"""Source code for ``spherical_inr.loss``: PyTorch loss modules built on ``spherical_inr.differentiation``."""

import torch
import torch.nn as nn

import spherical_inr.differentiation as D
from typing import Optional


class SphericalLaplacianLoss(nn.Module):
    r"""Squared spherical-Laplacian penalty.

    For a scalar field :math:`f` expressed in spherical coordinates
    :math:`(r, \theta, \phi)`, the spherical Laplacian is

    .. math::
        \Delta_{sph} f = \frac{1}{r^2}\frac{\partial}{\partial r}\left( r^2\,\frac{\partial f}{\partial r} \right)
        + \frac{1}{r^2 \sin\theta}\frac{\partial}{\partial \theta}\left( \sin\theta\,\frac{\partial f}{\partial \theta} \right)
        + \frac{1}{r^2 \sin^2\theta}\frac{\partial^2 f}{\partial \phi^2}.

    This module returns the mean of its square:

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( (\Delta_{sph} f)^2 \Bigr).

    The mean is taken over dimension 0 only. Any trailing dimensions of the
    Laplacian tensor are kept in the result.
    """

    def forward(self, output: torch.Tensor, input: torch.Tensor) -> torch.Tensor:
        r"""Average the squared spherical Laplacian of ``output`` over dimension 0.

        Parameters
        ----------
        output : torch.Tensor
            Network output :math:`f`. Presumably computed from ``input`` so
            that autograd can differentiate it; confirm with the caller.
        input : torch.Tensor
            Coordinates at which ``output`` was evaluated. The formula assumes
            :math:`(r, \theta, \phi)`; the exact layout is defined by
            ``D.spherical_laplacian``.

        Returns
        -------
        torch.Tensor
            ``mean(lap ** 2, dim=0)``, where ``lap`` comes from
            ``D.spherical_laplacian``.
        """
        # track=True is forwarded to the differentiation helper. It presumably
        # keeps the derivative graph so this loss can itself be backpropagated.
        laplacian = D.spherical_laplacian(output, input, track=True)
        return torch.mean(laplacian ** 2, dim=0)
class CartesianLaplacianLoss(nn.Module):
    r"""Squared Cartesian-Laplacian penalty.

    For a scalar field :math:`f` on :math:`\mathbb{R}^n`, the Cartesian
    Laplacian is

    .. math::
        \Delta f = \sum_{i=1}^{n} \frac{\partial^2 f}{\partial x_i^2}.

    This module returns the mean of its square:

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( (\Delta f)^2 \Bigr).

    The mean runs over dimension 0 only, so any trailing dimensions of the
    Laplacian tensor are kept in the result.
    """

    def forward(self, output: torch.Tensor, input: torch.Tensor) -> torch.Tensor:
        """Average the squared Cartesian Laplacian of ``output`` over dimension 0.

        Parameters
        ----------
        output : torch.Tensor
            Network output. Presumably computed from ``input`` so autograd can
            differentiate it; confirm with the caller.
        input : torch.Tensor
            Cartesian coordinates at which ``output`` was evaluated.

        Returns
        -------
        torch.Tensor
            ``mean(lap ** 2, dim=0)``, where ``lap`` comes from
            ``D.cartesian_laplacian``.
        """
        laplacian = D.cartesian_laplacian(output, input, track=True)
        squared = laplacian ** 2
        return squared.mean(dim=0)
class S2LaplacianLoss(nn.Module):
    r"""Squared Laplacian penalty on the 2-sphere.

    For a function :math:`f(\theta, \phi)` defined on :math:`S^2`, the
    Laplacian is

    .. math::
        \Delta_{S^2} f = \frac{1}{\sin\theta}\frac{\partial}{\partial \theta}\left( \sin\theta\,\frac{\partial f}{\partial \theta} \right)
        + \frac{1}{\sin^2\theta}\frac{\partial^2 f}{\partial \phi^2}.

    This module returns

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( (\Delta_{S^2} f)^2 \Bigr),

    with the mean taken over dimension 0 only.
    """

    def forward(self, output: torch.Tensor, input: torch.Tensor) -> torch.Tensor:
        r"""Average the squared S2 Laplacian of ``output`` over dimension 0.

        Parameters
        ----------
        output : torch.Tensor
            Network output :math:`f`. Presumably produced from ``input`` so
            autograd can differentiate it; confirm with the caller.
        input : torch.Tensor
            Angular coordinates on the sphere. The formula assumes
            :math:`(\theta, \phi)`; the exact layout is defined by
            ``D.s2_laplacian``.

        Returns
        -------
        torch.Tensor
            ``mean(lap ** 2, dim=0)``, where ``lap`` comes from ``D.s2_laplacian``.
        """
        return torch.mean(D.s2_laplacian(output, input, track=True) ** 2, dim=0)
class CartesianGradientMSELoss(nn.Module):
    r"""Mean squared error between a Cartesian gradient and a target.

    For :math:`f` on :math:`\mathbb{R}^n` with gradient

    .. math::
        \nabla f = \left( \frac{\partial f}{\partial x_1},\, \frac{\partial f}{\partial x_2},\, \dots,\, \frac{\partial f}{\partial x_n} \right),

    the loss is

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( \sum_{i=1}^{n}\Bigl( \frac{\partial f}{\partial x_i} - t_i \Bigr)^2 \Bigr),

    where :math:`t` is the target gradient. The sum runs over the last axis
    and the mean over dimension 0.
    """

    def forward(
        self, target: torch.Tensor, output: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        """Compare the Cartesian gradient of ``output`` against ``target``.

        Parameters
        ----------
        target : torch.Tensor
            Target gradient. It must broadcast against the tensor returned
            by ``D.cartesian_gradient``.
        output : torch.Tensor
            Network output. Presumably computed from ``input``; confirm with
            the caller.
        input : torch.Tensor
            Cartesian coordinates at which ``output`` was evaluated.

        Returns
        -------
        torch.Tensor
            Squared residual summed over the last axis, then averaged over
            dimension 0.
        """
        gradient = D.cartesian_gradient(output, input, track=True)
        residual = gradient - target
        # Sum squared components along the last axis, then average over dim 0.
        per_sample = torch.sum(residual ** 2, dim=-1)
        return torch.mean(per_sample, dim=0)
class SphericalGradientMSELoss(nn.Module):
    r"""Mean squared error between a spherical gradient and a target.

    For :math:`f` in spherical coordinates :math:`(r, \theta, \phi)`, the
    spherical gradient is

    .. math::
        \nabla_{sph} f = \left( \frac{\partial f}{\partial r},\, \frac{1}{r}\frac{\partial f}{\partial \theta},\, \frac{1}{r\,\sin\theta}\frac{\partial f}{\partial \phi} \right),

    and the loss is

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( \sum_{i}\Bigl( (\nabla_{sph} f)_i - t_i \Bigr)^2 \Bigr),

    with :math:`t` the target gradient. The sum runs over the last axis and
    the mean over dimension 0.
    """

    def forward(
        self, target: torch.Tensor, output: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        r"""Compare the spherical gradient of ``output`` against ``target``.

        Parameters
        ----------
        target : torch.Tensor
            Target gradient. It must broadcast against the tensor returned
            by ``D.spherical_gradient``.
        output : torch.Tensor
            Network output. Presumably computed from ``input``; confirm with
            the caller.
        input : torch.Tensor
            Spherical coordinates. The formula assumes :math:`(r, \theta, \phi)`;
            the exact layout is defined by ``D.spherical_gradient``.

        Returns
        -------
        torch.Tensor
            Squared residual summed over the last axis, then averaged over
            dimension 0.
        """
        gradient = D.spherical_gradient(output, input, track=True)
        squared_error = (gradient - target) ** 2
        return squared_error.sum(dim=-1).mean(dim=0)
class S2GradientMSELoss(nn.Module):
    r"""Mean squared error between a gradient on the 2-sphere and a target.

    For :math:`f(\theta, \phi)` on :math:`S^2`, the gradient is

    .. math::
        \nabla_{S^2} f = \left( \frac{\partial f}{\partial \theta},\, \frac{1}{\sin\theta}\frac{\partial f}{\partial \phi} \right),

    and the loss is

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( \sum_{i=1}^{2}\Bigl( (\nabla_{S^2} f)_i - t_i \Bigr)^2 \Bigr),

    with :math:`t` the target gradient. The sum runs over the last axis and
    the mean over dimension 0.
    """

    def forward(
        self, target: torch.Tensor, output: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        r"""Compare the S2 gradient of ``output`` against ``target``.

        Parameters
        ----------
        target : torch.Tensor
            Target gradient. It must broadcast against the tensor returned
            by ``D.s2_gradient``.
        output : torch.Tensor
            Network output. Presumably computed from ``input``; confirm with
            the caller.
        input : torch.Tensor
            Angular coordinates. The formula assumes :math:`(\theta, \phi)`;
            the exact layout is defined by ``D.s2_gradient``.

        Returns
        -------
        torch.Tensor
            Squared residual summed over the last axis, then averaged over
            dimension 0.
        """
        gradient = D.s2_gradient(output, input, track=True)
        residual = gradient - target
        return torch.mean(torch.sum(residual ** 2, dim=-1), dim=0)
class CartesianGradientLaplacianMSELoss(nn.Module):
    r"""Cartesian gradient MSE plus a squared-Laplacian regularizer.

    For :math:`f` on :math:`\mathbb{R}^n`, with

    .. math::
        \nabla f = \left( \frac{\partial f}{\partial x_1},\, \dots,\, \frac{\partial f}{\partial x_n} \right)
        \quad \text{and} \quad
        \Delta f = \sum_{i=1}^{n} \frac{\partial^2 f}{\partial x_i^2},

    the loss is

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( \sum_{i=1}^{n}\Bigl( \frac{\partial f}{\partial x_i} - t_i \Bigr)^2 \Bigr)
        \;+\; \alpha\,\operatorname{mean}\Bigl( (\Delta f)^2 \Bigr),

    where :math:`t` is the target gradient and :math:`\alpha` (``alpha_reg``)
    weights the regularizer. Both means are taken over dimension 0.

    Parameters
    ----------
    alpha_reg : float, default 1.0
        Regularization weight, stored as the buffer ``alpha_reg``.
    """

    def __init__(self, alpha_reg: float = 1.0):
        super().__init__()
        # A buffer, not a Parameter: it moves with .to()/.cuda() and is saved
        # in the state dict, but the optimizer does not update it.
        weight = torch.tensor(alpha_reg)
        self.register_buffer("alpha_reg", weight)

    def forward(
        self, target: torch.Tensor, output: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        """Return the gradient-fit term plus ``alpha_reg`` times the Laplacian penalty.

        Parameters
        ----------
        target : torch.Tensor
            Target gradient. It must broadcast against ``D.cartesian_gradient``'s
            result.
        output : torch.Tensor
            Network output. Presumably computed from ``input``; confirm with
            the caller.
        input : torch.Tensor
            Cartesian coordinates at which ``output`` was evaluated.

        Returns
        -------
        torch.Tensor
            ``mean(sum((grad - target) ** 2, -1), 0) + alpha_reg * mean(lap ** 2, 0)``.
        """
        gradient = D.cartesian_gradient(output, input, track=True)
        # The Laplacian is obtained as div(grad f), reusing the gradient above.
        laplacian = D.cartesian_divergence(gradient, input, track=True)
        fit = torch.mean(torch.sum((gradient - target) ** 2, dim=-1), dim=0)
        penalty = torch.mean(laplacian ** 2, dim=0)
        return fit + self.alpha_reg * penalty
class SphericalGradientLaplacianMSELoss(nn.Module):
    r"""Spherical gradient MSE plus a squared spherical-Laplacian regularizer.

    For :math:`f` in spherical coordinates :math:`(r, \theta, \phi)`, with

    .. math::
        \nabla_{sph} f = \left( \frac{\partial f}{\partial r},\, \frac{1}{r}\frac{\partial f}{\partial \theta},\, \frac{1}{r\,\sin\theta}\frac{\partial f}{\partial \phi} \right)

    and

    .. math::
        \Delta_{sph} f = \frac{1}{r^2}\frac{\partial}{\partial r}\left( r^2\,\frac{\partial f}{\partial r} \right)
        + \frac{1}{r^2 \sin\theta}\frac{\partial}{\partial \theta}\left( \sin\theta\,\frac{\partial f}{\partial \theta} \right)
        + \frac{1}{r^2 \sin^2\theta}\frac{\partial^2 f}{\partial \phi^2},

    the loss is

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( \sum_{i}\Bigl( (\nabla_{sph} f)_i - t_i \Bigr)^2 \Bigr)
        \;+\; \alpha\,\operatorname{mean}\Bigl( (\Delta_{sph} f)^2 \Bigr),

    where :math:`t` is the target gradient and :math:`\alpha` (``alpha_reg``)
    is the regularization weight. Both means are taken over dimension 0.

    Parameters
    ----------
    alpha_reg : float, default 1.0
        Regularization weight, stored as the buffer ``alpha_reg``.
    """

    def __init__(self, alpha_reg: float = 1.0):
        super().__init__()
        # Kept as a buffer so it follows device/dtype moves and is saved in
        # the state dict, without being exposed to the optimizer.
        coefficient = torch.tensor(alpha_reg)
        self.register_buffer("alpha_reg", coefficient)

    def forward(
        self, target: torch.Tensor, output: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        r"""Return the spherical gradient-fit term plus the weighted Laplacian penalty.

        Parameters
        ----------
        target : torch.Tensor
            Target gradient. It must broadcast against ``D.spherical_gradient``'s
            result.
        output : torch.Tensor
            Network output. Presumably computed from ``input``; confirm with
            the caller.
        input : torch.Tensor
            Spherical coordinates. The formula assumes :math:`(r, \theta, \phi)`;
            the exact layout is defined in ``spherical_inr.differentiation``.

        Returns
        -------
        torch.Tensor
            ``mean(sum((grad - target) ** 2, -1), 0) + alpha_reg * mean(lap ** 2, 0)``.
        """
        gradient = D.spherical_gradient(output, input, track=True)
        # The Laplacian is computed as the spherical divergence of the gradient.
        laplacian = D.spherical_divergence(gradient, input, track=True)
        data_term = ((gradient - target) ** 2).sum(dim=-1).mean(dim=0)
        smoothness_term = (laplacian ** 2).mean(dim=0)
        return data_term + self.alpha_reg * smoothness_term
class S2GradientLaplacianMSELoss(nn.Module):
    r"""S2 gradient MSE plus a squared S2-Laplacian regularizer.

    For :math:`f(\theta, \phi)` on the 2-sphere, with

    .. math::
        \nabla_{S^2} f = \left( \frac{\partial f}{\partial \theta},\, \frac{1}{\sin\theta}\frac{\partial f}{\partial \phi} \right)

    and

    .. math::
        \Delta_{S^2} f = \frac{1}{\sin\theta}\frac{\partial}{\partial \theta}\left( \sin\theta\,\frac{\partial f}{\partial \theta} \right)
        + \frac{1}{\sin^2\theta}\frac{\partial^2 f}{\partial \phi^2},

    the loss is

    .. math::
        \mathcal{L} = \operatorname{mean}\Bigl( \sum_{i=1}^{2}\Bigl( (\nabla_{S^2} f)_i - t_i \Bigr)^2 \Bigr)
        \;+\; \alpha\,\operatorname{mean}\Bigl( (\Delta_{S^2} f)^2 \Bigr),

    where :math:`t` is the target gradient and :math:`\alpha` (``alpha_reg``)
    is the regularization weight. Both means are taken over dimension 0.

    Parameters
    ----------
    alpha_reg : float, default 1.0
        Regularization weight, stored as the buffer ``alpha_reg``.
    """

    def __init__(self, alpha_reg: float = 1.0):
        super().__init__()
        # Registered as a buffer: saved in the state dict and moved with the
        # module, but not returned by .parameters().
        self.register_buffer("alpha_reg", torch.tensor(alpha_reg))

    def forward(
        self, target: torch.Tensor, output: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        r"""Return the S2 gradient-fit term plus the weighted Laplacian penalty.

        Parameters
        ----------
        target : torch.Tensor
            Target gradient. It must broadcast against ``D.s2_gradient``'s result.
        output : torch.Tensor
            Network output. Presumably computed from ``input``; confirm with
            the caller.
        input : torch.Tensor
            Angular coordinates. The formula assumes :math:`(\theta, \phi)`;
            the exact layout is defined in ``spherical_inr.differentiation``.

        Returns
        -------
        torch.Tensor
            ``mean(sum((grad - target) ** 2, -1), 0) + alpha_reg * mean(lap ** 2, 0)``.
        """
        grad_s2 = D.s2_gradient(output, input, track=True)
        # The Laplacian on S^2 is computed as the divergence of the S^2 gradient.
        lap_s2 = D.s2_divergence(grad_s2, input, track=True)
        mismatch = torch.sum((grad_s2 - target) ** 2, dim=-1)
        fit = torch.mean(mismatch, dim=0)
        return fit + self.alpha_reg * torch.mean(lap_s2 ** 2, dim=0)