diff --git a/h2/testing/certs.go b/h2/testing/certs.go index fe73004e..2300f265 100644 --- a/h2/testing/certs.go +++ b/h2/testing/certs.go @@ -93,7 +93,7 @@ func initLocalhostCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*tls.Cer hasher.Write(pkixpub) keyID := hasher.Sum(nil) - serial, err := rand.Int(rand.Reader, mitm.MaxSerialNumber) + serial, err := mitm.GenerateSerial() if err != nil { return nil, fmt.Errorf("generating serial number: %w", err) } diff --git a/mitm/mitm.go b/mitm/mitm.go index 83ee60f5..c90b636a 100644 --- a/mitm/mitm.go +++ b/mitm/mitm.go @@ -36,9 +36,30 @@ import ( "github.com/google/martian/v3/log" ) +// GenerateSerial generates a RFC 5280 conformant serial number for use in +// X.509 certificates. +func GenerateSerial() (serial *big.Int, err error) { + for { + serial, err = rand.Int(rand.Reader, MaxSerialNumber) + if err != nil { + return nil, err + } + // If the generated serial is 20 bytes, check that the MSB is not set. If + // it is set, generate a new serial, as encoding/asn1 will prefix the serial + // with 0x00 so that it isn't interpreted as a negative integer, resulting + // in a 21 byte serial number. + if serialBytes := serial.Bytes(); len(serialBytes) < 20 || len(serialBytes) > 0 && serialBytes[0]&0x80 == 0 { + return serial, nil + } + } +} + // MaxSerialNumber is the upper boundary that is used to create unique serial // numbers for the certificate. This can be any unsigned integer up to 20 // bytes (2^(8*20)-1). +// +// NOTE: Rather than using this as an input to crypto/rand.Int, GenerateSerial +// should be used. var MaxSerialNumber = big.NewInt(0).SetBytes(bytes.Repeat([]byte{255}, 20)) // Config is a set of configuration values that are used to build TLS configs @@ -81,7 +102,7 @@ func NewAuthority(name, organization string, validity time.Duration) (*x509.Cert // TODO: keep a map of used serial numbers to avoid potentially reusing a // serial multiple times. - serial, err := rand.Int(rand.Reader, MaxSerialNumber) + serial, err := GenerateSerial() if err != nil { return nil, nil, err } @@ -261,7 +282,7 @@ func (c *Config) cert(hostname string) (*tls.Certificate, error) { log.Debugf("mitm: cache miss for %s", hostname) - serial, err := rand.Int(rand.Reader, MaxSerialNumber) + serial, err := GenerateSerial() if err != nil { return nil, err } diff --git a/mitm/mitm_test.go b/mitm/mitm_test.go index d9f584ea..e264efb7 100644 --- a/mitm/mitm_test.go +++ b/mitm/mitm_test.go @@ -132,7 +132,7 @@ func TestCert(t *testing.T) { t.Fatal("x509c: got nil, want *x509.Certificate") } - if got := x509c.SerialNumber; got.Cmp(MaxSerialNumber) >= 0 { + if got := x509c.SerialNumber; got.Cmp(MaxSerialNumber) > 0 { t.Errorf("x509c.SerialNumber: got %v, want <= MaxSerialNumber", got) } if got, want := x509c.Subject.CommonName, "example.com"; got != want { @@ -203,3 +203,17 @@ func TestCert(t *testing.T) { t.Fatalf("x509c.IPAddresses: got %v, want %v", got, want) } } + +func TestGenerateSerial(t *testing.T) { + for i := 0; i < 100; i++ { + serial, err := GenerateSerial() + if err != nil { + t.Fatalf("GenerateSerial(): got %v, want no error", err) + } + if serial.Cmp(MaxSerialNumber) > 0 { + t.Errorf("serial: got %v, want <= MaxSerialNumber", serial) + } else if serialBytes := serial.Bytes(); len(serialBytes) == 20 && serialBytes[0]&0x80 != 0 { + t.Errorf("serial: got len=20 with MSB set, want MSB unset") + } + } +}