# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code to train a model.
"""
from __future__ import annotations
import dataclasses
import functools
import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from typing import DefaultDict, Dict, List, Literal, Optional, Tuple, Type, cast
import torch
import viser
from rich import box, style
from rich.panel import Panel
from rich.table import Table
from torch.cuda.amp.grad_scaler import GradScaler
from nerfstudio.configs.experiment_config import ExperimentConfig
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
from nerfstudio.engine.optimizers import Optimizers
from nerfstudio.pipelines.base_pipeline import VanillaPipeline
from nerfstudio.utils import profiler, writer
from nerfstudio.utils.decorators import check_eval_enabled, check_main_thread, check_viewer_enabled
from nerfstudio.utils.misc import step_check
from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.utils.writer import EventName, TimeWriter
from nerfstudio.viewer.viewer import Viewer as ViewerState
from nerfstudio.viewer_legacy.server.viewer_state import ViewerLegacyState
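# Type aliases: a training iteration returns (total loss, loss dict, metrics dict), and
# devices are passed around as plain strings such as "cpu", "cuda:0", or "mps".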
TRAIN_ITERATION_OUTPUT = Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
TORCH_DEVICE = str
@dataclass
class TrainerConfig(ExperimentConfig):
"""Configuration for training regimen"""
_target: Type = field(default_factory=lambda: Trainer)
"""target class to instantiate"""
steps_per_save: int = 1000
"""Number of steps between saves."""
    steps_per_eval_batch: int = 500
    """Number of steps between evaluations on a randomly sampled batch of eval rays."""
    steps_per_eval_image: int = 500
    """Number of steps between evaluations of a single eval image."""
    steps_per_eval_all_images: int = 25000
    """Number of steps between evaluations over all eval images."""
max_num_iterations: int = 1000000
"""Maximum number of iterations to run."""
mixed_precision: bool = False
"""Whether or not to use mixed precision for training."""
use_grad_scaler: bool = False
"""Use gradient scaler even if the automatic mixed precision is disabled."""
save_only_latest_checkpoint: bool = True
"""Whether to only save the latest checkpoint or all checkpoints."""
# optional parameters if we want to resume training
load_dir: Optional[Path] = None
"""Optionally specify a pre-trained model directory to load from."""
load_step: Optional[int] = None
"""Optionally specify model step to load from; if none, will find most recent model in load_dir."""
load_config: Optional[Path] = None
"""Path to config YAML file."""
load_checkpoint: Optional[Path] = None
"""Path to checkpoint file."""
log_gradients: bool = False
"""Optionally log gradients during training"""
gradient_accumulation_steps: Dict[str, int] = field(default_factory=lambda: {})
"""Number of steps to accumulate gradients over. Contains a mapping of {param_group:num}"""
start_paused: bool = False
"""Whether to start the training in a paused state."""
class Trainer:
"""Trainer class
Args:
config: The configuration object.
local_rank: Local rank of the process.
world_size: World size of the process.
Attributes:
config: The configuration object.
local_rank: Local rank of the process.
world_size: World size of the process.
device: The device to run the training on.
pipeline: The pipeline object.
optimizers: The optimizers object.
callbacks: The callbacks object.
training_state: Current model training state.
"""
pipeline: VanillaPipeline
optimizers: Optimizers
callbacks: List[TrainingCallback]
def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int = 1) -> None:
self.train_lock = Lock()
self.config = config
self.local_rank = local_rank
self.world_size = world_size
self.device: TORCH_DEVICE = config.machine.device_type
if self.device == "cuda":
self.device += f":{local_rank}"
self.mixed_precision: bool = self.config.mixed_precision
self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler
self.training_state: Literal["training", "paused", "completed"] = (
"paused" if self.config.start_paused else "training"
)
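        # Gradient accumulation is configured per param group; any group not listed in the
        # config defaults to 1, i.e. an optimizer step on every training iteration.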
self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1)
self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps)
if self.device == "cpu":
self.mixed_precision = False
CONSOLE.print("Mixed precision is disabled for CPU training.")
self._start_step: int = 0
# optimizers
self.grad_scaler = GradScaler(enabled=self.use_grad_scaler)
self.base_dir: Path = config.get_base_dir()
# directory to save checkpoints
self.checkpoint_dir: Path = config.get_checkpoint_dir()
CONSOLE.log(f"Saving checkpoints to: {self.checkpoint_dir}")
self.viewer_state = None
# used to keep track of the current step
self.step = 0
    def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
"""Setup the Trainer by calling other setup functions.
Args:
test_mode:
'val': loads train/val datasets into memory
'test': loads train/test datasets into memory
'inference': does not load any dataset into memory
"""
self.pipeline = self.config.pipeline.setup(
device=self.device,
test_mode=test_mode,
world_size=self.world_size,
local_rank=self.local_rank,
grad_scaler=self.grad_scaler,
)
self.optimizers = self.setup_optimizers()
# set up viewer if enabled
viewer_log_path = self.base_dir / self.config.viewer.relative_log_filename
self.viewer_state, banner_messages = None, None
if self.config.is_viewer_legacy_enabled() and self.local_rank == 0:
datapath = self.config.data
if datapath is None:
datapath = self.base_dir
self.viewer_state = ViewerLegacyState(
self.config.viewer,
log_filename=viewer_log_path,
datapath=datapath,
pipeline=self.pipeline,
trainer=self,
train_lock=self.train_lock,
)
banner_messages = [f"Legacy viewer at: {self.viewer_state.viewer_url}"]
if self.config.is_viewer_enabled() and self.local_rank == 0:
datapath = self.config.data
if datapath is None:
datapath = self.base_dir
self.viewer_state = ViewerState(
self.config.viewer,
log_filename=viewer_log_path,
datapath=datapath,
pipeline=self.pipeline,
trainer=self,
train_lock=self.train_lock,
share=self.config.viewer.make_share_url,
)
banner_messages = self.viewer_state.viewer_info
self._check_viewer_warnings()
self._load_checkpoint()
self.callbacks = self.pipeline.get_training_callbacks(
TrainingCallbackAttributes(
optimizers=self.optimizers, grad_scaler=self.grad_scaler, pipeline=self.pipeline, trainer=self
)
)
# set up writers/profilers if enabled
writer_log_path = self.base_dir / self.config.logging.relative_log_dir
writer.setup_event_writer(
self.config.is_wandb_enabled(),
self.config.is_tensorboard_enabled(),
self.config.is_comet_enabled(),
log_dir=writer_log_path,
experiment_name=self.config.experiment_name,
project_name=self.config.project_name,
)
writer.setup_local_writer(
self.config.logging, max_iter=self.config.max_num_iterations, banner_messages=banner_messages
)
writer.put_config(name="config", config_dict=dataclasses.asdict(self.config), step=0)
profiler.setup_profiler(self.config.logging, writer_log_path)
    def setup_optimizers(self) -> Optimizers:
"""Helper to set up the optimizers
Returns:
The optimizers object given the trainer config.
"""
optimizer_config = self.config.optimizers.copy()
param_groups = self.pipeline.get_param_groups()
return Optimizers(optimizer_config, param_groups)
    def train(self) -> None:
"""Train the model."""
        assert self.pipeline.datamanager.train_dataset is not None, "Missing DatasetInputs"
if hasattr(self.pipeline.datamanager, "train_dataparser_outputs"):
self.pipeline.datamanager.train_dataparser_outputs.save_dataparser_transform(
self.base_dir / "dataparser_transforms.json"
)
self._init_viewer_state()
with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME):
num_iterations = self.config.max_num_iterations
step = 0
self.stop_training = False
for step in range(self._start_step, self._start_step + num_iterations):
self.step = step
if self.stop_training:
break
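                # Busy-wait while the viewer has paused training; exit cleanly if shutdown was requested.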
while self.training_state == "paused":
if self.stop_training:
self._after_train()
return
time.sleep(0.01)
with self.train_lock:
with TimeWriter(writer, EventName.ITER_TRAIN_TIME, step=step) as train_t:
self.pipeline.train()
# training callbacks before the training iteration
for callback in self.callbacks:
callback.run_callback_at_location(
step, location=TrainingCallbackLocation.BEFORE_TRAIN_ITERATION
)
# time the forward pass
loss, loss_dict, metrics_dict = self.train_iteration(step)
# training callbacks after the training iteration
for callback in self.callbacks:
callback.run_callback_at_location(
step, location=TrainingCallbackLocation.AFTER_TRAIN_ITERATION
)
# Skip the first two steps to avoid skewed timings that break the viewer rendering speed estimate.
if step > 1:
writer.put_time(
name=EventName.TRAIN_RAYS_PER_SEC,
duration=self.world_size
* self.pipeline.datamanager.get_train_rays_per_batch()
/ max(0.001, train_t.duration),
step=step,
avg_over_steps=True,
)
self._update_viewer_state(step)
                # log statistics for this batch of train rays
if step_check(step, self.config.logging.steps_per_log, run_at_zero=True):
writer.put_scalar(name="Train Loss", scalar=loss, step=step)
writer.put_dict(name="Train Loss Dict", scalar_dict=loss_dict, step=step)
writer.put_dict(name="Train Metrics Dict", scalar_dict=metrics_dict, step=step)
                    # The actual memory allocated by PyTorch. This is likely less than the amount
# shown in nvidia-smi since some unused memory can be held by the caching
# allocator and some context needs to be created on GPU. See Memory management
# (https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management)
# for more details about GPU memory management.
writer.put_scalar(
name="GPU Memory (MB)", scalar=torch.cuda.max_memory_allocated() / (1024**2), step=step
)
# Do not perform evaluation if there are no validation images
if self.pipeline.datamanager.eval_dataset:
with self.train_lock:
self.eval_iteration(step)
if step_check(step, self.config.steps_per_save):
self.save_checkpoint(step)
writer.write_out_storage()
# save checkpoint at the end of training, and write out any remaining events
self._after_train()
    def shutdown(self) -> None:
"""Stop the trainer and stop all associated threads/processes (such as the viewer)."""
self.stop_training = True # tell the training loop to stop
if self.viewer_state is not None:
# stop the viewer
# this condition excludes the case where `viser_server` is either `None` or an
# instance of `viewer_legacy`'s `ViserServer` instead of the upstream one.
if isinstance(self.viewer_state.viser_server, viser.ViserServer):
self.viewer_state.viser_server.stop()
def _after_train(self) -> None:
"""Function to run after training is complete"""
self.training_state = "completed" # used to update the webui state
# save checkpoint at the end of training
self.save_checkpoint(self.step)
# write out any remaining events (e.g., total train time)
writer.write_out_storage()
table = Table(
title=None,
show_header=False,
box=box.MINIMAL,
title_style=style.Style(bold=True),
)
table.add_row("Config File", str(self.config.get_base_dir() / "config.yml"))
table.add_row("Checkpoint Directory", str(self.checkpoint_dir))
CONSOLE.print(Panel(table, title="[bold][green]:tada: Training Finished :tada:[/bold]", expand=False))
# after train end callbacks
for callback in self.callbacks:
callback.run_callback_at_location(step=self.step, location=TrainingCallbackLocation.AFTER_TRAIN)
if not self.config.viewer.quit_on_train_completion:
self._train_complete_viewer()
@check_main_thread
def _check_viewer_warnings(self) -> None:
"""Helper to print out any warnings regarding the way the viewer/loggers are enabled"""
if (
(self.config.is_viewer_legacy_enabled() or self.config.is_viewer_enabled())
and not self.config.is_tensorboard_enabled()
and not self.config.is_wandb_enabled()
and not self.config.is_comet_enabled()
):
string: str = (
"[NOTE] Not running eval iterations since only viewer is enabled.\n"
"Use [yellow]--vis {wandb, tensorboard, viewer+wandb, viewer+tensorboard}[/yellow] to run with eval."
)
CONSOLE.print(f"{string}")
@check_viewer_enabled
def _init_viewer_state(self) -> None:
"""Initializes viewer scene with given train dataset"""
assert self.viewer_state and self.pipeline.datamanager.train_dataset
self.viewer_state.init_scene(
train_dataset=self.pipeline.datamanager.train_dataset,
train_state=self.training_state,
eval_dataset=self.pipeline.datamanager.eval_dataset,
)
@check_viewer_enabled
def _update_viewer_state(self, step: int) -> None:
"""Updates the viewer state by rendering out scene with current pipeline
Returns the time taken to render scene.
Args:
step: current train step
"""
assert self.viewer_state is not None
num_rays_per_batch: int = self.pipeline.datamanager.get_train_rays_per_batch()
try:
self.viewer_state.update_scene(step, num_rays_per_batch)
except RuntimeError:
time.sleep(0.03) # sleep to allow buffer to reset
CONSOLE.log("Viewer failed. Continuing training.")
@check_viewer_enabled
def _train_complete_viewer(self) -> None:
"""Let the viewer know that the training is complete"""
assert self.viewer_state is not None
self.training_state = "completed"
try:
self.viewer_state.training_complete()
except RuntimeError:
time.sleep(0.03) # sleep to allow buffer to reset
CONSOLE.log("Viewer failed. Continuing training.")
CONSOLE.print("Use ctrl+c to quit", justify="center")
while True:
time.sleep(0.01)
@check_viewer_enabled
def _update_viewer_rays_per_sec(self, train_t: TimeWriter, vis_t: TimeWriter, step: int) -> None:
"""Performs update on rays/sec calculation for training
Args:
train_t: timer object carrying time to execute total training iteration
vis_t: timer object carrying time to execute visualization step
step: current step
"""
train_num_rays_per_batch: int = self.pipeline.datamanager.get_train_rays_per_batch()
writer.put_time(
name=EventName.TRAIN_RAYS_PER_SEC,
duration=self.world_size * train_num_rays_per_batch / (train_t.duration - vis_t.duration),
step=step,
avg_over_steps=True,
)
def _load_checkpoint(self) -> None:
"""Helper function to load pipeline and optimizer from prespecified checkpoint"""
load_dir = self.config.load_dir
load_checkpoint = self.config.load_checkpoint
if load_dir is not None:
load_step = self.config.load_step
if load_step is None:
print("Loading latest Nerfstudio checkpoint from load_dir...")
# NOTE: this is specific to the checkpoint name format
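                # Checkpoint files are written by save_checkpoint() as "step-{step:09d}.ckpt",
                # so the latest step is recovered by parsing the filenames in load_dir.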
load_step = sorted(int(x[x.find("-") + 1 : x.find(".")]) for x in os.listdir(load_dir))[-1]
load_path: Path = load_dir / f"step-{load_step:09d}.ckpt"
assert load_path.exists(), f"Checkpoint {load_path} does not exist"
loaded_state = torch.load(load_path, map_location="cpu")
self._start_step = loaded_state["step"] + 1
            # load the checkpoints for pipeline, optimizers, and gradient scaler
self.pipeline.load_pipeline(loaded_state["pipeline"], loaded_state["step"])
self.optimizers.load_optimizers(loaded_state["optimizers"])
if "schedulers" in loaded_state and self.config.load_scheduler:
self.optimizers.load_schedulers(loaded_state["schedulers"])
self.grad_scaler.load_state_dict(loaded_state["scalers"])
CONSOLE.print(f"Done loading Nerfstudio checkpoint from {load_path}")
elif load_checkpoint is not None:
assert load_checkpoint.exists(), f"Checkpoint {load_checkpoint} does not exist"
loaded_state = torch.load(load_checkpoint, map_location="cpu")
self._start_step = loaded_state["step"] + 1
            # load the checkpoints for pipeline, optimizers, and gradient scaler
self.pipeline.load_pipeline(loaded_state["pipeline"], loaded_state["step"])
self.optimizers.load_optimizers(loaded_state["optimizers"])
if "schedulers" in loaded_state and self.config.load_scheduler:
self.optimizers.load_schedulers(loaded_state["schedulers"])
self.grad_scaler.load_state_dict(loaded_state["scalers"])
CONSOLE.print(f"Done loading Nerfstudio checkpoint from {load_checkpoint}")
else:
CONSOLE.print("No Nerfstudio checkpoint to load, so training from scratch.")
@check_main_thread
def save_checkpoint(self, step: int) -> None:
"""Save the model and optimizers
Args:
step: number of steps in training for given checkpoint
"""
# possibly make the checkpoint directory
if not self.checkpoint_dir.exists():
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
# save the checkpoint
ckpt_path: Path = self.checkpoint_dir / f"step-{step:09d}.ckpt"
torch.save(
{
"step": step,
"pipeline": self.pipeline.module.state_dict() # type: ignore
if hasattr(self.pipeline, "module")
else self.pipeline.state_dict(),
"optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()},
"schedulers": {k: v.state_dict() for (k, v) in self.optimizers.schedulers.items()},
"scalers": self.grad_scaler.state_dict(),
},
ckpt_path,
)
# possibly delete old checkpoints
if self.config.save_only_latest_checkpoint:
            # delete all other checkpoints in the checkpoint folder
for f in self.checkpoint_dir.glob("*.ckpt"):
if f != ckpt_path:
f.unlink()
    @profiler.time_function
    def train_iteration(self, step: int) -> TRAIN_ITERATION_OUTPUT:
"""Run one iteration with a batch of inputs. Returns dictionary of model losses.
Args:
step: Current training step.
"""
needs_zero = [
group for group in self.optimizers.parameters.keys() if step % self.gradient_accumulation_steps[group] == 0
]
self.optimizers.zero_grad_some(needs_zero)
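        # torch.autocast does not support the "mps" device type here, so Apple MPS falls back to cpu autocast.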
cpu_or_cuda_str: str = self.device.split(":")[0]
cpu_or_cuda_str = "cpu" if cpu_or_cuda_str == "mps" else cpu_or_cuda_str
with torch.autocast(device_type=cpu_or_cuda_str, enabled=self.mixed_precision):
_, loss_dict, metrics_dict = self.pipeline.get_train_loss_dict(step=step)
loss = functools.reduce(torch.add, loss_dict.values())
self.grad_scaler.scale(loss).backward() # type: ignore
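        # Step only the optimizers whose accumulation window completes at this step; the grad
        # scaler skips an optimizer step if inf/NaN gradients are detected.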
needs_step = [
group
for group in self.optimizers.parameters.keys()
if step % self.gradient_accumulation_steps[group] == self.gradient_accumulation_steps[group] - 1
]
self.optimizers.optimizer_scaler_step_some(self.grad_scaler, needs_step)
if self.config.log_gradients:
total_grad = 0
for tag, value in self.pipeline.model.named_parameters():
assert tag != "Total"
if value.grad is not None:
grad = value.grad.norm()
metrics_dict[f"Gradients/{tag}"] = grad # type: ignore
total_grad += grad
metrics_dict["Gradients/Total"] = cast(torch.Tensor, total_grad) # type: ignore
scale = self.grad_scaler.get_scale()
self.grad_scaler.update()
        # If the grad scaler's scale decreased, the optimizer step was skipped, so only step the schedulers when the scale did not decrease.
if scale <= self.grad_scaler.get_scale():
self.optimizers.scheduler_step_all(step)
        # Return the total loss together with the loss and metrics dicts.
return loss, loss_dict, metrics_dict # type: ignore
@check_eval_enabled
@profiler.time_function
def eval_iteration(self, step: int) -> None:
"""Run one iteration with different batch/image/all image evaluations depending on step size.
Args:
step: Current training step.
"""
# a batch of eval rays
if step_check(step, self.config.steps_per_eval_batch):
_, eval_loss_dict, eval_metrics_dict = self.pipeline.get_eval_loss_dict(step=step)
eval_loss = functools.reduce(torch.add, eval_loss_dict.values())
writer.put_scalar(name="Eval Loss", scalar=eval_loss, step=step)
writer.put_dict(name="Eval Loss Dict", scalar_dict=eval_loss_dict, step=step)
writer.put_dict(name="Eval Metrics Dict", scalar_dict=eval_metrics_dict, step=step)
# one eval image
if step_check(step, self.config.steps_per_eval_image):
with TimeWriter(writer, EventName.TEST_RAYS_PER_SEC, write=False) as test_t:
metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step)
writer.put_time(
name=EventName.TEST_RAYS_PER_SEC,
duration=metrics_dict["num_rays"] / test_t.duration,
step=step,
avg_over_steps=True,
)
writer.put_dict(name="Eval Images Metrics", scalar_dict=metrics_dict, step=step)
group = "Eval Images"
for image_name, image in images_dict.items():
writer.put_image(name=group + "/" + image_name, image=image, step=step)
# all eval images
if step_check(step, self.config.steps_per_eval_all_images):
metrics_dict = self.pipeline.get_average_eval_image_metrics(step=step)
writer.put_dict(name="Eval Images Metrics Dict (all images)", scalar_dict=metrics_dict, step=step)