# Source code for nerfstudio.fields.tensorf_field

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensoRF Field"""

from typing import Dict, Optional

import torch
from torch import Tensor, nn
from torch.nn.parameter import Parameter

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.encodings import Encoding, Identity, SHEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames, RGBFieldHead
from nerfstudio.field_components.mlp import MLP
from nerfstudio.fields.base_field import Field

class TensoRFField(Field):
    """TensoRF Field.

    Density is the ReLU of the channel sum of ``density_encoding`` evaluated at
    sample positions normalized against ``aabb``. Color is decoded from
    ``color_encoding`` features projected by the bias-free linear layer ``B``,
    either through spherical harmonics (``use_sh=True``) or through an MLP head.

    Args:
        aabb: the aabb bounding box of the dataset
        feature_encoding: the encoding method used for appearance encoding outputs.
            Applied to the ``appearance_dim``-wide output of ``B``; defaults to
            ``Identity(in_dim=appearance_dim)``.
        direction_encoding: the encoding method used for ray direction.
            Defaults to ``Identity(in_dim=3)``.
        density_encoding: the tensor encoding method used for scene density.
            Defaults to ``Identity(in_dim=3)``.
        color_encoding: the tensor encoding method used for scene color.
            Defaults to ``Identity(in_dim=3)``.
        appearance_dim: the number of dimensions for the appearance embedding
        head_mlp_num_layers: number of layers for the MLP
        head_mlp_layer_width: layer width for the MLP
        use_sh: whether to use spherical harmonics as the feature decoding function
        sh_levels: number of levels to use for spherical harmonics
    """

    def __init__(
        self,
        aabb: Tensor,
        feature_encoding: Optional[Encoding] = None,
        direction_encoding: Optional[Encoding] = None,
        density_encoding: Optional[Encoding] = None,
        color_encoding: Optional[Encoding] = None,
        appearance_dim: int = 27,
        head_mlp_num_layers: int = 2,
        head_mlp_layer_width: int = 128,
        use_sh: bool = False,
        sh_levels: int = 2,
    ) -> None:
        super().__init__()

        # Default encodings are built per instance. A module used as a default
        # argument would be created once at definition time and shared (and
        # registered as a submodule) by every field constructed without it.
        # The feature encoding is applied to the ``appearance_dim``-wide output
        # of ``self.B`` (see ``get_outputs``), so its identity default must have
        # that width for the MLP head's ``in_dim`` to match the concatenation.
        if feature_encoding is None:
            feature_encoding = Identity(in_dim=appearance_dim)
        if direction_encoding is None:
            direction_encoding = Identity(in_dim=3)
        if density_encoding is None:
            density_encoding = Identity(in_dim=3)
        if color_encoding is None:
            color_encoding = Identity(in_dim=3)

        # Stored as a frozen Parameter so it follows the module across devices.
        self.aabb = Parameter(aabb, requires_grad=False)
        self.feature_encoding = feature_encoding
        self.direction_encoding = direction_encoding
        self.density_encoding = density_encoding
        self.color_encoding = color_encoding

        # Input width matches the concatenation in ``get_outputs``:
        # [rgb_features (appearance_dim), d (3), encoded features, encoded directions].
        self.mlp_head = MLP(
            in_dim=appearance_dim + 3 + self.direction_encoding.get_out_dim() + self.feature_encoding.get_out_dim(),
            num_layers=head_mlp_num_layers,
            layer_width=head_mlp_layer_width,
            activation=nn.ReLU(),
            out_activation=nn.ReLU(),
        )

        self.use_sh = use_sh

        if self.use_sh:
            self.sh = SHEncoding(sh_levels)
            # ``B`` predicts one SH coefficient vector per RGB channel.
            self.B = nn.Linear(
                in_features=self.color_encoding.get_out_dim(), out_features=3 * self.sh.get_out_dim(), bias=False
            )
        else:
            self.B = nn.Linear(in_features=self.color_encoding.get_out_dim(), out_features=appearance_dim, bias=False)

        self.field_output_rgb = RGBFieldHead(in_dim=self.mlp_head.get_out_dim(), activation=nn.Sigmoid())

    def _get_normalized_positions(self, ray_samples: RaySamples) -> Tensor:
        """Return sample positions normalized against ``self.aabb`` and rescaled by ``* 2 - 1``."""
        positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
        # NOTE(review): presumably maps SceneBox's [0, 1] normalization to [-1, 1]
        # for the tensor encodings — confirm against SceneBox.get_normalized_positions.
        return positions * 2 - 1

    def get_density(self, ray_samples: RaySamples) -> Tensor:
        """Compute a non-negative density for each sample.

        The channels of ``density_encoding`` are summed and passed through a ReLU.

        Args:
            ray_samples: samples at which to evaluate density.

        Returns:
            Densities with a trailing dimension of size 1. Assumes the encoding
            preserves the samples' leading batch dimensions.
        """
        positions = self._get_normalized_positions(ray_samples)
        density = self.density_encoding(positions)
        # keepdim=True yields the trailing singleton channel for any batch rank;
        # indexing with ``[:, :, None]`` only worked for 2-D (ray, sample) batches.
        density_enc = torch.sum(density, dim=-1, keepdim=True)
        return torch.relu(density_enc)

    def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[Tensor] = None) -> Tensor:
        """Compute RGB for each sample.

        Args:
            ray_samples: samples to shade.
            density_embedding: not used by this field; accepted for interface compatibility.

        Returns:
            Per-sample RGB. In the SH path the result is ``relu(sum(SH * coeffs) + 0.5)``;
            otherwise it comes from the MLP head followed by ``field_output_rgb``.
        """
        d = ray_samples.frustums.directions
        positions = self._get_normalized_positions(ray_samples)
        rgb_features = self.B(self.color_encoding(positions))

        if self.use_sh:
            # SH basis evaluated at each direction, with a broadcast axis for the RGB channels.
            sh_mult = self.sh(d)[..., None, :]
            # Split B's output into one coefficient vector per RGB channel.
            rgb_sh = rgb_features.view(*rgb_features.shape[:-1], 3, sh_mult.shape[-1])
            rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)
        else:
            d_encoded = self.direction_encoding(d)
            rgb_features_encoded = self.feature_encoding(rgb_features)
            out = self.mlp_head(torch.cat([rgb_features, d, rgb_features_encoded, d_encoded], dim=-1))  # type: ignore
            rgb = self.field_output_rgb(out)

        return rgb

    def forward(
        self,
        ray_samples: RaySamples,
        compute_normals: bool = False,
        mask: Optional[Tensor] = None,
        bg_color: Optional[Tensor] = None,
    ) -> Dict[FieldHeadNames, Tensor]:
        """Evaluate density and RGB for the given samples.

        When both ``mask`` and ``bg_color`` are given, only the entries selected by
        ``mask`` are evaluated. Unselected entries get zero density and ``bg_color``.

        Args:
            ray_samples: samples to evaluate.
            compute_normals: must be falsy; normals are not supported.
            mask: boolean selector over ``ray_samples``' leading dimension(s).
            bg_color: color assigned to entries that ``mask`` does not select.

        Returns:
            Mapping with ``FieldHeadNames.DENSITY`` and ``FieldHeadNames.RGB``.

        Raises:
            ValueError: if ``compute_normals`` is requested.
        """
        if compute_normals:
            raise ValueError("Surface normals are not currently supported with TensoRF")

        if mask is None or bg_color is None:
            density = self.get_density(ray_samples)
            rgb = self.get_outputs(ray_samples, None)
            return {FieldHeadNames.DENSITY: density, FieldHeadNames.RGB: rgb}

        # Fill values for unselected entries: zero density and the background color.
        # The shapes are built directly from ray_samples.shape instead of indexing
        # the RaySamples dataclass with ``None`` just to read its shape.
        base_density = torch.zeros((*ray_samples.shape, 1), device=mask.device)
        base_rgb = bg_color.repeat(*ray_samples.shape, 1)

        if mask.any():
            input_rays = ray_samples[mask, :]
            base_density[mask] = self.get_density(input_rays)
            base_rgb[mask] = self.get_outputs(input_rays, None)

        # Kept from the original: the outputs always report requires_grad, even
        # when no entry was selected and nothing was computed.
        base_density.requires_grad_()
        base_rgb.requires_grad_()

        return {FieldHeadNames.DENSITY: base_density, FieldHeadNames.RGB: base_rgb}