Nested AD Errors Out #2147

Open
turiya4 opened this issue Nov 30, 2024 · 7 comments

@turiya4

turiya4 commented Nov 30, 2024

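The following code (using Julia 1.10.6 and Enzyme v0.13.17) does simple nested AD. For each sample, the inner Enzyme.autodiff call computes the gradient of the scalar network output with respect to the input x, and batch_error sums the squared input gradients. The outer Enzyme.autodiff call then differentiates that sum with respect to the inputs and the parameters. However, this results in an error.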

```julia
using Enzyme, Lux, Random, ComponentArrays, LinearAlgebra
n = 1
x_batch = randn(2, n)
y_batch = randn(2, n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model);
psaxes = getaxes(ComponentArray(ps))

nnfunc(x, y, psarray, st) = first(model((x, y), ComponentArray(psarray, psaxes), st))[1]  # scalar network output for one (x, y) pair
psarray = getdata(ComponentArray(ps))

function batch_error(xb, yb, psarray, st)
    val = zeros(n)
    for i = 1 : n
        dx = zeros(2)
        Enzyme.autodiff(Enzyme.Reverse, nnfunc, Active, Duplicated(xb[:, i], dx), Duplicated(yb[:, i], zeros(2)), Duplicated(psarray, zeros(Float32, size(psarray))), Const(st))  # inner reverse pass: dx accumulates the gradient of nnfunc w.r.t. x
        val[i] = sum(dx.^2)
    end
    return sum(val)
end

dpsarray = zeros(Float32, size(psarray))
Enzyme.autodiff(Enzyme.Reverse, batch_error, Active, Duplicated(x_batch,zeros(size(x_batch))), Duplicated(y_batch,zeros(size(y_batch))), Duplicated(psarray, dpsarray), Const(st))  # outer reverse pass over batch_error
```
ERROR: LoadError: Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Pointer}" "enzymejl_parmtype"="123305751457168" "enzymejl_parmtype_ref"="1" [3 x {} addrspace(10)*] @preprocess_julia_runtime_generic_augfwd_4730_inner.1({} addrspace(10)* nocapture nofree noundef nonnull readnone "enzyme_inactive" "enzyme_type"="{[-1]:Pointer}" "enzymejl_parmtype"="123305469593088" "enzymejl_parmtype_ref"="2" %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="123305771723920" "enzymejl_parmtype_ref"="2" %1, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="123305771723920" "enzymejl_parmtype_ref"="2" %2) local_unnamed_addr #5 !dbg !47 {
entry:
  %3 = call {}*** @julia.get_pgcstack() #9, !noalias !48
  %current_task1.i6 = getelementptr inbounds {}**, {}*** %3, i64 -14
  %current_task1.i = bitcast {}*** %current_task1.i6 to {}**
  %ptls_field.i7 = getelementptr inbounds {}**, {}*** %3, i64 2
  %4 = bitcast {}*** %ptls_field.i7 to i64***
  %ptls_load.i89 = load i64**, i64*** %4, align 8, !tbaa !11, !noalias !48
  %5 = getelementptr inbounds i64*, i64** %ptls_load.i89, i64 2
  %safepoint.i = load i64*, i64** %5, align 8, !tbaa !15, !noalias !48
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i) #9, !dbg !51, !noalias !48
  fence syncscope("singlethread") seq_cst
  %6 = call { { {} addrspace(10)* }, { {} addrspace(10)* } } inttoptr (i64 123305430303696 to { { {} addrspace(10)* }, { {} addrspace(10)* } } ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %2) #9, !dbg !53
  %7 = extractvalue { { {} addrspace(10)* }, { {} addrspace(10)* } } %6, 0, !dbg !57
  %8 = extractvalue { { {} addrspace(10)* }, { {} addrspace(10)* } } %6, 1, !dbg !57
  %box.i = call noalias nonnull dereferenceable(8) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@float, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, [-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1.i, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 123305780515024 to {}*) to {} addrspace(10)*)) #10, !dbg !58
  %9 = bitcast {} addrspace(10)* %box.i to { {} addrspace(10)* } addrspace(10)*, !dbg !58
  %10 = extractvalue { {} addrspace(10)* } %7, 0, !dbg !58
  %11 = getelementptr { {} addrspace(10)* }, { {} addrspace(10)* } addrspace(10)* %9, i64 0, i32 0, !dbg !58
  store {} addrspace(10)* %10, {} addrspace(10)* addrspace(10)* %11, align 8, !dbg !58, !tbaa !32, !alias.scope !36, !noalias !60
  %box4.i = call noalias nonnull dereferenceable(8) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@float, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, [-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1.i, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 123305780515024 to {}*) to {} addrspace(10)*)) #10, !dbg !58
  %12 = bitcast {} addrspace(10)* %box4.i to { {} addrspace(10)* } addrspace(10)*, !dbg !58
  %13 = extractvalue { {} addrspace(10)* } %8, 0, !dbg !58
  %14 = getelementptr { {} addrspace(10)* }, { {} addrspace(10)* } addrspace(10)* %12, i64 0, i32 0, !dbg !58
  store {} addrspace(10)* %13, {} addrspace(10)* addrspace(10)* %14, align 8, !dbg !58, !tbaa !32, !alias.scope !36, !noalias !60
  %.fca.0.insert = insertvalue [3 x {} addrspace(10)*] poison, {} addrspace(10)* %box.i, 0, !dbg !63
  %.fca.1.insert = insertvalue [3 x {} addrspace(10)*] %.fca.0.insert, {} addrspace(10)* %box4.i, 1, !dbg !63
  %.fca.2.insert = insertvalue [3 x {} addrspace(10)*] %.fca.1.insert, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 123304558811824 to {}*) to {} addrspace(10)*), 2, !dbg !63
  ret [3 x {} addrspace(10)*] %.fca.2.insert, !dbg !63
}

Did not have return index set when differentiating function
 call  %6 = call { { {} addrspace(10)* }, { {} addrspace(10)* } } inttoptr (i64 123305430303696 to { { {} addrspace(10)* }, { {} addrspace(10)* } } ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %2) #9, !dbg !19
 augmentcall  %_augmented = call { i8*, { { {} addrspace(10)* }, { {} addrspace(10)* } } } %15({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* %1, {} addrspace(10)* %"'", {} addrspace(10)* %2, {} addrspace(10)* %"'1"), !dbg !19


Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6229
 [2] enzyme_call
   @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5775
 [3] AugmentedForwardThunk
   @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5697
 [4] runtime_generic_augfwd
   @ ~/.julia/packages/Enzyme/fpA3W/src/rules/jitrules.jl:480
 [5] runtime_generic_augfwd
   @ ~/.julia/packages/Enzyme/fpA3W/src/rules/jitrules.jl:0

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/errors.jl:242
  [2] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/errors.jl:97
  [3] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, runtimeActivity::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/fpA3W/src/api.jl:389
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:2145
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5426
  [6] codegen
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:4196 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6298
  [8] _thunk
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6298 [inlined]
  [9] cached_compilation
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6339 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6452
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6604
 [12] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Enzyme.Compiler.runtime_generic_augfwd), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing, primal_2::Val{…}, shadow_2_1::Nothing, primal_3::Val{…}, shadow_3_1::Nothing, primal_4::Val{…}, shadow_4_1::Nothing, primal_5::Val{…}, shadow_5_1::Nothing, primal_6::Type{…}, shadow_6_1::Nothing, primal_7::Nothing, shadow_7_1::Nothing, primal_8::Vector{…}, shadow_8_1::Vector{…}, primal_9::Vector{…}, shadow_9_1::Vector{…}, primal_10::Tuple{…}, shadow_10_1::Nothing, primal_11::Nothing, shadow_11_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/rules/jitrules.jl:465
 [13] nnfunc
    @ ~/temp/test.jl:11 [inlined]
 [14] augmented_julia_nnfunc_2777wrap
    @ ~/temp/test.jl:0
 [15] macro expansion
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6229 [inlined]
 [16] enzyme_call
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5775 [inlined]
 [17] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5697 [inlined]
 [18] autodiff_deferred
    @ ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:729 [inlined]
 [19] autodiff
    @ ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:524 [inlined]
 [20] batch_error
    @ ~/temp/test.jl:18 [inlined]
 [21] augmented_julia_batch_error_2348wrap
    @ ~/temp/test.jl:0
 [22] macro expansion
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6229 [inlined]
 [23] enzyme_call
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5775 [inlined]
 [24] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5697 [inlined]
 [25] autodiff
    @ ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:396 [inlined]
 [26] autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::typeof(batch_error), ::Type{Active}, ::Duplicated{Matrix{…}}, ::Duplicated{Matrix{…}}, ::Duplicated{Vector{…}}, ::Const{@NamedTuple{…}})
    @ Enzyme ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:524
 [27] top-level scope
    @ ~/temp/test.jl:25
 [28] include(fname::String)
    @ Base.MainInclude ./client.jl:494
 [29] top-level scope
    @ REPL[1]:1
in expression starting at /home/work/temp/test.jl:25
Some type information was truncated. Use `show(err)` to see complete types.
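For reference, the quantity the outer call differentiates can be written as follows. The notation is introduced here only for illustration: f(x, y; θ) stands for nnfunc(x, y, psarray, st), θ for psarray, and x_i, y_i for the i-th columns of x_batch and y_batch. The θ-gradient contains the mixed second derivative J_i, which is why the outer reverse pass has to differentiate through the inner one (reverse-over-reverse). The outer call also requests gradients with respect to x_batch and y_batch; those are omitted here.

$$
E(\theta) = \sum_{i=1}^{n} \lVert g_i \rVert_2^2, \qquad g_i = \nabla_{x} f(x_i, y_i; \theta)
$$

$$
\nabla_{\theta} E(\theta) = 2 \sum_{i=1}^{n} J_i^{\top} g_i, \qquad J_i = \frac{\partial g_i}{\partial \theta}
$$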

@vchuravy
Member

Please post the full backtrace; you cut out important information.

@turiya4
Author

turiya4 commented Nov 30, 2024

Sorry. I have updated the original post with the full error.

@vchuravy
Member

I am confused? You removed even more information?

@turiya4
Copy link
Author

turiya4 commented Nov 30, 2024

I have updated the error. Could you please take a look now?

@vchuravy
Member

Thanks, this information is always important:


Did not have return index set when differentiating function
 call  %6 = call { { {} addrspace(10)* }, { {} addrspace(10)* } } inttoptr (i64 123305430303696 to { { {} addrspace(10)* }, { {} addrspace(10)* } } ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %2) #9, !dbg !19
 augmentcall  %_augmented = call { i8*, { { {} addrspace(10)* }, { {} addrspace(10)* } } } %15({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* %1, {} addrspace(10)* %"'", {} addrspace(10)* %2, {} addrspace(10)* %"'1"), !dbg !19

@wsmoses
Member

wsmoses commented Nov 30, 2024

Yeah, so the issue is that there's an inttoptr call (presumably to a runtime function) that we didn't restore, and thus didn't handle properly.

@wsmoses
Member

wsmoses commented Nov 30, 2024

Oh no, actually this is an instance of the deferred codegen not triggering or something?

In particular, I presume the inttoptr call comes from the inner AD.

We need to fix this, but just for fun, what happens if you use set_abi(Reverse, InlineABI) for the innermost call?
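A minimal, self-contained sketch of what that experiment looks like. It assumes set_abi and InlineABI are reachable as Enzyme.set_abi and Enzyme.InlineABI in the installed Enzyme version. The Lux model is replaced by a hypothetical scalar function g(x, p) = tanh(p ⋅ x) so the example runs without Lux or ComponentArrays. As in batch_error, the inner reverse pass takes the gradient with respect to the input, and the outer reverse pass differentiates the sum of its squares with respect to the parameters. This is only an illustration of where the mode switch goes, not a confirmed fix for the original model.

```julia
using Enzyme

# Hypothetical stand-in for the Lux model: g(x, p) = tanh(p ⋅ x).
# An explicit loop is used instead of LinearAlgebra.dot so that no BLAS
# call ends up inside the nested derivative.
function g(x, p)
    s = zero(eltype(x))
    for i in eachindex(x, p)
        s += p[i] * x[i]
    end
    return tanh(s)
end

# Mirrors one iteration of batch_error: the inner reverse pass fills
# dx with the gradient of g w.r.t. x, and the function returns sum(dx .^ 2).
# Assumption: the innermost call uses set_abi(Reverse, InlineABI), as suggested above.
function sq_input_grad(x, p)
    dx = zero(x)
    inner_mode = Enzyme.set_abi(Enzyme.Reverse, Enzyme.InlineABI)
    Enzyme.autodiff(inner_mode, g, Active, Duplicated(x, dx), Duplicated(p, zero(p)))
    return sum(dx .^ 2)
end

# Outer reverse pass (reverse-over-reverse): dp accumulates the gradient
# of sq_input_grad w.r.t. p.
x = [0.3, -0.7]
p = [0.5, 1.5]
dp = zero(p)
Enzyme.autodiff(Enzyme.Reverse, sq_input_grad, Active, Const(x), Duplicated(p, dp))
println(dp)
```

In the reproducer, the corresponding change would be passing Enzyme.set_abi(Enzyme.Reverse, Enzyme.InlineABI) instead of Enzyme.Reverse to the autodiff call inside batch_error, leaving the outer call unchanged.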
