Skip to content

Commit

Permalink
Merge pull request FluxML#421 from mloubout/master
Browse files Browse the repository at this point in the history
support complex input for upsample
  • Loading branch information
ToucheSir authored and mloubout committed Jun 14, 2022
2 parents b02b8c2 + 22ef274 commit 745db36
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
40 changes: 22 additions & 18 deletions src/upsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,14 @@ upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) =
function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, channels, batches = size(input)
RT = real(T)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, _, _ = size(output)
output_slice_size = out_w

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))
#real(T)() and // so that we can handle rationals (super slow)
width_scale = RT((in_w - 1) // (out_w - 1))

@inline idx(c, w) = c * in_w + w + 1

Expand All @@ -175,14 +176,15 @@ end
function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, in_h, channels, batches = size(input)
RT = real(T)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, _, _ = size(output)
output_slice_size = out_h * out_w

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))
#real(T)() and // so that we can handle rationals (super slow)
width_scale = RT((in_w - 1) // (out_w - 1))
height_scale = RT((in_h - 1) // (out_h - 1))

@inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1

Expand All @@ -208,15 +210,16 @@ end
function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
size(input)[4:5] == size(output)[4:5] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, in_h, in_d, channels, batches = size(input)
RT = real(T)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, out_d, _, _ = size(output)
output_slice_size = out_h * out_w * out_d

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))
depth_scale = T((in_d - 1) // (out_d - 1))
#real(T)() and // so that we can handle rationals (super slow)
width_scale = RT((in_w - 1) // (out_w - 1))
height_scale = RT((in_h - 1) // (out_h - 1))
depth_scale = RT((in_d - 1) // (out_d - 1))

@inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1

Expand Down Expand Up @@ -268,13 +271,13 @@ end
function ∇upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, channels, batches = size(dx)

RT = real(T)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, _, _ = size(Δ)
output_slice_size = out_w

width_scale = T((in_w - 1) // (out_w - 1))
width_scale = RT((in_w - 1) // (out_w - 1))

@inline idx(c, w) = c * in_w + w + 1

Expand All @@ -294,14 +297,14 @@ end
function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, in_h, channels, batches = size(dx)

RT = real(T)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, _, _ = size(Δ)
output_slice_size = out_h * out_w

width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))
width_scale = RT((in_w - 1) // (out_w - 1))
height_scale = RT((in_h - 1) // (out_h - 1))

@inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1

Expand All @@ -326,15 +329,16 @@ end
function ∇upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
size(dx)[4:5] == size(Δ)[4:5] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
in_w, in_h, in_d, channels, batches = size(dx)
RT = real(T)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, out_d, _, _ = size(Δ)
output_slice_size = out_h * out_w * out_d

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))
depth_scale = T((in_d - 1) // (out_d - 1))
#real(T)() and // so that we can handle rationals (super slow)
width_scale = RT((in_w - 1) // (out_w - 1))
height_scale = RT((in_h - 1) // (out_h - 1))
depth_scale = RT((in_d - 1) // (out_d - 1))

@inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1

Expand Down
20 changes: 20 additions & 0 deletions test/upsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,23 @@ end
gradtest(x -> pixel_shuffle(x, r), x)
end
end

@testset "Complex-valued upsample" begin
for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear])
for (k, interp) in zip((2, ntuple(_ -> 2, d)), [method, upsample_nearest])
x = randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1)

upsize = (8, 16, 24)[1:d]
xup = interp(x, k)
@test size(xup)[1:d] == upsize
@test real(xup) == interp(real(x), k)
@test imag(xup) == interp(imag(x), k)

upsize = (8,24,48)[1:d]
xup = interp(x; size=upsize)
@test size(xup)[1:d] == upsize
@test real(xup) == interp(real(x), size=upsize)
@test imag(xup) == interp(imag(x), size=upsize)
end
end
end

0 comments on commit 745db36

Please sign in to comment.