Adjoint overload #29

Closed
ChrisRackauckas opened this issue Feb 7, 2021 · 10 comments

Comments

@ChrisRackauckas
Member

It's just the implicit function theorem:

https://github.com/mitmath/18335/blob/spring20/notes/adjoint/adjoint.pdf

There is already a generic one for SteadyStateProblem in DiffEqSensitivity, which could be extended to NonlinearProblem. SciML/SciMLSensitivity.jl#389 will work on anything with __solve.
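
For reference, the implicit function theorem argument can be written out as below. This is a sketch in assumed notation that does not appear in the thread: f is the residual, u^*(p) the root, \bar{u} the incoming cotangent on the solution, \lambda an auxiliary adjoint variable, and \bar{p} the resulting cotangent on the parameters. It assumes \partial f / \partial u is invertible at the root.

$$
f(u^*(p), p) = 0
\;\Longrightarrow\;
\frac{\partial f}{\partial u}\,\frac{du^*}{dp} + \frac{\partial f}{\partial p} = 0
\;\Longrightarrow\;
\frac{du^*}{dp} = -\left(\frac{\partial f}{\partial u}\right)^{-1}\frac{\partial f}{\partial p}
$$

$$
\left(\frac{\partial f}{\partial u}\right)^{\top}\lambda = \bar{u},
\qquad
\bar{p} = \left(\frac{du^*}{dp}\right)^{\top}\bar{u} = -\left(\frac{\partial f}{\partial p}\right)^{\top}\lambda
$$

The reverse pass therefore needs one linear solve with the transposed Jacobian at the converged root and does not differentiate through the solver iterations.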

@YingboMa
Member

YingboMa commented May 19, 2021

Currently, NonlinearSolve doesn't work with Zygote due to build_solution.

julia> using ModelingToolkit, Zygote

julia> Zygote.gradient((u, p)->ModelingToolkit.StructuralTransformations.numerical_nlsolve((u,p)->hypot(u, p)-cos(u), u, p), 0.1, 0.2)
ERROR: MethodError: no method matching ndims(::Tuple{Float64})
Closest candidates are:
  ndims(::AbstractAlgebra.MatrixElem{T} where T) at /Users/scheme/.julia/packages/AbstractAlgebra/Boo1X/src/generic/Matrix.jl:441
  ndims(::Base.Iterators.ProductIterator) at iterators.jl:967
  ndims(::Base.Generator) at generator.jl:53
  ...
Stacktrace:
  [1] build_solution(prob::NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, alg::NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, u::Tuple{Float64}, resid::Float64; calculate_error::Bool, retcode::Symbol, original::Nothing, left::Nothing, right::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ SciMLBase ~/src/julia/SciMLBase/src/solutions/nonlinear_solutions.jl:26
  [2] build_solution(prob::NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, alg::NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, u::Tuple{Float64}, resid::Float64)
    @ SciMLBase ~/src/julia/SciMLBase/src/solutions/nonlinear_solutions.jl:25
  [3] (::DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float64, 0, Float64, Float64, NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, Nothing, Nothing}})(Δ::Float64)
    @ DiffEqBase ~/src/julia/DiffEqBase/src/zygote.jl:33
  [4] (::DiffEqBase.var"#150#back#172"{DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float64, 0, Float64, Float64, NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, Nothing, Nothing}}})(Δ::Float64)
    @ DiffEqBase ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ~/src/julia/ModelingToolkit/src/structural_transformation/utils.jl:308 [inlined]
  [6] (::typeof((numerical_nlsolve)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./REPL[60]:1 [inlined]
  [8] (::typeof((#43)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] (::Zygote.var"#41#42"{typeof((#43))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [10] gradient(::Function, ::Float64, ::Vararg{Float64, N} where N)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [11] top-level scope
    @ REPL[60]:1

julia>

@DhairyaLGandhi could you take a look?
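
The first stack frame shows build_solution receiving u::Tuple{Float64} during the pullback, and the MethodError is raised by ndims on that tuple. A minimal sketch using only Base (no SciML packages; the variable name err is just for illustration) reproduces that part in isolation:

```julia
# ndims is defined for numbers and arrays, but not for tuples.
@show ndims(1.0)     # scalars have zero dimensions
@show ndims([1.0])   # a Vector has one dimension

# Calling it on a 1-tuple, the type seen in the stack trace, throws.
err = try
    ndims((1.0,))
catch e
    e
end
@show err isa MethodError
```

This only isolates the failing call. Whether the right fix is to convert the cotangent before it reaches build_solution or to avoid rebuilding the solution in the adjoint is not settled in this thread.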

@ChrisRackauckas
Member Author

Did you load DiffEqSensitivity (using DiffEqSensitivity)?

@YingboMa
Member

How is DiffEqSensitivity related? This is a nonlinear system.

@ChrisRackauckas
Member Author

It should use the SteadyStateProblem adjoint IIRC.

@YingboMa
Member

I derived frule and rrule here: https://gist.github.com/YingboMa/4e4496f828c6a3179004f6d0ca224d2a

Someone just needs to write a performant implementation of it.

@ChrisRackauckas
Member Author

What's wrong with the current implementation of the vjp?

@YingboMa
Member

For performance reasons, the adjoint for numerical_nlsolve should just be a dozen lines of non-allocating code.
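
As a rough illustration of what such a small, allocation-free scalar adjoint might look like, here is a sketch in plain Julia with no packages. It is not the implementation from the linked gist. The names newton, newton_with_pullback, dfdu, and dfdp are hypothetical, the partial derivatives are written by hand, and newton is only a stand-in for numerical_nlsolve. The residual is the one from the Zygote example above.

```julia
# Residual from the example earlier in this thread.
f(u, p) = hypot(u, p) - cos(u)

# Hand-written partial derivatives (an AD tool could supply these instead).
dfdu(u, p) = u / hypot(u, p) + sin(u)
dfdp(u, p) = p / hypot(u, p)

# Plain scalar Newton iteration; a hypothetical stand-in for numerical_nlsolve.
function newton(u, p; tol = 1e-12, maxiter = 100)
    for _ in 1:maxiter
        r = f(u, p)
        abs(r) < tol && break
        u -= r / dfdu(u, p)
    end
    return u
end

# Reverse rule from the implicit function theorem: return the root and a
# pullback that maps the cotangent ubar to pbar = -ubar * (df/dp) / (df/du),
# with both partials evaluated at the converged root.
function newton_with_pullback(u0, p)
    u = newton(u0, p)
    pullback(ubar) = -ubar * dfdp(u, p) / dfdu(u, p)
    return u, pullback
end

u, back = newton_with_pullback(0.1, 0.2)
pbar = back(1.0)   # sensitivity of the root with respect to p
```

Everything here is scalar arithmetic on Float64 values, so the pullback itself does no array allocation. The gradient with respect to the initial guess is zero at a converged root, which is why only the parameter cotangent is returned.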

@ChrisRackauckas
Member Author

I see, so it's just the small-problem issue, and it'll need a specialized form?

@YingboMa
Member

Yeah, exactly.


2 participants