From 598fe2e59494393e4ae30041f93a4ebabad8a1d1 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Wed, 21 Aug 2024 14:27:18 +0200 Subject: [PATCH] The next attempt --- Project.toml | 2 +- ...ConnectivityTracerDataInterpolationsExt.jl | 32 +++++++------------ test/Project.toml | 1 + test/test_gradient.jl | 11 ++----- 4 files changed, 15 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 825dce5..659eb51 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" [extensions] SparseConnectivityTracerNNlibExt = "NNlib" SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions" -SparseConnectivityTracerDataInteprolationsExt = "DataInterpolations" +SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations" [compat] ADTypes = "1" diff --git a/ext/SparseConnectivityTracerDataInterpolationsExt.jl b/ext/SparseConnectivityTracerDataInterpolationsExt.jl index 4e4c369..94241f9 100644 --- a/ext/SparseConnectivityTracerDataInterpolationsExt.jl +++ b/ext/SparseConnectivityTracerDataInterpolationsExt.jl @@ -3,33 +3,23 @@ import SparseConnectivityTracer as SCT if isdefined(Base, :get_extension) using DataInterpolations + using DataInterpolations: _interpolate, _derivative else using ..DataInterpolations + using ..DataInterpolations: _interpolate, _derivative end -interpolation_types = [] -for name in names(DataInterpolations) - if isdefined(DataInterpolations, name) - val = getfield(DataInterpolations, name) - if val isa Type && val <: DataInterpolations.AbstractInterpolation - push!(interpolation_types, val) - end - end -end +operations = [_interpolate, _derivative] -for interpolation_type in interpolation_types - if interpolation_type == ConstantInterpolation - @eval SCT.is_der1_zero_global(::Type{$interpolation_type}) = true - @eval SCT.is_der2_zero_global(::Type{$interpolation_type}) = true - elseif interpolation_type == LinearInterpolation - @eval SCT.is_der1_zero_global(::Type{$interpolation_type}) = false - @eval SCT.is_der2_zero_global(::Type{$interpolation_type}) = true - else - @eval SCT.is_der1_zero_global(::Type{$interpolation_type}) = false - @eval SCT.is_der2_zero_global(::Type{$interpolation_type}) = false - end +for operation in operations + @eval SCT.is_der1_arg1_zero_global(::typeof($operation)) = true + @eval SCT.is_der2_arg1_zero_global(::typeof($operation)) = true + @eval SCT.is_der1_arg2_zero_global(::typeof($operation)) = false + @eval SCT.is_der2_arg2_zero_global(::typeof($operation)) = false + @eval SCT.is_der_cross_zero_global(::typeof($operation)) = true end -SCT.overload_gradient_1_to_1(:DataInterpolations, interpolation_types) +eval(SCT.overload_gradient_2_to_1(:DataInterpolations, operations)) +eval(SCT.overload_hessian_2_to_1(:DataInterpolations, operations)) end # module SparseConnectivityTracerDataInterpolationsExt diff --git a/test/Project.toml b/test/Project.toml index fca60bc..ce3bd9f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 8d6f3c2..3f0acfd 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -9,15 +9,8 @@ using LinearAlgebra: det, dot, logdet # Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS include("tracers_definitions.jl") -interpolation_types = [] -for name in names(DataInterpolations) - if isdefined(DataInterpolations, name) - val = getfield(DataInterpolations, name) - if val isa Type && val <: DataInterpolations.AbstractInterpolation - push!(interpolation_types, val) - end - end -end +# Sample of interpolation types +interpolation_types = [ConstantInterpolation, LinearInterpolation, QuadraticInterpolation] REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int})