Skip to content

Commit

Permalink
new DFT api
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj committed Jul 9, 2015
1 parent 33ff40f commit dd1d724
Show file tree
Hide file tree
Showing 15 changed files with 1,393 additions and 1,118 deletions.
3 changes: 3 additions & 0 deletions base/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ real(x::Real) = x
imag(x::Real) = zero(x)
reim(z) = (real(z), imag(z))

real{T<:Real}(::Type{T}) = T
real{T<:Real}(::Type{Complex{T}}) = T

isreal(x::Real) = true
isreal(z::Complex) = imag(z) == 0
isimag(z::Number) = real(z) == 0
Expand Down
6 changes: 3 additions & 3 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,23 +474,23 @@ export float32_isvalid, float64_isvalid
@deprecate ($)(x::Char, y::Char) Char(UInt32(x) $ UInt32(y))

# 11241

@deprecate is_valid_char(ch::Char) isvalid(ch)
@deprecate is_valid_ascii(str::ASCIIString) isvalid(str)
@deprecate is_valid_utf8(str::UTF8String) isvalid(str)
@deprecate is_valid_utf16(str::UTF16String) isvalid(str)
@deprecate is_valid_utf32(str::UTF32String) isvalid(str)

@deprecate is_valid_char(ch) isvalid(Char, ch)
@deprecate is_valid_ascii(str) isvalid(ASCIIString, str)
@deprecate is_valid_utf8(str) isvalid(UTF8String, str)
@deprecate is_valid_utf16(str) isvalid(UTF16String, str)
@deprecate is_valid_utf32(str) isvalid(UTF32String, str)

# 11379

@deprecate utf32(c::Integer...) UTF32String(UInt32[c...,0])

# 6193
@deprecate call(P::Base.DFT.Plan, A) P * A

# 10862

function chol(A::AbstractMatrix, uplo::Symbol)
Expand Down
195 changes: 195 additions & 0 deletions base/dft.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license

module DFT

# DFT plan where the inputs are an array of eltype T
abstract Plan{T}

import Base: show, summary, size, ndims, length, eltype,
*, A_mul_B!, inv, \, A_ldiv_B!

eltype{T}(::Plan{T}) = T
eltype{P<:Plan}(T::Type{P}) = T.parameters[1]

# size(p) should return the size of the input array for p
size(p::Plan, d) = size(p)[d]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

##############################################################################
export fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft

complexfloat{T<:FloatingPoint}(x::AbstractArray{Complex{T}}) = x

# return an Array, rather than similar(x), to avoid an extra copy for FFTW
# (which only works on StridedArray types).
complexfloat{T<:Complex}(x::AbstractArray{T}) = copy!(Array(typeof(float(one(T))), size(x)), x)
complexfloat{T<:FloatingPoint}(x::AbstractArray{T}) = copy!(Array(typeof(complex(one(T))), size(x)), x)
complexfloat{T<:Real}(x::AbstractArray{T}) = copy!(Array(typeof(complex(float(one(T)))), size(x)), x)

# implementations only need to provide plan_X(x, region)
# for X in (:fft, :bfft, ...):
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft)
pf = symbol(string("plan_", f))
@eval begin
$f(x::AbstractArray) = $pf(x) * x
$f(x::AbstractArray, region) = $pf(x, region) * x
$pf(x::AbstractArray; kws...) = $pf(x, 1:ndims(x); kws...)
end
end

# promote to a complex floating-point type (out-of-place only),
# so implementations only need Complex{Float} methods
for f in (:fft, :bfft, :ifft)
pf = symbol(string("plan_", f))
@eval begin
$f{T<:Real}(x::AbstractArray{T}, region=1:ndims(x)) = $f(complexfloat(x), region)
$pf{T<:Real}(x::AbstractArray{T}, region; kws...) = $pf(complexfloat(x), region; kws...)
$f{T<:Union(Integer,Rational)}(x::AbstractArray{Complex{T}}, region=1:ndims(x)) = $f(complexfloat(x), region)
$pf{T<:Union(Integer,Rational)}(x::AbstractArray{Complex{T}}, region; kws...) = $pf(complexfloat(x), region; kws...)
end
end
rfft{T<:Union(Integer,Rational)}(x::AbstractArray{T}, region=1:ndims(x)) = rfft(float(x), region)
plan_rfft{T<:Union(Integer,Rational)}(x::AbstractArray{T}, region; kws...) = plan_rfft(float(x), region; kws...)

# only require implementation to provide *(::Plan{T}, ::Array{T})
*{T}(p::Plan{T}, x::AbstractArray) = p * copy!(Array(T, size(x)), x)

# Implementations should also implement A_mul_B!(Y, plan, X) so as to support
# pre-allocated output arrays. We don't define * in terms of A_mul_B!
# generically here, however, because of subtleties for in-place and rfft plans.

##############################################################################
# To support inv, \, and A_ldiv_B!(y, p, x), we require Plan subtypes
# to have a pinv::Plan field, which caches the inverse plan, and which
# should be initially undefined. They should also implement
# plan_inv(p) to construct the inverse of a plan p.

# hack from @simonster (in #6193) to compute the return type of plan_inv
# without actually calling it or even constructing the empty arrays.
_pinv_type(p::Plan) = typeof([plan_inv(x) for x in typeof(p)[]])
pinv_type(p::Plan) = eltype(_pinv_type(p))

inv(p::Plan) =
isdefined(p, :pinv) ? p.pinv::pinv_type(p) : (p.pinv = plan_inv(p))
\(p::Plan, x::AbstractArray) = inv(p) * x
A_ldiv_B!(y::AbstractArray, p::Plan, x::AbstractArray) = A_mul_B!(y, inv(p), x)

##############################################################################
# implementations only need to provide the unnormalized backwards FFT,
# similar to FFTW, and we do the scaling generically to get the ifft:

type ScaledPlan{T,P,N} <: Plan{T}
p::P
scale::N # not T, to avoid unnecessary promotion to Complex
pinv::Plan
ScaledPlan(p, scale) = new(p, scale)
end
ScaledPlan{P<:Plan,N<:Number}(p::P, scale::N) = ScaledPlan{eltype(P),P,N}(p, scale)
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)

show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p)
summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))

*(p::ScaledPlan, x::AbstractArray) = scale!(p.p * x, p.scale)

*::Number, p::Plan) = ScaledPlan(p, α)
*(p::Plan, α::Number) = ScaledPlan(p, α)
*(I::UniformScaling, p::ScaledPlan) = ScaledPlan(p, I.λ)
*(p::ScaledPlan, I::UniformScaling) = ScaledPlan(p, I.λ)
*(I::UniformScaling, p::Plan) = ScaledPlan(p, I.λ)
*(p::Plan, I::UniformScaling) = ScaledPlan(p, I.λ)

# Normalization for ifft, given unscaled bfft, is 1/prod(dimensions)
normalization(T, sz, region) = one(T) / prod([sz...][[region...]])
normalization(X, region) = normalization(real(eltype(X)), size(X), region)

plan_ifft(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft(x, region; kws...), normalization(x, region))
plan_ifft!(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region))

plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale))

A_mul_B!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
scale!(p.scale, A_mul_B!(y, p.p, x))

##############################################################################
# Real-input DFTs are annoying because the output has a different size
# than the input if we want to gain the full factor-of-two(ish) savings
# For backward real-data transforms, we must specify the original length
# of the first dimension, since there is no reliable way to detect this
# from the data (we can't detect whether the dimension was originally even
# or odd).

for f in (:brfft, :irfft)
pf = symbol(string("plan_", f))
@eval begin
$f(x::AbstractArray, d::Integer) = $pf(x, d) * x
$f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x
$pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...)
end
end

for f in (:brfft, :irfft)
@eval begin
$f{T<:Real}(x::AbstractArray{T}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
$f{T<:Union(Integer,Rational)}(x::AbstractArray{Complex{T}}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
end
end

function rfft_output_size(x::AbstractArray, region)
d1 = first(region)
osize = [size(x)...]
osize[d1] = osize[d1]>>1 + 1
return osize
end

function brfft_output_size(x::AbstractArray, d::Integer, region)
d1 = first(region)
osize = [size(x)...]
@assert osize[d1] == d>>1 + 1
osize[d1] = d
return osize
end

plan_irfft{T}(x::AbstractArray{Complex{T}}, d::Integer, region; kws...) =
ScaledPlan(plan_brfft(x, d, region; kws...),
normalization(T, brfft_output_size(x, d, region), region))

##############################################################################

export fftshift, ifftshift

fftshift(x) = circshift(x, div([size(x)...],2))

function fftshift(x,dim)
s = zeros(Int,ndims(x))
s[dim] = div(size(x,dim),2)
circshift(x, s)
end

ifftshift(x) = circshift(x, div([size(x)...],-2))

function ifftshift(x,dim)
s = zeros(Int,ndims(x))
s[dim] = -div(size(x,dim),2)
circshift(x, s)
end

##############################################################################

# FFTW module (may move to an external package at some point):
if Base.USE_GPL_LIBS
include("fft/FFTW.jl")
importall .FFTW
export FFTW, dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!
end

##############################################################################

end
135 changes: 4 additions & 131 deletions base/dsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@

module DSP

importall Base.FFTW
import Base.FFTW.normalization
import Base.trailingsize

export FFTW, filt, filt!, deconv, conv, conv2, xcorr, fftshift, ifftshift,
dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!,
# the rest are defined imported from FFTW:
fft, bfft, ifft, rfft, brfft, irfft,
plan_fft, plan_bfft, plan_ifft, plan_rfft, plan_brfft, plan_irfft,
fft!, bfft!, ifft!, plan_fft!, plan_bfft!, plan_ifft!
export filt, filt!, deconv, conv, conv2, xcorr

_zerosi(b,a,T) = zeros(promote_type(eltype(b), eltype(a), T), max(length(a), length(b))-1)

Expand Down Expand Up @@ -117,10 +110,10 @@ function conv{T<:Base.LinAlg.BlasFloat}(u::StridedVector{T}, v::StridedVector{T}
vpad = [v; zeros(T, np2 - nv)]
if T <: Real
p = plan_rfft(upad)
y = irfft(p(upad).*p(vpad), np2)
y = irfft((p*upad).*(p*vpad), np2)
else
p = plan_fft!(upad)
y = ifft!(p(upad).*p(vpad))
y = ifft!((p*upad).*(p*vpad))
end
return y[1:n]
end
Expand Down Expand Up @@ -149,7 +142,7 @@ function conv2{T}(A::StridedMatrix{T}, B::StridedMatrix{T})
At[1:sa[1], 1:sa[2]] = A
Bt[1:sb[1], 1:sb[2]] = B
p = plan_fft(At)
C = ifft(p(At).*p(Bt))
C = ifft((p*At).*(p*Bt))
if T <: Real
return real(C)
end
Expand All @@ -168,124 +161,4 @@ function xcorr(u, v)
flipdim(conv(flipdim(u, 1), v), 1)
end

fftshift(x) = circshift(x, div([size(x)...],2))

function fftshift(x,dim)
s = zeros(Int,ndims(x))
s[dim] = div(size(x,dim),2)
circshift(x, s)
end

ifftshift(x) = circshift(x, div([size(x)...],-2))

function ifftshift(x,dim)
s = zeros(Int,ndims(x))
s[dim] = -div(size(x,dim),2)
circshift(x, s)
end

# Discrete cosine and sine transforms via FFTW's r2r transforms;
# we follow the Matlab convention and adopt a unitary normalization here.
# Unlike Matlab we compute the multidimensional transform by default,
# similar to the Julia fft functions.

fftwcopy{T<:fftwNumber}(X::StridedArray{T}) = copy(X)
fftwcopy{T<:Real}(X::StridedArray{T}) = float(X)
fftwcopy{T<:Complex}(X::StridedArray{T}) = map(Complex128,X)

for (f, fr2r, Y, Tx) in ((:dct, :r2r, :Y, :Number),
(:dct!, :r2r!, :X, :fftwNumber))
plan_f = symbol("plan_",f)
plan_fr2r = symbol("plan_",fr2r)
fi = symbol("i",f)
plan_fi = symbol("plan_",fi)
Ycopy = Y == :X ? 0 : :(Y = fftwcopy(X))
@eval begin
function $f{T<:$Tx}(X::StridedArray{T}, region)
$Y = $fr2r(X, REDFT10, region)
scale!($Y, sqrt(0.5^length(region) * normalization(X,region)))
sqrthalf = sqrt(0.5)
r = map(n -> 1:n, [size(X)...])
for d in region
r[d] = 1:1
$Y[r...] *= sqrthalf
r[d] = 1:size(X,d)
end
return $Y
end

function $plan_f{T<:$Tx}(X::StridedArray{T}, region,
flags::Unsigned, timelimit::Real)
p = $plan_fr2r(X, REDFT10, region, flags, timelimit)
sqrthalf = sqrt(0.5)
r = map(n -> 1:n, [size(X)...])
nrm = sqrt(0.5^length(region) * normalization(X,region))
return X::StridedArray{T} -> begin
$Y = p(X)
scale!($Y, nrm)
for d in region
r[d] = 1:1
$Y[r...] *= sqrthalf
r[d] = 1:size(X,d)
end
return $Y
end
end

function $fi{T<:$Tx}(X::StridedArray{T}, region)
$Ycopy
scale!($Y, sqrt(0.5^length(region) * normalization(X, region)))
sqrt2 = sqrt(2)
r = map(n -> 1:n, [size(X)...])
for d in region
r[d] = 1:1
$Y[r...] *= sqrt2
r[d] = 1:size(X,d)
end
return r2r!($Y, REDFT01, region)
end

function $plan_fi{T<:$Tx}(X::StridedArray{T}, region,
flags::Unsigned, timelimit::Real)
p = $plan_fr2r(X, REDFT01, region, flags, timelimit)
sqrt2 = sqrt(2)
r = map(n -> 1:n, [size(X)...])
nrm = sqrt(0.5^length(region) * normalization(X,region))
return X::StridedArray{T} -> begin
$Ycopy
scale!($Y, nrm)
for d in region
r[d] = 1:1
$Y[r...] *= sqrt2
r[d] = 1:size(X,d)
end
return p($Y)
end
end

end
for (g,plan_g) in ((f,plan_f), (fi, plan_fi))
@eval begin
$g{T<:$Tx}(X::StridedArray{T}) = $g(X, 1:ndims(X))

$plan_g(X, region, flags::Unsigned) =
$plan_g(X, region, flags, NO_TIMELIMIT)
$plan_g(X, region) =
$plan_g(X, region, ESTIMATE, NO_TIMELIMIT)
$plan_g{T<:$Tx}(X::StridedArray{T}) =
$plan_g(X, 1:ndims(X), ESTIMATE, NO_TIMELIMIT)
end
end
end

# DCT of scalar is just the identity:
dct(x::Number, dims) = length(dims) == 0 || dims[1] == 1 ? x : throw(BoundsError())
idct(x::Number, dims) = dct(x, dims)
dct(x::Number) = x
idct(x::Number) = x
plan_dct(x::Number, dims, flags, tlim) = length(dims) == 0 || dims[1] == 1 ? y::Number -> y : throw(BoundsError())
plan_idct(x::Number, dims, flags, tlim) = plan_dct(x, dims)
plan_dct(x::Number) = y::Number -> y
plan_idct(x::Number) = y::Number -> y

end # module
Loading

0 comments on commit dd1d724

Please sign in to comment.