Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kyber: remove division by q in ciphertext compression #468

Merged
merged 1 commit into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions pke/kyber/internal/common/poly.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (p *Poly) CompressMessageTo(m []byte) {

// Set p to Decompress_q(m, 1).
//
// Assumes d is in {3, 4, 5, 10, 11}. p will be normalized.
// Assumes d is in {4, 5, 10, 11}. p will be normalized.
func (p *Poly) Decompress(m []byte, d int) {
// Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
// = ⌊(q/2ᵈ)x+½⌋
Expand Down Expand Up @@ -244,20 +244,28 @@ func (p *Poly) Decompress(m []byte, d int) {

// Writes Compress_q(p, d) to m.
//
// Assumes p is normalized and d is in {3, 4, 5, 10, 11}.
// Assumes p is normalized and d is in {4, 5, 10, 11}.
func (p *Poly) CompressTo(m []byte, d int) {
// Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
// = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
// = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
// = DIV((x << d) + q/2, q) & ((1<<d) - 1)
//
// We approximate DIV(x, q) by computing (x*a)>>e, where a/(2^e) ≈ 1/q.
// For d in {10,11} we use 20,642,679/2^36, which computes division by x/q
// correctly for 0 ≤ x < 41,522,616, which fits (q << 11) + q/2 comfortably.
// For d in {4,5} we use 315/2^20, which doesn't compute division by x/q
// correctly for all inputs, but it's close enough that the end result
// of the compression is correct. The advantage is that we do not need
// to use a 64-bit intermediate value.
switch d {
case 4:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
uint32(Q)) & ((1 << 4) - 1)
t[j] = uint16((((uint32(p[8*i+j])<<4)+uint32(Q)/2)*315)>>
20) & ((1 << 4) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<4)
m[idx+1] = byte(t[2]) | byte(t[3]<<4)
Expand All @@ -271,8 +279,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
uint32(Q)) & ((1 << 5) - 1)
t[j] = uint16((((uint32(p[8*i+j])<<5)+uint32(Q)/2)*315)>>
20) & ((1 << 5) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<5)
m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
Expand All @@ -287,8 +295,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/4; i++ {
for j := 0; j < 4; j++ {
t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
uint32(Q)) & ((1 << 10) - 1)
t[j] = uint16((uint64((uint32(p[4*i+j])<<10)+uint32(Q)/2)*
20642679)>>36) & ((1 << 10) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
Expand All @@ -302,8 +310,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
uint32(Q)) & ((1 << 11) - 1)
t[j] = uint16((uint64((uint32(p[8*i+j])<<11)+uint32(Q)/2)*
20642679)>>36) & ((1 << 11) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
Expand Down
105 changes: 105 additions & 0 deletions pke/kyber/internal/common/poly_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"bytes"
"crypto/rand"
"fmt"
"testing"
Expand Down Expand Up @@ -273,3 +274,107 @@ func TestNormalizeAgainstGeneric(t *testing.T) {
}
}
}

func (p *Poly) OldCompressTo(m []byte, d int) {
switch d {
case 4:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
uint32(Q)) & ((1 << 4) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<4)
m[idx+1] = byte(t[2]) | byte(t[3]<<4)
m[idx+2] = byte(t[4]) | byte(t[5]<<4)
m[idx+3] = byte(t[6]) | byte(t[7]<<4)
idx += 4
}

case 5:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
uint32(Q)) & ((1 << 5) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<5)
m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
idx += 5
}

case 10:
var t [4]uint16
idx := 0
for i := 0; i < N/4; i++ {
for j := 0; j < 4; j++ {
t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
uint32(Q)) & ((1 << 10) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
m[idx+4] = byte(t[3] >> 2)
idx += 5
}
case 11:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
uint32(Q)) & ((1 << 11) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
m[idx+3] = byte(t[2] >> 2)
m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
m[idx+7] = byte(t[5] >> 1)
m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
m[idx+10] = byte(t[7] >> 3)
idx += 11
}
default:
panic("unsupported d")
}
}

func TestCompressFullInputFirstCoeff(t *testing.T) {
for _, d := range []int{4, 5, 10, 11} {
d := d
t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
var p, q Poly
bound := (Q + (1 << uint(d))) >> uint(d+1)
buf := make([]byte, (N*d-1)/8+1)
buf2 := make([]byte, len(buf))
for i := int16(0); i < Q; i++ {
p[0] = i
p.CompressTo(buf, d)
p.OldCompressTo(buf2, d)
if !bytes.Equal(buf, buf2) {
t.Fatalf("%d", i)
}
q.Decompress(buf, d)
diff := sModQ(p[0] - q[0])
if diff < 0 {
diff = -diff
}
if diff > bound {
t.Logf("%v\n", buf)
t.Fatalf("|%d - %d mod^± q| = %d > %d",
p[0], q[0], diff, bound)
}
}
})
}
}