Skip to content

Commit

Permalink
The next attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic authored and adrhill committed Aug 22, 2024
1 parent 3155a07 commit 598fe2e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
[extensions]
SparseConnectivityTracerNNlibExt = "NNlib"
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"
SparseConnectivityTracerDataInteprolationsExt = "DataInterpolations"
SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations"

[compat]
ADTypes = "1"
Expand Down
32 changes: 11 additions & 21 deletions ext/SparseConnectivityTracerDataInterpolationsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,23 @@ import SparseConnectivityTracer as SCT

if isdefined(Base, :get_extension)
using DataInterpolations
using DataInterpolations: _interpolate, _derivative
else
using ..DataInterpolations
using ..DataInterpolations: _interpolate, _derivative
end

interpolation_types = []
for name in names(DataInterpolations)
if isdefined(DataInterpolations, name)
val = getfield(DataInterpolations, name)
if val isa Type && val <: DataInterpolations.AbstractInterpolation
push!(interpolation_types, val)
end
end
end
operations = [_interpolate, _derivative]

for interpolation_type in interpolation_types
if interpolation_type == ConstantInterpolation
@eval SCT.is_der1_zero_global(::Type{$interpolation_type}) = true
@eval SCT.is_der2_zero_global(::Type{$interpolation_type}) = true
elseif interpolation_type == LinearInterpolation
@eval SCT.is_der1_zero_global(::Type{$interpolation_type}) = false
@eval SCT.is_der2_zero_global(::Type{$interpolation_type}) = true
else
@eval SCT.is_der1_zero_global(::Type{$interpolation_type}) = false
@eval SCT.is_der2_zero_global(::Type{$interpolation_type}) = false
end
for operation in operations
@eval SCT.is_der1_arg1_zero_global(::typeof($operation)) = true
@eval SCT.is_der2_arg1_zero_global(::typeof($operation)) = true
@eval SCT.is_der1_arg2_zero_global(::typeof($operation)) = false
@eval SCT.is_der2_arg2_zero_global(::typeof($operation)) = false
@eval SCT.is_der_cross_zero_global(::typeof($operation)) = true
end

SCT.overload_gradient_1_to_1(:DataInterpolations, interpolation_types)
eval(SCT.overload_gradient_2_to_1(:DataInterpolations, operations))
eval(SCT.overload_hessian_2_to_1(:DataInterpolations, operations))

end # module SparseConnectivityTracerDataInterpolationsExt
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
11 changes: 2 additions & 9 deletions test/test_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,8 @@ using LinearAlgebra: det, dot, logdet
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

interpolation_types = []
for name in names(DataInterpolations)
if isdefined(DataInterpolations, name)
val = getfield(DataInterpolations, name)
if val isa Type && val <: DataInterpolations.AbstractInterpolation
push!(interpolation_types, val)
end
end
end
# Sample of interpolation types
interpolation_types = [ConstantInterpolation, LinearInterpolation, QuadraticInterpolation]

REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int})

Expand Down

0 comments on commit 598fe2e

Please sign in to comment.