# Source code for nerfstudio.field_components.field_heads

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Collection of render heads
"""
from enum import Enum
from typing import Callable, Optional, Union

import torch
from jaxtyping import Float, Shaped
from torch import Tensor, nn

from nerfstudio.field_components.base_field_component import FieldComponent


class FieldHeadNames(Enum):
    """Possible field outputs.

    Each member's value is the lowercase string form of its name.
    """

    # Appearance / geometry outputs
    RGB = "rgb"
    SH = "sh"
    DENSITY = "density"
    NORMALS = "normals"
    PRED_NORMALS = "pred_normals"
    UNCERTAINTY = "uncertainty"
    BACKGROUND_RGB = "background_rgb"
    # Transient (per-image) outputs
    TRANSIENT_RGB = "transient_rgb"
    TRANSIENT_DENSITY = "transient_density"
    # Other outputs
    SEMANTICS = "semantics"
    SDF = "sdf"
    ALPHA = "alpha"
    GRADIENT = "gradient"
class FieldHead(FieldComponent):
    """Base field output

    A single ``nn.Linear(in_dim, out_dim)`` layer followed by an optional activation.

    Args:
        out_dim: output dimension for renderer
        field_head_name: Field type
        in_dim: input dimension. If not defined in constructor, it must be set later
            via ``set_in_dim``.
        activation: output head activation, or ``None`` for no activation.
    """

    def __init__(
        self,
        out_dim: int,
        field_head_name: FieldHeadNames,
        in_dim: Optional[int] = None,
        activation: Optional[Union[nn.Module, Callable]] = None,
    ) -> None:
        super().__init__()
        self.out_dim = out_dim
        self.activation = activation
        self.field_head_name = field_head_name
        # The linear layer is built lazily: it stays None until an input dimension is known.
        self.net = None
        if in_dim is not None:
            self.in_dim = in_dim
            self._construct_net()

    def set_in_dim(self, in_dim: int) -> None:
        """Set input dimension of Field Head and (re)build its linear layer."""
        self.in_dim = in_dim
        self._construct_net()

    def _construct_net(self):
        """Create the linear layer mapping ``self.in_dim`` features to ``self.out_dim``."""
        self.net = nn.Linear(self.in_dim, self.out_dim)

    def forward(self, in_tensor: Shaped[Tensor, "*bs in_dim"]) -> Shaped[Tensor, "*bs out_dim"]:
        """Process network output for renderer

        Args:
            in_tensor: Network input

        Returns:
            Render head output: the linear layer's output, passed through ``activation``
            when one was given.

        Raises:
            SystemError: if no input dimension was provided to the constructor and
                ``set_in_dim`` was never called.
        """
        # Explicit identity checks: the intent is "has the layer been built?" /
        # "was an activation given?", not the module's truthiness (which for
        # container modules depends on their length).
        if self.net is None:
            # SystemError is kept (rather than e.g. RuntimeError) for compatibility
            # with callers that may already catch it.
            raise SystemError("in_dim not set. Must be provided to constructor, or set_in_dim() should be called.")
        out_tensor = self.net(in_tensor)
        if self.activation is not None:
            out_tensor = self.activation(out_tensor)
        return out_tensor
class DensityFieldHead(FieldHead):
    """Density output head producing a single value per sample.

    Args:
        in_dim: input dimension. If not defined in constructor, it must be set later.
        activation: output head activation (``nn.Softplus()`` by default).
    """

    def __init__(self, in_dim: Optional[int] = None, activation: Optional[nn.Module] = nn.Softplus()) -> None:
        super().__init__(
            out_dim=1,
            field_head_name=FieldHeadNames.DENSITY,
            in_dim=in_dim,
            activation=activation,
        )
class RGBFieldHead(FieldHead):
    """RGB output head producing three channels per sample.

    Args:
        in_dim: input dimension. If not defined in constructor, it must be set later.
        activation: output head activation (``nn.Sigmoid()`` by default).
    """

    def __init__(self, in_dim: Optional[int] = None, activation: Optional[nn.Module] = nn.Sigmoid()) -> None:
        super().__init__(
            out_dim=3,
            field_head_name=FieldHeadNames.RGB,
            in_dim=in_dim,
            activation=activation,
        )
class SHFieldHead(FieldHead):
    """Spherical harmonics output

    The head outputs ``channels * levels**2`` values, i.e. ``levels**2`` coefficients
    for each channel.

    Args:
        in_dim: input dimension. If not defined in constructor, it must be set later.
        levels: Number of spherical harmonics levels.
        channels: Number of channels. Defaults to 3 (ie RGB).
        activation: Output activation (none by default).
    """

    def __init__(
        self, in_dim: Optional[int] = None, levels: int = 3, channels: int = 3, activation: Optional[nn.Module] = None
    ) -> None:
        coeffs_per_channel = levels**2
        super().__init__(
            out_dim=channels * coeffs_per_channel,
            field_head_name=FieldHeadNames.SH,
            in_dim=in_dim,
            activation=activation,
        )
class UncertaintyFieldHead(FieldHead):
    """Uncertainty output head producing a single value per sample.

    Args:
        in_dim: input dimension. If not defined in constructor, it must be set later.
        activation: output head activation (``nn.Softplus()`` by default).
    """

    def __init__(self, in_dim: Optional[int] = None, activation: Optional[nn.Module] = nn.Softplus()) -> None:
        super().__init__(
            out_dim=1,
            field_head_name=FieldHeadNames.UNCERTAINTY,
            in_dim=in_dim,
            activation=activation,
        )
class TransientRGBFieldHead(FieldHead):
    """Transient RGB output head producing three channels per sample.

    Args:
        in_dim: input dimension. If not defined in constructor, it must be set later.
        activation: output head activation (``nn.Sigmoid()`` by default).
    """

    def __init__(self, in_dim: Optional[int] = None, activation: Optional[nn.Module] = nn.Sigmoid()) -> None:
        super().__init__(
            out_dim=3,
            field_head_name=FieldHeadNames.TRANSIENT_RGB,
            in_dim=in_dim,
            activation=activation,
        )
class TransientDensityFieldHead(FieldHead):
    """Transient density output head producing a single value per sample.

    Args:
        in_dim: input dimension. If not defined in constructor, it must be set later.
        activation: output head activation (``nn.Softplus()`` by default).
    """

    def __init__(self, in_dim: Optional[int] = None, activation: Optional[nn.Module] = nn.Softplus()) -> None:
        super().__init__(
            out_dim=1,
            field_head_name=FieldHeadNames.TRANSIENT_DENSITY,
            in_dim=in_dim,
            activation=activation,
        )
class SemanticFieldHead(FieldHead):
    """Semantic output

    Produces one value per semantic class. No activation is applied: the base head
    is constructed with ``activation=None``, so the linear layer's output is
    returned as-is.

    Args:
        num_classes: Number of semantic classes (the head's output dimension).
        in_dim: input dimension. If not defined in constructor, it must be set later.
    """

    def __init__(self, num_classes: int, in_dim: Optional[int] = None) -> None:
        super().__init__(in_dim=in_dim, out_dim=num_classes, field_head_name=FieldHeadNames.SEMANTICS, activation=None)
class PredNormalsFieldHead(FieldHead):
    """Predicted normals output (three components per sample).

    Args:
        in_dim: input dimension. If not defined in constructor, it must be set later.
        activation: output head activation (``nn.Tanh()`` by default), applied before
            normalization.
    """

    def __init__(self, in_dim: Optional[int] = None, activation: Optional[nn.Module] = nn.Tanh()) -> None:
        super().__init__(
            out_dim=3,
            field_head_name=FieldHeadNames.PRED_NORMALS,
            in_dim=in_dim,
            activation=activation,
        )

    def forward(self, in_tensor: Float[Tensor, "*bs in_dim"]) -> Float[Tensor, "*bs out_dim"]:
        """Run the base head, then rescale each output vector to unit length.

        Normalization is along the last dimension, so the result is a valid normal direction.
        """
        raw_normals = super().forward(in_tensor)
        return torch.nn.functional.normalize(raw_normals, dim=-1)