Skip to content

Commit

Permalink
Make imports explicit, test with ExplicitImports.jl (#188)
Browse files Browse the repository at this point in the history
* Fixes for ExplicitImports.jl

* Add package tests

* Add separate `test/linting.jl` file
  • Loading branch information
adrhill authored Sep 4, 2024
1 parent 116f027 commit 45be93b
Show file tree
Hide file tree
Showing 11 changed files with 293 additions and 52 deletions.
10 changes: 4 additions & 6 deletions ext/SparseConnectivityTracerDataInterpolationsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ if isdefined(Base, :get_extension)
using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using SparseConnectivityTracer: Fill # from FillArrays.jl
using FillArrays: Fill # from FillArrays.jl
import DataInterpolations:
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
Expand All @@ -19,15 +18,14 @@ if isdefined(Base, :get_extension)
BSplineInterpolation,
BSplineApprox,
CubicHermiteSpline,
PCHIPInterpolation,
# PCHIPInterpolation,
QuinticHermiteSpline
else
using ..SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
using ..SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using ..SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using ..SparseConnectivityTracer: Fill # from FillArrays.jl
using ..FillArrays: Fill # from FillArrays.jl
import ..DataInterpolations:
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
Expand All @@ -38,7 +36,7 @@ else
BSplineInterpolation,
BSplineApprox,
CubicHermiteSpline,
PCHIPInterpolation,
# PCHIPInterpolation,
QuinticHermiteSpline
end

Expand Down
54 changes: 52 additions & 2 deletions ext/SparseConnectivityTracerLogExpFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,60 @@ module SparseConnectivityTracerLogExpFunctionsExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using LogExpFunctions
using LogExpFunctions:
LogExpFunctions,
cexpexp,
cloglog,
log1mexp,
log1mlogistic,
log1pexp,
log1pmx,
log1psq,
log2mexp,
logabssinh,
logaddexp,
logcosh,
logexpm1,
logistic,
logit,
logit1mexp,
logitexp,
loglogistic,
logmxp1,
logsubexp,
xexpx,
xexpy,
xlog1py,
xlogx,
xlogy
else
import ..SparseConnectivityTracer as SCT
using ..LogExpFunctions
using ..LogExpFunctions:
LogExpFunctions,
cexpexp,
cloglog,
log1mexp,
log1mlogistic,
log1pexp,
log1pmx,
log1psq,
log2mexp,
logabssinh,
logaddexp,
logcosh,
logexpm1,
logistic,
logit,
logit1mexp,
logitexp,
loglogistic,
logmxp1,
logsubexp,
xexpx,
xexpy,
xlog1py,
xlogx,
xlogy
end

## 1-to-1 functions
Expand Down
52 changes: 50 additions & 2 deletions ext/SparseConnectivityTracerNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,58 @@ module SparseConnectivityTracerNNlibExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using NNlib
using NNlib:
NNlib,
celu,
elu,
gelu,
hardswish,
hardtanh,
hardσ,
leakyrelu,
lisht,
logcosh,
logσ,
mish,
relu,
relu6,
selu,
sigmoid_fast,
softplus,
softshrink,
softsign,
swish,
tanh_fast,
tanhshrink,
trelu,
σ
else
import ..SparseConnectivityTracer as SCT
using ..NNlib
using ..NNlib:
NNlib,
celu,
elu,
gelu,
hardswish,
hardtanh,
hardσ,
leakyrelu,
lisht,
logcosh,
logσ,
mish,
relu,
relu6,
selu,
sigmoid_fast,
softplus,
softshrink,
softsign,
swish,
tanh_fast,
tanhshrink,
trelu,
σ
end

## 1-to-1
Expand Down
4 changes: 2 additions & 2 deletions ext/SparseConnectivityTracerNaNMathExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ module SparseConnectivityTracerNaNMathExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using NaNMath
using NaNMath: NaNMath
else
import ..SparseConnectivityTracer as SCT
using ..NaNMath
using ..NaNMath: NaNMath
end

## 1-to-1
Expand Down
92 changes: 90 additions & 2 deletions ext/SparseConnectivityTracerSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,98 @@ module SparseConnectivityTracerSpecialFunctionsExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using SpecialFunctions
using SpecialFunctions:
SpecialFunctions,
airyai,
airyaiprime,
airyaiprimex,
airyaix,
airybi,
airybiprime,
airybiprimex,
airybix,
besseli,
besselix,
besselj,
besselj0,
besselj1,
besseljx,
besselk,
besselkx,
bessely,
bessely0,
bessely1,
besselyx,
beta,
cosint,
digamma,
ellipe,
ellipk,
erf,
erfc,
erfcinv,
erfcx,
erfinv,
expint,
expinti,
expintx,
gamma,
invdigamma,
jinc,
logbeta,
logerfc,
loggamma,
sinint,
sphericalbesselj,
sphericalbessely,
trigamma
else
import ..SparseConnectivityTracer as SCT
using ..SpecialFunctions
using ..SpecialFunctions:
SpecialFunctions,
airyai,
airyaiprime,
airyaiprimex,
airyaix,
airybi,
airybiprime,
airybiprimex,
airybix,
besseli,
besselix,
besselj,
besselj0,
besselj1,
besseljx,
besselk,
besselkx,
bessely,
bessely0,
bessely1,
besselyx,
beta,
cosint,
digamma,
ellipe,
ellipk,
erf,
erfc,
erfcinv,
erfcx,
erfinv,
expint,
expinti,
expintx,
gamma,
invdigamma,
jinc,
logbeta,
logerfc,
loggamma,
sinint,
sphericalbesselj,
sphericalbessely,
trigamma
end

#=
Expand Down
2 changes: 1 addition & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using LinearAlgebra: LinearAlgebra, Symmetric
using LinearAlgebra: Diagonal, diag, diagind
using FillArrays: Fill

using DocStringExtensions
using DocStringExtensions: DocStringExtensions, TYPEDEF, TYPEDFIELDS

if !isdefined(Base, :get_extension)
using Requires
Expand Down
2 changes: 1 addition & 1 deletion src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ for op in ops_2_to_1_fsc
end

# gradient of x/y: [1/y -x/y²]
SparseConnectivityTracer.is_der1_arg2_zero_local(::typeof(/), x, y) = iszero(x)
is_der1_arg2_zero_local(::typeof(/), x, y) = iszero(x)

# ops_2_to_1_fsz:
# ∂f/∂x != 0
Expand Down
8 changes: 4 additions & 4 deletions src/overloads/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ function LinearAlgebra.eigen(
end

## Inverse
function LinearAlgebra.inv(A::StridedMatrix{T}) where {T<:AbstractTracer}
function Base.inv(A::StridedMatrix{T}) where {T<:AbstractTracer}
LinearAlgebra.checksquare(A)
t = second_order_or(A)
return Fill(t, size(A)...)
end
function LinearAlgebra.inv(D::Diagonal{T}) where {T<:AbstractTracer}
function Base.inv(D::Diagonal{T}) where {T<:AbstractTracer}
ts_in = D.diag
ts_out = similar(ts_in)
for i in 1:length(ts_out)
Expand All @@ -132,7 +132,7 @@ function LinearAlgebra.pinv(
t = second_order_or(A)
return Fill(t, m, n)
end
LinearAlgebra.pinv(D::Diagonal{T}) where {T<:AbstractTracer} = LinearAlgebra.inv(D)
LinearAlgebra.pinv(D::Diagonal{T}) where {T<:AbstractTracer} = inv(D)

## Division
function LinearAlgebra.:\(
Expand All @@ -143,7 +143,7 @@ function LinearAlgebra.:\(
end

## Exponential
function LinearAlgebra.exp(A::AbstractMatrix{T}) where {T<:AbstractTracer}
function Base.exp(A::AbstractMatrix{T}) where {T<:AbstractTracer}
LinearAlgebra.checksquare(A)
n = size(A, 1)
t = second_order_or(A)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand Down
Loading

0 comments on commit 45be93b

Please sign in to comment.