From 5b53a863fe58dc85e1d1862e1128a25b4cde1c99 Mon Sep 17 00:00:00 2001
From: Yueh-Hua Tu
Date: Wed, 9 Jun 2021 18:31:57 +0800
Subject: [PATCH] support scatter for cuda gradient
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

add count_indices for cuarray

add CUDA kernel for divide_by_counts!

add NNlib.∇scatter_src for cuda gradient

support scatter mean AD for CUDA

support scatter *,/ AD for CUDA
---
Note: line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Hunk line counts match the
`@@` headers and the diffstat, but leading whitespace on context and
removed lines is a best guess, so `git apply --ignore-whitespace` may
be required. Relative to the original patch, `@atomic` in src/utils.jl
is qualified as `CUDA.@atomic` (it clashes with `Base.@atomic` on
Julia >= 1.7) and a comment typo in src/scatter.jl is fixed; the
`index` blob hashes still describe the original contents.

 Manifest.toml    | 64 ++++++++++++++++--------------------------
 src/NNlibCUDA.jl |  1 +
 src/scatter.jl   | 72 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/utils.jl     | 56 +++++++++++++++++++++++++++++++++++++
 test/scatter.jl  | 72 ++++++++++--------------------------------------
 5 files changed, 166 insertions(+), 99 deletions(-)
 create mode 100644 src/utils.jl

diff --git a/Manifest.toml b/Manifest.toml
index 22b48b9..d79d06a 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -33,22 +33,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 version = "0.4.1"
 
 [[CUDA]]
-deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
-git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad"
+deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
+git-tree-sha1 = "82b2811f5888465d96b38c7bb12d8fb9c25838e1"
 uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
-version = "3.2.1"
+version = "3.3.1"
 
 [[ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]
-git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82"
+git-tree-sha1 = "be770c08881f7bb928dfd86d1ba83798f76cf62a"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.10.2"
+version = "0.10.9"
 
 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
+git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "3.30.0"
+version = "3.31.0"
 
 [[CompilerSupportLibraries_jll]]
 deps = ["Artifacts", "Libdl"]
@@ -73,10 +73,10 @@ deps = ["Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[DocStringExtensions]]
-deps = ["LibGit2", "Markdown", "Pkg", "Test"]
-git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
+deps = ["LibGit2"]
+git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
 uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
-version = "0.8.4"
+version = "0.8.5"
 
 [[Downloads]]
 deps = ["ArgTools", "LibCURL", "NetworkOptions"]
@@ -89,15 +89,15 @@ version = "0.1.3"
 
 [[GPUArrays]]
 deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
-git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086"
+git-tree-sha1 = "ececbf05f8904c92814bdbd0aafd5540b0bf2e9a"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "6.4.1"
+version = "7.0.1"
 
 [[GPUCompiler]]
-deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a"
+deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
+git-tree-sha1 = "222c6cdb888ec24795936d6829aa978691def60e"
 uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.11.5"
+version = "0.12.3"
 
 [[InteractiveUtils]]
 deps = ["Markdown"]
@@ -111,9 +111,9 @@ version = "1.3.0"
 
 [[LLVM]]
 deps = ["CEnum", "Libdl", "Printf", "Unicode"]
-git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
+git-tree-sha1 = "f57ac3fd2045b50d3db081663837ac5b4096947e"
 uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "3.7.1"
+version = "3.9.0"
 
 [[LazyArtifacts]]
 deps = ["Artifacts", "Pkg"]
@@ -151,12 +151,6 @@ version = "0.2.4"
 [[Logging]]
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
 
-[[MacroTools]]
-deps = ["Markdown", "Random"]
-git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
-uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.5.6"
-
 [[Markdown]]
 deps = ["Base64"]
 uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -165,12 +159,6 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
 deps = ["Artifacts", "Libdl"]
 uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
 
-[[Memoize]]
-deps = ["MacroTools"]
-git-tree-sha1 = "2b1dfcba103de714d31c033b5dacc2e4a12c7caa"
-uuid = "c03570c3-d221-55d1-a50c-7939bbd78826"
-version = "0.4.4"
-
 [[Mmap]]
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
@@ -179,9 +167,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
 
 [[NNlib]]
 deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
-git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
+git-tree-sha1 = "7461639cef384a2ad058005b49e32b318d844343"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.21"
+version = "0.7.22"
 
 [[NetworkOptions]]
 uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
@@ -221,9 +209,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [[Random123]]
 deps = ["Libdl", "Random", "RandomNumbers"]
-git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
+git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3"
 uuid = "74087812-796a-5b5d-8853-05524746bad3"
-version = "1.3.1"
+version = "1.4.2"
 
 [[RandomNumbers]]
 deps = ["Random", "Requires"]
@@ -245,12 +233,6 @@ version = "1.1.3"
 [[SHA]]
 uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
 
-[[Scratch]]
-deps = ["Dates"]
-git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
-uuid = "6c6a2e73-6563-6170-7368-637461726353"
-version = "1.0.3"
-
 [[Serialization]]
 uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
 
@@ -289,9 +271,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [[TimerOutputs]]
 deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
+git-tree-sha1 = "9f494bc54b4c31404a9eff449235836615929de1"
 uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.9"
+version = "0.5.10"
 
 [[UUIDs]]
 deps = ["Random", "SHA"]
diff --git a/src/NNlibCUDA.jl b/src/NNlibCUDA.jl
index f98c08f..0fc5626 100644
--- a/src/NNlibCUDA.jl
+++ b/src/NNlibCUDA.jl
@@ -11,6 +11,7 @@ include("activations.jl")
 include("batchedmul.jl")
 include("scatter.jl")
 include("gather.jl")
+include("utils.jl")
 include("cudnn/cudnn.jl")
 include("cudnn/conv.jl")
 include("cudnn/pooling.jl")
diff --git a/src/scatter.jl b/src/scatter.jl
index d08fc32..b6569aa 100644
--- a/src/scatter.jl
+++ b/src/scatter.jl
@@ -46,3 +46,75 @@ function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx:
     dst .+= NNlib.safe_div.(dst_, Ns)
     return dst
 end
+
+
+## Gradients
+
+function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        cart_j = CartesianIndices(idx)[index]
+        # get the indices that are aggregated together with this one (its own index included)
+        inds = rev_idx[idx[cart_j]...]
+        # multiply all values to be aggregated but not itself
+        x = one(T)
+        for k in inds
+            x *= src[k]
+        end
+        x /= src[cart_j]
+        # apply `op` on `Δsrc[i, k]` and `x`
+        Δsrc[cart_j] = op(Δsrc[cart_j], x)
+    end
+    return nothing
+end
+
+function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, ax, max_dims_idx, max_idx, T)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        i, j = fldmod1(index, max_dims_idx)
+        cart_j = CartesianIndices(idx)[j]
+        cart_i = ax[i]
+        l = Tuple(cart_i)..., Tuple(cart_j)...
+        # get the indices that are aggregated together with this one (its own index included)
+        inds = rev_idx[idx[cart_j]...]
+        # multiply all values to be aggregated but not itself
+        x = one(T)
+        for k in inds
+            x *= src[Tuple(cart_i)..., k]
+        end
+        x /= src[l]
+        # apply `op` on `Δsrc[i, k]` and `x`
+        Δsrc[l] = op(Δsrc[l], x)
+    end
+    return nothing
+end
+
+function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
+                            src::AnyCuArray{Tsrc,Nsrc},
+                            idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
+    dims = Nsrc - Nidx
+    Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
+    rev_idx = NNlib.reverse_indices(idx)
+    rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))
+
+    if dims == 0
+        max_idx = length(idx)
+        args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc
+    else
+        ax = CartesianIndices(axes(src)[1:dims])
+        max_dims_idx = length(ax)
+        max_idx = max_dims_idx * length(idx)
+        args = op, Δsrc, src, idx, rev_idx, ax, max_dims_idx, max_idx, Tsrc
+    end
+
+    kernel = @cuda launch=false ∇scatter_src_kernel!(args...)
+    config = launch_configuration(kernel.fun; max_threads=256)
+    threads = min(max_idx, config.threads)
+    blocks = cld(max_idx, threads)
+    kernel(args...; threads=threads, blocks=blocks)
+
+    CUDA.unsafe_free!(rev_idx)
+    return Δsrc
+end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..bb0c8b2
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,56 @@
+function NNlib.count_indices(idx::AnyCuArray)
+    dst_counts = length.(NNlib.reverse_indices(idx))
+    src_counts = NNlib.gather(cu(dst_counts), idx)
+    return src_counts
+end
+
+function divide_kernel!(xs, ys, max_idx)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        xs[index] = xs[index] / ys[index]
+    end
+    return nothing
+end
+
+function divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        j, k = divrem(index-1, max_dims_idx)
+        dims_i = Tuple(CartesianIndices(dims_size)[k+1])
+        CUDA.@atomic xs[dims_i..., j+1] = xs[dims_i..., j+1] / counts[j+1]
+    end
+    return nothing
+end
+
+function NNlib.divide_by_counts!(xs::AnyCuArray{T}, idx::AnyCuArray, dims) where {T}
+    counts = CuArray{T}(NNlib.count_indices(idx))
+    args = if dims == 0
+        max_idx = length(idx)
+        xs, counts, max_idx
+    else
+        dims_size = size(xs)[1:dims]
+        max_dims_idx = prod(dims_size)
+        max_idx = prod(size(xs))
+        xs, counts, max_idx, max_dims_idx, dims_size
+    end
+
+    kernel = @cuda launch=false divide_kernel!(args...)
+    config = launch_configuration(kernel.fun; max_threads=256)
+    threads = min(max_idx, config.threads)
+    blocks = cld(max_idx, threads)
+    kernel(args...; threads=threads, blocks=blocks)
+    return xs
+end
+
+function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
+    max_dims = maximum_dims(idx)
+    T = CartesianIndex{N}
+    rev = Array{Vector{T}}(undef, max_dims...)
+    for i in eachindex(rev)
+        rev[i] = T[]
+    end
+    NNlib.reverse_indices!(rev, idx)
+    return map(cu, rev)
+end
diff --git a/test/scatter.jl b/test/scatter.jl
index 16b6e43..088284e 100644
--- a/test/scatter.jl
+++ b/test/scatter.jl
@@ -17,50 +17,6 @@ idxs = [
         (4,) (2,) (1,) (3,);
         (3,) (5,) (5,) (3,)]),  # tuple index
 ]
-res = Dict(
-    (+, 0, true) => cu([5, 6, 9, 8, 9]),
-    (+, 1, true) => cu([5 5 8 6 7;
-                        7 7 10 8 9]),
-    (+, 0, false) => cu([4, 4, 12, 5, 5]),
-    (+, 1, false) => cu([4 4 12 5 5;
-                         8 8 24 10 10]),
-    (-, 0, true) => cu([1, 2, 1, 4, 5]),
-    (-, 1, true) => cu([1 1 0 2 3;
-                        3 3 2 4 5]),
-    (-, 0, false) => cu([-4, -4, -12, -5, -5]),
-    (-, 1, false) => cu([-4 -4 -12 -5 -5;
-                         -8 -8 -24 -10 -10]),
-    (max, 0, true) => cu([3, 4, 5, 6, 7]),
-    (max, 1, true) => cu([3 3 4 4 5;
-                          5 5 6 6 7]),
-    (max, 0, false) => cu([3, 2, 4, 4, 3]),
-    (max, 1, false) => cu([3 2 4 4 3;
-                           6 4 8 8 6]),
-    (min, 0, true) => cu([1, 1, 1, 1, 1]),
-    (min, 1, true) => cu([1 1 1 1 1;
-                          1 1 1 1 1]),
-    (min, 0, false) => cu([1, 2, 1, 1, 2]),
-    (min, 1, false) => cu([1 2 1 1 2;
-                           2 4 2 2 4]),
-    (*, 0, true) => cu([3, 4, 5, 6, 7]),
-    (*, 1, true) => cu([3 3 4 4 5;
-                        5 5 6 6 7]),
-    (*, 0, false) => cu([3, 4, 48, 4, 6]),
-    (*, 1, false) => cu([3 4 48 4 6;
-                         12 16 768 16 24]),
-    (/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
-    (/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
-                        1.25 1.25 0.375 1.5 1.75]),
-    (/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
-    (/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
-                         1//12 1//16 1//768 1//16 1//24]),
-    (mean, 0, true) => cu([4., 5., 6., 7., 8.]),
-    (mean, 1, true) => cu([4. 4. 5. 5. 6.;
-                           6. 6. 7. 7. 8.]),
-    (mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
-    (mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
-                            4. 4. 6. 5. 5.]),
-)
 
 types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
 
@@ -71,40 +27,40 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
             @testset "+" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(+, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(+, T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])
+                    gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "-" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(-, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(-, T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])
+                    gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "max" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(max, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(max, T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])
+                    gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "min" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(min, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(min, T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])
+                    gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
         end
@@ -116,30 +72,30 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
             @testset "*" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(*, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(*, T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])
+                    gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "/" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(/, T(copy(dsts[dims])), T(srcs[(dims, mutated)].*2), idx) == T(res[(/, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(/, T(srcs[(dims, mutated)]), idx) == T(res[(/, dims, mutated)])
+                    gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "mean" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(mean, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(mean, T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])
+                    gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
         end