Skip to content

Commit

Permalink
support gather for cuda gradient
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
yuehhua committed Jun 9, 2021
1 parent e224c4d commit 6749baf
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
24 changes: 12 additions & 12 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ version = "1.0.1"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db"
git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.0"
version = "3.3.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Expand Down Expand Up @@ -40,9 +40,9 @@ version = "3.2.1"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.44"
version = "0.10.2"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand Down Expand Up @@ -179,18 +179,18 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NNlib]]
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "80b8360670f445d88b3475e88b33bbcc92f7866e"
git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.19"
version = "0.7.21"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "b9b8b8ed236998f91143938a760c2112dceeb2b4"
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.4+0"
version = "0.5.5+0"

[[OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
Expand Down Expand Up @@ -232,9 +232,9 @@ uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.4.0"

[[Reexport]]
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.0.0"
version = "1.1.0"

[[Requires]]
deps = ["UUIDs"]
Expand Down Expand Up @@ -267,9 +267,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
git-tree-sha1 = "c467f25b6ec4167ea3a9a4351c66c2e1cba5da33"
git-tree-sha1 = "a50550fa3164a8c46747e62063b4d774ac1bcf49"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.4.1"
version = "1.5.1"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand Down
2 changes: 1 addition & 1 deletion src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
end

function NNlib.scatter!(op, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
dims = NNlib._check_dims(dst, src, idx)
dims = NNlib.scatter_dims(dst, src, idx)
args = if dims == 0
max_idx = length(idx)
op, dst, src, idx
Expand Down
6 changes: 3 additions & 3 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
y = NNlib.gather(src, index)
@test y isa CuArray{Float32,2}
@test size(y) == size(index)
gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)

Expand All @@ -30,7 +30,7 @@
y = NNlib.gather(src, index)
@test y isa CuArray{Float32,3}
@test size(y) == size(index)
gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)


## 2d src, 2d index of ints -> 3d output
Expand All @@ -56,5 +56,5 @@
Nsrc = ndims(src)
@test y isa CuArray{Float32,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
end

0 comments on commit 6749baf

Please sign in to comment.