Skip to content

Commit

Permalink
Improve complex power
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierHnt committed Nov 29, 2024
1 parent ec34419 commit 6141c47
Showing 1 changed file with 92 additions and 117 deletions.
209 changes: 92 additions & 117 deletions src/intervals/arithmetic/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ power_mode() = PowerMode{:fast}()
^(x::BareInterval, y::BareInterval)
^(x::Interval, y::Interval)
Compute the power of the positive real part of `x` by `y`. This function is not
in the IEEE Standard 1788-2015. Its behaviour depend on the current
[`PowerMode`](@ref).
Compute the power of `x` by `y`. Unless `y` is an integer, the positive real
part of `x^y` is returned. This function is not in the IEEE Standard 1788-2015.
Its behaviour depends on the current [`PowerMode`](@ref).
See also: [`pow`](@ref) and [`pown`](@ref).
See also: [`pow`](@ref), [`pown`](@ref), [`fastpow`](@ref) and
[`fastpown`](@ref).
# Examples
Expand Down Expand Up @@ -60,40 +61,43 @@ function Base.:^(x::Interval, y::Interval)
return _unsafe_interval(bareinterval(r), d, t)
end

Base.:^(n::Integer, y::Interval) = ^(n//one(n), y)
Base.:^(x::Interval, n::Integer) = ^(x, n//one(n))
Base.:^(x::Rational, y::Interval) = ^(convert(Interval{typeof(x)}, x), y)
Base.:^(x::Interval, y::Rational) = ^(x, convert(Interval{typeof(y)}, y))

function Base.:^(x::Complex{Interval{T}}, y::Complex{Interval{T}}) where {T<:NumTypes}
!isthinzero(x) && return exp(y * log(x))
d = min(decoration(x), decoration(y))
t = isguaranteed(x) & isguaranteed(y)
isthinzero(y) && return complex(_unsafe_interval(one(BareInterval{T}), d, t), _unsafe_interval(zero(BareInterval{T}), d, t))
(inf(real(y)) > 0) & !isempty_interval(bareinterval(real(y))) && return complex(_unsafe_interval(zero(BareInterval{T}), d, t), _unsafe_interval(zero(BareInterval{T}), d, t))
d = min(d, trv)
return complex(_unsafe_interval(emptyinterval(BareInterval{T}), d, t), _unsafe_interval(emptyinterval(BareInterval{T}), d, t))
function Base.:^(x::Complex{<:Interval}, y::Complex{<:Interval})
if isreal(x) && isthininteger(y)
r = real(x) ^ real(y)
d = min(decoration(x), decoration(y), decoration(r))
t = isguaranteed(x) & isguaranteed(y)
return complex(_unsafe_interval(bareinterval(real(r)), d, t), _unsafe_interval(bareinterval(imag(r)), d, t))
else
isthininteger(y) && return exp(y * _log_no_branch_cut(x))
return exp(y * log(x))
end
end
Base.:^(x::Complex{<:Interval}, y::Complex{<:Interval}) = ^(promote(x, y)...)
Base.:^(x::Complex{<:Interval}, y::Real) = ^(promote(x, y)...)
Base.:^(x::Real, y::Complex{<:Interval}) = ^(promote(x, y)...)
# needed to avoid method ambiguities
Base.:^(x::Complex{<:Interval}, n::Bool) = ^(promote(x, n)...)
Base.:^(x::Complex{<:Interval}, n::Integer) = ^(promote(x, n)...)
Base.:^(x::Complex{<:Interval}, n::Rational) = ^(promote(x, n)...)

function _log_no_branch_cut(z::Complex{<:Interval})
x, y = reim(z)
by = bareinterval(y)
bx = bareinterval(x)
r = atan(by, bx)
d = min(decoration(y), decoration(x), decoration(r))
d = min(d,
ifelse(in_interval(0, by),
ifelse(in_interval(0, bx), trv, d),
d))
t = isguaranteed(y) & isguaranteed(x)
angle = _unsafe_interval(r, d, t)
return complex(log(abs(z)), angle)
end

# needed to avoid method errors
Base.:^(x::Complex{<:Interval}, y::Interval) = ^(promote(x, y)...)
Base.:^(x::Interval, y::Complex{<:Interval}) = ^(promote(x, y)...)

# overwrite behaviour for small integer powers from https://github.com/JuliaLang/julia/pull/24240
# Base.literal_pow(::typeof(^), x::Interval, ::Val{n}) where {n} = x^n
Base.literal_pow(::typeof(^), x::Interval, ::Val{n}) where {n} = _select_pown(power_mode(), x, n)
function Base.literal_pow(::typeof(^), x::Complex{Interval{T}}, ::Val{n}) where {T<:NumTypes,n}
!isthinzero(x) && return exp(interval(T, n) * log(x))
d = decoration(x)
t = isguaranteed(x)
n == 0 && return complex(_unsafe_interval(one(BareInterval{T}), d, t), _unsafe_interval(zero(BareInterval{T}), d, t))
n > 0 && return complex(_unsafe_interval(zero(BareInterval{T}), d, t), _unsafe_interval(zero(BareInterval{T}), d, t))
d = min(d, trv)
return complex(_unsafe_interval(emptyinterval(BareInterval{T}), d, t), _unsafe_interval(emptyinterval(BareInterval{T}), d, t))
end
Base.literal_pow(::typeof(^), x::Complex{<:Interval}, ::Val{n}) where {n} = ^(x, interval(n))

# helper functions for power

Expand All @@ -110,7 +114,7 @@ if `y` is a thin integer, this is not equivalent to `pown(x, sup(y))`.
Implement the `pow` function of the IEEE Standard 1788-2015 (Table 9.1).
See also: [`pown`](@ref).
See also: [`fastpow`](@ref), [`pown`](@ref) and [`fastpown`](@ref).
# Examples
Expand All @@ -127,40 +131,28 @@ julia> pow(interval(-1, 1), interval(-3))
Interval{Float64}(1.0, Inf, trv)
```
"""
function pow(x::BareInterval{T}, y::BareInterval{T}) where {T<:AbstractFloat}
function pow(x::BareInterval{T}, y::BareInterval{T}) where {T<:NumTypes}
isempty_interval(y) && return y
domain = _unsafe_bareinterval(T, zero(T), typemax(T))
x = intersect_interval(x, domain)
isempty_interval(x) && return x
isthin(y) && return _pow(x, sup(y))
return hull(_pow(x, inf(y)), _pow(x, sup(y)))
isthin(y) && return _thin_pow(x, sup(y))
return hull(_thin_pow(x, inf(y)), _thin_pow(x, sup(y)))
end
function pow(x::BareInterval{T}, y::BareInterval{T}) where {T<:Rational}
isempty_interval(y) && return y
domain = _unsafe_bareinterval(T, zero(T), typemax(T))
x = intersect_interval(x, domain)
isempty_interval(x) && return x
isthin(y) && return _pow(x, sup(y))
return hull(_pow(x, inf(y)), _pow(x, sup(y)))
end
pow(x::BareInterval{<:AbstractFloat}, y::BareInterval{<:AbstractFloat}) = pow(promote(x, y)...)
pow(x::BareInterval{<:Rational}, y::BareInterval{<:Rational}) = pow(promote(x, y)...)
pow(x::BareInterval{<:Rational}, y::BareInterval{<:AbstractFloat}) = pow(promote(x, y)...)
pow(x::BareInterval, y::BareInterval) = pow(promote(x, y)...)
# specialize on rational to improve exactness
function pow(x::BareInterval{T}, y::BareInterval{S}) where {T<:NumTypes,S<:Rational}
R = promote_numtype(T, S)
isempty_interval(y) && return emptyinterval(BareInterval{R})
domain = _unsafe_bareinterval(T, zero(T), typemax(T))
x = intersect_interval(x, domain)
isempty_interval(x) && return emptyinterval(BareInterval{R})
isthin(y) && return BareInterval{R}(_pow(x, sup(y)))
return BareInterval{R}(hull(_pow(x, inf(y)), _pow(x, sup(y))))
isthin(y) && return BareInterval{R}(_thin_pow(x, sup(y)))
return BareInterval{R}(hull(_thin_pow(x, inf(y)), _thin_pow(x, sup(y))))
end

pow(n::Integer, y::BareInterval) = pow(n//one(n), y)
pow(x::BareInterval, n::Integer) = pow(x, n//one(n))
pow(x::Real, y::BareInterval) = pow(bareinterval(x), y)
pow(x::BareInterval, y::Real) = pow(x, bareinterval(y))
pow(x::BareInterval, n::Integer) = pow(x, n//one(n))

function pow(x::Interval, y::Interval)
bx = bareinterval(x)
Expand All @@ -172,17 +164,15 @@ function pow(x::Interval, y::Interval)
return _unsafe_interval(r, d, t)
end

pow(n::Integer, y::Interval) = pow(n//one(n), y)
pow(x::Interval, n::Integer) = pow(x, n//one(n))
pow(x::Real, y::Interval) = pow(interval(x), y)
pow(x::Interval, y::Real) = pow(x, interval(y))
pow(x::Interval, n::Integer) = pow(x, n//one(n))

# helper functions for power
# helper function for `pow`

function _pow(x::BareInterval{T}, y::T) where {T<:NumTypes}
function _thin_pow(x::BareInterval{T}, y::T) where {T<:NumTypes}
# assume `inf(x) ≥ 0` and `!isempty_interval(x)`
if sup(x) == 0
y > 0 && return x # zero(x)
if sup(x) == 0 # isthinzero(x)
y > 0 && return x
return emptyinterval(BareInterval{T})
else
isinteger(y) && return pown(x, Integer(y))
Expand All @@ -193,15 +183,14 @@ function _pow(x::BareInterval{T}, y::T) where {T<:NumTypes}
end
end

function _pow(x::BareInterval{T}, y::Rational{S}) where {T<:NumTypes,S<:Integer}
function _thin_pow(x::BareInterval{T}, y::Rational{S}) where {T<:NumTypes,S<:Integer}
# assume `inf(x) ≥ 0` and `!isempty_interval(x)`
if sup(x) == 0
y > 0 && return x # zero(x)
if sup(x) == 0 # isthinzero(x)
y > 0 && return x
return emptyinterval(BareInterval{T})
else
isinteger(y) && return pown(x, S(y))
y == (1//2) && return sqrt(x)
return pown(rootn(x, y.den), y.num)
return pown(rootn(x, denominator(y)), numerator(y))
end
end

Expand All @@ -210,6 +199,8 @@ end
Implement the `pown` function of the IEEE Standard 1788-2015 (Table 9.1).
See also: [`fastpown`](@ref), [`pow`](@ref) and [`fastpow`](@ref).
# Examples
```jldoctest
Expand Down Expand Up @@ -321,37 +312,29 @@ Base.hypot(x::Interval, y::Interval) = sqrt(_select_pown(power_mode(), x, 2) + _
"""
fastpow(x, y)
A faster implementation of `pow(x, y)`, at the cost of maybe returning a
slightly larger interval.
A faster implementation of `pow(x, y)`, at the cost of maybe returning a larger
interval.
See also: [`pow`](@ref), [`pown`](@ref) and [`fastpown`](@ref).
"""
function fastpow(x::BareInterval{T}, y::BareInterval{T}) where {T<:NumTypes}
isempty_interval(y) && return y
isthininteger(y) && return fastpow(x, Integer(sup(y)))
domain = _unsafe_bareinterval(T, zero(T), typemax(T))
x = intersect_interval(x, domain)
isempty_interval(x) && return x
if sup(x) == 0
sup(y) > 0 && return x # zero(x)
if sup(x) == 0 # isthinzero(x)
sup(y) > 0 && return x
return emptyinterval(BareInterval{T})
elseif isthininteger(y)
n = Integer(sup(y))
n < 0 && return inv(_positive_power_by_squaring(x, -n))
return _positive_power_by_squaring(x, n)
else
return exp(y * log(x))
end
end

function fastpow(x::BareInterval{T}, n::Integer) where {T<:NumTypes}
n < 0 && return inv(fastpow(x, -n))
domain = _unsafe_bareinterval(T, zero(T), typemax(T))
x = intersect_interval(x, domain)
isempty_interval(x) && return x
if sup(x) == 0
n > 0 && return x # zero(x)
return emptyinterval(BareInterval{T}) # n == 0
else
return _positive_power_by_squaring(x, n)
end
end

fastpow(x::BareInterval, y::BareInterval) = fastpow(promote(x, y)...)

fastpow(x::BareInterval, y::Real) = fastpow(x, bareinterval(y))

function fastpow(x::Interval, y::Interval)
Expand All @@ -366,17 +349,35 @@ end

fastpow(x::Interval, y::Real) = fastpow(x, interval(y))

# helper function for fast power
"""
fastpown(x, n)
A faster implementation of `pown(x, n)`, at the cost of maybe returning a larger
interval.
See also: [`pown`](@ref), [`pow`](@ref) and [`fastpow`](@ref).
"""
function fastpown(x::BareInterval{T}, n::Integer) where {T<:NumTypes}
isempty_interval(x) && return x
n < 0 && return inv(fastpown(x, -n))
range = _unsafe_bareinterval(T, ifelse(iseven(n), zero(T), typemin(T)), typemax(T))
return intersect_interval(_positive_power_by_squaring(x, n), range)
end

function fastpown(x::Interval, n::Integer)
r = fastpown(bareinterval(x), n)
d = min(decoration(x), decoration(r))
d = min(d, ifelse((n < 0) & in_interval(0, x), trv, d))
return _unsafe_interval(r, d, isguaranteed(x))
end

# helper function for `fastpow` and `fastpown`

# code inspired by `power_by_squaring(::Any, ::Integer)` in base/intfuncs.jl
Base.@assume_effects :terminates_locally function _positive_power_by_squaring(x::BareInterval, n::Integer)
if n == 1
return x
elseif n == 0
return one(x)
elseif n == 2
return x*x
end
Base.@assume_effects :terminates_locally function _positive_power_by_squaring(x, n)
n == 0 && return one(x)
n == 1 && return x
n == 2 && return x*x
t = trailing_zeros(n) + 1
n >>= t
while (t -= 1) > 0
Expand All @@ -394,32 +395,6 @@ Base.@assume_effects :terminates_locally function _positive_power_by_squaring(x:
return y
end

"""
fastpown(x, n)
A faster implementation of `pown(x, y)`, at the cost of maybe returning a
slightly larger interval.
"""
function fastpown(x::BareInterval{T}, n::Integer) where {T<:NumTypes}
isempty_interval(x) && return x
n == 0 && return one(BareInterval{T})
n == 1 && return x
if n < 0
isthinzero(x) && return emptyinterval(BareInterval{T})
return inv(fastpown(x, -n))
else
range = _unsafe_bareinterval(T, ifelse(iseven(n), zero(T), typemin(T)), typemax(T))
return intersect_interval(_positive_power_by_squaring(x, n), range)
end
end

function fastpown(x::Interval, n::Integer)
r = fastpown(bareinterval(x), n)
d = min(decoration(x), decoration(r))
d = min(d, ifelse((n < 0) & in_interval(0, x), trv, d))
return _unsafe_interval(r, d, isguaranteed(x))
end

#

for f (:cbrt, :exp, :exp2, :exp10, :expm1)
Expand Down

0 comments on commit 6141c47

Please sign in to comment.