Source code for spherical_inr.diffops

r"""Differential operators (gradient, divergence, Laplacian) in Cartesian and spherical coordinates."""

import torch
import warnings

__all__ = [
    "cartesian_gradient",
    "spherical_gradient",
    "s2_gradient",
    "cartesian_divergence",
    "spherical_divergence",
    "s2_divergence",
    "cartesian_laplacian",
    "spherical_laplacian",
    "s2_laplacian",
]


def _gradient(
    outputs: torch.Tensor,
    inputs: torch.Tensor,
    create_graph: bool = False,
    retain_graph: bool = False,
) -> torch.Tensor:

    grad = torch.autograd.grad(
        outputs,
        inputs,
        grad_outputs=torch.ones_like(outputs),
        create_graph=create_graph,
        retain_graph=retain_graph,
        allow_unused=True,
    )[0]

    if grad is None:
        warnings.warn(
            "Computed _gradient is None; replacing with zeros", RuntimeWarning
        )
        grad = (
            torch.zeros_like(inputs)
            if not create_graph
            else torch.zeros_like(inputs).requires_grad_(True)
        )

    return grad


[docs] def cartesian_gradient( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False ) -> torch.Tensor: r"""Compute the gradient of a scalar function in Cartesian coordinates. For a scalar function :math:`f: \mathbb{R}^n \rightarrow \mathbb{R}`, the gradient is given by .. math:: \nabla f(x) = \left[ \frac{\partial f}{\partial x_1},\, \frac{\partial f}{\partial x_2},\, \dots,\, \frac{\partial f}{\partial x_n} \right]. This function computes :math:`\nabla f(x)` with respect to the Cartesian coordinates provided in ``inputs``. Enabling the ``track`` parameter allows for the construction of higher-order derivative graphs. Parameters ---------- outputs: torch.Tensor Tensor representing the function values :math:`f(x)`. inputs: torch.Tensor Tensor representing the Cartesian coordinates :math:`x`. track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The gradient :math:`\nabla f(x)` in Cartesian coordinates. """ return _gradient(outputs, inputs, create_graph=track, retain_graph=track)
[docs] def spherical_gradient( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False ) -> torch.Tensor: r"""Compute the gradient of a function defined in spherical coordinates :math:`(r, \theta, \phi)`. For a function :math:`f(r, \theta, \phi)`, the spherical gradient is defined as .. math:: \nabla f = \hat{r}\,\frac{\partial f}{\partial r} \;+\; \hat{\theta}\,\frac{1}{r}\frac{\partial f}{\partial \theta} \;+\; \hat{\phi}\,\frac{1}{r\,\sin\theta}\frac{\partial f}{\partial \phi}. This function computes the gradient of :math:`f(r,\theta,\phi)` with respect to the spherical coordinates. The input tensor ``inputs`` must have three components representing :math:`[r, \theta, \phi]`. Parameters ---------- outputs: torch.Tensor Tensor representing the function values :math:`f(r, \theta, \phi)`. inputs: torch.Tensor Tensor of shape (..., 3) representing spherical coordinates :math:`[r, \theta, \phi]`. track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The spherical gradient with components scaled as above. Raises ------ ValueError: If ``inputs`` does not have three components. """ if inputs.size(-1) != 3: raise ValueError( "Spherical gradient is only defined for 3D spherical (r, θ, φ) coordinates" ) grad = _gradient(outputs, inputs, create_graph=track, retain_graph=track) r = inputs[..., 0] theta = inputs[..., 1] with torch.set_grad_enabled(track): grad = torch.stack( [ grad[..., 0], grad[..., 1] / r, grad[..., 2] / (r * torch.sin(theta)), ], dim=-1, ) return grad
[docs] def s2_gradient( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False ) -> torch.Tensor: r"""Compute the gradient of a function defined on the 2-sphere (S²) with respect to the angles :math:`(\theta, \phi)`. For a function :math:`f(\theta, \phi)` defined on S², the gradient is expressed as .. math:: \nabla_{\mathbb{S}^2} f = \hat{\theta}\,\frac{\partial f}{\partial \theta} \;+\; \hat{\phi}\,\frac{1}{\sin\theta}\frac{\partial f}{\partial \phi}. This function computes the gradient of :math:`f(\theta, \phi)` on S². The input tensor ``inputs`` must have two components corresponding to :math:`(\theta, \phi)`. Parameters ---------- outputs: torch.Tensor Tensor representing the function values :math:`f(\theta, \phi)`. inputs: torch.Tensor Tensor of shape (..., 2) representing spherical coordinates :math:`(\theta, \phi)`. track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The gradient on S², :math:`\nabla_{\mathbb{S}^2} f`. Raises ------ ValueError: If ``inputs`` does not have two components. """ if inputs.size(-1) != 2: raise ValueError( "S2 gradient is only defined for 2D spherical (θ, φ) coordinates" ) grad = _gradient(outputs, inputs, create_graph=track, retain_graph=track) theta = inputs[..., 0] with torch.set_grad_enabled(track): grad = torch.stack( [ grad[..., 0], grad[..., 1] / (torch.sin(theta)), ], dim=-1, ) return grad
[docs] def cartesian_divergence( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False ) -> torch.Tensor: r"""Compute the divergence of a vector field in Cartesian coordinates. For a vector field :math:`\mathbf{F}: \mathbb{R}^n \rightarrow \mathbb{R}^n` with components :math:`\mathbf{F}(x) = \left[F_1(x),\, F_2(x),\, \dots,\, F_n(x)\right]`, the divergence is defined as .. math:: \nabla \cdot \mathbf{F} = \sum_{i=1}^{n} \frac{\partial F_i}{\partial x_i}. This function computes the divergence by summing the partial derivatives of each component of the vector field with respect to its corresponding Cartesian coordinate. Parameters ---------- outputs: torch.Tensor Tensor representing the vector field, where the last dimension contains the components of :math:`\mathbf{F}`. inputs: torch.Tensor Tensor representing the Cartesian coordinates :math:`x`. track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The divergence :math:`\nabla \cdot \mathbf{F}`. """ outputs_to_grad = [outputs[..., i] for i in range(outputs.size(-1))] div = torch.zeros_like(outputs[..., 0]) for i, out in enumerate(outputs_to_grad): div += _gradient( out, inputs, create_graph=track, retain_graph=True if i < outputs.size(-1) - 1 else track, )[..., i] return div
[docs] def spherical_divergence( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False ) -> torch.Tensor: r"""Compute the divergence of a vector field in spherical coordinates. For a vector field :math:`\mathbf{F}(r,\theta,\phi) = \left[F_r,\, F_\theta,\, F_\phi\right]`, the divergence is given by .. math:: \nabla \cdot \mathbf{F} = \frac{1}{r^2}\frac{\partial}{\partial r}\Bigl(r^2 F_r\Bigr) \;+\; \frac{1}{r \sin\theta}\frac{\partial}{\partial \theta}\Bigl(\sin\theta F_\theta\Bigr) \;+\; \frac{1}{r \sin\theta}\frac{\partial F_\phi}{\partial \phi}. This function computes the divergence of the vector field in spherical coordinates. The tensor ``outputs`` must have three components. Parameters ---------- outputs: torch.Tensor Tensor of shape (..., 3) representing the vector field in spherical coordinates. inputs: torch.Tensor Tensor of shape (..., 3) representing spherical coordinates :math:`[r, \theta, \phi]`. track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The divergence :math:`\nabla \cdot \mathbf{F}` in spherical coordinates. Raises ------ ValueError: If ``outputs`` does not have three components. """ if outputs.size(-1) != 3: raise ValueError( "Spherical divergence is only defined for (r_hat, θ_hat, φ_hat) vector fields." ) r = inputs[..., 0] theta = inputs[..., 1] sin_theta = torch.sin(theta) r_sin_theta = r * sin_theta r2 = r**2 # Combine gradient computations outputs_to_grad = [ r2 * outputs[..., 0], sin_theta * outputs[..., 1], outputs[..., 2], ] scaling_factors = [1 / r2, 1 / r_sin_theta, 1 / r_sin_theta] div = torch.zeros_like(outputs[..., 0]) for i, (out, scaling_factors) in enumerate(zip(outputs_to_grad, scaling_factors)): grad = _gradient( out, inputs, create_graph=track, retain_graph=True if i < outputs.size(-1) - 1 else track, )[..., i] with torch.set_grad_enabled(track): div += grad * scaling_factors return div
[docs] def s2_divergence( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False ) -> torch.Tensor: r"""Compute the divergence of a vector field defined on the 2-sphere (S²). For a vector field on S², :math:`\mathbf{F}(\theta,\phi) = \left[F_\theta,\, F_\phi\right]`, the divergence is defined as .. math:: \nabla_{\mathbb{S}^2} \cdot \mathbf{F} = \frac{1}{\sin\theta}\frac{\partial}{\partial \theta}\Bigl(\sin\theta F_\theta\Bigr) \;+\; \frac{1}{\sin\theta}\frac{\partial F_\phi}{\partial \phi}. This function computes the divergence of the vector field on S². The tensor ``outputs`` must have two components. Parameters ---------- outputs: torch.Tensor Tensor representing the vector field on S² with two components. inputs: torch.Tensor Tensor of shape (..., 2) representing spherical coordinates :math:`(\theta, \phi)`. track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The divergence :math:`\nabla_{\mathbb{S}^2} \cdot \mathbf{F}` on the 2-sphere. Raises ------ ValueError: If ``outputs`` does not have two components. """ if outputs.size(-1) != 2: raise ValueError( "Spherical divergence is only defined for s2 (θ_hat, φ_hat) vector fields." ) theta = inputs[..., 0] sin_theta = torch.sin(theta) # Combine gradient computations outputs_to_grad = [ sin_theta * outputs[..., 0], outputs[..., 1], ] scaling_factors = [1 / sin_theta, 1 / sin_theta] div = torch.zeros_like(outputs[..., 0]) for i, (out, scaling_factors) in enumerate(zip(outputs_to_grad, scaling_factors)): grad = _gradient( out, inputs, create_graph=track, retain_graph=True if i < outputs.size(-1) - 1 else track, )[..., i] with torch.set_grad_enabled(track): div += grad * scaling_factors return div
[docs] def cartesian_laplacian( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False, ) -> torch.Tensor: r"""Compute the Laplacian of a scalar function in Cartesian coordinates. The Laplacian is defined as the divergence of the gradient: .. math:: \Delta f = \nabla \cdot \left(\nabla f\right) \;=\; \sum_{i=1}^{n} \frac{\partial^2 f}{\partial x_i^2}. This function computes the gradient of :math:`f(x)` with respect to the Cartesian coordinates ``x`` and then evaluates its divergence to obtain the Laplacian. The ``track`` parameter enables higher-order derivative tracking. Parameters ---------- outputs: torch.Tensor Tensor representing the function values :math:`f(x)`. inputs: torch.Tensor Tensor representing Cartesian coordinates :math:`x`. track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The Laplacian :math:`\Delta f` of the function. """ grad = cartesian_gradient(outputs, inputs, track=True) laplacian = cartesian_divergence(grad, inputs, track=track) return laplacian
[docs] def spherical_laplacian( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False, ) -> torch.Tensor: r"""Compute the Laplacian of a function defined in spherical coordinates :math:`(r, \theta, \phi)`. The Laplacian is computed as the divergence of the gradient: .. math:: \Delta f = \nabla \cdot \left(\nabla f\right). In spherical coordinates, the Laplacian of :math:`f(r,\theta,\phi)` is expressed as .. math:: \Delta f = \frac{1}{r^2}\frac{\partial}{\partial r}\Bigl(r^2 \frac{\partial f}{\partial r}\Bigr) \;+\; \frac{1}{r^2 \sin\theta}\frac{\partial}{\partial \theta}\Bigl(\sin\theta \frac{\partial f}{\partial \theta}\Bigr) \;+\; \frac{1}{r^2 \sin^2\theta}\frac{\partial^2 f}{\partial \phi^2}. This function computes the gradient of :math:`f(r,\theta,\phi)` in spherical coordinates and then its divergence to obtain the Laplacian. Parameters ---------- outputs: torch.Tensor Tensor representing the function values :math:`f(r,\theta,\phi)`. inputs: torch.Tensor Tensor representing spherical coordinates :math:`[r, \theta, \phi]`. track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The Laplacian :math:`\Delta f` in spherical coordinates. """ grad = spherical_gradient(outputs, inputs, track=True) laplacian = spherical_divergence(grad, inputs, track=track) return laplacian
[docs] def s2_laplacian( outputs: torch.Tensor, inputs: torch.Tensor, track: bool = False, ) -> torch.Tensor: r"""Compute the Laplacian of a function defined on the 2-sphere (S²). The Laplacian on S² is defined as the divergence of the gradient on the sphere: .. math:: \Delta_{\mathbb{S}^2} f = \nabla_{\mathbb{S}^2} \cdot \left(\nabla_{\mathbb{S}^2} f\right). For a function :math:`f(\theta,\phi)` on S², the Laplacian expands to .. math:: \Delta_{\mathbb{S}^2} f = \frac{1}{\sin\theta}\frac{\partial}{\partial \theta}\Bigl(\sin\theta \frac{\partial f}{\partial \theta}\Bigr) \;+\; \frac{1}{\sin^2\theta}\frac{\partial^2 f}{\partial \phi^2}. This function computes the gradient of :math:`f(\theta,\phi)` on S² and then evaluates its divergence to obtain the Laplacian. Parameters ---------- outputs: torch.Tensor Tensor representing the function values :math:`f(\theta,\phi)`. inputs: torch.Tensor Tensor representing spherical coordinates :math:`(\theta, \phi)` on S². track: bool, optional If True, enables gradient tracking for higher-order derivatives. Default = ``False`` Returns ------- torch.Tensor: The Laplacian :math:`\Delta_{\mathbb{S}^2} f` on the 2-sphere. """ grad = s2_gradient(outputs, inputs, track=True) laplacian = s2_divergence(grad, inputs, track=track) return laplacian