Skip to content

Commit

Permalink
Fix remaining method ambiguities
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierHnt committed Dec 9, 2024
1 parent f74bbd4 commit 9bb9fc7
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 50 deletions.
59 changes: 32 additions & 27 deletions src/intervals/arithmetic/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,16 @@ julia> pow(interval(-1, 1), interval(-3))
Interval{Float64}(1.0, Inf, trv)
```
"""
function pow(x::BareInterval{T}, y::BareInterval{T}) where {T<:NumTypes}
isempty_interval(y) && return y
domain = _unsafe_bareinterval(T, zero(T), typemax(T))
x = intersect_interval(x, domain)
isempty_interval(x) && return x
isthin(y) && return _thin_pow(x, sup(y))
return hull(_thin_pow(x, inf(y)), _thin_pow(x, sup(y)))
pow(x, y)
for U (:AbstractFloat, :Rational) # needed to resolve ambiguity
@eval function pow(x::BareInterval{T}, y::BareInterval{T}) where {T<:$U}
isempty_interval(y) && return y
domain = _unsafe_bareinterval(T, zero(T), typemax(T))
x = intersect_interval(x, domain)
isempty_interval(x) && return x
isthin(y) && return _thin_pow(x, sup(y))
return hull(_thin_pow(x, inf(y)), _thin_pow(x, sup(y)))
end
end
pow(x::BareInterval, y::BareInterval) = pow(promote(x, y)...)
# specialize on rational to improve exactness
Expand Down Expand Up @@ -169,28 +172,30 @@ pow(x::Interval, n::Integer) = pow(x, n//one(n))

# helper function for `pow`

function _thin_pow(x::BareInterval{T}, y::T) where {T<:NumTypes}
# assume `inf(x) ≥ 0` and `!isempty_interval(x)`
if sup(x) == 0 # isthinzero(x)
y > 0 && return x
return emptyinterval(BareInterval{T})
else
isinteger(y) && return pown(x, Integer(y))
y == 0.5 && return sqrt(x)
lo = @round(T, inf(x)^y, inf(x)^y)
hi = @round(T, sup(x)^y, sup(x)^y)
return hull(lo, hi)
for U (:AbstractFloat, :Rational) # needed to resolve ambiguity
@eval function _thin_pow(x::BareInterval{T}, y::T) where {T<:$U}
# assume `inf(x) ≥ 0` and `!isempty_interval(x)`
if sup(x) == 0 # isthinzero(x)
y > 0 && return x
return emptyinterval(BareInterval{T})
else
isinteger(y) && return pown(x, Integer(y))
y == 0.5 && return sqrt(x)
lo = @round(T, inf(x)^y, inf(x)^y)
hi = @round(T, sup(x)^y, sup(x)^y)
return hull(lo, hi)
end
end
end

function _thin_pow(x::BareInterval{T}, y::Rational{S}) where {T<:NumTypes,S<:Integer}
# assume `inf(x) ≥ 0` and `!isempty_interval(x)`
if sup(x) == 0 # isthinzero(x)
y > 0 && return x
return emptyinterval(BareInterval{T})
else
isinteger(y) && return pown(x, S(y))
return pown(rootn(x, denominator(y)), numerator(y))
@eval function _thin_pow(x::BareInterval{T}, y::Rational{S}) where {T<:$U,S<:Integer}
# assume `inf(x) ≥ 0` and `!isempty_interval(x)`
if sup(x) == 0 # isthinzero(x)
y > 0 && return x
return emptyinterval(BareInterval{T})
else
isinteger(y) && return pown(x, S(y))
return pown(rootn(x, denominator(y)), numerator(y))
end
end
end

Expand Down
5 changes: 5 additions & 0 deletions src/intervals/exact_literals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Base.promote_rule(::Type{ExactReal{T}}, ::Type{Interval{S}}) where {T<:Real,S<:N

# to Real

Bool(x::ExactReal) = convert(Bool, x) # needed to resolve ambiguity
(::Type{T})(x::ExactReal) where {T<:Real} = convert(T, x)
Interval{T}(x::ExactReal) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve ambiguity
Interval(x::ExactReal) = Interval{promote_numtype(numtype(x.value), numtype(x.value))}(x) # needed to resolve ambiguity
Expand All @@ -129,6 +130,10 @@ Base.promote_rule(::Type{ExactReal{T}}, ::Type{BigFloat}) where {T<:Real} =
promote_type(T, BigFloat)
Base.promote_rule(::Type{BigFloat}, ::Type{ExactReal{T}}) where {T<:Real} =
promote_type(BigFloat, T)
Base.promote_rule(::Type{ExactReal{T}}, ::Type{S}) where {T<:Real,S<:AbstractIrrational} =
promote_type(T, S)
Base.promote_rule(::Type{T}, ::Type{ExactReal{S}}) where {T<:AbstractIrrational,S<:Real} =
promote_type(T, S)

# to complex -- by-pass default from Base which lead to "NG" flag in the (zero) imaginary part

Expand Down
8 changes: 8 additions & 0 deletions src/intervals/real_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ for T ∈ (:BareInterval, :Interval)
Base.intersect(::$T) =
throw(ArgumentError("`intersect` is purposely not supported for intervals. See instead `intersect_interval`"))

Base.union!(::BitSet, ::$T) = # needed to resolve ambiguity
throw(ArgumentError("`union!` is purposely not supported for intervals. See instead `hull`"))
Base.union!(::AbstractSet, ::$T) = # also returned when calling `intersect`, `symdiff` with intervals
throw(ArgumentError("`union!` is purposely not supported for intervals. See instead `hull`"))
Base.union!(::AbstractVector{S}, ::$T) where {S} =
Expand Down Expand Up @@ -154,6 +156,12 @@ result is true if and only if the interval contains only that number.
"""
Base.:(==)(x::Union{BareInterval,Interval}, y::Number) = isthin(x, y)
Base.:(==)(x::Number, y::Union{BareInterval,Interval}) = y == x
# needed to resolve ambiguity from irrationals.jl
Base.:(==)(x::Interval, y::AbstractIrrational) = isthin(x, y)
Base.:(==)(x::AbstractIrrational, y::Interval) = y == x
# needed to resolve ambiguity from complex.jl
Base.:(==)(x::Interval, y::Complex) = isreal(y) & (real(y) == x)
Base.:(==)(x::Complex, y::Interval) = y == x

# follows docstring of `Base.iszero`
Base.iszero(x::Union{BareInterval,Interval}) = isthinzero(x)
Expand Down
103 changes: 88 additions & 15 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,73 @@ end



# matrix eigenvalues

function LinearAlgebra.eigvals!(A::AbstractMatrix{<:Interval}; permute::Bool=true, scale::Bool=true, sortby::Union{Function,Nothing}=LinearAlgebra.eigsortby)
# note: this function does not overwrite `A`
v = _eigvals(A, permute, scale, sortby)
isreal(v) && return v
_fold_conjugate!(v)
isreal(v) && return real(v)
return v
end

LinearAlgebra.eigvals!(A::AbstractMatrix{<:Complex{<:Interval}}; permute::Bool=true, scale::Bool=true, sortby::Union{Function,Nothing}=LinearAlgebra.eigsortby) =
# note: this function does not overwrite `A`
_eigvals(A, permute, scale, sortby)

function _eigvals(A, permute, scale, sortby)
# Gershgorin circle theorem
B = _similarity_transform(A, permute, scale, sortby)
v = LinearAlgebra.diag(B)
T = eltype(v)
for j axes(B, 1)
r = zero(T)
for i axes(B, 2)
if i j
r += abs(B[i,j])
end
end
v[j] = interval(v[j], r; format = :midpoint)
end
return v
end

function _similarity_transform(A, permute, scale, sortby)
mA = mid.(A)
mλ, mV = LinearAlgebra.eigen(mA; permute = permute, scale = scale, sortby = sortby)
.+= LinearAlgebra.diag(mV \ (mA * mV - mV * LinearAlgebra.Diagonal(mλ)))
Λ = LinearAlgebra.Diagonal(interval(mλ))
V = interval(mV)
V .= Λ .+ inv(V) * (A * V - V * Λ)
return V
end

function _fold_conjugate!(v)
for i eachindex(v)
vᵢ = v[i]
idxs = findall(j -> (j i) & !isdisjoint_interval(conj(vᵢ), v[j]), eachindex(v))
if isempty(idxs)
v[i] = real(vᵢ)
else
w = view(v, idxs)
z = conj(intersect_interval(conj(vᵢ), reduce(intersect_interval, w)))
z = complex(setdecoration(real(z), min(decoration(real(vᵢ)), minimum(decoration real, w))), setdecoration(imag(z), min(decoration(imag(vᵢ)), minimum(decoration imag, w))))
v[i] = z
end
end
return v
end



# matrix determinant

LinearAlgebra.det(A::AbstractMatrix{<:Interval}) = real(reduce(*, LinearAlgebra.eigvals(A)))
LinearAlgebra.det(A::AbstractMatrix{<:Complex{<:Interval}}) = reduce(*, LinearAlgebra.eigvals(A))



# matrix multiplication

"""
Expand Down Expand Up @@ -94,19 +161,23 @@ function LinearAlgebra.mul!(C::AbstractVecOrMat{<:RealOrComplexI}, A::AbstractMa
return LinearAlgebra.mul!(C, A, B, interval(true), interval(false))
end

function LinearAlgebra.mul!(C::AbstractVecOrMat{<:RealOrComplexI}, A::AbstractMatrix{<:RealOrComplexI}, B::AbstractVecOrMat{<:RealOrComplexI}, α::Number, β::Number)
size(A, 2) == size(B, 1) || return throw(DimensionMismatch("The number of columns of A must match the number of rows of B."))
return _mul!(matmul_mode(), C, A, B, α, β)
end
for T (:AbstractVector, :AbstractMatrix) # needed to resolve method ambiguities
@eval begin
function LinearAlgebra.mul!(C::AbstractVecOrMat{<:RealOrComplexI}, A::AbstractMatrix{<:RealOrComplexI}, B::$T{<:RealOrComplexI}, α::Number, β::Number)
size(A, 2) == size(B, 1) || return throw(DimensionMismatch("The number of columns of A must match the number of rows of B."))
return _mul!(matmul_mode(), C, A, B, α, β)
end

function LinearAlgebra.mul!(C::AbstractVecOrMat{<:RealOrComplexI}, A::AbstractMatrix, B::AbstractVecOrMat{<:RealOrComplexI}, α::Number, β::Number)
size(A, 2) == size(B, 1) || return throw(DimensionMismatch("The number of columns of A must match the number of rows of B."))
return _mul!(matmul_mode(), C, A, B, α, β)
end
function LinearAlgebra.mul!(C::AbstractVecOrMat{<:RealOrComplexI}, A::AbstractMatrix, B::$T{<:RealOrComplexI}, α::Number, β::Number)
size(A, 2) == size(B, 1) || return throw(DimensionMismatch("The number of columns of A must match the number of rows of B."))
return _mul!(matmul_mode(), C, A, B, α, β)
end

function LinearAlgebra.mul!(C::AbstractVecOrMat{<:RealOrComplexI}, A::AbstractMatrix{<:RealOrComplexI}, B::AbstractVecOrMat, α::Number, β::Number)
size(A, 2) == size(B, 1) || return throw(DimensionMismatch("The number of columns of A must match the number of rows of B."))
return _mul!(matmul_mode(), C, A, B, α, β)
function LinearAlgebra.mul!(C::AbstractVecOrMat{<:RealOrComplexI}, A::AbstractMatrix{<:RealOrComplexI}, B::$T, α::Number, β::Number)
size(A, 2) == size(B, 1) || return throw(DimensionMismatch("The number of columns of A must match the number of rows of B."))
return _mul!(matmul_mode(), C, A, B, α, β)
end
end
end

function _mul!(::MatMulMode{:slow}, C, A::AbstractMatrix, B::AbstractVecOrMat, α, β)
Expand Down Expand Up @@ -139,8 +210,10 @@ _mul!(::MatMulMode{:fast}, C, A::AbstractMatrix{<:Complex{<:Interval{<:Rational}
_mul!(::MatMulMode{:fast}, C, A::AbstractMatrix{<:Interval{<:Rational}}, B::AbstractVecOrMat{<:Complex{<:Interval{<:Rational}}}, α, β) =
LinearAlgebra._mul!(C, A, B, α, β)

_mul!(::MatMulMode{:fast}, C, A, B, α, β) = _fastmul!(C, A, B, α, β)

for (T, S) ((:Interval, :Interval), (:Interval, :Any), (:Any, :Interval))
@eval function _mul!(::MatMulMode{:fast}, C, A::AbstractMatrix{<:$T}, B::AbstractVecOrMat{<:$S}, α, β)
@eval function _fastmul!(C, A::AbstractMatrix{<:$T}, B::AbstractVecOrMat{<:$S}, α, β)
CoefType = eltype(C)
if iszero(α)
if iszero(β)
Expand Down Expand Up @@ -177,7 +250,7 @@ end

for (T, S) ((:(Complex{<:Interval}), :(Complex{<:Interval})),
(:(Complex{<:Interval}), :Complex), (:Complex, :(Complex{<:Interval})))
@eval function _mul!(::MatMulMode{:fast}, C, A::AbstractMatrix{<:$T}, B::AbstractVecOrMat{<:$S}, α, β)
@eval function _fastmul!(C, A::AbstractMatrix{<:$T}, B::AbstractVecOrMat{<:$S}, α, β)
CoefType = eltype(C)
if iszero(α)
if iszero(β)
Expand Down Expand Up @@ -225,7 +298,7 @@ end

for (T, S) ((:(Complex{<:Interval}), :Interval), (:(Complex{<:Interval}), :Any), (:Complex, :Interval))
@eval begin
function _mul!(::MatMulMode{:fast}, C, A::AbstractMatrix{<:$T}, B::AbstractVecOrMat{<:$S}, α, β)
function _fastmul!(C, A::AbstractMatrix{<:$T}, B::AbstractVecOrMat{<:$S}, α, β)
CoefType = eltype(C)
if iszero(α)
if iszero(β)
Expand Down Expand Up @@ -261,7 +334,7 @@ for (T, S) ∈ ((:(Complex{<:Interval}), :Interval), (:(Complex{<:Interval}), :A
return C
end

function _mul!(::MatMulMode{:fast}, C, A::AbstractMatrix{<:$S}, B::AbstractVecOrMat{<:$T}, α, β)
function _fastmul!(C, A::AbstractMatrix{<:$S}, B::AbstractVecOrMat{<:$T}, α, β)
CoefType = eltype(C)
if iszero(α)
if iszero(β)
Expand Down
10 changes: 2 additions & 8 deletions test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,9 @@ using Aqua
pkg_match(pkgname, pkdir::Nothing) = false
pkg_match(pkgname, pkdir::AbstractString) = occursin(pkgname, pkdir)
filter!(x -> pkg_match("IntervalArithmetic", pkgdir(last(x).module)), ambs)
@test_broken length(ambs) == 0
@test length(ambs) == 0
end

@testset "Aqua tests (additional)" begin
Aqua.test_undefined_exports(IntervalArithmetic)
# Aqua.test_deps_compat(IntervalArithmetic)
Aqua.test_stale_deps(IntervalArithmetic)
Aqua.test_piracies(IntervalArithmetic)
Aqua.test_unbound_args(IntervalArithmetic)
Aqua.test_project_extras(IntervalArithmetic)
Aqua.test_persistent_tasks(IntervalArithmetic)
@test Aqua.test_all(IntervalArithmetic)
end

0 comments on commit 9bb9fc7

Please sign in to comment.