# Source code for nerfstudio.models.base_model

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Base Model implementation which takes in RayBundles or Cameras
"""

from __future__ import annotations

from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
from torch import nn
from torch.nn import Parameter

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.configs.config_utils import to_immutable_dict
from nerfstudio.data.scene_box import OrientedBox, SceneBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.scene_colliders import NearFarCollider


# Model related configs
@dataclass
class ModelConfig(InstantiateConfig):
    """Configuration for model instantiation"""

    _target: Type = field(default_factory=lambda: Model)
    """target class to instantiate"""
    enable_collider: bool = True
    """Whether to create a scene collider to filter rays."""
    collider_params: Optional[Dict[str, float]] = to_immutable_dict({"near_plane": 2.0, "far_plane": 6.0})
    """parameters to instantiate scene collider with"""
    loss_coefficients: Dict[str, float] = to_immutable_dict({"rgb_loss_coarse": 1.0, "rgb_loss_fine": 1.0})
    """coefficients for the named loss terms, keyed by loss name"""
    eval_num_rays_per_chunk: int = 4096
    """specifies number of rays per chunk during eval"""
    prompt: Optional[str] = None
    """A prompt to be used in text to NeRF models"""
class Model(nn.Module):
    """Model class
    Where everything (Fields, Optimizers, Samplers, Visualization, etc) is linked together. This should be
    subclassed for custom NeRF model.

    Args:
        config: configuration for instantiating model
        scene_box: dataset scene box
        num_train_data: stored on the instance as ``self.num_train_data``
        **kwargs: extra keyword arguments, stored untouched on ``self.kwargs``
    """

    config: ModelConfig

    def __init__(
        self,
        config: ModelConfig,
        scene_box: SceneBox,
        num_train_data: int,
        **kwargs,
    ) -> None:
        super().__init__()
        self.config = config
        self.scene_box = scene_box
        self.render_aabb: Optional[SceneBox] = None  # the box that we want to render - should be a subset of scene_box
        self.num_train_data = num_train_data
        self.kwargs = kwargs
        self.collider = None

        # Everything above is set before this call so that subclass overrides of
        # `populate_modules` can read it.
        self.populate_modules()  # populate the modules
        self.callbacks = None
        # to keep track of which device the nn.Module is on
        self.device_indicator_param = nn.Parameter(torch.empty(0))

    @property
    def device(self) -> torch.device:
        """Returns the device that the model is on."""
        return self.device_indicator_param.device

    def get_training_callbacks(
        self, training_callback_attributes: TrainingCallbackAttributes
    ) -> List[TrainingCallback]:
        """Returns a list of callbacks that run functions at the specified training iterations.

        The base implementation registers no callbacks and returns an empty list.

        Args:
            training_callback_attributes: attributes made available to the callbacks (unused here)
        """
        return []

    def populate_modules(self):
        """Set the necessary modules to get the network working.

        When ``config.enable_collider`` is set, a ``NearFarCollider`` is built from the
        ``near_plane`` / ``far_plane`` entries of ``config.collider_params`` and stored on
        ``self.collider``; otherwise ``self.collider`` is left as it was.
        """
        # default instantiates optional modules that are common among many networks
        # NOTE: call `super().populate_modules()` in subclasses

        if self.config.enable_collider:
            assert self.config.collider_params is not None
            self.collider = NearFarCollider(
                near_plane=self.config.collider_params["near_plane"], far_plane=self.config.collider_params["far_plane"]
            )

    @abstractmethod
    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Obtain the parameter groups for the optimizers

        Returns:
            Mapping of different parameter groups
        """

    @abstractmethod
    def get_outputs(self, ray_bundle: Union[RayBundle, Cameras]) -> Dict[str, Union[torch.Tensor, List]]:
        """Takes in a ray bundle (or cameras) and returns a dictionary of outputs.

        Args:
            ray_bundle: Input bundle of rays. This raybundle should have all the
                needed information to compute the outputs.

        Returns:
            Outputs of model. (ie. rendered colors)
        """

    def forward(self, ray_bundle: Union[RayBundle, Cameras]) -> Dict[str, Union[torch.Tensor, List]]:
        """Run forward starting with a ray bundle. This outputs different things depending on the configuration
        of the model and whether or not the batch is provided (whether or not we are training basically)

        If a collider is set, the bundle is passed through it first; the result is then handed to
        ``get_outputs``.

        Args:
            ray_bundle: containing all the information needed to render that ray, latents included

        Returns:
            Whatever ``get_outputs`` returns for the (possibly collided) bundle.
        """

        if self.collider is not None:
            ray_bundle = self.collider(ray_bundle)

        return self.get_outputs(ray_bundle)

    def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
        """Compute and returns metrics.

        The base implementation computes nothing and returns an empty dict.

        Args:
            outputs: the output to compute loss dict to
            batch: ground truth batch corresponding to outputs
        """

        return {}

    @abstractmethod
    def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
        """Computes and returns the losses dict.

        Args:
            outputs: the output to compute loss dict to
            batch: ground truth batch corresponding to outputs
            metrics_dict: dictionary of metrics, some of which we can use for loss
        """

    @torch.no_grad()
    def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]:
        """Takes in a camera, generates the raybundle, and computes the output of the model.
        Assumes a ray-based model.

        Rays are generated with ``camera_indices=0`` and ``keep_shape=True`` and then rendered
        through ``get_outputs_for_camera_ray_bundle``. Runs with gradients disabled.

        Args:
            camera: generates raybundle
            obb_box: optional oriented box forwarded to ``camera.generate_rays``

        Returns:
            The outputs of ``get_outputs_for_camera_ray_bundle`` for the generated rays.
        """
        return self.get_outputs_for_camera_ray_bundle(
            camera.generate_rays(camera_indices=0, keep_shape=True, obb_box=obb_box)
        )

    @torch.no_grad()
    def get_outputs_for_camera_ray_bundle(self, camera_ray_bundle: RayBundle) -> Dict[str, torch.Tensor]:
        """Takes in camera parameters and computes the output of the model.

        The bundle is rendered in row-major chunks of ``config.eval_num_rays_per_chunk`` rays.
        Each chunk is moved to the model device, run through ``forward``, and its tensor outputs
        are moved back to the device of the input bundle. Non-tensor outputs are dropped.
        Runs with gradients disabled.

        Args:
            camera_ray_bundle: ray bundle to calculate outputs over

        Returns:
            One tensor per output name, concatenated over all chunks and viewed as
            ``(image_height, image_width, -1)``, where height and width are the first two
            dimensions of ``camera_ray_bundle.origins``.
        """
        input_device = camera_ray_bundle.directions.device
        num_rays_per_chunk = self.config.eval_num_rays_per_chunk
        image_height, image_width = camera_ray_bundle.origins.shape[:2]
        num_rays = len(camera_ray_bundle)
        outputs_lists: Dict[str, List[torch.Tensor]] = defaultdict(list)
        for start_idx in range(0, num_rays, num_rays_per_chunk):
            end_idx = start_idx + num_rays_per_chunk
            ray_bundle = camera_ray_bundle.get_row_major_sliced_ray_bundle(start_idx, end_idx)
            # move the chunk inputs to the model device
            ray_bundle = ray_bundle.to(self.device)
            chunk_outputs = self.forward(ray_bundle=ray_bundle)
            for output_name, output in chunk_outputs.items():  # type: ignore
                if not isinstance(output, torch.Tensor):
                    # TODO: handle lists of tensors as well
                    continue
                # move the chunk outputs from the model device back to the device of the inputs.
                outputs_lists[output_name].append(output.to(input_device))
        outputs = {}
        for output_name, outputs_list in outputs_lists.items():
            outputs[output_name] = torch.cat(outputs_list).view(image_height, image_width, -1)  # type: ignore
        return outputs

    def get_rgba_image(self, outputs: Dict[str, torch.Tensor], output_name: str = "rgb") -> torch.Tensor:
        """Returns the RGBA image from the outputs of the model.

        The accumulation output is looked up under ``output_name`` with ``"rgb"`` replaced by
        ``"accumulation"``. When the rgb renderer's background colour is ``"random"``, the colour
        is divided by the accumulation (clamped to at least 1e-10) and the accumulation is
        appended as the alpha channel. Otherwise an all-ones alpha channel is appended.

        Args:
            outputs: Outputs of the model.
            output_name: Key of the colour image inside ``outputs``.

        Returns:
            RGBA image.

        Raises:
            NotImplementedError: if the model has no ``renderer_rgb`` attribute, the renderer has no
                ``background_color`` attribute, or the accumulation output is missing.
        """
        accumulation_name = output_name.replace("rgb", "accumulation")
        if (
            not hasattr(self, "renderer_rgb")
            or not hasattr(self.renderer_rgb, "background_color")
            or accumulation_name not in outputs
        ):
            raise NotImplementedError(f"get_rgba_image is not implemented for model {self.__class__.__name__}")
        rgb = outputs[output_name]
        if self.renderer_rgb.background_color == "random":  # type: ignore
            acc = outputs[accumulation_name]
            if acc.dim() < rgb.dim():
                # give the accumulation a trailing channel axis so it can be concatenated with rgb
                acc = acc.unsqueeze(-1)
            return torch.cat((rgb / acc.clamp(min=1e-10), acc), dim=-1)
        return torch.cat((rgb, torch.ones_like(rgb[..., :1])), dim=-1)

    @abstractmethod
    def get_image_metrics_and_images(
        self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
        """Writes the test image outputs.
        TODO: This shouldn't return a loss

        Args:
            outputs: Outputs of the model.
            batch: Batch of data.

        Returns:
            A tuple of (dictionary of metrics, dictionary of image tensors).
        """

    def load_model(self, loaded_state: Dict[str, Any]) -> None:
        """Load model weights from an already-loaded checkpoint dictionary.

        A leading ``"module."`` on a key (the prefix torch's ``DataParallel`` /
        ``DistributedDataParallel`` wrappers add) is removed before the weights are handed to
        ``load_state_dict``. Only the leading prefix is stripped, so parameters whose own name
        contains ``"module."`` (e.g. ``my_module.weight``) keep their name.

        Args:
            loaded_state: dictionary of pre-trained model states; the weights are read from
                its ``"model"`` entry
        """
        wrapper_prefix = "module."
        state = {
            (key[len(wrapper_prefix) :] if key.startswith(wrapper_prefix) else key): value
            for key, value in loaded_state["model"].items()
        }
        self.load_state_dict(state)  # type: ignore

    def update_to_step(self, step: int) -> None:
        """Called when loading a model from a checkpoint. Sets any model parameters that change over
        training to the correct value, based on the training step of the checkpoint.

        The base implementation does nothing.

        Args:
            step: training step of the loaded checkpoint
        """