Skip to content

Commit

Permalink
Merge pull request #147 from SciML/ap/explicit_imports
Browse files Browse the repository at this point in the history
Improve Code Standards
  • Loading branch information
avik-pal authored May 26, 2024
2 parents d80b885 + 248088c commit c985602
Show file tree
Hide file tree
Showing 34 changed files with 436 additions and 397 deletions.
3 changes: 2 additions & 1 deletion lib/SimpleNonlinearSolve/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
style = "sciml"
format_markdown = true
annotate_untyped_fields_with_any = false
format_docstrings = true
format_docstrings = true
join_lines_based_on_source = false
4 changes: 3 additions & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ ChainRulesCore = "1.22"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149"
DiffResults = "1.1"
ExplicitImports = "1.5.0"
FastClosures = "0.3.2"
FiniteDiff = "2.22"
ForwardDiff = "0.10.36"
Expand Down Expand Up @@ -73,6 +74,7 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -91,4 +93,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff", "ReverseDiff", "Tracker"]
test = ["AllocCheck", "Aqua", "CUDA", "DiffEqBase", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "Reexport", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "Tracker", "Zygote"]
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
module SimpleNonlinearSolveChainRulesCoreExt

using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
using ChainRulesCore: ChainRulesCore, NoTangent
using DiffEqBase: DiffEqBase
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve

# The expectation here is that no-one is using this directly inside a GPU kernel. We can
# eventually lift this requirement using a custom adjoint
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...;
kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...)
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
function ∇__internal_solve_up(Δ)
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(),
∂args...)
return (
f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...)
end
return out, ∇__internal_solve_up
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
module SimpleNonlinearSolvePolyesterForwardDiffExt

using SimpleNonlinearSolve, PolyesterForwardDiff
using PolyesterForwardDiff: PolyesterForwardDiff
using SimpleNonlinearSolve: SimpleNonlinearSolve

@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true

@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f!::F, y, J, x,
chunksize) where {F}
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(
f!::F, y, J, x, chunksize) where {F}
PolyesterForwardDiff.threaded_jacobian!(f!, y, J, x, chunksize)
return J
end

@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x,
chunksize) where {F}
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(
f::F, J, x, chunksize) where {F}
PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize)
return J
end
Expand Down
99 changes: 53 additions & 46 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,67 @@
module SimpleNonlinearSolveReverseDiffExt

using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve
import ReverseDiff: TrackedArray, TrackedReal
using ArrayInterface: ArrayInterface
using DiffEqBase: DiffEqBase
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve
import SimpleNonlinearSolve: __internal_solve_up

function __internal_solve_up(
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval begin
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end

function __internal_solve_up(
prob::NonlinearProblem, sensealg, u0, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end
function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end

function __internal_solve_up(
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
p, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray,
u0_changed, p, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end

function __internal_solve_up(prob::NonlinearProblem, sensealg,
u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal},
p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end
function __internal_solve_up(
prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed,
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end

function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed,
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end
function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed,
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
true, alg, args...; kwargs...)
end

function __internal_solve_up(prob::NonlinearProblem, sensealg,
u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end
function __internal_solve_up(
prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal},
u0_changed, p, p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
true, alg, args...; kwargs...)
end

ReverseDiff.@grad function __internal_solve_up(
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...)
function ∇__internal_solve_up(_args...)
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
ReverseDiff.@grad function __internal_solve_up(
prob::$(pType), sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
ReverseDiffOriginator(), alg, args...; kwargs...)
function ∇__internal_solve_up(_args...)
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
end
return Array(out), ∇__internal_solve_up
end
end
return Array(out), ∇__internal_solve_up
end

end
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SimpleNonlinearSolveStaticArraysExt

using SimpleNonlinearSolve
using SimpleNonlinearSolve: SimpleNonlinearSolve

@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true

Expand Down
79 changes: 43 additions & 36 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
module SimpleNonlinearSolveTrackerExt

using DiffEqBase, SciMLBase, SimpleNonlinearSolve, Tracker

function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...)
return Tracker.track(
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
end

function SimpleNonlinearSolve.__internal_solve_up(
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return Tracker.track(
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
end

function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
return Tracker.track(
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
end

Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(_prob::NonlinearProblem,
sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...)
u0, p = Tracker.data(u0_), Tracker.data(p_)
prob = remake(_prob; u0, p)
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
SciMLBase.TrackerOriginator(), alg, args...; kwargs...)

function ∇__internal_solve_up(Δ)
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
using DiffEqBase: DiffEqBase
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
using SimpleNonlinearSolve: SimpleNonlinearSolve
using Tracker: Tracker, TrackedArray

for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval begin
function SimpleNonlinearSolve.__internal_solve_up(
prob::$(pType), sensealg, u0::TrackedArray,
u0_changed, p, p_changed, alg, args...; kwargs...)
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
end

function SimpleNonlinearSolve.__internal_solve_up(
prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
end

function SimpleNonlinearSolve.__internal_solve_up(
prob::$(pType), sensealg, u0, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
end

Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(
_prob::$(pType), sensealg, u0_, u0_changed,
p_, p_changed, alg, args...; kwargs...)
u0, p = Tracker.data(u0_), Tracker.data(p_)
prob = remake(_prob; u0, p)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...)

function ∇__internal_solve_up(Δ)
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
end

return out, ∇__internal_solve_up
end
end

return out, ∇__internal_solve_up
end

end
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module SimpleNonlinearSolveZygoteExt

import SimpleNonlinearSolve, Zygote
using SimpleNonlinearSolve: SimpleNonlinearSolve
using Zygote: Zygote

SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true

Expand Down
Loading

0 comments on commit c985602

Please sign in to comment.