Skip to content

Commit

Permalink
Update HPKE code to use ecdh stdlib package.
Browse files Browse the repository at this point in the history
  • Loading branch information
armfazh committed Jan 22, 2025
1 parent 4987803 commit 342ad81
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 92 deletions.
9 changes: 4 additions & 5 deletions hpke/algs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/elliptic"
"crypto/ecdh"
_ "crypto/sha256" // Linking sha256.
_ "crypto/sha512" // Linking sha512.
"fmt"
Expand All @@ -13,7 +13,6 @@ import (

"github.com/cloudflare/circl/dh/x25519"
"github.com/cloudflare/circl/dh/x448"
"github.com/cloudflare/circl/ecc/p384"
"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/kyber/kyber768"
"github.com/cloudflare/circl/kem/xwing"
Expand Down Expand Up @@ -247,19 +246,19 @@ var (
)

func init() {
dhkemp256hkdfsha256.Curve = elliptic.P256()
dhkemp256hkdfsha256.Curve = ecdh.P256()
dhkemp256hkdfsha256.dhKemBase.id = KEM_P256_HKDF_SHA256
dhkemp256hkdfsha256.dhKemBase.name = "HPKE_KEM_P256_HKDF_SHA256"
dhkemp256hkdfsha256.dhKemBase.Hash = crypto.SHA256
dhkemp256hkdfsha256.dhKemBase.dhKEM = dhkemp256hkdfsha256

dhkemp384hkdfsha384.Curve = p384.P384()
dhkemp384hkdfsha384.Curve = ecdh.P384()
dhkemp384hkdfsha384.dhKemBase.id = KEM_P384_HKDF_SHA384
dhkemp384hkdfsha384.dhKemBase.name = "HPKE_KEM_P384_HKDF_SHA384"
dhkemp384hkdfsha384.dhKemBase.Hash = crypto.SHA384
dhkemp384hkdfsha384.dhKemBase.dhKEM = dhkemp384hkdfsha384

dhkemp521hkdfsha512.Curve = elliptic.P521()
dhkemp521hkdfsha512.Curve = ecdh.P521()
dhkemp521hkdfsha512.dhKemBase.id = KEM_P521_HKDF_SHA512
dhkemp521hkdfsha512.dhKemBase.name = "HPKE_KEM_P521_HKDF_SHA512"
dhkemp521hkdfsha512.dhKemBase.Hash = crypto.SHA512
Expand Down
1 change: 1 addition & 0 deletions hpke/hpke.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,5 +278,6 @@ var (
ErrInvalidKEMPublicKey = errors.New("hpke: invalid KEM public key")
ErrInvalidKEMPrivateKey = errors.New("hpke: invalid KEM private key")
ErrInvalidKEMSharedSecret = errors.New("hpke: invalid KEM shared secret")
ErrInvalidKEMDeriveKey = errors.New("hpke: too many tries to derive KEM key")
ErrAEADSeqOverflows = errors.New("hpke: AEAD sequence number overflows")
)
3 changes: 3 additions & 0 deletions hpke/hpke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ func BenchmarkHpkeRoundTrip(b *testing.B) {
kdf hpke.KDF
aead hpke.AEAD
}{
{hpke.KEM_P256_HKDF_SHA256, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM},
{hpke.KEM_P384_HKDF_SHA384, hpke.KDF_HKDF_SHA384, hpke.AEAD_AES256GCM},
{hpke.KEM_P521_HKDF_SHA512, hpke.KDF_HKDF_SHA512, hpke.AEAD_AES256GCM},
{hpke.KEM_X25519_HKDF_SHA256, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM},
{hpke.KEM_X25519_KYBER768_DRAFT00, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM},
{hpke.KEM_XWING, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM},
Expand Down
155 changes: 68 additions & 87 deletions hpke/shortkem.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
package hpke

import (
"crypto/elliptic"
"crypto/ecdh"
"crypto/rand"
"crypto/subtle"
"fmt"
"math/big"

"github.com/cloudflare/circl/kem"
)

type shortKEM struct {
dhKemBase
elliptic.Curve
ecdh.Curve
}

func (s shortKEM) PrivateKeySize() int { return s.byteSize() }
Expand All @@ -21,19 +19,40 @@ func (s shortKEM) CiphertextSize() int { return 1 + 2*s.byteSize() }
func (s shortKEM) PublicKeySize() int { return 1 + 2*s.byteSize() }
func (s shortKEM) EncapsulationSeedSize() int { return s.byteSize() }

func (s shortKEM) byteSize() int { return (s.Params().BitSize + 7) / 8 }
func (s shortKEM) byteSize() int {
var bits int
switch s.Curve {
case ecdh.P256():
bits = 256
case ecdh.P384():
bits = 384
case ecdh.P521():
bits = 521
default:
panic(ErrInvalidKEM)
}

return (bits + 7) / 8
}

func (s shortKEM) sizeDH() int { return s.byteSize() }
func (s shortKEM) calcDH(dh []byte, sk kem.PrivateKey, pk kem.PublicKey) error {
PK := pk.(*shortKEMPubKey)
SK := sk.(*shortKEMPrivKey)
l := len(dh)
x, _ := s.ScalarMult(PK.x, PK.y, SK.priv) // only x-coordinate is used.
if x.Sign() == 0 {
return ErrInvalidKEMSharedSecret
PK, ok := pk.(*shortKEMPubKey)
if !ok {
return ErrInvalidKEMPublicKey
}

SK, ok := sk.(*shortKEMPrivKey)
if !ok {
return ErrInvalidKEMPrivateKey
}

x, err := SK.priv.ECDH(&PK.pub)
if err != nil {
return err
}
b := x.Bytes()
copy(dh[l-len(b):l], b)

copy(dh, x)
return nil
}

Expand All @@ -49,122 +68,84 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
}

bitmask := byte(0xFF)
if s.Params().BitSize == 521 {
if s.Curve == ecdh.P521() {
bitmask = 0x01
}

dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
var bytes []byte
ctr := 0
for skBig := new(big.Int); skBig.Sign() == 0 || skBig.Cmp(s.Params().N) >= 0; ctr++ {
if ctr > 255 {
panic("derive key error")
}
bytes = s.labeledExpand(
for ctr := 0; ctr <= 255; ctr++ {
bytes := s.labeledExpand(
dkpPrk,
[]byte("candidate"),
[]byte{byte(ctr)},
uint16(s.byteSize()),
)
bytes[0] &= bitmask
skBig.SetBytes(bytes)
sk, err := s.UnmarshalBinaryPrivateKey(bytes)
if err == nil {
return sk.Public(), sk
}
}
l := s.PrivateKeySize()
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(bytes):], bytes)
return sk.Public(), sk

panic(ErrInvalidKEMDeriveKey)
}

func (s shortKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
sk, x, y, err := elliptic.GenerateKey(s, rand.Reader)
pub := &shortKEMPubKey{s, x, y}
return pub, &shortKEMPrivKey{s, sk, pub}, err
key, err := s.Curve.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}

sk := &shortKEMPrivKey{s, key}
return sk.Public(), sk, err
}

func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
l := s.PrivateKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPrivateKey
}
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(data):l], data[:l])
if !sk.validate() {
return nil, ErrInvalidKEMPrivateKey
key, err := s.Curve.NewPrivateKey(data)
if err != nil {
return nil, err
}

return sk, nil
return &shortKEMPrivKey{s, key}, nil
}

func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
x, y := elliptic.Unmarshal(s, data)
if x == nil {
return nil, ErrInvalidKEMPublicKey
}
key := &shortKEMPubKey{s, x, y}
if !key.validate() {
return nil, ErrInvalidKEMPublicKey
key, err := s.Curve.NewPublicKey(data)
if err != nil {
return nil, err
}
return key, nil

return &shortKEMPubKey{s, *key}, nil
}

type shortKEMPubKey struct {
scheme shortKEM
x, y *big.Int
pub ecdh.PublicKey
}

func (k *shortKEMPubKey) String() string {
return fmt.Sprintf("x: %v\ny: %v", k.x.Text(16), k.y.Text(16))
}
func (k *shortKEMPubKey) Scheme() kem.Scheme { return k.scheme }
func (k *shortKEMPubKey) MarshalBinary() ([]byte, error) {
return elliptic.Marshal(k.scheme, k.x, k.y), nil
}
func (k *shortKEMPubKey) String() string { return fmt.Sprintf("%x", k.pub.Bytes()) }
func (k *shortKEMPubKey) Scheme() kem.Scheme { return k.scheme }
func (k *shortKEMPubKey) MarshalBinary() ([]byte, error) { return k.pub.Bytes(), nil }

func (k *shortKEMPubKey) Equal(pk kem.PublicKey) bool {
k1, ok := pk.(*shortKEMPubKey)
return ok &&
k.scheme.Params().Name == k1.scheme.Params().Name &&
k.x.Cmp(k1.x) == 0 &&
k.y.Cmp(k1.y) == 0
}

func (k *shortKEMPubKey) validate() bool {
p := k.scheme.Params().P
notAtInfinity := k.x.Sign() > 0 && k.y.Sign() > 0
lessThanP := k.x.Cmp(p) < 0 && k.y.Cmp(p) < 0
onCurve := k.scheme.IsOnCurve(k.x, k.y)
return notAtInfinity && lessThanP && onCurve
return ok && k.scheme == k1.scheme && k.pub.Equal(&k1.pub)
}

type shortKEMPrivKey struct {
scheme shortKEM
priv []byte
pub *shortKEMPubKey
priv *ecdh.PrivateKey
}

func (k *shortKEMPrivKey) String() string { return fmt.Sprintf("%x", k.priv) }
func (k *shortKEMPrivKey) Scheme() kem.Scheme { return k.scheme }
func (k *shortKEMPrivKey) MarshalBinary() ([]byte, error) {
return append(make([]byte, 0, k.scheme.PrivateKeySize()), k.priv...), nil
}
func (k *shortKEMPrivKey) String() string { return fmt.Sprintf("%x", k.priv.Bytes()) }
func (k *shortKEMPrivKey) Scheme() kem.Scheme { return k.scheme }
func (k *shortKEMPrivKey) MarshalBinary() ([]byte, error) { return k.priv.Bytes(), nil }

func (k *shortKEMPrivKey) Equal(pk kem.PrivateKey) bool {
k1, ok := pk.(*shortKEMPrivKey)
return ok &&
k.scheme.Params().Name == k1.scheme.Params().Name &&
subtle.ConstantTimeCompare(k.priv, k1.priv) == 1
return ok && k.scheme == k1.scheme && k.priv.Equal(k1.priv)
}

func (k *shortKEMPrivKey) Public() kem.PublicKey {
if k.pub == nil {
x, y := k.scheme.ScalarBaseMult(k.priv)
k.pub = &shortKEMPubKey{k.scheme, x, y}
}
return k.pub
}

func (k *shortKEMPrivKey) validate() bool {
n := new(big.Int).SetBytes(k.priv)
order := k.scheme.Curve.Params().N
return len(k.priv) == k.scheme.PrivateKeySize() && n.Cmp(order) < 0
return &shortKEMPubKey{k.scheme, *k.priv.PublicKey()}
}

0 comments on commit 342ad81

Please sign in to comment.