Particle optimisation#

Demonstration of particle optimisation via gradient descent. Far-field cross sections (here the scattering efficiency) are optimised to fit a Gaussian curve centred at 600.0nm.

Core and shell refractive indices and radii are optimised. The materials are limited to dispersionless (constant refractive index) dielectrics, with at most weak absorption.

We optimize a large number of initial guesses concurrently, which reduces the risk that the optimization gets stuck in a local minimum.

author: O. Jackson, P. Wiecha 03/2025

imports#

import time
import matplotlib.pyplot as plt
import pymiediff as pmd
import torch
import numpy as np

set up optimisation target#

# - wavelengths (nm) at which the spectrum enters the optimisation loss
N_wl = 21
wl0 = torch.linspace(400, 800, steps=N_wl)
# - wavenumbers k0 = 2*pi / wavelength, one per sampled wavelength
k0 = (2 * torch.pi) / wl0


# - for this example we target a Gaussian-like spectrum centred at 600.0nm
def gaussian(x, mu, sig):
    """Evaluate a normalised Gaussian density.

    Args:
        x: evaluation point(s), scalar or numpy array
        mu: centre of the Gaussian
        sig: standard deviation (width)

    Returns:
        Gaussian values at ``x`` (same shape as ``x``).
    """
    z = (x - mu) / sig
    prefactor = 1.0 / (np.sqrt(2.0 * np.pi) * sig)
    return prefactor * np.exp(-np.power(z, 2.0) / 2)


# scale the unit-area Gaussian (peak ~4.65 for sig=60) and lift it onto a 0.5 baseline
target = 700 * gaussian(wl0.numpy(), 600.0, 60.0) + 0.5
# torch copy of the target, used by the loss function
target_tensor = torch.tensor(target)

# - we can plot the target spectrum
plt.figure()
plt.plot(wl0, target, label="Target spectrum")
# raw string: "\l" is not a valid Python escape sequence and triggers a
# SyntaxWarning (slated to become an error) in a normal string literal
plt.xlabel(r"$\lambda$ (nm)")
plt.legend()
# plt.savefig("ex_04a.svg", dpi=300)
plt.show()
ex 06 optimisation
/home/runner/work/MieDiff/MieDiff/examples/ex_06_optimisation.py:51: SyntaxWarning: invalid escape sequence '\l'
  plt.xlabel("$\lambda$ (nm)")

set up particle parameter limits#

# - constants: environment index (not referenced elsewhere in this script)
n_env = 1.0

# - [min, max] box limits for the particle properties. lim_r bounds both the
#   core radius and the shell thickness (nm). The imaginary index is capped at
#   0.1, so only weakly absorbing dielectrics are reachable.
lim_r = torch.tensor([10.0, 100.0], dtype=torch.float64)
lim_n_re = torch.tensor([1.0, 4.5], dtype=torch.float64)
lim_n_im = torch.tensor([0.0, 0.1], dtype=torch.float64)

normalization helper#

We let the optimizer work on normalized parameters, which we pass through a sigmoid. This is a straightforward way to implement box constraints (lower and upper bounds) on the optimization variables.

def params_to_physical(
    r_opt, n_opt, *, r_lim=None, n_re_lim=None, n_im_lim=None
):
    """Convert unconstrained optimisation variables to physical parameters.

    Each variable is squashed to [0, 1] with a sigmoid and then rescaled
    linearly to its [min, max] box limits. This lets the optimiser work on an
    unbounded domain while the physical values always stay inside the limits.

    Args:
        r_opt (torch.Tensor): normalised size parameters. ``r_opt[0]`` maps to
            the core radius, ``r_opt[1]`` to the shell thickness.
        n_opt (torch.Tensor): normalised material parameters. ``n_opt[0]`` /
            ``n_opt[1]`` map to the real part of the core / shell refractive
            index, ``n_opt[2]`` / ``n_opt[3]`` to the imaginary parts.
        r_lim (torch.Tensor, optional): [min, max] limits shared by core
            radius and shell thickness. Defaults to the module-level ``lim_r``.
        n_re_lim (torch.Tensor, optional): limits for the real index parts.
            Defaults to the module-level ``lim_n_re``.
        n_im_lim (torch.Tensor, optional): limits for the imaginary index
            parts. Defaults to the module-level ``lim_n_im``.

    Returns:
        tuple of torch.Tensor: ``(r_c, eps_c, r_s, eps_s)``, i.e. the core
        radius, core permittivity (``n_c**2``), outer shell radius
        (``r_c + d_s``) and shell permittivity (``n_s**2``).
    """
    # resolve defaults at call time so the module-level limits are honoured
    r_lim = lim_r if r_lim is None else r_lim
    n_re_lim = lim_n_re if n_re_lim is None else n_re_lim
    n_im_lim = lim_n_im if n_im_lim is None else n_im_lim

    def _to_box(x, lim):
        # sigmoid -> [0, 1], then rescale linearly to [lim.min(), lim.max()]
        return torch.sigmoid(x) * (lim.max() - lim.min()) + lim.min()

    # size parameters: the shell is parameterised by its thickness, so
    # r_s >= r_c whenever the lower limit is non-negative
    r_c = _to_box(r_opt[0], r_lim)
    d_s = _to_box(r_opt[1], r_lim)
    r_s = r_c + d_s

    # core and shell complex refractive indices
    n_c = _to_box(n_opt[0], n_re_lim) + 1j * _to_box(n_opt[2], n_im_lim)
    n_s = _to_box(n_opt[1], n_re_lim) + 1j * _to_box(n_opt[3], n_im_lim)

    return r_c, n_c**2, r_s, n_s**2

random initialization#

We use PyMieDiff’s vectorization capabilities to run the optimization of many random initial guesses in parallel.

# number of random initial guesses optimised in parallel
num_guesses = 100

# normalised (pre-sigmoid) optimisation variables, one column per guess:
#   r_opt_arr: 2 size parameters (core radius and shell thickness)
#   n_opt_arr: 4 material parameters (real and imaginary parts of the
#              constant core / shell refractive indices)
r_opt_arr = torch.normal(0, 1, (2, num_guesses)).requires_grad_()
n_opt_arr = torch.normal(0, 1, (4, num_guesses)).requires_grad_()

optimisation loop#

Define the losses, then create and run the optimization loop. This example uses the AdamW optimizer, but it is written so that it can be used with LBFGS instead (which requires a “closure”).

max_iter = 50

# - optimiser: AdamW acting on both sets of normalised parameters
optimizer = torch.optim.AdamW(params=[r_opt_arr, n_opt_arr], lr=0.2)

# - alternative optimizer: LBFGS (its step() must be given a closure)
# optimizer = torch.optim.LBFGS(
#     [r_opt_arr, n_opt_arr], lr=0.2, max_iter=10, history_size=7
# )


# - helper for batched forward pass (many particles)
def eval_batch(r_opt_arr, n_opt_arr):
    """Compute one loss value per guess for a whole batch of particles.

    Args:
        r_opt_arr (torch.Tensor): normalised size parameters, one column per guess
        n_opt_arr (torch.Tensor): normalised material parameters, one column per guess

    Returns:
        torch.Tensor: mean squared deviation between each guess's scattering
        efficiency spectrum (``q_sca``) and ``target_tensor``, averaged along
        dim 1 (presumably the wavelength axis, giving one value per guess).
    """
    core_r, core_eps, shell_r, shell_eps = params_to_physical(r_opt_arr, n_opt_arr)

    # layers stacked as (N_part, L) so the backend broadcasts over wavelengths
    layer_radii = torch.stack((core_r, shell_r), dim=1)
    layer_eps = torch.stack((core_eps, shell_eps), dim=1)

    spectra = pmd.multishell.cross_sections(
        k0=k0.unsqueeze(0),
        r_layers=layer_radii,
        eps_layers=layer_eps,
    )["q_sca"]

    # MSE between target and current spectra
    residual = target_tensor.unsqueeze(0) - spectra
    return (residual.abs() ** 2).mean(dim=1)


# - main loop
def closure():
    """Evaluate the mean batch loss and back-propagate it.

    Passing this closure to ``optimizer.step`` keeps the loop compatible with
    optimizers that re-evaluate the objective inside ``step`` (e.g. LBFGS).
    For AdamW it is called exactly once per step.
    """
    optimizer.zero_grad()
    batch_losses = eval_batch(r_opt_arr, n_opt_arr)
    mean_loss = torch.mean(batch_losses)
    mean_loss.backward()
    return mean_loss


print_every = 1  # print progress every `print_every` iterations
start_time = time.time()
loss_hist = []  # mean loss per iteration (evaluated before the update)
for o in range(max_iter + 1):
    loss = optimizer.step(closure)
    loss_hist.append(loss.item())  # Store loss value

    # re-evaluate after the update for reporting only: no autograd graph needed
    with torch.no_grad():
        all_losses = eval_batch(r_opt_arr, n_opt_arr)

    if o % print_every == 0:
        i_best = torch.argmin(all_losses)
        with torch.no_grad():
            r_c, eps_c, r_s, eps_s = params_to_physical(r_opt_arr, n_opt_arr)
        print(
            " --- iter {}: loss={:.2f}, best={:.2f}".format(
                o, loss.item(), all_losses.min().item()
            )
        )
        print(
            "     r_core, r_shell  = {:.1f}nm,     {:.1f}nm".format(
                r_c[i_best], r_s[i_best]
            )
        )
        print(
            "     n_core, n_shell  = {:.2f}, {:.2f}".format(
                torch.sqrt(eps_c[i_best]), torch.sqrt(eps_s[i_best])
            )
        )

# - finished
print(50 * "-")
t_opt = time.time() - start_time
# the loop runs max_iter + 1 iterations (range(max_iter + 1))
print(
    "Optimization finished in {:.1f}s ({:.1f}s per iteration)".format(
        t_opt, t_opt / (max_iter + 1)
    )
)
 --- iter 0: loss=5.13, best=0.92
     r_core, r_shell  = 78.1nm,     118.5nm
     n_core, n_shell  = 3.92+0.08j, 1.38+0.05j
 --- iter 1: loss=3.91, best=0.62
     r_core, r_shell  = 75.1nm,     118.8nm
     n_core, n_shell  = 3.66+0.05j, 1.91+0.06j
 --- iter 2: loss=3.44, best=0.72
     r_core, r_shell  = 50.3nm,     125.3nm
     n_core, n_shell  = 4.15+0.07j, 2.14+0.07j
 --- iter 3: loss=3.19, best=0.56
     r_core, r_shell  = 76.6nm,     110.3nm
     n_core, n_shell  = 3.81+0.07j, 1.51+0.04j
 --- iter 4: loss=2.98, best=0.52
     r_core, r_shell  = 78.0nm,     111.4nm
     n_core, n_shell  = 3.84+0.07j, 1.59+0.04j
 --- iter 5: loss=2.78, best=0.54
     r_core, r_shell  = 57.7nm,     107.9nm
     n_core, n_shell  = 3.77+0.09j, 2.48+0.08j
 --- iter 6: loss=2.60, best=0.51
     r_core, r_shell  = 60.3nm,     110.9nm
     n_core, n_shell  = 3.83+0.09j, 2.26+0.06j
 --- iter 7: loss=2.45, best=0.42
     r_core, r_shell  = 61.7nm,     112.8nm
     n_core, n_shell  = 3.90+0.09j, 2.26+0.07j
 --- iter 8: loss=2.32, best=0.46
     r_core, r_shell  = 76.7nm,     117.9nm
     n_core, n_shell  = 3.75+0.07j, 1.67+0.07j
 --- iter 9: loss=2.20, best=0.40
     r_core, r_shell  = 58.4nm,     107.3nm
     n_core, n_shell  = 3.93+0.09j, 2.37+0.09j
 --- iter 10: loss=2.11, best=0.35
     r_core, r_shell  = 59.7nm,     109.3nm
     n_core, n_shell  = 3.98+0.09j, 2.39+0.09j
 --- iter 11: loss=2.06, best=0.28
     r_core, r_shell  = 62.6nm,     112.3nm
     n_core, n_shell  = 4.05+0.10j, 2.15+0.08j
 --- iter 12: loss=2.00, best=0.29
     r_core, r_shell  = 62.1nm,     111.0nm
     n_core, n_shell  = 4.07+0.10j, 2.11+0.08j
 --- iter 13: loss=1.91, best=0.31
     r_core, r_shell  = 66.4nm,     115.7nm
     n_core, n_shell  = 4.10+0.05j, 1.92+0.09j
 --- iter 14: loss=1.81, best=0.27
     r_core, r_shell  = 60.3nm,     109.0nm
     n_core, n_shell  = 4.09+0.10j, 2.31+0.09j
 --- iter 15: loss=1.73, best=0.25
     r_core, r_shell  = 66.2nm,     110.6nm
     n_core, n_shell  = 4.15+0.04j, 1.96+0.09j
 --- iter 16: loss=1.69, best=0.26
     r_core, r_shell  = 65.6nm,     107.7nm
     n_core, n_shell  = 4.17+0.04j, 1.97+0.09j
 --- iter 17: loss=1.65, best=0.27
     r_core, r_shell  = 63.7nm,     111.4nm
     n_core, n_shell  = 4.17+0.10j, 2.06+0.08j
 --- iter 18: loss=1.59, best=0.28
     r_core, r_shell  = 65.5nm,     105.0nm
     n_core, n_shell  = 4.20+0.03j, 2.01+0.09j
 --- iter 19: loss=1.54, best=0.25
     r_core, r_shell  = 60.3nm,     107.2nm
     n_core, n_shell  = 4.17+0.10j, 2.22+0.09j
 --- iter 20: loss=1.50, best=0.23
     r_core, r_shell  = 61.1nm,     108.2nm
     n_core, n_shell  = 4.19+0.10j, 2.23+0.10j
 --- iter 21: loss=1.48, best=0.25
     r_core, r_shell  = 61.4nm,     110.3nm
     n_core, n_shell  = 4.24+0.09j, 2.14+0.10j
 --- iter 22: loss=1.46, best=0.24
     r_core, r_shell  = 57.5nm,     109.9nm
     n_core, n_shell  = 4.37+0.08j, 2.16+0.09j
 --- iter 23: loss=1.42, best=0.21
     r_core, r_shell  = 58.3nm,     109.8nm
     n_core, n_shell  = 4.37+0.07j, 2.18+0.09j
 --- iter 24: loss=1.38, best=0.22
     r_core, r_shell  = 58.9nm,     109.5nm
     n_core, n_shell  = 4.38+0.07j, 2.20+0.09j
 --- iter 25: loss=1.33, best=0.22
     r_core, r_shell  = 61.3nm,     107.3nm
     n_core, n_shell  = 4.23+0.10j, 2.20+0.10j
 --- iter 26: loss=1.29, best=0.19
     r_core, r_shell  = 59.3nm,     107.5nm
     n_core, n_shell  = 4.38+0.07j, 2.22+0.09j
 --- iter 27: loss=1.26, best=0.19
     r_core, r_shell  = 59.1nm,     106.0nm
     n_core, n_shell  = 4.38+0.07j, 2.22+0.09j
 --- iter 28: loss=1.22, best=0.20
     r_core, r_shell  = 59.0nm,     104.7nm
     n_core, n_shell  = 4.39+0.07j, 2.22+0.09j
 --- iter 29: loss=1.19, best=0.21
     r_core, r_shell  = 59.2nm,     104.1nm
     n_core, n_shell  = 4.39+0.06j, 2.23+0.09j
 --- iter 30: loss=1.18, best=0.19
     r_core, r_shell  = 59.6nm,     104.0nm
     n_core, n_shell  = 4.39+0.06j, 2.24+0.09j
 --- iter 31: loss=1.16, best=0.19
     r_core, r_shell  = 60.1nm,     104.1nm
     n_core, n_shell  = 4.39+0.06j, 2.25+0.09j
 --- iter 32: loss=1.13, best=0.20
     r_core, r_shell  = 60.4nm,     104.2nm
     n_core, n_shell  = 4.40+0.06j, 2.25+0.09j
 --- iter 33: loss=1.09, best=0.20
     r_core, r_shell  = 60.8nm,     103.7nm
     n_core, n_shell  = 4.33+0.09j, 2.24+0.10j
 --- iter 34: loss=1.06, best=0.19
     r_core, r_shell  = 56.5nm,     98.6nm
     n_core, n_shell  = 4.39+0.10j, 2.57+0.10j
 --- iter 35: loss=1.02, best=0.19
     r_core, r_shell  = 56.9nm,     99.2nm
     n_core, n_shell  = 4.39+0.10j, 2.57+0.10j
 --- iter 36: loss=1.00, best=0.20
     r_core, r_shell  = 57.2nm,     99.7nm
     n_core, n_shell  = 4.40+0.10j, 2.56+0.10j
 --- iter 37: loss=0.98, best=0.19
     r_core, r_shell  = 56.5nm,     104.8nm
     n_core, n_shell  = 4.48+0.08j, 2.37+0.10j
 --- iter 38: loss=0.94, best=0.17
     r_core, r_shell  = 56.4nm,     103.3nm
     n_core, n_shell  = 4.48+0.08j, 2.37+0.10j
 --- iter 39: loss=0.90, best=0.18
     r_core, r_shell  = 57.1nm,     99.8nm
     n_core, n_shell  = 4.40+0.10j, 2.51+0.10j
 --- iter 40: loss=0.88, best=0.19
     r_core, r_shell  = 56.8nm,     101.3nm
     n_core, n_shell  = 4.48+0.08j, 2.38+0.10j
 --- iter 41: loss=0.86, best=0.17
     r_core, r_shell  = 57.3nm,     101.3nm
     n_core, n_shell  = 4.48+0.08j, 2.39+0.10j
 --- iter 42: loss=0.85, best=0.17
     r_core, r_shell  = 57.9nm,     101.5nm
     n_core, n_shell  = 4.48+0.08j, 2.40+0.10j
 --- iter 43: loss=0.84, best=0.18
     r_core, r_shell  = 57.3nm,     100.4nm
     n_core, n_shell  = 4.41+0.10j, 2.47+0.10j
 --- iter 44: loss=0.83, best=0.18
     r_core, r_shell  = 57.6nm,     100.9nm
     n_core, n_shell  = 4.41+0.10j, 2.47+0.10j
 --- iter 45: loss=0.81, best=0.17
     r_core, r_shell  = 55.1nm,     101.0nm
     n_core, n_shell  = 4.46+0.10j, 2.51+0.10j
 --- iter 46: loss=0.79, best=0.18
     r_core, r_shell  = 55.3nm,     100.5nm
     n_core, n_shell  = 4.46+0.10j, 2.50+0.10j
 --- iter 47: loss=0.78, best=0.17
     r_core, r_shell  = 55.7nm,     100.3nm
     n_core, n_shell  = 4.46+0.10j, 2.50+0.10j
 --- iter 48: loss=0.77, best=0.17
     r_core, r_shell  = 56.2nm,     100.4nm
     n_core, n_shell  = 4.46+0.10j, 2.49+0.10j
 --- iter 49: loss=0.77, best=0.17
     r_core, r_shell  = 56.7nm,     100.5nm
     n_core, n_shell  = 4.46+0.10j, 2.49+0.10j
 --- iter 50: loss=0.76, best=0.17
     r_core, r_shell  = 57.0nm,     100.4nm
     n_core, n_shell  = 4.46+0.10j, 2.49+0.10j
--------------------------------------------------
Optimization finished in 9.6s (0.2s per iteration)

optimisation results#

View the optimised spectra and the corresponding particle parameters.

# - plot optimised spectra against target spectra
# dense wavelength grid (200 points) for a smooth optimised spectrum
wl0_eval = torch.linspace(400, 800, steps=200)
k0_eval = (2 * torch.pi) / wl0_eval

# best guess = lowest loss from the last evaluation of the optimisation loop
i_best = all_losses.argmin()
r_c, eps_c, r_s, eps_s = params_to_physical(
    r_opt_arr[:, i_best], n_opt_arr[:, i_best]
)

# single particle: core and shell stacked along a new leading axis
cs_opt = pmd.multishell.cross_sections(
    k0=k0_eval,
    r_layers=torch.stack([r_c, r_s]),
    eps_layers=torch.stack([eps_c, eps_s]),
)

plt.figure(figsize=(5, 3.5))
plt.plot(
    cs_opt["wavelength"],
    cs_opt["q_sca"][0].detach(),
    label="$Q_{sca}^{optim}$",
)
plt.plot(wl0, target_tensor, label="$Q_{sca}^{target}$", linestyle="--")
plt.xlabel("wavelength (nm)")
plt.ylabel("Scattering efficiency")
plt.legend()
plt.tight_layout()
plt.show()



# - print optimum parameters (refractive index recovered as sqrt(permittivity))
print(50 * "-")
print("optimum:")
print(f" r_core  = {r_c:.1f}nm")
print(f" r_shell = {r_s:.1f}nm")
print(f" n_core  = {torch.sqrt(eps_c):.2f}")
print(f" n_shell = {torch.sqrt(eps_s):.2f}")
ex 06 optimisation
--------------------------------------------------
optimum:
 r_core  = 57.0nm
 r_shell = 100.4nm
 n_core  = 4.46+0.10j
 n_shell = 2.49+0.10j

Total running time of the script: (0 minutes 10.855 seconds)

Estimated memory usage: 756 MB

Gallery generated by Sphinx-Gallery