Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dual detection and error message #1058

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
group:
- Core
Expand Down
3 changes: 2 additions & 1 deletion ext/DiffEqBaseCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module DiffEqBaseCUDAExt

using DiffEqBase, CUDA

function DiffEqBase.ODE_DEFAULT_NORM(u::CuArray{T},t) where {T <: Union{AbstractFloat, Complex}}
function DiffEqBase.ODE_DEFAULT_NORM(
u::CuArray{T}, t) where {T <: Union{AbstractFloat, Complex}}
sqrt(sum(DiffEqBase.sse, u; init = DiffEqBase.sse(zero(T))) / DiffEqBase.totallength(u))
end

Expand Down
70 changes: 70 additions & 0 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,65 @@ https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-
end
end

const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """
Failed to automatically detect ForwardDiff compatability of
the parameter object. In order for ForwardDiff.jl automatic
differentiation to work on a solution object, the state of
the differential equation or nonlinear solve (`u0`) needs to
be converted to a Dual type which matches the values being
differentiated. For example, for a loss function loss(p)
where `p`` is a `Vector{Float64}`, this conversion is
equivalent to:

```julia
# Convert u0 to match the new Dual element type of `p`
_prob = remake(prob, u0 = eltype(p).(prob.u0))
```

In most cases, SciML tools are able to do this conversion
automatically. However, it seems you have provided a
parameter type for which this automatic conversion has failed.

To fix this, you can do the conversion yourself. For example,
if you have a parameter vector being optimized `p` which is
then put into an odd struct, you can manually convert `u0`
to match `p`:

```julia
function loss(p)
_prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p))
sol = solve(_prob, ...)
# do stuff on sol
end
```

Or you can define a dispatch on `DiffEqBase.anyeltypedual`
which tells the system what fields to interpret as the
differentiable parts. For example, to support ODESolutions
as parameters we tell it the data is `sol.u` and `sol.t` via:

```julia
function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0)
DiffEqBase.anyeltypedual((sol.u, sol.t))
end
```

If you have defined this on a common type which should
be more generally supported, please open a pull request
adding this dispatch. If you need help defining this dispatch,
feel free to open an issue.
"""

struct ForwardDiffAutomaticDetectionFailure <: Exception end

function Base.showerror(io::IO, e::ForwardDiffAutomaticDetectionFailure)
print(io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE)
end

function anyeltypedual(::Type{Union{}})
throw(ForwardDiffAutomaticDetectionFailure())
end

# Opt out since these are using for preallocation, not differentiation
function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module},
::Type{Val{counter}} = Val{0}) where {counter}
Expand Down Expand Up @@ -192,6 +251,17 @@ function anyeltypedual(sol::RecursiveArrayTools.AbstractDiffEqArray, counter = 0
diffeqmapreduce(anyeltypedual, promote_dual, (sol.u, sol.t))
end

function anyeltypedual(prob::Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem},
::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}}
anyeltypedual((prob.u0, prob.p, prob.tspan))
end

function anyeltypedual(
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem},
::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}}
anyeltypedual((prob.u0, prob.p))
end

function anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter}
anyeltypedual(typeof(x))
end
Expand Down
2 changes: 1 addition & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ end

function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p,
args...; kwargs...)
alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,
p = p, kwargs...)
Expand Down
7 changes: 7 additions & 0 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,10 @@ t = ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Flo
@test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:ForwardDiff.Dual}
u0 = [1.0 + 1im, 2.0, 3.0]
@test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}

# Issue https://github.com/SciML/NonlinearSolve.jl/issues/440
f(u, p, t) = [u[2], 1.5u[1]^2]
ode = ODEProblem(f, [0.0, 0.0], (0, 1))
@inferred DiffEqBase.anyeltypedual(ode)
ode = NonlinearProblem(f, [0.0, 0.0], (0, 1))
@inferred DiffEqBase.anyeltypedual(ode)
Loading