Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix isless and update tests #161

Merged
merged 3 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/overloads/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ for fn in (
:isfinite,
:isinf,
:isinteger,
:isless,
:ismissing,
:isnan,
:isnothing,
Expand All @@ -25,3 +24,6 @@ for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=))
@eval Base.$fn(dx::D, y::Real) where {D<:Dual} = $fn(primal(dx), y)
@eval Base.$fn(x::Real, dy::D) where {D<:Dual} = $fn(x, primal(dy))
end

# In some cases, more specialized methods are needed
Base.isless(dx::D, y::AbstractFloat) where {D<:Dual} = isless(primal(dx), y)
250 changes: 114 additions & 136 deletions test/test_gradient.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using SparseConnectivityTracer
using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input
using ADTypes: jacobian_sparsity
using LinearAlgebra: det, dot, logdet
using SpecialFunctions: erf, beta
using NNlib: NNlib
Expand Down Expand Up @@ -38,52 +37,54 @@ NNLIB_ACTIVATIONS_F = (
)
NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F)

REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int})

@testset "Jacobian Global" begin
@testset "$P" for P in GRADIENT_PATTERNS
T = GradientTracer{P}
method = TracerSparsityDetector(; gradient_tracer_type=T)
J(f, x) = jacobian_sparsity(f, x, method)

f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]
@test jacobian_sparsity(f, rand(3), method) == [1 0 0; 1 1 0; 0 0 1]
@test jacobian_sparsity(identity, rand(), method) ≈ [1;;]
@test jacobian_sparsity(Returns(1), 1, method) ≈ [0;;]
@test J(f, rand(3)) == [1 0 0; 1 1 0; 0 0 1]
@test J(identity, rand()) ≈ [1;;]
@test J(Returns(1), 1) ≈ [0;;]

# Test GradientTracer on functions with zero derivatives
x = rand(2)
g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])]
@test jacobian_sparsity(g, x, method) == [1 1; 0 0; 1 0]
@test jacobian_sparsity(!, true, method) ≈ [0;;]
@test J(g, x) == [1 1; 0 0; 1 0]
@test J(!, true) ≈ [0;;]

# Code coverage
@test jacobian_sparsity(x -> [sincos(x)...], 1, method) ≈ [1; 1]
@test jacobian_sparsity(typemax, 1, method) ≈ [0;;]
@test jacobian_sparsity(x -> x^(2//3), 1, method) ≈ [1;;]
@test jacobian_sparsity(x -> (2//3)^x, 1, method) ≈ [1;;]
@test jacobian_sparsity(x -> x^ℯ, 1, method) ≈ [1;;]
@test jacobian_sparsity(x -> ℯ^x, 1, method) ≈ [1;;]
@test jacobian_sparsity(x -> round(x, RoundNearestTiesUp), 1, method) ≈ [0;;]
@test jacobian_sparsity(x -> 0, 1, method) ≈ [0;;]
@test J(x -> [sincos(x)...], 1) ≈ [1; 1]
@test J(typemax, 1) ≈ [0;;]
@test J(x -> x^(2//3), 1) ≈ [1;;]
@test J(x -> (2//3)^x, 1) ≈ [1;;]
@test J(x -> x^ℯ, 1) ≈ [1;;]
@test J(x -> ℯ^x, 1) ≈ [1;;]
@test J(x -> round(x, RoundNearestTiesUp), 1) ≈ [0;;]
@test J(x -> 0, 1) ≈ [0;;]

# Test special cases on empty tracer
@test jacobian_sparsity(x -> zero(x)^(2//3), 1, method) ≈ [0;;]
@test jacobian_sparsity(x -> (2//3)^zero(x), 1, method) ≈ [0;;]
@test jacobian_sparsity(x -> zero(x)^ℯ, 1, method) ≈ [0;;]
@test jacobian_sparsity(x -> ℯ^zero(x), 1, method) ≈ [0;;]
@test J(x -> zero(x)^(2//3), 1) ≈ [0;;]
@test J(x -> (2//3)^zero(x), 1) ≈ [0;;]
@test J(x -> zero(x)^ℯ, 1) ≈ [0;;]
@test J(x -> ℯ^zero(x), 1) ≈ [0;;]

# Linear Algebra
@test jacobian_sparsity(x -> dot(x[1:2], x[4:5]), rand(5), method) == [1 1 0 1 1]
@test J(x -> dot(x[1:2], x[4:5]), rand(5)) == [1 1 0 1 1]

# SpecialFunctions extension
@test jacobian_sparsity(x -> erf(x[1]), rand(2), method) == [1 0]
@test jacobian_sparsity(x -> beta(x[1], x[2]), rand(3), method) == [1 1 0]
@test J(x -> erf(x[1]), rand(2)) == [1 0]
@test J(x -> beta(x[1], x[2]), rand(3)) == [1 1 0]

# Missing primal errors
@testset "MissingPrimalError on $f" for f in (
iseven,
isfinite,
isinf,
isinteger,
isless,
ismissing,
isnan,
isnothing,
Expand All @@ -92,27 +93,20 @@ NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F)
isreal,
iszero,
)
@test_throws MissingPrimalError jacobian_sparsity(f, rand(), method)
@test_throws MissingPrimalError J(f, rand())
end

# NNlib extension
for f in NNLIB_ACTIVATIONS
@test jacobian_sparsity(f, 1, method) ≈ [1;;]
@test J(f, 1) ≈ [1;;]
end

# ifelse and comparisons
if VERSION >= v"1.8"
@test jacobian_sparsity(
x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4], method
) == [1 1 1 1]

@test jacobian_sparsity(
x -> ifelse(x[2] < x[3], x[1] + x[2], 1.0), [1 2 3 4], method
) == [1 1 0 0]

@test jacobian_sparsity(
x -> ifelse(x[2] < x[3], 1.0, x[3] * x[4]), [1 2 3 4], method
) == [0 0 1 1]
@test J(x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4]) ==
[1 1 1 1]
@test J(x -> ifelse(x[2] < x[3], x[1] + x[2], 1.0), [1 2 3 4]) == [1 1 0 0]
@test J(x -> ifelse(x[2] < x[3], 1.0, x[3] * x[4]), [1 2 3 4]) == [0 0 1 1]
end

function f_ampgo07(x)
Expand All @@ -121,13 +115,12 @@ NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F)
sin(10//3 * x[1]) +
log(abs(x[1])) - 84//100 * x[1] + 3
end
@test jacobian_sparsity(f_ampgo07, [1.0], method) ≈ [1;;]
@test J(f_ampgo07, [1.0]) ≈ [1;;]

## Error handling when applying non-dual tracers to "local" functions with control flow
# TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{BitSet}) used in boolean context
@test_throws TypeError jacobian_sparsity(
x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0], method
) == [0 0 1 1;]
@test_throws TypeError J(x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0]) ==
[0 0 1 1;]
yield()
end
end
Expand All @@ -136,124 +129,109 @@ end
@testset "$P" for P in GRADIENT_PATTERNS
T = GradientTracer{P}
method = TracerLocalSparsityDetector(; gradient_tracer_type=T)
J(f, x) = jacobian_sparsity(f, x, method)

# Multiplication
@test jacobian_sparsity(x -> x[1] * x[2], [1.0, 1.0], method) == [1 1;]
@test jacobian_sparsity(x -> x[1] * x[2], [1.0, 0.0], method) == [0 1;]
@test jacobian_sparsity(x -> x[1] * x[2], [0.0, 1.0], method) == [1 0;]
@test jacobian_sparsity(x -> x[1] * x[2], [0.0, 0.0], method) == [0 0;]
@test J(x -> x[1] * x[2], [1.0, 1.0]) == [1 1;]
@test J(x -> x[1] * x[2], [1.0, 0.0]) == [0 1;]
@test J(x -> x[1] * x[2], [0.0, 1.0]) == [1 0;]
@test J(x -> x[1] * x[2], [0.0, 0.0]) == [0 0;]

# Division
@test jacobian_sparsity(x -> x[1] / x[2], [1.0, 1.0], method) == [1 1;]
@test jacobian_sparsity(x -> x[1] / x[2], [0.0, 0.0], method) == [1 0;]
@test J(x -> x[1] / x[2], [1.0, 1.0]) == [1 1;]
@test J(x -> x[1] / x[2], [0.0, 0.0]) == [1 0;]

# Maximum
@test jacobian_sparsity(x -> max(x[1], x[2]), [1.0, 2.0], method) == [0 1;]
@test jacobian_sparsity(x -> max(x[1], x[2]), [2.0, 1.0], method) == [1 0;]
@test jacobian_sparsity(x -> max(x[1], x[2]), [1.0, 1.0], method) == [1 1;]
@test J(x -> max(x[1], x[2]), [1.0, 2.0]) == [0 1;]
@test J(x -> max(x[1], x[2]), [2.0, 1.0]) == [1 0;]
@test J(x -> max(x[1], x[2]), [1.0, 1.0]) == [1 1;]

# Minimum
@test jacobian_sparsity(x -> min(x[1], x[2]), [1.0, 2.0], method) == [1 0;]
@test jacobian_sparsity(x -> min(x[1], x[2]), [2.0, 1.0], method) == [0 1;]
@test jacobian_sparsity(x -> min(x[1], x[2]), [1.0, 1.0], method) == [1 1;]
@test J(x -> min(x[1], x[2]), [1.0, 2.0]) == [1 0;]
@test J(x -> min(x[1], x[2]), [2.0, 1.0]) == [0 1;]
@test J(x -> min(x[1], x[2]), [1.0, 1.0]) == [1 1;]

# Comparisons
@test jacobian_sparsity(
x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0], method
) == [0 0 0 1;]
@test jacobian_sparsity(
x -> x[1] > x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0], method
) == [0 0 1 0;]
@test jacobian_sparsity(
x -> x[1] < x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0], method
) == [0 0 1 0;]
@test jacobian_sparsity(
x -> x[1] < x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0], method
) == [0 0 0 1;]

@test jacobian_sparsity(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 2.0], method) ==
[0 1;]
@test jacobian_sparsity(x -> x[1] >= x[2] ? x[1] : x[2], [2.0, 1.0], method) ==
[1 0;]
@test jacobian_sparsity(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 1.0], method) ==
[1 0;]

@test jacobian_sparsity(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 2.0], method) ==
[0 1;]
@test jacobian_sparsity(x -> x[1] >= x[2] ? x[1] : x[2], [2.0, 1.0], method) ==
[1 0;]
@test jacobian_sparsity(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 1.0], method) ==
[1 0;]

@test jacobian_sparsity(x -> x[1] <= x[2] ? x[1] : x[2], [1.0, 2.0], method) ==
[1 0;]
@test jacobian_sparsity(x -> x[1] <= x[2] ? x[1] : x[2], [2.0, 1.0], method) ==
[0 1;]
@test jacobian_sparsity(x -> x[1] <= x[2] ? x[1] : x[2], [1.0, 1.0], method) ==
[1 0;]

@test jacobian_sparsity(x -> x[1] == x[2] ? x[1] : x[2], [1.0, 2.0], method) ==
[0 1;]
@test jacobian_sparsity(x -> x[1] == x[2] ? x[1] : x[2], [2.0, 1.0], method) ==
[0 1;]
@test jacobian_sparsity(x -> x[1] == x[2] ? x[1] : x[2], [1.0, 1.0], method) ==
[1 0;]

@test jacobian_sparsity(x -> x[1] > 1 ? x[1] : x[2], [0.0, 2.0], method) == [0 1;]
@test jacobian_sparsity(x -> x[1] > 1 ? x[1] : x[2], [2.0, 0.0], method) == [1 0;]
@test jacobian_sparsity(x -> x[1] >= 1 ? x[1] : x[2], [0.0, 2.0], method) == [0 1;]
@test jacobian_sparsity(x -> x[1] >= 1 ? x[1] : x[2], [2.0, 0.0], method) == [1 0;]
@test jacobian_sparsity(x -> x[1] < 1 ? x[1] : x[2], [0.0, 2.0], method) == [1 0;]
@test jacobian_sparsity(x -> x[1] < 1 ? x[1] : x[2], [2.0, 0.0], method) == [0 1;]
@test jacobian_sparsity(x -> x[1] <= 1 ? x[1] : x[2], [0.0, 2.0], method) == [1 0;]
@test jacobian_sparsity(x -> x[1] <= 1 ? x[1] : x[2], [2.0, 0.0], method) == [0 1;]
@test jacobian_sparsity(x -> 1 > x[2] ? x[1] : x[2], [0.0, 2.0], method) == [0 1;]
@test jacobian_sparsity(x -> 1 > x[2] ? x[1] : x[2], [2.0, 0.0], method) == [1 0;]
@test jacobian_sparsity(x -> 1 >= x[2] ? x[1] : x[2], [0.0, 2.0], method) == [0 1;]
@test jacobian_sparsity(x -> 1 >= x[2] ? x[1] : x[2], [2.0, 0.0], method) == [1 0;]
@test jacobian_sparsity(x -> 1 < x[2] ? x[1] : x[2], [0.0, 2.0], method) == [1 0;]
@test jacobian_sparsity(x -> 1 < x[2] ? x[1] : x[2], [2.0, 0.0], method) == [0 1;]
@test jacobian_sparsity(x -> 1 <= x[2] ? x[1] : x[2], [0.0, 2.0], method) == [1 0;]
@test jacobian_sparsity(x -> 1 <= x[2] ? x[1] : x[2], [2.0, 0.0], method) == [0 1;]
@test J(x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0]) == [0 0 0 1;]
@test J(x -> x[1] > x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0]) == [0 0 1 0;]
@test J(x -> x[1] < x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0]) == [0 0 1 0;]
@test J(x -> x[1] < x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0]) == [0 0 0 1;]

@test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;]
@test J(x -> x[1] >= x[2] ? x[1] : x[2], [2.0, 1.0]) == [1 0;]
@test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;]

@test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;]
@test J(x -> x[1] >= x[2] ? x[1] : x[2], [2.0, 1.0]) == [1 0;]
@test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;]

@test J(x -> x[1] <= x[2] ? x[1] : x[2], [1.0, 2.0]) == [1 0;]
@test J(x -> x[1] <= x[2] ? x[1] : x[2], [2.0, 1.0]) == [0 1;]
@test J(x -> x[1] <= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;]

@test J(x -> x[1] == x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;]
@test J(x -> x[1] == x[2] ? x[1] : x[2], [2.0, 1.0]) == [0 1;]
@test J(x -> x[1] == x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;]

@testset "Comparison with $T" for T in REAL_TYPES
_1 = oneunit(T)
@test J(x -> x[1] > _1 ? x[1] : x[2], [0.0, 2.0]) == [0 1;]
@test J(x -> x[1] > _1 ? x[1] : x[2], [2.0, 0.0]) == [1 0;]
@test J(x -> x[1] >= _1 ? x[1] : x[2], [0.0, 2.0]) == [0 1;]
@test J(x -> x[1] >= _1 ? x[1] : x[2], [2.0, 0.0]) == [1 0;]
@test J(x -> x[1] < _1 ? x[1] : x[2], [0.0, 2.0]) == [1 0;]
@test J(x -> x[1] < _1 ? x[1] : x[2], [2.0, 0.0]) == [0 1;]
@test J(x -> isless(x[1], _1) ? x[1] : x[2], [0.0, 2.0]) == [1 0;]
@test J(x -> isless(x[1], _1) ? x[1] : x[2], [2.0, 0.0]) == [0 1;]
@test J(x -> x[1] <= _1 ? x[1] : x[2], [0.0, 2.0]) == [1 0;]
@test J(x -> x[1] <= _1 ? x[1] : x[2], [2.0, 0.0]) == [0 1;]
@test J(x -> _1 > x[2] ? x[1] : x[2], [0.0, 2.0]) == [0 1;]
@test J(x -> _1 > x[2] ? x[1] : x[2], [2.0, 0.0]) == [1 0;]
@test J(x -> _1 >= x[2] ? x[1] : x[2], [0.0, 2.0]) == [0 1;]
@test J(x -> _1 >= x[2] ? x[1] : x[2], [2.0, 0.0]) == [1 0;]
@test J(x -> _1 < x[2] ? x[1] : x[2], [0.0, 2.0]) == [1 0;]
@test J(x -> _1 < x[2] ? x[1] : x[2], [2.0, 0.0]) == [0 1;]
@test J(x -> _1 <= x[2] ? x[1] : x[2], [0.0, 2.0]) == [1 0;]
@test J(x -> _1 <= x[2] ? x[1] : x[2], [2.0, 0.0]) == [0 1;]
end

# Code coverage
@test jacobian_sparsity(x -> [sincos(x)...], 1, method) ≈ [1; 1]
@test jacobian_sparsity(typemax, 1, method) ≈ [0;;]
@test jacobian_sparsity(x -> x^(2//3), 1, method) ≈ [1;;]
@test jacobian_sparsity(x -> (2//3)^x, 1, method) ≈ [1;;]
@test jacobian_sparsity(x -> x^ℯ, 1, method) ≈ [1;;]
@test jacobian_sparsity(x -> ℯ^x, 1, method) ≈ [1;;]
@test jacobian_sparsity(x -> round(x, RoundNearestTiesUp), 1, method) ≈ [0;;]
@test jacobian_sparsity(x -> 0, 1, method) ≈ [0;;]
@test J(x -> [sincos(x)...], 1) ≈ [1; 1]
@test J(typemax, 1) ≈ [0;;]
@test J(x -> x^(2//3), 1) ≈ [1;;]
@test J(x -> (2//3)^x, 1) ≈ [1;;]
@test J(x -> x^ℯ, 1) ≈ [1;;]
@test J(x -> ℯ^x, 1) ≈ [1;;]
@test J(x -> round(x, RoundNearestTiesUp), 1) ≈ [0;;]
@test J(x -> 0, 1) ≈ [0;;]

# Linear algebra
@test jacobian_sparsity(logdet, [1.0 -1.0; 2.0 2.0], method) == [1 1 1 1] # (#68)
@test jacobian_sparsity(x -> log(det(x)), [1.0 -1.0; 2.0 2.0], method) == [1 1 1 1]
@test jacobian_sparsity(x -> dot(x[1:2], x[4:5]), [0, 1, 0, 1, 0], method) ==
[1 0 0 0 1]
@test J(logdet, [1.0 -1.0; 2.0 2.0]) == [1 1 1 1] # (#68)
@test J(x -> log(det(x)), [1.0 -1.0; 2.0 2.0]) == [1 1 1 1]
@test J(x -> dot(x[1:2], x[4:5]), [0, 1, 0, 1, 0]) == [1 0 0 0 1]

# NNlib extension
@test jacobian_sparsity(NNlib.relu, -1, method) ≈ [0;;]
@test jacobian_sparsity(NNlib.relu, 1, method) ≈ [1;;]
@test J(NNlib.relu, -1) ≈ [0;;]
@test J(NNlib.relu, 1) ≈ [1;;]

@test jacobian_sparsity(NNlib.relu6, -1, method) ≈ [0;;]
@test jacobian_sparsity(NNlib.relu6, 1, method) ≈ [1;;]
@test jacobian_sparsity(NNlib.relu6, 7, method) ≈ [0;;]
@test J(NNlib.relu6, -1) ≈ [0;;]
@test J(NNlib.relu6, 1) ≈ [1;;]
@test J(NNlib.relu6, 7) ≈ [0;;]

@test jacobian_sparsity(NNlib.trelu, 0.9, method) ≈ [0;;]
@test jacobian_sparsity(NNlib.trelu, 1.1, method) ≈ [1;;]
@test J(NNlib.trelu, 0.9) ≈ [0;;]
@test J(NNlib.trelu, 1.1) ≈ [1;;]

@test jacobian_sparsity(NNlib.hardσ, -4, method) ≈ [0;;]
@test jacobian_sparsity(NNlib.hardσ, 0, method) ≈ [1;;]
@test jacobian_sparsity(NNlib.hardσ, 4, method) ≈ [0;;]
@test J(NNlib.hardσ, -4) ≈ [0;;]
@test J(NNlib.hardσ, 0) ≈ [1;;]
@test J(NNlib.hardσ, 4) ≈ [0;;]

@test jacobian_sparsity(NNlib.hardtanh, -2, method) ≈ [0;;]
@test jacobian_sparsity(NNlib.hardtanh, 0, method) ≈ [1;;]
@test jacobian_sparsity(NNlib.hardtanh, 2, method) ≈ [0;;]
@test J(NNlib.hardtanh, -2) ≈ [0;;]
@test J(NNlib.hardtanh, 0) ≈ [1;;]
@test J(NNlib.hardtanh, 2) ≈ [0;;]

@test jacobian_sparsity(NNlib.softshrink, -1, method) ≈ [1;;]
@test jacobian_sparsity(NNlib.softshrink, 0, method) ≈ [0;;]
@test jacobian_sparsity(NNlib.softshrink, 1, method) ≈ [1;;]
@test J(NNlib.softshrink, -1) ≈ [1;;]
@test J(NNlib.softshrink, 0) ≈ [0;;]
@test J(NNlib.softshrink, 1) ≈ [1;;]
yield()
end
end
Loading
Loading