diff --git a/python/adjoint/filters.py b/python/adjoint/filters.py index bba940bb4..c7eddd688 100644 --- a/python/adjoint/filters.py +++ b/python/adjoint/filters.py @@ -7,8 +7,6 @@ import meep as mp from scipy import special from scipy import signal -from autograd.extend import primitive, defvjp -from functools import partial def _proper_pad(x,n): ''' @@ -36,6 +34,28 @@ def _centered(arr, newshape): myslice = [slice(startind[k], endind[k]) for k in range(len(endind))] return arr[tuple(myslice)] +def _edge_pad(arr, pad): + + # fill sides + left = npa.tile(arr[0, :], (pad[0][0], 1)) # left side + right = npa.tile(arr[-1, :], (pad[0][1], 1)) # right side + top = npa.tile(arr[:, 0], (pad[1][0], 1)).transpose() # top side + bottom = npa.tile(arr[:, -1], (pad[1][1], 1)).transpose() # bottom side) + + # fill corners + top_left = npa.tile(arr[0, 0], (pad[0][0], pad[1][0])) # top left + top_right = npa.tile(arr[-1, 0], (pad[0][1], pad[1][0])) # top right + bottom_left = npa.tile(arr[0, -1], (pad[0][0], pad[1][1])) # bottom left + bottom_right = npa.tile(arr[-1, -1], + (pad[0][1], pad[1][1])) # bottom right + + out = npa.concatenate((npa.concatenate( + (top_left, top, top_right)), npa.concatenate((left, arr, right)), + npa.concatenate( + (bottom_left, bottom, bottom_right))), + axis=1) + + return out def simple_2d_filter(x, h): """A simple 2d filter algorithm that is differentiable with autograd. @@ -56,11 +76,9 @@ def simple_2d_filter(x, h): array_like (2D) The output of the 2d convolution. """ - x_shape = x.shape - x_pad = [[k,k] for k in x_shape] - x = np.pad(x,x_pad,'edge') - return _centered(np.real(npa.fft.ifft2(npa.fft.fft2(x)*npa.fft.fft2(h))),x_shape) - + (kx, ky) = x.shape + x = _edge_pad(x,((kx, kx), (ky, ky))) + return _centered(npa.real(npa.fft.ifft2(npa.fft.fft2(x)*npa.fft.fft2(h))),(kx, ky)) def cylindrical_filter(x, radius, Lx, Ly, resolution): '''A uniform cylindrical filter [1]. Typically allows for sharper transitions. @@ -90,6 +108,7 @@ def cylindrical_filter(x, radius, Lx, Ly, resolution): ''' Nx = int(Lx*resolution) Ny = int(Ly*resolution) + x = x.reshape(Nx, Ny) # Ensure the input is 2D xv = np.arange(0,Lx/2,1/resolution) yv = np.arange(0,Ly/2,1/resolution) @@ -135,6 +154,7 @@ def conic_filter(x, radius, Lx, Ly, resolution): ''' Nx = int(Lx*resolution) Ny = int(Ly*resolution) + x = x.reshape(Nx, Ny) # Ensure the input is 2D xv = np.arange(0,Lx/2,1/resolution) yv = np.arange(0,Ly/2,1/resolution) @@ -180,6 +200,7 @@ def gaussian_filter(x, sigma, Lx, Ly, resolution): ''' Nx = int(Lx*resolution) Ny = int(Ly*resolution) + x = x.reshape(Nx, Ny) # Ensure the input is 2D xv = np.arange(0,Lx/2,1/resolution) yv = np.arange(0,Ly/2,1/resolution)