Skip to content

Commit

Permalink
Abstract is mixed (#1536)
Browse files Browse the repository at this point in the history
* Abstract is mixed

* fix unionall

* fix

* more fixups
  • Loading branch information
wsmoses authored Jun 15, 2024
1 parent a889bb6 commit 82cc451
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 48 deletions.
97 changes: 64 additions & 33 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,30 @@ end
ActivityState(Int(a1) | Int(a2))
end

struct Merger{seen,worldT,justActive,UnionSret}
struct Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed}
world::worldT
end

@inline element(::Val{T}) where T = T

@inline function (c::Merger{seen,worldT,justActive,UnionSret})(f::Int) where {seen,worldT,justActive,UnionSret}
# From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570
@inline function isghostty(ty)
if ty === Union{}
return true
end
if Base.isconcretetype(ty) && !ismutabletype(ty)
if sizeof(ty) == 0
return true
end
# TODO consider struct_to_llvm ?
end
return false
end

@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})(f::Int) where {seen,worldT,justActive,UnionSret,AbstractIsMixed}
T = element(first(seen))

reftype = ismutabletype(T) || T isa UnionAll
reftype = ismutabletype(T) || (T isa UnionAll && !AbstractIsMixed)

if justActive && reftype
return Val(AnyState)
Expand All @@ -273,7 +287,7 @@ end
return Val(AnyState)
end

sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret))
sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))

if sub == AnyState
Val(AnyState)
Expand Down Expand Up @@ -372,24 +386,31 @@ end
end)
end

@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}) where {ST, Seen, justActive, UnionSret}
@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}, ::Val{AbstractIsMixed}) where {ST, Seen, justActive, UnionSret, AbstractIsMixed}
if ST isa Union
return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret))))
return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))))
end
return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret))
return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
end

@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret}
@inline is_vararg_tup(x) = false
@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where T2 = true

@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false), ::Val{AbstractIsMixed}=Val(false))::ActivityState where {ST,T, justActive, UnionSret, AbstractIsMixed}
if T === Any
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end

if T === Union{}
return AnyState
end

if T <: Complex && !(T isa UnionAll)
return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret))
return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
end

if T <: AbstractFloat
Expand All @@ -401,10 +422,14 @@ end
return AnyState
end

if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret)) == AnyState
if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) == AnyState
return AnyState
else
return DupState
if AbstractIsMixed && is_vararg_tup(T)
return MixedState
else
return DupState
end
end
end

Expand Down Expand Up @@ -434,35 +459,55 @@ end
if T isa UnionAll
aT = Base.argument_datatype(T)
if aT === nothing
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end
if datatype_fieldcount(aT) === nothing
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end
end

if T isa Union
# if sret union, the data is stored in a stack memory location and is therefore
# not unique'd preventing the boxing of the union in the default case
if UnionSret && is_sret_union(T)
return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret))
return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
else
if justActive
return AnyState
end
if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != AnyState
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end
if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != AnyState
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end
end
return AnyState
end

# if abstract it must be by reference
if Base.isabstracttype(T)
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end

if ismutabletype(T)
Expand Down Expand Up @@ -504,7 +549,7 @@ end

seen2 = (Val(nT), seen...)

fty = Merger{seen2,typeof(world),justActive, UnionSret}(world)
fty = Merger{seen2,typeof(world),justActive, UnionSret, AbstractIsMixed}(world)

ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...)

Expand Down Expand Up @@ -1158,20 +1203,6 @@ function permit_inlining!(f::LLVM.Function)
end
end

# From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570
@inline function isghostty(ty)
if ty === Union{}
return true
end
if Base.isconcretetype(ty) && !ismutabletype(ty)
if sizeof(ty) == 0
return true
end
# TODO consider struct_to_llvm ?
end
return false
end

struct Tape{TapeTy,ShadowTy,ResT}
internal_tape::TapeTy
shadow_return::ShadowTy
Expand Down
57 changes: 55 additions & 2 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing,
end
)
else
mixexpr = if Width == 1
quote
iterate_unwrap_augfwd_mix(Val($reverse), refs, $(primargs[i]), $(shadowargs[i]))
end
else
quote
iterate_unwrap_augfwd_batchmix(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i]))
end
end
dupexpr = if Width == 1
quote
iterate_unwrap_augfwd_dup(Val($reverse), refs, $(primargs[i]), $(shadowargs[i]))
Expand All @@ -110,8 +119,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing,
if $aref == ActiveState
iterate_unwrap_augfwd_act($(primargs[i])...)
elseif $aref == MixedState
T = $(primtypes[i])
throw(AssertionError("Mixed State of type $T is unsupported in apply iterate"))
$mixexpr
else
$dupexpr
end
Expand Down Expand Up @@ -586,6 +594,51 @@ end
end
end

@inline function iterate_unwrap_augfwd_mix(::Val{reverse}, vals, args, dargs0) where reverse
dargs = dargs0[]
ntuple(Val(length(args))) do i
Base.@_inline_meta
arg = args[i]
ty = Core.Typeof(arg)
actreg = active_reg_nothrow(ty, Val(nothing))
if actreg == AnyState
Const(arg)
elseif actreg == ActiveState
Active(arg)
elseif actreg == MixedState
darg = Base.inferencebarrier(dargs[i])
MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty})
else
Duplicated(arg, dargs[i])
end
end
end

@inline function iterate_unwrap_augfwd_batchmix(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width}
ntuple(Val(length(args))) do i
Base.@_inline_meta
arg = args[i]
ty = Core.Typeof(arg)
actreg = active_reg_nothrow(ty, Val(nothing))
if actreg == AnyState
Const(arg)
elseif actreg == ActiveState
Active(arg)
elseif actreg == MixedState
BatchMixedDuplicated(arg, ntuple(Val(Width)) do j
Base.@_inline_meta
darg = Base.inferencebarrier(dargs[j][][i])
push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}
end)
else
BatchDuplicated(arg, ntuple(Val(Width)) do j
Base.@_inline_meta
dargs[j][][i]
end)
end
end
end

@inline function allFirst(::Val{Width}, res) where Width
ntuple(Val(Width)) do i
Base.@_inline_meta
Expand Down
24 changes: 12 additions & 12 deletions src/rules/typeunstablerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch
shadow_rets_i = Expr[]
aref = Symbol("active_ref_$i")
for w in 1:Width
sref = Symbol("shadow_"*string(i)*"_"*string(w))
sref = Symbol("sub_shadow_"*string(i)*"_"*string(w))
push!(shadow_rets_i, quote
$sref = if $aref == AnyState
$(primargs[i]);
else
if !ActivityTup[$i]
if $aref == DupState || $aref == MixedState
if ($aref == DupState || $aref == MixedState) && $(batchshadowargs[i][w]) === nothing
prim = $(primargs[i])
throw("Error cannot store inactive but differentiable variable $prim into active tuple")
end
Expand Down Expand Up @@ -98,7 +98,7 @@ function body_construct_rev(N, Width, primtypes, active_refs, primargs, batchsha
shad = batchshadowargs[i][w]
out = :(if $(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState
if $shad isa Base.RefValue
$shad[] = recursive_add($shad[], $expr)
$shad[] = recursive_add($shad[], $expr, identity, guaranteed_nonactive)
else
error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad))
end
Expand Down Expand Up @@ -248,10 +248,10 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR)
# if any active [e.g. ActiveState / MixedState] data could exist
# err
if !fwd
if !found
if !found_partial
return false
end
act = active_reg_inner(typ, (), world)
act = active_reg_inner(typ_partial, (), world, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true))
if act == MixedState || act == ActiveState
return false
end
Expand Down Expand Up @@ -306,7 +306,7 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR)
return false
end

function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)
function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool
needsShadowP = Ref{UInt8}(0)
needsPrimalP = Ref{UInt8}(0)
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils))
Expand Down Expand Up @@ -379,7 +379,7 @@ function common_f_tuple_fwd(offset, B, orig, gutils, normalR, shadowR)
common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR)
end

function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)
function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool
needsShadowP = Ref{UInt8}(0)
needsPrimalP = Ref{UInt8}(0)
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils))
Expand Down Expand Up @@ -420,8 +420,8 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)

unsafe_store!(tapeR, sret.ref)

return false
end
return false
end

function common_f_tuple_rev(offset, B, orig, gutils, tape)
Expand Down Expand Up @@ -474,7 +474,7 @@ function f_tuple_fwd(B, orig, gutils, normalR, shadowR)
common_f_tuple_fwd(1, B, orig, gutils, normalR, shadowR)
end

function f_tuple_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
function f_tuple_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool
common_f_tuple_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR)
end

Expand All @@ -487,7 +487,7 @@ function new_structv_fwd(B, orig, gutils, normalR, shadowR)
common_newstructv_fwd(1, B, orig, gutils, normalR, shadowR)
end

function new_structv_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
function new_structv_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool
common_newstructv_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR)
end

Expand Down Expand Up @@ -525,7 +525,7 @@ function new_structt_fwd(B, orig, gutils, normalR, shadowR)
unsafe_store!(shadowR, shadowres.ref)
return false
end
function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool
new_structt_fwd(B, orig, gutils, normalR, shadowR)
end

Expand Down Expand Up @@ -821,7 +821,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}
return nothing
end

function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)
function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool
if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL
return true
end
Expand Down
25 changes: 24 additions & 1 deletion test/mixedapplyiter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,27 @@ end
@test out[] 5562.9996
@test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]])
@test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]])
end
end

struct MyRectilinearGrid5{FT,FZ}
x :: FT
z :: FZ
end


@inline flatten_tuple(a::Tuple) = @inbounds a[2:end]
@inline flatten_tuple(a::Tuple{<:Any}) = tuple() #inner_flatten_tuple(a[1])...)

function myupdate_state!(model)
tupled = Base.inferencebarrier((model,model))
flatten_tuple(tupled)
return nothing
end

@testset "Abstract type allocation" begin
model = MyRectilinearGrid5{Float64, Vector{Float64}}(0.0, [0.0])
dmodel = MyRectilinearGrid5{Float64, Vector{Float64}}(0.0, [0.0])
autodiff(Enzyme.Reverse,
myupdate_state!,
MixedDuplicated(model, Ref(dmodel)))
end
Loading

0 comments on commit 82cc451

Please sign in to comment.