Source code for nerfstudio.field_components.mlp

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Multi Layer Perceptron
"""

from typing import Literal, Optional, Set, Tuple, Union

import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor, nn

from nerfstudio.field_components.base_field_component import FieldComponent
from nerfstudio.field_components.encodings import HashEncoding
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
from nerfstudio.utils.printing import print_tcnn_speed_warning
from nerfstudio.utils.rich_utils import CONSOLE


def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str:
    """Translate a torch.nn activation module into the name TCNN expects.

    Args:
        activation: torch.nn activation module, or None for "no activation"

    Returns:
        str: TCNN activation function string

    Raises:
        ValueError: if the activation is not one of the module types listed below
    """
    if activation is None:
        return "None"

    # Checked in order; the first module class the activation is an instance of wins.
    supported = (
        (nn.ReLU, "ReLU"),
        (nn.LeakyReLU, "Leaky ReLU"),
        (nn.Sigmoid, "Sigmoid"),
        (nn.Softplus, "Softplus"),
        (nn.Tanh, "Tanh"),
    )
    for module_cls, tcnn_name in supported:
        if isinstance(activation, module_cls):
            return tcnn_name

    docs_url = "https://github.com/NVlabs/tiny-cuda-nn/blob/master/DOCUMENTATION.md#activation-functions"
    raise ValueError(
        f"TCNN activation {activation} not supported for now.\nSee {docs_url} for TCNN documentation."
    )
class MLP(FieldComponent):
    """Multilayer perceptron

    Args:
        in_dim: Input layer dimension
        num_layers: Number of network layers
        layer_width: Width of each MLP layer
        out_dim: Output layer dimension. Uses layer_width if None.
        skip_connections: Indices of layers that receive the network input concatenated with the
            previous layer's output. Index 0 is not allowed. Only used by the torch implementation;
            it is not forwarded to the tcnn network config.
        activation: intermediate layer activation function.
        out_activation: output activation function.
        implementation: Implementation of the MLP. Fallback to torch if tcnn not available.
    """

    def __init__(
        self,
        in_dim: int,
        num_layers: int,
        layer_width: int,
        out_dim: Optional[int] = None,
        skip_connections: Optional[Tuple[int, ...]] = None,
        activation: Optional[nn.Module] = nn.ReLU(),
        out_activation: Optional[nn.Module] = None,
        implementation: Literal["tcnn", "torch"] = "torch",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        assert self.in_dim > 0
        self.out_dim = out_dim if out_dim is not None else layer_width
        self.num_layers = num_layers
        self.layer_width = layer_width
        self.skip_connections = skip_connections
        # Set copy of `skip_connections` for O(1) membership tests while building and in forward.
        self._skip_connections: Set[int] = set(skip_connections) if skip_connections else set()
        self.activation = activation
        self.out_activation = out_activation
        self.net = None

        # Stays None unless the tcnn backend is actually built; `forward` dispatches on it.
        self.tcnn_encoding = None
        if implementation == "torch":
            self.build_nn_modules()
        elif implementation == "tcnn" and not TCNN_EXISTS:
            # tcnn was requested but is not installed: warn and fall back to torch.
            print_tcnn_speed_warning("MLP")
            self.build_nn_modules()
        elif implementation == "tcnn":
            network_config = self.get_tcnn_network_config(
                activation=self.activation,
                out_activation=self.out_activation,
                layer_width=self.layer_width,
                num_layers=self.num_layers,
            )
            self.tcnn_encoding = tcnn.Network(
                n_input_dims=in_dim,
                n_output_dims=self.out_dim,
                network_config=network_config,
            )

    @classmethod
    def get_tcnn_network_config(cls, activation, out_activation, layer_width, num_layers) -> dict:
        """Get the network configuration for tcnn if implemented

        Args:
            activation: intermediate layer activation (torch.nn module or None)
            out_activation: output activation (torch.nn module or None)
            layer_width: Width of each MLP layer
            num_layers: Number of network layers

        Returns:
            Config dict with "otype" set to "FullyFusedMLP" for widths 16, 32, 64 or 128 and to
            "CutlassMLP" (after printing a warning) for any other width.

        Raises:
            ValueError: if an activation cannot be converted to a TCNN activation string
        """
        activation_str = activation_to_tcnn_string(activation)
        output_activation_str = activation_to_tcnn_string(out_activation)
        if layer_width in (16, 32, 64, 128):
            otype = "FullyFusedMLP"
        else:
            CONSOLE.line()
            CONSOLE.print("[bold yellow]WARNING: Using slower TCNN CutlassMLP instead of TCNN FullyFusedMLP")
            CONSOLE.print("[bold yellow]Use layer width of 16, 32, 64, or 128 to use the faster TCNN FullyFusedMLP.")
            CONSOLE.line()
            otype = "CutlassMLP"
        # Both variants share every field except "otype".
        return {
            "otype": otype,
            "activation": activation_str,
            "output_activation": output_activation_str,
            "n_neurons": layer_width,
            "n_hidden_layers": num_layers - 1,
        }

    def build_nn_modules(self) -> None:
        """Initialize the torch version of the multi-layer perceptron.

        Every layer whose index is in `_skip_connections` is built with `in_dim` extra input
        features, matching the concatenation done in `pytorch_fwd`. This includes the final layer.
        """
        # Layer 0 already consumes the raw input, so a skip connection there is meaningless,
        # whatever the depth of the network.
        assert 0 not in self._skip_connections, "Skip connection at layer 0 doesn't make sense."
        layers = []
        if self.num_layers == 1:
            layers.append(nn.Linear(self.in_dim, self.out_dim))
        else:
            for i in range(self.num_layers - 1):
                if i == 0:
                    layers.append(nn.Linear(self.in_dim, self.layer_width))
                elif i in self._skip_connections:
                    layers.append(nn.Linear(self.layer_width + self.in_dim, self.layer_width))
                else:
                    layers.append(nn.Linear(self.layer_width, self.layer_width))
            # The output layer must also be widened when it has a skip connection; otherwise
            # `pytorch_fwd` would feed it `layer_width + in_dim` features and fail.
            last_in_dim = self.layer_width
            if (self.num_layers - 1) in self._skip_connections:
                last_in_dim += self.in_dim
            layers.append(nn.Linear(last_in_dim, self.out_dim))
        self.layers = nn.ModuleList(layers)

    def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs in_dim"]) -> Float[Tensor, "*bs out_dim"]:
        """Process input with a multilayer perceptron.

        Args:
            in_tensor: Network input

        Returns:
            MLP network output
        """
        x = in_tensor
        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            # as checked in `build_nn_modules`, 0 should not be in `_skip_connections`
            if i in self._skip_connections:
                x = torch.cat([in_tensor, x], -1)
            x = layer(x)
            # The hidden activation is applied after every layer except the last one.
            if self.activation is not None and i < last_index:
                x = self.activation(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x

    def forward(self, in_tensor: Float[Tensor, "*bs in_dim"]) -> Float[Tensor, "*bs out_dim"]:
        """Run the tcnn network if one was built, otherwise the torch layers.

        Args:
            in_tensor: Network input

        Returns:
            MLP network output
        """
        if self.tcnn_encoding is not None:
            return self.tcnn_encoding(in_tensor)
        return self.pytorch_fwd(in_tensor)
class MLPWithHashEncoding(FieldComponent):
    """Hash-grid encoding followed by a multilayer perceptron

    Args:
        num_levels: Number of feature grids.
        min_res: Resolution of smallest feature grid.
        max_res: Resolution of largest feature grid.
        log2_hashmap_size: Size of hash map is 2^log2_hashmap_size.
        features_per_level: Number of features per level.
        hash_init_scale: Value to initialize hash grid.
        interpolation: Interpolation override for tcnn hashgrid. Not supported for torch unless linear.
        num_layers: Number of network layers
        layer_width: Width of each MLP layer
        out_dim: Output layer dimension. Uses layer_width if None.
        skip_connections: Skip connection layer indices, handed to the torch MLP.
        activation: intermediate layer activation function.
        out_activation: output activation function.
        implementation: Implementation of hash encoding. Fallback to torch if tcnn not available.
    """

    def __init__(
        self,
        num_levels: int = 16,
        min_res: int = 16,
        max_res: int = 1024,
        log2_hashmap_size: int = 19,
        features_per_level: int = 2,
        hash_init_scale: float = 0.001,
        interpolation: Optional[Literal["Nearest", "Linear", "Smoothstep"]] = None,
        num_layers: int = 2,
        layer_width: int = 64,
        out_dim: Optional[int] = None,
        skip_connections: Optional[Tuple[int]] = None,
        activation: Optional[nn.Module] = nn.ReLU(),
        out_activation: Optional[nn.Module] = None,
        implementation: Literal["tcnn", "torch"] = "torch",
    ) -> None:
        super().__init__()

        # Hash grid settings. The input dimension is fixed to 3.
        self.in_dim = 3
        self.num_levels = num_levels
        self.min_res = min_res
        self.max_res = max_res
        self.features_per_level = features_per_level
        self.hash_init_scale = hash_init_scale
        self.log2_hashmap_size = log2_hashmap_size
        self.hash_table_size = 2**log2_hashmap_size
        # Per-level resolution multiplier; a single level needs no growth.
        if num_levels > 1:
            self.growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1))
        else:
            self.growth_factor = 1

        # MLP settings.
        if out_dim is None:
            self.out_dim = layer_width
        else:
            self.out_dim = out_dim
        self.num_layers = num_layers
        self.layer_width = layer_width
        self.skip_connections = skip_connections
        self._skip_connections: Set[int] = set(skip_connections) if skip_connections else set()
        self.activation = activation
        self.out_activation = out_activation
        self.net = None
        self.tcnn_encoding = None

        # Backend selection: fused tcnn model when requested and available, torch otherwise.
        wants_tcnn = implementation == "tcnn"
        if wants_tcnn and TCNN_EXISTS:
            encoding_config = HashEncoding.get_tcnn_encoding_config(
                num_levels=self.num_levels,
                features_per_level=self.features_per_level,
                log2_hashmap_size=self.log2_hashmap_size,
                min_res=self.min_res,
                growth_factor=self.growth_factor,
                interpolation=interpolation,
            )
            network_config = MLP.get_tcnn_network_config(
                activation=self.activation,
                out_activation=self.out_activation,
                layer_width=self.layer_width,
                num_layers=self.num_layers,
            )
            self.model = tcnn.NetworkWithInputEncoding(
                n_input_dims=self.in_dim,
                n_output_dims=self.out_dim,
                encoding_config=encoding_config,
                network_config=network_config,
            )
        elif wants_tcnn or implementation == "torch":
            if wants_tcnn:
                # tcnn was requested but is not installed: warn before falling back to torch.
                print_tcnn_speed_warning("MLPWithHashEncoding")
            self.build_nn_modules()

    def build_nn_modules(self) -> None:
        """Initialize the torch version of the MLP with hash encoding."""
        hash_grid = HashEncoding(
            num_levels=self.num_levels,
            min_res=self.min_res,
            max_res=self.max_res,
            log2_hashmap_size=self.log2_hashmap_size,
            features_per_level=self.features_per_level,
            hash_init_scale=self.hash_init_scale,
            implementation="torch",
        )
        # The MLP input width is whatever the encoder reports as its output width.
        head = MLP(
            in_dim=hash_grid.get_out_dim(),
            num_layers=self.num_layers,
            layer_width=self.layer_width,
            out_dim=self.out_dim,
            skip_connections=self.skip_connections,
            activation=self.activation,
            out_activation=self.out_activation,
            implementation="torch",
        )
        self.model = torch.nn.Sequential(hash_grid, head)

    def forward(self, in_tensor: Float[Tensor, "*bs in_dim"]) -> Float[Tensor, "*bs out_dim"]:
        """Pass the input through whichever model was built in the constructor."""
        return self.model(in_tensor)