Skip to content

Commit

Permalink
Add publish worker
Browse files Browse the repository at this point in the history
  • Loading branch information
richardhuaaa committed Aug 26, 2024
1 parent 9771fb2 commit a8cef43
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 41 deletions.
147 changes: 147 additions & 0 deletions pkg/api/publishWorker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package api

import (
"context"
"database/sql"
"time"

"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/registrant"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
)

type PublishWorker struct {
ctx context.Context
log *zap.Logger
listener <-chan []queries.StagedOriginatorEnvelope
notifier chan<- bool
registrant *registrant.Registrant
store *sql.DB
subscription db.DBSubscription[queries.StagedOriginatorEnvelope]
}

func StartPublishWorker(
ctx context.Context,
log *zap.Logger,
reg *registrant.Registrant,
store *sql.DB,
) (*PublishWorker, error) {
q := queries.New(store)
query := func(ctx context.Context, lastSeenID int64, numRows int32) ([]queries.StagedOriginatorEnvelope, int64, error) {
results, err := q.SelectStagedOriginatorEnvelopes(
ctx,
queries.SelectStagedOriginatorEnvelopesParams{
LastSeenID: lastSeenID,
NumRows: numRows,
},
)
if err != nil {
return nil, 0, err
}
if len(results) > 0 {
lastSeenID = results[len(results)-1].ID
}
return results, lastSeenID, nil
}
notifier := make(chan bool, 1)
subscription := db.NewDBSubscription(
ctx,
log,
query,
0, // lastSeenID
db.PollingOptions{Interval: time.Second, Notifier: notifier, NumRows: 100},
)
listener, err := subscription.Start()
if err != nil {
return nil, err
}

worker := &PublishWorker{
ctx: ctx,
log: log,
notifier: notifier,
subscription: *subscription,
listener: listener,
registrant: reg,
store: store,
}
go worker.start()

return worker, nil
}

func (p *PublishWorker) NotifyStagedPublish() {
select {
case p.notifier <- true:
default:
}
}

func (p *PublishWorker) start() {
for {
select {
case <-p.ctx.Done():
return
case new_batch := <-p.listener:
for _, stagedEnv := range new_batch {
for !p.publishStagedEnvelope(stagedEnv) {
// Infinite retry on failure to publish; we cannot
// continue to the next envelope until this one is processed
time.Sleep(time.Second)
}
}
}
}
}

func (p *PublishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginatorEnvelope) bool {
logger := p.log.With(zap.Int64("sequenceID", stagedEnv.ID))
originatorEnv, err := p.registrant.SignStagedEnvelope(stagedEnv)
if err != nil {
logger.Error(
"Failed to sign staged envelope",
zap.Error(err),
)
return false
}
originatorBytes, err := proto.Marshal(originatorEnv)
if err != nil {
logger.Error("Failed to marshal originator envelope", zap.Error(err))
return false
}

q := queries.New(p.store)

// On unique constraint conflicts, no error is thrown, but numRows is 0
inserted, err := q.InsertGatewayEnvelope(
p.ctx,
queries.InsertGatewayEnvelopeParams{
OriginatorID: int32(p.registrant.NodeID()),
OriginatorSequenceID: stagedEnv.ID,
Topic: stagedEnv.Topic,
OriginatorEnvelope: originatorBytes,
},
)
if err != nil {
logger.Error("Failed to insert gateway envelope", zap.Error(err))
return false
} else if inserted == 0 {
// Envelope was already inserted by another worker
logger.Debug("Envelope already inserted")
}

// Try to delete the row regardless of if the gateway envelope was inserted elsewhere
deleted, err := q.DeleteStagedOriginatorEnvelope(context.Background(), stagedEnv.ID)
if err != nil {
logger.Error("Failed to delete staged envelope", zap.Error(err))
// Envelope is already inserted, so it is safe to continue
return true
} else if deleted == 0 {
// Envelope was already deleted by another worker
logger.Debug("Envelope already deleted")
}

return true
}
84 changes: 70 additions & 14 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,27 @@ type Service struct {
ctx context.Context
log *zap.Logger
registrant *registrant.Registrant
queries *queries.Queries
store *sql.DB
worker *PublishWorker
}

func NewReplicationApiService(
ctx context.Context,
log *zap.Logger,
registrant *registrant.Registrant,
writerDB *sql.DB,
store *sql.DB,
) (*Service, error) {
return &Service{ctx: ctx, log: log, registrant: registrant, queries: queries.New(writerDB)}, nil
worker, err := StartPublishWorker(ctx, log, registrant, store)
if err != nil {
return nil, err
}
return &Service{
ctx: ctx,
log: log,
registrant: registrant,
store: store,
worker: worker,
}, nil
}

func (s *Service) Close() {
Expand All @@ -54,27 +65,32 @@ func (s *Service) PublishEnvelope(
ctx context.Context,
req *message_api.PublishEnvelopeRequest,
) (*message_api.PublishEnvelopeResponse, error) {
payerEnv := req.GetPayerEnvelope()
clientBytes := payerEnv.GetUnsignedClientEnvelope()
sig := payerEnv.GetPayerSignature()
if (clientBytes == nil) || (sig == nil) {
return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature")
clientEnv, err := s.validatePayerInfo(req.GetPayerEnvelope())
if err != nil {
return nil, err
}
// TODO(rich): Verify payer signature
// TODO(rich): Verify all originators have synced past `last_originator_sids`
// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
// TODO(rich): Perform any payload-specific validation (e.g. identity updates)

topic, err := s.validateClientInfo(clientEnv)
if err != nil {
return nil, err
}

// TODO(rich): If it is a commit, publish it to blockchain instead

payerBytes, err := proto.Marshal(payerEnv)
payerBytes, err := proto.Marshal(req.GetPayerEnvelope())
if err != nil {
return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err)
}

stagedEnv, err := s.queries.InsertStagedOriginatorEnvelope(ctx, payerBytes)
stagedEnv, err := queries.New(s.store).
InsertStagedOriginatorEnvelope(ctx, queries.InsertStagedOriginatorEnvelopeParams{
Topic: topic,
PayerEnvelope: payerBytes,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err)
}
s.worker.NotifyStagedPublish()

originatorEnv, err := s.registrant.SignStagedEnvelope(stagedEnv)
if err != nil {
Expand All @@ -83,3 +99,43 @@ func (s *Service) PublishEnvelope(

return &message_api.PublishEnvelopeResponse{OriginatorEnvelope: originatorEnv}, nil
}

func (s *Service) validatePayerInfo(
payerEnv *message_api.PayerEnvelope,
) (*message_api.ClientEnvelope, error) {
clientBytes := payerEnv.GetUnsignedClientEnvelope()
sig := payerEnv.GetPayerSignature()
if (clientBytes == nil) || (sig == nil) {
return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature")
}
// TODO(rich): Verify payer signature

clientEnv := &message_api.ClientEnvelope{}
err := proto.Unmarshal(clientBytes, clientEnv)
if err != nil {
return nil, status.Errorf(
codes.InvalidArgument,
"could not unmarshal client envelope: %v",
err,
)
}

return clientEnv, nil
}

func (s *Service) validateClientInfo(clientEnv *message_api.ClientEnvelope) ([]byte, error) {
if clientEnv.GetAad().GetTargetOriginator() != uint32(s.registrant.NodeID()) {
return nil, status.Errorf(codes.InvalidArgument, "invalid target originator")
}

topic := clientEnv.GetAad().GetTargetTopic()
if len(topic) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "missing target topic")
}

// TODO(rich): Verify all originators have synced past `last_originator_sids`
// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
// TODO(rich): Perform any payload-specific validation (e.g. identity updates)

return topic, nil
}
102 changes: 95 additions & 7 deletions pkg/api/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"testing"
"time"

"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -40,17 +41,41 @@ func newTestService(t *testing.T) (*Service, *sql.DB, func()) {
}
}

func createClientEnvelope() *message_api.ClientEnvelope {
return &message_api.ClientEnvelope{
Payload: nil,
Aad: &message_api.AuthenticatedData{
TargetOriginator: 1,
TargetTopic: []byte{0x5},
LastOriginatorSids: []uint64{},
},
}
}

func createPayerEnvelope(
t *testing.T,
clientEnv ...*message_api.ClientEnvelope,
) *message_api.PayerEnvelope {
if len(clientEnv) == 0 {
clientEnv = append(clientEnv, createClientEnvelope())
}
clientEnvBytes, err := proto.Marshal(clientEnv[0])
require.NoError(t, err)

return &message_api.PayerEnvelope{
UnsignedClientEnvelope: clientEnvBytes,
PayerSignature: &associations.RecoverableEcdsaSignature{},
}
}

func TestSimplePublish(t *testing.T) {
svc, _, cleanup := newTestService(t)
svc, db, cleanup := newTestService(t)
defer cleanup()

resp, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: &message_api.PayerEnvelope{
UnsignedClientEnvelope: []byte{0x5},
PayerSignature: &associations.RecoverableEcdsaSignature{},
},
PayerEnvelope: createPayerEnvelope(t),
},
)
require.NoError(t, err)
Expand All @@ -61,7 +86,70 @@ func TestSimplePublish(t *testing.T) {
t,
proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv),
)
require.Equal(t, uint8(0x5), unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope()[0])
clientEnv := &message_api.ClientEnvelope{}
require.NoError(
t,
proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv),
)
require.Equal(t, uint8(0x5), clientEnv.Aad.GetTargetTopic()[0])

// TODO(rich) Test that the published envelope is retrievable via the query API
// Check that the envelope was published to the database after a delay
require.Eventually(t, func() bool {
envs, err := queries.New(db).
SelectGatewayEnvelopes(context.Background(), queries.SelectGatewayEnvelopesParams{})
require.NoError(t, err)

if len(envs) != 1 {
return false
}

originatorEnv := &message_api.OriginatorEnvelope{}
require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv))
return proto.Equal(originatorEnv, resp.GetOriginatorEnvelope())
}, 500*time.Millisecond, 50*time.Millisecond)
}

func TestUnmarshalError(t *testing.T) {
svc, _, cleanup := newTestService(t)
defer cleanup()

envelope := createPayerEnvelope(t)
envelope.UnsignedClientEnvelope = []byte("invalidbytes")
_, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: envelope,
},
)
require.ErrorContains(t, err, "unmarshal")
}

func TestMismatchingOriginator(t *testing.T) {
svc, _, cleanup := newTestService(t)
defer cleanup()

clientEnv := createClientEnvelope()
clientEnv.Aad.TargetOriginator = 2
_, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: createPayerEnvelope(t, clientEnv),
},
)
require.ErrorContains(t, err, "originator")
}

func TestMissingTopic(t *testing.T) {
svc, _, cleanup := newTestService(t)
defer cleanup()

clientEnv := createClientEnvelope()
clientEnv.Aad.TargetTopic = nil
_, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: createPayerEnvelope(t, clientEnv),
},
)
require.ErrorContains(t, err, "topic")
}
Loading

0 comments on commit a8cef43

Please sign in to comment.