Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support plan_inv for ScaledPlan's #77

Merged
merged 3 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ plan_ifft(x::AbstractArray, region; kws...) =
plan_ifft!(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region))

plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale))
# Don't cache inverse of scaled plan (only inverse of inner plan)
inv(p::ScaledPlan) = ScaledPlan(inv(p.p), inv(p.scale))

LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
Expand Down
52 changes: 34 additions & 18 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,15 @@ end
dims = ndims(x)
y = AbstractFFTs.fft(x, dims)
@test y ≈ fftw_fft
P = plan_fft(x, dims)
@test eltype(P) === ComplexF64
@test P * x ≈ fftw_fft
@test P \ (P * x) ≈ x
@test fftdims(P) == dims
# test plan_fft and also inv and plan_inv of plan_ifft, which should all give
# functionally identical plans
for P in [plan_fft(x, dims), inv(plan_ifft(x, dims)),
AbstractFFTs.plan_inv(plan_ifft(x, dims))]
@test eltype(P) === ComplexF64
@test P * x ≈ fftw_fft
@test P \ (P * x) ≈ x
@test fftdims(P) == dims
end

fftw_bfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft
Expand All @@ -71,10 +75,14 @@ end

fftw_ifft = complex.(x)
@test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft
P = plan_ifft(x, dims)
@test P * y ≈ fftw_ifft
@test P \ (P * y) ≈ y
@test fftdims(P) == dims
# test plan_ifft and also inv and plan_inv of plan_fft, which should all give
# functionally identical plans
for P in [plan_ifft(x, dims), inv(plan_fft(x, dims)),
AbstractFFTs.plan_inv(plan_fft(x, dims))]
@test P * y ≈ fftw_ifft
@test P \ (P * y) ≈ y
@test fftdims(P) == dims
end

# real FFT
fftw_rfft = fftw_fft[
Expand All @@ -83,11 +91,15 @@ end
]
ry = AbstractFFTs.rfft(x, dims)
@test ry ≈ fftw_rfft
P = plan_rfft(x, dims)
@test eltype(P) === Int
@test P * x ≈ fftw_rfft
@test P \ (P * x) ≈ x
@test fftdims(P) == dims
# test plan_rfft and also inv and plan_inv of plan_irfft, which should all give
# functionally identical plans
for P in [plan_rfft(x, dims), inv(plan_irfft(ry, size(x, dims), dims)),
AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))]
@test eltype(P) <: Real
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than add the loop over P, the only other change I made was to this assertion. Previously the assertion was that the eltype should be Int here. That doesn't seem right (for FFTW the eltype would be Float64); rather than thinking too hard about what the test plan implementation should be doing (it's probably underspecified), I just made the check a little looser.

@test P * x ≈ fftw_rfft
@test P \ (P * x) ≈ x
@test fftdims(P) == dims
end

fftw_brfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft
Expand All @@ -98,10 +110,14 @@ end

fftw_irfft = complex.(x)
@test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft
P = plan_irfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_irfft
@test P \ (P * ry) ≈ ry
@test fftdims(P) == dims
# test plan_rfft and also inv and plan_inv of plan_irfft, which should all give
# functionally identical plans
for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x, dims)),
AbstractFFTs.plan_inv(plan_rfft(x, dims))]
@test P * ry ≈ fftw_irfft
@test P \ (P * ry) ≈ ry
@test fftdims(P) == dims
end
end
end

Expand Down
8 changes: 4 additions & 4 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(el
mutable struct TestRPlan{T,N} <: Plan{T}
region
sz::NTuple{N,Int}
pinv::Plan{T}
pinv::Plan{Complex{T}}
TestRPlan{T}(region, sz::NTuple{N,Int}) where {T,N} = new{T,N}(region, sz)
end

mutable struct InverseTestRPlan{T,N} <: Plan{T}
mutable struct InverseTestRPlan{T,N} <: Plan{Complex{T}}
d::Int
region
sz::NTuple{N,Int}
Expand All @@ -107,10 +107,10 @@ mutable struct InverseTestRPlan{T,N} <: Plan{T}
end
end

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T}
function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
end
function AbstractFFTs.plan_brfft(x::AbstractArray{T}, d, region; kwargs...) where {T}
function AbstractFFTs.plan_brfft(x::AbstractArray{Complex{T}}, d, region; kwargs...) where {T}
return InverseTestRPlan{T}(d, region, size(x))
end
function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N}
Expand Down