Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataInterpolations support #178

Merged
merged 35 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
8aaaabd
Add draft of DataInterpolations extension
SouthEndMusic Aug 20, 2024
3155a07
Next attempt
SouthEndMusic Aug 21, 2024
598fe2e
The next attempt
SouthEndMusic Aug 21, 2024
1c282d9
Move tests to `test/ext` folder
adrhill Aug 22, 2024
d3eff69
Add package extrension testset
adrhill Aug 22, 2024
e348a86
Add overloads
adrhill Aug 22, 2024
3d4c73d
Fix for Julia 1.6
adrhill Aug 22, 2024
4c659d0
Specify DataInterp ext. isn't an ideal template
adrhill Aug 22, 2024
78052bd
Test on more types of interpolation
adrhill Aug 22, 2024
1e73da1
Fix path
adrhill Aug 22, 2024
c0076fa
Test local tracers
adrhill Aug 22, 2024
4228dbc
Fix docstrings and testset names
adrhill Aug 22, 2024
7dc48da
Reorganize testset code
adrhill Aug 23, 2024
146fbb9
Add matrix test case
adrhill Aug 23, 2024
d3136fc
Prepare N-dim testing
adrhill Aug 23, 2024
9d444ab
Test output values and shapes
adrhill Aug 23, 2024
ae9749d
Minor fix
adrhill Aug 23, 2024
a322ae5
Test more types of interpolations
adrhill Aug 23, 2024
67e39ad
Subtype `AbstractVector`
adrhill Aug 23, 2024
67532af
Support N-dim interpolations
adrhill Aug 23, 2024
b90e043
Add disclaimer
adrhill Aug 23, 2024
d30086c
Fix for 1.6
adrhill Aug 23, 2024
c5a2291
Fix for 1.6 v2
adrhill Aug 23, 2024
8c8e194
Fix for 1.6 v3
adrhill Aug 23, 2024
7524a02
Reorganize code and add TODOs
adrhill Aug 28, 2024
457a1cc
Add test cases on other input types
adrhill Aug 28, 2024
abe0189
Fix Jacobian tests
adrhill Aug 29, 2024
f603ff4
Fix Hessian tests
adrhill Aug 29, 2024
c1bde9e
Remove TODOs
adrhill Aug 29, 2024
a98b0d6
Remove TODOs v2
adrhill Aug 29, 2024
894da06
Remove overloads on `Dual` when possible
adrhill Aug 29, 2024
99ac119
Remove support for Julia <1.10
adrhill Sep 2, 2024
814a042
Fix ambiguity
adrhill Sep 2, 2024
c993787
Test generation of tests
adrhill Sep 2, 2024
23ee8ab
Remove tests on interpolants containing tracers
adrhill Sep 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,28 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[weakdeps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"

[extensions]
SparseConnectivityTracerNNlibExt = "NNlib"
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"
SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations"

[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
ADTypes = "1"
DataInterpolations = "6"
DocStringExtensions = "0.9"
FillArrays = "1"
LinearAlgebra = "<0.0.1, 1"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/dev/adding_overloads.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ to improve the performance of your functions or to work around some of SCT's [li
This page of the documentation shows you how.

!!! tip "Copy one of our package extensions"
The easiest way to add overloads is to copy one of our [package extensions](https://github.com/adrhill/SparseConnectivityTracer.jl/tree/main/ext) and to modify it.
The easiest way to add overloads is to copy one of our package extensions, [e.g. our NNlib extension](https://github.com/adrhill/SparseConnectivityTracer.jl/blob/main/ext/SparseConnectivityTracerNNlibExt.jl), and to modify it.
Please upstream your additions by opening a pull request! We will help you out to get your feature merged.

## Operator classification
Expand Down
173 changes: 173 additions & 0 deletions ext/SparseConnectivityTracerDataInterpolationsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# WARNING: If you are following the "Adding Overloads" guide's advice to copy an existing package extension,
# copy another, less complicated one!
module SparseConnectivityTracerDataInterpolationsExt

if isdefined(Base, :get_extension)
using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using SparseConnectivityTracer: Fill # from FillArrays.jl
import DataInterpolations:
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
AkimaInterpolation,
ConstantInterpolation,
QuadraticSpline,
CubicSpline,
BSplineInterpolation,
BSplineApprox
# TODO: support when Julia 1.6 is dropped
# CubicHermiteSpline,
# PCHIPInterpolation,
# QuinticHermiteSpline
else
using ..SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
using ..SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using ..SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using ..SparseConnectivityTracer: Fill # from FillArrays.jl
import ..DataInterpolations:
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
AkimaInterpolation,
ConstantInterpolation,
QuadraticSpline,
CubicSpline,
BSplineInterpolation,
BSplineApprox
# TODO: support when Julia 1.6 is dropped
# CubicHermiteSpline,
# PCHIPInterpolation,
# QuinticHermiteSpline
end

#========================#
# General interpolations #
#========================#

# We assume that with the exception of ConstantInterpolation and LinearInterpolation,
# all interpolations have a non-zero second derivative at some point in the input domain.

for I in (
:QuadraticInterpolation,
:LagrangeInterpolation,
:AkimaInterpolation,
:QuadraticSpline,
:CubicSpline,
:BSplineInterpolation,
:BSplineApprox,
# TODO: support when Julia 1.6 is dropped
# :CubicHermiteSpline,
# :QuinticHermiteSpline,
)
# 1D Interpolations (uType<:AbstractVector)
@eval function (interp::$(I){uType})(t::GradientTracer) where {uType<:AbstractVector}
return gradient_tracer_1_to_1(t, false)
end
@eval function (interp::$(I){uType})(t::HessianTracer) where {uType<:AbstractVector}
return hessian_tracer_1_to_1(t, false, false)
end

# ND Interpolations (uType<:AbstractMatrix)
@eval function (interp::$(I){uType})(t::GradientTracer) where {uType<:AbstractMatrix}
t = gradient_tracer_1_to_1(t, false)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end
@eval function (interp::$(I){uType})(t::HessianTracer) where {uType<:AbstractMatrix}
t = hessian_tracer_1_to_1(t, false, false)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end
end

# Some Interpolations require custom overloads on `Dual` due to mutation of caches.
for I in (
:LagrangeInterpolation,
:BSplineInterpolation,
:BSplineApprox,
# TODO: support when Julia 1.6 is dropped
# :CubicHermiteSpline,
# :QuinticHermiteSpline,
)
@eval function (interp::$(I){uType})(d::Dual) where {uType<:AbstractVector}
p = interp(primal(d))
t = interp(tracer(d))
return Dual(p, t)
end

@eval function (interp::$(I){uType})(d::Dual) where {uType<:AbstractMatrix}
p = interp(primal(d))
t = interp(tracer(d))
return Dual.(p, t)
end
end

#=======================#
# ConstantInterpolation #
#=======================#

# 1D Interpolations (uType<:AbstractVector)
function (interp::ConstantInterpolation{uType})(
t::GradientTracer
) where {uType<:AbstractVector}
return gradient_tracer_1_to_1(t, true)
end
function (interp::ConstantInterpolation{uType})(
t::HessianTracer
) where {uType<:AbstractVector}
return hessian_tracer_1_to_1(t, true, true)
end

# ND Interpolations (uType<:AbstractMatrix)
function (interp::ConstantInterpolation{uType})(
t::GradientTracer
) where {uType<:AbstractMatrix}
t = gradient_tracer_1_to_1(t, true)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end
function (interp::ConstantInterpolation{uType})(
t::HessianTracer
) where {uType<:AbstractMatrix}
t = hessian_tracer_1_to_1(t, true, true)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end

#=====================#
# LinearInterpolation #
#=====================#

# 1D Interpolations (uType<:AbstractVector)
function (interp::LinearInterpolation{uType})(
t::GradientTracer
) where {uType<:AbstractVector}
return gradient_tracer_1_to_1(t, false)
end
function (interp::LinearInterpolation{uType})(
t::HessianTracer
) where {uType<:AbstractVector}
return hessian_tracer_1_to_1(t, false, true)
end

# ND Interpolations (uType<:AbstractMatrix)
function (interp::LinearInterpolation{uType})(
t::GradientTracer
) where {uType<:AbstractMatrix}
t = gradient_tracer_1_to_1(t, false)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end
function (interp::LinearInterpolation{uType})(
t::HessianTracer
) where {uType<:AbstractMatrix}
t = hessian_tracer_1_to_1(t, false, true)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end

end # module SparseConnectivityTracerDataInterpolationsExt
3 changes: 3 additions & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ function __init__()
@require NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" include(
"../ext/SparseConnectivityTracerNNlibExt.jl"
)
@require DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" include(
"../ext/SparseConnectivityTracerDataInterpolationsExt.jl"
)
end
end

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
Loading
Loading