Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Karatsuba multiplication #95

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 190 additions & 4 deletions src/bigints.nim
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ type
limbs: seq[uint32]
isNegative: bool


func normalize(a: var BigInt) =
for i in countdown(a.limbs.high, 0):
if a.limbs[i] > 0'u32:
Expand Down Expand Up @@ -69,6 +68,7 @@ func initBigInt*(val: BigInt): BigInt =
const
zero = initBigInt(0)
one = initBigInt(1)
karatsubaThreshold = 2

func isZero(a: BigInt): bool {.inline.} =
a.limbs.len == 0 or (a.limbs.len == 1 and a.limbs[0] == 0)
Expand Down Expand Up @@ -390,7 +390,6 @@ template `-=`*(a: var BigInt, b: BigInt) =
assert a == 3.initBigInt
a = a - b


func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
# always called with bl >= cl
let
Expand Down Expand Up @@ -420,6 +419,30 @@ func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
inc pos
normalize(a)

func scalarMultiplication(a: var BigInt, b: BigInt, c: uint32) {.inline.} =
# always called with bl >= cl
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment needs to be removed, since cl == 1 in this case.

if c == 0:
a = zero
return
let
bl = b.limbs.len
a.limbs.setLen(bl + 1)
var tmp = 0'u64
let c = uint64(c)

for i in 0 ..< bl:
tmp += uint64(b.limbs[i]) * c
a.limbs[i] = uint32(tmp and uint32.high)
tmp = tmp shr 32

a.limbs[bl] = uint32(tmp)
normalize(a)

# forward declaration for use in `multiplication`
func karatsubaMultiplication*(a: var BigInt, b, c: BigInt)
func `shl`*(x: BigInt, y: Natural): BigInt
func `shr`*(x: BigInt, y: Natural): BigInt

func multiplication(a: var BigInt, b, c: BigInt) =
# a = b * c
if b.isZero or c.isZero:
Expand All @@ -430,11 +453,73 @@ func multiplication(a: var BigInt, b, c: BigInt) =
cl = c.limbs.len

if cl > bl:
# if bl >= karatsubaThreshold:
# karatsubaMultiplication(a, c, b)
# else:
unsignedMultiplication(a, c, b)
else:
# if cl >= karatsubaThreshold:
# karatsubaMultiplication(a, b, c)
# else:
unsignedMultiplication(a, b, c)
a.isNegative = b.isNegative xor c.isNegative

func karatsubaMultiplication*(a: var BigInt, b, c: BigInt) =
if b.isZero or c.isZero:
a = zero
return
a.isNegative = b.isNegative xor c.isNegative
let
bl = b.limbs.len
cl = c.limbs.len
n = max(bl, cl)
k = n shr 1
if bl == 1:
# base case : multiply the only limb with each limb of second term
scalarMultiplication(a, c, b.limbs[0])
a.isNegative = b.isNegative xor c.isNegative
return
if cl == 1:
scalarMultiplication(a, b, c.limbs[0])
a.isNegative = b.isNegative xor c.isNegative
return
if bl < karatsubaThreshold:
if cl <= bl:
unsignedMultiplication(a, b, c)
else:
unsignedMultiplication(a, c, b)
a.isNegative = b.isNegative xor c.isNegative
return
if cl < karatsubaThreshold:
if bl <= cl:
unsignedMultiplication(a, c, b)
else:
unsignedMultiplication(a, b, c)
a.isNegative = b.isNegative xor c.isNegative
return
var
low_b, high_b, low_c, high_c: BigInt
# Decompose `b` and `c` in two parts of (almost) equal length
low_b.limbs = b.limbs[0 .. k-1]
high_b.limbs = b.limbs[k .. ^1]
low_c.limbs = c.limbs[0 .. k-1]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When c.limbs.len is smaller than half of b.limbs.len, k is larger than c.limbs.len and low_c.limbs = c.limbs[0 .. k-1] cause IndexDefect.
It would be better to add tests that multiply a large value by a small value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can simply change n to the min of bl and cl in this case.
I want to add these tests, but I am waiting for my other PRs to be merged. In these, I have added random generation and benchmarks. I am especially waiting for #112 . With initRandomBigInt, I will be able to generate bigints of a specific size for tests.

high_c.limbs = c.limbs[k .. ^1]
Comment on lines +503 to +506
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These all create new seqs, which isn't very efficient. Perhaps we should use openArray[uint32] instead (then this can use toOpenArray for O(1) slicing) or make our own "sliceable seq" to avoid copies.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I did not thougt about this. I used seq because I expect to join the results into a seq after.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@konsumlamm openArray are only used for procs arguments, when we want either a seq or an array as argument.
Can't we use some pointers here ? Do we have to create a new structure ?
Internal structure is a seq, if we use another structure, we would have to make a copy anyway or change the internal structure.
We can not use arrays neither, since we do not know the size at compile time. (It depends on the variable k, which value is known at runtime).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An openArray is basically a pointer and a length, it doesn't create a copy. Anything that doesn't create a copy but just modifies indices/pointers should be good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see how we can call multiplication then on those pointers, without making a multiplication directly on arrays.
We need to initialize BigInts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These all create new seqs, which isn't very efficient. Perhaps we should use openArray[uint32] instead (then this can use toOpenArray for O(1) slicing) or make our own "sliceable seq" to avoid copies.

The algorithm wants the value of the polynomial corresponding to each parts of the slice.
The best way so far that I conceive is to modify the BigInt's limbs field from:

type
  BigInt* = object
    limbs: seq[uint32]
    isNegative: bool

to

type
  BigInt* = object
    limbs: ref seq[uint32]
    isNegative: bool

Otherwise, we will have to reimplement addition, subtraction and base case multiplication for another container.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seqs are openarrays, if you implement for openarrays, it's implemented for seq and arrays.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to change the whole code.

Addition, subtraction and multiplication take bigints with a sequence parameter as input.

For the Karatsuba algorithm as well as some others, we need to manipulate these sequences through a pointer and get a pointer for parts of the sequence.

We also need to operate on those slices of the sequence, i.e. get the value associated with each part of the slice, and add, subtract, and multiply these slices.

That's why we need a seq-like container with the possibility of getting a reference to each value of the seq.

None of the openarray types enables this.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @konsumlamm said, I think openArray and toOpenArray is good enough to implement karatsuba multiplication.
You can write recursive procedure like this using openArray and toOpenArray:

proc sum(x: openArray[int]): int =
  case x.len:
  of 0:
    0
  of 1:
    x[0]
  of 2:
    x[0] + x[1]
  else:
    let mid = x.len div 2
    sum(toOpenArray(x, 0, mid - 1)) + sum(toOpenArray(x, mid, x.high))

echo sum([1, 2, 3, 4, 5])

This is a part of karatsubaMultiplication in your current PR:

var
  low_b, high_b, low_c, high_c: BigInt
# Decompose `b` and `c` in two parts of (almost) equal length
low_b.limbs = b.limbs[0 .. k-1]
high_b.limbs = b.limbs[k .. ^1]
low_c.limbs = c.limbs[0 .. k-1]
high_c.limbs = c.limbs[k .. ^1]

# subtractive version of Karatsuba's algorithm to limit carry handling
var lowProduct, highProduct, add3, add4, add5, middleTerm: BigInt = zero

multiplication(lowProduct, low_b, low_c)
multiplication(highProduct, high_b, high_c)

add3 = low_b - high_b
add4 = high_c - low_c

Above code would be written using toOpenArray like:

# Decompose `b` and `c` in two parts of (almost) equal length
template low_b = toOpenArray(b.limbs, 0, k - 1)
template high_b = toOpenArray(b.limbs, k, b.limbs.high)
template low_c = toOpenArray(c.limbs, 0, k - 1)
template high_c = toOpenArray(c.limbs, k, c.limbs.high)

# subtractive version of Karatsuba's algorithm to limit carry handling
var lowProduct, highProduct, add3, add4, add5, middleTerm: BigInt = zero

multiplication(lowProduct, low_b, low_c)
multiplication(highProduct, high_b, high_c)

# This code requires subtraction proc that takes 2 `openArray`
add3 = low_b - high_b
add4 = high_c - low_c

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the time taken and the detailed changes.

# This code requires subtraction proc that takes 2 `openArray`

I just want to point out that scalarMultiplication and unsignedMultiplication will also have to take openArrays:

  if bl == 1:
    scalarMultiplication(a, c, b.limbs[0]) # b and c are openArrays
    a.isNegative = b.isNegative xor c.isNegative
    return 
 ...
  if bl < karatsubaThreshold:
    if cl <= bl:
      unsignedMultiplication(a, b, c) # b and c are openArrays here
    else:
      unsignedMultiplication(a, c, b) # same
    a.isNegative = b.isNegative xor c.isNegative
    return
   ...

I can make a multiplication(a: var Bigint, b, c: openArray[uint32]) = proc, and avoid to rewrite addition and shr for openArrays.

I read the unsignedMultiplication and scalarMultiplication proc algorithms again and they effectively don't need parameters to be BigInts.
The substract proc will be quite delicate to convert for OpenArray parameters though.

I will look into it.


# subtractive version of Karatsuba's algorithm to limit carry handling
var lowProduct, highProduct, add3, add4, add5, middleTerm: BigInt = zero

multiplication(lowProduct, low_b, low_c)
multiplication(highProduct, high_b, high_c)

add3 = low_b - high_b
add4 = high_c - low_c

multiplication(add5, add4, add3)

middleTerm = lowProduct + highProduct + add5
a = lowProduct + middleTerm shl (32*k) + highProduct shl (64*k)
a.isNegative = b.isNegative xor c.isNegative

func `*`*(a, b: BigInt): BigInt =
## Multiplication for `BigInt`s.
runnableExamples:
Expand Down Expand Up @@ -1124,7 +1209,6 @@ iterator `..<`*(a, b: BigInt): BigInt =
yield res
inc res


func modulo(a, modulus: BigInt): BigInt =
## Like `mod`, but the result is always in the range `[0, modulus-1]`.
## `modulus` should be greater than zero.
Expand All @@ -1140,7 +1224,6 @@ func fastLog2*(a: BigInt): int =
return -1
bitops.fastLog2(a.limbs[^1]) + 32*(a.limbs.high)


func invmod*(a, modulus: BigInt): BigInt =
## Compute the modular inverse of `a` modulo `modulus`.
## The return value is always in the range `[1, modulus-1]`
Expand Down Expand Up @@ -1198,3 +1281,106 @@ func powmod*(base, exponent, modulus: BigInt): BigInt =
result = (result * basePow) mod modulus
basePow = (basePow * basePow) mod modulus
exponent = exponent shr 1

when isMainModule:
var a, b, c: BigInt
let
two = 2.initBigInt
three = 3.initBigInt
four = 4.initBigInt

a.limbs = @[1'u32, 2'u32]
b.limbs = @[3'u32, 4'u32]
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a.limbs = @[1'u32, 0'u32]
b.limbs = @[0'u32, 4'u32]
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a.limbs = @[2'u32, 1'u32]
b.limbs = @[3'u32, 4'u32]
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a.limbs = @[2'u32, 1'u32]
b.limbs = @[4'u32, 3'u32]
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a.limbs = @[1'u32, 2'u32]
b.limbs = @[3'u32, 4'u32]
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a = two shl 32 - one
b = four shl 32 - three
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a = -(two shl 32 + one)
b = four shl 32 - three
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a.limbs = @[1'u32, 2'u32, 3'u32]
b.limbs = @[4'u32, 5'u32, 6'u32]
a.isNegative = false
b.isNegative = false
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a.limbs = @[1'u32, 2'u32, 3'u32, 4'u32, 5'u32]
b.limbs = @[4'u32, 5'u32, 6'u32, 7'u32, 8'u32]
a.isNegative = false
b.isNegative = false
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b

a.limbs = @[1'u32, 2'u32, 3'u32, 4'u32, 5'u32, 6'u32, 7'u32, 8'u32, 9'u32, 10'u32]
b.limbs = @[10'u32, 9'u32, 8'u32, 7'u32, 6'u32, 5'u32, 4'u32, 3'u32, 2'u32, 1'u32]
a.isNegative = false
b.isNegative = false
echo a.limbs
echo b.limbs
echo "factors: ", a, " ", b
karatsubaMultiplication(c, a, b)
echo "product Karatsuba: ", c
echo "correct product: ", a * b
25 changes: 25 additions & 0 deletions tests/fastMultiplication.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import bigints
import std/[math, random, sequtils, strutils]

randomize()
# Pick a number in 0..100.
let limit = 10^9
let limbs = 10000
let randomBigInt = toSeq(1..limbs).mapIt(rand(limit)).join("").initBigInt
let randomBigInt2 = toSeq(1..limbs).mapIt(rand(limit)).join("").initBigInt
let randomBigInt3 = toSeq(1..limbs).mapIt(rand(limit)).join("").initBigInt

# Compute subproducts
let prod1 = randomBigInt * randomBigInt2
let prod1bis = randomBigInt2 * randomBigInt

# Check commutativity of the product
doAssert prod1 == prod1bis

let prod2 = randomBigInt2 * randomBigInt3
let prod3 = randomBigInt * randomBigInt3
let product = prod2 * randomBigInt

# Check associativity of the product
doAssert prod1 * randomBigInt3 == product
doAssert prod3 * randomBigInt2 == product
74 changes: 73 additions & 1 deletion tests/tbigints.nim
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,78 @@ proc main() =
doAssert (d xor d) == f
doAssert (d xor f) == d

block: # multiplication
# test sign
let one = 1.initBigInt
let negOne = -1.initBigInt
doAssert one * negOne == negOne

# factors with 4 limbs
let a = "1780983279228119273110576463639172624".initBigInt
let b = "1843917749452418885995463656480858321".initBigInt
doAssert a * b == "3283986680046702618742503890385314117448805445290098330749803441805804304".initBigInt

# factors with 17 limbs
let c = "15456863493948186026689401110531937466657435954521677287549013772194751214595085262021623597960658907994197330891108031896474775438991400654520526954653285".initBigInt
let d = "20867311096234429137120990056519061484140179793024844459539745043528236531589522382271230666075358518275274769618792229717222657110424037636116966396665200".initBigInt
var r: BigInt = 0.initBigInt
karatsubaMultiplication(r, c, d)
doAssert c * d == "322543179100245850295291096700090285623536165554432133161470913224665233565153206743023505404409261647296075477738317301701554637184306640109864382144645081119052516436652162825894456855767719709860985552674755702938369565636714472650667032224717209489767579823588160939485446085000195032327964706246225182000".initBigInt
doAssert c * d == r

# factors with 65 limbs
let e = initBigInt("1f3b839241b0aacc183858dc7a75a773e7bad642a9f426ef499d91e09c9f99a88ec9a14d5ee51175faeaa10d2fb06f3ee37d2f50fe755c2c963aeb539cd55c0e14f5a23f04c64839c22bd4108034b7afc95e01a1c2fe605d8b1930926e886a8f3d7fc09acd54d388cb5d4b3a3fb4eaf6781173ab3a0cd8ad3119c37dd2cf05544235d7b85b2c96d2ed29e1a685820c4afdd824bd8878f1b6a3f52a57eb886efaa737af47161c89f298d908aa950979b8c2615d4e03b47ee87a5381ca39d9ec4788d7abd07b174913b962c02cdd5f8319722a3345eb38d3ebdd51dec66a58e89902151539298c41446758bac66923c910fd7a2d12d0d5c8bb688970b8a77e7d5fc", base = 16)
let f = initBigInt("116be4e445ea68066fca5652e472eb1c1a5fb850311126a8a91fd6f1199f92a9d6602a81bb5e500d163b01df7eec15e41109c62f6c83425027272823d9888a51c93422d47ba4e1cecb94d6fa02eb27df537038b2ac7c9c264634b8febc452c5c9043ddc7d2eddd04f1f743d85cfefd0864441bb9cbf46308138d037b2057980c8b6d215e6bd2c0a73d64c176b7b59452a2c7968a121e5d46c859d85678acd0d4927509418fd351331791fdb7ad041ee2e7d5975867f1812d1e17a41f5a7c0735ebe224294f5ca9d607e95a8722adf58f676b23da1563ac62a52352c10efd0bf5cc8b5eb9ddd7fd1c22bc307d5edf86474f66302fd7cd288a9bc251d3ba45a851b", base = 16)
doAssert e * f == initBigInt("2201d898f354480704067f8513e7b8b365db60c9132d96e54b8546c05c59a08f50b8d8488841d9e893b3c6de34e22b70f2f6c9bc682700c060b60b6e614c9cc39a2c9a9c13ccc412a512c8126f60e1572d20281855e63d019be43a34a929c20818d05527e75f9edf4c5c9a096c4412f879ca14dd3f8bc48aafea6ce17223adefc1e55ca8216ed3d6d351f08dc38a5e6fed7cb3abd1844bfaed632a1571d0017f1285b2e7762c8c3ad0adc781f47df619f462415b8b4e3496fea3d1b1f4443184a5f5fdc155af2f62d861c9ab321a5083c07b4fdfc384aaa6c4a09559e1d7383b5a3fd9f6c9dfe2079c13bfcf307f98de6e16e474b55e94dcf9dbaf3e90ba38f37ddc69c318ed680a0db8a5d5257a1214c765b267f3df8352ad5b4ee9e9c27ca7085ee5687e6afc8108b683c622613c003c068b60bed656d2dc6a32b1f7b079194108301da8d8049f5b64e88da091803bf1582a45fd24f242f2d9b9d8090c4d088ea31faa3e997d20688481b8f1847524f28153e0ba5c3017338cf470c906d8b27352082741dcfb81ec81a3268569424f791c9a82777d66f2b6a52e0653843057da444b55ddb5f517f1676daafa3413ee3dc6dc0a8edc7cfcbd4b6dd5653957baa93e35fad6908addc018706b0acf64cd5cfe3ee462b57e2cc58f641a883a693505fab131a8f51f22cc34ff694af1e4ca0b4d8469067a76b378783b190c2377d94", base = 16)

block: # self-addition/self-subtraction
# self-addition
var a = zero
a += a
doAssert a == zero
a = 12.initBigInt
a += a
doAssert a == 24.initBigInt
a = 20736.initBigInt
a += a
doAssert a == 41472.initBigInt
a = "184884258895036416".initBigInt
a += a
doAssert a == "369768517790072832".initBigInt

# self-subtraction
var b = zero
b -= b
doAssert b == zero
b = 12.initBigInt
b -= b
doAssert b == zero
b = 20736.initBigInt
b -= b
doAssert b == zero
b = "184884258895036416".initBigInt
b -= b
doAssert b == zero

block: # self-multiplication
var a = 12.initBigInt
a *= a
doAssert a == 144.initBigInt
a *= a
doAssert a == 20736.initBigInt
a *= a
doAssert a == 429981696.initBigInt
a *= a
doAssert a == "184884258895036416".initBigInt
var b = zero
b *= b
doAssert b == zero
var c = one
c *= c
doAssert c == one
a *= b
doAssert a == zero

block: # inc/dec
var x: BigInt

Expand Down Expand Up @@ -787,5 +859,5 @@ proc main() =
doAssert succ(a, 3) == initBigInt(10)


static: main()
# static: main()
main()