Skip to content

Commit

Permalink
noise: use separate early data handlers for initiator and receiver
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 14, 2022
1 parent 65270f1 commit 0b36832
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 35 deletions.
2 changes: 1 addition & 1 deletion p2p/security/noise/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) {
init, resp := net.Pipe()
_ = resp.Close()

session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, true)
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, nil, true)
_, err := session.encrypt(nil, []byte("hi"))
if err == nil {
t.Error("expected encryption error when handshake incomplete")
Expand Down
16 changes: 8 additions & 8 deletions p2p/security/noise/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
// stage 0 //
// Handshake Msg Len = len(DH ephemeral key)
var ed []byte
if s.earlyDataHandler != nil {
ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
if s.initiatorEarlyDataHandler != nil {
ed = s.initiatorEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
if err := s.sendHandshakeMessage(hs, ed, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
Expand All @@ -101,8 +101,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return err
}
if s.earlyDataHandler != nil {
if err := s.earlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil {
if s.initiatorEarlyDataHandler != nil {
if err := s.initiatorEarlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil {
return err
}
}
Expand All @@ -123,8 +123,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
if s.earlyDataHandler != nil {
if err := s.earlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil {
if s.receiverEarlyDataHandler != nil {
if err := s.receiverEarlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil {
return err
}
}
Expand All @@ -133,8 +133,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
var ed []byte
if s.earlyDataHandler != nil {
ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
if s.receiverEarlyDataHandler != nil {
ed = s.receiverEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
payload, err := s.generateHandshakePayload(kp, ed)
if err != nil {
Expand Down
24 changes: 13 additions & 11 deletions p2p/security/noise/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,24 @@ type secureSession struct {
dec *noise.CipherState

// noise prologue
prologue []byte
earlyDataHandler EarlyDataHandler
prologue []byte

initiatorEarlyDataHandler, receiverEarlyDataHandler EarlyDataHandler
}

// newSecureSession creates a Noise session over the given insecureConn Conn, using
// the libp2p identity keypair from the given Transport.
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, edh EarlyDataHandler, initiator bool) (*secureSession, error) {
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiatorEDH, receiverEDH EarlyDataHandler, initiator bool) (*secureSession, error) {
s := &secureSession{
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
prologue: prologue,
earlyDataHandler: edh,
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
prologue: prologue,
initiatorEarlyDataHandler: initiatorEDH,
receiverEarlyDataHandler: receiverEDH,
}

// the go-routine we create to run the handshake will
Expand Down
14 changes: 8 additions & 6 deletions p2p/security/noise/session_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ type EarlyDataHandler interface {
Received(context.Context, net.Conn, []byte) error
}

func EarlyData(h EarlyDataHandler) SessionOption {
func EarlyData(initiator, receiver EarlyDataHandler) SessionOption {
return func(s *SessionTransport) error {
s.earlyDataHandler = h
s.initiatorEarlyDataHandler = initiator
s.receiverEarlyDataHandler = receiver
return nil
}
}
Expand All @@ -45,14 +46,15 @@ var _ sec.SecureTransport = &SessionTransport{}
type SessionTransport struct {
t *Transport
// options
prologue []byte
earlyDataHandler EarlyDataHandler
prologue []byte

initiatorEarlyDataHandler, receiverEarlyDataHandler EarlyDataHandler
}

// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, false)
c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.receiverEarlyDataHandler, false)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
Expand All @@ -64,5 +66,5 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn,

// SecureOutbound runs the Noise handshake as the initiator.
func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, true)
return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.receiverEarlyDataHandler, true)
}
4 changes: 2 additions & 2 deletions p2p/security/noise/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) {
// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(t, ctx, insecure, p, nil, nil, false)
c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
Expand All @@ -53,7 +53,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer

// SecureOutbound runs the Noise handshake as the initiator.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, nil, nil, true)
return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true)
}

func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) {
Expand Down
10 changes: 5 additions & 5 deletions p2p/security/noise/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,10 @@ func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []b
func TestEarlyDataAccepted(t *testing.T) {
handshake := func(t *testing.T, client, server EarlyDataHandler) {
t.Helper()
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client))
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client, nil))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(server))
respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server))
require.NoError(t, err)

initConn, respConn := newConnPair(t)
Expand Down Expand Up @@ -495,10 +495,10 @@ func TestEarlyDataAccepted(t *testing.T) {
func TestEarlyDataRejected(t *testing.T) {
handshake := func(t *testing.T, client, server EarlyDataHandler) (clientErr, serverErr error) {
t.Helper()
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client))
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client, nil))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(server))
respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server))
require.NoError(t, err)

initConn, respConn := newConnPair(t)
Expand Down Expand Up @@ -545,7 +545,7 @@ func TestEarlyDataAcceptedWithNoHandler(t *testing.T) {
clientEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH))
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH, nil))
require.NoError(t, err)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)

Expand Down
2 changes: 1 addition & 1 deletion p2p/transport/webtransport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
if err != nil {
return nil, err
}
n, err := l.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(l.checkEarlyData)))
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataReceiver(l.checkEarlyData)))
if err != nil {
return nil, fmt.Errorf("failed to initialize Noise session: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
if err != nil {
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes)))
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil))
if err != nil {
return nil, fmt.Errorf("failed to create Noise transport: %w", err)
}
Expand Down

0 comments on commit 0b36832

Please sign in to comment.