Skip to content

Commit

Permalink
Return only primal when applying non-differentiable methods to `Dua…
Browse files Browse the repository at this point in the history
…l` (#169)

* Return only `primal` on non-diff `Dual` methods

* Generate different overloads
based on type of differentiability of an operator

* Make comparisons regular operators

* Add more named testsets
  • Loading branch information
adrhill authored Aug 16, 2024
1 parent 8251cb3 commit 2cad635
Show file tree
Hide file tree
Showing 16 changed files with 1,000 additions and 862 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.6.2-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -25,7 +24,6 @@ SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[compat]
ADTypes = "1"
Compat = "3,4"
DocStringExtensions = "0.9"
FillArrays = "1"
LinearAlgebra = "<0.0.1, 1"
Expand Down
17 changes: 8 additions & 9 deletions ext/SparseConnectivityTracerNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,13 @@ SCT.is_der1_zero_local(::typeof(softshrink), x) = x > -0.5 && x < 0.5

ops_1_to_1 = union(ops_1_to_1_s, ops_1_to_1_f)

## Lists

SCT.list_operators_1_to_1(::Val{:NNlib}) = ops_1_to_1
SCT.list_operators_2_to_1(::Val{:NNlib}) = ()
SCT.list_operators_1_to_2(::Val{:NNlib}) = ()

## Overloads

eval(SCT.overload_all(:NNlib))
## Overload
eval(SCT.overload_gradient_1_to_1(:NNlib, ops_1_to_1))
eval(SCT.overload_hessian_1_to_1(:NNlib, ops_1_to_1))

## List operators for later testing
SCT.test_operators_1_to_1(::Val{:NNlib}) = ops_1_to_1
SCT.test_operators_2_to_1(::Val{:NNlib}) = ()
SCT.test_operators_1_to_2(::Val{:NNlib}) = ()

end
17 changes: 9 additions & 8 deletions ext/SparseConnectivityTracerSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,15 @@ end

ops_2_to_1 = ops_2_to_1_ssc

## Lists

SCT.list_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1
SCT.list_operators_2_to_1(::Val{:SpecialFunctions}) = ops_2_to_1
SCT.list_operators_1_to_2(::Val{:SpecialFunctions}) = ()

## Overloads

eval(SCT.overload_all(:SpecialFunctions))
eval(SCT.overload_gradient_1_to_1(:SpecialFunctions, ops_1_to_1))
eval(SCT.overload_gradient_2_to_1(:SpecialFunctions, ops_2_to_1))
eval(SCT.overload_hessian_1_to_1(:SpecialFunctions, ops_1_to_1))
eval(SCT.overload_hessian_2_to_1(:SpecialFunctions, ops_2_to_1))

## List operators for later testing
SCT.test_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1
SCT.test_operators_2_to_1(::Val{:SpecialFunctions}) = ops_2_to_1
SCT.test_operators_1_to_2(::Val{:SpecialFunctions}) = ()

end
1 change: 0 additions & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module SparseConnectivityTracer

using ADTypes: ADTypes, jacobian_sparsity, hessian_sparsity
using Compat: Returns
using SparseArrays: SparseArrays
using SparseArrays: sparse
using Random: AbstractRNG, SamplerType
Expand Down
6 changes: 2 additions & 4 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ end
ops_2_to_1_zzz = (
# division
div, fld, fld1, cld,
# comparisons
isequal, isapprox, isless, ==, <, >, <=, >=,
)
for op in ops_2_to_1_zzz
T = typeof(op)
Expand Down Expand Up @@ -582,7 +584,3 @@ ops_1_to_2 = union(
ops_1_to_2_zz,
)
#! format: on

list_operators_1_to_1(::Val{:Base}) = ops_1_to_1
list_operators_2_to_1(::Val{:Base}) = ops_2_to_1
list_operators_1_to_2(::Val{:Base}) = ops_1_to_2
30 changes: 19 additions & 11 deletions src/overloads/conversion.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#! format: off

##===============#
# AbstractTracer #
#================#

## Type conversions (non-dual)
Base.promote_rule(::Type{T}, ::Type{N}) where {T<:AbstractTracer,N<:Real} = T
Base.promote_rule(::Type{N}, ::Type{T}) where {T<:AbstractTracer,N<:Real} = T
Expand All @@ -24,7 +28,10 @@ Base.floatmin(::Type{T}) where {T<:AbstractTracer} = myempty(T)
Base.floatmax(::Type{T}) where {T<:AbstractTracer} = myempty(T)
Base.maxintfloat(::Type{T}) where {T<:AbstractTracer} = myempty(T)

## Duals
##======#
# Duals #
#=======#

function Base.promote_rule(::Type{Dual{P1, T}}, ::Type{Dual{P2, T}}) where {P1,P2,T}
PP = Base.promote_type(P1, P2) # TODO: possible method call error?
return Dual{PP,T}
Expand Down Expand Up @@ -59,15 +66,16 @@ for T in (:Int, :Integer, :Float64, :Float32)
end

## Constants
# These are methods defined on types. Methods on variables are in operators.jl
Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = D(zero(P), myempty(T))
Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = D(one(P), myempty(T))
Base.oneunit(::Type{D}) where {P,T,D<:Dual{P,T}} = D(oneunit(P), myempty(T))
Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemin(P), myempty(T))
Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemax(P), myempty(T))
Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = D(eps(P), myempty(T))
Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmin(P), myempty(T))
Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmax(P), myempty(T))
Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = D(maxintfloat(P), myempty(T))
# These are methods defined on types. Methods on variables are in operators.jl
# TODO: only return primal on methods on variable
Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = zero(P)
Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = one(P)
Base.oneunit(::Type{D}) where {P,T,D<:Dual{P,T}} = oneunit(P)
Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = typemin(P)
Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = typemax(P)
Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = eps(P)
Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = floatmin(P)
Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = floatmax(P)
Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = maxintfloat(P)

#! format: on
6 changes: 0 additions & 6 deletions src/overloads/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,5 @@ for fn in (
end
end

for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=))
@eval Base.$fn(dx::D, dy::D) where {D<:Dual} = $fn(primal(dx), primal(dy))
@eval Base.$fn(dx::D, y::Real) where {D<:Dual} = $fn(primal(dx), y)
@eval Base.$fn(x::Real, dy::D) where {D<:Dual} = $fn(x, primal(dy))
end

# In some cases, more specialized methods are needed
Base.isless(dx::D, y::AbstractFloat) where {D<:Dual} = isless(primal(dx), y)
Loading

0 comments on commit 2cad635

Please sign in to comment.