TensorDataclass

Tensor dataclass

class nerfstudio.utils.tensor_dataclass.TensorDataclass

Bases: object

@dataclass of tensors that share a common batch shape. Allows indexing and standard tensor operations. Fields that are not Tensors are not batched unless they are also TensorDataclasses. For fields that are dictionaries, the Tensors and TensorDataclasses they contain are batched and are included in the initial broadcast. Tensor fields must have at least one dimension, so a 0-dimensional field such as torch.tensor(1.0) must be converted to torch.tensor([1.0]).

Example:

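```python
from dataclasses import dataclass
from typing import Optional

import torch

from nerfstudio.utils.tensor_dataclass import TensorDataclass


@dataclass
class TestTensorDataclass(TensorDataclass):
    a: torch.Tensor
    b: torch.Tensor
    c: Optional[torch.Tensor] = None  # optional field, left as None here

# Create a new tensor dataclass with batch shape [2, 3, 4]
test = TestTensorDataclass(a=torch.ones((2, 3, 4, 2)), b=torch.ones((4, 3)))

test.shape    # [2, 3, 4]
test.a.shape  # [2, 3, 4, 2]
test.b.shape  # [2, 3, 4, 3]  (b was broadcast from [4, 3])

test.reshape((6, 4)).shape  # [6, 4]
test.flatten().shape        # [24,]

test[..., 0].shape     # [2, 3]
test[:, 0, :].shape    # [2, 4]
```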
__post_init__() → None

Finishes setting up the TensorDataclass.

This will (1) find the broadcast batch shape of all fields, (2) broadcast all fields to this shape, and (3) set _shape to the broadcast shape.
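A minimal shape-only sketch of these three steps, assuming (as the example above suggests) that the last dimension of each tensor field is its per-element feature dimension and all leading dimensions are batch dimensions. Tensors are represented only by their shape tuples, plain dicts of shapes stand in for dictionary fields, and `None` stands in for an unset optional field; nested TensorDataclass fields are not modeled. The helper names are hypothetical and this is not nerfstudio's implementation. Step 1 uses the same right-aligned rule as NumPy/PyTorch broadcasting.

```python
from typing import Any, Dict, List, Tuple

Shape = Tuple[int, ...]


def broadcast_shapes(*shapes: Shape) -> Shape:
    """Right-align the shapes; each dimension must match or be 1."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for i in range(-ndim, 0):
        dims = {s[i] for s in shapes if len(s) >= -i and s[i] != 1}
        if len(dims) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(dims.pop() if dims else 1)
    return tuple(result)


def collect_batch_shapes(fields: Dict[str, Any]) -> List[Shape]:
    """Batch shape (all dims but the last) of every tensor field,
    descending into dictionary fields and skipping None."""
    out: List[Shape] = []
    for value in fields.values():
        if isinstance(value, dict):
            out.extend(collect_batch_shapes(value))
        elif value is not None:
            if len(value) < 1:
                raise ValueError("tensor fields must have at least 1 dimension")
            out.append(value[:-1])
    return out


def broadcast_fields(fields: Dict[str, Any], batch: Shape) -> Dict[str, Any]:
    """Give every tensor field the shape (*batch, feature_dim)."""
    out: Dict[str, Any] = {}
    for name, value in fields.items():
        if isinstance(value, dict):
            out[name] = broadcast_fields(value, batch)
        elif value is None:
            out[name] = None
        else:
            out[name] = (*batch, value[-1])
    return out


def post_init_shapes(fields: Dict[str, Any]) -> Tuple[Dict[str, Any], Shape]:
    batch = broadcast_shapes(*collect_batch_shapes(fields))  # step 1
    new_fields = broadcast_fields(fields, batch)             # step 2
    return new_fields, batch                                 # step 3: the new _shape


# Same fields as the example: a is [2, 3, 4, 2], b is [4, 3], c is unset.
fields, shape = post_init_shapes({"a": (2, 3, 4, 2), "b": (4, 3), "c": None})
print(shape)   # (2, 3, 4)
print(fields)  # {'a': (2, 3, 4, 2), 'b': (2, 3, 4, 3), 'c': None}
```

A 0-dimensional field (shape `()`) is rejected in step 1, which mirrors the at-least-one-dimension rule in the class description.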

broadcast_to(shape: Union[Size, Tuple[int, ...]]) → TensorDataclassT

Returns a new TensorDataclass broadcast to the new shape.

Changes to the original tensor dataclass should affect the returned tensor dataclass: it is NOT a deep copy, and the two remain linked.

Parameters:

shape – The new shape of the tensor dataclass.

Returns:

A new TensorDataclass with the same data but with a new shape.
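A shape-only sketch of what broadcast_to does to each tensor field, again assuming the last dimension of a field is its feature dimension. Unlike the symmetric broadcasting in __post_init__, this direction is one-way, like torch.Tensor.broadcast_to: each existing batch dimension must equal the target dimension or be 1, and the target may add leading dimensions. The helper names are hypothetical. In PyTorch, broadcast_to is equivalent to expand and returns a view sharing storage with its input, which fits the note above that the result is not a deep copy.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def can_broadcast_to(src: Shape, dst: Shape) -> bool:
    """True if batch shape `src` can be expanded to `dst`: align on the
    right; every src dim must equal the dst dim or be 1."""
    if len(src) > len(dst):
        return False
    return all(s == d or s == 1 for s, d in zip(reversed(src), reversed(dst)))


def broadcast_field_to(field_shape: Shape, new_batch: Shape) -> Shape:
    """New shape of one tensor field: (*new_batch, feature_dim)."""
    batch, feature_dim = field_shape[:-1], field_shape[-1]
    if not can_broadcast_to(batch, new_batch):
        raise ValueError(f"cannot broadcast batch {batch} to {new_batch}")
    return (*new_batch, feature_dim)


# A field with batch (4,) and 3 features expands to batch (2, 3, 4).
print(broadcast_field_to((4, 3), (2, 3, 4)))  # (2, 3, 4, 3)
```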

flatten() → TensorDataclassT

Returns a new TensorDataclass with its batch dimensions flattened into a single dimension.

Returns:

A new TensorDataclass with the same data but with a new shape.

Return type:

TensorDataclass
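In terms of shapes, flattening multiplies the batch dimensions together and leaves each field's feature dimension untouched. A minimal shape-only sketch, assuming the last dimension of each tensor field is its feature dimension and that fields have already been broadcast to the common batch shape; flatten_fields is a hypothetical helper, not part of nerfstudio.

```python
from math import prod
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def flatten_fields(fields: Dict[str, Shape], batch: Shape) -> Tuple[Dict[str, Shape], Shape]:
    """Collapse the batch dims into one; keep each field's feature dim."""
    n = prod(batch)  # number of batch elements
    flat: Dict[str, Shape] = {}
    for name, shape in fields.items():
        if shape[:-1] != batch:
            raise ValueError(f"field {name!r} has batch {shape[:-1]}, expected {batch}")
        flat[name] = (n, shape[-1])
    return flat, (n,)


# Fields of the example after __post_init__ broadcasting.
flat, new_batch = flatten_fields({"a": (2, 3, 4, 2), "b": (2, 3, 4, 3)}, (2, 3, 4))
print(new_batch)  # (24,)
print(flat)       # {'a': (24, 2), 'b': (24, 3)}
```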

property ndim: int

Returns the number of dimensions of the tensor dataclass.

pin_memory() → TensorDataclassT

Pins the memory of the tensor dataclass, i.e. copies its tensors into page-locked (pinned) host memory.

Returns:

A new TensorDataclass with the same data but pinned.

Return type:

TensorDataclass

reshape(shape: Tuple[int, ...]) → TensorDataclassT

Returns a new TensorDataclass with the same data but with a new shape.

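This should deep-copy as well. Do not rely on this without checking, however: torch.Tensor.reshape returns a view of its input whenever the new shape is compatible with the input's strides.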

Parameters:

shape – The new shape of the tensor dataclass.

Returns:

A new TensorDataclass with the same data but with a new shape.
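Only the batch dimensions are reshaped, so the product of the new batch shape must equal that of the old one. The shape-only sketch below assumes the last dimension of each tensor field is its feature dimension and, as an assumption borrowed from torch.Tensor.reshape rather than confirmed for this class, allows a single -1 entry to be inferred. The helper names are hypothetical.

```python
from math import prod
from typing import Tuple

Shape = Tuple[int, ...]


def resolve_batch_shape(old_batch: Shape, new_batch: Shape) -> Shape:
    """Infer at most one -1 in `new_batch` and check element counts match."""
    total = prod(old_batch)
    if new_batch.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(d for d in new_batch if d != -1)
    if -1 in new_batch:
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot reshape batch {old_batch} into {new_batch}")
        new_batch = tuple(total // known if d == -1 else d for d in new_batch)
    if prod(new_batch) != total:
        raise ValueError(f"cannot reshape batch {old_batch} into {new_batch}")
    return new_batch


def reshape_field(field_shape: Shape, new_batch: Shape) -> Shape:
    """A field of shape (*old_batch, C) becomes (*new_batch, C)."""
    batch = resolve_batch_shape(field_shape[:-1], new_batch)
    return (*batch, field_shape[-1])


print(reshape_field((2, 3, 4, 2), (6, 4)))   # (6, 4, 2)
print(reshape_field((2, 3, 4, 2), (-1, 4)))  # (6, 4, 2)
```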

property shape: Tuple[int, ...]

Returns the batch shape of the tensor dataclass.

property size: int

Returns the total number of elements across the batch dimensions, i.e. the product of the entries of shape.
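Using the example's field a as an illustration: shape reports only the batch dimensions, ndim is the length of that shape, and size counts batch elements rather than the scalars stored in each field. A plain-Python sketch of that arithmetic, assuming size is the product of the batch shape:

```python
from math import prod

field_a = (2, 3, 4, 2)        # shape of tensor field `a` in the example
batch = field_a[:-1]          # batch shape, as reported by `shape`
ndim = len(batch)             # as reported by `ndim`
size = prod(batch)            # as reported by `size`: batch elements only
scalars_in_a = prod(field_a)  # scalars actually stored in `a`
print(batch, ndim, size, scalars_in_a)  # (2, 3, 4) 3 24 48
```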

to(device) → TensorDataclassT

Returns a new TensorDataclass with the same data but on the specified device.

Parameters:

device – The device on which to place the tensor dataclass.

Returns:

A new TensorDataclass with the same data but on the specified device.