Inverse Design

Every TorchOptics operation (propagation, modulation, detection) is fully differentiable through torch.autograd. This enables gradient-based optimization of optical systems using the same tools used to train neural networks.
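As a rough illustration of what differentiability buys, the sketch below runs autograd through a toy modulate, propagate, detect chain written in plain PyTorch. It is not TorchOptics code: the Fourier transform is only a stand-in for far-field propagation, and the grid size, beam width, and target pixel are arbitrary assumptions.

```python
import torch

# Toy stand-in for an optical system, NOT TorchOptics' propagation code:
# phase mask -> Fourier transform (far field up to scaling) -> intensity.
n = 32
coords = torch.arange(n, dtype=torch.float64) - (n - 1) / 2
xx, yy = torch.meshgrid(coords, coords, indexing="ij")
amplitude = torch.exp(-(xx**2 + yy**2) / 8.0**2)  # fixed Gaussian input amplitude

# The quantity to optimize: one phase value per pixel.
phase = torch.zeros(n, n, dtype=torch.float64, requires_grad=True)

field = amplitude * torch.exp(1j * phase)                      # modulation
far_field = torch.fft.fftshift(torch.fft.fft2(field, norm="ortho"))  # propagation
intensity = far_field.abs().square()                           # detection

# Scalar objective: fraction of the power landing in one off-axis pixel.
objective = intensity[n // 2, n // 2 + 2] / intensity.sum()
objective.backward()

# phase.grad now holds d(objective)/d(phase) for every pixel.
print(phase.grad.shape)
```

Because every step is an ordinary differentiable tensor operation, a single `backward()` call yields the gradient with respect to all pixels at once, which is what makes optimizer-driven design practical.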

Trainable Properties

Wrap any property value in torch.nn.Parameter to make it learnable.

Modulation data (the most common case):

import torch
from torch.nn import Parameter
from torchoptics.elements import PhaseModulator

slm = PhaseModulator(Parameter(torch.zeros(300, 300)), z=0)

Scalar properties such as focal length, position, or angle:

from torchoptics.elements import Lens, LinearPolarizer

lens = Lens(300, focal_length=Parameter(torch.tensor(100e-3)), z=0)
pol = LinearPolarizer(300, theta=Parameter(torch.tensor(0.0)), z=0)

This works for every registered property (z, spacing, offset, focal_length, theta, phi, etc.). If the value is a Parameter, it is learnable; otherwise it is a fixed buffer.
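The Parameter versus buffer distinction is a general PyTorch mechanism. The sketch below shows it on a hypothetical `ToyElement` class; the class and its `_store` helper are illustrative assumptions and do not reflect how TorchOptics registers properties internally.

```python
import torch
from torch import nn


class ToyElement(nn.Module):
    """Hypothetical element that stores each property as a Parameter or a buffer."""

    def __init__(self, focal_length, z):
        super().__init__()
        self._store("focal_length", focal_length)
        self._store("z", z)

    def _store(self, name, value):
        if isinstance(value, nn.Parameter):
            # Trainable: returned by .parameters() and updated by optimizers.
            self.register_parameter(name, value)
        else:
            # Fixed: moves with .to(device) and is saved in state_dict, but has no gradient.
            self.register_buffer(name, torch.as_tensor(value, dtype=torch.float32))


element = ToyElement(focal_length=nn.Parameter(torch.tensor(100e-3)), z=0.0)

trainable = [name for name, _ in element.named_parameters()]
fixed = [name for name, _ in element.named_buffers()]
print(trainable)  # ['focal_length']
print(fixed)      # ['z']
```

Only the names listed under `named_parameters()` are handed to an optimizer built from `element.parameters()`.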

Parameterization

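For unconstrained phase modulation, PhaseModulator(Parameter(torch.zeros(...))) works directly: gradients flow through the phase values and the optimizer is free to explore all real numbers. No constraint is needed because the phase enters only through \(e^{i\phi}\), which is \(2\pi\)-periodic, so every real value corresponds to a valid modulation.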

When a parameter must stay within a physical range, use torch.nn.utils.parametrize.register_parametrization(). For example, to keep an amplitude modulator’s values in \([0, 1]\), register a sigmoid parametrization:

import torch
import torch.nn.utils.parametrize as parametrize
from torch.nn import Parameter
from torchoptics.elements import AmplitudeModulator

slm = AmplitudeModulator(Parameter(torch.zeros(300, 300)), z=0)
parametrize.register_parametrization(slm, "amplitude", torch.nn.Sigmoid())
# slm.amplitude is always sigmoid(raw) in (0, 1); the optimizer trains the raw logits

The same pattern works for any differentiable constraint: torch.nn.Softplus() for positive-only values, or a custom torch.nn.Module for arbitrary mappings.
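A custom mapping might look like the following sketch, which bounds a scalar between two limits with a scaled sigmoid. `Bounded` and `ToyLens` are hypothetical names used only for illustration, and the limits, learning rate, and out-of-range target are arbitrary; only `register_parametrization` and the `parametrizations.<name>.original` attribute are standard PyTorch.

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


class Bounded(nn.Module):
    """Maps an unconstrained tensor into the open interval (low, high)."""

    def __init__(self, low, high):
        super().__init__()
        self.low, self.high = low, high

    def forward(self, raw):
        return self.low + (self.high - self.low) * torch.sigmoid(raw)


class ToyLens(nn.Module):  # hypothetical stand-in for an optical element
    def __init__(self):
        super().__init__()
        self.focal_length = nn.Parameter(torch.tensor(0.0))  # raw value, not metres


lens = ToyLens()
parametrize.register_parametrization(lens, "focal_length", Bounded(50e-3, 150e-3))

optimizer = torch.optim.SGD(lens.parameters(), lr=1.0)
for _ in range(20):
    optimizer.zero_grad()
    # Pull toward a target far outside the allowed range.
    loss = (lens.focal_length - 1.0).square()
    loss.backward()
    optimizer.step()

# The optimizer updates the unconstrained tensor; the attribute stays bounded.
raw = lens.parametrizations.focal_length.original
print(raw.item(), lens.focal_length.item())
```

The constrained value moves toward the upper limit but cannot cross it, since the optimizer only ever sees the raw tensor.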

Training Loop

The standard PyTorch training loop applies directly:

import torch
from torch.nn import Parameter
import torchoptics
from torchoptics import Field, System
from torchoptics.elements import PhaseModulator
from torchoptics.profiles import gaussian

torchoptics.set_default_spacing(10e-6)
torchoptics.set_default_wavelength(700e-9)

shape = 250
input_field = Field(gaussian(shape, waist_radius=300e-6), z=0)
target_field = Field(gaussian(shape, waist_radius=100e-6), z=0.4).normalize()

system = System(
    PhaseModulator(Parameter(torch.zeros(shape, shape)), z=0.0),
    PhaseModulator(Parameter(torch.zeros(shape, shape)), z=0.2),
)

optimizer = torch.optim.Adam(system.parameters(), lr=0.1)

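for iteration in range(200):
    optimizer.zero_grad()  # clear gradients from the previous iteration
    output = system.measure_at_z(input_field, z=0.4)  # propagate to the target plane
    loss = 1 - output.inner(target_field).abs().square()  # 1 - |overlap|^2
    loss.backward()  # backpropagate through the whole optical system
    optimizer.step()  # update both phase masks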

inner() computes the complex overlap integral between two fields. For fields normalized to unit power, the loss \(1 - |\eta|^2\), where \(\eta\) is the inner product, is zero for perfect overlap (up to a global phase) and one for orthogonal fields. Any differentiable PyTorch loss can be used here.
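To make the behavior of this loss concrete, here is a plain-PyTorch sketch of a discretized overlap integral. It does not call TorchOptics: `overlap_loss` and `normalize` are hypothetical helpers, the choice to conjugate the first argument is an assumption (it does not affect \(|\eta|^2\)), and the grid parameters are arbitrary.

```python
import torch


def normalize(field, dx):
    """Scale a sampled field to unit power on a square grid with pixel pitch dx."""
    power = field.abs().square().sum() * dx**2
    return field / power.sqrt()


def overlap_loss(a, b, dx):
    """1 - |eta|^2, with eta the discrete overlap integral of a and b."""
    eta = (a.conj() * b).sum() * dx**2
    return 1 - eta.abs().square()


n, dx, waist = 64, 10e-6, 100e-6
coords = (torch.arange(n, dtype=torch.float64) - (n - 1) / 2) * dx
x, y = torch.meshgrid(coords, coords, indexing="ij")
envelope = torch.exp(-(x**2 + y**2) / waist**2)

even_mode = normalize(envelope.to(torch.complex128), dx)       # Gaussian
odd_mode = normalize((x * envelope).to(torch.complex128), dx)  # odd in x

global_phase = torch.exp(torch.tensor(0.7j, dtype=torch.complex128))

same = overlap_loss(even_mode, even_mode, dx)                  # identical fields
shifted = overlap_loss(even_mode, even_mode * global_phase, dx)  # global phase only
orthogonal = overlap_loss(even_mode, odd_mode, dx)             # orthogonal by symmetry
print(same.item(), shifted.item(), orthogonal.item())
```

The first two values come out at zero to within floating-point error and the third at one, matching the limits described above. The insensitivity to a global phase is usually desirable, since an overall phase offset has no effect on measured intensity.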

See also

Optimization & Inverse Design — complete end-to-end optimization examples.