From 75ef91e8a2f438e6ce2b6e620d236add8be1887d Mon Sep 17 00:00:00 2001 From: Bas Westerbaan Date: Sat, 30 Dec 2023 13:47:10 +0100 Subject: [PATCH] kyber: remove division by q in ciphertext compression On some platforms, division by q leaks some information on the ciphertext by its timing. If a keypair is reused, and an attacker has access to a decapsulation oracle, this reveals information on the private key. This is known as "kyberslash2". Note that this does not affect to the typical ephemeral usage in TLS. --- pke/kyber/internal/common/poly.go | 28 ++++--- pke/kyber/internal/common/poly_test.go | 105 +++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 10 deletions(-) diff --git a/pke/kyber/internal/common/poly.go b/pke/kyber/internal/common/poly.go index d72f35d36..f580e9150 100644 --- a/pke/kyber/internal/common/poly.go +++ b/pke/kyber/internal/common/poly.go @@ -166,7 +166,7 @@ func (p *Poly) CompressMessageTo(m []byte) { // Set p to Decompress_q(m, 1). // -// Assumes d is in {3, 4, 5, 10, 11}. p will be normalized. +// Assumes d is in {4, 5, 10, 11}. p will be normalized. func (p *Poly) Decompress(m []byte, d int) { // Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋ // = ⌊(q/2ᵈ)x+½⌋ @@ -244,20 +244,28 @@ func (p *Poly) Decompress(m []byte, d int) { // Writes Compress_q(p, d) to m. // -// Assumes p is normalized and d is in {3, 4, 5, 10, 11}. +// Assumes p is normalized and d is in {4, 5, 10, 11}. func (p *Poly) CompressTo(m []byte, d int) { // Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ // = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ // = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ // = DIV((x << d) + q/2, q) & ((1<>e, where a/(2^e) ≈ 1/q. + // For d in {10,11} we use 20,642,679/2^36, which computes division by x/q + // correctly for 0 ≤ x < 41,522,616, which fits (q << 11) + q/2 comfortably. + // For d in {4,5} we use 315/2^20, which doesn't compute division by x/q + // correctly for all inputs, but it's close enough that the end result + // of the compression is correct. The advantage is that we do not need + // to use a 64-bit intermediate value. switch d { case 4: var t [8]uint16 idx := 0 for i := 0; i < N/8; i++ { for j := 0; j < 8; j++ { - t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/ - uint32(Q)) & ((1 << 4) - 1) + t[j] = uint16((((uint32(p[8*i+j])<<4)+uint32(Q)/2)*315)>> + 20) & ((1 << 4) - 1) } m[idx] = byte(t[0]) | byte(t[1]<<4) m[idx+1] = byte(t[2]) | byte(t[3]<<4) @@ -271,8 +279,8 @@ func (p *Poly) CompressTo(m []byte, d int) { idx := 0 for i := 0; i < N/8; i++ { for j := 0; j < 8; j++ { - t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/ - uint32(Q)) & ((1 << 5) - 1) + t[j] = uint16((((uint32(p[8*i+j])<<5)+uint32(Q)/2)*315)>> + 20) & ((1 << 5) - 1) } m[idx] = byte(t[0]) | byte(t[1]<<5) m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7) @@ -287,8 +295,8 @@ func (p *Poly) CompressTo(m []byte, d int) { idx := 0 for i := 0; i < N/4; i++ { for j := 0; j < 4; j++ { - t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/ - uint32(Q)) & ((1 << 10) - 1) + t[j] = uint16((uint64((uint32(p[4*i+j])<<10)+uint32(Q)/2)* + 20642679)>>36) & ((1 << 10) - 1) } m[idx] = byte(t[0]) m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2) @@ -302,8 +310,8 @@ func (p *Poly) CompressTo(m []byte, d int) { idx := 0 for i := 0; i < N/8; i++ { for j := 0; j < 8; j++ { - t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/ - uint32(Q)) & ((1 << 11) - 1) + t[j] = uint16((uint64((uint32(p[8*i+j])<<11)+uint32(Q)/2)* + 20642679)>>36) & ((1 << 11) - 1) } m[idx] = byte(t[0]) m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3) diff --git a/pke/kyber/internal/common/poly_test.go b/pke/kyber/internal/common/poly_test.go index fcce5fa78..350bef961 100644 --- a/pke/kyber/internal/common/poly_test.go +++ b/pke/kyber/internal/common/poly_test.go @@ -1,6 +1,7 @@ package common import ( + "bytes" "crypto/rand" "fmt" "testing" @@ -273,3 +274,107 @@ func TestNormalizeAgainstGeneric(t *testing.T) { } } } + +func (p *Poly) OldCompressTo(m []byte, d int) { + switch d { + case 4: + var t [8]uint16 + idx := 0 + for i := 0; i < N/8; i++ { + for j := 0; j < 8; j++ { + t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/ + uint32(Q)) & ((1 << 4) - 1) + } + m[idx] = byte(t[0]) | byte(t[1]<<4) + m[idx+1] = byte(t[2]) | byte(t[3]<<4) + m[idx+2] = byte(t[4]) | byte(t[5]<<4) + m[idx+3] = byte(t[6]) | byte(t[7]<<4) + idx += 4 + } + + case 5: + var t [8]uint16 + idx := 0 + for i := 0; i < N/8; i++ { + for j := 0; j < 8; j++ { + t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/ + uint32(Q)) & ((1 << 5) - 1) + } + m[idx] = byte(t[0]) | byte(t[1]<<5) + m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7) + m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4) + m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6) + m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3) + idx += 5 + } + + case 10: + var t [4]uint16 + idx := 0 + for i := 0; i < N/4; i++ { + for j := 0; j < 4; j++ { + t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/ + uint32(Q)) & ((1 << 10) - 1) + } + m[idx] = byte(t[0]) + m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2) + m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4) + m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6) + m[idx+4] = byte(t[3] >> 2) + idx += 5 + } + case 11: + var t [8]uint16 + idx := 0 + for i := 0; i < N/8; i++ { + for j := 0; j < 8; j++ { + t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/ + uint32(Q)) & ((1 << 11) - 1) + } + m[idx] = byte(t[0]) + m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3) + m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6) + m[idx+3] = byte(t[2] >> 2) + m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1) + m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4) + m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7) + m[idx+7] = byte(t[5] >> 1) + m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2) + m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5) + m[idx+10] = byte(t[7] >> 3) + idx += 11 + } + default: + panic("unsupported d") + } +} + +func TestCompressFullInputFirstCoeff(t *testing.T) { + for _, d := range []int{4, 5, 10, 11} { + d := d + t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) { + var p, q Poly + bound := (Q + (1 << uint(d))) >> uint(d+1) + buf := make([]byte, (N*d-1)/8+1) + buf2 := make([]byte, len(buf)) + for i := int16(0); i < Q; i++ { + p[0] = i + p.CompressTo(buf, d) + p.OldCompressTo(buf2, d) + if !bytes.Equal(buf, buf2) { + t.Fatalf("%d", i) + } + q.Decompress(buf, d) + diff := sModQ(p[0] - q[0]) + if diff < 0 { + diff = -diff + } + if diff > bound { + t.Logf("%v\n", buf) + t.Fatalf("|%d - %d mod^± q| = %d > %d", + p[0], q[0], diff, bound) + } + } + }) + } +}