Skip to content

Commit

Permalink
Merge pull request #100 from libp2p/upgrader-interface
Browse files Browse the repository at this point in the history
use the new transport.Upgrader interface
  • Loading branch information
marten-seemann authored Jan 4, 2022
2 parents c622cb0 + e8056e8 commit 17c6e5e
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 64 deletions.
6 changes: 3 additions & 3 deletions p2p/net/upgrader/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type listener struct {
manet.Listener

transport transport.Transport
upgrader *Upgrader
upgrader *upgrader

incoming chan transport.CapableConn
err error
Expand Down Expand Up @@ -83,7 +83,7 @@ func (l *listener) handleIncoming() {
catcher.Reset()

// gate the connection if applicable
if l.upgrader.ConnGater != nil && !l.upgrader.ConnGater.InterceptAccept(maconn) {
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())

Expand All @@ -106,7 +106,7 @@ func (l *listener) handleIncoming() {
go func() {
defer wg.Done()

ctx, cancel := context.WithTimeout(l.ctx, l.upgrader.acceptTimeout())
ctx, cancel := context.WithTimeout(l.ctx, l.upgrader.acceptTimeout)
defer cancel()

conn, err := l.upgrader.Upgrade(ctx, l.transport, maconn, network.DirInbound, "")
Expand Down
21 changes: 7 additions & 14 deletions p2p/net/upgrader/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (mux *MuxAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p
return sconn, false, err
}

func createListener(t *testing.T, upgrader *st.Upgrader) transport.Listener {
func createListener(t *testing.T, upgrader transport.Upgrader) transport.Listener {
t.Helper()
addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0")
require.NoError(t, err)
Expand Down Expand Up @@ -97,8 +97,7 @@ func TestConnectionsClosedIfNotAccepted(t *testing.T) {
timeout = 500 * time.Millisecond
}

id, upgrader := createUpgrader(t)
upgrader.AcceptTimeout = timeout
id, upgrader := createUpgrader(t, st.WithAcceptTimeout(timeout))
ln := createListener(t, upgrader)
defer ln.Close()

Expand Down Expand Up @@ -132,10 +131,8 @@ func TestConnectionsClosedIfNotAccepted(t *testing.T) {
func TestFailedUpgradeOnListen(t *testing.T) {
require := require.New(t)

id, upgrader := createUpgrader(t)
upgrader.Muxer = &errorMuxer{}
id, upgrader := createUpgraderWithMuxer(t, &errorMuxer{})
ln := createListener(t, upgrader)
defer ln.Close()

errCh := make(chan error)
go func() {
Expand Down Expand Up @@ -171,9 +168,8 @@ func TestListenerClose(t *testing.T) {
}

// unblocks Accept when it is closed.
err := ln.Close()
require.NoError(err)
err = <-errCh
require.NoError(ln.Close())
err := <-errCh
require.Error(err)
require.Contains(err.Error(), "use of closed network connection")

Expand Down Expand Up @@ -225,10 +221,8 @@ func TestListenerCloseClosesQueued(t *testing.T) {
func TestConcurrentAccept(t *testing.T) {
var num = 3 * st.AcceptQueueLength

id, upgrader := createUpgrader(t)
blockingMuxer := newBlockingMuxer()
upgrader.Muxer = blockingMuxer

id, upgrader := createUpgraderWithMuxer(t, blockingMuxer)
ln := createListener(t, upgrader)
defer ln.Close()

Expand Down Expand Up @@ -312,8 +306,7 @@ func TestListenerConnectionGater(t *testing.T) {
require := require.New(t)

testGater := &testGater{}
id, upgrader := createUpgrader(t)
upgrader.ConnGater = testGater
id, upgrader := createUpgrader(t, st.WithConnectionGater(testGater))

ln := createListener(t, upgrader)
defer ln.Close()
Expand Down
95 changes: 56 additions & 39 deletions p2p/net/upgrader/upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,63 @@ var AcceptQueueLength = 16

const defaultAcceptTimeout = 15 * time.Second

type Option func(*upgrader) error

func WithPSK(psk ipnet.PSK) Option {
return func(u *upgrader) error {
u.psk = psk
return nil
}
}

func WithAcceptTimeout(t time.Duration) Option {
return func(u *upgrader) error {
u.acceptTimeout = t
return nil
}
}

func WithConnectionGater(g connmgr.ConnectionGater) Option {
return func(u *upgrader) error {
u.connGater = g
return nil
}
}

// Upgrader is a multistream upgrader that can upgrade an underlying connection
// to a full transport connection (secure and multiplexed).
type Upgrader struct {
PSK ipnet.PSK
Secure sec.SecureMuxer
Muxer mux.Multiplexer
ConnGater connmgr.ConnectionGater
type upgrader struct {
secure sec.SecureMuxer
muxer mux.Multiplexer

psk ipnet.PSK
connGater connmgr.ConnectionGater

// AcceptTimeout is the maximum duration an Accept is allowed to take.
// This includes the time between accepting the raw network connection,
// protocol selection as well as the handshake, if applicable.
//
// If unset, the default value (15s) is used.
AcceptTimeout time.Duration
acceptTimeout time.Duration
}

var _ transport.Upgrader = &upgrader{}

func New(secureMuxer sec.SecureMuxer, muxer mux.Multiplexer, opts ...Option) (transport.Upgrader, error) {
u := &upgrader{
secure: secureMuxer,
muxer: muxer,
acceptTimeout: defaultAcceptTimeout,
}
for _, opt := range opts {
if err := opt(u); err != nil {
return nil, err
}
}
return u, nil
}

// UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
func (u *Upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener {
func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener {
ctx, cancel := context.WithCancel(context.Background())
l := &listener{
Listener: list,
Expand All @@ -60,22 +99,7 @@ func (u *Upgrader) UpgradeListener(t transport.Transport, list manet.Listener) t
return l
}

// UpgradeOutbound upgrades the given outbound multiaddr-net connection into a
// full libp2p-transport connection.
// Deprecated: use Upgrade instead.
func (u *Upgrader) UpgradeOutbound(ctx context.Context, t transport.Transport, maconn manet.Conn, p peer.ID) (transport.CapableConn, error) {
return u.Upgrade(ctx, t, maconn, network.DirOutbound, p)
}

// UpgradeInbound upgrades the given inbound multiaddr-net connection into a
// full libp2p-transport connection.
// Deprecated: use Upgrade instead.
func (u *Upgrader) UpgradeInbound(ctx context.Context, t transport.Transport, maconn manet.Conn) (transport.CapableConn, error) {
return u.Upgrade(ctx, t, maconn, network.DirInbound, "")
}

// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection.
func (u *Upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, dir network.Direction, p peer.ID) (transport.CapableConn, error) {
func (u *upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, dir network.Direction, p peer.ID) (transport.CapableConn, error) {
if dir == network.DirOutbound && p == "" {
return nil, ErrNilPeer
}
Expand All @@ -85,8 +109,8 @@ func (u *Upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn ma
}

var conn net.Conn = maconn
if u.PSK != nil {
pconn, err := pnet.NewProtectedConn(u.PSK, conn)
if u.psk != nil {
pconn, err := pnet.NewProtectedConn(u.psk, conn)
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to setup private network protector: %s", err)
Expand All @@ -104,7 +128,7 @@ func (u *Upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn ma
}

// call the connection gater, if one is registered.
if u.ConnGater != nil && !u.ConnGater.InterceptSecured(dir, sconn.RemotePeer(), maconn) {
if u.connGater != nil && !u.connGater.InterceptSecured(dir, sconn.RemotePeer(), maconn) {
if err := maconn.Close(); err != nil {
log.Errorf("failed to close connection with peer %s and addr %s; err: %s",
p.Pretty(), maconn.RemoteMultiaddr(), err)
Expand All @@ -129,22 +153,22 @@ func (u *Upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn ma
return tc, nil
}

func (u *Upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, dir network.Direction) (sec.SecureConn, bool, error) {
func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, dir network.Direction) (sec.SecureConn, bool, error) {
if dir == network.DirInbound {
return u.Secure.SecureInbound(ctx, conn, p)
return u.secure.SecureInbound(ctx, conn, p)
}
return u.Secure.SecureOutbound(ctx, conn, p)
return u.secure.SecureOutbound(ctx, conn, p)
}

func (u *Upgrader) setupMuxer(ctx context.Context, conn net.Conn, server bool) (mux.MuxedConn, error) {
func (u *upgrader) setupMuxer(ctx context.Context, conn net.Conn, server bool) (mux.MuxedConn, error) {
// TODO: The muxer should take a context.
done := make(chan struct{})

var smconn mux.MuxedConn
var err error
go func() {
defer close(done)
smconn, err = u.Muxer.NewConn(conn, server)
smconn, err = u.muxer.NewConn(conn, server)
}()

select {
Expand All @@ -158,10 +182,3 @@ func (u *Upgrader) setupMuxer(ctx context.Context, conn net.Conn, server bool) (
return nil, ctx.Err()
}
}

func (u *Upgrader) acceptTimeout() time.Duration {
if u.AcceptTimeout == 0 {
return defaultAcceptTimeout
}
return u.AcceptTimeout
}
19 changes: 11 additions & 8 deletions p2p/net/upgrader/upgrader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/libp2p/go-libp2p-core/sec/insecure"
"github.com/libp2p/go-libp2p-core/test"
"github.com/libp2p/go-libp2p-core/transport"

mplex "github.com/libp2p/go-libp2p-mplex"
st "github.com/libp2p/go-libp2p-transport-upgrader"

Expand All @@ -22,15 +23,18 @@ import (
"github.com/stretchr/testify/require"
)

func createUpgrader(t *testing.T) (peer.ID, *st.Upgrader) {
func createUpgrader(t *testing.T, opts ...st.Option) (peer.ID, transport.Upgrader) {
return createUpgraderWithMuxer(t, &negotiatingMuxer{}, opts...)
}

func createUpgraderWithMuxer(t *testing.T, muxer mux.Multiplexer, opts ...st.Option) (peer.ID, transport.Upgrader) {
priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
require.NoError(t, err)
id, err := peer.IDFromPrivateKey(priv)
require.NoError(t, err)
return id, &st.Upgrader{
Secure: &MuxAdapter{tpt: insecure.NewWithIdentity(id, priv)},
Muxer: &negotiatingMuxer{},
}
u, err := st.New(&MuxAdapter{tpt: insecure.NewWithIdentity(id, priv)}, muxer, opts...)
require.NoError(t, err)
return id, u
}

// negotiatingMuxer sets up a new mplex connection
Expand Down Expand Up @@ -99,7 +103,7 @@ func testConn(t *testing.T, clientConn, serverConn transport.CapableConn) {
require.Equal([]byte("foobar"), b)
}

func dial(t *testing.T, upgrader *st.Upgrader, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
func dial(t *testing.T, upgrader transport.Upgrader, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
t.Helper()

macon, err := manet.Dial(raddr)
Expand All @@ -117,8 +121,7 @@ func TestOutboundConnectionGating(t *testing.T) {
defer ln.Close()

testGater := &testGater{}
_, dialUpgrader := createUpgrader(t)
dialUpgrader.ConnGater = testGater
_, dialUpgrader := createUpgrader(t, st.WithConnectionGater(testGater))
conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id)
require.NoError(err)
require.NotNil(conn)
Expand Down

0 comments on commit 17c6e5e

Please sign in to comment.