Source code for nerfstudio.data.dataparsers.minimal_dataparser

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data parser for pre-prepared datasets for all cameras, with no additional processing needed
Optional fields - semantics, mask_filenames, cameras.distortion_params, cameras.times
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Type

import numpy as np
import torch

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs, Semantics
from nerfstudio.data.scene_box import SceneBox


[docs]@dataclass class MinimalDataParserConfig(DataParserConfig): """Minimal dataset config""" _target: Type = field(default_factory=lambda: MinimalDataParser) """target class to instantiate""" data: Path = Path("/home/nikhil/nerfstudio-main/tests/data/lego_test/minimal_parser")
[docs]@dataclass class MinimalDataParser(DataParser): """Minimal DatasetParser""" config: MinimalDataParserConfig def _generate_dataparser_outputs(self, split="train"): filepath = self.config.data / f"{split}.npz" data = np.load(filepath, allow_pickle=True) image_filenames = [filepath.parent / path for path in data["image_filenames"].tolist()] mask_filenames = None if "mask_filenames" in data.keys(): mask_filenames = [filepath.parent / path for path in data["mask_filenames"].tolist()] if "semantics" in data.keys(): semantics = data["semantics"].item() metadata = { "semantics": Semantics( filenames=[filepath.parent / path for path in semantics["filenames"].tolist()], classes=semantics["classes"].tolist(), colors=torch.from_numpy(semantics["colors"]), mask_classes=semantics["mask_classes"].tolist(), ) } else: metadata = {} scene_box_aabb = torch.from_numpy(data["scene_box"]) scene_box = SceneBox(aabb=scene_box_aabb) camera_np = data["cameras"].item() distortion_params = None if "distortion_params" in camera_np.keys(): distortion_params = torch.from_numpy(camera_np["distortion_params"]) cameras = Cameras( fx=torch.from_numpy(camera_np["fx"]), fy=torch.from_numpy(camera_np["fy"]), cx=torch.from_numpy(camera_np["cx"]), cy=torch.from_numpy(camera_np["cy"]), distortion_params=distortion_params, height=torch.from_numpy(camera_np["height"]), width=torch.from_numpy(camera_np["width"]), camera_to_worlds=torch.from_numpy(camera_np["camera_to_worlds"])[:, :3, :4], camera_type=torch.from_numpy(camera_np["camera_type"]), times=torch.from_numpy(camera_np["times"]) if "times" in camera_np.keys() else None, ) applied_scale = 1.0 applied_transform = torch.eye(4, dtype=torch.float32)[:3, :] if "applied_scale" in data.keys(): applied_scale = float(data["applied_scale"]) if "applied_transform" in data.keys(): applied_transform = data["applied_transform"].astype(np.float32) assert applied_transform.shape == (3, 4) dataparser_outputs = DataparserOutputs( image_filenames=image_filenames, cameras=cameras, scene_box=scene_box, mask_filenames=mask_filenames, dataparser_transform=applied_transform, dataparser_scale=applied_scale, metadata=metadata, ) return dataparser_outputs