Skip to content

Commit

Permalink
webrtc: setup datachannel handlers before connecting to a peer (#2716)
Browse files Browse the repository at this point in the history
If done after connecting to a peer, there's a small window of time
when datachannels created by the peer may cause a memory leak.
  • Loading branch information
sukunrt authored Feb 23, 2024
1 parent 972278a commit 473a5e9
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 133 deletions.
50 changes: 2 additions & 48 deletions p2p/transport/webrtc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@ import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"

"github.com/libp2p/go-msgio"

ma "github.com/multiformats/go-multiaddr"
"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
"google.golang.org/protobuf/proto"
)

var _ tpt.CapableConn = &connection{}
Expand Down Expand Up @@ -77,6 +73,7 @@ func newConnection(
remotePeer peer.ID,
remoteKey ic.PubKey,
remoteMultiaddr ma.Multiaddr,
incomingDataChannels chan dataChannel,
) (*connection, error) {
ctx, cancel := context.WithCancel(context.Background())
c := &connection{
Expand All @@ -94,7 +91,7 @@ func newConnection(
cancel: cancel,
streams: make(map[uint16]*stream),

acceptQueue: make(chan dataChannel, maxAcceptQueueLen),
acceptQueue: incomingDataChannels,
}
switch direction {
case network.DirInbound:
Expand All @@ -105,24 +102,6 @@ func newConnection(
}

pc.OnConnectionStateChange(c.onConnectionStateChange)
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
dc.OnOpen(func() {
rwc, err := dc.Detach()
if err != nil {
log.Warnf("could not detach datachannel: id: %d", *dc.ID())
return
}
select {
case c.acceptQueue <- dataChannel{rwc, dc}:
default:
log.Warnf("connection busy, rejecting stream")
b, _ := proto.Marshal(&pb.Message{Flag: pb.Message_RESET.Enum()})
w := msgio.NewWriter(rwc)
w.WriteMsg(b)
rwc.Close()
}
})
})
return c, nil
}

Expand Down Expand Up @@ -274,28 +253,3 @@ func (c *connection) detachChannel(ctx context.Context, dc *webrtc.DataChannel)
return rwc, err
}
}

// A note on these setters and why they are needed:
//
// The connection object sets up receiving datachannels (streams) from the remote peer.
// Please consider the XX noise handshake pattern from a peer A to peer B as described at:
// https://noiseexplorer.com/patterns/XX/
//
// The initiator A completes the noise handshake before B.
// This would allow A to create new datachannels before B has set up the callbacks to process incoming datachannels.
// This would create a situation where A has successfully created a stream but B is not aware of it.
// Moving the construction of the connection object before the noise handshake eliminates this issue,
// as callbacks have been set up for both peers.
//
// This could lead to a case where streams are created during the noise handshake,
// and the handshake fails. In this case, we would close the underlying peerconnection.

// only used during connection setup
func (c *connection) setRemotePeer(id peer.ID) {
c.remotePeer = id
}

// only used during connection setup
func (c *connection) setRemotePublicKey(key ic.PubKey) {
c.remoteKey = key
}
75 changes: 29 additions & 46 deletions p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ func (l *listener) setupConnection(
ctx context.Context, scope network.ConnManagementScope,
remoteMultiaddr ma.Multiaddr, candidate udpmux.Candidate,
) (tConn tpt.CapableConn, err error) {
var pc *webrtc.PeerConnection
var w webRTCConnection
defer func() {
if err != nil {
if pc != nil {
_ = pc.Close()
if w.PeerConnection != nil {
_ = w.PeerConnection.Close()
}
if tConn != nil {
_ = tConn.Close()
Expand Down Expand Up @@ -224,34 +224,24 @@ func (l *listener) setupConnection(
)
settingEngine.DetachDataChannels()

api := webrtc.NewAPI(webrtc.WithSettingEngine(settingEngine))
pc, err = api.NewPeerConnection(l.config)
w, err = newWebRTCConnection(settingEngine, l.config)
if err != nil {
return nil, err
return nil, fmt.Errorf("instantiating peer connection failed: %w", err)
}

negotiated, id := handshakeChannelNegotiated, handshakeChannelID
rawDatachannel, err := pc.CreateDataChannel("", &webrtc.DataChannelInit{
Negotiated: &negotiated,
ID: &id,
})
if err != nil {
return nil, err
}

errC := addOnConnectionStateChangeCallback(pc)
errC := addOnConnectionStateChangeCallback(w.PeerConnection)
// Infer the client SDP from the incoming STUN message by setting the ice-ufrag.
if err := pc.SetRemoteDescription(webrtc.SessionDescription{
if err := w.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
SDP: createClientSDP(candidate.Addr, candidate.Ufrag),
Type: webrtc.SDPTypeOffer,
}); err != nil {
return nil, err
}
answer, err := pc.CreateAnswer(nil)
answer, err := w.PeerConnection.CreateAnswer(nil)
if err != nil {
return nil, err
}
if err := pc.SetLocalDescription(answer); err != nil {
if err := w.PeerConnection.SetLocalDescription(answer); err != nil {
return nil, err
}

Expand All @@ -264,49 +254,42 @@ func (l *listener) setupConnection(
}
}

rwc, err := detachHandshakeDataChannel(ctx, rawDatachannel)
if err != nil {
return nil, err
}

localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })

handshakeChannel := newStream(rawDatachannel, rwc, func() {})
// The connection is instantiated before performing the Noise handshake. This is
// to handle the case where the remote is faster and attempts to initiate a stream
// before the ondatachannel callback can be set.
conn, err := newConnection(
network.DirInbound,
pc,
l.transport,
scope,
l.transport.localPeerId,
localMultiaddrWithoutCerthash,
"", // remotePeer
nil, // remoteKey
remoteMultiaddr,
)
// Run the noise handshake.
rwc, err := detachHandshakeDataChannel(ctx, w.HandshakeDataChannel)
if err != nil {
return nil, err
}

handshakeChannel := newStream(w.HandshakeDataChannel, rwc, func() {})
// we do not yet know A's peer ID so accept any inbound
remotePubKey, err := l.transport.noiseHandshake(ctx, pc, handshakeChannel, "", crypto.SHA256, true)
remotePubKey, err := l.transport.noiseHandshake(ctx, w.PeerConnection, handshakeChannel, "", crypto.SHA256, true)
if err != nil {
return nil, err
}
remotePeer, err := peer.IDFromPublicKey(remotePubKey)
if err != nil {
return nil, err
}

// earliest point where we know the remote's peerID
if err := scope.SetPeer(remotePeer); err != nil {
return nil, err
}

conn.setRemotePeer(remotePeer)
conn.setRemotePublicKey(remotePubKey)
localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
conn, err := newConnection(
network.DirInbound,
w.PeerConnection,
l.transport,
scope,
l.transport.localPeerId,
localMultiaddrWithoutCerthash,
remotePeer,
remotePubKey,
remoteMultiaddr,
w.IncomingDataChannels,
)
if err != nil {
return nil, err
}

return conn, err
}
Expand Down
Loading

0 comments on commit 473a5e9

Please sign in to comment.