diff --git a/lib/NNlibCUDA/src/NNlibCUDA.jl b/lib/NNlibCUDA/src/NNlibCUDA.jl
index 48c11dfa3..13e7928bd 100644
--- a/lib/NNlibCUDA/src/NNlibCUDA.jl
+++ b/lib/NNlibCUDA/src/NNlibCUDA.jl
@@ -4,8 +4,12 @@ using NNlib
 using CUDA
 using Random, Statistics
 
+const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
+const MAX_THREADS = 1024
+
 include("upsample.jl")
 include("batchedmul.jl")
+include("scatter.jl")
 include("cudnn/cudnn.jl")
 include("cudnn/conv.jl")
 include("cudnn/pooling.jl")
diff --git a/lib/NNlibCUDA/src/scatter.jl b/lib/NNlibCUDA/src/scatter.jl
new file mode 100644
index 000000000..e939c46e6
--- /dev/null
+++ b/lib/NNlibCUDA/src/scatter.jl
@@ -0,0 +1,66 @@
+# Reduction operator => CUDA atomic function that applies it in place through a pointer.
+# Declared `const` so the global binding is not an untyped, reassignable global.
+const ATM_OPS = Dict((+) => CUDA.atomic_add!, (-) => CUDA.atomic_sub!,
+                     (max) => CUDA.atomic_max!, (min) => CUDA.atomic_min!,
+                     (*) => CUDA.atomic_mul!, (/) => CUDA.atomic_div!,
+                     (&) => CUDA.atomic_and!, (|) => CUDA.atomic_or!)
+
+# Generate one pair of kernels and one `NNlib.scatter!` method per operator in `ATM_OPS`.
+for (op, atm_op) in ATM_OPS
+    # Kernel for `dims == 0`: thread `index` atomically combines `src[index]` into the
+    # element of `dst` addressed by `idx[index]` (an integer or a tuple of integers).
+    @eval function scatter_kernel!(op::typeof($(op)), dst, src, idx)
+        index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+        @inbounds if index <= length(idx)
+            i = Base._to_linear_index(dst, idx[index]...)
+            $(atm_op)(pointer(dst, i), src[index])
+        end
+        return nothing
+    end
+
+    # Kernel for `dims > 0`: thread `index` handles `src[index]`. The remainder `k` picks a
+    # position inside the leading `dims_size` block of `dst`; the quotient `j` picks the
+    # entry of `idx` that supplies the remaining coordinates.
+    @eval function scatter_kernel!(op::typeof($(op)), dst, src, idx, dims::Val{N},
+                                   max_idx, max_dims_idx, dims_size) where {N}
+        index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+        @inbounds if index <= max_idx
+            j, k = divrem(index-1, max_dims_idx)
+            dims_i = CartesianIndices(dims_size)[k+1]
+            i = Base._to_linear_index(dst, Tuple(dims_i)..., idx[j+1]...)
+            $(atm_op)(pointer(dst, i), src[index])
+        end
+        return nothing
+    end
+
+    # Host entry point: launch one thread per scattered element and return `dst`.
+    # The leading `N` dimensions of `dst` are traversed in full; `idx` supplies the rest.
+    # Throws `DimensionMismatch` when `src` has fewer elements than the kernel would read.
+    @eval function NNlib.scatter!(op::typeof($(op)), dst::CuArray{Tdst}, src::CuArray{Tsrc},
+                                  idx::CuArray{<:IntOrIntTuple},
+                                  dims::Val{N}) where {Tdst,Tsrc,N}
+        dims_size = size(dst)[1:N]
+        max_dims_idx = prod(dims_size)  # `prod(())` is 1, which covers `N == 0`
+        max_idx = max_dims_idx * length(idx)
+
+        # The kernels read `src[1:max_idx]` under `@inbounds`, so reject a short `src` here.
+        if length(src) < max_idx
+            throw(DimensionMismatch(string("src has ", length(src),
+                                           " elements but at least ", max_idx, " are required")))
+        end
+        # Nothing to scatter: skip the launch, since no block count can be derived from
+        # zero threads (`cld(0, 0)` throws `DivideError`).
+        max_idx == 0 && return dst
+
+        threads = min(MAX_THREADS, max_idx)
+        blocks = cld(max_idx, threads)  # integer ceiling division
+        if N == 0
+            @cuda blocks=blocks threads=threads scatter_kernel!(op, dst, src, idx)
+        else
+            @cuda blocks=blocks threads=threads scatter_kernel!(op, dst, src, idx, dims, max_idx, max_dims_idx, dims_size)
+        end
+        return dst
+    end
+end
diff --git a/lib/NNlibCUDA/test/runtests.jl b/lib/NNlibCUDA/test/runtests.jl
index 4624151ee..778b8753e 100644
--- a/lib/NNlibCUDA/test/runtests.jl
+++ b/lib/NNlibCUDA/test/runtests.jl
@@ -3,6 +3,7 @@ using NNlib
 using Zygote
 using NNlibCUDA
 using ForwardDiff: Dual
+using Statistics: mean
 using CUDA
 
 CUDA.allowscalar(false)
@@ -16,4 +17,5 @@ if CUDA.has_cuda()
     include("pooling.jl")
     include("softmax.jl")
     include("batchnorm.jl")
+    include("scatter.jl")
 end
diff --git a/lib/NNlibCUDA/test/scatter.jl b/lib/NNlibCUDA/test/scatter.jl
new file mode 100644
index 000000000..154a524bf
--- /dev/null
+++ b/lib/NNlibCUDA/test/scatter.jl
@@ -0,0 +1,159 @@
+dsts = Dict(
+    0 => cu([3, 4, 5, 6, 7]),
+    1 => cu([3 3 4 4 5;
+             5 5 6 6 7]),
+)
+srcs = Dict(
+    (0, true) => cu(ones(Int, 3, 4)),
+    (0, false) => cu(ones(Int, 3) * collect(1:4)'),
+    (1, true) => cu(ones(Int, 2, 3, 4)),
+    (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
+)
+idxs = [
+    cu([1 2 3 4;
+        4 2 1 3;
+        3 5 5 3]),  # integer index
+    cu([(1,) (2,) (3,) (4,);
+        (4,) (2,) (1,) (3,);
+        (3,) (5,) (5,) (3,)]),  # tuple index
+]
+res = Dict(
+    (+, 0, true) => cu([5, 6, 9, 8, 9]),
+    (+, 1, true) => cu([5 5 8 6 7;
+                        7 7 10 8 9]),
+    (+, 0, false) => cu([4, 4, 12, 5, 5]),
+    (+, 1, false) => cu([4 4 12 5 5;
+                         8 8 24 10 10]),
+    (-, 0, true) => cu([1, 2, 1, 4, 5]),
+    (-, 1, true) => cu([1 1 0 2 3;
+                        3 3 2 4 5]),
+    (-, 0, false) => cu([-4, -4, -12, -5, -5]),
+    (-, 1, false) => cu([-4 -4 -12 -5 -5;
+                         -8 -8 -24 -10 -10]),
+    (max, 0, true) => cu([3, 4, 5, 6, 7]),
+    (max, 1, true) => cu([3 3 4 4 5;
+                          5 5 6 6 7]),
+    (max, 0, false) => cu([3, 2, 4, 4, 3]),
+    (max, 1, false) => cu([3 2 4 4 3;
+                           6 4 8 8 6]),
+    (min, 0, true) => cu([1, 1, 1, 1, 1]),
+    (min, 1, true) => cu([1 1 1 1 1;
+                          1 1 1 1 1]),
+    (min, 0, false) => cu([1, 2, 1, 1, 2]),
+    (min, 1, false) => cu([1 2 1 1 2;
+                           2 4 2 2 4]),
+    (*, 0, true) => cu([3, 4, 5, 6, 7]),
+    (*, 1, true) => cu([3 3 4 4 5;
+                        5 5 6 6 7]),
+    (*, 0, false) => cu([3, 4, 48, 4, 6]),
+    (*, 1, false) => cu([3 4 48 4 6;
+                         12 16 768 16 24]),
+    (/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
+    (/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
+                        1.25 1.25 0.375 1.5 1.75]),
+    (/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
+    (/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
+                         1//12 1//16 1//768 1//16 1//24]),
+    (mean, 0, true) => cu([4., 5., 6., 7., 8.]),
+    (mean, 1, true) => cu([4. 4. 5. 5. 6.;
+                           6. 6. 7. 7. 8.]),
+    (mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
+    (mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
+                            4. 4. 6. 5. 5.]),
+)
+
+types = [CuArray{UInt32}, CuArray{UInt64},
+         CuArray{Int32}, CuArray{Int64},
+         CuArray{Float32}, CuArray{Float64}]
+
+
+@testset "scatter" begin
+    for T = types
+        @testset "$(T)" begin
+            @testset "+" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test scatter!(+, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])
+
+                    mutated = false
+                    # @test scatter(+, srcs[(dims, mutated)], idx) == T(res[(+, dims, mutated)])
+                end
+            end
+
+            @testset "-" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test scatter!(-, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])
+
+                    mutated = false
+                    # @test scatter(-, srcs[(dims, mutated)], idx) == T(res[(-, dims, mutated)])
+                end
+            end
+
+            @testset "max" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test scatter!(max, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])
+
+                    mutated = false
+                    # @test scatter(max, srcs[(dims, mutated)], idx) == T(res[(max, dims, mutated)])
+                end
+            end
+
+            @testset "min" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test scatter!(min, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])
+
+                    mutated = false
+                    # @test scatter(min, srcs[(dims, mutated)], idx) == T(res[(min, dims, mutated)])
+                end
+            end
+        end
+    end
+
+
+    for T = [CuArray{Float32}, CuArray{Float64}]
+        @testset "$(T)" begin
+            @testset "*" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test scatter!(*, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])
+
+                    mutated = false
+                    # @test scatter(*, srcs[(dims, mutated)], idx) == T(res[(*, dims, mutated)])
+                end
+            end
+
+            @testset "/" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test scatter!(/, T(dsts[dims]), T(srcs[(dims, mutated)].*2), idx) == T(res[(/, dims, mutated)])
+
+                    mutated = false
+                    # @test scatter(/, srcs[(dims, mutated)], idx) == T(res[(/, dims, mutated)])
+                end
+            end
+
+            @testset "mean" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test scatter!(mean, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])
+
+                    mutated = false
+                    # @test scatter(mean, srcs[(dims, mutated)], idx) == T(res[(mean, dims, mutated)])
+                end
+            end
+        end
+    end
+
+    @testset "empty index" begin
+        # An empty index must leave `dst` untouched instead of throwing.
+        for dims = [0, 1]
+            dst = CuArray{Float32}(dsts[dims])
+            src = CUDA.zeros(Float32, size(dst)[1:dims]..., 0)
+            idx = CUDA.zeros(Int, 0)
+            @test scatter!(+, copy(dst), src, idx, Val(dims)) == dst
+        end
+    end
+end