diff --git a/p2p/security/noise/crypto_test.go b/p2p/security/noise/crypto_test.go index 9b7d390829..1ca11476f4 100644 --- a/p2p/security/noise/crypto_test.go +++ b/p2p/security/noise/crypto_test.go @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) { init, resp := net.Pipe() _ = resp.Close() - session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, true) + session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, nil, true) _, err := session.encrypt(nil, []byte("hi")) if err == nil { t.Error("expected encryption error when handshake incomplete") diff --git a/p2p/security/noise/handshake.go b/p2p/security/noise/handshake.go index c5aa968cbd..488ee7894e 100644 --- a/p2p/security/noise/handshake.go +++ b/p2p/security/noise/handshake.go @@ -10,8 +10,6 @@ import ( "runtime/debug" "time" - "golang.org/x/crypto/chacha20poly1305" - "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/p2p/security/noise/pb" @@ -64,11 +62,6 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { return fmt.Errorf("error initializing handshake state: %w", err) } - payload, err := s.generateHandshakePayload(kp) - if err != nil { - return err - } - // set a deadline to complete the handshake, if one has been supplied. // clear it after we're done. if deadline, ok := ctx.Deadline(); ok { @@ -78,22 +71,14 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { } } - // We can re-use this buffer for all handshake messages as its size - // will be the size of the maximum handshake message for the Noise XX pattern. - // Also, since we prefix every noise handshake message with its length, we need to account for - // it when we fetch the buffer from the pool - maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*chacha20poly1305.Overhead - hbuf := pool.Get(maxMsgSize + LengthPrefixLength) + // We can re-use this buffer for all handshake messages. + hbuf := pool.Get(2 << 10) defer pool.Put(hbuf) if s.initiator { // stage 0 // // Handshake Msg Len = len(DH ephemeral key) - var ed []byte - if s.earlyDataHandler != nil { - ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) - } - if err := s.sendHandshakeMessage(hs, ed, hbuf); err != nil { + if err := s.sendHandshakeMessage(hs, nil, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) } @@ -102,31 +87,47 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("error reading handshake message: %w", err) } - if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil { + rcvdEd, err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) + if err != nil { return err } + if s.initiatorEarlyDataHandler != nil { + if err := s.initiatorEarlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil { + return err + } + } // stage 2 // // Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted) + var ed []byte + if s.initiatorEarlyDataHandler != nil { + ed = s.initiatorEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) + } + payload, err := s.generateHandshakePayload(kp, ed) + if err != nil { + return err + } if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) } return nil } else { // stage 0 // - initialPayload, err := s.readHandshakeMessage(hs) - if err != nil { + if _, err := s.readHandshakeMessage(hs); err != nil { return fmt.Errorf("error reading handshake message: %w", err) } - if s.earlyDataHandler != nil { - if err := s.earlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil { - return err - } - } // stage 1 // // Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) + // MAC(payload is encrypted) + var ed []byte + if s.responderEarlyDataHandler != nil { + ed = s.responderEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) + } + payload, err := s.generateHandshakePayload(kp, ed) + if err != nil { + return err + } if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) } @@ -136,7 +137,16 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("error reading handshake message: %w", err) } - return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) + rcvdEd, err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) + if err != nil { + return err + } + if s.responderEarlyDataHandler != nil { + if err := s.responderEarlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil { + return err + } + } + return nil } } @@ -214,8 +224,8 @@ func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, // generateHandshakePayload creates a libp2p handshake payload with a // signature of our static noise key. -func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) { - // obtain the public key from the handshake session so we can sign it with +func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, data []byte) ([]byte, error) { + // obtain the public key from the handshake session, so we can sign it with // our libp2p secret key. localKeyRaw, err := crypto.MarshalPublicKey(s.LocalPublicKey()) if err != nil { @@ -230,10 +240,11 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt } // create payload - payload := new(pb.NoiseHandshakePayload) - payload.IdentityKey = localKeyRaw - payload.IdentitySig = signedPayload - payloadEnc, err := proto.Marshal(payload) + payloadEnc, err := proto.Marshal(&pb.NoiseHandshakePayload{ + IdentityKey: localKeyRaw, + IdentitySig: signedPayload, + Data: data, + }) if err != nil { return nil, fmt.Errorf("error marshaling handshake payload: %w", err) } @@ -242,22 +253,23 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt // handleRemoteHandshakePayload unmarshals the handshake payload object sent // by the remote peer and validates the signature against the peer's static Noise key. -func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error { +// It returns the data attached to the payload. +func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) ([]byte, error) { // unmarshal payload nhp := new(pb.NoiseHandshakePayload) err := proto.Unmarshal(payload, nhp) if err != nil { - return fmt.Errorf("error unmarshaling remote handshake payload: %w", err) + return nil, fmt.Errorf("error unmarshaling remote handshake payload: %w", err) } // unpack remote peer's public libp2p key remotePubKey, err := crypto.UnmarshalPublicKey(nhp.GetIdentityKey()) if err != nil { - return err + return nil, err } id, err := peer.IDFromPublicKey(remotePubKey) if err != nil { - return err + return nil, err } // check the peer ID for: @@ -265,7 +277,7 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati // * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID) if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) { // use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms. - return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty()) + return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty()) } // verify payload is signed by asserted remote libp2p key. @@ -273,13 +285,13 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati msg := append([]byte(payloadSigPrefix), remoteStatic...) ok, err := remotePubKey.Verify(msg, sig) if err != nil { - return fmt.Errorf("error verifying signature: %w", err) + return nil, fmt.Errorf("error verifying signature: %w", err) } else if !ok { - return fmt.Errorf("handshake signature invalid") + return nil, fmt.Errorf("handshake signature invalid") } // set remote peer key and id s.remoteID = id s.remoteKey = remotePubKey - return nil + return nhp.Data, nil } diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index 5e3d0956cf..692c9174ad 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -36,22 +36,24 @@ type secureSession struct { dec *noise.CipherState // noise prologue - prologue []byte - earlyDataHandler EarlyDataHandler + prologue []byte + + initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler } // newSecureSession creates a Noise session over the given insecureConn Conn, using // the libp2p identity keypair from the given Transport. -func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, edh EarlyDataHandler, initiator bool) (*secureSession, error) { +func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiatorEDH, responderEDH EarlyDataHandler, initiator bool) (*secureSession, error) { s := &secureSession{ - insecureConn: insecure, - insecureReader: bufio.NewReader(insecure), - initiator: initiator, - localID: tpt.localID, - localKey: tpt.privateKey, - remoteID: remote, - prologue: prologue, - earlyDataHandler: edh, + insecureConn: insecure, + insecureReader: bufio.NewReader(insecure), + initiator: initiator, + localID: tpt.localID, + localKey: tpt.privateKey, + remoteID: remote, + prologue: prologue, + initiatorEarlyDataHandler: initiatorEDH, + responderEarlyDataHandler: responderEDH, } // the go-routine we create to run the handshake will diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go index 973f6facf2..291cc8266a 100644 --- a/p2p/security/noise/session_transport.go +++ b/p2p/security/noise/session_transport.go @@ -22,18 +22,28 @@ func Prologue(prologue []byte) SessionOption { } } -// EarlyDataHandler allows attaching an (unencrypted) application payload to the first handshake message. -// While unencrypted, the integrity of this early data is retroactively authenticated on completion of the handshake. +// EarlyDataHandler defines what the application payload is for either the second +// (if responder) or third (if initiator) handshake message, and defines the +// logic for handling the other side's early data. Note the early data in the +// second handshake message is encrypted, but the peer is not authenticated at that point. type EarlyDataHandler interface { - // Send is called for the client before sending the first handshake message. + // Send for the initiator is called for the client before sending the third + // handshake message. Defines the application payload for the third message. + // Send for the responder is called before sending the second handshake message. Send(context.Context, net.Conn, peer.ID) []byte - // Received is called for the server when the first handshake message from the client is received. + // Received for the initiator is called when the second handshake message + // from the responder is received. + // Received for the responder is called when the third handshake message + // from the initiator is received. Received(context.Context, net.Conn, []byte) error } -func EarlyData(h EarlyDataHandler) SessionOption { +// EarlyData sets the `EarlyDataHandler` for the initiator and responder roles. +// See `EarlyDataHandler` for more details. +func EarlyData(initiator, responder EarlyDataHandler) SessionOption { return func(s *SessionTransport) error { - s.earlyDataHandler = h + s.initiatorEarlyDataHandler = initiator + s.responderEarlyDataHandler = responder return nil } } @@ -45,14 +55,15 @@ var _ sec.SecureTransport = &SessionTransport{} type SessionTransport struct { t *Transport // options - prologue []byte - earlyDataHandler EarlyDataHandler + prologue []byte + + initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler } // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, false) + c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { @@ -64,5 +75,5 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, // SecureOutbound runs the Noise handshake as the initiator. func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, true) + return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true) } diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index bd66d0fdd1..c6923698cc 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -41,7 +41,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(t, ctx, insecure, p, nil, nil, false) + c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { @@ -53,7 +53,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // SecureOutbound runs the Noise handshake as the initiator. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, nil, nil, true) + return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true) } func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) { diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index f86705f462..ed29418ea0 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -447,74 +447,116 @@ func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []b } func TestEarlyDataAccepted(t *testing.T) { + handshake := func(t *testing.T, client, server EarlyDataHandler) { + t.Helper() + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client, nil)) + require.NoError(t, err) + tpt := newTestTransport(t, crypto.Ed25519, 2048) + respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server)) + require.NoError(t, err) + + initConn, respConn := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := respTransport.SecureInbound(context.Background(), initConn, "") + errChan <- err + }() + + conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + require.NoError(t, err) + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout") + case err := <-errChan: + require.NoError(t, err) + } + defer conn.Close() + } + var receivedEarlyData []byte - serverEDH := &earlyDataHandler{ + receivingEDH := &earlyDataHandler{ received: func(_ context.Context, _ net.Conn, data []byte) error { receivedEarlyData = data return nil }, } - clientEDH := &earlyDataHandler{ + sendingEDH := &earlyDataHandler{ send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, } - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) - require.NoError(t, err) - tpt := newTestTransport(t, crypto.Ed25519, 2048) - respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH)) - require.NoError(t, err) - initConn, respConn := newConnPair(t) + t.Run("client sending", func(t *testing.T) { + handshake(t, sendingEDH, receivingEDH) + require.Equal(t, []byte("foobar"), receivedEarlyData) + receivedEarlyData = nil + }) - errChan := make(chan error) - go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") - errChan <- err - }() - - conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) - require.NoError(t, err) - defer conn.Close() - - require.Equal(t, []byte("foobar"), receivedEarlyData) + t.Run("server sending", func(t *testing.T) { + handshake(t, receivingEDH, sendingEDH) + require.Equal(t, []byte("foobar"), receivedEarlyData) + receivedEarlyData = nil + }) } func TestEarlyDataRejected(t *testing.T) { - serverEDH := &earlyDataHandler{ + handshake := func(t *testing.T, client, server EarlyDataHandler) (clientErr, serverErr error) { + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client, nil)) + require.NoError(t, err) + tpt := newTestTransport(t, crypto.Ed25519, 2048) + respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server)) + require.NoError(t, err) + + initConn, respConn := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := respTransport.SecureInbound(context.Background(), initConn, "") + errChan <- err + }() + + // As early data is sent with the last handshake message, the handshake will appear + // to succeed for the client. + var conn sec.SecureConn + conn, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + if clientErr == nil { + _, clientErr = conn.Read([]byte{0}) + } + + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout") + case err := <-errChan: + serverErr = err + } + return + } + + receivingEDH := &earlyDataHandler{ received: func(_ context.Context, _ net.Conn, data []byte) error { return errors.New("nope") }, } - clientEDH := &earlyDataHandler{ + sendingEDH := &earlyDataHandler{ send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, } - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) - require.NoError(t, err) - tpt := newTestTransport(t, crypto.Ed25519, 2048) - respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH)) - require.NoError(t, err) - initConn, respConn := newConnPair(t) + t.Run("client sending", func(t *testing.T) { + clientErr, serverErr := handshake(t, sendingEDH, receivingEDH) + require.Error(t, clientErr) + require.EqualError(t, serverErr, "nope") - errChan := make(chan error) - go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") - errChan <- err - }() - - _, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) - require.Error(t, err) + }) - select { - case <-time.After(500 * time.Millisecond): - t.Fatal("timeout") - case err := <-errChan: - require.EqualError(t, err, "nope") - } + t.Run("server sending", func(t *testing.T) { + clientErr, serverErr := handshake(t, receivingEDH, sendingEDH) + require.Error(t, serverErr) + require.EqualError(t, clientErr, "nope") + }) } -func TestEarlyDataAcceptedWithNoHandler(t *testing.T) { +func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) { clientEDH := &earlyDataHandler{ send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, } - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH, nil)) require.NoError(t, err) respTransport := newTestTransport(t, crypto.Ed25519, 2048) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index ca8fe1cf35..5c51b89e2e 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -190,7 +190,7 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (* if err != nil { return nil, err } - n, err := l.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(l.checkEarlyData))) + n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataReceiver(l.checkEarlyData))) if err != nil { return nil, fmt.Errorf("failed to initialize Noise session: %w", err) } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 03cacb133c..67c2c33840 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -208,7 +208,7 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p if err != nil { return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) } - n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes))) + n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil)) if err != nil { return nil, fmt.Errorf("failed to create Noise transport: %w", err) } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 92680fafbe..3001438d7b 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -162,8 +162,11 @@ func TestHashVerification(t *testing.T) { }) t.Run("fails when adding a wrong hash", func(t *testing.T) { - _, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID) - require.Error(t, err) + conn, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID) + if err != nil { + _, err = conn.AcceptStream() + require.Error(t, err) + } }) require.NoError(t, ln.Close())