diff --git a/src/upsample.jl b/src/upsample.jl index 69063802f..02c0fca48 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -149,13 +149,14 @@ upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) = function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") in_w, channels, batches = size(input) + RT = real(T) # treat batch and channel dimension as one for better parallelization granularity channels *= batches out_w, _, _ = size(output) output_slice_size = out_w - # T() and // so that we can handle rationals (super slow) - width_scale = T((in_w - 1) // (out_w - 1)) + #real(T)() and // so that we can handle rationals (super slow) + width_scale = RT((in_w - 1) // (out_w - 1)) @inline idx(c, w) = c * in_w + w + 1 @@ -175,14 +176,15 @@ end function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") in_w, in_h, channels, batches = size(input) + RT = real(T) # treat batch and channel dimension as one for better parallelization granularity channels *= batches out_w, out_h, _, _ = size(output) output_slice_size = out_h * out_w - # T() and // so that we can handle rationals (super slow) - width_scale = T((in_w - 1) // (out_w - 1)) - height_scale = T((in_h - 1) // (out_h - 1)) + #real(T)() and // so that we can handle rationals (super slow) + width_scale = RT((in_w - 1) // (out_w - 1)) + height_scale = RT((in_h - 1) // (out_h - 1)) @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1 @@ -208,15 +210,16 @@ end function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T size(input)[4:5] == size(output)[4:5] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") in_w, in_h, in_d, channels, batches = size(input) + RT = real(T) # treat batch and channel dimension as one for better parallelization granularity channels *= batches out_w, out_h, out_d, _, _ = size(output) output_slice_size = out_h * out_w * out_d - # T() and // so that we can handle rationals (super slow) - width_scale = T((in_w - 1) // (out_w - 1)) - height_scale = T((in_h - 1) // (out_h - 1)) - depth_scale = T((in_d - 1) // (out_d - 1)) + #real(T)() and // so that we can handle rationals (super slow) + width_scale = RT((in_w - 1) // (out_w - 1)) + height_scale = RT((in_h - 1) // (out_h - 1)) + depth_scale = RT((in_d - 1) // (out_d - 1)) @inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1 @@ -268,13 +271,13 @@ end function ∇upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") in_w, channels, batches = size(dx) - + RT = real(T) # treat batch and channel dimension as one for better parallelization granularity channels *= batches out_w, _, _ = size(Δ) output_slice_size = out_w - width_scale = T((in_w - 1) // (out_w - 1)) + width_scale = RT((in_w - 1) // (out_w - 1)) @inline idx(c, w) = c * in_w + w + 1 @@ -294,14 +297,14 @@ end function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") in_w, in_h, channels, batches = size(dx) - + RT = real(T) # treat batch and channel dimension as one for better parallelization granularity channels *= batches out_w, out_h, _, _ = size(Δ) output_slice_size = out_h * out_w - width_scale = T((in_w - 1) // (out_w - 1)) - height_scale = T((in_h - 1) // (out_h - 1)) + width_scale = RT((in_w - 1) // (out_w - 1)) + height_scale = RT((in_h - 1) // (out_h - 1)) @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1 @@ -326,15 +329,16 @@ end function ∇upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T size(dx)[4:5] == size(Δ)[4:5] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))") in_w, in_h, in_d, channels, batches = size(dx) + RT = real(T) # treat batch and channel dimension as one for better parallelization granularity channels *= batches out_w, out_h, out_d, _, _ = size(Δ) output_slice_size = out_h * out_w * out_d - # T() and // so that we can handle rationals (super slow) - width_scale = T((in_w - 1) // (out_w - 1)) - height_scale = T((in_h - 1) // (out_h - 1)) - depth_scale = T((in_d - 1) // (out_d - 1)) + #real(T)() and // so that we can handle rationals (super slow) + width_scale = RT((in_w - 1) // (out_w - 1)) + height_scale = RT((in_h - 1) // (out_h - 1)) + depth_scale = RT((in_d - 1) // (out_d - 1)) @inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1 diff --git a/test/upsample.jl b/test/upsample.jl index 24f71d0b1..ffd4852cb 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -172,3 +172,23 @@ end gradtest(x -> pixel_shuffle(x, r), x) end end + +@testset "Complex-valued upsample" begin + for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear]) + for (k, interp) in zip((2, ntuple(_ -> 2, d)), [method, upsample_nearest]) + x = randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1) + + upsize = (8, 16, 24)[1:d] + xup = interp(x, k) + @test size(xup)[1:d] == upsize + @test real(xup) == interp(real(x), k) + @test imag(xup) == interp(imag(x), k) + + upsize = (8,24,48)[1:d] + xup = interp(x; size=upsize) + @test size(xup)[1:d] == upsize + @test real(xup) == interp(real(x), size=upsize) + @test imag(xup) == interp(imag(x), size=upsize) + end + end +end