From 7dcb206fcf46d75477a9abd8bc73ab45605fcaa1 Mon Sep 17 00:00:00 2001
From: Alec Hammond
Date: Tue, 19 Mar 2024 13:35:56 -0700
Subject: [PATCH 1/2] add subpixel smoothing example

---
 Metagrating3D/metagrating_fmmax_smoothing.py | 371 +++++++++++++++++++
 1 file changed, 371 insertions(+)
 create mode 100644 Metagrating3D/metagrating_fmmax_smoothing.py

diff --git a/Metagrating3D/metagrating_fmmax_smoothing.py b/Metagrating3D/metagrating_fmmax_smoothing.py
new file mode 100644
index 0000000..dbee962
--- /dev/null
+++ b/Metagrating3D/metagrating_fmmax_smoothing.py
@@ -0,0 +1,371 @@
+"""metagrating_fmmax_smoothing.py - simulate the metagrating problem using
+fmmax, and optimize using the smoothed projection operator.
+
+The simulation itself is generated from the invrs.io challenge gym.
+
+The subpixel smoothing routine is a first-order approximation, ported over
+from meep.
+"""
+import dataclasses
+import functools
+from dataclasses import dataclass
+from typing import Callable, List, Tuple
+
+import agjax
+import jax
+import nlopt
+from invrs_gym import challenges
+from invrs_gym.challenges.base import Challenge
+from invrs_gym.challenges.diffract.metagrating_challenge import METAGRATING_SPEC
+from jax import numpy as jnp
+from matplotlib import pyplot as plt
+from meep import adjoint as mpa
+from totypes import types
+
+# -------------------------------------------- #
+# Challenge problem constants and types
+# -------------------------------------------- #
+
+RESOLUTION = 1 / METAGRATING_SPEC.grid_spacing
+DEFAULT_ETA = 0.5
+DEFAULT_ETA_E = 0.75
+
+# Degrees of freedom in x- and y-directions, pulled from challenge problem.
+N_x, N_y = 118, 46
+
+
+@dataclass
+class OptimizationParams:
+    """Filtering and projection parameters used to transform the design weights."""
+
+    beta: float
+    eta: float
+    filter_radius: float
+
+
+# The expected results composite type
+Results = Tuple[jnp.ndarray, List[jnp.ndarray], List[float]]
+# -------------------------------------------- #
+# Main routines
+# -------------------------------------------- #
+
+
+def run_shape_optimization(
+    starting_design: jnp.ndarray, num_iters: int, min_lengthscale: float
+) -> Results:
+    """
+    Runs shape optimization (β=∞) with the given parameters.
+
+    Args:
+        starting_design: The optimization initial condition.
+        num_iters: The number of iterations to run the optimization for.
+        min_lengthscale: The minimum length scale for the optimization.
+
+    Returns:
+        Results: The final design, design history, and FOM history of the optimization process.
+    """
+
+    return _run_optimization(
+        starting_design=starting_design,
+        beta=jnp.inf,
+        num_iters=num_iters,
+        min_lengthscale=min_lengthscale,
+    )
+
+
+def run_topology_optimization(
+    starting_design: jnp.ndarray, betas: List[float], num_iters, min_lengthscale: float
+) -> Results:
+    """
+    Runs multi-epoch topology optimization with the given parameters, one
+    epoch per β value in the schedule.
+
+    Args:
+        starting_design: The optimization initial condition.
+        betas: Schedule of projection function parameters, one per epoch.
+        num_iters: The number of iterations to run each epoch for.
+        min_lengthscale: The minimum length scale for the optimization.
+
+    Returns:
+        Results: The final design, design history, and FOM history of the optimization process.
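+
+    Example:
+        An illustrative invocation; the β schedule and hyperparameter values
+        below are placeholders rather than recommended settings:
+
+            results = run_topology_optimization(
+                starting_design=jnp.full((N_x, N_y), 0.5),
+                betas=[8.0, 32.0, jnp.inf],
+                num_iters=30,
+                min_lengthscale=0.1,
+            )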
+    """
+    data = []
+    results = []
+
+    # iterate through each beta epoch
+    for current_beta in betas:
+        final_design, current_data, current_results = _run_optimization(
+            starting_design=starting_design,
+            beta=current_beta,
+            num_iters=num_iters,
+            min_lengthscale=min_lengthscale,
+        )
+
+        # refresh the starting design with our latest optimized result
+        starting_design = final_design
+
+        # Concatenate results
+        data += current_data
+        results += current_results
+
+    return final_design, data, results
+
+
+# -------------------------------------------- #
+# Define jax wrappers for autograd utils
+# -------------------------------------------- #
+
+
+@agjax.wrap_for_jax
+def jax_conic_filter(input_array: jnp.ndarray, radius: float) -> jnp.ndarray:
+    """Jax wrapper for meep's conic filter function."""
+    return mpa.conic_filter(
+        x=input_array,
+        radius=radius,
+        Lx=METAGRATING_SPEC.period_x,
+        Ly=METAGRATING_SPEC.period_y,
+        resolution=RESOLUTION,
+        periodic_axes=[True, True],  # periodic in both directions
+    )
+
+
+@agjax.wrap_for_jax
+def jax_smoothed_projection(x_smoothed: jnp.ndarray, beta: float, eta: float):
+    """Jax wrapper for meep's smoothed projection operator."""
+    return mpa.smoothed_projection(
+        x_smoothed=x_smoothed,
+        beta=beta,
+        eta=eta,
+        resolution=RESOLUTION,
+    )
+
+
+# -------------------------------------------- #
+# Optimization helper routines
+# -------------------------------------------- #
+
+
+def _loss_function(
+    design_vector: jnp.ndarray,
+    design_params: types.Density2DArray,
+    optimization_params: OptimizationParams,
+    challenge_problem: Challenge,
+):
+    """
+    Computes a weighted loss function for the diffraction problem.
+
+    The exact loss function is pulled directly from the invrs.io gym. Filtering,
+    projection, and symmetry operations are performed to ensure proper setup.
+
+    Args:
+        design_vector: The design vector to compute the loss for.
+        design_params: The design parameters for the optimization.
+        optimization_params: The optimization parameters.
+        challenge_problem: The challenge problem for the optimization.
+
+    Returns:
+        Tuple: The loss and a tuple containing the smoothed array, response, and efficiency.
+    """
+    design_array = design_vector.reshape(N_x, N_y)
+
+    # enforce symmetry
+    design_array = (design_array + jnp.fliplr(design_array)) / 2
+
+    # Filter the design parameters
+    filtered_array = jax_conic_filter(design_array, optimization_params.filter_radius)
+
+    # Smoothly project the design parameters
+    smoothed_array = jax_smoothed_projection(
+        filtered_array, optimization_params.beta, optimization_params.eta
+    )
+
+    design_params = dataclasses.replace(design_params, array=smoothed_array)
+
+    # Simulate the challenge problem
+    response, aux = challenge_problem.component.response(design_params)
+
+    # Use the same loss quantities as the paper
+    loss = challenge_problem.loss(response)
+    metrics = challenge_problem.metrics(response, params=design_params, aux=aux)
+    efficiency = metrics["average_efficiency"]
+
+    return loss, (smoothed_array, response, efficiency)
+
+
+def nlopt_fom(
+    x: jnp.ndarray, gradient: jnp.ndarray, loss_fn: Callable, data: List, results: List
+):
+    """Wrapper for NLopt FOM.
+    Args:
+        x: Degrees of freedom array.
+        gradient: Gradient of FOM.
+        loss_fn: Problem-specific loss function.
+        data: Structure to store the simulated design each iteration.
+        results: Structure to store the simulated FOM each iteration.
+    Returns:
+        The FOM value at the current iteration.
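+
+    Example:
+        A sketch of how this wrapper is bound and registered with NLopt,
+        mirroring `_run_optimization` below; `loss_fn` is the value-and-gradient
+        function and `solver` is an `nlopt.opt` instance:
+
+            wrapped_fom = functools.partial(
+                nlopt_fom, loss_fn=loss_fn, data=[], results=[]
+            )
+            solver.set_min_objective(wrapped_fom)  # or set_max_objective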
+    """
+
+    loss_val_aux, current_grad = loss_fn(x)
+
+    # Unpack the loss value and auxiliary outputs
+    loss_val, (smoothed_array, response, efficiency) = loss_val_aux
+
+    if gradient.size > 0:
+        gradient[:] = current_grad
+
+    # Data logging
+    data.append(smoothed_array.copy())
+    results.append(float(efficiency))
+
+    print("FOM: {:.2f}, Efficiency: {:.2f}%".format(loss_val, efficiency * 100))
+
+    return float(loss_val)  # explicit cast for nlopt
+
+
+def _run_optimization(
+    starting_design: jnp.ndarray, beta: float, num_iters: int, min_lengthscale: float
+) -> Results:
+    """
+    Runs a single optimization epoch with the given parameters.
+
+    Args:
+        starting_design: The optimization initial condition.
+        beta: The projection parameter [0,∞].
+        num_iters: The number of iterations to run the optimization for.
+        min_lengthscale: The minimum length scale for the optimization.
+
+    Returns:
+        Results: The final design, design history, and FOM history of the optimization process.
+    """
+    # Set up logging data structures
+    data = []
+    results = []
+
+    # Set up the challenge problem
+    challenge_problem = challenges.metagrating()
+    design_params = challenge_problem.component.init(jax.random.PRNGKey(0))
+
+    filter_radius = mpa.get_conic_radius_from_eta_e(min_lengthscale, DEFAULT_ETA_E)
+
+    optimization_params = OptimizationParams(
+        beta=beta,
+        eta=DEFAULT_ETA,
+        filter_radius=filter_radius,
+    )
+
+    loss_fn = jax.value_and_grad(
+        functools.partial(
+            _loss_function,
+            design_params=design_params,
+            optimization_params=optimization_params,
+            challenge_problem=challenge_problem,
+        ),
+        has_aux=True,
+    )
+
+    nlopt_wrapper = functools.partial(
+        nlopt_fom,
+        loss_fn=loss_fn,
+        data=data,
+        results=results,
+    )
+
+    # Set up nlopt's CCSA algorithm
+    algorithm = nlopt.LD_CCSAQ
+    solver = nlopt.opt(algorithm, N_x * N_y)
+    solver.set_lower_bounds(0)
+    solver.set_upper_bounds(1)
+    solver.set_maxeval(num_iters)
+    solver.set_min_objective(nlopt_wrapper)
+
+    # Run the optimization
+    final_design = solver.optimize(starting_design.flatten())
+
+    return final_design, data, results
+
+
+# -------------------------------------------- #
+# Visualization routines
+# -------------------------------------------- #
+
+
+def visualize_evolution(
+    data: List, results: List, design_samples: List, output_filename: str
+) -> None:
+    """
+    Visualizes the evolution of the design optimization process.
+
+    Saves the resulting figure.
+
+    Args:
+        data: The list of design data at each iteration.
+        results: The list of results at each iteration.
+        design_samples: The list of design samples to visualize.
+        output_filename: The filename to save the plot as.
+
+    Returns:
+        Nothing.
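+
+    Example:
+        An illustrative call that plots the first, middle, and final designs:
+
+            visualize_evolution(
+                data=data,
+                results=results,
+                design_samples=[0, len(data) // 2, -1],
+                output_filename="evolution.png",
+            )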
+    """
+    num_samples = len(design_samples)
+    plt.figure(figsize=(2 * num_samples, 4), constrained_layout=True)
+
+    for k in range(num_samples):
+        plt.subplot(2, num_samples, k + 1)
+        plt.imshow(data[design_samples[k]], cmap="binary", vmin=0, vmax=1)
+        plt.axis("off")
+        plt.title(f"Iter {design_samples[k]}")
+
+    plt.subplot(2, 1, 2)
+    plt.plot(range(1, len(results) + 1), jnp.asarray(results) * 100, "-o")
+    plt.xlabel("Optimization Iteration")
+    plt.ylabel("Efficiency (%)")
+
+    plt.savefig(output_filename)
+
+
+# -------------------------------------------- #
+#
+# -------------------------------------------- #
+
+if __name__ == "__main__":
+    # Hyperparameters
+    num_iters = 60
+    min_lengthscale = 0.1
+
+    # generate a random initial design
+    key = jax.random.PRNGKey(314159)
+    starting_design = jax.random.uniform(key, (N_x, N_y))
+    starting_design = (starting_design + jnp.fliplr(starting_design)) / 2
+
+    # Run a round of shape optimization
+    if True:
+        final_design, data, results = run_shape_optimization(
+            starting_design=starting_design,
+            num_iters=num_iters,
+            min_lengthscale=min_lengthscale,
+        )
+
+        visualize_evolution(
+            data=data,
+            results=results,
+            design_samples=[0, 10, 30, 45, -1],
+            output_filename="shape_optimization.png",
+        )
+
+    # Run a round of topology optimization
+    if True:
+        betas = [16.0, 64.0, jnp.inf]
+
+        final_design, data, results = run_topology_optimization(
+            betas=betas,
+            starting_design=starting_design,
+            num_iters=num_iters,
+            min_lengthscale=min_lengthscale,
+        )
+
+        visualize_evolution(
+            data=data,
+            results=results,
+            design_samples=[0, 20, 50, 100, -1],
+            output_filename="topology_optimization.png",
+        )

From 87fb9ab4b6851e83e4fdf34d46d9b3b35d807d2f Mon Sep 17 00:00:00 2001
From: Alec Hammond
Date: Thu, 11 Apr 2024 11:05:46 -0700
Subject: [PATCH 2/2] use updated projection function

---
 Metagrating3D/metagrating_fmmax_smoothing.py | 156 +++++++++++++++----
 1 file changed, 126 insertions(+), 30 deletions(-)

diff --git a/Metagrating3D/metagrating_fmmax_smoothing.py b/Metagrating3D/metagrating_fmmax_smoothing.py
index dbee962..78d9396 100644
--- a/Metagrating3D/metagrating_fmmax_smoothing.py
+++ b/Metagrating3D/metagrating_fmmax_smoothing.py
@@ -6,6 +6,7 @@
 The subpixel smoothing routine is a first-order approximation, ported over
 from meep.
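+
+Typical usage (illustrative): run the module directly as a script from the
+repository root, e.g.
+
+    python Metagrating3D/metagrating_fmmax_smoothing.py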
""" + import dataclasses import functools from dataclasses import dataclass @@ -16,7 +17,10 @@ import nlopt from invrs_gym import challenges from invrs_gym.challenges.base import Challenge -from invrs_gym.challenges.diffract.metagrating_challenge import METAGRATING_SPEC +from invrs_gym.challenges.diffract.metagrating_challenge import ( + METAGRATING_SIM_PARAMS, + METAGRATING_SPEC, +) from jax import numpy as jnp from matplotlib import pyplot as plt from meep import adjoint as mpa @@ -44,7 +48,7 @@ class OptimizationParams: # The expected results composite type -Results = Tuple[jnp.ndarray, List[jnp.ndarray], List[float]] +Results = Tuple[jnp.ndarray, jnp.ndarray, List[jnp.ndarray], List[float]] # -------------------------------------------- # # Main routines # -------------------------------------------- # @@ -93,11 +97,13 @@ def run_topology_optimization( # iterate through each bet aepoch for current_beta in betas: - final_design, current_data, current_results = _run_optimization( - starting_design=starting_design, - beta=current_beta, - num_iters=num_iters, - min_lengthscale=min_lengthscale, + final_design, projected_design, current_data, current_results = ( + _run_optimization( + starting_design=starting_design, + beta=current_beta, + num_iters=num_iters, + min_lengthscale=min_lengthscale, + ) ) # refresh the starting design with our latest optimized result @@ -107,7 +113,7 @@ def run_topology_optimization( data += current_data results += current_results - return final_design, data, results + return final_design, projected_design, data, results # -------------------------------------------- # @@ -115,6 +121,8 @@ def run_topology_optimization( # -------------------------------------------- # +# TODO [smartalecH] Use experimental jit compatibility, where we +# define all the input/output shapes along with the types. @agjax.wrap_for_jax def jax_conic_filter(input_array: jnp.ndarray, radius: float) -> jnp.ndarray: """Jax wrapper for meep's conic filter function.""" @@ -128,6 +136,8 @@ def jax_conic_filter(input_array: jnp.ndarray, radius: float) -> jnp.ndarray: ) +# TODO [smartalecH] Use experimental jit compatibility, where we define all the +# input/output shapes along with the types. @agjax.wrap_for_jax def jax_smoothed_projection(x_smoothed: jnp.ndarray, beta: float, eta: float): """Jax wrapper for meep's smoothed projection operator.""" @@ -144,6 +154,34 @@ def jax_smoothed_projection(x_smoothed: jnp.ndarray, beta: float, eta: float): # -------------------------------------------- # +def _latents_to_params( + design_vector: jnp.ndarray, + optimization_params: OptimizationParams, +) -> jnp.ndarray: + """Transform the latent design weights to the projected weights. + + Args: + design_vector: The design vector to compute the loss for. + optimization_params: The optimization parameters. + Returns: + the transformed parameters + """ + design_array = design_vector.reshape(N_x, N_y) + + # enforce symmetry + design_array = (design_array + jnp.fliplr(design_array)) / 2 + + # Filter the design parameters + filtered_array = jax_conic_filter(design_array, optimization_params.filter_radius) + + # Smoothly project the design parameters + smoothed_array = jax_smoothed_projection( + filtered_array, optimization_params.beta, optimization_params.eta + ) + + return smoothed_array + + def _loss_function( design_vector: jnp.ndarray, design_params: types.Density2DArray, @@ -165,17 +203,8 @@ def _loss_function( Returns: Tuple: The loss and a tuple containing the smoothed array, response, and efficiency. 
""" - design_array = design_vector.reshape(N_x, N_y) - - # enforce symmetry - design_array = (design_array + jnp.fliplr(design_array)) / 2 - - # Filter the design parameters - filtered_array = jax_conic_filter(design_array, optimization_params.filter_radius) - - # Smoothly project the design parameters - smoothed_array = jax_smoothed_projection( - filtered_array, optimization_params.beta, optimization_params.eta + smoothed_array = _latents_to_params( + design_vector=design_vector, optimization_params=optimization_params ) design_params = dataclasses.replace(design_params, array=smoothed_array) @@ -184,11 +213,14 @@ def _loss_function( response, aux = challenge_problem.component.response(design_params) # Use the same loss quantities as the paper - loss = challenge_problem.loss(response) metrics = challenge_problem.metrics(response, params=design_params, aux=aux) + + # Rather than optimizing the gym problem's loss function, we'll simply + # maximize the raw efficiency, to be consistent with the other + # implementations in the testbed. efficiency = metrics["average_efficiency"] - return loss, (smoothed_array, response, efficiency) + return efficiency, (smoothed_array, response, efficiency) def nlopt_fom( @@ -217,7 +249,7 @@ def nlopt_fom( data.append(smoothed_array.copy()) results.append(float(efficiency)) - print("FOM: {:.2f}, Efficiency: {:.2f}%".format(loss_val, efficiency * 100)) + print("Efficiency: {:.1f}%".format(efficiency * 100)) return float(loss_val) # explicit cast for nlopt @@ -276,12 +308,67 @@ def _run_optimization( solver.set_lower_bounds(0) solver.set_upper_bounds(1) solver.set_maxeval(num_iters) - solver.set_min_objective(nlopt_wrapper) + solver.set_max_objective(nlopt_wrapper) # Run the optimization final_design = solver.optimize(starting_design.flatten()) - return final_design, data, results + # map the final design to it's projected equivalent + projected_design = _latents_to_params( + design_vector=final_design, optimization_params=optimization_params + ) + + return final_design, projected_design, data, results + + +# -------------------------------------------- # +# Validation +# -------------------------------------------- # + + +def convergence_check( + design: jnp.ndarray, fourier_terms: jnp.ndarray, min_lengthscale: float +) -> jnp.ndarray: + """Run a convergence check by sweeping the number of Fourier terms. + + Args: + design: Binary, projected design. + fourier_terms: Array of fourier terms over which to sweep. + min_lengthscale: Minimum lengthscale (in um). + + Returns: + Vector of loss values for each fourier term. 
+    """
+
+    filter_radius = mpa.get_conic_radius_from_eta_e(min_lengthscale, DEFAULT_ETA_E)
+
+    FOM_values = []
+    for num_terms in fourier_terms:
+        optimization_params = OptimizationParams(
+            beta=jnp.inf,
+            eta=DEFAULT_ETA,
+            filter_radius=filter_radius,
+        )
+
+        modified_params = METAGRATING_SIM_PARAMS
+        modified_params.approximate_num_terms = num_terms
+
+        # Set up the challenge problem
+        challenge_problem = challenges.metagrating(sim_params=modified_params)
+        design_params = challenge_problem.component.init(jax.random.PRNGKey(0))
+
+        loss_fn = functools.partial(
+            _loss_function,
+            design_params=design_params,
+            optimization_params=optimization_params,
+            challenge_problem=challenge_problem,
+        )
+
+        loss_val, (smoothed_array, response, efficiency) = loss_fn(design)
+
+        FOM_values.append(efficiency)
+
+    return jnp.asarray(FOM_values)
 
 
 # -------------------------------------------- #
@@ -322,6 +409,8 @@ def visualize_evolution(
 
     plt.savefig(output_filename)
 
+    plt.close("all")
+
 
 # -------------------------------------------- #
 #
@@ -329,7 +418,7 @@ def visualize_evolution(
 
 if __name__ == "__main__":
     # Hyperparameters
-    num_iters = 60
+    num_iters = 20
     min_lengthscale = 0.1
 
     # generate a random initial design
@@ -339,33 +428,40 @@ def visualize_evolution(
 
     # Run a round of shape optimization
    if True:
-        final_design, data, results = run_shape_optimization(
+        print("RUNNING SHAPE OPTIMIZATION EXAMPLE...")
+        final_design, projected_design, data, results = run_shape_optimization(
             starting_design=starting_design,
             num_iters=num_iters,
             min_lengthscale=min_lengthscale,
         )
 
+        # cache the projected design weights
+        jnp.savez("shape_optimization_design.npz", design_weights=projected_design)
+
         visualize_evolution(
             data=data,
             results=results,
-            design_samples=[0, 10, 30, 45, -1],
+            design_samples=[0, 4, 9, 14, num_iters - 1],
             output_filename="shape_optimization.png",
         )
 
     # Run a round of topology optimization
     if True:
-        betas = [16.0, 64.0, jnp.inf]
+        print("RUNNING TOPOLOGY OPTIMIZATION EXAMPLE...")
+        betas = [8.0, 16.0, jnp.inf]
 
-        final_design, data, results = run_topology_optimization(
+        final_design, projected_design, data, results = run_topology_optimization(
             betas=betas,
             starting_design=starting_design,
             num_iters=num_iters,
             min_lengthscale=min_lengthscale,
         )
 
+        jnp.savez("topology_optimization_design.npz", design_weights=projected_design)
+
         visualize_evolution(
             data=data,
             results=results,
-            design_samples=[0, 20, 50, 100, -1],
+            design_samples=[0, 15, 30, 45, 59],
             output_filename="topology_optimization.png",
         )
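+
+        # Optional validation sketch (illustrative term counts): sweep the
+        # number of Fourier terms over the final projected design and plot
+        # the resulting efficiencies.
+        # fourier_terms = jnp.arange(100, 501, 100)
+        # efficiencies = convergence_check(
+        #     projected_design, fourier_terms, min_lengthscale
+        # )
+        # plt.figure()
+        # plt.plot(fourier_terms, efficiencies * 100, "-o")
+        # plt.xlabel("Approximate number of Fourier terms")
+        # plt.ylabel("Efficiency (%)")
+        # plt.savefig("convergence_check.png")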