-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplanned_fft.jl
43 lines (36 loc) · 960 Bytes
/
planned_fft.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Build an out-of-place forward FFT plan for ComplexF64 arrays of shape `dims`.
# `@memoize` caches the returned plan keyed on `dims`, so the (slow) FFTW.MEASURE
# planning step runs only on the first call for each distinct shape.
# A throwaway zero-filled template is planned on because FFTW.MEASURE may
# overwrite the array it plans with.
# NOTE(review): the memo cache is never cleared; one plan is kept per shape seen.
@memoize function make_plan(dims::Tuple)
    template = zeros(ComplexF64, dims)
    return plan_fft(template; flags = FFTW.MEASURE)
end
# Build an in-place forward FFT plan for ComplexF64 arrays of shape `dims`.
# Like `make_plan`, the result is memoized per `dims`, so FFTW.MEASURE planning
# is paid once per distinct shape. The zero-filled template is disposable
# because MEASURE planning may clobber it.
# NOTE(review): the memo cache is never cleared; one plan is kept per shape seen.
@memoize function make_plan!(dims::Tuple)
    template = zeros(ComplexF64, dims)
    return plan_fft!(template; flags = FFTW.MEASURE)
end
"""
    planned_fft(x)

Return the forward (unnormalized, as with `fft`) discrete Fourier transform of
`x`, computed out of place with the cached plan for arrays of `size(x)`
obtained from `make_plan`.
"""
planned_fft(x) = make_plan(size(x)) * x
"""
    planned_ifft(x)

Return the normalized inverse discrete Fourier transform of `x` (as with
`ifft`), computed out of place by left-dividing by the cached forward plan for
arrays of `size(x)` obtained from `make_plan`.
"""
planned_ifft(x) = make_plan(size(x)) \ x
"""
    planned_fft!(x)

Forward discrete Fourier transform of `x` using the cached in-place plan for
arrays of `size(x)` obtained from `make_plan!`. Returns the transformed array.

NOTE(review): `x` is presumably overwritten only when it matches the plan
(a strided `ComplexF64` array); for other inputs the generic AbstractFFTs
fallback appears to transform a converted copy and leave `x` untouched, so
always use the return value — confirm.
"""
planned_fft!(x) = make_plan!(size(x)) * x
"""
    planned_ifft!(x)

Normalized inverse discrete Fourier transform of `x`, obtained by
left-dividing by the cached in-place forward plan for arrays of `size(x)`
from `make_plan!`. Returns the transformed array.

NOTE(review): as with `planned_fft!`, `x` is presumably mutated only when it is
a strided `ComplexF64` array; otherwise a transformed copy is returned — confirm,
and always use the return value.
"""
planned_ifft!(x) = make_plan!(size(x)) \ x
# Reverse-mode rule for `planned_fft`.
# The unnormalized DFT F is linear, and its adjoint is F' = N * (normalized
# inverse DFT) with N = length(x). The pullback therefore maps the output
# cotangent ȳ to N * planned_ifft(ȳ).
# - The cotangent is `unthunk`ed before use, since AD systems may pass it lazily.
# - The result is projected onto the tangent space of `x` via `ProjectTo`, so a
#   real `x` (accepted by the plan through conversion to ComplexF64) receives a
#   real gradient instead of a complex one.
function rrule(::typeof(planned_fft), x)
    n = length(x)
    project_x = ProjectTo(x)
    function planned_fft_pullback(ȳ)
        Δ = unthunk(ȳ)
        return NoTangent(), project_x(n * planned_ifft(Δ))
    end
    return planned_fft(x), planned_fft_pullback
end
# Reverse-mode rule for `planned_ifft`.
# The normalized inverse DFT is (1/N) * F' with N = length(x), so its adjoint is
# (1/N) * F. The pullback therefore maps the output cotangent ȳ to
# planned_fft(ȳ) / N.
# - The cotangent is `unthunk`ed before use, since AD systems may pass it lazily.
# - The result is projected onto the tangent space of `x` via `ProjectTo`, so a
#   real `x` receives a real gradient instead of a complex one.
function rrule(::typeof(planned_ifft), x)
    n = length(x)
    project_x = ProjectTo(x)
    function planned_ifft_pullback(ȳ)
        Δ = unthunk(ȳ)
        return NoTangent(), project_x(planned_fft(Δ) / n)
    end
    return planned_ifft(x), planned_ifft_pullback
end