# Source code for nerfstudio.cameras.lie_groups

# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helper for Lie group operations. Currently only used for pose optimization.
"""

import torch
from jaxtyping import Float
from torch import Tensor


# We make an exception on snake case conventions because SO3 != so3.
def exp_map_SO3xR3(tangent_vector: Float[Tensor, "b 6"]) -> Float[Tensor, "b 3 4"]:
    """Exponential map of the direct product group `SO(3) x R^3`.

    The rotational part goes through `so(3) -> SO(3)` (Rodrigues' formula). The translational part
    is copied into the last column unchanged, so rotation and translation are decoupled. This can be
    used for learning pose deltas on SE(3), and is generally faster than `exp_map_SE3`.

    Args:
        tangent_vector: Tangent vector; length-3 translations, followed by an `so(3)` tangent vector.

    Returns:
        [R|t] transformation matrices.
    """
    # Rodrigues' formula, adapted from pytorch3d and reduced to the essentials.
    translation = tangent_vector[:, :3]
    axis_angle = tangent_vector[:, 3:]
    wx, wy, wz = axis_angle.unbind(dim=1)

    # The squared angle is clamped to 1e-4 (angle >= 0.01) so the reciprocals below stay finite.
    # For smaller rotations the hat matrix itself is tiny, so R still approaches the identity.
    angle = torch.clamp((axis_angle * axis_angle).sum(1), 1e-4).sqrt()
    inv_angle = 1.0 / angle
    sin_coef = (inv_angle * angle.sin())[:, None, None]
    cos_coef = (inv_angle * inv_angle * (1.0 - angle.cos()))[:, None, None]

    # Hat (cross-product) matrix [w]_x for every batch element, assembled row by row.
    zero = torch.zeros_like(wx)
    hat = torch.stack(
        (
            torch.stack((zero, -wz, wy), dim=1),
            torch.stack((wz, zero, -wx), dim=1),
            torch.stack((-wy, wx, zero), dim=1),
        ),
        dim=1,
    )

    eye = torch.eye(3, dtype=axis_angle.dtype, device=axis_angle.device)[None]
    rotation = sin_coef * hat + cos_coef * torch.bmm(hat, hat) + eye

    # Append the translation as the fourth column -> (b, 3, 4).
    return torch.cat((rotation, translation[:, :, None]), dim=2)
def exp_map_SE3(tangent_vector: Float[Tensor, "b 6"]) -> Float[Tensor, "b 3 4"]:
    """Compute the exponential map `se(3) -> SE(3)`.

    This can be used for learning pose deltas on `SE(3)`.

    Args:
        tangent_vector: A tangent vector from `se(3)`: the first three entries are the
            translational part `v`, the last three the rotational part `w`.

    Returns:
        [R|t] transformation matrices with the same dtype and device as `tangent_vector`.
    """
    tangent_vector_lin = tangent_vector[:, :3].view(-1, 3, 1)
    tangent_vector_ang = tangent_vector[:, 3:].view(-1, 3, 1)

    theta = torch.linalg.norm(tangent_vector_ang, dim=1).unsqueeze(1)  # (b, 1, 1)
    theta2 = theta**2
    theta3 = theta**3

    # Below this rotation angle the closed-form coefficients are replaced by approximations.
    near_zero = theta < 1e-2
    # Substitute 1 for the tiny denominators so the discarded branch of torch.where stays finite
    # (a non-finite value there can still corrupt gradients).
    non_zero = torch.ones(1, dtype=tangent_vector.dtype, device=tangent_vector.device)
    theta_nz = torch.where(near_zero, non_zero, theta)
    theta2_nz = torch.where(near_zero, non_zero, theta2)
    theta3_nz = torch.where(near_zero, non_zero, theta3)

    # Compute the rotation (Rodrigues): R = cos(t) I + (sin(t)/t) [w]x + ((1 - cos(t))/t^2) w w^T.
    sine = theta.sin()
    # Near zero: cos ~ (4 - t^2)/(4 + t^2), sin/t ~ 4/(4 + t^2), (1 - cos)/t^2 ~ 2/(4 + t^2).
    # These rational forms satisfy cos^2 + sin^2 = 1 exactly (in exact arithmetic), so R stays a rotation.
    cosine = torch.where(near_zero, 8 / (4 + theta2) - 1, theta.cos())
    sine_by_theta = torch.where(near_zero, 0.5 * cosine + 0.5, sine / theta_nz)
    one_minus_cosine_by_theta2 = torch.where(near_zero, 0.5 * sine_by_theta, (1 - cosine) / theta2_nz)

    # Allocate directly with the input's dtype/device (as exp_map_SO3xR3 does) instead of creating a
    # default CPU float tensor and copying it over on every call.
    ret = torch.zeros(tangent_vector.shape[0], 3, 4, dtype=tangent_vector.dtype, device=tangent_vector.device)
    ret[:, :3, :3] = one_minus_cosine_by_theta2 * tangent_vector_ang @ tangent_vector_ang.transpose(1, 2)
    ret[:, 0, 0] += cosine.view(-1)
    ret[:, 1, 1] += cosine.view(-1)
    ret[:, 2, 2] += cosine.view(-1)
    # Add the skew-symmetric term (sin(t)/t) [w]x entry by entry.
    temp = sine_by_theta.view(-1, 1) * tangent_vector_ang.view(-1, 3)
    ret[:, 0, 1] -= temp[:, 2]
    ret[:, 1, 0] += temp[:, 2]
    ret[:, 0, 2] += temp[:, 1]
    ret[:, 2, 0] -= temp[:, 1]
    ret[:, 1, 2] -= temp[:, 0]
    ret[:, 2, 1] += temp[:, 0]

    # Compute the translation t = V v, where
    # V = (sin(t)/t) I + ((1 - cos(t))/t^2) [w]x + ((t - sin(t))/t^3) w w^T.
    # Near zero the three coefficients use their Taylor expansions instead.
    sine_by_theta = torch.where(near_zero, 1 - theta2 / 6, sine_by_theta)
    one_minus_cosine_by_theta2 = torch.where(near_zero, 0.5 - theta2 / 24, one_minus_cosine_by_theta2)
    theta_minus_sine_by_theta3_t = torch.where(near_zero, 1.0 / 6 - theta2 / 120, (theta - sine) / theta3_nz)

    ret[:, :, 3:] = sine_by_theta * tangent_vector_lin
    ret[:, :, 3:] += one_minus_cosine_by_theta2 * torch.cross(tangent_vector_ang, tangent_vector_lin, dim=1)
    ret[:, :, 3:] += theta_minus_sine_by_theta3_t * (
        tangent_vector_ang @ (tangent_vector_ang.transpose(1, 2) @ tangent_vector_lin)
    )
    return ret