From f4621e7e06a850838923adcff039a2302b634cd6 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 25 Jan 2024 14:20:15 +0800 Subject: [PATCH] WIP: premapping based cyclic type zeros --- src/tangent_types/abstract_zero.jl | 44 ++++++++++++++++++++++++++++- test/tangent_types/abstract_zero.jl | 32 +++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index f921db29d..d52f213d0 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -138,7 +138,19 @@ end ) Expr(:kw, fname, fval) end - return if has_mutable_tangent(primal) + + # easy case exit early, can't hold references, can't be a reference. + if isbitstype(primal) + return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) + end + + # hard case need to be prepared for cycic references to this, or that are contained within this + quote + counts = $count_references!(primal) + end + +## TODO rewrite below + has_mutable_tangent(primal) any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype)) @@ -171,6 +183,36 @@ function zero_tangent(x::Array{P,N}) where {P,N} return y end +############################################### +count_references!(x) = count_references(IdDict{Any, Int}(), x) +function count_references!(counts::IdDict{Any, Int}, x) + isbits(x) && return counts # can't be a refernece and can't hold a reference + counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing + if counts[x] == 1 # Only recurse the first time + for ii in fieldcount(typeof(x)) + field = getfield(x, ii) + count_references!(counts, field) + end + end + return counts +end + +function count_references!(counts::IdDict{Any, Int}, x::Array) + counts[x] = get(counts, x, 0) + 1 # increment before recursing + isbitstype(eltype(x)) && return counts # no need to look inside, it can't hold references + if counts[x] == 1 # only recurse the first time + for ele in x + count_references!(counts, ele) + end + end + return counts +end + +count_references!(counts::IdDict{Any, Int}, ::DataType) = counts + +############################################### + + # Sad heauristic methods we need because of unassigned values guess_zero_tangent_type(::Type{T}) where {T<:Number} = T guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index a4df83ebf..f198f5f77 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -275,4 +275,36 @@ end @test d.z == [2.0, 3.0] @test d.z isa SubArray end + + + @testset "cyclic references" begin + mutable struct Link + data::Float64 + next::Link + Link(data) = new(data) + end + + lk = Link(1.5) + lk.next = lk + + d = zero_tangent(lk) + @test d.data == 0.0 + @test d.next === d + + struct CarryingArray + x::Vector + end + ca = CarryingArray(Any[1.5]) + push!(ca.x, ca) + @test d_ca = zero_tangent(ca) + @test d_ca[1] == 0.0 + @test d_ca[2] === _ca + + # Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing + xs = Any[1.5] + push!(xs, xs) + @test d_xs = zero_tangent(xs) + @test d_xs[1] == 0.0 + @test d_xs[2] == d_xs + end end