Source code for spherical_inr.inr

import torch
import torch.nn as nn

from .transforms import (
    tp_to_r3,
    rtp_to_r3,
)
from .positional_encoding import (
    get_positional_encoding,
    HerglotzPE,
    FourierPE,
    SphericalHarmonicsPE,
    RegularSolidHarmonicsPE,
    IrregularSolidHarmonicsPE,
    RegularHerglotzPE,
    IrregularHerglotzPE,
    )

from .mlp import (
    MLP,
    SineMLP,
)

from typing import Optional, List


class INR(nn.Module):
    r"""Implicit Neural Representation (INR).

    Maps inputs in ℝᵈ through a positional encoding ψ and then through a
    multilayer perceptron. For an input **x** of shape `(..., input_dim)`
    the output has shape `(..., output_dim)`, obtained by

    1. computing **ψ(x)** with the chosen positional encoding (PE), and
    2. feeding the encoded features to an MLP with hidden sizes `mlp_sizes`.

    Parameters:
        num_atoms (int):
            Number of channels (atoms) requested from the positional encoding.
        mlp_sizes (List[int]):
            Hidden-layer sizes of the MLP, e.g. `[64, 64]` for two hidden
            layers of width 64.
        output_dim (int):
            Number of output features per input point.
        input_dim (int, keyword-only):
            Dimensionality of each input x. Must be 2 (θ, φ) when
            `pe="spherical_harmonics"`; this is enforced in the constructor.
        pe (str, optional):
            Which PE to use. One of `"herglotz"`, `"spherical_harmonics"`
            or `"fourier"`. Defaults to `"herglotz"`.
        activation (str, optional):
            Activation for the MLP layers, e.g. `"relu"` or `"gelu"`.
            The special value `"sin"` selects `SineMLP` instead of `MLP`.
        pe_kwargs (dict, optional):
            Extra keyword arguments for the PE. They are merged *after*
            `num_atoms` / `input_dim`, so they can override those two keys.
        mlp_kwargs (dict, optional):
            Extra keyword arguments forwarded to `MLP(...)` / `SineMLP(...)`.
        activation_kwargs (dict, optional):
            Extra keyword arguments for the activation function. Only
            forwarded on the `MLP` branch; not used when `activation="sin"`.

    Raises:
        ValueError: If `pe` is not one of the supported names, or if
            `pe="spherical_harmonics"` is combined with `input_dim != 2`.

    Input:
        - **x**: Tensor of shape `(..., input_dim)`.

    Output:
        - Tensor of shape `(..., output_dim)`.
    """

    def __init__(
        self,
        num_atoms: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        input_dim: int,
        pe: str = "herglotz",
        activation: str = "relu",
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
        activation_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        # Validate the configuration before building any submodule.
        if pe not in ["herglotz", "spherical_harmonics", "fourier"]:
            raise ValueError(
                "Invalid positional encoding type. Choose from 'herglotz', 'spherical_harmonics', or 'fourier'."
            )
        if pe == "spherical_harmonics" and input_dim != 2:
            raise ValueError(
                "Spherical harmonics positional encoding requires input_dim to be 2."
            )

        # User-supplied pe_kwargs come last so they win over the defaults.
        self.pe = get_positional_encoding(
            pe,
            **{
                "num_atoms": num_atoms,
                "input_dim": input_dim,
                **(pe_kwargs or {}),
            },
        )

        # The MLP input width is whatever the PE reports, not the requested
        # num_atoms, in case the PE adjusts its channel count.
        if activation == "sin":
            self.mlp = SineMLP(
                input_features=self.pe.num_atoms,
                output_features=output_dim,
                hidden_sizes=mlp_sizes,
                **(mlp_kwargs or {}),
            )
        else:
            self.mlp = MLP(
                input_features=self.pe.num_atoms,
                output_features=output_dim,
                hidden_sizes=mlp_sizes,
                activation=activation,
                activation_kwargs=activation_kwargs or {},
                **(mlp_kwargs or {}),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` with the positional encoding, then apply the MLP."""
        features = self.pe(x)
        return self.mlp(features)
class HerglotzNet(nn.Module):
    r"""HerglotzNet on the 2-sphere.

    Expects inputs in spherical coordinates (θ, φ), converts them to
    Cartesian (x, y, z) with `tp_to_r3`, then applies a 3-D Herglotz
    positional encoding followed by a sine-activated MLP (`SineMLP`).

    Workflow:
        x_sph (θ, φ) ──tp_to_r3──▶ x_cart ∈ ℝ³
            ──HerglotzPE──▶ ψ(x) ──SineMLP──▶ output

    Parameters:
        L (int):
            Harmonic order passed to `HerglotzPE`; per the original
            documentation the PE creates `num_atoms = (L+1)**2` channels.
        mlp_sizes (List[int]):
            Hidden layer sizes for the SineMLP.
        output_dim (int):
            Number of output features.
        seed (int, optional):
            RNG seed forwarded to `HerglotzPE`.
        pe_kwargs (dict, optional):
            Extra keyword arguments for `HerglotzPE(...)`.
        mlp_kwargs (dict, optional):
            Extra keyword arguments for `SineMLP(...)` (e.g. `omega0`).

    Input:
        - **x**: Tensor `(..., 2)` of spherical coords (θ, φ).

    Output:
        - Tensor `(..., output_dim)`.
    """

    def __init__(
        self,
        L: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        seed: Optional[int] = None,
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        # input_dim is fixed to 3 because forward() lifts (θ, φ) to ℝ³ first.
        self.pe = HerglotzPE(
            L=L,
            input_dim=3,
            seed=seed,
            **(pe_kwargs or {}),
        )
        self.mlp = SineMLP(
            input_features=self.pe.num_atoms,
            output_features=output_dim,
            hidden_sizes=mlp_sizes,
            **(mlp_kwargs or {}),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert (θ, φ) to Cartesian, encode, then apply the SineMLP."""
        cartesian = tp_to_r3(x)
        features = self.pe(cartesian)
        return self.mlp(features)
class RegularHerglotzNet(nn.Module):
    r"""Regular solid HerglotzNet.

    Like HerglotzNet but built on `RegularHerglotzPE`, i.e. **regular**
    solid harmonics whose features grow like rˡ. Inputs are full spherical
    coordinates (r, θ, φ), so the features carry both radial and angular
    information.

    Workflow:
        x_sph (r, θ, φ) ──rtp_to_r3──▶ x_cart ∈ ℝ³
            ──RegularHerglotzPE──▶ ψ(x) ──SineMLP──▶ output

    Parameters:
        L (int):
            Harmonic order; per the original documentation
            `num_atoms = (L+1)**2`.
        mlp_sizes (List[int]):
            Hidden layer widths for the SineMLP.
        output_dim (int):
            Dimensionality of the network's final output.
        seed (int, optional):
            RNG seed forwarded to the PE.
        pe_kwargs (dict, optional):
            Extra keyword arguments for `RegularHerglotzPE(...)`.
        mlp_kwargs (dict, optional):
            Extra keyword arguments for `SineMLP(...)`.

    Input:
        - **x**: Tensor `(..., 3)` as (r, θ, φ).

    Output:
        - Tensor `(..., output_dim)`.
    """

    def __init__(
        self,
        L: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        seed: Optional[int] = None,
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        self.pe = RegularHerglotzPE(
            L=L,
            seed=seed,
            **(pe_kwargs or {}),
        )
        # MLP input width follows the channel count reported by the PE.
        self.mlp = SineMLP(
            input_features=self.pe.num_atoms,
            output_features=output_dim,
            hidden_sizes=mlp_sizes,
            **(mlp_kwargs or {}),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert (r, θ, φ) to Cartesian, encode, then apply the SineMLP."""
        cartesian = rtp_to_r3(x)
        features = self.pe(cartesian)
        return self.mlp(features)
class IrregularHerglotzNet(nn.Module):
    r"""Irregular solid HerglotzNet.

    Same pipeline as RegularHerglotzNet but built on `IrregularHerglotzPE`,
    i.e. **irregular** solid harmonics whose features decay like 1/rˡ⁺¹.
    Use this when the encoding should vanish at infinity.

    Workflow:
        x_sph (r, θ, φ) ──rtp_to_r3──▶ x_cart ∈ ℝ³
            ──IrregularHerglotzPE──▶ ψ(x) ──SineMLP──▶ output

    Parameters:
        L (int):
            Harmonic order; per the original documentation
            `num_atoms = (L+1)**2`.
        mlp_sizes (List[int]):
            Hidden widths for the SineMLP.
        output_dim (int):
            Output feature count.
        seed (int, optional):
            RNG seed forwarded to the PE.
        pe_kwargs (dict, optional):
            Extra keyword arguments for `IrregularHerglotzPE(...)`.
        mlp_kwargs (dict, optional):
            Extra keyword arguments for `SineMLP(...)`.

    Input:
        - **x**: Tensor `(..., 3)` as (r, θ, φ).

    Output:
        - Tensor `(..., output_dim)`.
    """

    def __init__(
        self,
        L: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        seed: Optional[int] = None,
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        self.pe = IrregularHerglotzPE(
            L=L,
            seed=seed,
            **(pe_kwargs or {}),
        )
        # MLP input width follows the channel count reported by the PE.
        self.mlp = SineMLP(
            input_features=self.pe.num_atoms,
            output_features=output_dim,
            hidden_sizes=mlp_sizes,
            **(mlp_kwargs or {}),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert (r, θ, φ) to Cartesian, encode, then apply the SineMLP."""
        cartesian = rtp_to_r3(x)
        features = self.pe(cartesian)
        return self.mlp(features)
class SirenNet(nn.Module):
    r"""Standard SIREN network with a Fourier positional encoding.

    Applies a `FourierPE` followed by a sine-activated MLP (`SineMLP`).

    Workflow:
        x ∈ ℝᵈ ──FourierPE──▶ ψ(x) ──SineMLP──▶ output

    Parameters:
        num_atoms (int):
            Number of channels for the FourierPE.
        mlp_sizes (List[int]):
            Hidden layer sizes for the SineMLP.
        output_dim (int):
            Final output dimensionality.
        input_dim (int, keyword-only):
            Dimensionality d of x.
        pe_kwargs (dict, optional):
            Extra keyword arguments for `FourierPE(...)` (e.g. `omega0`).
        mlp_kwargs (dict, optional):
            Extra keyword arguments for `SineMLP(...)` (e.g. `omega0`).

    Input:
        - **x**: Tensor `(..., input_dim)`.

    Output:
        - Tensor `(..., output_dim)`.
    """

    def __init__(
        self,
        num_atoms: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        input_dim: int,
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        self.pe = FourierPE(
            num_atoms=num_atoms,
            input_dim=input_dim,
            **(pe_kwargs or {}),
        )
        # MLP input width follows the channel count reported by the PE.
        self.mlp = SineMLP(
            input_features=self.pe.num_atoms,
            output_features=output_dim,
            hidden_sizes=mlp_sizes,
            **(mlp_kwargs or {}),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` with the Fourier PE, then apply the SineMLP."""
        features = self.pe(x)
        return self.mlp(features)
class HerglotzSirenNet(nn.Module):
    r"""HerglotzSirenNet.

    Cartesian-coordinate SIREN that uses a Herglotz positional encoding in
    place of the usual Fourier features. No coordinate conversion is applied:
    the input is passed to `HerglotzPE` as given.

    Workflow:
        x ∈ ℝᵈ ──HerglotzPE──▶ ψ(x) ──SineMLP──▶ output

    Parameters:
        num_atoms (int):
            Number of atoms (channels) in the HerglotzPE.
        mlp_sizes (List[int]):
            Hidden-layer widths for the sine-activated MLP, e.g. `[64, 64]`
            for two hidden layers of 64 units each.
        output_dim (int):
            Dimensionality of the final output.
        input_dim (int, keyword-only):
            Dimensionality d of each input vector x; forwarded to
            `HerglotzPE` as `input_dim` (commonly 3).
        pe_kwargs (Optional[dict]):
            Extra keyword arguments forwarded to `HerglotzPE(...)`.
        mlp_kwargs (Optional[dict]):
            Extra keyword arguments forwarded to `SineMLP(...)`
            (e.g. `omega0`).

    Input:
        - **x**: Tensor of shape `(..., input_dim)`, in Cartesian coords.

    Output:
        - Tensor of shape `(..., output_dim)`.
    """

    def __init__(
        self,
        num_atoms: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        input_dim: int,
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        self.pe = HerglotzPE(
            num_atoms=num_atoms,
            input_dim=input_dim,
            **(pe_kwargs or {}),
        )
        # MLP input width follows the channel count reported by the PE.
        self.mlp = SineMLP(
            input_features=self.pe.num_atoms,
            output_features=output_dim,
            hidden_sizes=mlp_sizes,
            **(mlp_kwargs or {}),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` with the Herglotz PE, then apply the SineMLP."""
        features = self.pe(x)
        return self.mlp(features)
class SphericalSirenNet(nn.Module):
    r"""SphericalSirenNet.

    Angular SIREN on the 2-sphere: encodes (θ, φ) with
    `SphericalHarmonicsPE`, then processes the features with a
    sine-activated MLP.

    Workflow:
        x_sph (θ, φ) ──SphericalHarmonicsPE──▶ ψ(x) ──SineMLP──▶ output

    Parameters:
        L (int):
            Maximum spherical-harmonic degree; per the original
            documentation the PE outputs `(L+1)**2` channels.
        mlp_sizes (List[int]):
            Hidden-layer widths for the SineMLP.
        output_dim (int):
            Dimensionality of the network's final output.
        seed (int, optional):
            RNG seed forwarded to `SphericalHarmonicsPE`.
        pe_kwargs (Optional[dict]):
            Extra keyword arguments for `SphericalHarmonicsPE(...)`.
        mlp_kwargs (Optional[dict]):
            Extra keyword arguments for `SineMLP(...)`.

    Input:
        - **x**: Tensor of shape `(..., 2)`, representing (θ, φ).

    Output:
        - Tensor of shape `(..., output_dim)`.
    """

    def __init__(
        self,
        L: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        seed: Optional[int] = None,
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        self.pe = SphericalHarmonicsPE(
            L=L,
            seed=seed,
            **(pe_kwargs or {}),
        )
        # MLP input width follows the channel count reported by the PE.
        self.mlp = SineMLP(
            input_features=self.pe.num_atoms,
            output_features=output_dim,
            hidden_sizes=mlp_sizes,
            **(mlp_kwargs or {}),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode (θ, φ) with spherical harmonics, then apply the SineMLP.

        Raises:
            ValueError: If the last dimension of `x` is not 2.
        """
        # Fail early with a clear message rather than deep inside the PE.
        if x.shape[-1] != 2:
            raise ValueError(
                f"Expected input shape (..., 2) for spherical coordinates (θ, φ), but got {x.shape}."
            )
        features = self.pe(x)
        return self.mlp(features)
class IrregularSolidSirenNet(nn.Module):
    r"""IrregularSolidSirenNet.

    Solid-harmonic SIREN on ℝ³ with **irregular** (decaying) basis
    functions, provided by `IrregularSolidHarmonicsPE`.

    Workflow:
        x_sph (r, θ, φ) ──rtp_to_r3──▶ x_cart ∈ ℝ³
            ──IrregularSolidHarmonicsPE──▶ ψ(x) ──SineMLP──▶ output

    Parameters:
        L (int):
            Maximum harmonic degree; per the original documentation
            `num_atoms = (L+1)**2`.
        mlp_sizes (List[int]):
            Hidden-layer widths for the SineMLP.
        output_dim (int):
            Dimensionality of the final output.
        seed (int, optional):
            RNG seed forwarded to `IrregularSolidHarmonicsPE`.
        pe_kwargs (Optional[dict]):
            Extra keyword arguments forwarded to
            `IrregularSolidHarmonicsPE(...)`.
        mlp_kwargs (Optional[dict]):
            Extra keyword arguments forwarded to `SineMLP(...)`.

    Input:
        - **x**: Tensor of shape `(..., 3)`, representing (r, θ, φ).

    Output:
        - Tensor of shape `(..., output_dim)`.
    """

    def __init__(
        self,
        L: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        seed: Optional[int] = None,
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        self.pe = IrregularSolidHarmonicsPE(
            L=L,
            seed=seed,
            **(pe_kwargs or {}),
        )
        # MLP input width follows the channel count reported by the PE.
        self.mlp = SineMLP(
            input_features=self.pe.num_atoms,
            output_features=output_dim,
            hidden_sizes=mlp_sizes,
            **(mlp_kwargs or {}),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert (r, θ, φ) to Cartesian, encode, then apply the SineMLP."""
        cartesian = rtp_to_r3(x)
        features = self.pe(cartesian)
        return self.mlp(features)
class RegularSolidSirenNet(nn.Module):
    r"""RegularSolidSirenNet.

    Solid-harmonic SIREN on ℝ³ using **regular** solid harmonics (features
    grow like rˡ), provided by `RegularSolidHarmonicsPE`.

    Workflow:
        x_sph (r, θ, φ) ──rtp_to_r3──▶ x_cart ∈ ℝ³
            ──RegularSolidHarmonicsPE──▶ ψ(x) ──SineMLP──▶ output

    Parameters:
        L (int):
            Maximum spherical-harmonic degree; per the original
            documentation the PE produces `num_atoms = (L+1)**2` channels.
        mlp_sizes (List[int]):
            Sizes of hidden layers for the sine-activated MLP
            (e.g. `[64, 64]`).
        output_dim (int):
            Dimensionality of the network's final output.
        seed (int, optional):
            Random seed forwarded to `RegularSolidHarmonicsPE`.
        pe_kwargs (dict, optional):
            Additional keyword arguments forwarded to
            `RegularSolidHarmonicsPE(...)`.
        mlp_kwargs (dict, optional):
            Additional keyword arguments forwarded to `SineMLP(...)`.

    Input:
        - **x**: Tensor of shape `(..., 3)`, representing spherical
          coordinates `(r ≥ 0, θ ∈ [0,π], φ ∈ [0,2π))`.

    Output:
        - Tensor of shape `(..., output_dim)`, the MLP's prediction per
          input point.
    """

    def __init__(
        self,
        L: int,
        mlp_sizes: List[int],
        output_dim: int,
        *,
        seed: Optional[int] = None,
        pe_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        self.pe = RegularSolidHarmonicsPE(
            L=L,
            seed=seed,
            **(pe_kwargs or {}),
        )
        # MLP input width follows the channel count reported by the PE.
        self.mlp = SineMLP(
            input_features=self.pe.num_atoms,
            output_features=output_dim,
            hidden_sizes=mlp_sizes,
            **(mlp_kwargs or {}),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert (r, θ, φ) to Cartesian, encode, then apply the SineMLP."""
        cartesian = rtp_to_r3(x)
        features = self.pe(cartesian)
        return self.mlp(features)