# Source code for nerfstudio.fields.density_fields

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Proposal network field.
"""


from typing import Literal, Optional, Tuple

import torch
from torch import Tensor, nn

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.activations import trunc_exp
from nerfstudio.field_components.encodings import HashEncoding
from nerfstudio.field_components.mlp import MLP
from nerfstudio.field_components.spatial_distortions import SpatialDistortion
from nerfstudio.fields.base_field import Field


class HashMLPDensityField(Field):
    """A lightweight density field module.

    Sample positions are encoded with a ``HashEncoding`` and mapped to a single
    density value, either through a small ``MLP`` or (when ``use_linear`` is set)
    a single ``torch.nn.Linear`` layer.

    Args:
        aabb: parameters of scene aabb bounds; used to normalize positions when
            no spatial distortion is given
        num_layers: number of hidden layers
        hidden_dim: dimension of hidden layers
        spatial_distortion: spatial distortion module
        use_linear: whether to skip the MLP and use a single linear layer instead
        num_levels: number of levels of the hash encoding
        max_res: maximum resolution of the hash encoding
        base_res: base (minimum) resolution of the hash encoding
        log2_hashmap_size: hash table size exponent, forwarded to ``HashEncoding``
        features_per_level: number of features per level of the hash encoding
        average_init_density: scale factor multiplied onto the exponentiated
            density output
        implementation: backend forwarded to both ``HashEncoding`` and ``MLP``
            ("tcnn" or "torch")
    """

    aabb: Tensor

    def __init__(
        self,
        aabb: Tensor,
        num_layers: int = 2,
        hidden_dim: int = 64,
        spatial_distortion: Optional[SpatialDistortion] = None,
        use_linear: bool = False,
        num_levels: int = 8,
        max_res: int = 1024,
        base_res: int = 16,
        log2_hashmap_size: int = 18,
        features_per_level: int = 2,
        average_init_density: float = 1.0,
        implementation: Literal["tcnn", "torch"] = "tcnn",
    ) -> None:
        super().__init__()
        self.register_buffer("aabb", aabb)
        self.spatial_distortion = spatial_distortion
        self.use_linear = use_linear
        self.average_init_density = average_init_density

        # Hash-grid hyperparameters are registered as buffers, so they are
        # included in the module's state_dict alongside the weights.
        self.register_buffer("max_res", torch.tensor(max_res))
        self.register_buffer("num_levels", torch.tensor(num_levels))
        self.register_buffer("log2_hashmap_size", torch.tensor(log2_hashmap_size))

        self.encoding = HashEncoding(
            num_levels=num_levels,
            min_res=base_res,
            max_res=max_res,
            log2_hashmap_size=log2_hashmap_size,
            features_per_level=features_per_level,
            implementation=implementation,
        )

        if not self.use_linear:
            network = MLP(
                in_dim=self.encoding.get_out_dim(),
                num_layers=num_layers,
                layer_width=hidden_dim,
                out_dim=1,
                activation=nn.ReLU(),
                out_activation=None,
                implementation=implementation,
            )
            self.mlp_base = torch.nn.Sequential(self.encoding, network)
        else:
            self.linear = torch.nn.Linear(self.encoding.get_out_dim(), 1)

    def get_density(self, ray_samples: RaySamples) -> Tuple[Tensor, None]:
        """Compute densities at the sample positions of ``ray_samples``.

        Args:
            ray_samples: samples whose frustum positions are queried.

        Returns:
            A tuple ``(density, None)``. ``density`` has shape
            ``(*ray_samples.frustums.shape, 1)``. Samples whose normalized
            position falls outside the open unit cube get density 0.
        """
        if self.spatial_distortion is not None:
            positions = self.spatial_distortion(ray_samples.frustums.get_positions())
            # Affine map sending [-2, 2] to [0, 1]; presumably the distortion's
            # output is bounded to [-2, 2] -- confirm against SpatialDistortion.
            positions = (positions + 2.0) / 4.0
        else:
            positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)

        # Make sure the tcnn gets inputs between 0 and 1: positions outside the
        # open unit cube are zeroed here, and their density is zeroed below.
        selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
        positions = positions * selector[..., None]
        positions_flat = positions.view(-1, 3)

        if not self.use_linear:
            density_before_activation = (
                self.mlp_base(positions_flat).view(*ray_samples.frustums.shape, -1).to(positions)
            )
        else:
            x = self.encoding(positions_flat).to(positions)
            density_before_activation = self.linear(x).view(*ray_samples.frustums.shape, -1)

        # Rectifying the density with an exponential is much more stable than a ReLU or
        # softplus, because it enables high post-activation (float32) density outputs
        # from smaller internal (float16) parameters.
        density = self.average_init_density * trunc_exp(density_before_activation)
        density = density * selector[..., None]
        return density, None

    def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[Tensor] = None) -> dict:
        """Return no additional outputs; this field only provides density."""
        return {}