Skip to content

Commit

Permalink
Add PolyesterForwardDiff Support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 17, 2024
1 parent 4ad919a commit 7c1264f
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 14 deletions.
14 changes: 9 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.15.1"
version = "2.16.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -27,17 +27,19 @@ VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SparseDiffToolsEnzymeExt = "Enzyme"
SparseDiffToolsPolyesterForwardDiffExt = "PolyesterForwardDiff"
SparseDiffToolsSymbolicsExt = "Symbolics"
SparseDiffToolsZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.1"
Adapt = "3.0, 4"
ADTypes = "0.2.6"
Adapt = "3, 4"
ArrayInterface = "7.4.2"
Compat = "4"
DataStructures = "0.18"
Expand All @@ -47,7 +49,8 @@ ForwardDiff = "0.10"
Graphs = "1"
LinearAlgebra = "<0.0.1, 1"
PackageExtensionCompat = "1"
Random = "<0.0.1, 1"
PolyesterForwardDiff = "0.1.1"
Random = "1.6"
Reexport = "1"
SciMLOperators = "0.3.7"
Setfield = "1"
Expand All @@ -67,6 +70,7 @@ BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -75,4 +79,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays", "PolyesterForwardDiff"]
77 changes: 77 additions & 0 deletions ext/SparseDiffToolsPolyesterForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
module SparseDiffToolsPolyesterForwardDiffExt

using ADTypes, SparseDiffTools, PolyesterForwardDiff
import ForwardDiff
import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection,
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache, sparse_jacobian!,
sparse_jacobian_static_array, __standard_tag, __chunksize

struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
AbstractMaybeSparseJacobianCache
coloring::CO
cache::CA
jac_prototype::J
fx::FX
x::X
end

function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,
AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f::F, x;
fx = nothing) where {F}
coloring_result = sd(ad, f, x)
fx = fx === nothing ? similar(f(x)) : fx
if coloring_result isa NoMatrixColoring
cache = __chunksize(ad, x)
jac_prototype = nothing
else
@warn """Currently PolyesterForwardDiff does not support sparsity detection
natively. Falling back to using ForwardDiff.jl""" maxlog=1
tag = __standard_tag(nothing, x)
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
jac_prototype = coloring_result.jacobian_sparsity
end
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
end

function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,

Check warning on line 38 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L38

Added line #L38 was not covered by tests
AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f!::F, fx,
x) where {F}
coloring_result = sd(ad, f!, fx, x)
if coloring_result isa NoMatrixColoring
cache = __chunksize(ad, x)
jac_prototype = nothing

Check warning on line 44 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L41-L44

Added lines #L41 - L44 were not covered by tests
else
@warn """Currently PolyesterForwardDiff does not support sparsity detection

Check warning on line 46 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L46

Added line #L46 was not covered by tests
natively. Falling back to using ForwardDiff.jl""" maxlog=1
tag = __standard_tag(nothing, x)

Check warning on line 48 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L48

Added line #L48 was not covered by tests
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,

Check warning on line 50 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L50

Added line #L50 was not covered by tests
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
jac_prototype = coloring_result.jacobian_sparsity

Check warning on line 52 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L52

Added line #L52 was not covered by tests
end
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)

Check warning on line 54 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L54

Added line #L54 was not covered by tests
end

function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
f::F, x) where {F}
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
else
PolyesterForwardDiff.threaded_jacobian!(f, J, x, cache.cache) # Don't try to exploit sparsity
end
return J
end

function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,

Check warning on line 67 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L67

Added line #L67 was not covered by tests
f!::F, fx, x) where {F}
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff

Check warning on line 70 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
else
PolyesterForwardDiff.threaded_jacobian!(f!, fx, J, x, cache.cache) # Don't try to exploit sparsity

Check warning on line 72 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L72

Added line #L72 was not covered by tests
end
return J

Check warning on line 74 in ext/SparseDiffToolsPolyesterForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SparseDiffToolsPolyesterForwardDiffExt.jl#L74

Added line #L74 was not covered by tests
end

end
7 changes: 5 additions & 2 deletions src/highlevel/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ function init_jacobian end
const __init_𝒥 = init_jacobian

# Misc Functions
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C}}, x) where {C}
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}}, x) where {C}
C isa ForwardDiff.Chunk && return C
return __chunksize(Val(C), x)
end
Expand All @@ -285,7 +286,8 @@ end
__chunksize(x) = ForwardDiff.Chunk(x)
__chunksize(x::StaticArray) = ForwardDiff.Chunk{ForwardDiff.pickchunksize(prod(Size(x)))}()

function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C}}) where {C}
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}}) where {C}
C === nothing && return nothing
C isa Integer && !(C isa Bool) && return C 0 ? nothing : Val(C)
return nothing
Expand Down Expand Up @@ -347,4 +349,5 @@ end
@inline __backend(::Union{AutoEnzyme, AutoSparseEnzyme}) = :Enzyme
@inline __backend(::Union{AutoZygote, AutoSparseZygote}) = :Zygote
@inline __backend(::Union{AutoForwardDiff, AutoSparseForwardDiff}) = :ForwardDiff
@inline __backend(::Union{AutoPolyesterForwardDiff, AutoSparsePolyesterForwardDiff}) = :PolyesterForwardDiff

Check warning on line 352 in src/highlevel/common.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/common.jl#L352

Added line #L352 was not covered by tests
@inline __backend(::Union{AutoFiniteDiff, AutoSparseFiniteDiff}) = :FiniteDiff
27 changes: 20 additions & 7 deletions test/test_sparse_jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Sparse Jacobian tests
using SparseDiffTools,
Symbolics, ForwardDiff, LinearAlgebra, SparseArrays, Zygote, Enzyme, Test, StaticArrays
using SparseDiffTools, PolyesterForwardDiff, Symbolics, ForwardDiff, LinearAlgebra,
SparseArrays, Zygote, Enzyme, Test, StaticArrays

@views function fdiff(y, x) # in-place
L = length(x)
Expand Down Expand Up @@ -42,7 +42,12 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
AutoZygote(), AutoSparseForwardDiff(), AutoForwardDiff(),
AutoSparseForwardDiff(; chunksize = 0), AutoForwardDiff(; chunksize = 0),
AutoSparseForwardDiff(; chunksize = 4), AutoForwardDiff(; chunksize = 4),
AutoSparseFiniteDiff(), AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
AutoSparsePolyesterForwardDiff(), AutoPolyesterForwardDiff(),
AutoSparsePolyesterForwardDiff(; chunksize = 0),
AutoPolyesterForwardDiff(; chunksize = 0),
AutoSparsePolyesterForwardDiff(; chunksize = 4),
AutoPolyesterForwardDiff(; chunksize = 4), AutoSparseFiniteDiff(),
AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
@testset "Cache & Reuse" begin
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
J = init_jacobian(cache)
Expand All @@ -59,7 +64,9 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa

@test J J_true

if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
difftype isa AutoSparsePolyesterForwardDiff ||
difftype isa AutoPolyesterForwardDiff)
@inferred sparse_jacobian(difftype, cache, fdiff, x)
end

Expand All @@ -71,7 +78,9 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
J = sparse_jacobian(difftype, sd, fdiff, x)

@test J J_true
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
difftype isa AutoSparsePolyesterForwardDiff ||
difftype isa AutoPolyesterForwardDiff)
@inferred sparse_jacobian(difftype, sd, fdiff, x)
end

Expand Down Expand Up @@ -114,7 +123,9 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
J = sparse_jacobian(difftype, cache, fdiff, y, x)

@test J J_true
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
difftype isa AutoSparsePolyesterForwardDiff ||
difftype isa AutoPolyesterForwardDiff)
@inferred sparse_jacobian(difftype, cache, fdiff, y, x)
end

Expand All @@ -126,7 +137,9 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
J = sparse_jacobian(difftype, sd, fdiff, y, x)

@test J J_true
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
difftype isa AutoSparsePolyesterForwardDiff ||
difftype isa AutoPolyesterForwardDiff)
@inferred sparse_jacobian(difftype, sd, fdiff, y, x)
end

Expand Down

0 comments on commit 7c1264f

Please sign in to comment.