Skip to content

Commit

Permalink
kyber: remove division by q in ciphertext compression
Browse files Browse the repository at this point in the history
On some platforms, division by q leaks some information on the
ciphertext by its timing. If a keypair is reused, and an attacker has access to
a decapsulation oracle, this reveals information on the private key.
This is known as "kyberslash2".

Note that this does not affect to the typical ephemeral usage in TLS.
  • Loading branch information
bwesterb committed Dec 30, 2023
1 parent 899732a commit e30c9a4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
28 changes: 18 additions & 10 deletions pke/kyber/internal/common/poly.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (p *Poly) CompressMessageTo(m []byte) {

// Set p to Decompress_q(m, 1).
//
// Assumes d is in {3, 4, 5, 10, 11}. p will be normalized.
// Assumes d is in {4, 5, 10, 11}. p will be normalized.
func (p *Poly) Decompress(m []byte, d int) {
// Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
// = ⌊(q/2ᵈ)x+½⌋
Expand Down Expand Up @@ -244,20 +244,28 @@ func (p *Poly) Decompress(m []byte, d int) {

// Writes Compress_q(p, d) to m.
//
// Assumes p is normalized and d is in {3, 4, 5, 10, 11}.
// Assumes p is normalized and d is in {4, 5, 10, 11}.
func (p *Poly) CompressTo(m []byte, d int) {
// Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
// = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
// = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
// = DIV((x << d) + q/2, q) & ((1<<d) - 1)
//
// We approximate DIV(x, q) by computing (x*a)>>e, where a/(2^e) ≈ 1/q.
// For d in {10,11} we use 20,642,678/2^36, which computes division by x/q
// correctly for 0 ≤ x < 41,522,616, which fits (q << 11) + q/2 comfortably.
// For d in {4,5} we use 315/2^20, which doesn't compute division by x/q
// correctly for all inputs, but it's close enough that the end result
// of the compression is correct. The advantage is that we do not need
// to use a 64-bit intermediate value.
switch d {
case 4:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
uint32(Q)) & ((1 << 4) - 1)
t[j] = uint16((((uint32(p[8*i+j])<<4)+uint32(Q)/2)*315)>>
20) & ((1 << 4) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<4)
m[idx+1] = byte(t[2]) | byte(t[3]<<4)
Expand All @@ -271,8 +279,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
uint32(Q)) & ((1 << 5) - 1)
t[j] = uint16((((uint32(p[8*i+j])<<5)+uint32(Q)/2)*315)>>
20) & ((1 << 5) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<5)
m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
Expand All @@ -287,8 +295,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/4; i++ {
for j := 0; j < 4; j++ {
t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
uint32(Q)) & ((1 << 10) - 1)
t[j] = uint16((uint64((uint32(p[4*i+j])<<10)+uint32(Q)/2)*
20642678)>>36) & ((1 << 10) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
Expand All @@ -302,8 +310,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
uint32(Q)) & ((1 << 11) - 1)
t[j] = uint16((uint64((uint32(p[8*i+j])<<11)+uint32(Q)/2)*
20642678)>>36) & ((1 << 11) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
Expand Down
25 changes: 25 additions & 0 deletions pke/kyber/internal/common/poly_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@ func TestDecompressMessage(t *testing.T) {
}
}

func TestCompressFullInputFirstCoeff(t *testing.T) {
for _, d := range []int{4, 5, 10, 11} {
d := d
t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
var p, q Poly
bound := (Q + (1 << uint(d))) >> uint(d+1)
buf := make([]byte, (N*d-1)/8+1)
for i := int16(0); i < Q; i++ {
p[0] = i
p.CompressTo(buf, d)
q.Decompress(buf, d)
diff := sModQ(p[0] - q[0])
if diff < 0 {
diff = -diff
}
if diff > bound {
t.Logf("%v\n", buf)
t.Fatalf("|%d - %d mod^± q| = %d > %d",
p[0], q[0], diff, bound)
}
}
})
}
}

func TestCompress(t *testing.T) {
for _, d := range []int{4, 5, 10, 11} {
d := d
Expand Down

0 comments on commit e30c9a4

Please sign in to comment.