# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Field for compound nerf model, adds scene contraction and image embeddings to instant ngp
"""
from typing import Dict, Literal, Optional, Tuple
import torch
from torch import Tensor, nn
from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.activations import trunc_exp
from nerfstudio.field_components.embedding import Embedding
from nerfstudio.field_components.encodings import NeRFEncoding, SHEncoding
from nerfstudio.field_components.field_heads import (
FieldHeadNames,
PredNormalsFieldHead,
SemanticFieldHead,
TransientDensityFieldHead,
TransientRGBFieldHead,
UncertaintyFieldHead,
)
from nerfstudio.field_components.mlp import MLP, MLPWithHashEncoding
from nerfstudio.field_components.spatial_distortions import SpatialDistortion
from nerfstudio.fields.base_field import Field, get_normalized_directions


class NerfactoField(Field):
"""Compound Field
Args:
aabb: parameters of scene aabb bounds
num_images: number of images in the dataset
num_layers: number of hidden layers
hidden_dim: dimension of hidden layers
        geo_feat_dim: dimension of the output geometry feature vector
num_levels: number of levels of the hashmap for the base mlp
base_res: base resolution of the hashmap for the base mlp
max_res: maximum resolution of the hashmap for the base mlp
        log2_hashmap_size: base-2 logarithm of the hashmap size for the base mlp
num_layers_color: number of hidden layers for color network
num_layers_transient: number of hidden layers for transient network
features_per_level: number of features per level for the hashgrid
hidden_dim_color: dimension of hidden layers for color network
hidden_dim_transient: dimension of hidden layers for transient network
appearance_embedding_dim: dimension of appearance embedding
transient_embedding_dim: dimension of transient embedding
use_transient_embedding: whether to use transient embedding
use_semantics: whether to use semantic segmentation
        num_semantic_classes: number of semantic classes
        pass_semantic_gradients: whether to let gradients from the semantic head flow back into the geometry features
        use_pred_normals: whether to use predicted normals
        use_average_appearance_embedding: whether to use average appearance embedding or zeros for inference
        spatial_distortion: spatial distortion to apply to the scene
        average_init_density: scale factor applied to the density output (controls the density at initialization)
        implementation: implementation of the MLPs and encodings, either "tcnn" or "torch"
    """

aabb: Tensor
def __init__(
self,
aabb: Tensor,
num_images: int,
num_layers: int = 2,
hidden_dim: int = 64,
geo_feat_dim: int = 15,
num_levels: int = 16,
base_res: int = 16,
max_res: int = 2048,
log2_hashmap_size: int = 19,
num_layers_color: int = 3,
num_layers_transient: int = 2,
features_per_level: int = 2,
hidden_dim_color: int = 64,
hidden_dim_transient: int = 64,
appearance_embedding_dim: int = 32,
transient_embedding_dim: int = 16,
use_transient_embedding: bool = False,
use_semantics: bool = False,
num_semantic_classes: int = 100,
pass_semantic_gradients: bool = False,
use_pred_normals: bool = False,
use_average_appearance_embedding: bool = False,
spatial_distortion: Optional[SpatialDistortion] = None,
average_init_density: float = 1.0,
implementation: Literal["tcnn", "torch"] = "tcnn",
) -> None:
super().__init__()
self.register_buffer("aabb", aabb)
self.geo_feat_dim = geo_feat_dim
self.register_buffer("max_res", torch.tensor(max_res))
self.register_buffer("num_levels", torch.tensor(num_levels))
self.register_buffer("log2_hashmap_size", torch.tensor(log2_hashmap_size))
self.spatial_distortion = spatial_distortion
self.num_images = num_images
self.appearance_embedding_dim = appearance_embedding_dim
if self.appearance_embedding_dim > 0:
self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim)
else:
self.embedding_appearance = None
self.use_average_appearance_embedding = use_average_appearance_embedding
self.use_transient_embedding = use_transient_embedding
self.use_semantics = use_semantics
self.use_pred_normals = use_pred_normals
self.pass_semantic_gradients = pass_semantic_gradients
self.base_res = base_res
self.average_init_density = average_init_density
self.step = 0
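        # Spherical-harmonics encoding (4 levels) of the normalized view directions; feeds the color head.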
self.direction_encoding = SHEncoding(
levels=4,
implementation=implementation,
)
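        # Low-frequency frequency encoding of raw positions; used only as input to the predicted-normals head.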
self.position_encoding = NeRFEncoding(
in_dim=3, num_frequencies=2, min_freq_exp=0, max_freq_exp=2 - 1, implementation=implementation
)
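        # Base network: a multi-resolution hash-grid encoding fused with an MLP, mapping positions
        # to one density logit plus geo_feat_dim geometry features.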
self.mlp_base = MLPWithHashEncoding(
num_levels=num_levels,
min_res=base_res,
max_res=max_res,
log2_hashmap_size=log2_hashmap_size,
features_per_level=features_per_level,
num_layers=num_layers,
layer_width=hidden_dim,
out_dim=1 + self.geo_feat_dim,
activation=nn.ReLU(),
out_activation=None,
implementation=implementation,
)
# transients
if self.use_transient_embedding:
self.transient_embedding_dim = transient_embedding_dim
self.embedding_transient = Embedding(self.num_images, self.transient_embedding_dim)
self.mlp_transient = MLP(
in_dim=self.geo_feat_dim + self.transient_embedding_dim,
num_layers=num_layers_transient,
layer_width=hidden_dim_transient,
out_dim=hidden_dim_transient,
activation=nn.ReLU(),
out_activation=None,
implementation=implementation,
)
self.field_head_transient_uncertainty = UncertaintyFieldHead(in_dim=self.mlp_transient.get_out_dim())
self.field_head_transient_rgb = TransientRGBFieldHead(in_dim=self.mlp_transient.get_out_dim())
self.field_head_transient_density = TransientDensityFieldHead(in_dim=self.mlp_transient.get_out_dim())
# semantics
if self.use_semantics:
self.mlp_semantics = MLP(
in_dim=self.geo_feat_dim,
num_layers=2,
layer_width=64,
out_dim=hidden_dim_transient,
activation=nn.ReLU(),
out_activation=None,
implementation=implementation,
)
self.field_head_semantics = SemanticFieldHead(
in_dim=self.mlp_semantics.get_out_dim(), num_classes=num_semantic_classes
)
# predicted normals
if self.use_pred_normals:
self.mlp_pred_normals = MLP(
in_dim=self.geo_feat_dim + self.position_encoding.get_out_dim(),
num_layers=3,
layer_width=64,
out_dim=hidden_dim_transient,
activation=nn.ReLU(),
out_activation=None,
implementation=implementation,
)
self.field_head_pred_normals = PredNormalsFieldHead(in_dim=self.mlp_pred_normals.get_out_dim())
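        # Color head: takes the encoded view direction, the geometry features, and (optionally)
        # the appearance embedding, and outputs sigmoid-activated RGB.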
self.mlp_head = MLP(
in_dim=self.direction_encoding.get_out_dim() + self.geo_feat_dim + self.appearance_embedding_dim,
num_layers=num_layers_color,
layer_width=hidden_dim_color,
out_dim=3,
activation=nn.ReLU(),
out_activation=nn.Sigmoid(),
implementation=implementation,
)

    def get_density(self, ray_samples: RaySamples) -> Tuple[Tensor, Tensor]:
"""Computes and returns the densities."""
if self.spatial_distortion is not None:
positions = ray_samples.frustums.get_positions()
positions = self.spatial_distortion(positions)
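            # The contraction maps positions into [-2, 2]; rescale to [0, 1] for the hash grid.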
positions = (positions + 2.0) / 4.0
else:
positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
# Make sure the tcnn gets inputs between 0 and 1.
selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
positions = positions * selector[..., None]
assert positions.numel() > 0, "positions is empty."
self._sample_locations = positions
if not self._sample_locations.requires_grad:
self._sample_locations.requires_grad = True
positions_flat = positions.view(-1, 3)
assert positions_flat.numel() > 0, "positions_flat is empty."
h = self.mlp_base(positions_flat).view(*ray_samples.frustums.shape, -1)
density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1)
self._density_before_activation = density_before_activation
# Rectifying the density with an exponential is much more stable than a ReLU or
# softplus, because it enables high post-activation (float32) density outputs
# from smaller internal (float16) parameters.
density = self.average_init_density * trunc_exp(density_before_activation.to(positions))
density = density * selector[..., None]
return density, base_mlp_out

    def get_outputs(
self, ray_samples: RaySamples, density_embedding: Optional[Tensor] = None
) -> Dict[FieldHeadNames, Tensor]:
assert density_embedding is not None
outputs = {}
if ray_samples.camera_indices is None:
raise AttributeError("Camera indices are not provided.")
camera_indices = ray_samples.camera_indices.squeeze()
directions = get_normalized_directions(ray_samples.frustums.directions)
directions_flat = directions.view(-1, 3)
d = self.direction_encoding(directions_flat)
outputs_shape = ray_samples.frustums.directions.shape[:-1]
# appearance
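        # During training, look up the per-image appearance embedding; at eval time fall back to
        # either the mean training embedding or an all-zeros vector.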
embedded_appearance = None
if self.embedding_appearance is not None:
if self.training:
embedded_appearance = self.embedding_appearance(camera_indices)
else:
if self.use_average_appearance_embedding:
embedded_appearance = torch.ones(
(*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device
) * self.embedding_appearance.mean(dim=0)
else:
embedded_appearance = torch.zeros(
(*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device
)
# transients
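        # Transient heads model per-image effects, so they are only evaluated during training.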
if self.use_transient_embedding and self.training:
embedded_transient = self.embedding_transient(camera_indices)
transient_input = torch.cat(
[
density_embedding.view(-1, self.geo_feat_dim),
embedded_transient.view(-1, self.transient_embedding_dim),
],
dim=-1,
)
x = self.mlp_transient(transient_input).view(*outputs_shape, -1).to(directions)
outputs[FieldHeadNames.UNCERTAINTY] = self.field_head_transient_uncertainty(x)
outputs[FieldHeadNames.TRANSIENT_RGB] = self.field_head_transient_rgb(x)
outputs[FieldHeadNames.TRANSIENT_DENSITY] = self.field_head_transient_density(x)
# semantics
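        # Semantics are predicted from the geometry features alone, optionally detached so the
        # semantic loss does not influence geometry.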
if self.use_semantics:
semantics_input = density_embedding.view(-1, self.geo_feat_dim)
if not self.pass_semantic_gradients:
semantics_input = semantics_input.detach()
x = self.mlp_semantics(semantics_input).view(*outputs_shape, -1).to(directions)
outputs[FieldHeadNames.SEMANTICS] = self.field_head_semantics(x)
# predicted normals
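        # Normals are regressed from the frequency-encoded sample positions concatenated with the geometry features.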
if self.use_pred_normals:
positions = ray_samples.frustums.get_positions()
positions_flat = self.position_encoding(positions.view(-1, 3))
pred_normals_inp = torch.cat([positions_flat, density_embedding.view(-1, self.geo_feat_dim)], dim=-1)
x = self.mlp_pred_normals(pred_normals_inp).view(*outputs_shape, -1).to(directions)
outputs[FieldHeadNames.PRED_NORMALS] = self.field_head_pred_normals(x)
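        # Assemble the color-head input: encoded directions, geometry features, and the appearance embedding when present.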
h = torch.cat(
[
d,
density_embedding.view(-1, self.geo_feat_dim),
]
+ (
[embedded_appearance.view(-1, self.appearance_embedding_dim)] if embedded_appearance is not None else []
),
dim=-1,
)
rgb = self.mlp_head(h).view(*outputs_shape, -1).to(directions)
outputs.update({FieldHeadNames.RGB: rgb})
return outputs
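

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hedged example of how this field might be constructed and queried. The values
# below (scene bounds, image count, the SceneContraction distortion, and the "torch"
# implementation) are assumptions chosen for illustration, not defaults taken from this file.
#
#     from nerfstudio.field_components.spatial_distortions import SceneContraction
#
#     aabb = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])  # hypothetical scene box
#     field = NerfactoField(
#         aabb=aabb,
#         num_images=100,  # number of training images (hypothetical)
#         spatial_distortion=SceneContraction(order=float("inf")),
#         implementation="torch",  # avoids the tcnn dependency
#     )
#     # Given RaySamples from a sampler (with camera_indices populated), the base Field API
#     # returns a dict mapping FieldHeadNames to tensors, including DENSITY and RGB:
#     # outputs = field.forward(ray_samples)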