diff --git a/src/bigints.nim b/src/bigints.nim index 9ef5abb..c5ab457 100644 --- a/src/bigints.nim +++ b/src/bigints.nim @@ -11,7 +11,6 @@ type limbs: seq[uint32] isNegative: bool - func normalize(a: var BigInt) = for i in countdown(a.limbs.high, 0): if a.limbs[i] > 0'u32: @@ -69,6 +68,7 @@ func initBigInt*(val: BigInt): BigInt = const zero = initBigInt(0) one = initBigInt(1) + karatsubaThreshold = 2 func isZero(a: BigInt): bool {.inline.} = a.limbs.len == 0 or (a.limbs.len == 1 and a.limbs[0] == 0) @@ -390,7 +390,6 @@ template `-=`*(a: var BigInt, b: BigInt) = assert a == 3.initBigInt a = a - b - func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} = # always called with bl >= cl let @@ -420,6 +419,30 @@ func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} = inc pos normalize(a) +func scalarMultiplication(a: var BigInt, b: BigInt, c: uint32) {.inline.} = + # always called with bl >= cl + if c == 0: + a = zero + return + let + bl = b.limbs.len + a.limbs.setLen(bl + 1) + var tmp = 0'u64 + let c = uint64(c) + + for i in 0 ..< bl: + tmp += uint64(b.limbs[i]) * c + a.limbs[i] = uint32(tmp and uint32.high) + tmp = tmp shr 32 + + a.limbs[bl] = uint32(tmp) + normalize(a) + +# forward declaration for use in `multiplication` +func karatsubaMultiplication*(a: var BigInt, b, c: BigInt) +func `shl`*(x: BigInt, y: Natural): BigInt +func `shr`*(x: BigInt, y: Natural): BigInt + func multiplication(a: var BigInt, b, c: BigInt) = # a = b * c if b.isZero or c.isZero: @@ -430,11 +453,73 @@ func multiplication(a: var BigInt, b, c: BigInt) = cl = c.limbs.len if cl > bl: + # if bl >= karatsubaThreshold: + # karatsubaMultiplication(a, c, b) + # else: unsignedMultiplication(a, c, b) else: + # if cl >= karatsubaThreshold: + # karatsubaMultiplication(a, b, c) + # else: unsignedMultiplication(a, b, c) a.isNegative = b.isNegative xor c.isNegative +func karatsubaMultiplication*(a: var BigInt, b, c: BigInt) = + if b.isZero or c.isZero: + a = zero + return + a.isNegative = b.isNegative xor c.isNegative + let + bl = b.limbs.len + cl = c.limbs.len + n = max(bl, cl) + k = n shr 1 + if bl == 1: + # base case : multiply the only limb with each limb of second term + scalarMultiplication(a, c, b.limbs[0]) + a.isNegative = b.isNegative xor c.isNegative + return + if cl == 1: + scalarMultiplication(a, b, c.limbs[0]) + a.isNegative = b.isNegative xor c.isNegative + return + if bl < karatsubaThreshold: + if cl <= bl: + unsignedMultiplication(a, b, c) + else: + unsignedMultiplication(a, c, b) + a.isNegative = b.isNegative xor c.isNegative + return + if cl < karatsubaThreshold: + if bl <= cl: + unsignedMultiplication(a, c, b) + else: + unsignedMultiplication(a, b, c) + a.isNegative = b.isNegative xor c.isNegative + return + var + low_b, high_b, low_c, high_c: BigInt + # Decompose `b` and `c` in two parts of (almost) equal length + low_b.limbs = b.limbs[0 .. k-1] + high_b.limbs = b.limbs[k .. ^1] + low_c.limbs = c.limbs[0 .. k-1] + high_c.limbs = c.limbs[k .. ^1] + + # subtractive version of Karatsuba's algorithm to limit carry handling + var lowProduct, highProduct, add3, add4, add5, middleTerm: BigInt = zero + + multiplication(lowProduct, low_b, low_c) + multiplication(highProduct, high_b, high_c) + + add3 = low_b - high_b + add4 = high_c - low_c + + multiplication(add5, add4, add3) + + middleTerm = lowProduct + highProduct + add5 + a = lowProduct + middleTerm shl (32*k) + highProduct shl (64*k) + a.isNegative = b.isNegative xor c.isNegative + func `*`*(a, b: BigInt): BigInt = ## Multiplication for `BigInt`s. runnableExamples: @@ -1124,7 +1209,6 @@ iterator `..<`*(a, b: BigInt): BigInt = yield res inc res - func modulo(a, modulus: BigInt): BigInt = ## Like `mod`, but the result is always in the range `[0, modulus-1]`. ## `modulus` should be greater than zero. @@ -1140,7 +1224,6 @@ func fastLog2*(a: BigInt): int = return -1 bitops.fastLog2(a.limbs[^1]) + 32*(a.limbs.high) - func invmod*(a, modulus: BigInt): BigInt = ## Compute the modular inverse of `a` modulo `modulus`. ## The return value is always in the range `[1, modulus-1]` @@ -1198,3 +1281,106 @@ func powmod*(base, exponent, modulus: BigInt): BigInt = result = (result * basePow) mod modulus basePow = (basePow * basePow) mod modulus exponent = exponent shr 1 + +when isMainModule: + var a, b, c: BigInt + let + two = 2.initBigInt + three = 3.initBigInt + four = 4.initBigInt + + a.limbs = @[1'u32, 2'u32] + b.limbs = @[3'u32, 4'u32] + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a.limbs = @[1'u32, 0'u32] + b.limbs = @[0'u32, 4'u32] + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a.limbs = @[2'u32, 1'u32] + b.limbs = @[3'u32, 4'u32] + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a.limbs = @[2'u32, 1'u32] + b.limbs = @[4'u32, 3'u32] + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a.limbs = @[1'u32, 2'u32] + b.limbs = @[3'u32, 4'u32] + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a = two shl 32 - one + b = four shl 32 - three + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a = -(two shl 32 + one) + b = four shl 32 - three + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a.limbs = @[1'u32, 2'u32, 3'u32] + b.limbs = @[4'u32, 5'u32, 6'u32] + a.isNegative = false + b.isNegative = false + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a.limbs = @[1'u32, 2'u32, 3'u32, 4'u32, 5'u32] + b.limbs = @[4'u32, 5'u32, 6'u32, 7'u32, 8'u32] + a.isNegative = false + b.isNegative = false + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b + + a.limbs = @[1'u32, 2'u32, 3'u32, 4'u32, 5'u32, 6'u32, 7'u32, 8'u32, 9'u32, 10'u32] + b.limbs = @[10'u32, 9'u32, 8'u32, 7'u32, 6'u32, 5'u32, 4'u32, 3'u32, 2'u32, 1'u32] + a.isNegative = false + b.isNegative = false + echo a.limbs + echo b.limbs + echo "factors: ", a, " ", b + karatsubaMultiplication(c, a, b) + echo "product Karatsuba: ", c + echo "correct product: ", a * b diff --git a/tests/fastMultiplication.nim b/tests/fastMultiplication.nim new file mode 100644 index 0000000..cd9323a --- /dev/null +++ b/tests/fastMultiplication.nim @@ -0,0 +1,25 @@ +import bigints +import std/[math, random, sequtils, strutils] + +randomize() +# Pick a number in 0..100. +let limit = 10^9 +let limbs = 10000 +let randomBigInt = toSeq(1..limbs).mapIt(rand(limit)).join("").initBigInt +let randomBigInt2 = toSeq(1..limbs).mapIt(rand(limit)).join("").initBigInt +let randomBigInt3 = toSeq(1..limbs).mapIt(rand(limit)).join("").initBigInt + +# Compute subproducts +let prod1 = randomBigInt * randomBigInt2 +let prod1bis = randomBigInt2 * randomBigInt + +# Check commutativity of the product +doAssert prod1 == prod1bis + +let prod2 = randomBigInt2 * randomBigInt3 +let prod3 = randomBigInt * randomBigInt3 +let product = prod2 * randomBigInt + +# Check associativity of the product +doAssert prod1 * randomBigInt3 == product +doAssert prod3 * randomBigInt2 == product diff --git a/tests/tbigints.nim b/tests/tbigints.nim index 3b3a59d..8ac4668 100644 --- a/tests/tbigints.nim +++ b/tests/tbigints.nim @@ -350,6 +350,78 @@ proc main() = doAssert (d xor d) == f doAssert (d xor f) == d + block: # multiplication + # test sign + let one = 1.initBigInt + let negOne = -1.initBigInt + doAssert one * negOne == negOne + + # factors with 4 limbs + let a = "1780983279228119273110576463639172624".initBigInt + let b = "1843917749452418885995463656480858321".initBigInt + doAssert a * b == "3283986680046702618742503890385314117448805445290098330749803441805804304".initBigInt + + # factors with 17 limbs + let c = "15456863493948186026689401110531937466657435954521677287549013772194751214595085262021623597960658907994197330891108031896474775438991400654520526954653285".initBigInt + let d = "20867311096234429137120990056519061484140179793024844459539745043528236531589522382271230666075358518275274769618792229717222657110424037636116966396665200".initBigInt + var r: BigInt = 0.initBigInt + karatsubaMultiplication(r, c, d) + doAssert c * d == "322543179100245850295291096700090285623536165554432133161470913224665233565153206743023505404409261647296075477738317301701554637184306640109864382144645081119052516436652162825894456855767719709860985552674755702938369565636714472650667032224717209489767579823588160939485446085000195032327964706246225182000".initBigInt + doAssert c * d == r + + # factors with 65 limbs + let e = initBigInt("1f3b839241b0aacc183858dc7a75a773e7bad642a9f426ef499d91e09c9f99a88ec9a14d5ee51175faeaa10d2fb06f3ee37d2f50fe755c2c963aeb539cd55c0e14f5a23f04c64839c22bd4108034b7afc95e01a1c2fe605d8b1930926e886a8f3d7fc09acd54d388cb5d4b3a3fb4eaf6781173ab3a0cd8ad3119c37dd2cf05544235d7b85b2c96d2ed29e1a685820c4afdd824bd8878f1b6a3f52a57eb886efaa737af47161c89f298d908aa950979b8c2615d4e03b47ee87a5381ca39d9ec4788d7abd07b174913b962c02cdd5f8319722a3345eb38d3ebdd51dec66a58e89902151539298c41446758bac66923c910fd7a2d12d0d5c8bb688970b8a77e7d5fc", base = 16) + let f = initBigInt("116be4e445ea68066fca5652e472eb1c1a5fb850311126a8a91fd6f1199f92a9d6602a81bb5e500d163b01df7eec15e41109c62f6c83425027272823d9888a51c93422d47ba4e1cecb94d6fa02eb27df537038b2ac7c9c264634b8febc452c5c9043ddc7d2eddd04f1f743d85cfefd0864441bb9cbf46308138d037b2057980c8b6d215e6bd2c0a73d64c176b7b59452a2c7968a121e5d46c859d85678acd0d4927509418fd351331791fdb7ad041ee2e7d5975867f1812d1e17a41f5a7c0735ebe224294f5ca9d607e95a8722adf58f676b23da1563ac62a52352c10efd0bf5cc8b5eb9ddd7fd1c22bc307d5edf86474f66302fd7cd288a9bc251d3ba45a851b", base = 16) + doAssert e * f == initBigInt("2201d898f354480704067f8513e7b8b365db60c9132d96e54b8546c05c59a08f50b8d8488841d9e893b3c6de34e22b70f2f6c9bc682700c060b60b6e614c9cc39a2c9a9c13ccc412a512c8126f60e1572d20281855e63d019be43a34a929c20818d05527e75f9edf4c5c9a096c4412f879ca14dd3f8bc48aafea6ce17223adefc1e55ca8216ed3d6d351f08dc38a5e6fed7cb3abd1844bfaed632a1571d0017f1285b2e7762c8c3ad0adc781f47df619f462415b8b4e3496fea3d1b1f4443184a5f5fdc155af2f62d861c9ab321a5083c07b4fdfc384aaa6c4a09559e1d7383b5a3fd9f6c9dfe2079c13bfcf307f98de6e16e474b55e94dcf9dbaf3e90ba38f37ddc69c318ed680a0db8a5d5257a1214c765b267f3df8352ad5b4ee9e9c27ca7085ee5687e6afc8108b683c622613c003c068b60bed656d2dc6a32b1f7b079194108301da8d8049f5b64e88da091803bf1582a45fd24f242f2d9b9d8090c4d088ea31faa3e997d20688481b8f1847524f28153e0ba5c3017338cf470c906d8b27352082741dcfb81ec81a3268569424f791c9a82777d66f2b6a52e0653843057da444b55ddb5f517f1676daafa3413ee3dc6dc0a8edc7cfcbd4b6dd5653957baa93e35fad6908addc018706b0acf64cd5cfe3ee462b57e2cc58f641a883a693505fab131a8f51f22cc34ff694af1e4ca0b4d8469067a76b378783b190c2377d94", base = 16) + + block: # self-addition/self-subtraction + # self-addition + var a = zero + a += a + doAssert a == zero + a = 12.initBigInt + a += a + doAssert a == 24.initBigInt + a = 20736.initBigInt + a += a + doAssert a == 41472.initBigInt + a = "184884258895036416".initBigInt + a += a + doAssert a == "369768517790072832".initBigInt + + # self-subtraction + var b = zero + b -= b + doAssert b == zero + b = 12.initBigInt + b -= b + doAssert b == zero + b = 20736.initBigInt + b -= b + doAssert b == zero + b = "184884258895036416".initBigInt + b -= b + doAssert b == zero + + block: # self-multiplication + var a = 12.initBigInt + a *= a + doAssert a == 144.initBigInt + a *= a + doAssert a == 20736.initBigInt + a *= a + doAssert a == 429981696.initBigInt + a *= a + doAssert a == "184884258895036416".initBigInt + var b = zero + b *= b + doAssert b == zero + var c = one + c *= c + doAssert c == one + a *= b + doAssert a == zero + block: # inc/dec var x: BigInt @@ -787,5 +859,5 @@ proc main() = doAssert succ(a, 3) == initBigInt(10) -static: main() +# static: main() main()