# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions to allow easy re-use of common operations across dataloaders"""
from pathlib import Path
from typing import IO, List, Tuple, Union
import cv2
import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as PILImage
def pil_to_numpy(im: PILImage) -> np.ndarray:
    """Converts a PIL Image object to a NumPy array.

    Pixel data is decoded directly into a preallocated NumPy buffer using Pillow's raw encoder.

    NOTE: this relies on the private Pillow helpers ``Image._getencoder`` and ``Image._conv_type_shape``,
    which are not part of Pillow's public API and may change between Pillow releases.

    Args:
        im (PIL.Image.Image): The input PIL Image object.

    Returns:
        numpy.ndarray representing the image data. Bilevel (mode "1") images are returned as a
        boolean array with one element per pixel.

    Raises:
        RuntimeError: If the raw encoder reports an error.
    """
    # Load in image completely (PIL defaults to lazy loading)
    im.load()

    # The raw "1" packer stores 8 pixels per byte, but the buffer below holds one byte per pixel.
    # For bilevel images, unpack to "L" (one byte per pixel, 0/255) instead. Pillow's own
    # Image.__array_interface__ does the same for mode "1".
    is_bilevel = im.mode == "1"
    rawmode = "L" if is_bilevel else im.mode

    # Unpack data
    encoder = Image._getencoder(im.mode, "raw", rawmode)
    encoder.setimage(im.im)

    # NumPy buffer for the result; bilevel data is staged as uint8 and converted to bool at the end
    shape, typestr = Image._conv_type_shape(im)
    data = np.empty(shape, dtype=np.uint8 if is_bilevel else np.dtype(typestr))
    # Flat byte view over the array so encoder chunks can be copied in sequentially
    mem = data.data.cast("B", (data.data.nbytes,))

    # The encoder returns a zero status while more data remains, a positive status when finished,
    # and a negative status on error.
    bufsize, status, offset = 65536, 0, 0
    while not status:
        _, status, chunk = encoder.encode(bufsize)
        mem[offset : offset + len(chunk)] = chunk
        offset += len(chunk)
    if status < 0:
        raise RuntimeError("encoder error %d in tobytes" % status)

    if is_bilevel:
        # "L" unpacking yields 0/255 bytes; normalise to a proper 0/1 boolean array
        return data != 0
    return data
def get_image_mask_tensor_from_path(filepath: Union[Path, IO[bytes]], scale_factor: float = 1.0) -> torch.Tensor:
    """Utility function to read a mask image from the given path and return a boolean tensor.

    Args:
        filepath: Path to the mask image, or an open binary file object.
        scale_factor: Factor applied to both image dimensions; resizing uses nearest-neighbour
            resampling so mask values are not blended.

    Returns:
        Boolean tensor of shape [height, width, 1]; nonzero pixels become True.

    Raises:
        ValueError: If the mask image has more than one channel.
    """
    # The context manager releases the file Pillow opened even if decoding or resizing raises.
    with Image.open(filepath) as pil_mask:
        if scale_factor != 1.0:
            width, height = pil_mask.size
            newsize = (int(width * scale_factor), int(height * scale_factor))
            pil_mask = pil_mask.resize(newsize, resample=Image.Resampling.NEAREST)
        mask_tensor = torch.from_numpy(pil_to_numpy(pil_mask)).unsqueeze(-1).bool()
    # A single-channel image is [H, W] before unsqueeze; anything with a channel axis becomes rank 4
    if len(mask_tensor.shape) != 3:
        raise ValueError("The mask image should have 1 channel")
    return mask_tensor
def get_semantics_and_mask_tensors_from_path(
    filepath: Path, mask_indices: Union[List, torch.Tensor], scale_factor: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Utility function to read a segmentation image and build a mask from it.

    If no mask is required, use ``mask_indices = []``: the returned mask is then True everywhere.

    Args:
        filepath: Path to the segmentation image.
        mask_indices: Label values to mask out. Can be a list (or any other sequence convertible with
            ``torch.tensor``) or a tensor that broadcasts against the semantics tensor.
        scale_factor: Factor applied to both image dimensions; resizing uses nearest-neighbour
            resampling so label values are not blended.

    Returns:
        Tuple ``(semantics, mask)``:
        - ``semantics`` is an int64 tensor with a trailing singleton axis (``[H, W, 1]`` for a
          single-channel label image).
        - ``mask`` is a boolean tensor that is True where the pixel's label is NOT in ``mask_indices``.
    """
    if not isinstance(mask_indices, torch.Tensor):
        # Shape (1, 1, K) so the comparison below broadcasts against (H, W, 1) to (H, W, K)
        mask_indices = torch.tensor(mask_indices, dtype=torch.int64).view(1, 1, -1)
    # The context manager releases the file Pillow opened even if decoding or resizing raises.
    with Image.open(filepath) as pil_image:
        if scale_factor != 1.0:
            width, height = pil_image.size
            newsize = (int(width * scale_factor), int(height * scale_factor))
            pil_image = pil_image.resize(newsize, resample=Image.Resampling.NEAREST)
        semantics = torch.from_numpy(np.array(pil_image, dtype="int64"))[..., None]
    # A pixel is kept (True) when it matches none of the indices
    mask = torch.sum(semantics == mask_indices, dim=-1, keepdim=True) == 0
    return semantics, mask
def get_depth_image_from_path(
    filepath: Path,
    height: int,
    width: int,
    scale_factor: float,
    interpolation: int = cv2.INTER_NEAREST,
) -> torch.Tensor:
    """Loads, rescales and resizes depth images.

    Filepath points to a 16-bit or 32-bit depth image, or a numpy array `*.npy`.

    Args:
        filepath: Path to depth image (a ``str`` path is also accepted).
        height: Target depth image height.
        width: Target depth image width.
        scale_factor: Factor by which to scale depth values.
        interpolation: Depth value interpolation for resizing.

    Returns:
        Depth image torch tensor of dtype float32 with shape [height, width, 1].

    Raises:
        OSError: If a non-``.npy`` depth image cannot be read by OpenCV.
    """
    filepath = Path(filepath)
    if filepath.suffix.lower() == ".npy":
        image = np.load(filepath)
    else:
        # IMREAD_ANYDEPTH keeps 16/32-bit depth values instead of converting to 8 bit
        image = cv2.imread(str(filepath.absolute()), cv2.IMREAD_ANYDEPTH)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising
            raise OSError(f"Could not read depth image: {filepath}")
    image = image.astype(np.float32) * scale_factor
    # cv2.resize takes the target size as (width, height)
    image = cv2.resize(image, (width, height), interpolation=interpolation)
    return torch.from_numpy(image[:, :, np.newaxis])
def identity_collate(x):
    """Return the given batch untouched.

    Dataloaders need a picklable collate function. A lambda such as ``lambda x: x`` cannot be
    pickled, so this module-level function stands in for it.
    """
    batch = x
    return batch