# Source code for nerfstudio.utils.math

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Math Helper Functions"""

import itertools
import math
from dataclasses import dataclass
from typing import Literal, Tuple

import torch
from jaxtyping import Bool, Float, Int
from torch import Tensor

from nerfstudio.data.scene_box import OrientedBox

@dataclass
class Gaussians:
    """Container for a batch of multivariate Gaussian distributions.

    Args:
        mean: Per-element mean vectors, shape ``(*batch, dim)``.
        cov: Per-element covariance matrices, shape ``(*batch, dim, dim)``.
    """

    # Centre of each Gaussian.
    mean: Float[Tensor, "*batch dim"]
    # Covariance matrix of each Gaussian.
    cov: Float[Tensor, "*batch dim dim"]
def compute_3d_gaussian(
    directions: Float[Tensor, "*batch 3"],
    means: Float[Tensor, "*batch 3"],
    dir_variance: Float[Tensor, "*batch 1"],
    radius_variance: Float[Tensor, "*batch 1"],
) -> Gaussians:
    """Build an oriented 3D Gaussian whose principal axis follows ``directions``.

    The covariance is ``dir_variance * d d^T + radius_variance * (I - d d^T / |d|^2)``.
    The first term spreads mass along the axis; it is not normalised, so it scales with
    ``|d|^2``. The second term spreads mass over the plane orthogonal to the axis.

    Args:
        directions: Axis of each Gaussian.
        means: Centre of each Gaussian.
        dir_variance: Variance along the axis.
        radius_variance: Variance in the plane perpendicular to the axis.

    Returns:
        Gaussians: The oriented Gaussians; ``cov`` has shape ``(*batch, 3, 3)``.
    """
    col = directions[..., :, None]
    row = directions[..., None, :]
    # Squared axis length, clamped so a zero-length axis does not divide by zero.
    length_sq = torch.clamp(torch.sum(directions**2, dim=-1, keepdim=True), min=1e-10)
    identity = torch.eye(directions.shape[-1], device=directions.device)
    # Projector onto the orthogonal complement of the axis.
    perpendicular = identity - col * (directions / length_sq)[..., None, :]
    along = col * row
    return Gaussians(
        mean=means,
        cov=dir_variance[..., None] * along + radius_variance[..., None] * perpendicular,
    )
def cylinder_to_gaussian(
    origins: Float[Tensor, "*batch 3"],
    directions: Float[Tensor, "*batch 3"],
    starts: Float[Tensor, "*batch 1"],
    ends: Float[Tensor, "*batch 1"],
    radius: Float[Tensor, "*batch 1"],
) -> Gaussians:
    """Approximate each cylinder segment along a ray with a single Gaussian.

    The segment covers ``origins + t * directions`` for ``t`` in ``[starts, ends]``.
    The Gaussian is centred at the segment midpoint. The axial variance is that of a
    uniform distribution of length ``ends - starts``, i.e. ``length**2 / 12``. The
    radial variance is ``radius**2 / 4``, the per-axis variance of a uniform disc.

    Args:
        origins: Origins of the cylinders.
        directions: Axis direction of the cylinders.
        starts: Start distance of each cylinder along its axis.
        ends: End distance of each cylinder along its axis.
        radius: Radius of each cylinder.

    Returns:
        Gaussians: One Gaussian approximation per cylinder.
    """
    midpoint_t = (starts + ends) / 2.0
    centres = origins + directions * midpoint_t
    axial_var = (ends - starts) ** 2 / 12
    radial_var = radius**2 / 4.0
    return compute_3d_gaussian(directions, centres, axial_var, radial_var)
def conical_frustum_to_gaussian(
    origins: Float[Tensor, "*batch 3"],
    directions: Float[Tensor, "*batch 3"],
    starts: Float[Tensor, "*batch 1"],
    ends: Float[Tensor, "*batch 1"],
    radius: Float[Tensor, "*batch 1"],
) -> Gaussians:
    """Approximate each conical frustum along a ray with a single Gaussian.

    Uses the numerically stable parameterisation from the mip-NeRF paper. The frustum
    is described by its midpoint ``t_mid`` and half-width ``half_width`` along the ray,
    rather than by its raw start/end distances.

    Args:
        origins: Origins of the cones.
        directions: Axis direction of the frustums.
        starts: Start distance of each frustum.
        ends: End distance of each frustum.
        radius: Cone radius at distance 1 from the origin.

    Returns:
        Gaussians: One Gaussian approximation per frustum.
    """
    t_mid = (starts + ends) / 2.0
    half_width = (ends - starts) / 2.0

    mid_sq = t_mid**2
    hw_sq = half_width**2
    hw_4 = half_width**4
    # Shared denominator of the mip-NeRF closed-form moments.
    denom = 3 * mid_sq + hw_sq

    t_mean = t_mid + (2.0 * t_mid * hw_sq) / denom
    var_along = hw_sq / 3 - (4 / 15) * ((hw_4 * (12 * mid_sq - hw_sq)) / denom**2)
    var_across = radius**2 * (mid_sq / 4 + (5 / 12) * hw_sq - 4 / 15 * hw_4 / denom)

    return compute_3d_gaussian(directions, origins + directions * t_mean, var_along, var_across)
def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
    """Return ``E[sin(y)]`` for ``y ~ N(x_means, x_vars)``.

    Uses the closed form ``E[sin(y)] = exp(-var / 2) * sin(mean)``: the sine of the
    mean, attenuated by the variance.

    Args:
        x_means: Means of the Gaussians.
        x_vars: Variances of the Gaussians.

    Returns:
        torch.Tensor: Element-wise expected value of ``sin``.
    """
    attenuation = torch.exp(-0.5 * x_vars)
    return attenuation * torch.sin(x_means)
# @torch_compile(dynamic=True, mode="reduce-overhead", backend="eager")
def intersect_aabb(
    origins: torch.Tensor,
    directions: torch.Tensor,
    aabb: torch.Tensor,
    max_bound: float = 1e10,
    invalid_value: float = 1e10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Intersect rays with an axis-aligned bounding box (slab method).

    Args:
        origins: ``[..., 3]`` ray origins.
        directions: ``[..., 3]`` ray directions. Distances are measured in units of the
            direction length. Zero components are allowed, including for rays whose
            origin lies exactly on a face plane.
        aabb: ``[6]`` box as ``[x_min, y_min, z_min, x_max, y_max, z_max]``.
        max_bound: Upper clamp applied to both returned distances.
        invalid_value: Value written to both outputs for rays that miss the box.

    Returns:
        ``(t_min, t_max)``: entry and exit distances along each ray, each of shape
        ``[...]`` and clamped to ``[0, max_bound]``. Rays whose clamped interval is
        empty (a miss, or the box lies behind the origin) get ``invalid_value`` in both.
    """
    tx_min = (aabb[:3] - origins) / directions
    tx_max = (aabb[3:] - origins) / directions

    # Per-axis entry/exit parameters, regardless of the ray's sign along that axis.
    t_near = torch.minimum(tx_min, tx_max)
    t_far = torch.maximum(tx_min, tx_max)

    # A zero direction component with the origin exactly on that slab's plane yields
    # 0/0 = NaN. NaN would propagate through the reductions below and poison the ray.
    # Such an axis places no constraint on t, so use the unbounded interval instead.
    t_near = torch.where(torch.isnan(t_near), -math.inf, t_near)
    t_far = torch.where(torch.isnan(t_far), math.inf, t_far)

    t_min = t_near.amax(dim=-1)
    t_max = t_far.amin(dim=-1)

    t_min = torch.clamp(t_min, min=0, max=max_bound)
    t_max = torch.clamp(t_max, min=0, max=max_bound)

    # Empty interval: the ray misses the box or the box is entirely behind the origin.
    miss = t_max <= t_min
    t_min = torch.where(miss, invalid_value, t_min)
    t_max = torch.where(miss, invalid_value, t_max)
    return t_min, t_max
def intersect_obb(
    origins: torch.Tensor,
    directions: torch.Tensor,
    obb: OrientedBox,
    max_bound: float = 1e10,
    invalid_value: float = 1e10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Intersect rays with an oriented bounding box (OBB).

    The rays are mapped into the box's local frame and intersected with
    ``intersect_aabb``. In that frame the box is axis-aligned and centred at the origin.

    Args:
        origins: ``[..., 3]`` ray origins in world coordinates.
        directions: ``[..., 3]`` ray directions in world coordinates.
        obb: Box providing ``R`` (``[3, 3]`` rotation matrix), ``T`` (``[3]``
            translation vector) and ``S`` (``[3]`` full side lengths).
        max_bound: Maximum value of ``t_max``.
        invalid_value: Value to return in case of no intersection.

    Returns:
        ``(t_min, t_max)``: entry and exit distances along each ray, as returned by
        ``intersect_aabb``.
    """
    R, T, S = obb.R, obb.T, obb.S

    # Box-to-world homogeneous transform; its inverse maps world points into the box frame.
    box2world = torch.eye(4, device=origins.device, dtype=origins.dtype)
    box2world[:3, :3] = R
    box2world[:3, 3] = T
    world2box = torch.inverse(box2world)

    # Points take the full affine transform (homogeneous coordinate 1).
    origins_h = torch.cat((origins, torch.ones_like(origins[..., :1])), dim=-1)
    local_origins = (origins_h @ world2box.T)[..., :3]
    # Directions are free vectors: apply only the linear part.
    local_directions = directions @ world2box[:3, :3].T

    # In the box frame the box spans [-S/2, S/2] on each axis.
    aabb = torch.cat((-S / 2, S / 2))
    return intersect_aabb(
        local_origins, local_directions, aabb, max_bound=max_bound, invalid_value=invalid_value
    )
def safe_normalize(
    vectors: Float[Tensor, "*batch_dim N"],
    eps: float = 1e-10,
) -> Float[Tensor, "*batch_dim N"]:
    """Scale each vector along the last axis to (approximately) unit length.

    ``eps`` is added to the norm, so a zero vector maps to zero instead of NaN.

    Args:
        vectors: Vectors to normalize.
        eps: Value added to the norm before dividing.

    Returns:
        The normalized vectors, same shape as ``vectors``.
    """
    lengths = vectors.norm(dim=-1, keepdim=True)
    return vectors / (lengths + eps)
def masked_reduction(
    input_tensor: Float[Tensor, "*batch"],
    mask: Tensor,
    reduction_type: Literal["image", "batch"],
) -> Tensor:
    """Reduce ``input_tensor`` to a scalar, normalising by ``mask``.

    Args:
        input_tensor: Values to reduce. Not modified.
        mask: Expected to match the shape of ``input_tensor``. It acts as a per-element
            normaliser, and zero entries mark elements with nothing valid.
        reduction_type: Selects the normalisation:
            - ``"batch"``: divide the total of ``input_tensor`` by the total of ``mask``.
            - ``"image"``: divide each element by its own non-zero mask value, then
              average.

    Returns:
        The reduced tensor. In ``"batch"`` mode with an all-zero mask, a zero scalar of
        ``input_tensor``'s dtype is returned. Any other ``reduction_type`` returns
        ``input_tensor`` unchanged.
    """
    if reduction_type == "batch":
        divisor = torch.sum(mask)
        # Nothing valid in the whole batch: define the result as 0 instead of dividing by 0.
        if divisor == 0:
            return torch.zeros((), dtype=input_tensor.dtype, device=input_tensor.device)
        return torch.sum(input_tensor) / divisor
    if reduction_type == "image":
        valid = mask != 0
        # Divide by 1 where the mask is zero, so those entries pass through unchanged.
        # This never produces inf/NaN values (or NaN gradients); it works for any
        # shape and avoids mutating the caller's tensor in place.
        safe_divisor = torch.where(valid, mask, torch.ones_like(mask))
        return torch.mean(input_tensor / safe_divisor)
    return input_tensor
def normalized_depth_scale_and_shift(
    prediction: Float[Tensor, "1 32 mult"],
    target: Float[Tensor, "1 32 mult"],
    mask: Bool[Tensor, "1 32 mult"],
):
    """Solve, per image, for the scale and shift that best align ``prediction`` to ``target``.

    Monocular depth networks predict depth only up to an affine transform. This
    function finds ``scale`` and ``shift`` minimising the masked squared error
    ``sum(mask * (scale * prediction + shift - target)**2)`` over dims 1 and 2. The 2x2
    normal equations are solved in closed form (Cramer's rule). See supplementary
    section A2 ("Depth Consistency Loss") of MonoSDF (arXiv:2206.00665).

    Args:
        prediction: Predicted depth maps; dims 1 and 2 are summed over.
        target: Ground-truth depth maps, same shape as ``prediction``.
        mask: Valid-pixel mask, same shape as ``prediction``.

    Returns:
        ``(scale, shift)``, one value per leading index. Both are 0 where the system is
        singular (determinant exactly 0, e.g. no valid pixels).
    """

    def reduce_hw(values: Tensor) -> Tensor:
        return torch.sum(values, (1, 2))

    weighted_pred = mask * prediction
    # Normal equations: [[a_00, a_01], [a_01, a_11]] @ [scale, shift] = [b_0, b_1].
    a_00 = reduce_hw(weighted_pred * prediction)
    a_01 = reduce_hw(weighted_pred)
    a_11 = reduce_hw(mask)
    b_0 = reduce_hw(weighted_pred * target)
    b_1 = reduce_hw(mask * target)

    det = a_00 * a_11 - a_01 * a_01
    solvable = torch.nonzero(det)

    scale = torch.zeros_like(b_0)
    shift = torch.zeros_like(b_1)
    # Only fill non-singular entries; singular ones keep their zero initialisation.
    scale[solvable] = (a_11[solvable] * b_0[solvable] - a_01[solvable] * b_1[solvable]) / det[solvable]
    shift[solvable] = (a_00[solvable] * b_1[solvable] - a_01[solvable] * b_0[solvable]) / det[solvable]
    return scale, shift
def columnwise_squared_l2_distance(
    x: Float[Tensor, "*M N"],
    y: Float[Tensor, "*M N"],
) -> Float[Tensor, "N N"]:
    """Squared Euclidean distance between every column of ``x`` and every column of ``y``.

    Adapted from MultiNeRF (``internal/geopoly.py``).

    Args:
        x: ``[M, N]`` tensor; each column is one point.
        y: ``[M, K]`` tensor; each column is one point.

    Returns:
        sq_dist: ``[N, K]`` tensor where ``sq_dist[i, j] = ||x[:, i] - y[:, j]||^2``.
    """
    # Expand ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so all pairs come from one matmul.
    x_sq = (x**2).sum(dim=0)
    y_sq = (y**2).sum(dim=0)
    cross = x.T @ y
    return x_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2 * cross
def _compute_tesselation_weights(v: int) -> Tensor: """Tesselate the vertices of a triangle by a factor of `v`. Adapted from Args: v: int, the factor of the tesselation (v==1 is a no-op to the triangle). Returns: weights: tesselated weights. """ if v < 1: raise ValueError(f"v {v} must be >= 1") int_weights = [] for i in range(v + 1): for j in range(v + 1 - i): int_weights.append((i, j, v - (i + j))) int_weights = torch.FloatTensor(int_weights) weights = int_weights / v # Barycentric weights. return weights def _tesselate_geodesic( vertices: Float[Tensor, "N 3"], faces: Float[Tensor, "M 3"], v: int, eps: float = 1e-4, ) -> Tensor: """Tesselate the vertices of a geodesic polyhedron. Adapted from Args: vertices: tensor of floats, the vertex coordinates of the geodesic. faces: tensor of ints, the indices of the vertices of base_verts that constitute eachface of the polyhedra. v: int, the factor of the tesselation (v==1 is a no-op). eps: float, a small value used to determine if two vertices are the same. Returns: verts: a tensor of floats, the coordinates of the tesselated vertices. """ tri_weights = _compute_tesselation_weights(v) verts = [] for face in faces: new_verts = torch.matmul(tri_weights, vertices[face, :]) new_verts /= torch.sqrt(torch.sum(new_verts**2, 1, keepdim=True)) verts.append(new_verts) verts = torch.concatenate(verts, 0) sq_dist = columnwise_squared_l2_distance(verts.T, verts.T) assignment = torch.tensor([torch.min(torch.argwhere(d <= eps)) for d in sq_dist]) unique = torch.unique(assignment) verts = verts[unique, :] return verts
def generate_polyhedron_basis(
    basis_shape: Literal["icosahedron", "octahedron"],
    angular_tesselation: int,
    remove_symmetries: bool = True,
    eps: float = 1e-4,
) -> Tensor:
    """Generate a 3D basis by tesselating a geodesic polyhedron.

    The basis is used to construct Fourier features for positional encoding, as in the
    Mip-NeRF 360 paper. Adapted from MultiNeRF (``internal/geopoly.py``).

    Args:
        basis_shape: Starting polyhedron, either ``"icosahedron"`` or ``"octahedron"``.
        angular_tesselation: Number of subdivisions per face edge. Must be >= 1; a value
            of 1 leaves the polyhedron unchanged.
        remove_symmetries: If True, keep only one vector of each antipodal pair
            ``(v, -v)``. Usually desirable: otherwise projections onto the basis
            contain redundant negated copies of each other.
        eps: Squared-distance threshold used to detect antipodal pairs.

    Returns:
        basis: Tensor of shape ``[n, 3]``; each row is a unit-length basis direction.

    Raises:
        ValueError: If ``basis_shape`` is not ``"icosahedron"`` or ``"octahedron"``.
    """
    if basis_shape == "icosahedron":
        a = (math.sqrt(5) + 1) / 2  # golden ratio
        # Divide by sqrt(a + 2) to get unit-length vertices (1 + a**2 == a + 2).
        verts = torch.FloatTensor(
            [
                (-1, 0, a),
                (1, 0, a),
                (-1, 0, -a),
                (1, 0, -a),
                (0, a, 1),
                (0, a, -1),
                (0, -a, 1),
                (0, -a, -1),
                (a, 1, 0),
                (-a, 1, 0),
                (a, -1, 0),
                (-a, -1, 0),
            ]
        ) / math.sqrt(a + 2)
        faces = torch.tensor(
            [
                (0, 4, 1),
                (0, 9, 4),
                (9, 5, 4),
                (4, 5, 8),
                (4, 8, 1),
                (8, 10, 1),
                (8, 3, 10),
                (5, 3, 8),
                (5, 2, 3),
                (2, 7, 3),
                (7, 10, 3),
                (7, 6, 10),
                (7, 11, 6),
                (11, 0, 6),
                (0, 1, 6),
                (6, 1, 10),
                (9, 0, 11),
                (9, 11, 2),
                (9, 2, 5),
                (7, 2, 11),
            ]
        )
    elif basis_shape == "octahedron":
        verts = torch.FloatTensor([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)])
        # The octahedron vertices at squared distance 2 from a cube corner form one face.
        corners = torch.FloatTensor(list(itertools.product([-1, 1], repeat=3)))
        pairs = torch.argwhere(columnwise_squared_l2_distance(corners.T, verts.T) == 2)
        # NOTE(review): reshape([3, -1]).T does not literally group the triples per corner.
        # For this fixed vertex/corner ordering it still yields all 8 faces, in a permuted
        # order. Changing it would reorder the basis, so it is kept as is.
        faces, _ = torch.sort(torch.reshape(pairs[:, 1], [3, -1]).T, 1)
    else:
        raise ValueError(f"basis_shape {basis_shape} must be 'icosahedron' or 'octahedron'")

    verts = _tesselate_geodesic(verts, faces, angular_tesselation)

    if remove_symmetries:
        # match[i, j] is True when verts[i] ~= -verts[j]. triu keeps only pairs with j >= i,
        # so row i survives only if its antipode has a higher index. This keeps the
        # lower-indexed member of each antipodal pair.
        match = columnwise_squared_l2_distance(verts.T, -verts.T) < eps
        verts = verts[torch.any(torch.triu(match), 1), :]

    # Reverse the coordinate order of each row; presumably this mirrors the reference
    # implementation's column order. TODO confirm.
    basis = verts.flip(-1)
    return basis
def random_quat_tensor(N: int) -> Float[Tensor, "*batch 4"]:
    """Sample ``N`` random unit quaternions.

    Uses Shoemake's construction from three independent uniform draws ``u, v, w`` in
    ``[0, 1)``. The four components are
    ``(sqrt(1-u) sin 2πv, sqrt(1-u) cos 2πv, sqrt(u) sin 2πw, sqrt(u) cos 2πw)``.

    Args:
        N: Number of quaternions to generate.

    Returns:
        Tensor of shape ``(N, 4)``; each row has unit norm.
    """
    # Drawn one after another (u, then v, then w) from the global RNG.
    u, v, w = (torch.rand(N) for _ in range(3))
    scale_a = torch.sqrt(1 - u)
    scale_b = torch.sqrt(u)
    angle_a = 2 * math.pi * v
    angle_b = 2 * math.pi * w
    components = [
        scale_a * torch.sin(angle_a),
        scale_a * torch.cos(angle_a),
        scale_b * torch.sin(angle_b),
        scale_b * torch.cos(angle_b),
    ]
    return torch.stack(components, dim=-1)
def k_nearest_sklearn(
    x: torch.Tensor, k: int, metric: str = "euclidean"
) -> Tuple[Float[Tensor, "*batch k"], Int[Tensor, "*batch k"]]:
    """Find the ``k`` nearest neighbours of every point in ``x``, excluding the point itself.

    Uses scikit-learn's ``NearestNeighbors``.

    Args:
        x: ``[num_points, dim]`` points. It may live on any device and may require
            grad; it is detached and copied to host memory for scikit-learn.
        k: Number of neighbours to return per point.
        metric: Distance metric passed to ``NearestNeighbors``.

    Returns:
        distances: ``[num_points, k]`` float32 CPU tensor of neighbour distances.
        indices: ``[num_points, k]`` int64 CPU tensor of neighbour indices into ``x``.
    """
    # Imported lazily so scikit-learn is only required when this function is used.
    from sklearn.neighbors import NearestNeighbors

    # Detach first: .numpy() raises on tensors that require grad.
    x_np = x.detach().cpu().numpy()
    # Ask for k + 1 neighbours because each point is normally its own nearest neighbour.
    nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric=metric).fit(x_np)
    distances, indices = nn_model.kneighbors(x_np)
    # Drop column 0 (the query point itself).
    # NOTE(review): with exact duplicate points, column 0 may be a duplicate rather than
    # the query point, so the query's own index can then appear in the result.
    return (
        torch.tensor(distances[:, 1:], dtype=torch.float32),
        torch.tensor(indices[:, 1:], dtype=torch.int64),
    )