Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return only primal when applying non-differentiable methods to Dual #169

Merged
merged 20 commits into from
Aug 16, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Get rid of ambiguity errors on isless
adrhill committed Aug 16, 2024
commit cd21200b8517fdfa460b393f56ba555a17b28ccd
7 changes: 5 additions & 2 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
@@ -262,9 +262,8 @@ function overload_gradient_1_to_2(M::Symbol, f)
return Expr(:block, expr_gradienttracer, expr_dual)
end

## Special cases to avoid ambiguity errors
## Special overloads to avoid ambiguity errors

## Exponent (requires extra types)
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::T, ::S) where {T<:GradientTracer} = t
Base.:^(::S, t::T) where {T<:GradientTracer} = t
@@ -280,6 +279,10 @@ for S in (Integer, Rational, Irrational{:ℯ})
end
end

function Base.isless(dx::D, y::AbstractFloat) where {P,T<:GradientTracer,D<:Dual{P,T}}
return isless(primal(dx), y)
end

## Rounding
Base.round(::T, ::RoundingMode; kwargs...) where {T<:GradientTracer} = myempty(T)
function Base.round(
7 changes: 5 additions & 2 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
@@ -304,9 +304,8 @@ function overload_hessian_1_to_2(M::Symbol, f)
end
end

## Special cases
## Special overloads to avoid ambiguity errors

## Exponent (requires extra types)
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)
Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)
@@ -323,6 +322,10 @@ for S in (Integer, Rational, Irrational{:ℯ})
end
end

function Base.isless(dx::D, y::AbstractFloat) where {P,T<:GradientTracer,D<:Dual{P,T}}
return isless(primal(dx), y)
end

## Rounding
Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = myempty(T)
function Base.round(