From 4f67a9b97e9bf8caaf57e293c054a7f64201dd62 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Jun 2021 09:41:29 +0200 Subject: [PATCH 1/6] use ArrayInterface.restructure in update! --- Manifest.toml | 150 +++++++++++++++++++++++++-------------- Project.toml | 5 +- src/Flux.jl | 1 - src/optimise/Optimise.jl | 1 + src/optimise/train.jl | 7 +- test/optimise.jl | 37 +++++++--- 6 files changed, 133 insertions(+), 68 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 455ea5bddf..422daca7d6 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -13,13 +13,19 @@ version = "0.3.4" [[Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db" +git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.3.0" +version = "3.3.1" [[ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +[[ArrayInterface]] +deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] +git-tree-sha1 = "045ff5e1bc8c6fb1ecb28694abba0a0d55b5f4f5" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "3.1.17" + [[Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -38,22 +44,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.1" [[CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "TimerOutputs"] -git-tree-sha1 = "a6ce96dcf22fc4f1bfdfac02d54f0b77ecf2a4cc" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.0.3" +version = "3.2.1" [[ChainRules]] -deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "1f410fba5c04d03ab712f348f1542e6059376547" +deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "720fa9a9ce61ff18842a40f501d6a1f8ba771c64" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.61" +version = "0.8.6" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "42e3c181483fbd2c416087a0a93838803e358358" +git-tree-sha1 = "8b31cc69cbc38c5c826aaa1c890c694be3622d99" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.38" +version = "0.10.3" [[CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -63,15 +69,15 @@ version = "0.7.0" [[ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "32a2b8af383f11cbb65803883837a149d10dfe8a" +git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.10.12" +version = "0.11.0" [[Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "82f4e6ff9f847eca3e5ebc666ea2cd7b48e8b47e" +git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.7" +version = "0.12.8" [[CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -81,9 +87,9 @@ version = "0.3.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956" +git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.27.0" +version = "3.30.0" [[CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] @@ -124,6 +130,12 @@ version = "1.0.2" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.5" + [[Downloads]] deps = ["ArgTools", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" @@ -158,16 +170,16 @@ uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" version = "0.2.1" [[GPUArrays]] -deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] -git-tree-sha1 = "9c95b2fd5c16bc7f97371e9f92f0fef77e0f5957" +deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] +git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "6.2.2" +version = "6.4.1" [[GPUCompiler]] deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "6eadd2321dc3ac0fc9d530ab01c2caa7fe5d74c6" +git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.11.4" +version = "0.11.5" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] @@ -175,6 +187,11 @@ git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" version = "0.4.2" +[[IfElse]] +git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.0" + [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -193,9 +210,9 @@ version = "0.8.4" [[LLVM]] deps = ["CEnum", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194" +git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "3.6.0" +version = "3.7.1" [[LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -224,6 +241,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[LogExpFunctions]] +deps = ["DocStringExtensions", "LinearAlgebra"] +git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.2.4" + [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -255,9 +278,9 @@ version = "0.4.4" [[Missings]] deps = ["DataAPI"] -git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c" +git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.5" +version = "1.0.0" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -267,15 +290,15 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" [[NNlib]] deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "80b8360670f445d88b3475e88b33bbcc92f7866e" +git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.19" +version = "0.7.21" [[NNlibCUDA]] deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] -git-tree-sha1 = "4b368b466bcdd25d448a5b20de4b7e481d68b88e" +git-tree-sha1 = "86d4d75e1091fe89d56422044be0dcc83766d2b4" uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" -version = "0.1.0" +version = "0.1.2" [[NaNMath]] git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" @@ -287,14 +310,14 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" [[OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.3+4" +version = "0.5.5+0" [[OrderedCollections]] -git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf" +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.0" +version = "1.4.1" [[Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -302,9 +325,9 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[Preferences]] deps = ["TOML"] -git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902" +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.2.1" +version = "1.2.2" [[Printf]] deps = ["Unicode"] @@ -322,6 +345,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[Random123]] +deps = ["Libdl", "Random", "RandomNumbers"] +git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.3.1" + [[RandomNumbers]] deps = ["Random", "Requires"] git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f" @@ -329,9 +358,9 @@ uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" version = "1.4.0" [[Reexport]] -git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5" +git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220" uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.0.0" +version = "1.1.0" [[Requires]] deps = ["UUIDs"] @@ -344,9 +373,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" [[Scratch]] deps = ["Dates"] -git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6" +git-tree-sha1 = "0b4b7f1393cff97c33891da2a0bf69c6ed241fda" uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.0.3" +version = "1.1.0" [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -359,36 +388,47 @@ uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[SortingAlgorithms]] -deps = ["DataStructures", "Random", "Test"] -git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +deps = ["DataStructures"] +git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "0.3.1" +version = "1.0.0" [[SparseArrays]] deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] -deps = ["ChainRulesCore", "OpenSpecFun_jll"] -git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902" +deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] +git-tree-sha1 = "a50550fa3164a8c46747e62063b4d774ac1bcf49" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.3.0" +version = "1.5.1" + +[[Static]] +deps = ["IfElse"] +git-tree-sha1 = "2740ea27b66a41f9d213561a04573da5d3823d4b" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.2.5" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "e8cd1b100d37f5b4cfd2c83f45becf61c762eaf7" +git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.1.1" +version = "1.2.2" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[[StatsAPI]] +git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.0.0" + [[StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "4bc58880426274277a066de306ef19ecc22a6863" +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.5" +version = "0.33.8" [[TOML]] deps = ["Dates"] @@ -403,10 +443,10 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[TimerOutputs]] -deps = ["Printf"] -git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236" +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.8" +version = "0.5.9" [[TranscodingStreams]] deps = ["Random", "Test"] @@ -433,9 +473,9 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a" [[Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "927209c83efa62256788a9880c191774c07c5b51" +git-tree-sha1 = "b1d95edd4e693066c38c13a10aab0a8f6a6e2f65" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.10" +version = "0.6.12" [[ZygoteRules]] deps = ["MacroTools"] diff --git a/Project.toml b/Project.toml index b45ba0dedd..8e8e17aef5 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.4" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" @@ -29,6 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractTrees = "0.3" Adapt = "3.0" +ArrayInterface = "3.1" CUDA = "3" CodecZlib = "0.7" Colors = "0.12" @@ -45,9 +47,10 @@ julia = "1.6" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays"] diff --git a/src/Flux.jl b/src/Flux.jl index 5e6776d601..0689beb278 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,7 +8,6 @@ using Zygote, MacroTools, Juno, Reexport using MacroTools: @forward @reexport using NNlib using Zygote: Params, @adjoint, gradient, pullback, @nograd - export gradient export Chain, Dense, Maxout, SkipConnection, Parallel, flatten, diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index e2485a05d0..010cbfc9bb 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,6 +1,7 @@ module Optimise using LinearAlgebra +import ArrayInterface export train!, update!, Descent, ADAM, Momentum, Nesterov, RMSProp, diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 79acae1778..a224bf3389 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -18,14 +18,17 @@ Perform an update step of the parameters `ps` (or the single parameter `p`) according to optimizer `opt` and the gradients `gs` (the gradient `g`). As a result, the parameters are mutated and the optimizer's internal state may change. +The gradient could be mutated as well. """ function update!(opt, x, x̄) - x .-= apply!(opt, x, x̄) + x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's + # output are not mutable, see #1510 + x .-= apply!(opt, x, x̄r) end function update!(opt, xs::Params, gs) for x in xs - gs[x] == nothing && continue + isnothing(gs[x]) && continue update!(opt, x, gs[x]) end end diff --git a/test/optimise.jl b/test/optimise.jl index 04cbf6f6c0..fa593df25d 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -1,6 +1,7 @@ using Flux.Optimise using Flux.Optimise: runall using Flux: Params, gradient +import FillArrays using Test using Random @@ -131,13 +132,31 @@ end end @testset "Clipping" begin - w = randn(10, 10) - loss(x) = sum(w * x) - θ = Params([w]) - x = 1000 * randn(10) - w̄ = gradient(() -> loss(x), θ)[w] - w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) - @test all(w̄_value .<= 1) - w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) - @test norm(w̄_norm) <= 1 + w = randn(10, 10) + loss(x) = sum(w * x) + θ = Params([w]) + x = 1000 * randn(10) + w̄ = gradient(() -> loss(x), θ)[w] + w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) + @test all(w̄_value .<= 1) + w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) + @test norm(w̄_norm) <= 1 +end + +@testset "handle Fills from Zygote" begin + w = randn(10,10) + wold = copy(w) + g = FillArrays.Ones(size(w)) + opt = Descent(0.1) + Flux.update!(opt, w, g) + @test w ≈ wold .- 0.1 + + ## Issue #1550 + w = randn(10,10) + wold = copy(w) + θ = Flux.params([w]) + gs = gradient(() -> sum(w), θ) + opt = Descent(0.1) + Flux.update!(opt, w, g) + @test w ≈ wold .- 0.1 end From f96fe9377a422234b8b3cb0dd47dde6f561937ab Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Jun 2021 10:03:46 +0200 Subject: [PATCH 2/6] more tests --- test/optimise.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/optimise.jl b/test/optimise.jl index fa593df25d..76ebc4c014 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -151,12 +151,21 @@ end Flux.update!(opt, w, g) @test w ≈ wold .- 0.1 + w = randn(3) + wold = copy(w) + θ = Flux.params([w]) + gs = gradient(() -> w[1], θ) + opt = Descent(0.1) + Flux.update!(opt, θ, gs) + @test w[1] ≈ wold[1] .- 0.1 + @test w[2:3] ≈ wold[2:3] + ## Issue #1550 w = randn(10,10) wold = copy(w) θ = Flux.params([w]) gs = gradient(() -> sum(w), θ) opt = Descent(0.1) - Flux.update!(opt, w, g) + Flux.update!(opt, θ, gs) @test w ≈ wold .- 0.1 end From 28bd53e20fa3f77ec6f314a1460220c547d45631 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Jun 2021 10:06:42 +0200 Subject: [PATCH 3/6] fix typo --- test/optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimise.jl b/test/optimise.jl index 76ebc4c014..cec267f3d1 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -160,7 +160,7 @@ end @test w[1] ≈ wold[1] .- 0.1 @test w[2:3] ≈ wold[2:3] - ## Issue #1550 + ## Issue #1510 w = randn(10,10) wold = copy(w) θ = Flux.params([w]) From 6bae1e7f226323a71407e89a6065dac5a858c844 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Jun 2021 19:17:37 +0200 Subject: [PATCH 4/6] add test for ComponentArrays --- Manifest.toml | 6 ++++++ Project.toml | 3 ++- test/optimise.jl | 17 +++++++++++++++-- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 422daca7d6..d9856e3dbf 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -95,6 +95,12 @@ version = "3.30.0" deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +[[ComponentArrays]] +deps = ["ArrayInterface", "LinearAlgebra", "Requires"] +git-tree-sha1 = "76495e7a7e47abc3771d70c782d5f6e66f114d36" +uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +version = "0.10.5" + [[DataAPI]] git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" diff --git a/Project.toml b/Project.toml index 8e8e17aef5..d4bcbca67a 100644 --- a/Project.toml +++ b/Project.toml @@ -46,6 +46,7 @@ Zygote = "0.6" julia = "1.6" [extras] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" @@ -53,4 +54,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] diff --git a/test/optimise.jl b/test/optimise.jl index cec267f3d1..4de8b4ac2f 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -1,7 +1,7 @@ using Flux.Optimise using Flux.Optimise: runall using Flux: Params, gradient -import FillArrays +import FillArrays, ComponentArrays using Test using Random @@ -143,7 +143,7 @@ end @test norm(w̄_norm) <= 1 end -@testset "handle Fills from Zygote" begin +@testset "update!: handle Fills from Zygote" begin w = randn(10,10) wold = copy(w) g = FillArrays.Ones(size(w)) @@ -169,3 +169,16 @@ end Flux.update!(opt, θ, gs) @test w ≈ wold .- 0.1 end + +@testset "update!: handle ComponentArrays" begin + w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2])) + wold = deepcopy(w) + θ = Flux.params([w]) + gs = gradient(() -> sum(w.a) + sum(w.c.b), θ) + opt = Descent(0.1) + Flux.update!(opt, θ, gs) + @test w.a ≈ wold.a .- 0.1 + @test w.b ≈ wold.b + @test w.c.b ≈ wold.c.b .- 0.1 + @test w.c.a ≈ wold.c.a +end From 1fe64dcef2b1993f7fd9c64aeaa910b73c5008f8 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Jun 2021 19:24:39 +0200 Subject: [PATCH 5/6] add another ComponentArray test --- Project.toml | 1 + test/optimise.jl | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/Project.toml b/Project.toml index d4bcbca67a..9c87d426e3 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" diff --git a/test/optimise.jl b/test/optimise.jl index 4de8b4ac2f..a6c82a99a8 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -181,4 +181,12 @@ end @test w.b ≈ wold.b @test w.c.b ≈ wold.c.b .- 0.1 @test w.c.a ≈ wold.c.a + + w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2])) + wold = deepcopy(w) + θ = Flux.params([w]) + gs = gradient(() -> sum(w), θ) + opt = Descent(0.1) + Flux.update!(opt, θ, gs) + @test w ≈ wold .- 0.1 end From a77b32f8dabf84a6fb1f290832840f61a1317098 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Jun 2021 19:27:11 +0200 Subject: [PATCH 6/6] cleanup project --- Manifest.toml | 10 ++-------- Project.toml | 1 - 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index d9856e3dbf..d163161570 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -95,12 +95,6 @@ version = "3.30.0" deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -[[ComponentArrays]] -deps = ["ArrayInterface", "LinearAlgebra", "Requires"] -git-tree-sha1 = "76495e7a7e47abc3771d70c782d5f6e66f114d36" -uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -version = "0.10.5" - [[DataAPI]] git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" @@ -302,9 +296,9 @@ version = "0.7.21" [[NNlibCUDA]] deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] -git-tree-sha1 = "86d4d75e1091fe89d56422044be0dcc83766d2b4" +git-tree-sha1 = "bd8b29bf75be7a6c2b288b4b9a4e8903d0376ac1" uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" -version = "0.1.2" +version = "0.1.3" [[NaNMath]] git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" diff --git a/Project.toml b/Project.toml index 9c87d426e3..d4bcbca67a 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"