Source code for nerfstudio.models.vanilla_nerf

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Implementation of vanilla nerf.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Tuple, Type

import torch
from torch.nn import Parameter

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.config_utils import to_immutable_dict
from nerfstudio.field_components.encodings import NeRFEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind
from nerfstudio.fields.vanilla_nerf_field import NeRFField
from nerfstudio.model_components.losses import MSELoss, scale_gradients_by_distance_squared
from nerfstudio.model_components.ray_samplers import PDFSampler, UniformSampler
from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, RGBRenderer
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils import colormaps, misc


@dataclass
class VanillaModelConfig(ModelConfig):
    """Configuration for the vanilla NeRF model.

    The ``_target`` field defaults to :class:`NeRFModel`, the model class this
    configuration is meant to instantiate.
    """

    _target: Type = field(default_factory=lambda: NeRFModel)

    num_coarse_samples: int = 64
    """Number of samples drawn per ray by the uniform sampler (coarse field evaluation)."""

    num_importance_samples: int = 128
    """Number of samples drawn per ray by the PDF sampler (fine field evaluation)."""

    enable_temporal_distortion: bool = False
    """Whether to include ray warping based on time."""

    temporal_distortion_params: Dict[str, Any] = to_immutable_dict({"kind": TemporalDistortionKind.DNERF})
    """Parameters used to instantiate the temporal distortion; the "kind" entry selects its type."""

    use_gradient_scaling: bool = False
    """Whether to scale field-output gradients by squared distance, lowering them for points closer to the camera."""

    background_color: Literal["random", "last_sample", "black", "white"] = "white"
    """Background color mode handed to the RGB renderer."""
class NeRFModel(Model):
    """Vanilla NeRF model

    Every ray is rendered twice: once with a coarse field evaluated at
    uniformly drawn samples, and once with a fine field evaluated at samples
    drawn by a PDF sampler from the coarse weights.

    Args:
        config: Basic NeRF configuration to instantiate model
    """

    config: VanillaModelConfig

    def __init__(
        self,
        config: VanillaModelConfig,
        **kwargs,
    ) -> None:
        # These placeholders are assigned *before* the base constructor runs and
        # must stay there: presumably the base constructor calls
        # populate_modules() (TODO confirm against Model.__init__), in which case
        # assigning them afterwards would wipe the freshly built modules.
        self.field_coarse = None
        self.field_fine = None
        self.temporal_distortion = None

        super().__init__(
            config=config,
            **kwargs,
        )

    def populate_modules(self):
        """Set the fields and modules"""
        super().populate_modules()

        # fields
        position_encoding = NeRFEncoding(
            in_dim=3, num_frequencies=10, min_freq_exp=0.0, max_freq_exp=8.0, include_input=True
        )
        direction_encoding = NeRFEncoding(
            in_dim=3, num_frequencies=4, min_freq_exp=0.0, max_freq_exp=4.0, include_input=True
        )

        self.field_coarse = NeRFField(
            position_encoding=position_encoding,
            direction_encoding=direction_encoding,
        )

        self.field_fine = NeRFField(
            position_encoding=position_encoding,
            direction_encoding=direction_encoding,
        )

        # samplers
        self.sampler_uniform = UniformSampler(num_samples=self.config.num_coarse_samples)
        self.sampler_pdf = PDFSampler(num_samples=self.config.num_importance_samples)

        # renderers
        self.renderer_rgb = RGBRenderer(background_color=self.config.background_color)
        self.renderer_accumulation = AccumulationRenderer()
        self.renderer_depth = DepthRenderer()

        # losses
        self.rgb_loss = MSELoss()

        # metrics (imported here rather than at module level, as in the original)
        from torchmetrics.functional import structural_similarity_index_measure
        from torchmetrics.image import PeakSignalNoiseRatio
        from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = structural_similarity_index_measure
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)

        if getattr(self.config, "enable_temporal_distortion", False):
            # Pop "kind" from a shallow copy: popping from the config's own dict
            # would strip the key from the config, so building a second model
            # from the same config would raise KeyError.
            params = dict(self.config.temporal_distortion_params)
            kind = params.pop("kind")
            self.temporal_distortion = kind.to_temporal_distortion(params)

    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Collect the trainable parameters, grouped by name.

        Returns:
            A dict with a "fields" entry holding the parameters of both the
            coarse and the fine field, plus a "temporal_distortion" entry when
            a temporal distortion module is present.

        Raises:
            ValueError: If the coarse or fine field has not been created yet.
        """
        param_groups: Dict[str, List[Parameter]] = {}
        if self.field_coarse is None or self.field_fine is None:
            raise ValueError("populate_modules() must be called before get_param_groups")
        param_groups["fields"] = list(self.field_coarse.parameters()) + list(self.field_fine.parameters())
        if self.temporal_distortion is not None:
            param_groups["temporal_distortion"] = list(self.temporal_distortion.parameters())
        return param_groups

    def _apply_temporal_distortion(self, ray_samples) -> None:
        """Set time-dependent position offsets on ``ray_samples`` in place (no-op without a distortion module)."""
        if self.temporal_distortion is None:
            return
        offsets = None
        if ray_samples.times is not None:
            offsets = self.temporal_distortion(ray_samples.frustums.get_positions(), ray_samples.times)
        # Called even when offsets is None, exactly as the original code did.
        ray_samples.frustums.set_offsets(offsets)

    def _render_samples(self, field_module, ray_samples):
        """Evaluate ``field_module`` on ``ray_samples`` and return (weights, rgb, accumulation, depth)."""
        field_outputs = field_module.forward(ray_samples)
        if self.config.use_gradient_scaling:
            field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)
        weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
        rgb = self.renderer_rgb(
            rgb=field_outputs[FieldHeadNames.RGB],
            weights=weights,
        )
        accumulation = self.renderer_accumulation(weights)
        depth = self.renderer_depth(weights, ray_samples)
        return weights, rgb, accumulation, depth

    def get_outputs(self, ray_bundle: RayBundle):
        """Render a bundle of rays with the coarse and then the fine field.

        Args:
            ray_bundle: Rays to render.

        Returns:
            A dict with the keys "rgb_coarse", "rgb_fine", "accumulation_coarse",
            "accumulation_fine", "depth_coarse" and "depth_fine".

        Raises:
            ValueError: If the coarse or fine field has not been created yet.
        """
        if self.field_coarse is None or self.field_fine is None:
            raise ValueError("populate_modules() must be called before get_outputs")

        # uniform sampling + coarse field
        ray_samples_uniform = self.sampler_uniform(ray_bundle)
        self._apply_temporal_distortion(ray_samples_uniform)
        weights_coarse, rgb_coarse, accumulation_coarse, depth_coarse = self._render_samples(
            self.field_coarse, ray_samples_uniform
        )

        # pdf sampling (guided by the coarse weights) + fine field
        ray_samples_pdf = self.sampler_pdf(ray_bundle, ray_samples_uniform, weights_coarse)
        self._apply_temporal_distortion(ray_samples_pdf)
        _, rgb_fine, accumulation_fine, depth_fine = self._render_samples(self.field_fine, ray_samples_pdf)

        outputs = {
            "rgb_coarse": rgb_coarse,
            "rgb_fine": rgb_fine,
            "accumulation_coarse": accumulation_coarse,
            "accumulation_fine": accumulation_fine,
            "depth_coarse": depth_coarse,
            "depth_fine": depth_fine,
        }
        return outputs

    def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
        """Compute the coarse and fine RGB losses.

        Args:
            outputs: Dict produced by ``get_outputs``.
            batch: Dict holding the ground-truth pixels under "image".
            metrics_dict: Unused by this model.

        Returns:
            A dict with "rgb_loss_coarse" and "rgb_loss_fine", scaled by
            ``self.config.loss_coefficients``.
        """
        # Scaling metrics by coefficients to create the losses.
        device = outputs["rgb_coarse"].device
        image = batch["image"].to(device)

        coarse_pred, coarse_image = self.renderer_rgb.blend_background_for_loss_computation(
            pred_image=outputs["rgb_coarse"],
            pred_accumulation=outputs["accumulation_coarse"],
            gt_image=image,
        )
        fine_pred, fine_image = self.renderer_rgb.blend_background_for_loss_computation(
            pred_image=outputs["rgb_fine"],
            pred_accumulation=outputs["accumulation_fine"],
            gt_image=image,
        )

        rgb_loss_coarse = self.rgb_loss(coarse_image, coarse_pred)
        rgb_loss_fine = self.rgb_loss(fine_image, fine_pred)

        loss_dict = {"rgb_loss_coarse": rgb_loss_coarse, "rgb_loss_fine": rgb_loss_fine}
        loss_dict = misc.scale_dict(loss_dict, self.config.loss_coefficients)
        return loss_dict

    def get_image_metrics_and_images(
        self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
        """Build evaluation metrics and side-by-side visualisations.

        Args:
            outputs: Dict produced by ``get_outputs``.
            batch: Dict holding the ground-truth image under "image".

        Returns:
            A tuple ``(metrics_dict, images_dict)``. ``metrics_dict`` holds
            "psnr", "coarse_psnr", "fine_psnr", "fine_ssim" and "fine_lpips" as
            Python floats; ``images_dict`` holds the concatenated "img",
            "accumulation" and "depth" images.
        """
        image = batch["image"].to(outputs["rgb_coarse"].device)
        image = self.renderer_rgb.blend_background(image)
        rgb_coarse = outputs["rgb_coarse"]
        rgb_fine = outputs["rgb_fine"]
        acc_coarse = colormaps.apply_colormap(outputs["accumulation_coarse"])
        acc_fine = colormaps.apply_colormap(outputs["accumulation_fine"])

        assert self.config.collider_params is not None
        near_plane = self.config.collider_params["near_plane"]
        far_plane = self.config.collider_params["far_plane"]
        depth_coarse = colormaps.apply_depth_colormap(
            outputs["depth_coarse"],
            accumulation=outputs["accumulation_coarse"],
            near_plane=near_plane,
            far_plane=far_plane,
        )
        depth_fine = colormaps.apply_depth_colormap(
            outputs["depth_fine"],
            accumulation=outputs["accumulation_fine"],
            near_plane=near_plane,
            far_plane=far_plane,
        )

        combined_rgb = torch.cat([image, rgb_coarse, rgb_fine], dim=1)
        combined_acc = torch.cat([acc_coarse, acc_fine], dim=1)
        combined_depth = torch.cat([depth_coarse, depth_fine], dim=1)

        # Switch images from [H, W, C] to [1, C, H, W] for metrics computations
        image = torch.moveaxis(image, -1, 0)[None, ...]
        rgb_coarse = torch.moveaxis(rgb_coarse, -1, 0)[None, ...]
        rgb_fine = torch.moveaxis(rgb_fine, -1, 0)[None, ...]

        coarse_psnr = self.psnr(image, rgb_coarse)
        fine_psnr = self.psnr(image, rgb_fine)
        fine_ssim = self.ssim(image, rgb_fine)
        fine_lpips = self.lpips(image, rgb_fine)
        assert isinstance(fine_ssim, torch.Tensor)

        # "psnr" intentionally duplicates "fine_psnr", as in the original dict.
        metrics_dict = {
            "psnr": float(fine_psnr),
            "coarse_psnr": float(coarse_psnr),
            "fine_psnr": float(fine_psnr),
            "fine_ssim": float(fine_ssim),
            "fine_lpips": float(fine_lpips),
        }
        images_dict = {"img": combined_rgb, "accumulation": combined_acc, "depth": combined_depth}
        return metrics_dict, images_dict