Skip to content

Commit

Permalink
Faster Rationals by avoiding unnecessary divgcd
Browse files Browse the repository at this point in the history
  • Loading branch information
Liozou committed Apr 16, 2020
1 parent 07b7237 commit 64fdb02
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 32 deletions.
100 changes: 68 additions & 32 deletions base/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,32 @@ struct Rational{T<:Integer} <: Real
end
return new(num2, den2)
end

function Rational{T}(num::Integer, den::Integer, ::Val{false}) where T<:Integer
# Used when num and den are only known to be coprime and not both equal to 0
if T<:Signed && signbit(den)
den = -den
signbit(den) && __throw_rational_argerror_typemin(T)
num = -num
end
return new(num, den)
end

function Rational{T}(num::Integer, den::Integer, ::Val{true}) where T<:Integer
# Used to skip all checks. This means we know num and den are coprime
# and cannot both be 0 and den >= 0
return new(num, den)
end

end
@noinline __throw_rational_argerror_zero(T) = throw(ArgumentError("invalid rational: zero($T)//zero($T)"))
@noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)"))

Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n,d)
Rational(n::T, d::T, check::Val{b}) where {T<:Integer, b} = Rational{T}(n, d, check)
Rational(n::Integer, d::Integer) = Rational(promote(n,d)...)
Rational(n::Integer) = Rational(n,one(n))
Rational(n::Integer, d::Integer, check::Val{b}) where {b} = Rational(promote(n, d)..., check)
Rational(n::Integer) = Rational(n, one(n), Val(true))

function divgcd(x::Integer,y::Integer)
g = gcd(x,y)
Expand All @@ -51,19 +70,19 @@ julia> (3 // 5) // (2 // 1)

function //(x::Rational, y::Integer)
xn,yn = divgcd(x.num,y)
xn//checked_mul(x.den,yn)
Rational(xn, checked_mul(x.den, yn), Val(false))
end
function //(x::Integer, y::Rational)
xn,yn = divgcd(x,y.num)
checked_mul(xn,y.den)//yn
Rational(checked_mul(xn, y.den), yn, Val(false))
end
function //(x::Rational, y::Rational)
xn,yn = divgcd(x.num,y.num)
xd,yd = divgcd(x.den,y.den)
checked_mul(xn,yd)//checked_mul(xd,yn)
Rational(checked_mul(xn, yd), checked_mul(xd, yn), Val(false))
end

//(x::Complex, y::Real) = complex(real(x)//y,imag(x)//y)
//(x::Complex, y::Real) = complex(real(x)//y, imag(x)//y)
//(x::Number, y::Complex) = x*conj(y)//abs2(y)


Expand All @@ -84,8 +103,8 @@ function write(s::IO, z::Rational)
write(s,numerator(z),denominator(z))
end

Rational{T}(x::Rational) where {T<:Integer} = Rational{T}(convert(T,x.num), convert(T,x.den))
Rational{T}(x::Integer) where {T<:Integer} = Rational{T}(convert(T,x), convert(T,1))
Rational{T}(x::Rational) where {T<:Integer} = Rational{T}(convert(T,x.num), convert(T,x.den), Val(true))
Rational{T}(x::Integer) where {T<:Integer} = Rational{T}(convert(T,x), one(T), Val(true))

Rational(x::Rational) = x

Expand All @@ -108,7 +127,7 @@ end
Rational(x::Float64) = Rational{Int64}(x)
Rational(x::Float32) = Rational{Int}(x)

big(q::Rational) = big(numerator(q))//big(denominator(q))
big(q::Rational) = Rational(big(numerator(q)), big(denominator(q)), Val(true))

big(z::Complex{<:Rational{<:Integer}}) = Complex{Rational{BigInt}}(z)

Expand Down Expand Up @@ -140,8 +159,11 @@ function rationalize(::Type{T}, x::AbstractFloat, tol::Real) where T<:Integer
if tol < 0
throw(ArgumentError("negative tolerance $tol"))
end
if T<:Unsigned && x < 0
throw(OverflowError("cannot negate unsigned number"))
end
isnan(x) && return T(x)//one(T)
isinf(x) && return (x < 0 ? -one(T) : one(T))//zero(T)
isinf(x) && return Rational(x < 0 ? -one(T) : one(T), zero(T), Val(true))

p, q = (x < 0 ? -one(T) : one(T)), zero(T)
pp, qq = zero(T), one(T)
Expand Down Expand Up @@ -234,59 +256,73 @@ denominator(x::Rational) = x.den

sign(x::Rational) = oftype(x, sign(x.num))
signbit(x::Rational) = signbit(x.num)
copysign(x::Rational, y::Real) = copysign(x.num,y) // x.den
copysign(x::Rational, y::Rational) = copysign(x.num,y.num) // x.den
copysign(x::Rational, y::Real) = Rational(copysign(x.num, y), x.den, Val(true))
copysign(x::Rational, y::Rational) = Rational(copysign(x.num, y.num), x.den, Val(true))

abs(x::Rational) = Rational(abs(x.num), x.den)

typemin(::Type{Rational{T}}) where {T<:Integer} = -one(T)//zero(T)
typemax(::Type{Rational{T}}) where {T<:Integer} = one(T)//zero(T)
typemin(::Type{Rational{T}}) where {T<:Signed} = Rational(-one(T), zero(T), Val(true))
typemin(::Type{Rational{T}}) where {T<:Integer} = Rational(zero(T), one(T), Val(true))
typemax(::Type{Rational{T}}) where {T<:Integer} = Rational(one(T), zero(T), Val(true))

isinteger(x::Rational) = x.den == 1

+(x::Rational) = (+x.num) // x.den
-(x::Rational) = (-x.num) // x.den
+(x::Rational) = Rational(+x.num, x.den, Val(true))
-(x::Rational) = Rational(-x.num, x.den, Val(true))

function -(x::Rational{T}) where T<:BitSigned
x.num == typemin(T) && throw(OverflowError("rational numerator is typemin(T)"))
(-x.num) // x.den
Rational(-x.num, x.den, Val(true))
end
function -(x::Rational{T}) where T<:Unsigned
x.num != zero(T) && throw(OverflowError("cannot negate unsigned number"))
x
end

for (op,chop) in ((:+,:checked_add), (:-,:checked_sub),
(:rem,:rem), (:mod,:mod))
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), (:rem,:rem), (:mod,:mod))
@eval begin
function ($op)(x::Rational, y::Rational)
xd, yd = divgcd(x.den, y.den)
Rational(($chop)(checked_mul(x.num,yd), checked_mul(y.num,xd)), checked_mul(x.den,yd))
end

function ($op)(x::Rational, y::Integer)
Rational(($chop)(x.num, checked_mul(x.den, y)), x.den)
Rational(($chop)(x.num, checked_mul(x.den, y)), x.den, Val(true))
end

end
end
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub))
@eval begin
function ($op)(y::Integer, x::Rational)
Rational(($chop)(checked_mul(x.den, y), x.num), x.den, Val(true))
end
end
end
for (op,chop) in ((:rem,:rem), (:mod,:mod))
@eval begin
function ($op)(y::Integer, x::Rational)
Rational(($chop)(checked_mul(x.den, y), x.num), x.den)
end
end
end

function *(x::Rational, y::Rational)
xn,yd = divgcd(x.num,y.den)
xd,yn = divgcd(x.den,y.num)
checked_mul(xn,yn) // checked_mul(xd,yd)
xn, yd = divgcd(x.num, y.den)
xd, yn = divgcd(x.den, y.num)
Rational(checked_mul(xn, yn), checked_mul(xd, yd), Val(true))
end
function *(x::Rational, y::Integer)
xd, yn = divgcd(x.den, y)
checked_mul(x.num, yn) // xd
Rational(checked_mul(x.num, yn), xd, Val(true))
end
function *(y::Integer, x::Rational)
yn, xd = divgcd(y, x.den)
Rational(checked_mul(yn, x.num), xd, Val(true))
end
*(x::Integer, y::Rational) = *(y, x)
/(x::Rational, y::Rational) = x//y
/(x::Rational, y::Union{Rational, Integer}) = x//y
/(x::Integer, y::Rational) = x//y
/(x::Rational, y::Complex{<:Union{Integer,Rational}}) = x//y
inv(x::Rational) = Rational(x.den, x.num)
inv(x::Rational) = Rational(x.den, x.num, Val(false))

fma(x::Rational, y::Rational, z::Rational) = x*y+z

Expand Down Expand Up @@ -405,7 +441,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:Nearest}) w
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
Expand All @@ -419,7 +455,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTies
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
Expand All @@ -433,7 +469,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTies
if denominator(x) == zero(Tr) && T <: Integer
throw(DivideError())
elseif denominator(x) == zero(Tr)
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x)))
end
q,r = divrem(numerator(x), denominator(x))
s = q
Expand Down Expand Up @@ -477,8 +513,8 @@ end

float(::Type{Rational{T}}) where {T<:Integer} = float(T)

gcd(x::Rational, y::Rational) = gcd(x.num, y.num) // lcm(x.den, y.den)
lcm(x::Rational, y::Rational) = lcm(x.num, y.num) // gcd(x.den, y.den)
gcd(x::Rational, y::Rational) = Rational(gcd(x.num, y.num), lcm(x.den, y.den), Val(true))
lcm(x::Rational, y::Rational) = Rational(lcm(x.num, y.num), gcd(x.den, y.den), Val(true))
function gcdx(x::Rational, y::Rational)
c = gcd(x, y)
if iszero(c.num)
Expand Down
8 changes: 8 additions & 0 deletions test/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Test
@test 5//0 == 1//0
@test -1//0 == -1//0
@test -7//0 == -1//0
@test (-1//2) // (-2//5) == 5//4

@test_throws OverflowError -(0x01//0x0f)
@test_throws OverflowError -(typemin(Int)//1)
Expand All @@ -26,9 +27,12 @@ using Test
@test (typemax(Int)//1) / (typemax(Int)//1) == 1
@test (1//typemax(Int)) / (1//typemax(Int)) == 1
@test_throws OverflowError (1//2)^63
@test inv((1+typemin(Int))//typemax(Int)) == -1
@test_throws ArgumentError inv(typemin(Int)//typemax(Int))

@test @inferred(rationalize(Int, 3.0, 0.0)) === 3//1
@test @inferred(rationalize(Int, 3.0, 0)) === 3//1
@test_throws OverflowError rationalize(UInt, -2.0)
@test_throws ArgumentError rationalize(Int, big(3.0), -1.)
# issue 26823
@test_throws InexactError rationalize(Int, NaN)
Expand Down Expand Up @@ -120,6 +124,8 @@ end
@test widen(Rational{T}) == Rational{widen(T)}
end

@test iszero(typemin(Rational{UInt}))

@test Rational(Float32(rand_int)) == Rational(rand_int)

@test Rational(Rational(rand_int)) == Rational(rand_int)
Expand Down Expand Up @@ -416,6 +422,8 @@ end
end
@test 1//2 * 3 == 3//2
@test -3 * (1//2) == -3//2
@test (6//5) // -3 == -2//5
@test -4 // (-6//5) == 10//3

@test_throws OverflowError UInt(1)//2 - 1
@test_throws OverflowError 1 - UInt(5)//2
Expand Down

0 comments on commit 64fdb02

Please sign in to comment.