Thread-safe subscriptions and listener cleanup (#166)
This PR adds:
- Thread safety for mutations of the listener maps
- A safe cleanup flow for listener channels

As a general principle, we use Go's `sync.Map`. This is an optimistic
concurrency primitive that confines contention to the per-key level and
separates reads from writes via an internal read-only map and dirty map.
The dispatch loop is therefore unaffected: any mutations to the sync maps
are performed outside of the dispatching goroutine.
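
For illustration, here is a minimal, self-contained sketch of that per-key
pattern (the `listenersByTopic` variable and the key names are hypothetical,
not this PR's actual fields):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	// Outer sync.Map: topic key -> inner sync.Map acting as a set of listeners.
	// Readers never take a lock; writers contend only on the key they touch.
	var listenersByTopic sync.Map

	// Subscriber side: LoadOrStore atomically creates the set on first use.
	set, _ := listenersByTopic.LoadOrStore("topic-a", &sync.Map{})
	set.(*sync.Map).Store("listener-1", struct{}{})

	// Dispatch side: Load and Range proceed without blocking writers.
	if v, ok := listenersByTopic.Load("topic-a"); ok {
		v.(*sync.Map).Range(func(k, _ any) bool {
			fmt.Println("dispatching to", k)
			return true // keep iterating
		})
	}
}
```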

I've also added an `RWMutex` to synchronize adding and removing listeners.
The main thing we are protecting against: when removing a listener, we may
want to delete its `listenerSet` once it is empty, but we can't perform the
emptiness check and the deletion step atomically without a mutex. This
should be okay, because the current server only sees something like 10 new
subscriptions a second.
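
To make that hazard concrete, here is a condensed sketch of the locking
discipline, adapted from the `listenersMap` type this PR adds (method names
shortened to single-key versions; not a drop-in copy of the real code):

```go
// Adds share the read lock: sync.Map already makes the insert itself safe,
// so adds only need to be excluded while a removal is in progress.
func (lm *listenersMap[K]) add(key K, l *listener) {
	lm.mu.RLock()
	defer lm.mu.RUnlock()
	v, _ := lm.data.LoadOrStore(key, &listenerSet{})
	v.(*listenerSet).addListener(l)
}

// Removal takes the exclusive lock so no add can slip in between the
// emptiness check and the delete.
func (lm *listenersMap[K]) remove(key K, l *listener) {
	lm.mu.Lock()
	defer lm.mu.Unlock()
	if v, ok := lm.data.Load(key); ok {
		set := v.(*listenerSet)
		set.removeListener(l)
		if set.isEmpty() {
			// Without the write lock, a concurrent add could repopulate
			// the set here and then be lost when the key is deleted.
			lm.data.Delete(key)
		}
	}
}
```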

Would love to do some benchmarking on this later to make sure we've made
the right tradeoffs; I can see us swapping out the underlying implementation
depending on where the bottleneck turns out to be.
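
As a possible starting point for that benchmarking, a sketch along these
lines could exercise subscribe/unsubscribe churn on a single hot key
(`BenchmarkListenerChurn` is invented for this sketch, not part of the PR):

```go
// Hypothetical benchmark in package api, next to subscribeWorker.go.
// It hammers addListener/removeListener on one key to surface contention
// on the RWMutex and the underlying sync.Map.
func BenchmarkListenerChurn(b *testing.B) {
	var lm listenersMap[string]
	keys := map[string]struct{}{"hot-topic": {}}
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			l := &listener{
				ch: make(chan []*message_api.OriginatorEnvelope, 1),
			}
			lm.addListener(keys, l)
			lm.removeListener(keys, l)
		}
	})
}
```

Running it with `go test -race -bench=ListenerChurn ./pkg/api` would double
as a race check on the new locking.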

Closes #125
richardhuaaa authored Sep 20, 2024
1 parent 4caf7ad commit 2427b56
Showing 3 changed files with 158 additions and 58 deletions.
4 changes: 2 additions & 2 deletions pkg/api/service.go
@@ -79,7 +79,7 @@ func (s *Service) BatchSubscribeEnvelopes(
 		return status.Errorf(codes.InvalidArgument, "missing requests")
 	}
 
-	ch, err := s.subscribeWorker.listen(requests)
+	ch, err := s.subscribeWorker.listen(stream.Context(), requests)
 	if err != nil {
 		return status.Errorf(codes.InvalidArgument, "invalid subscription request: %v", err)
 	}
@@ -96,7 +96,7 @@ func (s *Service) BatchSubscribeEnvelopes(
 			}
 		} else {
 			// TODO(rich) Recover from backpressure
-			log.Info("stream closed due to backpressure")
+			log.Debug("channel closed by worker")
 			return nil
 		}
 	case <-stream.Context().Done():
207 changes: 156 additions & 51 deletions pkg/api/subscribeWorker.go
@@ -5,14 +5,13 @@ import (
 	"database/sql"
 	"encoding/hex"
 	"fmt"
+	"sync"
 	"time"
 
 	"github.com/xmtp/xmtpd/pkg/db"
 	"github.com/xmtp/xmtpd/pkg/db/queries"
 	"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
 	"go.uber.org/zap"
-	"google.golang.org/grpc/codes"
-	"google.golang.org/grpc/status"
 	"google.golang.org/protobuf/proto"
 )
@@ -24,7 +23,75 @@ const (
 	maxTopicLength = 128
 )
 
-type listener = chan<- []*message_api.OriginatorEnvelope
+type listener struct {
+	ctx         context.Context
+	ch          chan<- []*message_api.OriginatorEnvelope
+	closed      bool
+	topics      map[string]struct{}
+	originators map[uint32]struct{}
+	isGlobal    bool
+}
+
+type listenerSet struct {
+	sync.Map // map[*listener]struct{}
+}
+
+func (ls *listenerSet) addListener(l *listener) {
+	ls.Store(l, struct{}{})
+}
+
+func (ls *listenerSet) removeListener(l *listener) {
+	ls.Delete(l)
+}
+
+func (ls *listenerSet) isEmpty() bool {
+	empty := true
+	ls.Range(func(_, _ interface{}) bool {
+		empty = false
+		return false // stop iteration
+	})
+	return empty
+}
+
+// Maps from a key to a set of listeners
+type listenersMap[K comparable] struct {
+	data sync.Map     // map[K]*listenerSet
+	mu   sync.RWMutex // ensures mutations are consistent
+}
+
+func (lm *listenersMap[K]) addListener(keys map[K]struct{}, l *listener) {
+	lm.mu.RLock()
+	defer lm.mu.RUnlock()
+	for key := range keys {
+		value, _ := lm.data.LoadOrStore(key, &listenerSet{})
+		set := value.(*listenerSet)
+		set.addListener(l)
+	}
+}
+
+func (lm *listenersMap[K]) removeListener(keys map[K]struct{}, l *listener) {
+	lm.mu.Lock()
+	defer lm.mu.Unlock()
+	for key := range keys {
+		value, ok := lm.data.Load(key)
+		if !ok || value == nil {
+			return // Key doesn't exist, nothing to do
+		}
+		set := value.(*listenerSet)
+		set.removeListener(l)
+		if set.isEmpty() {
+			lm.data.Delete(key)
+		}
+	}
+}
+
+func (lm *listenersMap[K]) getListeners(key K) *listenerSet {
+	// No lock needed, because we are not mutating lm.data
+	if value, ok := lm.data.Load(key); ok {
+		return value.(*listenerSet)
+	}
+	return nil
+}
 
 // A worker that listens for new envelopes in the DB and sends them to subscribers
 // Assumes that there are many listeners - non-blocking updates are sent on buffered channels
@@ -35,9 +102,9 @@ type subscribeWorker struct {
 
 	dbSubscription <-chan []queries.GatewayEnvelope
 	// Assumption: listeners cannot be in multiple slices
-	globalListeners     []listener
-	originatorListeners map[uint32][]listener
-	topicListeners      map[string][]listener
+	globalListeners     listenerSet
+	originatorListeners listenersMap[uint32]
+	topicListeners      listenersMap[string]
 }
 
 func startSubscribeWorker(
@@ -86,9 +153,9 @@ func startSubscribeWorker(
 		ctx:            ctx,
 		log:            log,
 		dbSubscription: dbChan,
-		globalListeners:     make([]listener, 0),
-		originatorListeners: make(map[uint32][]listener),
-		topicListeners:      make(map[string][]listener),
+		globalListeners:     listenerSet{},
+		originatorListeners: listenersMap[uint32]{},
+		topicListeners:      listenersMap[string]{},
 	}
 
 	go worker.start()
Expand All @@ -109,42 +176,78 @@ func (s *subscribeWorker) start() {
}
}

func (s *subscribeWorker) dispatch(
row *queries.GatewayEnvelope,
) {
bytes := row.OriginatorEnvelope
func (s *subscribeWorker) dispatch(row *queries.GatewayEnvelope) {
env := &message_api.OriginatorEnvelope{}
err := proto.Unmarshal(bytes, env)
err := proto.Unmarshal(row.OriginatorEnvelope, env)
if err != nil {
s.log.Error("Failed to unmarshal envelope", zap.Error(err))
return
}
for _, listener := range s.originatorListeners[uint32(row.OriginatorNodeID)] {
select {
case listener <- []*message_api.OriginatorEnvelope{env}:
default: // TODO(rich) Close and clean up channel
}

originatorListeners := s.originatorListeners.getListeners(uint32(row.OriginatorNodeID))
topicListeners := s.topicListeners.getListeners(hex.EncodeToString(row.Topic))
s.dispatchToListeners(originatorListeners, env)
s.dispatchToListeners(topicListeners, env)
s.dispatchToListeners(&s.globalListeners, env)
}

func (s *subscribeWorker) dispatchToListeners(
listeners *listenerSet,
env *message_api.OriginatorEnvelope,
) {
if listeners == nil {
return
}
for _, listener := range s.topicListeners[hex.EncodeToString(row.Topic)] {
select {
case listener <- []*message_api.OriginatorEnvelope{env}:
default:
listeners.Range(func(key, _ any) bool {
l := key.(*listener)
if l.closed {
return true
}
}
for _, listener := range s.globalListeners {
// Assumption: listener channel is never closed by a different goroutine
select {
case listener <- []*message_api.OriginatorEnvelope{env}:
case <-l.ctx.Done():
s.log.Debug("Stream closed, removing listener", zap.Any("listener", l.ch))
s.closeListener(l)
default:
select {
case l.ch <- []*message_api.OriginatorEnvelope{env}:
default:
s.log.Info("Channel full, removing listener", zap.Any("listener", l.ch))
s.closeListener(l)
}
}
}
return true
})
}

func (s *subscribeWorker) closeListener(l *listener) {
// Assumption: this method may not be called across multiple goroutines
l.closed = true
close(l.ch)

go func() {
if l.isGlobal {
s.globalListeners.Delete(l)
} else if len(l.topics) > 0 {
s.topicListeners.removeListener(l.topics, l)
} else if len(l.originators) > 0 {
s.originatorListeners.removeListener(l.originators, l)
}
}()
}

func (s *subscribeWorker) listen(
ctx context.Context,
requests []*message_api.BatchSubscribeEnvelopesRequest_SubscribeEnvelopesRequest,
) (<-chan []*message_api.OriginatorEnvelope, error) {
subscribeAll := false
topics := make(map[string]bool, len(requests))
originators := make(map[uint32]bool, len(requests))
ch := make(chan []*message_api.OriginatorEnvelope, subscriptionBufferSize)
l := &listener{
ctx: ctx,
ch: ch,
topics: make(map[string]struct{}),
originators: make(map[uint32]struct{}),
isGlobal: false,
}

if len(requests) > maxSubscriptionsPerClient {
return nil, fmt.Errorf(
@@ -155,40 +258,42 @@ func (s *subscribeWorker) listen(
 	for _, req := range requests {
 		enum := req.GetQuery().GetFilter()
 		if enum == nil {
-			subscribeAll = true
+			l.isGlobal = true
 		}
 		switch filter := enum.(type) {
 		case *message_api.EnvelopesQuery_Topic:
+			topic := hex.EncodeToString(filter.Topic)
 			if len(filter.Topic) == 0 || len(filter.Topic) > maxTopicLength {
-				return nil, status.Errorf(codes.InvalidArgument, "invalid topic")
+				return nil, fmt.Errorf("invalid topic: %s", topic)
 			}
-			topics[hex.EncodeToString(filter.Topic)] = true
+			if _, exists := l.topics[topic]; exists {
+				return nil, fmt.Errorf("multiple requests for same topic: %s", topic)
+			}
+			l.topics[topic] = struct{}{}
 		case *message_api.EnvelopesQuery_OriginatorNodeId:
-			originators[filter.OriginatorNodeId] = true
+			if _, exists := l.originators[filter.OriginatorNodeId]; exists {
+				return nil, fmt.Errorf("multiple requests for same originator: %d", filter.OriginatorNodeId)
+			}
+			l.originators[filter.OriginatorNodeId] = struct{}{}
 		default:
-			subscribeAll = true
+			l.isGlobal = true
 		}
 	}
 
-	ch := make(chan []*message_api.OriginatorEnvelope, subscriptionBufferSize)
-
-	if subscribeAll {
-		if len(topics) > 0 || len(originators) > 0 {
-			return nil, fmt.Errorf("cannot filter by topic or originator when subscribing to all")
+	if l.isGlobal {
+		if len(l.topics) > 0 || len(l.originators) > 0 {
+			return nil, fmt.Errorf(
+				"cannot filter by topic or originator when subscribing to all",
+			)
 		}
-		// TODO(rich) thread safety
-		s.globalListeners = append(s.globalListeners, ch)
-	} else if len(topics) > 0 {
-		if len(originators) > 0 {
+		s.globalListeners.Store(l, struct{}{})
+	} else if len(l.topics) > 0 {
+		if len(l.originators) > 0 {
 			return nil, fmt.Errorf("cannot filter by both topic and originator in same subscription request")
 		}
-		for topic := range topics {
-			s.topicListeners[topic] = append(s.topicListeners[topic], ch)
-		}
-	} else if len(originators) > 0 {
-		for originator := range originators {
-			s.originatorListeners[originator] = append(s.originatorListeners[originator], ch)
-		}
+		s.topicListeners.addListener(l.topics, l)
+	} else if len(l.originators) > 0 {
+		s.originatorListeners.addListener(l.originators, l)
 	}
 
 	return ch, nil
5 changes: 0 additions & 5 deletions pkg/api/subscribe_test.go
@@ -104,7 +104,6 @@ func validateUpdates(
 }
 
 func TestSubscribeEnvelopesAll(t *testing.T) {
-	t.Skip("TODO(rich) thread safety for race tests")
 	client, db, cleanup := setupTest(t)
 	defer cleanup()
 	insertInitialRows(t, db)
@@ -131,7 +130,6 @@ func TestSubscribeEnvelopesAll(t *testing.T) {
 }
 
 func TestSubscribeEnvelopesByTopic(t *testing.T) {
-	t.Skip("TODO(rich) thread safety for race tests")
 	client, db, cleanup := setupTest(t)
 	defer cleanup()
 	insertInitialRows(t, db)
@@ -164,7 +162,6 @@ func TestSubscribeEnvelopesByTopic(t *testing.T) {
 }
 
 func TestSubscribeEnvelopesByOriginator(t *testing.T) {
-	t.Skip("TODO(rich) thread safety for race tests")
 	client, db, cleanup := setupTest(t)
 	defer cleanup()
 	insertInitialRows(t, db)
@@ -197,7 +194,6 @@ func TestSubscribeEnvelopesByOriginator(t *testing.T) {
 }
 
 func TestSimultaneousSubscriptions(t *testing.T) {
-	t.Skip("TODO(rich) thread safety for race tests")
 	client, db, cleanup := setupTest(t)
 	defer cleanup()
 	insertInitialRows(t, db)
@@ -256,7 +252,6 @@ func TestSimultaneousSubscriptions(t *testing.T) {
 }
 
 func TestSubscribeEnvelopesInvalidRequest(t *testing.T) {
-	t.Skip("TODO(rich) thread safety for race tests")
 	client, _, cleanup := setupTest(t)
 	defer cleanup()
 
