diff --git a/pkg/api/publish_test.go b/pkg/api/publish_test.go new file mode 100644 index 00000000..3397a4f7 --- /dev/null +++ b/pkg/api/publish_test.go @@ -0,0 +1,99 @@ +package api_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" + "github.com/xmtp/xmtpd/pkg/testutils" + "google.golang.org/protobuf/proto" +) + +func TestPublishEnvelope(t *testing.T) { + svc, db, cleanup := testutils.NewTestService(t) + defer cleanup() + + resp, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: testutils.CreatePayerEnvelope(t), + }, + ) + require.NoError(t, err) + require.NotNil(t, resp) + + unsignedEnv := &message_api.UnsignedOriginatorEnvelope{} + require.NoError( + t, + proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv), + ) + clientEnv := &message_api.ClientEnvelope{} + require.NoError( + t, + proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv), + ) + require.Equal(t, uint8(0x5), clientEnv.Aad.GetTargetTopic()[0]) + + // Check that the envelope was published to the database after a delay + require.Eventually(t, func() bool { + envs, err := queries.New(db). + SelectGatewayEnvelopes(context.Background(), queries.SelectGatewayEnvelopesParams{}) + require.NoError(t, err) + + if len(envs) != 1 { + return false + } + + originatorEnv := &message_api.OriginatorEnvelope{} + require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv)) + return proto.Equal(originatorEnv, resp.GetOriginatorEnvelope()) + }, 500*time.Millisecond, 50*time.Millisecond) +} + +func TestUnmarshalErrorOnPublish(t *testing.T) { + svc, _, cleanup := testutils.NewTestService(t) + defer cleanup() + + envelope := testutils.CreatePayerEnvelope(t) + envelope.UnsignedClientEnvelope = []byte("invalidbytes") + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: envelope, + }, + ) + require.ErrorContains(t, err, "unmarshal") +} + +func TestMismatchingOriginatorOnPublish(t *testing.T) { + svc, _, cleanup := testutils.NewTestService(t) + defer cleanup() + + clientEnv := testutils.CreateClientEnvelope() + clientEnv.Aad.TargetOriginator = 2 + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: testutils.CreatePayerEnvelope(t, clientEnv), + }, + ) + require.ErrorContains(t, err, "originator") +} + +func TestMissingTopicOnPublish(t *testing.T) { + svc, _, cleanup := testutils.NewTestService(t) + defer cleanup() + + clientEnv := testutils.CreateClientEnvelope() + clientEnv.Aad.TargetTopic = nil + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: testutils.CreatePayerEnvelope(t, clientEnv), + }, + ) + require.ErrorContains(t, err, "topic") +} diff --git a/pkg/api/service_test.go b/pkg/api/query_test.go similarity index 54% rename from pkg/api/service_test.go rename to pkg/api/query_test.go index 639190fe..0833c94b 100644 --- a/pkg/api/service_test.go +++ b/pkg/api/query_test.go @@ -1,131 +1,16 @@ -package api +package api_test import ( "context" "database/sql" "testing" - "time" - "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/db/queries" - mocks "github.com/xmtp/xmtpd/pkg/mocks/registry" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" - "github.com/xmtp/xmtpd/pkg/registrant" - "github.com/xmtp/xmtpd/pkg/registry" "github.com/xmtp/xmtpd/pkg/testutils" - "google.golang.org/protobuf/proto" ) -func newTestService(t *testing.T) (*Service, *sql.DB, func()) { - ctx := context.Background() - log := testutils.NewLog(t) - db, _, dbCleanup := testutils.NewDB(t, ctx) - privKey, err := crypto.GenerateKey() - require.NoError(t, err) - privKeyStr := "0x" + testutils.HexEncode(crypto.FromECDSA(privKey)) - mockRegistry := mocks.NewMockNodeRegistry(t) - mockRegistry.EXPECT().GetNodes().Return([]registry.Node{ - {NodeID: 1, SigningKey: &privKey.PublicKey}, - }, nil) - registrant, err := registrant.NewRegistrant(ctx, queries.New(db), mockRegistry, privKeyStr) - require.NoError(t, err) - - svc, err := NewReplicationApiService(ctx, log, registrant, db) - require.NoError(t, err) - - return svc, db, func() { - svc.Close() - dbCleanup() - } -} - -func TestPublishEnvelope(t *testing.T) { - svc, db, cleanup := newTestService(t) - defer cleanup() - - resp, err := svc.PublishEnvelope( - context.Background(), - &message_api.PublishEnvelopeRequest{ - PayerEnvelope: testutils.CreatePayerEnvelope(t), - }, - ) - require.NoError(t, err) - require.NotNil(t, resp) - - unsignedEnv := &message_api.UnsignedOriginatorEnvelope{} - require.NoError( - t, - proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv), - ) - clientEnv := &message_api.ClientEnvelope{} - require.NoError( - t, - proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv), - ) - require.Equal(t, uint8(0x5), clientEnv.Aad.GetTargetTopic()[0]) - - // Check that the envelope was published to the database after a delay - require.Eventually(t, func() bool { - envs, err := queries.New(db). - SelectGatewayEnvelopes(context.Background(), queries.SelectGatewayEnvelopesParams{}) - require.NoError(t, err) - - if len(envs) != 1 { - return false - } - - originatorEnv := &message_api.OriginatorEnvelope{} - require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv)) - return proto.Equal(originatorEnv, resp.GetOriginatorEnvelope()) - }, 500*time.Millisecond, 50*time.Millisecond) -} - -func TestUnmarshalErrorOnPublish(t *testing.T) { - svc, _, cleanup := newTestService(t) - defer cleanup() - - envelope := testutils.CreatePayerEnvelope(t) - envelope.UnsignedClientEnvelope = []byte("invalidbytes") - _, err := svc.PublishEnvelope( - context.Background(), - &message_api.PublishEnvelopeRequest{ - PayerEnvelope: envelope, - }, - ) - require.ErrorContains(t, err, "unmarshal") -} - -func TestMismatchingOriginatorOnPublish(t *testing.T) { - svc, _, cleanup := newTestService(t) - defer cleanup() - - clientEnv := testutils.CreateClientEnvelope() - clientEnv.Aad.TargetOriginator = 2 - _, err := svc.PublishEnvelope( - context.Background(), - &message_api.PublishEnvelopeRequest{ - PayerEnvelope: testutils.CreatePayerEnvelope(t, clientEnv), - }, - ) - require.ErrorContains(t, err, "originator") -} - -func TestMissingTopicOnPublish(t *testing.T) { - svc, _, cleanup := newTestService(t) - defer cleanup() - - clientEnv := testutils.CreateClientEnvelope() - clientEnv.Aad.TargetTopic = nil - _, err := svc.PublishEnvelope( - context.Background(), - &message_api.PublishEnvelopeRequest{ - PayerEnvelope: testutils.CreatePayerEnvelope(t, clientEnv), - }, - ) - require.ErrorContains(t, err, "topic") -} - func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopeParams { db_rows := []queries.InsertGatewayEnvelopeParams{ { @@ -179,7 +64,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar } func TestQueryAllEnvelopes(t *testing.T) { - svc, db, cleanup := newTestService(t) + svc, db, cleanup := testutils.NewTestService(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -195,7 +80,7 @@ func TestQueryAllEnvelopes(t *testing.T) { } func TestQueryPagedEnvelopes(t *testing.T) { - svc, db, cleanup := newTestService(t) + svc, db, cleanup := testutils.NewTestService(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -211,7 +96,7 @@ func TestQueryPagedEnvelopes(t *testing.T) { } func TestQueryEnvelopesByOriginator(t *testing.T) { - svc, db, cleanup := newTestService(t) + svc, db, cleanup := testutils.NewTestService(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -232,7 +117,7 @@ func TestQueryEnvelopesByOriginator(t *testing.T) { } func TestQueryEnvelopesByTopic(t *testing.T) { - svc, db, cleanup := newTestService(t) + svc, db, cleanup := testutils.NewTestService(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -251,7 +136,7 @@ func TestQueryEnvelopesByTopic(t *testing.T) { } func TestQueryEnvelopesFromLastSeen(t *testing.T) { - svc, db, cleanup := newTestService(t) + svc, db, cleanup := testutils.NewTestService(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -270,7 +155,7 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) { } func TestQueryEnvelopesWithEmptyResult(t *testing.T) { - svc, db, cleanup := newTestService(t) + svc, db, cleanup := testutils.NewTestService(t) defer cleanup() db_rows := setupQueryTest(t, db) diff --git a/pkg/db/subscription_test.go b/pkg/db/subscription_test.go index 213ad614..42e52057 100644 --- a/pkg/db/subscription_test.go +++ b/pkg/db/subscription_test.go @@ -1,4 +1,4 @@ -package db +package db_test import ( "context" @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/testutils" "go.uber.org/zap" @@ -15,15 +16,15 @@ import ( func setup(t *testing.T) (*sql.DB, *zap.Logger, func()) { ctx := context.Background() - db, _, dbCleanup := testutils.NewDB(t, ctx) + store, _, storeCleanup := testutils.NewDB(t, ctx) log, err := zap.NewDevelopment() require.NoError(t, err) - return db, log, dbCleanup + return store, log, storeCleanup } -func insertInitialRows(t *testing.T, db *sql.DB) { - testutils.InsertGatewayEnvelopes(t, db, []queries.InsertGatewayEnvelopeParams{ +func insertInitialRows(t *testing.T, store *sql.DB) { + testutils.InsertGatewayEnvelopes(t, store, []queries.InsertGatewayEnvelopeParams{ { OriginatorNodeID: 1, OriginatorSequenceID: 1, @@ -39,12 +40,12 @@ func insertInitialRows(t *testing.T, db *sql.DB) { }) } -func envelopesQuery(db *sql.DB) PollableDBQuery[queries.GatewayEnvelope, VectorClock] { - return func(ctx context.Context, lastSeen VectorClock, numRows int32) ([]queries.GatewayEnvelope, VectorClock, error) { - envs, err := queries.New(db). - SelectGatewayEnvelopes(ctx, *SetVectorClock(&queries.SelectGatewayEnvelopesParams{ - OriginatorNodeID: NullInt32(1), - RowLimit: NullInt32(numRows), +func envelopesQuery(store *sql.DB) db.PollableDBQuery[queries.GatewayEnvelope, db.VectorClock] { + return func(ctx context.Context, lastSeen db.VectorClock, numRows int32) ([]queries.GatewayEnvelope, db.VectorClock, error) { + envs, err := queries.New(store). + SelectGatewayEnvelopes(ctx, *db.SetVectorClock(&queries.SelectGatewayEnvelopesParams{ + OriginatorNodeID: db.NullInt32(1), + RowLimit: db.NullInt32(numRows), }, lastSeen)) if err != nil { return nil, lastSeen, err @@ -56,8 +57,8 @@ func envelopesQuery(db *sql.DB) PollableDBQuery[queries.GatewayEnvelope, VectorC } } -func insertAdditionalRows(t *testing.T, db *sql.DB, notifyChan ...chan bool) { - testutils.InsertGatewayEnvelopes(t, db, []queries.InsertGatewayEnvelopeParams{ +func insertAdditionalRows(t *testing.T, store *sql.DB, notifyChan ...chan bool) { + testutils.InsertGatewayEnvelopes(t, store, []queries.InsertGatewayEnvelopeParams{ { OriginatorNodeID: 1, OriginatorSequenceID: 2, @@ -99,10 +100,12 @@ func validateUpdates(t *testing.T, updates <-chan []queries.GatewayEnvelope, ctx // flakyEnvelopesQuery returns a query that fails every other time // to simulate a transient database error -func flakyEnvelopesQuery(db *sql.DB) PollableDBQuery[queries.GatewayEnvelope, VectorClock] { +func flakyEnvelopesQuery( + store *sql.DB, +) db.PollableDBQuery[queries.GatewayEnvelope, db.VectorClock] { numQueries := 0 - query := envelopesQuery(db) - return func(ctx context.Context, lastSeen VectorClock, numRows int32) ([]queries.GatewayEnvelope, VectorClock, error) { + query := envelopesQuery(store) + return func(ctx context.Context, lastSeen db.VectorClock, numRows int32) ([]queries.GatewayEnvelope, db.VectorClock, error) { numQueries++ if numQueries%2 == 1 { return nil, lastSeen, fmt.Errorf("flaky query") @@ -113,19 +116,19 @@ func flakyEnvelopesQuery(db *sql.DB) PollableDBQuery[queries.GatewayEnvelope, Ve } func TestIntervalSubscription(t *testing.T) { - db, log, cleanup := setup(t) + store, log, cleanup := setup(t) defer cleanup() - insertInitialRows(t, db) + insertInitialRows(t, store) // Create a subscription that polls every 100ms ctx, ctxCancel := context.WithCancel(context.Background()) - subscription := NewDBSubscription( + subscription := db.NewDBSubscription( ctx, log, - envelopesQuery(db), - VectorClock{1: 1}, - PollingOptions{ + envelopesQuery(store), + db.VectorClock{1: 1}, + db.PollingOptions{ Interval: 100 * time.Millisecond, NumRows: 1, }, @@ -133,25 +136,25 @@ func TestIntervalSubscription(t *testing.T) { updates, err := subscription.Start() require.NoError(t, err) - insertAdditionalRows(t, db) + insertAdditionalRows(t, store) validateUpdates(t, updates, ctxCancel) } func TestNotifiedSubscription(t *testing.T) { - db, log, cleanup := setup(t) + store, log, cleanup := setup(t) defer cleanup() - insertInitialRows(t, db) + insertInitialRows(t, store) // Create a subscription that polls every 100ms ctx, ctxCancel := context.WithCancel(context.Background()) notifyChan := make(chan bool) - subscription := NewDBSubscription( + subscription := db.NewDBSubscription( ctx, log, - envelopesQuery(db), - VectorClock{1: 1}, - PollingOptions{ + envelopesQuery(store), + db.VectorClock{1: 1}, + db.PollingOptions{ Notifier: notifyChan, Interval: 30 * time.Second, NumRows: 1, @@ -160,24 +163,24 @@ func TestNotifiedSubscription(t *testing.T) { updates, err := subscription.Start() require.NoError(t, err) - insertAdditionalRows(t, db, notifyChan) + insertAdditionalRows(t, store, notifyChan) validateUpdates(t, updates, ctxCancel) } func TestTemporaryDBError(t *testing.T) { - db, log, cleanup := setup(t) + store, log, cleanup := setup(t) defer cleanup() - insertInitialRows(t, db) + insertInitialRows(t, store) // Create a subscription that polls every 100ms ctx, ctxCancel := context.WithCancel(context.Background()) - subscription := NewDBSubscription( + subscription := db.NewDBSubscription( ctx, log, - flakyEnvelopesQuery(db), - VectorClock{1: 1}, - PollingOptions{ + flakyEnvelopesQuery(store), + db.VectorClock{1: 1}, + db.PollingOptions{ Interval: 100 * time.Millisecond, NumRows: 1, }, @@ -185,6 +188,6 @@ func TestTemporaryDBError(t *testing.T) { updates, err := subscription.Start() require.NoError(t, err) - insertAdditionalRows(t, db) + insertAdditionalRows(t, store) validateUpdates(t, updates, ctxCancel) } diff --git a/pkg/testutils/service.go b/pkg/testutils/service.go new file mode 100644 index 00000000..c83324f8 --- /dev/null +++ b/pkg/testutils/service.go @@ -0,0 +1,38 @@ +package testutils + +import ( + "context" + "database/sql" + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/api" + "github.com/xmtp/xmtpd/pkg/db/queries" + mocks "github.com/xmtp/xmtpd/pkg/mocks/registry" + "github.com/xmtp/xmtpd/pkg/registrant" + "github.com/xmtp/xmtpd/pkg/registry" +) + +func NewTestService(t *testing.T) (*api.Service, *sql.DB, func()) { + ctx := context.Background() + log := NewLog(t) + db, _, dbCleanup := NewDB(t, ctx) + privKey, err := crypto.GenerateKey() + require.NoError(t, err) + privKeyStr := "0x" + HexEncode(crypto.FromECDSA(privKey)) + mockRegistry := mocks.NewMockNodeRegistry(t) + mockRegistry.EXPECT().GetNodes().Return([]registry.Node{ + {NodeID: 1, SigningKey: &privKey.PublicKey}, + }, nil) + registrant, err := registrant.NewRegistrant(ctx, queries.New(db), mockRegistry, privKeyStr) + require.NoError(t, err) + + svc, err := api.NewReplicationApiService(ctx, log, registrant, db) + require.NoError(t, err) + + return svc, db, func() { + svc.Close() + dbCleanup() + } +}