From 14b301d74bb36fdf4f18ec465ddcd56557560b03 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Wed, 15 Sep 2021 21:21:22 -0700 Subject: [PATCH 1/9] Cata + Anamorphisms working --- old_src/Functors.jl | 7 ++ old_src/functor.jl | 202 +++++++++++++++++++++++++++++++ test/basics.jl | 288 ++++++++++++++++++++++++++------------------ 3 files changed, 377 insertions(+), 120 deletions(-) create mode 100644 old_src/Functors.jl create mode 100644 old_src/functor.jl diff --git a/old_src/Functors.jl b/old_src/Functors.jl new file mode 100644 index 0000000..54527a3 --- /dev/null +++ b/old_src/Functors.jl @@ -0,0 +1,7 @@ +module Functors + +export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect + +include("functor.jl") + +end # module diff --git a/old_src/functor.jl b/old_src/functor.jl new file mode 100644 index 0000000..b502493 --- /dev/null +++ b/old_src/functor.jl @@ -0,0 +1,202 @@ +functor(T, x) = (), _ -> x +functor(x) = functor(typeof(x), x) + +functor(::Type{<:Tuple}, x) = x, y -> y +functor(::Type{<:NamedTuple}, x) = x, y -> y + +functor(::Type{<:AbstractArray}, x) = x, y -> y +functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x + +@static if VERSION >= v"1.6" + functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner) +end + +function makefunctor(m::Module, T, fs = fieldnames(T)) + yᵢ = 0 + escargs = map(fieldnames(T)) do f + f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f) + end + escfs = [:($f=x.$f) for f in fs] + + @eval m begin + $Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...)) + end +end + +function functorm(T, fs = nothing) + fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)") + fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] + :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) +end + +macro functor(args...) + functorm(args...) +end + +function makeflexiblefunctor(m::Module, T, pfield) + pfield = QuoteNode(pfield) + @eval m begin + function $Functors.functor(::Type{<:$T}, x) + pfields = getproperty(x, $pfield) + function re(y) + all_args = map(fn -> getproperty(fn in pfields ? y : x, fn), fieldnames($T)) + return $T(all_args...) + end + func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields)) + return func, re + end + + end + +end + +function flexiblefunctorm(T, pfield = :params) + pfield isa Symbol || error("@flexiblefunctor T param_field") + pfield = QuoteNode(pfield) + :(makeflexiblefunctor(@__MODULE__, $(esc(T)), $(esc(pfield)))) +end + +macro flexiblefunctor(args...) + flexiblefunctorm(args...) +end + +""" + isleaf(x) + +Return true if `x` has no [`children`](@ref) according to [`functor`](@ref). +""" +isleaf(x) = children(x) === () + +""" + children(x) + +Return the children of `x` as defined by [`functor`](@ref). +Equivalent to `functor(x)[1]`. +""" +children(x) = functor(x)[1] + +function _default_walk(f, x) + func, re = functor(x) + re(map(f, func)) +end + +""" + fmap(f, x; exclude = isleaf, walk = Functors._default_walk) + +A structure and type preserving `map` that works for all [`functor`](@ref)s. + +By default, traverses `x` recursively using [`functor`](@ref) +and transforms every leaf node identified by `exclude` with `f`. + +For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`. +This function walks (maps) over `xs` calling the continuation `f'` to continue traversal. + +# Examples +```jldoctest +julia> struct Foo; x; y; end + +julia> @functor Foo + +julia> struct Bar; x; end + +julia> @functor Bar + +julia> m = Foo(Bar([1,2,3]), (4, 5)); + +julia> fmap(x -> 2x, m) +Foo(Bar([2, 4, 6]), (8, 10)) + +julia> fmap(string, m) +Foo(Bar("[1, 2, 3]"), ("4", "5")) + +julia> fmap(string, m, exclude = v -> v isa Bar) +Foo("Bar([1, 2, 3])", (4, 5)) + +julia> fmap(x -> 2x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x)) +Foo(Bar([1, 2, 3]), (8, 10)) +``` +""" +function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict()) + haskey(cache, x) && return cache[x] + y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x) + cache[x] = y + + return y +end + +""" + fmapstructure(f, x; exclude = isleaf) + +Like [`fmap`](@ref), but doesn't preserve the type of custom structs. Instead, it returns a (potentially nested) `NamedTuple`. + +Useful for when the output must not contain custom structs. + +# Examples +```jldoctest +julia> struct Foo; x; y; end + +julia> @functor Foo + +julia> m = Foo([1,2,3], (4, 5)); + +julia> fmapstructure(x -> 2x, m) +(x = [2, 4, 6], y = (8, 10)) +``` +""" +fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...) + +""" + fcollect(x; exclude = v -> false) + +Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref) +and collecting the results into a flat array. + +Doesn't recurse inside branches rooted at nodes `v` +for which `exclude(v) == true`. +In such cases, the root `v` is also excluded from the result. +By default, `exclude` always yields `false`. + +See also [`children`](@ref). + +# Examples + +```jldoctest +julia> struct Foo; x; y; end + +julia> @functor Foo + +julia> struct Bar; x; end + +julia> @functor Bar + +julia> struct NoChildren; x; y; end + +julia> m = Foo(Bar([1,2,3]), NoChildren(:a, :b)) +Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) + +julia> fcollect(m) +4-element Vector{Any}: + Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) + Bar([1, 2, 3]) + [1, 2, 3] + NoChildren(:a, :b) + +julia> fcollect(m, exclude = v -> v isa Bar) +2-element Vector{Any}: + Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) + NoChildren(:a, :b) + +julia> fcollect(m, exclude = v -> Functors.isleaf(v)) +2-element Vector{Any}: + Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) + Bar([1, 2, 3]) +``` +""" +function fcollect(x; cache = [], exclude = v -> false) + x in cache && return cache + if !exclude(x) + push!(cache, x) + foreach(y -> fcollect(y; cache = cache, exclude = exclude), children(x)) + end + return cache +end diff --git a/test/basics.jl b/test/basics.jl index bc1617b..9720092 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,153 +1,201 @@ using Functors, Test -struct Foo - x - y -end -@functor Foo +abstract type Expr end -struct Bar - x +struct Index <: Expr + arg::Expr + idx::Expr end -@functor Bar +@functor Index -struct Baz - x - y - z +struct Call <: Expr + fn::Expr + args::Vector{Expr} end -@functor Baz (y,) +@functor Call -struct NoChildren - x - y +struct Unary <: Expr + op::String + arg::Expr end +@functor Unary -@static if VERSION >= v"1.6" - @testset "ComposedFunction" begin - f1 = Foo(1.1, 2.2) - f2 = Bar(3.3) - @test Functors.functor(f1 ∘ f2)[1] == (outer = f1, inner = f2) - @test Functors.functor(f1 ∘ f2)[2]((outer = f1, inner = f2)) == f1 ∘ f2 - @test fmap(x -> x + 10, f1 ∘ f2) == Foo(11.1, 12.2) ∘ Bar(13.3) - end +struct Binary <: Expr + lhs::Expr + op::String + rhs::Expr end +@functor Binary -@testset "Nested" begin - model = Bar(Foo(1, [1, 2, 3])) - - model′ = fmap(float, model) - - @test model.x.y == model′.x.y - @test model′.x.y isa Vector{Float64} +struct Paren <: Expr + inner::Expr end +@functor Paren -@testset "Exclude" begin - f(x::AbstractArray) = x - f(x::Char) = 'z' - - x = ['a', 'b', 'c'] - @test fmap(f, x) == ['z', 'z', 'z'] - @test fmap(f, x; exclude = x -> x isa AbstractArray) == x - - x = (['a', 'b', 'c'], ['d', 'e', 'f']) - @test fmap(f, x) == (['z', 'z', 'z'], ['z', 'z', 'z']) - @test fmap(f, x; exclude = x -> x isa AbstractArray) == x +struct Ident <: Expr + name::String end +@functor Ident -@testset "Walk" begin - model = Foo((0, Bar([1, 2, 3])), [4, 5]) - - model′ = fmapstructure(identity, model) - @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) +struct Literal{T} <: Expr + val::T end +@functor Literal -@testset "Property list" begin - model = Baz(1, 2, 3) - model′ = fmap(x -> 2x, model) - - @test (model′.x, model′.y, model′.z) == (1, 4, 3) -end -@testset "fcollect" begin - m1 = [1, 2, 3] - m2 = 1 - m3 = Foo(m1, m2) - m4 = Bar(m3) - @test all(fcollect(m4) .=== [m4, m3, m1, m2]) - @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3, m2]) - @test all(fcollect(m4, exclude = x -> x isa Foo) .=== [m4]) - - m1 = [1, 2, 3] - m2 = Bar(m1) - m0 = NoChildren(:a, :b) - m3 = Foo(m2, m0) - m4 = Bar(m3) - @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) -end +@testset "cata" begin + countnodes(e::Functor{Unary}) = 1 + e.arg + countnodes(e::Functor{Binary}) = 1 + e.lhs + e.rhs + countnodes(e::Functor{Call}) = 1 + e.fn + sum(e.args) + countnodes(e::Functor{Index}) = 1 + e.arg + e.idx + countnodes(e::Functor{Paren}) = 1 + e.inner + countnodes(e::Functor{Literal}) = 1 + countnodes(e::Functor{Ident}) = 1 + countnodes(e::Union{String,Int}) = 0 + countnodes(e) = e -struct FFoo - x - y - p -end -@flexiblefunctor FFoo p + ten, add = Literal(10), Ident("add") + call = Call(add, [ten, ten]) -struct FBar - x - p + @test Functors.cata(countnodes, call) == 4 + + @show Functors.cata(Functors.backing, call) end -@flexiblefunctor FBar p -struct FBaz - x - y - z - p +@testset "ana" begin + function nested(n) + go(m) = m == 0 ? Literal(n) : Functor{Paren}((;inner=(m - 1))) + Functors.ana(go, n) + end + + @test nested(3) == 3 |> Literal |> Paren |> Paren |> Paren end -@flexiblefunctor FBaz p -@testset "Flexible Nested" begin - model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) +# @static if VERSION >= v"1.6" +# @testset "ComposedFunction" begin +# f1 = Foo(1.1, 2.2) +# f2 = Bar(3.3) +# @test Functors.functor(f1 ∘ f2)[1] == (outer = f1, inner = f2) +# @test Functors.functor(f1 ∘ f2)[2]((outer = f1, inner = f2)) == f1 ∘ f2 +# @test fmap(x -> x + 10, f1 ∘ f2) == Foo(11.1, 12.2) ∘ Bar(13.3) +# end +# end - model′ = fmap(float, model) +# @testset "Nested" begin +# model = Bar(Foo(1, [1, 2, 3])) - @test model.x.y == model′.x.y - @test model′.x.y isa Vector{Float64} -end +# model′ = fmap(float, model) -@testset "Flexible Walk" begin - model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) +# @test model.x.y == model′.x.y +# @test model′.x.y isa Vector{Float64} +# end - model′ = fmapstructure(identity, model) - @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) +# @testset "Exclude" begin +# f(x::AbstractArray) = x +# f(x::Char) = 'z' - model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) +# x = ['a', 'b', 'c'] +# @test fmap(f, x) == ['z', 'z', 'z'] +# @test fmap(f, x; exclude = x -> x isa AbstractArray) == x - model2′ = fmapstructure(identity, model2) - @test model2′ == (; x=(0, (; x=[1, 2, 3]))) -end +# x = (['a', 'b', 'c'], ['d', 'e', 'f']) +# @test fmap(f, x) == (['z', 'z', 'z'], ['z', 'z', 'z']) +# @test fmap(f, x; exclude = x -> x isa AbstractArray) == x +# end -@testset "Flexible Property list" begin - model = FBaz(1, 2, 3, (:x, :z)) - model′ = fmap(x -> 2x, model) +# @testset "Walk" begin +# model = Foo((0, Bar([1, 2, 3])), [4, 5]) - @test (model′.x, model′.y, model′.z) == (2, 2, 6) -end +# model′ = fmapstructure(identity, model) +# @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) +# end -@testset "Flexible fcollect" begin - m1 = 1 - m2 = [1, 2, 3] - m3 = FFoo(m1, m2, (:y, )) - m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2]) - @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) - @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) - - m0 = NoChildren(:a, :b) - m1 = [1, 2, 3] - m2 = FBar(m1, ()) - m3 = FFoo(m2, m0, (:x, :y,)) - m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2, m0]) -end +# @testset "Property list" begin +# model = Baz(1, 2, 3) +# model′ = fmap(x -> 2x, model) + +# @test (model′.x, model′.y, model′.z) == (1, 4, 3) +# end + +# @testset "fcollect" begin +# m1 = [1, 2, 3] +# m2 = 1 +# m3 = Foo(m1, m2) +# m4 = Bar(m3) +# @test all(fcollect(m4) .=== [m4, m3, m1, m2]) +# @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3, m2]) +# @test all(fcollect(m4, exclude = x -> x isa Foo) .=== [m4]) + +# m1 = [1, 2, 3] +# m2 = Bar(m1) +# m0 = NoChildren(:a, :b) +# m3 = Foo(m2, m0) +# m4 = Bar(m3) +# @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) +# end + +# struct FFoo +# x +# y +# p +# end +# @flexiblefunctor FFoo p + +# struct FBar +# x +# p +# end +# @flexiblefunctor FBar p + +# struct FBaz +# x +# y +# z +# p +# end +# @flexiblefunctor FBaz p + +# @testset "Flexible Nested" begin +# model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) + +# model′ = fmap(float, model) + +# @test model.x.y == model′.x.y +# @test model′.x.y isa Vector{Float64} +# end + +# @testset "Flexible Walk" begin +# model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) + +# model′ = fmapstructure(identity, model) +# @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) + +# model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) + +# model2′ = fmapstructure(identity, model2) +# @test model2′ == (; x=(0, (; x=[1, 2, 3]))) +# end + +# @testset "Flexible Property list" begin +# model = FBaz(1, 2, 3, (:x, :z)) +# model′ = fmap(x -> 2x, model) + +# @test (model′.x, model′.y, model′.z) == (2, 2, 6) +# end + +# @testset "Flexible fcollect" begin +# m1 = 1 +# m2 = [1, 2, 3] +# m3 = FFoo(m1, m2, (:y, )) +# m4 = FBar(m3, (:x,)) +# @test all(fcollect(m4) .=== [m4, m3, m2]) +# @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) +# @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) + +# m0 = NoChildren(:a, :b) +# m1 = [1, 2, 3] +# m2 = FBar(m1, ()) +# m3 = FFoo(m2, m0, (:x, :y,)) +# m4 = FBar(m3, (:x,)) +# @test all(fcollect(m4) .=== [m4, m3, m2, m0]) +# end From c08f27c9a3ea861b5c6f1b194278fc09b88fdca3 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Wed, 15 Sep 2021 22:13:35 -0700 Subject: [PATCH 2/9] More fmap tests and recursive fmap --- src/Functors.jl | 4 +- src/functor.jl | 219 ++++++++---------------------------------------- test/basics.jl | 57 ++++++++++--- 3 files changed, 81 insertions(+), 199 deletions(-) diff --git a/src/Functors.jl b/src/Functors.jl index 54527a3..50f2a8b 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -1,7 +1,7 @@ module Functors -export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect - include("functor.jl") +export Functor, @functor, fmap + end # module diff --git a/src/functor.jl b/src/functor.jl index b502493..861253b 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,202 +1,53 @@ -functor(T, x) = (), _ -> x -functor(x) = functor(typeof(x), x) - -functor(::Type{<:Tuple}, x) = x, y -> y -functor(::Type{<:NamedTuple}, x) = x, y -> y - -functor(::Type{<:AbstractArray}, x) = x, y -> y -functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x - -@static if VERSION >= v"1.6" - functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner) +struct Functor{T,FS} + inner::FS end -function makefunctor(m::Module, T, fs = fieldnames(T)) - yᵢ = 0 - escargs = map(fieldnames(T)) do f - f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f) - end - escfs = [:($f=x.$f) for f in fs] +Functor{T}(inner) where T = Functor{T,typeof(inner)}(inner) +backing(x) = x +backing(func::Functor) = getfield(func, :inner) - @eval m begin - $Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...)) - end -end +Base.getproperty(func::Functor, prop::Symbol) = getproperty(backing(func), prop) +Base.getindex(func::Functor, prop) = getindex(backing(func), prop) -function functorm(T, fs = nothing) - fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)") - fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] - :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) -end +fmap(_, x) = x +fmap(f, func::Functor{T}) where T = Functor{T}(map(f, backing(func))) -macro functor(args...) - functorm(args...) -end +project(x) = x +embed(func) = func -function makeflexiblefunctor(m::Module, T, pfield) - pfield = QuoteNode(pfield) - @eval m begin - function $Functors.functor(::Type{<:$T}, x) - pfields = getproperty(x, $pfield) - function re(y) - all_args = map(fn -> getproperty(fn in pfields ? y : x, fn), fieldnames($T)) - return $T(all_args...) - end - func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields)) - return func, re +function makefunctor(m::Module, T, fs=fieldnames(T)) + escfs = [:($f = x.$f) for f in fs] + @eval m begin + $Functors.project(x::$T) = $Functors.Functor{$T}(($(escfs...),)) + # TODO use ConstructionBase? + $Functors.embed(func::$Functors.Functor{$T}) = $T($Functors.backing(func)...) end - - end - end -function flexiblefunctorm(T, pfield = :params) - pfield isa Symbol || error("@flexiblefunctor T param_field") - pfield = QuoteNode(pfield) - :(makeflexiblefunctor(@__MODULE__, $(esc(T)), $(esc(pfield)))) +function functorm(T, fs=nothing) + fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)") + fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] + :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) end -macro flexiblefunctor(args...) - flexiblefunctorm(args...) -end - -""" - isleaf(x) - -Return true if `x` has no [`children`](@ref) according to [`functor`](@ref). -""" -isleaf(x) = children(x) === () - -""" - children(x) - -Return the children of `x` as defined by [`functor`](@ref). -Equivalent to `functor(x)[1]`. -""" -children(x) = functor(x)[1] - -function _default_walk(f, x) - func, re = functor(x) - re(map(f, func)) -end - -""" - fmap(f, x; exclude = isleaf, walk = Functors._default_walk) - -A structure and type preserving `map` that works for all [`functor`](@ref)s. - -By default, traverses `x` recursively using [`functor`](@ref) -and transforms every leaf node identified by `exclude` with `f`. - -For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`. -This function walks (maps) over `xs` calling the continuation `f'` to continue traversal. - -# Examples -```jldoctest -julia> struct Foo; x; y; end - -julia> @functor Foo - -julia> struct Bar; x; end - -julia> @functor Bar - -julia> m = Foo(Bar([1,2,3]), (4, 5)); - -julia> fmap(x -> 2x, m) -Foo(Bar([2, 4, 6]), (8, 10)) - -julia> fmap(string, m) -Foo(Bar("[1, 2, 3]"), ("4", "5")) - -julia> fmap(string, m, exclude = v -> v isa Bar) -Foo("Bar([1, 2, 3])", (4, 5)) - -julia> fmap(x -> 2x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x)) -Foo(Bar([1, 2, 3]), (8, 10)) -``` -""" -function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict()) - haskey(cache, x) && return cache[x] - y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x) - cache[x] = y - - return y +macro functor(args...) + functorm(args...) end -""" - fmapstructure(f, x; exclude = isleaf) - -Like [`fmap`](@ref), but doesn't preserve the type of custom structs. Instead, it returns a (potentially nested) `NamedTuple`. - -Useful for when the output must not contain custom structs. - -# Examples -```jldoctest -julia> struct Foo; x; y; end - -julia> @functor Foo - -julia> m = Foo([1,2,3], (4, 5)); +# recursion schemes +cata(f, x) = f(fmap(y -> cata(f, y), project(x))) +ana(f, x) = embed(fmap(y -> ana(f, y), f(x))) -julia> fmapstructure(x -> 2x, m) -(x = [2, 4, 6], y = (8, 10)) -``` -""" -fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...) +# aliases +const fold = cata +const unfold = ana -""" - fcollect(x; exclude = v -> false) +# convenience functions +rfmap(f, x) = cata(embed ∘ f, x) -Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref) -and collecting the results into a flat array. +# built-ins +fmap(f, xs::T) where T <: Union{Tuple,NamedTuple,AbstractArray} = map(f, xs) -Doesn't recurse inside branches rooted at nodes `v` -for which `exclude(v) == true`. -In such cases, the root `v` is also excluded from the result. -By default, `exclude` always yields `false`. - -See also [`children`](@ref). - -# Examples - -```jldoctest -julia> struct Foo; x; y; end - -julia> @functor Foo - -julia> struct Bar; x; end - -julia> @functor Bar - -julia> struct NoChildren; x; y; end - -julia> m = Foo(Bar([1,2,3]), NoChildren(:a, :b)) -Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) - -julia> fcollect(m) -4-element Vector{Any}: - Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) - Bar([1, 2, 3]) - [1, 2, 3] - NoChildren(:a, :b) - -julia> fcollect(m, exclude = v -> v isa Bar) -2-element Vector{Any}: - Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) - NoChildren(:a, :b) - -julia> fcollect(m, exclude = v -> Functors.isleaf(v)) -2-element Vector{Any}: - Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) - Bar([1, 2, 3]) -``` -""" -function fcollect(x; cache = [], exclude = v -> false) - x in cache && return cache - if !exclude(x) - push!(cache, x) - foreach(y -> fcollect(y; cache = cache, exclude = exclude), children(x)) - end - return cache +@static if VERSION >= v"1.6" + @functor Base.ComposedFunction end diff --git a/test/basics.jl b/test/basics.jl index 9720092..e029835 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -12,6 +12,7 @@ struct Call <: Expr fn::Expr args::Vector{Expr} end +Base.:(==)(c1::Call, c2::Call) = c1.fn == c2.fn && c1.args == c2.args @functor Call struct Unary <: Expr @@ -44,22 +45,52 @@ end @testset "cata" begin - countnodes(e::Functor{Unary}) = 1 + e.arg - countnodes(e::Functor{Binary}) = 1 + e.lhs + e.rhs - countnodes(e::Functor{Call}) = 1 + e.fn + sum(e.args) - countnodes(e::Functor{Index}) = 1 + e.arg + e.idx - countnodes(e::Functor{Paren}) = 1 + e.inner - countnodes(e::Functor{Literal}) = 1 - countnodes(e::Functor{Ident}) = 1 - countnodes(e::Union{String,Int}) = 0 - countnodes(e) = e - ten, add = Literal(10), Ident("add") call = Call(add, [ten, ten]) - @test Functors.cata(countnodes, call) == 4 - - @show Functors.cata(Functors.backing, call) + @testset "converting to POJOs" begin + expected = (fn = (;name="add"), args = [(;val=10), (;val=10)]) + @test Functors.cata(Functors.backing, call) == expected + end + + @testset "roundtrip" begin + @test Functors.cata(Functors.embed, call) == call + end + + @testset "rfmap" begin + f64(x) = x + f64(x::Real) = Float64(x) + + expected = Call(Ident("add"), [Literal(10.), Literal(10.)]) + @test Functors.rfmap(f64, call) == expected + end + + @testset "counting nodes" begin + countnodes(e::Functor{Unary}) = 1 + e.arg + countnodes(e::Functor{Binary}) = 1 + e.lhs + e.rhs + countnodes(e::Functor{Call}) = 1 + e.fn + sum(e.args) + countnodes(e::Functor{Index}) = 1 + e.arg + e.idx + countnodes(e::Functor{Paren}) = 1 + e.inner + countnodes(e::Functor{Literal}) = 1 + countnodes(e::Functor{Ident}) = 1 + countnodes(e::Union{String,Int}) = 0 + countnodes(e) = e + + @test Functors.cata(countnodes, call) == 4 + end + + @testset "pretty-printing" begin + prettyprint(l::Functor{Literal}) = repr(l.val) + prettyprint(i::Functor{Ident}) = i.name + prettyprint(c::Functor{Call}) = "$(c.fn)($(join(c.args, ",")))" + prettyprint(i::Functor{Index}) = "$(i.arg)[$(i.idx)]" + prettyprint(u::Functor{Unary}) = u.op * u.expr + prettyprint(b::Functor{Binary}) = b.lhs * b.op * b.rhs + prettyprint(p::Functor{Paren}) = "[$(p.inner)]" + prettyprint(e) = e + + @test Functors.cata(prettyprint, call) == "add(10,10)" + end end @testset "ana" begin From 7c9dd4fb540abb85652b5736b23df955fd48a46b Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 16 Sep 2021 13:08:15 -0700 Subject: [PATCH 3/9] Param values and convenience constructors --- src/Functors.jl | 2 +- src/functor.jl | 22 +++++++++++++++++----- test/basics.jl | 7 +++---- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/Functors.jl b/src/Functors.jl index 50f2a8b..0b43f00 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -2,6 +2,6 @@ module Functors include("functor.jl") -export Functor, @functor, fmap +export Functor, functor, @functor, fmap end # module diff --git a/src/functor.jl b/src/functor.jl index 861253b..c5fd75d 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -3,8 +3,13 @@ struct Functor{T,FS} end Functor{T}(inner) where T = Functor{T,typeof(inner)}(inner) +functor(::Type{T}, args...) where T = Functor{T}(NamedTuple{fieldnames(T)}(args)) +functor(::Type{T}; kwargs...) where T = Functor{T}(NamedTuple{fieldnames(T)}(values(kwargs))) + backing(x) = x backing(func::Functor) = getfield(func, :inner) +# TODO better name +paramvalues(x) = backing(x) Base.getproperty(func::Functor, prop::Symbol) = getproperty(backing(func), prop) Base.getindex(func::Functor, prop) = getindex(backing(func), prop) @@ -14,19 +19,26 @@ fmap(f, func::Functor{T}) where T = Functor{T}(map(f, backing(func))) project(x) = x embed(func) = func +# TODO use ConstructionBase? +embed(func::Functor{T}) where T = T(backing(func)...) function makefunctor(m::Module, T, fs=fieldnames(T)) - escfs = [:($f = x.$f) for f in fs] + escfields = [:($field = x.$field) for field in fieldnames(T)] + escfs = [:($field = x.$field) for field in fs] + escfmap = map(fieldnames(T)) do field + field in fs ? :($field = f(func.$field)) : :($field = func.$field) + end + @eval m begin - $Functors.project(x::$T) = $Functors.Functor{$T}(($(escfs...),)) - # TODO use ConstructionBase? - $Functors.embed(func::$Functors.Functor{$T}) = $T($Functors.backing(func)...) + $Functors.project(x::$T) = $Functors.Functor{$T}(($(escfields...),)) + $Functors.paramvalues(func::$Functors.Functor{$T}) = ($(escfs...),) + $Functors.fmap(f, func::$Functors.Functor{$T}) = $Functors.Functor{$T}(($(escfmap...),)) end end function functorm(T, fs=nothing) fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)") - fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] +fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) end diff --git a/test/basics.jl b/test/basics.jl index e029835..e757561 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -19,14 +19,14 @@ struct Unary <: Expr op::String arg::Expr end -@functor Unary +@functor Unary (arg,) struct Binary <: Expr lhs::Expr op::String rhs::Expr end -@functor Binary +@functor Binary (lhs, rhs) struct Paren <: Expr inner::Expr @@ -73,7 +73,6 @@ end countnodes(e::Functor{Paren}) = 1 + e.inner countnodes(e::Functor{Literal}) = 1 countnodes(e::Functor{Ident}) = 1 - countnodes(e::Union{String,Int}) = 0 countnodes(e) = e @test Functors.cata(countnodes, call) == 4 @@ -95,7 +94,7 @@ end @testset "ana" begin function nested(n) - go(m) = m == 0 ? Literal(n) : Functor{Paren}((;inner=(m - 1))) + go(m) = m == 0 ? Literal(n) : functor(Paren, m - 1) Functors.ana(go, n) end From 7dd805700afdb166e71d6ef485567f473731de41 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 16 Sep 2021 13:23:52 -0700 Subject: [PATCH 4/9] Renaming and re-orging --- src/Functors.jl | 3 ++- src/functor.jl | 17 +++-------------- src/recursion_schemes.jl | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 15 deletions(-) create mode 100644 src/recursion_schemes.jl diff --git a/src/Functors.jl b/src/Functors.jl index 0b43f00..d378b2f 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -1,7 +1,8 @@ module Functors include("functor.jl") +include("recursion_schemes.jl") -export Functor, functor, @functor, fmap +export Functor, functor, fmap, @functor end # module diff --git a/src/functor.jl b/src/functor.jl index c5fd75d..a4713b7 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -9,7 +9,7 @@ functor(::Type{T}; kwargs...) where T = Functor{T}(NamedTuple{fieldnames(T)}(val backing(x) = x backing(func::Functor) = getfield(func, :inner) # TODO better name -paramvalues(x) = backing(x) +children(x) = backing(x) Base.getproperty(func::Functor, prop::Symbol) = getproperty(backing(func), prop) Base.getindex(func::Functor, prop) = getindex(backing(func), prop) @@ -31,13 +31,13 @@ function makefunctor(m::Module, T, fs=fieldnames(T)) @eval m begin $Functors.project(x::$T) = $Functors.Functor{$T}(($(escfields...),)) - $Functors.paramvalues(func::$Functors.Functor{$T}) = ($(escfs...),) + $Functors.children(func::$Functors.Functor{$T}) = ($(escfs...),) $Functors.fmap(f, func::$Functors.Functor{$T}) = $Functors.Functor{$T}(($(escfmap...),)) end end function functorm(T, fs=nothing) - fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)") +fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)") fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) end @@ -46,17 +46,6 @@ macro functor(args...) functorm(args...) end -# recursion schemes -cata(f, x) = f(fmap(y -> cata(f, y), project(x))) -ana(f, x) = embed(fmap(y -> ana(f, y), f(x))) - -# aliases -const fold = cata -const unfold = ana - -# convenience functions -rfmap(f, x) = cata(embed ∘ f, x) - # built-ins fmap(f, xs::T) where T <: Union{Tuple,NamedTuple,AbstractArray} = map(f, xs) diff --git a/src/recursion_schemes.jl b/src/recursion_schemes.jl new file mode 100644 index 0000000..9e6ae68 --- /dev/null +++ b/src/recursion_schemes.jl @@ -0,0 +1,32 @@ +# recursion schemes +struct Fold{F,C} + fn::F + cache::C +end +Fold(f) = Fold(f, IdDict()) + +function (f::Fold)(x) + haskey(f.cache, x) && return f.cache[x] + f.cache[x] = f.fn(fmap(f, project(x))) +end + +struct Unfold{F,C} + fn::F + cache::C +end +Unfold(f) = Unfold(f, IdDict()) + +function (f::Unfold)(x) + haskey(f.cache, x) && return f.cache[x] + f.cache[x] = embed(fmap(f, f.fn(x))) +end + +fold(f, x) = Fold(f)(x) +ana(f, x) = Unfold(f)(x) + +# aliases +const cata = fold +const ana = unfold + +# convenience functions +rfmap(f, x) = fold(embed ∘ f, x) \ No newline at end of file From 9dc969413800a8cf7ce377b378af0ddaa63c927f Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 16 Sep 2021 13:50:11 -0700 Subject: [PATCH 5/9] name tweaks and ComposedFunction test --- src/Functors.jl | 2 +- src/functor.jl | 3 ++- src/recursion_schemes.jl | 8 ++++---- test/basics.jl | 40 ++++++++++++++++++++++++---------------- 4 files changed, 31 insertions(+), 22 deletions(-) diff --git a/src/Functors.jl b/src/Functors.jl index d378b2f..372d089 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -3,6 +3,6 @@ module Functors include("functor.jl") include("recursion_schemes.jl") -export Functor, functor, fmap, @functor +export Functor, functor, @functor end # module diff --git a/src/functor.jl b/src/functor.jl index a4713b7..e81baab 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -24,7 +24,7 @@ embed(func::Functor{T}) where T = T(backing(func)...) function makefunctor(m::Module, T, fs=fieldnames(T)) escfields = [:($field = x.$field) for field in fieldnames(T)] - escfs = [:($field = x.$field) for field in fs] + escfs = [:($field = func.$field) for field in fs] escfmap = map(fieldnames(T)) do field field in fs ? :($field = f(func.$field)) : :($field = func.$field) end @@ -32,6 +32,7 @@ function makefunctor(m::Module, T, fs=fieldnames(T)) @eval m begin $Functors.project(x::$T) = $Functors.Functor{$T}(($(escfields...),)) $Functors.children(func::$Functors.Functor{$T}) = ($(escfs...),) + # Ref. https://gitlab.haskell.org/ghc/ghc/-/wikis/commentary/compiler/derive-functor $Functors.fmap(f, func::$Functors.Functor{$T}) = $Functors.Functor{$T}(($(escfmap...),)) end end diff --git a/src/recursion_schemes.jl b/src/recursion_schemes.jl index 9e6ae68..4393bfa 100644 --- a/src/recursion_schemes.jl +++ b/src/recursion_schemes.jl @@ -16,13 +16,13 @@ struct Unfold{F,C} end Unfold(f) = Unfold(f, IdDict()) -function (f::Unfold)(x) - haskey(f.cache, x) && return f.cache[x] - f.cache[x] = embed(fmap(f, f.fn(x))) +function (u::Unfold)(x) + haskey(u.cache, x) && return u.cache[x] + u.cache[x] = embed(fmap(u, u.fn(x))) end fold(f, x) = Fold(f)(x) -ana(f, x) = Unfold(f)(x) +unfold(f, x) = Unfold(f)(x) # aliases const cata = fold diff --git a/test/basics.jl b/test/basics.jl index e757561..bcb2c2d 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -44,17 +44,17 @@ end @functor Literal -@testset "cata" begin +@testset "folding" begin ten, add = Literal(10), Ident("add") call = Call(add, [ten, ten]) @testset "converting to POJOs" begin expected = (fn = (;name="add"), args = [(;val=10), (;val=10)]) - @test Functors.cata(Functors.backing, call) == expected + @test Functors.fold(Functors.backing, call) == expected end @testset "roundtrip" begin - @test Functors.cata(Functors.embed, call) == call + @test Functors.fold(Functors.embed, call) == call end @testset "rfmap" begin @@ -88,28 +88,36 @@ end prettyprint(p::Functor{Paren}) = "[$(p.inner)]" prettyprint(e) = e - @test Functors.cata(prettyprint, call) == "add(10,10)" + @test Functors.fold(prettyprint, call) == "add(10,10)" end end -@testset "ana" begin +@testset "unfolding" begin function nested(n) go(m) = m == 0 ? Literal(n) : functor(Paren, m - 1) - Functors.ana(go, n) + Functors.unfold(go, n) end @test nested(3) == 3 |> Literal |> Paren |> Paren |> Paren end -# @static if VERSION >= v"1.6" -# @testset "ComposedFunction" begin -# f1 = Foo(1.1, 2.2) -# f2 = Bar(3.3) -# @test Functors.functor(f1 ∘ f2)[1] == (outer = f1, inner = f2) -# @test Functors.functor(f1 ∘ f2)[2]((outer = f1, inner = f2)) == f1 ∘ f2 -# @test fmap(x -> x + 10, f1 ∘ f2) == Foo(11.1, 12.2) ∘ Bar(13.3) -# end -# end +@static if VERSION >= v"1.6" + @testset "ComposedFunction" begin + struct Foo a; b end + struct Bar c end + @functor Foo + @functor Bar + + plus10(x) = x + plus10(x::Real) = x + 10 + + f1 = Foo(1.1, 2.2) + f2 = Bar(3.3) + @test Functors.children(Functors.project(f1 ∘ f2)) == (outer = f1, inner = f2) + @test Functors.embed(Functors.project(f1 ∘ f2)) == f1 ∘ f2 + @test Functors.rfmap(plus10, f1 ∘ f2) == Foo(11.1, 12.2) ∘ Bar(13.3) + end +end # @testset "Nested" begin # model = Bar(Foo(1, [1, 2, 3])) @@ -117,7 +125,7 @@ end # model′ = fmap(float, model) # @test model.x.y == model′.x.y -# @test model′.x.y isa Vector{Float64} + # @test model′.x.y isa Vector{Float64} # end # @testset "Exclude" begin From 1b968cb0a8c7ae873d7d24e437ec31ebe8588fd9 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 16 Sep 2021 23:07:28 -0700 Subject: [PATCH 6/9] Big Cata revamp and L2 example --- src/functor.jl | 34 +++++-- src/recursion_schemes.jl | 31 +++++-- test/basics.jl | 194 ++++++++++++++------------------------- 3 files changed, 120 insertions(+), 139 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index e81baab..7039c1b 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -9,13 +9,17 @@ functor(::Type{T}; kwargs...) where T = Functor{T}(NamedTuple{fieldnames(T)}(val backing(x) = x backing(func::Functor) = getfield(func, :inner) # TODO better name -children(x) = backing(x) +children(_) = nothing +children(x::Functor) = backing(x) +isleaf(x) = children(x) === nothing Base.getproperty(func::Functor, prop::Symbol) = getproperty(backing(func), prop) Base.getindex(func::Functor, prop) = getindex(backing(func), prop) fmap(_, x) = x +fmap(_, xs...) = xs fmap(f, func::Functor{T}) where T = Functor{T}(map(f, backing(func))) +fmap(f, func::Functor, funcs...) where T = Functor{T}(map(f, backing(func), map(backing, funcs)...)) project(x) = x embed(func) = func @@ -23,23 +27,32 @@ embed(func) = func embed(func::Functor{T}) where T = T(backing(func)...) function makefunctor(m::Module, T, fs=fieldnames(T)) - escfields = [:($field = x.$field) for field in fieldnames(T)] + escfields = [:(x.$field) for field in fieldnames(T)] escfs = [:($field = func.$field) for field in fs] escfmap = map(fieldnames(T)) do field - field in fs ? :($field = f(func.$field)) : :($field = func.$field) + field in fs ? :(f(func.$field)) : :(func.$field) + end + escvfmap = map(fieldnames(T)) do field + field ∉ fs && :(func.$field) + :(f(func.$field, getproperty.(funcs, $(Meta.quot(field)))...)) end @eval m begin - $Functors.project(x::$T) = $Functors.Functor{$T}(($(escfields...),)) + $Functors.project(x::$T) = $Functors.functor($T, $(escfields...)) $Functors.children(func::$Functors.Functor{$T}) = ($(escfs...),) # Ref. https://gitlab.haskell.org/ghc/ghc/-/wikis/commentary/compiler/derive-functor - $Functors.fmap(f, func::$Functors.Functor{$T}) = $Functors.Functor{$T}(($(escfmap...),)) + $Functors.fmap(f, func::$Functors.Functor{$T}) = $Functors.functor($T, $(escfmap...)) + function $Functors.fmap(f, func::$Functors.Functor{$T}, funcs...) + # @show typeof(funcs) + # @show ($Functors.functor($T, $(escvfmap...)), funcs...) + $Functors.functor($T, $(escvfmap...)) + end end end function functorm(T, fs=nothing) -fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)") -fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] + fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)") + fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) end @@ -48,7 +61,12 @@ macro functor(args...) end # built-ins -fmap(f, xs::T) where T <: Union{Tuple,NamedTuple,AbstractArray} = map(f, xs) +const JLCollection = Union{Tuple,NamedTuple,AbstractArray} +children(xs::T) where T <: JLCollection = xs +project(xs::T) where T <: JLCollection = map(project, xs) +embed(xs::T) where T <: JLCollection = map(embed, xs) +fmap(f, xs::T) where T <: JLCollection = map(f, xs) +fmap(f, xs::T, xss...) where T <: JLCollection = map(f, xs, map(backing, xss)...) @static if VERSION >= v"1.6" @functor Base.ComposedFunction diff --git a/src/recursion_schemes.jl b/src/recursion_schemes.jl index 4393bfa..47146df 100644 --- a/src/recursion_schemes.jl +++ b/src/recursion_schemes.jl @@ -1,32 +1,47 @@ # recursion schemes -struct Fold{F,C} +struct Fold{F,I,C} fn::F + isleaf::I cache::C end -Fold(f) = Fold(f, IdDict()) -function (f::Fold)(x) - haskey(f.cache, x) && return f.cache[x] - f.cache[x] = f.fn(fmap(f, project(x))) +(f::Fold)(x) = get!(f.cache, x) do + y = project(x) + f.fn(f.isleaf(y) ? y : fmap(f, y)) +end + +(f::Fold)(xs...) = get!(f.cache, xs) do + ys = map(project, xs) + ys = f.isleaf(ys[1]) ? ys : fmap((xs′...) -> f(xs′...), ys...) + ys isa Tuple ? f.fn(ys...) : f.fn(ys) end struct Unfold{F,C} fn::F cache::C end -Unfold(f) = Unfold(f, IdDict()) function (u::Unfold)(x) haskey(u.cache, x) && return u.cache[x] u.cache[x] = embed(fmap(u, u.fn(x))) end + +function (u::Unfold)(xs...) + haskey(u.cache, xs) && return u.cache[xs] + u.cache[xs] = map(embed, fmap((ys...) -> u(ys...), u.fn(xs...))) +end -fold(f, x) = Fold(f)(x) +fold(f, x; isleaf=isleaf, cache=IdDict()) = Fold(f, isleaf, cache)(x) +fold(f, xs...; isleaf=isleaf, cache=IdDict()) = Fold(f, isleaf, cache)(xs...) unfold(f, x) = Unfold(f)(x) +unfold(f, xs...) = Unfold(f)(xs...) # aliases const cata = fold const ana = unfold # convenience functions -rfmap(f, x) = fold(embed ∘ f, x) \ No newline at end of file +rfmap(f, x) = fold(embed ∘ f, x) +rfmap(f, xs...) = embed(fold(xs...) do (ys...) + embed(length(ys) > 1 ? f(ys...) : only(ys)) +end) \ No newline at end of file diff --git a/test/basics.jl b/test/basics.jl index bcb2c2d..caedd3a 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -78,7 +78,7 @@ end @test Functors.cata(countnodes, call) == 4 end - @testset "pretty-printing" begin + @testset "pretty printing" begin prettyprint(l::Functor{Literal}) = repr(l.val) prettyprint(i::Functor{Ident}) = i.name prettyprint(c::Functor{Call}) = "$(c.fn)($(join(c.args, ",")))" @@ -90,6 +90,73 @@ end @test Functors.fold(prettyprint, call) == "add(10,10)" end + + @testset "zipped fold" begin + calldata = (fn = (;name="add"), args = [(;val=5), (;val=5)]) + + # div(x) = x + div(x, _) = x + div(x::Number, y::Number) = x / y + expected = Call(Ident("add"), [Literal(2.), Literal(2.)]) + @test Functors.rfmap(div, call, calldata) == expected + + # pairnums(x) = x + pairnums(x, y) = x => y + expected = (fn = (;name="add" => "add"), args = [(;val=5 => 10), (;val=5 => 10)]) + @test Functors.rfmap(pairnums, calldata, call) == expected + end + + # based on https://github.com/FluxML/Flux.jl/issues/1284 + @testset "L₂ regularization" begin + struct Chain{T <: Tuple} + layers::T + end + + Chain(layers...) = Chain(layers) + Functors.project(c::Chain) = Functors.Functor{Chain}(c.layers) + + struct Dense{W,B} + weight::W + bias::B + end + @functor Dense + + struct Conv{W,B} + weight::W + bias::B + end + @functor Conv + + struct SkipConnection{L,C} + layer::L + connection::C + end + @functor SkipConnection + + _isbitsarray(::AbstractArray{<:Number}) = true + _isbitsarray(::AbstractArray{T}) where T = isbitstype(T) + _isbitsarray(x) = false + custom_isleaf(x) = Functors.isleaf(x) || _isbitsarray(x) + + L₂(x) = sum(something(Functors.children(x), 0)) + L₂(a::AbstractArray{<:Number}) = length(a) + L₂(d::Functor{Dense}) = d.weight + L₂(c::Functor{Conv}) = c.weight + + conv1 = Conv(ones(3, 3, 4, 4), ones(4)) + conv2 = Conv(ones(3, 3, 4, 4), ones(4)) + dense1 = Dense(ones(4, 2), ones(2)) + dense2 = Dense(ones(2, 1), ones(1)) + + model = Chain( + SkipConnection(conv1, conv2), + x -> dropdims(max(x, dims=2), dims=2), + dense1, + dense2 + ) + expected = sum(length, (conv1.weight, conv2.weight, dense1.weight, dense2.weight)) + @test Functors.fold(L₂, model; isleaf=custom_isleaf) == expected + end end @testset "unfolding" begin @@ -98,11 +165,11 @@ end Functors.unfold(go, n) end - @test nested(3) == 3 |> Literal |> Paren |> Paren |> Paren + # @test nested(3) == 3 |> Literal |> Paren |> Paren |> Paren end @static if VERSION >= v"1.6" - @testset "ComposedFunction" begin + @testset "ComposedFunction" begin struct Foo a; b end struct Bar c end @functor Foo @@ -117,123 +184,4 @@ end @test Functors.embed(Functors.project(f1 ∘ f2)) == f1 ∘ f2 @test Functors.rfmap(plus10, f1 ∘ f2) == Foo(11.1, 12.2) ∘ Bar(13.3) end -end - -# @testset "Nested" begin -# model = Bar(Foo(1, [1, 2, 3])) - -# model′ = fmap(float, model) - -# @test model.x.y == model′.x.y - # @test model′.x.y isa Vector{Float64} -# end - -# @testset "Exclude" begin -# f(x::AbstractArray) = x -# f(x::Char) = 'z' - -# x = ['a', 'b', 'c'] -# @test fmap(f, x) == ['z', 'z', 'z'] -# @test fmap(f, x; exclude = x -> x isa AbstractArray) == x - -# x = (['a', 'b', 'c'], ['d', 'e', 'f']) -# @test fmap(f, x) == (['z', 'z', 'z'], ['z', 'z', 'z']) -# @test fmap(f, x; exclude = x -> x isa AbstractArray) == x -# end - -# @testset "Walk" begin -# model = Foo((0, Bar([1, 2, 3])), [4, 5]) - -# model′ = fmapstructure(identity, model) -# @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) -# end - -# @testset "Property list" begin -# model = Baz(1, 2, 3) -# model′ = fmap(x -> 2x, model) - -# @test (model′.x, model′.y, model′.z) == (1, 4, 3) -# end - -# @testset "fcollect" begin -# m1 = [1, 2, 3] -# m2 = 1 -# m3 = Foo(m1, m2) -# m4 = Bar(m3) -# @test all(fcollect(m4) .=== [m4, m3, m1, m2]) -# @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3, m2]) -# @test all(fcollect(m4, exclude = x -> x isa Foo) .=== [m4]) - -# m1 = [1, 2, 3] -# m2 = Bar(m1) -# m0 = NoChildren(:a, :b) -# m3 = Foo(m2, m0) -# m4 = Bar(m3) -# @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) -# end - -# struct FFoo -# x -# y -# p -# end -# @flexiblefunctor FFoo p - -# struct FBar -# x -# p -# end -# @flexiblefunctor FBar p - -# struct FBaz -# x -# y -# z -# p -# end -# @flexiblefunctor FBaz p - -# @testset "Flexible Nested" begin -# model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) - -# model′ = fmap(float, model) - -# @test model.x.y == model′.x.y -# @test model′.x.y isa Vector{Float64} -# end - -# @testset "Flexible Walk" begin -# model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) - -# model′ = fmapstructure(identity, model) -# @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) - -# model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) - -# model2′ = fmapstructure(identity, model2) -# @test model2′ == (; x=(0, (; x=[1, 2, 3]))) -# end - -# @testset "Flexible Property list" begin -# model = FBaz(1, 2, 3, (:x, :z)) -# model′ = fmap(x -> 2x, model) - -# @test (model′.x, model′.y, model′.z) == (2, 2, 6) -# end - -# @testset "Flexible fcollect" begin -# m1 = 1 -# m2 = [1, 2, 3] -# m3 = FFoo(m1, m2, (:y, )) -# m4 = FBar(m3, (:x,)) -# @test all(fcollect(m4) .=== [m4, m3, m2]) -# @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) -# @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) - -# m0 = NoChildren(:a, :b) -# m1 = [1, 2, 3] -# m2 = FBar(m1, ()) -# m3 = FFoo(m2, m0, (:x, :y,)) -# m4 = FBar(m3, (:x,)) -# @test all(fcollect(m4) .=== [m4, m3, m2, m0]) -# end +end \ No newline at end of file From 4279da9a05c56cfed314574aef1038080dd72079 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 11 Oct 2021 22:23:34 -0700 Subject: [PATCH 7/9] refactoring and better rfmap --- src/recursion_schemes.jl | 54 ++++++++++++++++++++++++---------------- test/basics.jl | 7 ++---- 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/recursion_schemes.jl b/src/recursion_schemes.jl index 47146df..18b310b 100644 --- a/src/recursion_schemes.jl +++ b/src/recursion_schemes.jl @@ -1,46 +1,58 @@ -# recursion schemes +## Generalized folds (Catamorphisms) struct Fold{F,I,C} fn::F isleaf::I cache::C end -(f::Fold)(x) = get!(f.cache, x) do +function _fold_helper(f, x) y = project(x) f.fn(f.isleaf(y) ? y : fmap(f, y)) end -(f::Fold)(xs...) = get!(f.cache, xs) do +function _fold_helper(f, xs...) ys = map(project, xs) - ys = f.isleaf(ys[1]) ? ys : fmap((xs′...) -> f(xs′...), ys...) - ys isa Tuple ? f.fn(ys...) : f.fn(ys) + any(f.isleaf, ys) ? f.fn(ys...) : f.fn(fmap((xs′...) -> f(xs′...), ys...)) end +(f::Fold)(x) = isbits(x) ? get!(() -> _fold_helper(f, x), f.cache, x) : _fold_helper(f, x) +function (f::Fold)(xs...) + # fully immutable types can't be reliably tracked by objectid, so bail + isbits(xs) && _fold_helper(f, xs...) + get!(() -> _fold_helper(f, xs...), f.cache, xs) +end + +""" + fold(f, x; isleaf, cache) + fold(f, x; isleaf, cache, accum, accum_cache) +""" +fold(f, x; isleaf=isleaf, cache=IdDict()) = Fold(f, isleaf, cache)(x) +fold(f, xs...; isleaf=isleaf, cache=IdDict()) = Fold(f, isleaf, cache)(xs...) + +""" +Alias for [fold](@ref) +""" +const cata = fold + + +## Generalized unfolds (Anamorphisms) struct Unfold{F,C} fn::F cache::C end -function (u::Unfold)(x) - haskey(u.cache, x) && return u.cache[x] - u.cache[x] = embed(fmap(u, u.fn(x))) -end - -function (u::Unfold)(xs...) - haskey(u.cache, xs) && return u.cache[xs] - u.cache[xs] = map(embed, fmap((ys...) -> u(ys...), u.fn(xs...))) -end +_unfold_helper(u, x) = embed(fmap(u, u.fn(x))) +_unfold_helper(u, xs...) = map(embed, fmap((ys...) -> u(ys...), u.fn(xs...))) -fold(f, x; isleaf=isleaf, cache=IdDict()) = Fold(f, isleaf, cache)(x) -fold(f, xs...; isleaf=isleaf, cache=IdDict()) = Fold(f, isleaf, cache)(xs...) -unfold(f, x) = Unfold(f)(x) -unfold(f, xs...) = Unfold(f)(xs...) +# (u::Unfold)(x) = ismutable(x) ? get!(() -> _unfold_helper(u, x), u.cache, x) : _unfold_helper(u, x) +(u::Unfold)(x, xs...) = ismutable(x) ? get!(() -> _unfold_helper(u, x, xs...), u.cache, x) : _unfold_helper(u, x, xs...) -# aliases -const cata = fold +unfold(f, x; cache=IdDict()) = Unfold(f, cache)(x) +unfold(f, xs...; cache=IdDict()) = Unfold(f, cache)(xs...) const ana = unfold -# convenience functions + +## Convenience functions rfmap(f, x) = fold(embed ∘ f, x) rfmap(f, xs...) = embed(fold(xs...) do (ys...) embed(length(ys) > 1 ? f(ys...) : only(ys)) diff --git a/test/basics.jl b/test/basics.jl index caedd3a..67d1017 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -94,16 +94,13 @@ end @testset "zipped fold" begin calldata = (fn = (;name="add"), args = [(;val=5), (;val=5)]) - # div(x) = x div(x, _) = x div(x::Number, y::Number) = x / y expected = Call(Ident("add"), [Literal(2.), Literal(2.)]) @test Functors.rfmap(div, call, calldata) == expected - # pairnums(x) = x - pairnums(x, y) = x => y expected = (fn = (;name="add" => "add"), args = [(;val=5 => 10), (;val=5 => 10)]) - @test Functors.rfmap(pairnums, calldata, call) == expected + @test Functors.rfmap(=>, calldata, call) == expected end # based on https://github.com/FluxML/Flux.jl/issues/1284 @@ -165,7 +162,7 @@ end Functors.unfold(go, n) end - # @test nested(3) == 3 |> Literal |> Paren |> Paren |> Paren + @test nested(3) == 3 |> Literal |> Paren |> Paren |> Paren end @static if VERSION >= v"1.6" From 235d7e04e624a4929803539f67eedf04269a922d Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Wed, 13 Oct 2021 16:07:16 -0700 Subject: [PATCH 8/9] Stricter caching requirements --- src/recursion_schemes.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/recursion_schemes.jl b/src/recursion_schemes.jl index 18b310b..c689680 100644 --- a/src/recursion_schemes.jl +++ b/src/recursion_schemes.jl @@ -1,3 +1,9 @@ +# Ordinarily separate instances of mutable types would have distinct objectids +# (or at least fail a === test), but String is an exception to both. +# This function exists purely to handle that edge case. +iscacheable(x) = ismutable(x) +iscacheable(::String) = false + ## Generalized folds (Catamorphisms) struct Fold{F,I,C} fn::F @@ -12,13 +18,13 @@ end function _fold_helper(f, xs...) ys = map(project, xs) + # TODO are there times when we _should_ recurse despite a leaf? any(f.isleaf, ys) ? f.fn(ys...) : f.fn(fmap((xs′...) -> f(xs′...), ys...)) end -(f::Fold)(x) = isbits(x) ? get!(() -> _fold_helper(f, x), f.cache, x) : _fold_helper(f, x) +(f::Fold)(x) = !iscacheable(x) ? get!(() -> _fold_helper(f, x), f.cache, x) : _fold_helper(f, x) function (f::Fold)(xs...) - # fully immutable types can't be reliably tracked by objectid, so bail - isbits(xs) && _fold_helper(f, xs...) + all(iscacheable, xs) || _fold_helper(f, xs...) get!(() -> _fold_helper(f, xs...), f.cache, xs) end From be92c7edfc647640f6861597680034af7f1a5cfe Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Wed, 13 Oct 2021 17:42:28 -0700 Subject: [PATCH 9/9] Cleanup and docstrings --- src/functor.jl | 91 +++++++++++++++++++++++++++++++++++++++- src/recursion_schemes.jl | 79 ++++++++++++++++++++++++++++------ test/basics.jl | 2 +- 3 files changed, 158 insertions(+), 14 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 7039c1b..6c573e1 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,27 +1,105 @@ +""" + Functor{T, FS} + +Type-level representation of a [base functor](https://hackage.haskell.org/package/recursion-schemes-5.2.2.1/docs/Data-Functor-Foldable-TH.html). +A `Functor` wraps a struct of type `T` using a backing of type `FS`. +`Functors` also implement `getindex` and `getproperty` to act as rough proxies for the original type's structure. + +Note that not all types need to be wrapped in `Functor`. +Notable exceptions include primitives, tuples and arrays. +For an explanation of why we need base functors, see [here](https://blog.sumtypeofway.com/posts/recursion-schemes-part-4-point-5.html). +To implement a `Functor` wrapper for your own class, see [`@functor`](@ref). +""" struct Functor{T,FS} inner::FS end Functor{T}(inner) where T = Functor{T,typeof(inner)}(inner) + +""" + functor(::Type{T}, args...) + functor(::Type{T}; kwargs...) + +Convenience function for manually instantiating [`Functor`](@ref)s of type `T`. +""" functor(::Type{T}, args...) where T = Functor{T}(NamedTuple{fieldnames(T)}(args)) functor(::Type{T}; kwargs...) where T = Functor{T}(NamedTuple{fieldnames(T)}(values(kwargs))) -backing(x) = x +""" + backing(func) + +Returns the backing storage of a [`Functor`](@ref). +For unwrapped types, this is a no-op. +For wrapped struct types, this returns every field in the original struct. +If you only want the fields that are mapped over, see [`Functors.children`](@ref). +""" +backing(func) = func backing(func::Functor) = getfield(func, :inner) + # TODO better name +""" + children(func) + +Returns the child nodes of a [`Functor`](@ref). +For types that have not opted into the `Functor` interface, this returns `nothing`. +For many special-cased base types, this is a no-op. +For wrapped struct types, this returns the fields that can mapped over with [`Functors.fmap`](@ref). +""" children(_) = nothing children(x::Functor) = backing(x) + +""" + isleaf(func) + +Determines if a given value is a leaf node, i.e. it is not a functor or has no children. +By default, any type that has not implemented the functor interface is considered a leaf node. +""" isleaf(x) = children(x) === nothing Base.getproperty(func::Functor, prop::Symbol) = getproperty(backing(func), prop) Base.getindex(func::Functor, prop) = getindex(backing(func), prop) +""" + fmap(f, x) + fmap(f, xs...) + +A [structure and type preserving](https://hackage.haskell.org/package/base-4.15.0.0/docs/Prelude.html#v:fmap) `map` that works for all functors. +Similar to `map` on a `Tuple` or `NamedTuple`, except it works for any type that implements the functor interface. + +When multiple inputs `xs...` are provided, `fmap` will return a Functor with the structure of the first input. + +# Examples +```jldoctest +julia> struct Foo; x; y; z; end +julia> @functor Foo (x, z) +julia> foo = Foo(1, 3, 5); +julia> func = fmap(x -> 2x, project(foo)) +Functor{Foo}((x=2, y=3, z=10)) +julia> embed(func) +Foo(2, 3, 10) +``` +""" fmap(_, x) = x fmap(_, xs...) = xs fmap(f, func::Functor{T}) where T = Functor{T}(map(f, backing(func))) fmap(f, func::Functor, funcs...) where T = Functor{T}(map(f, backing(func), map(backing, funcs)...)) +""" + project(x) + +Transforms a plain Julia value into its functor representation. +For types that have not opted into the `Functor` interface and special-cased base types, this is a no-op. +For wrapped struct types, this returns a [`Functor`](@ref). +""" project(x) = x + +""" + embed(func) + +Transforms a functor back into its plain Julia representation. +For [`Functor`](@ref)s, this returns the original wrapped struct type. +For types that have not opted into the `Functor` interface and special-cased base types, this is a no-op. +""" embed(func) = func # TODO use ConstructionBase? embed(func::Functor{T}) where T = T(backing(func)...) @@ -56,6 +134,17 @@ function functorm(T, fs=nothing) :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) end +""" + @functor T + @functor T (field1, field2, ...) + +Fancy macro that automatically implements the functor interface for a given type `T`. +This includes [`Functors.project`](@ref), [`Functors.children`](@ref) and [`Functors.fmap`](@ref). + +By default, `@functor T` will include all fields of the original type as children in the [`Functor`](@ref) representation. +To include only certain fields, pass a tuple of field names to `@functor`. +Passing the empty tuple like `@functor T ()` will exclude all fields from being child nodes. +""" macro functor(args...) functorm(args...) end diff --git a/src/recursion_schemes.jl b/src/recursion_schemes.jl index c689680..23e11d2 100644 --- a/src/recursion_schemes.jl +++ b/src/recursion_schemes.jl @@ -22,6 +22,7 @@ function _fold_helper(f, xs...) any(f.isleaf, ys) ? f.fn(ys...) : f.fn(fmap((xs′...) -> f(xs′...), ys...)) end +# TODO add fast paths when f.cache === nothing? (f::Fold)(x) = !iscacheable(x) ? get!(() -> _fold_helper(f, x), f.cache, x) : _fold_helper(f, x) function (f::Fold)(xs...) all(iscacheable, xs) || _fold_helper(f, xs...) @@ -29,37 +30,91 @@ function (f::Fold)(xs...) end """ - fold(f, x; isleaf, cache) - fold(f, x; isleaf, cache, accum, accum_cache) + fold(f, x; isleaf=Functors.isleaf, cache=IdDict()) + fold(f, xs...; isleaf=Functors.isleaf, cache=IdDict()) + +Generalized fold over functors. +The fancy functional programming term for this is a "catamorphism" +(see [here](https://blog.sumtypeofway.com/posts/recursion-schemes-part-2.html) for a good intro). + +`f` is a function that takes one or more [`Functor`](@ref)s and returns any value. + +To control when `fold` stops recursing, pass a custom `isleaf` predicate. +This defaults to [`Functors.isleaf`](@ref). + +If `x` doesn't have any structural sharing, or can be safely unpacked into a tree without sharing, +Setting `cache=nothing` will provide a small speedup. +If you're not sure whether this applies, just leave it as-is. + +When multiple inputs `xs...` are passed, `fold` will return a functor based on the structure of the first input. +Just like `map` and `zip`, `fold` will stop recursing once any it hits a leaf node in any input. """ fold(f, x; isleaf=isleaf, cache=IdDict()) = Fold(f, isleaf, cache)(x) fold(f, xs...; isleaf=isleaf, cache=IdDict()) = Fold(f, isleaf, cache)(xs...) """ -Alias for [fold](@ref) +Alias for [`Functors.fold`](@ref) """ const cata = fold ## Generalized unfolds (Anamorphisms) -struct Unfold{F,C} +struct Unfold{F} fn::F - cache::C end _unfold_helper(u, x) = embed(fmap(u, u.fn(x))) _unfold_helper(u, xs...) = map(embed, fmap((ys...) -> u(ys...), u.fn(xs...))) -# (u::Unfold)(x) = ismutable(x) ? get!(() -> _unfold_helper(u, x), u.cache, x) : _unfold_helper(u, x) -(u::Unfold)(x, xs...) = ismutable(x) ? get!(() -> _unfold_helper(u, x, xs...), u.cache, x) : _unfold_helper(u, x, xs...) +(u::Unfold)(x, xs...) = _unfold_helper(u, x, xs...) + +""" + unfold(f, x) + unfold(f, xs...) -unfold(f, x; cache=IdDict()) = Unfold(f, cache)(x) -unfold(f, xs...; cache=IdDict()) = Unfold(f, cache)(xs...) +Generalized unfold producing functors. +The fancy functional programming term for this is an "anamorphism" +(see [here](https://blog.sumtypeofway.com/posts/recursion-schemes-part-2.html) for a good intro). + +`f` is a function that takes one or more inputs and returns a [`Functor`](@ref). +""" +unfold(f, x) = Unfold(f)(x) +unfold(f, xs...) = Unfold(f)(xs...) + +""" +Alias for [`Functors.unfold`](@ref) +""" const ana = unfold ## Convenience functions -rfmap(f, x) = fold(embed ∘ f, x) -rfmap(f, xs...) = embed(fold(xs...) do (ys...) +""" + rfmap(f, x; isleaf=Functors.isleaf, walk = Functors._default_walk) + rfmap(f, xs...; isleaf=Functors.isleaf, walk = Functors._default_walk) + +A recursive, structure and type preserving `map` that works on nested functors. +`rfmap` traverses and transforms every leaf node identified by `isleaf` with `f`. + +To control when `fold` stops recursing, pass a custom `isleaf` predicate. +This defaults to [`Functors.isleaf`](@ref). + +When multiple inputs `xs...` are passed, `rfmap` will return a functor based on the structure of the first input. +Just like `map` and `zip`, `rfmap` will stop recursing once any it hits a leaf node in any input. + +# Examples +```jldoctest +julia> struct Foo; x; y; end +julia> @functor Foo +julia> struct Bar; x; end +julia> @functor Bar +julia> m = Foo(Bar([1,2,3]), (4, 5)); +julia> rfmap(x -> 2x, m) +Foo(Bar([2, 4, 6]), (8, 10)) +julia> rfmap(string, m) +Foo(Bar("[1, 2, 3]"), ("4", "5")) +``` +""" +rfmap(f, x; isleaf=isleaf) = fold(embed ∘ f, x, isleaf=isleaf) +rfmap(f, xs...; isleaf=isleaf) = fold(xs...; isleaf=isleaf) do (ys...) embed(length(ys) > 1 ? f(ys...) : only(ys)) -end) \ No newline at end of file +end \ No newline at end of file diff --git a/test/basics.jl b/test/basics.jl index 67d1017..1190630 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -166,7 +166,7 @@ end end @static if VERSION >= v"1.6" - @testset "ComposedFunction" begin + @testset "ComposedFunction" begin struct Foo a; b end struct Bar c end @functor Foo