Source code for nerfstudio.field_components.encodings

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Encoding functions
"""

import itertools
from abc import abstractmethod
from typing import Literal, Optional, Sequence

import numpy as np
import torch
import torch.nn.functional as F
from jaxtyping import Float, Int, Shaped
from torch import Tensor, nn

from nerfstudio.field_components.base_field_component import FieldComponent
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin, generate_polyhedron_basis
from nerfstudio.utils.printing import print_tcnn_speed_warning


class Encoding(FieldComponent):
    """Base class for components that encode an input tensor.

    Intended to be subclassed; concrete encodings provide :meth:`forward`.

    Args:
        in_dim: Input dimension of tensor. Must be greater than zero.

    Raises:
        ValueError: If ``in_dim`` is zero or negative.
    """

    def __init__(self, in_dim: int) -> None:
        # Validate before handing the dimension to the parent component.
        if in_dim <= 0:
            raise ValueError("Input dimension should be greater than zero")
        super().__init__(in_dim=in_dim)

    @classmethod
    def get_tcnn_encoding_config(cls) -> dict:
        """Get the encoding configuration for tcnn if implemented.

        Raises:
            NotImplementedError: Always; the base encoding has no TCNN implementation.
        """
        raise NotImplementedError("Encoding does not have a TCNN implementation")

    @abstractmethod
    def forward(self, in_tensor: Shaped[Tensor, "*bs input_dim"]) -> Shaped[Tensor, "*bs output_dim"]:
        """Encode the input tensor and return the processed tensor.

        Args:
            in_tensor: the input tensor to process
        """
        raise NotImplementedError
class Identity(Encoding):
    """Pass-through encoding: the output is the input, unchanged."""

    def get_out_dim(self) -> int:
        """Return the output dimension, which equals the input dimension.

        Raises:
            ValueError: If the input dimension has not been set.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        out_dim = self.in_dim
        return out_dim

    def forward(self, in_tensor: Shaped[Tensor, "*bs input_dim"]) -> Shaped[Tensor, "*bs output_dim"]:
        """Return ``in_tensor`` itself (no copy is made)."""
        return in_tensor
class ScalingAndOffset(Encoding):
    """Affine encoding that computes ``scaling * x + offset``.

    Args:
        in_dim: Input dimension of tensor
        scaling: Scaling applied to tensor.
        offset: Offset applied to tensor.
    """

    def __init__(self, in_dim: int, scaling: float = 1.0, offset: float = 0.0) -> None:
        super().__init__(in_dim)

        self.scaling = scaling
        self.offset = offset

    def get_out_dim(self) -> int:
        """Return the output dimension, which equals the input dimension.

        Raises:
            ValueError: If the input dimension has not been set.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        out_dim = self.in_dim
        return out_dim

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Scale the input, then add the offset."""
        scaled = self.scaling * in_tensor
        return scaled + self.offset
class NeRFEncoding(Encoding):
    """Multi-scale sinusoidal encodings. Support ``integrated positional encodings`` if covariances are provided.
    Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp.

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        include_input: Append the input coordinate to the encoding
        implementation: Backend to use. With ``"tcnn"`` the constructor asserts ``min_freq_exp == 0`` and
            ``max_freq_exp == num_frequencies - 1``. If tcnn is requested but not available, a speed warning
            is printed and the torch implementation is used instead.
    """

    def __init__(
        self,
        in_dim: int,
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        include_input: bool = False,
        implementation: Literal["tcnn", "torch"] = "torch",
    ) -> None:
        super().__init__(in_dim)

        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input

        # Stays None unless a tcnn backend is successfully created below.
        self.tcnn_encoding = None
        if implementation != "tcnn":
            return
        if not TCNN_EXISTS:
            print_tcnn_speed_warning("NeRFEncoding")
            return

        assert min_freq_exp == 0, "tcnn only supports min_freq_exp = 0"
        assert max_freq_exp == num_frequencies - 1, "tcnn only supports max_freq_exp = num_frequencies - 1"
        self.tcnn_encoding = tcnn.Encoding(
            n_input_dims=in_dim,
            encoding_config=self.get_tcnn_encoding_config(num_frequencies=self.num_frequencies),
        )

    @classmethod
    def get_tcnn_encoding_config(cls, num_frequencies) -> dict:
        """Get the encoding configuration for tcnn if implemented"""
        return {"otype": "Frequency", "n_frequencies": num_frequencies}

    def get_out_dim(self) -> int:
        """Return ``in_dim * num_frequencies * 2``, plus ``in_dim`` when the input is appended.

        Raises:
            ValueError: If the input dimension has not been set.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        out_dim = self.in_dim * self.num_frequencies * 2
        return out_dim + self.in_dim if self.include_input else out_dim

    def pytorch_fwd(
        self,
        in_tensor: Float[Tensor, "*bs input_dim"],
        covs: Optional[Float[Tensor, "*bs input_dim input_dim"]] = None,
    ) -> Float[Tensor, "*bs output_dim"]:
        """Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed
            in mip-NeRF.

        Args:
            in_tensor: For best performance, the input tensor should be between 0 and 1.
            covs: Covariances of input points.
        Returns:
            Output values will be between -1 and 1
        """
        freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
        angles = (2 * torch.pi * in_tensor)[..., None] * freqs  # [..., "input_dim", "num_scales"]
        leading_shape = angles.shape[:-2]
        angles = angles.view(*leading_shape, -1)  # [..., "input_dim" * "num_scales"]

        # Shifting the phase by pi/2 turns the second half into cosines.
        phases = torch.cat([angles, angles + torch.pi / 2.0], dim=-1)
        if covs is None:
            return torch.sin(phases)

        # Only the diagonal of the covariance is used; variance scales with frequency squared.
        input_var = torch.diagonal(covs, dim1=-2, dim2=-1)[..., :, None] * freqs[None, :] ** 2
        input_var = input_var.reshape((*input_var.shape[:-2], -1))
        return expected_sin(phases, torch.cat([input_var, input_var], dim=-1))

    def forward(
        self, in_tensor: Float[Tensor, "*bs input_dim"], covs: Optional[Float[Tensor, "*bs input_dim input_dim"]] = None
    ) -> Float[Tensor, "*bs output_dim"]:
        """Encode ``in_tensor`` with the configured backend.

        Args:
            in_tensor: Points to encode.
            covs: Covariances of input points. Only used by the torch backend; the tcnn backend is called
                with ``in_tensor`` alone.
        """
        if self.tcnn_encoding is None:
            encoded_inputs = self.pytorch_fwd(in_tensor, covs)
        else:
            encoded_inputs = self.tcnn_encoding(in_tensor)

        if not self.include_input:
            return encoded_inputs
        return torch.cat([encoded_inputs, in_tensor], dim=-1)
class FFEncoding(Encoding):
    """Fourier Feature encoding. Supports integrated encodings.

    Args:
        in_dim: Input dimension of tensor
        basis: Basis matrix from which to construct the Fourier features.
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        include_input: Append the input coordinate to the encoding
    """

    def __init__(
        self,
        in_dim: int,
        basis: Float[Tensor, "M N"],
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        include_input: bool = False,
    ) -> None:
        super().__init__(in_dim)

        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input
        # Registered as a buffer (not a parameter) under the name ``b_matrix``.
        self.register_buffer(name="b_matrix", tensor=basis)

    def get_out_dim(self) -> int:
        """Return ``basis_columns * num_frequencies * 2``, plus ``in_dim`` when the input is appended.

        Raises:
            ValueError: If the input dimension has not been set.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        assert isinstance(self.b_matrix, Tensor)
        out_dim = self.b_matrix.shape[1] * self.num_frequencies * 2
        return out_dim + self.in_dim if self.include_input else out_dim

    def forward(
        self,
        in_tensor: Float[Tensor, "*bs input_dim"],
        covs: Optional[Float[Tensor, "*bs input_dim input_dim"]] = None,
    ) -> Float[Tensor, "*bs output_dim"]:
        """Calculates FF encoding. If covariances are provided the encodings will be integrated as proposed
            in mip-NeRF.

        Args:
            in_tensor: For best performance, the input tensor should be between 0 and 1.
            covs: Covariances of input points.

        Returns:
            Output values will be between -1 and 1
        """
        # Project the [0, 2pi]-scaled input onto the basis columns.
        projected = (2 * torch.pi * in_tensor) @ self.b_matrix  # [..., "num_basis"]
        freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
        angles = projected[..., None] * freqs  # [..., "num_basis", "num_scales"]
        leading_shape = angles.shape[:-2]
        angles = angles.view(*leading_shape, -1)  # [..., "num_basis" * "num_scales"]

        # Shifting the phase by pi/2 turns the second half into cosines.
        phases = torch.cat([angles, angles + torch.pi / 2.0], dim=-1)
        if covs is None:
            encoded_inputs = torch.sin(phases)
        else:
            # Per-basis-column variance: column-wise b^T @ cov @ b.
            input_var = torch.sum((covs @ self.b_matrix) * self.b_matrix, -2)
            input_var = input_var[..., :, None] * freqs[None, :] ** 2
            input_var = input_var.reshape((*input_var.shape[:-2], -1))
            encoded_inputs = expected_sin(phases, torch.cat([input_var, input_var], dim=-1))

        if not self.include_input:
            return encoded_inputs
        return torch.cat([encoded_inputs, in_tensor], dim=-1)
class RFFEncoding(FFEncoding):
    """Random Fourier Feature encoding. Supports integrated encodings.

    The basis is sampled once at construction from a zero-mean Gaussian and used at a single
    frequency (both frequency exponents are 0).

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoding frequencies
        scale: Std of Gaussian to sample frequencies. Must be greater than zero
        include_input: Append the input coordinate to the encoding

    Raises:
        ValueError: If ``scale`` is not greater than zero.
    """

    def __init__(self, in_dim: int, num_frequencies: int, scale: float, include_input: bool = False) -> None:
        if not scale > 0:
            raise ValueError("RFF encoding scale should be greater than zero")

        # One random basis column per requested frequency.
        random_basis = torch.normal(mean=0, std=scale, size=(in_dim, num_frequencies))
        super().__init__(
            in_dim,
            random_basis,
            1,
            0.0,
            0.0,
            include_input,
        )
class PolyhedronFFEncoding(FFEncoding):
    """Fourier Feature encoding using polyhedron basis as proposed by mip-NeRF360. Supports integrated encodings.

    The input dimension is fixed to 3.

    Args:
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        basis_shape: Shape of polyhedron basis. Either "octahedron" or "icosahedron"
        basis_subdivisions: Number of times to tesselate the polyhedron.
        include_input: Append the input coordinate to the encoding
    """

    def __init__(
        self,
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        basis_shape: Literal["octahedron", "icosahedron"] = "octahedron",
        basis_subdivisions: int = 1,
        include_input: bool = False,
    ) -> None:
        polyhedron_basis = generate_polyhedron_basis(basis_shape, basis_subdivisions)
        # The parent takes the transposed basis.
        super().__init__(
            3,
            polyhedron_basis.T,
            num_frequencies,
            min_freq_exp,
            max_freq_exp,
            include_input,
        )
class HashEncoding(Encoding):
    """Hash encoding

    Args:
        num_levels: Number of feature grids.
        min_res: Resolution of smallest feature grid.
        max_res: Resolution of largest feature grid.
        log2_hashmap_size: Size of hash map is 2^log2_hashmap_size.
        features_per_level: Number of features per level.
        hash_init_scale: Value to initialize hash grid.
        implementation: Implementation of hash encoding. Fallback to torch if tcnn not available.
        interpolation: Interpolation override for tcnn hashgrid. Not supported for torch unless linear.
    """

    def __init__(
        self,
        num_levels: int = 16,
        min_res: int = 16,
        max_res: int = 1024,
        log2_hashmap_size: int = 19,
        features_per_level: int = 2,
        hash_init_scale: float = 0.001,
        implementation: Literal["tcnn", "torch"] = "tcnn",
        interpolation: Optional[Literal["Nearest", "Linear", "Smoothstep"]] = None,
    ) -> None:
        super().__init__(in_dim=3)

        self.num_levels = num_levels
        self.min_res = min_res
        self.features_per_level = features_per_level
        self.hash_init_scale = hash_init_scale
        self.log2_hashmap_size = log2_hashmap_size
        self.hash_table_size = 2**log2_hashmap_size

        # Resolutions grow geometrically from min_res to max_res across the levels.
        levels = torch.arange(num_levels)
        self.growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1)) if num_levels > 1 else 1
        self.scalings = torch.floor(min_res * self.growth_factor**levels)

        # Each level owns a contiguous slice of the shared hash table.
        self.hash_offset = levels * self.hash_table_size

        self.tcnn_encoding = None
        self.hash_table = torch.empty(0)

        wants_tcnn = implementation == "tcnn"
        if wants_tcnn and TCNN_EXISTS:
            self.tcnn_encoding = tcnn.Encoding(
                n_input_dims=3,
                encoding_config=self.get_tcnn_encoding_config(
                    num_levels=self.num_levels,
                    features_per_level=self.features_per_level,
                    log2_hashmap_size=self.log2_hashmap_size,
                    min_res=self.min_res,
                    growth_factor=self.growth_factor,
                    interpolation=interpolation,
                ),
            )
        else:
            if wants_tcnn:
                # tcnn was requested but is not installed: warn, then fall back to torch.
                print_tcnn_speed_warning("HashEncoding")
            if wants_tcnn or implementation == "torch":
                self.build_nn_modules()

        if self.tcnn_encoding is None:
            assert (
                interpolation is None or interpolation == "Linear"
            ), f"interpolation '{interpolation}' is not supported for torch encoding backend"

    def build_nn_modules(self) -> None:
        """Initialize the torch version of the hash encoding."""
        num_entries = self.hash_table_size * self.num_levels
        # Uniform values in [-1, 1), scaled by hash_init_scale.
        self.hash_table = torch.rand(size=(num_entries, self.features_per_level)) * 2 - 1
        self.hash_table *= self.hash_init_scale
        self.hash_table = nn.Parameter(self.hash_table)

    @classmethod
    def get_tcnn_encoding_config(
        cls, num_levels, features_per_level, log2_hashmap_size, min_res, growth_factor, interpolation=None
    ) -> dict:
        """Get the encoding configuration for tcnn if implemented"""
        config = {
            "otype": "HashGrid",
            "n_levels": num_levels,
            "n_features_per_level": features_per_level,
            "log2_hashmap_size": log2_hashmap_size,
            "base_resolution": min_res,
            "per_level_scale": growth_factor,
        }
        # Only override tcnn's interpolation when one was explicitly requested.
        if interpolation is not None:
            config["interpolation"] = interpolation
        return config

    def get_out_dim(self) -> int:
        """Return ``num_levels * features_per_level``."""
        return self.num_levels * self.features_per_level

    def hash_fn(self, in_tensor: Int[Tensor, "*bs num_levels 3"]) -> Shaped[Tensor, "*bs num_levels"]:
        """Returns hash tensor using method described in Instant-NGP

        Args:
            in_tensor: Tensor to be hashed
        """
        # Per-axis multipliers, combined with XOR and wrapped to the table size.
        primes = torch.tensor([1, 2654435761, 805459861], device=in_tensor.device)
        mixed = in_tensor * primes
        hashed = torch.bitwise_xor(mixed[..., 0], mixed[..., 1])
        hashed = torch.bitwise_xor(hashed, mixed[..., 2])
        hashed %= self.hash_table_size
        # Shift each level into its own slice of the shared table.
        hashed += self.hash_offset.to(hashed.device)
        return hashed

    def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Forward pass using pytorch. Significantly slower than TCNN implementation."""
        assert in_tensor.shape[-1] == 3

        in_tensor = in_tensor[..., None, :]  # [..., 1, 3]
        scaled = in_tensor * self.scalings.view(-1, 1).to(in_tensor.device)  # [..., L, 3]
        upper = torch.ceil(scaled).type(torch.int32)
        lower = torch.floor(scaled).type(torch.int32)
        offset = scaled - lower

        ux, uy, uz = upper[..., 0:1], upper[..., 1:2], upper[..., 2:3]
        lx, ly, lz = lower[..., 0:1], lower[..., 1:2], lower[..., 2:3]

        def lookup(corner: Tensor) -> Tensor:
            """Fetch the features stored for a grid corner."""
            return self.hash_table[self.hash_fn(corner)]  # [..., num_levels, features_per_level]

        def corner(x: Tensor, y: Tensor, z: Tensor) -> Tensor:
            """Fetch the features for the corner assembled from per-axis indices."""
            return lookup(torch.cat([x, y, z], dim=-1))

        # The eight corners of the enclosing grid cell.
        f_0 = lookup(upper)
        f_1 = corner(ux, ly, uz)
        f_2 = corner(lx, ly, uz)
        f_3 = corner(lx, uy, uz)
        f_4 = corner(ux, uy, lz)
        f_5 = corner(ux, ly, lz)
        f_6 = lookup(lower)
        f_7 = corner(lx, uy, lz)

        wx, wy, wz = offset[..., 0:1], offset[..., 1:2], offset[..., 2:3]

        # Trilinear interpolation: blend along x, then y, then z.
        f_03 = f_0 * wx + f_3 * (1 - wx)
        f_12 = f_1 * wx + f_2 * (1 - wx)
        f_56 = f_5 * wx + f_6 * (1 - wx)
        f_47 = f_4 * wx + f_7 * (1 - wx)

        f0312 = f_03 * wy + f_12 * (1 - wy)
        f4756 = f_47 * wy + f_56 * (1 - wy)

        encoded_value = f0312 * wz + f4756 * (1 - wz)  # [..., num_levels, features_per_level]

        return torch.flatten(encoded_value, start_dim=-2, end_dim=-1)  # [..., num_levels * features_per_level]

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Encode ``in_tensor`` with tcnn when available, otherwise with the torch implementation."""
        if self.tcnn_encoding is None:
            return self.pytorch_fwd(in_tensor)
        return self.tcnn_encoding(in_tensor)
class TensorCPEncoding(Encoding):
    """Learned CANDECOMP/PARFAC (CP) decomposition encoding used in TensoRF

    Args:
        resolution: Resolution of grid.
        num_components: Number of components per dimension.
        init_scale: Initialization scale.
    """

    def __init__(self, resolution: int = 256, num_components: int = 24, init_scale: float = 0.1) -> None:
        super().__init__(in_dim=3)

        self.resolution = resolution
        self.num_components = num_components

        # TODO Learning rates should be different for these
        # One learnable line of features per axis.
        line_shape = (3, num_components, resolution, 1)
        self.line_coef = nn.Parameter(init_scale * torch.randn(line_shape))

    def get_out_dim(self) -> int:
        """Return the number of components."""
        return self.num_components

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Sample the three feature lines and multiply them element-wise.

        Args:
            in_tensor: Positions to encode. They are passed to ``F.grid_sample`` as sampling coordinates.

        Returns:
            Encoded positions of shape ``[..., num_components]``.
        """
        batch_shape = in_tensor.shape[:-1]

        axis_values = torch.stack([in_tensor[..., axis] for axis in (2, 1, 0)])  # [3, ...]
        # The lines have width 1, so the first sampling coordinate is fixed to zero.
        line_coord = torch.stack([torch.zeros_like(axis_values), axis_values], dim=-1)  # [3, ...., 2]

        # Stop gradients from going to sampler
        line_coord = line_coord.view(3, -1, 1, 2).detach()

        line_features = F.grid_sample(self.line_coef, line_coord, align_corners=True)  # [3, Components, -1, 1]

        combined = torch.prod(line_features, dim=0)
        combined = combined.view(self.num_components, *batch_shape)
        return torch.moveaxis(combined, 0, -1)  # [..., Components]

    @torch.no_grad()
    def upsample_grid(self, resolution: int) -> None:
        """Upsamples underlying feature grid

        The line coefficients are bilinearly interpolated to the new resolution and stored in a
        newly created ``Parameter``.

        Args:
            resolution: Target resolution.
        """
        upsampled = F.interpolate(self.line_coef.data, size=(resolution, 1), mode="bilinear", align_corners=True)

        self.line_coef = torch.nn.Parameter(upsampled)
        self.resolution = resolution
class TensorVMEncoding(Encoding):
    """Learned vector-matrix encoding proposed by TensoRF

    Args:
        resolution: Resolution of grid.
        num_components: Number of components per dimension.
        init_scale: Initialization scale.
    """

    plane_coef: Float[Tensor, "3 num_components resolution resolution"]
    line_coef: Float[Tensor, "3 num_components resolution 1"]

    def __init__(
        self,
        resolution: int = 128,
        num_components: int = 24,
        init_scale: float = 0.1,
    ) -> None:
        super().__init__(in_dim=3)

        self.resolution = resolution
        self.num_components = num_components

        # Three feature planes, each paired with a feature line.
        plane_shape = (3, num_components, resolution, resolution)
        line_shape = (3, num_components, resolution, 1)
        self.plane_coef = nn.Parameter(init_scale * torch.randn(plane_shape))
        self.line_coef = nn.Parameter(init_scale * torch.randn(line_shape))

    def get_out_dim(self) -> int:
        """Return ``3 * num_components`` (one set of components per plane/line pair)."""
        return self.num_components * 3

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Compute encoding for each position in in_positions

        Args:
            in_tensor: position inside bounds in range [-1,1],

        Returns: Encoded position
        """
        batch_shape = in_tensor.shape[:-1]

        # Each plane (xy, xz, yz) is paired with the line along the remaining axis (z, y, x).
        plane_coord = torch.stack([in_tensor[..., axes] for axes in ([0, 1], [0, 2], [1, 2])])  # [3,...,2]
        axis_values = torch.stack([in_tensor[..., axis] for axis in (2, 1, 0)])  # [3, ...]
        line_coord = torch.stack([torch.zeros_like(axis_values), axis_values], dim=-1)  # [3, ...., 2]

        # Stop gradients from going to sampler
        plane_coord = plane_coord.view(3, -1, 1, 2).detach()
        line_coord = line_coord.view(3, -1, 1, 2).detach()

        plane_features = F.grid_sample(self.plane_coef, plane_coord, align_corners=True)  # [3, Components, -1, 1]
        line_features = F.grid_sample(self.line_coef, line_coord, align_corners=True)  # [3, Components, -1, 1]

        combined = plane_features * line_features  # [3, Components, -1, 1]
        combined = combined.view(3 * self.num_components, *batch_shape)
        return torch.moveaxis(combined, 0, -1)  # [..., 3 * Components]

    @torch.no_grad()
    def upsample_grid(self, resolution: int) -> None:
        """Upsamples underlying feature grid

        Planes and lines are bilinearly interpolated to the new resolution and stored in newly
        created ``Parameter`` objects.

        Args:
            resolution: Target resolution.
        """
        new_planes = F.interpolate(
            self.plane_coef.data, size=(resolution, resolution), mode="bilinear", align_corners=True
        )
        new_lines = F.interpolate(self.line_coef.data, size=(resolution, 1), mode="bilinear", align_corners=True)

        self.plane_coef = torch.nn.Parameter(new_planes)
        self.line_coef = torch.nn.Parameter(new_lines)
        self.resolution = resolution
class TriplaneEncoding(Encoding):
    """Learned triplane encoding

    The encoding at [i,j,k] is an n dimensional vector obtained by combining the three n dimensional
    vectors at plane_coeff[i,j], plane_coeff[i,k], and plane_coeff[j,k], either by element-wise sum
    (the default) or by element-wise product, as selected by ``reduce``.

    This allows for marginally more expressivity than the TensorVMEncoding, and each component is self standing
    and symmetrical, unlike with VM decomposition where we needed one component with a vector along all the x, y, z
    directions for symmetry.

    This can be thought of as 3 planes of features perpendicular to the x, y, and z axes, respectively and intersecting
    at the origin, and the encoding being the element-wise combination of the element at the projection of [i, j, k]
    on these planes.

    The use for this is in representing a tensor decomp of a 4D embedding tensor: (x, y, z, feature_size)

    This will return a tensor of shape (bs:..., num_components)

    Args:
        resolution: Resolution of grid.
        num_components: The number of scalar triplanes to use (ie: output feature size)
        init_scale: The scale of the initial values of the planes
        reduce: Whether to use the element-wise product of the planes or the sum. Any value other than
            ``"product"`` is treated as ``"sum"``.
    """

    plane_coef: Float[Tensor, "3 num_components resolution resolution"]

    def __init__(
        self,
        resolution: int = 32,
        num_components: int = 64,
        init_scale: float = 0.1,
        reduce: Literal["sum", "product"] = "sum",
    ) -> None:
        super().__init__(in_dim=3)

        self.resolution = resolution
        self.num_components = num_components
        self.init_scale = init_scale
        self.reduce = reduce

        # One [num_components, resolution, resolution] feature plane per axis pair.
        self.plane_coef = nn.Parameter(
            self.init_scale * torch.randn((3, self.num_components, self.resolution, self.resolution))
        )

    def get_out_dim(self) -> int:
        """Return the number of components."""
        return self.num_components

    def forward(self, in_tensor: Float[Tensor, "*bs 3"]) -> Float[Tensor, "*bs num_components"]:
        """Sample features from this encoder.

        Args:
            in_tensor: Positions to encode. They are passed to ``F.grid_sample`` as sampling coordinates,
                which that function interprets as normalized to the range [-1, 1].

        Returns:
            Features of shape ``[*bs, num_components]``.
        """
        original_shape = in_tensor.shape
        in_tensor = in_tensor.reshape(-1, 3)

        # Project every point onto the xy, xz and yz planes.
        plane_coord = torch.stack([in_tensor[..., [0, 1]], in_tensor[..., [0, 2]], in_tensor[..., [1, 2]]], dim=0)

        # Stop gradients from going to sampler
        plane_coord = plane_coord.detach().view(3, -1, 1, 2)
        plane_features = F.grid_sample(
            self.plane_coef, plane_coord, align_corners=True
        )  # [3, num_components, flattened_bs, 1]

        # Combine the three planes, then move the batch dimension first.
        if self.reduce == "product":
            plane_features = plane_features.prod(0).squeeze(-1).T  # [flattened_bs, num_components]
        else:
            plane_features = plane_features.sum(0).squeeze(-1).T  # [flattened_bs, num_components]

        return plane_features.reshape(*original_shape[:-1], self.num_components)

    @torch.no_grad()
    def upsample_grid(self, resolution: int) -> None:
        """Upsamples underlying feature grid

        The planes are bilinearly interpolated to the new resolution and stored in a newly created
        ``Parameter``.

        Args:
            resolution: Target resolution.
        """
        plane_coef = F.interpolate(
            self.plane_coef.data, size=(resolution, resolution), mode="bilinear", align_corners=True
        )

        self.plane_coef = torch.nn.Parameter(plane_coef)
        self.resolution = resolution
class KPlanesEncoding(Encoding):
    """Learned K-Planes encoding

    A plane encoding supporting both 3D and 4D coordinates. With 3D coordinates this is similar to
    :class:`TriplaneEncoding`. With 4D coordinates, the encoding at point ``[i,j,k,q]`` is
    a n-dimensional vector computed as the elementwise product of 6 n-dimensional vectors at
    ``planes[i,j]``, ``planes[i,k]``, ``planes[i,q]``, ``planes[j,k]``, ``planes[j,q]``,
    ``planes[k,q]``.

    Unlike :class:`TriplaneEncoding` this class supports different resolution along each axis.

    This will return a tensor of shape (bs:..., num_components)

    Args:
        resolution: Resolution of the grid. Can be a sequence of 3 or 4 integers.
        num_components: The number of scalar planes to use (ie: output feature size)
        init_a: The lower-bound of the uniform distribution used to initialize the spatial planes
        init_b: The upper-bound of the uniform distribution used to initialize the spatial planes
        reduce: Whether to use the element-wise product of the planes or the sum

    Raises:
        ValueError: If ``resolution`` does not have 3 or 4 entries.
    """

    def __init__(
        self,
        resolution: Sequence[int] = (128, 128, 128),
        num_components: int = 64,
        init_a: float = 0.1,
        init_b: float = 0.5,
        reduce: Literal["sum", "product"] = "product",
    ) -> None:
        super().__init__(in_dim=len(resolution))

        self.resolution = resolution
        self.num_components = num_components
        self.reduce = reduce
        if self.in_dim not in (3, 4):
            raise ValueError(
                f"The dimension of coordinates must be either 3 (static scenes) "
                f"or 4 (dynamic scenes). Found resolution with {self.in_dim} dimensions."
            )
        is_dynamic = self.in_dim == 4

        # Every unordered pair of coordinate axes gets its own plane.
        self.coo_combs = list(itertools.combinations(range(self.in_dim), 2))
        # Unlike the Triplane encoding, we use a parameter list instead of batching all planes
        # together to support uneven resolutions (especially useful for time).
        # Dynamic models (in_dim == 4) will have 6 planes:
        # (y, x), (z, x), (t, x), (z, y), (t, y), (t, z)
        # static models (in_dim == 3) will only have the 1st, 2nd and 4th planes.
        self.plane_coefs = nn.ParameterList()
        for axes in self.coo_combs:
            # Spatial sizes are listed in reversed axis order, matching the comment above.
            plane_shape = [self.num_components] + [self.resolution[axis] for axis in axes[::-1]]
            plane = nn.Parameter(torch.empty(plane_shape))
            if is_dynamic and 3 in axes:
                # Time planes initialized to 1
                nn.init.ones_(plane)
            else:
                nn.init.uniform_(plane, a=init_a, b=init_b)
            self.plane_coefs.append(plane)

    def get_out_dim(self) -> int:
        """Return the number of components."""
        return self.num_components

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Sample features from this encoder. Expects ``in_tensor`` to be in range [-1, 1]"""
        batch_shape = in_tensor.shape[:-1]
        use_product = self.reduce == "product"

        assert any(self.coo_combs)
        output = 1.0 if use_product else 0.0  # identity for corresponding op
        for plane, axes in zip(self.plane_coefs, self.coo_combs):
            grid = plane.unsqueeze(0)  # [1, feature_dim, reso1, reso2]
            coords = in_tensor[..., axes].view(1, 1, -1, 2)  # [1, 1, flattened_bs, 2]
            sampled = F.grid_sample(
                grid, coords, align_corners=True, padding_mode="border"
            )  # [1, output_dim, 1, flattened_bs]
            sampled = sampled.view(self.num_components, -1).T  # [flattened_bs, output_dim]
            output = output * sampled if use_product else output + sampled

        # Typing: output gets converted to a tensor after the first iteration of the loop
        assert isinstance(output, Tensor)
        return output.reshape(*batch_shape, self.num_components)
class SHEncoding(Encoding):
    """Spherical harmonic encoding

    Args:
        levels: Number of spherical harmonic levels to encode. Must be between 1 and 4.
        implementation: Implementation of the encoding. If tcnn is requested but not available, a speed
            warning is printed and the torch implementation is used.

    Raises:
        ValueError: If ``levels`` is outside the supported range.
    """

    def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
        super().__init__(in_dim=3)

        if levels <= 0 or levels > 4:
            raise ValueError(f"Spherical harmonic encoding only supports 1 to 4 levels, requested {levels}")

        self.levels = levels

        # Stays None unless a tcnn backend is successfully created below.
        self.tcnn_encoding = None
        if implementation != "tcnn":
            return
        if not TCNN_EXISTS:
            print_tcnn_speed_warning("SHEncoding")
            return

        self.tcnn_encoding = tcnn.Encoding(
            n_input_dims=3,
            encoding_config=self.get_tcnn_encoding_config(levels=self.levels),
        )

    @classmethod
    def get_tcnn_encoding_config(cls, levels) -> dict:
        """Get the encoding configuration for tcnn if implemented"""
        return {
            "otype": "SphericalHarmonics",
            "degree": levels,
        }

    def get_out_dim(self) -> int:
        """Return ``levels ** 2``."""
        return self.levels**2

    @torch.no_grad()
    def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Forward pass using pytorch. Significantly slower than TCNN implementation."""
        # Runs under no_grad, so no gradients flow through this path.
        return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor)

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Encode ``in_tensor`` with tcnn when available, otherwise with the torch implementation."""
        if self.tcnn_encoding is None:
            return self.pytorch_fwd(in_tensor)
        return self.tcnn_encoding(in_tensor)