diff --git a/Project.toml b/Project.toml index e53ebb6..54e1b2e 100644 --- a/Project.toml +++ b/Project.toml @@ -9,11 +9,22 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +FFTWChainRulesCoreExt = "ChainRulesCore" [compat] +ChainRulesCore = "1" AbstractFFTs = "1.5" FFTW_jll = "3.3.9" MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023" Preferences = "1.2" Reexport = "0.2, 1.0" julia = "1.6" + +[extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl new file mode 100644 index 0000000..789e4e9 --- /dev/null +++ b/ext/FFTWChainRulesCoreExt.jl @@ -0,0 +1,42 @@ +module FFTWChainRulesCoreExt + +using FFTW +using FFTW: r2r +using ChainRulesCore + +# DCT/IDCT + +for (fwd, bwd) in ( + (dct, idct), + (idct, dct), +) + function ChainRulesCore.frule(Δ, ::typeof(fwd), x::AbstractArray, region = 1:ndims(x)) + Δx = Δ[2] + y = fwd(x, region) + Δy = fwd(Δx, region) + return y, Δy + end + + function ChainRulesCore.rrule(::typeof(fwd), x::AbstractArray) + project_x = ProjectTo(x) + dct_pb(Δ) = NoTangent(), project_x(bwd(unthunk(Δ))) + return fwd(x), dct_pb + end + + function ChainRulesCore.rrule(::typeof(fwd), x::AbstractArray, region) + project_x = ProjectTo(x) + dct_pb(Δ) = NoTangent(), project_x(bwd(unthunk(Δ), region)), NoTangent() + return fwd(x, region), dct_pb + end +end + +# R2R + +function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, kind, region = 1:ndims(x)) + Δx = Δ[2] + y = r2r(x, kind, region) + Δy = r2r(Δx, kind, region) + return y, Δy +end + +end # module diff --git a/src/FFTW.jl b/src/FFTW.jl index 4366ee7..82d6bde 100644 --- a/src/FFTW.jl +++ b/src/FFTW.jl @@ -72,4 +72,8 @@ include("dct.jl") include("precompile.jl") _precompile_() +@static if !isdefined(Base, :get_extension) + include("../ext/FFTWChainRulesCoreExt.jl") +end + end # module diff --git a/test/Project.toml b/test/Project.toml index c46e7ba..a477908 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,6 @@ -# A bug in Julia 1.6.0's Pkg causes Preferences to be dropped during `Pkg.test()`, so we work around -# it by explicitly creating a `test/Project.toml` which will correctly communicate any preferences -# through to the child Julia process. X-ref: https://github.com/JuliaLang/Pkg.jl/issues/2500 - [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 981f3f1..351054d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ # This file was formerly a part of Julia. License is MIT: https://julialang.org/license using FFTW -using FFTW: fftw_provider +using FFTW: fftw_provider, r2r using AbstractFFTs: Plan, plan_inv using Test using LinearAlgebra