
The end of concrete_solve #610

Closed
ChrisRackauckas opened this issue May 28, 2020 · 8 comments · Fixed by SciML/DiffEqBase.jl#520

Comments

@ChrisRackauckas
Member

Let's end concrete_solve. The questions "why do we have concrete_solve" and "how do we get rid of concrete_solve" are essentially the same question, so let me lay out what we need to fix about solve in order to deprecate concrete_solve:

Adjoints are not safe with interpolation.

This is not something that can or ever will be fixed, because it's just not a good idea. You could add it by adding discrete sensitivity analysis as a way to calculate the derivatives of sol.k w.r.t. parameters, but if you look at what's going on... no. Discrete sensitivity analysis is not only equivalent to AD through the solver, but it would also require an implementation per method. So you might as well AD through the solver if that's what you're looking for. That would also be much more efficient: adjoints plus continuous outputs essentially compute two different versions of reverse-mode AD when it could be done with one, so it's never a good way, computationally or memory-wise, to compute the derivative there.

If dense=false, then a linear interpolation is used and that would be safe, and currently concrete_solve does not allow this safe option. However, this is not something we could blindly do to the user: if outside of an AD context they have a 9th order algorithm but when doing AD it's a 1st order algorithm, that would introduce so many numerical issues it's not even funny. So we can't just set dense=false for the user. But if we don't, and they do use the interpolation, they will get a zero gradient from the values generated by the interpolation, which Zygote then propagates all the way back as zero gradients instead of erroring. Training is then silently turned off for loss functions that depend on the interpolation, and only when dense=true. That is also a major trap, and the big reason why concrete_solve was added in the first place.

Solution

Allow for passing dense = NullInterpolation() and when the solvers see this, they create a post-solution interpolation that errors if you try to use it, saying that this interpolation is not compatible with usage inside of AD, suggesting that you use saveat or resort to AD on the solver itself for this functionality. Downstream packages need to get updated for this. I think most algorithms just pass through dense=dense to build_solution, so it can be handled in DiffEqBase+OrdinaryDiffEq and that should handle everything.
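A minimal sketch of the proposed behavior in plain Julia, with no SciML packages. NullInterpolation follows the name proposed above; ToySolution, LinearInterp and toy_build_solution are hypothetical stand-ins, not the DiffEqBase API. The point is that the solution is still returned, but calling it at an arbitrary time errors with an explanatory message instead of silently producing values that AD treats as having zero gradient.

```julia
# Sketch only: NullInterpolation follows the name proposed in this issue;
# ToySolution, LinearInterp and toy_build_solution are hypothetical stand-ins.
struct NullInterpolation end

# Placeholder for a solver's dense output: piecewise-linear between saved points.
struct LinearInterp
    t::Vector{Float64}
    u::Vector{Float64}
end

function (li::LinearInterp)(τ)
    t, u = li.t, li.u
    t[1] <= τ <= t[end] || throw(DomainError(τ, "outside the saved time span"))
    i = searchsortedlast(t, τ)          # last saved index with t[i] <= τ
    i == length(t) && return u[end]
    θ = (τ - t[i]) / (t[i+1] - t[i])
    return (1 - θ) * u[i] + θ * u[i+1]
end

# The null interpolation refuses to evaluate and tells the user why.
function (::NullInterpolation)(τ)
    error("This solution has no usable interpolation because it was built for " *
          "use inside AD. Use saveat to save the needed time points, or " *
          "differentiate through the solver directly.")
end

struct ToySolution{I}
    t::Vector{Float64}
    u::Vector{Float64}
    interp::I
end

# Calling the solution at time τ forwards to whatever interpolation it carries.
(sol::ToySolution)(τ) = sol.interp(τ)

# Stand-in for build_solution: a NullInterpolation passed as `dense` is stored
# as-is; any other value gets the placeholder interpolation.
function toy_build_solution(t, u; dense = true)
    interp = dense isa NullInterpolation ? dense : LinearInterp(t, u)
    return ToySolution(t, u, interp)
end
```

With dense = NullInterpolation(), sol.t and sol.u remain available, so saveat-style usage keeps working; only the call sol(τ) errors.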

AD pass through

So okay, this brings up the second issue: if we want to make this work, then we need an option so that AD can still work on the solver itself; otherwise we recursively keep capturing the call and sending it to an adjoint method.

Solution

This means we need a sensealg choice, SensitivityPassThrough: when this is seen, the adjoint is bypassed and the original AD call continues into the solver. I think this can be done in ChainRules by having !(SensitivityPassThrough <: AbstractSensitivityAlgorithm) and then only defining the rrule to dispatch on Union{Nothing,AbstractSensitivityAlgorithm}, since then the AD should just ignore the adjoint when it's SensitivityPassThrough (@oxinabox can you confirm?)
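The dispatch trick can be illustrated in plain Julia without ChainRules. This is a minimal sketch, assuming the AD system uses a custom rule whenever a matching method exists and otherwise differentiates the code directly. (In ChainRulesCore the generic fallback rrule returns nothing, which AD systems such as Zygote treat as "no custom rule here".) ToyAdjoint, adjoint_rule and ad_path are hypothetical names; applicable stands in for the AD system's rule lookup.

```julia
# Sketch only: SensitivityPassThrough and AbstractSensitivityAlgorithm follow the
# names in this issue; ToyAdjoint, adjoint_rule and ad_path are hypothetical.
abstract type AbstractSensitivityAlgorithm end

struct ToyAdjoint <: AbstractSensitivityAlgorithm end

# Deliberately NOT a subtype of AbstractSensitivityAlgorithm.
struct SensitivityPassThrough end

# Stand-in for the custom rrule: only defined for these sensealg types.
adjoint_rule(sensealg::Union{Nothing,AbstractSensitivityAlgorithm}) =
    "custom adjoint ($(typeof(sensealg)))"

# Stand-in for the AD system's decision: use the custom rule if a method
# applies, otherwise keep differentiating through the solver's own code.
ad_path(sensealg) =
    applicable(adjoint_rule, sensealg) ? :custom_adjoint : :ad_through_solver
```

Because SensitivityPassThrough sits outside the AbstractSensitivityAlgorithm hierarchy, no extra method is needed to opt out; the absence of a matching rule is the opt-out.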

Differentiation w.r.t. input fields

Lastly, we need to figure out how to differentiate w.r.t. input fields. concrete_solve specifically does:

concrete_solve(prob,alg,u0=prob.u0,p=prob.p;...)

so that (a) it's easier for us to write dispatches to differentiate w.r.t. u0 and p but also (b) it's easier for users to change u0 and p.

Solution

One more recent change is that we now have a system for allowing overrides:

solve(prob,alg,u0=...,p=...)

I think we just need to have a lowering process that's like solve -> _solve_up (which then gets the adjoint definitions) -> internal stuff for handling distributions and all of that.
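A minimal sketch of that lowering, assuming a toy scalar problem and a fixed-step explicit Euler loop in place of a real solver. toy_solve, toy_solve_up and toy_solve_internal are hypothetical names mirroring solve, _solve_up and the internal layer described above. The user-facing call resolves the u0=... and p=... overrides, and the middle layer receives u0 and p as explicit positional arguments, which is where adjoint definitions could dispatch, much as concrete_solve's signature allowed.

```julia
# Sketch only: all names are hypothetical stand-ins, not the DiffEqBase API.
struct ToyProblem{F}
    f::F                          # f(u, p, t) returns du/dt
    u0::Float64
    p::Float64
    tspan::Tuple{Float64,Float64}
end

# User-facing entry point: resolve overrides, then pass u0 and p down explicitly.
function toy_solve(prob::ToyProblem; u0 = prob.u0, p = prob.p, kwargs...)
    return toy_solve_up(prob, u0, p; kwargs...)
end

# Layer that would carry the adjoint definitions: u0 and p are positional,
# so a derivative rule can dispatch on them directly.
function toy_solve_up(prob::ToyProblem, u0, p; kwargs...)
    return toy_solve_internal(prob.f, u0, p, prob.tspan; kwargs...)
end

# Internal layer: fixed-step explicit Euler as a placeholder for a real solver.
function toy_solve_internal(f, u0, p, tspan; dt = 0.1)
    t0, t1 = tspan
    n = round(Int, (t1 - t0) / dt)
    u = u0
    for i in 1:n
        t = t0 + (i - 1) * dt
        u += dt * f(u, p, t)
    end
    return u                      # state at the final time only
end
```

For example, with prob = ToyProblem((u, p, t) -> p, 1.0, 2.0, (0.0, 1.0)), calling toy_solve(prob; p = 0.5) solves with the overridden parameter without rebuilding the problem.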

@oxinabox is there something special that could help here?

@ChrisRackauckas
Member Author

ChrisRackauckas commented May 29, 2020

@oxinabox

I think this can be done in ChainRules by having !(SensitivityPassThrough <: AbstractSensitivityAlgorithm) and then only defining the rrule to dispatch on Union{Nothing,AbstractSensitivityAlgorithm}, since then the AD should just ignore the adjoint when it's SensitivityPassThrough (@oxinabox can you confirm?)

Sounds right to me.

@oxinabox is there something special that could help here?

Not sure, are there particular problems you have still? Seems like you have a solution.
JuliaDiff/ChainRulesCore.jl#68 would open up some more options, but not necessarily useful ones.

@ChrisRackauckas
Member Author

SciML/DiffEqFlux.jl#273

@ChrisRackauckas
Member Author

Not sure, are there particular problems you have still? Seems like you have a solution.

Yeah, I got something working. It's a bit complex but it solves all of these problems so we're good.

@ChrisRackauckas
Member Author

Done. concrete_solve is no more: solve does it all and is safe.

@ChrisRackauckas
Member Author

Thanks, fixed.


3 participants