Source code for spherical_inr.mlp
r"""MLP backbones (ReLU or sine) for parameterizing implicit neural representations."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List
__all__ = ["ReLUMLP", "SineMLP"]
[docs]
class ReLUMLP(nn.Module):
r"""
ReLU-activated multi-layer perceptron.
Hidden layers apply:
.. math::
h_k = \mathrm{ReLU}(W_k h_{k-1} + b_k), \quad k=1,\dots,L-1,
and the output layer is linear:
.. math::
f_\theta(x) = W_L h_{L-1} + b_L.
Parameters
----------
input_features:
Input dimension.
output_features:
Output dimension.
hidden_sizes:
List of hidden layer widths.
bias:
Whether to include biases in each linear layer.
"""
def __init__(
self,
input_features: int,
output_features: int,
hidden_sizes: List[int],
bias: bool = True,
):
super().__init__()
self.input_features = int(input_features)
self.output_features = int(output_features)
sizes = [self.input_features] + list(hidden_sizes) + [self.output_features]
self.layers = nn.ModuleList(
nn.Linear(sizes[i], sizes[i + 1], bias=bias) for i in range(len(sizes) - 1)
)
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ReLU-activated MLP.
Applies a sequence of linear layers with ReLU activation on all hidden
layers, followed by a final linear output layer without activation.
Parameters
----------
x : torch.Tensor
Input tensor of shape ``(..., input_features)``.
Returns
-------
torch.Tensor
Output tensor of shape ``(..., output_features)``.
"""
for layer in self.layers[:-1]:
x = F.relu(layer(x))
return self.layers[-1](x)
[docs]
class SineMLP(nn.Module):
r"""
Sine-activated multi-layer perceptron.
This module is identical to :class:`ReluMLP` except the hidden activation is
a sine nonlinearity with frequency factor :math:`\omega_0`.
Given an input :math:`x`, the hidden activations are
.. math::
h_0 = x, \qquad
h_k = \sin\!\bigl(\omega_0 (W_k h_{k-1} + b_k)\bigr),
\quad k=1,\dots,L-1,
and the output layer is linear:
.. math::
f_\theta(x) = W_L h_{L-1} + b_L.
The weights are initialized uniformly (per layer) as
.. math::
W_k \sim \mathcal{U}\!\left[-\frac{\sqrt{6/n_k}}{\omega_0},
\frac{\sqrt{6/n_k}}{\omega_0}\right],
where :math:`n_k` is the fan-in (number of input features) of layer :math:`k`.
Biases are initialized to zero when present.
Parameters
----------
input_features:
Input dimension.
output_features:
Output dimension.
hidden_sizes:
List of hidden layer widths.
bias:
Whether to include biases in each linear layer.
omega0:
Frequency factor :math:`\omega_0` used in the sine activation and in the
weight initialization bound.
"""
def __init__(
self,
input_features: int,
output_features: int,
hidden_sizes: List[int],
bias: bool = True,
omega0: float = 1.0,
) -> None:
super().__init__()
self.input_features = input_features
self.output_features = output_features
self.bias = bias
sizes = [input_features] + hidden_sizes + [output_features]
self.hidden_layers = nn.ModuleList(
nn.Linear(sizes[i], sizes[i + 1], bias=bias) for i in range(len(sizes) - 1)
)
self.omega0 = omega0
self.reset_parameters()
def reset_parameters(
self,
) -> None:
with torch.no_grad():
for layer in self.hidden_layers:
fan_in = layer.weight.size(1)
bound = math.sqrt(6 / fan_in) / self.omega0
layer.weight.uniform_(-bound, bound)
if layer.bias is not None:
nn.init.constant_(layer.bias, 0)
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the sine-activated MLP.
Applies sine nonlinearities with frequency scaling ``omega0`` after each
hidden linear layer, followed by a final linear output layer.
Parameters
----------
x : torch.Tensor
Input tensor of shape ``(..., input_features)``.
Returns
-------
torch.Tensor
Output tensor of shape ``(..., output_features)``.
"""
for layer in self.hidden_layers[:-1]:
x = torch.sin(self.omega0 * layer(x))
return self.hidden_layers[-1](x)