Skip to content

Commit

Permalink
Add Base.checked_pow(x,y) to Base.Checked library (#52849) (#134)
Browse files Browse the repository at this point in the history
Fixes #52262.

Performs `^(x, y)` but throws OverflowError on overflow.

Example:
```julia
julia> 2^62
4611686018427387904

julia> 2^63
-9223372036854775808

julia> checked_pow(2, 63)
ERROR: OverflowError: 2147483648 * 4294967296 overflowed for type Int64
```

Co-authored-by: Nathan Daly <[email protected]>
Co-authored-by: Jameson Nash <[email protected]>
Co-authored-by: Shuhei Kadowaki <[email protected]>
Co-authored-by: Tomáš Drvoštěp <[email protected]>
  • Loading branch information
5 people authored Mar 19, 2024
1 parent 9a0475d commit 6cd7ce2
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 30 deletions.
15 changes: 14 additions & 1 deletion base/checked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ return both the unchecked results and a boolean value denoting the presence of a
module Checked

export checked_neg, checked_abs, checked_add, checked_sub, checked_mul,
checked_div, checked_rem, checked_fld, checked_mod, checked_cld,
checked_div, checked_rem, checked_fld, checked_mod, checked_cld, checked_pow,
checked_length, add_with_overflow, sub_with_overflow, mul_with_overflow

import Core.Intrinsics:
Expand Down Expand Up @@ -358,6 +358,19 @@ The overflow protection may impose a perceptible performance penalty.
"""
checked_cld(x::T, y::T) where {T<:Integer} = cld(x, y) # Base.cld already checks

"""
Base.checked_pow(x, y)
Calculates `^(x,y)`, checking for overflow errors where applicable.
The overflow protection may impose a perceptible performance penalty.
"""
checked_pow(x::Integer, y::Integer) = checked_power_by_squaring(x, y)

checked_power_by_squaring(x_, p::Integer) = Base.power_by_squaring(x_, p; mul = checked_mul)
# For Booleans, the default implementation covers all cases.
checked_power_by_squaring(x::Bool, p::Integer) = Base.power_by_squaring(x, p)

"""
Base.checked_length(r)
Expand Down
11 changes: 6 additions & 5 deletions base/intfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,15 @@ to_power_type(x) = convert(Base._return_type(*, Tuple{typeof(x), typeof(x)}), x)
"\nMake x a float matrix by adding a zero decimal ",
"(e.g., [2.0 1.0;1.0 0.0]^", p, " instead of [2 1;1 0]^", p, ")",
"or write float(x)^", p, " or Rational.(x)^", p, ".")))
@assume_effects :terminates_locally function power_by_squaring(x_, p::Integer)
# The * keyword supports `*=checked_mul` for `checked_pow`
@assume_effects :terminates_locally function power_by_squaring(x_, p::Integer; mul=*)
x = to_power_type(x_)
if p == 1
return copy(x)
elseif p == 0
return one(x)
elseif p == 2
return x*x
return mul(x, x)
elseif p < 0
isone(x) && return copy(x)
isone(-x) && return iseven(p) ? one(x) : copy(x)
Expand All @@ -288,16 +289,16 @@ to_power_type(x) = convert(Base._return_type(*, Tuple{typeof(x), typeof(x)}), x)
t = trailing_zeros(p) + 1
p >>= t
while (t -= 1) > 0
x *= x
x = mul(x, x)
end
y = x
while p > 0
t = trailing_zeros(p) + 1
p >>= t
while (t -= 1) >= 0
x *= x
x = mul(x, x)
end
y *= x
y = mul(y, x)
end
return y
end
Expand Down
1 change: 1 addition & 0 deletions doc/src/base/math.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ Base.Checked.checked_rem
Base.Checked.checked_fld
Base.Checked.checked_mod
Base.Checked.checked_cld
Base.Checked.checked_pow
Base.Checked.add_with_overflow
Base.Checked.sub_with_overflow
Base.Checked.mul_with_overflow
Expand Down
22 changes: 21 additions & 1 deletion test/checked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Checked integer arithmetic

import Base: checked_abs, checked_neg, checked_add, checked_sub, checked_mul,
checked_div, checked_rem, checked_fld, checked_mod, checked_cld,
checked_div, checked_rem, checked_fld, checked_mod, checked_cld, checked_pow,
add_with_overflow, sub_with_overflow, mul_with_overflow

# checked operations
Expand Down Expand Up @@ -166,6 +166,19 @@ import Base: checked_abs, checked_neg, checked_add, checked_sub, checked_mul,
@test checked_cld(typemin(T), T(1)) === typemin(T)
@test_throws DivideError checked_cld(typemin(T), T(0))
@test_throws DivideError checked_cld(typemin(T), T(-1))

@test checked_pow(T(1), T(0)) === T(1)
@test checked_pow(typemax(T), T(0)) === T(1)
@test checked_pow(typemin(T), T(0)) === T(1)
@test checked_pow(T(1), T(1)) === T(1)
@test checked_pow(T(1), typemax(T)) === T(1)
@test checked_pow(T(2), T(2)) === T(4)
@test_throws OverflowError checked_pow(T(2), typemax(T))
@test_throws OverflowError checked_pow(T(-2), typemax(T))
@test_throws OverflowError checked_pow(typemax(T), T(2))
@test_throws OverflowError checked_pow(typemin(T), T(2))
@test_throws DomainError checked_pow(T(2), -T(1))
@test_throws DomainError checked_pow(-T(2), -T(1))
end

@testset for T in (UInt8, UInt16, UInt32, UInt64, UInt128)
Expand Down Expand Up @@ -296,6 +309,10 @@ end
@test checked_cld(true, true) === true
@test checked_cld(false, true) === false
@test_throws DivideError checked_cld(true, false)

@test checked_pow(true, 1) === true
@test checked_pow(true, 1000000) === true
@test checked_pow(false, 1000000) === false
end
@testset "BigInt" begin
@test checked_abs(BigInt(-1)) == BigInt(1)
Expand All @@ -310,6 +327,9 @@ end
@test checked_fld(BigInt(10), BigInt(3)) == BigInt(3)
@test checked_mod(BigInt(9), BigInt(4)) == BigInt(1)
@test checked_cld(BigInt(10), BigInt(3)) == BigInt(4)

@test checked_pow(BigInt(2), 2) == BigInt(4)
@test checked_pow(BigInt(2), 100) == BigInt(1267650600228229401496703205376)
end

@testset "Additional tests" begin
Expand Down
45 changes: 22 additions & 23 deletions test/compiler/ssair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,21 +549,25 @@ import Core.Compiler: NewInstruction, insert_node!
let ir = Base.code_ircode((Int,Int); optimize_until="inlining") do a, b
a^b
end |> only |> first
@test length(ir.stmts) == 2
@test Meta.isexpr(ir.stmts[1][:inst], :invoke)
nstmts = length(ir.stmts)
invoke_idx = findfirst(@nospecialize(stmt)->Meta.isexpr(stmt, :invoke), ir.stmts.inst)
@test invoke !== nothing

newssa = insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, println, SSAValue(1)), Nothing), #=attach_after=#true)
invoke_ssa = SSAValue(invoke_idx)
newssa = insert_node!(ir, invoke_ssa, NewInstruction(Expr(:call, println, invoke_ssa), Nothing), #=attach_after=#true)
newssa = insert_node!(ir, newssa, NewInstruction(Expr(:call, println, newssa), Nothing), #=attach_after=#true)

ir = Core.Compiler.compact!(ir)
@test length(ir.stmts) == 4
@test Meta.isexpr(ir.stmts[1][:inst], :invoke)
call1 = ir.stmts[2][:inst]

@test length(ir.stmts) == nstmts + 2
@test Meta.isexpr(ir.stmts.inst[invoke_idx], :invoke)
call1 = ir.stmts.inst[invoke_idx+1]
@test iscall((ir,println), call1)
@test call1.args[2] === SSAValue(1)
call2 = ir.stmts[3][:inst]
@test call1.args[2] === invoke_ssa
call2 = ir.stmts.inst[invoke_idx+2]

@test iscall((ir,println), call2)
@test call2.args[2] === SSAValue(2)
@test call2.args[2] === SSAValue(invoke_idx+1)
end

# Issue #50379 - insert_node!(::IncrementalCompact, ...) at end of basic block
Expand Down Expand Up @@ -607,47 +611,42 @@ end
let ir = Base.code_ircode((Int,Int); optimize_until="inlining") do a, b
a^b
end |> only |> first
invoke_idx = findfirst(ir.stmts.inst) do @nospecialize(x)
Meta.isexpr(x, :invoke)
end
invoke_idx = findfirst(@nospecialize(stmt)->Meta.isexpr(stmt, :invoke), ir.stmts.inst)
@test invoke_idx !== nothing
invoke_expr = ir.stmts.inst[invoke_idx]
invoke_ssa = SSAValue(invoke_idx)

# effect-ful node
let compact = Core.Compiler.IncrementalCompact(Core.Compiler.copy(ir))
insert_node!(compact, SSAValue(1), NewInstruction(Expr(:call, println, SSAValue(1)), Nothing), #=attach_after=#true)
insert_node!(compact, invoke_ssa, NewInstruction(Expr(:call, println, invoke_ssa), Nothing), #=attach_after=#true)
state = Core.Compiler.iterate(compact)
while state !== nothing
state = Core.Compiler.iterate(compact, state[2])
end
ir = Core.Compiler.finish(compact)
new_invoke_idx = findfirst(ir.stmts.inst) do @nospecialize(x)
x == invoke_expr
end
new_invoke_idx = findfirst(@nospecialize(stmt)->stmt==invoke_expr, ir.stmts.inst)
@test new_invoke_idx !== nothing
new_call_idx = findfirst(ir.stmts.inst) do @nospecialize(x)
iscall((ir,println), x) && x.args[2] === SSAValue(invoke_idx)
new_call_idx = findfirst(ir.stmts.inst) do @nospecialize(stmt)
iscall((ir,println), stmt) && stmt.args[2] === SSAValue(new_invoke_idx)
end
@test new_call_idx !== nothing
@test new_call_idx == new_invoke_idx+1
end

# effect-free node
let compact = Core.Compiler.IncrementalCompact(Core.Compiler.copy(ir))
insert_node!(compact, SSAValue(1), NewInstruction(Expr(:call, GlobalRef(Base, :add_int), SSAValue(1), SSAValue(1)), Int), #=attach_after=#true)
insert_node!(compact, invoke_ssa, NewInstruction(Expr(:call, GlobalRef(Base, :add_int), invoke_ssa, invoke_ssa), Int), #=attach_after=#true)
state = Core.Compiler.iterate(compact)
while state !== nothing
state = Core.Compiler.iterate(compact, state[2])
end
ir = Core.Compiler.finish(compact)

ir = Core.Compiler.finish(compact)
new_invoke_idx = findfirst(ir.stmts.inst) do @nospecialize(x)
x == invoke_expr
end
new_invoke_idx = findfirst(@nospecialize(stmt)->stmt==invoke_expr, ir.stmts.inst)
@test new_invoke_idx !== nothing
new_call_idx = findfirst(ir.stmts.inst) do @nospecialize(x)
iscall((ir,Base.add_int), x) && x.args[2] === SSAValue(invoke_idx)
iscall((ir,Base.add_int), x) && x.args[2] === SSAValue(new_invoke_idx)
end
@test new_call_idx === nothing # should be deleted during the compaction
end
Expand Down

0 comments on commit 6cd7ce2

Please sign in to comment.