# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data parser for pre-prepared datasets for all cameras, with no additional processing needed
Optional fields - semantics, mask_filenames, cameras.distortion_params, cameras.times
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type
import numpy as np
import torch
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs, Semantics
from nerfstudio.data.scene_box import SceneBox
[docs]@dataclass
class MinimalDataParserConfig(DataParserConfig):
"""Minimal dataset config"""
_target: Type = field(default_factory=lambda: MinimalDataParser)
"""target class to instantiate"""
data: Path = Path("/home/nikhil/nerfstudio-main/tests/data/lego_test/minimal_parser")
[docs]@dataclass
class MinimalDataParser(DataParser):
"""Minimal DatasetParser"""
config: MinimalDataParserConfig
def _generate_dataparser_outputs(self, split="train"):
filepath = self.config.data / f"{split}.npz"
data = np.load(filepath, allow_pickle=True)
image_filenames = [filepath.parent / path for path in data["image_filenames"].tolist()]
mask_filenames = None
if "mask_filenames" in data.keys():
mask_filenames = [filepath.parent / path for path in data["mask_filenames"].tolist()]
if "semantics" in data.keys():
semantics = data["semantics"].item()
metadata = {
"semantics": Semantics(
filenames=[filepath.parent / path for path in semantics["filenames"].tolist()],
classes=semantics["classes"].tolist(),
colors=torch.from_numpy(semantics["colors"]),
mask_classes=semantics["mask_classes"].tolist(),
)
}
else:
metadata = {}
scene_box_aabb = torch.from_numpy(data["scene_box"])
scene_box = SceneBox(aabb=scene_box_aabb)
camera_np = data["cameras"].item()
distortion_params = None
if "distortion_params" in camera_np.keys():
distortion_params = torch.from_numpy(camera_np["distortion_params"])
cameras = Cameras(
fx=torch.from_numpy(camera_np["fx"]),
fy=torch.from_numpy(camera_np["fy"]),
cx=torch.from_numpy(camera_np["cx"]),
cy=torch.from_numpy(camera_np["cy"]),
distortion_params=distortion_params,
height=torch.from_numpy(camera_np["height"]),
width=torch.from_numpy(camera_np["width"]),
camera_to_worlds=torch.from_numpy(camera_np["camera_to_worlds"])[:, :3, :4],
camera_type=torch.from_numpy(camera_np["camera_type"]),
times=torch.from_numpy(camera_np["times"]) if "times" in camera_np.keys() else None,
)
applied_scale = 1.0
applied_transform = torch.eye(4, dtype=torch.float32)[:3, :]
if "applied_scale" in data.keys():
applied_scale = float(data["applied_scale"])
if "applied_transform" in data.keys():
applied_transform = data["applied_transform"].astype(np.float32)
assert applied_transform.shape == (3, 4)
dataparser_outputs = DataparserOutputs(
image_filenames=image_filenames,
cameras=cameras,
scene_box=scene_box,
mask_filenames=mask_filenames,
dataparser_transform=applied_transform,
dataparser_scale=applied_scale,
metadata=metadata,
)
return dataparser_outputs