Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

noise: make it possible for the server to send early data #1750

Merged
merged 6 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion p2p/security/noise/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) {
init, resp := net.Pipe()
_ = resp.Close()

session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, true)
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, nil, true)
_, err := session.encrypt(nil, []byte("hi"))
if err == nil {
t.Error("expected encryption error when handshake incomplete")
Expand Down
77 changes: 44 additions & 33 deletions p2p/security/noise/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"runtime/debug"
"time"

"golang.org/x/crypto/chacha20poly1305"

"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
Expand Down Expand Up @@ -64,11 +62,6 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
return fmt.Errorf("error initializing handshake state: %w", err)
}

payload, err := s.generateHandshakePayload(kp)
if err != nil {
return err
}

// set a deadline to complete the handshake, if one has been supplied.
// clear it after we're done.
if deadline, ok := ctx.Deadline(); ok {
Expand All @@ -78,20 +71,16 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
}
}

// We can re-use this buffer for all handshake messages as its size
// will be the size of the maximum handshake message for the Noise XX pattern.
// Also, since we prefix every noise handshake message with its length, we need to account for
// it when we fetch the buffer from the pool
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*chacha20poly1305.Overhead
hbuf := pool.Get(maxMsgSize + LengthPrefixLength)
// We can re-use this buffer for all handshake messages.
hbuf := pool.Get(2 << 10)
defer pool.Put(hbuf)

if s.initiator {
// stage 0 //
// Handshake Msg Len = len(DH ephemeral key)
var ed []byte
if s.earlyDataHandler != nil {
ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
if s.initiatorEarlyDataHandler != nil {
ed = s.initiatorEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
if err := s.sendHandshakeMessage(hs, ed, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
Expand All @@ -102,12 +91,22 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil {
rcvdEd, err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
return err
}
if s.initiatorEarlyDataHandler != nil {
if err := s.initiatorEarlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil {
return err
}
}

// stage 2 //
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
payload, err := s.generateHandshakePayload(kp, nil)
if err != nil {
return err
}
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
Expand All @@ -118,15 +117,23 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
if s.earlyDataHandler != nil {
if err := s.earlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil {
if s.responderEarlyDataHandler != nil {
if err := s.responderEarlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil {
return err
}
}

// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
var ed []byte
if s.responderEarlyDataHandler != nil {
ed = s.responderEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
payload, err := s.generateHandshakePayload(kp, ed)
if err != nil {
return err
}
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
Expand All @@ -136,7 +143,9 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
// we don't expect any early data on this message
_, err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
return err
}
}

Expand Down Expand Up @@ -214,8 +223,8 @@ func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte,

// generateHandshakePayload creates a libp2p handshake payload with a
// signature of our static noise key.
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) {
// obtain the public key from the handshake session so we can sign it with
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, data []byte) ([]byte, error) {
// obtain the public key from the handshake session, so we can sign it with
// our libp2p secret key.
localKeyRaw, err := crypto.MarshalPublicKey(s.LocalPublicKey())
if err != nil {
Expand All @@ -230,10 +239,11 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt
}

// create payload
payload := new(pb.NoiseHandshakePayload)
payload.IdentityKey = localKeyRaw
payload.IdentitySig = signedPayload
payloadEnc, err := proto.Marshal(payload)
payloadEnc, err := proto.Marshal(&pb.NoiseHandshakePayload{
IdentityKey: localKeyRaw,
IdentitySig: signedPayload,
Data: data,
})
if err != nil {
return nil, fmt.Errorf("error marshaling handshake payload: %w", err)
}
Expand All @@ -242,44 +252,45 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt

// handleRemoteHandshakePayload unmarshals the handshake payload object sent
// by the remote peer and validates the signature against the peer's static Noise key.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error {
// It returns the data attached to the payload.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) ([]byte, error) {
// unmarshal payload
nhp := new(pb.NoiseHandshakePayload)
err := proto.Unmarshal(payload, nhp)
if err != nil {
return fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
return nil, fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
}

// unpack remote peer's public libp2p key
remotePubKey, err := crypto.UnmarshalPublicKey(nhp.GetIdentityKey())
if err != nil {
return err
return nil, err
}
id, err := peer.IDFromPublicKey(remotePubKey)
if err != nil {
return err
return nil, err
}

// check the peer ID for:
// * all outbound connection
// * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID)
if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) {
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
}

// verify payload is signed by asserted remote libp2p key.
sig := nhp.GetIdentitySig()
msg := append([]byte(payloadSigPrefix), remoteStatic...)
ok, err := remotePubKey.Verify(msg, sig)
if err != nil {
return fmt.Errorf("error verifying signature: %w", err)
return nil, fmt.Errorf("error verifying signature: %w", err)
} else if !ok {
return fmt.Errorf("handshake signature invalid")
return nil, fmt.Errorf("handshake signature invalid")
}

// set remote peer key and id
s.remoteID = id
s.remoteKey = remotePubKey
return nil
return nhp.Data, nil
}
24 changes: 13 additions & 11 deletions p2p/security/noise/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,24 @@ type secureSession struct {
dec *noise.CipherState

// noise prologue
prologue []byte
earlyDataHandler EarlyDataHandler
prologue []byte

initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler
}
marten-seemann marked this conversation as resolved.
Show resolved Hide resolved

// newSecureSession creates a Noise session over the given insecureConn Conn, using
// the libp2p identity keypair from the given Transport.
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, edh EarlyDataHandler, initiator bool) (*secureSession, error) {
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiatorEDH, responderEDH EarlyDataHandler, initiator bool) (*secureSession, error) {
s := &secureSession{
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
prologue: prologue,
earlyDataHandler: edh,
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
prologue: prologue,
initiatorEarlyDataHandler: initiatorEDH,
responderEarlyDataHandler: responderEDH,
}

// the go-routine we create to run the handshake will
Expand Down
14 changes: 8 additions & 6 deletions p2p/security/noise/session_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ type EarlyDataHandler interface {
Received(context.Context, net.Conn, []byte) error
}

func EarlyData(h EarlyDataHandler) SessionOption {
func EarlyData(initiator, responder EarlyDataHandler) SessionOption {
MarcoPolo marked this conversation as resolved.
Show resolved Hide resolved
return func(s *SessionTransport) error {
s.earlyDataHandler = h
s.initiatorEarlyDataHandler = initiator
s.responderEarlyDataHandler = responder
return nil
}
}
Expand All @@ -45,14 +46,15 @@ var _ sec.SecureTransport = &SessionTransport{}
type SessionTransport struct {
t *Transport
// options
prologue []byte
earlyDataHandler EarlyDataHandler
prologue []byte

initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler
}

// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, false)
c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, false)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
Expand All @@ -64,5 +66,5 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn,

// SecureOutbound runs the Noise handshake as the initiator.
func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, true)
return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true)
}
4 changes: 2 additions & 2 deletions p2p/security/noise/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) {
// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(t, ctx, insecure, p, nil, nil, false)
c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
Expand All @@ -53,7 +53,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer

// SecureOutbound runs the Noise handshake as the initiator.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, nil, nil, true)
return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true)
}

func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) {
Expand Down
Loading