import math

import torch
[docs]
def rtp_to_r3(rtp_coords: torch.Tensor) -> torch.Tensor:
    """
    Map full spherical coordinates (rtp) [r, θ, ϕ] onto Cartesian coordinates [x, y, z] in R^3.

    Parameters:
        rtp_coords (torch.Tensor): Tensor of shape [..., 3]; the last axis holds
            the radius r, the polar angle θ (angle from the +z axis) and the
            azimuthal angle ϕ (angle in the xy-plane from the +x axis).

    Returns:
        torch.Tensor: Tensor of shape [..., 3] holding [x, y, z], with
            x = r·sinθ·cosϕ, y = r·sinθ·sinϕ and z = r·cosθ.

    Raises:
        ValueError: If the last dimension of rtp_coords is not 3.
    """
    if rtp_coords.shape[-1] == 3:
        radius = rtp_coords[..., 0]
        polar = rtp_coords[..., 1]
        azimuth = rtp_coords[..., 2]
        # r·sinθ is the factor shared by the x and y components.
        planar = radius * torch.sin(polar)
        cartesian = (
            planar * torch.cos(azimuth),
            planar * torch.sin(azimuth),
            radius * torch.cos(polar),
        )
        return torch.stack(cartesian, dim=-1)
    raise ValueError("The last dimension of rtp_coords must be 3.")
[docs]
def tp_to_r3(tp_coords: torch.Tensor) -> torch.Tensor:
    """
    Map angles (tp) [θ, ϕ] on the unit sphere to Cartesian coordinates [x, y, z] in R^3.

    Parameters:
        tp_coords (torch.Tensor): Tensor of shape [..., 2] holding the polar
            angle θ (from the +z axis) and the azimuthal angle ϕ (from the +x
            axis in the xy-plane).

    Returns:
        torch.Tensor: Tensor of shape [..., 3] holding the unit vector
            [sinθ·cosϕ, sinθ·sinϕ, cosθ].

    Raises:
        ValueError: If the last dimension of tp_coords is not 2.
    """
    if tp_coords.shape[-1] != 2:
        raise ValueError("The last dimension of tp_coords must be 2.")
    polar = tp_coords[..., 0]
    azimuth = tp_coords[..., 1]
    ring_radius = torch.sin(polar)  # factor shared by the x and y components
    components = [
        ring_radius * torch.cos(azimuth),
        ring_radius * torch.sin(azimuth),
        torch.cos(polar),
    ]
    return torch.stack(components, dim=-1)
[docs]
def r3_to_rtp(r3_coords: torch.Tensor) -> torch.Tensor:
    """
    Converts R^3 Cartesian coordinates [x, y, z] to full spherical coordinates (rtp) [r, θ, ϕ].

    Parameters:
        r3_coords (torch.Tensor): Tensor with shape [..., 3] representing [x, y, z].
            Non-floating (e.g. integer) tensors are promoted to the default
            floating dtype.

    Returns:
        torch.Tensor: Tensor with shape [..., 3] representing [r, θ, ϕ], where
            r >= 0 is the radius, θ in [0, π] is the polar angle measured from
            the +z axis and ϕ = atan2(y, x) in [-π, π] is the azimuthal angle.
            At the origin, where the direction is undefined, θ is reported
            as π/2.

    Raises:
        ValueError: If the last dimension of r3_coords is not 3.
    """
    if r3_coords.shape[-1] != 3:
        raise ValueError("The last dimension of r3_coords must be 3.")
    if not r3_coords.is_floating_point():
        # torch.hypot does not support integer dtypes; the old sqrt-based
        # formula promoted integers to the default float dtype, so do the same.
        r3_coords = r3_coords.to(torch.get_default_dtype())
    x, y, z = r3_coords.unbind(dim=-1)
    # hypot never squares its inputs, so r neither overflows to inf for large
    # coordinates nor underflows to 0 for tiny ones.
    rho = torch.hypot(x, y)  # distance from the z-axis
    r = torch.hypot(rho, z)
    # atan2(rho, z) is accurate at every polar angle and needs no epsilon:
    # acos(z / (r + eps)) lost all precision near the poles and biased θ
    # for small radii (e.g. (0, 0, 1e-6) gave θ ≈ 0.14 instead of 0).
    theta = torch.atan2(rho, z).masked_fill(r == 0, math.pi / 2)
    phi = torch.atan2(y, x)
    return torch.stack([r, theta, phi], dim=-1)
[docs]
def r3_to_tp(r3_coords: torch.Tensor) -> torch.Tensor:
    """
    Converts R^3 Cartesian coordinates [x, y, z] to spherical angles on the unit sphere (tp) [θ, ϕ].

    Only the direction of each vector is used, so inputs need not have unit
    length.

    Parameters:
        r3_coords (torch.Tensor): Tensor with shape [..., 3] representing [x, y, z].
            Non-floating (e.g. integer) tensors are promoted to the default
            floating dtype.

    Returns:
        torch.Tensor: Tensor with shape [..., 2] representing [θ, ϕ], where
            θ in [0, π] is the polar angle from the +z axis and ϕ = atan2(y, x)
            in [-π, π] is the azimuthal angle. For the zero vector θ is
            reported as π/2.

    Raises:
        ValueError: If the last dimension of r3_coords is not 3.
    """
    if r3_coords.shape[-1] != 3:
        raise ValueError("The last dimension of r3_coords must be 3.")
    if not r3_coords.is_floating_point():
        # torch.hypot does not support integer dtypes.
        r3_coords = r3_coords.to(torch.get_default_dtype())
    x, y, z = r3_coords.unbind(dim=-1)
    # Both angles are invariant to positive rescaling, so no normalisation is
    # needed. Dividing by (norm + 1e-8) biased θ for short vectors and turned
    # infinite components into NaN; acos of the normalised z also lost
    # precision near the poles.
    rho = torch.hypot(x, y)  # distance from the z-axis
    theta = torch.atan2(rho, z).masked_fill((rho == 0) & (z == 0), math.pi / 2)
    phi = torch.atan2(y, x)
    return torch.stack([theta, phi], dim=-1)
# === 2D Conversion Functions ===
[docs]
def rt_to_r2(rt_coords: torch.Tensor) -> torch.Tensor:
    """
    Map full polar coordinates (rt) [r, θ] onto Cartesian coordinates [x, y] in R^2.

    Parameters:
        rt_coords (torch.Tensor): Tensor of shape [..., 2] holding the radius r
            and the angle θ measured from the +x axis.

    Returns:
        torch.Tensor: Tensor of shape [..., 2] holding [r·cosθ, r·sinθ].

    Raises:
        ValueError: If the last dimension of rt_coords is not 2.
    """
    if rt_coords.shape[-1] == 2:
        radius, angle = rt_coords[..., 0], rt_coords[..., 1]
        return torch.stack(
            (radius * torch.cos(angle), radius * torch.sin(angle)),
            dim=-1,
        )
    raise ValueError("The last dimension of rt_coords must be 2.")
[docs]
def t_to_r2(t_coords: torch.Tensor) -> torch.Tensor:
    """
    Map angles (t) [θ] to points [x, y] on the unit circle in R^2.

    Parameters:
        t_coords (torch.Tensor): Tensor of shape [..., 1] holding the angle θ
            measured from the +x axis.

    Returns:
        torch.Tensor: Tensor of shape [..., 2] holding [cosθ, sinθ].

    Raises:
        ValueError: If the last dimension of t_coords is not 1.
    """
    if t_coords.shape[-1] != 1:
        raise ValueError("The last dimension of t_coords must be 1.")
    # The trailing axis already has size 1, so concatenating cos and sin along
    # it yields the [..., 2] result without squeezing first.
    return torch.cat((torch.cos(t_coords), torch.sin(t_coords)), dim=-1)
[docs]
def r2_to_rt(r2_coords: torch.Tensor) -> torch.Tensor:
    """
    Converts R^2 Cartesian coordinates [x, y] to full polar coordinates (rt) [r, θ].

    Parameters:
        r2_coords (torch.Tensor): Tensor with shape [..., 2] representing [x, y].
            Non-floating (e.g. integer) tensors are promoted to the default
            floating dtype.

    Returns:
        torch.Tensor: Tensor with shape [..., 2] representing [r, θ], where
            r >= 0 and θ = atan2(y, x) lies in [-π, π].

    Raises:
        ValueError: If the last dimension of r2_coords is not 2.
    """
    if r2_coords.shape[-1] != 2:
        raise ValueError("The last dimension of r2_coords must be 2.")
    if not r2_coords.is_floating_point():
        # torch.hypot does not support integer dtypes; the old
        # sqrt(x**2 + y**2) promoted integers to the default float dtype.
        r2_coords = r2_coords.to(torch.get_default_dtype())
    x, y = r2_coords.unbind(dim=-1)
    # hypot never squares its inputs, so the radius neither overflows to inf
    # for large coordinates nor underflows to 0 for tiny ones.
    r = torch.hypot(x, y)
    theta = torch.atan2(y, x)
    return torch.stack([r, theta], dim=-1)
[docs]
def r2_to_t(r2_coords: torch.Tensor) -> torch.Tensor:
    """
    Converts R^2 Cartesian coordinates [x, y] to angle-only representation (t) [θ].

    Only the direction of each vector is used, so inputs need not lie on the
    unit circle.

    Parameters:
        r2_coords (torch.Tensor): Tensor with shape [..., 2] representing [x, y].

    Returns:
        torch.Tensor: Tensor with shape [..., 1] representing θ = atan2(y, x)
            in [-π, π].

    Raises:
        ValueError: If the last dimension of r2_coords is not 2.
    """
    if r2_coords.shape[-1] != 2:
        raise ValueError("The last dimension of r2_coords must be 2.")
    x, y = r2_coords.unbind(dim=-1)
    # atan2 is invariant to positive rescaling of (x, y), so normalising
    # first is unnecessary; dividing by (norm + eps) also turned infinite
    # components into NaN.
    return torch.atan2(y, x).unsqueeze(dim=-1)