2nd order AD fails #298

Closed
avik-pal opened this issue Nov 20, 2024 · 6 comments · Fixed by #433

Comments

@avik-pal
Collaborator

using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device()
const cdev = cpu_device()

model = Dense(5 => 5, gelu);
ps, st = Lux.setup(Random.default_rng(), model) |> xdev;
potential = StatefulLuxLayer{true}(model, ps, st)

# Currently EnzymeMLIR doesn't support batching so we force chunksize to 1
function ∇potential(potential, x)
    J = reshape(only(Enzyme.jacobian(Forward, potential, x; chunk=Val(1))), :, length(x))
    J_diag = @allowscalar diag(J)
    return reshape(J_diag, size(x))
end

function ∇²potential(potential, x)
    J = reshape(only(
        Enzyme.jacobian(Forward, Base.Fix1(∇potential, potential), x; chunk=Val(1))
    ), :, length(x))
end

x_ra = randn(Float32, 5, 3) |> xdev

@code_hlo ∇²potential(potential, x_ra)

A non-minimal example taken from LuxDL/Lux.jl#614

@avik-pal
Collaborator Author

Error message:

ERROR: AssertionError: Base.isconcretetype(typ)
Stacktrace:
  [1] abs_typeof(arg::LLVM.LoadInst, partial::Bool, seenphis::Set{LLVM.PHIInst})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/absint.jl:557
  [2] abs_typeof(arg::LLVM.LoadInst)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/absint.jl:283
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7066
  [4] codegen
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:6146 [inlined]
  [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468
  [6] _thunk
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468 [inlined]
  [7] cached_compilation
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8509 [inlined]
  [8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8641
  [9] #s2105#19135
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8778 [inlined]
 [10] 
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:707
 [12] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:633 [inlined]
 [13] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:512 [inlined]
 [14] macro expansion
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2090 [inlined]
 [15] gradient(::ForwardMode{…}, ::StatefulLuxLayer{…}, ::Reactant.TracedRArray{…}; chunk::Val{…}, shadows::Tuple{…})
    @ Enzyme /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:1970
 [16] #jacobian#133
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2177 [inlined]
 [17] jacobian
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2176 [inlined]
 [18] ∇potential(potential::StatefulLuxLayer{…}, x::Reactant.TracedRArray{…})
    @ Main /mnt/software/lux/Lux.jl/docs/src/manual/nested_autodiff_reactant.md:17
 [19] Fix1
    @ ./operators.jl:1127 [inlined]
 [20] #apply#24
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:37 [inlined]
 [21] apply
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined]
 [22] (::Tuple{})(none::Base.Fix1{typeof(∇potential), StatefulLuxLayer{…}}, none::Tuple{Reactant.TracedRArray{…}})
    @ Base.Experimental ./<missing>:0
 [23] (::Reactant.var"#32#42"{Bool, Bool, typeof(Reactant.apply), Tuple{}, Vector{}, Tuple{}})()
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:148
 [24] block!(f::Reactant.var"#32#42"{}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [25] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:120
 [26] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined]
 [27] #make_mlir_fn#25
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:53 [inlined]
 [28] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined]
 [29] overload_autodiff(::ForwardMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…})
    @ Reactant /mnt/software/lux/Reactant.jl/src/Interpreter.jl:373
 [30] autodiff
    @ /mnt/software/lux/Reactant.jl/src/Interpreter.jl:660 [inlined]
 [31] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:512 [inlined]
 [32] macro expansion
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2090 [inlined]
 [33] gradient(::ForwardMode{…}, ::Base.Fix1{…}, ::Reactant.TracedRArray{…}; chunk::Val{…}, shadows::Tuple{…})
    @ Enzyme /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:1970
 [34] #jacobian#133
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2177 [inlined]
 [35] jacobian
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:2176 [inlined]
 [36] ∇²potential
    @ /mnt/software/lux/Lux.jl/docs/src/manual/nested_autodiff_reactant.md:23 [inlined]
 [37] (::Tuple{})(none::StatefulLuxLayer{…}, none::Reactant.TracedRArray{…})
    @ Base.Experimental ./<missing>:0
 [38] (::Reactant.var"#32#42"{Bool, Bool, typeof(∇²potential), Tuple{}, Vector{}, Tuple{}})()
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:157
 [39] block!(f::Reactant.var"#32#42"{}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [40] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:120
 [41] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:40 [inlined]
 [42] #10
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:286 [inlined]
 [43] block!(f::Reactant.Compiler.var"#10#15"{typeof(∇²potential), Tuple{}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [44] #9
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:285 [inlined]
 [45] mmodule!(f::Reactant.Compiler.var"#9#14"{}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:92
 [46] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:282
 [47] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:281 [inlined]
 [48] #6
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:276 [inlined]
 [49] context!(f::Reactant.Compiler.var"#6#7"{@Kwargs{}, typeof(∇²potential), Tuple{}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
 [50] compile_mlir(f::Function, args::Tuple{StatefulLuxLayer{…}, ConcreteRArray{…}}; kwargs::@Kwargs{optimize::Bool})
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:274
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Member

wsmoses commented Nov 21, 2024

Just for fun, what happens if you use set_abi(Forward, ReactantABI)?

@avik-pal
Collaborator Author

That did work!
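
For readers following along, here is a minimal sketch of what that change might look like when applied to the reproducer above. The thread does not show the code that was actually run, so this is an assumption: it assumes `set_abi` is reachable as `Enzyme.set_abi`, that the ABI type is reachable as `Reactant.ReactantABI`, and that the new mode is used at both `jacobian` callsites. The name `ForwardRA` is hypothetical.

```julia
using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device()

model = Dense(5 => 5, gelu);
ps, st = Lux.setup(Random.default_rng(), model) |> xdev;
potential = StatefulLuxLayer{true}(model, ps, st)

# Assumption: forward mode tagged with Reactant's ABI, as suggested in the thread.
# `ForwardRA` is a hypothetical name introduced only for this sketch.
const ForwardRA = Enzyme.set_abi(Forward, Reactant.ReactantABI)

# Same as the reproducer, with `Forward` swapped for `ForwardRA`.
function ∇potential(potential, x)
    J = reshape(only(Enzyme.jacobian(ForwardRA, potential, x; chunk=Val(1))), :, length(x))
    J_diag = @allowscalar diag(J)
    return reshape(J_diag, size(x))
end

function ∇²potential(potential, x)
    J = only(Enzyme.jacobian(ForwardRA, Base.Fix1(∇potential, potential), x; chunk=Val(1)))
    return reshape(J, :, length(x))
end

x_ra = randn(Float32, 5, 3) |> xdev

@code_hlo ∇²potential(potential, x_ra)
```

Per the later comments in this thread, this workaround should no longer be needed on newer Reactant versions.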

@wsmoses
Member

wsmoses commented Nov 22, 2024

Yeah, so this again stems from "any abstract-interpreter-based shenanigans fail to go through type-unstable code".

The resolution we implemented earlier is to have our abstract interpreter (absint) replace Forward with set_abi(Forward, ReactantABI). This makes things much nicer (the replacement is also done at the callsite of autodiff/jacobian/etc.), so type-unstable intermediates don't cause any issues. Similarly, it means we can call it natively as above. Unfortunately, this only applies at the top-level absint.

Probably the solution here is to have the absint replace type-unstable calls with my_call(...), which itself runs things again in an absint.
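
As a small illustration of what "type unstable" means here, the sketch below uses only Base Julia. It is not the code path that fails inside Enzyme; the functions `stable` and `unstable` are hypothetical and only show the kind of non-concrete inferred type that a check like the `Base.isconcretetype(typ)` assertion in the stack trace rejects.

```julia
# Type-stable: both branches return a Float64.
stable(x) = x > 0 ? 1.0 : 2.0

# Type-unstable: one branch returns an Int, the other a Float64,
# so inference can only conclude a Union of the two.
unstable(x) = x > 0 ? 1 : 1.0

# Ask the compiler which return type it infers for an Int argument.
stable_ret = only(Base.return_types(stable, (Int,)))
unstable_ret = only(Base.return_types(unstable, (Int,)))

println(stable_ret, " is concrete: ", isconcretetype(stable_ret))
println(unstable_ret, " is concrete: ", isconcretetype(unstable_ret))
```

The first inferred type is concrete, while the second is a Union and therefore not concrete.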

@wsmoses
Member

wsmoses commented Dec 30, 2024

Incidentally, now that @jumerckx's initial batching work is in, it would be interesting to see if batched AD works here.

Also, now that the absint stuff is in, presumably the set_abi stuff shouldn't be needed anymore.

@avik-pal
Collaborator Author

I checked: set_abi is not needed, but we need stack to work (#433) for jacobian.
