Fix Enzyme integration following minimal example #479

Open
Larbino1 opened this issue Mar 4, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@Larbino1

Larbino1 commented Mar 4, 2024

Enzyme autodiff fails to work with LinearSolve.

**Minimal Reproducible Example 👇**

using Enzyme
using LinearSolve

function testls(A, b, u)
    oa = OperatorAssumptions(true, condition = LinearSolve.OperatorCondition.WellConditioned)
    prob = LinearProblem{true}(A, b, u0=u)
    linsolve = init(prob, oa)
    solve(linsolve)
    return nothing
end


A = [1. 2.; 3. 4.]
b = [1., 2.]
u = zero(b)
dA = deepcopy(A)
db = deepcopy(b)
du = deepcopy(u)
@time testls(A, b, u)

Enzyme.autodiff(Reverse, testls, Duplicated(A, dA), Duplicated(b, db), Duplicated(u, du))

**Error & Stacktrace ⚠️**

```julia
ERROR: Enzyme execution failed.
Enzyme: Augmented forward pass custom rule Tuple{EnzymeCore.EnzymeRules.ConfigWidth{1, true, false, (false, true)}, Const{typeof(solve!)}, Type{Const{SciMLBase.LinearSolution{Float64, 1, Vector{Float64}, Nothing, LinearSolve.DefaultLinearSolver, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LinearSolve.DefaultLinearSolver, LinearSolve.DefaultLinearSolverInit{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Int64}}, Nothing, Nothing, Nothing, LinearAlgebra.SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int32}}, Base.RefValue{Int32}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Base.RefValue{Int64}}, LinearAlgebra.QRPivoted{Float64, Matrix{Float64}, Vector{Float64}, Vector{Int64}}, Krylov.CraigmrSolver{Float64, Float64, Vector{Float64}}, Krylov.LsmrSolver{Float64, Float64, Vector{Float64}}}, IdentityOperator, IdentityOperator, Float64, Bool}, Nothing}}}, Duplicated{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LinearSolve.DefaultLinearSolver, LinearSolve.DefaultLinearSolverInit{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Int64}}, Nothing, Nothing, Nothing, LinearAlgebra.SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}}, 
LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int32}}, Base.RefValue{Int32}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Base.RefValue{Int64}}, LinearAlgebra.QRPivoted{Float64, Matrix{Float64}, Vector{Float64}, Vector{Int64}}, Krylov.CraigmrSolver{Float64, Float64, Vector{Float64}}, Krylov.LsmrSolver{Float64, Float64, Vector{Float64}}}, IdentityOperator, IdentityOperator, Float64, Bool}}} return type mismatch, expected EnzymeCore.EnzymeRules.AugmentedReturn{SciMLBase.LinearSolution{Float64, 1, Vector{Float64}, Nothing, LinearSolve.DefaultLinearSolver, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LinearSolve.DefaultLinearSolver, LinearSolve.DefaultLinearSolverInit{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Int64}}, Nothing, Nothing, Nothing, LinearAlgebra.SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int32}}, Base.RefValue{Int32}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Base.RefValue{Int64}}, LinearAlgebra.QRPivoted{Float64, Matrix{Float64}, Vector{Float64}, Vector{Int64}}, Krylov.CraigmrSolver{Float64, Float64, Vector{Float64}}, Krylov.LsmrSolver{Float64, Float64, Vector{Float64}}}, IdentityOperator, IdentityOperator, Float64, Bool}, Nothing}, Nothing, Any} found EnzymeCore.EnzymeRules.AugmentedReturn{SciMLBase.LinearSolution{Float64, 1, Vector{Float64}, Nothing, LinearSolve.DefaultLinearSolver, 
LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LinearSolve.DefaultLinearSolver, LinearSolve.DefaultLinearSolverInit{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Int64}}, Nothing, Nothing, Nothing, LinearAlgebra.SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int32}}, Base.RefValue{Int32}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Base.RefValue{Int64}}, LinearAlgebra.QRPivoted{Float64, Matrix{Float64}, Vector{Float64}, Vector{Int64}}, Krylov.CraigmrSolver{Float64, Float64, Vector{Float64}}, Krylov.LsmrSolver{Float64, Float64, Vector{Float64}}}, IdentityOperator, IdentityOperator, Float64, Bool}, Nothing}, SciMLBase.LinearSolution{Float64, 1, Vector{Float64}, Nothing, LinearSolve.DefaultLinearSolver, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LinearSolve.DefaultLinearSolver, LinearSolve.DefaultLinearSolverInit{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Int64}}, Nothing, Nothing, Nothing, LinearAlgebra.SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int32}}, Base.RefValue{Int32}}, 
Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Base.RefValue{Int64}}, LinearAlgebra.QRPivoted{Float64, Matrix{Float64}, Vector{Float64}, Vector{Int64}}, Krylov.CraigmrSolver{Float64, Float64, Vector{Float64}}, Krylov.LsmrSolver{Float64, Float64, Vector{Float64}}}, IdentityOperator, IdentityOperator, Float64, Bool}, Nothing}, Tuple{Vector{Float64}, Vector{Float64}, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LinearSolve.DefaultLinearSolver, LinearSolve.DefaultLinearSolverInit{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Int64}}, Nothing, Nothing, Nothing, LinearAlgebra.SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int32}}, Base.RefValue{Int32}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Base.RefValue{Int64}}, LinearAlgebra.QRPivoted{Float64, Matrix{Float64}, Vector{Float64}, Vector{Int64}}, Krylov.CraigmrSolver{Float64, Float64, Vector{Float64}}, Krylov.LsmrSolver{Float64, Float64, Vector{Float64}}}, IdentityOperator, IdentityOperator, Float64, Bool}, Tuple{Matrix{Float64}}, Tuple{Vector{Float64}}}}
Stacktrace:
 [1] #solve#96
   @ ./deprecated.jl:105

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:1317
  [2] #solve#96
    @ ./deprecated.jl:105
  [3] testls
    @ ~/Dropbox (Cambridge University)/shared_Daniel/code/RobotDynamics.jl/examples/gradients/enzyme_experiments.jl:178 [inlined]
  [4] diffejulia_testls_4370wrap
    @ ~/Dropbox (Cambridge University)/shared_Daniel/code/RobotDynamics.jl/examples/gradients/enzyme_experiments.jl:0
  [5] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:5306 [inlined]
  [6] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…}, ::Duplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4984
  [7] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4926
  [8] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:215
  [9] autodiff(::ReverseMode{…}, ::Const{…}, ::Duplicated{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:238
 [10] autodiff(::ReverseMode{…}, ::typeof(testls), ::Duplicated{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:224
 [11] top-level scope
    @ ~/Dropbox (Cambridge University)/shared_Daniel/code/RobotDynamics.jl/examples/gradients/enzyme_experiments.jl:191
Some type information was truncated. Use `show(err)` to see complete types.
```

**Additional context**

See discourse thread

Larbino1 added the bug label Mar 4, 2024
@ChrisRackauckas
Member

Using the return works:

using Enzyme
using LinearSolve

function testls(A, b, u)
    oa = OperatorAssumptions(true, condition = LinearSolve.OperatorCondition.WellConditioned)
    prob = LinearProblem(A, b)
    linsolve = init(prob, LUFactorization(), assumptions = oa)
    cache = solve!(linsolve)
    sum(cache.u)
end


A = [1. 2.; 3. 4.]
b = [1., 2.]
u = zero(b)
dA = deepcopy(A)
db = deepcopy(b)
du = deepcopy(u)
@time testls(A, b, u)
Enzyme.autodiff(Reverse, testls, Duplicated(A, dA), Duplicated(b, db), Duplicated(u, du))

It's just when I change it to use the mutation directly:

function testls(A, b, u)
    oa = OperatorAssumptions(true, condition = LinearSolve.OperatorCondition.WellConditioned)
    prob = LinearProblem(A, b)
    linsolve = init(prob, LUFactorization(), assumptions = oa)
    solve!(linsolve)
    sum(linsolve.u)
end

Enzyme.autodiff(Reverse, testls, Duplicated(A, dA), Duplicated(b, db), Duplicated(u, du))

that it then fails, which is weird because the return is just the mutated linsolve object.

@wsmoses is this an Enzyme limitation: if the return value is ignored, does Enzyme assume a `nothing` return and therefore hit a type mismatch? One thing we can do in LinearSolve.jl is make `solve!` return `nothing`, since it's mutating (this would be breaking, but might be a good change). That would help make sure all codes support Enzyme. Or if there's a way to handle this case, that would be good to know.
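For anyone checking the working variant, a reference gradient for `sum(A \ b)` can be written by hand without Enzyme or LinearSolve. The sketch below uses the standard adjoint of a linear solve plus a finite-difference check. `loss`, `loss_gradient`, and `fd_gradient` are hypothetical helper names, not part of either package. As far as I understand, Enzyme accumulates into the shadow arrays, so a comparison against these values presumably needs `dA` and `db` to start at zero rather than as copies of `A` and `b`.

```julia
using LinearAlgebra

# Loss used in the thread: L(A, b) = sum(A \ b)
loss(A, b) = sum(A \ b)

# Hand-derived reverse-mode gradient (an illustration, not LinearSolve's rule):
#   u = A \ b,  λ = A' \ ∂L/∂u,  ∂L/∂b = λ,  ∂L/∂A = -λ * u'
function loss_gradient(A, b)
    u = A \ b
    λ = A' \ ones(length(u))   # ∂L/∂u is a vector of ones for L = sum(u)
    return -λ * u', λ          # (∂L/∂A, ∂L/∂b)
end

# Central finite differences as an independent check (hypothetical helper)
function fd_gradient(f, x; h = 1e-6)
    g = zero(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

A = [1.0 2.0; 3.0 4.0]
b = [1.0, 2.0]

dA_ref, db_ref = loss_gradient(A, b)
dA_fd = fd_gradient(M -> loss(M, b), A)
db_fd = fd_gradient(v -> loss(A, v), b)
```

For this `A` and `b` the hand-derived values are `db_ref = [-0.5, 0.5]` and `dA_ref = [0 0.25; 0 -0.25]`, and the finite-difference estimates should agree to within the step-size error.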

@wsmoses
Contributor

wsmoses commented Mar 12, 2024 via email
