diff --git a/spec/std/big/big_int_spec.cr b/spec/std/big/big_int_spec.cr index 79b6525bd465..d653f75523ac 100644 --- a/spec/std/big/big_int_spec.cr +++ b/spec/std/big/big_int_spec.cr @@ -543,4 +543,8 @@ describe "BigInt Math" do it "sqrt" do Math.sqrt(BigInt.new("1" + "0"*48)).should eq(BigFloat.new("1" + "0"*24)) end + + it "isqrt" do + Math.isqrt(BigInt.new("1" + "0"*48)).should eq(BigInt.new("1" + "0"*24)) + end end diff --git a/spec/std/math_spec.cr b/spec/std/math_spec.cr index 68cfe85394d5..c000fc569f7c 100644 --- a/spec/std/math_spec.cr +++ b/spec/std/math_spec.cr @@ -41,6 +41,17 @@ describe "Math" do Math.sqrt(4_f32).should eq(2) Math.sqrt(4).should eq(2) end + + it "isqrt" do + Math.isqrt(9).should eq(3) + Math.isqrt(8).should eq(2) + Math.isqrt(4).should eq(2) + {% for type in [UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64] %} + %val = {{type}}.new 42 + %exp = {{type}}.new 6 + Math.isqrt(%val).should eq(%exp) + {% end %} + end end describe "Exponents" do diff --git a/src/big/big_int.cr b/src/big/big_int.cr index 9c50d55e7265..a06074f7d70d 100644 --- a/src/big/big_int.cr +++ b/src/big/big_int.cr @@ -723,6 +723,11 @@ module Math def sqrt(value : BigInt) sqrt(value.to_big_f) end + + # Calculates the integer square root of *value*. + def isqrt(value : BigInt) + BigInt.new { |mpz| LibGMP.sqrt(mpz, value) } + end end module Random diff --git a/src/big/lib_gmp.cr b/src/big/lib_gmp.cr index 782b79fe3739..7f46a383aafb 100644 --- a/src/big/lib_gmp.cr +++ b/src/big/lib_gmp.cr @@ -83,6 +83,8 @@ lib LibGMP fun pow_ui = __gmpz_pow_ui(rop : MPZ*, base : MPZ*, exp : ULong) fun fac_ui = __gmpz_fac_ui(rop : MPZ*, n : ULong) + fun sqrt = __gmpz_sqrt(rop : MPZ*, op : MPZ*) + # # Bitwise operations fun and = __gmpz_and(rop : MPZ*, op1 : MPZ*, op2 : MPZ*) diff --git a/src/math/math.cr b/src/math/math.cr index 49bcd91b353d..64c4d19dd698 100644 --- a/src/math/math.cr +++ b/src/math/math.cr @@ -332,6 +332,25 @@ module Math sqrt(value.to_f) end + # Calculates the integer square root of *value*. + def isqrt(value : Int::Primitive) + raise ArgumentError.new "Input must be non-negative integer" if value < 0 + return value if value < 2 + res = value.class.zero + bit = res.succ << (res.leading_zeros_count - 2) + bit >>= value.leading_zeros_count & ~0x3 + while (bit != 0) + if value >= res + bit + value -= res + bit + res = (res >> 1) + bit + else + res >>= 1 + end + bit >>= 2 + end + res + end + # Calculates the cubic root of *value*. def cbrt(value : Float32) LibM.cbrt_f32(value)