Adjoint overload #29
Currently:

```julia
julia> using ModelingToolkit, Zygote

julia> Zygote.gradient((u, p)->ModelingToolkit.StructuralTransformations.numerical_nlsolve((u,p)->hypot(u, p)-cos(u), u, p), 0.1, 0.2)
ERROR: MethodError: no method matching ndims(::Tuple{Float64})
Closest candidates are:
  ndims(::AbstractAlgebra.MatrixElem{T} where T) at /Users/scheme/.julia/packages/AbstractAlgebra/Boo1X/src/generic/Matrix.jl:441
  ndims(::Base.Iterators.ProductIterator) at iterators.jl:967
  ndims(::Base.Generator) at generator.jl:53
  ...
Stacktrace:
  [1] build_solution(prob::NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, alg::NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, u::Tuple{Float64}, resid::Float64; calculate_error::Bool, retcode::Symbol, original::Nothing, left::Nothing, right::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ SciMLBase ~/src/julia/SciMLBase/src/solutions/nonlinear_solutions.jl:26
  [2] build_solution(prob::NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, alg::NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, u::Tuple{Float64}, resid::Float64)
    @ SciMLBase ~/src/julia/SciMLBase/src/solutions/nonlinear_solutions.jl:25
  [3] (::DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float64, 0, Float64, Float64, NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, Nothing, Nothing}})(Δ::Float64)
    @ DiffEqBase ~/src/julia/DiffEqBase/src/zygote.jl:33
  [4] (::DiffEqBase.var"#150#back#172"{DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float64, 0, Float64, Float64, NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, Nothing, Nothing}}})(Δ::Float64)
    @ DiffEqBase ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ~/src/julia/ModelingToolkit/src/structural_transformation/utils.jl:308 [inlined]
  [6] (::typeof(∂(numerical_nlsolve)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./REPL[60]:1 [inlined]
  [8] (::typeof(∂(#43)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] (::Zygote.var"#41#42"{typeof(∂(#43))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [10] gradient(::Function, ::Float64, ::Vararg{Float64, N} where N)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [11] top-level scope
    @ REPL[60]:1
```

@DhairyaLGandhi could you take a look?
How is

It should use the SteadyStateProblem adjoint IIRC.

I derived frule and rrule here: https://gist.github.com/YingboMa/4e4496f828c6a3179004f6d0ca224d2a Someone just needs to write a performant implementation of it.
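For reference, here is the standard implicit-function-theorem form of such rules, written independently of the linked gist (which is not reproduced here). It assumes a residual $f(u, p) = 0$ whose solution $u^*(p)$ has a nonsingular Jacobian $\partial f / \partial u$ at the converged point:

```math
\begin{aligned}
f\bigl(u^*(p), p\bigr) = 0
  &\;\Rightarrow\; \frac{\partial f}{\partial u}\,\frac{du^*}{dp} + \frac{\partial f}{\partial p} = 0
  \;\Rightarrow\; \frac{du^*}{dp} = -\Bigl(\frac{\partial f}{\partial u}\Bigr)^{-1}\frac{\partial f}{\partial p}, \\
\text{frule:}\quad \dot u &= -\Bigl(\frac{\partial f}{\partial u}\Bigr)^{-1}\Bigl(\frac{\partial f}{\partial p}\,\dot p\Bigr), \\
\text{rrule:}\quad \Bigl(\frac{\partial f}{\partial u}\Bigr)^{\!\top}\lambda &= \bar u, \qquad \bar p = -\Bigl(\frac{\partial f}{\partial p}\Bigr)^{\!\top}\lambda .
\end{aligned}
```

Either rule costs one linear solve with the Jacobian at the solution instead of differentiating through the solver iterations; in the scalar case it reduces to $\dot u = -(f_p / f_u)\,\dot p$ and $\bar p = -(f_p / f_u)\,\bar u$.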
What's wrong with the current implementation of the vjp? |
For performance reasons, the adjoint for

I see, so it's just an issue with small problems, and it'll need a specialized form?
Yeah, exactly. |
It's just the implicit function theorem
https://github.com/mitmath/18335/blob/spring20/notes/adjoint/adjoint.pdf
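As a concrete illustration of the implicit-function-theorem approach, here is a minimal pure-Julia sketch for the scalar residual from the original report. The helper names (`resid`, `dresid_du`, `dresid_dp`, `newton_solve`, `solve_with_pullback`) are hypothetical and are not part of ModelingToolkit, NonlinearSolve, or DiffEqBase; the plain Newton loop stands in for whatever solver `numerical_nlsolve` actually calls.

```julia
# Minimal sketch with hypothetical helper names (not a package API):
# scalar Newton solve plus an implicit-function-theorem pullback,
# using the residual from the issue.

# Residual from the issue: r(u, p) = hypot(u, p) - cos(u)
resid(u, p) = hypot(u, p) - cos(u)

# Hand-written partials of the residual (valid away from (u, p) = (0, 0))
dresid_du(u, p) = u / hypot(u, p) + sin(u)
dresid_dp(u, p) = p / hypot(u, p)

# Plain Newton iteration for resid(u, p) = 0, starting from u0
function newton_solve(u0, p; tol = 1e-12, maxiters = 100)
    u = u0
    for _ in 1:maxiters
        step = resid(u, p) / dresid_du(u, p)
        u -= step
        abs(step) < tol && return u
    end
    error("Newton iteration did not converge")
end

# Solve, then return the solution and a pullback built from the implicit
# function theorem: du*/dp = -r_p / r_u, so pbar = -(r_p / r_u) * ubar.
# The Newton iterations themselves are never differentiated.
function solve_with_pullback(u0, p)
    u = newton_solve(u0, p)
    dudp = -dresid_dp(u, p) / dresid_du(u, p)
    # Locally the converged root does not depend on u0, so its cotangent is 0.
    pullback(ubar) = (zero(ubar), dudp * ubar)   # returns (u0bar, pbar)
    return u, pullback
end

u, back = solve_with_pullback(0.1, 0.2)
u0bar, pbar = back(1.0)
println("u* = ", u, ", du*/dp = ", pbar)
```

The pullback only needs the partial derivatives at the converged root, which is what keeps this cheap for small problems. In a package, the same (solution, pullback) pair is what a ChainRulesCore `rrule` would return, with an extra `NoTangent()` for the function itself; that wiring is left out of this sketch.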
There is also a generic one for `SteadyStateProblem` in DiffEqSensitivity, which could be extended to `NonlinearProblem`, and SciML/SciMLSensitivity.jl#389 will work on anything with `__solve`.