Source code for nerfstudio.data.pixel_samplers

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Code for sampling pixels.
"""

import random
import warnings
from dataclasses import dataclass, field
from typing import Dict, Optional, Type, Union

import torch
from jaxtyping import Int
from torch import Tensor

from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.data.utils.pixel_sampling_utils import divide_rays_per_image, erode_mask


@dataclass
class PixelSamplerConfig(InstantiateConfig):
    """Configuration for pixel sampler instantiation."""

    _target: Type = field(default_factory=lambda: PixelSampler)
    """Target class to instantiate."""
    num_rays_per_batch: int = 4096
    """Number of rays to sample per batch."""
    keep_full_image: bool = False
    """Whether or not to include a reference to the full image in returned batch."""
    is_equirectangular: bool = False
    """Whether the images are equirectangular. This is a single flag applied to every image in the
    batch; when set, the sampler uses its equirectangular sampling method."""
    ignore_mask: bool = False
    """Whether to ignore the masks when sampling."""
    fisheye_crop_radius: Optional[float] = None
    """Set to the radius (in pixels) for fisheye cameras. When not None (and the images are not
    equirectangular), unmasked samples are drawn inside a circle of this radius around the image center."""
    rejection_sample_mask: bool = True
    """Whether or not to use rejection sampling when sampling images with masks."""
    max_num_iterations: int = 100
    """If rejection sampling masks, the maximum number of times to sample."""
class PixelSampler:
    """Samples 'pixel_batch's from 'image_batch's.

    Args:
        config: the PixelSamplerConfig used to instantiate class
        **kwargs: optional overrides for ``num_rays_per_batch``, ``keep_full_image``,
            ``is_equirectangular`` and ``fisheye_crop_radius``. Overrides are written back
            onto ``config``.
    """

    config: PixelSamplerConfig

    def __init__(self, config: PixelSamplerConfig, **kwargs) -> None:
        self.kwargs = kwargs
        self.config = config
        # Possibly override some values if they are present in the kwargs dictionary
        self.config.num_rays_per_batch = self.kwargs.get("num_rays_per_batch", self.config.num_rays_per_batch)
        self.config.keep_full_image = self.kwargs.get("keep_full_image", self.config.keep_full_image)
        self.config.is_equirectangular = self.kwargs.get("is_equirectangular", self.config.is_equirectangular)
        self.config.fisheye_crop_radius = self.kwargs.get("fisheye_crop_radius", self.config.fisheye_crop_radius)
        # Goes through the setter so subclasses can adjust the value (e.g. round it).
        self.set_num_rays_per_batch(self.config.num_rays_per_batch)

    def set_num_rays_per_batch(self, num_rays_per_batch: int):
        """Set the number of rays to sample per batch.

        Args:
            num_rays_per_batch: number of rays to sample per batch
        """
        self.num_rays_per_batch = num_rays_per_batch

    @staticmethod
    def _sample_nonzero_mask_indices(mask: Tensor, batch_size: int) -> Tensor:
        """Pick ``batch_size`` (image, y, x) rows among the nonzero entries of ``mask[..., 0]``.

        Rows are drawn without replacement when enough nonzero entries exist. If the mask has
        fewer nonzero entries than ``batch_size``, rows are drawn with replacement instead of
        failing, so the returned batch may contain duplicates.

        Args:
            mask: mask tensor whose last dimension is indexed at 0 to find valid pixels
            batch_size: number of rows to return

        Raises:
            ValueError: if samples are requested but the mask has no nonzero entry.
        """
        nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
        num_candidates = len(nonzero_indices)
        if batch_size > num_candidates:
            if num_candidates == 0:
                raise ValueError("Cannot sample pixels: the mask contains no valid (nonzero) pixels.")
            # Not enough valid pixels for a sample without replacement; allow repeats.
            chosen_indices = random.choices(range(num_candidates), k=batch_size)
        else:
            chosen_indices = random.sample(range(num_candidates), k=batch_size)
        return nonzero_indices[chosen_indices]

    def sample_method(
        self,
        batch_size: int,
        num_images: int,
        image_height: int,
        image_width: int,
        mask: Optional[Tensor] = None,
        device: Union[torch.device, str] = "cpu",
    ) -> Int[Tensor, "batch_size 3"]:
        """
        Naive pixel sampler, uniformly samples across all possible pixels of all possible images.

        When a mask is given (and masks are not ignored), samples are restricted to pixels where
        ``mask[..., 0]`` is nonzero. With ``config.rejection_sample_mask`` this is attempted by
        rejection sampling for at most ``config.max_num_iterations`` rounds; if that fails, a
        warning is emitted, ``config.rejection_sample_mask`` is switched off and the samples are
        drawn directly from the nonzero mask entries.

        Args:
            batch_size: number of samples in a batch
            num_images: number of images to sample over
            image_height: height of the images in pixels
            image_width: width of the images in pixels
            mask: mask of possible pixels in an image to sample from.
            device: device on which the random indices are created

        Returns:
            Integer tensor of (image index, y, x) rows.
        """
        extents = torch.tensor([num_images, image_height, image_width], device=device)
        indices = (torch.rand((batch_size, 3), device=device) * extents).long()

        if not isinstance(mask, torch.Tensor) or self.config.ignore_mask:
            return indices

        if not self.config.rejection_sample_mask:
            return self._sample_nonzero_mask_indices(mask, batch_size)

        num_valid = 0
        for _ in range(self.config.max_num_iterations):
            c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
            chosen_indices_validity = mask[..., 0][c, y, x].bool()
            num_valid = int(torch.sum(chosen_indices_validity).item())
            if num_valid == batch_size:
                break
            # Redraw only the rows that landed outside the mask.
            replacement_indices = (torch.rand((batch_size - num_valid, 3), device=device) * extents).long()
            indices[~chosen_indices_validity] = replacement_indices

        if num_valid != batch_size:
            warnings.warn(
                """
                Masked sampling failed, mask is either empty or mostly empty.
                Reverting behavior to non-rejection sampling. Consider setting
                pipeline.datamanager.pixel-sampler.rejection-sample-mask to False
                or increasing pipeline.datamanager.pixel-sampler.max-num-iterations
                """
            )
            # Persistently disable rejection sampling so later calls skip the failed loop.
            self.config.rejection_sample_mask = False
            indices = self._sample_nonzero_mask_indices(mask, batch_size)

        return indices

    def sample_method_equirectangular(
        self,
        batch_size: int,
        num_images: int,
        image_height: int,
        image_width: int,
        mask: Optional[Tensor] = None,
        device: Union[torch.device, str] = "cpu",
    ) -> Int[Tensor, "batch_size 3"]:
        """Sample pixels of equirectangular images, weighting rows by sin(phi).

        With a usable mask this falls back to ``sample_method``.

        Args:
            batch_size: number of samples in a batch
            num_images: number of images to sample over
            image_height: height of the images in pixels
            image_width: width of the images in pixels
            mask: mask of possible pixels in an image to sample from.
            device: device on which the random indices are created

        Returns:
            Integer tensor of (image index, y, x) rows.
        """
        if isinstance(mask, torch.Tensor) and not self.config.ignore_mask:
            # Note: if there is a mask, sampling reduces back to uniform sampling, which gives more
            # sampling weight to the poles of the image than the equators.
            # TODO(kevinddchen): implement the correct mask-sampling method.
            return self.sample_method(batch_size, num_images, image_height, image_width, mask=mask, device=device)

        # We sample theta uniformly in [0, 2*pi]
        # We sample phi in [0, pi] according to the PDF f(phi) = sin(phi) / 2.
        # This is done by inverse transform sampling.
        # http://corysimon.github.io/articles/uniformdistn-on-sphere/
        num_images_rand = torch.rand(batch_size, device=device)
        phi_rand = torch.acos(1 - 2 * torch.rand(batch_size, device=device)) / torch.pi
        theta_rand = torch.rand(batch_size, device=device)
        indices = torch.floor(
            torch.stack((num_images_rand, phi_rand, theta_rand), dim=-1)
            * torch.tensor([num_images, image_height, image_width], device=device)
        ).long()
        return indices

    def sample_method_fisheye(
        self,
        batch_size: int,
        num_images: int,
        image_height: int,
        image_width: int,
        mask: Optional[Tensor] = None,
        device: Union[torch.device, str] = "cpu",
    ) -> Int[Tensor, "batch_size 3"]:
        """Sample pixels inside a circle of radius ``config.fisheye_crop_radius`` around the image center.

        Samples that fall outside the image bounds are redrawn until every row is in bounds.
        With a usable mask this falls back to ``sample_method``.

        Args:
            batch_size: number of samples in a batch
            num_images: number of images to sample over
            image_height: height of the images in pixels
            image_width: width of the images in pixels
            mask: mask of possible pixels in an image to sample from.
            device: device on which the random indices are created

        Returns:
            Integer tensor of (image index, y, x) rows.
        """
        if isinstance(mask, torch.Tensor) and not self.config.ignore_mask:
            return self.sample_method(batch_size, num_images, image_height, image_width, mask=mask, device=device)

        # Rejection sampling.
        valid: Optional[torch.Tensor] = None
        indices: Optional[torch.Tensor] = None
        while True:
            samples_needed = batch_size if valid is None else int(batch_size - torch.sum(valid).item())

            # Check if done!
            if samples_needed == 0:
                break

            rand_samples = torch.rand((samples_needed, 2), device=device)
            # Convert random samples to radius and theta.
            # The sqrt makes the samples uniform over the disc area.
            radii = self.config.fisheye_crop_radius * torch.sqrt(rand_samples[:, 0])
            theta = 2.0 * torch.pi * rand_samples[:, 1]

            # Convert radius and theta to x and y.
            x = (radii * torch.cos(theta) + image_width // 2).long()
            y = (radii * torch.sin(theta) + image_height // 2).long()
            sampled_indices = torch.stack(
                [torch.randint(0, num_images, size=(samples_needed,), device=device), y, x], dim=-1
            )
            in_bounds = (
                (sampled_indices[:, 1] >= 0)
                & (sampled_indices[:, 1] < image_height)
                & (sampled_indices[:, 2] >= 0)
                & (sampled_indices[:, 2] < image_width)
            )

            # Update indices.
            if valid is None:
                indices = sampled_indices
                valid = in_bounds
            else:
                assert indices is not None
                not_valid = ~valid
                indices[not_valid, :] = sampled_indices
                valid[not_valid] = in_bounds

        assert indices is not None
        assert indices.shape == (batch_size, 3)
        return indices

    def _dispatch_sample_method(
        self,
        batch_size: int,
        num_images: int,
        image_height: int,
        image_width: int,
        mask: Optional[Tensor],
        device: Union[torch.device, str],
    ) -> Tensor:
        """Call the sampling method selected by the config (equirectangular, fisheye, or default)."""
        if self.config.is_equirectangular:
            sampler = self.sample_method_equirectangular
        elif self.config.fisheye_crop_radius is not None:
            sampler = self.sample_method_fisheye
        else:
            sampler = self.sample_method
        return sampler(batch_size, num_images, image_height, image_width, mask=mask, device=device)

    def collate_image_dataset_batch(self, batch: Dict, num_rays_per_batch: int, keep_full_image: bool = False):
        """
        Operates on a batch of images and samples pixels to use for generating rays.
        Returns a collated batch which is input to the Graph.
        It will sample only within the valid 'mask' if it's specified.

        Args:
            batch: batch of images to sample from
            num_rays_per_batch: number of rays to sample per batch
            keep_full_image: whether or not to include a reference to the full image in returned batch
        """
        device = batch["image"].device
        num_images, image_height, image_width, _ = batch["image"].shape

        # A missing "mask" entry is equivalent to passing mask=None to the sampling methods.
        indices = self._dispatch_sample_method(
            num_rays_per_batch, num_images, image_height, image_width, mask=batch.get("mask"), device=device
        )

        c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
        c, y, x = c.cpu(), y.cpu(), x.cpu()
        collated_batch = {
            key: value[c, y, x] for key, value in batch.items() if key != "image_idx" and value is not None
        }

        assert collated_batch["image"].shape[0] == num_rays_per_batch

        # Needed to correct the random indices to their actual camera idx locations.
        indices[:, 0] = batch["image_idx"][c]
        collated_batch["indices"] = indices  # with the abs camera indices

        if keep_full_image:
            collated_batch["full_image"] = batch["image"]

        return collated_batch

    def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int, keep_full_image: bool = False):
        """
        Does the same as collate_image_dataset_batch, except it will operate over a list of images / masks inside
        a list.

        We will use this with the intent of DEPRECATING it as soon as we find a viable alternative.
        The intention will be to replace this with a more efficient implementation that doesn't require a for loop, but
        since pytorch's ragged tensors are still in beta (this would allow for some vectorization), this will do.

        Args:
            batch: batch of images to sample from
            num_rays_per_batch: number of rays to sample per batch
            keep_full_image: whether or not to include a reference to the full image in returned batch
        """
        device = batch["image"][0].device
        num_images = len(batch["image"])

        all_indices = []
        all_images = []
        all_depth_images = []

        assert num_rays_per_batch % 2 == 0, "num_rays_per_batch must be divisible by 2"
        num_rays_per_image = divide_rays_per_image(num_rays_per_batch, num_images)

        has_mask = "mask" in batch
        has_depth = "depth_image" in batch

        for i, num_rays in enumerate(num_rays_per_image):
            image_height, image_width, _ = batch["image"][i].shape

            if has_mask:
                # only sample within the mask, if the mask is in the batch
                indices = self.sample_method(
                    num_rays, 1, image_height, image_width, mask=batch["mask"][i].unsqueeze(0), device=device
                )
            elif self.config.is_equirectangular:
                indices = self.sample_method_equirectangular(num_rays, 1, image_height, image_width, device=device)
            else:
                indices = self.sample_method(num_rays, 1, image_height, image_width, device=device)

            # Each image was sampled on its own (num_images=1), so record its position in the list.
            indices[:, 0] = i
            all_indices.append(indices)
            all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
            if has_depth:
                all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]])

        indices = torch.cat(all_indices, dim=0)

        c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
        # These entries are either handled separately or are per-image lists that cannot be indexed jointly.
        skipped_keys = ("image_idx", "image", "mask", "depth_image")
        collated_batch = {
            key: value[c, y, x] for key, value in batch.items() if key not in skipped_keys and value is not None
        }

        collated_batch["image"] = torch.cat(all_images, dim=0)
        if has_depth:
            collated_batch["depth_image"] = torch.cat(all_depth_images, dim=0)

        assert collated_batch["image"].shape[0] == num_rays_per_batch

        # Needed to correct the random indices to their actual camera idx locations.
        indices[:, 0] = batch["image_idx"][c]
        collated_batch["indices"] = indices  # with the abs camera indices

        if keep_full_image:
            collated_batch["full_image"] = batch["image"]

        return collated_batch

    def sample(self, image_batch: Dict):
        """Sample an image batch and return a pixel batch.

        Args:
            image_batch: batch of images to sample from

        Raises:
            ValueError: if ``image_batch["image"]`` is neither a list nor a torch.Tensor.
        """
        if isinstance(image_batch["image"], list):
            image_batch = dict(image_batch.items())  # copy the dictionary so we don't modify the original
            pixel_batch = self.collate_image_dataset_batch_list(
                image_batch, self.num_rays_per_batch, keep_full_image=self.config.keep_full_image
            )
        elif isinstance(image_batch["image"], torch.Tensor):
            pixel_batch = self.collate_image_dataset_batch(
                image_batch, self.num_rays_per_batch, keep_full_image=self.config.keep_full_image
            )
        else:
            raise ValueError("image_batch['image'] must be a list or torch.Tensor")
        return pixel_batch
@dataclass
class PatchPixelSamplerConfig(PixelSamplerConfig):
    """Configuration for instantiating a PatchPixelSampler."""

    _target: Type = field(default_factory=lambda: PatchPixelSampler)
    """Class that this config instantiates."""
    patch_size: int = 32
    """Edge length, in pixels, of each square patch. The method config must use the same value,
    otherwise the sampled rays cannot be reshaped into patches correctly."""
class PatchPixelSampler(PixelSampler):
    """Samples 'pixel_batch's from 'image_batch's. Samples square patches
    from the images randomly. Useful for patch-based losses.

    Args:
        config: the PatchPixelSamplerConfig used to instantiate class
        **kwargs: optional overrides; ``patch_size`` is handled here, the rest by the base class.
    """

    config: PatchPixelSamplerConfig

    def __init__(self, config: PatchPixelSamplerConfig, **kwargs) -> None:
        # Apply the patch_size override before the base initializer runs: the base __init__ calls
        # set_num_rays_per_batch, which rounds the ray count using config.patch_size.
        config.patch_size = kwargs.get("patch_size", config.patch_size)
        super().__init__(config, **kwargs)

    def set_num_rays_per_batch(self, num_rays_per_batch: int):
        """Set the number of rays to sample per batch. Overridden to deal with patch-based sampling.

        The stored value is rounded down to a whole number of patches (a multiple of patch_size**2).

        Args:
            num_rays_per_batch: number of rays to sample per batch
        """
        rays_per_patch = self.config.patch_size**2
        self.num_rays_per_batch = (num_rays_per_batch // rays_per_patch) * rays_per_patch

    def _expand_to_patches(self, anchors: Tensor, offset: int, device: Union[torch.device, str]) -> Tensor:
        """Expand one integer (image, y, x) anchor per patch into all pixels of that patch.

        Args:
            anchors: long tensor with one (image, y, x) row per patch
            offset: value added to the 0..patch_size-1 pixel offsets along both y and x
            device: device on which the offset grid is created

        Returns:
            Long tensor with patch_size**2 consecutive rows per anchor.
        """
        patch_size = self.config.patch_size
        num_patches = anchors.shape[0]
        patches = anchors.view(num_patches, 1, 1, 3).broadcast_to(num_patches, patch_size, patch_size, 3).clone()

        steps = torch.arange(patch_size, device=device) + offset
        # "ij" is the layout meshgrid uses by default; passing it explicitly avoids the deprecation warning.
        yys, xxs = torch.meshgrid(steps, steps, indexing="ij")
        patches[:, ..., 1] += yys
        patches[:, ..., 2] += xxs
        return patches.flatten(0, 2)

    # overrides base method
    def sample_method(
        self,
        batch_size: int,
        num_images: int,
        image_height: int,
        image_width: int,
        mask: Optional[Tensor] = None,
        device: Union[torch.device, str] = "cpu",
    ) -> Int[Tensor, "batch_size 3"]:
        """Sample ``batch_size // patch_size**2`` square patches and return all of their pixels.

        With a usable mask, patch centers are drawn from the mask after eroding it by half the
        patch size. Without one, the top-left corner of each patch is drawn uniformly.

        Args:
            batch_size: number of samples in a batch
            num_images: number of images to sample over
            image_height: height of the images in pixels
            image_width: width of the images in pixels
            mask: mask of possible pixels in an image to sample from.
            device: device on which the indices are created

        Returns:
            Integer tensor of (image index, y, x) rows, patch_size**2 consecutive rows per patch.

        Raises:
            ValueError: if patches are requested but the eroded mask has no valid pixel.
        """
        patch_size = self.config.patch_size
        sub_bs = batch_size // (patch_size**2)

        if isinstance(mask, Tensor) and not self.config.ignore_mask:
            half_patch_size = int(patch_size / 2)
            m = erode_mask(mask.permute(0, 3, 1, 2).float(), pixel_radius=half_patch_size)
            nonzero_indices = torch.nonzero(m[:, 0], as_tuple=False).to(device)
            num_candidates = len(nonzero_indices)
            if sub_bs > num_candidates:
                if num_candidates == 0:
                    raise ValueError("Cannot sample patches: the eroded mask contains no valid (nonzero) pixels.")
                # Fewer valid centers than patches requested; allow repeated centers.
                chosen_indices = random.choices(range(num_candidates), k=sub_bs)
            else:
                chosen_indices = random.sample(range(num_candidates), k=sub_bs)
            # Anchors are patch centers, so pixel offsets start at -half_patch_size.
            return self._expand_to_patches(nonzero_indices[chosen_indices], -half_patch_size, device)

        anchors = torch.rand((sub_bs, 3), device=device) * torch.tensor(
            [num_images, image_height - patch_size, image_width - patch_size],
            device=device,
        )
        # Floor before adding the integer offsets so every patch stays exactly contiguous.
        anchors = torch.floor(anchors).long()
        # Anchors are top-left corners, so pixel offsets start at 0.
        return self._expand_to_patches(anchors, 0, device)
@dataclass
class PairPixelSamplerConfig(PixelSamplerConfig):
    """Configuration for instantiating a PairPixelSampler."""

    _target: Type = field(default_factory=lambda: PairPixelSampler)
    """Class that this config instantiates."""
    radius: int = 2
    """Max distance between the two pixels of a pair."""
class PairPixelSampler(PixelSampler):  # pylint: disable=too-few-public-methods
    """Samples pair of pixels from 'image_batch's. Samples pairs of pixels from
    the images randomly within a 'radius' distance apart. Useful for pair-based losses.

    Args:
        config: the PairPixelSamplerConfig used to instantiate class
        **kwargs: optional overrides forwarded to the base class.
    """

    config: PairPixelSamplerConfig

    def __init__(self, config: PairPixelSamplerConfig, **kwargs) -> None:
        self.config = config
        self.radius = self.config.radius
        super().__init__(self.config, **kwargs)
        self.rays_to_sample = self.config.num_rays_per_batch // 2

    def set_num_rays_per_batch(self, num_rays_per_batch: int):
        """Set the number of rays to sample per batch.

        Also refreshes ``rays_to_sample``, the default pair count used by ``sample_method``
        when no batch size is given, so it does not go stale.

        Args:
            num_rays_per_batch: number of rays to sample per batch
        """
        super().set_num_rays_per_batch(num_rays_per_batch)
        self.rays_to_sample = num_rays_per_batch // 2

    # overrides base method
    def sample_method(  # pylint: disable=no-self-use
        self,
        batch_size: Optional[int],
        num_images: int,
        image_height: int,
        image_width: int,
        mask: Optional[Tensor] = None,
        device: Union[torch.device, str] = "cpu",
    ) -> Int[Tensor, "batch_size 3"]:
        """Sample pixel pairs: a first pixel plus a partner offset by at most ``radius`` per axis.

        First pixels are kept at least ``radius`` away from the image border (or drawn from the
        mask eroded by ``radius``), so partners offset by up to ``radius`` stay in range.

        Args:
            batch_size: number of samples in a batch; must be even. If None, ``rays_to_sample`` pairs are drawn.
            num_images: number of images to sample over
            image_height: height of the images in pixels
            image_width: width of the images in pixels
            mask: mask of possible pixels in an image to sample from.
            device: device on which the indices are created

        Returns:
            Integer tensor of (image index, y, x) rows where rows 2k and 2k+1 form a pair.

        Raises:
            ValueError: if pairs are requested but the eroded mask has no valid pixel.
        """
        rays_to_sample = self.rays_to_sample
        if batch_size is not None:
            assert (
                int(batch_size) % 2 == 0
            ), f"PairPixelSampler can only return batch sizes in multiples of two (got {batch_size})"
            rays_to_sample = batch_size // 2

        if isinstance(mask, Tensor) and not self.config.ignore_mask:
            m = erode_mask(mask.permute(0, 3, 1, 2).float(), pixel_radius=self.radius)
            nonzero_indices = torch.nonzero(m[:, 0], as_tuple=False).to(device)
            num_candidates = len(nonzero_indices)
            if rays_to_sample > num_candidates:
                if num_candidates == 0:
                    raise ValueError("Cannot sample pixel pairs: the eroded mask contains no valid (nonzero) pixels.")
                # Fewer valid pixels than pairs requested; allow repeated first pixels.
                chosen_indices = random.choices(range(num_candidates), k=rays_to_sample)
            else:
                chosen_indices = random.sample(range(num_candidates), k=rays_to_sample)
            indices = nonzero_indices[chosen_indices]
        else:
            s = (rays_to_sample, 1)
            ns = torch.randint(0, num_images, s, dtype=torch.long, device=device)
            hs = torch.randint(self.radius, image_height - self.radius, s, dtype=torch.long, device=device)
            ws = torch.randint(self.radius, image_width - self.radius, s, dtype=torch.long, device=device)
            indices = torch.concat((ns, hs, ws), dim=1)

        pair_indices = torch.hstack(
            (
                # The partner pixel stays in the same image, so its image offset is zero.
                torch.zeros(rays_to_sample, 1, device=device, dtype=torch.long),
                # randint's upper bound is exclusive; radius + 1 makes the offsets span [-radius, radius].
                torch.randint(-self.radius, self.radius + 1, (rays_to_sample, 2), device=device, dtype=torch.long),
            )
        )
        pair_indices += indices
        # Interleave so that each first pixel is immediately followed by its partner.
        indices = torch.hstack((indices, pair_indices)).view(rays_to_sample * 2, 3)
        return indices