Source code for nerfstudio.model_components.losses

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Collection of Losses.
"""
from enum import Enum
from typing import Dict, Literal, Optional, Tuple, cast

import torch
from jaxtyping import Bool, Float
from torch import Tensor, nn

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.utils.math import masked_reduction, normalized_depth_scale_and_shift

L1Loss = nn.L1Loss
MSELoss = nn.MSELoss

LOSSES = {"L1": L1Loss, "MSE": MSELoss}

EPS = 1.0e-7

# Sigma scale factor from Urban Radiance Fields (Rematas et al., 2022)
URF_SIGMA_SCALE_FACTOR = 3.0


class DepthLossType(Enum):
    """Types of depth losses for depth supervision."""

    DS_NERF = 1
    URF = 2
    SPARSENERF_RANKING = 3

FORCE_PSEUDODEPTH_LOSS = False
PSEUDODEPTH_COMPATIBLE_LOSSES = (DepthLossType.SPARSENERF_RANKING,)

def outer(
    t0_starts: Float[Tensor, "*batch num_samples_0"],
    t0_ends: Float[Tensor, "*batch num_samples_0"],
    t1_starts: Float[Tensor, "*batch num_samples_1"],
    t1_ends: Float[Tensor, "*batch num_samples_1"],
    y1: Float[Tensor, "*batch num_samples_1"],
) -> Float[Tensor, "*batch num_samples_0"]:
    """Faster version of

    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L117
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L64

    Args:
        t0_starts: start of the interval edges
        t0_ends: end of the interval edges
        t1_starts: start of the interval edges
        t1_ends: end of the interval edges
        y1: weights
    """
    cy1 = torch.cat([torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1)

    idx_lo = torch.searchsorted(t1_starts.contiguous(), t0_starts.contiguous(), side="right") - 1
    idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1)
    idx_hi = torch.searchsorted(t1_ends.contiguous(), t0_ends.contiguous(), side="right")
    idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1)
    cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1)
    cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1)
    y0_outer = cy1_hi - cy1_lo

    return y0_outer

def lossfun_outer(
    t: Float[Tensor, "*batch num_samples_1"],
    w: Float[Tensor, "*batch num_samples"],
    t_env: Float[Tensor, "*batch num_samples_1"],
    w_env: Float[Tensor, "*batch num_samples"],
):
    """
    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L136
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L80

    Args:
        t: interval edges
        w: weights
        t_env: interval edges of the upper bound enveloping histogram
        w_env: weights that should upper bound the inner (t,w) histogram
    """
    w_outer = outer(t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env)
    return torch.clip(w - w_outer, min=0) ** 2 / (w + EPS)

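# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# lossfun_outer penalizes histogram mass in (t, w) that is not covered by the enveloping
# proposal histogram (t_env, w_env). All shapes and values below are arbitrary assumptions.
import torch
from nerfstudio.model_components.losses import lossfun_outer

t = torch.linspace(0, 1, 9).expand(4, 9)          # 4 rays, 8 fine bins: interval edges
w = torch.softmax(torch.randn(4, 8), dim=-1)      # fine weights, summing to 1 per ray
t_env = torch.linspace(0, 1, 5).expand(4, 5)      # 4 rays, 4 proposal bins
w_env = torch.softmax(torch.randn(4, 4), dim=-1)  # proposal weights
penalty = lossfun_outer(t, w, t_env, w_env)       # (4, 8); zero wherever w is fully enveloped
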
def ray_samples_to_sdist(ray_samples):
    """Convert ray samples to s space"""
    starts = ray_samples.spacing_starts
    ends = ray_samples.spacing_ends
    sdist = torch.cat([starts[..., 0], ends[..., -1:, 0]], dim=-1)  # (num_rays, num_samples + 1)
    return sdist

def interlevel_loss(weights_list, ray_samples_list) -> torch.Tensor:
    """Calculates the proposal loss in the MipNeRF-360 paper.

    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/model.py#L515
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/train_utils.py#L133
    """
    c = ray_samples_to_sdist(ray_samples_list[-1]).detach()
    w = weights_list[-1][..., 0].detach()
    assert len(ray_samples_list) > 0

    loss_interlevel = 0.0
    for ray_samples, weights in zip(ray_samples_list[:-1], weights_list[:-1]):
        sdist = ray_samples_to_sdist(ray_samples)
        cp = sdist  # (num_rays, num_samples + 1)
        wp = weights[..., 0]  # (num_rays, num_samples)
        loss_interlevel += torch.mean(lossfun_outer(c, w, cp, wp))

    assert isinstance(loss_interlevel, Tensor)
    return loss_interlevel

# Verified
def lossfun_distortion(t, w):
    """
    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L142
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L266
    """
    ut = (t[..., 1:] + t[..., :-1]) / 2
    dut = torch.abs(ut[..., :, None] - ut[..., None, :])
    loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1)

    loss_intra = torch.sum(w**2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3

    return loss_inter + loss_intra

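# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# The distortion loss is smallest when all of a ray's weight sits in a single narrow bin,
# and grows as the weight spreads out along the ray. Shapes and values below are assumptions.
import torch
from nerfstudio.model_components.losses import lossfun_distortion

t = torch.linspace(0, 1, 65).expand(2, 65)  # 2 rays, 64 bins in normalized s-space
w_spread = torch.full((2, 64), 1.0 / 64)    # mass spread uniformly along the ray
w_peaked = torch.zeros(2, 64)
w_peaked[:, 32] = 1.0                       # all mass concentrated in one bin
print(lossfun_distortion(t, w_spread))      # comparatively large
print(lossfun_distortion(t, w_peaked))      # close to the single-bin floor w**2 * dt / 3
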
def distortion_loss(weights_list, ray_samples_list):
    """From mipnerf360"""
    c = ray_samples_to_sdist(ray_samples_list[-1])
    w = weights_list[-1][..., 0]
    loss = torch.mean(lossfun_distortion(c, w))
    return loss

def nerfstudio_distortion_loss(
    ray_samples: RaySamples,
    densities: Optional[Float[Tensor, "*bs num_samples 1"]] = None,
    weights: Optional[Float[Tensor, "*bs num_samples 1"]] = None,
) -> Float[Tensor, "*bs 1"]:
    """Ray based distortion loss proposed in MipNeRF-360. Returns distortion Loss.

    .. math::

        \\mathcal{L}(\\mathbf{s}, \\mathbf{w}) =\\iint\\limits_{-\\infty}^{\\,\\,\\,\\infty}
        \\mathbf{w}_\\mathbf{s}(u)\\mathbf{w}_\\mathbf{s}(v)|u - v|\\,d_{u}\\,d_{v}

    where :math:`\\mathbf{w}_\\mathbf{s}(u)=\\sum_i w_i \\mathbb{1}_{[\\mathbf{s}_i, \\mathbf{s}_{i+1})}(u)`
    is the weight at location :math:`u` between bin locations :math:`s_i` and :math:`s_{i+1}`.

    Args:
        ray_samples: Ray samples to compute loss over
        densities: Predicted sample densities
        weights: Predicted weights from densities and sample locations
    """
    if torch.is_tensor(densities):
        assert not torch.is_tensor(weights), "Cannot use both densities and weights"
        assert densities is not None
        # Compute the weight at each sample location
        weights = ray_samples.get_weights(densities)
    if torch.is_tensor(weights):
        assert not torch.is_tensor(densities), "Cannot use both densities and weights"
    assert weights is not None

    starts = ray_samples.spacing_starts
    ends = ray_samples.spacing_ends
    assert starts is not None and ends is not None, "Ray samples must have spacing starts and ends"
    midpoints = (starts + ends) / 2.0  # (..., num_samples, 1)

    loss = (
        weights * weights[..., None, :, 0] * torch.abs(midpoints - midpoints[..., None, :, 0])
    )  # (..., num_samples, num_samples)
    loss = torch.sum(loss, dim=(-1, -2))[..., None]  # (..., 1)
    loss = loss + 1 / 3.0 * torch.sum(weights**2 * (ends - starts), dim=-2)

    return loss

def orientation_loss(
    weights: Float[Tensor, "*bs num_samples 1"],
    normals: Float[Tensor, "*bs num_samples 3"],
    viewdirs: Float[Tensor, "*bs 3"],
):
    """Orientation loss proposed in Ref-NeRF.
    Loss that encourages that all visible normals are facing towards the camera.
    """
    w = weights
    n = normals
    v = viewdirs * -1
    n_dot_v = (n * v[..., None, :]).sum(dim=-1)
    return (w[..., 0] * torch.fmin(torch.zeros_like(n_dot_v), n_dot_v) ** 2).sum(dim=-1)

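# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# orientation_loss only penalizes samples whose density normal faces away from the camera
# (n . -d < 0), weighted by the sample weight. Shapes below are assumptions.
import torch
from nerfstudio.model_components.losses import orientation_loss

weights = torch.rand(16, 48, 1)                                          # 16 rays, 48 samples
normals = torch.nn.functional.normalize(torch.randn(16, 48, 3), dim=-1)  # density normals
viewdirs = torch.nn.functional.normalize(torch.randn(16, 3), dim=-1)     # ray directions
loss_per_ray = orientation_loss(weights, normals, viewdirs)              # shape (16,)
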
def pred_normal_loss(
    weights: Float[Tensor, "*bs num_samples 1"],
    normals: Float[Tensor, "*bs num_samples 3"],
    pred_normals: Float[Tensor, "*bs num_samples 3"],
):
    """Loss between normals calculated from density and normals from prediction network."""
    return (weights[..., 0] * (1.0 - torch.sum(normals * pred_normals, dim=-1))).sum(dim=-1)

def ds_nerf_depth_loss(
    weights: Float[Tensor, "*batch num_samples 1"],
    termination_depth: Float[Tensor, "*batch 1"],
    steps: Float[Tensor, "*batch num_samples 1"],
    lengths: Float[Tensor, "*batch num_samples 1"],
    sigma: Float[Tensor, "0"],
) -> Float[Tensor, "*batch 1"]:
    """Depth loss from Depth-supervised NeRF (Deng et al., 2022).

    Args:
        weights: Weights predicted for each sample.
        termination_depth: Ground truth depth of rays.
        steps: Sampling distances along rays.
        lengths: Distances between steps.
        sigma: Uncertainty around depth values.
    Returns:
        Depth loss scalar.
    """
    depth_mask = termination_depth > 0

    loss = -torch.log(weights + EPS) * torch.exp(-((steps - termination_depth[:, None]) ** 2) / (2 * sigma)) * lengths
    loss = loss.sum(-2) * depth_mask
    return torch.mean(loss)

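# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# ds_nerf_depth_loss pushes weight towards samples near the ground-truth termination depth
# of each ray. The ray discretization and sigma below are arbitrary assumptions.
import torch
from nerfstudio.model_components.losses import ds_nerf_depth_loss

num_rays, num_samples = 8, 64
steps = torch.linspace(0.1, 4.0, num_samples).view(1, num_samples, 1).expand(num_rays, -1, -1)
lengths = torch.full((num_rays, num_samples, 1), (4.0 - 0.1) / (num_samples - 1))
weights = torch.softmax(torch.randn(num_rays, num_samples, 1), dim=-2)
termination_depth = torch.full((num_rays, 1), 2.0)  # e.g. SfM / lidar depth per ray
sigma = torch.tensor(0.01)                          # depth uncertainty
loss = ds_nerf_depth_loss(weights, termination_depth, steps, lengths, sigma)
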
def urban_radiance_field_depth_loss(
    weights: Float[Tensor, "*batch num_samples 1"],
    termination_depth: Float[Tensor, "*batch 1"],
    predicted_depth: Float[Tensor, "*batch 1"],
    steps: Float[Tensor, "*batch num_samples 1"],
    sigma: Float[Tensor, "0"],
) -> Float[Tensor, "*batch 1"]:
    """Lidar losses from Urban Radiance Fields (Rematas et al., 2022).

    Args:
        weights: Weights predicted for each sample.
        termination_depth: Ground truth depth of rays.
        predicted_depth: Depth prediction from the network.
        steps: Sampling distances along rays.
        sigma: Uncertainty around depth values.
    Returns:
        Depth loss scalar.
    """
    depth_mask = termination_depth > 0

    # Expected depth loss
    expected_depth_loss = (termination_depth - predicted_depth) ** 2

    # Line of sight losses
    target_distribution = torch.distributions.normal.Normal(0.0, sigma / URF_SIGMA_SCALE_FACTOR)
    termination_depth = termination_depth[:, None]
    line_of_sight_loss_near_mask = torch.logical_and(
        steps <= termination_depth + sigma, steps >= termination_depth - sigma
    )
    line_of_sight_loss_near = (weights - torch.exp(target_distribution.log_prob(steps - termination_depth))) ** 2
    line_of_sight_loss_near = (line_of_sight_loss_near_mask * line_of_sight_loss_near).sum(-2)
    line_of_sight_loss_empty_mask = steps < termination_depth - sigma
    line_of_sight_loss_empty = (line_of_sight_loss_empty_mask * weights**2).sum(-2)
    line_of_sight_loss = line_of_sight_loss_near + line_of_sight_loss_empty

    loss = (expected_depth_loss + line_of_sight_loss) * depth_mask
    return torch.mean(loss)

def depth_loss(
    weights: Float[Tensor, "*batch num_samples 1"],
    ray_samples: RaySamples,
    termination_depth: Float[Tensor, "*batch 1"],
    predicted_depth: Float[Tensor, "*batch 1"],
    sigma: Float[Tensor, "0"],
    directions_norm: Float[Tensor, "*batch 1"],
    is_euclidean: bool,
    depth_loss_type: DepthLossType,
) -> Float[Tensor, "0"]:
    """Implementation of depth losses.

    Args:
        weights: Weights predicted for each sample.
        ray_samples: Samples along rays corresponding to weights.
        termination_depth: Ground truth depth of rays.
        predicted_depth: Depth prediction from the network.
        sigma: Uncertainty around depth value.
        directions_norm: Norms of ray direction vectors in the camera frame.
        is_euclidean: Whether ground truth depths correspond to normalized direction vectors.
        depth_loss_type: Type of depth loss to apply.
    Returns:
        Depth loss scalar.
    """
    if not is_euclidean:
        termination_depth = termination_depth * directions_norm
    steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2

    if depth_loss_type == DepthLossType.DS_NERF:
        lengths = ray_samples.frustums.ends - ray_samples.frustums.starts
        return ds_nerf_depth_loss(weights, termination_depth, steps, lengths, sigma)

    if depth_loss_type == DepthLossType.URF:
        return urban_radiance_field_depth_loss(weights, termination_depth, predicted_depth, steps, sigma)

    raise NotImplementedError("Provided depth loss type not implemented.")

def monosdf_normal_loss(
    normal_pred: Float[Tensor, "num_samples 3"], normal_gt: Float[Tensor, "num_samples 3"]
) -> Float[Tensor, "0"]:
    """
    Normal consistency loss proposed in monosdf - https://niujinshuchong.github.io/monosdf/
    Enforces consistency between the volume rendered normal and the predicted monocular normal.
    Uses both an angular and an L1 loss. Eq 14 https://arxiv.org/pdf/2206.00665.pdf

    Args:
        normal_pred: volume rendered normal
        normal_gt: monocular normal
    """
    normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1)
    normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1)
    l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean()
    cos = (1.0 - torch.sum(normal_pred * normal_gt, dim=-1)).mean()
    return l1 + cos

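# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# Both inputs are re-normalized inside the function, so raw per-sample vectors can be
# passed directly; the sample count below is an assumption.
import torch
from nerfstudio.model_components.losses import monosdf_normal_loss

normal_pred = torch.randn(1024, 3)  # volume-rendered normals
normal_gt = torch.randn(1024, 3)    # monocular normal predictions
loss = monosdf_normal_loss(normal_pred, normal_gt)  # scalar: L1 term + angular term
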
class MiDaSMSELoss(nn.Module):
    """
    data term from MiDaS paper
    """

    def __init__(self, reduction_type: Literal["image", "batch"] = "batch"):
        super().__init__()

        self.reduction_type: Literal["image", "batch"] = reduction_type
        # reduction here is different from the image/batch-based reduction. This is either "mean" or "sum"
        self.mse_loss = MSELoss(reduction="none")

    def forward(
        self,
        prediction: Float[Tensor, "1 32 mult"],
        target: Float[Tensor, "1 32 mult"],
        mask: Bool[Tensor, "1 32 mult"],
    ) -> Float[Tensor, "0"]:
        """
        Args:
            prediction: predicted depth map
            target: ground truth depth map
            mask: mask of valid pixels
        Returns:
            mse loss based on reduction function
        """
        summed_mask = torch.sum(mask, (1, 2))
        image_loss = torch.sum(self.mse_loss(prediction, target) * mask, (1, 2))
        # multiply by 2 magic number?
        image_loss = masked_reduction(image_loss, 2 * summed_mask, self.reduction_type)

        return image_loss

# losses based on https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py
class GradientLoss(nn.Module):
    """
    multiscale, scale-invariant gradient matching term to the disparity space.
    This term biases discontinuities to be sharp and to coincide with discontinuities in the ground truth
    More info here https://arxiv.org/pdf/1907.01341.pdf Equation 11
    """

    def __init__(self, scales: int = 4, reduction_type: Literal["image", "batch"] = "batch"):
        """
        Args:
            scales: number of scales to use
            reduction_type: either "batch" or "image"
        """
        super().__init__()
        self.reduction_type: Literal["image", "batch"] = reduction_type
        self.__scales = scales

    def forward(
        self,
        prediction: Float[Tensor, "1 32 mult"],
        target: Float[Tensor, "1 32 mult"],
        mask: Bool[Tensor, "1 32 mult"],
    ) -> Float[Tensor, "0"]:
        """
        Args:
            prediction: predicted depth map
            target: ground truth depth map
            mask: mask of valid pixels
        Returns:
            gradient loss based on reduction function
        """
        assert self.__scales >= 1
        total = 0.0

        for scale in range(self.__scales):
            step = pow(2, scale)

            grad_loss = self.gradient_loss(
                prediction[:, ::step, ::step],
                target[:, ::step, ::step],
                mask[:, ::step, ::step],
            )
            total += grad_loss

        assert isinstance(total, Tensor)
        return total

    def gradient_loss(
        self,
        prediction: Float[Tensor, "1 32 mult"],
        target: Float[Tensor, "1 32 mult"],
        mask: Bool[Tensor, "1 32 mult"],
    ) -> Float[Tensor, "0"]:
        """
        multiscale, scale-invariant gradient matching term to the disparity space.
        This term biases discontinuities to be sharp and to coincide with discontinuities in the ground truth
        More info here https://arxiv.org/pdf/1907.01341.pdf Equation 11

        Args:
            prediction: predicted depth map
            target: ground truth depth map
            mask: mask of valid pixels
        Returns:
            gradient loss based on reduction function
        """
        summed_mask = torch.sum(mask, (1, 2))
        diff = prediction - target
        diff = torch.mul(mask, diff)

        grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
        mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
        grad_x = torch.mul(mask_x, grad_x)

        grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
        mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
        grad_y = torch.mul(mask_y, grad_y)

        image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
        image_loss = masked_reduction(image_loss, summed_mask, self.reduction_type)

        return image_loss

class ScaleAndShiftInvariantLoss(nn.Module):
    """
    Scale and shift invariant loss as described in
    "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer"
    https://arxiv.org/pdf/1907.01341.pdf
    """

    def __init__(self, alpha: float = 0.5, scales: int = 4, reduction_type: Literal["image", "batch"] = "batch"):
        """
        Args:
            alpha: weight of the regularization term
            scales: number of scales to use
            reduction_type: either "batch" or "image"
        """
        super().__init__()
        self.__data_loss = MiDaSMSELoss(reduction_type=reduction_type)
        self.__regularization_loss = GradientLoss(scales=scales, reduction_type=reduction_type)
        self.__alpha = alpha

        self.__prediction_ssi = None

    def forward(
        self,
        prediction: Float[Tensor, "1 32 mult"],
        target: Float[Tensor, "1 32 mult"],
        mask: Bool[Tensor, "1 32 mult"],
    ) -> Float[Tensor, "0"]:
        """
        Args:
            prediction: predicted depth map (unnormalized)
            target: ground truth depth map (normalized)
            mask: mask of valid pixels
        Returns:
            scale and shift invariant loss
        """
        scale, shift = normalized_depth_scale_and_shift(prediction, target, mask)
        self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)

        total = self.__data_loss(self.__prediction_ssi, target, mask)
        if self.__alpha > 0:
            total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)

        return total

    def __get_prediction_ssi(self):
        """
        scale and shift invariant prediction
        from https://arxiv.org/pdf/1907.01341.pdf equation 1
        """
        return self.__prediction_ssi

    prediction_ssi = property(__get_prediction_ssi)

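# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# ScaleAndShiftInvariantLoss first aligns the prediction to the target with a per-image
# scale and shift, then applies the MiDaS data term plus the multiscale gradient term.
# The (1, 32, 64) shapes mirror the annotations above and are otherwise arbitrary assumptions.
import torch
from nerfstudio.model_components.losses import ScaleAndShiftInvariantLoss

ssi_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=4, reduction_type="batch")
prediction = torch.rand(1, 32, 64)              # unnormalized predicted depth
target = torch.rand(1, 32, 64)                  # normalized ground-truth depth
mask = torch.ones(1, 32, 64, dtype=torch.bool)  # all pixels valid
loss = ssi_loss(prediction, target, mask)
aligned = ssi_loss.prediction_ssi               # prediction after per-image scale/shift
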
def tv_loss(grids: Float[Tensor, "grids feature_dim row column"]) -> Float[Tensor, ""]:
    """
    https://github.com/apchenstu/TensoRF/blob/4ec894dc1341a2201fe13ae428631b58458f105d/utils.py#L139

    Args:
        grids: stacks of explicit feature grids (stacked at dim 0)
    Returns:
        average total variation loss for neighbor rows and columns.
    """
    number_of_grids = grids.shape[0]
    h_tv_count = grids[:, :, 1:, :].shape[1] * grids[:, :, 1:, :].shape[2] * grids[:, :, 1:, :].shape[3]
    w_tv_count = grids[:, :, :, 1:].shape[1] * grids[:, :, :, 1:].shape[2] * grids[:, :, :, 1:].shape[3]
    h_tv = torch.pow((grids[:, :, 1:, :] - grids[:, :, :-1, :]), 2).sum()
    w_tv = torch.pow((grids[:, :, :, 1:] - grids[:, :, :, :-1]), 2).sum()
    return 2 * (h_tv / h_tv_count + w_tv / w_tv_count) / number_of_grids

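# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# tv_loss is typically added to the training loss with a small weight to smooth explicit
# feature grids (e.g. TensoRF planes). The shapes and the 0.01 weight below are assumptions.
import torch
from nerfstudio.model_components.losses import tv_loss

grids = torch.randn(3, 16, 128, 128, requires_grad=True)  # (grids, feature_dim, row, column)
reg = tv_loss(grids)
(0.01 * reg).backward()  # gradients flow back into the feature grids
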
class _GradientScaler(torch.autograd.Function):  # typing: ignore
    """
    Scale gradients by a constant factor.
    """

    @staticmethod
    def forward(ctx, value, scaling):
        ctx.save_for_backward(scaling)
        return value, scaling

    @staticmethod
    def backward(ctx, output_grad, grad_scaling):
        (scaling,) = ctx.saved_tensors
        return output_grad * scaling, grad_scaling

def scale_gradients_by_distance_squared(
    field_outputs: Dict[FieldHeadNames, torch.Tensor],
    ray_samples: RaySamples,
) -> Dict[FieldHeadNames, torch.Tensor]:
    """
    Scale gradients by the ray distance to the pixel
    as suggested in `Radiance Field Gradient Scaling for Unbiased Near-Camera Training` paper

    Note: The scaling is applied on the interval of [0, 1] along the ray!

    Example:
        Gradient scaling should be applied right after obtaining the densities and colors from the field. ::
            >>> field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)
    """
    out = {}
    ray_dist = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2
    scaling = torch.square(ray_dist).clamp(0, 1)
    for key, value in field_outputs.items():
        out[key], _ = cast(Tuple[Tensor, Tensor], _GradientScaler.apply(value, scaling))
    return out

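# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# The underlying autograd trick: _GradientScaler is the identity in the forward pass and
# multiplies the incoming gradient by `scaling` in the backward pass. Values are assumptions.
import torch
from nerfstudio.model_components.losses import _GradientScaler

value = torch.randn(4, 3, requires_grad=True)           # e.g. per-sample field outputs
scaling = torch.tensor([[0.04], [0.25], [1.0], [1.0]])  # clamped squared ray distances
scaled_value, _ = _GradientScaler.apply(value, scaling)
scaled_value.sum().backward()
# value.grad now equals `scaling` broadcast over the last dimension
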
def depth_ranking_loss(rendered_depth, gt_depth):
    """
    Depth ranking loss as described in the SparseNeRF paper
    Assumes that the layout of the batch comes from a PairPixelSampler, so that adjacent samples in the
    gt_depth and rendered_depth are from pixels within a small radius of each other
    """
    m = 1e-4
    if rendered_depth.shape[0] % 2 != 0:
        # chop off one index
        rendered_depth = rendered_depth[:-1, :]
        gt_depth = gt_depth[:-1, :]
    dpt_diff = gt_depth[::2, :] - gt_depth[1::2, :]
    out_diff = rendered_depth[::2, :] - rendered_depth[1::2, :] + m
    differing_signs = torch.sign(dpt_diff) != torch.sign(out_diff)
    return torch.nanmean((out_diff[differing_signs] * torch.sign(out_diff[differing_signs])))

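# --- Illustrative example (added sketch, not part of the upstream nerfstudio source). ---
# Rows come in adjacent pairs (as produced by a PairPixelSampler); only pairs whose
# rendered ordering disagrees with the ground-truth ordering contribute to the loss.
import torch
from nerfstudio.model_components.losses import depth_ranking_loss

gt_depth = torch.tensor([[1.0], [2.0], [3.0], [2.5]])        # two pixel pairs
rendered_depth = torch.tensor([[1.2], [1.1], [2.8], [2.6]])  # first pair mis-ordered
loss = depth_ranking_loss(rendered_depth, gt_depth)          # penalizes only the first pair
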