Skip to content

Commit

Permalink
fix: minor fixes to support adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 7, 2024
1 parent 0cbc2fc commit 2c1cdaf
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 4 deletions.
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
NonlinearSolveBaseSparseArraysExt = "SparseArrays"

Expand All @@ -33,6 +35,7 @@ ArrayInterface = "7.9"
CommonSolve = "0.2.4"
Compat = "4.15"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149"
DifferentiationInterface = "0.6.1"
EnzymeCore = "0.8"
FastClosures = "0.3"
Expand Down
16 changes: 16 additions & 0 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module NonlinearSolveBaseDiffEqBaseExt

using DiffEqBase: DiffEqBase
using SciMLBase: remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem

function DiffEqBase.get_concrete_problem(
prob::ImmutableNonlinearProblem, isadapt; kwargs...)
u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs)
u0 = DiffEqBase.promote_u0(u0, prob.p, nothing)
p = DiffEqBase.get_concrete_p(prob, kwargs)
return remake(prob; u0 = u0, p = p)
end

end
5 changes: 3 additions & 2 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ CUDA = "5.3"
ChainRulesCore = "1.24"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.155"
DiffEqBase = "6.149"
DifferentiationInterface = "0.6.1"
Enzyme = "0.13"
ExplicitImports = "1.9"
Expand Down Expand Up @@ -79,6 +79,7 @@ julia = "1.10"
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -95,4 +96,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["AllocCheck", "Aqua", "CUDA", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]
test = ["AllocCheck", "Aqua", "CUDA", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using DiffEqBase: DiffEqBase

using SimpleNonlinearSolve: SimpleNonlinearSolve

SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true

function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...)
return DiffEqBase._solve_adjoint(args...; kwargs...)
end
Expand Down
3 changes: 2 additions & 1 deletion lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testitem "Simple Adjoint Test" tags=[:adjoint] begin
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, DiffEqBase,
SimpleNonlinearSolve

ff(u, p) = u .^ 2 .- p

Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/test/core/allocation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
@test true
catch e
@error e
@test false broken = (alg isa SimpleHalley)
@test false broken=(alg isa SimpleHalley)
end
end
end
Empty file.

0 comments on commit 2c1cdaf

Please sign in to comment.