Skip to content

Commit

Permalink
Support and test clamp! on arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Oct 21, 2024
1 parent dd12b48 commit 86b8e83
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/overloads/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,16 @@ function Base.literal_pow(::typeof(^), D::Diagonal{T}, ::Val{0}) where {T<:Abstr
end

## clamp!
Base.clamp!(A::AbstractArray{<:AbstractTracer}, lo, hi) = A
Base.clamp!(A::AbstractArray{T}, lo, hi) where {T<:AbstractTracer} = A
function Base.clamp!(A::AbstractArray{T}, lo::T, hi) where {T<:AbstractTracer}
return first_order_or.(A, lo)
end
function Base.clamp!(A::AbstractArray{T}, lo, hi::T) where {T<:AbstractTracer}
return first_order_or.(A, hi)
end
function Base.clamp!(A::AbstractArray{T}, lo::T, hi::T) where {T<:AbstractTracer}
return first_order_or.(A, first_order_or(lo, hi))
end

#==========================#
# LinearAlgebra.jl on Dual #
Expand Down
35 changes: 35 additions & 0 deletions test/test_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,41 @@ S = BitSet
P = IndexSetGradientPattern{Int,S}
TG = GradientTracer{P}

@testset "clamp!" begin
t1 = TG(P(S(1)))
t2 = TG(P(S(2)))
t3 = TG(P(S(3)))
t4 = TG(P(S(4)))
A = [t1 t2; t3 t4]

t_lo = TG(P(S(5)))
t_hi = TG(P(S(6)))

out = clamp!(A, 0.0, 1.0)
@test SCT.gradient(out[1, 1]) == S(1)
@test SCT.gradient(out[1, 2]) == S(2)
@test SCT.gradient(out[2, 1]) == S(3)
@test SCT.gradient(out[2, 2]) == S(4)

out = clamp!(A, t_lo, 1.0)
@test SCT.gradient(out[1, 1]) == S([1, 5])
@test SCT.gradient(out[1, 2]) == S([2, 5])
@test SCT.gradient(out[2, 1]) == S([3, 5])
@test SCT.gradient(out[2, 2]) == S([4, 5])

out = clamp!(A, 0.0, t_hi)
@test SCT.gradient(out[1, 1]) == S([1, 6])
@test SCT.gradient(out[1, 2]) == S([2, 6])
@test SCT.gradient(out[2, 1]) == S([3, 6])
@test SCT.gradient(out[2, 2]) == S([4, 6])

out = clamp!(A, t_lo, t_hi)
@test SCT.gradient(out[1, 1]) == S([1, 5, 6])
@test SCT.gradient(out[1, 2]) == S([2, 5, 6])
@test SCT.gradient(out[2, 1]) == S([3, 5, 6])
@test SCT.gradient(out[2, 2]) == S([4, 5, 6])
end

@testset "Matrix division" begin
t1 = TG(P(S([1, 3, 4])))
t2 = TG(P(S([2, 4])))
Expand Down

0 comments on commit 86b8e83

Please sign in to comment.