Source code for nerfstudio.fields.tensorf_field

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensoRF Field"""


from typing import Dict, Optional

import torch
from torch import Tensor, nn
from torch.nn.parameter import Parameter

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.encodings import Encoding, Identity, SHEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames, RGBFieldHead
from nerfstudio.field_components.mlp import MLP
from nerfstudio.fields.base_field import Field


[docs]class TensoRFField(Field): """TensoRF Field""" def __init__( self, aabb: Tensor, # the aabb bounding box of the dataset feature_encoding: Encoding = Identity(in_dim=3), # the encoding method used for appearance encoding outputs direction_encoding: Encoding = Identity(in_dim=3), # the encoding method used for ray direction density_encoding: Encoding = Identity(in_dim=3), # the tensor encoding method used for scene density color_encoding: Encoding = Identity(in_dim=3), # the tensor encoding method used for scene color appearance_dim: int = 27, # the number of dimensions for the appearance embedding head_mlp_num_layers: int = 2, # number of layers for the MLP head_mlp_layer_width: int = 128, # layer width for the MLP use_sh: bool = False, # whether to use spherical harmonics as the feature decoding function sh_levels: int = 2, # number of levels to use for spherical harmonics ) -> None: super().__init__() self.aabb = Parameter(aabb, requires_grad=False) self.feature_encoding = feature_encoding self.direction_encoding = direction_encoding self.density_encoding = density_encoding self.color_encoding = color_encoding self.mlp_head = MLP( in_dim=appearance_dim + 3 + self.direction_encoding.get_out_dim() + self.feature_encoding.get_out_dim(), num_layers=head_mlp_num_layers, layer_width=head_mlp_layer_width, activation=nn.ReLU(), out_activation=nn.ReLU(), ) self.use_sh = use_sh if self.use_sh: self.sh = SHEncoding(sh_levels) self.B = nn.Linear( in_features=self.color_encoding.get_out_dim(), out_features=3 * self.sh.get_out_dim(), bias=False ) else: self.B = nn.Linear(in_features=self.color_encoding.get_out_dim(), out_features=appearance_dim, bias=False) self.field_output_rgb = RGBFieldHead(in_dim=self.mlp_head.get_out_dim(), activation=nn.Sigmoid())
[docs] def get_density(self, ray_samples: RaySamples) -> Tensor: positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb) positions = positions * 2 - 1 density = self.density_encoding(positions) density_enc = torch.sum(density, dim=-1)[:, :, None] relu = torch.nn.ReLU() density_enc = relu(density_enc) return density_enc
[docs] def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[Tensor] = None) -> Tensor: d = ray_samples.frustums.directions positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb) positions = positions * 2 - 1 rgb_features = self.color_encoding(positions) rgb_features = self.B(rgb_features) if self.use_sh: sh_mult = self.sh(d)[:, :, None] rgb_sh = rgb_features.view(sh_mult.shape[0], sh_mult.shape[1], 3, sh_mult.shape[-1]) rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5) else: d_encoded = self.direction_encoding(d) rgb_features_encoded = self.feature_encoding(rgb_features) out = self.mlp_head(torch.cat([rgb_features, d, rgb_features_encoded, d_encoded], dim=-1)) # type: ignore rgb = self.field_output_rgb(out) return rgb
[docs] def forward( self, ray_samples: RaySamples, compute_normals: bool = False, mask: Optional[Tensor] = None, bg_color: Optional[Tensor] = None, ) -> Dict[FieldHeadNames, Tensor]: if compute_normals is True: raise ValueError("Surface normals are not currently supported with TensoRF") if mask is not None and bg_color is not None: base_density = torch.zeros(ray_samples.shape)[:, :, None].to(mask.device) base_rgb = bg_color.repeat(ray_samples[:, :, None].shape) if mask.any(): input_rays = ray_samples[mask, :] density = self.get_density(input_rays) rgb = self.get_outputs(input_rays, None) base_density[mask] = density base_rgb[mask] = rgb base_density.requires_grad_() base_rgb.requires_grad_() density = base_density rgb = base_rgb else: density = self.get_density(ray_samples) rgb = self.get_outputs(ray_samples, None) return {FieldHeadNames.DENSITY: density, FieldHeadNames.RGB: rgb}