diff --git a/base/rational.jl b/base/rational.jl index 5d2674f14e89a9..832ca2ec1b7327 100644 --- a/base/rational.jl +++ b/base/rational.jl @@ -20,13 +20,32 @@ struct Rational{T<:Integer} <: Real end return new(num2, den2) end + + function Rational{T}(num::Integer, den::Integer, ::Val{false}) where T<:Integer + # Used when num and den are only known to be coprime and not both equal to 0 + if T<:Signed && signbit(den) + den = -den + signbit(den) && __throw_rational_argerror_typemin(T) + num = -num + end + return new(num, den) + end + + function Rational{T}(num::Integer, den::Integer, ::Val{true}) where T<:Integer + # Used to skip all checks. This means we know num and den are coprime + # and cannot both be 0 and den >= 0 + return new(num, den) + end + end @noinline __throw_rational_argerror_zero(T) = throw(ArgumentError("invalid rational: zero($T)//zero($T)")) @noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)")) -Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n,d) +Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n, d) +Rational(n::T, d::T, check::Val{b}) where {T<:Integer, b} = Rational{T}(n, d, check) Rational(n::Integer, d::Integer) = Rational(promote(n,d)...) -Rational(n::Integer) = Rational(n,one(n)) +Rational(n::Integer, d::Integer, check::Val{b}) where {b} = Rational(promote(n, d)..., check) +Rational(n::Integer) = Rational(n, one(n), Val(true)) function divgcd(x::Integer,y::Integer) g = gcd(x,y) @@ -51,19 +70,19 @@ julia> (3 // 5) // (2 // 1) function //(x::Rational, y::Integer) xn,yn = divgcd(x.num,y) - xn//checked_mul(x.den,yn) + Rational(xn, checked_mul(x.den, yn), Val(false)) end function //(x::Integer, y::Rational) xn,yn = divgcd(x,y.num) - checked_mul(xn,y.den)//yn + Rational(checked_mul(xn, y.den), yn, Val(false)) end function //(x::Rational, y::Rational) xn,yn = divgcd(x.num,y.num) xd,yd = divgcd(x.den,y.den) - checked_mul(xn,yd)//checked_mul(xd,yn) + Rational(checked_mul(xn, yd), checked_mul(xd, yn), Val(false)) end -//(x::Complex, y::Real) = complex(real(x)//y,imag(x)//y) +//(x::Complex, y::Real) = complex(real(x)//y, imag(x)//y) //(x::Number, y::Complex) = x*conj(y)//abs2(y) @@ -84,8 +103,8 @@ function write(s::IO, z::Rational) write(s,numerator(z),denominator(z)) end -Rational{T}(x::Rational) where {T<:Integer} = Rational{T}(convert(T,x.num), convert(T,x.den)) -Rational{T}(x::Integer) where {T<:Integer} = Rational{T}(convert(T,x), convert(T,1)) +Rational{T}(x::Rational) where {T<:Integer} = Rational{T}(convert(T,x.num), convert(T,x.den), Val(true)) +Rational{T}(x::Integer) where {T<:Integer} = Rational{T}(convert(T,x), one(T), Val(true)) Rational(x::Rational) = x @@ -108,7 +127,7 @@ end Rational(x::Float64) = Rational{Int64}(x) Rational(x::Float32) = Rational{Int}(x) -big(q::Rational) = big(numerator(q))//big(denominator(q)) +big(q::Rational) = Rational(big(numerator(q)), big(denominator(q)), Val(true)) big(z::Complex{<:Rational{<:Integer}}) = Complex{Rational{BigInt}}(z) @@ -140,8 +159,11 @@ function rationalize(::Type{T}, x::AbstractFloat, tol::Real) where T<:Integer if tol < 0 throw(ArgumentError("negative tolerance $tol")) end + if T<:Unsigned && x < 0 + throw(OverflowError("cannot negate unsigned number")) + end isnan(x) && return T(x)//one(T) - isinf(x) && return (x < 0 ? -one(T) : one(T))//zero(T) + isinf(x) && return Rational(x < 0 ? -one(T) : one(T), zero(T), Val(true)) p, q = (x < 0 ? -one(T) : one(T)), zero(T) pp, qq = zero(T), one(T) @@ -234,30 +256,30 @@ denominator(x::Rational) = x.den sign(x::Rational) = oftype(x, sign(x.num)) signbit(x::Rational) = signbit(x.num) -copysign(x::Rational, y::Real) = copysign(x.num,y) // x.den -copysign(x::Rational, y::Rational) = copysign(x.num,y.num) // x.den +copysign(x::Rational, y::Real) = Rational(copysign(x.num, y), x.den, Val(true)) +copysign(x::Rational, y::Rational) = Rational(copysign(x.num, y.num), x.den, Val(true)) abs(x::Rational) = Rational(abs(x.num), x.den) -typemin(::Type{Rational{T}}) where {T<:Integer} = -one(T)//zero(T) -typemax(::Type{Rational{T}}) where {T<:Integer} = one(T)//zero(T) +typemin(::Type{Rational{T}}) where {T<:Signed} = Rational(-one(T), zero(T), Val(true)) +typemin(::Type{Rational{T}}) where {T<:Integer} = Rational(zero(T), one(T), Val(true)) +typemax(::Type{Rational{T}}) where {T<:Integer} = Rational(one(T), zero(T), Val(true)) isinteger(x::Rational) = x.den == 1 -+(x::Rational) = (+x.num) // x.den --(x::Rational) = (-x.num) // x.den ++(x::Rational) = Rational(+x.num, x.den, Val(true)) +-(x::Rational) = Rational(-x.num, x.den, Val(true)) function -(x::Rational{T}) where T<:BitSigned x.num == typemin(T) && throw(OverflowError("rational numerator is typemin(T)")) - (-x.num) // x.den + Rational(-x.num, x.den, Val(true)) end function -(x::Rational{T}) where T<:Unsigned x.num != zero(T) && throw(OverflowError("cannot negate unsigned number")) x end -for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), - (:rem,:rem), (:mod,:mod)) +for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), (:rem,:rem), (:mod,:mod)) @eval begin function ($op)(x::Rational, y::Rational) xd, yd = divgcd(x.den, y.den) @@ -265,9 +287,19 @@ for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), end function ($op)(x::Rational, y::Integer) - Rational(($chop)(x.num, checked_mul(x.den, y)), x.den) + Rational(($chop)(x.num, checked_mul(x.den, y)), x.den, Val(true)) end - + end +end +for (op,chop) in ((:+,:checked_add), (:-,:checked_sub)) + @eval begin + function ($op)(y::Integer, x::Rational) + Rational(($chop)(checked_mul(x.den, y), x.num), x.den, Val(true)) + end + end +end +for (op,chop) in ((:rem,:rem), (:mod,:mod)) + @eval begin function ($op)(y::Integer, x::Rational) Rational(($chop)(checked_mul(x.den, y), x.num), x.den) end @@ -275,18 +307,22 @@ for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), end function *(x::Rational, y::Rational) - xn,yd = divgcd(x.num,y.den) - xd,yn = divgcd(x.den,y.num) - checked_mul(xn,yn) // checked_mul(xd,yd) + xn, yd = divgcd(x.num, y.den) + xd, yn = divgcd(x.den, y.num) + Rational(checked_mul(xn, yn), checked_mul(xd, yd), Val(true)) end function *(x::Rational, y::Integer) xd, yn = divgcd(x.den, y) - checked_mul(x.num, yn) // xd + Rational(checked_mul(x.num, yn), xd, Val(true)) +end +function *(y::Integer, x::Rational) + yn, xd = divgcd(y, x.den) + Rational(checked_mul(yn, x.num), xd, Val(true)) end -*(x::Integer, y::Rational) = *(y, x) -/(x::Rational, y::Rational) = x//y +/(x::Rational, y::Union{Rational, Integer}) = x//y +/(x::Integer, y::Rational) = x//y /(x::Rational, y::Complex{<:Union{Integer,Rational}}) = x//y -inv(x::Rational) = Rational(x.den, x.num) +inv(x::Rational) = Rational(x.den, x.num, Val(false)) fma(x::Rational, y::Rational, z::Rational) = x*y+z @@ -405,7 +441,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:Nearest}) w if denominator(x) == zero(Tr) && T <: Integer throw(DivideError()) elseif denominator(x) == zero(Tr) - return convert(T, copysign(one(Tr)//zero(Tr), numerator(x))) + return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x))) end q,r = divrem(numerator(x), denominator(x)) s = q @@ -419,7 +455,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTies if denominator(x) == zero(Tr) && T <: Integer throw(DivideError()) elseif denominator(x) == zero(Tr) - return convert(T, copysign(one(Tr)//zero(Tr), numerator(x))) + return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x))) end q,r = divrem(numerator(x), denominator(x)) s = q @@ -433,7 +469,7 @@ function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTies if denominator(x) == zero(Tr) && T <: Integer throw(DivideError()) elseif denominator(x) == zero(Tr) - return convert(T, copysign(one(Tr)//zero(Tr), numerator(x))) + return convert(T, copysign(Rational(one(Tr), zero(Tr), Val(true)), numerator(x))) end q,r = divrem(numerator(x), denominator(x)) s = q @@ -477,8 +513,8 @@ end float(::Type{Rational{T}}) where {T<:Integer} = float(T) -gcd(x::Rational, y::Rational) = gcd(x.num, y.num) // lcm(x.den, y.den) -lcm(x::Rational, y::Rational) = lcm(x.num, y.num) // gcd(x.den, y.den) +gcd(x::Rational, y::Rational) = Rational(gcd(x.num, y.num), lcm(x.den, y.den), Val(true)) +lcm(x::Rational, y::Rational) = Rational(lcm(x.num, y.num), gcd(x.den, y.den), Val(true)) function gcdx(x::Rational, y::Rational) c = gcd(x, y) if iszero(c.num) diff --git a/test/rational.jl b/test/rational.jl index 7a8f37e711320a..bf295852a1f440 100644 --- a/test/rational.jl +++ b/test/rational.jl @@ -17,6 +17,7 @@ using Test @test 5//0 == 1//0 @test -1//0 == -1//0 @test -7//0 == -1//0 + @test (-1//2) // (-2//5) == 5//4 @test_throws OverflowError -(0x01//0x0f) @test_throws OverflowError -(typemin(Int)//1) @@ -26,9 +27,12 @@ using Test @test (typemax(Int)//1) / (typemax(Int)//1) == 1 @test (1//typemax(Int)) / (1//typemax(Int)) == 1 @test_throws OverflowError (1//2)^63 + @test inv((1+typemin(Int))//typemax(Int)) == -1 + @test_throws ArgumentError inv(typemin(Int)//typemax(Int)) @test @inferred(rationalize(Int, 3.0, 0.0)) === 3//1 @test @inferred(rationalize(Int, 3.0, 0)) === 3//1 + @test_throws OverflowError rationalize(UInt, -2.0) @test_throws ArgumentError rationalize(Int, big(3.0), -1.) # issue 26823 @test_throws InexactError rationalize(Int, NaN) @@ -120,6 +124,8 @@ end @test widen(Rational{T}) == Rational{widen(T)} end + @test iszero(typemin(Rational{UInt})) + @test Rational(Float32(rand_int)) == Rational(rand_int) @test Rational(Rational(rand_int)) == Rational(rand_int) @@ -416,6 +422,8 @@ end end @test 1//2 * 3 == 3//2 @test -3 * (1//2) == -3//2 + @test (6//5) // -3 == -2//5 + @test -4 // (-6//5) == 10//3 @test_throws OverflowError UInt(1)//2 - 1 @test_throws OverflowError 1 - UInt(5)//2