diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c24342f6a..fd83d20cf 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -15,6 +15,7 @@ jobs: test: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: group: - Core diff --git a/ext/DiffEqBaseCUDAExt.jl b/ext/DiffEqBaseCUDAExt.jl index 0adfa02bb..3af91fb9c 100644 --- a/ext/DiffEqBaseCUDAExt.jl +++ b/ext/DiffEqBaseCUDAExt.jl @@ -2,7 +2,8 @@ module DiffEqBaseCUDAExt using DiffEqBase, CUDA -function DiffEqBase.ODE_DEFAULT_NORM(u::CuArray{T},t) where {T <: Union{AbstractFloat, Complex}} +function DiffEqBase.ODE_DEFAULT_NORM( + u::CuArray{T}, t) where {T <: Union{AbstractFloat, Complex}} sqrt(sum(DiffEqBase.sse, u; init = DiffEqBase.sse(zero(T))) / DiffEqBase.totallength(u)) end diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 5b3e85d7c..52fe08d9f 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -112,6 +112,65 @@ https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a- end end +const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """ + Failed to automatically detect ForwardDiff compatability of + the parameter object. In order for ForwardDiff.jl automatic + differentiation to work on a solution object, the state of + the differential equation or nonlinear solve (`u0`) needs to + be converted to a Dual type which matches the values being + differentiated. For example, for a loss function loss(p) + where `p`` is a `Vector{Float64}`, this conversion is + equivalent to: + + ```julia + # Convert u0 to match the new Dual element type of `p` + _prob = remake(prob, u0 = eltype(p).(prob.u0)) + ``` + + In most cases, SciML tools are able to do this conversion + automatically. However, it seems you have provided a + parameter type for which this automatic conversion has failed. + + To fix this, you can do the conversion yourself. For example, + if you have a parameter vector being optimized `p` which is + then put into an odd struct, you can manually convert `u0` + to match `p`: + + ```julia + function loss(p) + _prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p)) + sol = solve(_prob, ...) + # do stuff on sol + end + ``` + + Or you can define a dispatch on `DiffEqBase.anyeltypedual` + which tells the system what fields to interpret as the + differentiable parts. For example, to support ODESolutions + as parameters we tell it the data is `sol.u` and `sol.t` via: + + ```julia + function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0) + DiffEqBase.anyeltypedual((sol.u, sol.t)) + end + ``` + + If you have defined this on a common type which should + be more generally supported, please open a pull request + adding this dispatch. If you need help defining this dispatch, + feel free to open an issue. + """ + +struct ForwardDiffAutomaticDetectionFailure <: Exception end + +function Base.showerror(io::IO, e::ForwardDiffAutomaticDetectionFailure) + print(io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE) +end + +function anyeltypedual(::Type{Union{}}) + throw(ForwardDiffAutomaticDetectionFailure()) +end + # Opt out since these are using for preallocation, not differentiation function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, ::Type{Val{counter}} = Val{0}) where {counter} @@ -192,6 +251,17 @@ function anyeltypedual(sol::RecursiveArrayTools.AbstractDiffEqArray, counter = 0 diffeqmapreduce(anyeltypedual, promote_dual, (sol.u, sol.t)) end +function anyeltypedual(prob::Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem}, + ::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}} + anyeltypedual((prob.u0, prob.p, prob.tspan)) +end + +function anyeltypedual( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem}, + ::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}} + anyeltypedual((prob.u0, prob.p)) +end + function anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter} anyeltypedual(typeof(x)) end diff --git a/src/solve.jl b/src/solve.jl index 28ac78270..68893d53a 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1065,7 +1065,7 @@ end function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p, args...; kwargs...) - alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) + alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, p = p, kwargs...) diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 2a244864f..132b24556 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -328,3 +328,10 @@ t = ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Flo @test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:ForwardDiff.Dual} u0 = [1.0 + 1im, 2.0, 3.0] @test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}} + +# Issue https://github.com/SciML/NonlinearSolve.jl/issues/440 +f(u, p, t) = [u[2], 1.5u[1]^2] +ode = ODEProblem(f, [0.0, 0.0], (0, 1)) +@inferred DiffEqBase.anyeltypedual(ode) +ode = NonlinearProblem(f, [0.0, 0.0], (0, 1)) +@inferred DiffEqBase.anyeltypedual(ode)