diff --git a/6_hybridlens_design.py b/6_hybridlens_design.py new file mode 100644 index 0000000..ada44db --- /dev/null +++ b/6_hybridlens_design.py @@ -0,0 +1,103 @@ +""" +Jointly optimize refractive-diffractive lens with a differentiable ray-wave model. This code can be easily extended to end-to-end refractive-diffractive lens and network design. + +Technical Paper: + Xinge Yang, Matheus Souza, Kunyi Wang, Praneeth Chakravarthula, Qiang Fu and Wolfgang Heidrich, "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model," Siggraph Asia 2024. + +This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: + # The license is only for non-commercial use (commercial licenses can be obtained from authors). + # The material is provided as-is, with no warranties whatsoever. + # If you publish any code, data, or scientific work based on this, please cite our work. +""" + +import logging +import os +import random +import string +from datetime import datetime + +import torch +import yaml +from torchvision.utils import save_image +from tqdm import tqdm + +from deeplens.hybridlens import HybridLens +from deeplens.optics.loss import PSFLoss +from deeplens.utils import set_logger, set_seed + + +def config(): + # ==> Config + args = {"seed": 0, "DEBUG": True} + + # ==> Result folder + characters = string.ascii_letters + string.digits + random_string = "".join(random.choice(characters) for i in range(4)) + result_dir = ( + "./results/" + + datetime.now().strftime("%m%d-%H%M%S") + + "-HybridLens" + + "-" + + random_string + ) + args["result_dir"] = result_dir + os.makedirs(result_dir, exist_ok=True) + print(f"Result folder: {result_dir}") + + if args["seed"] is None: + seed = random.randint(0, 100) + args["seed"] = seed + set_seed(args["seed"]) + + # ==> Log + set_logger(result_dir) + if not args["DEBUG"]: + raise Exception("Add your wandb logging config here.") + + # ==> Device + num_gpus = torch.cuda.device_count() + args["num_gpus"] = num_gpus + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args["device"] = device + logging.info(f"Using {num_gpus} {torch.cuda.get_device_name(0)} GPU(s)") + + # ==> Save config + with open(f"{result_dir}/config.yml", "w") as f: + yaml.dump(args, f) + + with open(f"{result_dir}/6_hybridlens_design.py", "w") as f: + with open("6_hybridlens_design.py", "r") as code: + f.write(code.read()) + + return args + + +def main(args): + # Create a hybrid refractive-diffractive lens + lens = HybridLens(filename="./lenses/hybridlens/a489_doe.json") + lens.double() + + # PSF optimization loop to focus blue light + optimizer = lens.get_optimizer(doe_lr=0.1, lens_lr=[1e-4, 1e-4, 1e-1, 1e-5]) + loss_fn = PSFLoss() + for i in tqdm(range(100 + 1)): + psf = lens.psf(point=[0.0, 0.0, -10000.0], ks=101, wvln=0.489) + + optimizer.zero_grad() + loss = loss_fn(psf) + loss.backward() + optimizer.step() + + if i % 25 == 0: + lens.write_lens_json(f"{args['result_dir']}/lens_iter{i}.json") + lens.analysis(save_name=f"{args['result_dir']}/lens_iter{i}.png") + save_image( + psf.detach().clone(), + f"{args['result_dir']}/psf_iter{i}.png", + normalize=True, + ) + + +if __name__ == "__main__": + args = config() + main(args) diff --git a/deeplens/geolens.py b/deeplens/geolens.py index a0daeb2..c9bedc7 100644 --- a/deeplens/geolens.py +++ b/deeplens/geolens.py @@ -29,24 +29,20 @@ from tqdm import tqdm from transformers import get_cosine_schedule_with_warmup +from .lens import Lens from .optics import DEPTH, EPSILON, GEO_SPP, SELLMEIER_TABLE, Ray from .optics.basics import ( - init_device, - BLUE_RESPONSE, COHERENT_SPP, DEFAULT_WAVE, GEO_GRID, - GREEN_RESPONSE, - RED_RESPONSE, - WAVE_BOARD_BAND, WAVE_RGB, - DeepObj, + init_device, ) from .optics.materials import Material from .optics.monte_carlo import forward_integral from .optics.render_psf import render_psf_map from .optics.surfaces import ( - DOE_GEO, + Diffractive_GEO, Aperture, Aspheric, Cubic, @@ -65,62 +61,40 @@ ) -class GeoLens(DeepObj): +class GeoLens(Lens): """Geolens class. A geometric lens consisting of refractive surfaces, simulate with ray tracing. May contain diffractive surfaces, but still use ray tracing to simulate.""" - def __init__(self, filename=None, sensor_res=[1024, 1024], use_roc=False): + + def __init__(self, filename=None): """Initialize Lensgroup. Args: filename (string): lens file. sensor_res: (H, W) """ - super(GeoLens, self).__init__() self.device = init_device() - # Load lens file. + # Load lens file if filename is not None: self.lens_name = filename - self.load_file(filename, use_roc, sensor_res) + self.load_file(filename) self.to(self.device) # Lens calculation self.find_aperture() - self.prepare_sensor(sensor_res) + self.prepare_sensor() self.diff_surf_range = self.find_diff_surf() self.post_computation() else: - self.sensor_res = sensor_res + self.sensor_res = [1024, 1024] self.surfaces = [] self.materials = [] self.to(self.device) - def load_file(self, filename, use_roc, sensor_res): - """Load lens from .txt file. - - Args: - filename (string): lens file. - use_roc (bool): use radius of curvature (roc) or not. In the old code, we store lens data in roc rather than curvature. - post_computation (bool): compute fnum, fov, foclen or not. - sensor_res (list): sensor resolution. - """ + def load_file(self, filename): + """Load lens file.""" if filename[-4:] == ".txt": - raise Exception("File format support will be removed in the future.") - self.surfaces, self.materials, self.r_sensor, d_last = self.read_lensfile( - filename, use_roc - ) - self.d_sensor = d_last + self.surfaces[-1].d.item() - self.sensor_size = [ - 2 - * self.r_sensor - * sensor_res[0] - / math.sqrt(sensor_res[0] ** 2 + sensor_res[1] ** 2), - 2 - * self.r_sensor - * sensor_res[1] - / math.sqrt(sensor_res[0] ** 2 + sensor_res[1] ** 2), - ] - self.focz = self.d_sensor + raise Exception("File format support has been removed.") elif filename[-5:] == ".json": self.read_lens_json(filename) @@ -133,7 +107,7 @@ def load_file(self, filename, use_roc, sensor_res): def load_external(self, surfaces, materials, r_sensor, d_sensor): """Load lens from extrenal surface/material list. - + Args: surfaces (list): list of surfaces. materials (list): list of materials. @@ -150,18 +124,20 @@ def load_external(self, surfaces, materials, r_sensor, d_sensor): self.surfaces[i].mat1 = self.materials[i] self.surfaces[i].mat2 = self.materials[i + 1] - def prepare_sensor(self, sensor_res=[512, 512], sensor_size=None): + def prepare_sensor(self, sensor_res=None, sensor_size=None): """Create sensor. Args: sensor_res (list): Resolution, pixel number. pixel_size (float): Pixel size in [mm]. """ - sensor_res = ( - [sensor_res, sensor_res] if isinstance(sensor_res, int) else sensor_res - ) - self.sensor_res = sensor_res - H, W = sensor_res + if sensor_res is not None: + sensor_res = ( + [sensor_res, sensor_res] if isinstance(sensor_res, int) else sensor_res + ) + self.sensor_res = sensor_res + + H, W = self.sensor_res if sensor_size is None: self.sensor_size = [ 2 * self.r_sensor * H / math.sqrt(H**2 + W**2), @@ -175,7 +151,7 @@ def prepare_sensor(self, sensor_res=[512, 512], sensor_size=None): self.r_sensor = math.sqrt(sensor_size[0] ** 2 + sensor_size[1] ** 2) / 2 self.sensor_size = [float(self.sensor_size[0]), float(self.sensor_size[1])] - self.pixel_size = self.sensor_size[0] / sensor_res[0] + self.pixel_size = self.sensor_size[0] / self.sensor_res[0] def post_computation(self): """After loading lens, compute foclen, fov and fnum.""" @@ -627,8 +603,6 @@ def sample_pupil( # => Sample more uniformly when spp is not large else: - num_r2 = spp // num_angle - # ==> For each pixel, sample different points on the pupil x, y = [], [] for i in range(num_angle): @@ -696,7 +670,7 @@ def trace(self, ray, lens_range=None, record=False): def trace2obj(self, ray, depth=DEPTH): """Trace rays through the lens and reach the sensor plane. - + Args: ray (Ray object): Ray object. depth (float): sensor distance. @@ -744,7 +718,7 @@ def trace2sensor(self, ray, record=False, ignore_invalid=False): def forward_tracing(self, ray, lens_range, record): """Trace rays from object space to sensor plane. - + Args: ray (Ray object): Ray object. lens_range (list): lens range. @@ -782,7 +756,7 @@ def forward_tracing(self, ray, lens_range, record): def backward_tracing(self, ray, lens_range, record): """Trace rays from sensor plane to object space. - + Args: ray (Ray object): Ray object. lens_range (list): lens range. @@ -908,8 +882,8 @@ def render_single_img( # ==> Unwarp to correct geometry distortion if unwarp: img_render = self.unwarp(img_render, depth) - # if save_name is not None: - # save_image(img_render, f'{save_name}_unwarped.png') + if save_name is not None: + save_image(img_render, f"{save_name}_unwarped.png") # ==> Add noise if noise > 0: @@ -1082,98 +1056,9 @@ def render_compute_image(self, img, depth, scale, ray): return image - def isp(self, img, psf, noise=0.01): - """Image signal processing.""" - raise NotImplementedError("This function has not been implemented yet.") - # Energy - - # Gamma - - # White balance - - # Noise - img += noise * torch.randn_like(img).to(self.device) - return img - # ==================================================================================== # PSF and spot diagram (incoherent ray tracing) # ==================================================================================== - def point_source_grid( - self, depth, grid=8, normalized=True, quater=False, center=True - ): - """Compute point grid [-1: 1] * [-1: 1] in the object space to compute PSF grid. - - Args: - depth (float): Depth of the point source plane. - grid (int): Grid size. Defaults to 9. - normalized (bool): Whether to use normalized x, y corrdinates [-1, 1]. Defaults to True. - quater (bool): Whether to use quater of the grid. Defaults to False. - center (bool): Whether to use center of each patch. Defaults to False. - - Returns: - point_source: Shape of [grid, grid, 3]. - """ - if grid == 1: - x, y = torch.tensor([[0.0]]), torch.tensor([[0.0]]) - assert not quater, "Quater should be False when grid is 1." - else: - # ==> Use center of each patch - if center: - half_bin_size = 1 / 2 / (grid - 1) - x, y = torch.meshgrid( - torch.linspace(-1 + half_bin_size, 1 - half_bin_size, grid), - torch.linspace(1 - half_bin_size, -1 + half_bin_size, grid), - indexing="xy", - ) - # ==> Use corner - else: - x, y = torch.meshgrid( - torch.linspace(-0.98, 0.98, grid), - torch.linspace(0.98, -0.98, grid), - indexing="xy", - ) - - z = torch.full((grid, grid), depth) - point_source = torch.stack([x, y, z], dim=-1) - - # ==> Use quater of the sensor plane to save memory - if quater: - z = torch.full((grid, grid), depth) - point_source = torch.stack([x, y, z], dim=-1) - bound_i = grid // 2 if grid % 2 == 0 else grid // 2 + 1 - bound_j = grid // 2 - point_source = point_source[0:bound_i, bound_j:, :] - - if not normalized: - scale = self.calc_scale_pinhole(depth) - point_source[..., 0] *= scale * self.sensor_size[0] / 2 - point_source[..., 1] *= scale * self.sensor_size[1] / 2 - - return point_source - - def point_source_radial(self, depth, grid=9, center=False): - """Compute point radial [0, 1] in the object space to compute PSF grid. - - Args: - grid (int, optional): Grid size. Defaults to 9. - - Returns: - point_source: Shape of [grid, 3]. - """ - if grid == 1: - x = torch.tensor([0.0]) - else: - # Select center of bin to calculate PSF - if center: - half_bin_size = 1 / 2 / (grid - 1) - x = torch.linspace(0, 1 - half_bin_size, grid) - else: - x = torch.linspace(0, 0.98, grid) - - z = torch.full_like(x, depth) - point_source = torch.stack([x, x, z], dim=-1) - return point_source - @torch.no_grad() def psf_center(self, point, method="chief_ray"): """Compute reference PSF center (flipped to match the original point, green light) for given point source. @@ -1191,7 +1076,9 @@ def psf_center(self, point, method="chief_ray"): assert (ray.ra == 1).any(), "No sampled rays is valid." psf_center = (ray.o * ray.ra.unsqueeze(-1)).sum(0) / ray.ra.unsqueeze( -1 - ).sum(0).add(EPSILON) # shape [N, 3] + ).sum(0).add( + EPSILON + ) # shape [N, 3] psf_center = -psf_center[..., :2] # shape [N, 2] elif method == "pinhole": @@ -1264,34 +1151,6 @@ def psf(self, points, ks=31, wvln=DEFAULT_WAVE, spp=GEO_SPP, center=True): return psf - def psf_board_band(self, points, ks=31, spp=GEO_SPP, recenter=True): - """Compute boardband PSF. Each color channel responses to all wavelenghts. - - 3 channels * 31 wvlns = 93 values - """ - # Calculate boardband RGB PSF - psf_r = [] - for i, wvln in enumerate(WAVE_BOARD_BAND): - psf = self.psf(points=points, ks=ks, wvln=wvln, spp=spp) - psf_r.append(psf * RED_RESPONSE[i]) - psf_r = torch.stack(psf_r, dim=0).sum(dim=0) / sum(RED_RESPONSE) - - psf_g = [] - for i, wvln in enumerate(WAVE_BOARD_BAND): - psf = self.psf(points=points, ks=ks, wvln=wvln, spp=spp) - psf_g.append(psf * GREEN_RESPONSE[i]) - psf_g = torch.stack(psf_g, dim=0).sum(dim=0) / sum(GREEN_RESPONSE) - - psf_b = [] - for i, wvln in enumerate(WAVE_BOARD_BAND): - psf = self.psf(points=points, ks=ks, wvln=wvln, spp=spp) - psf_b.append(psf * BLUE_RESPONSE[i]) - psf_b = torch.stack(psf_b, dim=0).sum(dim=0) / sum(BLUE_RESPONSE) - - psfs = torch.stack([psf_r, psf_g, psf_b], dim=0) # shape [3, ks, ks] - - return psfs - def psf_rgb(self, points, ks=31, spp=GEO_SPP, center=True): """Compute RGB point PSF. This function is differentiable. @@ -1332,7 +1191,9 @@ def psf_map( points = points.reshape(-1, 3) psfs = self.psf( points=points, ks=ks, center=center, spp=spp, wvln=wvln - ).unsqueeze(1) # shape [grid**2, 1, ks, ks] + ).unsqueeze( + 1 + ) # shape [grid**2, 1, ks, ks] psf_map = make_grid(psfs, nrow=grid, padding=0)[ 0, :, : @@ -1454,14 +1315,14 @@ def analysis_rms(self, depth=DEPTH): # Coherent ray tracing # ==================================================================================== def pupil_field(self, point, wvln=DEFAULT_WAVE, spp=COHERENT_SPP): - """Compute complex wavefront (flipped for further PSF calculation) at exit pupil plane by coherent ray tracing. + """Compute complex wavefront (flipped for further PSF calculation) at exit pupil plane by coherent ray tracing. The wavefront has the same size as image sensor. This function is differentiable. Args: - point (tensor, optional): Point source position. - wvln (float, optional): wvln. - spp (int, optional): Sample per pixel. + point (tensor): Point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0]. + wvln (float): Ray wavelength in [um]. + spp (int): Ray sample number per point. """ assert ( spp >= 1000000 @@ -1471,7 +1332,7 @@ def pupil_field(self, point, wvln=DEFAULT_WAVE, spp=COHERENT_SPP): ), "Please set the default dtype to float64 for accurate phase calculation." if len(point.shape) == 1: - point = point.unsqueeze(0) + point = point.unsqueeze(0) # shape of [1, 3] # Ray origin in the object space scale = self.calc_scale_ray(point[:, 2].item()) @@ -1900,6 +1761,18 @@ def calc_principal(self, wvln=DEFAULT_WAVE): return front_principal, back_principal + @torch.no_grad() + def calc_scale(self, depth, method="pinhole"): + """Calculate the scale factor.""" + if method == "pinhole": + scale = self.calc_scale_pinhole(depth) + elif method == "ray": + scale = self.calc_scale_ray(depth) + else: + raise ValueError("Invalid method.") + + return scale + @torch.no_grad() def calc_scale_pinhole(self, depth): """Assume the first principle point is at (0, 0, 0), use pinhole camera to calculate the scale factor.""" @@ -1954,7 +1827,7 @@ def chief_ray(self): def exit_pupil(self, shrink_pupil=False): """Sample **forward** rays to compute z coordinate and radius of exit pupil. Exit pupil: ray comes from sensor to object space. - Reference: https://en.wikipedia.org/wiki/Exit_pupil + Reference: https://en.wikipedia.org/wiki/Exit_pupil """ return self.entrance_pupil(entrance=False, shrink_pupil=shrink_pupil) @@ -1962,9 +1835,9 @@ def exit_pupil(self, shrink_pupil=False): def entrance_pupil(self, M=128, entrance=True, shrink_pupil=False): """Sample **backward** rays, return z coordinate and radius of entrance pupil. Entrance pupil: how many rays can come from object space to sensor. - Reference: https://en.wikipedia.org/wiki/Entrance_pupil "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop." + Reference: https://en.wikipedia.org/wiki/Entrance_pupil "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop." """ - if self.aper_idx is None: + if self.aper_idx is None or hasattr(self, "aper_idx") is False: if entrance: return self.surfaces[0].d.item(), self.surfaces[0].r else: @@ -2024,7 +1897,7 @@ def entrance_pupil(self, M=128, entrance=True, shrink_pupil=False): else: avg_pupilz = self.surfaces[-1].d.item() avg_pupilx = self.surfaces[-1].r - + if shrink_pupil: avg_pupilx *= 0.5 return avg_pupilz, avg_pupilx @@ -2186,7 +2059,7 @@ def prune_surf(self, expand_surf=None, surface_range=None): self.surfaces[i].r = max(height) * (1 + expand_surf) except: continue - + # ==> 4. Remove nan part, also the maximum height should not exceed sensor radius for i in surface_range: # max_height = min(self.surfaces[i].max_height(), self.r_sensor) @@ -2582,7 +2455,7 @@ def draw_aperture(ax, surface, color): # Draw lens surfaces for i, s in enumerate(self.surfaces): # DOE - if isinstance(s, DOE_GEO): + if isinstance(s, Diffractive_GEO): # DOE r = torch.linspace(-s.r, s.r, s.APERTURE_SAMPLING, device=self.device) max_offset = self.d_sensor.item() / 100 @@ -2670,68 +2543,6 @@ def draw_aperture(ax, surface, color): return ax, fig - @torch.no_grad() - def draw_psf_map( - self, - grid=7, - depth=DEPTH, - ks=101, - log_scale=False, - center=True, - save_name="./psf.png", - ): - """Draw RGB PSF map at a certain depth. Will draw M x M PSFs, each of size ks x ks.""" - # Calculate PSF map - psf_map = self.psf_map_rgb( - depth=depth, grid=grid, ks=ks, spp=GEO_SPP, center=center - ) - - if log_scale: - # Los scale the PSF for better visualization - psf_map = torch.log(psf_map + 1e-4) # 1e-4 is an empirical value - psf_map = (psf_map - psf_map.min()) / (psf_map.max() - psf_map.min()) - else: - # Normalize for each field - for i in range(0, psf_map.shape[-2], ks): - for j in range(0, psf_map.shape[-1], ks): - if psf_map[:, i : i + ks, j : j + ks].max() != 0: - psf_map[:, i : i + ks, j : j + ks] /= psf_map[ - :, i : i + ks, j : j + ks - ].max() - - # Save figure using matplotlib - plt.figure(figsize=(10, 10)) - psf_map = psf_map.permute(1, 2, 0).cpu().numpy() - plt.imshow(psf_map) - - H, W = psf_map.shape[:2] - ruler_len = 100 - arrow_end = ruler_len / (self.pixel_size * 1e3) # plot a scale ruler - plt.annotate( - "", - xy=(0, H - 10), - xytext=(arrow_end, H - 10), - arrowprops=dict(arrowstyle="<->", color="white"), - ) - plt.text( - arrow_end + 10, - H - 10, - f"{ruler_len} um", - color="white", - fontsize=12, - ha="left", - ) - - plt.axis("off") - plt.tight_layout(pad=0) # Removes padding - save_name = ( - f"./psf{-depth}mm.png" - if save_name is None - else f"{save_name}_psf{-depth}mm.png" - ) - plt.savefig(save_name, dpi=300) - plt.close() - @torch.no_grad() def draw_psf_radial( self, M=3, depth=DEPTH, ks=51, log_scale=False, save_name="./psf_radial.png" @@ -2978,7 +2789,9 @@ def unwarp(self, img, depth=DEPTH, grid=256, spp=256, crop=True): o2 = ray.project_to(self.d_sensor) o_dist = (o2 * ray.ra.unsqueeze(-1)).sum(0) / ray.ra.sum(0).add( EPSILON - ).unsqueeze(-1) # shape (H, W, 2) + ).unsqueeze( + -1 + ) # shape (H, W, 2) # Reshape to [N, C, H, W], normalize to [-1, 1], then resize to img resolution [N, C, H, W] x_dist = F.interpolate( @@ -3200,7 +3013,9 @@ def loss_reg(self, w_focus=None): loss_intersec = self.loss_self_intersec( dist_bound=0.1, thickness_bound=0.3, flange_bound=0.5 ) - loss_surf = self.loss_surface(sag_bound=0.8, grad_bound=1.0, grad2_bound=100.0) + loss_surf = self.loss_surface( + sag_bound=0.8, grad_bound=1.0, grad2_bound=100.0 + ) loss_angle = self.loss_ray_angle() w_focus = 2.0 if w_focus is None else w_focus @@ -3276,7 +3091,7 @@ def get_optimizer_params( lr=lr, decay=decay, optim_mat=optim_mat ) - elif isinstance(surf, DOE_GEO): + elif isinstance(surf, Diffractive_GEO): params += surf.get_optimizer_params(lr=lr[2]) elif isinstance(surf, Plane): @@ -3358,7 +3173,7 @@ def optimize( if i > 0: if shape_control: self.correct_shape() - + if optim_mat and match_mat: self.match_materials() @@ -3413,7 +3228,9 @@ def optimize( ra = ra[:, num_grid // 2 :, num_grid // 2 :] # Weight mask - weight_mask = (xy_norm.clone().detach() ** 2).sum([0, -1]) / (ra.sum([0]) + EPSILON) # Use L2 error as weight mask + weight_mask = (xy_norm.clone().detach() ** 2).sum([0, -1]) / ( + ra.sum([0]) + EPSILON + ) # Use L2 error as weight mask weight_mask /= weight_mask.mean() # shape of [M, M] # RMS loss @@ -3468,7 +3285,9 @@ def read_lens_json(self, filename="./test.json"): mat2=surf_dict["mat2"], ) else: - raise Exception("ROC not found. This case will be removed in the future.") + raise Exception( + "ROC not found. This case will be removed in the future." + ) s = Aspheric( c=surf_dict["c"], r=surf_dict["r"], @@ -3479,7 +3298,7 @@ def read_lens_json(self, filename="./test.json"): ) else: s = Aspheric( - c=1/surf_dict["roc"], + c=1 / surf_dict["roc"], r=surf_dict["r"], d=d, k=surf_dict["k"] if "k" in surf_dict else 0.001, @@ -3495,14 +3314,11 @@ def read_lens_json(self, filename="./test.json"): r=surf_dict["r"], d=d, b=surf_dict["b"], mat2=surf_dict["mat2"] ) - elif surf_dict["type"] == "DOE_GEO": - s = DOE_GEO(l=surf_dict["l"], d=d, glass=surf_dict["glass"]) + elif surf_dict["type"] == "Diffractive_GEO": + s = Diffractive_GEO(l=surf_dict["l"], d=d, glass=surf_dict["glass"]) elif surf_dict["type"] == "Plane": - if "l" in surf_dict: - s = Plane(l=surf_dict["l"], d=d, mat2=surf_dict["mat2"]) - else: - s = Plane(l=2 * surf_dict["r"], d=d, mat2=surf_dict["mat2"]) + s = Plane(r=surf_dict["r"], d=d, mat2=surf_dict["mat2"]) elif surf_dict["type"] == "Stop": s = Aperture(r=surf_dict["r"], d=d) @@ -3529,6 +3345,7 @@ def read_lens_json(self, filename="./test.json"): d += surf_dict["d_next"] # self.sensor_size = data['sensor_size'] + self.sensor_res = data["sensor_res"] if "sensor_res" in data else [1024, 1024] self.r_sensor = data["r_sensor"] self.d_sensor = torch.tensor(d) self.lens_info = data["info"] if "info" in data else "None" @@ -3537,22 +3354,24 @@ def write_lens_json(self, filename="./test.json"): """Write the lens into .json file.""" data = {} data["info"] = self.lens_info if hasattr(self, "lens_info") else "None" - data["foclen"] = self.foclen - data["fnum"] = self.fnum + data["foclen"] = round(self.foclen, 4) + data["fnum"] = round(self.fnum, 4) data["r_sensor"] = self.r_sensor - data["d_sensor"] = self.d_sensor.item() - data["(sensor_size)"] = self.sensor_size + data["d_sensor"] = round(self.d_sensor.item(), 4) + data["(sensor_size)"] = [round(i, 4) for i in self.sensor_size] data["surfaces"] = [] for i, s in enumerate(self.surfaces): surf_dict = {"idx": i + 1} surf_dict.update(s.surf_dict()) if i < len(self.surfaces) - 1: - surf_dict["d_next"] = ( - self.surfaces[i + 1].d.item() - self.surfaces[i].d.item() + surf_dict["d_next"] = round( + self.surfaces[i + 1].d.item() - self.surfaces[i].d.item(), 4 ) else: - surf_dict["d_next"] = self.d_sensor.item() - self.surfaces[i].d.item() + surf_dict["d_next"] = round( + self.d_sensor.item() - self.surfaces[i].d.item(), 4 + ) data["surfaces"].append(surf_dict) @@ -3705,13 +3524,14 @@ def write_lens_zmx(self, filename="./test.zmx"): # Other functions. # ==================================================================================== + def create_lens( foclen, fov, fnum, flange, thickness=None, - lens_type=['Spheric', 'Spheric', 'Aperture', 'Spheric', 'Aspheric'], + lens_type=["Spheric", "Spheric", "Aperture", "Spheric", "Aspheric"], save_dir="./", ): """Create a flat starting point for camera lens design. @@ -3726,7 +3546,7 @@ def create_lens( """ assert "Aperture" in lens_type, "Aperture should be included in lens_type." lens_num = len(lens_type) - + # Compute lens parameters aper_r = foclen / fnum / 2 imgh = 2 * foclen * np.tan(fov / 2 / 57.3) @@ -3735,7 +3555,7 @@ def create_lens( d_opt = thickness - flange d_air = np.random.rand(lens_num).astype(np.float32) + 0.5 d_lens = np.random.rand(lens_num).astype(np.float32) + 1.0 - d_lens[lens_type.index('Aperture')] = 0.0 + d_lens[lens_type.index("Aperture")] = 0.0 d_sum = np.sum(d_air) + np.sum(d_lens) d_air = d_air / d_sum * d_opt d_lens = d_lens / d_sum * d_opt @@ -3766,7 +3586,9 @@ def create_lens( ai1 = np.random.randn(7).astype(np.float32) * 1e-30 k1 = np.random.randn(1).astype(np.float32) * 0.001 surfaces.append( - Aspheric(r=max(imgh / 2, aper_r), d=d_total, c=c1, ai=ai1, k=k1, mat2=mat) + Aspheric( + r=max(imgh / 2, aper_r), d=d_total, c=c1, ai=ai1, k=k1, mat2=mat + ) ) # Back surface @@ -3775,9 +3597,11 @@ def create_lens( ai2 = np.random.randn(7).astype(np.float32) * 1e-30 k2 = np.random.randn(1).astype(np.float32) * 0.001 surfaces.append( - Aspheric(r=max(imgh / 2, aper_r), d=d_total, c=c2, ai=ai2, k=k2, mat2="air") + Aspheric( + r=max(imgh / 2, aper_r), d=d_total, c=c2, ai=ai2, k=k2, mat2="air" + ) ) - + elif lens_type[i] == "Spheric": # Front surface d_total += d_air[i] @@ -3788,7 +3612,9 @@ def create_lens( # Back surface d_total += d_lens[i] c2 = np.random.randn(1).astype(np.float32) * 0.001 - surfaces.append(Spheric(r=max(imgh / 2, aper_r), d=d_total, c=c2, mat2="air")) + surfaces.append( + Spheric(r=max(imgh / 2, aper_r), d=d_total, c=c2, mat2="air") + ) else: raise Exception("Surface type not supported yet.") @@ -3808,4 +3634,4 @@ def create_lens( filename = f"starting_point_f{foclen}mm_imgh{imgh}_fnum{fnum}.json" lens.write_lens_json(os.path.join(save_dir, filename)) - return lens \ No newline at end of file + return lens diff --git a/deeplens/hybridlens.py b/deeplens/hybridlens.py new file mode 100644 index 0000000..f3c854f --- /dev/null +++ b/deeplens/hybridlens.py @@ -0,0 +1,382 @@ +"""A hybrid refractive-diffractive lens consisting of a geolens and a DOE in the back. Hybrid ray-tracing-wave-propagation is used for differentiable simulation. + +This differentiable hybrid lens model can similate: + 1. Aberration of the refractive lens + 2. DOE phase modulation + +Technical Paper: + Xinge Yang, Matheus Souza, Kunyi Wang, Praneeth Chakravarthula, Qiang Fu, Wolfgang Heidrich, "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model," Siggraph Asia 2024. + +This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: + # The license is only for non-commercial use (commercial licenses can be obtained from authors). + # The material is provided as-is, with no warranties whatsoever. + # If you publish any code, data, or scientific work based on this, please cite our work. +""" + +import json + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F + +from .geolens import GeoLens +from .lens import Lens +from .optics.basics import ( + COHERENT_SPP, + DEFAULT_WAVE, + WAVE_RGB, +) +from .optics.monte_carlo import forward_integral +from .optics.surfaces import Diffractive_GEO +from .optics.surfaces_diffractive import DOE +from .optics.wave import AngularSpectrumMethod +from .optics.waveoptics_utils import diff_float + + +class HybridLens(Lens): + """A hybrid refractive-diffractive lens with a Geolens and a DOE at last. + + This differentiable hybrid lens model can similate: + 1. Aberration of the refractive lens + 2. DOE phase modulation + """ + + def double(self): + self.geolens.double() + self.doe.double() + + def read_lens_json(self, lens_path): + """Read the lens from .json file.""" + # Load geolens + geolens0 = GeoLens(filename=lens_path) + + with open(lens_path, "r") as f: + data = json.load(f) + + # Load DOE + doe_dict = data["DOE"] + doe0 = DOE( + l=doe_dict["l"], + d=doe_dict["d"], + res=doe_dict["res"], + fab_ps=doe_dict["fab_ps"], + param_model=doe_dict["param_model"], + ) + try: + doe0.load_doe(doe_dict) + except Exception: + print( + "When loading DOE, DOE parameter is not found, use random initialization." + ) + doe0.init_param_model(param_model=doe_dict["param_model"]) + + self.doe = doe0 + + # Add a DOE surface to GeoLens + geolens0.surfaces.append(Diffractive_GEO(l=doe0.l, d=doe0.d)) + self.geolens = geolens0 + + # + self.sensor_res = geolens0.sensor_res + self.pixel_size = geolens0.pixel_size + + def write_lens_json(self, lens_path): + """Write the lens into .json file.""" + geolens = self.geolens + data = {} + data["info"] = geolens.lens_info if hasattr(geolens, "lens_info") else "None" + data["foclen"] = round(geolens.foclen, 4) + data["fnum"] = round(geolens.fnum, 4) + data["r_sensor"] = round(geolens.r_sensor, 4) + data["d_sensor"] = round(geolens.d_sensor.item(), 4) + data["sensor_size"] = [round(i, 4) for i in geolens.sensor_size] + data["sensor_res"] = geolens.sensor_res + + # Geolens + data["surfaces"] = [] + for i, s in enumerate(geolens.surfaces[:-1]): + surf_dict = {"idx": i + 1} + surf_dict.update(s.surf_dict()) + + # To exclude the last surface (DOE) + if i < len(geolens.surfaces) - 2: + surf_dict["d_next"] = round( + geolens.surfaces[i + 1].d.item() - geolens.surfaces[i].d.item(), 3 + ) + else: + surf_dict["d_next"] = round( + geolens.d_sensor.item() - geolens.surfaces[i].d.item(), 3 + ) + + data["surfaces"].append(surf_dict) + + # DOE + data["DOE"] = self.doe.surf_dict() + + with open(lens_path, "w") as f: + json.dump(data, f, indent=4) + + # ===================================================================== + # Lens operation + # ===================================================================== + def analysis(self, save_name="./test.png"): + self.draw_layout(save_name=save_name) + + def prepare_sensor(self, sensor_res): + self.geolens.prepare_sensor(sensor_res) + + self.sensor_res = self.geolens.sensor_res + self.pixel_size = self.geolens.pixel_size + + def refocus(self, foc_dist): + """Refocus the DoeLens to a given depth. Donot move DOE because DOE is installed with geolens in the Siggraph Asia 2024 paper.""" + self.geolens.refocus(foc_dist) + + def draw_layout(self, save_name="./DOELens.png", depth=-10000, ax=None, fig=None): + """Draw DOELens layout with ray-tracing and wave-propagation.""" + geolens = self.geolens + + # Draw lens layout + if ax is None: + ax, fig = geolens.plot_setup2D() + save_fig = True + else: + save_fig = False + + # Draw light path + color_list = ["#CC0000", "#006600", "#0066CC"] + views = [0, np.rad2deg(geolens.hfov) * 0.707, np.rad2deg(geolens.hfov) * 0.99] + arc_radi_list = [0.1, 0.4, 0.7, 1.0, 1.4, 1.8] + for i, view in enumerate(views): + # Draw ray tracing + ray = geolens.sample_point_source_2D( + depth=depth, view=view, M=5, entrance_pupil=True, wvln=WAVE_RGB[2 - i] + ) + ray, _, oss = geolens.trace(ray=ray, record=True) + ax, fig = geolens.plot_raytraces( + oss, ax=ax, fig=fig, color=color_list[i], ra=ray.ra + ) + + # Draw wave propagation + ray.prop_to(geolens.d_sensor) + arc_center = (ray.o[:, 0] * ray.ra).sum() / ray.ra.sum() + arc_center = arc_center.item() + # arc_radi = geolens.d_sensor.item() - geolens.surfaces[-1].d.item() + arc_radi = geolens.d_sensor.item() - self.doe.d.item() + theta1 = ( + np.rad2deg( + np.arctan2( + ray.o[0, 0].item() - oss[-1][-1][0], + ray.o[0, 2].item() - oss[-1][-1][2], + ) + ) + - 4 + ) + theta2 = ( + np.rad2deg( + np.arctan2( + ray.o[0, 0].item() - oss[0][-1][0], + ray.o[0, 2].item() - oss[0][-1][2], + ) + ) + + 4 + ) + + for j in arc_radi_list: + arc_radi_j = arc_radi * j + arc = patches.Arc( + (geolens.d_sensor.item(), arc_center), + arc_radi_j, + arc_radi_j, + angle=180.0, + theta1=theta1, + theta2=theta2, + color=color_list[i], + ) + ax.add_patch(arc) + + if save_fig: + # Save figure + ax.axis("off") + ax.set_title("DOE Lens") + fig.savefig(save_name, bbox_inches="tight", format="png", dpi=600) + plt.close() + else: + return ax, fig + + def get_optimizer( + self, doe_lr=1e-4, lens_lr=[1e-4, 1e-4, 1e-2, 1e-5], lr_decay=0.01 + ): + params = [] + params += self.geolens.get_optimizer_params(lr=lens_lr, decay=lr_decay) + params += self.doe.get_optimizer_params(lr=doe_lr) + + optimizer = torch.optim.Adam(params, weight_decay=1e-4) + return optimizer + + # ===================================================================== + # PSF-related functions + # ===================================================================== + def doe_field(self, point, wvln=DEFAULT_WAVE, spp=COHERENT_SPP): + """Compute the complex wavefront at the DOE plane using coherent ray tracing. This function reimplements geolens.pupil_field() by changing the wavefront computation position to the last surface. + + Args: + point (torch.Tensor): Tensor of shape (3,) representing the point source position. Defaults to torch.tensor([0.0, 0.0, -10000.0]). + wvln (float): Wavelength. Defaults to DEFAULT_WAVE. + spp (int): Samples per pixel. Must be >= 1,000,000 for accurate simulation. Defaults to COHERENT_SPP. + + Returns: + + wavefront: Tensor of shape [H, W] representing the complex wavefront. + psf_center: List containing the PSF center coordinates [x, y]. + """ + assert spp >= 1_000_000, ( + "Coherent ray tracing spp is too small, " + "which may lead to inaccurate simulation." + ) + assert ( + torch.get_default_dtype() == torch.float64 + ), "Default dtype must be set to float64 for accurate phase tracing." + + geolens, doe = self.geolens, self.doe + + if point.dim() == 1: + point = point.unsqueeze(0) + + # Calculate ray origin in the object space + scale = geolens.calc_scale_ray(point[:, 2].item()) + point_obj = point.clone() + point_obj[:, 0] = point[:, 0] * scale * geolens.sensor_size[1] / 2 + point_obj[:, 1] = point[:, 1] * scale * geolens.sensor_size[0] / 2 + + # Determine ray center via chief ray + pointc_chief_ray = geolens.psf_center(point_obj)[0] # shape [2] + + # Ray tracing + ray = geolens.sample_from_points(o=point_obj, spp=spp, wvln=wvln) + ray.coherent = True + ray, _, _ = geolens.trace(ray) + ray = ray.prop_to(doe.d) + + # Calculate full-resolution complex field for exit-pupil diffraction + wavefront = forward_integral( + ray, + ps=doe.ps, + ks=doe.res[0], + pointc_ref=torch.zeros_like(point[:, :2]), + coherent=True, + ).squeeze( + 0 + ) # shape [H, W] + + # Compute PSF center based on chief ray + psf_center = [ + pointc_chief_ray[0] / geolens.sensor_size[0] * 2, + pointc_chief_ray[1] / geolens.sensor_size[1] * 2, + ] + + return wavefront, psf_center + + def psf( + self, + points=[0.0, 0.0, -10000.0], + ks=101, + wvln=DEFAULT_WAVE, + spp=COHERENT_SPP, + ): + """Single point monochromatic PSF using ray-wave model. + + Steps: + 1, calculate complex wavefield at DOE (pupil) plane by coherent ray tracing. + 2, propagate through DOE to sensor plane, calculate intensity PSF, crop the valid region and normalize. + + Args: + point (torch.Tensor, optional): [x, y, z] coordinates of the point source. Defaults to torch.Tensor([0,0,-10000]). + ks (int, optional): size of the PSF patch. Defaults to 101. + wvln (float, optional): wvln. Defaults to 0.589. + spp (int, optional): number of rays to sample. Defaults to 1000000. + + Returns: + psf_out (torch.Tensor): PSF patch. Normalized to sum to 1. Shape [ks, ks] + """ + # Check double precision + if not torch.get_default_dtype() == torch.float64: + raise ValueError( + "Please call HybridLens.double() to set the default dtype to float64 for accurate phase tracing." + ) + + # Check lens last surface + assert isinstance( + self.geolens.surfaces[-1], Diffractive_GEO + ), "The last lens surface should be a DOE." + geolens, doe = self.geolens, self.doe + + # Compute pupil field by coherent ray tracing + if isinstance(points, list): + point0 = torch.tensor(points) + elif isinstance(points, torch.Tensor): + point0 = points + else: + raise ValueError("point should be a list or a torch.Tensor.") + + wavefront, psfc = self.doe_field(point=point0, wvln=wvln, spp=spp) + wavefront = wavefront.squeeze(0) # shape of [H, W] + + # DOE phase modulation. We have to flip the phase map because the wavefront has been flipped + phase_map = torch.flip(doe.get_phase_map(wvln), [-1, -2]) + wavefront = wavefront * torch.exp(1j * phase_map) + + # Propagate wave field to sensor plane + h, w = wavefront.shape + wavefront = F.pad( + wavefront.unsqueeze(0).unsqueeze(0), + [h // 2, h // 2, w // 2, w // 2], + mode="constant", + value=0, + ) + sensor_field = AngularSpectrumMethod( + wavefront, z=geolens.d_sensor - doe.d, wvln=wvln, ps=doe.ps, padding=False + ) + + # Compute PSF (intensity distribution) + psf_inten = sensor_field.abs() ** 2 + psf_inten = ( + F.interpolate( + psf_inten, + scale_factor=geolens.sensor_res[0] / h, + mode="bilinear", + align_corners=False, + ) + .squeeze(0) + .squeeze(0) + ) + + # Calculate PSF center index and crop valid PSF region (Consider both interplation and padding) + if ks is not None: + h, w = psf_inten.shape[-2:] + psfc_idx_i = ((2 - psfc[1]) * h / 4).round().long() + psfc_idx_j = ((2 + psfc[0]) * w / 4).round().long() + + # Pad to avoid invalid edge region + psf_inten_pad = F.pad( + psf_inten, + [ks // 2, ks // 2, ks // 2, ks // 2], + mode="constant", + value=0, + ) + psf = psf_inten_pad[ + psfc_idx_i : psfc_idx_i + ks, psfc_idx_j : psfc_idx_j + ks + ] + else: + h, w = psf_inten.shape[-2:] + psf = psf_inten[ + int(h / 2 - h / 4) : int(h / 2 + h / 4), + int(w / 2 - w / 4) : int(w / 2 + w / 4), + ] + + # Normalize and convert to float precision + psf /= psf.sum() # shape of [ks, ks] or [h, w] + psf = diff_float(psf) + return psf diff --git a/deeplens/lens.py b/deeplens/lens.py index 59a60c2..fcf6acc 100644 --- a/deeplens/lens.py +++ b/deeplens/lens.py @@ -3,46 +3,55 @@ When creating a new lens class, it is recommended to inherit from the Lens class and re-write core functions. """ +import matplotlib.pyplot as plt import torch from torchvision.utils import make_grid, save_image + from .optics import ( - DeepObj, - init_device, - WAVE_RGB, + BLUE_RESPONSE, + DEPTH, + GREEN_RESPONSE, + RED_RESPONSE, WAVE_BLUE, + WAVE_BOARD_BAND, WAVE_GREEN, WAVE_RED, - WAVE_BOARD_BAND, - RED_RESPONSE, - GREEN_RESPONSE, - BLUE_RESPONSE, - DEPTH, + WAVE_RGB, + DeepObj, + init_device, ) +from .optics.render_psf import render_psf_map class Lens(DeepObj): - """Geolens class. A geometric lens consisting of refractive surfaces, simulate with ray tracing. May contain diffractive surfaces, but still use ray tracing to simulate.""" + """Base lens class.""" - def __init__(self, filename, sensor_res=[1024, 1024]): - """A lens class.""" - super(Lens, self).__init__() + def __init__(self, lens_path): + """Initialize a lens class.""" self.device = init_device() - - # Load lens file - self.lens_name = filename - self.load_file(filename) + self.read_lens(lens_path) self.to(self.device) - # # Lens calculation - # self.prepare_sensor(sensor_res) - # self.post_computation() + def read_lens(self, lens_path): + """Read lens from a file.""" + if lens_path.endswith(".json"): + self.read_lens_json(lens_path) + else: + raise Exception("Unknown lens file format.") - def load_file(self, filename): - """Load lens from a file.""" + def read_lens_json(self, lens_path): + """Read lens from a json file.""" raise NotImplementedError - def write_file(self, filename): + def write_lens(self, lens_path): """Write lens to a file.""" + if lens_path.endswith(".json"): + self.write_lens_json(lens_path) + else: + raise Exception("Unknown lens file format.") + + def write_lens_json(self, lens_path): + """Write lens to a json file.""" raise NotImplementedError def prepare_sensor(self, sensor_res=[1024, 1024]): @@ -57,21 +66,46 @@ def post_computation(self): # PSF-ralated functions # =========================================== def psf(self, points, ks=51, wvln=0.589, **kwargs): - """Compute monochrome point PSF. This function is differentiable.""" + """Compute monochrome point PSF. This function is differentiable. + + Args: + point (tensor): Shape of [N, 3] or [3]. + ks (int, optional): Kernel size. Defaults to 51. + wvln (float, optional): Wavelength. Defaults to 0.589. + + Returns: + psf: Shape of [ks, ks] or [N, ks, ks]. + """ raise NotImplementedError - def psf_rgb(self, point, ks=51, **kwargs): - """Compute RGB point PSF. This function is differentiable.""" + def psf_rgb(self, points, ks=51, **kwargs): + """Compute RGB point PSF. + + Args: + points (tensor): Shape of [N, 3] or [3]. + ks (int, optional): Kernel size. Defaults to 51. + + Returns: + psf_rgb: Shape of [3, ks, ks] or [N, 3, ks, ks]. + """ psfs = [] for wvln in WAVE_RGB: - psfs.append(self.psf(point=point, ks=ks, wvln=wvln, **kwargs)) + psfs.append(self.psf(points=points, ks=ks, wvln=wvln, **kwargs)) psf_rgb = torch.stack(psfs, dim=-3) # shape [3, ks, ks] or [N, 3, ks, ks] return psf_rgb def psf_narrow_band(self, points, ks=51, **kwargs): - """Should be migrated to psf_rgb. + """Compute RGB PSF considering three wavelengths for each color. Different wavelengths are simpliy averaged for the results, but the sensor response function will be more reasonable. - In this function we use an average for different wavelengths. Actually we should use the sensor response function. + Reference: + https://en.wikipedia.org/wiki/Spectral_sensitivity + + Args: + points (tensor): Shape of [N, 3] or [3]. + ks (int, optional): Kernel size. Defaults to 51. + + Returns: + psf: Shape of [3, ks, ks]. """ # Red psf_r = [] @@ -96,7 +130,18 @@ def psf_narrow_band(self, points, ks=51, **kwargs): return psf def psf_spectrum(self, points, ks=51, **kwargs): - """Should be migrated to psf_rgb.""" + """Compute RGB PSF considering full spectrum for each color. A placeholder RGB sensor response function is used to calculate the final PSF. But the actual sensor response function will be more reasonable. + + Reference: + https://en.wikipedia.org/wiki/Spectral_sensitivity + + Args: + points (tensor): Shape of [N, 3] or [3]. + ks (int, optional): Kernel size. Defaults to 51. + + Returns: + psf: Shape of [3, ks, ks]. + """ # Red psf_r = [] for i, wvln in enumerate(WAVE_BOARD_BAND): @@ -133,25 +178,35 @@ def draw_psf(self, depth=DEPTH, ks=101, save_name="./psf.png"): save_image(psfs.unsqueeze(0), save_name, normalize=True) def point_source_grid( - self, depth, grid=9, normalized=True, quater=False, center=False + self, depth, grid=9, normalized=True, quater=False, center=True ): + """Generate point source grid for PSF calculation. + + Args: + depth (float): Depth of the point source. + grid (int): Grid size. Defaults to 9, meaning 9x9 grid. + normalized (bool): Return normalized object source coordinates. Defaults to True, meaning object sources xy coordinates range from [-1, 1]. + quater (bool): Use quater of the sensor plane to save memory. Defaults to False. + center (bool): Use center of each patch. Defaults to False. + + Returns: + point_source: Shape of [grid, grid, 3]. """ - Generate point source grid for PSF calculation. - """ - # ==> Use center of each patch + # Compute point source grid if grid == 1: x, y = torch.tensor([[0.0]]), torch.tensor([[0.0]]) assert not quater, "Quater should be False when grid is 1." else: if center: + # Use center of each patch half_bin_size = 1 / 2 / (grid - 1) x, y = torch.meshgrid( torch.linspace(-1 + half_bin_size, 1 - half_bin_size, grid), torch.linspace(1 - half_bin_size, -1 + half_bin_size, grid), indexing="xy", ) - # ==> Use corner else: + # Use corner of image sensor x, y = torch.meshgrid( torch.linspace(-0.98, 0.98, grid), torch.linspace(0.98, -0.98, grid), @@ -161,7 +216,7 @@ def point_source_grid( z = torch.full((grid, grid), depth) point_source = torch.stack([x, y, z], dim=-1) - # ==> Use quater of the sensor plane to save memory + # Use quater of the sensor plane to save memory if quater: z = torch.full((grid, grid), depth) point_source = torch.stack([x, y, z], dim=-1) @@ -169,32 +224,67 @@ def point_source_grid( bound_j = grid // 2 point_source = point_source[0:bound_i, bound_j:, :] + # De-normalize object source coordinates to physical coordinates if not normalized: - raise Exception("Need to specify the scale.") - scale = self.calc_scale_pinhole(depth) + scale = self.calc_scale(depth) point_source[..., 0] *= scale * self.sensor_size[0] / 2 point_source[..., 1] *= scale * self.sensor_size[1] / 2 return point_source - def psf_map(self, grid=21, ks=51, depth=-20000.0, wvln=0.589, **kwargs): - """Compute PSF map.""" - # raise NotImplementedError + def point_source_radial(self, depth, grid=9, center=False): + """Compute point radial [0, 1] in the object space to compute PSF grid. + + Args: + grid (int, optional): Grid size. Defaults to 9. + + Returns: + point_source: Shape of [grid, 3]. + """ + if grid == 1: + x = torch.tensor([0.0]) + else: + # Select center of bin to calculate PSF + if center: + half_bin_size = 1 / 2 / (grid - 1) + x = torch.linspace(0, 1 - half_bin_size, grid) + else: + x = torch.linspace(0, 0.98, grid) + + z = torch.full_like(x, depth) + point_source = torch.stack([x, x, z], dim=-1) + return point_source + + def psf_map(self, grid=5, ks=51, depth=DEPTH, wvln=0.589, **kwargs): + """Compute monochrome PSF map. + + Args: + grid (int, optional): Grid size. Defaults to 5, meaning 5x5 grid. + ks (int, optional): Kernel size. Defaults to 51, meaning 51x51 kernel size. + depth (float, optional): Depth of the object. Defaults to DEPTH. + wvln (float, optional): Wavelength. Defaults to 0.589. + + Returns: + psf_map: Shape of [1, grid*ks, grid*ks]. + """ + # PSF map grid points = self.point_source_grid(depth=depth, grid=grid, center=True) points = points.reshape(-1, 3) + # Compute PSF map psfs = [] for i in range(points.shape[0]): point = points[i, ...] - psf = self.psf(point=point, ks=ks, wvln=wvln) + psf = self.psf(points=point, ks=ks, wvln=wvln) psfs.append(psf) - psf_map = torch.stack(psfs).unsqueeze(1) + + # Reshape PSF map from [grid*grid, 1, ks, ks] -> [1, grid*ks, grid*ks] psf_map = make_grid(psf_map, nrow=grid, padding=0)[0, :, :] return psf_map - def psf_map_rgb(self, grid=21, ks=51, depth=-20000.0, **kwargs): + def psf_map_rgb(self, grid=5, ks=51, depth=DEPTH, **kwargs): """Compute RGB PSF map.""" psfs = [] for wvln in WAVE_RGB: @@ -203,59 +293,132 @@ def psf_map_rgb(self, grid=21, ks=51, depth=-20000.0, **kwargs): psf_map = torch.stack(psfs, dim=0) # shape [3, grid*ks, grid*ks] return psf_map + @torch.no_grad() def draw_psf_map( - self, - grid=8, - depth=DEPTH, - ks=101, - log_scale=False, - save_name="./psf_map.png", + self, grid=7, ks=101, depth=DEPTH, log_scale=False, save_name="./psf_map.png" ): - """Draw RGB PSF map of the DOE thin lens.""" - # Calculate PSF map - psf_maps = [] - for wvln in WAVE_RGB: - psf_map = self.psf_map(grid=grid, depth=depth, ks=ks, wvln=wvln) - psf_maps.append(psf_map) - psf_map = torch.stack(psf_maps, dim=0) # shape [3, grid*ks, grid*ks] + """Draw RGB PSF map of the doelens.""" + # Calculate RGB PSF map + psf_map = self.psf_map_rgb(depth=depth, grid=grid, ks=ks) - # Data processing for visualization if log_scale: - psf_map = torch.log(psf_map + 1e-4) - psf_map = (psf_map - psf_map.min()) / (psf_map.max() - psf_map.min()) - - save_image(psf_map.unsqueeze(0), save_name, normalize=True) + # Log scale normalization for better visualization + psf_map = torch.log(psf_map + 1e-4) # 1e-4 is an empirical value + psf_map = (psf_map - psf_map.min()) / (psf_map.max() - psf_map.min()) + else: + # Linear normalization + for i in range(0, psf_map.shape[-2], ks): + for j in range(0, psf_map.shape[-1], ks): + local_max = psf_map[:, i : i + ks, j : j + ks].max() + if local_max > 0: + psf_map[:, i : i + ks, j : j + ks] /= local_max + + fig, ax = plt.subplots(figsize=(10, 10)) + + psf_map = psf_map.permute(1, 2, 0).cpu().numpy() + ax.imshow(psf_map) + + # Add scale bar near bottom-left + H, W, _ = psf_map.shape + scale_bar_length = 100 + arrow_length = scale_bar_length / (self.pixel_size * 1e3) + y_position = H - 20 # a little above the lower edge + x_start = 20 + x_end = x_start + arrow_length + + ax.annotate( + "", + xy=(x_start, y_position), + xytext=(x_end, y_position), + arrowprops=dict(arrowstyle="-", color="white"), + annotation_clip=False, + ) + ax.text( + x_end + 5, + y_position, + f"{scale_bar_length} μm", + color="white", + fontsize=12, + ha="left", + va="center", + clip_on=False, + ) + + # Clean up axes and save + ax.axis("off") + plt.tight_layout(pad=0) + plt.savefig(save_name, dpi=300, bbox_inches="tight", pad_inches=0) + plt.close(fig) # =========================================== - # Rendering-ralated functions + # Image simulation-ralated functions # =========================================== - def render(self, img, method="psf", noise_std=0.01): - """PSF based rendering or image based rendering.""" + def render(self, img_obj, depth=DEPTH, method="psf", **kwargs): + """Differentiable image simulation. + + Image simulation methods: + [1] PSF map block convolution. + [2] Ray tracing-based rendering. + [3] ... + + Args: + img_obj (tensor): Input image object in raw space. Shape of [N, C, H, W]. + depth (float, optional): Depth of the object. Defaults to DEPTH. + method (str, optional): Image simulation method. Defaults to "psf". + **kwargs: Additional arguments for different methods. + """ + # Check sensor resolution if not ( - self.sensor_res[0] == img.shape[-2] and self.sensor_res[1] == img.shape[-1] + self.sensor_res[0] == img_obj.shape[-2] + and self.sensor_res[1] == img_obj.shape[-1] ): - self.prepare_sensor(sensor_res=[img.shape[-2], img.shape[-1]]) + H, W = img_obj.shape[-2], img_obj.shape[-1] + self.prepare_sensor(sensor_res=[H, W]) + # Image simulation (in RAW space) if method == "psf": - # Note: larger psf_grid and psf_ks are better - psf_map = self.psf_map(grid=psf_grid, ks=psf_ks, depth=depth) - img_render = render_psf_map(img, psf_map, grid=psf_grid) + # Note: larger psf_grid and psf_ks are typically better + if "psf_grid" in kwargs and "psf_ks" in kwargs: + psf_grid, psf_ks = kwargs["psf_grid"], kwargs["psf_ks"] + else: + raise Exception("Please provide psf_grid and psf_ks.") + + psf_map = self.psf_map_rgb(grid=psf_grid, ks=psf_ks, depth=depth) + img_render = render_psf_map(img_obj, psf_map, grid=psf_grid) + elif method == "ray_tracing": + raise NotImplementedError else: raise Exception("Unknown method.") # Add sensor noise - if noise_std > 0: - img_render = img_render + torch.randn_like(img_render) * noise_std + img_render = self.add_noise( + img_render, read_noise_std=0.0, shot_noise_alpha=0.0 + ) return img_render def add_noise(self, img, read_noise_std=0.01, shot_noise_alpha=1.0): - """Add both read noise and shot noise. Note: for an accurate noise model, we need to convert back to RAW space.""" + """Add sensor read noise and shot noise. + + Note: use RAW space image for accurate noise simulation. + + Args: + img_raw (tensor): RAW space image. Shape of [N, C, H, W]. + read_noise_std (float): Read noise standard deviation. + shot_noise_alpha (float): Shot noise alpha. + + Returns: + img: Noisy image. Shape of [N, C, H, W]. + """ noise_std = torch.sqrt(img) * shot_noise_alpha + read_noise_std noise = torch.randn_like(img) * noise_std img = img + noise return img + def isp(self, img_raw): + """Image signal processing.""" + raise NotImplementedError + # =========================================== # Visualization-ralated functions # =========================================== @@ -263,22 +426,6 @@ def draw_layout(self): """Draw lens layout.""" raise NotImplementedError - def draw_psf_map( - self, - grid=7, - depth=DEPTH, - ks=101, - log_scale=False, - recenter=True, - save_name="./psf", - ): - """Draw lens RGB PSF map.""" - raise NotImplementedError - - def draw_render_image(self, image): - """Draw input and simulated images.""" - raise NotImplementedError - # =========================================== # Optimization-ralated functions # =========================================== diff --git a/deeplens/optics/basics.py b/deeplens/optics/basics.py index 46084c2..29bdd58 100644 --- a/deeplens/optics/basics.py +++ b/deeplens/optics/basics.py @@ -12,12 +12,6 @@ def init_device(): device = torch.device("cuda") device_name = torch.cuda.get_device_name(0) print(f"Using CUDA: {device_name}") - # elif torch.backends.mps.is_available(): - # raise NotImplementedError( - # "MPS is not supported yet due to incompatible with some functions." - # ) - # device = torch.device("mps") - # print("Using MPS") else: device = torch.device("cpu") device_name = "CPU" @@ -196,7 +190,7 @@ def init_device(): DEPTH = -20000.0 GEO_SPP = 10000 # spp for geometric optics calculation (psf) -COHERENT_SPP = 1000000 # spp for coherent optics calculation +COHERENT_SPP = 10000000 # spp for coherent optics calculation GEO_GRID = 21 # grid number for geometric optics calculation (PSF map) PSF_KS = 51 @@ -219,7 +213,7 @@ def wave_rgb(): # =========================================== # Classes # =========================================== -class DeepObj(nn.Module): +class DeepObj: def __str__(self): """Called when using print() and str()""" lines = [self.__class__.__name__ + ":"] diff --git a/deeplens/optics/loss.py b/deeplens/optics/loss.py new file mode 100644 index 0000000..e19c986 --- /dev/null +++ b/deeplens/optics/loss.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn + + +class PSFLoss(nn.Module): + def __init__(self, w_achromatic=1.0, w_psf_size=1.0): + super(PSFLoss, self).__init__() + self.w_achromatic = w_achromatic + self.w_psf_size = w_psf_size + + def forward(self, psf): + # Ensure psf has shape [batch, channels, height, width] + if psf.dim() == 3: + psf = psf.unsqueeze(0) # Add batch dimension + elif psf.dim() == 2: + psf = ( + psf.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1) + ) # Add batch and channel dimensions + + batch, channels, height, width = psf.shape + + # Normalize PSF across spatial dimensions + psf_normalized = psf / psf.view(batch, channels, -1).sum( + dim=2, keepdim=True + ).view(batch, channels, 1, 1) + + # Concentration Loss: Minimize the spatial variance + # Compute coordinates + x = torch.linspace(-1, 1, steps=width, device=psf.device) + y = torch.linspace(-1, 1, steps=height, device=psf.device) + xv, yv = torch.meshgrid(x, y, indexing="ij") + xv = xv.unsqueeze(0).unsqueeze(0) # Shape [1, 1, H, W] + yv = yv.unsqueeze(0).unsqueeze(0) + + # Calculate mean positions + mean_x = (psf_normalized * xv).sum(dim=(2, 3)) + mean_y = (psf_normalized * yv).sum(dim=(2, 3)) + + # Calculate variance + var_x = ((xv - mean_x.view(batch, channels, 1, 1)) ** 2 * psf_normalized).sum( + dim=(2, 3) + ) + var_y = ((yv - mean_y.view(batch, channels, 1, 1)) ** 2 * psf_normalized).sum( + dim=(2, 3) + ) + concentration_loss = var_x + var_y + concentration_loss = concentration_loss.mean() + + # Achromatic Loss: Minimize differences between channels + channel_diff = 0 + for i in range(channels): + for j in range(i + 1, channels): + channel_diff += torch.mean((psf[:, i, :, :] - psf[:, j, :, :]) ** 2) + channel_diff = channel_diff / (channels * (channels - 1) / 2) + + total_loss = ( + self.w_psf_size * concentration_loss + self.w_achromatic * channel_diff + ) + return total_loss diff --git a/deeplens/optics/materials.py b/deeplens/optics/materials.py index a5ce104..e1eab8d 100644 --- a/deeplens/optics/materials.py +++ b/deeplens/optics/materials.py @@ -5,6 +5,7 @@ from .basics import DeepObj + class Material(DeepObj): def __init__(self, name=None, device="cpu"): self.name = "vacuum" if name is None else name.lower() @@ -118,7 +119,7 @@ def match_material(self, mat_table=None): mat_table = CDGM_GLASS else: raise NotImplementedError - + weight_n = 2 dist_min = 1e6 for name in mat_table: @@ -127,8 +128,7 @@ def match_material(self, mat_table=None): if dist < dist_min: self.name = name dist_min = dist - - breakpoint() + self.load_dispersion() def get_optimizer_params(self, lr=[1e-5, 1e-3]): diff --git a/deeplens/optics/surfaces.py b/deeplens/optics/surfaces.py index 8a3bce5..a44a07e 100644 --- a/deeplens/optics/surfaces.py +++ b/deeplens/optics/surfaces.py @@ -85,9 +85,9 @@ def newtons_method(self, ray): d_surf = self.d + self.d_perturb # 1. inital guess of t - t0 = ( - (d_surf - ray.o[..., 2]) / ray.d[..., 2] - ) # if the shape of aspheric surface is strange, will hit the back surface region instead + t0 = (d_surf - ray.o[..., 2]) / ray.d[ + ..., 2 + ] # if the shape of aspheric surface is strange, will hit the back surface region instead # 2. use Newton's method to update t to find the intersection points (non-differentiable) with torch.no_grad(): @@ -279,15 +279,15 @@ def d2fdxyz2(self, x, y, valid=None): d2g_dx2, d2g_dxdy, d2g_dy2 = self.d2gd(x, y) # Since f(x, y, z) = z - g(x, y), its second derivatives are: - d2fdx2 = -d2g_dx2 # ∂²f/∂x² = -∂²g/∂x² - d2fdxdy = -d2g_dxdy # ∂²f/∂x∂y = -∂²g/∂x∂y - d2fdy2 = -d2g_dy2 # ∂²f/∂y² = -∂²g/∂y² + d2fdx2 = -d2g_dx2 # ∂²f/∂x² = -∂²g/∂x² + d2fdxdy = -d2g_dxdy # ∂²f/∂x∂y = -∂²g/∂x∂y + d2fdy2 = -d2g_dy2 # ∂²f/∂y² = -∂²g/∂y² # Mixed partial derivatives involving z are zero zeros = torch.zeros_like(x) - d2fdxdz = zeros # ∂²f/∂x∂z = 0 - d2fdydz = zeros # ∂²f/∂y∂z = 0 - d2fdz2 = zeros # ∂²f/∂z² = 0 + d2fdxdz = zeros # ∂²f/∂x∂z = 0 + d2fdydz = zeros # ∂²f/∂y∂z = 0 + d2fdz2 = zeros # ∂²f/∂z² = 0 return d2fdx2, d2fdxdy, d2fdy2, d2fdxdz, d2fdydz, d2fdz2 @@ -320,7 +320,7 @@ def dgd(self, x, y): dgdy (tensor): dg / dy """ raise NotImplementedError() - + def d2gd(self, x, y): """Compute second-order derivatives of sag to x and y. (d2gdx2, d2gdy2) = (g''xx, g''yy). @@ -403,8 +403,8 @@ def get_optimizer(self, lr, optim_mat=False): def surf_dict(self): surf_dict = { "type": self.__class__.__name__, - "r": float(f"{self.r:.6f}"), - "(d)": float(f"{self.d.item():.3f}"), + "r": round(self.r, 4), + "(d)": round(self.d.item(), 4), "is_square": self.is_square, "mat2": self.mat2.get_name(), } @@ -473,8 +473,9 @@ def surf_dict(self): """Return a dict of surface.""" surf_dict = { "type": "Aperture", - "r": float(f"{self.r:.6f}"), - "(d)": float(f"{self.d.item():.3f}"), + "r": round(self.r, 4), + "(d)": round(self.d.item(), 4), + "mat2": "air", "is_square": self.is_square, "diffraction": self.diffraction, } @@ -688,7 +689,7 @@ def dgd(self, x, y): exec(f"dsdr2 += {i} * self.ai{2*i} * r2 ** {i-1}") return dsdr2 * 2 * x, dsdr2 * 2 * y - + def d2gd(self, x, y): """Compute second-order derivatives of surface height with respect to x and y.""" r2 = x**2 + y**2 @@ -701,10 +702,12 @@ def d2gd(self, x, y): # Compute derivative of dsdr2 with respect to r2 (ddsdr2_dr2) ddsdr2_dr2 = ( - ((1 + k) * c**2 / (2 * sf)) - + ((1 + k) ** 2 * r2 * c**4) / (4 * sf**3) - ) * c / (1 + sf) ** 2 \ - - 2 * dsdr2 * ((1 + sf + (1 + k) * r2 * c**2 / (2 * sf))) / (1 + sf) + ((1 + k) * c**2 / (2 * sf)) + ((1 + k) ** 2 * r2 * c**4) / (4 * sf**3) + ) * c / (1 + sf) ** 2 - 2 * dsdr2 * ( + (1 + sf + (1 + k) * r2 * c**2 / (2 * sf)) + ) / ( + 1 + sf + ) if self.ai_degree > 0: if self.ai_degree == 4: @@ -757,7 +760,7 @@ def d2gd(self, x, y): ) else: for i in range(1, self.ai_degree + 1): - ai_coeff = getattr(self, f'ai{2*i}') + ai_coeff = getattr(self, f"ai{2*i}") dsdr2 += i * ai_coeff * r2 ** (i - 1) if i > 1: ddsdr2_dr2 += i * (i - 1) * ai_coeff * r2 ** (i - 2) @@ -885,17 +888,20 @@ def surf_dict(self): """Return a dict of surface.""" surf_dict = { "type": "Aspheric", - "r": float(f"{self.r:.6f}"), - "(c)": float(f"{self.c.item():.3f}"), - "roc": float(f"{1/self.c.item():.3f}"), - "(d)": float(f"{self.d.item():.3f}"), - "k": float(f"{self.k.item():.6f}"), + "r": round(self.r, 4), + "(c)": round(self.c.item(), 4), + "roc": round(1 / self.c.item(), 4), + "d": round(self.d.item(), 4), + "k": round(self.k.item(), 4), "ai": [], "mat2": self.mat2.get_name(), } + # for i in range(1, self.ai_degree + 1): + # exec(f"surf_dict['(ai{2*i})'] = self.ai{2*i}.item()") + # surf_dict["ai"].append(eval(f"self.ai{2*i}.item()")) for i in range(1, self.ai_degree + 1): - exec(f"surf_dict['(ai{2*i})'] = self.ai{2*i}.item()") - surf_dict["ai"].append(eval(f"self.ai{2*i}.item()")) + exec(f"surf_dict['(ai{2*i})'] = float(format(self.ai{2*i}.item(), '.6e'))") + surf_dict["ai"].append(float(format(eval(f"self.ai{2*i}.item()"), ".6e"))) return surf_dict @@ -1081,14 +1087,15 @@ def surf_dict(self): "b5": self.b5.item(), "b7": self.b7.item(), "r": self.r, - "(d)": float(f"{self.d.item():.3f}"), + "(d)": round(self.d.item(), 4), } -class DOE_GEO(Surface): - """Kinoform and binary diffractive surfaces for ray tracing. +class Diffractive_GEO(Surface): + """Diffractive surfaces simulated with ray tracing. - https://support.zemax.com/hc/en-us/articles/1500005489061-How-diffractive-surfaces-are-modeled-in-OpticStudio + Reference: + https://support.zemax.com/hc/en-us/articles/1500005489061-How-diffractive-surfaces-are-modeled-in-OpticStudio """ def __init__( @@ -1108,7 +1115,7 @@ def __init__( # Use ray tracing to simulate diffraction, the same as Zemax self.diffraction = False self.diffraction_order = 1 - print("A DOE_GEO is created, but ray-tracing diffraction is not activated.") + # print("A Diffractive_GEO is created, but ray-tracing diffraction is not activated.") self.to(device) self.init_param_model(param_model) @@ -1153,7 +1160,7 @@ def activate_diffraction(self, diffraction_order=1): # ============================== # Computation (ray tracing) # ============================== - def ray_reaction(self, ray, **kwargs): + def ray_reaction(self, ray, n1=None, n2=None): """Ray reaction on DOE surface. Imagine the DOE as a wrapped positive convex lens for debugging. 1, The phase φ in radians adds to the optical path length of the ray @@ -1482,7 +1489,7 @@ def surf_dict(self): "glass": self.glass, "param_model": self.param_model, "f0": self.f0.item(), - "(d)": float(f"{self.d.item():.3f}"), + "(d)": round(self.d.item(), 4), "mat2": self.mat2.get_name(), } @@ -1496,7 +1503,7 @@ def surf_dict(self): "order4": self.order4.item(), "order6": self.order6.item(), "order8": self.order8.item(), - "(d)": f"{float(self.d.item()):.3f}", + "(d)": round(self.d.item(), 4), "mat2": self.mat2.get_name(), } @@ -1512,7 +1519,7 @@ def surf_dict(self): "order5": self.order5.item(), "order6": self.order6.item(), "order7": self.order7.item(), - "(d)": float(f"{self.d.item():.3f}"), + "(d)": round(self.d.item(), 4), "mat2": self.mat2.get_name(), } @@ -1524,7 +1531,7 @@ def surf_dict(self): "param_model": self.param_model, "theta": self.theta.item(), "alpha": self.alpha.item(), - "(d)": float(f"{self.d.item():.3f}"), + "(d)": round(self.d.item(), 4), "mat2": self.mat2.get_name(), } @@ -1575,12 +1582,10 @@ def ray_reaction(self, ray, **kwargs): class Plane(Surface): - def __init__(self, l, d, mat2, is_square=True, device="cpu"): + def __init__(self, r, d, mat2, is_square=False, device="cpu"): """Plane surface, typically rectangle. Working as IR filter, lens cover glass or DOE base.""" - Surface.__init__( - self, l / np.sqrt(2), d, mat2=mat2, is_square=is_square, device=device - ) - self.l = l + Surface.__init__(self, r, d, mat2=mat2, is_square=is_square, device=device) + self.l = r * np.sqrt(2) def intersect(self, ray, n=1.0): """Solve ray-surface intersection and update ray data.""" @@ -1627,8 +1632,9 @@ def dgd(self, x, y): def surf_dict(self): surf_dict = { "type": "Plane", - "l": self.l, - "(d)": float(f"{self.d.item():.3f}"), + "(l)": self.l, + "r": self.r, + "(d)": round(self.d.item(), 4), "is_square": True, "mat2": self.mat2.get_name(), } @@ -1663,35 +1669,35 @@ def dgd(self, x, y): sf = torch.sqrt(1 - r2 * c**2 + EPSILON) dgdr2 = c / (2 * sf) return dgdr2 * 2 * x, dgdr2 * 2 * y - - def d2gd(self, x, y): - """Compute second-order derivatives of the surface sag z = g(x, y). - Args: - x (tensor): x coordinate - y (tensor): y coordinate + def d2gd(self, x, y): + """Compute second-order derivatives of the surface sag z = g(x, y). - Returns: - d2g_dx2 (tensor): ∂²g / ∂x² - d2g_dxdy (tensor): ∂²g / ∂x∂y - d2g_dy2 (tensor): ∂²g / ∂y² - """ - c = self.c + self.c_perturb - r2 = x**2 + y**2 - sf = torch.sqrt(1 - r2 * c**2 + EPSILON) + Args: + x (tensor): x coordinate + y (tensor): y coordinate - # First derivative (dg/dr2) - dgdr2 = c / (2 * sf) + Returns: + d2g_dx2 (tensor): ∂²g / ∂x² + d2g_dxdy (tensor): ∂²g / ∂x∂y + d2g_dy2 (tensor): ∂²g / ∂y² + """ + c = self.c + self.c_perturb + r2 = x**2 + y**2 + sf = torch.sqrt(1 - r2 * c**2 + EPSILON) - # Second derivative (d²g/dr2²) - d2g_dr2_dr2 = (c**3) / (4 * sf**3) + # First derivative (dg/dr2) + dgdr2 = c / (2 * sf) - # Compute second-order partial derivatives using the chain rule - d2g_dx2 = 4 * x**2 * d2g_dr2_dr2 + 2 * dgdr2 - d2g_dxdy = 4 * x * y * d2g_dr2_dr2 - d2g_dy2 = 4 * y**2 * d2g_dr2_dr2 + 2 * dgdr2 + # Second derivative (d²g/dr2²) + d2g_dr2_dr2 = (c**3) / (4 * sf**3) - return d2g_dx2, d2g_dxdy, d2g_dy2 + # Compute second-order partial derivatives using the chain rule + d2g_dx2 = 4 * x**2 * d2g_dr2_dr2 + 2 * dgdr2 + d2g_dxdy = 4 * x * y * d2g_dr2_dr2 + d2g_dy2 = 4 * y**2 * d2g_dr2_dr2 + 2 * dgdr2 + + return d2g_dx2, d2g_dxdy, d2g_dy2 def valid(self, x, y): """Invalid when shape is non-defined.""" @@ -1736,10 +1742,10 @@ def surf_dict(self): roc = 1 / self.c.item() if self.c.item() != 0 else 0.0 surf_dict = { "type": "Spheric", - "r": float(f"{self.r:.3f}"), - "(c)": float(f"{self.c.item():.3f}"), - "roc": float(f"{roc:.3f}"), - "(d)": float(f"{self.d.item():.3f}"), + "r": round(self.r, 4), + "(c)": round(self.c.item(), 4), + "roc": round(roc, 4), + "(d)": round(self.d.item(), 4), "mat2": self.mat2.get_name(), } diff --git a/deeplens/optics/surfaces_diffractive.py b/deeplens/optics/surfaces_diffractive.py index fa28e28..91b84bf 100644 --- a/deeplens/optics/surfaces_diffractive.py +++ b/deeplens/optics/surfaces_diffractive.py @@ -1,6 +1,6 @@ -"""Diffractive optical surfaces. +"""Diffractive optical surfaces and related functions. -The input and output of each surface is a complex wave field. +Diffractive surfaces: the input and output of each surface is a complex wave field. """ import math @@ -11,13 +11,16 @@ import torch.nn.functional as F from torchvision.utils import save_image -from .basics import DeepObj +from .basics import DeepObj, EPSILON +# ======================================= +# Diffractive optical surfaces +# ======================================= class DOE(DeepObj): def __init__(self, l, d, res, fab_ps=0.001, param_model="pixel2d", device="cpu"): """DOE class.""" - super().__init__() + # super().__init__() # DOE material self.glass = "fused_silica" # DOE substrate material @@ -77,6 +80,8 @@ def init_param_model(self, param_model="none", **kwargs): # "Phase fresnel" or "Fresnel zone plate (FPZ)" f0 = kwargs.get("f0", 100.0) self.f0 = torch.tensor([f0]) + + # In the future we donot want to give another wvln fresnel_wvln = kwargs.get("fresnel_wvln", 0.55) self.fresnel_wvln = fresnel_wvln @@ -182,46 +187,47 @@ def save_ckpt(self, save_path="./doe.pth"): else: raise Exception("Unknown parameterization.") - def load_doe(self, load_path="./doe_fab.pth"): - """Load DOE phase map.""" - self.load_ckpt(load_path) - - def load_ckpt(self, load_path="./doe.pth"): - """Load DOE phase map.""" - ckpt = torch.load(load_path) - param_model = ckpt["param_model"] + def load_doe(self, doe_dict): + """Load DOE parameters from a dict.""" + # Init DOE parameter model + param_model = doe_dict["param_model"] self.init_param_model(param_model) + # Load DOE parameters if self.param_model == "fresnel": - self.f0 = ckpt["f0"].to(self.device) + self.f0 = doe_dict["f0"].to(self.device) elif self.param_model == "cubic": - self.a3 = ckpt["a3"].to(self.device) + self.a3 = doe_dict["a3"].to(self.device) elif self.param_model == "binary2": - self.init_param_model("binary2") - self.order2 = ckpt["order2"].to(self.device) - self.order4 = ckpt["order4"].to(self.device) - self.order6 = ckpt["order6"].to(self.device) - self.order8 = ckpt["order8"].to(self.device) + self.order2 = doe_dict["order2"].to(self.device) + self.order4 = doe_dict["order4"].to(self.device) + self.order6 = doe_dict["order6"].to(self.device) + self.order8 = doe_dict["order8"].to(self.device) elif self.param_model == "poly1d": - self.order2 = ckpt["order2"].to(self.device) - self.order3 = ckpt["order3"].to(self.device) - self.order4 = ckpt["order4"].to(self.device) - self.order5 = ckpt["order5"].to(self.device) - self.order6 = ckpt["order6"].to(self.device) - self.order7 = ckpt["order7"].to(self.device) + self.order2 = doe_dict["order2"].to(self.device) + self.order3 = doe_dict["order3"].to(self.device) + self.order4 = doe_dict["order4"].to(self.device) + self.order5 = doe_dict["order5"].to(self.device) + self.order6 = doe_dict["order6"].to(self.device) + self.order7 = doe_dict["order7"].to(self.device) elif self.param_model == "zernike": - self.z_coeff = ckpt["z_coeff"].to(self.device) + self.z_coeff = doe_dict["z_coeff"].to(self.device) elif self.param_model == "pixel2d": - self.pmap = ckpt["pmap"].to(self.device) + self.pmap = doe_dict["pmap"].to(self.device) else: raise Exception("Unknown parameterization.") + def load_ckpt(self, load_path="./doe.pth"): + """Load DOE phase map.""" + ckpt = torch.load(load_path) + self.load_doe(ckpt) + # ======================================= # Computation # ======================================= @@ -230,7 +236,7 @@ def get_phase_map(self, wvln=0.55): First we should calculate the phase map at 0.55um, then calculate the phase map for the given other wavelength. """ - phase_map0 = self.get_pmap() + phase_map0 = self.get_phase_map0() n = self.refractive_index(wvln) phase_map = phase_map0 * (self.wvln0 / wvln) * (n - 1) / (self.n0 - 1) @@ -243,7 +249,7 @@ def get_phase_map(self, wvln=0.55): ) return phase_map - def get_pmap(self): + def get_phase_map0(self): """Calculate phase map at wvln 0.55 um. Returns: @@ -389,14 +395,14 @@ def refractive_index(self, wvln=0.55): def pmap_quantize(self, bits=16): """Quantize phase map to bits levels.""" - pmap = self.get_pmap() + pmap = self.get_phase_map0() pmap_q = torch.round(pmap / (2 * np.pi / bits)) * (2 * np.pi / bits) return pmap_q def pmap_fab(self, bits=16, save_path=None): """Convert to fabricate phase map and save it. This function is used to output DOE_fab file, and it will not change the DOE object itself.""" # Fab resolution quantized pmap - pmap = self.get_pmap() + pmap = self.get_phase_map0() fab_res = int(self.ps / self.fab_ps * self.res[0]) pmap = ( F.interpolate( @@ -424,7 +430,7 @@ def loss_quantization(self, bits=16): Reference: Quantization-aware Deep Optics for Diffractive Snapshot Hyperspectral Imaging """ - pmap = self.get_pmap() + pmap = self.get_phase_map0() pmap_q = self.pmap_quantize(bits) loss = torch.mean(torch.abs(pmap - pmap_q)) return loss @@ -435,7 +441,7 @@ def loss_quantization(self, bits=16): def activate_grad(self, activate=True): """Activate gradient for phase map parameters.""" if self.param_model == "fresnel": - self.c.requires_grad = activate + self.f0.requires_grad = activate elif self.param_model == "cubic": self.a3.requires_grad = activate @@ -469,7 +475,7 @@ def get_optimizer_params(self, lr=None): if self.param_model == "fresnel": lr = 0.001 if lr is None else lr - params.append({"params": [self.c], "lr": lr}) + params.append({"params": [self.f0], "lr": lr}) elif self.param_model == "cubic": lr = 0.1 if lr is None else lr @@ -513,7 +519,7 @@ def get_optimizer(self, lr=None): lr (float, optional): Learning rate. Defaults to 1e-3. """ params = self.get_optimizer_params(lr) - optimizer = torch.optim.Adam(params) + optimizer = torch.optim.Adam(params, weight_decay=1e-4) return optimizer @@ -561,12 +567,12 @@ def save_pmap(self, save_path="./DOE_phase_map.png"): def draw_phase_map(self, save_name="./DOE_phase_map.png"): """Draw phase map. Range from [0, 2pi].""" - pmap = self.get_pmap() + pmap = self.get_phase_map0() save_image(pmap, save_name, normalize=True) def draw_phase_map_fab(self, save_name="./DOE_phase_map.png"): """Draw phase map. Range from [0, 2pi].""" - pmap = self.get_pmap() + pmap = self.get_phase_map0() pmap_q = self.pmap_quantize() fig, ax = plt.subplots(1, 2, figsize=(10, 5)) @@ -585,7 +591,7 @@ def draw_phase_map_fab(self, save_name="./DOE_phase_map.png"): def draw_phase_map3d(self, save_name="./DOE_phase_map3d.png"): """Draw 3D phase map.""" - pmap = self.get_pmap() / 20.0 + pmap = self.get_phase_map0() / 20.0 x = np.linspace(-self.w / 2, self.w / 2, self.res[0]) y = np.linspace(-self.h / 2, self.h / 2, self.res[1]) X, Y = np.meshgrid(x, y) @@ -608,7 +614,7 @@ def draw_phase_map3d(self, save_name="./DOE_phase_map3d.png"): def draw_cross_section(self, save_name="./DOE_corss_sec.png"): """Draw cross section of the phase map.""" - pmap = self.get_pmap() + pmap = self.get_phase_map0() pmap = torch.diag(pmap).cpu().numpy() r = np.linspace(-self.w / 2 * np.sqrt(2), self.w / 2 * np.sqrt(2), self.res[0]) @@ -618,6 +624,35 @@ def draw_cross_section(self, save_name="./DOE_corss_sec.png"): fig.savefig(save_name, dpi=600, bbox_inches="tight") plt.close(fig) + def surface(self, x, y, max_offset=0.2): + """When drawing the lens setup, this function is called to compute the surface height. + + Here we use a fake height ONLY for drawing. + """ + roc = self.l + r = torch.sqrt(x**2 + y**2 + EPSILON) + sag = roc * (1 - torch.sqrt(1 - r**2 / roc**2)) + sag = max_offset - torch.fmod(sag, max_offset) + return sag + + def draw_wedge(self, ax, color="black"): + # Create radius points + r = torch.linspace(-self.r, self.r, 256, device=self.device) + offset = 0.1 + + # Draw base at z = self.d + base_z = torch.tensor([self.d + offset, self.d, self.d, self.d + offset]) + base_x = torch.tensor([-self.r, -self.r, self.r, self.r]) + base_points = torch.stack((base_x, torch.zeros_like(base_x), base_z), dim=-1) + base_points = base_points.cpu().detach().numpy() + ax.plot(base_points[..., 2], base_points[..., 0], color=color, linewidth=0.8) + + # Calculate and draw surface + z = self.surface(r, torch.zeros_like(r), max_offset=offset) + self.d + offset + points = torch.stack((r, torch.zeros_like(r), z), dim=-1) + points = points.cpu().detach().numpy() + ax.plot(points[..., 2], points[..., 0], color=color, linewidth=0.8) + # ======================================= # Utils # ======================================= @@ -625,13 +660,39 @@ def surf_dict(self): """Return a dict of surface.""" surf_dict = { "type": "DOE", - "l": float(f"{self.l:.6f}"), + "l": round(self.l, 4), + "d": round(self.d[0].item(), 4), "res": self.res, - "fab_ps": float(f"{self.fab_ps:.6f}"), + "fab_ps": round(self.fab_ps, 6), "is_square": True, "param_model": self.param_model, "doe_path": None, } + + if self.param_model == "fresnel": + surf_dict["f0"] = round(self.f0.item(), 6) + elif self.param_model == "cubic": + surf_dict["a3"] = round(self.a3.item(), 6) + elif self.param_model == "binary2": + surf_dict["order2"] = round(self.order2.item(), 6) + surf_dict["order4"] = round(self.order4.item(), 6) + surf_dict["order6"] = round(self.order6.item(), 6) + surf_dict["order8"] = round(self.order8.item(), 6) + elif self.param_model == "poly1d": + surf_dict["order2"] = round(self.order2.item(), 6) + surf_dict["order3"] = round(self.order3.item(), 6) + surf_dict["order4"] = round(self.order4.item(), 6) + surf_dict["order5"] = round(self.order5.item(), 6) + surf_dict["order6"] = round(self.order6.item(), 6) + surf_dict["order7"] = round(self.order7.item(), 6) + elif self.param_model == "zernike": + surf_dict["z_coeff"] = self.z_coeff.tolist() + elif self.param_model == "pixel2d": + raise NotImplementedError + surf_dict["pmap"] = self.pmap.tolist() + else: + raise NotImplementedError + return surf_dict @@ -786,6 +847,9 @@ def surf_dict(self): return surf_dict +# ======================================= +# Functions +# ======================================= def Zernike(z_coeff, grid=256): """Calculate phase map produced by the first 37 Zernike polynomials. The output zernike phase map is in real value, to use it in the future we need to convert it to complex value.""" # Generate meshgrid diff --git a/deeplens/optics/wave.py b/deeplens/optics/wave.py index 1b5a3b6..f174765 100644 --- a/deeplens/optics/wave.py +++ b/deeplens/optics/wave.py @@ -23,9 +23,18 @@ def __init__( z=0.0, phy_size=[4.0, 4.0], valid_phy_size=None, - res=[1024, 1024], + res=[1000, 1000], ): - """Complex wave field class.""" + """Complex wave field class. + + Args: + u (tensor): complex wave field, shape [H, W] or [B, C, H, W]. + wvln (float): wavelength in [um]. + z (float): distance in [mm]. + phy_size (list): physical size in [mm]. + valid_phy_size (list): valid physical size in [mm]. + res (list): resolution. + """ super(ComplexWave, self).__init__() # Wave field has shape of [N, 1, H, W] for batch processing @@ -349,11 +358,13 @@ def AngularSpectrumMethod(u, z, wvln, ps, n=1.0, padding=True, TF=True): https://blog.csdn.net/zhenpixiaoyang/article/details/111569495 Args: - u: complex field, shape [H, W] or [B, C, H, W] - wvln: wvln - res: field resolution - ps (float): pixel size - z (float): propagation distance + u (tesor): complex field, shape [H, W] or [B, C, H, W] + z (float): propagation distance in [mm] + wvln (float): wavelength in [um] + ps (float): pixel size in [mm] + n (float): refractive index + padding (bool): padding or not + TF (bool): transfer function or impulse response """ if torch.is_tensor(z): z = z.item() @@ -376,7 +387,8 @@ def AngularSpectrumMethod(u, z, wvln, ps, n=1.0, padding=True, TF=True): # Propagation assert wvln > 0.1 and wvln < 1, "wvln unit should be [um]." - k = 2 * np.pi / (wvln * 1e-3) # we use k in vaccum, k in [mm]-1 + wvln_mm = wvln * 1e-3 # [um] to [mm] + k = 2 * np.pi / wvln_mm # we use k in vaccum, k in [mm]-1 x, y = torch.meshgrid( torch.linspace(-0.5 * Wimg * ps, 0.5 * Himg * ps, Wimg, device=u.device), torch.linspace(0.5 * Wimg * ps, -0.5 * Himg * ps, Himg, device=u.device), @@ -389,17 +401,17 @@ def AngularSpectrumMethod(u, z, wvln, ps, n=1.0, padding=True, TF=True): ) # Determine TF or IR - if ps > wvln * np.abs(z) / (Wimg * ps): + if ps > wvln_mm * np.abs(z) / (Wimg * ps): TF = True else: TF = False if TF: if n == 1: - square_root = torch.sqrt(1 - (wvln * 1e-3) ** 2 * (fx**2 + fy**2)) + square_root = torch.sqrt(1 - wvln_mm**2 * (fx**2 + fy**2)) H = torch.exp(1j * k * z * square_root) else: - square_root = torch.sqrt(n**2 - (wvln * 1e-3) ** 2 * (fx**2 + fy**2)) + square_root = torch.sqrt(n**2 - wvln_mm**2 * (fx**2 + fy**2)) H = n * torch.exp(1j * k * z * square_root) H = fftshift(H) @@ -409,9 +421,9 @@ def AngularSpectrumMethod(u, z, wvln, ps, n=1.0, padding=True, TF=True): r = torch.sqrt(r2) if n == 1: - h = z / (1j * wvln * r2) * torch.exp(1j * k * r) + h = z / (1j * wvln_mm * r2) * torch.exp(1j * k * r) else: - h = z * n / (1j * wvln * r2) * torch.exp(1j * n * k * r) + h = z * n / (1j * wvln_mm * r2) * torch.exp(1j * n * k * r) H = fft2(fftshift(h)) * ps**2 diff --git a/lenses/camera/sigma70mm_f2.8.json b/lenses/camera/sigma70mm_f2.8.json new file mode 100644 index 0000000..137e218 --- /dev/null +++ b/lenses/camera/sigma70mm_f2.8.json @@ -0,0 +1,212 @@ +{ + "info": "JP 2008-020656 Example 1 (Sigma Macro 70mm F2.8 EX DG)", + "foclen": 69.168, + "fnum": 2.912, + "r_sensor": 21.6, + "d_sensor": 124.856, + "(sensor_size)": [ + 23.476, + 23.476 + ], + "surfaces": [ + { + "idx": 1, + "type": "Spheric", + "r": 19.21, + "(c)": 0.003, + "roc": 369.55, + "(d)": 0.0, + "mat2": "1.54072/47.2", + "d_next": 1.8 + }, + { + "idx": 2, + "type": "Spheric", + "r": 18.0, + "(c)": 0.029, + "roc": 34.98, + "(d)": 1.8, + "mat2": "air", + "d_next": 13.45 + }, + { + "idx": 3, + "type": "Spheric", + "r": 16.5, + "(c)": 0.022, + "roc": 46.16, + "(d)": 15.25, + "mat2": "1.77250/49.6", + "d_next": 5.65 + }, + { + "idx": 4, + "type": "Spheric", + "r": 16.5, + "(c)": -0.007, + "roc": -135.9, + "(d)": 20.9, + "mat2": "air", + "d_next": 0.15 + }, + { + "idx": 5, + "type": "Spheric", + "r": 14.4, + "(c)": 0.032, + "roc": 31.73, + "(d)": 21.05, + "mat2": "1.69680/55.5", + "d_next": 3.5 + }, + { + "idx": 6, + "type": "Spheric", + "r": 14.4, + "(c)": 0.014, + "roc": 69.5, + "(d)": 24.55, + "mat2": "air", + "d_next": 5.79 + }, + { + "idx": 7, + "type": "Spheric", + "r": 15.0, + "(c)": 0.001, + "roc": 1000.0, + "(d)": 30.34, + "mat2": "1.60342/38.0", + "d_next": 1.0 + }, + { + "idx": 8, + "type": "Spheric", + "r": 13.5, + "(c)": 0.044, + "roc": 22.74, + "(d)": 31.34, + "mat2": "air", + "d_next": 12.1 + }, + { + "idx": 9, + "type": "Aperture", + "r": 10.5, + "(d)": 43.44, + "is_square": false, + "diffraction": false, + "d_next": 4.2 + }, + { + "idx": 10, + "type": "Spheric", + "r": 11.2, + "(c)": -0.037, + "roc": -27.01, + "(d)": 47.64, + "mat2": "1.58144/40.9", + "d_next": 1.0 + }, + { + "idx": 11, + "type": "Spheric", + "r": 12.5, + "(c)": 0.01, + "roc": 101.88, + "(d)": 48.64, + "mat2": "1.56045/71.6", + "d_next": 4.7 + }, + { + "idx": 12, + "type": "Spheric", + "r": 12.5, + "(c)": -0.028, + "roc": -35.28, + "(d)": 53.34, + "mat2": "air", + "d_next": 1.85 + }, + { + "idx": 13, + "type": "Spheric", + "r": 13.0, + "(c)": 0.001, + "roc": 1000.0, + "(d)": 55.19, + "mat2": "1.56045/71.6", + "d_next": 3.05 + }, + { + "idx": 14, + "type": "Spheric", + "r": 13.0, + "(c)": -0.019, + "roc": -53.2, + "(d)": 58.24, + "mat2": "air", + "d_next": 0.15 + }, + { + "idx": 15, + "type": "Spheric", + "r": 13.3, + "(c)": 0.008, + "roc": 125.46, + "(d)": 58.39, + "mat2": "1.49700/81.6", + "d_next": 3.0 + }, + { + "idx": 16, + "type": "Spheric", + "r": 13.3, + "(c)": -0.011, + "roc": -89.7, + "(d)": 61.39, + "mat2": "air", + "d_next": 1.5 + }, + { + "idx": 17, + "type": "Spheric", + "r": 14.4, + "(c)": 0.0, + "roc": 0.0, + "(d)": 62.89, + "mat2": "1.64000/60.2", + "d_next": 1.2 + }, + { + "idx": 18, + "type": "Spheric", + "r": 13.0, + "(c)": 0.019, + "roc": 52.82, + "(d)": 64.09, + "mat2": "air", + "d_next": 2.6 + }, + { + "idx": 19, + "type": "Spheric", + "r": 15.0, + "(c)": 0.001, + "roc": 931.0, + "(d)": 66.69, + "mat2": "1.83481/42.7", + "d_next": 2.35 + }, + { + "idx": 20, + "type": "Spheric", + "r": 15.0, + "(c)": -0.009, + "roc": -113.6, + "(d)": 69.04, + "mat2": "air", + "d_next": 55.816 + } + ] +} \ No newline at end of file diff --git a/lenses/camera/sigma70mm_f2.8.png b/lenses/camera/sigma70mm_f2.8.png new file mode 100644 index 0000000..f7d2b4d Binary files /dev/null and b/lenses/camera/sigma70mm_f2.8.png differ diff --git a/lenses/hybridlens/a489_doe.json b/lenses/hybridlens/a489_doe.json new file mode 100644 index 0000000..4ccddcc --- /dev/null +++ b/lenses/hybridlens/a489_doe.json @@ -0,0 +1,79 @@ +{ + "info": "A489 with a DOE", + "foclen": 7.7719, + "fnum": 3.8859, + "r_sensor": 2.1213, + "d_sensor": 11.257, + "(sensor_size)": [ + 3.0, + 3.0 + ], + "sensor_res": [ + 3000, + 3000 + ], + "surfaces": [ + { + "idx": 1, + "type": "Aperture", + "r": 1.0, + "(d)": 0.0, + "mat2": "air", + "is_square": false, + "diffraction": false, + "d_next": 0.1 + }, + { + "idx": 2, + "type": "Aspheric", + "r": 6.35, + "(c)": 0.2104, + "roc": 4.7531, + "d": 0.1, + "k": -1.2051, + "ai": [ + 0.0, + 0.000533, + 1.12e-05, + -3.75e-07, + -7.63e-09, + 1.36e-10 + ], + "mat2": "hk51", + "(ai2)": 0.0, + "(ai4)": 0.000533, + "(ai6)": 1.12e-05, + "(ai8)": -3.75e-07, + "(ai10)": -7.63e-09, + "(ai12)": 1.36e-10, + "d_next": 7.5 + }, + { + "idx": 3, + "type": "Spheric", + "r": 6.35, + "(c)": -0.0639, + "roc": -15.65, + "(d)": 7.6, + "mat2": "air", + "d_next": 3.657 + } + ], + "DOE": { + "type": "DOE", + "l": 3.0, + "d": 8.3098, + "res": [ + 3000, + 3000 + ], + "fab_ps": 0.001, + "is_square": true, + "param_model": "binary2", + "doe_path": null, + "order2": 0.0, + "order4": 0.0, + "order6": 0.0, + "order8": 0.0 + } +} \ No newline at end of file diff --git a/lenses/readme.md b/lenses/readme.md new file mode 100644 index 0000000..5f9f96a --- /dev/null +++ b/lenses/readme.md @@ -0,0 +1,6 @@ +# Lens file recources + +Lens data in this folder is collected from the following several resources: + +- https://www.photonstophotos.net/GeneralTopics/Lenses/OpticalBench/OpticalBench.htm +- https://www.lens-designs.com/ \ No newline at end of file