Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support scatter for CUDA gradient #13

Merged
merged 1 commit into from
Jul 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 23 additions & 41 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad"
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "82b2811f5888465d96b38c7bb12d8fb9c25838e1"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.2.1"
version = "3.3.1"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82"
git-tree-sha1 = "be770c08881f7bb928dfd86d1ba83798f76cf62a"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.2"
version = "0.10.9"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.30.0"
version = "3.31.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
Expand All @@ -73,10 +73,10 @@ deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
deps = ["LibGit2"]
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.4"
version = "0.8.5"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
Expand All @@ -89,15 +89,15 @@ version = "0.1.3"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086"
git-tree-sha1 = "ececbf05f8904c92814bdbd0aafd5540b0bf2e9a"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "6.4.1"
version = "7.0.1"

[[GPUCompiler]]
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a"
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "222c6cdb888ec24795936d6829aa978691def60e"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.11.5"
version = "0.12.3"

[[InteractiveUtils]]
deps = ["Markdown"]
Expand All @@ -111,9 +111,9 @@ version = "1.3.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
git-tree-sha1 = "f57ac3fd2045b50d3db081663837ac5b4096947e"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "3.7.1"
version = "3.9.0"

[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
Expand Down Expand Up @@ -151,12 +151,6 @@ version = "0.2.4"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.6"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand All @@ -165,12 +159,6 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Memoize]]
deps = ["MacroTools"]
git-tree-sha1 = "2b1dfcba103de714d31c033b5dacc2e4a12c7caa"
uuid = "c03570c3-d221-55d1-a50c-7939bbd78826"
version = "0.4.4"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

Expand All @@ -179,9 +167,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NNlib]]
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
git-tree-sha1 = "7e6f31cfa39b1ff1c541cc8580b14b0ff4ba22d0"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.21"
version = "0.7.23"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
Expand Down Expand Up @@ -221,9 +209,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Random123]]
deps = ["Libdl", "Random", "RandomNumbers"]
git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3"
uuid = "74087812-796a-5b5d-8853-05524746bad3"
version = "1.3.1"
version = "1.4.2"

[[RandomNumbers]]
deps = ["Random", "Requires"]
Expand All @@ -245,12 +233,6 @@ version = "1.1.3"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Scratch]]
deps = ["Dates"]
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
uuid = "6c6a2e73-6563-6170-7368-637461726353"
version = "1.0.3"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

Expand Down Expand Up @@ -289,9 +271,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
git-tree-sha1 = "9f494bc54b4c31404a9eff449235836615929de1"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.9"
version = "0.5.10"

[[UUIDs]]
deps = ["Random", "SHA"]
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
CUDA = "3.3.1"
NNlib = "0.7.21"
NNlib = "0.7.23"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/NNlibCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include("activations.jl")
include("batchedmul.jl")
include("scatter.jl")
include("gather.jl")
include("utils.jl")
include("cudnn/cudnn.jl")
include("cudnn/conv.jl")
include("cudnn/pooling.jl")
Expand Down
72 changes: 72 additions & 0 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,75 @@ function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx:
dst .+= NNlib.safe_div.(dst_, Ns)
return dst
end


## Gradients

function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if index <= max_idx
cart_j = CartesianIndices(idx)[index]
# get aggregating indeices, which is to be aggregated together, and itself index
inds = rev_idx[idx[cart_j]...]
# multiply all values to be aggregated but not itself
x = one(T)
for k in inds
x *= src[k]
end
x /= src[cart_j]
# apply `op` on `Δsrc[i, k]` and `x`
Δsrc[cart_j] = op(Δsrc[cart_j], x)
end
return nothing
end

function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if index <= max_idx
i, j = fldmod1(index, max_dims_idx)
cart_i = CartesianIndices(idx)[i]
cart_j = pre_cart_idx[j]
# get aggregating indeices, which is to be aggregated together, and itself index
inds = rev_idx[idx[cart_i]...]
# multiply all values to be aggregated but not itself
x = one(T)
for k in inds
jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Base._to_linear_index is introduced here to transform index of any form into integer index. Integer index is required to index a cuarray.

x *= src[jk]
end
x /= src[index]
# apply `op` on `Δsrc[i, k]` and `x`
Δsrc[index] = op(Δsrc[index], x)
end
return nothing
end

function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
src::AnyCuArray{Tsrc,Nsrc},
idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
rev_idx = NNlib.reverse_indices(idx)
rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))

if dims == 0
max_idx = length(idx)
args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc
else
pre_cart_idx = CartesianIndices(axes(src)[1:dims])
max_dims_idx = length(pre_cart_idx)
max_idx = max_dims_idx * length(idx)
args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc
end

kernel = @cuda launch=false ∇scatter_src_kernel!(args...)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(max_idx, config.threads)
blocks = cld(max_idx, threads)
kernel(args...; threads=threads, blocks=blocks)

CUDA.unsafe_free!(rev_idx)
return Δsrc
end
56 changes: 56 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
function NNlib.count_indices(idx::AnyCuArray)
dst_counts = length.(NNlib.reverse_indices(idx))
src_counts = NNlib.gather(cu(dst_counts), idx)
return src_counts
end

function divide_kernel!(xs, ys, max_idx)
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if index <= max_idx
xs[index] = xs[index] / ys[index]
end
return nothing
end

function divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if index <= max_idx
j, k = divrem(index-1, max_dims_idx)
dims_i = Tuple(CartesianIndices(dims_size)[k+1])
@atomic xs[dims_i..., j+1] = xs[dims_i..., j+1] / counts[j+1]
end
return nothing
end

function NNlib.divide_by_counts!(xs::AnyCuArray{T}, idx::AnyCuArray, dims) where {T}
counts = CuArray{T}(NNlib.count_indices(idx))
args = if dims == 0
max_idx = length(idx)
xs, counts, max_idx
else
dims_size = size(xs)[1:dims]
max_dims_idx = prod(dims_size)
max_idx = prod(size(xs))
xs, counts, max_idx, max_dims_idx, dims_size
end

kernel = @cuda launch=false divide_kernel!(args...)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(max_idx, config.threads)
blocks = cld(max_idx, threads)
kernel(args...; threads=threads, blocks=blocks)
return xs
end

function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
max_dims = NNlib.maximum_dims(idx)
T = CartesianIndex{N}
rev = Array{Vector{T}}(undef, max_dims...)
for i in eachindex(rev)
rev[i] = T[]
end
NNlib.reverse_indices!(rev, idx)
return map(cu, rev)
end
Loading