From abf86104dbcd9ab64544c03ce11daffd2c8e3fbe Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 2 Sep 2024 17:51:32 +0200 Subject: [PATCH 1/2] Add LogExpFunctions extension --- Project.toml | 16 +++-- ...rseConnectivityTracerLogExpFunctionsExt.jl | 65 +++++++++++++++++++ src/SparseConnectivityTracer.jl | 3 + test/Project.toml | 1 + test/classification.jl | 15 +++-- 5 files changed, 90 insertions(+), 10 deletions(-) create mode 100644 ext/SparseConnectivityTracerLogExpFunctionsExt.jl diff --git a/Project.toml b/Project.toml index dad1279d..292287c9 100644 --- a/Project.toml +++ b/Project.toml @@ -13,29 +13,33 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" [extensions] SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations" +SparseConnectivityTracerLogExpFunctionsExt = "LogExpFunctions" SparseConnectivityTracerNNlibExt = "NNlib" SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions" -[extras] -DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - [compat] ADTypes = "1" DataInterpolations = "6.2" DocStringExtensions = "0.9" FillArrays = "1" LinearAlgebra = "<0.0.1, 1" +LogExpFunctions = "0.3" NNlib = "0.8, 0.9" Random = "<0.0.1, 1" Requires = "1.3" SparseArrays = "<0.0.1, 1" SpecialFunctions = "2.4" julia = "1.6" + +[extras] +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/ext/SparseConnectivityTracerLogExpFunctionsExt.jl b/ext/SparseConnectivityTracerLogExpFunctionsExt.jl new file mode 100644 index 00000000..c4074fa9 --- /dev/null +++ b/ext/SparseConnectivityTracerLogExpFunctionsExt.jl @@ -0,0 +1,65 @@ +module SparseConnectivityTracerLogExpFunctionsExt + +if isdefined(Base, :get_extension) + import SparseConnectivityTracer as SCT + using LogExpFunctions +else + import ..SparseConnectivityTracer as SCT + using ..LogExpFunctions +end + +# 1-to-1 functions + +ops_1_to_1 = ( + xlogx, + xexpx, + logistic, + logit, + logcosh, + logabssinh, + log1psq, + log1pexp, + log1mexp, + log2mexp, + logexpm1, + softplus, + invsoftplus, + log1pmx, + logmxp1, + cloglog, + cexpexp, + loglogistic, + logitexp, + log1mlogistic, + logit1mexp, +) + +for op in ops_1_to_1 + T = typeof(op) + @eval SCT.is_der1_zero_global(::$T) = false + @eval SCT.is_der2_zero_global(::$T) = false +end + +# 2-to-1 functions + +ops_2_to_1 = (xlogy, xlog1py, xexpy, logaddexp, logsubexp) + +for op in ops_2_to_1 + T = typeof(op) + @eval SCT.is_der1_arg1_zero_global(::$T) = false + @eval SCT.is_der2_arg1_zero_global(::$T) = false + @eval SCT.is_der1_arg2_zero_global(::$T) = false + @eval SCT.is_der2_arg2_zero_global(::$T) = false + @eval SCT.is_der_cross_zero_global(::$T) = false +end + +# Generate overloads +eval(SCT.generate_code_1_to_1(:LogExpFunctions, ops_1_to_1)) +eval(SCT.generate_code_2_to_1(:LogExpFunctions, ops_2_to_1)) + +# List operators for later testing +SCT.test_operators_1_to_1(::Val{:LogExpFunctions}) = ops_1_to_1 +SCT.test_operators_2_to_1(::Val{:LogExpFunctions}) = ops_2_to_1 +SCT.test_operators_1_to_2(::Val{:LogExpFunctions}) = () + +end diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index df716472..a17511b7 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -50,6 +50,9 @@ function __init__() @require NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" include( "../ext/SparseConnectivityTracerNNlibExt.jl" ) + @require LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" include( + "../ext/SparseConnectivityTracerLogExpFunctionsExt.jl" + ) # NOTE: SparseConnectivityTracerDataInterpolationsExt is not loaded on Julia <1.10 end end diff --git a/test/Project.toml b/test/Project.toml index ce3bd9ff..d2d23b83 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" NLPModelsJuMP = "792afdf1-32c1-5681-94e0-d7bf7a5df49e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/test/classification.jl b/test/classification.jl index c1c932e1..b53759aa 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -1,4 +1,4 @@ -using SparseConnectivityTracer: # 1-to-1 +using SparseConnectivityTracer: # 1-to-1 is_der1_zero_global, is_der2_zero_global, is_der1_zero_local, @@ -29,6 +29,7 @@ using SparseConnectivityTracer: # testing test_operators_1_to_2 using SpecialFunctions: SpecialFunctions using NNlib: NNlib +using LogExpFunctions: LogExpFunctions using Test using ForwardDiff: derivative, gradient, hessian @@ -43,6 +44,12 @@ random_input(op) = rand() random_input(::Union{typeof(acosh),typeof(acoth),typeof(acsc),typeof(asec)}) = 1 + rand() random_input(::typeof(sincosd)) = 180 * rand() +# LogExpFunctions.jl +random_input(::typeof(LogExpFunctions.log1mexp)) = -rand() # log1mexp(x) is defined for x < 0 +random_input(::typeof(LogExpFunctions.log2mexp)) = -rand() # log2mexp(x) is defined for x < 0 +random_input(::typeof(LogExpFunctions.logitexp)) = -rand() # logitexp(x) is defined for x < 0 +random_input(::typeof(LogExpFunctions.logit1mexp)) = -rand() # logit1mexp(x) is defined for x < 0 + random_first_input(op) = random_input(op) random_second_input(op) = random_input(op) @@ -90,7 +97,7 @@ function correct_classification_1_to_1(op, x; atol) end @testset verbose = true "1-to-1" begin - @testset "$m" for m in (Base, SpecialFunctions, NNlib) + @testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions) @testset "$op" for op in test_operators_1_to_1(Val(Symbol(m))) @test all( correct_classification_1_to_1(op, random_input(op); atol=DEFAULT_ATOL) for @@ -133,7 +140,7 @@ function correct_classification_2_to_1(op, x, y; atol) end @testset verbose = true "2-to-1" begin - @testset "$m" for m in (Base, SpecialFunctions, NNlib) + @testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions) @testset "$op" for op in test_operators_2_to_1(Val(Symbol(m))) @test all( correct_classification_2_to_1( @@ -173,7 +180,7 @@ function correct_classification_1_to_2(op, x; atol) end @testset verbose = true "1-to-2" begin - @testset "$m" for m in (Base, SpecialFunctions, NNlib) + @testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions) @testset "$op" for op in test_operators_1_to_2(Val(Symbol(m))) @test all( correct_classification_1_to_2(op, random_input(op); atol=DEFAULT_ATOL) for From 4b3318089e9fd956c40883cf78b29c0291ac15b9 Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 2 Sep 2024 17:51:38 +0200 Subject: [PATCH 2/2] Add tests --- test/ext/test_LogExpFunctions.jl | 98 ++++++++++++++++++++++++++++++++ test/runtests.jl | 2 +- 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 test/ext/test_LogExpFunctions.jl diff --git a/test/ext/test_LogExpFunctions.jl b/test/ext/test_LogExpFunctions.jl new file mode 100644 index 00000000..7bc8ee9f --- /dev/null +++ b/test/ext/test_LogExpFunctions.jl @@ -0,0 +1,98 @@ +using SparseConnectivityTracer +using LogExpFunctions +using Test + +# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS +include("../tracers_definitions.jl") + +lef_1_to_1_pos_input = ( + xlogx, + logistic, + logit, + log1psq, + log1pexp, + logexpm1, + softplus, + invsoftplus, + log1pmx, + logmxp1, + logcosh, + logabssinh, + cloglog, + cexpexp, + loglogistic, + log1mlogistic, +) +lef_1_to_1_neg_input = (log1mexp, log2mexp, logitexp, logit1mexp) +lef_1_to_1 = union(lef_1_to_1_pos_input, lef_1_to_1_neg_input) +lef_2_to_1 = (xlogy, xlog1py, xexpy, logaddexp, logsubexp) + +@testset "Jacobian Global" begin + method = TracerSparsityDetector() + J(f, x) = jacobian_sparsity(f, x, method) + + @testset "1-to-1 functions" begin + @testset "$f" for f in lef_1_to_1 + @test J(x -> f(x[1]), rand(2)) == [1 0] + end + end + @testset "2-to-1 functions" begin + @testset "$f" for f in lef_2_to_1 + @test J(x -> f(x[1], x[2]), rand(3)) == [1 1 0] + end + end +end + +@testset "Jacobian Local" begin + method = TracerLocalSparsityDetector() + J(f, x) = jacobian_sparsity(f, x, method) + + @testset "1-to-1 functions" begin + @testset "$f" for f in lef_1_to_1_pos_input + @test J(x -> f(x[1]), [0.5, 1.0]) == [1 0] + end + @testset "$f" for f in lef_1_to_1_neg_input + @test J(x -> f(x[1]), [-0.5, 1.0]) == [1 0] + end + end + @testset "2-to-1 functions" begin + @testset "$f" for f in lef_2_to_1 + @test J(x -> f(x[1], x[2]), [0.5, 1.0, 2.0]) == [1 1 0] + end + end +end + +@testset "Hessian Global" begin + method = TracerSparsityDetector() + H(f, x) = hessian_sparsity(f, x, method) + + @testset "1-to-1 functions" begin + @testset "$f" for f in lef_1_to_1 + @test H(x -> f(x[1]), rand(2)) == [1 0; 0 0] + end + end + @testset "2-to-1 functions" begin + @testset "$f" for f in lef_2_to_1 + @test H(x -> f(x[1], x[2]), rand(3)) == [1 1 0; 1 1 0; 0 0 0] + end + end +end + +@testset "Hessian Local" begin + method = TracerLocalSparsityDetector() + H(f, x) = hessian_sparsity(f, x, method) + + @testset "1-to-1 functions" begin + @testset "$f" for f in lef_1_to_1_pos_input + @test H(x -> f(x[1]), [0.5, 1.0]) == [1 0; 0 0] + end + @testset "$f" for f in lef_1_to_1_neg_input + @test H(x -> f(x[1]), [-0.5, 1.0]) == [1 0; 0 0] + end + end + @testset "2-to-1 functions" begin + @testset "$f" for f in lef_2_to_1 + @test H(x -> f(x[1], x[2]), [0.5, 1.0, 2.0]) == [1 1 0; 1 1 0; 0 0 0] + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 85a1196d..89f29629 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -94,7 +94,7 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") if GROUP in ("Core", "All") @info "Testing package extensions..." @testset verbose = true "Package extensions" begin - for ext in (:NNlib, :SpecialFunctions) + for ext in (:NNlib, :SpecialFunctions, :LogExpFunctions) @testset "$ext" begin @info "...$ext" include("ext/test_$ext.jl")