From 3c727ab53b105a4d1397ba8ed97fbb433ad1935c Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Jul 2023 21:17:07 -0400 Subject: [PATCH] Support new AdjointStyle trait --- Project.toml | 2 +- src/dct.jl | 2 ++ src/fft.jl | 15 +++++++++++++++ test/Project.toml | 8 ++++---- test/runtests.jl | 18 ++++++++++++++++++ 5 files changed, 40 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 79902d9..b3cd54e 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" [compat] -AbstractFFTs = "1.0" +AbstractFFTs = "1.4" FFTW_jll = "3.3.9" MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023" Preferences = "1.2" diff --git a/src/dct.jl b/src/dct.jl index cd3ec60..a921731 100644 --- a/src/dct.jl +++ b/src/dct.jl @@ -171,3 +171,5 @@ end mul!(Array{T}(undef, p.plan.osz), p, copy(x)) # need copy to preserve input *(p::DCTPlan{T,K,true}, x::StridedArray{T}) where {T,K} = mul!(x, p, x) + +AbstractFFTs.AdjointStyle(::DCTPlan) = AbstractFFTs.UnitaryAdjointStyle() diff --git a/src/fft.jl b/src/fft.jl index daa5866..db88044 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1049,3 +1049,18 @@ function *(p::r2rFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K} unsafe_execute!(p, x, x) return x end + +####################################################################### + +""" + R2RAdjointStyle(kinds) + +Projection style for real to real transforms +""" +struct R2RAdjointStyle{K} <: AbstractFFTs.AdjointStyle + kinds::K +end + +AbstractFFTs.AdjointStyle(::cFFTWPlan) = AbstractFFTs.FFTAdjointStyle() +AbstractFFTs.AdjointStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle() +AbstractFFTs.AdjointStyle(P::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(P.osz[first(P.region)]) diff --git a/test/Project.toml b/test/Project.toml index c46e7ba..3ac006b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,8 @@ -# A bug in Julia 1.6.0's Pkg causes Preferences to be dropped during `Pkg.test()`, so we work around -# it by explicitly creating a `test/Project.toml` which will correctly communicate any preferences -# through to the child Julia process. X-ref: https://github.com/JuliaLang/Pkg.jl/issues/2500 - [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/test/runtests.jl b/test/runtests.jl index 301194d..9f0e6de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -577,3 +577,21 @@ end end end end + +@testset "DCT adjoints" begin + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) + y = randn(size(x)) + N = ndims(x) + for dims in unique((1, 1:N, N)) + for P in (plan_dct(x, dims), plan_idct(x, dims)) + AbstractFFTs.TestUtils.test_plan_adjoint(P, x) + end + end + end +end + +@testset "AbstractFFTs FFT backend tests" begin + # note this also tests adjoint functionality for FFT plans + AbstractFFTs.TestUtils.test_complex_ffts(Array) + AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true) +end