# Source code for spherical_inr.coords

r"""Coordinate transforms between Cartesian and spherical/polar parameterizations.

The 3D functions use spherical coordinates :math:`(r, \theta, \phi)` with
:math:`\theta` the polar angle and :math:`\phi` the azimuth. The 2D functions
use polar coordinates :math:`(r, \theta)` or just the angle :math:`\theta` on
the unit circle.
"""

import torch


def rtp_to_r3(rtp_coords: torch.Tensor) -> torch.Tensor:
    r"""Convert spherical coordinates :math:`(r, \theta, \phi)` to Cartesian ``[x, y, z]``.

    The mapping is

    .. math::
        x = r \sin\theta \cos\phi,\quad y = r \sin\theta \sin\phi,\quad z = r \cos\theta,

    i.e. :math:`\theta` is the polar angle measured from the +z axis and
    :math:`\phi` the azimuth in the xy-plane.

    Parameters
    ----------
    rtp_coords: torch.Tensor
        Tensor of shape (..., 3) holding :math:`[r, \theta, \phi]` along the
        last axis; any leading dimensions are carried through unchanged.

    Returns
    -------
    torch.Tensor
        Tensor of shape (..., 3) holding :math:`[x, y, z]`.

    Raises
    ------
    ValueError
        If the size of the last dimension of ``rtp_coords`` is not 3.
    """
    if rtp_coords.shape[-1] != 3:
        raise ValueError("The last dimension of rtp_coords must be 3.")

    radius = rtp_coords[..., 0]
    polar = rtp_coords[..., 1]
    azimuth = rtp_coords[..., 2]

    # Distance from the z-axis; shared by the x and y components.
    planar = radius * torch.sin(polar)
    return torch.stack(
        (
            planar * torch.cos(azimuth),
            planar * torch.sin(azimuth),
            radius * torch.cos(polar),
        ),
        dim=-1,
    )


def tp_to_r3(tp_coords: torch.Tensor) -> torch.Tensor:
    r"""Convert unit-sphere angles :math:`(\theta, \phi)` to Cartesian ``[x, y, z]``.

    Equivalent to the spherical-to-Cartesian mapping with :math:`r = 1`:
    :math:`x = \sin\theta\cos\phi`, :math:`y = \sin\theta\sin\phi`,
    :math:`z = \cos\theta`.

    Parameters
    ----------
    tp_coords: torch.Tensor
        Tensor of shape (..., 2) holding :math:`[\theta, \phi]` along the last
        axis.

    Returns
    -------
    torch.Tensor
        Tensor of shape (..., 3) holding :math:`[x, y, z]`, a point on the unit
        sphere.

    Raises
    ------
    ValueError
        If the size of the last dimension of ``tp_coords`` is not 2.
    """
    if tp_coords.shape[-1] != 2:
        raise ValueError("The last dimension of tp_coords must be 2.")

    polar, azimuth = tp_coords[..., 0], tp_coords[..., 1]
    sin_polar = torch.sin(polar)
    components = [
        sin_polar * torch.cos(azimuth),
        sin_polar * torch.sin(azimuth),
        torch.cos(polar),
    ]
    return torch.stack(components, dim=-1)


def r3_to_rtp(r3_coords: torch.Tensor) -> torch.Tensor:
    r"""Convert Cartesian coordinates to spherical coordinates :math:`(r, \theta, \phi)`.

    The components are computed as

    .. math::
        \rho = \sqrt{x^2 + y^2},\quad r = \sqrt{\rho^2 + z^2},\quad
        \theta = \operatorname{atan2}(\rho, z),\quad \phi = \operatorname{atan2}(y, x).

    Using ``atan2`` for the polar angle (instead of ``acos(z / r)``) needs no
    epsilon in a denominator, so the angle is correct for arbitrarily small
    non-zero vectors, and it stays accurate near the poles where ``acos`` is
    ill-conditioned. Magnitudes are computed with :func:`torch.hypot`, which
    avoids the overflow/underflow of squaring each component.

    Parameters
    ----------
    r3_coords: torch.Tensor
        Tensor with shape (..., 3) representing :math:`[x, y, z]`.

    Returns
    -------
    torch.Tensor
        Tensor with shape (..., 3) containing :math:`[r, \theta, \phi]` with
        :math:`r \ge 0`, :math:`\theta \in [0, \pi]` and
        :math:`\phi \in [-\pi, \pi]`. At the origin the angles are undefined
        and follow :func:`torch.atan2`'s handling of zero arguments (``0`` for
        ``+0.0`` inputs).

    Raises
    ------
    ValueError
        If the last dimension of ``r3_coords`` is not 3.
    """
    if r3_coords.shape[-1] != 3:
        raise ValueError("The last dimension of r3_coords must be 3.")

    x, y, z = r3_coords.unbind(dim=-1)
    # Distance from the z-axis (always >= 0, so theta lands in [0, pi]).
    rho = torch.hypot(x, y)
    r = torch.hypot(rho, z)
    theta = torch.atan2(rho, z)
    phi = torch.atan2(y, x)
    return torch.stack([r, theta, phi], dim=-1)


def r3_to_tp(r3_coords: torch.Tensor) -> torch.Tensor:
    r"""Map Cartesian directions to unit-sphere angles :math:`(\theta, \phi)`.

    Each input vector is treated as a direction: the result depends only on
    ``[x, y, z] / ||[x, y, z]||``, so points need not lie exactly on the unit
    sphere. The angles are computed as

    .. math::
        \theta = \operatorname{atan2}\left(\sqrt{x^2 + y^2},\, z\right),\quad
        \phi = \operatorname{atan2}(y, x).

    Both expressions are invariant to positive rescaling, so no explicit
    normalization (and no epsilon) is needed. This keeps the result exact for
    very small vectors and accurate near the poles.

    Parameters
    ----------
    r3_coords: torch.Tensor
        Tensor with shape (..., 3) representing :math:`[x, y, z]`.

    Returns
    -------
    torch.Tensor
        Tensor with shape (..., 2) containing :math:`[\theta, \phi]` with
        :math:`\theta \in [0, \pi]` and :math:`\phi \in [-\pi, \pi]`. For the
        zero vector the direction is undefined and the result follows
        :func:`torch.atan2`'s handling of zero arguments.

    Raises
    ------
    ValueError
        If the last dimension of ``r3_coords`` is not 3.
    """
    if r3_coords.shape[-1] != 3:
        raise ValueError("The last dimension of r3_coords must be 3.")

    x, y, z = r3_coords.unbind(dim=-1)
    # hypot gives the (non-negative) distance from the z-axis without
    # squaring, so theta stays in [0, pi] and avoids overflow/underflow.
    theta = torch.atan2(torch.hypot(x, y), z)
    phi = torch.atan2(y, x)
    return torch.stack([theta, phi], dim=-1)


# === 2D Conversion Functions ===
def rt_to_r2(rt_coords: torch.Tensor) -> torch.Tensor:
    r"""Convert polar coordinates :math:`(r, \theta)` to Cartesian ``[x, y]`` in :math:`\mathbb{R}^2`.

    Uses :math:`x = r\cos\theta` and :math:`y = r\sin\theta`.

    Parameters
    ----------
    rt_coords: torch.Tensor
        Tensor of shape (..., 2) holding :math:`[r, \theta]` along the last
        axis.

    Returns
    -------
    torch.Tensor
        Tensor of shape (..., 2) holding :math:`[x, y]`.

    Raises
    ------
    ValueError
        If the size of the last dimension of ``rt_coords`` is not 2.
    """
    if rt_coords.shape[-1] != 2:
        raise ValueError("The last dimension of rt_coords must be 2.")

    radius = rt_coords[..., 0]
    angle = rt_coords[..., 1]
    return torch.stack((radius * torch.cos(angle), radius * torch.sin(angle)), dim=-1)


def t_to_r2(t_coords: torch.Tensor) -> torch.Tensor:
    r"""Convert an angle :math:`\theta` to the matching point on the unit circle.

    Returns :math:`[\cos\theta, \sin\theta]`.

    Parameters
    ----------
    t_coords: torch.Tensor
        Tensor of shape (..., 1) whose single trailing entry is the angle
        :math:`\theta`.

    Returns
    -------
    torch.Tensor
        Tensor of shape (..., 2) holding :math:`[x, y]` on the unit circle.

    Raises
    ------
    ValueError
        If the size of the last dimension of ``t_coords`` is not 1.
    """
    if t_coords.shape[-1] != 1:
        raise ValueError("The last dimension of t_coords must be 1.")

    # Drop the trailing singleton axis so the stacked result is (..., 2).
    angle = t_coords[..., 0]
    return torch.stack((angle.cos(), angle.sin()), dim=-1)


def r2_to_rt(r2_coords: torch.Tensor) -> torch.Tensor:
    r"""Convert Cartesian coordinates in :math:`\mathbb{R}^2` to polar form :math:`(r, \theta)`.

    The radius is computed with :func:`torch.hypot`, which avoids the
    intermediate overflow/underflow of ``sqrt(x**2 + y**2)``. For example,
    float32 components of magnitude ``1e-25`` would otherwise square to 0.

    Parameters
    ----------
    r2_coords: torch.Tensor
        Tensor with shape (..., 2) representing :math:`[x, y]`.

    Returns
    -------
    torch.Tensor
        Tensor with shape (..., 2) containing :math:`[r, \theta]` with
        :math:`r \ge 0` and :math:`\theta = \operatorname{atan2}(y, x) \in [-\pi, \pi]`.

    Raises
    ------
    ValueError
        If the last dimension of ``r2_coords`` is not 2.
    """
    if r2_coords.shape[-1] != 2:
        raise ValueError("The last dimension of r2_coords must be 2.")

    x, y = r2_coords.unbind(dim=-1)
    r = torch.hypot(x, y)
    theta = torch.atan2(y, x)
    return torch.stack([r, theta], dim=-1)


def r2_to_t(r2_coords: torch.Tensor) -> torch.Tensor:
    r"""Map Cartesian directions in :math:`\mathbb{R}^2` to their angle :math:`\theta` on the unit circle.

    Each input vector is treated as a direction. Because
    :math:`\operatorname{atan2}(y, x)` is invariant to positive rescaling of
    :math:`(x, y)`, no explicit normalization is performed. Normalizing first
    would add nothing, and it breaks for very large or infinite components,
    where the norm overflows to ``inf``.

    Parameters
    ----------
    r2_coords: torch.Tensor
        Tensor with shape (..., 2) representing :math:`[x, y]`.

    Returns
    -------
    torch.Tensor
        Tensor with shape (..., 1) containing :math:`[\theta]` with
        :math:`\theta \in [-\pi, \pi]`. For the zero vector the result follows
        :func:`torch.atan2`'s handling of zero arguments.

    Raises
    ------
    ValueError
        If the last dimension of ``r2_coords`` is not 2.
    """
    if r2_coords.shape[-1] != 2:
        raise ValueError("The last dimension of r2_coords must be 2.")

    x, y = r2_coords.unbind(dim=-1)
    theta = torch.atan2(y, x)
    # Keep a trailing singleton axis so the output shape is (..., 1).
    return theta.unsqueeze(dim=-1)