Skip to content

Commit

Permalink
Support clamp and clamp! (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Oct 21, 2024
1 parent 6e404ca commit ececbfd
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 47 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# SparseConnectivityTracer.jl

## Version `v0.6.8`

* ![Feature][badge-feature] Support `clamp` and `clamp!` ([#208])

## Version `v0.6.7`

* ![Enhancement][badge-enhancement] Drop compatibility with Julia <1.10 to improve tracer performance ([#204], [#205])
Expand Down Expand Up @@ -80,6 +84,7 @@
[badge-maintenance]: https://img.shields.io/badge/maintenance-gray.svg
[badge-docs]: https://img.shields.io/badge/docs-orange.svg

[#208]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/208
[#205]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/205
[#204]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/204
[#202]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/202
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseConnectivityTracer"
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
authors = ["Adrian Hill <[email protected]>"]
version = "0.6.7"
version = "0.6.8-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ include("operators.jl")
include("overloads/conversion.jl")
include("overloads/gradient_tracer.jl")
include("overloads/hessian_tracer.jl")
include("overloads/utils.jl")
include("overloads/special_cases.jl")
include("overloads/three_arg.jl")
include("overloads/ifelse_global.jl")
include("overloads/dual.jl")
include("overloads/arrays.jl")
include("overloads/utils.jl")
include("overloads/ambiguities.jl")

include("trace_functions.jl")
Expand Down
57 changes: 12 additions & 45 deletions src/overloads/arrays.jl
Original file line number Diff line number Diff line change
@@ -1,48 +1,3 @@
"""
second_order_or(tracers)
Compute the most conservative elementwise OR of tracer sparsity patterns,
including second-order interactions to update the `hessian` field of `HessianTracer`.
This is functionally equivalent to:
```julia
reduce(^, tracers)
```
"""
function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(second_order_or, ts; init=myempty(T))
end

function second_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function second_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, false, false, false, false)
end

"""
first_order_or(tracers)
Compute the most conservative elementwise OR of tracer sparsity patterns,
excluding second-order interactions of `HessianTracer`.
This is functionally equivalent to:
```julia
reduce(+, tracers)
```
"""
function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(first_order_or, ts; init=myempty(T))
end
function first_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function first_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, true, false, true, true)
end

#===========#
# Utilities #
#===========#
Expand Down Expand Up @@ -168,6 +123,18 @@ function Base.literal_pow(::typeof(^), D::Diagonal{T}, ::Val{0}) where {T<:Abstr
return Diagonal(ts)
end

## clamp!
Base.clamp!(A::AbstractArray{T}, lo, hi) where {T<:AbstractTracer} = A
function Base.clamp!(A::AbstractArray{T}, lo::T, hi) where {T<:AbstractTracer}
return first_order_or.(A, lo)
end
function Base.clamp!(A::AbstractArray{T}, lo, hi::T) where {T<:AbstractTracer}
return first_order_or.(A, hi)
end
function Base.clamp!(A::AbstractArray{T}, lo::T, hi::T) where {T<:AbstractTracer}
return first_order_or.(A, first_order_or(lo, hi))
end

#==========================#
# LinearAlgebra.jl on Dual #
#==========================#
Expand Down
12 changes: 12 additions & 0 deletions src/overloads/three_arg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#= For now, three-argument functions are overloaded individually.
If this file grows too large:
1. 3-arg operators should be classified in src/operators.jl
2. the classification should be tested in test/classification.jl
3. code generation utilities should be added to the src/overloads/*_tracer.jl files
=#
Base.clamp(t::T, lo, hi) where {T<:AbstractTracer} = t
Base.clamp(t::T, lo::T, hi) where {T<:AbstractTracer} = first_order_or(t, lo)
Base.clamp(t::T, lo, hi::T) where {T<:AbstractTracer} = first_order_or(t, hi)
function Base.clamp(t::T, lo::T, hi::T) where {T<:AbstractTracer}
return first_order_or(t, first_order_or(lo, hi))
end
53 changes: 53 additions & 0 deletions src/overloads/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,56 @@
#===============#
# Tracer unions #
#===============#

"""
first_order_or(tracers)
Compute the most conservative elementwise OR of tracer sparsity patterns,
excluding second-order interactions of `HessianTracer`.
This is functionally equivalent to:
```julia
reduce(+, tracers)
```
"""
function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(first_order_or, ts; init=myempty(T))
end
function first_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function first_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, true, false, true, true)
end

"""
second_order_or(tracers)
Compute the most conservative elementwise OR of tracer sparsity patterns,
including second-order interactions to update the `hessian` field of `HessianTracer`.
This is functionally equivalent to:
```julia
reduce(^, tracers)
```
"""
function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(second_order_or, ts; init=myempty(T))
end

function second_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function second_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, false, false, false, false)
end

#=================#
# Code generation #
#=================#

dims = (Symbol("1_to_1"), Symbol("2_to_1"), Symbol("1_to_2"))

# Generate both Gradient and Hessian code with one call to `generate_code_X_to_Y`
Expand Down
35 changes: 35 additions & 0 deletions test/test_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,41 @@ S = BitSet
P = IndexSetGradientPattern{Int,S}
TG = GradientTracer{P}

@testset "clamp!" begin
t1 = TG(P(S(1)))
t2 = TG(P(S(2)))
t3 = TG(P(S(3)))
t4 = TG(P(S(4)))
A = [t1 t2; t3 t4]

t_lo = TG(P(S(5)))
t_hi = TG(P(S(6)))

out = clamp!(A, 0.0, 1.0)
@test SCT.gradient(out[1, 1]) == S(1)
@test SCT.gradient(out[1, 2]) == S(2)
@test SCT.gradient(out[2, 1]) == S(3)
@test SCT.gradient(out[2, 2]) == S(4)

out = clamp!(A, t_lo, 1.0)
@test SCT.gradient(out[1, 1]) == S([1, 5])
@test SCT.gradient(out[1, 2]) == S([2, 5])
@test SCT.gradient(out[2, 1]) == S([3, 5])
@test SCT.gradient(out[2, 2]) == S([4, 5])

out = clamp!(A, 0.0, t_hi)
@test SCT.gradient(out[1, 1]) == S([1, 6])
@test SCT.gradient(out[1, 2]) == S([2, 6])
@test SCT.gradient(out[2, 1]) == S([3, 6])
@test SCT.gradient(out[2, 2]) == S([4, 6])

out = clamp!(A, t_lo, t_hi)
@test SCT.gradient(out[1, 1]) == S([1, 5, 6])
@test SCT.gradient(out[1, 2]) == S([2, 5, 6])
@test SCT.gradient(out[2, 1]) == S([3, 5, 6])
@test SCT.gradient(out[2, 2]) == S([4, 5, 6])
end

@testset "Matrix division" begin
t1 = TG(P(S([1, 3, 4])))
t2 = TG(P(S([2, 4])))
Expand Down
19 changes: 19 additions & 0 deletions test/test_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ J(f, x) = jacobian_sparsity(f, x, detector)
@test J(x -> round(x; digits=3, base=2), 1.1) [0;;]
end

@testset "Three-argument operators" begin
@test J(x -> clamp(x, 0.0, 1.0), rand()) == [1;;]
@test J(x -> clamp(x[1], x[2], 1.0), rand(2)) == [1 1]
@test J(x -> clamp(x[1], 0.0, x[2]), rand(2)) == [1 1]
@test J(x -> clamp(x[1], x[2], x[3]), rand(3)) == [1 1 1]
end

@testset "Random" begin
@test J(x -> rand(typeof(x)), 1) [0;;]
@test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) [0;;]
Expand Down Expand Up @@ -219,6 +226,18 @@ end
@test J(x -> round(x; digits=3, base=2), 1.1) [0;;]
end

@testset "Three-argument operators" begin
@test J(x -> clamp(x, 0.0, 1.0), 0.5) == [1;;]
@test J(x -> clamp(x, 0.0, 1.0), -0.5) == [0;;]
@test J(x -> clamp(x[1], x[2], 1.0), [0.5, 0.0]) == [1 0]
@test J(x -> clamp(x[1], x[2], 1.0), [0.5, 0.6]) == [0 1]
@test J(x -> clamp(x[1], 0.0, x[2]), [0.5, 1.0]) == [1 0]
@test J(x -> clamp(x[1], 0.0, x[2]), [0.5, 0.4]) == [0 1]
@test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.0, 1.0]) == [1 0 0]
@test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.6, 1.0]) == [0 1 0]
@test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.0, 0.4]) == [0 0 1]
end

@testset "Random" begin
@test J(x -> rand(typeof(x)), 1) [0;;]
@test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) [0;;]
Expand Down
25 changes: 25 additions & 0 deletions test/test_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ D = Dual{Int,T}
@test H(x -> round(x; digits=3, base=2), 1.1) [0;;]
end

@testset "Three-argument operators" begin
@test H(x -> clamp(x, 0.1, 0.9), rand()) == [0;;]
@test H(x -> clamp(x[1], x[2], 0.9), rand(2)) == [0 0; 0 0]
@test H(x -> clamp(x[1], 0.1, x[2]), rand(2)) == [0 0; 0 0]
@test H(x -> clamp(x[1], x[2], x[3]), rand(3)) == [0 0 0; 0 0 0; 0 0 0]
@test H(x -> x[1] * clamp(x[1], x[2], x[3]), rand(3)) == [1 1 1; 1 0 0; 1 0 0]
@test H(x -> x[2] * clamp(x[1], x[2], x[3]), rand(3)) == [0 1 0; 1 1 1; 0 1 0]
@test H(x -> x[3] * clamp(x[1], x[2], x[3]), rand(3)) == [0 0 1; 0 0 1; 1 1 1]
end

@testset "Random" begin
@test H(x -> rand(typeof(x)), 1) [0;;]
@test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) [0;;]
Expand Down Expand Up @@ -377,6 +387,21 @@ end
@test H(x -> round(x; digits=3, base=2), 1.1) [0;;]
end

@testset "Three-argument operators" begin
@test H(x -> x * clamp(x, 0.0, 1.0), 0.5) == [1;;]
@test H(x -> x * clamp(x, 0.0, 1.0), -0.5) == [0;;]
@test H(x -> sum(x) * clamp(x[1], x[2], 1.0), [0.5, 0.0]) == [1 1; 1 0]
@test H(x -> sum(x) * clamp(x[1], x[2], 1.0), [0.5, 0.6]) == [0 1; 1 1]
@test H(x -> sum(x) * clamp(x[1], 0.0, x[2]), [0.5, 1.0]) == [1 1; 1 0]
@test H(x -> sum(x) * clamp(x[1], 0.0, x[2]), [0.5, 0.4]) == [0 1; 1 1]
@test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.0, 1.0]) ==
[1 1 1; 1 0 0; 1 0 0]
@test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.6, 1.0]) ==
[0 1 0; 1 1 1; 0 1 0]
@test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.0, 0.4]) ==
[0 0 1; 0 0 1; 1 1 1]
end

@testset "Random" begin
@test H(x -> rand(typeof(x)), 1) [0;;]
@test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) [0;;]
Expand Down

0 comments on commit ececbfd

Please sign in to comment.