Skip to content

Commit

Permalink
More methods for round and rand, fix NNlib (#162)
Browse files Browse the repository at this point in the history
* More methods for `round` and `rand`

* Fix NNlib bug

* Remove some dead code from connectivity tracer

* Increase NNlib code coverage

* Increase code coverage on `round`

* Increase code coverage on `rand`
  • Loading branch information
adrhill authored Aug 13, 2024
1 parent c3d2cfb commit 199899d
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 23 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# SparseConnectivityTracer.jl

## Version `v0.6.2`

* ![Feature][badge-feature] More methods for `round` and `rand` ([#162])
* ![Bugfix][badge-bugfix] Fix Hessian on NNlib activation functions ([#162])
* ![Bugfix][badge-bugfix] Fix `isless` ([#161])

## Version `v0.6.1`

* ![Enhancement][badge-enhancement] Improve the performance of Hessian pattern tracing by an order of magnitude:
Expand Down Expand Up @@ -33,6 +39,8 @@
[badge-maintenance]: https://img.shields.io/badge/maintenance-gray.svg
[badge-docs]: https://img.shields.io/badge/docs-orange.svg

[#162]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/162
[#161]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/161
[#158]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/158
[#155]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/155
[#151]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/151
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseConnectivityTracer"
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
authors = ["Adrian Hill <[email protected]>"]
version = "0.6.1"
version = "0.6.2-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
10 changes: 5 additions & 5 deletions ext/SparseConnectivityTracerNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ for op in ops_1_to_1_s
@eval SCT.is_der2_zero_global(::$T) = false
end

# Local (primal-dependent) derivative classification for NNlib activations.
# The primal value `x` must be an explicit argument: one-argument methods whose
# body refers to `x` would throw an `UndefVarError` as soon as they are called.

# celu, elu and selu: second derivative is declared zero for x > 0.
SCT.is_der2_zero_local(::typeof(celu), x) = x > 0
SCT.is_der2_zero_local(::typeof(elu), x) = x > 0
SCT.is_der2_zero_local(::typeof(selu), x) = x > 0

# hardswish: first derivative is declared zero for x < -3,
# second derivative is declared zero for x < -3 or x > 3.
SCT.is_der1_zero_local(::typeof(hardswish), x) = x < -3
SCT.is_der2_zero_local(::typeof(hardswish), x) = x < -3 || x > 3

# ops_1_to_1_f:
# x -> f != 0
Expand Down
2 changes: 0 additions & 2 deletions ext/SparseConnectivityTracerSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ ops_2_to_1_ssc = (

for op in ops_2_to_1_ssc
T = typeof(op)
@eval SCT.is_infl_arg1_zero_global(::$T) = false
@eval SCT.is_infl_arg2_zero_global(::$T) = false
@eval SCT.is_der1_arg1_zero_global(::$T) = false
@eval SCT.is_der2_arg1_zero_global(::$T) = false
@eval SCT.is_der1_arg2_zero_global(::$T) = false
Expand Down
3 changes: 1 addition & 2 deletions src/operators.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
## Operator definitions

# We use a system of letters to categorize operators:
# i: independence - no influence at all
# z: influence but first- and second-order derivatives (FOD, SOD) are zero
# z: first- and second-order derivatives (FOD, SOD) are zero
# f: FOD ∂f/∂x is non-zero, SOD ∂²f/∂x² is zero
# s: FOD ∂f/∂x is non-zero, SOD ∂²f/∂x² is non-zero
# c: Cross-derivative ∂²f/∂x∂y is non-zero
Expand Down
28 changes: 23 additions & 5 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,23 +211,41 @@ for S in (Integer, Rational, Irrational{:ℯ})
# `S` is the type bound by the enclosing `for S in (...)` loop.
# A value of type `S` raised to a plain tracer returns the tracer unchanged.
Base.:^(::S, t::T) where {T<:GradientTracer} = t

# Dual raised to a power of type `S`: exponentiate the primal and derive the
# output tracer from the input tracer.
function Base.:^(d::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}}
    x = primal(d)
    # NOTE(review): the `false` flag presumably marks the first derivative as
    # non-zero; confirm against the definition of `gradient_tracer_1_to_1`.
    t = gradient_tracer_1_to_1(tracer(d), false)
    return Dual(x^y, t)
end

# Base of type `S` raised to a Dual: same scheme with the roles swapped.
function Base.:^(x::S, d::D) where {P,T<:GradientTracer,D<:Dual{P,T}}
    y = primal(d)
    # NOTE(review): same flag as in the method above; confirm its meaning.
    t = gradient_tracer_1_to_1(tracer(d), false)
    return Dual(x^y, t)
end
end

## Rounding
# Rounding a tracer with an explicit rounding mode returns an empty tracer.
# The tracer value itself is not needed, only its type `T`; keyword arguments
# (e.g. `digits`, `base`) are accepted and ignored.
Base.round(::T, ::RoundingMode; kwargs...) where {T<:GradientTracer} = myempty(T)

# Dual variant: round the primal value with the requested mode and keyword
# arguments, then pair it with an empty tracer.
function Base.round(
    d::D, mode::RoundingMode; kwargs...
) where {P,T<:GradientTracer,D<:Dual{P,T}}
    p = round(primal(d), mode; kwargs...)
    t = myempty(T)
    return Dual(p, t)
end

# Typed rounding `round(R, x)`: one pair of methods is defined per upper bound
# on the target type `R`.
for Bound in (Real, Integer, Bool)
    # Plain tracer: the result is an empty tracer of the same tracer type.
    function Base.round(::Type{R}, ::T) where {R<:Bound,T<:GradientTracer}
        return myempty(T)
    end

    # Dual: round the primal to `R` and attach an empty tracer.
    function Base.round(
        ::Type{R}, d::D
    ) where {R<:Bound,P,T<:GradientTracer,D<:Dual{P,T}}
        return Dual(round(R, primal(d)), myempty(T))
    end
end

## Random numbers
# Sampling a tracer type returns an empty tracer; the RNG is not consulted.
Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:GradientTracer} = myempty(T)

# Sampling a Dual type: draw a primal of type `P` from `rng` and pair it with
# an empty tracer.
function Base.rand(
    rng::AbstractRNG, ::SamplerType{D}
) where {P,T<:GradientTracer,D<:Dual{P,T}}
    p = rand(rng, P)
    t = myempty(T)
    return Dual(p, t)
end
25 changes: 24 additions & 1 deletion src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,29 @@ end

## Rounding
# Rounding a HessianTracer with an explicit rounding mode returns an empty
# tracer; keyword arguments are accepted and ignored.
function Base.round(::T, ::RoundingMode; kwargs...) where {T<:HessianTracer}
    return myempty(T)
end

# Dual variant: round the primal with the same mode and keyword arguments,
# then attach an empty tracer.
function Base.round(
    d::D, mode::RoundingMode; kwargs...
) where {P,T<:HessianTracer,D<:Dual{P,T}}
    rounded = round(primal(d), mode; kwargs...)
    return Dual(rounded, myempty(T))
end

# Typed rounding `round(R, x)`: one pair of methods is defined per upper bound
# on the target type `R`.
for Bound in (Real, Integer, Bool)
    # Plain tracer: the result is an empty tracer of the same tracer type.
    function Base.round(::Type{R}, ::T) where {R<:Bound,T<:HessianTracer}
        return myempty(T)
    end

    # Dual: round the primal to `R` and attach an empty tracer.
    function Base.round(
        ::Type{R}, d::D
    ) where {R<:Bound,P,T<:HessianTracer,D<:Dual{P,T}}
        return Dual(round(R, primal(d)), myempty(T))
    end
end

## Random numbers
# Sampling a tracer type returns an empty tracer; the RNG is not consulted.
Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:HessianTracer} = myempty(T)

# Sampling a Dual type: draw a primal of type `P` from `rng` and pair it with
# an empty tracer.
function Base.rand(
    rng::AbstractRNG, ::SamplerType{D}
) where {P,T<:HessianTracer,D<:Dual{P,T}}
    p = rand(rng, P)
    t = myempty(T)
    return Dual(p, t)
end
2 changes: 0 additions & 2 deletions src/overloads/ifelse_global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ end
# Overload only on AbstractTracer, not Dual
for op in (isequal, isapprox, isless, ==, <, >, <=, >=)
T = typeof(op)
@eval is_infl_arg1_zero_global(::$T) = false
@eval is_infl_arg2_zero_global(::$T) = false
@eval is_der1_arg1_zero_global(::$T) = true
@eval is_der2_arg1_zero_global(::$T) = true
@eval is_der1_arg2_zero_global(::$T) = true
Expand Down
43 changes: 40 additions & 3 deletions test/test_gradient.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using SparseConnectivityTracer
using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input
using Test

using Random: rand, GLOBAL_RNG
using LinearAlgebra: det, dot, logdet
using SpecialFunctions: erf, beta
using NNlib: NNlib
using Test

# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")
Expand Down Expand Up @@ -63,7 +65,6 @@ REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int})
# Powers with a rational base and with Euler's number ℯ as exponent or base
@test J(x -> (2//3)^x, 1) ≈ [1;;]
@test J(x -> x^ℯ, 1) ≈ [1;;]
@test J(x -> ℯ^x, 1) ≈ [1;;]
@test J(x -> round(x, RoundNearestTiesUp), 1) ≈ [0;;]
@test J(x -> 0, 1) ≈ [0;;]

# Test special cases on empty tracer
Expand All @@ -72,6 +73,18 @@ REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int})
# Powers involving ℯ evaluated on an empty tracer
@test J(x -> zero(x)^ℯ, 1) ≈ [0;;]
@test J(x -> ℯ^zero(x), 1) ≈ [0;;]

# Round
@test J(round, 1.1) ≈ [0;;]
@test J(x -> round(Int, x), 1.1) ≈ [0;;]
@test J(x -> round(Bool, x), 1.1) ≈ [0;;]
@test J(x -> round(Float16, x), 1.1) ≈ [0;;]
@test J(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;]
@test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]

# Random
@test J(x -> rand(typeof(x)), 1) ≈ [0;;]
@test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]

# Linear Algebra
@test J(x -> dot(x[1:2], x[4:5]), rand(5)) == [1 1 0 1 1]

Expand Down Expand Up @@ -202,9 +215,19 @@ end
# Powers with a rational base and with Euler's number ℯ as exponent or base
@test J(x -> (2//3)^x, 1) ≈ [1;;]
@test J(x -> x^ℯ, 1) ≈ [1;;]
@test J(x -> ℯ^x, 1) ≈ [1;;]
@test J(x -> round(x, RoundNearestTiesUp), 1) ≈ [0;;]
@test J(x -> 0, 1) ≈ [0;;]

# Round
@test J(round, 1.1) ≈ [0;;]
@test J(x -> round(Int, x), 1.1) ≈ [0;;]
@test J(x -> round(Bool, x), 1.1) ≈ [0;;]
@test J(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;]
@test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]

# Random
@test J(x -> rand(typeof(x)), 1) ≈ [0;;]
@test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]

# Linear algebra
@test J(logdet, [1.0 -1.0; 2.0 2.0]) == [1 1 1 1] # (#68)
@test J(x -> log(det(x)), [1.0 -1.0; 2.0 2.0]) == [1 1 1 1]
Expand All @@ -213,6 +236,12 @@ end
# NNlib extension
# relu and the ELU family, evaluated on both sides of zero
@test J(NNlib.relu, -1) ≈ [0;;]
@test J(NNlib.relu, 1) ≈ [1;;]
@test J(NNlib.elu, -1) ≈ [1;;]
@test J(NNlib.elu, 1) ≈ [1;;]
@test J(NNlib.celu, -1) ≈ [1;;]
@test J(NNlib.celu, 1) ≈ [1;;]
@test J(NNlib.selu, -1) ≈ [1;;]
@test J(NNlib.selu, 1) ≈ [1;;]

@test J(NNlib.relu6, -1) ≈ [0;;]
@test J(NNlib.relu6, 1) ≈ [1;;]
Expand All @@ -221,6 +250,14 @@ end
# Piecewise activations, evaluated at several primal values each
@test J(NNlib.trelu, 0.9) ≈ [0;;]
@test J(NNlib.trelu, 1.1) ≈ [1;;]

@test J(NNlib.swish, -5) ≈ [1;;]
@test J(NNlib.swish, 0) ≈ [1;;]
@test J(NNlib.swish, 5) ≈ [1;;]

@test J(NNlib.hardswish, -5) ≈ [0;;]
@test J(NNlib.hardswish, 0) ≈ [1;;]
@test J(NNlib.hardswish, 5) ≈ [1;;]

@test J(NNlib.hardσ, -4) ≈ [0;;]
@test J(NNlib.hardσ, 0) ≈ [1;;]
@test J(NNlib.hardσ, 4) ≈ [0;;]
Expand Down
71 changes: 69 additions & 2 deletions test/test_hessian.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
using SparseConnectivityTracer
using SparseConnectivityTracer: Dual, HessianTracer, MissingPrimalError
using SparseConnectivityTracer: trace_input, create_tracers, pattern, shared
using SpecialFunctions: erf, beta
using Test

using Random: rand, GLOBAL_RNG
using SpecialFunctions: erf, beta
using NNlib: NNlib

# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

# This exists to be able to quickly run tests in the REPL.
# NOTE: H gets overwritten inside the testsets.
method = TracerSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@testset "Global Hessian" begin
@testset "$P" for P in HESSIAN_PATTERNS
T = HessianTracer{P}
Expand All @@ -26,9 +34,20 @@ include("tracers_definitions.jl")
# Powers with a rational base and with Euler's number ℯ as exponent or base
@test H(x -> (2//3)^x, 1) ≈ [1;;]
@test H(x -> x^ℯ, 1) ≈ [1;;]
@test H(x -> ℯ^x, 1) ≈ [1;;]
@test H(x -> round(x, RoundNearestTiesUp), 1) ≈ [0;;]
@test H(x -> 0, 1) ≈ [0;;]

# Round
@test H(round, 1.1) ≈ [0;;]
@test H(x -> round(Int, x), 1.1) ≈ [0;;]
@test H(x -> round(Bool, x), 1.1) ≈ [0;;]
@test H(x -> round(Float16, x), 1.1) ≈ [0;;]
@test H(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;]
@test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]

# Random
@test H(x -> rand(typeof(x)), 1) ≈ [0;;]
@test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]

@test H(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4)) == [
0 1 0 0
1 1 0 0
Expand Down Expand Up @@ -318,11 +337,59 @@ end
# Euler's number as base, and a constant function
@test H(x -> ℯ^x, 1) ≈ [1;;]
@test H(x -> 0, 1) ≈ [0;;]

# Round
@test H(round, 1.1) ≈ [0;;]
@test H(x -> round(Int, x), 1.1) ≈ [0;;]
@test H(x -> round(Bool, x), 1.1) ≈ [0;;]
@test H(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;]
@test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]

# Random
@test H(x -> rand(typeof(x)), 1) ≈ [0;;]
@test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]

# Test special cases on empty tracer
@test H(x -> zero(x)^(2//3), 1) ≈ [0;;]
@test H(x -> (2//3)^zero(x), 1) ≈ [0;;]
@test H(x -> zero(x)^ℯ, 1) ≈ [0;;]
@test H(x -> ℯ^zero(x), 1) ≈ [0;;]

# NNlib extension
@test H(NNlib.relu, -1) ≈ [0;;]
@test H(NNlib.relu, 1) ≈ [0;;]
@test H(NNlib.elu, -1) ≈ [1;;]
@test H(NNlib.elu, 1) ≈ [0;;]
@test H(NNlib.celu, -1) ≈ [1;;]
@test H(NNlib.celu, 1) ≈ [0;;]
@test H(NNlib.selu, -1) ≈ [1;;]
@test H(NNlib.selu, 1) ≈ [0;;]

@test H(NNlib.relu6, -1) ≈ [0;;]
@test H(NNlib.relu6, 1) ≈ [0;;]
@test H(NNlib.relu6, 7) ≈ [0;;]

@test H(NNlib.trelu, 0.9) ≈ [0;;]
@test H(NNlib.trelu, 1.1) ≈ [0;;]

@test H(NNlib.swish, -5) ≈ [1;;]
@test H(NNlib.swish, 0) ≈ [1;;]
@test H(NNlib.swish, 5) ≈ [1;;]

@test H(NNlib.hardswish, -5) ≈ [0;;]
@test H(NNlib.hardswish, 0) ≈ [1;;]
@test H(NNlib.hardswish, 5) ≈ [0;;]

@test H(NNlib.hardσ, -4) ≈ [0;;]
@test H(NNlib.hardσ, 0) ≈ [0;;]
@test H(NNlib.hardσ, 4) ≈ [0;;]

@test H(NNlib.hardtanh, -2) ≈ [0;;]
@test H(NNlib.hardtanh, 0) ≈ [0;;]
@test H(NNlib.hardtanh, 2) ≈ [0;;]

@test H(NNlib.softshrink, -1) ≈ [0;;]
@test H(NNlib.softshrink, 0) ≈ [0;;]
@test H(NNlib.softshrink, 1) ≈ [0;;]
yield()
end
end
Expand Down

0 comments on commit 199899d

Please sign in to comment.