From d0dfeee80a544b1330643b1e1bb3010ba228b6f4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 15 Nov 2018 09:15:45 +0700 Subject: [PATCH] only use a single certificate --- conn_test.go | 12 +++++++----- crypto.go | 47 +++++++++++------------------------------------ 2 files changed, 18 insertions(+), 41 deletions(-) diff --git a/conn_test.go b/conn_test.go index d474488..8db3d8c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -54,7 +54,9 @@ var _ = Describe("Connection", func() { // modify the cert chain such that verificiation will fail invalidateCertChain := func(tlsConf *tls.Config) { - tlsConf.Certificates[0].Certificate = [][]byte{tlsConf.Certificates[0].Certificate[0]} + key, err := rsa.GenerateKey(rand.Reader, 1024) + Expect(err).ToNot(HaveOccurred()) + tlsConf.Certificates[0].PrivateKey = key } BeforeEach(func() { @@ -147,8 +149,8 @@ var _ = Describe("Connection", func() { serverAddr, serverConnChan := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") clientTransport, err := NewTransport(clientKey) - invalidateCertChain(clientTransport.(*transport).tlsConf) Expect(err).ToNot(HaveOccurred()) + invalidateCertChain(clientTransport.(*transport).tlsConf) conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID) Expect(err).ToNot(HaveOccurred()) Eventually(func() bool { return conn.IsClosed() }).Should(BeTrue()) @@ -157,15 +159,15 @@ var _ = Describe("Connection", func() { It("fails if the server presents an invalid cert chain", func() { serverTransport, err := NewTransport(serverKey) - invalidateCertChain(serverTransport.(*transport).tlsConf) Expect(err).ToNot(HaveOccurred()) + invalidateCertChain(serverTransport.(*transport).tlsConf) serverAddr, serverConnChan := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") clientTransport, err := NewTransport(clientKey) Expect(err).ToNot(HaveOccurred()) _, err = clientTransport.Dial(context.Background(), serverAddr, serverID) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("TLS handshake error: bad certificate")) + Expect(err.Error()).To(ContainSubstring("TLS handshake error")) Consistently(serverConnChan).ShouldNot(Receive()) }) @@ -176,8 +178,8 @@ var _ = Describe("Connection", func() { // first dial with an invalid cert chain clientTransport1, err := NewTransport(clientKey) - invalidateCertChain(clientTransport1.(*transport).tlsConf) Expect(err).ToNot(HaveOccurred()) + invalidateCertChain(clientTransport1.(*transport).tlsConf) _, err = clientTransport1.Dial(context.Background(), serverAddr, serverID) Expect(err).ToNot(HaveOccurred()) Consistently(serverConnChan).ShouldNot(Receive()) diff --git a/crypto.go b/crypto.go index 1ce18bc..de1522c 100644 --- a/crypto.go +++ b/crypto.go @@ -1,8 +1,6 @@ package libp2pquic import ( - "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" "crypto/tls" "crypto/x509" @@ -21,29 +19,7 @@ const hostname = "quic.ipfs" const certValidityPeriod = 180 * 24 * time.Hour func generateConfig(privKey ic.PrivKey) (*tls.Config, error) { - key, hostCert, err := keyToCertificate(privKey) - if err != nil { - return nil, err - } - // The ephemeral key used just for a couple of connections (or a limited time). - ephemeralKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, err - } - // Sign the ephemeral key using the host key. - // This is the only time that the host's private key of the peer is needed. - // Note that this step could be done asynchronously, such that a running node doesn't need access its private key at all. - certTemplate := &x509.Certificate{ - DNSNames: []string{hostname}, - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-24 * time.Hour), - NotAfter: time.Now().Add(certValidityPeriod), - } - certDER, err := x509.CreateCertificate(rand.Reader, certTemplate, hostCert, ephemeralKey.Public(), key) - if err != nil { - return nil, err - } - cert, err := x509.ParseCertificate(certDER) + key, cert, err := keyToCertificate(privKey) if err != nil { return nil, err } @@ -52,22 +28,22 @@ func generateConfig(privKey ic.PrivKey) (*tls.Config, error) { InsecureSkipVerify: true, // This is not insecure here. We will verify the cert chain ourselves. ClientAuth: tls.RequireAnyClientCert, Certificates: []tls.Certificate{{ - Certificate: [][]byte{cert.Raw, hostCert.Raw}, - PrivateKey: ephemeralKey, + Certificate: [][]byte{cert.Raw}, + PrivateKey: key, }}, }, nil } func getRemotePubKey(chain []*x509.Certificate) (ic.PubKey, error) { - if len(chain) != 2 { - return nil, errors.New("expected 2 certificates in the chain") + if len(chain) != 1 { + return nil, errors.New("expected one certificates in the chain") } pool := x509.NewCertPool() - pool.AddCert(chain[1]) + pool.AddCert(chain[0]) if _, err := chain[0].Verify(x509.VerifyOptions{Roots: pool}); err != nil { return nil, err } - remotePubKey, err := x509.MarshalPKIXPublicKey(chain[1].PublicKey) + remotePubKey, err := x509.MarshalPKIXPublicKey(chain[0].PublicKey) if err != nil { return nil, err } @@ -80,11 +56,10 @@ func keyToCertificate(sk ic.PrivKey) (interface{}, *x509.Certificate, error) { return nil, nil, err } tmpl := &x509.Certificate{ - SerialNumber: sn, - NotBefore: time.Now().Add(-24 * time.Hour), - NotAfter: time.Now().Add(certValidityPeriod), - IsCA: true, - BasicConstraintsValid: true, + SerialNumber: sn, + NotBefore: time.Now().Add(-24 * time.Hour), + NotAfter: time.Now().Add(certValidityPeriod), + DNSNames: []string{hostname}, } var publicKey, privateKey interface{}