Skip to content

Commit

Permalink
Merge pull request #1058 from SciML/dual_detect
Browse files Browse the repository at this point in the history
Improve dual detection and error message
  • Loading branch information
ChrisRackauckas authored Jun 14, 2024
2 parents b9f045c + c6afee5 commit 1ab515f
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
group:
- Core
Expand Down
3 changes: 2 additions & 1 deletion ext/DiffEqBaseCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module DiffEqBaseCUDAExt

using DiffEqBase, CUDA

function DiffEqBase.ODE_DEFAULT_NORM(u::CuArray{T},t) where {T <: Union{AbstractFloat, Complex}}
function DiffEqBase.ODE_DEFAULT_NORM(
u::CuArray{T}, t) where {T <: Union{AbstractFloat, Complex}}
sqrt(sum(DiffEqBase.sse, u; init = DiffEqBase.sse(zero(T))) / DiffEqBase.totallength(u))
end

Expand Down
70 changes: 70 additions & 0 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,65 @@ https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-
end
end

const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """
Failed to automatically detect ForwardDiff compatability of
the parameter object. In order for ForwardDiff.jl automatic
differentiation to work on a solution object, the state of
the differential equation or nonlinear solve (`u0`) needs to
be converted to a Dual type which matches the values being
differentiated. For example, for a loss function loss(p)
where `p`` is a `Vector{Float64}`, this conversion is
equivalent to:
```julia
# Convert u0 to match the new Dual element type of `p`
_prob = remake(prob, u0 = eltype(p).(prob.u0))
```
In most cases, SciML tools are able to do this conversion
automatically. However, it seems you have provided a
parameter type for which this automatic conversion has failed.
To fix this, you can do the conversion yourself. For example,
if you have a parameter vector being optimized `p` which is
then put into an odd struct, you can manually convert `u0`
to match `p`:
```julia
function loss(p)
_prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p))
sol = solve(_prob, ...)
# do stuff on sol
end
```
Or you can define a dispatch on `DiffEqBase.anyeltypedual`
which tells the system what fields to interpret as the
differentiable parts. For example, to support ODESolutions
as parameters we tell it the data is `sol.u` and `sol.t` via:
```julia
function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0)
DiffEqBase.anyeltypedual((sol.u, sol.t))
end
```
If you have defined this on a common type which should
be more generally supported, please open a pull request
adding this dispatch. If you need help defining this dispatch,
feel free to open an issue.
"""

struct ForwardDiffAutomaticDetectionFailure <: Exception end

function Base.showerror(io::IO, e::ForwardDiffAutomaticDetectionFailure)
print(io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE)
end

function anyeltypedual(::Type{Union{}})
throw(ForwardDiffAutomaticDetectionFailure())
end

# Opt out since these are using for preallocation, not differentiation
function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module},
::Type{Val{counter}} = Val{0}) where {counter}
Expand Down Expand Up @@ -192,6 +251,17 @@ function anyeltypedual(sol::RecursiveArrayTools.AbstractDiffEqArray, counter = 0
diffeqmapreduce(anyeltypedual, promote_dual, (sol.u, sol.t))
end

function anyeltypedual(prob::Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem},
::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}}
anyeltypedual((prob.u0, prob.p, prob.tspan))
end

function anyeltypedual(
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem},
::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}}
anyeltypedual((prob.u0, prob.p))
end

function anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter}
anyeltypedual(typeof(x))
end
Expand Down
2 changes: 1 addition & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ end

function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p,
args...; kwargs...)
alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,
p = p, kwargs...)
Expand Down
7 changes: 7 additions & 0 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,10 @@ t = ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Flo
@test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:ForwardDiff.Dual}
u0 = [1.0 + 1im, 2.0, 3.0]
@test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}

# Issue https://github.com/SciML/NonlinearSolve.jl/issues/440
f(u, p, t) = [u[2], 1.5u[1]^2]
ode = ODEProblem(f, [0.0, 0.0], (0, 1))
@inferred DiffEqBase.anyeltypedual(ode)
ode = NonlinearProblem(f, [0.0, 0.0], (0, 1))
@inferred DiffEqBase.anyeltypedual(ode)

0 comments on commit 1ab515f

Please sign in to comment.