From d0b2e115e1a3323a7099cd3972065d58c6d42c74 Mon Sep 17 00:00:00 2001
From: Anton Smirnov
Date: Thu, 6 Apr 2023 00:30:05 +0300
Subject: [PATCH 01/11] Use KernelAbstractions.jl for upsample kernels

- Add an `align_corners` option.
- Add a unified test suite which accepts a backend as an argument and runs the
  tests for it.
---
 Project.toml       |   2 +
 src/NNlib.jl       |   2 +
 src/upsample.jl    | 571 ++++++++++++++++++++-------------------
 test/Project.toml  |   2 +
 test/runtests.jl   |  69 +++++-
 test/test_utils.jl |  41 +++-
 test/upsample.jl   | 386 +++++++++++++++---------------
 7 files changed, 557 insertions(+), 516 deletions(-)

diff --git a/Project.toml b/Project.toml
index 1892a1b43..e8c55345e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "0.8.19"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -21,6 +22,7 @@ NNlibAMDGPUExt = "AMDGPU"
AMDGPU = "0.4.8"
Adapt = "2, 3.2"
ChainRulesCore = "1.13"
+KernelAbstractions = "0.9"
Requires = "0.5, 1.0"
julia = "1.6"

diff --git a/src/NNlib.jl b/src/NNlib.jl
index 38dab7f62..61e6bf52b 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -12,6 +12,8 @@ using Statistics: mean
using LinearAlgebra
using LinearAlgebra: BlasFloat, Transpose, Adjoint, AdjOrTransAbsMat
using LinearAlgebra.BLAS: BlasInt, @blasfunc
+using KernelAbstractions
+using KernelAbstractions: @atomic

const libblas = Base.libblas_name

diff --git a/src/upsample.jl b/src/upsample.jl
index 027ed4ebb..b4ac2d50c 100644
--- a/src/upsample.jl
+++ b/src/upsample.jl
@@ -1,3 +1,76 @@
"""
    pixel_shuffle(x, r::Integer)

Pixel shuffling operation, upscaling by a factor `r`.

For 4-dimensional arrays representing `N` images, the operation converts input `size(x) == (W, H, r^2*C, N)`
to output of size `(r*W, r*H, C, N)`. For `D`-dimensional data, it expects `ndims(x) == D+2`
with channel and batch dimensions, and divides the number of channels by `r^D`.

Used in super-resolution networks to upsample towards high-resolution features.
Reference: Shi et
al., "Real-Time Single Image and Video Super-Resolution ...", CVPR 2016, https://arxiv.org/abs/1609.05158 + +# Examples + +```jldoctest +julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1] +2×3×4×1 Array{Float64, 4}: +[:, :, 1, 1] = + 11.1 12.1 13.1 + 21.1 22.1 23.1 + +[:, :, 2, 1] = + 11.2 12.2 13.2 + 21.2 22.2 23.2 + +[:, :, 3, 1] = + 11.3 12.3 13.3 + 21.3 22.3 23.3 + +[:, :, 4, 1] = + 11.4 12.4 13.4 + 21.4 22.4 23.4 + +julia> pixel_shuffle(x, 2) # 4 channels used up as 2x upscaling of image dimensions +4×6×1×1 Array{Float64, 4}: +[:, :, 1, 1] = + 11.1 11.3 12.1 12.3 13.1 13.3 + 11.2 11.4 12.2 12.4 13.2 13.4 + 21.1 21.3 22.1 22.3 23.1 23.3 + 21.2 21.4 22.2 22.4 23.2 23.4 + +julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1] +3×6×1 Array{Float64, 3}: +[:, :, 1] = + 1.1 1.2 1.3 1.4 1.5 1.6 + 2.1 2.2 2.3 2.4 2.5 2.6 + 3.1 3.2 3.3 3.4 3.5 3.6 + +julia> pixel_shuffle(y, 2) # 1D image, with 6 channels reduced to 3 +6×3×1 Array{Float64, 3}: +[:, :, 1] = + 1.1 1.3 1.5 + 1.2 1.4 1.6 + 2.1 2.3 2.5 + 2.2 2.4 2.6 + 3.1 3.3 3.5 + 3.2 3.4 3.6 +``` +""" +function pixel_shuffle(x::AbstractArray, r::Integer) + ndims(x) > 2 || throw(ArgumentError("expected x with at least 3 dimensions")) + d = ndims(x) - 2 + sizein = size(x)[1:d] + cin, n = size(x, d+1), size(x, d+2) + cin % r^d == 0 || throw(ArgumentError("expected channel dimension to be divisible by r^d = $( + r^d), where d=$d is the number of spatial dimensions. Given r=$r, input size(x) = $(size(x))")) + cout = cin ÷ r^d + x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n) + perm = hcat(d+1:2d, 1:d) |> transpose |> vec # = [d+1, 1, d+2, 2, ..., 2d, d] + x = permutedims(x, (perm..., 2d+1, 2d+2)) + return reshape(x, map(s -> s*r, sizein)..., cout, n) +end + """ upsample_nearest(x, scale::NTuple{S,Int}) upsample_nearest(x; size::NTuple{S,Int}) @@ -75,28 +148,9 @@ function rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple) return Ω, upsample_nearest_pullback end -# utility function -@inline function compute_source_index_and_lambda( - ratio, # 0 < ratio < 1 - output_index, - input_size, - output_size -) - real_input_index = ratio*output_index - input_index0 = floor(Int, real_input_index) # typecast to int was here in C++ - offset = (input_index0 < input_size - 1) ? 1 : 0 - input_index1 = input_index0 + offset - lambda1 = real_input_index - input_index0 - lambda0 = 1 - lambda1 - return input_index0, input_index1, lambda0, lambda1 -end - -########### -# linear -########### """ - upsample_linear(x::AbstractArray{T,3}, scale::Real) - upsample_linear(x::AbstractArray{T,3}; size::Integer) + upsample_linear(x::AbstractArray{T,3}, scale::Real; align_corners::Bool = true) + upsample_linear(x::AbstractArray{T,3}; size::Integer, align_corners::Bool = true) Upsamples the first dimension of the array `x` by the upsample provided `scale`, using linear interpolation. As an alternative to using `scale`, the resulting array `size` @@ -105,17 +159,18 @@ can be directly specified with a keyword argument. The size of the output is equal to `(scale*S1, S2, S3)`, where `S1, S2, S3 = size(x)`. 
""" # the user facing function -function upsample_linear(x::AbstractArray{<:Any,N}, scale::NTuple{M,Real}) where {N,M} +function upsample_linear(x::AbstractArray{<:Any,N}, scale::NTuple{M,Real}; align_corners::Bool = true) where {N,M} M == N-2 || error("The scale argument should be an NTuple with length $(N-2), but it has length $M.") outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), N-2) - return upsample_linear(x; size=outsize) + return upsample_linear(x; size=outsize, align_corners) end # convenience for single-number scale -upsample_linear(x::AbstractArray{<:Any,N}, scale::Real) where N = upsample_linear(x, ntuple(_ -> scale, N-2)) +upsample_linear(x::AbstractArray{<:Any,N}, scale::Real; align_corners::Bool = true) where N = + upsample_linear(x, ntuple(_ -> scale, N-2); align_corners) # this actually calls the upsamling kernel -function upsample_linear(x::AbstractArray{T,N}; size::Union{Integer, NTuple{<:Any,Integer}}) where {T,N} +function upsample_linear(x::AbstractArray{T,N}; size::Union{Integer, NTuple{<:Any,Integer}}, align_corners::Bool = true) where {T,N} length(size) == N-2 || error("The scale argument should be an NTuple with length $(N-2), but it has length $(length(size)).") if Base.size(x)[1:N-2] == size @@ -123,133 +178,18 @@ function upsample_linear(x::AbstractArray{T,N}; size::Union{Integer, NTuple{<:An end y = similar(x, T, size..., Base.size(x)[end-1:end]...) - return upsample_linear_kernel!(y, x) + return upsample_linear_kernel!(y, x; align_corners) end # Convenience definition for integers. The algo internally works with floats and then rounds. -function upsample_linear(x::AbstractArray{T,<:Any}; size) where T<:Integer +function upsample_linear(x::AbstractArray{T,<:Any}; size, align_corners::Bool = true) where T<:Integer y = float.(x) - res = upsample_linear(y; size=size) + res = upsample_linear(y; size=size, align_corners) return round.(T, res) end -# compatibility layer for old versions of NNlibCUDA -# old versions overload upsample_linear_wcn, new versions overload upsample_linear_kernel -# can be removed from NNlib 0.9, i.e. revert https://github.com/FluxML/NNlib.jl/pull/414 -# IF https://github.com/FluxML/NNlibCUDA.jl/pull/49 has been merged -upsample_linear_kernel!(y::AbstractArray{<:Any,3}, x::AbstractArray{<:Any,3}) = upsample_linear_wcn!(y,x) -upsample_linear_kernel!(y::AbstractArray{<:Any,4}, x::AbstractArray{<:Any,4}) = upsample_bilinear_whcn!(y,x) -upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) = upsample_trilinear_whdcn!(y,x) -∇upsample_linear_kernel!(y::AbstractArray{<:Any,3}, x::AbstractArray{<:Any,3}) = ∇upsample_linear_wcn!(y,x) -∇upsample_linear_kernel!(y::AbstractArray{<:Any,4}, x::AbstractArray{<:Any,4}) = ∇upsample_bilinear_whcn!(y,x) -∇upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) = ∇upsample_trilinear_whdcn!(y,x) - -# linearly upsamples first dim of 3D array -function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T - size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") - in_w, channels, batches = size(input) - RT = real(T) - # treat batch and channel dimension as one for better parallelization granularity - channels *= batches - out_w, _, _ = size(output) - output_slice_size = out_w - - #real(T)() and // so that we can handle rationals (super slow) - width_scale = RT((in_w - 1) // (out_w - 1)) - - @inline idx(c, w) = c * in_w + w + 1 - - Threads.@threads for c in 0:channels-1 - @inbounds for ow in 0:out_w-1 - iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) - output_offset = c * output_slice_size + ow + 1 - output[output_offset] = (w0lambda * input[idx(c, iw0)] + # w0 * i00 - w1lambda * input[idx(c, iw1)]) # w1 * i01 - end - end - return output -end - -# bilinear -# linearly upsamples first two dims of 4D array -function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T - size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") - in_w, in_h, channels, batches = size(input) - RT = real(T) - # treat batch and channel dimension as one for better parallelization granularity - channels *= batches - out_w, out_h, _, _ = size(output) - output_slice_size = out_h * out_w - - #real(T)() and // so that we can handle rationals (super slow) - width_scale = RT((in_w - 1) // (out_w - 1)) - height_scale = RT((in_h - 1) // (out_h - 1)) - - @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1 - - Threads.@threads for c in 0:channels-1 - @inbounds for oh in 0:out_h-1 - ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) - for ow in 0:out_w-1 - iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) - output_offset = c * output_slice_size + oh * out_w + ow + 1 - output[output_offset] = - (h0lambda * w0lambda * input[idx(c, ih0, iw0)] + # h0 * w0 * i00 - h0lambda * w1lambda * input[idx(c, ih0, iw1)] + # h0 * w1 * i01 - h1lambda * w0lambda * input[idx(c, ih1, iw0)] + # h1 * w0 * i10 - h1lambda * w1lambda * input[idx(c, ih1, iw1)]) # h1 * w1 * i11 - end - end - end - return output -end - -# trilinear -# linearly upsamples first three dims of 5D array -function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T - size(input)[4:5] == size(output)[4:5] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") - in_w, in_h, in_d, channels, batches = size(input) - RT = real(T) - # treat batch and channel dimension as one for better parallelization granularity - channels *= batches - out_w, out_h, out_d, _, _ = size(output) - output_slice_size = out_h * out_w * out_d - - #real(T)() and // so that we can handle rationals (super slow) - width_scale = RT((in_w - 1) // (out_w - 1)) - height_scale = RT((in_h - 1) // (out_h - 1)) - depth_scale = RT((in_d - 1) // (out_d - 1)) - - @inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1 - - Threads.@threads for c in 0:channels-1 - @inbounds for od in 0:out_d-1 - id0, id1, d0lambda, d1lambda = compute_source_index_and_lambda(depth_scale, od, in_d, out_d) - for oh in 0:out_h-1 - ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) - for ow in 0:out_w-1 - iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) - output_offset = c * output_slice_size + od * out_w * out_h + oh * out_w + ow + 1 - output[output_offset] = - d0lambda * h0lambda * w0lambda * input[idx(c, id0, ih0, iw0)] + # d0 * h0 * w0 * i000 - d0lambda * h0lambda * w1lambda * input[idx(c, id0, ih0, iw1)] + # d0 * h0 * w1 * i001 - d0lambda * h1lambda * w0lambda * input[idx(c, id0, ih1, iw0)] + # d0 * h1 * w0 * i010 - d0lambda * h1lambda * w1lambda * input[idx(c, id0, ih1, iw1)] + # d0 * h1 * w1 * i011 - d1lambda * h0lambda * w0lambda * input[idx(c, id1, ih0, iw0)] + # d1 * h0 * w0 * i100 - d1lambda * h0lambda * w1lambda * input[idx(c, id1, ih0, iw1)] + # d1 * h0 * w1 * i101 - d1lambda * h1lambda * w0lambda * input[idx(c, id1, ih1, iw0)] + # d1 * h1 * w0 * i110 - d1lambda * h1lambda * w1lambda * input[idx(c, id1, ih1, iw1)] # d1 * h1 * w1 * i111 - end - end - end - end - return output -end - - - """ - ∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer) where T + ∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer, align_corners::Bool = true) where T # Arguments - `Δ`: Incoming gradient array, backpropagated from downstream layers @@ -258,127 +198,26 @@ end # Outputs - `dx`: Downsampled version of `Δ` """ -function ∇upsample_linear(Δ::AbstractArray{T,N}; size::NTuple{<:Any,Integer}) where {T,N} +function ∇upsample_linear(Δ::AbstractArray{T,N}; size::NTuple{<:Any,Integer}, align_corners::Bool = true) where {T,N} if Base.size(Δ)[1:N-2] == size return Δ end dx = zero(similar(Δ, T, size..., Base.size(Δ)[end-1:end]...)) - return ∇upsample_linear_kernel!(dx, Δ) + return ∇upsample_linear_kernel!(dx, Δ; align_corners) end -# linear -function ∇upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T - size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") - in_w, channels, batches = size(dx) - RT = real(T) - # treat batch and channel dimension as one for better parallelization granularity - channels *= batches - out_w, _, _ = size(Δ) - output_slice_size = out_w - - width_scale = RT((in_w - 1) // (out_w - 1)) - - @inline idx(c, w) = c * in_w + w + 1 - - Threads.@threads for c in 0:channels-1 - @inbounds for ow in 0:out_w-1 - iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) - output_offset = c * output_slice_size + ow + 1 - Δ_value = Δ[output_offset] - dx[idx(c, iw0)] += w0lambda * Δ_value # i00 - dx[idx(c, iw1)] += w1lambda * Δ_value # i01 - end - end - return dx -end -# bilinear -function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T - size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") - in_w, in_h, channels, batches = size(dx) - RT = real(T) - # treat batch and channel dimension as one for better parallelization granularity - channels *= batches - out_w, out_h, _, _ = size(Δ) - output_slice_size = out_h * out_w - - width_scale = RT((in_w - 1) // (out_w - 1)) - height_scale = RT((in_h - 1) // (out_h - 1)) - - @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1 - - Threads.@threads for c in 0:channels-1 - @inbounds for oh in 0:out_h-1 - ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) - for ow in 0:out_w-1 - iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) - output_offset = c * output_slice_size + oh * out_w + ow + 1 - Δ_value = Δ[output_offset] - dx[idx(c, ih0, iw0)] += h0lambda * w0lambda * Δ_value # i00 - dx[idx(c, ih0, iw1)] += h0lambda * w1lambda * Δ_value # i01 - dx[idx(c, ih1, iw0)] += h1lambda * w0lambda * Δ_value # i10 - dx[idx(c, ih1, iw1)] += h1lambda * w1lambda * Δ_value # i11 - end - end - end - return dx -end - -# trilinear -function ∇upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T - size(dx)[4:5] == size(Δ)[4:5] || error("Number of input and output channels and batches must match. 
Got dx $(size(dx)) and Δ $(size(Δ))") - in_w, in_h, in_d, channels, batches = size(dx) - RT = real(T) - # treat batch and channel dimension as one for better parallelization granularity - channels *= batches - out_w, out_h, out_d, _, _ = size(Δ) - output_slice_size = out_h * out_w * out_d - - #real(T)() and // so that we can handle rationals (super slow) - width_scale = RT((in_w - 1) // (out_w - 1)) - height_scale = RT((in_h - 1) // (out_h - 1)) - depth_scale = RT((in_d - 1) // (out_d - 1)) - - @inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1 - - Threads.@threads for c in 0:channels-1 - @inbounds for od in 0:out_d-1 - id0, id1, d0lambda, d1lambda = compute_source_index_and_lambda(depth_scale, od, in_d, out_d) - for oh in 0:out_h-1 - ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) - for ow in 0:out_w-1 - iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) - output_offset = c * output_slice_size + od * out_w * out_h + oh * out_w + ow + 1 - Δ_value = Δ[output_offset] - dx[idx(c, id0, ih0, iw0)] += d0lambda * h0lambda * w0lambda * Δ_value # /* i000 */ - dx[idx(c, id0, ih0, iw1)] += d0lambda * h0lambda * w1lambda * Δ_value # /* i001 */ - dx[idx(c, id0, ih1, iw0)] += d0lambda * h1lambda * w0lambda * Δ_value # /* i010 */ - dx[idx(c, id0, ih1, iw1)] += d0lambda * h1lambda * w1lambda * Δ_value # /* i011 */ - dx[idx(c, id1, ih0, iw0)] += d1lambda * h0lambda * w0lambda * Δ_value # /* i100 */ - dx[idx(c, id1, ih0, iw1)] += d1lambda * h0lambda * w1lambda * Δ_value # /* i101 */ - dx[idx(c, id1, ih1, iw0)] += d1lambda * h1lambda * w0lambda * Δ_value # /* i110 */ - dx[idx(c, id1, ih1, iw1)] += d1lambda * h1lambda * w1lambda * Δ_value # /* i111 */ - end - end - end - end - return dx -end - -function rrule(::typeof(upsample_linear), x::AbstractArray{<:Any,N}; size) where N - Ω = upsample_linear(x; size=size) +function rrule(::typeof(upsample_linear), x::AbstractArray{<:Any,N}; size, align_corners::Bool = true) where N + Ω = upsample_linear(x; size, align_corners) function upsample_linear_pullback(Δ) - (NoTangent(), ∇upsample_linear(unthunk(Δ); size=Base.size(x)[1:N-2])) + (NoTangent(), ∇upsample_linear(unthunk(Δ); size=Base.size(x)[1:N-2], align_corners)) end return Ω, upsample_linear_pullback end -########### -# bilinear -########### """ - upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}) - upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) + upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}; align_corners::Bool = true) + upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}, align_corners::Bool = true) Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`, using bilinear interpolation. 
As an alternative to using `scale`, the resulting image `size` @@ -417,12 +256,12 @@ julia> upsample_bilinear(x, (2.5, 3.5)) # non-integer scaling factors are allow 4.0 4.22222 4.44444 4.66667 4.88889 5.33333 5.55556 5.77778 6.0 ``` """ -upsample_bilinear(x, scale) = upsample_linear(x, scale) -upsample_bilinear(x; size) = upsample_linear(x; size) +upsample_bilinear(x, scale; align_corners::Bool = true) = upsample_linear(x, scale; align_corners) +upsample_bilinear(x; size, align_corners::Bool = true) = upsample_linear(x; size, align_corners) """ - ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T + ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}, align_corners::Bool = true) where T # Arguments - `Δ`: Incoming gradient array, backpropagated from downstream layers @@ -431,11 +270,11 @@ upsample_bilinear(x; size) = upsample_linear(x; size) # Outputs - `dx`: Downsampled version of `Δ` """ -∇upsample_bilinear(Δ; size) = ∇upsample_linear(Δ; size) +∇upsample_bilinear(Δ; size, align_corners::Bool = true) = ∇upsample_linear(Δ; size, align_corners) """ - upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real}) - upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) + upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real}; align_corners::Bool = true) + upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}, align_corners::Bool = true) Upsamples the first 3 dimensions of the array `x` by the upsample factors stored in `scale`, using trilinear interpolation. As an alternative to using `scale`, the resulting image `size` @@ -452,11 +291,11 @@ upsample_trilinear(x; size=(4, 9, 11)) # specify ouput size instead upsample_trilinear(x, (2.5, 3.5, pi)) # non-integer scaling factors are allowed ``` """ -upsample_trilinear(x, scale) = upsample_linear(x, scale) -upsample_trilinear(x; size) = upsample_linear(x; size) +upsample_trilinear(x, scale; align_corners::Bool = true) = upsample_linear(x, scale; align_corners) +upsample_trilinear(x; size, align_corners::Bool = true) = upsample_linear(x; size, align_corners) """ - ∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}) where T + ∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}, align_corners::Bool = true) where T # Arguments - `Δ`: Incoming gradient array, backpropagated from downstream layers @@ -465,79 +304,167 @@ upsample_trilinear(x; size) = upsample_linear(x; size) # Outputs - `dx`: Downsampled version of `Δ` """ -∇upsample_trilinear(Δ; size)= ∇upsample_linear(Δ; size) +∇upsample_trilinear(Δ; size, align_corners::Bool = true) = ∇upsample_linear(Δ; size, align_corners) + +function upsample_linear_kernel!( + y::AbstractArray{T, N}, x::AbstractArray{T, N}; align_corners::Bool = true, +) where {T, N} + ndrange = size(y)[1:N - 2] + ratios = align_corners ? + ntuple(i -> real(T)((size(x, i) - 1) / (size(y, i) - 1)), N - 2) : + ntuple(i -> real(T)(size(x, i) / size(y, i)), N - 2) + + backend = KernelAbstractions.get_backend(x) + _upsample_linear_kernel!(backend)(y, x, ratios..., Val(align_corners); ndrange) + return y +end +function ∇upsample_linear_kernel!( + dx::AbstractArray{T, N}, Δ::AbstractArray{T, N}; align_corners::Bool = true, +) where {T, N} + ndrange = size(Δ)[1:N - 2] + ratios = align_corners ? 
+ ntuple(i -> real(T)((size(dx, i) - 1) / (size(Δ, i) - 1)), N - 2) : + ntuple(i -> real(T)(size(dx, i) / size(Δ, i)), N - 2) + backend = KernelAbstractions.get_backend(dx) + _∇upsample_linear_kernel!(backend)(dx, Δ, ratios..., Val(align_corners); ndrange) + return dx +end -""" - pixel_shuffle(x, r::Integer) +# Linear. -Pixel shuffling operation, upscaling by a factor `r`. +@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, align::Val{A}) where { + T <: AbstractArray{<: Any, 3}, A, +} + @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) -For 4-arrays representing `N` images, the operation converts input `size(x) == (W, H, r^2*C, N)` -to output of size `(r*W, r*H, C, N)`. For `D`-dimensional data, it expects `ndims(x) == D+2` -with channel and batch dimensions, and divides the number of channels by `r^D`. + i::UInt32 = @index(Global) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda( rwidth, i - 0x1, align, in_width) + @inbounds for n in 1:batch, c in 1:channels + y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n] + end +end -Used in super-resolution networks to upsample towards high resolution features. -Reference: Shi et. al., "Real-Time Single Image and Video Super-Resolution ...", CVPR 2016, https://arxiv.org/abs/1609.05158 +@kernel function _∇upsample_linear_kernel!(dx::T, Δ::T, rwidth, align::Val{A}) where { + T <: AbstractArray{<: Any, 3}, A, +} + @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) + @uniform out_width::UInt32 = size(dx, 1) + + i::UInt32 = @index(Global) + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + @inbounds for n in 1:batch, c in 1:channels + val = Δ[i, c, n] + @atomic dx[ow0, c, n] += w0lambda * val + @atomic dx[ow1, c, n] += w1lambda * val + end +end -# Examples +# Bilinear. 
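# As in the 1D case, each workitem handles one output pixel (i, j) and loops over
# the channel and batch dimensions, so the forward kernel writes every output
# element exactly once. The backward kernel scatters each gradient value into up
# to four source pixels, which can be shared between workitems, hence @atomic.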
-```jldoctest -julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1] -2×3×4×1 Array{Float64, 4}: -[:, :, 1, 1] = - 11.1 12.1 13.1 - 21.1 22.1 23.1 +@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, align::Val{A}) where { + T <: AbstractArray{<: Any, 4}, A, +} + @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) -[:, :, 2, 1] = - 11.2 12.2 13.2 - 21.2 22.2 23.2 + i::UInt32, j::UInt32 = @index(Global, NTuple) -[:, :, 3, 1] = - 11.3 12.3 13.3 - 21.3 22.3 23.3 + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) -[:, :, 4, 1] = - 11.4 12.4 13.4 - 21.4 22.4 23.4 + @inbounds for n in 1:batch, c in 1:channels + y[i, j, c, n] = + h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + + h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) + end +end -julia> pixel_shuffle(x, 2) # 4 channels used up as 2x upscaling of image dimensions -4×6×1×1 Array{Float64, 4}: -[:, :, 1, 1] = - 11.1 11.3 12.1 12.3 13.1 13.3 - 11.2 11.4 12.2 12.4 13.2 13.4 - 21.1 21.3 22.1 22.3 23.1 23.3 - 21.2 21.4 22.2 22.4 23.2 23.4 +@kernel function _∇upsample_linear_kernel!(dx::T, Δ::T, rwidth, rheight, align::Val{A}) where { + T <: AbstractArray{<: Any, 4}, A, +} + @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) + @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] -julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1] -3×6×1 Array{Float64, 3}: -[:, :, 1] = - 1.1 1.2 1.3 1.4 1.5 1.6 - 2.1 2.2 2.3 2.4 2.5 2.6 - 3.1 3.2 3.3 3.4 3.5 3.6 + i::UInt32, j::UInt32 = @index(Global, NTuple) -julia> pixel_shuffle(y, 2) # 1D image, with 6 channels reduced to 3 -6×3×1 Array{Float64, 3}: -[:, :, 1] = - 1.1 1.3 1.5 - 1.2 1.4 1.6 - 2.1 2.3 2.5 - 2.2 2.4 2.6 - 3.1 3.3 3.5 - 3.2 3.4 3.6 -``` -""" -function pixel_shuffle(x::AbstractArray, r::Integer) - ndims(x) > 2 || throw(ArgumentError("expected x with at least 3 dimensions")) - d = ndims(x) - 2 - sizein = size(x)[1:d] - cin, n = size(x, d+1), size(x, d+2) - cin % r^d == 0 || throw(ArgumentError("expected channel dimension to be divisible by r^d = $( - r^d), where d=$d is the number of spatial dimensions. Given r=$r, input size(x) = $(size(x))")) - cout = cin ÷ r^d - x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n) - perm = hcat(d+1:2d, 1:d) |> transpose |> vec # = [d+1, 1, d+2, 2, ..., 2d, d] - x = permutedims(x, (perm..., 2d+1, 2d+2)) - return reshape(x, map(s -> s*r, sizein)..., cout, n) + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) + + @inbounds for n in 1:batch, c in 1:channels + val = Δ[i, j, c, n] + @atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val + @atomic dx[ow1, oh0, c, n] += w1lambda * h0lambda * val + @atomic dx[ow0, oh1, c, n] += w0lambda * h1lambda * val + @atomic dx[ow1, oh1, c, n] += w1lambda * h1lambda * val + end +end + +# Trilinear. 
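# The 3D scheme extends the bilinear one: each workitem handles one output voxel
# (i, j, k), and the backward kernel accumulates into the eight surrounding input
# voxels with atomic adds.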
+ +@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where { + T <: AbstractArray{<: Any, 5}, A, +} + @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3] + @uniform channels::UInt32, batch::UInt32 = size(x)[4:5] + + i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) + + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) + id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, in_depth) + + @inbounds for n in 1:batch, c in 1:channels + y[i, j, k, c, n] = + d0lambda * ( + h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) + + h1lambda * (w0lambda * x[iw0, ih1, id0, c, n] + w1lambda * x[iw1, ih1, id0, c, n])) + + d1lambda * ( + h0lambda * (w0lambda * x[iw0, ih0, id1, c, n] + w1lambda * x[iw1, ih0, id1, c, n]) + + h1lambda * (w0lambda * x[iw0, ih1, id1, c, n] + w1lambda * x[iw1, ih1, id1, c, n])) + end +end + +@kernel function _∇upsample_linear_kernel!(dx::T, Δ::T, rwidth, rheight, rdepth, align::Val{A}) where { + T <: AbstractArray{<: Any, 5}, A, +} + @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(Δ)[1:3] + @uniform channels::UInt32, batch::UInt32 = size(Δ)[4:5] + @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] + + i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) + + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) + od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, out_depth) + + @inbounds for n in 1:batch, c in 1:channels + val = Δ[i, j, k, c, n] + @atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val + @atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val + @atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val + @atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val + + @atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val + @atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val + @atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val + @atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val + end +end + +@inline function source_index_and_lambda( + ratio::T, out_idx::UInt32, ::Val{align}, in_width::UInt32, +) where {T, align} + real_index = align ? 
+ ratio * out_idx : + max(zero(T), ratio * (out_idx + T(0.5)) - T(0.5)) + + iw0 = floor(UInt32, real_index) + offset::UInt32 = ifelse(iw0 < in_width - 0x1, 0x1, 0x0) + iw1 = iw0 + offset + 0x1 + + w1lambda = real_index - iw0 + w0lambda = T(1) - w1lambda + + return iw0 + 0x1, iw1, w0lambda, w1lambda end diff --git a/test/Project.toml b/test/Project.toml index d158aeae4..82062f811 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,8 @@ [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/test/runtests.jl b/test/runtests.jl index 38daea5e3..0c038bbb6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,11 +7,76 @@ import Zygote using Zygote: gradient using StableRNGs using Documenter +using Adapt +using KernelAbstractions DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true) const rng = StableRNG(123) include("test_utils.jl") +macro conditional_testset(name, skip_tests, expr) + esc(quote + @testset $name begin + if $name ∉ $skip_tests + $expr + else + @test_skip false + end + end + end) +end + +include("upsample.jl") + +function nnlib_testsuite(Backend; skip_tests = Set{String}()) + @conditional_testset "Upsample" skip_tests begin + upsample_testsuite(Backend) + end +end + +@testset verbose=true "NNlib.jl - Test Suite" begin + @testset "CPU" begin + nnlib_testsuite(CPU) + end + + if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" + using CUDA + if CUDA.functional() + @testset "CUDABackend" begin + nnlib_testsuite(CUDABackend) + end + else + @info "CUDA.jl is not functional. Skipping test suite for CUDABackend." + end + else + @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them." + end + + if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" + import Pkg + test_info = Pkg.project() + # Add MIOpen_jll to AMDGPU. + Pkg.develop("AMDGPU") + Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) + Pkg.add("MIOpen_jll") + Pkg.update() + # Update test project. + Pkg.activate(test_info.path) + Pkg.update() + + using AMDGPU + if AMDGPU.functional() + @testset "ROCBackend" begin + nnlib_testsuite(ROCBackend) + end + else + @info "AMDGPU.jl is not functional. Skipping test suite for ROCBackend." + end + else + @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them." + end +end + @testset verbose=true "NNlib.jl" begin if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" using CUDA @@ -104,10 +169,6 @@ include("test_utils.jl") include("softmax.jl") end - @testset "Upsampling" begin - include("upsample.jl") - end - @testset "Gather" begin include("gather.jl") end diff --git a/test/test_utils.jl b/test/test_utils.jl index 57f75f801..da3991156 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,8 +1,8 @@ const IntOrTuple = Union{Int, NTuple{N,Int} where N} -gradtest(f, dims::IntOrTuple...; kw...) = +gradtest(f, dims::IntOrTuple...; kw...) = gradtest(f, randn.(Ref(rng), Float64, dims)...; kw...) # julia v1.3 compat - # gradtest(f, randn.(rng, Float64, dims)...; kw...) + # gradtest(f, randn.(rng, Float64, dims)...; kw...) """ Compare numerical gradient and automatic gradient @@ -10,11 +10,11 @@ given by Zygote. `f` has to be a scalar valued function. 
Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly defined. """ -function gradtest(f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), - check_rrule = false, - fdm = :central, - check_broadcast = false, - skip = false, broken = false) +function gradtest( + f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), + check_rrule = false, fdm = :central, check_broadcast = false, + skip = false, broken = false, +) # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166 # is merged if check_rrule @@ -55,3 +55,30 @@ function gradtest(f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), end return true end + +""" + gputest(f, xs...; checkgrad=true, atol=1e-6, kws...) + +Compare gradients computed on the device vs CPU. +`xs...` should already be on the device. +""" +function gputest(f, xs...; checkgrad=true, atol=1e-6, kws...) + cpu_xs = map(x -> adapt(CPU(), x), xs) + + cpu_y = f(cpu_xs...; kws...) + y = f(xs...; kws...) + @test collect(cpu_y) ≈ collect(y) + + if checkgrad + cpu_grad = gradient((x...) -> sum(sin.(f(x...; kws...))), cpu_xs...) + gpu_grad = gradient((x...) -> sum(sin.(f(x...; kws...))), xs...) + + for (cpu_g, gpu_g) in zip(cpu_grad, adapt(CPU(), gpu_grad)) + if cpu_g === nothing + @test gpu_g === nothing + else + @test collect(cpu_g) ≈ collect(gpu_g) atol=atol + end + end + end +end diff --git a/test/upsample.jl b/test/upsample.jl index ffd4852cb..6f77ccf04 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -1,194 +1,214 @@ -@testset "upsample_nearest, integer scale via reshape" begin - x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1)) - @test upsample_nearest(x, (3,3))[1,:] == [1,1,1, 2,2,2] - - y = upsample_nearest(x, (2,3)) - @test size(y) == (4,6,1,1) - ∇upsample_nearest(y, (2,3)) == [6 12; 18 24] - - gradtest(x -> upsample_nearest(x, (2,3)), rand(2,2,1,1)) - - y2 = upsample_nearest(x, size=(4,6)) - @test y ≈ y2 - gradtest(x -> upsample_nearest(x, size=(4,6)), rand(2,2,1,1)) - - @test_throws ArgumentError ∇upsample_nearest(y, (2,4)) - @test_throws ArgumentError upsample_nearest(x, (1,2,3,4,5)) - @test_throws ArgumentError upsample_nearest(x, size=(3,4)) -end +function upsample_testsuite(Backend) + cpu, backend = CPU(), Backend() + T = Float32 # TODO test against all supported eltypes for each backend. + atol = T == Float32 ? 1e-3 : 1e-6 + gradtest_fn = backend == CPU() ? 
gradtest : gputest + + @testset "upsample_nearest, integer scale via reshape" begin + x = adapt(backend, reshape(T[1 2; 3 4], (2,2,1,1))) + @test adapt(cpu, upsample_nearest(x, (3,3)))[1,:] == [1,1,1, 2,2,2] + + y = upsample_nearest(x, (2,3)) + @test size(y) == (4,6,1,1) + y2 = upsample_nearest(x, size=(4,6)) + @test adapt(cpu, y) ≈ adapt(cpu, y2) + + @test adapt(cpu, ∇upsample_nearest(y, (2,3)))[:, :, 1, 1] == [6 12; 18 24] + gradtest_fn( + x -> upsample_nearest(x, (2,3)), + adapt(backend, rand(T, 2,2,1,1)); atol) + gradtest_fn( + x -> upsample_nearest(x, size=(4,6)), + adapt(backend, rand(T, 2,2,1,1)); atol) + + @test_throws ArgumentError ∇upsample_nearest(y, (2,4)) + @test_throws ArgumentError upsample_nearest(x, (1,2,3,4,5)) + @test_throws ArgumentError upsample_nearest(x, size=(3,4)) + end -@testset "Linear upsampling (1D)" begin - x = Float64[1,2,3,4] - x = hcat(x,x,x)[:,:,:] + @testset "Linear upsampling (1D)" begin + x = Float64[1,2,3,4] + x = hcat(x,x,x)[:,:,:] - y = collect(1:1//3:4) - y = hcat(y,y,y)[:,:,:] - yF64 = Float64.(y) + y = collect(1:1//3:4) + y = hcat(y,y,y)[:,:,:] - @test y ≈ upsample_linear(x, 2.5) - @test y ≈ upsample_linear(x; size=10) - gradtest(x->upsample_linear(x, 2.5), x) -end + xd = adapt(backend, x) + @test y ≈ adapt(cpu, upsample_linear(xd, 2.5)) + @test y ≈ adapt(cpu, upsample_linear(xd; size=10)) + gradtest_fn(x -> upsample_linear(x, 2.5), xd; atol) + end -@testset "Bilinear upsampling (2D)" begin - x = Float32[1 2; 3 4][:,:,:,:] - x = cat(x,x; dims=3) - x = cat(x,x; dims=4) - - # this output matches the one of pytorch v1.5.0 - # nn.UpsamplingBilinear2d(scale_factor=(3,2), align_corners=True) - # for above x - y_true = Float32[ 1//1 4//3 5//3 2//1; - 7//5 26//15 31//15 12//5; - 9//5 32//15 37//15 14//5; - 11//5 38//15 43//15 16//5; - 13//5 44//15 49//15 18//5; - 3//1 10//3 11//3 4//1][:,:,:,:] - y_true = cat(y_true,y_true; dims=3) - y_true = cat(y_true,y_true; dims=4) - - y = upsample_bilinear(x, (3, 2)) - @test size(y) == size(y_true) - @test eltype(y) == Float32 - @test y ≈ y_true - - gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-3) # works to higher precision for Float64 - - # additional grad check, also compliant with pytorch - o = ones(Float32,6,4,2,1) - grad_true = 6*ones(Float32,2,2,2,1) - @test ∇upsample_bilinear(o; size = (2,2)) ≈ grad_true - - y_true_2 = Rational{Int}[1//1 5//4 6//4 7//4 2//1; - 3//2 7//4 8//4 9//4 5//2; - 4//2 9//4 10//4 11//4 6//2; - 5//2 11//4 12//4 13//4 7//2; - 3//1 13//4 14//4 15//4 4//1][:,:,:,:] - - # check for real-valued single-number argument and type stability for rationals - upsample_bilinear(x, 2.5) == y_true_2 - - # check Integer support for forward pass - # grads are always assumed to be floats, so no extension there - x = UInt8[1 3; 3 5][:,:,:,:] - y_true_int = UInt8[1 2 3; 2 3 4; 3 4 5][:,:,:,:] - y = upsample_bilinear(x, 1.5) - - @test eltype(y) == UInt8 - @test y == y_true_int -end + @testset "Bilinear upsampling (2D)" begin + x = Float32[1 2; 3 4][:,:,:,:] + x = cat(x,x; dims=3) + x = cat(x,x; dims=4) + + # this output matches the one of pytorch v1.5.0 + # nn.UpsamplingBilinear2d(scale_factor=(3,2), align_corners=True) + # for above x + y_true = Float32[ 1//1 4//3 5//3 2//1; + 7//5 26//15 31//15 12//5; + 9//5 32//15 37//15 14//5; + 11//5 38//15 43//15 16//5; + 13//5 44//15 49//15 18//5; + 3//1 10//3 11//3 4//1][:,:,:,:] + y_true = cat(y_true, y_true; dims=3) + y_true = cat(y_true, y_true; dims=4) + + xd = adapt(backend, x) + y = upsample_bilinear(xd, (3, 2)) + @test size(y) == size(y_true) + @test 
eltype(y) == Float32 + @test adapt(cpu, y) ≈ y_true + + gradtest_fn(x -> upsample_bilinear(x, (3, 2)), xd; atol) + + # additional grad check, also compliant with pytorch + o = ones(Float32,6,4,2,1) + grad_true = 6*ones(Float32,2,2,2,1) + @test adapt(cpu, ∇upsample_bilinear(adapt(backend, o); size = (2,2))) ≈ grad_true + + # CPU only tests. + + y_true_2 = Rational{Int}[1//1 5//4 6//4 7//4 2//1; + 3//2 7//4 8//4 9//4 5//2; + 4//2 9//4 10//4 11//4 6//2; + 5//2 11//4 12//4 13//4 7//2; + 3//1 13//4 14//4 15//4 4//1][:,:,:,:] + y_true_2 = cat(y_true_2, y_true_2; dims=3) + y_true_2 = cat(y_true_2, y_true_2; dims=4) + + # check for real-valued single-number argument and type stability for rationals + y_rational = upsample_bilinear(Rational{Int}.(x), 2.5) + @test eltype(y_rational) == Rational{Int} + @test y_rational == y_true_2 + + # check Integer support for forward pass + # grads are always assumed to be floats, so no extension there + x = UInt8[1 3; 3 5][:,:,:,:] + y_true_int = UInt8[1 2 3; 2 3 4; 3 4 5][:,:,:,:] + y = upsample_bilinear(x, 1.5) + + @test eltype(y) == UInt8 + @test y == y_true_int + end -@testset "Trilinear upsampling (3D)" begin - # Layout: WHDCN, where D is depth - # we generate data which is constant along W & H and differs in D - # then we upsample along all dimensions - x = ones(Float32, 3,3,3,1,1) - x[:,:,1,:,:] .= 1. - x[:,:,2,:,:] .= 2. - x[:,:,3,:,:] .= 3. - - y_true = ones(Float32, 5,5,5,1,1) - y_true[:,:,1,:,:] .= 1. - y_true[:,:,2,:,:] .= 1.5 - y_true[:,:,3,:,:] .= 2. - y_true[:,:,4,:,:] .= 2.5 - y_true[:,:,5,:,:] .= 3. - - y = upsample_trilinear(x; size=(5,5,5)) - - @test size(y) == size(y_true) - @test eltype(y) == Float32 - @test collect(y) ≈ collect(y_true) - - # this test only works when align_corners=false (not present for CPU yet) - # o = ones(Float32,8,8,8,1,1) - # grad_true = 8*ones(Float32,4,4,4,1,1) - # @test ∇upsample_trilinear(o; size=(4,4,4)) ≈ grad_true - - x = Float64.(x) - gradtest(x -> upsample_trilinear(x, (2,2,2)), x) -end + @testset "Trilinear upsampling (3D)" begin + # Layout: WHDCN, where D is depth + # we generate data which is constant along W & H and differs in D + # then we upsample along all dimensions + x = ones(T, 3,3,3,1,1) + x[:,:,1,:,:] .= 1. + x[:,:,2,:,:] .= 2. + x[:,:,3,:,:] .= 3. + + y_true = ones(T, 5,5,5,1,1) + y_true[:,:,1,:,:] .= 1. + y_true[:,:,2,:,:] .= 1.5 + y_true[:,:,3,:,:] .= 2. + y_true[:,:,4,:,:] .= 2.5 + y_true[:,:,5,:,:] .= 3. + + xd = adapt(backend, x) + y = upsample_trilinear(xd; size=(5,5,5)) + + @test size(y) == size(y_true) + @test eltype(y) == T + @test collect(y) ≈ collect(y_true) + + gradtest_fn( + x -> upsample_trilinear(x, (2,2,2)), xd; + atol=(T == Float32) ? 1e-2 : 1e-5) + + # This test only works when `align_corners=false`. 
+ o = adapt(backend, ones(Float32,8,8,8,1,1)) + grad_true = 8 * ones(Float32,4,4,4,1,1) + @test adapt(cpu, ∇upsample_trilinear(o; size=(4,4,4), align_corners=false)) ≈ grad_true + end -@testset "pixel_shuffle" begin - x = reshape(1:16, (2, 2, 4, 1)) - # [:, :, 1, 1] = - # 1 3 - # 2 4 - # [:, :, 2, 1] = - # 5 7 - # 6 8 - # [:, :, 3, 1] = - # 9 11 - # 10 12 - # [:, :, 4, 1] = - # 13 15 - # 14 16 - - y_true = [1 9 3 11 - 5 13 7 15 - 2 10 4 12 - 6 14 8 16][:,:,:,:] - - y = pixel_shuffle(x, 2) - @test size(y) == size(y_true) - @test y_true == y - - x = reshape(1:32, (2, 2, 8, 1)) - y_true = zeros(Int, 4, 4, 2, 1) - y_true[:,:,1,1] .= [ 1 9 3 11 - 5 13 7 15 - 2 10 4 12 - 6 14 8 16 ] - - y_true[:,:,2,1] .= [ 17 25 19 27 - 21 29 23 31 - 18 26 20 28 - 22 30 24 32] - - y = pixel_shuffle(x, 2) - @test size(y) == size(y_true) - @test y_true == y - - x = reshape(1:4*3*27*2, (4,3,27,2)) - y = pixel_shuffle(x, 3) - @test size(y) == (12, 9, 3, 2) - # batch dimension is preserved - x1 = x[:,:,:,[1]] - x2 = x[:,:,:,[2]] - y1 = pixel_shuffle(x1, 3) - y2 = pixel_shuffle(x2, 3) - @test cat(y1, y2, dims=4) == y - - for d in [1, 2, 3] - r = rand(1:5) - n = rand(1:5) - c = rand(1:5) - insize = rand(1:5, d) - x = rand(insize..., r^d*c, n) - - y = pixel_shuffle(x, r) - @test size(y) == ((r .* insize)..., c, n) - - gradtest(x -> pixel_shuffle(x, r), x) + @testset "pixel_shuffle" begin + x = reshape(1:16, (2, 2, 4, 1)) + # [:, :, 1, 1] = + # 1 3 + # 2 4 + # [:, :, 2, 1] = + # 5 7 + # 6 8 + # [:, :, 3, 1] = + # 9 11 + # 10 12 + # [:, :, 4, 1] = + # 13 15 + # 14 16 + + y_true = [1 9 3 11 + 5 13 7 15 + 2 10 4 12 + 6 14 8 16][:,:,:,:] + + y = pixel_shuffle(adapt(backend, x), 2) + @test size(y) == size(y_true) + @test y_true == adapt(cpu, y) + + x = reshape(1:32, (2, 2, 8, 1)) + y_true = zeros(Int, 4, 4, 2, 1) + y_true[:,:,1,1] .= [ 1 9 3 11 + 5 13 7 15 + 2 10 4 12 + 6 14 8 16 ] + + y_true[:,:,2,1] .= [ 17 25 19 27 + 21 29 23 31 + 18 26 20 28 + 22 30 24 32] + + y = pixel_shuffle(adapt(backend, x), 2) + @test size(y) == size(y_true) + @test y_true == adapt(cpu, y) + + x = reshape(1:4*3*27*2, (4,3,27,2)) + y = pixel_shuffle(adapt(backend, x), 3) + @test size(y) == (12, 9, 3, 2) + + # batch dimension is preserved + x1 = x[:,:,:,[1]] + x2 = x[:,:,:,[2]] + y1 = pixel_shuffle(adapt(backend, x1), 3) + y2 = pixel_shuffle(adapt(backend, x2), 3) + @test adapt(cpu, cat(y1, y2, dims=4)) == adapt(cpu, y) + + for d in [1, 2, 3] + r = rand(1:5) + n = rand(1:5) + c = rand(1:5) + insize = rand(1:5, d) + x = rand(insize..., r^d*c, n) + xd = adapt(backend, x) + + y = pixel_shuffle(xd, r) + @test size(y) == ((r .* insize)..., c, n) + gradtest_fn(x -> pixel_shuffle(x, r), xd) + end end -end -@testset "Complex-valued upsample" begin - for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear]) - for (k, interp) in zip((2, ntuple(_ -> 2, d)), [method, upsample_nearest]) - x = randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1) - - upsize = (8, 16, 24)[1:d] - xup = interp(x, k) - @test size(xup)[1:d] == upsize - @test real(xup) == interp(real(x), k) - @test imag(xup) == interp(imag(x), k) - - upsize = (8,24,48)[1:d] - xup = interp(x; size=upsize) - @test size(xup)[1:d] == upsize - @test real(xup) == interp(real(x), size=upsize) - @test imag(xup) == interp(imag(x), size=upsize) + @testset "Complex-valued upsample" begin + for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear]) + for (k, interp) in zip((2, ntuple(_ -> 2, d)), [method, upsample_nearest]) + x = adapt(backend, 
randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1)) + + upsize = (8, 16, 24)[1:d] + xup = interp(x, k) + @test size(xup)[1:d] == upsize + @test adapt(cpu, real(xup)) == adapt(cpu, interp(real(x), k)) + @test adapt(cpu, imag(xup)) == adapt(cpu, interp(imag(x), k)) + + upsize = (8,24,48)[1:d] + xup = interp(x; size=upsize) + @test size(xup)[1:d] == upsize + @test adapt(cpu, real(xup)) == adapt(cpu, interp(real(x), size=upsize)) + @test adapt(cpu, imag(xup)) == adapt(cpu, interp(imag(x), size=upsize)) + end end end end From 5e6cd973bf1cb0fc36041bf68920cceee83dae0a Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 6 Apr 2023 01:02:58 +0300 Subject: [PATCH 02/11] Remove upsample from NNlibCUDA.jl --- ext/NNlibCUDA/src/NNlibCUDA.jl | 1 - ext/NNlibCUDA/src/upsample.jl | 361 --------------------------------- ext/NNlibCUDA/test/runtests.jl | 1 - ext/NNlibCUDA/test/upsample.jl | 77 ------- src/upsample.jl | 65 +++++- test/runtests.jl | 240 +++++++++++----------- test/test_utils.jl | 8 +- 7 files changed, 182 insertions(+), 571 deletions(-) delete mode 100644 ext/NNlibCUDA/src/upsample.jl delete mode 100644 ext/NNlibCUDA/test/upsample.jl diff --git a/ext/NNlibCUDA/src/NNlibCUDA.jl b/ext/NNlibCUDA/src/NNlibCUDA.jl index b3c3bbb72..b0d5ea9e7 100644 --- a/ext/NNlibCUDA/src/NNlibCUDA.jl +++ b/ext/NNlibCUDA/src/NNlibCUDA.jl @@ -6,7 +6,6 @@ using Random, Statistics const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N} -include("upsample.jl") include("sampling.jl") include("activations.jl") include("batchedadjtrans.jl") diff --git a/ext/NNlibCUDA/src/upsample.jl b/ext/NNlibCUDA/src/upsample.jl deleted file mode 100644 index a6a529fb1..000000000 --- a/ext/NNlibCUDA/src/upsample.jl +++ /dev/null @@ -1,361 +0,0 @@ -# -# Upsampling -# - -# GPU based bilinear upsampling including its gradient -# -# Based on the Caffe2 implementation at: -# The code is a translation from the following files: -# - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/operators/upsample_op.cu -# - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/core/common_gpu.h -# -# Copyright (c) 2016-2021 Facebook Inc. -# Copyright (c) 2015 Google Inc. -# Copyright (c) 2015 Yangqing Jia -# Copyright 2019-2020 Kakao Brain -# -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without modification, are -# permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this list of -# conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, this list of -# conditions and the following disclaimer in the documentation and/or other materials -# provided with the distribution. -# -# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America and -# IDIAP Research Institute nor the names of its contributors may be used to endorse or -# promote products derived from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE -# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -# Forward and backward pass have been tested to produce the same output -# as pytorch with align_corners=True - it works modulo bit noise. -# pytorch's default is align_corners=False, because otherwise the gradients depend on the -# image size, which should be avoided -> this should be considered here as well - -@inline function compute_source_index(ratio::T, dst_index, align_corners) where T - if align_corners - return ratio*dst_index - else - src_idx = ratio * (dst_index + T(0.5)) - T(0.5) - return max(zero(T), src_idx) - end -end - -function NNlib.upsample_linear_kernel!(y::CuArray{T,N}, x::CuArray{T,N}; align_corners=true) where {T,N} - out_size = prod(size(y)[1:N-2]) - - if align_corners - ratios = ntuple(i -> T((size(x,i)-1) / (size(y,i)-1)), N-2) - else - ratios = ntuple(i -> T(size(x,i) / size(y,i)), N-2) - end - - kernel = @cuda launch=false upsample_linear_cuda_kernel!(out_size, ratios..., x, y, align_corners) - config = launch_configuration(kernel.fun; max_threads=256) - threads = Base.min(out_size, config.threads) - blocks = cld(out_size, threads) - kernel(out_size, ratios..., x, y, align_corners; threads=threads, blocks=blocks) - return y -end - -function NNlib.∇upsample_linear_kernel!(dx::CuArray{T,N}, Δ::CuArray{T,N}; align_corners=true) where {T,N} - in_size = prod(size(Δ)[1:N-2]) - - if align_corners - ratios = ntuple(i -> T((size(dx,i)-1) / (size(Δ,i)-1)), N-2) # reversed compared to forward pass - else - ratios = ntuple(i -> T(size(dx,i) / size(Δ,i)), N-2) - end - - kernel = @cuda launch=false ∇upsample_linear_cuda_kernel!(in_size, ratios..., Δ, dx, align_corners) - config = launch_configuration(kernel.fun; max_threads=256) - threads = Base.min(in_size, config.threads) - blocks = cld(in_size, threads) - kernel(in_size, ratios..., Δ, dx, align_corners; threads=threads, blocks=blocks) - return dx -end - - -########### -# linear -########### -function upsample_linear_cuda_kernel!(n_elem, rwidth, x::CuDeviceArray{<:Any, 3}, y::CuDeviceArray{<:Any, 3}, align_corners) - index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x - - if index < n_elem - in_w, channels, batchsize = size(x) - out_w, _, _ = size(y) - - ow = index % out_w - - # real_index = rwidth*ow - real_index = compute_source_index(rwidth, ow, align_corners) - iw0 = Base.floor(Int, real_index) - offset = (iw0 < in_w-1) ? 
1 : 0 - iw1 = iw0 + offset + 1 - w1lambda = real_index - iw0 - w0lambda = 1 - w1lambda - iw0 += 1 - - @inbounds for n in 1:batchsize - for c in 1:channels - val = (w0lambda * x[iw0, c, n] + # w0 * i00 - w1lambda * x[iw1, c, n]) # w1 * i01 - y[ow+1, c, n] = val - end - end - end - return nothing -end - -# Δ is the gradient backpropagated from downstream layers -function ∇upsample_linear_cuda_kernel!(n_elem, rwidth, Δ::CuDeviceArray{<:Any, 3}, dx::CuDeviceArray{<:Any, 3}, align_corners) - index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x - - if index < n_elem - in_width, channels, batchsize = size(Δ) - out_width, _, _ = size(dx) - - iw = index % in_width - - # real_index_w = rwidth * iw - real_index_w = compute_source_index(rwidth, iw, align_corners) - ow0 = Base.floor(Int, real_index_w) - offset = (ow0 < out_width - 1) ? 1 : 0 - ow1 = ow0 + offset + 1 - w1lambda = real_index_w - ow0 - w0lambda = 1 - w1lambda - ow0 += 1 - - @inbounds for n in 1:batchsize - for c in 1:channels - val = Δ[iw+1, c, n] - CUDA.@atomic dx[ow0, c, n] += w0lambda * val - CUDA.@atomic dx[ow1, c, n] += w1lambda * val - end - end - end # if - return nothing -end - - -########### -# bilinear -########### -function upsample_linear_cuda_kernel!(n_elem, rwidth, rheight, x::CuDeviceArray{<:Any, 4}, y::CuDeviceArray{<:Any, 4}, align_corners) - index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x - - if index < n_elem - in_w, in_h, channels, batchsize = size(x) - out_w, out_h, _, _ = size(y) - - ow = index % out_w - oh = index ÷ out_w - - # real_index = rheight*oh - real_index = compute_source_index(rheight, oh, align_corners) - ih0 = Base.floor(Int, real_index) - offset = (ih0 < in_h-1) ? 1 : 0 - ih1 = ih0 + offset + 1 - h1lambda = real_index - ih0 - h0lambda = 1 - h1lambda - ih0 += 1 - - # real_index = rwidth*ow - real_index = compute_source_index(rwidth, ow, align_corners) - iw0 = Base.floor(Int, real_index) - offset = (iw0 < in_w-1) ? 1 : 0 - iw1 = iw0 + offset + 1 - w1lambda = real_index - iw0 - w0lambda = 1 - w1lambda - iw0 += 1 - - @inbounds for n in 1:batchsize - for c in 1:channels - val = h0lambda * (w0lambda * x[iw0, ih0, c, n] + # h0 * w0 * i00 - w1lambda * x[iw1, ih0, c, n]) + # h0 * w1 * i01 - h1lambda * (w0lambda * x[iw0, ih1, c, n] + # h1 * w0 * i10 - w1lambda * x[iw1, ih1, c, n]) # h1 * w1 * i11 - y[ow+1, oh+1, c, n] = val - end - end - end - return nothing -end - -# Δ is the gradient backpropagated from downstream layers -function ∇upsample_linear_cuda_kernel!(n_elem, rwidth, rheight, Δ::CuDeviceArray{<:Any, 4}, dx::CuDeviceArray{<:Any, 4}, align_corners) - index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x - - if index < n_elem - in_width, in_height, channels, batchsize = size(Δ) - out_width, out_height, _, _ = size(dx) - - iw = index % in_width - ih = index ÷ in_width - - # Compute Y axis lambdas - # real_index_h = rheight*ih - real_index_h = compute_source_index(rheight, ih, align_corners) - oh0 = Base.floor(Int, real_index_h) - offset = (oh0 < out_height-1) ? 1 : 0 - oh1 = oh0 + offset + 1 - h1lambda = real_index_h - oh0 - h0lambda = 1 - h1lambda - oh0 += 1 - - # # Compute X axis lambdas - # real_index_w = rwidth * iw - real_index_w = compute_source_index(rwidth, iw, align_corners) - ow0 = Base.floor(Int, real_index_w) - offset = (ow0 < out_width - 1) ? 
1 : 0 - ow1 = ow0 + offset + 1 - w1lambda = real_index_w - ow0 - w0lambda = 1 - w1lambda - ow0 += 1 - - @inbounds for n in 1:batchsize - for c in 1:channels - val = Δ[iw+1, ih+1, c, n] - CUDA.@atomic dx[ow0, oh0, c, n] += h0lambda * w0lambda * val - CUDA.@atomic dx[ow1, oh0, c, n] += h0lambda * w1lambda * val - CUDA.@atomic dx[ow0, oh1, c, n] += h1lambda * w0lambda * val - CUDA.@atomic dx[ow1, oh1, c, n] += h1lambda * w1lambda * val - end - end - end # if - return nothing -end - - -########### -# trilinear -########### -function upsample_linear_cuda_kernel!(n_elem, rwidth, rheight, rdepth, x::CuDeviceArray{<:Any, 5}, y::CuDeviceArray{<:Any, 5}, align_corners) - index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x - - if index < n_elem - in_w, in_h, in_d, channels, batchsize = size(x) - out_w, out_h, out_d, _, _ = size(y) - - ow = (index % (out_w * out_h)) % out_w - oh = (index % (out_w * out_h)) ÷ out_w - od = index ÷ (out_w * out_h) - - # real_index = rwidth*ow - real_index = compute_source_index(rwidth, ow, align_corners) - iw0 = Base.floor(Int, real_index) - offset = (iw0 < in_w-1) ? 1 : 0 - iw1 = iw0 + offset + 1 - w1lambda = real_index - iw0 - w0lambda = 1 - w1lambda - iw0 += 1 - - # real_index = rheight*oh - real_index = compute_source_index(rheight, oh, align_corners) - ih0 = Base.floor(Int, real_index) - offset = (ih0 < in_h-1) ? 1 : 0 - ih1 = ih0 + offset + 1 - h1lambda = real_index - ih0 - h0lambda = 1 - h1lambda - ih0 += 1 - - # real_index = rdepth*od - real_index = compute_source_index(rdepth, od, align_corners) - id0 = Base.floor(Int, real_index) - offset = (id0 < in_d-1) ? 1 : 0 - id1 = id0 + offset + 1 - d1lambda = real_index - id0 - d0lambda = 1 - d1lambda - id0 += 1 - - @inbounds for n in 1:batchsize - for c in 1:channels - val = d0lambda * - (h0lambda * - (w0lambda * x[iw0, ih0, id0, c, n] + - w1lambda * x[iw1, ih0, id0, c, n]) + - h1lambda * - (w0lambda * x[iw0, ih1, id0, c, n] + - w1lambda * x[iw1, ih1, id0, c, n])) + - d1lambda * - (h0lambda * - (w0lambda * x[iw0, ih0, id1, c, n] + - w1lambda * x[iw1, ih0, id1, c, n]) + - h1lambda * - (w0lambda * x[iw0, ih1, id1, c, n] + - w1lambda * x[iw1, ih1, id1, c, n])) - - y[ow+1, oh+1, od+1, c, n] = val - end - end - end - return nothing -end - -# Δ is the gradient backpropagated from downstream layers -function ∇upsample_linear_cuda_kernel!(n_elem, rwidth, rheight, rdepth, Δ::CuDeviceArray{<:Any, 5}, dx::CuDeviceArray{<:Any, 5}, align_corners) - index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x - - if index < n_elem - in_width, in_height, in_depth, channels, batchsize = size(Δ) - out_width, out_height, out_depth, _, _ = size(dx) - - iw = (index % (in_height * in_width)) % in_width - ih = (index % (in_height * in_width)) ÷ in_width - id = index ÷ (in_height * in_width) - - real_index_w = compute_source_index(rwidth, iw, align_corners) - ow0 = Base.floor(Int, real_index_w) - offset = (ow0 < out_width - 1) ? 1 : 0 - ow1 = ow0 + offset + 1 - w1lambda = real_index_w - ow0 - w0lambda = 1 - w1lambda - ow0 += 1 - - real_index_h = compute_source_index(rheight, ih, align_corners) - oh0 = Base.floor(Int, real_index_h) - offset = (oh0 < out_height-1) ? 1 : 0 - oh1 = oh0 + offset + 1 - h1lambda = real_index_h - oh0 - h0lambda = 1 - h1lambda - oh0 += 1 - - real_index_d = compute_source_index(rdepth, id, align_corners) - od0 = Base.floor(Int, real_index_d) - offset = (od0 < out_depth-1) ? 
1 : 0 - od1 = od0 + offset + 1 - d1lambda = real_index_d - od0 - d0lambda = 1 - d1lambda - od0 += 1 - - @inbounds for n in 1:batchsize - for c in 1:channels - val = Δ[iw+1, ih+1, id+1, c, n] - CUDA.@atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val - CUDA.@atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val - CUDA.@atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val - CUDA.@atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val - - CUDA.@atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val - CUDA.@atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val - CUDA.@atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val - CUDA.@atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val - end - end - end # if - return nothing -end diff --git a/ext/NNlibCUDA/test/runtests.jl b/ext/NNlibCUDA/test/runtests.jl index 5b76d48ef..8af877bba 100644 --- a/ext/NNlibCUDA/test/runtests.jl +++ b/ext/NNlibCUDA/test/runtests.jl @@ -13,7 +13,6 @@ include("activations.jl") include("dropout.jl") include("batchedadjtrans.jl") include("batchedmul.jl") -include("upsample.jl") include("conv.jl") include("ctc.jl") include("fold.jl") diff --git a/ext/NNlibCUDA/test/upsample.jl b/ext/NNlibCUDA/test/upsample.jl deleted file mode 100644 index ee881ea66..000000000 --- a/ext/NNlibCUDA/test/upsample.jl +++ /dev/null @@ -1,77 +0,0 @@ -@testset "Linear upsampling" begin - x = Float32.(1:10)[:,:,:] - x = cat(x, x; dims=2) - x = cat(x, x; dims=3) - xgpu = cu(x) - - y_true = Float32.(1:1//2:10) - y_true = cat(y_true, y_true; dims=2) - y_true = cat(y_true, y_true; dims=3) - y_true_gpu = cu(y_true) - - y = upsample_linear(xgpu; size=19) - - @test size(y) == size(y_true_gpu) - @test eltype(y) == Float32 - @test collect(y) ≈ collect(y_true_gpu) - - gputest(x -> upsample_linear(x, 2), x, atol=1e-5) -end - -@testset "Bilinear upsampling" begin - x = Float32[1 2; 3 4][:,:,:,:] - x = cat(x,x; dims=3) - x = cat(x,x; dims=4) - xgpu = cu(x) - - y_true = Float32[ 1//1 4//3 5//3 2//1; - 7//5 26//15 31//15 12//5; - 9//5 32//15 37//15 14//5; - 11//5 38//15 43//15 16//5; - 13//5 44//15 49//15 18//5; - 3//1 10//3 11//3 4//1] - y_true = cat(y_true,y_true; dims=3) - y_true = cat(y_true,y_true; dims=4) - y_true_gpu = cu(y_true) - - y = upsample_bilinear(xgpu, (3,2)) - @test size(y) == size(y_true_gpu) - @test eltype(y) == Float32 - @test collect(y) ≈ collect(y_true_gpu) - - o = CUDA.ones(Float32,6,4,2,1) - grad_true = 6*CUDA.ones(Float32,2,2,2,1) - @test ∇upsample_bilinear(o; size=(2,2)) ≈ grad_true - - gputest(x -> upsample_bilinear(x, (3, 2)), x, atol=1e-5) -end - -@testset "Trilinear upsampling" begin - # Layout: WHDCN, where D is depth - # we generate data which is constant along W & H and differs in D - # then we upsample along all dimensions - x = CUDA.ones(Float32, 3,3,3,1,1) - x[:,:,1,:,:] .= 1. - x[:,:,2,:,:] .= 2. - x[:,:,3,:,:] .= 3. - - y_true = CUDA.ones(Float32, 5,5,5,1,1) - y_true[:,:,1,:,:] .= 1. - y_true[:,:,2,:,:] .= 1.5 - y_true[:,:,3,:,:] .= 2. - y_true[:,:,4,:,:] .= 2.5 - y_true[:,:,5,:,:] .= 3. 
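# (A quick check of the expected values above, assuming align_corners=true as
# this test uses: in=3, out=5 along depth gives ratio (3 - 1)/(5 - 1) = 0.5,
# so source positions 0, 0.5, 1, 1.5, 2, i.e. interpolated depth values
# 1, 1.5, 2, 2.5, 3.)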
-
-    y = upsample_trilinear(x; size=(5,5,5))
-
-    @test size(y) == size(y_true)
-    @test eltype(y) == Float32
-    @test collect(y) ≈ collect(y_true)
-
-    # this test only works when align_corners=false
-    # o = CUDA.ones(Float32,8,8,8,1,1)
-    # grad_true = 8*CUDA.ones(Float32,4,4,4,1,1)
-    # @test ∇upsample_trilinear(o; size=(4,4,4)) ≈ grad_true
-
-    gputest(x -> upsample_trilinear(x, (2,2,2)), x, atol=1e-5)
-end
diff --git a/src/upsample.jl b/src/upsample.jl
index b4ac2d50c..5c2250454 100644
--- a/src/upsample.jl
+++ b/src/upsample.jl
@@ -71,6 +71,52 @@ function pixel_shuffle(x::AbstractArray, r::Integer)
     return reshape(x, map(s -> s*r, sizein)..., cout, n)
 end

+#
+# Upsampling
+#
+# Linear, bilinear and trilinear upsampling and their gradients
+#
+# Based on the Caffe2 implementation;
+# the code is a translation of the following files:
+# - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/operators/upsample_op.cu
+# - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/core/common_gpu.h
+#
+# Copyright (c) 2016-2021 Facebook Inc.
+# Copyright (c) 2015 Google Inc.
+# Copyright (c) 2015 Yangqing Jia
+# Copyright 2019-2020 Kakao Brain
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are
+# permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this list of
+# conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice, this list of
+# conditions and the following disclaimer in the documentation and/or other materials
+# provided with the distribution.
+#
+# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America and
+# IDIAP Research Institute nor the names of its contributors may be used to endorse or
+# promote products derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# Forward and backward pass have been tested to produce the same output
+# as pytorch with align_corners=True; they match up to floating-point noise.
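+# As a concrete sketch of the two conventions (matching `source_index_and_lambda`
+# below, with `o` the 0-based output index and `in`/`out` the axis lengths):
+#     align_corners=true:  src = o * (in - 1) / (out - 1)
+#     align_corners=false: src = max(0, (o + 0.5) * in / out - 0.5)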
+# pytorch's default is align_corners=False, because otherwise the gradients depend on the +# image size, which should be avoided -> this should be considered here as well + """ upsample_nearest(x, scale::NTuple{S,Int}) upsample_nearest(x; size::NTuple{S,Int}) @@ -346,8 +392,9 @@ end end end -@kernel function _∇upsample_linear_kernel!(dx::T, Δ::T, rwidth, align::Val{A}) where { - T <: AbstractArray{<: Any, 3}, A, +@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, align::Val{A}) where { + T1 <: AbstractArray{<: Any, 3}, + T2 <: AbstractArray{<: Any, 3}, A, } @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32 = size(dx, 1) @@ -380,8 +427,9 @@ end end end -@kernel function _∇upsample_linear_kernel!(dx::T, Δ::T, rwidth, rheight, align::Val{A}) where { - T <: AbstractArray{<: Any, 4}, A, +@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where { + T1 <: AbstractArray{<: Any, 4}, + T2 <: AbstractArray{<: Any, 4}, A, } @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] @@ -406,7 +454,7 @@ end T <: AbstractArray{<: Any, 5}, A, } @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3] - @uniform channels::UInt32, batch::UInt32 = size(x)[4:5] + @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) @@ -425,11 +473,12 @@ end end end -@kernel function _∇upsample_linear_kernel!(dx::T, Δ::T, rwidth, rheight, rdepth, align::Val{A}) where { - T <: AbstractArray{<: Any, 5}, A, +@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, rheight, rdepth, align::Val{A}) where { + T1 <: AbstractArray{<: Any, 5}, + T2 <: AbstractArray{<: Any, 5}, A, } @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(Δ)[1:3] - @uniform channels::UInt32, batch::UInt32 = size(Δ)[4:5] + @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5) @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) diff --git a/test/runtests.jl b/test/runtests.jl index 0c038bbb6..6414d9bf8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,158 +34,160 @@ function nnlib_testsuite(Backend; skip_tests = Set{String}()) end end -@testset verbose=true "NNlib.jl - Test Suite" begin - @testset "CPU" begin - nnlib_testsuite(CPU) - end +@testset "NNlib.jl" verbose=true begin + @testset "Test Suite" begin + @testset "CPU" begin + nnlib_testsuite(CPU) + end - if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" - using CUDA - if CUDA.functional() - @testset "CUDABackend" begin - nnlib_testsuite(CUDABackend) + if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" + using CUDA + if CUDA.functional() + @testset "CUDABackend" begin + nnlib_testsuite(CUDABackend) + end + else + @info "CUDA.jl is not functional. Skipping test suite for CUDABackend." end else - @info "CUDA.jl is not functional. Skipping test suite for CUDABackend." + @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them." end - else - @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them." - end - if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" - import Pkg - test_info = Pkg.project() - # Add MIOpen_jll to AMDGPU. - Pkg.develop("AMDGPU") - Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) - Pkg.add("MIOpen_jll") - Pkg.update() - # Update test project. 
- Pkg.activate(test_info.path) - Pkg.update() - - using AMDGPU - if AMDGPU.functional() - @testset "ROCBackend" begin - nnlib_testsuite(ROCBackend) + if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" + import Pkg + test_info = Pkg.project() + # Add MIOpen_jll to AMDGPU. + Pkg.develop("AMDGPU") + Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) + Pkg.add("MIOpen_jll") + Pkg.update() + # Update test project. + Pkg.activate(test_info.path) + Pkg.update() + + using AMDGPU + if AMDGPU.functional() + @testset "ROCBackend" begin + nnlib_testsuite(ROCBackend) + end + else + @info "AMDGPU.jl is not functional. Skipping test suite for ROCBackend." end else - @info "AMDGPU.jl is not functional. Skipping test suite for ROCBackend." + @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them." end - else - @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them." end -end -@testset verbose=true "NNlib.jl" begin - if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" - using CUDA - if CUDA.functional() - import Pkg - using NNlibCUDA - @testset "CUDA" begin - Pkg.test("NNlibCUDA") + @testset "NNlib.jl" begin + if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" + using CUDA + if CUDA.functional() + import Pkg + using NNlibCUDA + @testset "CUDA" begin + Pkg.test("NNlibCUDA") + end + else + @info "Insufficient version or CUDA not found; Skipping CUDA tests" end else - @info "Insufficient version or CUDA not found; Skipping CUDA tests" + @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them" end - else - @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them" - end - if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" - import Pkg - test_info = Pkg.project() - - # Add MIOpen_jll to AMDGPU. - Pkg.develop("AMDGPU") - Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) - Pkg.add("MIOpen_jll") - Pkg.update() - # Update test project. - Pkg.activate(test_info.path) - Pkg.update() - - using AMDGPU - AMDGPU.versioninfo() - if AMDGPU.functional() && AMDGPU.functional(:MIOpen) - @show AMDGPU.MIOpen.version() - @testset "AMDGPU" begin - include("ext_amdgpu/runtests.jl") + if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" + import Pkg + test_info = Pkg.project() + + # Add MIOpen_jll to AMDGPU. + Pkg.develop("AMDGPU") + Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) + Pkg.add("MIOpen_jll") + Pkg.update() + # Update test project. + Pkg.activate(test_info.path) + Pkg.update() + + using AMDGPU + AMDGPU.versioninfo() + if AMDGPU.functional() && AMDGPU.functional(:MIOpen) + @show AMDGPU.MIOpen.version() + @testset "AMDGPU" begin + include("ext_amdgpu/runtests.jl") + end + else + @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." end else - @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." + @info "Skipping AMDGPU tests, set NNLIB_TEST_CUDA=true to run them." end - else - @info "Skipping AMDGPU tests, set NNLIB_TEST_CUDA=true to run them." 
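+            # NB: the message above names NNLIB_TEST_CUDA although this branch is
+            # gated on the `get(ENV, "NNLIB_TEST_AMDGPU", ...)` check above, so
+            # NNLIB_TEST_AMDGPU appears to be the intended variable. Both suites
+            # are opt-in; assuming a standard Pkg setup, something like
+            #     NNLIB_TEST_CUDA=true julia --project -e 'using Pkg; Pkg.test()'
+            # enables them.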
- end - @testset "Doctests" begin - doctest(NNlib, manual=false) - end + @testset "Doctests" begin + doctest(NNlib, manual=false) + end - @testset "Activation Functions" begin - include("activations.jl") - end + @testset "Activation Functions" begin + include("activations.jl") + end - @testset "Attention" begin - include("attention.jl") - end + @testset "Attention" begin + include("attention.jl") + end - @testset "Batched Multiplication" begin - include("batchedmul.jl") - end + @testset "Batched Multiplication" begin + include("batchedmul.jl") + end - @testset "Convolution" begin - include("conv.jl") - include("conv_bias_act.jl") - end + @testset "Convolution" begin + include("conv.jl") + include("conv_bias_act.jl") + end - @testset "CTC Loss" begin - include("ctc.jl") - end + @testset "CTC Loss" begin + include("ctc.jl") + end - @testset "Dropout" begin - include("dropout.jl") - end + @testset "Dropout" begin + include("dropout.jl") + end - @testset "Fold/Unfold" begin - include("fold.jl") - end + @testset "Fold/Unfold" begin + include("fold.jl") + end - @testset "Inference" begin - include("inference.jl") - end + @testset "Inference" begin + include("inference.jl") + end - @testset "Pooling" begin - include("pooling.jl") - end + @testset "Pooling" begin + include("pooling.jl") + end - @testset "Padding" begin - include("padding.jl") - end + @testset "Padding" begin + include("padding.jl") + end - @testset "Softmax" begin - include("softmax.jl") - end + @testset "Softmax" begin + include("softmax.jl") + end - @testset "Gather" begin - include("gather.jl") - end + @testset "Gather" begin + include("gather.jl") + end - @testset "Scatter" begin - include("scatter.jl") - end + @testset "Scatter" begin + include("scatter.jl") + end - @testset "Utilities" begin - include("utils.jl") - end + @testset "Utilities" begin + include("utils.jl") + end - @testset "Grid Sampling" begin - include("sampling.jl") - end + @testset "Grid Sampling" begin + include("sampling.jl") + end - @testset "Functions" begin - include("functions.jl") + @testset "Functions" begin + include("functions.jl") + end end end diff --git a/test/test_utils.jl b/test/test_utils.jl index da3991156..16b3998dc 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -23,9 +23,9 @@ function gradtest( if check_broadcast length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args") - h = (xs...) -> sum(sin.(f.(xs...))) + h = (xs...) -> sum(f.(xs...)) else - h = (xs...) -> sum(sin.(f(xs...; fkwargs...))) + h = (xs...) -> sum(f(xs...; fkwargs...)) end y_true = h(xs...) @@ -70,8 +70,8 @@ function gputest(f, xs...; checkgrad=true, atol=1e-6, kws...) @test collect(cpu_y) ≈ collect(y) if checkgrad - cpu_grad = gradient((x...) -> sum(sin.(f(x...; kws...))), cpu_xs...) - gpu_grad = gradient((x...) -> sum(sin.(f(x...; kws...))), xs...) + cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_xs...) + gpu_grad = gradient((x...) -> sum(f(x...; kws...)), xs...) 
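+        # For reference: with `sum(f(x...))` the scalar loss is linear in the
+        # outputs, so the pullback seeds `f` with a sensitivity of all ones,
+        # whereas the old `sum(sin.(f(x...)))` seeded it with `cos.(f(x...))`,
+        # a distinct weight per output element. A minimal sketch:
+        #     gradient(x -> sum(x), [0.5, 1.0])        # ([1.0, 1.0],)
+        #     gradient(x -> sum(sin.(x)), [0.5, 1.0])  # ([cos(0.5), cos(1.0)],)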
for (cpu_g, gpu_g) in zip(cpu_grad, adapt(CPU(), gpu_grad)) if cpu_g === nothing From 4c5f827d7a62f605e3c02d8666021226fb200acd Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 6 Apr 2023 12:19:13 +0300 Subject: [PATCH 03/11] =?UTF-8?q?Remove=20allocation=20in=20'=E2=88=87upsa?= =?UTF-8?q?mple=5Flinear'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/upsample.jl | 2 +- test/runtests.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 5c2250454..d507582ab 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -248,7 +248,7 @@ function ∇upsample_linear(Δ::AbstractArray{T,N}; size::NTuple{<:Any,Integer}, if Base.size(Δ)[1:N-2] == size return Δ end - dx = zero(similar(Δ, T, size..., Base.size(Δ)[end-1:end]...)) + dx = fill!(similar(Δ, T, size..., Base.size(Δ)[end-1:end]...), zero(T)) return ∇upsample_linear_kernel!(dx, Δ; align_corners) end diff --git a/test/runtests.jl b/test/runtests.jl index 6414d9bf8..d03d4bf1d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -78,7 +78,7 @@ end end end - @testset "NNlib.jl" begin + @testset "Tests" begin if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" using CUDA if CUDA.functional() From efedbc12e3e7f42263301ca4dae2f85d150dceb7 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 6 Apr 2023 13:15:44 +0300 Subject: [PATCH 04/11] Change indexing order --- src/upsample.jl | 49 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index d507582ab..6fa6f9a53 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -355,7 +355,8 @@ upsample_trilinear(x; size, align_corners::Bool = true) = upsample_linear(x; si function upsample_linear_kernel!( y::AbstractArray{T, N}, x::AbstractArray{T, N}; align_corners::Bool = true, ) where {T, N} - ndrange = size(y)[1:N - 2] + # ndrange = size(y)[1:N - 2] + ndrange = size(y)[N - 1:end] ratios = align_corners ? 
ntuple(i -> real(T)((size(x, i) - 1) / (size(y, i) - 1)), N - 2) : ntuple(i -> real(T)(size(x, i) / size(y, i)), N - 2) @@ -414,19 +415,49 @@ end T <: AbstractArray{<: Any, 4}, A, } @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) + @uniform out_width::UInt32, out_height::UInt32 = size(y)[1:2] + + c::UInt32, n::UInt32 = @index(Global, NTuple) + + for j in UnitRange{UInt32}(1, out_height) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) + for i in UnitRange{UInt32}(1, out_width) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + @inbounds y[i, j, c, n] = + h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + + h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) + end + end - i::UInt32, j::UInt32 = @index(Global, NTuple) + # i::UInt32, j::UInt32 = @index(Global, NTuple) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) + # iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + # ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) - @inbounds for n in 1:batch, c in 1:channels - y[i, j, c, n] = - h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + - h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) - end + # @inbounds for n in 1:batch, c in 1:channels + # y[i, j, c, n] = + # h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + + # h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) + # end end +# @kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, align::Val{A}) where { +# T <: AbstractArray{<: Any, 4}, A, +# } +# @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) + +# i::UInt32, j::UInt32 = @index(Global, NTuple) + +# iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) +# ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) + +# @inbounds for n in 1:batch, c in 1:channels +# y[i, j, c, n] = +# h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + +# h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) +# end +# end + @kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where { T1 <: AbstractArray{<: Any, 4}, T2 <: AbstractArray{<: Any, 4}, A, From e4966168d8855fdcdaa7c8270cee7f2c41c18051 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 6 Apr 2023 16:20:59 +0300 Subject: [PATCH 05/11] Specialize kernels for CPU & GPU --- src/NNlib.jl | 17 ++-- src/upsample.jl | 205 ++++++++++++++++++++++++++++++++--------------- test/runtests.jl | 4 - test/upsample.jl | 2 +- 4 files changed, 150 insertions(+), 78 deletions(-) diff --git a/src/NNlib.jl b/src/NNlib.jl index 61e6bf52b..184fbfc74 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -1,19 +1,20 @@ module NNlib -using Pkg -using Requires -using ChainRulesCore import ChainRulesCore: rrule + using Base.Broadcast: broadcasted using Base.Threads +using ChainRulesCore +using KernelAbstractions +using KernelAbstractions: @atomic +using LinearAlgebra +using LinearAlgebra.BLAS: @blasfunc, BlasInt +using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose +using Pkg using Random 
+using Requires using Statistics using Statistics: mean -using LinearAlgebra -using LinearAlgebra: BlasFloat, Transpose, Adjoint, AdjOrTransAbsMat -using LinearAlgebra.BLAS: BlasInt, @blasfunc -using KernelAbstractions -using KernelAbstractions: @atomic const libblas = Base.libblas_name diff --git a/src/upsample.jl b/src/upsample.jl index 6fa6f9a53..cdfd5452c 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -355,51 +355,78 @@ upsample_trilinear(x; size, align_corners::Bool = true) = upsample_linear(x; si function upsample_linear_kernel!( y::AbstractArray{T, N}, x::AbstractArray{T, N}; align_corners::Bool = true, ) where {T, N} - # ndrange = size(y)[1:N - 2] - ndrange = size(y)[N - 1:end] + backend = KernelAbstractions.get_backend(x) + ndrange = backend isa CPU ? + size(y)[N - 1:end] : # Parallelization along channel x batch. + size(y)[1:N - 2] # Parallelization along WHD. ratios = align_corners ? ntuple(i -> real(T)((size(x, i) - 1) / (size(y, i) - 1)), N - 2) : ntuple(i -> real(T)(size(x, i) / size(y, i)), N - 2) - - backend = KernelAbstractions.get_backend(x) - _upsample_linear_kernel!(backend)(y, x, ratios..., Val(align_corners); ndrange) + _upsample_linear_kernel!(backend)(backend, y, x, ratios..., Val(align_corners); ndrange) return y end function ∇upsample_linear_kernel!( dx::AbstractArray{T, N}, Δ::AbstractArray{T, N}; align_corners::Bool = true, ) where {T, N} - ndrange = size(Δ)[1:N - 2] + backend = KernelAbstractions.get_backend(dx) + ndrange = backend isa CPU ? + size(Δ)[N - 1:end] : # Parallelization along channel x batch. + size(Δ)[1:N - 2] # Parallelization along WHD. ratios = align_corners ? ntuple(i -> real(T)((size(dx, i) - 1) / (size(Δ, i) - 1)), N - 2) : ntuple(i -> real(T)(size(dx, i) / size(Δ, i)), N - 2) - - backend = KernelAbstractions.get_backend(dx) - _∇upsample_linear_kernel!(backend)(dx, Δ, ratios..., Val(align_corners); ndrange) + _∇upsample_linear_kernel!(backend)(backend, dx, Δ, ratios..., Val(align_corners); ndrange) return dx end -# Linear. +# Linear (CPU): parallelization along channel x batch dimensions. -@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, align::Val{A}) where { - T <: AbstractArray{<: Any, 3}, A, +@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, align::Val{A}) where { + T <: AbstractArray{<:Any, 3}, A, } @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) + @uniform out_width::UInt32 = size(y, 1) + c::UInt32, n::UInt32 = @index(Global, NTuple) + @inbounds for i in UnitRange{UInt32}(1, out_width) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n] + end +end + +@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, align::Val{A}) where { + T1 <: AbstractArray{<:Any, 3}, T2 <: AbstractArray{<:Any, 3}, A, +} + @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) + @uniform out_width::UInt32 = size(dx, 1) + c::UInt32, n::UInt32 = @index(Global, NTuple) + @inbounds for i in UnitRange{UInt32}(1, in_width) + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + val = Δ[i, c, n] + @atomic dx[ow0, c, n] += w0lambda * val + @atomic dx[ow1, c, n] += w1lambda * val + end +end + +# Linear (GPU): parallelization along width dimension. +# TODO replace AbstractArray -> AbstractGPUArray once device arrays subtype it. 
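+# A sketch of the dispatch pattern used throughout this file: the kernel is
+# instantiated for a backend and also receives it as its first runtime argument,
+#     backend = KernelAbstractions.get_backend(x)
+#     _upsample_linear_kernel!(backend)(backend, y, x, ratios..., Val(align_corners); ndrange)
+# The `::CPU` methods above walk the spatial axes inside a single (channel, batch)
+# work-item, so each work-item owns its whole slice (which is what later lets the
+# CPU backward passes drop `@atomic`), while the `::B <: GPU` methods below use
+# one work-item per spatial position and `@atomic` accumulation when scattering.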
+@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, align::Val{A}) where { + B <: GPU, T <: AbstractArray{<:Any, 3}, A, +} + @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) i::UInt32 = @index(Global) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda( rwidth, i - 0x1, align, in_width) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) @inbounds for n in 1:batch, c in 1:channels y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n] end end -@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, align::Val{A}) where { - T1 <: AbstractArray{<: Any, 3}, - T2 <: AbstractArray{<: Any, 3}, A, +@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, align::Val{A}) where { + B <: GPU, T <: AbstractArray{<:Any, 3}, A, } @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32 = size(dx, 1) - i::UInt32 = @index(Global) ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) @inbounds for n in 1:batch, c in 1:channels @@ -409,16 +436,14 @@ end end end -# Bilinear. +# Bilinear (CPU): parallelization along channel x batch dimensions. -@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, align::Val{A}) where { - T <: AbstractArray{<: Any, 4}, A, +@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, align::Val{A}) where { + T <: AbstractArray{<:Any, 4}, A, } @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) @uniform out_width::UInt32, out_height::UInt32 = size(y)[1:2] - c::UInt32, n::UInt32 = @index(Global, NTuple) - for j in UnitRange{UInt32}(1, out_height) ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) for i in UnitRange{UInt32}(1, out_width) @@ -428,48 +453,51 @@ end h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) end end - - # i::UInt32, j::UInt32 = @index(Global, NTuple) - - # iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) - # ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) - - # @inbounds for n in 1:batch, c in 1:channels - # y[i, j, c, n] = - # h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + - # h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) - # end end -# @kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, align::Val{A}) where { -# T <: AbstractArray{<: Any, 4}, A, -# } -# @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) - -# i::UInt32, j::UInt32 = @index(Global, NTuple) +@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where { + T1 <: AbstractArray{<:Any, 4}, T2 <: AbstractArray{<:Any, 4}, A, +} + @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) + @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] + c::UInt32, n::UInt32 = @index(Global, NTuple) + for j in UnitRange{UInt32}(1, in_height) + oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) + for i in UnitRange{UInt32}(1, in_width) + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + val = Δ[i, j, c, n] + @atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val + @atomic dx[ow1, oh0, c, n] += w1lambda * h0lambda * val + @atomic dx[ow0, oh1, c, n] 
+= w0lambda * h1lambda * val + @atomic dx[ow1, oh1, c, n] += w1lambda * h1lambda * val + end + end +end -# iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) -# ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) +# Bilinear (GPU): parallelization along width, height dimensions. -# @inbounds for n in 1:batch, c in 1:channels -# y[i, j, c, n] = -# h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + -# h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) -# end -# end +@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, align::Val{A}) where { + B <: GPU, T <: AbstractArray{<:Any, 4}, A, +} + @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) + i::UInt32, j::UInt32 = @index(Global, NTuple) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) + @inbounds for n in 1:batch, c in 1:channels + y[i, j, c, n] = + h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + + h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) + end +end -@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where { - T1 <: AbstractArray{<: Any, 4}, - T2 <: AbstractArray{<: Any, 4}, A, +@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, align::Val{A}) where { + B <: GPU, T <: AbstractArray{<:Any, 4}, A, } @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] - i::UInt32, j::UInt32 = @index(Global, NTuple) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) - @inbounds for n in 1:batch, c in 1:channels val = Δ[i, j, c, n] @atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val @@ -479,20 +507,72 @@ end end end -# Trilinear. +# Trilinear (CPU): parallelization along channel x batch dimensions. 
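+# Since w0lambda = 1 - w1lambda (and likewise for the h and d pairs), the eight
+# corner weights below multiply out to
+#     (w0lambda + w1lambda) * (h0lambda + h1lambda) * (d0lambda + d1lambda) = 1,
+# i.e. trilinear interpolation is a convex combination of the 2x2x2 corner values.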
-@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where { - T <: AbstractArray{<: Any, 5}, A, +@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where { + T <: AbstractArray{<:Any, 5}, A, } @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3] @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) + @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(y)[1:3] + c::UInt32, n::UInt32 = @index(Global, NTuple) + for k in UnitRange{UInt32}(1, out_depth) + id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, in_depth) + for j in UnitRange{UInt32}(1, out_height) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) + for i in UnitRange{UInt32}(1, out_width) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + @inbounds y[i, j, k, c, n] = + d0lambda * ( + h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) + + h1lambda * (w0lambda * x[iw0, ih1, id0, c, n] + w1lambda * x[iw1, ih1, id0, c, n])) + + d1lambda * ( + h0lambda * (w0lambda * x[iw0, ih0, id1, c, n] + w1lambda * x[iw1, ih0, id1, c, n]) + + h1lambda * (w0lambda * x[iw0, ih1, id1, c, n] + w1lambda * x[iw1, ih1, id1, c, n])) + end + end + end +end - i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) +@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, rdepth, align::Val{A}) where { + T1 <: AbstractArray{<:Any, 5}, T2 <: AbstractArray{<:Any, 5}, A, +} + @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(Δ)[1:3] + @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5) + @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] + c::UInt32, n::UInt32 = @index(Global, NTuple) + for k in UnitRange{UInt32}(1, in_depth) + od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, out_depth) + for j in UnitRange{UInt32}(1, in_height) + oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) + @inbounds for i in UnitRange{UInt32}(1, in_width) + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + val = Δ[i, j, k, c, n] + @atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val + @atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val + @atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val + @atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val + + @atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val + @atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val + @atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val + @atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val + end + end + end +end + +# Trilinear (GPU): parallelization along width x height x depth dimensions. 
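+# Note on the backward kernels below: each work-item reads one element of Δ and
+# scatters it into up to eight corners of `dx`; neighbouring work-items hit the
+# same corners, so on the GPU the accumulation must be `@atomic`.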
+@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where { + B <: GPU, T <: AbstractArray{<:Any, 5}, A, +} + @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3] + @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) + i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, in_depth) - @inbounds for n in 1:batch, c in 1:channels y[i, j, k, c, n] = d0lambda * ( @@ -504,20 +584,16 @@ end end end -@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, rheight, rdepth, align::Val{A}) where { - T1 <: AbstractArray{<: Any, 5}, - T2 <: AbstractArray{<: Any, 5}, A, +@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, rdepth, align::Val{A}) where { + B <: GPU, T <: AbstractArray{<:Any, 5}, A, } @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(Δ)[1:3] @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5) @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] - i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, out_depth) - @inbounds for n in 1:batch, c in 1:channels val = Δ[i, j, k, c, n] @atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val @@ -545,6 +621,5 @@ end w1lambda = real_index - iw0 w0lambda = T(1) - w1lambda - return iw0 + 0x1, iw1, w0lambda, w1lambda end diff --git a/test/runtests.jl b/test/runtests.jl index d03d4bf1d..3b6df78cb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,11 +56,7 @@ end if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" import Pkg test_info = Pkg.project() - # Add MIOpen_jll to AMDGPU. Pkg.develop("AMDGPU") - Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) - Pkg.add("MIOpen_jll") - Pkg.update() # Update test project. 
Pkg.activate(test_info.path) Pkg.update() diff --git a/test/upsample.jl b/test/upsample.jl index 6f77ccf04..eac86f4a6 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -27,7 +27,7 @@ function upsample_testsuite(Backend) end @testset "Linear upsampling (1D)" begin - x = Float64[1,2,3,4] + x = T[1,2,3,4] x = hcat(x,x,x)[:,:,:] y = collect(1:1//3:4) From 97fb365fd1ce4a4d70805c3cd31f81907fb86041 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 6 Apr 2023 16:37:19 +0300 Subject: [PATCH 06/11] Remove @atomic for CPU --- src/upsample.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index cdfd5452c..4dffcf251 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -403,8 +403,8 @@ end @inbounds for i in UnitRange{UInt32}(1, in_width) ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) val = Δ[i, c, n] - @atomic dx[ow0, c, n] += w0lambda * val - @atomic dx[ow1, c, n] += w1lambda * val + dx[ow0, c, n] += w0lambda * val + dx[ow1, c, n] += w1lambda * val end end @@ -466,10 +466,10 @@ end for i in UnitRange{UInt32}(1, in_width) ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) val = Δ[i, j, c, n] - @atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val - @atomic dx[ow1, oh0, c, n] += w1lambda * h0lambda * val - @atomic dx[ow0, oh1, c, n] += w0lambda * h1lambda * val - @atomic dx[ow1, oh1, c, n] += w1lambda * h1lambda * val + dx[ow0, oh0, c, n] += w0lambda * h0lambda * val + dx[ow1, oh0, c, n] += w1lambda * h0lambda * val + dx[ow0, oh1, c, n] += w0lambda * h1lambda * val + dx[ow1, oh1, c, n] += w1lambda * h1lambda * val end end end @@ -548,15 +548,15 @@ end @inbounds for i in UnitRange{UInt32}(1, in_width) ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) val = Δ[i, j, k, c, n] - @atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val - @atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val - @atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val - @atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val - - @atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val - @atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val - @atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val - @atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val + dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val + dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val + dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val + dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val + + dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val + dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val + dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val + dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val end end end From 7cc9524881094350619b338ea558a666ad290a87 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 6 Apr 2023 16:57:18 +0300 Subject: [PATCH 07/11] Fixes --- src/upsample.jl | 90 ++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 4dffcf251..6e0ebb4a1 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -388,8 +388,8 @@ end @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) @uniform 
out_width::UInt32 = size(y, 1) c::UInt32, n::UInt32 = @index(Global, NTuple) - @inbounds for i in UnitRange{UInt32}(1, out_width) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + @inbounds for i in UnitRange{UInt32}(one(UInt32), out_width) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n] end end @@ -400,8 +400,8 @@ end @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32 = size(dx, 1) c::UInt32, n::UInt32 = @index(Global, NTuple) - @inbounds for i in UnitRange{UInt32}(1, in_width) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + @inbounds for i in UnitRange{UInt32}(one(UInt32), in_width) + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) val = Δ[i, c, n] dx[ow0, c, n] += w0lambda * val dx[ow1, c, n] += w1lambda * val @@ -416,8 +416,8 @@ end } @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) i::UInt32 = @index(Global) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) - @inbounds for n in 1:batch, c in 1:channels + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) + @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n] end end @@ -428,8 +428,8 @@ end @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32 = size(dx, 1) i::UInt32 = @index(Global) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) - @inbounds for n in 1:batch, c in 1:channels + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) + @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) val = Δ[i, c, n] @atomic dx[ow0, c, n] += w0lambda * val @atomic dx[ow1, c, n] += w1lambda * val @@ -444,10 +444,10 @@ end @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) @uniform out_width::UInt32, out_height::UInt32 = size(y)[1:2] c::UInt32, n::UInt32 = @index(Global, NTuple) - for j in UnitRange{UInt32}(1, out_height) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) - for i in UnitRange{UInt32}(1, out_width) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + for j in UnitRange{UInt32}(one(UInt32), out_height) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height) + for i in UnitRange{UInt32}(one(UInt32), out_width) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) @inbounds y[i, j, c, n] = h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) @@ -461,10 +461,10 @@ end @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] c::UInt32, n::UInt32 = @index(Global, NTuple) - for j in UnitRange{UInt32}(1, in_height) - oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) - for i in UnitRange{UInt32}(1, in_width) - ow0, ow1, 
w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + for j in UnitRange{UInt32}(one(UInt32), in_height) + oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height) + for i in UnitRange{UInt32}(one(UInt32), in_width) + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) val = Δ[i, j, c, n] dx[ow0, oh0, c, n] += w0lambda * h0lambda * val dx[ow1, oh0, c, n] += w1lambda * h0lambda * val @@ -481,9 +481,9 @@ end } @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) i::UInt32, j::UInt32 = @index(Global, NTuple) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) - @inbounds for n in 1:batch, c in 1:channels + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height) + @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) y[i, j, c, n] = h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) @@ -496,9 +496,9 @@ end @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] i::UInt32, j::UInt32 = @index(Global, NTuple) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) - oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) - @inbounds for n in 1:batch, c in 1:channels + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) + oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height) + @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) val = Δ[i, j, c, n] @atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val @atomic dx[ow1, oh0, c, n] += w1lambda * h0lambda * val @@ -516,12 +516,12 @@ end @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(y)[1:3] c::UInt32, n::UInt32 = @index(Global, NTuple) - for k in UnitRange{UInt32}(1, out_depth) - id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, in_depth) - for j in UnitRange{UInt32}(1, out_height) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) - for i in UnitRange{UInt32}(1, out_width) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) + for k in UnitRange{UInt32}(one(UInt32), out_depth) + id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, in_depth) + for j in UnitRange{UInt32}(one(UInt32), out_height) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height) + for i in UnitRange{UInt32}(one(UInt32), out_width) + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) @inbounds y[i, j, k, c, n] = d0lambda * ( h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) + @@ -541,12 +541,12 @@ end @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 
5) @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] c::UInt32, n::UInt32 = @index(Global, NTuple) - for k in UnitRange{UInt32}(1, in_depth) - od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, out_depth) - for j in UnitRange{UInt32}(1, in_height) - oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) - @inbounds for i in UnitRange{UInt32}(1, in_width) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) + for k in UnitRange{UInt32}(one(UInt32), in_depth) + od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, out_depth) + for j in UnitRange{UInt32}(one(UInt32), in_height) + oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height) + @inbounds for i in UnitRange{UInt32}(one(UInt32), in_width) + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) val = Δ[i, j, k, c, n] dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val @@ -570,10 +570,10 @@ end @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3] @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height) - id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, in_depth) - @inbounds for n in 1:batch, c in 1:channels + iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) + ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height) + id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, in_depth) + @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) y[i, j, k, c, n] = d0lambda * ( h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) + @@ -591,10 +591,10 @@ end @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5) @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width) - oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height) - od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, out_depth) - @inbounds for n in 1:batch, c in 1:channels + ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) + oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height) + od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, out_depth) + @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) val = Δ[i, j, k, c, n] @atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val @atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val @@ -616,10 +616,10 @@ end max(zero(T), ratio * (out_idx + T(0.5)) - T(0.5)) iw0 = floor(UInt32, real_index) - offset::UInt32 = ifelse(iw0 < in_width - 0x1, 0x1, 0x0) - iw1 = iw0 + offset + 
0x1 + offset::UInt32 = ifelse(iw0 < in_width - one(UInt32), one(UInt32), zero(UInt32)) + iw1 = iw0 + offset + one(UInt32) w1lambda = real_index - iw0 w0lambda = T(1) - w1lambda - return iw0 + 0x1, iw1, w0lambda, w1lambda + return iw0 + one(UInt32), iw1, w0lambda, w1lambda end From f30fb3742824d610d321fefc2dda73f550edeba4 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 6 Apr 2023 17:16:07 +0300 Subject: [PATCH 08/11] Add compatibility layers back --- src/upsample.jl | 9 +++++++++ test/runtests.jl | 5 ++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/upsample.jl b/src/upsample.jl index 6e0ebb4a1..bfb1cecfe 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -380,6 +380,15 @@ function ∇upsample_linear_kernel!( return dx end +# Compatibility layer for old versions of NNlibCUDA. +# TODO Can be removed from NNlib 0.9. +upsample_linear_wcn!(y, x) = upsample_linear_kernel!(y, x) +upsample_bilinear_whcn!(y, x) = upsample_linear_kernel!(y, x) +upsample_trilinear_whdcn!(y, x) = upsample_linear_kernel!(y, x) +∇upsample_linear_wcn!(y, x) = ∇upsample_linear_kernel!(y, x) +∇upsample_bilinear_whcn!(y, x) = ∇upsample_linear_kernel!(y, x) +∇upsample_trilinear_whdcn!(y, x) = ∇upsample_linear_kernel!(y, x) + # Linear (CPU): parallelization along channel x batch dimensions. @kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, align::Val{A}) where { diff --git a/test/runtests.jl b/test/runtests.jl index 3b6df78cb..ac1361be7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,7 +56,11 @@ end if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" import Pkg test_info = Pkg.project() + # Add MIOpen_jll to AMDGPU. Pkg.develop("AMDGPU") + Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) + Pkg.add("MIOpen_jll") + Pkg.update() # Update test project. Pkg.activate(test_info.path) Pkg.update() @@ -93,7 +97,6 @@ end if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" import Pkg test_info = Pkg.project() - # Add MIOpen_jll to AMDGPU. Pkg.develop("AMDGPU") Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) From a39d91b96bab653901b76096389b227f953b43f3 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 10 Apr 2023 14:32:10 +0300 Subject: [PATCH 09/11] Refactor to use 'cpu' & 'device' functions --- test/runtests.jl | 2 +- test/upsample.jl | 63 ++++++++++++++++++++++++------------------------ 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ac1361be7..8fc0c15df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,7 +35,7 @@ function nnlib_testsuite(Backend; skip_tests = Set{String}()) end @testset "NNlib.jl" verbose=true begin - @testset "Test Suite" begin + @testset verbose=true "Test Suite" begin @testset "CPU" begin nnlib_testsuite(CPU) end diff --git a/test/upsample.jl b/test/upsample.jl index eac86f4a6..d91626ccb 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -1,25 +1,26 @@ function upsample_testsuite(Backend) - cpu, backend = CPU(), Backend() + cpu(x) = adapt(CPU(), x) + device(x) = adapt(Backend(), x) + gradtest_fn = KernelAbstractions.isgpu(Backend()) ? gputest : gradtest T = Float32 # TODO test against all supported eltypes for each backend. atol = T == Float32 ? 1e-3 : 1e-6 - gradtest_fn = backend == CPU() ? 
gradtest : gputest @testset "upsample_nearest, integer scale via reshape" begin - x = adapt(backend, reshape(T[1 2; 3 4], (2,2,1,1))) - @test adapt(cpu, upsample_nearest(x, (3,3)))[1,:] == [1,1,1, 2,2,2] + x = device(reshape(T[1 2; 3 4], (2,2,1,1))) + @test cpu(upsample_nearest(x, (3,3)))[1,:] == [1,1,1, 2,2,2] y = upsample_nearest(x, (2,3)) @test size(y) == (4,6,1,1) y2 = upsample_nearest(x, size=(4,6)) - @test adapt(cpu, y) ≈ adapt(cpu, y2) + @test cpu(y) ≈ cpu(y2) - @test adapt(cpu, ∇upsample_nearest(y, (2,3)))[:, :, 1, 1] == [6 12; 18 24] + @test cpu(∇upsample_nearest(y, (2,3)))[:, :, 1, 1] == [6 12; 18 24] gradtest_fn( x -> upsample_nearest(x, (2,3)), - adapt(backend, rand(T, 2,2,1,1)); atol) + device(rand(T, 2,2,1,1)); atol) gradtest_fn( x -> upsample_nearest(x, size=(4,6)), - adapt(backend, rand(T, 2,2,1,1)); atol) + device(rand(T, 2,2,1,1)); atol) @test_throws ArgumentError ∇upsample_nearest(y, (2,4)) @test_throws ArgumentError upsample_nearest(x, (1,2,3,4,5)) @@ -33,9 +34,9 @@ function upsample_testsuite(Backend) y = collect(1:1//3:4) y = hcat(y,y,y)[:,:,:] - xd = adapt(backend, x) - @test y ≈ adapt(cpu, upsample_linear(xd, 2.5)) - @test y ≈ adapt(cpu, upsample_linear(xd; size=10)) + xd = device(x) + @test y ≈ cpu(upsample_linear(xd, 2.5)) + @test y ≈ cpu(upsample_linear(xd; size=10)) gradtest_fn(x -> upsample_linear(x, 2.5), xd; atol) end @@ -56,18 +57,18 @@ function upsample_testsuite(Backend) y_true = cat(y_true, y_true; dims=3) y_true = cat(y_true, y_true; dims=4) - xd = adapt(backend, x) + xd = device(x) y = upsample_bilinear(xd, (3, 2)) @test size(y) == size(y_true) @test eltype(y) == Float32 - @test adapt(cpu, y) ≈ y_true + @test cpu(y) ≈ y_true gradtest_fn(x -> upsample_bilinear(x, (3, 2)), xd; atol) # additional grad check, also compliant with pytorch o = ones(Float32,6,4,2,1) grad_true = 6*ones(Float32,2,2,2,1) - @test adapt(cpu, ∇upsample_bilinear(adapt(backend, o); size = (2,2))) ≈ grad_true + @test cpu(∇upsample_bilinear(device(o); size = (2,2))) ≈ grad_true # CPU only tests. @@ -110,7 +111,7 @@ function upsample_testsuite(Backend) y_true[:,:,4,:,:] .= 2.5 y_true[:,:,5,:,:] .= 3. - xd = adapt(backend, x) + xd = device(x) y = upsample_trilinear(xd; size=(5,5,5)) @test size(y) == size(y_true) @@ -122,9 +123,9 @@ function upsample_testsuite(Backend) atol=(T == Float32) ? 1e-2 : 1e-5) # This test only works when `align_corners=false`. 
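        # Why: with align_corners=false and an integer scale, the interpolation
        # weights received by every input pixel sum to the scale factor along each
        # dimension, so for 2x upsampling in 3D the pullback of an all-ones
        # sensitivity is uniformly 2^3 = 8 (hence `grad_true` below). With
        # align_corners=true the border pixels receive a different total weight,
        # so the gradient is not constant.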
- o = adapt(backend, ones(Float32,8,8,8,1,1)) + o = device(ones(Float32,8,8,8,1,1)) grad_true = 8 * ones(Float32,4,4,4,1,1) - @test adapt(cpu, ∇upsample_trilinear(o; size=(4,4,4), align_corners=false)) ≈ grad_true + @test cpu(∇upsample_trilinear(o; size=(4,4,4), align_corners=false)) ≈ grad_true end @testset "pixel_shuffle" begin @@ -147,9 +148,9 @@ function upsample_testsuite(Backend) 2 10 4 12 6 14 8 16][:,:,:,:] - y = pixel_shuffle(adapt(backend, x), 2) + y = pixel_shuffle(device(x), 2) @test size(y) == size(y_true) - @test y_true == adapt(cpu, y) + @test y_true == cpu(y) x = reshape(1:32, (2, 2, 8, 1)) y_true = zeros(Int, 4, 4, 2, 1) @@ -163,20 +164,20 @@ function upsample_testsuite(Backend) 18 26 20 28 22 30 24 32] - y = pixel_shuffle(adapt(backend, x), 2) + y = pixel_shuffle(device(x), 2) @test size(y) == size(y_true) - @test y_true == adapt(cpu, y) + @test y_true == cpu(y) x = reshape(1:4*3*27*2, (4,3,27,2)) - y = pixel_shuffle(adapt(backend, x), 3) + y = pixel_shuffle(device(x), 3) @test size(y) == (12, 9, 3, 2) # batch dimension is preserved x1 = x[:,:,:,[1]] x2 = x[:,:,:,[2]] - y1 = pixel_shuffle(adapt(backend, x1), 3) - y2 = pixel_shuffle(adapt(backend, x2), 3) - @test adapt(cpu, cat(y1, y2, dims=4)) == adapt(cpu, y) + y1 = pixel_shuffle(device(x1), 3) + y2 = pixel_shuffle(device(x2), 3) + @test cpu(cat(y1, y2, dims=4)) == cpu(y) for d in [1, 2, 3] r = rand(1:5) @@ -184,7 +185,7 @@ function upsample_testsuite(Backend) c = rand(1:5) insize = rand(1:5, d) x = rand(insize..., r^d*c, n) - xd = adapt(backend, x) + xd = device(x) y = pixel_shuffle(xd, r) @test size(y) == ((r .* insize)..., c, n) @@ -195,19 +196,19 @@ function upsample_testsuite(Backend) @testset "Complex-valued upsample" begin for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear]) for (k, interp) in zip((2, ntuple(_ -> 2, d)), [method, upsample_nearest]) - x = adapt(backend, randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1)) + x = device(randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1)) upsize = (8, 16, 24)[1:d] xup = interp(x, k) @test size(xup)[1:d] == upsize - @test adapt(cpu, real(xup)) == adapt(cpu, interp(real(x), k)) - @test adapt(cpu, imag(xup)) == adapt(cpu, interp(imag(x), k)) + @test cpu(real(xup)) == cpu(interp(real(x), k)) + @test cpu(imag(xup)) == cpu(interp(imag(x), k)) upsize = (8,24,48)[1:d] xup = interp(x; size=upsize) @test size(xup)[1:d] == upsize - @test adapt(cpu, real(xup)) == adapt(cpu, interp(real(x), size=upsize)) - @test adapt(cpu, imag(xup)) == adapt(cpu, interp(imag(x), size=upsize)) + @test cpu(real(xup)) == cpu(interp(real(x), size=upsize)) + @test cpu(imag(xup)) == cpu(interp(imag(x), size=upsize)) end end end From 5fbf12f39cad45e330d0fee7e8b2835fc51732dc Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 10 Apr 2023 14:44:41 +0300 Subject: [PATCH 10/11] Refactor --- test/runtests.jl | 2 ++ test/upsample.jl | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8fc0c15df..8b9061b77 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,8 @@ macro conditional_testset(name, skip_tests, expr) end) end +cpu(x) = adapt(CPU(), x) + include("upsample.jl") function nnlib_testsuite(Backend; skip_tests = Set{String}()) diff --git a/test/upsample.jl b/test/upsample.jl index d91626ccb..28109d4b2 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -1,5 +1,4 @@ function upsample_testsuite(Backend) - cpu(x) = adapt(CPU(), x) device(x) = adapt(Backend(), x) gradtest_fn = 
KernelAbstractions.isgpu(Backend()) ? gputest : gradtest T = Float32 # TODO test against all supported eltypes for each backend. From bf4d6cf8dc869dd73cc72269896f960585b350db Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 10 Apr 2023 23:08:17 +0300 Subject: [PATCH 11/11] Optimize CPU kernels --- src/upsample.jl | 166 +++++++++++++++++++++++++----------------------- 1 file changed, 86 insertions(+), 80 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index bfb1cecfe..2f58666b9 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -397,9 +397,10 @@ upsample_trilinear_whdcn!(y, x) = upsample_linear_kernel!(y, x) @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) @uniform out_width::UInt32 = size(y, 1) c::UInt32, n::UInt32 = @index(Global, NTuple) + yv, xv = @view(y[:, c, n]), @view(x[:, c, n]) @inbounds for i in UnitRange{UInt32}(one(UInt32), out_width) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) - y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n] + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) + yv[i] = w0λ * xv[iw0] + w1λ * xv[iw1] end end @@ -409,11 +410,12 @@ end @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32 = size(dx, 1) c::UInt32, n::UInt32 = @index(Global, NTuple) + Δv, dxv = @view(Δ[:, c, n]), @view(dx[:, c, n]) @inbounds for i in UnitRange{UInt32}(one(UInt32), in_width) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) - val = Δ[i, c, n] - dx[ow0, c, n] += w0lambda * val - dx[ow1, c, n] += w1lambda * val + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) + val = Δv[i] + dxv[ow0] += w0λ * val + dxv[ow1] += w1λ * val end end @@ -425,9 +427,9 @@ end } @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) i::UInt32 = @index(Global) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) - y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n] + y[i, c, n] = w0λ * x[iw0, c, n] + w1λ * x[iw1, c, n] end end @@ -437,11 +439,11 @@ end @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32 = size(dx, 1) i::UInt32 = @index(Global) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) val = Δ[i, c, n] - @atomic dx[ow0, c, n] += w0lambda * val - @atomic dx[ow1, c, n] += w1lambda * val + @atomic dx[ow0, c, n] += w0λ * val + @atomic dx[ow1, c, n] += w1λ * val end end @@ -453,13 +455,14 @@ end @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) @uniform out_width::UInt32, out_height::UInt32 = size(y)[1:2] c::UInt32, n::UInt32 = @index(Global, NTuple) + yv, xv = @view(y[:, :, c, n]), @view(x[:, :, c, n]) for j in UnitRange{UInt32}(one(UInt32), out_height) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height) + ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height) for i in UnitRange{UInt32}(one(UInt32), 
out_width) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) - @inbounds y[i, j, c, n] = - h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + - h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) + @inbounds yv[i, j] = + h0λ * (w0λ * xv[iw0, ih0] + w1λ * xv[iw1, ih0]) + + h1λ * (w0λ * xv[iw0, ih1] + w1λ * xv[iw1, ih1]) end end end @@ -470,15 +473,16 @@ end @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] c::UInt32, n::UInt32 = @index(Global, NTuple) + Δv, dxv = @view(Δ[:, :, c, n]), @view(dx[:, :, c, n]) for j in UnitRange{UInt32}(one(UInt32), in_height) - oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height) - for i in UnitRange{UInt32}(one(UInt32), in_width) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) - val = Δ[i, j, c, n] - dx[ow0, oh0, c, n] += w0lambda * h0lambda * val - dx[ow1, oh0, c, n] += w1lambda * h0lambda * val - dx[ow0, oh1, c, n] += w0lambda * h1lambda * val - dx[ow1, oh1, c, n] += w1lambda * h1lambda * val + oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height) + @inbounds for i in UnitRange{UInt32}(one(UInt32), in_width) + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) + val = Δv[i, j] + dxv[ow0, oh0] += w0λ * h0λ * val + dxv[ow1, oh0] += w1λ * h0λ * val + dxv[ow0, oh1] += w0λ * h1λ * val + dxv[ow1, oh1] += w1λ * h1λ * val end end end @@ -490,12 +494,12 @@ end } @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) i::UInt32, j::UInt32 = @index(Global, NTuple) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height) + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) + ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height) @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) y[i, j, c, n] = - h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) + - h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n]) + h0λ * (w0λ * x[iw0, ih0, c, n] + w1λ * x[iw1, ih0, c, n]) + + h1λ * (w0λ * x[iw0, ih1, c, n] + w1λ * x[iw1, ih1, c, n]) end end @@ -505,14 +509,14 @@ end @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] i::UInt32, j::UInt32 = @index(Global, NTuple) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) - oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height) + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) + oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height) @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) val = Δ[i, j, c, n] - @atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val - @atomic dx[ow1, oh0, c, n] += w1lambda * h0lambda * val - @atomic dx[ow0, oh1, c, n] += w0lambda * h1lambda * val - @atomic dx[ow1, 
oh1, c, n] += w1lambda * h1lambda * val + @atomic dx[ow0, oh0, c, n] += w0λ * h0λ * val + @atomic dx[ow1, oh0, c, n] += w1λ * h0λ * val + @atomic dx[ow0, oh1, c, n] += w0λ * h1λ * val + @atomic dx[ow1, oh1, c, n] += w1λ * h1λ * val end end @@ -525,19 +529,20 @@ end @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(y)[1:3] c::UInt32, n::UInt32 = @index(Global, NTuple) + yv, xv = @view(y[:, :, :, c, n]), @view(x[:, :, :, c, n]) for k in UnitRange{UInt32}(one(UInt32), out_depth) - id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, in_depth) + id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, in_depth) for j in UnitRange{UInt32}(one(UInt32), out_height) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height) + ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height) for i in UnitRange{UInt32}(one(UInt32), out_width) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) - @inbounds y[i, j, k, c, n] = - d0lambda * ( - h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) + - h1lambda * (w0lambda * x[iw0, ih1, id0, c, n] + w1lambda * x[iw1, ih1, id0, c, n])) + - d1lambda * ( - h0lambda * (w0lambda * x[iw0, ih0, id1, c, n] + w1lambda * x[iw1, ih0, id1, c, n]) + - h1lambda * (w0lambda * x[iw0, ih1, id1, c, n] + w1lambda * x[iw1, ih1, id1, c, n])) + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) + @inbounds yv[i, j, k] = + d0λ * ( + h0λ * (w0λ * xv[iw0, ih0, id0] + w1λ * xv[iw1, ih0, id0]) + + h1λ * (w0λ * xv[iw0, ih1, id0] + w1λ * xv[iw1, ih1, id0])) + + d1λ * ( + h0λ * (w0λ * xv[iw0, ih0, id1] + w1λ * xv[iw1, ih0, id1]) + + h1λ * (w0λ * xv[iw0, ih1, id1] + w1λ * xv[iw1, ih1, id1])) end end end @@ -550,22 +555,23 @@ end @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5) @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] c::UInt32, n::UInt32 = @index(Global, NTuple) + Δv, dxv = @view(Δ[:, :, :, c, n]), @view(dx[:, :, :, c, n]) for k in UnitRange{UInt32}(one(UInt32), in_depth) - od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, out_depth) + od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, out_depth) for j in UnitRange{UInt32}(one(UInt32), in_height) - oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height) + oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height) @inbounds for i in UnitRange{UInt32}(one(UInt32), in_width) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) - val = Δ[i, j, k, c, n] - dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val - dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val - dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val - dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val - - dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val - dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val - dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val - dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) + val = Δv[i, j, k] + dxv[ow0, oh0, od0] += w0λ * h0λ * d0λ * val + 
dxv[ow1, oh0, od0] += w1λ * h0λ * d0λ * val + dxv[ow0, oh1, od0] += w0λ * h1λ * d0λ * val + dxv[ow1, oh1, od0] += w1λ * h1λ * d0λ * val + + dxv[ow0, oh0, od1] += w0λ * h0λ * d1λ * val + dxv[ow1, oh0, od1] += w1λ * h0λ * d1λ * val + dxv[ow0, oh1, od1] += w0λ * h1λ * d1λ * val + dxv[ow1, oh1, od1] += w1λ * h1λ * d1λ * val end end end @@ -579,17 +585,17 @@ end @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3] @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) - iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width) - ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height) - id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, in_depth) + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) + ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height) + id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, in_depth) @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) y[i, j, k, c, n] = - d0lambda * ( - h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) + - h1lambda * (w0lambda * x[iw0, ih1, id0, c, n] + w1lambda * x[iw1, ih1, id0, c, n])) + - d1lambda * ( - h0lambda * (w0lambda * x[iw0, ih0, id1, c, n] + w1lambda * x[iw1, ih0, id1, c, n]) + - h1lambda * (w0lambda * x[iw0, ih1, id1, c, n] + w1lambda * x[iw1, ih1, id1, c, n])) + d0λ * ( + h0λ * (w0λ * x[iw0, ih0, id0, c, n] + w1λ * x[iw1, ih0, id0, c, n]) + + h1λ * (w0λ * x[iw0, ih1, id0, c, n] + w1λ * x[iw1, ih1, id0, c, n])) + + d1λ * ( + h0λ * (w0λ * x[iw0, ih0, id1, c, n] + w1λ * x[iw1, ih0, id1, c, n]) + + h1λ * (w0λ * x[iw0, ih1, id1, c, n] + w1λ * x[iw1, ih1, id1, c, n])) end end @@ -600,24 +606,24 @@ end @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5) @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) - ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width) - oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height) - od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, out_depth) + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) + oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height) + od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, out_depth) @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) val = Δ[i, j, k, c, n] - @atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val - @atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val - @atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val - @atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val - - @atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val - @atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val - @atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val - @atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val + @atomic dx[ow0, oh0, od0, c, n] += w0λ * h0λ * d0λ * val + @atomic dx[ow1, oh0, od0, c, n] += w1λ * h0λ * d0λ * val + 
@atomic dx[ow0, oh1, od0, c, n] += w0λ * h1λ * d0λ * val + @atomic dx[ow1, oh1, od0, c, n] += w1λ * h1λ * d0λ * val + + @atomic dx[ow0, oh0, od1, c, n] += w0λ * h0λ * d1λ * val + @atomic dx[ow1, oh0, od1, c, n] += w1λ * h0λ * d1λ * val + @atomic dx[ow0, oh1, od1, c, n] += w0λ * h1λ * d1λ * val + @atomic dx[ow1, oh1, od1, c, n] += w1λ * h1λ * d1λ * val end end -@inline function source_index_and_lambda( +@inline function source_idx_and_λ( ratio::T, out_idx::UInt32, ::Val{align}, in_width::UInt32, ) where {T, align} real_index = align ? @@ -629,6 +635,6 @@ end iw1 = iw0 + offset + one(UInt32) w1lambda = real_index - iw0 - w0lambda = T(1) - w1lambda + w0lambda = one(T) - w1lambda return iw0 + one(UInt32), iw1, w0lambda, w1lambda end
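
A note on the helper above: `source_idx_and_λ` maps a 0-based output index to the two
0-based source pixels and their linear-interpolation weights, returning the indices
1-based; the `@atomic` updates in the backward kernels are needed because neighbouring
output indices can scatter into the same source pixel. The hunk elides part of the
`align` branch, so the reference below assumes the usual convention: `align_corners=true`
maps grid ends onto grid ends, `align_corners=false` uses half-pixel centers clamped at
zero. This is a minimal CPU-only sketch for checking the kernels on small inputs; the
names `ref_source_idx_and_λ` and `ref_upsample_linear` are illustrative and not part of
the patch.

# Plain-Julia reference for the 1D case (floating-point eltypes only).
function ref_source_idx_and_λ(ratio::T, out_idx::Integer, align::Bool, in_width::Integer) where T
    real_index = align ?
        ratio * (out_idx - 1) :                               # corners aligned
        max(zero(T), ratio * (out_idx - 1 + T(0.5)) - T(0.5)) # half-pixel centers
    iw0 = floor(Int, real_index)       # left source pixel, 0-based
    iw1 = min(iw0 + 1, in_width - 1)   # right source pixel, clamped at the border
    w1λ = real_index - iw0             # weight of the right pixel
    return iw0 + 1, iw1 + 1, one(T) - w1λ, w1λ  # 1-based indices, (left, right) weights
end

function ref_upsample_linear(x::AbstractArray{T,3}, out_width::Integer; align::Bool = true) where T
    in_width = size(x, 1)
    ratio = align ? T(in_width - 1) / T(out_width - 1) : T(in_width) / T(out_width)
    y = similar(x, out_width, size(x, 2), size(x, 3))
    for n in axes(x, 3), c in axes(x, 2), i in 1:out_width
        iw0, iw1, w0λ, w1λ = ref_source_idx_and_λ(ratio, i, align, in_width)
        y[i, c, n] = w0λ * x[iw0, c, n] + w1λ * x[iw1, c, n]
    end
    return y
end

# For example, matching the `upsample_linear` test earlier in the suite:
#   ref_upsample_linear(reshape(Float32.(1:4), 4, 1, 1), 10) ≈ Float32.(1:1//3:4)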