# Source code for nerfstudio.data.utils.nerfstudio_collate

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Custom collate function that includes cases for nerfstudio types.
"""

import collections
import collections.abc
import re
from typing import Any, Callable, Dict, Union

import torch
import torch.utils.data

from nerfstudio.cameras.cameras import Cameras

# Template for the error message used when a batch element cannot be collated.
# It contains TWO positional replacement fields, so ``str.format`` must be
# given two arguments (presumably the extra supported types first and the
# offending type second -- confirm against the call sites).
NERFSTUDIO_COLLATE_ERR_MSG_FORMAT = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, dicts, lists or anything in {}; found {}"
)

# Searched against ``ndarray.dtype.str``; a hit on any of the characters
# S, a, U or O makes the collate function reject the array as a string/object array.
np_str_obj_array_pattern = re.compile(r"[SaUO]")


def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], None] = None) -> Any:
    r"""
    This is the default pytorch collate function, but with support for nerfstudio types. All documentation
    below is copied straight over from pytorch's default_collate function, python version 3.8.13,
    pytorch version '1.12.1+cu113'. Custom nerfstudio types are accounted for at the end, and extra
    mappings can be passed in to handle custom types. These mappings are from types: callable (types
    being like int or float or the return value of type(3.), etc). The only code before we parse for custom
    types that was changed from default pytorch was the addition of the extra_mappings argument, a find
    and replace operation from default_collate to nerfstudio_collate, and the addition of the
    nerfstudio_collate_err_msg_format variable.

    Function that takes in a batch of data and puts the elements within the batch
    into a tensor with an additional outer dimension - batch size. The exact output type can be
    a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
    This is used as the default function for collation when
    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.

    Here is the general input type (based on the type of the element within the batch) to output type mapping:

        * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
        * NumPy Arrays -> :class:`torch.Tensor`
        * `float` -> :class:`torch.Tensor`
        * `int` -> :class:`torch.Tensor`
        * `str` -> `str` (unchanged)
        * `bytes` -> `bytes` (unchanged)
        * `Mapping[K, V_i]` -> `Mapping[K, nerfstudio_collate([V_1, V_2, ...])]`
        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[nerfstudio_collate([V1_1, V1_2, ...]),
          nerfstudio_collate([V2_1, V2_2, ...]), ...]`
        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[nerfstudio_collate([V1_1, V1_2, ...]),
          nerfstudio_collate([V2_1, V2_2, ...]), ...]`
        * `Cameras` -> `Cameras` (fields stacked or concatenated along dim 0)

    Args:
        batch: a single batch to be collated
        extra_mappings: optional mapping from a type to a callable. It is consulted only after every
            built-in case (including `Cameras`) has failed to match; the first key for which
            ``isinstance(batch[0], key)`` holds has its callable invoked with the whole batch.

    Returns:
        The collated batch; its type depends on the element type as listed above.

    Raises:
        TypeError: if the element type is not supported, or if the batch holds numpy string/object arrays.
        RuntimeError: if the batch holds sequences of differing lengths.

    Examples:
        >>> # Example with a batch of `int`s:
        >>> nerfstudio_collate([0, 1, 2, 3])
        tensor([0, 1, 2, 3])
        >>> # Example with a batch of `str`s:
        >>> nerfstudio_collate(['a', 'b', 'c'])
        ['a', 'b', 'c']
        >>> # Example with `Map` inside the batch:
        >>> nerfstudio_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
        >>> # Example with `NamedTuple` inside the batch:
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> nerfstudio_collate([Point(0, 0), Point(1, 1)])
        Point(x=tensor([0, 1]), y=tensor([0, 1]))
        >>> # Example with `Tuple` inside the batch:
        >>> nerfstudio_collate([(0, 1), (2, 3)])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Example with `List` inside the batch:
        >>> nerfstudio_collate([[0, 1], [2, 3]])
        [tensor([0, 2]), tensor([1, 3])]
    """
    if extra_mappings is None:
        extra_mappings = {}
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_":
        if elem_type.__name__ in ("ndarray", "memmap"):
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise _unsupported_batch_error(elem.dtype, extra_mappings)

            return nerfstudio_collate([torch.as_tensor(b) for b in batch], extra_mappings=extra_mappings)
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other numpy type falls through to the extra_mappings lookup below.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        try:
            return elem_type(
                {key: nerfstudio_collate([d[key] for d in batch], extra_mappings=extra_mappings) for key in elem}
            )
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: nerfstudio_collate([d[key] for d in batch], extra_mappings=extra_mappings) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(nerfstudio_collate(samples, extra_mappings=extra_mappings) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(other) == elem_size for other in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        if isinstance(elem, tuple):
            # Backwards compatibility: plain tuples are returned as a list.
            return [nerfstudio_collate(samples, extra_mappings=extra_mappings) for samples in transposed]
        else:
            try:
                return elem_type([nerfstudio_collate(samples, extra_mappings=extra_mappings) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                return [nerfstudio_collate(samples, extra_mappings=extra_mappings) for samples in transposed]

    # NerfStudio types supported below

    elif isinstance(elem, Cameras):
        return _collate_cameras(batch)

    # User-supplied handlers are tried last, in the mapping's iteration order.
    for type_key in extra_mappings:
        if isinstance(elem, type_key):
            return extra_mappings[type_key](batch)

    raise _unsupported_batch_error(elem_type, extra_mappings)


def _unsupported_batch_error(found: Any, extra_mappings: Dict[type, Callable]) -> TypeError:
    """Build the TypeError reported when a batch element cannot be collated.

    The message template has two replacement fields, so both must be supplied;
    formatting it with a single argument raises ``IndexError`` instead of
    producing the intended message.
    """
    supported = ["Cameras"] + [getattr(type_key, "__name__", repr(type_key)) for type_key in extra_mappings]
    return TypeError(NERFSTUDIO_COLLATE_ERR_MSG_FORMAT.format(supported, found))


def _collate_cameras(batch: Any) -> "Cameras":
    """Merge a batch of ``Cameras`` objects into a single ``Cameras`` object.

    When the first element's ``shape`` is ``()`` every field is stacked to create a
    new leading dimension; otherwise the fields are concatenated along dim 0.
    All cameras must agree on whether distortion parameters, metadata (with
    identical keys) and times are present.
    """
    first = batch[0]

    # If a camera, just concatenate along the batch dimension. In the future, this may change to stacking
    assert all(isinstance(cam, Cameras) for cam in batch)
    assert all(cam.distortion_params is None for cam in batch) or all(
        cam.distortion_params is not None for cam in batch
    ), (
        "All cameras must have distortion parameters or none of them should have distortion parameters. "
        "Generalized batching will be supported in the future."
    )

    if first.metadata is not None:
        metadata_keys = first.metadata.keys()
        # The None check keeps a mixed batch from failing with AttributeError
        # instead of the intended assertion message.
        assert all(
            cam.metadata is not None and cam.metadata.keys() == metadata_keys for cam in batch
        ), "All cameras must have the same metadata keys."
    else:
        assert all(cam.metadata is None for cam in batch), "All cameras must have the same metadata keys."

    if first.times is not None:
        assert all(cam.times is not None for cam in batch), "All cameras must have times present or absent."
    else:
        assert all(cam.times is None for cam in batch), "All cameras must have times present or absent."

    if first.shape == ():
        # If no batch dimension exists, then we need to stack everything and create a batch dimension on 0th dim
        op = torch.stack
    else:
        # If batch dimension exists, then we need to concatenate along the 0th dimension
        op = torch.cat

    # Create metadata dictionary
    if first.metadata is not None:
        metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in first.metadata.keys()}
    else:
        metadata = None

    if first.distortion_params is not None:
        distortion_params = op([cam.distortion_params for cam in batch], dim=0)
    else:
        distortion_params = None

    if first.times is not None:
        # Use the same op as every other field so that `times` keeps a batch
        # shape consistent with them when the cameras are already batched.
        times = op([cam.times for cam in batch], dim=0)
    else:
        times = None

    return Cameras(
        op([cam.camera_to_worlds for cam in batch], dim=0),
        op([cam.fx for cam in batch], dim=0),
        op([cam.fy for cam in batch], dim=0),
        op([cam.cx for cam in batch], dim=0),
        op([cam.cy for cam in batch], dim=0),
        height=op([cam.height for cam in batch], dim=0),
        width=op([cam.width for cam in batch], dim=0),
        distortion_params=distortion_params,
        camera_type=op([cam.camera_type for cam in batch], dim=0),
        times=times,
        metadata=metadata,
    )