Skip to content

Commit

Permalink
fix gradient of filters
Browse files Browse the repository at this point in the history
  • Loading branch information
smartalecH committed Mar 24, 2022
1 parent d104c99 commit dfdbba3
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions python/adjoint/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import meep as mp
from scipy import special
from scipy import signal
from autograd.extend import primitive, defvjp
from functools import partial

def _proper_pad(x,n):
'''
Expand Down Expand Up @@ -36,6 +34,28 @@ def _centered(arr, newshape):
myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
return arr[tuple(myslice)]

def _edge_pad(arr, pad):

# fill sides
left = npa.tile(arr[0, :], (pad[0][0], 1)) # left side
right = npa.tile(arr[-1, :], (pad[0][1], 1)) # right side
top = npa.tile(arr[:, 0], (pad[1][0], 1)).transpose() # top side
bottom = npa.tile(arr[:, -1], (pad[1][1], 1)).transpose() # bottom side)

# fill corners
top_left = npa.tile(arr[0, 0], (pad[0][0], pad[1][0])) # top left
top_right = npa.tile(arr[-1, 0], (pad[0][1], pad[1][0])) # top right
bottom_left = npa.tile(arr[0, -1], (pad[0][0], pad[1][1])) # bottom left
bottom_right = npa.tile(arr[-1, -1],
(pad[0][1], pad[1][1])) # bottom right

out = npa.concatenate((npa.concatenate(
(top_left, top, top_right)), npa.concatenate((left, arr, right)),
npa.concatenate(
(bottom_left, bottom, bottom_right))),
axis=1)

return out

def simple_2d_filter(x, h):
"""A simple 2d filter algorithm that is differentiable with autograd.
Expand All @@ -56,11 +76,9 @@ def simple_2d_filter(x, h):
array_like (2D)
The output of the 2d convolution.
"""
x_shape = x.shape
x_pad = [[k,k] for k in x_shape]
x = np.pad(x,x_pad,'edge')
return _centered(np.real(npa.fft.ifft2(npa.fft.fft2(x)*npa.fft.fft2(h))),x_shape)

(kx, ky) = x.shape
x = _edge_pad(x,((kx, kx), (ky, ky)))
return _centered(npa.real(npa.fft.ifft2(npa.fft.fft2(x)*npa.fft.fft2(h))),(kx, ky))

def cylindrical_filter(x, radius, Lx, Ly, resolution):
'''A uniform cylindrical filter [1]. Typically allows for sharper transitions.
Expand Down Expand Up @@ -90,6 +108,7 @@ def cylindrical_filter(x, radius, Lx, Ly, resolution):
'''
Nx = int(Lx*resolution)
Ny = int(Ly*resolution)
x = x.reshape(Nx, Ny) # Ensure the input is 2D

xv = np.arange(0,Lx/2,1/resolution)
yv = np.arange(0,Ly/2,1/resolution)
Expand Down Expand Up @@ -135,6 +154,7 @@ def conic_filter(x, radius, Lx, Ly, resolution):
'''
Nx = int(Lx*resolution)
Ny = int(Ly*resolution)
x = x.reshape(Nx, Ny) # Ensure the input is 2D

xv = np.arange(0,Lx/2,1/resolution)
yv = np.arange(0,Ly/2,1/resolution)
Expand Down Expand Up @@ -180,6 +200,7 @@ def gaussian_filter(x, sigma, Lx, Ly, resolution):
'''
Nx = int(Lx*resolution)
Ny = int(Ly*resolution)
x = x.reshape(Nx, Ny) # Ensure the input is 2D

xv = np.arange(0,Lx/2,1/resolution)
yv = np.arange(0,Ly/2,1/resolution)
Expand Down

0 comments on commit dfdbba3

Please sign in to comment.