torchoptics#

TorchOptics: Differentiable wave optics simulation library built on PyTorch.

Classes#

Field

Optical field class.

OpticsModule

Base class for all optics modules.

PlanarGrid

Base class for TorchOptics classes with 2D planar grid geometries.

SpatialCoherence

Spatial Coherence class.

System

Optical system of elements similar to torch.nn.Sequential.

Functions#

animate_tensor(tensor[, title, xlabel, ylabel, ...])

Animate a 3D tensor over time using matplotlib.

get_default_dtype()

Get the current default dtype value.

get_default_spacing()

Get the current default spacing value.

get_default_wavelength()

Get the current default wavelength value.

set_default_dtype(value)

Set the default dtype value.

set_default_spacing(value)

Set the default spacing value.

set_default_wavelength(value)

Set the default wavelength value.

visualize_tensor(tensor[, title, xlabel, ylabel, ...])

Visualize a 2D real or complex-valued tensor using matplotlib.

Package Contents#

class torchoptics.Field(data, wavelength=None, z=0, spacing=None, offset=None)#

Bases: torchoptics.planar_grid.PlanarGrid

Optical field class.

Parameters:
  • data (Tensor) – The complex-valued field data.

  • wavelength (Scalar | None) – The wavelength of the field. Default: if None, uses a global default (see torchoptics.set_default_wavelength()).

  • z (Scalar) – Position along the z-axis. Default: 0.

  • spacing (Vector2 | None) – Distance between grid points along planar dimensions. Default: if None, uses a global default (see torchoptics.set_default_spacing()).

  • offset (Vector2 | None) – Center coordinates of the plane. Default: (0, 0).

centroid()#

Return the centroid of the intensity.

Return type:

torch.Tensor

copy(**kwargs)#

Create a copy of the field with optionally updated properties.

Parameters:

**kwargs – Properties to update in the copy.

Returns:

A new field with updated properties.

Return type:

Field

inner(other)#

Return the inner product of the field (last two data dimensions) with another field.

Parameters:

other (Field) – The other field.

Returns:

The inner product.

Return type:

Tensor

intensity()#

Return the intensity of the field.

Return type:

torch.Tensor

modulate(modulation_profile)#

Modulate the field by a modulation profile.

Parameters:

modulation_profile (Tensor) – The modulation profile.

Returns:

Modulated field.

Return type:

Field
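
As a rough picture of what modulation does, the sketch below multiplies a small complex grid element by element with a profile of the same shape. It uses plain Python lists rather than the library; the helper name `modulate` here is a hypothetical stand-in, and element-wise multiplication is an assumption about the method's behavior, not something this reference states.

```python
import cmath
import math


def modulate(data, profile):
    """Element-wise product of a field grid with a same-shaped profile."""
    return [[v * m for v, m in zip(row, prow)] for row, prow in zip(data, profile)]


data = [[1 + 0j, 2 + 0j], [0.5 + 0j, 1j]]

# Phase-only mask: every entry has unit magnitude, so intensity is unchanged
phase_mask = [
    [cmath.exp(1j * math.pi / 2), 1 + 0j],
    [cmath.exp(1j * math.pi), 1 + 0j],
]

# Amplitude mask: blocks the right-hand column
aperture = [[1, 0], [1, 0]]

after_phase = modulate(data, phase_mask)
after_aperture = modulate(data, aperture)
```

With the library, the corresponding call would be `field.modulate(modulation_profile)`, which returns a new Field.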

normalize(normalized_power=1.0)#

Normalize the field to a specified power.

Parameters:

normalized_power (Scalar) – The normalized power. Default: 1.0.

Returns:

Normalized field.

Return type:

Field
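
One way to read this: the field is rescaled by a single real factor so that its total power equals `normalized_power`. The sketch below shows that arithmetic on plain Python lists. The helper names and the Riemann-sum definition of power are assumptions made for illustration, not the library's implementation.

```python
import math


def total_power(data, cell_area):
    """Sum of squared magnitudes times the area of one grid cell (assumed)."""
    return sum(abs(v) ** 2 for row in data for v in row) * cell_area


def normalize(data, cell_area, normalized_power=1.0):
    """Scale every sample by the same real factor to hit the target power."""
    scale = math.sqrt(normalized_power / total_power(data, cell_area))
    return [[v * scale for v in row] for row in data]


data = [[3 + 4j, 0j], [0j, 0j]]
cell_area = 1e-10  # e.g. a 10 µm x 10 µm cell, in square meters

normalized = normalize(data, cell_area, normalized_power=2.0)
```

Because the scale factor is real and positive, the phase of each sample is left untouched; only the amplitude changes.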

outer(other)#

Return the outer product of the field (last two data dimensions) with another field.

Parameters:

other (Field) – The other field.

Returns:

The outer product.

Return type:

Tensor

polarized_modulate(polarized_modulation_profile)#

Modulate the field by a polarized modulation profile.

Parameters:

polarized_modulation_profile (Tensor) – The polarized modulation profile.

Returns:

Modulated field.

Return type:

Field

polarized_split()#

Split the field into three polarized fields.

Returns:

The split fields.

Return type:

tuple[Field, Field, Field]

power()#

Return the total power of the field calculated by integrating the intensity over the plane.

Return type:

torch.Tensor
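
To make the integral concrete, the sketch below sums squared magnitudes over a small grid and multiplies by the cell area. It uses plain Python lists, not the library; the helper names `intensity` and `power` are hypothetical stand-ins, and treating the integral as a simple Riemann sum is an assumption rather than something this reference confirms.

```python
def intensity(data):
    """Squared magnitude of each complex sample in a 2D grid (list of rows)."""
    return [[abs(v) ** 2 for v in row] for row in data]


def power(data, spacing):
    """Riemann-sum approximation of the integral of intensity over the plane.

    Assumption: each sample represents one grid cell of area dy * dx.
    """
    dy, dx = spacing
    cell_area = dy * dx
    return sum(sum(row) for row in intensity(data)) * cell_area


# 2x2 field with unit amplitude everywhere, on a 10 µm grid
data = [[1 + 0j, 0 + 1j], [-1 + 0j, 0 - 1j]]
total = power(data, spacing=(10e-6, 10e-6))
```

In the library, the corresponding quantities come from `field.intensity()`, `field.cell_area()`, and `field.power()`.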

propagate(shape, z, spacing=None, offset=None, *, propagation_method='AUTO', asm_pad=None, interpolation_mode='nearest')#

Propagate the field through free-space to a plane defined by the input parameters.

Parameters:
  • shape (Vector2) – Number of grid points along the planar dimensions.

  • z (Scalar) – Position along the z-axis.

  • spacing (Vector2 | None) – Distance between grid points along planar dimensions. Default: if None, uses a global default (see torchoptics.set_default_spacing()).

  • offset (Vector2 | None) – Center coordinates of the plane. Default: (0, 0).

  • propagation_method (str) – The propagation method to use. Default: “AUTO”.

  • asm_pad (Vector2 | None) – The padding size along both planar dimensions for ASM propagation. Default: if None, pads by 2x the input field size in each dimension.

  • interpolation_mode (str) – The interpolation mode to use. Default: “nearest”.

Returns:

Output field after propagating to the plane.

Return type:

Field
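
The `asm_pad` parameter refers to the angular spectrum method (ASM). For orientation, the sketch below evaluates the textbook ASM transfer function for a single spatial frequency. It is not the library's implementation: the function name `asm_transfer` is hypothetical, and how `propagation_method="AUTO"` selects a method is not described in this reference.

```python
import cmath
import math


def asm_transfer(fx, fy, z, wavelength):
    """Textbook angular-spectrum transfer function H(fx, fy) for distance z."""
    # Longitudinal wavenumber; becomes imaginary for evanescent components
    kz = 2 * math.pi * cmath.sqrt(1 / wavelength**2 - fx**2 - fy**2)
    return cmath.exp(1j * kz * z)


wavelength = 700e-9  # meters
z = 1e-6             # propagation distance in meters

# Spatial frequency below 1/wavelength: a propagating plane-wave component
h_propagating = asm_transfer(1e5, 0.0, z, wavelength)

# Spatial frequency above 1/wavelength: an evanescent component that decays with z
h_evanescent = asm_transfer(2e6, 0.0, z, wavelength)
```

Propagating components only pick up a phase (unit magnitude), while evanescent ones are attenuated. In a full ASM implementation this factor multiplies the 2D Fourier transform of the field, and padding the input before the transform reduces wrap-around artifacts.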

propagate_to_plane(plane, *, propagation_method='AUTO', asm_pad=None, interpolation_mode='nearest')#

Propagate the field through free-space to a plane defined by a PlanarGrid object.

Parameters:
  • plane (PlanarGrid) – Plane grid.

  • propagation_method (str) – The propagation method to use. Default: “AUTO”.

  • asm_pad (Vector2 | None) – The padding size along both planar dimensions for ASM propagation. Default: if None, pads by 2x the input field size in each dimension.

  • interpolation_mode (str) – The interpolation mode to use. Default: “nearest”.

Returns:

Output field after propagating to the plane.

Return type:

Field

propagate_to_z(z, *, propagation_method='AUTO', asm_pad=None, interpolation_mode='nearest')#

Propagate the field through free-space to a plane at a specific z position.

The plane has the same shape, spacing, and offset as the input field.

Parameters:
  • z (Scalar) – Position along the z-axis.

  • propagation_method (str) – The propagation method to use. Default: “AUTO”.

  • asm_pad (Vector2 | None) – The padding size along both planar dimensions for ASM propagation. Default: if None, pads by 2x the input field size in each dimension.

  • interpolation_mode (str) – The interpolation mode to use. Default: “nearest”.

Returns:

Output field after propagating to the plane.

Return type:

Field

std()#

Return the standard deviation of the intensity.

Return type:

torch.Tensor
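
The centroid and standard deviation of the intensity can be pictured as intensity-weighted statistics of the grid coordinates. The sketch below computes both on a plain Python grid. The helper names, and the convention that grid coordinates are centered on the offset, are assumptions for illustration and may differ from the library's conventions.

```python
import math


def coords(n, spacing, offset=0.0):
    """Grid-point coordinates centered on `offset` (assumed convention)."""
    return [offset + (i - (n - 1) / 2) * spacing for i in range(n)]


def centroid_and_std(intensity, spacing, offset=(0.0, 0.0)):
    """Intensity-weighted mean and standard deviation along each axis."""
    ys = coords(len(intensity), spacing[0], offset[0])
    xs = coords(len(intensity[0]), spacing[1], offset[1])
    total = sum(sum(row) for row in intensity)
    cy = sum(w * y for row, y in zip(intensity, ys) for w in row) / total
    cx = sum(w * x for row in intensity for w, x in zip(row, xs)) / total
    var_y = sum(w * (y - cy) ** 2 for row, y in zip(intensity, ys) for w in row) / total
    var_x = sum(w * (x - cx) ** 2 for row in intensity for w, x in zip(row, xs)) / total
    return (cy, cx), (math.sqrt(var_y), math.sqrt(var_x))


# All the energy sits in the two outer pixels of the top row
intensity = [
    [1.0, 0.0, 1.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
]
(cy, cx), (sy, sx) = centroid_and_std(intensity, spacing=(1.0, 1.0))
```

Here the centroid lies on the top row, midway between the two bright pixels, and the spread is zero along the first axis and one grid step along the second.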

visualize(*index, **kwargs)#

Visualize the field.

Parameters:
  • *index (int) – Index of the data tensor to visualize.

  • intensity (bool) – Whether to visualize only the intensity. Default: False.

  • **kwargs – Additional keyword arguments for visualization.

Return type:

matplotlib.figure.Figure | None

class torchoptics.OpticsModule#

Bases: torch.nn.Module

Base class for all optics modules.

This class facilitates the registration of tensors, representing optics-related properties, as either PyTorch parameters or buffers. These properties are validated and registered using register_optics_property():

from torchoptics import OpticsModule
from torch.nn import Parameter

class MyOpticsModule(OpticsModule):
    def __init__(self, trainable_property, non_trainable_property):
        super().__init__()
        self.register_optics_property("trainable_property", Parameter(trainable_property), shape=())
        self.register_optics_property("non_trainable_property", non_trainable_property, shape=())

Once the properties are registered, they can be updated using set_optics_property().

Note

__setattr__() is overridden to call set_optics_property() when setting the value of an optics property.

register_optics_property(name, value, **kwargs)#

Register an optics property as a PyTorch parameter or buffer.

Parameters:
  • name (str) – Name of the optics property.

  • value (Any) – Initial value of the property.

  • **kwargs – Additional keyword arguments for tensor initialization (e.g., is_scalar, is_vector2, is_complex, is_positive, is_non_negative).

Return type:

None

set_optics_property(name, value)#

Set the value of an existing optics property.

Parameters:
  • name (str) – Name of the optics property.

  • value (Any) – New value of the property.

Raises:
  • AttributeError – If the property is not registered.

  • ValueError – If the value does not match the property’s conditions.

Return type:

None
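
The register-then-set pattern described above, including the `__setattr__` override, can be illustrated without PyTorch. The class below is a toy stand-in, not the torchoptics class: its name, method names, and the single `is_positive` check are hypothetical, chosen only to mirror the documented behavior (AttributeError for unregistered names, ValueError for invalid values).

```python
class ValidatedProperties:
    """Toy stand-in for the register/set pattern (not the torchoptics class)."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the registry
        object.__setattr__(self, "_validators", {})

    def register_property(self, name, value, is_positive=False):
        """Record the property's condition, then store its initial value."""
        self._validators[name] = is_positive
        self.set_property(name, value)

    def set_property(self, name, value):
        """Validate and store a value for an already registered property."""
        if name not in self._validators:
            raise AttributeError(f"{name!r} is not a registered property")
        if self._validators[name] and value <= 0:
            raise ValueError(f"{name!r} must be positive, got {value}")
        object.__setattr__(self, name, value)

    def __setattr__(self, name, value):
        # Route assignments to registered names through validation
        if name in self._validators:
            self.set_property(name, value)
        else:
            object.__setattr__(self, name, value)


module = ValidatedProperties()
module.register_property("focal_length", 0.2, is_positive=True)
module.focal_length = 0.5  # validated via __setattr__

try:
    module.focal_length = -1.0
    rejected = False
except ValueError:
    rejected = True
```

The benefit of routing plain attribute assignment through the setter is that a property cannot silently drift into an invalid state after registration.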

class torchoptics.PlanarGrid(shape, z=0, spacing=None, offset=None)#

Bases: torchoptics.optics_module.OpticsModule

Base class for TorchOptics classes with 2D planar grid geometries.

This class defines objects with planar geometries perpendicular to the z-axis. It includes methods for calculating various properties of the planar grid.

Parameters:
  • shape (Vector2) – Number of grid points along the planar dimensions.

  • z (Scalar) – Position along the z-axis. Default: 0.

  • spacing (Vector2 | None) – Distance between grid points along planar dimensions. Default: if None, uses a global default (see torchoptics.set_default_spacing()).

  • offset (Vector2 | None) – Center coordinates of the plane. Default: (0, 0).

bounds(use_grid_points=False)#

Return the position of the plane boundaries along the planar dimensions.

Parameters:

use_grid_points (bool) – If True, returns the position of the first and last grid points along the planar dimensions. Otherwise, returns the position of the edges of the first and last grid cells. Default: False.

Return type:

torch.Tensor

cell_area()#

Return the area between adjacent grid points.

Return type:

torch.Tensor

extra_repr()#

Return the extra representation of the class.

Return type:

str

property geometry: dict[str, Any]#

Return a dictionary containing shape, z, spacing, and offset.

Return type:

dict[str, Any]

geometry_str()#

Return a string representation of the geometry properties.

Return type:

str

is_same_geometry(other)#

Check if the geometry is the same as another PlanarGrid instance.

Parameters:

other (PlanarGrid) – Another instance of PlanarGrid to compare with.

Return type:

bool

length(use_grid_points=False)#

Return the length of the plane along the planar dimensions.

Parameters:

use_grid_points (bool) – If True, returns the length between the first and last grid points along the planar dimensions. Otherwise, returns the length between the edges of the first and last grid cells. Default: False.

Return type:

torch.Tensor
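
The difference between the two `use_grid_points` settings in `bounds()` and `length()` is easiest to see along a single axis. The sketch below assumes the grid is centered on the offset and that each grid point sits in the middle of its cell; both are assumptions for illustration, and the helper functions are hypothetical one-dimensional stand-ins for the methods.

```python
def length(n, spacing, use_grid_points=False):
    """Extent along one axis: point-to-point, or cell edge to cell edge."""
    return spacing * (n - 1) if use_grid_points else spacing * n


def bounds(n, spacing, offset=0.0, use_grid_points=False):
    """Lower and upper boundary positions along one axis."""
    half = length(n, spacing, use_grid_points) / 2
    return (offset - half, offset + half)


# Four points spaced 1.0 apart, centered on the origin
point_bounds = bounds(4, 1.0, use_grid_points=True)  # first and last grid points
edge_bounds = bounds(4, 1.0)                         # outer edges of the end cells
```

Under these assumptions the two results differ by half a grid spacing at each end, which matters when aligning planes of different resolution.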

meshgrid()#

Return a 2D meshgrid of the grid points along the plane.

Return type:

tuple[torch.Tensor, torch.Tensor]

property shape: tuple[int, int]#

Return the shape of the plane.

Return type:

tuple[int, int]

class torchoptics.SpatialCoherence(data, wavelength=None, z=0, spacing=None, offset=None)#

Bases: Field

Spatial Coherence class.

Parameters:
  • data (Tensor) – The complex-valued spatial coherence data.

  • wavelength (Scalar | None) – The wavelength of the field. Default: if None, uses a global default (see torchoptics.set_default_wavelength()).

  • z (Scalar) – Position along the z-axis. Default: 0.

  • spacing (Vector2 | None) – Distance between grid points along planar dimensions. Default: if None, uses a global default (see torchoptics.set_default_spacing()).

  • offset (Vector2 | None) – Center coordinates of the plane. Default: (0, 0).

inner(other)#

SpatialCoherence does not support the inner product.

Parameters:

other (Field) – The other field.

Return type:

torch.Tensor

intensity()#

Return the intensity of the spatial coherence.

Return type:

torch.Tensor

modulate#

Modulate the field by a modulation profile.

Parameters:

modulation_profile (Tensor) – The modulation profile.

Returns:

Modulated field.

Return type:

Field

normalize(normalized_power=1.0)#

Normalize the spatial coherence to a given power.

Parameters:

normalized_power (torchoptics.types.Scalar) – The normalized power. Default: 1.0.

Return type:

Field

outer(other)#

SpatialCoherence does not support the outer product.

Parameters:

other (Field) – The other field.

Return type:

torch.Tensor

propagate#

Propagate the field through free-space to a plane defined by the input parameters.

Parameters:
  • shape (Vector2) – Number of grid points along the planar dimensions.

  • z (Scalar) – Position along the z-axis.

  • spacing (Vector2 | None) – Distance between grid points along planar dimensions. Default: if None, uses a global default (see torchoptics.set_default_spacing()).

  • offset (Vector2 | None) – Center coordinates of the plane. Default: (0, 0).

  • propagation_method (str) – The propagation method to use. Default: “AUTO”.

  • asm_pad (Vector2 | None) – The padding size along both planar dimensions for ASM propagation. Default: if None, pads by 2x the input field size in each dimension.

  • interpolation_mode (str) – The interpolation mode to use. Default: “nearest”.

Returns:

Output field after propagating to the plane.

Return type:

Field

visualize(*index, **kwargs)#

Visualize the time-averaged intensity (diagonal of the spatial coherence matrix).

Parameters:
  • *index (int) – Index of the data tensor to visualize.

  • intensity (bool) – Whether to visualize only the intensity. Default: False.

  • **kwargs – Additional keyword arguments for visualization.

Return type:

matplotlib.figure.Figure | None
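
To illustrate why the diagonal of a spatial coherence matrix gives the intensity, the sketch below builds the matrix for a fully coherent one-dimensional field and reads off its diagonal. This is a deliberately reduced 1D picture with hypothetical helper names; the tensor layout the library uses for 2D planes is not described in this reference.

```python
def mutual_intensity(field):
    """Coherence matrix G[i][j] = E_i * conj(E_j) for a fully coherent 1D field."""
    return [[a * b.conjugate() for b in field] for a in field]


def diagonal(matrix):
    """Entries G[i][i], i.e. each point correlated with itself."""
    return [matrix[i][i] for i in range(len(matrix))]


field = [1 + 1j, 2 + 0j, 0 - 3j]
gamma = mutual_intensity(field)

# The diagonal is real and equals |E_i|^2 at each point
time_averaged_intensity = [g.real for g in diagonal(gamma)]
```

Off-diagonal entries carry the correlations between different points, which is the information a plain Field cannot represent.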

class torchoptics.System(*elements)#

Bases: torch.nn.Module

Optical system of elements similar to torch.nn.Sequential.

The system consists of a sequence of optical elements, ordered by their z positions. When a Field is passed to forward(), it is propagated through the system: each element applies its own transformation via forward(). The output from the final element is returned.

Field measurements at arbitrary planes can be performed using measure(), measure_at_z(), or measure_at_plane().

Indexing with system[i] returns the i-th optical element. Slicing, e.g. system[i:j], returns a new System containing the selected elements.

Example

Create a 4f system consisting of two lenses:

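from torchoptics import System
from torchoptics.elements import Lens

# shape, focal_length, input_field, and device are assumed to be defined beforehand
system = System(
    Lens(shape, focal_length, z=1 * focal_length),
    Lens(shape, focal_length, z=3 * focal_length),
).to(device)

# Measure the field at the 4f plane
output_field = system.measure_at_z(input_field, z=4 * focal_length)

Parameters: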

*elements (Element) – Optical elements in the system.

property elements: tuple[torchoptics.elements.Element, ...]#

Return the elements in the system.

Return type:

tuple[torchoptics.elements.Element, ...]

elements_in_field_path(field, last_element)#

Return the elements along the field path.

Parameters:
  • field (Field) – Input field.

  • last_element (Element | None) – Last element of the system.

Returns:

Elements along the field path.

Return type:

tuple[Element]

forward(field, *, propagation_method='AUTO', asm_pad=None, interpolation_mode='nearest')#

Propagate the field through the system.

Parameters:
  • field (Field) – Input field.

  • propagation_method (str) – The propagation method to use. Default: “AUTO”.

  • asm_pad (Vector2 | None) – The padding size along both planar dimensions for ASM propagation. Default: if None, pads by 2x the input field size in each dimension.

  • interpolation_mode (str) – The interpolation mode to use. Default: “nearest”.

Returns:

Output field after propagating through the system.

Return type:

Field

measure(field, shape, z, spacing=None, offset=None, *, propagation_method='AUTO', asm_pad=None, interpolation_mode='nearest')#

Propagate the field through the system to a plane defined by the input parameters.

Parameters:
  • field (Field) – Input field.

  • shape (Vector2) – Number of grid points along the planar dimensions.

  • z (Scalar) – Position along the z-axis.

  • spacing (Vector2 | None) – Distance between grid points along planar dimensions. Default: if None, uses a global default (see torchoptics.set_default_spacing()).

  • offset (Vector2 | None) – Center coordinates of the plane. Default: (0, 0).

  • propagation_method (str) – The propagation method to use. Default: “AUTO”.

  • asm_pad (Vector2 | None) – The padding size along both planar dimensions for ASM propagation. Default: if None, pads by 2x the input field size in each dimension.

  • interpolation_mode (str) – The interpolation mode to use. Default: “nearest”.

Returns:

Output field after propagating to the plane.

Return type:

Field

measure_at_plane(field, plane, *, propagation_method='AUTO', asm_pad=None, interpolation_mode='nearest')#

Propagate the field through the system to a plane defined by a PlanarGrid object.

Parameters:
  • field (Field) – Input field.

  • plane (PlanarGrid) – Plane grid.

  • propagation_method (str) – The propagation method to use. Default: “AUTO”.

  • asm_pad (Vector2 | None) – The padding size along both planar dimensions for ASM propagation. Default: if None, pads by 2x the input field size in each dimension.

  • interpolation_mode (str) – The interpolation mode to use. Default: “nearest”.

Returns:

Output field after propagating to the plane.

Return type:

Field

measure_at_z(field, z, *, propagation_method='AUTO', asm_pad=None, interpolation_mode='nearest')#

Propagate the field through the system to a plane at a specific z position.

The plane has the same shape, spacing, and offset as the input field.

Parameters:
  • field (Field) – Input field.

  • z (Scalar) – Position along the z-axis.

  • propagation_method (str) – The propagation method to use. Default: “AUTO”.

  • asm_pad (Vector2 | None) – The padding size along both planar dimensions for ASM propagation. Default: if None, pads by 2x the input field size in each dimension.

  • interpolation_mode (str) – The interpolation mode to use. Default: “nearest”.

Returns:

Output field after propagating to the plane.

Return type:

Field

sorted_elements()#

Return the elements sorted by their z position.

Return type:

tuple[torchoptics.elements.Element, ...]
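
The ordering behavior can be sketched with toy objects: elements are visited in order of increasing z regardless of the order in which they were passed in. The `ToyElement` class and `run_system` function below are hypothetical stand-ins. Skipping elements that lie behind the starting position is an assumption, and free-space propagation between elements is omitted entirely.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyElement:
    """Hypothetical stand-in for an optical element at a fixed z position."""
    name: str
    z: float
    transform: Callable[[float], float]


def run_system(elements, value, start_z=0.0):
    """Apply elements in order of increasing z, skipping any behind start_z."""
    visited = []
    for element in sorted(elements, key=lambda e: e.z):
        if element.z < start_z:
            continue  # assumed: elements behind the field are not in its path
        value = element.transform(value)
        visited.append(element.name)
    return value, visited


# Elements deliberately listed out of order
elements = [
    ToyElement("second_lens", z=3.0, transform=lambda v: v * 2),
    ToyElement("first_lens", z=1.0, transform=lambda v: v + 1),
]
result, order = run_system(elements, value=5)
```

Sorting by position rather than by argument order means the result depends only on the physical layout of the system.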

torchoptics.animate_tensor(tensor, title=None, xlabel=None, ylabel=None, symbol=None, show=True, func_anim_kwargs=None, **imshow_kwargs)#

Animate a 3D tensor over time using matplotlib.

The first dimension of the tensor is treated as time or frame index. If the tensor is complex, each frame is visualized as both magnitude squared and phase.

Parameters:
  • tensor (Tensor) – A 3D tensor of shape (T, H, W).

  • title (str | Sequence[str] | None) – Title for each frame, or a static title.

  • xlabel (str | None) – Label for the x-axis.

  • ylabel (str | None) – Label for the y-axis.

  • symbol (str | None) – Symbol used in subplot titles for LaTeX rendering.

  • show (bool) – Whether to call plt.show(). Default: True.

  • func_anim_kwargs (dict | None) – Additional keyword arguments for FuncAnimation.

  • **imshow_kwargs – Additional keyword arguments passed directly to matplotlib.pyplot.imshow(), such as cmap, vmin, vmax, interpolation, etc.

Returns:

The matplotlib animation object.

Return type:

FuncAnimation

torchoptics.get_default_dtype()#

Get the current default dtype value.

Return type:

torch.dtype

torchoptics.get_default_spacing()#

Get the current default spacing value.

Return type:

torch.Tensor

torchoptics.get_default_wavelength()#

Get the current default wavelength value.

Return type:

torch.Tensor

torchoptics.set_default_dtype(value)#

Set the default dtype value.

Parameters:

value (torch.dtype) – The default dtype.

Return type:

None

Example

>>> torchoptics.set_default_dtype(torch.float32)

torchoptics.set_default_spacing(value)#

Set the default spacing value.

Parameters:

value (Vector2) – The default spacing.

Return type:

None

Example

>>> torchoptics.set_default_spacing((10e-6, 10e-6))

torchoptics.set_default_wavelength(value)#

Set the default wavelength value.

Parameters:

value (Scalar) – The default wavelength.

Return type:

None

Example

>>> torchoptics.set_default_wavelength(700e-6)

torchoptics.visualize_tensor(tensor, title=None, xlabel=None, ylabel=None, symbol=None, show=True, return_fig=False, **imshow_kwargs)#

Visualize a 2D real or complex-valued tensor using matplotlib.

If the tensor is complex, two subplots are shown: one for the magnitude squared and one for the phase.

Parameters:
  • tensor (Tensor) – A 2D tensor of shape (H, W).

  • title (str | None) – Title for the figure.

  • xlabel (str | None) – Label for the x-axis.

  • ylabel (str | None) – Label for the y-axis.

  • symbol (str | None) – Symbol used in subplot titles for LaTeX rendering.

  • show (bool) – Whether to call plt.show(). Default: True.

  • return_fig (bool) – If True, returns the matplotlib Figure. Default: False.

  • **imshow_kwargs – Additional keyword arguments passed directly to matplotlib.pyplot.imshow(), such as cmap, vmin, vmax, interpolation, etc.

Returns:

The matplotlib Figure if return_fig is True, else None.

Return type:

plt.Figure | None
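
For a complex tensor, the two quantities shown in the subplots are the magnitude squared and the phase of each entry. The sketch below computes both for a small grid using only the standard library. It does not call matplotlib or the library, and the helper name is hypothetical.

```python
import cmath


def magnitude_squared_and_phase(tensor):
    """Split a complex 2D grid into the two real-valued panels."""
    mag2 = [[abs(v) ** 2 for v in row] for row in tensor]
    phase = [[cmath.phase(v) for v in row] for row in tensor]  # radians in (-pi, pi]
    return mag2, phase


tensor = [[1 + 0j, 0 + 2j], [-1 + 0j, 1 + 1j]]
mag2, phase = magnitude_squared_and_phase(tensor)
```

A real-valued tensor has no phase panel to show, which is consistent with the single-plot case implied by the description above.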