Skip to content

Commit

Permalink
Support non-1 indices and fix type problems in DFT (fixes #17896)
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Aug 9, 2016
1 parent b378ece commit 7f55ca5
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 8 deletions.
29 changes: 22 additions & 7 deletions base/dft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,28 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft

complexfloat{T<:AbstractFloat}(x::AbstractArray{Complex{T}}) = x
typealias FFTWFloat Union{Float32,Float64}
fftwfloat(x) = _fftwfloat(float(x))
_fftwfloat{T<:FFTWFloat}(::Type{T}) = T
_fftwfloat(::Type{Float16}) = Float32
_fftwfloat{T}(::Type{T}) = error("type $T not supported")
_fftwfloat{T}(x::T) = _fftwfloat(T)(x)

complexfloat{T<:FFTWFloat}(x::StridedArray{Complex{T}}) = x
realfloat{T<:FFTWFloat}(x::StridedArray{T}) = x

# return an Array, rather than similar(x), to avoid an extra copy for FFTW
# (which only works on StridedArray types).
complexfloat{T<:Complex}(x::AbstractArray{T}) = copy!(Array{typeof(float(one(T)))}(size(x)), x)
complexfloat{T<:AbstractFloat}(x::AbstractArray{T}) = copy!(Array{typeof(complex(one(T)))}(size(x)), x)
complexfloat{T<:Real}(x::AbstractArray{T}) = copy!(Array{typeof(complex(float(one(T))))}(size(x)), x)
complexfloat{T<:Complex}(x::AbstractArray{T}) = copy1(typeof(fftwfloat(one(T))), x)
complexfloat{T<:Real}(x::AbstractArray{T}) = copy1(typeof(complex(fftwfloat(one(T)))), x)

realfloat{T<:Real}(x::AbstractArray{T}) = copy1(typeof(fftwfloat(one(T))), x)

# copy to a 1-based array, using circular permutation
function copy1{T}(::Type{T}, x)
y = Array{T}(map(length, indices(x)))
circcopy!(y, x)
end

# implementations only need to provide plan_X(x, region)
# for X in (:fft, :bfft, ...):
Expand Down Expand Up @@ -187,11 +202,11 @@ for f in (:fft, :bfft, :ifft)
$pf{T<:Union{Integer,Rational}}(x::AbstractArray{Complex{T}}, region; kws...) = $pf(complexfloat(x), region; kws...)
end
end
rfft{T<:Union{Integer,Rational}}(x::AbstractArray{T}, region=1:ndims(x)) = rfft(float(x), region)
plan_rfft{T<:Union{Integer,Rational}}(x::AbstractArray{T}, region; kws...) = plan_rfft(float(x), region; kws...)
rfft{T<:Union{Integer,Rational}}(x::AbstractArray{T}, region=1:ndims(x)) = rfft(realfloat(x), region)
plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kws...)

# only require implementation to provide *(::Plan{T}, ::Array{T})
*{T}(p::Plan{T}, x::AbstractArray) = p * copy!(Array{T}(size(x)), x)
*{T}(p::Plan{T}, x::AbstractArray) = p * copy1(T, x)

# Implementations should also implement A_mul_B!(Y, plan, X) so as to support
# pre-allocated output arrays. We don't define * in terms of A_mul_B!
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ export
cat,
checkbounds,
checkindex,
circcopy!,
circshift,
circshift!,
clamp!,
Expand Down
60 changes: 60 additions & 0 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,66 @@ function _circshift!(dest, rdest, src, rsrc, inds, shiftamt)
copy!(dest, CartesianRange(rdest), src, CartesianRange(rsrc))
end

# circcopy!
"""
circcopy!(dest, src)
Copy `src` to `dest`, indexing each dimension modulo its length.
`src` and `dest` must have the same size, but can be offset in
their indices; any offset results in a (circular) wraparound. If the
arrays have overlapping indices, then on the domain of the overlap
`dest` agrees with `src`.
```jldoctest
julia> src = reshape(collect(1:16), (4,4))
4×4 Array{Int64,2}:
1 5 9 13
2 6 10 14
3 7 11 15
4 8 12 16
julia> dest = OffsetArray(Int, (-1,1));
julia> circcopy!(dest, src)
OffsetArrays.OffsetArray{Int64,2,Array{Int64,2}} with indices 0:3×2:5:
8 12 16 4
5 9 13 1
6 10 14 2
7 11 15 3
julia> dest[1:3,2:4] == src[1:3,2:4]
true
"""
function circcopy!(dest, src)
dest === src && throw(ArgumentError("dest and src must be separate arrays"))
indssrc, indsdest = indices(src), indices(dest)
if (szsrc = map(length, indssrc)) != (szdest = map(length, indsdest))
throw(DimensionMismatch("src and dest must have the same sizes (got $szsrc and $szdest)"))
end
shift = map((isrc, idest)->first(isrc)-first(idest), indssrc, indsdest)
all(x->x==0, shift) && return copy!(dest, src)
_circcopy!(dest, (), indsdest, src, (), indssrc)
end

@inline function _circcopy!(dest, rdest, indsdest::Tuple{AbstractUnitRange,Vararg{Any}},
src, rsrc, indssrc::Tuple{AbstractUnitRange,Vararg{Any}})
indd1, inds1 = indsdest[1], indssrc[1]
l = length(indd1)
s = mod(first(inds1)-first(indd1), l)
sdf = first(indd1)+s
rd1, rd2 = first(indd1):sdf-1, sdf:last(indd1)
ssf = last(inds1)-s
rs1, rs2 = first(inds1):ssf, ssf+1:last(inds1)
tindsd, tindss = tail(indsdest), tail(indssrc)
_circcopy!(dest, (rdest..., rd1), tindsd, src, (rsrc..., rs2), tindss)
_circcopy!(dest, (rdest..., rd2), tindsd, src, (rsrc..., rs1), tindss)
end

# At least one of indsdest, indssrc are empty (and both should be, since we've checked)
function _circcopy!(dest, rdest, indsdest, src, rsrc, indssrc)
copy!(dest, CartesianRange(rdest), src, CartesianRange(rsrc))
end

### BitArrays

## getindex
Expand Down
8 changes: 8 additions & 0 deletions test/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,11 @@ for x in (randn(10),randn(10,12))
# note: inference doesn't work for plan_fft_ since the
# algorithm steps are included in the CTPlan type
end

# issue #17896
a = rand(5)
@test fft(a) == fft(view(a,:)) == fft(view(a, 1:5)) == fft(view(a, [1:5;]))
@test rfft(a) == rfft(view(a,:)) == rfft(view(a, 1:5)) == rfft(view(a, [1:5;]))
a16 = convert(Vector{Float16}, a)
@test fft(a16) == fft(view(a16,:)) == fft(view(a16, 1:5)) == fft(view(a16, [1:5;]))
@test rfft(a16) == rfft(view(a16,:)) == rfft(view(a16, 1:5)) == rfft(view(a16, [1:5;]))
29 changes: 28 additions & 1 deletion test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,37 @@ v = OffsetArray(rand(8), (-2,))
@test rotr90(A) == OffsetArray(rotr90(parent(A)), A.offsets[[2,1]])
@test flipdim(A, 1) == OffsetArray(flipdim(parent(A), 1), A.offsets)
@test flipdim(A, 2) == OffsetArray(flipdim(parent(A), 2), A.offsets)
@test circshift(A, (-1,2)) == OffsetArray(circshift(parent(A), (-1,2)), A.offsets)

@test A+1 == OffsetArray(parent(A)+1, A.offsets)
@test 2*A == OffsetArray(2*parent(A), A.offsets)
@test A+A == OffsetArray(parent(A)+parent(A), A.offsets)
@test A.*A == OffsetArray(parent(A).*parent(A), A.offsets)

@test circshift(A, (-1,2)) == OffsetArray(circshift(parent(A), (-1,2)), A.offsets)

src = reshape(collect(1:16), (4,4))
dest = OffsetArray(Array{Int}(4,4), (-1,1))
circcopy!(dest, src)
@test parent(dest) == [8 12 16 4; 5 9 13 1; 6 10 14 2; 7 11 15 3]
@test dest[1:3,2:4] == src[1:3,2:4]

e = eye(5)
a = [e[:,1], e[:,2], e[:,3], e[:,4], e[:,5]]
a1 = zeros(5)
c = [ones(Complex{Float64}, 5),
exp(-2*pi*im*(0:4)/5),
exp(-4*pi*im*(0:4)/5),
exp(-6*pi*im*(0:4)/5),
exp(-8*pi*im*(0:4)/5)]
for s = -5:5
for i = 1:5
thisa = OffsetArray(a[i], (s,))
thisc = c[mod1(i+s+5,5)]
@test_approx_eq fft(thisa) thisc
@test_approx_eq ifft(fft(thisa)) circcopy!(a1, thisa)
@test_approx_eq rfft(thisa) thisc[1:3]
@test_approx_eq irfft(rfft(thisa), 5) a1
end
end

end # let

0 comments on commit 7f55ca5

Please sign in to comment.