Source code for nerfstudio.data.datamanagers.base_datamanager

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Datamanager.
"""

from __future__ import annotations

from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    ForwardRef,
    Generic,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
    get_args,
    get_origin,
)

import torch
import tyro
from torch import nn
from torch.nn import Parameter
from torch.utils.data.distributed import DistributedSampler
from typing_extensions import TypeVar

from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig
from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.ray_generators import RayGenerator
from nerfstudio.utils.misc import IterableWrapper, get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE


def variable_res_collate(batch: List[Dict]) -> Dict:
    """Default collate function for the cached dataloader.

    Args:
        batch: Batch of samples from the dataset.

    Returns:
        Collated batch.
    """
    images = []
    imgdata_lists = defaultdict(list)
    for data in batch:
        image = data.pop("image")
        images.append(image)
        topop = []
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                # if the value has the same height and width as the image, assume that it should be collated accordingly.
                if len(val.shape) >= 2 and val.shape[:2] == image.shape[:2]:
                    imgdata_lists[key].append(val)
                    topop.append(key)
        # now that iteration is complete, the image data items can be removed from the batch
        for key in topop:
            del data[key]

    new_batch = nerfstudio_collate(batch)
    new_batch["image"] = images
    new_batch.update(imgdata_lists)

    return new_batch

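# --- Illustrative sketch, not part of the upstream module: what variable_res_collate
# does with samples of different resolutions. The shapes and keys below are arbitrary
# examples chosen for this sketch.
def example_variable_res_collate() -> Dict:
    batch = [
        {"image": torch.rand(480, 640, 3), "mask": torch.ones(480, 640, 1), "idx": torch.tensor(0)},
        {"image": torch.rand(720, 1280, 3), "mask": torch.ones(720, 1280, 1), "idx": torch.tensor(1)},
    ]
    collated = variable_res_collate(batch)
    # Per-pixel data ("image", "mask") is kept as lists of differently sized tensors,
    # while the remaining entries ("idx") are stacked by nerfstudio_collate.
    assert isinstance(collated["image"], list) and isinstance(collated["mask"], list)
    assert collated["idx"].shape == (2,)
    return collated
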
@dataclass
class DataManagerConfig(InstantiateConfig):
    """Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers;
    After instantiation, data manager holds both train/eval datasets and is in charge of returning unpacked
    train/eval data at each iteration
    """

    _target: Type = field(default_factory=lambda: DataManager)
    """Target class to instantiate."""
    data: Optional[Path] = None
    """Source of data, may not be used by all models."""
    masks_on_gpu: bool = False
    """Process masks on GPU for speed at the expense of memory, if True."""
    images_on_gpu: bool = False
    """Process images on GPU for speed at the expense of memory, if True."""

class DataManager(nn.Module):
    """Generic data manager's abstract class

    This version of the data manager is designed to be a monolithic way to load data and latents,
    especially since this may contain learnable parameters which need to be shared across the train
    and test data managers. The idea is that we have setup methods for train and eval separately, so
    this can serve as a combined train/eval manager if you want.

    Usage:
    To get data, use the next_train and next_eval functions.
    This data manager's next_train and next_eval methods will return 2 things:

    1. 'rays': This will contain the rays or camera we are sampling, with latents and
       conditionals attached (everything needed at inference)
    2. A "batch" of auxiliary information: This will contain the mask, the ground truth
       pixels, etc needed to actually train, score, etc the model

    Rationale:
    Because of this abstraction we've added, we can support more NeRF paradigms beyond the
    vanilla nerf paradigm of single-scene, fixed-images, no-learnt-latents.
    We can now support variable scenes, variable numbers of images, and arbitrary latents.

    Train Methods:
        setup_train: sets up for being used as train
        iter_train: will be called on __iter__() for the train iterator
        next_train: will be called on __next__() for the training iterator
        get_train_iterable: utility that gets a clean pythonic iterator for your training data

    Eval Methods:
        setup_eval: sets up for being used as eval
        iter_eval: will be called on __iter__() for the eval iterator
        next_eval: will be called on __next__() for the eval iterator
        get_eval_iterable: utility that gets a clean pythonic iterator for your eval data

    Attributes:
        train_count (int): the step number of our train iteration, needs to be incremented manually
        eval_count (int): the step number of our eval iteration, needs to be incremented manually
        train_dataset (Dataset): the dataset used for training
        eval_dataset (Dataset): the dataset used for evaluation
        includes_time (bool): whether the dataset includes time information

        Additional attributes specific to each subclass are defined in the setup_train and setup_eval
        functions.
    """

    train_dataset: Optional[InputDataset] = None
    eval_dataset: Optional[InputDataset] = None
    train_sampler: Optional[DistributedSampler] = None
    eval_sampler: Optional[DistributedSampler] = None
    includes_time: bool = False

    def __init__(self):
        """Constructor for the DataManager class.

        Subclassed DataManagers will likely need to override this constructor.

        If you aren't manually calling the setup_train and setup_eval functions from an overridden
        constructor, make sure that you call super().__init__() BEFORE you initialize any nn.Modules
        or nn.Parameters, but AFTER you've already set all the attributes you need for the setup
        functions."""
        super().__init__()
        self.train_count = 0
        self.eval_count = 0
        if self.train_dataset and self.test_mode != "inference":
            self.setup_train()
        if self.eval_dataset and self.test_mode != "inference":
            self.setup_eval()

    def forward(self):
        """Blank forward method

        This is an nn.Module, and so requires a forward() method normally, although in our case
        we do not need a forward() method"""
        raise NotImplementedError

    def iter_train(self):
        """The __iter__ function for the train iterator.

        This only exists to assist the get_train_iterable function, since we need to pass
        in an __iter__ function for our trivial iterable that we are making."""
        self.train_count = 0

    def iter_eval(self):
        """The __iter__ function for the eval iterator.

        This only exists to assist the get_eval_iterable function, since we need to pass
        in an __iter__ function for our trivial iterable that we are making."""
        self.eval_count = 0

    def get_train_iterable(self, length=-1) -> IterableWrapper:
        """Gets a trivial pythonic iterator that will use the iter_train and next_train functions
        as __iter__ and __next__ methods respectively.

        This basically is just a little utility if you want to do something like:
        | for ray_bundle, batch in datamanager.get_train_iterable():
        |     <train code here>
        since the returned IterableWrapper is just an iterator with the __iter__ and __next__
        methods (methods bound to our DataManager instance in this case) specified in the constructor.
        """
        return IterableWrapper(self.iter_train, self.next_train, length)

    def get_eval_iterable(self, length=-1) -> IterableWrapper:
        """Gets a trivial pythonic iterator that will use the iter_eval and next_eval functions
        as __iter__ and __next__ methods respectively.

        This basically is just a little utility if you want to do something like:
        | for ray_bundle, batch in datamanager.get_eval_iterable():
        |     <eval code here>
        since the returned IterableWrapper is just an iterator with the __iter__ and __next__
        methods (methods bound to our DataManager instance in this case) specified in the constructor.
        """
        return IterableWrapper(self.iter_eval, self.next_eval, length)

    @abstractmethod
    def setup_train(self):
        """Sets up the data manager for training.

        Here you will define any subclass-specific object attributes needed for training."""

    @abstractmethod
    def setup_eval(self):
        """Sets up the data manager for evaluation"""

    @abstractmethod
    def next_train(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
        """Returns the next batch of data from the train data manager.

        Args:
            step: the current training step

        Returns:
            A tuple of the ray bundle for the sampled batch, and a dictionary of additional batch
            information such as the groundtruth image.
        """
        raise NotImplementedError

    @abstractmethod
    def next_eval(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
        """Returns the next batch of data from the eval data manager.

        Args:
            step: the current evaluation step

        Returns:
            A tuple of the ray/camera for the image, and a dictionary of additional batch
            information such as the groundtruth image.
        """
        raise NotImplementedError

    @abstractmethod
    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        """Retrieve the next eval image.

        Args:
            step: the step number of the eval image to retrieve

        Returns:
            A tuple of the camera for the image, and a dictionary of additional batch
            information such as the groundtruth image.
        """
        raise NotImplementedError

    @abstractmethod
    def get_train_rays_per_batch(self) -> int:
        """Returns the number of rays per batch for training."""
        raise NotImplementedError

    @abstractmethod
    def get_eval_rays_per_batch(self) -> int:
        """Returns the number of rays per batch for evaluation."""
        raise NotImplementedError

    @abstractmethod
    def get_datapath(self) -> Path:
        """Returns the path to the data. This is used to determine where to save camera paths."""

    def get_training_callbacks(
        self, training_callback_attributes: TrainingCallbackAttributes
    ) -> List[TrainingCallback]:
        """Returns a list of callbacks to be used during training."""
        return []

    @abstractmethod
    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Get the param groups for the data manager.

        Returns:
            A dictionary mapping group names to lists of the data manager's parameters.
        """
        return {}

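# --- Illustrative sketch, not part of the upstream module: a minimal DataManager subclass
# showing the contract above. Everything here (the class name, the fixed ray counts, the
# placeholder path) is hypothetical. The key detail is that DataManager.__init__ reads
# self.test_mode (and the optional datasets), so those attributes must be set before
# super().__init__() is called, as the constructor docstring notes.
class ExampleDataManager(DataManager):
    def __init__(self, device: Union[torch.device, str] = "cpu"):
        self.device = device
        self.test_mode = "val"  # required before super().__init__()
        super().__init__()

    def setup_train(self):
        """No-op: a real subclass would build train dataloaders and samplers here."""

    def setup_eval(self):
        """No-op: a real subclass would build eval dataloaders here."""

    def next_train(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
        raise NotImplementedError

    def next_eval(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
        raise NotImplementedError

    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        raise NotImplementedError

    def get_train_rays_per_batch(self) -> int:
        return 1024

    def get_eval_rays_per_batch(self) -> int:
        return 1024

    def get_datapath(self) -> Path:
        return Path("data")  # hypothetical

    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        return {}
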
@dataclass
class VanillaDataManagerConfig(DataManagerConfig):
    """A basic data manager for a ray-based model"""

    _target: Type = field(default_factory=lambda: VanillaDataManager)
    """Target class to instantiate."""
    dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)
    """Specifies the dataparser used to unpack the data."""
    train_num_rays_per_batch: int = 1024
    """Number of rays per batch to use per training iteration."""
    train_num_images_to_sample_from: int = -1
    """Number of images to sample during training iteration."""
    train_num_times_to_repeat_images: int = -1
    """When not training on all images, number of iterations before picking new
    images. If -1, never pick new images."""
    eval_num_rays_per_batch: int = 1024
    """Number of rays per batch to use per eval iteration."""
    eval_num_images_to_sample_from: int = -1
    """Number of images to sample during eval iteration."""
    eval_num_times_to_repeat_images: int = -1
    """When not evaluating on all images, number of iterations before picking
    new images. If -1, never pick new images."""
    eval_image_indices: Optional[Tuple[int, ...]] = (0,)
    """Specifies the image indices to use during eval; if None, uses all."""
    collate_fn: Callable[[Any], Any] = cast(Any, staticmethod(nerfstudio_collate))
    """Specifies the collate function to use for the train and eval dataloaders."""
    camera_res_scale_factor: float = 1.0
    """The scale factor for scaling spatial data such as images, mask, semantics
    along with relevant information about camera intrinsics
    """
    patch_size: int = 1
    """Size of patch to sample from. If > 1, patch-based sampling will be used."""

    # tyro.conf.Suppress prevents us from creating CLI arguments for this field.
    camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None)
    """Deprecated, has been moved to the model config."""
    pixel_sampler: PixelSamplerConfig = field(default_factory=PixelSamplerConfig)
    """Specifies the pixel sampler used to sample pixels from images."""

    def __post_init__(self):
        """Warn user of camera optimizer change."""
        if self.camera_optimizer is not None:
            import warnings

            CONSOLE.print(
                "\nCameraOptimizerConfig has been moved from the DataManager to the Model.\n", style="bold yellow"
            )
            warnings.warn("above message coming from", FutureWarning, stacklevel=3)

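# --- Illustrative sketch, not part of the upstream module: overriding a few of the defaults
# above. The ray-count values are arbitrary choices for this sketch.
def example_vanilla_config() -> VanillaDataManagerConfig:
    """Build a config that draws 4096 train rays per batch from a Blender-style dataset."""
    return VanillaDataManagerConfig(
        dataparser=BlenderDataParserConfig(),
        train_num_rays_per_batch=4096,
        eval_num_rays_per_batch=1024,
        patch_size=1,  # a value > 1 makes _get_pixel_sampler switch to a patch-based sampler
    )
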
TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset)
class VanillaDataManager(DataManager, Generic[TDataset]):
    """Basic stored data manager implementation.

    This is pretty much a port over from our old dataloading utilities, and is a little jank
    under the hood. We may clean this up a little bit under the hood with more standard dataloading
    components that can be strung together, but it can be just used as a black box for now since
    only the constructor is likely to change in the future, or maybe passing in step number to the
    next_train and next_eval functions.

    Args:
        config: the DataManagerConfig used to instantiate class
    """

    config: VanillaDataManagerConfig
    train_dataset: TDataset
    eval_dataset: TDataset
    train_dataparser_outputs: DataparserOutputs
    train_pixel_sampler: Optional[PixelSampler] = None
    eval_pixel_sampler: Optional[PixelSampler] = None

    def __init__(
        self,
        config: VanillaDataManagerConfig,
        device: Union[torch.device, str] = "cpu",
        test_mode: Literal["test", "val", "inference"] = "val",
        world_size: int = 1,
        local_rank: int = 0,
        **kwargs,
    ):
        self.config = config
        self.device = device
        self.world_size = world_size
        self.local_rank = local_rank
        self.sampler = None
        self.test_mode = test_mode
        self.test_split = "test" if test_mode in ["test", "inference"] else "val"
        self.dataparser_config = self.config.dataparser
        if self.config.data is not None:
            self.config.dataparser.data = Path(self.config.data)
        else:
            self.config.data = self.config.dataparser.data
        self.dataparser = self.dataparser_config.setup()
        if test_mode == "inference":
            self.dataparser.downscale_factor = 1  # Avoid opening images
        self.includes_time = self.dataparser.includes_time
        self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train")
        self.train_dataset = self.create_train_dataset()
        self.eval_dataset = self.create_eval_dataset()
        self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device
        if self.config.masks_on_gpu is True and "mask" in self.exclude_batch_keys_from_device:
            self.exclude_batch_keys_from_device.remove("mask")
        if self.config.images_on_gpu is True and "image" in self.exclude_batch_keys_from_device:
            self.exclude_batch_keys_from_device.remove("image")

        if self.train_dataparser_outputs is not None:
            cameras = self.train_dataparser_outputs.cameras
            if len(cameras) > 1:
                for i in range(1, len(cameras)):
                    if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height:
                        CONSOLE.print("Variable resolution, using variable_res_collate")
                        self.config.collate_fn = variable_res_collate
                        break
        super().__init__()

    @cached_property
    def dataset_type(self) -> Type[TDataset]:
        """Returns the dataset type passed as the generic argument"""
        default: Type[TDataset] = cast(TDataset, TDataset.__default__)  # type: ignore
        orig_class: Type[VanillaDataManager] = get_orig_class(self, default=None)  # type: ignore
        if type(self) is VanillaDataManager and orig_class is None:
            return default
        if orig_class is not None and get_origin(orig_class) is VanillaDataManager:
            return get_args(orig_class)[0]

        # For inherited classes, we need to find the correct type to instantiate
        for base in getattr(self, "__orig_bases__", []):
            if get_origin(base) is VanillaDataManager:
                for value in get_args(base):
                    if isinstance(value, ForwardRef):
                        if value.__forward_evaluated__:
                            value = value.__forward_value__
                        elif value.__forward_module__ is None:
                            value.__forward_module__ = type(self).__module__
                            value = getattr(value, "_evaluate")(None, None, set())
                    assert isinstance(value, type)
                    if issubclass(value, InputDataset):
                        return cast(Type[TDataset], value)

        return default

    def create_train_dataset(self) -> TDataset:
        """Sets up the data loaders for training"""
        return self.dataset_type(
            dataparser_outputs=self.train_dataparser_outputs,
            scale_factor=self.config.camera_res_scale_factor,
        )

    def create_eval_dataset(self) -> TDataset:
        """Sets up the data loaders for evaluation"""
        return self.dataset_type(
            dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split),
            scale_factor=self.config.camera_res_scale_factor,
        )

    def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler:
        """Infer pixel sampler to use."""
        if self.config.patch_size > 1 and type(self.config.pixel_sampler) is PixelSamplerConfig:
            return PatchPixelSamplerConfig().setup(
                patch_size=self.config.patch_size, num_rays_per_batch=num_rays_per_batch
            )
        is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all()
        if is_equirectangular.any():
            CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.")

        fisheye_crop_radius = None
        if dataset.cameras.metadata is not None:
            fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius")

        return self.config.pixel_sampler.setup(
            is_equirectangular=is_equirectangular,
            num_rays_per_batch=num_rays_per_batch,
            fisheye_crop_radius=fisheye_crop_radius,
        )

    def setup_train(self):
        """Sets up the data loaders for training"""
        assert self.train_dataset is not None
        CONSOLE.print("Setting up training dataset...")
        self.train_image_dataloader = CacheDataloader(
            self.train_dataset,
            num_images_to_sample_from=self.config.train_num_images_to_sample_from,
            num_times_to_repeat_images=self.config.train_num_times_to_repeat_images,
            device=self.device,
            num_workers=self.world_size * 4,
            pin_memory=True,
            collate_fn=self.config.collate_fn,
            exclude_batch_keys_from_device=self.exclude_batch_keys_from_device,
        )
        self.iter_train_image_dataloader = iter(self.train_image_dataloader)
        self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch)
        self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device))

    def setup_eval(self):
        """Sets up the data loader for evaluation"""
        assert self.eval_dataset is not None
        CONSOLE.print("Setting up evaluation dataset...")
        self.eval_image_dataloader = CacheDataloader(
            self.eval_dataset,
            num_images_to_sample_from=self.config.eval_num_images_to_sample_from,
            num_times_to_repeat_images=self.config.eval_num_times_to_repeat_images,
            device=self.device,
            num_workers=self.world_size * 4,
            pin_memory=True,
            collate_fn=self.config.collate_fn,
            exclude_batch_keys_from_device=self.exclude_batch_keys_from_device,
        )
        self.iter_eval_image_dataloader = iter(self.eval_image_dataloader)
        self.eval_pixel_sampler = self._get_pixel_sampler(self.eval_dataset, self.config.eval_num_rays_per_batch)
        self.eval_ray_generator = RayGenerator(self.eval_dataset.cameras.to(self.device))
        # for loading full images
        self.fixed_indices_eval_dataloader = FixedIndicesEvalDataloader(
            input_dataset=self.eval_dataset,
            device=self.device,
            num_workers=self.world_size * 4,
        )
        self.eval_dataloader = RandIndicesEvalDataloader(
            input_dataset=self.eval_dataset,
            device=self.device,
            num_workers=self.world_size * 4,
        )

    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the train dataloader."""
        self.train_count += 1
        image_batch = next(self.iter_train_image_dataloader)
        assert self.train_pixel_sampler is not None
        assert isinstance(image_batch, dict)
        batch = self.train_pixel_sampler.sample(image_batch)
        ray_indices = batch["indices"]
        ray_bundle = self.train_ray_generator(ray_indices)
        return ray_bundle, batch

    def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the eval dataloader."""
        self.eval_count += 1
        image_batch = next(self.iter_eval_image_dataloader)
        assert self.eval_pixel_sampler is not None
        assert isinstance(image_batch, dict)
        batch = self.eval_pixel_sampler.sample(image_batch)
        ray_indices = batch["indices"]
        ray_bundle = self.eval_ray_generator(ray_indices)
        return ray_bundle, batch

    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        for camera, batch in self.eval_dataloader:
            assert camera.shape[0] == 1
            return camera, batch
        raise ValueError("No more eval images")

    def get_train_rays_per_batch(self) -> int:
        if self.train_pixel_sampler is not None:
            return self.train_pixel_sampler.num_rays_per_batch
        return self.config.train_num_rays_per_batch

    def get_eval_rays_per_batch(self) -> int:
        if self.eval_pixel_sampler is not None:
            return self.eval_pixel_sampler.num_rays_per_batch
        return self.config.eval_num_rays_per_batch

    def get_datapath(self) -> Path:
        return self.config.dataparser.data

    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Get the param groups for the data manager.

        Returns:
            A dictionary mapping group names to lists of the data manager's parameters.
        """
        return {}

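# --- Illustrative sketch, not part of the upstream module: driving a VanillaDataManager by
# hand. The data path and step count are hypothetical placeholders, and in a real run the
# Trainer owns this loop. Parameterizing the generic (e.g. VanillaDataManager[SomeInputDatasetSubclass])
# is how dataset_type above resolves to a custom dataset class instead of InputDataset.
def example_train_loop(num_steps: int = 10) -> None:
    config = VanillaDataManagerConfig(
        dataparser=BlenderDataParserConfig(data=Path("data/blender/lego")),  # hypothetical path
        train_num_rays_per_batch=1024,
    )
    datamanager = VanillaDataManager(config, device="cpu", test_mode="val")
    for step in range(num_steps):
        ray_bundle, batch = datamanager.next_train(step)
        # ray_bundle holds the sampled rays; batch holds the matching groundtruth pixels
        # (batch["image"]) plus any mask or per-pixel metadata kept by the collate function.
        assert ray_bundle.origins.shape[0] == config.train_num_rays_per_batch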