Source code for nerfstudio.fields.sdf_field

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Field for SDF based model, rather then estimating density to generate a surface,
a signed distance function (SDF) for surface representation is used to help with extracting high fidelity surfaces
"""

from dataclasses import dataclass, field
from typing import Dict, Literal, Optional, Type

import numpy as np
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor, nn
from torch.nn.parameter import Parameter

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.field_components.embedding import Embedding
from nerfstudio.field_components.encodings import NeRFEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.field_components.spatial_distortions import SpatialDistortion
from nerfstudio.fields.base_field import Field, FieldConfig
from nerfstudio.utils.external import tcnn


class LearnedVariance(nn.Module):
    """Variance network in NeuS

    Args:
        init_val: initial value in NeuS variance network
    """

    variance: Tensor

    def __init__(self, init_val):
        super().__init__()
        self.register_parameter("variance", nn.Parameter(init_val * torch.ones(1), requires_grad=True))

    def forward(self, x: Float[Tensor, "1"]) -> Float[Tensor, "1"]:
        """Returns the current variance value, broadcast to one value per sample"""
        return torch.ones([len(x), 1], device=x.device) * torch.exp(self.variance * 10.0)

    def get_variance(self) -> Float[Tensor, "1"]:
        """Returns the current variance value, clipped to a numerically safe range"""
        return torch.exp(self.variance * 10.0).clip(1e-6, 1e6)
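
# Note (editor's sketch, not part of the original file): the parameter is stored in
# log-space, so ``get_variance()`` returns exp(10 * variance), i.e. the inverse
# standard deviation ``inv_s`` of the logistic density used by NeuS. A minimal
# usage sketch:
#
#     deviation_network = LearnedVariance(init_val=0.3)
#     inv_s = deviation_network.get_variance()  # tensor([20.0855...]) == exp(0.3 * 10)
#     x = torch.rand(8, 3)
#     inv_s_per_sample = deviation_network(x)   # shape (8, 1), same value broadcast per sample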
@dataclass
class SDFFieldConfig(FieldConfig):
    """SDF Field Config"""

    _target: Type = field(default_factory=lambda: SDFField)
    num_layers: int = 8
    """Number of layers for the geometric network"""
    hidden_dim: int = 256
    """Number of hidden dimensions of the geometric network"""
    geo_feat_dim: int = 256
    """Dimension of the geometric feature"""
    num_layers_color: int = 4
    """Number of layers for the color network"""
    hidden_dim_color: int = 256
    """Number of hidden dimensions of the color network"""
    appearance_embedding_dim: int = 32
    """Dimension of the appearance embedding"""
    use_appearance_embedding: bool = False
    """Whether to use the appearance embedding"""
    bias: float = 0.8
    """Sphere radius of the geometric initialization"""
    geometric_init: bool = True
    """Whether to use geometric initialization"""
    inside_outside: bool = True
    """Whether to flip the sign of the signed distance value; set to True for indoor scenes"""
    weight_norm: bool = True
    """Whether to use weight norm for linear layers"""
    use_grid_feature: bool = False
    """Whether to use multi-resolution feature grids"""
    divide_factor: float = 2.0
    """Normalization factor for multi-resolution grids"""
    beta_init: float = 0.1
    """Initial learnable beta value for the transformation of sdf to density"""
    encoding_type: Literal["hash", "periodic", "tensorf_vm"] = "hash"
    num_levels: int = 16
    """Number of encoding levels"""
    max_res: int = 2048
    """Maximum resolution of the encoding"""
    base_res: int = 16
    """Base resolution of the encoding"""
    log2_hashmap_size: int = 19
    """Size of the hash map"""
    features_per_level: int = 2
    """Number of features per encoding level"""
    use_hash: bool = True
    """Whether to use hash encoding"""
    smoothstep: bool = True
    """Whether to use the smoothstep function"""
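
# Example (editor's sketch, not part of the original file): the config is a plain
# dataclass, so fields can be overridden at construction time, and ``setup()``
# (inherited from nerfstudio's InstantiateConfig) builds the ``_target`` class.
# The ``aabb`` and ``num_images`` values below are placeholders:
#
#     field_config = SDFFieldConfig(use_grid_feature=False, num_layers=2, hidden_dim=64)
#     sdf_field = field_config.setup(
#         aabb=torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]),
#         num_images=100,
#     )  # -> SDFField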
class SDFField(Field):
    """
    A field for Signed Distance Functions (SDF).

    Args:
        config: The configuration for the SDF field.
        aabb: An axis-aligned bounding box for the SDF field.
        num_images: The number of images for embedding appearance.
        use_average_appearance_embedding: Whether to use average appearance embedding. Defaults to False.
        spatial_distortion: The spatial distortion. Defaults to None.
    """

    config: SDFFieldConfig

    def __init__(
        self,
        config: SDFFieldConfig,
        aabb: Float[Tensor, "2 3"],
        num_images: int,
        use_average_appearance_embedding: bool = False,
        spatial_distortion: Optional[SpatialDistortion] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.aabb = Parameter(aabb, requires_grad=False)

        self.spatial_distortion = spatial_distortion
        self.num_images = num_images

        self.embedding_appearance = Embedding(self.num_images, self.config.appearance_embedding_dim)
        self.use_average_appearance_embedding = use_average_appearance_embedding
        self.use_grid_feature = self.config.use_grid_feature
        self.divide_factor = self.config.divide_factor

        growth_factor = np.exp((np.log(config.max_res) - np.log(config.base_res)) / (config.num_levels - 1))

        if self.config.encoding_type == "hash":
            # feature encoding
            self.encoding = tcnn.Encoding(
                n_input_dims=3,
                encoding_config={
                    "otype": "HashGrid" if config.use_hash else "DenseGrid",
                    "n_levels": config.num_levels,
                    "n_features_per_level": config.features_per_level,
                    "log2_hashmap_size": config.log2_hashmap_size,
                    "base_resolution": config.base_res,
                    "per_level_scale": growth_factor,
                    "interpolation": "Smoothstep" if config.smoothstep else "Linear",
                },
            )

        # we concatenate the input positions ourselves
        self.position_encoding = NeRFEncoding(
            in_dim=3, num_frequencies=6, min_freq_exp=0.0, max_freq_exp=5.0, include_input=False
        )

        self.direction_encoding = NeRFEncoding(
            in_dim=3, num_frequencies=4, min_freq_exp=0.0, max_freq_exp=3.0, include_input=True
        )

        # initialize geometric network
        self.initialize_geo_layers()

        # deviation network to compute alpha from the sdf, as in NeuS
        self.deviation_network = LearnedVariance(init_val=self.config.beta_init)

        # color network
        dims = [self.config.hidden_dim_color for _ in range(self.config.num_layers_color)]
        # point, view_direction, normal, feature, embedding
        in_dim = (
            3
            + self.direction_encoding.get_out_dim()
            + 3
            + self.config.geo_feat_dim
            + self.embedding_appearance.get_out_dim()
        )
        dims = [in_dim] + dims + [3]
        self.num_layers_color = len(dims)

        for layer in range(0, self.num_layers_color - 1):
            out_dim = dims[layer + 1]
            lin = nn.Linear(dims[layer], out_dim)

            if self.config.weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "clin" + str(layer), lin)

        self.softplus = nn.Softplus(beta=100)
        self.relu = nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

        self._cos_anneal_ratio = 1.0

        if self.use_grid_feature:
            assert self.spatial_distortion is not None, "spatial distortion must be provided when using grid feature"
    def initialize_geo_layers(self) -> None:
        """
        Initialize layers for geometric network (sdf)
        """
        # MLP with geometric initialization
        dims = [self.config.hidden_dim for _ in range(self.config.num_layers)]
        in_dim = 3 + self.position_encoding.get_out_dim() + self.encoding.n_output_dims
        dims = [in_dim] + dims + [1 + self.config.geo_feat_dim]
        self.num_layers = len(dims)
        self.skip_in = [4]

        for layer in range(0, self.num_layers - 1):
            if layer + 1 in self.skip_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]

            lin = nn.Linear(dims[layer], out_dim)

            if self.config.geometric_init:
                if layer == self.num_layers - 2:
                    if not self.config.inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[layer]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -self.config.bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[layer]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, self.config.bias)
                elif layer == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif layer in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if self.config.weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "glin" + str(layer), lin)
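
    # Note (editor's sketch, not part of the original file): with ``geometric_init``
    # the network approximates the SDF of a sphere of radius ``config.bias`` at
    # initialization, following the SAL/IGR geometric initialization. With
    # ``inside_outside=True`` the sign is flipped (positive inside the sphere),
    # which suits indoor scenes where the cameras sit inside the surface.
    # Illustrative check, assuming geometric_init=True and use_grid_feature=False:
    #
    #     # forward_geonetwork(torch.zeros(1, 3))[:, :1] is roughly +config.bias
    #     # when inside_outside=True, and roughly -config.bias otherwise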
    def set_cos_anneal_ratio(self, anneal: float) -> None:
        """Set the cosine annealing ratio used when converting the SDF to alpha (see `get_alpha`)."""
        self._cos_anneal_ratio = anneal
    def forward_geonetwork(self, inputs: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch geo_features+1"]:
        """Forward pass through the geometry network."""
        if self.use_grid_feature:
            assert self.spatial_distortion is not None, "spatial distortion must be provided when using grid feature"
            positions = self.spatial_distortion(inputs)
            # map range [-2, 2] to [0, 1]
            positions = (positions + 2.0) / 4.0
            feature = self.encoding(positions)
        else:
            feature = torch.zeros_like(inputs[:, :1].repeat(1, self.encoding.n_output_dims))

        pe = self.position_encoding(inputs)

        inputs = torch.cat((inputs, pe, feature), dim=-1)

        # Pass through layers
        outputs = inputs

        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "glin" + str(layer))

            if layer in self.skip_in:
                outputs = torch.cat([outputs, inputs], 1) / np.sqrt(2)

            outputs = lin(outputs)

            if layer < self.num_layers - 2:
                outputs = self.softplus(outputs)
        return outputs
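
    # Note (editor's sketch, not part of the original file): at each skip layer the
    # running activations are re-concatenated with the encoded input and divided by
    # sqrt(2), an IGR-style trick that keeps the variance of the activations roughly
    # constant across the skip connection.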
    # TODO: fix ... in shape annotations.
    def get_sdf(self, ray_samples: RaySamples) -> Float[Tensor, "num_samples ... 1"]:
        """Predict SDF values for the given ray samples."""
        positions = ray_samples.frustums.get_start_positions()
        positions_flat = positions.view(-1, 3)
        hidden_output = self.forward_geonetwork(positions_flat).view(*ray_samples.frustums.shape, -1)
        sdf, _ = torch.split(hidden_output, [1, self.config.geo_feat_dim], dim=-1)
        return sdf
    def get_alpha(
        self,
        ray_samples: RaySamples,
        sdf: Optional[Float[Tensor, "num_samples ... 1"]] = None,
        gradients: Optional[Float[Tensor, "num_samples ... 1"]] = None,
    ) -> Float[Tensor, "num_samples ... 1"]:
        """Compute alpha from the SDF as in NeuS."""
        if sdf is None or gradients is None:
            inputs = ray_samples.frustums.get_start_positions()
            inputs.requires_grad_(True)
            with torch.enable_grad():
                hidden_output = self.forward_geonetwork(inputs)
                sdf, _ = torch.split(hidden_output, [1, self.config.geo_feat_dim], dim=-1)
            d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
            gradients = torch.autograd.grad(
                outputs=sdf,
                inputs=inputs,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True,
            )[0]

        inv_s = self.deviation_network.get_variance()  # Single parameter

        true_cos = (ray_samples.frustums.directions * gradients).sum(-1, keepdim=True)

        # anneal as in NeuS
        cos_anneal_ratio = self._cos_anneal_ratio

        # "cos_anneal_ratio" grows from 0 to 1 during the early training iterations. The anneal strategy below keeps
        # the cos value from going "dead" early in training, for better convergence.
        iter_cos = -(
            F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + F.relu(-true_cos) * cos_anneal_ratio
        )  # always non-positive

        # Estimate signed distances at section points
        estimated_next_sdf = sdf + iter_cos * ray_samples.deltas * 0.5
        estimated_prev_sdf = sdf - iter_cos * ray_samples.deltas * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

        p = prev_cdf - next_cdf
        c = prev_cdf

        alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0)

        return alpha
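
    # Note (editor's sketch, not part of the original file): the expression above is
    # the discrete opacity from the NeuS paper (Wang et al., 2021):
    #
    #     alpha_i = max( (Phi_s(f(p_i)) - Phi_s(f(p_{i+1}))) / Phi_s(f(p_i)), 0 )
    #
    # where Phi_s(x) = sigmoid(inv_s * x) is the CDF of the logistic density. The SDF
    # values f(p_i), f(p_{i+1}) at the section endpoints are estimated from the
    # mid-point SDF plus/minus half the step size times the (annealed) cosine between
    # the ray direction and the SDF gradient.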
    def get_density(self, ray_samples: RaySamples):
        raise NotImplementedError
    def get_colors(
        self,
        points: Float[Tensor, "*batch 3"],
        directions: Float[Tensor, "*batch 3"],
        normals: Float[Tensor, "*batch 3"],
        geo_features: Float[Tensor, "*batch geo_feat_dim"],
        camera_indices: Tensor,
    ) -> Float[Tensor, "*batch 3"]:
        """Compute colors from points, view directions, normals, geometric features, and appearance."""
        d = self.direction_encoding(directions)

        # appearance
        if self.training:
            embedded_appearance = self.embedding_appearance(camera_indices)
            # set it to zero if the appearance embedding is not used
            if not self.config.use_appearance_embedding:
                embedded_appearance = torch.zeros_like(embedded_appearance)
        else:
            if self.use_average_appearance_embedding:
                embedded_appearance = torch.ones(
                    (*directions.shape[:-1], self.config.appearance_embedding_dim), device=directions.device
                ) * self.embedding_appearance.mean(dim=0)
            else:
                embedded_appearance = torch.zeros(
                    (*directions.shape[:-1], self.config.appearance_embedding_dim), device=directions.device
                )

        hidden_input = torch.cat(
            [
                points,
                d,
                normals,
                geo_features.view(-1, self.config.geo_feat_dim),
                embedded_appearance.view(-1, self.config.appearance_embedding_dim),
            ],
            dim=-1,
        )

        for layer in range(0, self.num_layers_color - 1):
            lin = getattr(self, "clin" + str(layer))

            hidden_input = lin(hidden_input)

            if layer < self.num_layers_color - 2:
                hidden_input = self.relu(hidden_input)

        rgb = self.sigmoid(hidden_input)

        return rgb
    def get_outputs(
        self,
        ray_samples: RaySamples,
        density_embedding: Optional[Tensor] = None,
        return_alphas: bool = False,
    ) -> Dict[FieldHeadNames, Tensor]:
        """Compute the outputs for the given ray samples."""
        if ray_samples.camera_indices is None:
            raise AttributeError("Camera indices are not provided.")

        outputs = {}

        camera_indices = ray_samples.camera_indices.squeeze()

        inputs = ray_samples.frustums.get_start_positions()
        inputs = inputs.view(-1, 3)

        directions = ray_samples.frustums.directions
        directions_flat = directions.reshape(-1, 3)

        inputs.requires_grad_(True)
        with torch.enable_grad():
            hidden_output = self.forward_geonetwork(inputs)
            sdf, geo_feature = torch.split(hidden_output, [1, self.config.geo_feat_dim], dim=-1)
        d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
        gradients = torch.autograd.grad(
            outputs=sdf, inputs=inputs, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True
        )[0]

        rgb = self.get_colors(inputs, directions_flat, gradients, geo_feature, camera_indices)

        rgb = rgb.view(*ray_samples.frustums.directions.shape[:-1], -1)
        sdf = sdf.view(*ray_samples.frustums.directions.shape[:-1], -1)
        gradients = gradients.view(*ray_samples.frustums.directions.shape[:-1], -1)
        normals = torch.nn.functional.normalize(gradients, p=2, dim=-1)

        outputs.update(
            {
                FieldHeadNames.RGB: rgb,
                FieldHeadNames.SDF: sdf,
                FieldHeadNames.NORMALS: normals,
                FieldHeadNames.GRADIENT: gradients,
            }
        )

        if return_alphas:
            alphas = self.get_alpha(ray_samples, sdf, gradients)
            outputs.update({FieldHeadNames.ALPHA: alphas})

        return outputs
    def forward(
        self, ray_samples: RaySamples, compute_normals: bool = False, return_alphas: bool = False
    ) -> Dict[FieldHeadNames, Tensor]:
        """Evaluates the field at points along the ray.

        Args:
            ray_samples: Samples to evaluate the field on.
            compute_normals: Not currently used in this implementation.
            return_alphas: Whether to return alpha values.
        """
        field_outputs = self.get_outputs(ray_samples, return_alphas=return_alphas)
        return field_outputs
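
# Example (editor's sketch, not part of the original file): an end-to-end usage
# sketch; ``ray_samples`` stands in for a RaySamples batch produced by a sampler
# (with ``camera_indices`` set) and is an assumption here:
#
#     aabb = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])
#     sdf_field = SDFField(SDFFieldConfig(use_grid_feature=False), aabb=aabb, num_images=100)
#     sdf_field.set_cos_anneal_ratio(1.0)
#
#     outputs = sdf_field(ray_samples, return_alphas=True)
#     rgb = outputs[FieldHeadNames.RGB]      # per-sample color
#     alpha = outputs[FieldHeadNames.ALPHA]  # NeuS opacity, ready for volume rendering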