From 98b3c74b937f07658327ed27cc286a82b2768930 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 6 Mar 2023 10:34:37 -0500 Subject: [PATCH 1/3] Add chain rules for function calls without dims --- src/chainrules.jl | 29 ++++++++++++++++++++++++++ test/runtests.jl | 52 +++++++++++++++++++++++++++-------------------- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 97d4d22..db554c7 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -150,3 +150,32 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) end return y, ifftshift_pullback end + +# explicitly handle the default dims argument because e.g. fft(x) does not necessarily call fft(x, dims). (PR #83) +for f in (:fft, :rfft, :ifft, :bfft, :fftshift, :ifftshift) + @eval begin + function ChainRulesCore.frule((_, Δx), ::typeof($f), x::AbstractArray) + dims = 1:ndims(x) + return ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δx, ChainRulesCore.NoTangent()), $f, x, dims) + end + function ChainRulesCore.rrule(::typeof($f), x::AbstractArray) + dims = 1:ndims(x) + y, pb = ChainRulesCore.rrule($f, x, dims) + y, (ȳ -> pb(ȳ)[1:end-1]) + end + end +end +for f in (:irfft, brfft) + @eval begin + function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, d::Int) + dims = 1:ndims(x) + Δ = (ChainRulesCore.NoTangent(), Δx, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()) + return ChainRulesCore.frule(Δ, $f, x, d, dims) + end + function ChainRulesCore.rrule(::typeof($f), x::AbstractArray, d::Int) + dims = 1:ndims(x) + y, pb = ChainRulesCore.rrule($f, x, d, dims) + y, (ȳ -> pb(ȳ)[1:end-1]) + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 4d402c5..a3a3ca7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -216,19 +216,23 @@ end @testset "ChainRules" begin @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) - for dims in ((), 1, 2, (1,2), 1:2) - any(d > ndims(x) for d in dims) && continue + # type inference checks of `rrule` fail on old Julia versions + # for higher-dimensional arrays: + # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016 + check_inferred = ndims(x) < 3 || VERSION >= v"1.6" - # type inference checks of `rrule` fail on old Julia versions - # for higher-dimensional arrays: - # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016 - check_inferred = ndims(x) < 3 || VERSION >= v"1.6" + for dims in ((), 1, 2, (1,2), 1:2, nothing) + # if dims=nothing, test handling of default dims argument + args = (dims === nothing) ? () : (dims,) + real_dims = (dims === nothing) ? (1:ndims(x)) : dims - test_frule(AbstractFFTs.fftshift, x, dims) - test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred) + any(d > ndims(x) for d in real_dims) && continue - test_frule(AbstractFFTs.ifftshift, x, dims) - test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred) + test_frule(AbstractFFTs.fftshift, x, args...) + test_rrule(AbstractFFTs.fftshift, x, args...; check_inferred=check_inferred) + + test_frule(AbstractFFTs.ifftshift, x, args...) + test_rrule(AbstractFFTs.ifftshift, x, args...; check_inferred=check_inferred) end end end @@ -237,23 +241,27 @@ end for x in (randn(3), randn(3, 4), randn(3, 4, 5)) N = ndims(x) complex_x = complex.(x) - for dims in unique((1, 1:N, N)) + for dims in unique((1, 1:N, N, nothing)) + # if dims=nothing, test handling of default dims argument + args = (dims === nothing) ? () : (dims,) + real_dims = (dims === nothing) ? (1:N) : dims + for f in (fft, ifft, bfft) - test_frule(f, x, dims) - test_rrule(f, x, dims) - test_frule(f, complex_x, dims) - test_rrule(f, complex_x, dims) + test_frule(f, x, args...) + test_rrule(f, x, args...) + test_frule(f, complex_x, args...) + test_rrule(f, complex_x, args...) end - test_frule(rfft, x, dims) - test_rrule(rfft, x, dims) + test_frule(rfft, x, args...) + test_rrule(rfft, x, args...) for f in (irfft, brfft) - for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) - test_frule(f, x, d, dims) - test_rrule(f, x, d, dims) - test_frule(f, complex_x, d, dims) - test_rrule(f, complex_x, d, dims) + for d in (2 * size(x, first(real_dims)) - 1, 2 * size(x, first(real_dims)) - 2) + test_frule(f, x, d, args...) + test_rrule(f, x, d, args...) + test_frule(f, complex_x, d, args...) + test_rrule(f, complex_x, d, args...) end end end From ce4e395ffbf4dbf24216a2661adc301899baffbe Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 6 Mar 2023 14:41:55 -0500 Subject: [PATCH 2/3] Write direct chain rules with and without dims argument --- src/chainrules.jl | 160 +++++++++++++++++++++++----------------------- test/runtests.jl | 20 +++--- 2 files changed, 89 insertions(+), 91 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index db554c7..7d1bb8d 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,121 +1,152 @@ # ffts -function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) - y = fft(x, dims) - Δy = fft(Δx, dims) +# we explicitly handle both unprovided and provided dims arguments in all rules, which +# results in some additional complexity here but means no assumptions are made on what +# signatures downstream implementations support. +function ChainRulesCore.frule(Δargs, ::typeof(fft), x::AbstractArray, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = fft(x, dims_args...) + Δy = fft(Δx, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims) - y = fft(x, dims) +function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + y = fft(x, dims_args...) project_x = ChainRulesCore.ProjectTo(x) function fft_pullback(ȳ) - x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, dims_args_tangent... end return y, fft_pullback end -function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims) - y = rfft(x, dims) - Δy = rfft(Δx, dims) +function ChainRulesCore.frule(Δargs, ::typeof(rfft), x::AbstractArray{<:Real}, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = rfft(x, dims_args...) + Δy = rfft(Δx, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) - y = rfft(x, dims) +function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:ndims(x)) : dims + y = rfft(x, dims_args...) # compute scaling factors - halfdim = first(dims) + halfdim = first(true_dims) d = size(x, halfdim) n = size(y, halfdim) scale = reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))), ) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) - x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, dims_args_tangent... end return y, rfft_pullback end -function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims) - y = ifft(x, dims) - Δy = ifft(Δx, dims) +function ChainRulesCore.frule(Δargs, ::typeof(ifft), x::AbstractArray, dims=nothing) + Δx = Δargs[2] + args = (dims === nothing) ? () : (dims,) + y = ifft(x, args...) + Δy = ifft(Δx, args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) - y = ifft(x, dims) - invN = normalization(y, dims) +function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:ndims(x)) : dims + y = ifft(x, dims_args...) + invN = normalization(y, true_dims) project_x = ChainRulesCore.ProjectTo(x) function ifft_pullback(ȳ) - x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, dims_args_tangent... end return y, ifft_pullback end -function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims) - y = irfft(x, d, dims) - Δy = irfft(Δx, d, dims) +function ChainRulesCore.frule(Δargs, ::typeof(irfft), x::AbstractArray, d::Int, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = irfft(x, d, dims_args...) + Δy = irfft(Δx, d, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) - y = irfft(x, d, dims) +function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:ndims(x)) : dims + y = irfft(x, d, dims_args...) # compute scaling factors - halfdim = first(dims) + halfdim = first(true_dims) n = size(x, halfdim) - invN = normalization(y, dims) + invN = normalization(y, true_dims) twoinvN = 2 * invN scale = reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))), ) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), dims_args_tangent... end return y, irfft_pullback end -function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims) - y = bfft(x, dims) - Δy = bfft(Δx, dims) +function ChainRulesCore.frule(Δargs, ::typeof(bfft), x::AbstractArray, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = bfft(x, dims_args...) + Δy = bfft(Δx, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims) - y = bfft(x, dims) +function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + y = bfft(x, dims_args...) project_x = ChainRulesCore.ProjectTo(x) function bfft_pullback(ȳ) - x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, dims_args_tangent... end return y, bfft_pullback end -function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims) - y = brfft(x, d, dims) - Δy = brfft(Δx, d, dims) +function ChainRulesCore.frule(Δargs, ::typeof(brfft), x::AbstractArray, d::Int, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = brfft(x, d, dims_args...) + Δy = brfft(Δx, d, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) - y = brfft(x, d, dims) +function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:ndims(x)) : dims + y = brfft(x, d, dims_args...) # compute scaling factors - halfdim = first(dims) + halfdim = first(true_dims) n = size(x, halfdim) scale = reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))), ) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), dims_args_tangent... end return y, brfft_pullback end @@ -150,32 +181,3 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) end return y, ifftshift_pullback end - -# explicitly handle the default dims argument because e.g. fft(x) does not necessarily call fft(x, dims). (PR #83) -for f in (:fft, :rfft, :ifft, :bfft, :fftshift, :ifftshift) - @eval begin - function ChainRulesCore.frule((_, Δx), ::typeof($f), x::AbstractArray) - dims = 1:ndims(x) - return ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δx, ChainRulesCore.NoTangent()), $f, x, dims) - end - function ChainRulesCore.rrule(::typeof($f), x::AbstractArray) - dims = 1:ndims(x) - y, pb = ChainRulesCore.rrule($f, x, dims) - y, (ȳ -> pb(ȳ)[1:end-1]) - end - end -end -for f in (:irfft, brfft) - @eval begin - function ChainRulesCore.frule((_, Δx, _), ::typeof($f), x::AbstractArray, d::Int) - dims = 1:ndims(x) - Δ = (ChainRulesCore.NoTangent(), Δx, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()) - return ChainRulesCore.frule(Δ, $f, x, d, dims) - end - function ChainRulesCore.rrule(::typeof($f), x::AbstractArray, d::Int) - dims = 1:ndims(x) - y, pb = ChainRulesCore.rrule($f, x, d, dims) - y, (ȳ -> pb(ȳ)[1:end-1]) - end - end -end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index a3a3ca7..40ca3e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -221,18 +221,14 @@ end # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016 check_inferred = ndims(x) < 3 || VERSION >= v"1.6" - for dims in ((), 1, 2, (1,2), 1:2, nothing) - # if dims=nothing, test handling of default dims argument - args = (dims === nothing) ? () : (dims,) - real_dims = (dims === nothing) ? (1:ndims(x)) : dims - - any(d > ndims(x) for d in real_dims) && continue + for dims in ((), 1, 2, (1,2), 1:2) + any(d > ndims(x) for d in dims) && continue - test_frule(AbstractFFTs.fftshift, x, args...) - test_rrule(AbstractFFTs.fftshift, x, args...; check_inferred=check_inferred) + test_frule(AbstractFFTs.fftshift, x, dims) + test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred) - test_frule(AbstractFFTs.ifftshift, x, args...) - test_rrule(AbstractFFTs.ifftshift, x, args...; check_inferred=check_inferred) + test_frule(AbstractFFTs.ifftshift, x, dims) + test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred) end end end @@ -244,7 +240,7 @@ end for dims in unique((1, 1:N, N, nothing)) # if dims=nothing, test handling of default dims argument args = (dims === nothing) ? () : (dims,) - real_dims = (dims === nothing) ? (1:N) : dims + true_dims = (dims === nothing) ? (1:N) : dims for f in (fft, ifft, bfft) test_frule(f, x, args...) @@ -257,7 +253,7 @@ end test_rrule(rfft, x, args...) for f in (irfft, brfft) - for d in (2 * size(x, first(real_dims)) - 1, 2 * size(x, first(real_dims)) - 2) + for d in (2 * size(x, first(true_dims)) - 1, 2 * size(x, first(true_dims)) - 2) test_frule(f, x, d, args...) test_rrule(f, x, d, args...) test_frule(f, complex_x, d, args...) From 13cf3af6591efa86ed1295ce4fbf272c120ddbde Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 6 Mar 2023 14:56:54 -0500 Subject: [PATCH 3/3] Use same naming in tests as in rules --- test/runtests.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 40ca3e4..b8998bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -239,25 +239,25 @@ end complex_x = complex.(x) for dims in unique((1, 1:N, N, nothing)) # if dims=nothing, test handling of default dims argument - args = (dims === nothing) ? () : (dims,) + dims_args = (dims === nothing) ? () : (dims,) true_dims = (dims === nothing) ? (1:N) : dims for f in (fft, ifft, bfft) - test_frule(f, x, args...) - test_rrule(f, x, args...) - test_frule(f, complex_x, args...) - test_rrule(f, complex_x, args...) + test_frule(f, x, dims_args...) + test_rrule(f, x, dims_args...) + test_frule(f, complex_x, dims_args...) + test_rrule(f, complex_x, dims_args...) end - test_frule(rfft, x, args...) - test_rrule(rfft, x, args...) + test_frule(rfft, x, dims_args...) + test_rrule(rfft, x, dims_args...) for f in (irfft, brfft) for d in (2 * size(x, first(true_dims)) - 1, 2 * size(x, first(true_dims)) - 2) - test_frule(f, x, d, args...) - test_rrule(f, x, d, args...) - test_frule(f, complex_x, d, args...) - test_rrule(f, complex_x, d, args...) + test_frule(f, x, d, dims_args...) + test_rrule(f, x, d, dims_args...) + test_frule(f, complex_x, d, dims_args...) + test_rrule(f, complex_x, d, dims_args...) end end end