Skip to content

Commit

Permalink
WIP: premapping based cyclic type zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jan 25, 2024
1 parent fe63c33 commit f4621e7
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 1 deletion.
44 changes: 43 additions & 1 deletion src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,19 @@ end
)
Expr(:kw, fname, fval)
end
return if has_mutable_tangent(primal)

# easy case exit early, can't hold references, can't be a reference.
if isbitstype(primal)
return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
end

# hard case need to be prepared for cycic references to this, or that are contained within this
quote
counts = $count_references!(primal)
end

## TODO rewrite below
has_mutable_tangent(primal)
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
# If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
Expand Down Expand Up @@ -171,6 +183,36 @@ function zero_tangent(x::Array{P,N}) where {P,N}
return y
end

###############################################
count_references!(x) = count_references(IdDict{Any, Int}(), x)
function count_references!(counts::IdDict{Any, Int}, x)
isbits(x) && return counts # can't be a refernece and can't hold a reference
counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing
if counts[x] == 1 # Only recurse the first time
for ii in fieldcount(typeof(x))
field = getfield(x, ii)
count_references!(counts, field)
end
end
return counts
end

function count_references!(counts::IdDict{Any, Int}, x::Array)
counts[x] = get(counts, x, 0) + 1 # increment before recursing
isbitstype(eltype(x)) && return counts # no need to look inside, it can't hold references
if counts[x] == 1 # only recurse the first time
for ele in x
count_references!(counts, ele)
end
end
return counts
end

count_references!(counts::IdDict{Any, Int}, ::DataType) = counts

###############################################


# Sad heauristic methods we need because of unassigned values
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))
Expand Down
32 changes: 32 additions & 0 deletions test/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,36 @@ end
@test d.z == [2.0, 3.0]
@test d.z isa SubArray
end


@testset "cyclic references" begin
mutable struct Link
data::Float64
next::Link
Link(data) = new(data)
end

lk = Link(1.5)
lk.next = lk

d = zero_tangent(lk)
@test d.data == 0.0
@test d.next === d

struct CarryingArray
x::Vector
end
ca = CarryingArray(Any[1.5])
push!(ca.x, ca)
@test d_ca = zero_tangent(ca)
@test d_ca[1] == 0.0
@test d_ca[2] === _ca

# Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing
xs = Any[1.5]
push!(xs, xs)
@test d_xs = zero_tangent(xs)
@test d_xs[1] == 0.0
@test d_xs[2] == d_xs
end
end

0 comments on commit f4621e7

Please sign in to comment.