From d22ab50b5bced55e1a46995ffd020e51c0bcb007 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Fri, 2 Jun 2017 19:26:59 -0400 Subject: [PATCH] add `merge` and `structdiff` for named tuples --- base/namedtuple.jl | 41 +++++++++++++++++++++++++++++++++++++++++ test/namedtuple.jl | 10 ++++++++++ 2 files changed, 51 insertions(+) diff --git a/base/namedtuple.jl b/base/namedtuple.jl index b8a5d417c96f6..cd647845f5a1e 100644 --- a/base/namedtuple.jl +++ b/base/namedtuple.jl @@ -89,3 +89,44 @@ end namedtuple($NT, $(args...)) end end + +# a version of `in` for the older world these generated functions run in +function sym_in(x, itr) + for y in itr + y === x && return true + end + return false +end + +@generated function merge(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} + names = Symbol[an...] + for n in bn + if !sym_in(n, an) + push!(names, n) + end + end + vals = map(names) do n + if sym_in(n, bn) + :(getfield(b, $(Expr(:quote, n)))) + else + :(getfield(a, $(Expr(:quote, n)))) + end + end + names = (names...,) + :(namedtuple(NamedTuple{$names}, $(vals...))) +end + +@generated function structdiff(a::NamedTuple{an}, + b::Union{NamedTuple{bn},Type{NamedTuple{bn}}}) where {an,bn} + names = Symbol[] + for n in an + if !sym_in(n, bn) + push!(names, n) + end + end + vals = map(names) do n + :(getfield(a, $(Expr(:quote, n)))) + end + names = (names...,) + :(namedtuple(NamedTuple{$names}, $(vals...))) +end diff --git a/test/namedtuple.jl b/test/namedtuple.jl index 6738fa3f94f63..b689546367676 100644 --- a/test/namedtuple.jl +++ b/test/namedtuple.jl @@ -51,3 +51,13 @@ @test map(+, (x=1, y=2), (x=10, y=20)) == (x=11, y=22) @test map(string, (x=1, y=2)) == (x="1", y="2") @test map(round, (x=1//3, y=Int), (x=3, y=2//3)) == (x=0.333, y=1) + +@test merge((a=1, b=2), (a=10,)) == (a=10, b=2) +@test merge((a=1, b=2), (a=10, z=20)) == (a=10, b=2, z=20) +@test merge((a=1, b=2), (z=20,)) == (a=1, b=2, z=20) + +@test Base.structdiff((a=1, b=2), (b=3,)) == (a=1,) +@test Base.structdiff((a=1, b=2, z=20), (b=3,)) == (a=1, z=20) +@test Base.structdiff((a=1, b=2, z=20), (b=3, q=20, z=1)) == (a=1,) +@test Base.structdiff((a=1, b=2, z=20), (b=3, q=20, z=1, a=0)) == NamedTuple() +@test Base.structdiff((a=1, b=2, z=20), NamedTuple{(:b,)}) == (a=1, z=20)