From 86b8e83449820736071af4543769ae2f7830e853 Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 21 Oct 2024 14:41:36 +0200 Subject: [PATCH] Support and test `clamp!` on arrays --- src/overloads/arrays.jl | 11 ++++++++++- test/test_arrays.jl | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/overloads/arrays.jl b/src/overloads/arrays.jl index 73aedc7..5901f91 100644 --- a/src/overloads/arrays.jl +++ b/src/overloads/arrays.jl @@ -124,7 +124,16 @@ function Base.literal_pow(::typeof(^), D::Diagonal{T}, ::Val{0}) where {T<:Abstr end ## clamp! -Base.clamp!(A::AbstractArray{<:AbstractTracer}, lo, hi) = A +Base.clamp!(A::AbstractArray{T}, lo, hi) where {T<:AbstractTracer} = A +function Base.clamp!(A::AbstractArray{T}, lo::T, hi) where {T<:AbstractTracer} + return first_order_or.(A, lo) +end +function Base.clamp!(A::AbstractArray{T}, lo, hi::T) where {T<:AbstractTracer} + return first_order_or.(A, hi) +end +function Base.clamp!(A::AbstractArray{T}, lo::T, hi::T) where {T<:AbstractTracer} + return first_order_or.(A, first_order_or(lo, hi)) +end #==========================# # LinearAlgebra.jl on Dual # diff --git a/test/test_arrays.jl b/test/test_arrays.jl index 0525a16..1e3ff74 100644 --- a/test/test_arrays.jl +++ b/test/test_arrays.jl @@ -330,6 +330,41 @@ S = BitSet P = IndexSetGradientPattern{Int,S} TG = GradientTracer{P} +@testset "clamp!" begin + t1 = TG(P(S(1))) + t2 = TG(P(S(2))) + t3 = TG(P(S(3))) + t4 = TG(P(S(4))) + A = [t1 t2; t3 t4] + + t_lo = TG(P(S(5))) + t_hi = TG(P(S(6))) + + out = clamp!(A, 0.0, 1.0) + @test SCT.gradient(out[1, 1]) == S(1) + @test SCT.gradient(out[1, 2]) == S(2) + @test SCT.gradient(out[2, 1]) == S(3) + @test SCT.gradient(out[2, 2]) == S(4) + + out = clamp!(A, t_lo, 1.0) + @test SCT.gradient(out[1, 1]) == S([1, 5]) + @test SCT.gradient(out[1, 2]) == S([2, 5]) + @test SCT.gradient(out[2, 1]) == S([3, 5]) + @test SCT.gradient(out[2, 2]) == S([4, 5]) + + out = clamp!(A, 0.0, t_hi) + @test SCT.gradient(out[1, 1]) == S([1, 6]) + @test SCT.gradient(out[1, 2]) == S([2, 6]) + @test SCT.gradient(out[2, 1]) == S([3, 6]) + @test SCT.gradient(out[2, 2]) == S([4, 6]) + + out = clamp!(A, t_lo, t_hi) + @test SCT.gradient(out[1, 1]) == S([1, 5, 6]) + @test SCT.gradient(out[1, 2]) == S([2, 5, 6]) + @test SCT.gradient(out[2, 1]) == S([3, 5, 6]) + @test SCT.gradient(out[2, 2]) == S([4, 5, 6]) +end + @testset "Matrix division" begin t1 = TG(P(S([1, 3, 4]))) t2 = TG(P(S([2, 4])))