diff --git a/Project.toml b/Project.toml index 79902d9..e53ebb6 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" [compat] -AbstractFFTs = "1.0" +AbstractFFTs = "1.5" FFTW_jll = "3.3.9" MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023" Preferences = "1.2" diff --git a/src/dct.jl b/src/dct.jl index cd3ec60..a921731 100644 --- a/src/dct.jl +++ b/src/dct.jl @@ -171,3 +171,5 @@ end mul!(Array{T}(undef, p.plan.osz), p, copy(x)) # need copy to preserve input *(p::DCTPlan{T,K,true}, x::StridedArray{T}) where {T,K} = mul!(x, p, x) + +AbstractFFTs.AdjointStyle(::DCTPlan) = AbstractFFTs.UnitaryAdjointStyle() diff --git a/src/fft.jl b/src/fft.jl index daa5866..4831a18 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1049,3 +1049,9 @@ function *(p::r2rFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K} unsafe_execute!(p, x, x) return x end + +####################################################################### + +AbstractFFTs.AdjointStyle(::cFFTWPlan) = AbstractFFTs.FFTAdjointStyle() +AbstractFFTs.AdjointStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle() +AbstractFFTs.AdjointStyle(p::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(p.osz[first(p.region)]) diff --git a/test/runtests.jl b/test/runtests.jl index 301194d..981f3f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -577,3 +577,27 @@ end end end end + +@testset "DCT adjoints" begin + # only test on FFTW because MKL is missing functionality + if FFTW.get_provider() == "fftw" + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) + y = randn(size(x)) + N = ndims(x) + for dims in unique((1, 1:N, N)) + for P in (plan_dct(x, dims), plan_idct(x, dims)) + AbstractFFTs.TestUtils.test_plan_adjoint(P, x) + end + end + end + end +end + +@testset "AbstractFFTs FFT backend tests" begin + # note this also tests adjoint functionality for FFT plans + # only test on FFTW because MKL is missing functionality + if FFTW.get_provider() == "fftw" + AbstractFFTs.TestUtils.test_complex_ffts(Array) + AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true) + end +end