diff --git a/Manifest.toml b/Manifest.toml index 4db5124..22b48b9 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -8,9 +8,9 @@ version = "1.0.1" [[Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db" +git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.3.0" +version = "3.3.1" [[ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -40,9 +40,9 @@ version = "3.2.1" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" +git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.44" +version = "0.10.2" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -179,18 +179,18 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" [[NNlib]] deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "80b8360670f445d88b3475e88b33bbcc92f7866e" +git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.19" +version = "0.7.21" [[NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" [[OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "b9b8b8ed236998f91143938a760c2112dceeb2b4" +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.4+0" +version = "0.5.5+0" [[OrderedCollections]] git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" @@ -232,9 +232,9 @@ uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" version = "1.4.0" [[Reexport]] -git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5" +git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220" uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.0.0" +version = "1.1.0" [[Requires]] deps = ["UUIDs"] @@ -267,9 +267,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] -git-tree-sha1 = "c467f25b6ec4167ea3a9a4351c66c2e1cba5da33" +git-tree-sha1 = "a50550fa3164a8c46747e62063b4d774ac1bcf49" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.4.1" +version = "1.5.1" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] diff --git a/src/gather.jl b/src/gather.jl index dcfd29b..da06a19 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -51,3 +51,12 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) kernel(args...; threads=threads, blocks=blocks) return dst end + +# Gradient + +function NNlib.rrule(::typeof(NNlib.gather!), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) + y = NNlib.gather!(dst, src, idx) + src_size = size(src) + gather!_pullback(Δ) = (NO_FIELDS, NoTangent(), NNlib.∇gather_src(Δ, src_size, idx), NoTangent()) + y, gather!_pullback +end diff --git a/src/scatter.jl b/src/scatter.jl index 6f81211..e1d4ed1 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -21,7 +21,7 @@ function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size) end function NNlib.scatter!(op, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) - dims = NNlib._check_dims(dst, src, idx) + dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) op, dst, src, idx diff --git a/test/gather.jl b/test/gather.jl index 8dd2f20..423f8c6 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -14,7 +14,7 @@ y = NNlib.gather(src, index) @test y isa CuArray{Float32,2} @test size(y) == size(index) - gputest(src -> NNlib.gather(src, index), src, checkgrad=false) + gputest(src -> NNlib.gather(src, index), src, checkgrad=true) @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index) @@ -30,7 +30,7 @@ y = NNlib.gather(src, index) @test y isa CuArray{Float32,3} @test size(y) == size(index) - gputest(src -> NNlib.gather(src, index), src, checkgrad=false) + gputest(src -> NNlib.gather(src, index), src, checkgrad=true) ## 2d src, 2d index of ints -> 3d output @@ -56,5 +56,5 @@ Nsrc = ndims(src) @test y isa CuArray{Float32,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) - gputest(src -> NNlib.gather(src, index), src, checkgrad=false) + gputest(src -> NNlib.gather(src, index), src, checkgrad=true) end