From 5ba4c99f2abf686949529e050e9bab6e8b2078a0 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 23 Mar 2021 12:51:46 +0000 Subject: [PATCH 1/4] add a warning when trying to accum two NametTuples with different keys --- src/lib/lib.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 0d52c876d..b46341d39 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -16,6 +16,8 @@ accum(x::Tuple, y::Tuple) = accum.(x, y) accum(x::AbstractArray, y::AbstractArray) = accum.(x, y) @generated function accum(x::NamedTuple, y::NamedTuple) + # Zygote assumes that the NamedTuples will have the same keys + fieldnames(x) === fieldnames(y) || throw(ArgumentError("$x and $y keys must be the same")) grad(x) = x in fieldnames(y) ? :(y.$x) : :nothing Expr(:tuple, [:($f=accum(x.$f, $(grad(f)))) for f in fieldnames(x)]...) end From 4c7a68cbef106ad7be29a3293e1970490028444a Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 23 Mar 2021 12:53:51 +0000 Subject: [PATCH 2/4] bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 432cee9c1..34af23245 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.4" +version = "0.6.5" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 0e63d38a96721b0d683f3701725f014016f85ef4 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 23 Mar 2021 13:48:09 +0000 Subject: [PATCH 3/4] update implementation --- src/lib/lib.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index b46341d39..dc389e316 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -16,9 +16,10 @@ accum(x::Tuple, y::Tuple) = accum.(x, y) accum(x::AbstractArray, y::AbstractArray) = accum.(x, y) @generated function accum(x::NamedTuple, y::NamedTuple) - # Zygote assumes that the NamedTuples will have the same keys - fieldnames(x) === fieldnames(y) || throw(ArgumentError("$x and $y keys must be the same")) - grad(x) = x in fieldnames(y) ? :(y.$x) : :nothing + # assumes that y has no keys apart from those also in x + fieldnames(y) ⊆ fieldnames(x) || throw(ArgumentError("$y keys must be a subset of $x keys")) + + grad(field) = field in fieldnames(y) ? :(y.$field) : :nothing Expr(:tuple, [:($f=accum(x.$f, $(grad(f)))) for f in fieldnames(x)]...) end From ada3d02fbcc2c85b37cc51a7530835451f557997 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 23 Mar 2021 13:56:53 +0000 Subject: [PATCH 4/4] add test --- test/lib/lib.jl | 8 ++++++++ test/runtests.jl | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 test/lib/lib.jl diff --git a/test/lib/lib.jl b/test/lib/lib.jl new file mode 100644 index 000000000..0886b9969 --- /dev/null +++ b/test/lib/lib.jl @@ -0,0 +1,8 @@ +@testset "lib.jl" begin + @testset "accum" begin + t1 = (a=1, b=2, c=3) + t2 = (a=1, b=2) + @test Zygote.accum(t1, t2) == (a = 2, b = 4, c = 3) + @test_throws ArgumentError Zygote.accum(t2, t1) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 6fbc75341..b6b7aab0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,8 +22,9 @@ end include("utils.jl") end -@testset "lib/number" begin +@testset "lib" begin include("lib/number.jl") + include("lib/lib.jl") end @testset "Features" begin