Skip to content

Commit

Permalink
Faster Rationals by avoiding unnecessary divgcd (#35492)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liozou authored May 4, 2020
1 parent 5ec0608 commit 78af72f
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 47 deletions.
127 changes: 80 additions & 47 deletions base/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,34 @@ struct Rational{T<:Integer} <: Real
num::T
den::T

function Rational{T}(num::Integer, den::Integer) where T<:Integer
num == den == zero(T) && __throw_rational_argerror_zero(T)
num2, den2 = divgcd(num, den)
if T<:Signed && signbit(den2)
den2 = -den2
signbit(den2) && __throw_rational_argerror_typemin(T)
num2 = -num2
end
return new(num2, den2)
# Unexported inner constructor of Rational that bypasses all checks
global unsafe_rational(::Type{T}, num, den) where {T} = new{T}(num, den)
end

unsafe_rational(num::T, den::T) where {T<:Integer} = unsafe_rational(T, num, den)
unsafe_rational(num::Integer, den::Integer) = unsafe_rational(promote(num, den)...)

@noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)"))
function checked_den(num::T, den::T) where T<:Integer
if signbit(den)
den = -den
signbit(den) && __throw_rational_argerror_typemin(T)
num = -num
end
return unsafe_rational(T, num, den)
end
checked_den(num::Integer, den::Integer) = checked_den(promote(num, den)...)

@noinline __throw_rational_argerror_zero(T) = throw(ArgumentError("invalid rational: zero($T)//zero($T)"))
@noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)"))
function Rational{T}(num::Integer, den::Integer) where T<:Integer
iszero(den) && iszero(num) && __throw_rational_argerror_zero(T)
num, den = divgcd(num, den)
return checked_den(T(num), T(den))
end

Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n,d)
Rational(n::Integer, d::Integer) = Rational(promote(n,d)...)
Rational(n::Integer) = Rational(n,one(n))
Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n, d)
Rational(n::Integer, d::Integer) = Rational(promote(n, d)...)
Rational(n::Integer) = unsafe_rational(n, one(n))

function divgcd(x::Integer,y::Integer)
g = gcd(x,y)
Expand All @@ -50,20 +61,20 @@ julia> (3 // 5) // (2 // 1)
//(n::Integer, d::Integer) = Rational(n,d)

function //(x::Rational, y::Integer)
xn,yn = divgcd(x.num,y)
xn//checked_mul(x.den,yn)
xn, yn = divgcd(x.num,y)
checked_den(xn, checked_mul(x.den, yn))
end
function //(x::Integer, y::Rational)
xn,yn = divgcd(x,y.num)
checked_mul(xn,y.den)//yn
xn, yn = divgcd(x,y.num)
checked_den(checked_mul(xn, y.den), yn)
end
function //(x::Rational, y::Rational)
xn,yn = divgcd(x.num,y.num)
xd,yd = divgcd(x.den,y.den)
checked_mul(xn,yd)//checked_mul(xd,yn)
checked_den(checked_mul(xn, yd), checked_mul(xd, yn))
end

//(x::Complex, y::Real) = complex(real(x)//y,imag(x)//y)
//(x::Complex, y::Real) = complex(real(x)//y, imag(x)//y)
//(x::Number, y::Complex) = x*conj(y)//abs2(y)


Expand All @@ -84,8 +95,12 @@ function write(s::IO, z::Rational)
write(s,numerator(z),denominator(z))
end

Rational{T}(x::Rational) where {T<:Integer} = Rational{T}(convert(T,x.num), convert(T,x.den))
Rational{T}(x::Integer) where {T<:Integer} = Rational{T}(convert(T,x), convert(T,1))
function Rational{T}(x::Rational) where T<:Integer
unsafe_rational(T, convert(T, x.num), convert(T, x.den))
end
function Rational{T}(x::Integer) where T<:Integer
unsafe_rational(T, convert(T, x), one(T))
end

Rational(x::Rational) = x

Expand All @@ -108,7 +123,7 @@ end
Rational(x::Float64) = Rational{Int64}(x)
Rational(x::Float32) = Rational{Int}(x)

big(q::Rational) = big(numerator(q))//big(denominator(q))
big(q::Rational) = unsafe_rational(big(numerator(q)), big(denominator(q)))

big(z::Complex{<:Rational{<:Integer}}) = Complex{Rational{BigInt}}(z)

Expand All @@ -118,6 +133,8 @@ promote_rule(::Type{Rational{T}}, ::Type{S}) where {T<:Integer,S<:AbstractFloat}

widen(::Type{Rational{T}}) where {T} = Rational{widen(T)}

@noinline __throw_negate_unsigned() = throw(OverflowError("cannot negate unsigned number"))

"""
rationalize([T<:Integer=Int,] x; tol::Real=eps(x))
Expand All @@ -140,8 +157,9 @@ function rationalize(::Type{T}, x::AbstractFloat, tol::Real) where T<:Integer
if tol < 0
throw(ArgumentError("negative tolerance $tol"))
end
T<:Unsigned && x < 0 && __throw_negate_unsigned()
isnan(x) && return T(x)//one(T)
isinf(x) && return (x < 0 ? -one(T) : one(T))//zero(T)
isinf(x) && return unsafe_rational(x < 0 ? -one(T) : one(T), zero(T))

p, q = (x < 0 ? -one(T) : one(T)), zero(T)
pp, qq = zero(T), one(T)
Expand Down Expand Up @@ -234,59 +252,74 @@ denominator(x::Rational) = x.den

sign(x::Rational) = oftype(x, sign(x.num))
signbit(x::Rational) = signbit(x.num)
copysign(x::Rational, y::Real) = copysign(x.num,y) // x.den
copysign(x::Rational, y::Rational) = copysign(x.num,y.num) // x.den
copysign(x::Rational, y::Real) = unsafe_rational(copysign(x.num, y), x.den)
copysign(x::Rational, y::Rational) = unsafe_rational(copysign(x.num, y.num), x.den)

abs(x::Rational) = Rational(abs(x.num), x.den)

typemin(::Type{Rational{T}}) where {T<:Integer} = -one(T)//zero(T)
typemax(::Type{Rational{T}}) where {T<:Integer} = one(T)//zero(T)
typemin(::Type{Rational{T}}) where {T<:Signed} = unsafe_rational(T, -one(T), zero(T))
typemin(::Type{Rational{T}}) where {T<:Integer} = unsafe_rational(T, zero(T), one(T))
typemax(::Type{Rational{T}}) where {T<:Integer} = unsafe_rational(T, one(T), zero(T))

isinteger(x::Rational) = x.den == 1

+(x::Rational) = (+x.num) // x.den
-(x::Rational) = (-x.num) // x.den
+(x::Rational) = unsafe_rational(+x.num, x.den)
-(x::Rational) = unsafe_rational(-x.num, x.den)

function -(x::Rational{T}) where T<:BitSigned
x.num == typemin(T) && throw(OverflowError("rational numerator is typemin(T)"))
(-x.num) // x.den
x.num == typemin(T) && __throw_rational_numerator_typemin(T)
unsafe_rational(-x.num, x.den)
end
@noinline __throw_rational_numerator_typemin(T) = throw(OverflowError("rational numerator is typemin($T)"))

function -(x::Rational{T}) where T<:Unsigned
x.num != zero(T) && throw(OverflowError("cannot negate unsigned number"))
x.num != zero(T) && __throw_negate_unsigned()
x
end

for (op,chop) in ((:+,:checked_add), (:-,:checked_sub),
(:rem,:rem), (:mod,:mod))
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), (:rem,:rem), (:mod,:mod))
@eval begin
function ($op)(x::Rational, y::Rational)
xd, yd = divgcd(x.den, y.den)
Rational(($chop)(checked_mul(x.num,yd), checked_mul(y.num,xd)), checked_mul(x.den,yd))
end

function ($op)(x::Rational, y::Integer)
Rational(($chop)(x.num, checked_mul(x.den, y)), x.den)
unsafe_rational(($chop)(x.num, checked_mul(x.den, y)), x.den)
end

end
end
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub))
@eval begin
function ($op)(y::Integer, x::Rational)
unsafe_rational(($chop)(checked_mul(x.den, y), x.num), x.den)
end
end
end
for (op,chop) in ((:rem,:rem), (:mod,:mod))
@eval begin
function ($op)(y::Integer, x::Rational)
Rational(($chop)(checked_mul(x.den, y), x.num), x.den)
end
end
end

function *(x::Rational, y::Rational)
xn,yd = divgcd(x.num,y.den)
xd,yn = divgcd(x.den,y.num)
checked_mul(xn,yn) // checked_mul(xd,yd)
xn, yd = divgcd(x.num, y.den)
xd, yn = divgcd(x.den, y.num)
unsafe_rational(checked_mul(xn, yn), checked_mul(xd, yd))
end
function *(x::Rational, y::Integer)
xd, yn = divgcd(x.den, y)
checked_mul(x.num, yn) // xd
unsafe_rational(checked_mul(x.num, yn), xd)
end
function *(y::Integer, x::Rational)
yn, xd = divgcd(y, x.den)
unsafe_rational(checked_mul(yn, x.num), xd)
end
*(x::Integer, y::Rational) = *(y, x)
/(x::Rational, y::Rational) = x//y
/(x::Rational, y::Complex{<:Union{Integer,Rational}}) = x//y
inv(x::Rational) = Rational(x.den, x.num)
/(x::Rational, y::Union{Rational, Integer, Complex{<:Union{Integer,Rational}}}) = x//y
/(x::Union{Integer, Complex{<:Union{Integer,Rational}}}, y::Rational) = x//y
inv(x::Rational{T}) where {T} = checked_den(x.den, x.num)

fma(x::Rational, y::Rational, z::Rational) = x*y+z

Expand Down Expand Up @@ -403,7 +436,7 @@ round(x::Rational, r::RoundingMode=RoundNearest) = round(typeof(x), x, r)

function round(::Type{T}, x::Rational{Tr}, r::RoundingMode=RoundNearest) where {T,Tr}
if iszero(denominator(x)) && !(T <: Integer)
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
return convert(T, copysign(unsafe_rational(one(Tr), zero(Tr)), numerator(x)))
end
convert(T, div(numerator(x), denominator(x), r))
end
Expand Down Expand Up @@ -437,8 +470,8 @@ end

float(::Type{Rational{T}}) where {T<:Integer} = float(T)

gcd(x::Rational, y::Rational) = gcd(x.num, y.num) // lcm(x.den, y.den)
lcm(x::Rational, y::Rational) = lcm(x.num, y.num) // gcd(x.den, y.den)
gcd(x::Rational, y::Rational) = unsafe_rational(gcd(x.num, y.num), lcm(x.den, y.den))
lcm(x::Rational, y::Rational) = unsafe_rational(lcm(x.num, y.num), gcd(x.den, y.den))
function gcdx(x::Rational, y::Rational)
c = gcd(x, y)
if iszero(c.num)
Expand Down
14 changes: 14 additions & 0 deletions test/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Test
@test 5//0 == 1//0
@test -1//0 == -1//0
@test -7//0 == -1//0
@test (-1//2) // (-2//5) == 5//4

@test_throws OverflowError -(0x01//0x0f)
@test_throws OverflowError -(typemin(Int)//1)
Expand All @@ -26,9 +27,13 @@ using Test
@test (typemax(Int)//1) / (typemax(Int)//1) == 1
@test (1//typemax(Int)) / (1//typemax(Int)) == 1
@test_throws OverflowError (1//2)^63
@test inv((1+typemin(Int))//typemax(Int)) == -1
@test_throws ArgumentError inv(typemin(Int)//typemax(Int))
@test_throws ArgumentError Rational(0x1, typemin(Int32))

@test @inferred(rationalize(Int, 3.0, 0.0)) === 3//1
@test @inferred(rationalize(Int, 3.0, 0)) === 3//1
@test_throws OverflowError rationalize(UInt, -2.0)
@test_throws ArgumentError rationalize(Int, big(3.0), -1.)
# issue 26823
@test_throws InexactError rationalize(Int, NaN)
Expand All @@ -38,6 +43,11 @@ using Test
@test -2 // typemin(Int) == -1 // (typemin(Int) >> 1)
@test 2 // typemin(Int) == 1 // (typemin(Int) >> 1)

@test_throws InexactError Rational(UInt(1), typemin(Int32))
@test iszero(Rational{Int}(UInt(0), 1))
@test Rational{BigInt}(UInt(1), Int(-1)) == -1
@test_broken Rational{Int64}(UInt(1), typemin(Int32)) == Int64(1) // Int64(typemin(Int32))

for a = -5:5, b = -5:5
if a == b == 0; continue; end
if ispow2(b)
Expand Down Expand Up @@ -120,6 +130,8 @@ end
@test widen(Rational{T}) == Rational{widen(T)}
end

@test iszero(typemin(Rational{UInt}))

@test Rational(Float32(rand_int)) == Rational(rand_int)

@test Rational(Rational(rand_int)) == Rational(rand_int)
Expand Down Expand Up @@ -548,6 +560,8 @@ end
end
@test 1//2 * 3 == 3//2
@test -3 * (1//2) == -3//2
@test (6//5) // -3 == -2//5
@test -4 // (-6//5) == 10//3

@test_throws OverflowError UInt(1)//2 - 1
@test_throws OverflowError 1 - UInt(5)//2
Expand Down

0 comments on commit 78af72f

Please sign in to comment.