Skip to content

Commit

Permalink
Merge pull request #272 from jbenet/context-dial-peer
Browse files Browse the repository at this point in the history
add context to DialPeer interface
  • Loading branch information
Brian Tiger Chow committed Nov 5, 2014
2 parents 9d7e0bb + 390f4d7 commit d742984
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 37 deletions.
46 changes: 29 additions & 17 deletions exchange/bitswap/bitswap.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,28 @@ import (
var log = u.Logger("bitswap")

// NetMessageSession initializes a BitSwap session that communicates over the
// provided NetMessage service
func NetMessageSession(parent context.Context, p peer.Peer,
// provided NetMessage service.
// Runs until context is cancelled
func NetMessageSession(ctx context.Context, p peer.Peer,
net inet.Network, srv inet.Service, directory bsnet.Routing,
d ds.ThreadSafeDatastore, nice bool) exchange.Interface {

networkAdapter := bsnet.NetMessageAdapter(srv, net, nil)

notif := notifications.New()

go func() {
for {
select {
case <-ctx.Done():
notif.Shutdown()
}
}
}()

bs := &bitswap{
blockstore: blockstore.NewBlockstore(d),
notifications: notifications.New(), // TODO Shutdown()
notifications: notif,
strategy: strategy.New(nice),
routing: directory,
sender: networkAdapter,
Expand Down Expand Up @@ -75,6 +88,8 @@ func (bs *bitswap) Block(parent context.Context, k u.Key) (*blocks.Block, error)
}()

ctx, cancelFunc := context.WithCancel(parent)
defer cancelFunc()

bs.wantlist.Add(k)
promise := bs.notifications.Subscribe(ctx, k)

Expand All @@ -91,7 +106,7 @@ func (bs *bitswap) Block(parent context.Context, k u.Key) (*blocks.Block, error)
go func(p peer.Peer) {

log.Debugf("bitswap dialing peer: %s", p)
err := bs.sender.DialPeer(p)
err := bs.sender.DialPeer(ctx, p)
if err != nil {
log.Errorf("Error sender.DialPeer(%s)", p)
return
Expand All @@ -117,17 +132,15 @@ func (bs *bitswap) Block(parent context.Context, k u.Key) (*blocks.Block, error)

select {
case block := <-promise:
cancelFunc()
bs.wantlist.Remove(k)
// TODO remove from wantlist
return &block, nil
case <-parent.Done():
return nil, parent.Err()
}
}

// HasBlock announces the existance of a block to bitswap, potentially sending
// it to peers (Partners) whose WantLists include it.
// HasBlock announces the existance of a block to this bitswap service. The
// service will potentially notify its peers.
func (bs *bitswap) HasBlock(ctx context.Context, blk blocks.Block) error {
log.Debugf("Has Block %v", blk.Key())
bs.wantlist.Remove(blk.Key())
Expand Down Expand Up @@ -162,13 +175,11 @@ func (bs *bitswap) ReceiveMessage(ctx context.Context, p peer.Peer, incoming bsm
if err := bs.blockstore.Put(&block); err != nil {
continue // FIXME(brian): err ignored
}
go bs.notifications.Publish(block)
go func(block blocks.Block) {
err := bs.HasBlock(ctx, block) // FIXME err ignored
if err != nil {
log.Warningf("HasBlock errored: %s", err)
}
}(block)
bs.notifications.Publish(block)
err := bs.HasBlock(ctx, block)
if err != nil {
log.Warningf("HasBlock errored: %s", err)
}
}

message := bsmsg.New()
Expand Down Expand Up @@ -202,11 +213,12 @@ func (bs *bitswap) ReceiveError(err error) {
// sent
func (bs *bitswap) send(ctx context.Context, p peer.Peer, m bsmsg.BitSwapMessage) {
bs.sender.SendMessage(ctx, p, m)
go bs.strategy.MessageSent(p, m)
bs.strategy.MessageSent(p, m)
}

func (bs *bitswap) sendToPeersThatWant(ctx context.Context, block blocks.Block) {
log.Debugf("Sending %v to peers that want it", block.Key())

for _, p := range bs.strategy.Peers() {
if bs.strategy.BlockIsWantedByPeer(block.Key(), p) {
log.Debugf("%v wants %v", p, block.Key())
Expand All @@ -216,7 +228,7 @@ func (bs *bitswap) sendToPeersThatWant(ctx context.Context, block blocks.Block)
for _, wanted := range bs.wantlist.Keys() {
message.AddWanted(wanted)
}
go bs.send(ctx, p, message)
bs.send(ctx, p, message)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion exchange/bitswap/network/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
type Adapter interface {

// DialPeer ensures there is a connection to peer.
DialPeer(peer.Peer) error
DialPeer(context.Context, peer.Peer) error

// SendMessage sends a BitSwap message to a peer.
SendMessage(
Expand Down
4 changes: 2 additions & 2 deletions exchange/bitswap/network/net_message_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func (adapter *impl) HandleMessage(
return outgoing
}

func (adapter *impl) DialPeer(p peer.Peer) error {
return adapter.net.DialPeer(p)
func (adapter *impl) DialPeer(ctx context.Context, p peer.Peer) error {
return adapter.net.DialPeer(ctx, p)
}

func (adapter *impl) SendMessage(
Expand Down
2 changes: 1 addition & 1 deletion exchange/bitswap/testnet/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (nc *networkClient) SendRequest(
return nc.network.SendRequest(ctx, nc.local, to, message)
}

func (nc *networkClient) DialPeer(p peer.Peer) error {
func (nc *networkClient) DialPeer(ctx context.Context, p peer.Peer) error {
// no need to do anything because dialing isn't a thing in this test net.
if !nc.network.HasPeer(p) {
return fmt.Errorf("Peer not in network: %s", p)
Expand Down
5 changes: 3 additions & 2 deletions net/interface.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package net

import (
"github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
msg "github.com/jbenet/go-ipfs/net/message"
mux "github.com/jbenet/go-ipfs/net/mux"
srv "github.com/jbenet/go-ipfs/net/service"
Expand All @@ -19,7 +20,7 @@ type Network interface {
// TODO: for now, only listen on addrs in local peer when initializing.

// DialPeer attempts to establish a connection to a given peer
DialPeer(peer.Peer) error
DialPeer(context.Context, peer.Peer) error

// ClosePeer connection to peer
ClosePeer(peer.Peer) error
Expand Down Expand Up @@ -64,5 +65,5 @@ type Service srv.Service
type Dialer interface {

// DialPeer attempts to establish a connection to a given peer
DialPeer(peer.Peer) error
DialPeer(context.Context, peer.Peer) error
}
11 changes: 8 additions & 3 deletions net/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
mux "github.com/jbenet/go-ipfs/net/mux"
swarm "github.com/jbenet/go-ipfs/net/swarm"
peer "github.com/jbenet/go-ipfs/peer"
util "github.com/jbenet/go-ipfs/util"
ctxc "github.com/jbenet/go-ipfs/util/ctxcloser"

context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
Expand Down Expand Up @@ -56,9 +57,13 @@ func NewIpfsNetwork(ctx context.Context, listen []ma.Multiaddr, local peer.Peer,
// Listen handles incoming connections on given Multiaddr.
// func (n *IpfsNetwork) Listen(*ma.Muliaddr) error {}

// DialPeer attempts to establish a connection to a given peer
func (n *IpfsNetwork) DialPeer(p peer.Peer) error {
_, err := n.swarm.Dial(p)
// DialPeer attempts to establish a connection to a given peer.
// Respects the context.
func (n *IpfsNetwork) DialPeer(ctx context.Context, p peer.Peer) error {
err := util.ContextDo(ctx, func() error {
_, err := n.swarm.Dial(p)
return err
})
return err
}

Expand Down
10 changes: 5 additions & 5 deletions routing/dht/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (dht *IpfsDHT) Connect(ctx context.Context, npeer peer.Peer) (peer.Peer, er
//
// /ip4/10.20.30.40/tcp/1234/ipfs/Qxhxxchxzcncxnzcnxzcxzm
//
err := dht.dialer.DialPeer(npeer)
err := dht.dialer.DialPeer(ctx, npeer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -311,7 +311,7 @@ func (dht *IpfsDHT) getFromPeerList(ctx context.Context, key u.Key,
peerlist []*pb.Message_Peer, level int) ([]byte, error) {

for _, pinfo := range peerlist {
p, err := dht.ensureConnectedToPeer(pinfo)
p, err := dht.ensureConnectedToPeer(ctx, pinfo)
if err != nil {
log.Errorf("getFromPeers error: %s", err)
continue
Expand Down Expand Up @@ -496,14 +496,14 @@ func (dht *IpfsDHT) peerFromInfo(pbp *pb.Message_Peer) (peer.Peer, error) {
return p, nil
}

func (dht *IpfsDHT) ensureConnectedToPeer(pbp *pb.Message_Peer) (peer.Peer, error) {
func (dht *IpfsDHT) ensureConnectedToPeer(ctx context.Context, pbp *pb.Message_Peer) (peer.Peer, error) {
p, err := dht.peerFromInfo(pbp)
if err != nil {
return nil, err
}

// dial connection
err = dht.dialer.DialPeer(p)
err = dht.dialer.DialPeer(ctx, p)
return p, err
}

Expand Down Expand Up @@ -556,7 +556,7 @@ func (dht *IpfsDHT) Bootstrap(ctx context.Context) {
if err != nil {
log.Error("Bootstrap peer error: %s", err)
}
err = dht.dialer.DialPeer(p)
err = dht.dialer.DialPeer(ctx, p)
if err != nil {
log.Errorf("Bootstrap peer error: %s", err)
}
Expand Down
3 changes: 1 addition & 2 deletions routing/dht/ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"testing"

crand "crypto/rand"

context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
"github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto"

Expand Down Expand Up @@ -82,7 +81,7 @@ type fauxNet struct {
}

// DialPeer attempts to establish a connection to a given peer
func (f *fauxNet) DialPeer(peer.Peer) error {
func (f *fauxNet) DialPeer(context.Context, peer.Peer) error {
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion routing/dht/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (r *dhtQueryRunner) queryPeer(p peer.Peer) {

// make sure we're connected to the peer.
// (Incidentally, this will add it to the peerstore too)
err := r.query.dialer.DialPeer(p)
err := r.query.dialer.DialPeer(r.ctx, p)
if err != nil {
log.Debugf("ERROR worker for: %v -- err connecting: %v", p, err)
r.Lock()
Expand Down
6 changes: 3 additions & 3 deletions routing/dht/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key u.Key, count int
log.Error(err)
return
}
dht.addPeerListAsync(key, pmes.GetProviderPeers(), ps, count, peerOut)
dht.addPeerListAsync(ctx, key, pmes.GetProviderPeers(), ps, count, peerOut)
}(pp)
}
wg.Wait()
Expand All @@ -154,13 +154,13 @@ func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key u.Key, count int
return peerOut
}

func (dht *IpfsDHT) addPeerListAsync(k u.Key, peers []*pb.Message_Peer, ps *peerSet, count int, out chan peer.Peer) {
func (dht *IpfsDHT) addPeerListAsync(ctx context.Context, k u.Key, peers []*pb.Message_Peer, ps *peerSet, count int, out chan peer.Peer) {
done := make(chan struct{})
for _, pbp := range peers {
go func(mp *pb.Message_Peer) {
defer func() { done <- struct{}{} }()
// construct new peer
p, err := dht.ensureConnectedToPeer(mp)
p, err := dht.ensureConnectedToPeer(ctx, mp)
if err != nil {
log.Error("%s", err)
return
Expand Down
22 changes: 22 additions & 0 deletions util/do.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package util

import "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"

func ContextDo(ctx context.Context, f func() error) error {

ch := make(chan error)

go func() {
select {
case <-ctx.Done():
case ch <- f():
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case val := <-ch:
return val
}
return nil
}
42 changes: 42 additions & 0 deletions util/do_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package util

import (
"errors"
"testing"

"github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
)

func TestDoReturnsContextErr(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan struct{})
err := ContextDo(ctx, func() error {
cancel()
ch <- struct{}{} // won't return
return nil
})
if err != ctx.Err() {
t.Fail()
}
}

func TestDoReturnsFuncError(t *testing.T) {
ctx := context.Background()
expected := errors.New("expected to be returned by ContextDo")
err := ContextDo(ctx, func() error {
return expected
})
if err != expected {
t.Fail()
}
}

func TestDoReturnsNil(t *testing.T) {
ctx := context.Background()
err := ContextDo(ctx, func() error {
return nil
})
if err != nil {
t.Fail()
}
}

0 comments on commit d742984

Please sign in to comment.