Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added warning if port mesh refinement is incompatible with the `GridSpec` in the `TerminalComponentModeler`.
- Various types, e.g. different `Simulation` or `SimulationData` sub-classes, can be loaded from file directly with `Tidy3dBaseModel.from_file()`.
- Added `interp_spec` in `EMEModeSpec` to enable faster multi-frequency EME simulations. Note that the default is now `ModeInterpSpec.cheb(num_points=3, reduce_data=True)`; previously the computation was repeated at all frequencies.
- Added `smoothed_projection` for topology optimization of completely binarized designs.

### Breaking Changes
- Edge singularity correction at PEC and lossy metal edges defaults to `True`.
Expand Down
1 change: 1 addition & 0 deletions docs/api/plugins/autograd.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,5 @@ Inverse Design
tidy3d.plugins.autograd.invdes.make_filter_and_project
tidy3d.plugins.autograd.invdes.ramp_projection
tidy3d.plugins.autograd.invdes.tanh_projection
tidy3d.plugins.autograd.invdes.smoothed_projection

101 changes: 101 additions & 0 deletions tests/test_plugins/autograd/invdes/test_projections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations

import autograd
import numpy as np

from tidy3d.plugins.autograd.invdes.filters import ConicFilter
from tidy3d.plugins.autograd.invdes.projections import smoothed_projection, tanh_projection


def test_smoothed_projection_beta_inf():
nx, ny = 50, 50
arr = np.zeros((50, 50), dtype=float)

center_x, center_y = 25, 25
radius = 10
x = np.arange(nx)
y = np.arange(ny)
X, Y = np.meshgrid(x, y)
distance = np.sqrt((X - center_x) ** 2 + (Y - center_y) ** 2)

arr[distance <= radius] = 1

filter = ConicFilter(kernel_size=5)
arr_filtered = filter(arr)

result = smoothed_projection(
array=arr_filtered,
beta=np.inf,
eta=0.5,
)
assert not np.any(np.isinf(result) | np.isnan(result))
assert np.isclose(result[center_x, center_y], 1)
assert np.isclose(result[0, -1], 0)
assert np.isclose(result[0, 0], 0)
assert np.isclose(result[-1, 0], 0)
assert np.isclose(result[-1, -1], 0)

# fully discrete input should lead to fully discrete output
discrete_result = smoothed_projection(
array=arr,
beta=np.inf,
eta=0.5,
)
assert np.all(np.isclose(discrete_result, 0) | np.isclose(discrete_result, 1))


def test_smoothed_projection_beta_non_inf():
nx, ny = 50, 50
arr = np.zeros((50, 50), dtype=float)

center_x, center_y = 25, 25
radius = 10
x = np.arange(nx)
y = np.arange(ny)
X, Y = np.meshgrid(x, y)
distance = np.sqrt((X - center_x) ** 2 + (Y - center_y) ** 2)

arr[distance <= radius] = 1

# fully discrete input should still be fully discrete output
discrete_result = smoothed_projection(
array=arr,
beta=1.0,
eta=0.5,
)
assert np.all(np.isclose(discrete_result, 0) | np.isclose(discrete_result, 1))

filter = ConicFilter(kernel_size=11)
arr_filtered = filter(arr)

smooth_result = smoothed_projection(
array=arr_filtered,
beta=1.0,
eta=0.5,
)
# for sufficiently smooth input, the result should be the same as tanh projection
tanh_result = tanh_projection(
array=arr_filtered,
beta=1.0,
eta=0.5,
)
assert np.isclose(smooth_result, tanh_result, rtol=0, atol=1e-4).all()


def test_smoothed_projection_initialization():
# test that for initialization at eta=0.5, projection returns simply 0.5
arr = np.zeros((5, 5), dtype=float) + 0.5
result = smoothed_projection(array=arr, beta=1.0, eta=0.5)
assert np.all(np.isclose(result, 0.5))


def test_projection_gradient():
# test that gradient is finite
arr = np.zeros((5, 5), dtype=float) + 0.5

def _helper_fn(x):
return smoothed_projection(array=x, beta=1.0, eta=0.5).mean()

val, grad = autograd.value_and_grad(_helper_fn)(arr)
assert val == 0.5
assert np.all(~(np.isnan(grad) | np.isinf(grad)))
3 changes: 2 additions & 1 deletion tidy3d/plugins/autograd/invdes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
make_filter_and_project,
)
from .penalties import ErosionDilationPenalty, make_curvature_penalty, make_erosion_dilation_penalty
from .projections import ramp_projection, tanh_projection
from .projections import ramp_projection, smoothed_projection, tanh_projection

__all__ = [
"CircularFilter",
Expand All @@ -34,5 +34,6 @@
"make_filter_and_project",
"make_gaussian_filter",
"ramp_projection",
"smoothed_projection",
"tanh_projection",
]
111 changes: 111 additions & 0 deletions tidy3d/plugins/autograd/invdes/projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,114 @@ def tanh_projection(
num = np.tanh(beta * eta) + np.tanh(beta * (array - eta))
denom = np.tanh(beta * eta) + np.tanh(beta * (1 - eta))
return num / denom


def smoothed_projection(
array: NDArray,
beta: float = BETA_DEFAULT,
eta: float = ETA_DEFAULT,
scaling_factor=1.0,
) -> NDArray:
"""
Apply a subpixel-smoothed projection method.
The subpixel-smoothed projection method is discussed in [1]_ as follows:

This projection method eliminates discontinuities by applying first-order
smoothing at material boundaries through analytical fill factors. Unlike
traditional quadrature approaches, it works with maximum projection strength
(:math:`\\beta = \\infty`) and derives closed-form expressions for interfacial regions.

Prerequisites: input fields must be pre-filtered for continuity (for example
using a conic filter).

The algorithm detects whether boundaries intersect grid cells. When interfaces
are absent, standard projection is applied. For cells containing boundaries,
analytical fill ratios are computed to maintain gradient continuity as interfaces
move through cells and traverse pixel centers. This enables arbitrarily large
:math:`\\beta` values while preserving differentiability throughout the transition
process.

.. warning::
This function assumes that the device is placed on a uniform grid. When using
```GridSpec.auto``` in the simulation, make sure to place a ``MeshOverrideStructure`` at
the position of the optimized geometry.

.. warning::
When using :math:`\\beta = \\infty` the function will produce NaN values if
the input is exactly equal to ``eta``.

Parameters
----------
array : np.ndarray
The input array to be projected.
beta : float = BETA_DEFAULT
The steepness of the projection. Higher values result in a sharper transition.
eta : float = ETA_DEFAULT
The midpoint of the projection.
scaling_factor: float = 1.0
Optional scaling factor to adjust dx and dy to different resolutions.

Example
-------
>>> import autograd.numpy as np
>>> from tidy3d.plugins.autograd.invdes.filters import ConicFilter
>>> arr = np.random.uniform(size=(50, 50))
>>> filter = ConicFilter(kernel_size=5)
>>> arr_filtered = filter(arr)
>>> eta = 0.5 # center of projection
>>> smoothed = smoothed_projection(arr_filtered, beta=np.inf, eta=eta)

.. [1] A. M. Hammond, A. Oskooi, I. M. Hammond, M. Chen, S. E. Ralph, and
S. G. Johnson, "Unifying and accelerating level-set and density-based topology
optimization by subpixel-smoothed projection," arXiv:2503.20189v3 [physics.optics]
(2025).
"""
# sanity checks
if array.ndim != 2:
raise ValueError(f"Smoothed projection expects a 2d-array, but got shape {array.shape=}")

# smoothing kernel is circle (or ellipse for non-uniform grid)
# we choose smoothing kernel with unit area, which is r~=0.56, a bit larger than (arbitrary) default r=0.55 in paper
dx = dy = scaling_factor
smooth_radius = np.sqrt(1 / np.pi) * scaling_factor

original_projected = tanh_projection(array, beta=beta, eta=eta)

# finite-difference spatial gradients
rho_filtered_grad = np.gradient(array)
rho_filtered_grad_helper = (rho_filtered_grad[0] / dx) ** 2 + (rho_filtered_grad[1] / dy) ** 2

nonzero_norm = np.abs(rho_filtered_grad_helper) > 1e-10

filtered_grad_norm = np.sqrt(np.where(nonzero_norm, rho_filtered_grad_helper, 1))
filtered_grad_norm_eff = np.where(nonzero_norm, filtered_grad_norm, 1)

# distance of pixel center to nearest interface
distance = (eta - array) / filtered_grad_norm_eff

needs_smoothing = nonzero_norm & (np.abs(distance) < smooth_radius)

# double where trick
d_rel = distance / smooth_radius
polynom = np.where(
needs_smoothing, 0.5 - 15 / 16 * d_rel + 5 / 8 * d_rel**3 - 3 / 16 * d_rel**5, 1.0
)
# F(-d)
polynom_neg = np.where(
needs_smoothing, 0.5 + 15 / 16 * d_rel - 5 / 8 * d_rel**3 + 3 / 16 * d_rel**5, 1.0
)

# two projections, one for lower and one for upper bound
rho_filtered_minus = array - smooth_radius * filtered_grad_norm_eff * polynom
rho_filtered_plus = array + smooth_radius * filtered_grad_norm_eff * polynom_neg
rho_minus_eff_projected = tanh_projection(rho_filtered_minus, beta=beta, eta=eta)
rho_plus_eff_projected = tanh_projection(rho_filtered_plus, beta=beta, eta=eta)

# Smoothing is only applied to projections
projected_smoothed = (1 - polynom) * rho_minus_eff_projected + polynom * rho_plus_eff_projected
smoothed = np.where(
needs_smoothing,
projected_smoothed,
original_projected,
)
return smoothed