From 604ae5e973fa49b2d0cfaab00c7ab413f5dc4d22 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Wed, 28 Nov 2018 14:26:25 -0800 Subject: [PATCH] refactor(sessions): extract sessions to package - moved sessions out of main bitswap package - modified session manager to manage all sessions - moved get functions to their own package so sessions can directly BREAKING CHANGE: SessionsForBlock, while not used outside of Bitswap, has been removed, and was an exported function --- bitswap.go | 33 +- ...n_test.go => bitswap_with_sessions_test.go | 5 +- dup_blocks_test.go | 5 +- get.go => getter/getter.go | 22 +- session.go => session/session.go | 372 ++++++++++-------- sessionmanager/sessionmanager.go | 65 ++- 6 files changed, 303 insertions(+), 199 deletions(-) rename session_test.go => bitswap_with_sessions_test.go (97%) rename get.go => getter/getter.go (68%) rename session.go => session/session.go (54%) diff --git a/bitswap.go b/bitswap.go index b3e472d2..aa7de181 100644 --- a/bitswap.go +++ b/bitswap.go @@ -10,6 +10,7 @@ import ( "time" decision "github.com/ipfs/go-bitswap/decision" + bsgetter "github.com/ipfs/go-bitswap/getter" bsmsg "github.com/ipfs/go-bitswap/message" bsmq "github.com/ipfs/go-bitswap/messagequeue" bsnet "github.com/ipfs/go-bitswap/network" @@ -100,6 +101,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, return bsmq.New(p, network) } + wm := bswm.New(ctx) bs := &Bitswap{ blockstore: bstore, notifications: notif, @@ -109,9 +111,9 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, process: px, newBlocks: make(chan cid.Cid, HasBlockBufferSize), provideKeys: make(chan cid.Cid, provideKeysBufferSize), - wm: bswm.New(ctx), + wm: wm, pm: bspm.New(ctx, peerQueueFactory), - sm: bssm.New(), + sm: bssm.New(ctx, wm, network), counters: new(counters), dupMetric: dupHist, allMetric: allHist, @@ -202,7 +204,7 @@ type blockRequest struct { // GetBlock attempts to retrieve a particular block from peers within the // deadline enforced by the context. func (bs *Bitswap) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, error) { - return getBlock(parent, k, bs.GetBlocks) + return bsgetter.SyncGetBlock(parent, k, bs.GetBlocks) } func (bs *Bitswap) WantlistForPeer(p peer.ID) []cid.Cid { @@ -307,7 +309,7 @@ func (bs *Bitswap) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks return out, nil } -// CancelWant removes a given key from the wantlist +// CancelWants removes a given key from the wantlist func (bs *Bitswap) CancelWants(cids []cid.Cid, ses uint64) { if len(cids) == 0 { return @@ -345,12 +347,7 @@ func (bs *Bitswap) receiveBlockFrom(blk blocks.Block, from peer.ID) error { // it now as it requires more thought and isnt causing immediate problems. bs.notifications.Publish(blk) - k := blk.Cid() - ks := []cid.Cid{k} - for _, s := range bs.SessionsForBlock(k) { - s.receiveBlockFrom(from, blk) - bs.CancelWants(ks, s.id) - } + bs.sm.ReceiveBlockFrom(from, blk) bs.engine.AddBlock(blk) @@ -363,18 +360,6 @@ func (bs *Bitswap) receiveBlockFrom(blk blocks.Block, from peer.ID) error { return nil } -// SessionsForBlock returns a slice of all sessions that may be interested in the given cid -func (bs *Bitswap) SessionsForBlock(c cid.Cid) []*Session { - var out []*Session - bs.sm.IterateSessions(func(session exchange.Fetcher) { - s := session.(*Session) - if s.interestedIn(c) { - out = append(out, s) - } - }) - return out -} - func (bs *Bitswap) ReceiveMessage(ctx context.Context, p peer.ID, incoming bsmsg.BitSwapMessage) { atomic.AddUint64(&bs.counters.messagesRecvd, 1) @@ -477,3 +462,7 @@ func (bs *Bitswap) GetWantlist() []cid.Cid { func (bs *Bitswap) IsOnline() bool { return true } + +func (bs *Bitswap) NewSession(ctx context.Context) exchange.Fetcher { + return bs.sm.NewSession(ctx) +} diff --git a/session_test.go b/bitswap_with_sessions_test.go similarity index 97% rename from session_test.go rename to bitswap_with_sessions_test.go index c5a00a90..5034aaee 100644 --- a/session_test.go +++ b/bitswap_with_sessions_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + bssession "github.com/ipfs/go-bitswap/session" blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" blocksutil "github.com/ipfs/go-ipfs-blocksutil" @@ -132,8 +133,8 @@ func TestSessionSplitFetch(t *testing.T) { cids = append(cids, blk.Cid()) } - ses := inst[10].Exchange.NewSession(ctx).(*Session) - ses.baseTickDelay = time.Millisecond * 10 + ses := inst[10].Exchange.NewSession(ctx).(*bssession.Session) + ses.SetBaseTickDelay(time.Millisecond * 10) for i := 0; i < 10; i++ { ch, err := ses.GetBlocks(ctx, cids[i*10:(i+1)*10]) diff --git a/dup_blocks_test.go b/dup_blocks_test.go index a48889a3..58fc9614 100644 --- a/dup_blocks_test.go +++ b/dup_blocks_test.go @@ -11,6 +11,7 @@ import ( tn "github.com/ipfs/go-bitswap/testnet" + bssession "github.com/ipfs/go-bitswap/session" "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" blocksutil "github.com/ipfs/go-ipfs-blocksutil" @@ -248,14 +249,14 @@ func onePeerPerBlock(b *testing.B, provs []Instance, blks []blocks.Block) { } func oneAtATime(b *testing.B, bs *Bitswap, ks []cid.Cid) { - ses := bs.NewSession(context.Background()).(*Session) + ses := bs.NewSession(context.Background()).(*bssession.Session) for _, c := range ks { _, err := ses.GetBlock(context.Background(), c) if err != nil { b.Fatal(err) } } - b.Logf("Session fetch latency: %s", ses.latTotal/time.Duration(ses.fetchcnt)) + b.Logf("Session fetch latency: %s", ses.GetAverageLatency()) } // fetch data in batches, 10 at a time diff --git a/get.go b/getter/getter.go similarity index 68% rename from get.go rename to getter/getter.go index 8578277e..2ed97f2d 100644 --- a/get.go +++ b/getter/getter.go @@ -1,19 +1,27 @@ -package bitswap +package getter import ( "context" "errors" notifications "github.com/ipfs/go-bitswap/notifications" + logging "github.com/ipfs/go-log" blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" blockstore "github.com/ipfs/go-ipfs-blockstore" ) -type getBlocksFunc func(context.Context, []cid.Cid) (<-chan blocks.Block, error) +var log = logging.Logger("bitswap") -func getBlock(p context.Context, k cid.Cid, gb getBlocksFunc) (blocks.Block, error) { +// GetBlocksFunc is any function that can take an array of CIDs and return a +// channel of incoming blocks +type GetBlocksFunc func(context.Context, []cid.Cid) (<-chan blocks.Block, error) + +// SyncGetBlock takes a block cid and an async function for getting several +// blocks that returns a channel, and uses that function to return the +// block syncronously +func SyncGetBlock(p context.Context, k cid.Cid, gb GetBlocksFunc) (blocks.Block, error) { if !k.Defined() { log.Error("undefined cid in GetBlock") return nil, blockstore.ErrNotFound @@ -49,9 +57,13 @@ func getBlock(p context.Context, k cid.Cid, gb getBlocksFunc) (blocks.Block, err } } -type wantFunc func(context.Context, []cid.Cid) +// WantFunc is any function that can express a want for set of blocks +type WantFunc func(context.Context, []cid.Cid) -func getBlocksImpl(ctx context.Context, keys []cid.Cid, notif notifications.PubSub, want wantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) { +// AsyncGetBlocks take a set of block cids, a pubsub channel for incoming +// blocks, a want function, and a close function, +// and returns a channel of incoming blocks +func AsyncGetBlocks(ctx context.Context, keys []cid.Cid, notif notifications.PubSub, want WantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) { if len(keys) == 0 { out := make(chan blocks.Block) close(out) diff --git a/session.go b/session/session.go similarity index 54% rename from session.go rename to session/session.go index cd5f645a..470aeafd 100644 --- a/session.go +++ b/session/session.go @@ -1,16 +1,16 @@ -package bitswap +package session import ( "context" "fmt" "time" - notifications "github.com/ipfs/go-bitswap/notifications" - lru "github.com/hashicorp/golang-lru" + bsgetter "github.com/ipfs/go-bitswap/getter" + bsnet "github.com/ipfs/go-bitswap/network" + notifications "github.com/ipfs/go-bitswap/notifications" blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" - exchange "github.com/ipfs/go-ipfs-exchange-interface" logging "github.com/ipfs/go-log" loggables "github.com/libp2p/go-libp2p-loggables" peer "github.com/libp2p/go-libp2p-peer" @@ -18,41 +18,61 @@ import ( const activeWantsLimit = 16 -// Session holds state for an individual bitswap transfer operation. +// SessionWantmanager is an interface that can be used to request blocks +// from given peers +type SessionWantManager interface { + WantBlocks(ctx context.Context, ks []cid.Cid, peers []peer.ID, ses uint64) + CancelWants(ctx context.Context, ks []cid.Cid, peers []peer.ID, ses uint64) +} + +type interestReq struct { + c cid.Cid + resp chan bool +} + +type blkRecv struct { + from peer.ID + blk blocks.Block +} + +// session holds state for an individual bitswap transfer operation. // This allows bitswap to make smarter decisions about who to send wantlist // info to, and who to request blocks from type Session struct { - ctx context.Context + // dependencies + ctx context.Context + wm SessionWantManager + network bsnet.BitSwapNetwork + + // channels + incoming chan blkRecv + newReqs chan []cid.Cid + cancelKeys chan []cid.Cid + interestReqs chan interestReq + latencyReqs chan chan time.Duration + tickDelayReqs chan time.Duration + + // do not touch outside run loop tofetch *cidQueue activePeers map[peer.ID]struct{} activePeersArr []peer.ID - - bs *Bitswap - incoming chan blkRecv - newReqs chan []cid.Cid - cancelKeys chan []cid.Cid - interestReqs chan interestReq - - interest *lru.Cache - liveWants map[cid.Cid]time.Time - - tick *time.Timer - baseTickDelay time.Duration - - latTotal time.Duration - fetchcnt int - + interest *lru.Cache + liveWants map[cid.Cid]time.Time + tick *time.Timer + baseTickDelay time.Duration + latTotal time.Duration + fetchcnt int + + // identifiers notif notifications.PubSub - - uuid logging.Loggable - - id uint64 - tag string + uuid logging.Loggable + id uint64 + tag string } -// NewSession creates a new bitswap session whose lifetime is bounded by the +// New creates a new bitswap session whose lifetime is bounded by the // given context -func (bs *Bitswap) NewSession(ctx context.Context) exchange.Fetcher { +func New(ctx context.Context, id uint64, wm SessionWantManager, network bsnet.BitSwapNetwork) *Session { s := &Session{ activePeers: make(map[peer.ID]struct{}), liveWants: make(map[cid.Cid]time.Time), @@ -60,13 +80,16 @@ func (bs *Bitswap) NewSession(ctx context.Context) exchange.Fetcher { cancelKeys: make(chan []cid.Cid), tofetch: newCidQueue(), interestReqs: make(chan interestReq), + latencyReqs: make(chan chan time.Duration), + tickDelayReqs: make(chan time.Duration), ctx: ctx, - bs: bs, + wm: wm, + network: network, incoming: make(chan blkRecv), notif: notifications.New(), uuid: loggables.Uuid("GetBlockRequest"), baseTickDelay: time.Millisecond * 500, - id: bs.sm.GetNextSessionID(), + id: id, } s.tag = fmt.Sprint("bs-ses-", s.id) @@ -74,39 +97,63 @@ func (bs *Bitswap) NewSession(ctx context.Context) exchange.Fetcher { cache, _ := lru.New(2048) s.interest = cache - bs.sm.AddSession(s) go s.run(ctx) return s } -func (bs *Bitswap) removeSession(s *Session) { - s.notif.Shutdown() - - live := make([]cid.Cid, 0, len(s.liveWants)) - for c := range s.liveWants { - live = append(live, c) +// ReceiveBlockFrom receives an incoming block from the given peer +func (s *Session) ReceiveBlockFrom(from peer.ID, blk blocks.Block) { + select { + case s.incoming <- blkRecv{from: from, blk: blk}: + case <-s.ctx.Done(): } - bs.CancelWants(live, s.id) +} - bs.sm.RemoveSession(s) +// InterestedIn returns true if this session is interested in the given Cid +func (s *Session) InterestedIn(c cid.Cid) bool { + return s.interest.Contains(c) || s.isLiveWant(c) } -type blkRecv struct { - from peer.ID - blk blocks.Block +// GetBlock fetches a single block +func (s *Session) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, error) { + return bsgetter.SyncGetBlock(parent, k, s.GetBlocks) +} + +// GetBlocks fetches a set of blocks within the context of this session and +// returns a channel that found blocks will be returned on. No order is +// guaranteed on the returned blocks. +func (s *Session) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks.Block, error) { + ctx = logging.ContextWithLoggable(ctx, s.uuid) + return bsgetter.AsyncGetBlocks(ctx, keys, s.notif, s.fetch, s.cancel) +} + +// ID returns the sessions identifier +func (s *Session) ID() uint64 { + return s.id } -func (s *Session) receiveBlockFrom(from peer.ID, blk blocks.Block) { +func (s *Session) GetAverageLatency() time.Duration { + resp := make(chan time.Duration) select { - case s.incoming <- blkRecv{from: from, blk: blk}: + case s.latencyReqs <- resp: + case <-s.ctx.Done(): + return -1 * time.Millisecond + } + + select { + case latency := <-resp: + return latency case <-s.ctx.Done(): + return -1 * time.Millisecond } } -type interestReq struct { - c cid.Cid - resp chan bool +func (s *Session) SetBaseTickDelay(baseTickDelay time.Duration) { + select { + case s.tickDelayReqs <- baseTickDelay: + case <-s.ctx.Done(): + } } // TODO: PERF: this is using a channel to guard a map access against race @@ -135,114 +182,147 @@ func (s *Session) isLiveWant(c cid.Cid) bool { } } -func (s *Session) interestedIn(c cid.Cid) bool { - return s.interest.Contains(c) || s.isLiveWant(c) -} - -const provSearchDelay = time.Second * 10 - -func (s *Session) addActivePeer(p peer.ID) { - if _, ok := s.activePeers[p]; !ok { - s.activePeers[p] = struct{}{} - s.activePeersArr = append(s.activePeersArr, p) - - cmgr := s.bs.network.ConnectionManager() - cmgr.TagPeer(p, s.tag, 10) +func (s *Session) fetch(ctx context.Context, keys []cid.Cid) { + select { + case s.newReqs <- keys: + case <-ctx.Done(): + case <-s.ctx.Done(): } } -func (s *Session) resetTick() { - if s.latTotal == 0 { - s.tick.Reset(provSearchDelay) - } else { - avLat := s.latTotal / time.Duration(s.fetchcnt) - s.tick.Reset(s.baseTickDelay + (3 * avLat)) +func (s *Session) cancel(keys []cid.Cid) { + select { + case s.cancelKeys <- keys: + case <-s.ctx.Done(): } } +const provSearchDelay = time.Second * 10 + +// Session run loop -- everything function below here should not be called +// of this loop func (s *Session) run(ctx context.Context) { s.tick = time.NewTimer(provSearchDelay) newpeers := make(chan peer.ID, 16) for { select { case blk := <-s.incoming: - s.tick.Stop() - - if blk.from != "" { - s.addActivePeer(blk.from) - } - - s.receiveBlock(ctx, blk.blk) - - s.resetTick() + s.handleIncomingBlock(ctx, blk) case keys := <-s.newReqs: - for _, k := range keys { - s.interest.Add(k, nil) - } - if len(s.liveWants) < activeWantsLimit { - toadd := activeWantsLimit - len(s.liveWants) - if toadd > len(keys) { - toadd = len(keys) - } - - now := keys[:toadd] - keys = keys[toadd:] - - s.wantBlocks(ctx, now) - } - for _, k := range keys { - s.tofetch.Push(k) - } + s.handleNewRequest(ctx, keys) case keys := <-s.cancelKeys: - s.cancel(keys) - + s.handleCancel(keys) case <-s.tick.C: - live := make([]cid.Cid, 0, len(s.liveWants)) - now := time.Now() - for c := range s.liveWants { - live = append(live, c) - s.liveWants[c] = now - } - - // Broadcast these keys to everyone we're connected to - s.bs.wm.WantBlocks(ctx, live, nil, s.id) - - if len(live) > 0 { - go func(k cid.Cid) { - // TODO: have a task queue setup for this to: - // - rate limit - // - manage timeouts - // - ensure two 'findprovs' calls for the same block don't run concurrently - // - share peers between sessions based on interest set - for p := range s.bs.network.FindProvidersAsync(ctx, k, 10) { - newpeers <- p - } - }(live[0]) - } - s.resetTick() + s.handleTick(ctx, newpeers) case p := <-newpeers: s.addActivePeer(p) case lwchk := <-s.interestReqs: lwchk.resp <- s.cidIsWanted(lwchk.c) + case resp := <-s.latencyReqs: + resp <- s.averageLatency() + case baseTickDelay := <-s.tickDelayReqs: + s.baseTickDelay = baseTickDelay case <-ctx.Done(): - s.tick.Stop() - s.bs.removeSession(s) - - cmgr := s.bs.network.ConnectionManager() - for _, p := range s.activePeersArr { - cmgr.UntagPeer(p, s.tag) - } + s.handleShutdown() return } } } +func (s *Session) handleIncomingBlock(ctx context.Context, blk blkRecv) { + s.tick.Stop() + + if blk.from != "" { + s.addActivePeer(blk.from) + } + + s.receiveBlock(ctx, blk.blk) + + s.resetTick() +} + +func (s *Session) handleNewRequest(ctx context.Context, keys []cid.Cid) { + for _, k := range keys { + s.interest.Add(k, nil) + } + if len(s.liveWants) < activeWantsLimit { + toadd := activeWantsLimit - len(s.liveWants) + if toadd > len(keys) { + toadd = len(keys) + } + + now := keys[:toadd] + keys = keys[toadd:] + + s.wantBlocks(ctx, now) + } + for _, k := range keys { + s.tofetch.Push(k) + } +} + +func (s *Session) handleCancel(keys []cid.Cid) { + for _, c := range keys { + s.tofetch.Remove(c) + } +} + +func (s *Session) handleTick(ctx context.Context, newpeers chan<- peer.ID) { + live := make([]cid.Cid, 0, len(s.liveWants)) + now := time.Now() + for c := range s.liveWants { + live = append(live, c) + s.liveWants[c] = now + } + + // Broadcast these keys to everyone we're connected to + s.wm.WantBlocks(ctx, live, nil, s.id) + + if len(live) > 0 { + go func(k cid.Cid) { + // TODO: have a task queue setup for this to: + // - rate limit + // - manage timeouts + // - ensure two 'findprovs' calls for the same block don't run concurrently + // - share peers between sessions based on interest set + for p := range s.network.FindProvidersAsync(ctx, k, 10) { + newpeers <- p + } + }(live[0]) + } + s.resetTick() +} + +func (s *Session) addActivePeer(p peer.ID) { + if _, ok := s.activePeers[p]; !ok { + s.activePeers[p] = struct{}{} + s.activePeersArr = append(s.activePeersArr, p) + + cmgr := s.network.ConnectionManager() + cmgr.TagPeer(p, s.tag, 10) + } +} + +func (s *Session) handleShutdown() { + s.tick.Stop() + s.notif.Shutdown() + + live := make([]cid.Cid, 0, len(s.liveWants)) + for c := range s.liveWants { + live = append(live, c) + } + s.wm.CancelWants(s.ctx, live, nil, s.id) + cmgr := s.network.ConnectionManager() + for _, p := range s.activePeersArr { + cmgr.UntagPeer(p, s.tag) + } +} + func (s *Session) cidIsWanted(c cid.Cid) bool { _, ok := s.liveWants[c] if !ok { ok = s.tofetch.Has(c) } - return ok } @@ -270,43 +350,21 @@ func (s *Session) wantBlocks(ctx context.Context, ks []cid.Cid) { for _, c := range ks { s.liveWants[c] = now } - s.bs.wm.WantBlocks(ctx, ks, s.activePeersArr, s.id) + s.wm.WantBlocks(ctx, ks, s.activePeersArr, s.id) } -func (s *Session) cancel(keys []cid.Cid) { - for _, c := range keys { - s.tofetch.Remove(c) - } -} - -func (s *Session) cancelWants(keys []cid.Cid) { - select { - case s.cancelKeys <- keys: - case <-s.ctx.Done(): - } +func (s *Session) averageLatency() time.Duration { + return s.latTotal / time.Duration(s.fetchcnt) } - -func (s *Session) fetch(ctx context.Context, keys []cid.Cid) { - select { - case s.newReqs <- keys: - case <-ctx.Done(): - case <-s.ctx.Done(): +func (s *Session) resetTick() { + if s.latTotal == 0 { + s.tick.Reset(provSearchDelay) + } else { + avLat := s.averageLatency() + s.tick.Reset(s.baseTickDelay + (3 * avLat)) } } -// GetBlocks fetches a set of blocks within the context of this session and -// returns a channel that found blocks will be returned on. No order is -// guaranteed on the returned blocks. -func (s *Session) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks.Block, error) { - ctx = logging.ContextWithLoggable(ctx, s.uuid) - return getBlocksImpl(ctx, keys, s.notif, s.fetch, s.cancelWants) -} - -// GetBlock fetches a single block -func (s *Session) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, error) { - return getBlock(parent, k, s.GetBlocks) -} - type cidQueue struct { elems []cid.Cid eset *cid.Set diff --git a/sessionmanager/sessionmanager.go b/sessionmanager/sessionmanager.go index 1ebee2fd..aed86af5 100644 --- a/sessionmanager/sessionmanager.go +++ b/sessionmanager/sessionmanager.go @@ -1,32 +1,71 @@ package sessionmanager import ( + "context" "sync" + blocks "github.com/ipfs/go-block-format" + cid "github.com/ipfs/go-cid" + + bsnet "github.com/ipfs/go-bitswap/network" + bssession "github.com/ipfs/go-bitswap/session" + bswm "github.com/ipfs/go-bitswap/wantmanager" exchange "github.com/ipfs/go-ipfs-exchange-interface" + peer "github.com/libp2p/go-libp2p-peer" ) +// SessionManager is responsible for creating, managing, and dispatching to +// sessions type SessionManager struct { + wm *bswm.WantManager + network bsnet.BitSwapNetwork + ctx context.Context // Sessions sessLk sync.Mutex - sessions []exchange.Fetcher + sessions []*bssession.Session // Session Index sessIDLk sync.Mutex sessID uint64 } -func New() *SessionManager { - return &SessionManager{} +// New creates a new SessionManager +func New(ctx context.Context, wm *bswm.WantManager, network bsnet.BitSwapNetwork) *SessionManager { + return &SessionManager{ + ctx: ctx, + wm: wm, + network: network, + } } -func (sm *SessionManager) AddSession(session exchange.Fetcher) { +// NewSession initializes a session with the given context, and adds to the +// session manager +func (sm *SessionManager) NewSession(ctx context.Context) exchange.Fetcher { + id := sm.GetNextSessionID() + sessionctx, cancel := context.WithCancel(ctx) + + session := bssession.New(sessionctx, id, sm.wm, sm.network) sm.sessLk.Lock() sm.sessions = append(sm.sessions, session) sm.sessLk.Unlock() + go func() { + for { + defer cancel() + select { + case <-sm.ctx.Done(): + sm.removeSession(session) + return + case <-ctx.Done(): + sm.removeSession(session) + return + } + } + }() + + return session } -func (sm *SessionManager) RemoveSession(session exchange.Fetcher) { +func (sm *SessionManager) removeSession(session exchange.Fetcher) { sm.sessLk.Lock() defer sm.sessLk.Unlock() for i := 0; i < len(sm.sessions); i++ { @@ -38,6 +77,7 @@ func (sm *SessionManager) RemoveSession(session exchange.Fetcher) { } } +// GetNextSessionID returns the next sequentional identifier for a session func (sm *SessionManager) GetNextSessionID() uint64 { sm.sessIDLk.Lock() defer sm.sessIDLk.Unlock() @@ -45,15 +85,18 @@ func (sm *SessionManager) GetNextSessionID() uint64 { return sm.sessID } -type IterateSessionFunc func(session exchange.Fetcher) - -// IterateSessions loops through all managed sessions and applies the given -// IterateSessionFunc -func (sm *SessionManager) IterateSessions(iterate IterateSessionFunc) { +// ReceiveBlockFrom receives a block from a peer and dispatches to interested +// sessions +func (sm *SessionManager) ReceiveBlockFrom(from peer.ID, blk blocks.Block) { sm.sessLk.Lock() defer sm.sessLk.Unlock() + k := blk.Cid() + ks := []cid.Cid{k} for _, s := range sm.sessions { - iterate(s) + if s.InterestedIn(k) { + s.ReceiveBlockFrom(from, blk) + sm.wm.CancelWants(sm.ctx, ks, nil, s.ID()) + } } }