RoMa: 3D Rotations in PyTorch

1.5.6 · active · verified Thu Apr 16

RoMa (Rotation Manipulation) is a lightweight Python library for handling 3D rotations in PyTorch. It provides differentiable mappings between common 3D rotation representations (rotation vectors, unit quaternions, rotation matrices, and Euler angles), mappings from arbitrary Euclidean inputs (such as unconstrained 3x3 matrices) onto the rotation space, and a set of utilities for rotation-related operations. It is intended as an easy-to-use, efficient toolbox for machine learning and gradient-based optimization. The current version is 1.5.6, and the library is actively maintained by NAVER Corp.

Install

pip install roma

Quickstart

This quickstart shows how to import `roma` and convert between three 3D rotation representations: rotation vectors, unit quaternions, and rotation matrices. It also uses `special_procrustes` to project an arbitrary 3x3 matrix onto the closest valid rotation matrix.

```python
import torch
import roma

# Convert a batch of rotation vectors to unit quaternions, then to rotation matrices
batch_shape = (2, 3)  # 2 batches of 3 rotations each

# Random rotation vectors (axis * angle, in radians), shape batch_shape + (3,)
rotvec = torch.randn(batch_shape + (3,))
print(f"Rotation vector shape: {rotvec.shape}")

# Rotation vectors -> unit quaternions (XYZW convention), shape batch_shape + (4,)
q = roma.rotvec_to_unitquat(rotvec)
print(f"Unit quaternion shape: {q.shape}")

# Unit quaternions -> rotation matrices, shape batch_shape + (3, 3)
R = roma.unitquat_to_rotmat(q)
print(f"Rotation matrix shape: {R.shape}")

# Direct conversion from rotation vectors to rotation matrices
R_direct = roma.rotvec_to_rotmat(rotvec)
print(f"Direct rotation matrix shape: {R_direct.shape}")
assert torch.allclose(R, R_direct, atol=1e-6)

# Special Procrustes orthonormalization:
# projects an arbitrary 3x3 matrix onto the closest rotation matrix (in Frobenius norm)
random_matrix = torch.randn(batch_shape + (3, 3))
R_procrustes = roma.special_procrustes(random_matrix)
print(f"Procrustes orthonormalized matrix shape: {R_procrustes.shape}")
```