From a9bb19f79d38f4d74cbd81c0e979a506050377d5 Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Thu, 9 Jun 2016 16:18:41 +0530 Subject: [PATCH] ndigits: check for invalid bases (fix #16766) --- base/gmp.jl | 14 +++++++++++-- base/intfuncs.jl | 54 ++++++++++++++++++++++++++---------------------- test/bigint.jl | 6 +++--- test/intfuncs.jl | 6 +++++- 4 files changed, 49 insertions(+), 31 deletions(-) diff --git a/base/gmp.jl b/base/gmp.jl index 3bcb2ade3e8675..c0e28a92df62ec 100644 --- a/base/gmp.jl +++ b/base/gmp.jl @@ -8,7 +8,8 @@ import Base: *, +, -, /, <, <<, >>, >>>, <=, ==, >, >=, ^, (~), (&), (|), ($), binomial, cmp, convert, div, divrem, factorial, fld, gcd, gcdx, lcm, mod, ndigits, promote_rule, rem, show, isqrt, string, powermod, sum, trailing_zeros, trailing_ones, count_ones, base, tryparse_internal, - bin, oct, dec, hex, isequal, invmod, prevpow2, nextpow2, ndigits0z, widen, signed, unsafe_trunc, trunc + bin, oct, dec, hex, isequal, invmod, prevpow2, nextpow2, ndigits0z, + ndigits0znb, widen, signed, unsafe_trunc, trunc if Clong == Int32 typealias ClongMax Union{Int8, Int16, Int32} @@ -519,6 +520,7 @@ end function ndigits0z(x::BigInt, b::Integer=10) b < 2 && throw(DomainError()) + x.size == 0 && return 0 # for consistency with other ndigits0z methods if ispow2(b) && 2 <= b <= 62 # GMP assumes b is in this range Int(ccall((:__gmpz_sizeinbase,:libgmp), Csize_t, (Ptr{BigInt}, Cint), &x, b)) else @@ -538,7 +540,15 @@ function ndigits0z(x::BigInt, b::Integer=10) end end end -ndigits(x::BigInt, b::Integer=10) = x.size == 0 ? 1 : ndigits0z(x,b) + +ndigits(x::BigInt, b::Integer=10) = + if b < -1 + x.size == 0 ? 1 : ndigits0znb(x, b) + elseif b > 1 + x.size == 0 ? 1 : ndigits0z(x, b) + else + throw(DomainError()) + end prevpow2(x::BigInt) = x.size < 0 ? -prevpow2(-x) : (x <= 2 ? x : one(BigInt) << (ndigits(x, 2)-1)) nextpow2(x::BigInt) = x.size < 0 ? -nextpow2(-x) : (x <= 2 ? x : one(BigInt) << ndigits(x-1, 2)) diff --git a/base/intfuncs.jl b/base/intfuncs.jl index d85b98cb4a8fce..3ebb8fce650e2b 100644 --- a/base/intfuncs.jl +++ b/base/intfuncs.jl @@ -183,10 +183,13 @@ function ndigits0z(x::UInt128) return n + ndigits0z(UInt64(x)) end ndigits0z(x::Integer) = ndigits0z(unsigned(abs(x))) +ndigits(x::Integer) = x==0 ? 1 : ndigits0z(x) const ndigits_max_mul = Core.sizeof(Int32) == 4 ? 69000000 : 290000000000000000 -function ndigits0znb(n::Int, b::Int) +# The suffix "nb" stands for "negative base" +function ndigits0znb(n::Integer, b::Integer) + # precondition: b < -1 && !(typeof(n) <: Unsigned) d = 0 while n != 0 n = cld(n,b) @@ -195,36 +198,38 @@ function ndigits0znb(n::Int, b::Int) return d end +ndigits0znb(n::Unsigned, b::Integer) = ndigits0znb(signed(n), b) + function ndigits0z(n::Unsigned, b::Int) + # precondition: b > 1 d = 0 - if b < 0 - d = ndigits0znb(signed(n), b) - else - b == 2 && return (sizeof(n)<<3-leading_zeros(n)) - b == 8 && return div((sizeof(n)<<3)-leading_zeros(n)+2,3) - b == 16 && return (sizeof(n)<<1)-(leading_zeros(n)>>2) - b == 10 && return ndigits0z(n) - while ndigits_max_mul < n - n = div(n,b) - d += 1 - end - m = 1 - while m <= n - m *= b - d += 1 - end + b == 2 && return (sizeof(n)<<3-leading_zeros(n)) + b == 8 && return div((sizeof(n)<<3)-leading_zeros(n)+2,3) + b == 16 && return (sizeof(n)<<1)-(leading_zeros(n)>>2) + b == 10 && return ndigits0z(n) + while ndigits_max_mul < n + n = div(n,b) + d += 1 + end + m = 1 + while m <= n + m *= b + d += 1 end return d end -ndigits0z(x::Integer, b::Integer) = ndigits0z(unsigned(abs(x)),Int(b)) -ndigitsnb(x::Integer, b::Integer) = x==0 ? 1 : ndigits0znb(x, b) +ndigits0z(x::Integer, b::Integer) = ndigits0z(unsigned(abs(x)), Int(b)) -ndigits(x::Unsigned, b::Integer) = x==0 ? 1 : ndigits0z(x,Int(b)) -ndigits(x::Unsigned) = x==0 ? 1 : ndigits0z(x) +ndigits(x::Integer, b::Integer) = + if b < -1 + x == 0 ? 1 : ndigits0znb(x, b) + elseif b > 1 + x == 0 ? 1 : ndigits0z(x, b) + else + throw(DomainError()) + end -ndigits(x::Integer, b::Integer) = b >= 0 ? ndigits(unsigned(abs(x)),Int(b)) : ndigitsnb(x, b) -ndigits(x::Integer) = ndigits(unsigned(abs(x))) ## integer to string functions ## @@ -319,8 +324,7 @@ bits(x::Union{Int128,UInt128}) = bin(reinterpret(UInt128,x),128) digits{T<:Integer}(n::Integer, base::T=10, pad::Integer=1) = digits(T, n, base, pad) function digits{T<:Integer}(::Type{T}, n::Integer, base::Integer=10, pad::Integer=1) - 2 <= base || throw(ArgumentError("base must be ≥ 2, got $base")) - digits!(zeros(T, max(pad, ndigits0z(n,base))), n, base) + digits!(zeros(T, max(pad, ndigits(n,base))), n, base) end function digits!{T<:Integer}(a::AbstractArray{T,1}, n::Integer, base::Integer=10) diff --git a/test/bigint.jl b/test/bigint.jl index 4e4554e736788c..29b292b7b379d6 100644 --- a/test/bigint.jl +++ b/test/bigint.jl @@ -275,10 +275,10 @@ ndigits_mismatch(n) = ndigits(n) != ndigits(BigInt(n)) @test !any(ndigits_mismatch, 8192:9999) # The following should not crash (#16579) -ndigits(rand(big(-999:999)), rand(63:typemax(Int))) -ndigits(rand(big(-999:999)), big(2)^rand(2:999)) +ndigits(big(rand(Int)), rand(63:typemax(Int))) +ndigits(big(rand(Int)), big(2)^rand(2:999)) -@test_throws DomainError ndigits(rand(big(-999:999)), rand(typemin(Int):1)) +@test_throws DomainError ndigits(big(rand(Int)), rand(-1:1)) # conversion from float @test BigInt(2.0) == BigInt(2.0f0) == BigInt(big(2.0)) == 2 diff --git a/test/intfuncs.jl b/test/intfuncs.jl index 50a1344ba7daad..253359b03a0173 100644 --- a/test/intfuncs.jl +++ b/test/intfuncs.jl @@ -79,10 +79,14 @@ end @test ndigits(146, -3) == 5 -let n = rand(Int) +let (n, b) = rand(Int, 2) + -1 <= b <= 1 && (b = 2) # invalid bases @test ndigits(n) == ndigits(big(n)) == ndigits(n, 10) + @test ndigits(n, b) == ndigits(big(n), b) end +@test_throws DomainError ndigits(rand(Int), rand(-1:1)) + @test bin(3) == "11" @test bin(3, 2) == "11" @test bin(3, 3) == "011"