Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LogExpFunctions package extension #184

Merged
merged 2 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,33 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"

[extensions]
SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations"
SparseConnectivityTracerLogExpFunctionsExt = "LogExpFunctions"
SparseConnectivityTracerNNlibExt = "NNlib"
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
ADTypes = "1"
DataInterpolations = "6.2"
DocStringExtensions = "0.9"
FillArrays = "1"
LinearAlgebra = "<0.0.1, 1"
LogExpFunctions = "0.3"
NNlib = "0.8, 0.9"
Random = "<0.0.1, 1"
Requires = "1.3"
SparseArrays = "<0.0.1, 1"
SpecialFunctions = "2.4"
julia = "1.6"

[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
65 changes: 65 additions & 0 deletions ext/SparseConnectivityTracerLogExpFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module SparseConnectivityTracerLogExpFunctionsExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using LogExpFunctions
else
import ..SparseConnectivityTracer as SCT
using ..LogExpFunctions
end

# 1-to-1 functions

ops_1_to_1 = (
xlogx,
xexpx,
logistic,
logit,
logcosh,
logabssinh,
log1psq,
log1pexp,
log1mexp,
log2mexp,
logexpm1,
softplus,
invsoftplus,
log1pmx,
logmxp1,
cloglog,
cexpexp,
loglogistic,
logitexp,
log1mlogistic,
logit1mexp,
)

for op in ops_1_to_1
T = typeof(op)
@eval SCT.is_der1_zero_global(::$T) = false
@eval SCT.is_der2_zero_global(::$T) = false
end

# 2-to-1 functions

ops_2_to_1 = (xlogy, xlog1py, xexpy, logaddexp, logsubexp)

for op in ops_2_to_1
T = typeof(op)
@eval SCT.is_der1_arg1_zero_global(::$T) = false
@eval SCT.is_der2_arg1_zero_global(::$T) = false
@eval SCT.is_der1_arg2_zero_global(::$T) = false
@eval SCT.is_der2_arg2_zero_global(::$T) = false
@eval SCT.is_der_cross_zero_global(::$T) = false
end

# Generate overloads
eval(SCT.generate_code_1_to_1(:LogExpFunctions, ops_1_to_1))
eval(SCT.generate_code_2_to_1(:LogExpFunctions, ops_2_to_1))

# List operators for later testing
SCT.test_operators_1_to_1(::Val{:LogExpFunctions}) = ops_1_to_1
SCT.test_operators_2_to_1(::Val{:LogExpFunctions}) = ops_2_to_1
SCT.test_operators_1_to_2(::Val{:LogExpFunctions}) = ()

end
3 changes: 3 additions & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ function __init__()
@require NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" include(
"../ext/SparseConnectivityTracerNNlibExt.jl"
)
@require LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" include(
"../ext/SparseConnectivityTracerLogExpFunctionsExt.jl"
)
# NOTE: SparseConnectivityTracerDataInterpolationsExt is not loaded on Julia <1.10
end
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
NLPModelsJuMP = "792afdf1-32c1-5681-94e0-d7bf7a5df49e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand Down
15 changes: 11 additions & 4 deletions test/classification.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SparseConnectivityTracer: # 1-to-1
using SparseConnectivityTracer: # 1-to-1
is_der1_zero_global,
is_der2_zero_global,
is_der1_zero_local,
Expand Down Expand Up @@ -29,6 +29,7 @@ using SparseConnectivityTracer: # testing
test_operators_1_to_2
using SpecialFunctions: SpecialFunctions
using NNlib: NNlib
using LogExpFunctions: LogExpFunctions
using Test
using ForwardDiff: derivative, gradient, hessian

Expand All @@ -43,6 +44,12 @@ random_input(op) = rand()
random_input(::Union{typeof(acosh),typeof(acoth),typeof(acsc),typeof(asec)}) = 1 + rand()
random_input(::typeof(sincosd)) = 180 * rand()

# LogExpFunctions.jl
random_input(::typeof(LogExpFunctions.log1mexp)) = -rand() # log1mexp(x) is defined for x < 0
random_input(::typeof(LogExpFunctions.log2mexp)) = -rand() # log2mexp(x) is defined for x < 0
random_input(::typeof(LogExpFunctions.logitexp)) = -rand() # logitexp(x) is defined for x < 0
random_input(::typeof(LogExpFunctions.logit1mexp)) = -rand() # logit1mexp(x) is defined for x < 0

random_first_input(op) = random_input(op)
random_second_input(op) = random_input(op)

Expand Down Expand Up @@ -90,7 +97,7 @@ function correct_classification_1_to_1(op, x; atol)
end

@testset verbose = true "1-to-1" begin
@testset "$m" for m in (Base, SpecialFunctions, NNlib)
@testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions)
@testset "$op" for op in test_operators_1_to_1(Val(Symbol(m)))
@test all(
correct_classification_1_to_1(op, random_input(op); atol=DEFAULT_ATOL) for
Expand Down Expand Up @@ -133,7 +140,7 @@ function correct_classification_2_to_1(op, x, y; atol)
end

@testset verbose = true "2-to-1" begin
@testset "$m" for m in (Base, SpecialFunctions, NNlib)
@testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions)
@testset "$op" for op in test_operators_2_to_1(Val(Symbol(m)))
@test all(
correct_classification_2_to_1(
Expand Down Expand Up @@ -173,7 +180,7 @@ function correct_classification_1_to_2(op, x; atol)
end

@testset verbose = true "1-to-2" begin
@testset "$m" for m in (Base, SpecialFunctions, NNlib)
@testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions)
@testset "$op" for op in test_operators_1_to_2(Val(Symbol(m)))
@test all(
correct_classification_1_to_2(op, random_input(op); atol=DEFAULT_ATOL) for
Expand Down
98 changes: 98 additions & 0 deletions test/ext/test_LogExpFunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using SparseConnectivityTracer
using LogExpFunctions
using Test

# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("../tracers_definitions.jl")

lef_1_to_1_pos_input = (
xlogx,
logistic,
logit,
log1psq,
log1pexp,
logexpm1,
softplus,
invsoftplus,
log1pmx,
logmxp1,
logcosh,
logabssinh,
cloglog,
cexpexp,
loglogistic,
log1mlogistic,
)
lef_1_to_1_neg_input = (log1mexp, log2mexp, logitexp, logit1mexp)
lef_1_to_1 = union(lef_1_to_1_pos_input, lef_1_to_1_neg_input)
lef_2_to_1 = (xlogy, xlog1py, xexpy, logaddexp, logsubexp)

@testset "Jacobian Global" begin
method = TracerSparsityDetector()
J(f, x) = jacobian_sparsity(f, x, method)

@testset "1-to-1 functions" begin
@testset "$f" for f in lef_1_to_1
@test J(x -> f(x[1]), rand(2)) == [1 0]
end
end
@testset "2-to-1 functions" begin
@testset "$f" for f in lef_2_to_1
@test J(x -> f(x[1], x[2]), rand(3)) == [1 1 0]
end
end
end

@testset "Jacobian Local" begin
method = TracerLocalSparsityDetector()
J(f, x) = jacobian_sparsity(f, x, method)

@testset "1-to-1 functions" begin
@testset "$f" for f in lef_1_to_1_pos_input
@test J(x -> f(x[1]), [0.5, 1.0]) == [1 0]
end
@testset "$f" for f in lef_1_to_1_neg_input
@test J(x -> f(x[1]), [-0.5, 1.0]) == [1 0]
end
end
@testset "2-to-1 functions" begin
@testset "$f" for f in lef_2_to_1
@test J(x -> f(x[1], x[2]), [0.5, 1.0, 2.0]) == [1 1 0]
end
end
end

@testset "Hessian Global" begin
method = TracerSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@testset "1-to-1 functions" begin
@testset "$f" for f in lef_1_to_1
@test H(x -> f(x[1]), rand(2)) == [1 0; 0 0]
end
end
@testset "2-to-1 functions" begin
@testset "$f" for f in lef_2_to_1
@test H(x -> f(x[1], x[2]), rand(3)) == [1 1 0; 1 1 0; 0 0 0]
end
end
end

@testset "Hessian Local" begin
method = TracerLocalSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@testset "1-to-1 functions" begin
@testset "$f" for f in lef_1_to_1_pos_input
@test H(x -> f(x[1]), [0.5, 1.0]) == [1 0; 0 0]
end
@testset "$f" for f in lef_1_to_1_neg_input
@test H(x -> f(x[1]), [-0.5, 1.0]) == [1 0; 0 0]
end
end
@testset "2-to-1 functions" begin
@testset "$f" for f in lef_2_to_1
@test H(x -> f(x[1], x[2]), [0.5, 1.0, 2.0]) == [1 1 0; 1 1 0; 0 0 0]
end
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core")
if GROUP in ("Core", "All")
@info "Testing package extensions..."
@testset verbose = true "Package extensions" begin
for ext in (:NNlib, :SpecialFunctions)
for ext in (:NNlib, :SpecialFunctions, :LogExpFunctions)
@testset "$ext" begin
@info "...$ext"
include("ext/test_$ext.jl")
Expand Down
Loading