# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Datamanager.
"""
from __future__ import annotations
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
ForwardRef,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
get_origin,
)
import torch
import tyro
from torch import nn
from torch.nn import Parameter
from torch.utils.data.distributed import DistributedSampler
from typing_extensions import TypeVar
from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig
from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.ray_generators import RayGenerator
from nerfstudio.utils.misc import IterableWrapper, get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE
def variable_res_collate(batch: List[Dict]) -> Dict:
"""Default collate function for the cached dataloader.
Args:
batch: Batch of samples from the dataset.
Returns:
Collated batch.
"""
images = []
imgdata_lists = defaultdict(list)
for data in batch:
image = data.pop("image")
images.append(image)
topop = []
for key, val in data.items():
if isinstance(val, torch.Tensor):
# if the value has same height and width as the image, assume that it should be collated accordingly.
if len(val.shape) >= 2 and val.shape[:2] == image.shape[:2]:
imgdata_lists[key].append(val)
topop.append(key)
# now that iteration is complete, the image data items can be removed from the batch
for key in topop:
del data[key]
new_batch = nerfstudio_collate(batch)
new_batch["image"] = images
new_batch.update(imgdata_lists)
return new_batch
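# A minimal usage sketch (illustrative only, not executed by this module): collate two samples
# whose images differ in resolution. Per-pixel tensors such as "mask" stay grouped as lists
# next to the images, while any remaining keys would be handled by nerfstudio_collate.
#
#   batch = [
#       {"image": torch.rand(480, 640, 3), "mask": torch.ones(480, 640, 1)},
#       {"image": torch.rand(720, 1280, 3), "mask": torch.ones(720, 1280, 1)},
#   ]
#   collated = variable_res_collate(batch)
#   assert isinstance(collated["image"], list) and len(collated["image"]) == 2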
@dataclass
class DataManagerConfig(InstantiateConfig):
"""Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers;
After instantiation, data manager holds both train/eval datasets and is in charge of returning unpacked
train/eval data at each iteration
"""
_target: Type = field(default_factory=lambda: DataManager)
"""Target class to instantiate."""
data: Optional[Path] = None
"""Source of data, may not be used by all models."""
masks_on_gpu: bool = False
"""Process masks on GPU for speed at the expense of memory, if True."""
images_on_gpu: bool = False
"""Process images on GPU for speed at the expense of memory, if True."""
class DataManager(nn.Module):
    """Generic data manager's abstract class
    This version of the data manager is designed to be a monolithic way to load data and latents,
especially since this may contain learnable parameters which need to be shared across the train
and test data managers. The idea is that we have setup methods for train and eval separately and
this can be a combined train/eval if you want.
Usage:
To get data, use the next_train and next_eval functions.
This data manager's next_train and next_eval methods will return 2 things:
1. 'rays': This will contain the rays or camera we are sampling, with latents and
conditionals attached (everything needed at inference)
2. A "batch" of auxiliary information: This will contain the mask, the ground truth
pixels, etc needed to actually train, score, etc the model
Rationale:
Because of this abstraction we've added, we can support more NeRF paradigms beyond the
vanilla nerf paradigm of single-scene, fixed-images, no-learnt-latents.
We can now support variable scenes, variable number of images, and arbitrary latents.
Train Methods:
setup_train: sets up for being used as train
iter_train: will be called on __iter__() for the train iterator
next_train: will be called on __next__() for the training iterator
get_train_iterable: utility that gets a clean pythonic iterator for your training data
Eval Methods:
setup_eval: sets up for being used as eval
iter_eval: will be called on __iter__() for the eval iterator
next_eval: will be called on __next__() for the eval iterator
get_eval_iterable: utility that gets a clean pythonic iterator for your eval data
Attributes:
train_count (int): the step number of our train iteration, needs to be incremented manually
eval_count (int): the step number of our eval iteration, needs to be incremented manually
        train_dataset (Dataset): the training dataset
        eval_dataset (Dataset): the evaluation dataset
includes_time (bool): whether the dataset includes time information
Additional attributes specific to each subclass are defined in the setup_train and setup_eval
functions.
"""
train_dataset: Optional[InputDataset] = None
eval_dataset: Optional[InputDataset] = None
train_sampler: Optional[DistributedSampler] = None
eval_sampler: Optional[DistributedSampler] = None
includes_time: bool = False
def __init__(self):
"""Constructor for the DataManager class.
Subclassed DataManagers will likely need to override this constructor.
        If you aren't manually calling the setup_train and setup_eval functions from an overridden
        constructor, make sure that you call super().__init__() BEFORE you initialize any
        nn.Modules or nn.Parameters, but AFTER you've already set all the attributes you need
        for the setup functions."""
super().__init__()
self.train_count = 0
self.eval_count = 0
if self.train_dataset and self.test_mode != "inference":
self.setup_train()
if self.eval_dataset and self.test_mode != "inference":
self.setup_eval()
    def forward(self):
"""Blank forward method
This is an nn.Module, and so requires a forward() method normally, although in our case
we do not need a forward() method"""
raise NotImplementedError
    def iter_train(self):
"""The __iter__ function for the train iterator.
This only exists to assist the get_train_iterable function, since we need to pass
in an __iter__ function for our trivial iterable that we are making."""
self.train_count = 0
    def iter_eval(self):
"""The __iter__ function for the eval iterator.
This only exists to assist the get_eval_iterable function, since we need to pass
in an __iter__ function for our trivial iterable that we are making."""
self.eval_count = 0
    def get_train_iterable(self, length=-1) -> IterableWrapper:
"""Gets a trivial pythonic iterator that will use the iter_train and next_train functions
as __iter__ and __next__ methods respectively.
This basically is just a little utility if you want to do something like:
| for ray_bundle, batch in datamanager.get_train_iterable():
            |    <train code here>
since the returned IterableWrapper is just an iterator with the __iter__ and __next__
methods (methods bound to our DataManager instance in this case) specified in the constructor.
"""
return IterableWrapper(self.iter_train, self.next_train, length)
    def get_eval_iterable(self, length=-1) -> IterableWrapper:
"""Gets a trivial pythonic iterator that will use the iter_eval and next_eval functions
as __iter__ and __next__ methods respectively.
This basically is just a little utility if you want to do something like:
| for ray_bundle, batch in datamanager.get_eval_iterable():
| <eval code here>
since the returned IterableWrapper is just an iterator with the __iter__ and __next__
methods (methods bound to our DataManager instance in this case) specified in the constructor.
"""
return IterableWrapper(self.iter_eval, self.next_eval, length)
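    # Illustrative sketch (not part of the class): both iterables yield (ray_bundle, batch)
    # pairs, and `length` bounds the number of iterations. Names outside this module
    # ('model', 'compute_psnr') are assumptions.
    #
    #   for ray_bundle, batch in datamanager.get_eval_iterable(length=10):
    #       outputs = model(ray_bundle)
    #       psnr = compute_psnr(outputs["rgb"], batch["image"])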
    @abstractmethod
def setup_train(self):
"""Sets up the data manager for training.
        Here you will define any subclass-specific object attributes needed for training."""
    @abstractmethod
def setup_eval(self):
"""Sets up the data manager for evaluation"""
    @abstractmethod
def next_train(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
"""Returns the next batch of data from the train data manager.
Args:
            step: the current training step
Returns:
A tuple of the ray bundle for the image, and a dictionary of additional batch information
such as the groundtruth image.
"""
raise NotImplementedError
    @abstractmethod
def next_eval(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
"""Returns the next batch of data from the eval data manager.
Args:
            step: the current evaluation step
Returns:
A tuple of the ray/camera for the image, and a dictionary of additional batch information
such as the groundtruth image.
"""
raise NotImplementedError
    @abstractmethod
def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
"""Retrieve the next eval image.
Args:
step: the step number of the eval image to retrieve
Returns:
            A tuple of the camera for the image, and a dictionary of additional batch information
            such as the groundtruth image.
"""
raise NotImplementedError
    @abstractmethod
def get_train_rays_per_batch(self) -> int:
"""Returns the number of rays per batch for training."""
raise NotImplementedError
    @abstractmethod
def get_eval_rays_per_batch(self) -> int:
"""Returns the number of rays per batch for evaluation."""
raise NotImplementedError
    @abstractmethod
def get_datapath(self) -> Path:
"""Returns the path to the data. This is used to determine where to save camera paths."""
    def get_training_callbacks(
self, training_callback_attributes: TrainingCallbackAttributes
) -> List[TrainingCallback]:
"""Returns a list of callbacks to be used during training."""
return []
    @abstractmethod
def get_param_groups(self) -> Dict[str, List[Parameter]]:
"""Get the param groups for the data manager.
Returns:
            A dictionary mapping parameter group names to lists of the data manager's parameters.
"""
return {}
@dataclass
class VanillaDataManagerConfig(DataManagerConfig):
"""A basic data manager for a ray-based model"""
_target: Type = field(default_factory=lambda: VanillaDataManager)
"""Target class to instantiate."""
dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)
"""Specifies the dataparser used to unpack the data."""
train_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per training iteration."""
train_num_images_to_sample_from: int = -1
"""Number of images to sample during training iteration."""
train_num_times_to_repeat_images: int = -1
"""When not training on all images, number of iterations before picking new
images. If -1, never pick new images."""
eval_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per eval iteration."""
eval_num_images_to_sample_from: int = -1
"""Number of images to sample during eval iteration."""
eval_num_times_to_repeat_images: int = -1
"""When not evaluating on all images, number of iterations before picking
new images. If -1, never pick new images."""
eval_image_indices: Optional[Tuple[int, ...]] = (0,)
"""Specifies the image indices to use during eval; if None, uses all."""
collate_fn: Callable[[Any], Any] = cast(Any, staticmethod(nerfstudio_collate))
"""Specifies the collate function to use for the train and eval dataloaders."""
camera_res_scale_factor: float = 1.0
"""The scale factor for scaling spatial data such as images, mask, semantics
along with relevant information about camera intrinsics
"""
patch_size: int = 1
"""Size of patch to sample from. If > 1, patch-based sampling will be used."""
# tyro.conf.Suppress prevents us from creating CLI arguments for this field.
camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None)
"""Deprecated, has been moved to the model config."""
pixel_sampler: PixelSamplerConfig = field(default_factory=PixelSamplerConfig)
"""Specifies the pixel sampler used to sample pixels from images."""
    def __post_init__(self):
"""Warn user of camera optimizer change."""
if self.camera_optimizer is not None:
import warnings
CONSOLE.print(
"\nCameraOptimizerConfig has been moved from the DataManager to the Model.\n", style="bold yellow"
)
warnings.warn("above message coming from", FutureWarning, stacklevel=3)
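# Illustrative configuration sketch (values are placeholders, not recommendations). This assumes
# InstantiateConfig.setup() forwards keyword arguments to the VanillaDataManager constructor:
#
#   config = VanillaDataManagerConfig(
#       data=Path("data/nerfstudio/poster"),  # hypothetical dataset path
#       train_num_rays_per_batch=4096,
#       patch_size=32,  # > 1 switches _get_pixel_sampler to a PatchPixelSamplerConfig
#   )
#   datamanager = config.setup(device="cuda", test_mode="val")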
TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset)
class VanillaDataManager(DataManager, Generic[TDataset]):
"""Basic stored data manager implementation.
    This is pretty much a port of our old dataloading utilities and is a little janky under the
    hood. We may clean it up with more standard dataloading components that can be strung
    together, but for now it can be used as a black box: only the constructor is likely to change
    in the future, or perhaps the step number passed to the next_train and next_eval functions.
Args:
config: the DataManagerConfig used to instantiate class
"""
config: VanillaDataManagerConfig
train_dataset: TDataset
eval_dataset: TDataset
train_dataparser_outputs: DataparserOutputs
train_pixel_sampler: Optional[PixelSampler] = None
eval_pixel_sampler: Optional[PixelSampler] = None
def __init__(
self,
config: VanillaDataManagerConfig,
device: Union[torch.device, str] = "cpu",
test_mode: Literal["test", "val", "inference"] = "val",
world_size: int = 1,
local_rank: int = 0,
**kwargs,
):
self.config = config
self.device = device
self.world_size = world_size
self.local_rank = local_rank
self.sampler = None
self.test_mode = test_mode
self.test_split = "test" if test_mode in ["test", "inference"] else "val"
self.dataparser_config = self.config.dataparser
if self.config.data is not None:
self.config.dataparser.data = Path(self.config.data)
else:
self.config.data = self.config.dataparser.data
self.dataparser = self.dataparser_config.setup()
if test_mode == "inference":
self.dataparser.downscale_factor = 1 # Avoid opening images
self.includes_time = self.dataparser.includes_time
self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train")
self.train_dataset = self.create_train_dataset()
self.eval_dataset = self.create_eval_dataset()
self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device
if self.config.masks_on_gpu is True and "mask" in self.exclude_batch_keys_from_device:
self.exclude_batch_keys_from_device.remove("mask")
if self.config.images_on_gpu is True and "image" in self.exclude_batch_keys_from_device:
self.exclude_batch_keys_from_device.remove("image")
if self.train_dataparser_outputs is not None:
cameras = self.train_dataparser_outputs.cameras
if len(cameras) > 1:
for i in range(1, len(cameras)):
if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height:
CONSOLE.print("Variable resolution, using variable_res_collate")
self.config.collate_fn = variable_res_collate
break
super().__init__()
@cached_property
def dataset_type(self) -> Type[TDataset]:
"""Returns the dataset type passed as the generic argument"""
default: Type[TDataset] = cast(TDataset, TDataset.__default__) # type: ignore
orig_class: Type[VanillaDataManager] = get_orig_class(self, default=None) # type: ignore
if type(self) is VanillaDataManager and orig_class is None:
return default
if orig_class is not None and get_origin(orig_class) is VanillaDataManager:
return get_args(orig_class)[0]
# For inherited classes, we need to find the correct type to instantiate
for base in getattr(self, "__orig_bases__", []):
if get_origin(base) is VanillaDataManager:
for value in get_args(base):
if isinstance(value, ForwardRef):
if value.__forward_evaluated__:
value = value.__forward_value__
elif value.__forward_module__ is None:
value.__forward_module__ = type(self).__module__
value = getattr(value, "_evaluate")(None, None, set())
assert isinstance(value, type)
if issubclass(value, InputDataset):
return cast(Type[TDataset], value)
return default
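    # Illustrative sketch (comment only): dataset_type resolves to the generic argument used
    # when parameterizing or subclassing VanillaDataManager. "MyDataset" is a hypothetical
    # InputDataset subclass, not something defined in this module.
    #
    #   class MyDataManager(VanillaDataManager[MyDataset]):
    #       ...
    #
    #   VanillaDataManager[MyDataset](config).dataset_type  # -> MyDataset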
    def create_train_dataset(self) -> TDataset:
        """Creates the training dataset."""
return self.dataset_type(
dataparser_outputs=self.train_dataparser_outputs,
scale_factor=self.config.camera_res_scale_factor,
)
    def create_eval_dataset(self) -> TDataset:
        """Creates the evaluation dataset."""
return self.dataset_type(
dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split),
scale_factor=self.config.camera_res_scale_factor,
)
def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler:
"""Infer pixel sampler to use."""
if self.config.patch_size > 1 and type(self.config.pixel_sampler) is PixelSamplerConfig:
return PatchPixelSamplerConfig().setup(
patch_size=self.config.patch_size, num_rays_per_batch=num_rays_per_batch
)
is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all()
if is_equirectangular.any():
CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.")
fisheye_crop_radius = None
if dataset.cameras.metadata is not None:
fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius")
return self.config.pixel_sampler.setup(
is_equirectangular=is_equirectangular,
num_rays_per_batch=num_rays_per_batch,
fisheye_crop_radius=fisheye_crop_radius,
)
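    # Sketch of how the sampler above is typically used (comment only; `dm` is a constructed
    # VanillaDataManager and `image_batch` comes from its CacheDataloader):
    #
    #   sampler = dm._get_pixel_sampler(dm.train_dataset, num_rays_per_batch=4096)
    #   batch = sampler.sample(image_batch)
    #   batch["indices"]  # per-ray (image_idx, row, col) indices fed to the RayGenerator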
    def setup_train(self):
"""Sets up the data loaders for training"""
assert self.train_dataset is not None
CONSOLE.print("Setting up training dataset...")
self.train_image_dataloader = CacheDataloader(
self.train_dataset,
num_images_to_sample_from=self.config.train_num_images_to_sample_from,
num_times_to_repeat_images=self.config.train_num_times_to_repeat_images,
device=self.device,
num_workers=self.world_size * 4,
pin_memory=True,
collate_fn=self.config.collate_fn,
exclude_batch_keys_from_device=self.exclude_batch_keys_from_device,
)
self.iter_train_image_dataloader = iter(self.train_image_dataloader)
self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch)
self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device))
    def setup_eval(self):
"""Sets up the data loader for evaluation"""
assert self.eval_dataset is not None
CONSOLE.print("Setting up evaluation dataset...")
self.eval_image_dataloader = CacheDataloader(
self.eval_dataset,
num_images_to_sample_from=self.config.eval_num_images_to_sample_from,
num_times_to_repeat_images=self.config.eval_num_times_to_repeat_images,
device=self.device,
num_workers=self.world_size * 4,
pin_memory=True,
collate_fn=self.config.collate_fn,
exclude_batch_keys_from_device=self.exclude_batch_keys_from_device,
)
self.iter_eval_image_dataloader = iter(self.eval_image_dataloader)
self.eval_pixel_sampler = self._get_pixel_sampler(self.eval_dataset, self.config.eval_num_rays_per_batch)
self.eval_ray_generator = RayGenerator(self.eval_dataset.cameras.to(self.device))
# for loading full images
self.fixed_indices_eval_dataloader = FixedIndicesEvalDataloader(
input_dataset=self.eval_dataset,
device=self.device,
num_workers=self.world_size * 4,
)
self.eval_dataloader = RandIndicesEvalDataloader(
input_dataset=self.eval_dataset,
device=self.device,
num_workers=self.world_size * 4,
)
    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
"""Returns the next batch of data from the train dataloader."""
self.train_count += 1
image_batch = next(self.iter_train_image_dataloader)
assert self.train_pixel_sampler is not None
assert isinstance(image_batch, dict)
batch = self.train_pixel_sampler.sample(image_batch)
ray_indices = batch["indices"]
ray_bundle = self.train_ray_generator(ray_indices)
return ray_bundle, batch
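    # Typical training-loop usage (sketch; `model`, `loss_fn`, and `num_steps` are assumptions
    # that live outside this module):
    #
    #   for step in range(num_steps):
    #       ray_bundle, batch = datamanager.next_train(step)
    #       outputs = model(ray_bundle)
    #       loss = loss_fn(outputs, batch)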
    def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
"""Returns the next batch of data from the eval dataloader."""
self.eval_count += 1
image_batch = next(self.iter_eval_image_dataloader)
assert self.eval_pixel_sampler is not None
assert isinstance(image_batch, dict)
batch = self.eval_pixel_sampler.sample(image_batch)
ray_indices = batch["indices"]
ray_bundle = self.eval_ray_generator(ray_indices)
return ray_bundle, batch
    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
for camera, batch in self.eval_dataloader:
assert camera.shape[0] == 1
return camera, batch
raise ValueError("No more eval images")
    def get_train_rays_per_batch(self) -> int:
if self.train_pixel_sampler is not None:
return self.train_pixel_sampler.num_rays_per_batch
return self.config.train_num_rays_per_batch
    def get_eval_rays_per_batch(self) -> int:
if self.eval_pixel_sampler is not None:
return self.eval_pixel_sampler.num_rays_per_batch
return self.config.eval_num_rays_per_batch
    def get_datapath(self) -> Path:
return self.config.dataparser.data
    def get_param_groups(self) -> Dict[str, List[Parameter]]:
"""Get the param groups for the data manager.
Returns:
            A dictionary mapping parameter group names to lists of the data manager's parameters.
"""
return {}