-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Karatsuba multiplication #95
base: master
Are you sure you want to change the base?
Changes from all commits
bdabfe8
244a6a8
f2a48e2
9e4efd6
f3bcc9b
bc118cb
438a26e
5dd2512
dc69b3c
fd4d321
bb05ee5
1327726
d7bfc9f
290a1d1
334537a
463ece4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,6 @@ type | |
limbs: seq[uint32] | ||
isNegative: bool | ||
|
||
|
||
func normalize(a: var BigInt) = | ||
for i in countdown(a.limbs.high, 0): | ||
if a.limbs[i] > 0'u32: | ||
|
@@ -69,6 +68,7 @@ func initBigInt*(val: BigInt): BigInt = | |
const
  # Frequently used constants, built once at compile time.
  zero = 0.initBigInt
  one = 1.initBigInt
  # Minimum limb count for the Karatsuba split: `karatsubaMultiplication`
  # falls back to `unsignedMultiplication` when an operand has fewer limbs.
  karatsubaThreshold = 2
|
||
func isZero(a: BigInt): bool {.inline.} =
  ## Returns `true` when `a` has no limbs at all, or exactly one limb whose
  ## value is 0. Any other limb count yields `false`.
  case a.limbs.len
  of 0: true
  of 1: a.limbs[0] == 0
  else: false
|
@@ -390,7 +390,6 @@ template `-=`*(a: var BigInt, b: BigInt) = | |
assert a == 3.initBigInt | ||
a = a - b | ||
|
||
|
||
func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} = | ||
# always called with bl >= cl | ||
let | ||
|
@@ -420,6 +419,30 @@ func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} = | |
inc pos | ||
normalize(a) | ||
|
||
func scalarMultiplication(a: var BigInt, b: BigInt, c: uint32) {.inline.} =
  ## Multiplies the limbs of `b` by the single 32-bit limb `c` and stores the
  ## result in `a`.
  ##
  ## * When `c == 0`, `a` is set to `zero`.
  ## * Otherwise only `a.limbs` is written; `a.isNegative` is not touched, so
  ##   the caller is responsible for setting the sign of the product.
  if c == 0:
    a = zero
    return
  let
    bl = b.limbs.len
    factor = uint64(c)
  # One extra limb receives the final carry; `normalize` is run at the end
  # (presumably to drop it again when it turns out to be zero).
  a.limbs.setLen(bl + 1)
  var carry = 0'u64
  for i in 0 ..< bl:
    # limb * factor + carry <= (2^32 - 1)^2 + (2^32 - 1) < 2^64, so the
    # 64-bit accumulator cannot overflow.
    carry += uint64(b.limbs[i]) * factor
    a.limbs[i] = uint32(carry and uint32.high)  # keep the low 32 bits
    carry = carry shr 32                        # high 32 bits carry over
  a.limbs[bl] = uint32(carry)
  normalize(a)
|
||
# forward declaration for use in `multiplication` | ||
func karatsubaMultiplication*(a: var BigInt, b, c: BigInt) | ||
func `shl`*(x: BigInt, y: Natural): BigInt | ||
func `shr`*(x: BigInt, y: Natural): BigInt | ||
|
||
func multiplication(a: var BigInt, b, c: BigInt) = | ||
# a = b * c | ||
if b.isZero or c.isZero: | ||
|
@@ -430,11 +453,73 @@ func multiplication(a: var BigInt, b, c: BigInt) = | |
cl = c.limbs.len | ||
|
||
if cl > bl: | ||
# if bl >= karatsubaThreshold: | ||
# karatsubaMultiplication(a, c, b) | ||
# else: | ||
unsignedMultiplication(a, c, b) | ||
else: | ||
# if cl >= karatsubaThreshold: | ||
# karatsubaMultiplication(a, b, c) | ||
# else: | ||
unsignedMultiplication(a, b, c) | ||
a.isNegative = b.isNegative xor c.isNegative | ||
|
||
func karatsubaMultiplication*(a: var BigInt, b, c: BigInt) =
  ## Computes `a = b * c` with one level of Karatsuba splitting.
  ##
  ## Small inputs are handled directly:
  ## * a zero operand yields `zero`;
  ## * a single-limb operand is handled by `scalarMultiplication`;
  ## * operands shorter than `karatsubaThreshold` limbs are handled by
  ##   `unsignedMultiplication`.
  ##
  ## Otherwise both operands are cut at limb index `k` into a low part
  ## (limbs `0 ..< k`) and a high part (limbs `k .. ^1`).
  if b.isZero or c.isZero:
    a = zero
    return
  let
    bl = b.limbs.len
    cl = c.limbs.len
    # Computed once, before `a` is modified.
    negative = b.isNegative xor c.isNegative
    # The split point is derived from the *shorter* operand so that both
    # `limbs[0 .. k-1]` and `limbs[k .. ^1]` are valid, non-empty slices of
    # `b` and of `c`, even when their lengths differ a lot.
    n = min(bl, cl)
    k = n shr 1
  if bl == 1:
    # base case: multiply the only limb of `b` with each limb of `c`
    scalarMultiplication(a, c, b.limbs[0])
    a.isNegative = negative
    return
  if cl == 1:
    # base case: multiply the only limb of `c` with each limb of `b`
    scalarMultiplication(a, b, c.limbs[0])
    a.isNegative = negative
    return
  if n < karatsubaThreshold:
    # Too short to be worth splitting; `unsignedMultiplication` expects the
    # longer operand first.
    if cl <= bl:
      unsignedMultiplication(a, b, c)
    else:
      unsignedMultiplication(a, c, b)
    a.isNegative = negative
    return
  var
    low_b, high_b, low_c, high_c: BigInt
  # Decompose `b` and `c` at limb `k`; the parts are non-negative magnitudes
  # (their `isNegative` stays at its default `false`).
  low_b.limbs = b.limbs[0 .. k-1]
  high_b.limbs = b.limbs[k .. ^1]
  low_c.limbs = c.limbs[0 .. k-1]
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can simply change n to the min of bl and cl in this case. |
||
high_c.limbs = c.limbs[k .. ^1] | ||
Comment on lines
+503
to
+506
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These all create new There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, I did not think about this. I used seq because I expect to join the results into a seq afterwards. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @konsumlamm openArrays are only used for proc arguments, when we want either a seq or an array as argument. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. An There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I do not see how we can then call multiplication on those pointers without implementing a multiplication directly on arrays. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The algorithm wants the value of the polynomial corresponding to each part of the slice. type
BigInt* = object
limbs: seq[uint32]
isNegative: bool to type
BigInt* = object
limbs: ref seq[uint32]
isNegative: bool Otherwise, we will have to reimplement addition, subtraction and base case multiplication for another container. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seqs are openarrays, if you implement for openarrays, it's implemented for seq and arrays. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have to change the whole code. Addition, subtraction and multiplication take bigints with a sequence parameter as input. For the Karatsuba algorithm as well as some others, we need to manipulate these sequences through a pointer and get a pointer for parts of the sequence. We also need to operate on those slices of the sequence, i.e. get the value associated with each part of the slice, and add, subtract, and multiply these slices. That's why we need a seq-like container with the possibility of getting a reference to each value of the seq. None of the openarray types enables this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As @konsumlamm said, I think openArray and proc sum(x: openArray[int]): int =
case x.len:
of 0:
0
of 1:
x[0]
of 2:
x[0] + x[1]
else:
let mid = x.len div 2
sum(toOpenArray(x, 0, mid - 1)) + sum(toOpenArray(x, mid, x.high))
echo sum([1, 2, 3, 4, 5]) This is a part of var
low_b, high_b, low_c, high_c: BigInt
# Decompose `b` and `c` in two parts of (almost) equal length
low_b.limbs = b.limbs[0 .. k-1]
high_b.limbs = b.limbs[k .. ^1]
low_c.limbs = c.limbs[0 .. k-1]
high_c.limbs = c.limbs[k .. ^1]
# subtractive version of Karatsuba's algorithm to limit carry handling
var lowProduct, highProduct, add3, add4, add5, middleTerm: BigInt = zero
multiplication(lowProduct, low_b, low_c)
multiplication(highProduct, high_b, high_c)
add3 = low_b - high_b
add4 = high_c - low_c Above code would be written using # Decompose `b` and `c` in two parts of (almost) equal length
template low_b = toOpenArray(b.limbs, 0, k - 1)
template high_b = toOpenArray(b.limbs, k, b.limbs.high)
template low_c = toOpenArray(c.limbs, 0, k - 1)
template high_c = toOpenArray(c.limbs, k, c.limbs.high)
# subtractive version of Karatsuba's algorithm to limit carry handling
var lowProduct, highProduct, add3, add4, add5, middleTerm: BigInt = zero
multiplication(lowProduct, low_b, low_c)
multiplication(highProduct, high_b, high_c)
# This code requires subtraction proc that takes 2 `openArray`
add3 = low_b - high_b
add4 = high_c - low_c There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for the time taken and the detailed changes. # This code requires subtraction proc that takes 2 `openArray` I just want to point out that scalarMultiplication and unsignedMultiplication will also have to take openArrays: if bl == 1:
scalarMultiplication(a, c, b.limbs[0]) # b and c are openArrays
a.isNegative = b.isNegative xor c.isNegative
return
...
if bl < karatsubaThreshold:
if cl <= bl:
unsignedMultiplication(a, b, c) # b and c are openArrays here
else:
unsignedMultiplication(a, c, b) # same
a.isNegative = b.isNegative xor c.isNegative
return
... I can make a I read the unsignedMultiplication and scalarMultiplication proc algorithms again and they effectively don't need parameters to be BigInts. I will look into it. |
||
|
||
# subtractive version of Karatsuba's algorithm to limit carry handling | ||
var lowProduct, highProduct, add3, add4, add5, middleTerm: BigInt = zero | ||
|
||
multiplication(lowProduct, low_b, low_c) | ||
multiplication(highProduct, high_b, high_c) | ||
|
||
add3 = low_b - high_b | ||
add4 = high_c - low_c | ||
|
||
multiplication(add5, add4, add3) | ||
|
||
middleTerm = lowProduct + highProduct + add5 | ||
a = lowProduct + middleTerm shl (32*k) + highProduct shl (64*k) | ||
a.isNegative = b.isNegative xor c.isNegative | ||
|
||
func `*`*(a, b: BigInt): BigInt = | ||
## Multiplication for `BigInt`s. | ||
runnableExamples: | ||
|
@@ -1124,7 +1209,6 @@ iterator `..<`*(a, b: BigInt): BigInt = | |
yield res | ||
inc res | ||
|
||
|
||
func modulo(a, modulus: BigInt): BigInt = | ||
## Like `mod`, but the result is always in the range `[0, modulus-1]`. | ||
## `modulus` should be greater than zero. | ||
|
@@ -1140,7 +1224,6 @@ func fastLog2*(a: BigInt): int = | |
return -1 | ||
bitops.fastLog2(a.limbs[^1]) + 32*(a.limbs.high) | ||
|
||
|
||
func invmod*(a, modulus: BigInt): BigInt = | ||
## Compute the modular inverse of `a` modulo `modulus`. | ||
## The return value is always in the range `[1, modulus-1]` | ||
|
@@ -1198,3 +1281,106 @@ func powmod*(base, exponent, modulus: BigInt): BigInt = | |
result = (result * basePow) mod modulus | ||
basePow = (basePow * basePow) mod modulus | ||
exponent = exponent shr 1 | ||
|
||
when isMainModule:
  # Manual smoke test: for each pair of factors, print the limbs, the factors,
  # the product computed by `karatsubaMultiplication` and the product computed
  # by `*`, so that both results can be compared by eye.
  var a, b, c: BigInt
  let
    two = 2.initBigInt
    three = 3.initBigInt
    four = 4.initBigInt

  template compareProducts() =
    # Uses the current values of `a` and `b`; the Karatsuba result goes to `c`.
    echo a.limbs
    echo b.limbs
    echo "factors: ", a, " ", b
    karatsubaMultiplication(c, a, b)
    echo "product Karatsuba: ", c
    echo "correct product: ", a * b

  template setLimbs(x, y: seq[uint32]) =
    # Overwrites only the limbs of `a` and `b`; signs are left as they are.
    a.limbs = x
    b.limbs = y

  template setPositiveLimbs(x, y: seq[uint32]) =
    # Overwrites the limbs of `a` and `b` and clears both sign flags.
    setLimbs(x, y)
    a.isNegative = false
    b.isNegative = false

  # Two-limb operands
  setLimbs(@[1'u32, 2'u32], @[3'u32, 4'u32])
  compareProducts()

  setLimbs(@[1'u32, 0'u32], @[0'u32, 4'u32])
  compareProducts()

  setLimbs(@[2'u32, 1'u32], @[3'u32, 4'u32])
  compareProducts()

  setLimbs(@[2'u32, 1'u32], @[4'u32, 3'u32])
  compareProducts()

  setLimbs(@[1'u32, 2'u32], @[3'u32, 4'u32])
  compareProducts()

  # Operands built with arithmetic operators
  a = two shl 32 - one
  b = four shl 32 - three
  compareProducts()

  # One negative factor
  a = -(two shl 32 + one)
  b = four shl 32 - three
  compareProducts()

  # Longer operands: 3, 5 and 10 limbs
  setPositiveLimbs(@[1'u32, 2'u32, 3'u32], @[4'u32, 5'u32, 6'u32])
  compareProducts()

  setPositiveLimbs(@[1'u32, 2'u32, 3'u32, 4'u32, 5'u32],
                   @[4'u32, 5'u32, 6'u32, 7'u32, 8'u32])
  compareProducts()

  setPositiveLimbs(
    @[1'u32, 2'u32, 3'u32, 4'u32, 5'u32, 6'u32, 7'u32, 8'u32, 9'u32, 10'u32],
    @[10'u32, 9'u32, 8'u32, 7'u32, 6'u32, 5'u32, 4'u32, 3'u32, 2'u32, 1'u32])
  compareProducts()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import bigints | ||
import std/[math, random, sequtils, strutils] | ||
|
||
randomize()

# Each operand is built by concatenating the decimal representations of
# `chunks` random integers drawn from 0..limit and parsing the resulting
# digit string with `initBigInt`.
let
  limit = 10^9
  chunks = 10000

proc randomOperand(): BigInt =
  toSeq(1..chunks).mapIt(rand(limit)).join("").initBigInt

let
  x = randomOperand()
  y = randomOperand()
  z = randomOperand()

# Commutativity: x * y == y * x
let
  xy = x * y
  yx = y * x
doAssert xy == yx

# Associativity: every grouping of x, y and z gives the same product
let
  yz = y * z
  xz = x * z
  xyz = yz * x
doAssert xy * z == xyz
doAssert xz * y == xyz
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment needs to be removed, since
cl == 1
in this case.