diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e307030f..9adc66b32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ The following emojis are used to highlight certain changes: - 🛠 `blockstore` and `blockservice`'s `WriteThrough()` option now takes an "enabled" parameter: `WriteThrough(enabled bool)`. - Replaced unmaintained mock time implementation uses in tests: [from](github.com/benbjohnson/clock) => [to](github.com/filecoin-project/go-clock) - updated to go-libp2p to [v0.38.0](https://github.com/libp2p/go-libp2p/releases/tag/v0.38.0) +- `bitswap/client`: if a libp2p connection has a context, use `context.AfterFunc` to cleanup the connection. ### Removed diff --git a/bitswap/network/ipfs_impl.go b/bitswap/network/ipfs_impl.go index 72f86d099..4a60aaf6b 100644 --- a/bitswap/network/ipfs_impl.go +++ b/bitswap/network/ipfs_impl.go @@ -82,18 +82,45 @@ type impl struct { receivers []Receiver } +// interfaceWrapper is concrete type that wraps an interface. Necessary because +// atomic.Value needs the same type and can not Store(nil). This indirection +// allows us to store nil. +type interfaceWrapper[T any] struct { + t T +} +type atomicInterface[T any] struct { + iface atomic.Value +} + +func (a *atomicInterface[T]) Load() T { + var v T + x := a.iface.Load() + if x != nil { + return x.(interfaceWrapper[T]).t + } + return v +} + +func (a *atomicInterface[T]) Store(v T) { + a.iface.Store(interfaceWrapper[T]{v}) +} + type streamMessageSender struct { - to peer.ID - stream network.Stream - connected bool - bsnet *impl - opts *MessageSenderOpts + to peer.ID + stream atomicInterface[network.Stream] + bsnet *impl + opts *MessageSenderOpts +} + +type HasContext interface { + Context() context.Context } // Open a stream to the remote peer func (s *streamMessageSender) Connect(ctx context.Context) (network.Stream, error) { - if s.connected { - return s.stream, nil + stream := s.stream.Load() + if stream != nil { + return stream, nil } tctx, cancel := context.WithTimeout(ctx, s.opts.SendTimeout) @@ -107,17 +134,22 @@ func (s *streamMessageSender) Connect(ctx context.Context) (network.Stream, erro if err != nil { return nil, err } + if withCtx, ok := stream.Conn().(HasContext); ok { + context.AfterFunc(withCtx.Context(), func() { + s.stream.Store(nil) + }) + } - s.stream = stream - s.connected = true - return s.stream, nil + s.stream.Store(stream) + return stream, nil } // Reset the stream func (s *streamMessageSender) Reset() error { - if s.stream != nil { - err := s.stream.Reset() - s.connected = false + stream := s.stream.Load() + if stream != nil { + err := stream.Reset() + s.stream.Store(nil) return err } return nil @@ -125,12 +157,22 @@ func (s *streamMessageSender) Reset() error { // Close the stream func (s *streamMessageSender) Close() error { - return s.stream.Close() + stream := s.stream.Load() + if stream != nil { + err := stream.Close() + s.stream.Store(nil) + return err + } + return nil } // Indicates whether the peer supports HAVE / DONT_HAVE messages func (s *streamMessageSender) SupportsHave() bool { - return s.bsnet.SupportsHave(s.stream.Protocol()) + stream := s.stream.Load() + if stream == nil { + return false + } + return s.bsnet.SupportsHave(stream.Protocol()) } // Send a message to the peer, attempting multiple times