# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Collection of renderers.
Example:
.. code-block:: python
field_outputs = field(ray_sampler)
weights = ray_sampler.get_weights(field_outputs[FieldHeadNames.DENSITY])
rgb_renderer = RGBRenderer()
rgb = rgb_renderer(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
"""
import contextlib
import math
from typing import Generator, Literal, Optional, Tuple, Union
import nerfacc
import torch
from jaxtyping import Float, Int
from torch import Tensor, nn
from nerfstudio.cameras.rays import RaySamples
from nerfstudio.utils import colors
from nerfstudio.utils.math import safe_normalize
from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics
BackgroundColor = Union[Literal["random", "last_sample", "black", "white"], Float[Tensor, "3"], Float[Tensor, "*bs 3"]]
BACKGROUND_COLOR_OVERRIDE: Optional[Float[Tensor, "3"]] = None
@contextlib.contextmanager
def background_color_override_context(mode: Float[Tensor, "3"]) -> Generator[None, None, None]:
"""Context manager for setting background mode."""
global BACKGROUND_COLOR_OVERRIDE
old_background_color = BACKGROUND_COLOR_OVERRIDE
try:
BACKGROUND_COLOR_OVERRIDE = mode
yield
finally:
BACKGROUND_COLOR_OVERRIDE = old_background_color
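
# A minimal usage sketch for the override context (assumes an ``RGBRenderer``
# instance ``rgb_renderer`` and ``rgb``/``weights`` tensors as in the module
# docstring; the names are illustrative):
#
#     with background_color_override_context(torch.tensor([0.0, 0.0, 0.0])):
#         comp_rgb = rgb_renderer(rgb=rgb, weights=weights)  # composited onto black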
class RGBRenderer(nn.Module):
"""Standard volumetric rendering.
Args:
        background_color: Background color as RGB, a preset color name,
            "last_sample", or "random". Defaults to "random".
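
    Example:
    .. code-block:: python

        # A minimal sketch; ``rgb`` ([..., num_samples, 3]) and ``weights``
        # ([..., num_samples, 1]) are illustrative tensors, as in the module
        # docstring.
        renderer = RGBRenderer(background_color="white")
        comp_rgb = renderer(rgb=rgb, weights=weights)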
"""
def __init__(self, background_color: BackgroundColor = "random") -> None:
super().__init__()
self.background_color: BackgroundColor = background_color
    @classmethod
def combine_rgb(
cls,
rgb: Float[Tensor, "*bs num_samples 3"],
weights: Float[Tensor, "*bs num_samples 1"],
background_color: BackgroundColor = "random",
ray_indices: Optional[Int[Tensor, "num_samples"]] = None,
num_rays: Optional[int] = None,
) -> Float[Tensor, "*bs 3"]:
"""Composite samples along ray and render color image.
        If background_color is "random", no background color is blended in;
        the result is equivalent to compositing onto black.
Args:
rgb: RGB for each sample
weights: Weights for each sample
background_color: Background color as RGB.
ray_indices: Ray index for each sample, used when samples are packed.
num_rays: Number of rays, used when samples are packed.
Returns:
            Composited RGB values per ray.
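
        For unpacked samples this computes ``sum_i(w_i * rgb_i)`` over the
        sample dimension and then blends ``background_color * (1 - sum_i(w_i))``
        on top, matching the implementation below.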
"""
if ray_indices is not None and num_rays is not None:
# Necessary for packed samples from volumetric ray sampler
if background_color == "last_sample":
raise NotImplementedError("Background color 'last_sample' not implemented for packed samples.")
comp_rgb = nerfacc.accumulate_along_rays(
weights[..., 0], values=rgb, ray_indices=ray_indices, n_rays=num_rays
)
accumulated_weight = nerfacc.accumulate_along_rays(
weights[..., 0], values=None, ray_indices=ray_indices, n_rays=num_rays
)
else:
comp_rgb = torch.sum(weights * rgb, dim=-2)
accumulated_weight = torch.sum(weights, dim=-2)
if BACKGROUND_COLOR_OVERRIDE is not None:
background_color = BACKGROUND_COLOR_OVERRIDE
if background_color == "random":
# If background color is random, the predicted color is returned without blending,
# as if the background color was black.
return comp_rgb
elif background_color == "last_sample":
# Note, this is only supported for non-packed samples.
background_color = rgb[..., -1, :]
background_color = cls.get_background_color(background_color, shape=comp_rgb.shape, device=comp_rgb.device)
assert isinstance(background_color, torch.Tensor)
comp_rgb = comp_rgb + background_color * (1.0 - accumulated_weight)
return comp_rgb
    @classmethod
def get_background_color(
cls, background_color: BackgroundColor, shape: Tuple[int, ...], device: torch.device
) -> Union[Float[Tensor, "3"], Float[Tensor, "*bs 3"]]:
"""Returns the RGB background color for a specified background color.
Note:
This function CANNOT be called for background_color being either "last_sample" or "random".
Args:
background_color: The background color specification. If a string is provided, it must be a valid color name.
shape: Shape of the output tensor.
device: Device on which to create the tensor.
Returns:
Background color as RGB.
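
        Example:
        .. code-block:: python

            # Sketch: expand a preset color to an image shape (assumes the
            # "white" key exists in ``colors.COLORS_DICT``).
            bg = RGBRenderer.get_background_color(
                "white", shape=(48, 64, 3), device=torch.device("cpu")
            )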
"""
assert background_color not in {"last_sample", "random"}
assert shape[-1] == 3, "Background color must be RGB."
if BACKGROUND_COLOR_OVERRIDE is not None:
background_color = BACKGROUND_COLOR_OVERRIDE
if isinstance(background_color, str) and background_color in colors.COLORS_DICT:
background_color = colors.COLORS_DICT[background_color]
assert isinstance(background_color, Tensor)
# Ensure correct shape
return background_color.expand(shape).to(device)
    def blend_background(
self,
image: Tensor,
background_color: Optional[BackgroundColor] = None,
) -> Float[Tensor, "*bs 3"]:
"""Blends the background color into the image if image is RGBA.
Otherwise no blending is performed (we assume opacity of 1).
Args:
            image: RGB or RGBA per pixel; if RGBA, the alpha channel is used
                as the per-pixel opacity.
background_color: Background color.
Returns:
Blended RGB.
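
        Example:
        .. code-block:: python

            # Sketch: blend a hypothetical RGBA image onto a black background;
            # 3-channel inputs are returned unchanged (``renderer`` is an
            # illustrative RGBRenderer instance).
            rgba = torch.rand(32, 32, 4)
            rgb = renderer.blend_background(rgba, background_color="black")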
"""
if image.size(-1) < 4:
return image
rgb, opacity = image[..., :3], image[..., 3:]
if background_color is None:
background_color = self.background_color
if background_color in {"last_sample", "random"}:
background_color = "black"
background_color = self.get_background_color(background_color, shape=rgb.shape, device=rgb.device)
assert isinstance(background_color, torch.Tensor)
return rgb * opacity + background_color.to(rgb.device) * (1 - opacity)
    def blend_background_for_loss_computation(
self,
pred_image: Tensor,
pred_accumulation: Tensor,
gt_image: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Blends a background color into the ground truth and predicted image for
loss computation.
Args:
            pred_image: The predicted RGB values (without background blending).
            pred_accumulation: The predicted opacity/accumulation.
            gt_image: The ground truth image.
Returns:
A tuple of the predicted and ground truth RGB values.
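
        Example:
        .. code-block:: python

            # Sketch with illustrative tensor names: blend both images
            # consistently before computing a reconstruction loss.
            pred_blended, gt_blended = renderer.blend_background_for_loss_computation(
                pred_image=pred_rgb, pred_accumulation=accumulation, gt_image=gt_rgb
            )
            loss = torch.nn.functional.mse_loss(pred_blended, gt_blended)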
"""
background_color = self.background_color
if background_color == "last_sample":
background_color = "black" # No background blending for GT
elif background_color == "random":
background_color = torch.rand_like(pred_image)
pred_image = pred_image + background_color * (1.0 - pred_accumulation)
gt_image = self.blend_background(gt_image, background_color=background_color)
return pred_image, gt_image
    def forward(
self,
rgb: Float[Tensor, "*bs num_samples 3"],
weights: Float[Tensor, "*bs num_samples 1"],
ray_indices: Optional[Int[Tensor, "num_samples"]] = None,
num_rays: Optional[int] = None,
background_color: Optional[BackgroundColor] = None,
) -> Float[Tensor, "*bs 3"]:
"""Composite samples along ray and render color image
Args:
rgb: RGB for each sample
weights: Weights for each sample
ray_indices: Ray index for each sample, used when samples are packed.
num_rays: Number of rays, used when samples are packed.
background_color: The background color to use for rendering.
Returns:
            Composited RGB values per ray.
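
        Example:
        .. code-block:: python

            # Sketch for packed samples: flat per-sample tensors plus a ray
            # index per sample, as produced by a volumetric sampler (tensor
            # names are illustrative).
            comp_rgb = renderer(
                rgb=rgb_flat,
                weights=weights_flat,
                ray_indices=ray_indices,
                num_rays=num_rays,
            )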
"""
if background_color is None:
background_color = self.background_color
if not self.training:
rgb = torch.nan_to_num(rgb)
rgb = self.combine_rgb(
rgb, weights, background_color=background_color, ray_indices=ray_indices, num_rays=num_rays
)
if not self.training:
torch.clamp_(rgb, min=0.0, max=1.0)
return rgb
class SHRenderer(nn.Module):
"""Render RGB value from spherical harmonics.
Args:
        background_color: Background color as RGB, a preset color name,
            "last_sample", or "random". Defaults to "random".
activation: Output activation.
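
    Example:
    .. code-block:: python

        # Sketch: ``sh`` holds 16 coefficients per color channel (48 values in
        # the last dimension), i.e. 4 spherical-harmonics levels; the tensor
        # names are illustrative.
        sh_renderer = SHRenderer(background_color="black")
        rgb = sh_renderer(sh=sh, directions=directions, weights=weights)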
"""
def __init__(
self,
background_color: BackgroundColor = "random",
activation: Optional[nn.Module] = nn.Sigmoid(),
) -> None:
super().__init__()
self.background_color: BackgroundColor = background_color
self.activation = activation
    def forward(
self,
sh: Float[Tensor, "*batch num_samples coeffs"],
directions: Float[Tensor, "*batch num_samples 3"],
weights: Float[Tensor, "*batch num_samples 1"],
) -> Float[Tensor, "*batch 3"]:
"""Composite samples along ray and render color image
Args:
            sh: Spherical harmonics coefficients for each sample, flattened to
                3 * num_coeffs values per sample (unpacked to [..., 3, num_coeffs] below).
directions: Sample direction
weights: Weights for each sample
Returns:
            Composited RGB values per ray.
"""
sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3)
levels = int(math.sqrt(sh.shape[-1]))
components = components_from_spherical_harmonics(degree=levels - 1, directions=directions)
rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components]
rgb = torch.sum(rgb, dim=-1) # [..., num_samples, 3]
if self.activation is not None:
rgb = self.activation(rgb)
if not self.training:
rgb = torch.nan_to_num(rgb)
rgb = RGBRenderer.combine_rgb(rgb, weights, background_color=self.background_color)
if not self.training:
torch.clamp_(rgb, min=0.0, max=1.0)
return rgb
class AccumulationRenderer(nn.Module):
"""Accumulated value along a ray."""
    @classmethod
def forward(
cls,
weights: Float[Tensor, "*bs num_samples 1"],
ray_indices: Optional[Int[Tensor, "num_samples"]] = None,
num_rays: Optional[int] = None,
) -> Float[Tensor, "*bs 1"]:
"""Composite samples along ray and calculate accumulation.
Args:
weights: Weights for each sample
ray_indices: Ray index for each sample, used when samples are packed.
num_rays: Number of rays, used when samples are packed.
Returns:
            Accumulated weight (opacity) per ray.
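
        Example:
        .. code-block:: python

            # Sketch: per-ray accumulation from the same ``weights`` tensor
            # used for RGB compositing (illustrative name).
            accumulation = AccumulationRenderer()(weights=weights)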
"""
if ray_indices is not None and num_rays is not None:
# Necessary for packed samples from volumetric ray sampler
accumulation = nerfacc.accumulate_along_rays(
weights[..., 0], values=None, ray_indices=ray_indices, n_rays=num_rays
)
else:
accumulation = torch.sum(weights, dim=-2)
return accumulation
class DepthRenderer(nn.Module):
"""Calculate depth along ray.
Depth Method:
- median: Depth is set to the distance where the accumulated weight reaches 0.5.
        - expected: Expected depth along the ray; same compositing as RGB, but with per-sample depths as the values.
Args:
method: Depth calculation method.
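
    Example:
    .. code-block:: python

        # Sketch (assumes ``weights`` and ``ray_samples`` from a sampler, as
        # in the module docstring):
        depth_renderer = DepthRenderer(method="expected")
        depth = depth_renderer(weights=weights, ray_samples=ray_samples)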
"""
def __init__(self, method: Literal["median", "expected"] = "median") -> None:
super().__init__()
self.method = method
    def forward(
self,
weights: Float[Tensor, "*batch num_samples 1"],
ray_samples: RaySamples,
ray_indices: Optional[Int[Tensor, "num_samples"]] = None,
num_rays: Optional[int] = None,
) -> Float[Tensor, "*batch 1"]:
"""Composite samples along ray and calculate depths.
Args:
weights: Weights for each sample.
ray_samples: Set of ray samples.
ray_indices: Ray index for each sample, used when samples are packed.
num_rays: Number of rays, used when samples are packed.
Returns:
            Depth values per ray.
"""
if self.method == "median":
steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2
if ray_indices is not None and num_rays is not None:
raise NotImplementedError("Median depth calculation is not implemented for packed samples.")
cumulative_weights = torch.cumsum(weights[..., 0], dim=-1) # [..., num_samples]
split = torch.ones((*weights.shape[:-2], 1), device=weights.device) * 0.5 # [..., 1]
median_index = torch.searchsorted(cumulative_weights, split, side="left") # [..., 1]
median_index = torch.clamp(median_index, 0, steps.shape[-2] - 1) # [..., 1]
median_depth = torch.gather(steps[..., 0], dim=-1, index=median_index) # [..., 1]
return median_depth
if self.method == "expected":
eps = 1e-10
steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2
if ray_indices is not None and num_rays is not None:
# Necessary for packed samples from volumetric ray sampler
depth = nerfacc.accumulate_along_rays(
weights[..., 0], values=steps, ray_indices=ray_indices, n_rays=num_rays
)
accumulation = nerfacc.accumulate_along_rays(
weights[..., 0], values=None, ray_indices=ray_indices, n_rays=num_rays
)
depth = depth / (accumulation + eps)
else:
depth = torch.sum(weights * steps, dim=-2) / (torch.sum(weights, -2) + eps)
depth = torch.clip(depth, steps.min(), steps.max())
return depth
raise NotImplementedError(f"Method {self.method} not implemented")
class UncertaintyRenderer(nn.Module):
"""Calculate uncertainty along the ray."""
    @classmethod
def forward(
cls, betas: Float[Tensor, "*bs num_samples 1"], weights: Float[Tensor, "*bs num_samples 1"]
) -> Float[Tensor, "*bs 1"]:
"""Calculate uncertainty along the ray.
Args:
betas: Uncertainty betas for each sample.
weights: Weights of each sample.
Returns:
Rendering of uncertainty.
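
        Example:
        .. code-block:: python

            # Sketch: expected uncertainty per ray; ``betas`` is an
            # illustrative tensor shaped like ``weights``.
            uncertainty = UncertaintyRenderer()(betas=betas, weights=weights)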
"""
uncertainty = torch.sum(weights * betas, dim=-2)
return uncertainty
class SemanticRenderer(nn.Module):
"""Calculate semantics along the ray."""
    @classmethod
def forward(
cls,
semantics: Float[Tensor, "*bs num_samples num_classes"],
weights: Float[Tensor, "*bs num_samples 1"],
ray_indices: Optional[Int[Tensor, "num_samples"]] = None,
num_rays: Optional[int] = None,
) -> Float[Tensor, "*bs num_classes"]:
"""Calculate semantics along the ray."""
if ray_indices is not None and num_rays is not None:
# Necessary for packed samples from volumetric ray sampler
return nerfacc.accumulate_along_rays(
weights[..., 0], values=semantics, ray_indices=ray_indices, n_rays=num_rays
)
else:
return torch.sum(weights * semantics, dim=-2)
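
# Sketch: composite per-sample class logits into per-ray logits, then take the
# argmax for a hypothetical label map (tensor names are illustrative):
#
#     sem = SemanticRenderer()(semantics=logits, weights=weights)
#     labels = torch.argmax(sem, dim=-1)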
class NormalsRenderer(nn.Module):
"""Calculate normals along the ray."""
    @classmethod
def forward(
cls,
normals: Float[Tensor, "*bs num_samples 3"],
weights: Float[Tensor, "*bs num_samples 1"],
normalize: bool = True,
) -> Float[Tensor, "*bs 3"]:
"""Calculate normals along the ray.
Args:
normals: Normals for each sample.
weights: Weights of each sample.
            normalize: Whether to renormalize the composited normals to unit length.
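
        Example:
        .. code-block:: python

            # Sketch: weight-composited normals, renormalized to unit length
            # (illustrative tensor names).
            n = NormalsRenderer()(normals=normals, weights=weights, normalize=True)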
"""
n = torch.sum(weights * normals, dim=-2)
if normalize:
n = safe_normalize(n)
return n