r"""Positional encodings for spherical data via harmonics, Fourier features, and Herglotz kernels."""
from __future__ import annotations
import torch
import torch.nn as nn
import math
from . import _kernels as PE
from typing import Tuple
# Public API: the three positional-encoding modules defined in this file.
__all__ = ["HerglotzPE", "FourierPE", "SphericalHarmonicsPE"]
class SphericalHarmonicsPE(nn.Module):
    r"""
    Real spherical harmonics positional encoding.

    This module maps spherical angles
    :math:`x = (\theta, \phi) \in [0,\pi] \times [-\pi,\pi]`
    to a vector of real spherical harmonics

    .. math::

        \psi^{\mathrm{SH}}(x) =
        \bigl(
        Y_{\ell_1}^{m_1}(\theta,\phi), \dots,
        Y_{\ell_N}^{m_N}(\theta,\phi)
        \bigr),

    where the index pairs :math:`(\ell_k, m_k)` follow the standard ordering

    .. math::

        (0,0), (1,-1), (1,0), (1,1), (2,-2), \dots

    and only the first ``num_atoms = N`` basis functions are retained.
    ``N`` need not be a perfect square: the highest degree band is simply
    truncated.

    The real spherical harmonics are defined as

    .. math::

        Y_\ell^m(\theta,\phi)
        = N_{\ell m}\,P_\ell^{|m|}(\cos\theta)
        \begin{cases}
        \cos(m\phi), & m \ge 0, \\
        \sin(|m|\phi), & m < 0,
        \end{cases}

    where :math:`P_\ell^m` are the associated Legendre polynomials and
    :math:`N_{\ell m}` is a normalization constant. The evaluation is
    delegated to ``_kernels.sph_harm``, which fixes the exact normalization
    convention.

    Parameters
    ----------
    num_atoms: int
        Number of spherical harmonic basis functions returned (``>= 0``).

    Raises
    ------
    ValueError
        If ``num_atoms`` is negative.
    """

    def __init__(
        self,
        num_atoms: int,
    ) -> None:
        super().__init__()
        if num_atoms < 0:
            raise ValueError(f"num_atoms must be non-negative, got {num_atoms}")
        self.num_atoms = num_atoms

        # Smallest degree L with (L + 1)^2 >= num_atoms, i.e. just enough
        # degree bands (band l holds 2l + 1 orders) to supply num_atoms pairs.
        L_upper = math.ceil(math.sqrt(num_atoms)) - 1
        pairs = [
            (degree, order)
            for degree in range(L_upper + 1)
            for order in range(-degree, degree + 1)
        ][: self.num_atoms]
        degrees = [degree for degree, _ in pairs]
        orders = [order for _, order in pairs]

        # Non-persistent buffers: they follow .to(device) moves but are
        # rebuilt from num_atoms instead of being stored in the state_dict.
        self.register_buffer(
            "l_list", torch.tensor(degrees, dtype=torch.int64), persistent=False
        )
        self.register_buffer(
            "m_list", torch.tensor(orders, dtype=torch.int64), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the spherical harmonics encoding.

        This method evaluates the preselected real spherical harmonic basis
        functions at the input angular coordinates and returns the first
        ``num_atoms`` coefficients in standard ordering.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(..., 2)`` containing spherical angles
            ``(theta, phi)`` in radians.

        Returns
        -------
        torch.Tensor
            Real spherical harmonics evaluated at the input angles,
            with shape ``(..., num_atoms)``.

        Raises
        ------
        ValueError
            If ``x.shape[-1] != 2``.
        """
        # Explicit raise (not assert): asserts vanish under `python -O`, and
        # the documented contract is ValueError.
        if x.shape[-1] != 2:
            raise ValueError(
                f"Input dim must be (theta, phi), got shape {tuple(x.shape)}"
            )
        return PE.sph_harm(x, self.l_list, self.m_list)
class HerglotzPE(nn.Module):
    r"""
    Herglotz positional encoding with learnable per-atom magnitudes.

    This module implements a real-valued Herglotz-type feature map defined on
    Cartesian coordinates :math:`x \in \mathbb{R}^3`.

    Each atom :math:`k` is defined by two orthonormal vectors
    :math:`a_{k, \Re}, a_{k, \Im} \in \mathbb{R}^3`, sampled at random and
    stored as fixed (non-learned) buffers ``A_real0`` / ``A_imag0``. Together
    they form an implicit complex direction
    :math:`a_k = a_{k, \Re} + i\,a_{k, \Im}`.
    For an input point :math:`x`, we compute the projections

    .. math::

        u_k = \langle x, a_{k, \Re} \rangle, \qquad
        v_k = \langle x, a_{k, \Im} \rangle.

    Each atom has a learnable magnitude parameter :math:`\sigma_k`
    (``sigmas``), initialized as
    :math:`\sigma_k \sim \mathcal{U}(0, L_{\mathrm{init}})`.
    The Herglotz feature associated with atom :math:`k` is defined in closed
    form as

    .. math::

        \psi^{\mathrm{H}}_k(x)
        = \frac{1}{1 + 2L_{\mathrm{init}}} e^{\rho_k (u_k - 1)}
        \Bigl[
        (1 + 2\rho_k u_k)\cos(\rho_k v_k)
        - (2\rho_k v_k)\sin(\rho_k v_k)
        \Bigr],

    where :math:`\rho_k` is the magnitude of atom :math:`k`. Its exact
    relation to :math:`\sigma_k` is defined by ``_kernels.herglotz``.

    Optionally (``rot=True``), a learnable quaternion per atom is passed to
    the kernel. It is stored as ``qrot`` with shape ``(num_atoms, 4)`` and
    initialized to ``(1, 0, 0, 0)`` (presumably the identity in
    ``(w, x, y, z)`` order). The kernel presumably uses it to rotate the
    atoms before evaluation.

    Parameters
    ----------
    num_atoms: int
        Number of Herglotz atoms (output features).
    L_init: int
        Upper bound used to initialize the magnitude parameters
        :math:`\sigma_k \sim \mathcal{U}(0, L_{\mathrm{init}})`; also sets the
        constant prefactor :math:`1 / (1 + 2 L_{\mathrm{init}})`.
    rot: bool, optional
        If ``True``, adds a learnable per-atom quaternion ``qrot``.
        Default = ``False``

    Notes
    -----
    This module is **Cartesian-only**. If your data is given in spherical
    coordinates :math:`(\theta,\phi)`, use a wrapper to convert inputs to
    Cartesian coordinates before applying this encoding.
    """

    def __init__(self, num_atoms: int, L_init: int, rot: bool = False) -> None:
        super().__init__()
        self.num_atoms = num_atoms
        self.L_init = L_init
        self.rot = rot

        # Learnable magnitudes; values are filled in by reset_parameters().
        self.sigmas = nn.Parameter(torch.empty(self.num_atoms))

        # Fixed atom directions (persistent buffers, saved in the state_dict).
        for buffer_name in ("A_real0", "A_imag0"):
            self.register_buffer(buffer_name, torch.empty(self.num_atoms, 3))

        # Registering None keeps `self.qrot` defined when rotation is off.
        quaternions = nn.Parameter(torch.empty(num_atoms, 4)) if rot else None
        self.register_parameter("qrot", quaternions)

        # Constant prefactor 1 / (1 + 2 * L_init). It is non-persistent, so it
        # is rebuilt from L_init rather than loaded from a checkpoint.
        prefactor = 1.0 / (1.0 + 2 * self.L_init)
        self.register_buffer("inv_const", torch.tensor(prefactor), persistent=False)

        self.reset_parameters()

    def reset_parameters(
        self,
    ) -> None:
        """Re-sample the atoms and magnitudes, and reset ``qrot`` if present.

        Random draws happen in a fixed order: atom directions first, then
        ``sigmas``. Seeded runs therefore reproduce identical parameters.
        """
        template = self.A_real0
        with torch.no_grad():
            real_dirs, imag_dirs = self._generate_atoms(
                self.num_atoms, device=template.device, dtype=template.dtype
            )
            self.A_real0.copy_(real_dirs)
            self.A_imag0.copy_(imag_dirs)
            nn.init.uniform_(self.sigmas, 0, self.L_init)
            if self.qrot is not None:
                self.qrot.fill_(0.0)
                self.qrot[:, 0].fill_(1.0)

    @staticmethod
    def _generate_atoms(
        num_atoms: int, device=None, dtype=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``num_atoms`` random orthonormal direction pairs in R^3.

        Returns
        -------
        tuple of torch.Tensor
            ``(aR, aI)``, each of shape ``(num_atoms, 3)``. Every row of
            ``aR`` has unit norm. Every row of ``aI`` is made orthogonal to
            the matching ``aR`` row by one Gram-Schmidt step, then
            normalized. Norms are clamped at 1e-12 to avoid division by zero.
        """
        factory = dict(device=device, dtype=dtype)
        # NOTE: the imaginary part is drawn before the real part. Keep this
        # order so seeded runs reproduce the same atoms.
        imag = torch.randn(num_atoms, 3, **factory)
        real = torch.randn(num_atoms, 3, **factory)

        real = real / real.norm(dim=1, keepdim=True).clamp(1e-12)
        imag = imag - (imag * real).sum(dim=1, keepdim=True) * real
        imag = imag / imag.norm(dim=1, keepdim=True).clamp(1e-12)
        return real, imag

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Herglotz positional encoding.

        Validates the input and delegates to ``_kernels.herglotz`` with the
        fixed atom directions, the learnable magnitudes ``sigmas``, the
        constant prefactor and the optional per-atom quaternions ``qrot``.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(..., 3)`` containing Cartesian coordinates.

        Returns
        -------
        torch.Tensor
            Herglotz features evaluated at the input points,
            with shape ``(..., num_atoms)``.

        Raises
        ------
        ValueError
            If ``x.shape[-1] != 3``.
        """
        if x.shape[-1] == 3:
            return PE.herglotz(
                x,
                self.A_real0,
                self.A_imag0,
                self.sigmas,
                self.inv_const,
                self.qrot,
            )
        raise ValueError(
            f"HerglotzPE(coord='cartesian') expects x[...,3]=(x,y,z), got {x.shape}"
        )
class FourierPE(nn.Module):
    r"""
    Learned Fourier positional encoding.

    This module implements a learnable sinusoidal feature map of the form

    .. math::

        \psi^{\mathrm{F}}(x) = \sin\bigl(\omega_0 (x \Omega^\top + b)\bigr),

    where:

    - :math:`\Omega \in \mathbb{R}^{N \times d}` (``Omega``) is a learnable
      frequency matrix with :math:`N` = ``num_atoms``, initialized as
      :math:`\mathcal{U}(-1/d, 1/d)`,
    - :math:`b \in \mathbb{R}^N` (``bias``) is an optional learnable bias,
      initialized as :math:`\mathcal{U}(-1/\sqrt{d}, 1/\sqrt{d})`,
    - :math:`\omega_0 > 0` is a fixed (non-learned) frequency scaling factor.

    This corresponds to a standard Fourier-feature embedding with learned
    frequencies. The evaluation is delegated to ``_kernels.fourier``.

    Parameters
    ----------
    num_atoms: int
        Number of output features.
    input_dim: int
        Dimension :math:`d` of the input space (must be ``>= 1``).
    bias: bool, optional
        Whether to include a learnable bias term :math:`b`.
        Default = ``True``
    omega0: float, optional
        Frequency scaling factor :math:`\omega_0`.
        Default = ``1.0``

    Raises
    ------
    ValueError
        If ``input_dim < 1``.
    """

    def __init__(
        self,
        num_atoms: int,
        input_dim: int = 3,
        bias: bool = True,
        omega0: float = 1.0,
    ) -> None:
        super().__init__()
        # input_dim appears as a divisor in reset_parameters; reject it early
        # instead of failing there with ZeroDivisionError / a bad range.
        if input_dim < 1:
            raise ValueError(f"input_dim must be a positive integer, got {input_dim}")
        self.num_atoms = num_atoms
        self.input_dim = input_dim
        self.omega0 = omega0
        self.Omega = nn.Parameter(torch.empty(num_atoms, input_dim))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.num_atoms))
        else:
            # Standard nn.Module idiom for an optional parameter: the slot
            # exists (as None) and is skipped by state_dict/parameters().
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(
        self,
    ) -> None:
        """Re-initialize ``Omega`` ~ U(-1/d, 1/d) and ``bias`` ~ U(-1/sqrt(d), 1/sqrt(d))."""
        with torch.no_grad():
            nn.init.uniform_(self.Omega, -1 / self.input_dim, 1 / self.input_dim)
            if self.bias is not None:
                bound = 1 / math.sqrt(self.input_dim)
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the learned Fourier positional encoding.

        This method applies a learned linear projection of the input coordinates,
        followed by sinusoidal activation with fixed frequency scaling.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(..., input_dim)``.

        Returns
        -------
        torch.Tensor
            Learned Fourier features with shape ``(..., num_atoms)``.

        Raises
        ------
        ValueError
            If ``x.shape[-1] != input_dim``.
        """
        if x.shape[-1] != self.input_dim:
            raise ValueError(
                f"FourierPE expects x[..., {self.input_dim}], got shape {tuple(x.shape)}"
            )
        return PE.fourier(x, self.Omega, self.omega0, self.bias)