We have now released v0.3.0! Please use the latest version for the best experience.

Source code for omni.isaac.orbit.utils.math

# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Sub-module containing utilities for various math operations."""

# needed to import for allowing type-hinting: torch.Tensor | np.ndarray
from __future__ import annotations

import numpy as np
import torch
import torch.nn.functional
from typing import Literal

"""
General
"""


@torch.jit.script
def scale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
    """Maps values from the interval [lower, upper] onto [-1, 1].

    The midpoint of the interval is mapped to zero, ``lower`` to -1 and ``upper`` to +1.
    Values outside the interval are not clipped.

    .. note::
        Batched inputs are handled through pytorch broadcasting.

    Args:
        x: Input tensor of shape (N, dims).
        lower: The lower bound of the interval. Shape is (N, dims) or (dims,).
        upper: The upper bound of the interval. Shape is (N, dims) or (dims,).

    Returns:
        The rescaled tensor. Shape is (N, dims).
    """
    # midpoint of the interval (maps to zero)
    center = (lower + upper) * 0.5
    # full width of the interval
    span = upper - lower
    # shift to the midpoint, then scale so that the bounds land on -1 and +1
    centered = x - center
    return 2 * centered / span


@torch.jit.script
def unscale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
    """Maps values from the interval [-1, 1] back onto [lower, upper].

    An input of -1 is mapped to ``lower``, +1 to ``upper`` and 0 to the midpoint of the
    interval. Values outside [-1, 1] are not clipped.

    .. note::
        Batched inputs are handled through pytorch broadcasting.

    Args:
        x: Input tensor of shape (N, dims).
        lower: The lower bound of the interval. Shape is (N, dims) or (dims,).
        upper: The upper bound of the interval. Shape is (N, dims) or (dims,).

    Returns:
        The de-normalized tensor. Shape is (N, dims).
    """
    # midpoint of the interval (image of zero)
    center = (lower + upper) * 0.5
    # full width of the interval
    span = upper - lower
    # scale by the half-width, then shift to the midpoint
    return x * span * 0.5 + center


@torch.jit.script
def saturate(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
    """Clamps a given input tensor element-wise to the bounds (lower, upper).

    The upper bound is applied first and the lower bound second. Batched inputs are handled
    through pytorch broadcasting.

    Args:
        x: Input tensor of shape (N, dims).
        lower: The lower bound. Shape is (N, dims) or (dims,).
        upper: The upper bound. Shape is (N, dims) or (dims,).

    Returns:
        The clamped tensor. Shape is (N, dims).
    """
    # cap from above first ...
    capped = torch.minimum(x, upper)
    # ... then from below
    return torch.maximum(capped, lower)


@torch.jit.script
def normalize(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Scales each vector along the last dimension to unit L2 length.

    Vectors whose length is below ``eps`` are divided by ``eps`` instead of their length,
    so an all-zero vector stays all-zero rather than producing NaNs.

    Args:
        x: Input tensor of shape (N, dims).
        eps: Lower bound applied to the vector length to avoid division by zero.
            Defaults to 1e-9.

    Returns:
        Normalized tensor of shape (N, dims).
    """
    # L2 length of every vector along the last dimension
    length = x.norm(p=2, dim=-1)
    # guard against division by (near) zero
    safe_length = length.clamp(min=eps, max=None)
    return x / safe_length.unsqueeze(-1)


@torch.jit.script
def wrap_to_pi(angles: torch.Tensor) -> torch.Tensor:
    """Wraps input angles (in radians) to the range [-pi, pi].

    Args:
        angles: Input angles of any shape.

    Returns:
        Angles in the range [-pi, pi].
    """
    angles = angles.clone()
    angles %= 2 * torch.pi
    angles -= 2 * torch.pi * (angles > torch.pi)
    return angles


@torch.jit.script
def copysign(mag: float, other: torch.Tensor) -> torch.Tensor:
    """Create a new floating-point tensor with the magnitude of input and the sign of other, element-wise.

    Note:
        The implementation follows from `torch.copysign`. The function allows a scalar magnitude.
        Unlike `torch.copysign`, the sign is taken from :func:`torch.sign`, so elements of
        ``other`` that are exactly zero produce zero in the output.

    Args:
        mag: The magnitude scalar.
        other: The tensor containing values whose signs are applied to magnitude. Can be of any shape.

    Returns:
        The output tensor. It has the same shape as ``other``.
    """
    # build the magnitude with the full shape of `other` (not just its first dimension) so that
    # inputs of any rank are supported
    magnitude = torch.full_like(other, mag, dtype=torch.float)
    return torch.abs(magnitude) * torch.sign(other)


"""
Rotation
"""


@torch.jit.script
def matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as quaternions to rotation matrices.

    The quaternions do not have to be normalized: every entry is scaled by the inverse of
    the squared quaternion norm.

    Args:
        quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4).

    Returns:
        Rotation matrices. The shape is (..., 3, 3).

    Reference:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70
    """
    qw, qx, qy, qz = torch.unbind(quaternions, -1)
    # 2 / |q|^2: accounts for quaternions that are not of unit length
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    # first row
    m00 = 1 - scale * (qy * qy + qz * qz)
    m01 = scale * (qx * qy - qz * qw)
    m02 = scale * (qx * qz + qy * qw)
    # second row
    m10 = scale * (qx * qy + qz * qw)
    m11 = 1 - scale * (qx * qx + qz * qz)
    m12 = scale * (qy * qz - qx * qw)
    # third row
    m20 = scale * (qx * qz - qy * qw)
    m21 = scale * (qy * qz + qx * qw)
    m22 = 1 - scale * (qx * qx + qy * qy)

    # entries are stacked in row-major order and folded into 3x3 matrices
    flat = torch.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))


def convert_quat(quat: torch.Tensor | np.ndarray, to: Literal["xyzw", "wxyz"] = "xyzw") -> torch.Tensor | np.ndarray:
    """Converts quaternion from one convention to another.

    The convention to convert TO is specified as an optional argument. If to == 'xyzw',
    then the input is in 'wxyz' format, and vice-versa.

    Inputs that are neither numpy arrays nor torch tensors (for instance lists or tuples)
    are converted to a double-precision torch tensor before the conversion.

    Args:
        quat: The quaternion of shape (..., 4).
        to: Convention to convert quaternion to. Defaults to "xyzw".

    Returns:
        The converted quaternion in specified convention. A numpy array is returned for
        numpy input, a torch tensor otherwise.

    Raises:
        ValueError: Invalid input argument `to`, i.e. not "xyzw" or "wxyz".
        ValueError: Invalid shape of input `quat`, i.e. not (..., 4,).
    """
    # convert array-likes to torch first: they have no `.shape` attribute to validate
    if not isinstance(quat, (np.ndarray, torch.Tensor)):
        quat = torch.tensor(quat, dtype=float)
    # check input is correct (a 0-dim input has no last dimension to inspect)
    if quat.ndim == 0 or quat.shape[-1] != 4:
        msg = f"Expected input quaternion shape mismatch: {quat.shape} != (..., 4)."
        raise ValueError(msg)
    if to not in ["xyzw", "wxyz"]:
        msg = f"Expected input argument `to` to be 'xyzw' or 'wxyz'. Received: {to}."
        raise ValueError(msg)
    # wxyz -> xyzw moves the scalar part to the back (shift of -1),
    # xyzw -> wxyz moves it to the front (shift of +1)
    shift = -1 if to == "xyzw" else 1
    # numpy backend is supported since some classes use numpy
    if isinstance(quat, np.ndarray):
        return np.roll(quat, shift, axis=-1)
    return quat.roll(shift, dims=-1)


@torch.jit.script
def quat_conjugate(q: torch.Tensor) -> torch.Tensor:
    """Computes the conjugate of a quaternion.

    Args:
        q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).

    Returns:
        The conjugate quaternion in (w, x, y, z). Shape is (..., 4).
    """
    shape = q.shape
    # flatten the batch dimensions so that the components can be sliced column-wise
    q = q.reshape(-1, 4)
    # keep the real part, negate the imaginary part
    return torch.cat((q[:, 0:1], -q[:, 1:]), dim=-1).view(shape)


@torch.jit.script
def quat_inv(q: torch.Tensor) -> torch.Tensor:
    """Compute the inverse of a quaternion.

    The result is the conjugate rescaled to unit length, i.e. the output is always a unit
    quaternion.

    Args:
        q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).

    Returns:
        The inverse quaternion in (w, x, y, z). Shape is (N, 4).
    """
    return normalize(quat_conjugate(q))


@torch.jit.script
def quat_from_euler_xyz(roll: torch.Tensor, pitch: torch.Tensor, yaw: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as Euler angles in radians to Quaternions.

    Note:
        The euler angles are assumed in XYZ convention.

    Args:
        roll: Rotation around x-axis (in radians). Shape is (N,).
        pitch: Rotation around y-axis (in radians). Shape is (N,).
        yaw: Rotation around z-axis (in radians). Shape is (N,).

    Returns:
        The quaternion in (w, x, y, z). Shape is (N, 4).
    """
    # half-angle sines and cosines
    cy = torch.cos(yaw * 0.5)
    sy = torch.sin(yaw * 0.5)
    cr = torch.cos(roll * 0.5)
    sr = torch.sin(roll * 0.5)
    cp = torch.cos(pitch * 0.5)
    sp = torch.sin(pitch * 0.5)
    # compute quaternion
    qw = cy * cr * cp + sy * sr * sp
    qx = cy * sr * cp - sy * cr * sp
    qy = cy * cr * sp + sy * sr * cp
    qz = sy * cr * cp - cy * sr * sp

    return torch.stack([qw, qx, qy, qz], dim=-1)


@torch.jit.script
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """Returns torch.sqrt(torch.max(0, x)) but with a zero sub-gradient where x is 0.

    Reference:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L91-L99
    """
    ret = torch.zeros_like(x)
    # the square root is only evaluated on strictly positive entries; all others stay zero
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


@torch.jit.script
def quat_from_matrix(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: The rotation matrices. Shape is (..., 3, 3).

    Returns:
        The quaternion in (w, x, y, z). Shape is (..., 4).

    Raises:
        ValueError: If the last two dimensions of ``matrix`` are not (3, 3).

    Reference:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L102-L161
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    # twice the absolute value of each quaternion component (w, x, y, z)
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    return quat_candidates[torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
        batch_dim + (4,)
    )


def _axis_angle_rotation(axis: Literal["X", "Y", "Z"], angle: torch.Tensor) -> torch.Tensor:
    """Return the rotation matrices for one of the rotations about an axis of which Euler angles describe,
    for each value of the angle given.

    Args:
        axis: Axis label "X" or "Y" or "Z".
        angle: Euler angles in radians of any shape.

    Returns:
        Rotation matrices. Shape is (..., 3, 3).

    Raises:
        ValueError: If ``axis`` is not one of "X", "Y" or "Z".

    Reference:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L164-L191
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    # matrix entries in row-major order
    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))


def matrix_from_euler(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """Convert rotations given as Euler angles in radians to rotation matrices.

    One elementary rotation matrix is built per angle and the three matrices are multiplied
    in the order given by ``convention``.

    Args:
        euler_angles: Euler angles in radians. Shape is (..., 3).
        convention: Convention string of three uppercase letters from {"X", "Y", and "Z"}.
            For example, "XYZ" means that the rotations should be applied first about x,
            then y, then z.

    Returns:
        Rotation matrices. Shape is (..., 3, 3).

    Raises:
        ValueError: If ``euler_angles`` is a scalar or its last dimension is not 3.
        ValueError: If ``convention`` does not have exactly three letters.
        ValueError: If the middle letter of ``convention`` repeats one of its neighbours.
        ValueError: If ``convention`` contains a letter other than "X", "Y" or "Z".

    Reference:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L194-L220
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    # two consecutive rotations about the same axis are rejected
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    # one elementary rotation per (axis letter, angle) pair
    matrices = [_axis_angle_rotation(c, e) for c, e in zip(convention, torch.unbind(euler_angles, -1))]
    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])


@torch.jit.script
def euler_xyz_from_quat(quat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert rotations given as quaternions to Euler angles in radians.

    Note:
        The euler angles are assumed in XYZ convention. Each returned angle is reduced
        modulo :math:`2 \\pi`.

    Args:
        quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4).

    Returns:
        A tuple containing roll-pitch-yaw. Each element is a tensor of shape (N,).

    Reference:
        https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
    """
    q_w, q_x, q_y, q_z = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]
    # roll (x-axis rotation)
    sin_roll = 2.0 * (q_w * q_x + q_y * q_z)
    cos_roll = 1 - 2 * (q_x * q_x + q_y * q_y)
    roll = torch.atan2(sin_roll, cos_roll)

    # pitch (y-axis rotation)
    sin_pitch = 2.0 * (q_w * q_y - q_z * q_x)
    # asin is only used inside its domain; otherwise the pitch is set to +-pi/2
    pitch = torch.where(torch.abs(sin_pitch) >= 1, copysign(torch.pi / 2.0, sin_pitch), torch.asin(sin_pitch))

    # yaw (z-axis rotation)
    sin_yaw = 2.0 * (q_w * q_z + q_x * q_y)
    cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z)
    yaw = torch.atan2(sin_yaw, cos_yaw)

    return roll % (2 * torch.pi), pitch % (2 * torch.pi), yaw % (2 * torch.pi)  # TODO: why not wrap_to_pi here ?


@torch.jit.script
def quat_unique(q: torch.Tensor) -> torch.Tensor:
    """Convert a unit quaternion to a standard form where the real part is non-negative.

    Quaternion representations have a singularity since ``q`` and ``-q`` represent the same
    rotation. This function ensures the real part of the quaternion is non-negative.

    Args:
        q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).

    Returns:
        Standardized quaternions. Shape is (..., 4).
    """
    return torch.where(q[..., 0:1] < 0, -q, q)


@torch.jit.script
def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Multiply two quaternions together.

    Args:
        q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
        q2: The second quaternion in (w, x, y, z). Shape is (..., 4).

    Returns:
        The product of the two quaternions in (w, x, y, z). Shape is (..., 4).

    Raises:
        ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
    """
    # check input is correct
    if q1.shape != q2.shape:
        msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
        raise ValueError(msg)
    # reshape to (N, 4) for multiplication
    shape = q1.shape
    q1 = q1.reshape(-1, 4)
    q2 = q2.reshape(-1, 4)
    # extract components from quaternions
    w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
    w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
    # perform multiplication (shared partial products are computed once)
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)

    return torch.stack([w, x, y, z], dim=-1).view(shape)


@torch.jit.script
def quat_box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """The box-minus operator (quaternion difference) between two quaternions.

    Args:
        q1: The first quaternion in (w, x, y, z). Shape is (N, 4).
        q2: The second quaternion in (w, x, y, z). Shape is (N, 4).

    Returns:
        The difference between the two quaternions. Shape is (N, 3).

    Reference:
        https://docs.leggedrobotics.com/kindr/cheatsheet_latest.pdf
    """
    quat_diff = quat_mul(q1, quat_conjugate(q2))  # q1 * q2^-1
    re = quat_diff[:, 0]  # real part, q = [w, x, y, z] = [re, im]
    im = quat_diff[:, 1:]  # imaginary part
    norm_im = torch.norm(im, dim=1)
    # below the threshold the ratio atan2(|im|, re) / |im| is replaced by sign(re)
    scale = 2.0 * torch.where(norm_im > 1.0e-7, torch.atan2(norm_im, re) / norm_im, torch.sign(re))
    return scale.unsqueeze(-1) * im


@torch.jit.script
def yaw_quat(quat: torch.Tensor) -> torch.Tensor:
    """Extract the yaw component of a quaternion.

    Args:
        quat: The orientation in (w, x, y, z). Shape is (..., 4)

    Returns:
        A quaternion with only yaw component. Shape is (..., 4).
    """
    shape = quat.shape
    # work on a copy: the buffer is overwritten in-place below
    quat_yaw = quat.clone().view(-1, 4)
    qw = quat_yaw[:, 0]
    qx = quat_yaw[:, 1]
    qy = quat_yaw[:, 2]
    qz = quat_yaw[:, 3]
    # the yaw angle must be computed before the buffer is cleared (the components are views)
    yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))
    # rebuild the quaternion as a pure rotation about the z-axis
    quat_yaw[:] = 0.0
    quat_yaw[:, 3] = torch.sin(yaw / 2)
    quat_yaw[:, 0] = torch.cos(yaw / 2)
    quat_yaw = normalize(quat_yaw)
    return quat_yaw.view(shape)


@torch.jit.script
def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Apply a quaternion rotation to a vector.

    Args:
        quat: The quaternion in (w, x, y, z). Shape is (..., 4).
        vec: The vector in (x, y, z). Shape is (..., 3).

    Returns:
        The rotated vector in (x, y, z). Shape is (..., 3).
    """
    # store shape
    shape = vec.shape
    # reshape to (N, 3) for multiplication
    quat = quat.reshape(-1, 4)
    vec = vec.reshape(-1, 3)
    # extract components from quaternions
    xyz = quat[:, 1:]
    t = xyz.cross(vec, dim=-1) * 2
    return (vec + quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)


@torch.jit.script
def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Rotate a vector only around the yaw-direction.

    Args:
        quat: The orientation in (w, x, y, z). Shape is (N, 4).
        vec: The vector in (x, y, z). Shape is (N, 3).

    Returns:
        The rotated vector in (x, y, z). Shape is (N, 3).
    """
    quat_yaw = yaw_quat(quat)
    return quat_apply(quat_yaw, vec)


@torch.jit.script
def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Rotate a vector by a quaternion.

    Args:
        q: The quaternion in (w, x, y, z). Shape is (N, 4).
        v: The vector in (x, y, z). Shape is (N, 3).

    Returns:
        The rotated vector in (x, y, z). Shape is (N, 3).
    """
    shape = q.shape
    q_w = q[:, 0]
    q_vec = q[:, 1:]
    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    # batched dot product between the imaginary part and the vector
    c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
    return a + b + c


@torch.jit.script
def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Rotate a vector by the inverse of a quaternion.

    Args:
        q: The quaternion in (w, x, y, z). Shape is (N, 4).
        v: The vector in (x, y, z). Shape is (N, 3).

    Returns:
        The rotated vector in (x, y, z). Shape is (N, 3).
    """
    shape = q.shape
    q_w = q[:, 0]
    q_vec = q[:, 1:]
    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    # batched dot product between the imaginary part and the vector
    c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
    # same terms as the forward rotation, with the sign of the cross-product term flipped
    return a - b + c


@torch.jit.script
def quat_from_angle_axis(angle: torch.Tensor, axis: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as angle-axis to quaternions.

    Args:
        angle: The angle turned anti-clockwise in radians around the vector's direction. Shape is (N,).
        axis: The axis of rotation. Shape is (N, 3).

    Returns:
        The quaternion in (w, x, y, z). Shape is (N, 4).
    """
    theta = (angle / 2).unsqueeze(-1)
    xyz = normalize(axis) * theta.sin()
    w = theta.cos()
    return normalize(torch.cat([w, xyz], dim=-1))


@torch.jit.script
def axis_angle_from_quat(quat: torch.Tensor, eps: float = 1.0e-6) -> torch.Tensor:
    """Convert rotations given as quaternions to axis/angle.

    Args:
        quat: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
        eps: The tolerance for Taylor approximation. Defaults to 1.0e-6.

    Returns:
        Rotations given as a vector in axis angle form. Shape is (..., 3).
        The vector's magnitude is the angle turned anti-clockwise in radians around the vector's direction.

    Reference:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L526-L554
    """
    # Modified to take in quat as [q_w, q_x, q_y, q_z]
    # Quaternion is [q_w, q_x, q_y, q_z] = [cos(theta/2), n_x * sin(theta/2), n_y * sin(theta/2), n_z * sin(theta/2)]
    # Axis-angle is [a_x, a_y, a_z] = [theta * n_x, theta * n_y, theta * n_z]
    # Thus, axis-angle is [q_x, q_y, q_z] / (sin(theta/2) / theta)
    # When theta = 0, (sin(theta/2) / theta) is undefined
    # However, as theta --> 0, we can use the Taylor approximation 1/2 - theta^2 / 48
    # flip the sign of quaternions with a negative real part
    quat = quat * (1.0 - 2.0 * (quat[..., 0:1] < 0.0))
    mag = torch.linalg.norm(quat[..., 1:], dim=-1)
    half_angle = torch.atan2(mag, quat[..., 0])
    angle = 2.0 * half_angle
    # check whether to apply Taylor approximation
    sin_half_angles_over_angles = torch.where(
        angle.abs() > eps, torch.sin(half_angle) / angle, 0.5 - angle * angle / 48
    )
    return quat[..., 1:4] / sin_half_angles_over_angles.unsqueeze(-1)


@torch.jit.script
def quat_error_magnitude(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Computes the rotation difference between two quaternions.

    Args:
        q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
        q2: The second quaternion in (w, x, y, z). Shape is (..., 4).

    Returns:
        Angular error between input quaternions in radians.
    """
    quat_diff = quat_mul(q1, quat_conjugate(q2))
    return torch.norm(axis_angle_from_quat(quat_diff), dim=-1)


@torch.jit.script
def skew_symmetric_matrix(vec: torch.Tensor) -> torch.Tensor:
    """Computes the skew-symmetric matrix of a vector.

    Args:
        vec: The input vector. Shape is (3,) or (N, 3).

    Returns:
        The skew-symmetric matrix. Shape is (1, 3, 3) or (N, 3, 3).

    Raises:
        ValueError: If input tensor is not of shape (..., 3).
    """
    # check input is correct
    if vec.shape[-1] != 3:
        raise ValueError(f"Expected input vector shape mismatch: {vec.shape} != (..., 3).")
    # add a batch dimension to a single vector
    if vec.ndim == 1:
        vec = vec.unsqueeze(0)
    # create a skew-symmetric matrix
    skew_sym_mat = torch.zeros(vec.shape[0], 3, 3, device=vec.device, dtype=vec.dtype)
    skew_sym_mat[:, 0, 1] = -vec[:, 2]
    skew_sym_mat[:, 0, 2] = vec[:, 1]
    skew_sym_mat[:, 1, 2] = -vec[:, 0]
    skew_sym_mat[:, 1, 0] = vec[:, 2]
    skew_sym_mat[:, 2, 0] = -vec[:, 1]
    skew_sym_mat[:, 2, 1] = vec[:, 0]

    return skew_sym_mat


"""
Transformations
"""


def is_identity_pose(pos: torch.Tensor, rot: torch.Tensor) -> bool:
    """Checks if input poses are identity transforms.

    The position is compared element-wise against zero and the quaternion element-wise
    against (1, 0, 0, 0), both with the default tolerances of :func:`torch.allclose`.
    No rotation error is computed: the quaternion (-1, 0, 0, 0) is therefore reported
    as a non-identity orientation.

    Args:
        pos: The cartesian position. Shape is (N, 3).
        rot: The quaternion in (w, x, y, z). Shape is (N, 4).

    Returns:
        True if all the input poses result in identity transform. Otherwise, False.
    """
    # create identity transformations
    pos_identity = torch.zeros_like(pos)
    rot_identity = torch.zeros_like(rot)
    rot_identity[..., 0] = 1
    # compare input to identity
    return torch.allclose(pos, pos_identity) and torch.allclose(rot, rot_identity)


# @torch.jit.script
def combine_frame_transforms(
    t01: torch.Tensor, q01: torch.Tensor, t12: torch.Tensor | None = None, q12: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Combine transformations between two reference frames into a stationary frame.

    It performs the following transformation operation: :math:`T_{02} = T_{01} \times T_{12}`,
    where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B.

    Args:
        t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).
        q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
        t12: Position of frame 2 w.r.t. frame 1. Shape is (N, 3).
            Defaults to None, in which case the position is assumed to be zero.
        q12: Quaternion orientation of frame 2 w.r.t. frame 1 in (w, x, y, z). Shape is (N, 4).
            Defaults to None, in which case the orientation is assumed to be identity.

    Returns:
        A tuple containing the position and orientation of frame 2 w.r.t. frame 0.
        Shape of the tensors are (N, 3) and (N, 4) respectively. When ``t12`` (``q12``) is
        None, the input ``t01`` (``q01``) is returned as-is, not a copy.
    """
    # compute orientation
    if q12 is not None:
        q02 = quat_mul(q01, q12)
    else:
        q02 = q01
    # compute translation: the offset t12 is expressed in frame 1, so it is rotated into frame 0
    if t12 is not None:
        t02 = t01 + quat_apply(q01, t12)
    else:
        t02 = t01

    return t02, q02


# @torch.jit.script
def subtract_frame_transforms(
    t01: torch.Tensor, q01: torch.Tensor, t02: torch.Tensor | None = None, q02: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Subtract transformations between two reference frames into a stationary frame.

    It performs the following transformation operation: :math:`T_{12} = T_{01}^{-1} \times T_{02}`,
    where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B.

    Args:
        t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).
        q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
        t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3).
            Defaults to None, in which case the position is assumed to be zero.
        q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
            Defaults to None, in which case the orientation is assumed to be identity.

    Returns:
        A tuple containing the position and orientation of frame 2 w.r.t. frame 1.
        Shape of the tensors are (N, 3) and (N, 4) respectively.
    """
    # compute orientation
    q10 = quat_inv(q01)
    if q02 is not None:
        q12 = quat_mul(q10, q02)
    else:
        q12 = q10
    # compute translation: the offset between the frame origins is rotated into frame 1
    if t02 is not None:
        t12 = quat_apply(q10, t02 - t01)
    else:
        t12 = quat_apply(q10, -t01)
    return t12, q12


# @torch.jit.script
def compute_pose_error(
    t01: torch.Tensor,
    q01: torch.Tensor,
    t02: torch.Tensor,
    q02: torch.Tensor,
    rot_error_type: Literal["quat", "axis_angle"] = "axis_angle",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the position and orientation error between source and target frames.

    Args:
        t01: Position of source frame. Shape is (N, 3).
        q01: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4).
        t02: Position of target frame. Shape is (N, 3).
        q02: Quaternion orientation of target frame in (w, x, y, z). Shape is (N, 4).
        rot_error_type: The rotation error type to return: "quat", "axis_angle".
            Defaults to "axis_angle".

    Returns:
        A tuple containing position and orientation error. Shape of position error is (N, 3).
        Shape of orientation error depends on the value of :attr:`rot_error_type`:

        - If :attr:`rot_error_type` is "quat", the orientation error is returned
          as a quaternion. Shape is (N, 4).
        - If :attr:`rot_error_type` is "axis_angle", the orientation error is
          returned as an axis-angle vector. Shape is (N, 3).

    Raises:
        ValueError: Invalid rotation error type.
    """
    # reject an unsupported error type before doing any tensor work
    if rot_error_type not in ("quat", "axis_angle"):
        raise ValueError(f"Unsupported orientation error type: {rot_error_type}. Valid: 'quat', 'axis_angle'.")

    # Compute quaternion error (i.e., difference quaternion)
    # Reference: https://personal.utdallas.edu/~sxb027100/dock/quaternion.html
    # q_current_norm = q_current * q_current_conj
    source_quat_norm = quat_mul(q01, quat_conjugate(q01))[:, 0]
    # q_current_inv = q_current_conj / q_current_norm
    source_quat_inv = quat_conjugate(q01) / source_quat_norm.unsqueeze(-1)
    # q_error = q_target * q_current_inv
    quat_error = quat_mul(q02, source_quat_inv)

    # Compute position error
    pos_error = t02 - t01

    # return error based on specified type
    if rot_error_type == "quat":
        return pos_error, quat_error
    # Convert to axis-angle error
    axis_angle_error = axis_angle_from_quat(quat_error)
    return pos_error, axis_angle_error
@torch.jit.script
def apply_delta_pose(
    source_pos: torch.Tensor, source_rot: torch.Tensor, delta_pose: torch.Tensor, eps: float = 1.0e-6
) -> tuple[torch.Tensor, torch.Tensor]:
    """Applies delta pose transformation on source pose.

    The first three elements of `delta_pose` are interpreted as cartesian position displacement.
    The remaining three elements of `delta_pose` are interpreted as orientation displacement
    in the angle-axis format.

    Args:
        source_pos: Position of source frame. Shape is (N, 3).
        source_rot: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4).
        delta_pose: Position and orientation displacements. Shape is (N, 6).
        eps: The tolerance to consider orientation displacement as zero. Defaults to 1.0e-6.

    Returns:
        A tuple containing the displaced position and orientation frames.
        Shape of the tensors are (N, 3) and (N, 4) respectively.
    """
    # number of poses given
    num_poses = source_pos.shape[0]
    device = source_pos.device

    # interpret delta_pose[:, 0:3] as target position displacements
    target_pos = source_pos + delta_pose[:, 0:3]
    # interpret delta_pose[:, 3:6] as target rotation displacements
    rot_actions = delta_pose[:, 3:6]
    angle = torch.linalg.vector_norm(rot_actions, dim=1)
    # Clamp the divisor so a zero rotation does not produce 0/0 = NaN. Rows with
    # angle <= eps are replaced by the identity below, so the selected values are
    # unchanged; the clamp only keeps NaN out of the discarded branch.
    axis = rot_actions / angle.clamp(min=eps).unsqueeze(-1)
    # change from axis-angle to quat convention
    identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device).repeat(num_poses, 1)
    rot_delta_quat = torch.where(
        angle.unsqueeze(-1).repeat(1, 4) > eps, quat_from_angle_axis(angle, axis), identity_quat
    )
    # TODO: Check if this is the correct order for this multiplication.
    target_rot = quat_mul(rot_delta_quat, source_rot)

    return target_pos, target_rot


# @torch.jit.script
def transform_points(
    points: torch.Tensor, pos: torch.Tensor | None = None, quat: torch.Tensor | None = None
) -> torch.Tensor:
    r"""Transform input points in a given frame to a target frame.

    This function transforms points from a source frame to a target frame. The transformation is defined by
    the position :math:`t` and orientation :math:`R` of the target frame in the source frame.

    .. math::
        p_{target} = R_{target} \times p_{source} + t_{target}

    If the input `points` is a batch of points, the inputs `pos` and `quat` must be either a batch of
    positions and quaternions or a single position and quaternion. If the inputs `pos` and `quat` are
    a single position and quaternion, the same transformation is applied to all points in the batch.

    If `points` is a single set of points of shape (P, 3) and `pos` and/or `quat` are batched with
    N > 1 entries, the points are broadcast over the batch and the result has shape (N, P, 3).

    If either the inputs :attr:`pos` and :attr:`quat` are None, the corresponding transformation is not applied.

    Args:
        points: Points to transform. Shape is (N, P, 3) or (P, 3).
        pos: Position of the target frame. Shape is (N, 3) or (3,).
            Defaults to None, in which case the position is assumed to be zero.
        quat: Quaternion orientation of the target frame in (w, x, y, z). Shape is (N, 4) or (4,).
            Defaults to None, in which case the orientation is assumed to be identity.

    Returns:
        Transformed points in the target frame. Shape is (N, P, 3) or (P, 3).
        The input tensor is never modified.

    Raises:
        ValueError: If the inputs `points` is not of shape (N, P, 3) or (P, 3).
        ValueError: If the inputs `pos` is not of shape (N, 3) or (3,).
        ValueError: If the inputs `quat` is not of shape (N, 4) or (4,).
    """
    # copy so the result never aliases the caller's tensor, even when no transform is applied
    points_batch = points.clone()
    # check if inputs are batched
    is_batched = points_batch.dim() == 3
    # -- check inputs
    if points_batch.dim() == 2:
        points_batch = points_batch[None]  # (P, 3) -> (1, P, 3)
    if points_batch.dim() != 3:
        raise ValueError(f"Expected points to have dim = 2 or dim = 3: got shape {points.shape}")
    if not (pos is None or pos.dim() == 1 or pos.dim() == 2):
        raise ValueError(f"Expected pos to have dim = 1 or dim = 2: got shape {pos.shape}")
    if not (quat is None or quat.dim() == 1 or quat.dim() == 2):
        raise ValueError(f"Expected quat to have dim = 1 or dim = 2: got shape {quat.shape}")
    # -- rotation
    if quat is not None:
        # convert to batched rotation matrix
        rot_mat = matrix_from_quat(quat)
        if rot_mat.dim() == 2:
            rot_mat = rot_mat[None]  # (3, 3) -> (1, 3, 3)
        # (N, P, 3) -> (N, 3, P), rotate, then back to (N, P, 3); matmul broadcasts the batch dim
        points_batch = torch.matmul(rot_mat, points_batch.transpose(1, 2)).transpose(1, 2)
    # -- translation
    if pos is not None:
        # convert to batched translation vector
        if pos.dim() == 1:
            pos = pos[None, None, :]  # (3,) -> (1, 1, 3)
        else:
            pos = pos[:, None, :]  # (N, 3) -> (N, 1, 3)
        # Out-of-place add: an in-place `+=` cannot grow the batch dimension, so it
        # failed for a single set of points combined with a batched position.
        points_batch = points_batch + pos
    # -- return points in same shape as input
    if not is_batched:
        points_batch = points_batch.squeeze(0)  # (1, P, 3) -> (P, 3)

    return points_batch
"""
Projection operations.
"""


@torch.jit.script
def unproject_depth(depth: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
    r"""Unproject depth image into a pointcloud.

    This function converts depth images into points given the calibration matrix of the camera.

    .. math::
        p_{3D} = K^{-1} \times [u, v, 1]^T \times d

    where :math:`p_{3D}` is the 3D point, :math:`d` is the depth value, :math:`u` and :math:`v` are
    the pixel coordinates and :math:`K` is the intrinsic matrix.

    If `depth` is a batch of depth images and `intrinsics` is a single intrinsic matrix, the same
    calibration matrix is applied to all depth images in the batch.

    The function assumes that the width and height are both greater than 1. This makes the function
    deal with many possible shapes of depth images and intrinsics matrices.

    The returned points are ordered with the pixel column :math:`u` as the slow index and the
    pixel row :math:`v` as the fast index.

    Args:
        depth: The depth measurement. Shape is (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
        intrinsics: A tensor providing camera's calibration matrix. Shape is (3, 3) or (N, 3, 3).

    Returns:
        The 3D coordinates of points. Shape is (P, 3) or (N, P, 3).

    Raises:
        ValueError: When depth is not of shape (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
        ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3).
    """
    # work on copies: the depth copy is transposed in place further below
    depth_batch = depth.clone()
    intrinsics_batch = intrinsics.clone()
    # check if inputs are batched
    is_batched = depth_batch.dim() == 4 or (depth_batch.dim() == 3 and depth_batch.shape[-1] != 1)
    # make sure inputs are batched
    if depth_batch.dim() == 3 and depth_batch.shape[-1] == 1:
        depth_batch = depth_batch.squeeze(dim=2)  # (H, W, 1) -> (H, W)
    if depth_batch.dim() == 2:
        depth_batch = depth_batch[None]  # (H, W) -> (1, H, W)
    if depth_batch.dim() == 4 and depth_batch.shape[-1] == 1:
        depth_batch = depth_batch.squeeze(dim=3)  # (N, H, W, 1) -> (N, H, W)
    if intrinsics_batch.dim() == 2:
        intrinsics_batch = intrinsics_batch[None]  # (3, 3) -> (1, 3, 3)
    # check shape of inputs
    if depth_batch.dim() != 3:
        raise ValueError(f"Expected depth images to have dim = 2 or 3 or 4: got shape {depth.shape}")
    if intrinsics_batch.dim() != 3:
        raise ValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}")

    # get image height and width
    im_height, im_width = depth_batch.shape[1:]
    # create image points in homogeneous coordinates (3, W x H)
    # note: "ij" indexing over (u, v) gives grids of shape (W, H), i.e. u is the slow index
    indices_u = torch.arange(im_width, device=depth.device, dtype=depth.dtype)
    indices_v = torch.arange(im_height, device=depth.device, dtype=depth.dtype)
    img_indices = torch.stack(torch.meshgrid([indices_u, indices_v], indexing="ij"), dim=0).reshape(2, -1)
    # append a row of ones below (u, v) to get homogeneous pixel coordinates
    pixels = torch.nn.functional.pad(img_indices, (0, 0, 0, 1), mode="constant", value=1.0)
    pixels = pixels.unsqueeze(0)  # (3, W x H) -> (1, 3, W x H)

    # unproject points into 3D space
    points = torch.matmul(torch.inverse(intrinsics_batch), pixels)  # (N, 3, W x H)
    points = points / points[:, -1, :].unsqueeze(1)  # normalize by last coordinate
    # flatten depth image (N, H, W) -> (N, W, H) -> (N, W x H) so it lines up with the pixel ordering
    depth_batch = depth_batch.transpose_(1, 2).reshape(depth_batch.shape[0], -1).unsqueeze(2)
    depth_batch = depth_batch.expand(-1, -1, 3)
    # scale points by depth
    points_xyz = points.transpose_(1, 2) * depth_batch  # (N, W x H, 3)

    # return points in same shape as input
    if not is_batched:
        points_xyz = points_xyz.squeeze(0)

    return points_xyz


@torch.jit.script
def project_points(points: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
    r"""Projects 3D points into 2D image plane.

    This projects 3D points into a 2D image plane. The transformation is defined by the intrinsic
    matrix of the camera.

    .. math::

        \begin{align}
            p &= K \times p_{3D} \\
            p_{2D} &= \begin{pmatrix} u \\ v \\ d \end{pmatrix}
                    = \begin{pmatrix} p[0] / p[2] \\ p[1] / p[2] \\ Z \end{pmatrix}
        \end{align}

    where :math:`p_{2D} = (u, v, d)` is the projected 3D point, :math:`p_{3D} = (X, Y, Z)` is the
    3D point and :math:`K \in \mathbb{R}^{3 \times 3}` is the intrinsic matrix.

    If `points` is a batch of 3D points and `intrinsics` is a single intrinsic matrix, the same
    calibration matrix is applied to all points in the batch.

    Args:
        points: The 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
        intrinsics: Camera's calibration matrix. Shape is (3, 3) or (N, 3, 3).

    Returns:
        Projected 3D coordinates of points. Shape is (P, 3) or (N, P, 3).

    Raises:
        ValueError: When points is not of shape (P, 3) or (N, P, 3).
        ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3).
    """
    points_batch = points.clone()
    intrinsics_batch = intrinsics.clone()
    # check if inputs are batched
    # (batched input is 3-dimensional; testing for dim == 2 here inverted the flag and made
    # the function return (1, P, 3) for (P, 3) input and (P, 3) for (1, P, 3) input)
    is_batched = points_batch.dim() == 3
    # make sure inputs are batched
    if points_batch.dim() == 2:
        points_batch = points_batch[None]  # (P, 3) -> (1, P, 3)
    if intrinsics_batch.dim() == 2:
        intrinsics_batch = intrinsics_batch[None]  # (3, 3) -> (1, 3, 3)
    # check shape of inputs
    if points_batch.dim() != 3:
        raise ValueError(f"Expected points to have dim = 3: got shape {points.shape}.")
    if intrinsics_batch.dim() != 3:
        raise ValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}.")
    # project points into 2D image plane
    points_2d = torch.matmul(intrinsics_batch, points_batch.transpose(1, 2))
    points_2d = points_2d / points_2d[:, -1, :].unsqueeze(1)  # normalize by last coordinate
    points_2d = points_2d.transpose_(1, 2)  # (N, 3, P) -> (N, P, 3)
    # replace last coordinate with depth
    points_2d[:, :, -1] = points_batch[:, :, -1]
    # return points in same shape as input
    if not is_batched:
        points_2d = points_2d.squeeze(0)  # (1, P, 3) -> (P, 3)

    return points_2d


"""
Sampling
"""


@torch.jit.script
def default_orientation(num: int, device: str) -> torch.Tensor:
    """Returns identity rotation transform.

    Args:
        num: The number of rotations to sample.
        device: Device to create tensor on.

    Returns:
        Identity quaternion in (w, x, y, z). Shape is (num, 4).
    """
    quat = torch.zeros((num, 4), dtype=torch.float, device=device)
    # the identity rotation has w = 1 and a zero vector part
    quat[..., 0] = 1.0

    return quat


@torch.jit.script
def random_orientation(num: int, device: str) -> torch.Tensor:
    """Returns sampled rotation in 3D as quaternion.

    Args:
        num: The number of rotations to sample.
        device: Device to create tensor on.

    Returns:
        Sampled quaternion in (w, x, y, z). Shape is (num, 4).

    Reference:
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.random.html
    """
    # sample random orientation from normal distribution
    quat = torch.randn((num, 4), dtype=torch.float, device=device)
    # normalize the quaternion
    return torch.nn.functional.normalize(quat, p=2.0, dim=-1, eps=1e-12)


@torch.jit.script
def random_yaw_orientation(num: int, device: str) -> torch.Tensor:
    """Returns sampled rotation around z-axis.

    Args:
        num: The number of rotations to sample.
        device: Device to create tensor on.

    Returns:
        Sampled quaternion in (w, x, y, z). Shape is (num, 4).
    """
    # roll and pitch stay zero; only the yaw angle is drawn, uniformly from [0, 2*pi)
    roll = torch.zeros(num, dtype=torch.float, device=device)
    pitch = torch.zeros(num, dtype=torch.float, device=device)
    yaw = 2 * torch.pi * torch.rand(num, dtype=torch.float, device=device)

    return quat_from_euler_xyz(roll, pitch, yaw)
def sample_triangle(lower: float, upper: float, size: int | tuple[int, ...], device: str) -> torch.Tensor:
    """Randomly samples tensor from a triangular distribution.

    A uniform sample in [-1, 1] is passed through a signed square root (the sign is kept and the
    square root is taken of the magnitude), and the result is rescaled to [lower, upper].

    Args:
        lower: The lower range of the sampled tensor.
        upper: The upper range of the sampled tensor.
        size: The shape of the tensor.
        device: Device to create tensor on.

    Returns:
        Sampled tensor. Shape is based on :attr:`size`.
    """
    # a bare integer means a one-dimensional shape
    shape = (size,) if isinstance(size, int) else size
    # uniform draw, stretched from [0, 1] to [-1, 1]
    uniform = 2 * torch.rand(*shape, device=device) - 1
    # signed square root: same sign as the draw, magnitude sqrt(|draw|)
    warped = torch.sign(uniform) * torch.sqrt(torch.abs(uniform))
    # map [-1, 1] back onto [0, 1]
    unit = (warped + 1.0) / 2.0
    # finally stretch onto [lower, upper]
    return (upper - lower) * unit + lower
def sample_uniform(
    lower: torch.Tensor | float, upper: torch.Tensor | float, size: int | tuple[int, ...], device: str
) -> torch.Tensor:
    """Sample uniformly within a range.

    Args:
        lower: Lower bound of uniform range.
        upper: Upper bound of uniform range.
        size: The shape of the tensor.
        device: Device to create tensor on.

    Returns:
        Sampled tensor. Shape is based on :attr:`size`.
    """
    # a bare integer means a one-dimensional shape
    shape = (size,) if isinstance(size, int) else size
    # unit-interval noise, then an affine map onto [lower, upper]
    noise = torch.rand(*shape, device=device)
    return noise * (upper - lower) + lower
def sample_log_uniform(
    lower: torch.Tensor | float, upper: torch.Tensor | float, size: int | tuple[int, ...], device: str
) -> torch.Tensor:
    r"""Sample using log-uniform distribution within a range.

    The log-uniform distribution is defined as a uniform distribution in the log-space. It
    is useful for sampling values that span several orders of magnitude. The sampled values
    are uniformly distributed in the log-space and then exponentiated to get the final values.

    .. math::

        x = \exp(\text{uniform}(\log(\text{lower}), \log(\text{upper})))

    Args:
        lower: Lower bound of uniform range.
        upper: Upper bound of uniform range.
        size: The shape of the tensor.
        device: Device to create tensor on.

    Returns:
        Sampled tensor. Shape is based on :attr:`size`.
    """
    # plain numbers are promoted to float tensors on the target device; tensors pass through as-is
    lower_t = lower if isinstance(lower, torch.Tensor) else torch.tensor(lower, dtype=torch.float, device=device)
    upper_t = upper if isinstance(upper, torch.Tensor) else torch.tensor(upper, dtype=torch.float, device=device)
    # draw uniformly between the log-bounds, then map back with exp
    log_samples = sample_uniform(torch.log(lower_t), torch.log(upper_t), size, device)
    return torch.exp(log_samples)
def sample_cylinder(
    radius: float, h_range: tuple[float, float], size: int | tuple[int, ...], device: str
) -> torch.Tensor:
    """Sample 3D points uniformly on a cylinder's surface.

    The cylinder is centered at the origin and aligned with the z-axis. The height of the cylinder is
    sampled uniformly from the range :obj:`h_range`, while the radius is fixed to :obj:`radius`.

    The sampled points are returned as a tensor of shape :obj:`(*size, 3)`, i.e. the last dimension
    contains the x, y, and z coordinates of the sampled points.

    Args:
        radius: The radius of the cylinder.
        h_range: The minimum and maximum height of the cylinder.
        size: The shape of the tensor.
        device: Device to create tensor on.

    Returns:
        Sampled tensor. Shape is :obj:`(*size, 3)`.
    """
    # one angle per point, uniform over [-pi, pi]
    theta = (torch.rand(size, device=device) * 2 - 1) * torch.pi
    z_low, z_high = h_range
    # output shape is the requested shape with a trailing xyz axis
    out_shape = (size, 3) if isinstance(size, int) else (*size, 3)
    points = torch.zeros(out_shape, device=device)
    # x and y lie on the circle of the given radius
    points[..., 0] = radius * torch.cos(theta)
    points[..., 1] = radius * torch.sin(theta)
    # z is filled in place with uniform heights
    points[..., 2].uniform_(z_low, z_high)
    return points