Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-thread-safe subscriptions #160

Merged
merged 11 commits into from
Sep 20, 2024
16 changes: 8 additions & 8 deletions pkg/api/publishWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"google.golang.org/protobuf/proto"
)

type PublishWorker struct {
type publishWorker struct {
ctx context.Context
log *zap.Logger
listener <-chan []queries.StagedOriginatorEnvelope
Expand All @@ -22,13 +22,13 @@ type PublishWorker struct {
subscription db.DBSubscription[queries.StagedOriginatorEnvelope, int64]
}

func StartPublishWorker(
func startPublishWorker(
ctx context.Context,
log *zap.Logger,
reg *registrant.Registrant,
store *sql.DB,
) (*PublishWorker, error) {
log = log.Named("publishWorker")
) (*publishWorker, error) {
log = log.With(zap.String("method", "publishWorker"))
q := queries.New(store)
query := func(ctx context.Context, lastSeenID int64, numRows int32) ([]queries.StagedOriginatorEnvelope, int64, error) {
results, err := q.SelectStagedOriginatorEnvelopes(
Expand Down Expand Up @@ -59,7 +59,7 @@ func StartPublishWorker(
return nil, err
}

worker := &PublishWorker{
worker := &publishWorker{
ctx: ctx,
log: log,
notifier: notifier,
Expand All @@ -73,14 +73,14 @@ func StartPublishWorker(
return worker, nil
}

func (p *PublishWorker) NotifyStagedPublish() {
func (p *publishWorker) notifyStagedPublish() {
select {
case p.notifier <- true:
default:
}
}

func (p *PublishWorker) start() {
func (p *publishWorker) start() {
for {
select {
case <-p.ctx.Done():
Expand All @@ -97,7 +97,7 @@ func (p *PublishWorker) start() {
}
}

func (p *PublishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginatorEnvelope) bool {
func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginatorEnvelope) bool {
logger := p.log.With(zap.Int64("sequenceID", stagedEnv.ID))
originatorEnv, err := p.registrant.SignStagedEnvelope(stagedEnv)
if err != nil {
Expand Down
85 changes: 69 additions & 16 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/registrant"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"

Expand All @@ -23,11 +24,12 @@ const (
type Service struct {
message_api.UnimplementedReplicationApiServer

ctx context.Context
log *zap.Logger
registrant *registrant.Registrant
store *sql.DB
worker *PublishWorker
ctx context.Context
log *zap.Logger
registrant *registrant.Registrant
store *sql.DB
publishWorker *publishWorker
subscribeWorker *subscribeWorker
}

func NewReplicationApiService(
Expand All @@ -36,16 +38,22 @@ func NewReplicationApiService(
registrant *registrant.Registrant,
store *sql.DB,
) (*Service, error) {
worker, err := StartPublishWorker(ctx, log, registrant, store)
publishWorker, err := startPublishWorker(ctx, log, registrant, store)
if err != nil {
return nil, err
}
subscribeWorker, err := startSubscribeWorker(ctx, log, store)
if err != nil {
return nil, err
}

return &Service{
ctx: ctx,
log: log,
registrant: registrant,
store: store,
worker: worker,
ctx: ctx,
log: log,
registrant: registrant,
store: store,
publishWorker: publishWorker,
subscribeWorker: subscribeWorker,
}, nil
}

Expand All @@ -55,15 +63,57 @@ func (s *Service) Close() {

func (s *Service) BatchSubscribeEnvelopes(
req *message_api.BatchSubscribeEnvelopesRequest,
server message_api.ReplicationApi_BatchSubscribeEnvelopesServer,
stream message_api.ReplicationApi_BatchSubscribeEnvelopesServer,
) error {
return status.Errorf(codes.Unimplemented, "method BatchSubscribeEnvelopes not implemented")
log := s.log.With(zap.String("method", "batchSubscribe"))

// Send a header (any header) to fix an issue with Tonic based GRPC clients.
// See: https://github.com/xmtp/libxmtp/pull/58
err := stream.SendHeader(metadata.Pairs("subscribed", "true"))
if err != nil {
return status.Errorf(codes.Internal, "could not send header: %v", err)
}

requests := req.GetRequests()
if len(requests) == 0 {
return status.Errorf(codes.InvalidArgument, "missing requests")
}

ch, err := s.subscribeWorker.listen(requests)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid subscription request: %v", err)
}

for {
select {
case envs, open := <-ch:
if open {
err := stream.Send(&message_api.BatchSubscribeEnvelopesResponse{
Envelopes: envs,
})
if err != nil {
return status.Errorf(codes.Internal, "error sending envelope: %v", err)
}
} else {
// TODO(rich) Recover from backpressure
log.Info("stream closed due to backpressure")
richardhuaaa marked this conversation as resolved.
Show resolved Hide resolved
return nil
}
case <-stream.Context().Done():
log.Debug("stream closed")
return nil
case <-s.ctx.Done():
log.Info("service closed")
return nil
}
}
}

func (s *Service) QueryEnvelopes(
ctx context.Context,
req *message_api.QueryEnvelopesRequest,
) (*message_api.QueryEnvelopesResponse, error) {
log := s.log.With(zap.String("method", "query"))
params, err := s.queryReqToDBParams(req)
if err != nil {
return nil, err
Expand All @@ -80,7 +130,7 @@ func (s *Service) QueryEnvelopes(
err := proto.Unmarshal(row.OriginatorEnvelope, originatorEnv)
if err != nil {
// We expect to have already validated the envelope when it was inserted
s.log.Error("could not unmarshal originator envelope", zap.Error(err))
log.Error("could not unmarshal originator envelope", zap.Error(err))
continue
}
envs = append(envs, originatorEnv)
Expand Down Expand Up @@ -109,6 +159,9 @@ func (s *Service) queryReqToDBParams(

switch filter := query.GetFilter().(type) {
case *message_api.EnvelopesQuery_Topic:
if len(filter.Topic) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "missing topic")
}
params.Topic = filter.Topic
case *message_api.EnvelopesQuery_OriginatorNodeId:
params.OriginatorNodeID = db.NullInt32(int32(filter.OriginatorNodeId))
Expand Down Expand Up @@ -162,7 +215,7 @@ func (s *Service) PublishEnvelope(
if err != nil {
return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err)
}
s.worker.NotifyStagedPublish()
s.publishWorker.notifyStagedPublish()

originatorEnv, err := s.registrant.SignStagedEnvelope(stagedEnv)
if err != nil {
Expand Down Expand Up @@ -196,7 +249,7 @@ func (s *Service) validatePayerInfo(
}

func (s *Service) validateClientInfo(clientEnv *message_api.ClientEnvelope) ([]byte, error) {
if clientEnv.GetAad().GetTargetOriginator() != uint32(s.registrant.NodeID()) {
if clientEnv.GetAad().GetTargetOriginator() != s.registrant.NodeID() {
return nil, status.Errorf(codes.InvalidArgument, "invalid target originator")
}

Expand Down
Loading
Loading