diff --git a/README.md b/README.md index 7a53547b..f3bcbb83 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ dev/generate Create a new migration by running: ```sh -dev/gen-migration +dev/gen-migration {migration-name} ``` Fill in the migrations in the generated files. If you are unfamiliar with migrations, you may follow [this guide](https://github.com/golang-migrate/migrate/blob/master/MIGRATIONS.md). The database is PostgreSQL and the driver is PGX. @@ -106,5 +106,5 @@ Fill in the migrations in the generated files. If you are unfamiliar with migrat We use [sqlc](https://docs.sqlc.dev/en/latest/index.html) to generate the code for our DB queries. Modify the `queries.sql` file, and then run: ```sh -sqlc generate +dev/generate ``` diff --git a/dev/generate b/dev/generate index 218f4b1c..909a0705 100755 --- a/dev/generate +++ b/dev/generate @@ -6,6 +6,7 @@ go generate ./... rm -f pkg/mocks/* ./dev/abigen mockery +sqlc generate rm -rf pkg/proto/**/*.pb.go pkg/proto/**/*.pb.gw.go pkg/proto/**/*.swagger.json if ! buf generate https://github.com/xmtp/proto.git#subdir=proto; then diff --git a/pkg/db/queries.sql b/pkg/db/queries.sql index 305fc6ba..072fd5b4 100644 --- a/pkg/db/queries.sql +++ b/pkg/db/queries.sql @@ -1,13 +1,39 @@ --- name: InsertStagedOriginatorEnvelope :one -INSERT INTO staged_originator_envelopes(payer_envelope) - VALUES (@payer_envelope) -RETURNING - *; - --- name: InsertNodeInfo :one +-- name: InsertNodeInfo :execrows INSERT INTO node_info(node_id, public_key) VALUES (@node_id, @public_key) - RETURNING *; +ON CONFLICT + DO NOTHING; -- name: SelectNodeInfo :one -SELECT * FROM node_info WHERE singleton_id = 1; +SELECT + * +FROM + node_info +WHERE + singleton_id = 1; + +-- name: InsertGatewayEnvelope :execrows +SELECT + insert_gateway_envelope(@originator_id, @sequence_id, @topic, @originator_envelope); + +-- name: InsertStagedOriginatorEnvelope :one +SELECT + * +FROM + insert_staged_originator_envelope(@payer_envelope); + +-- name: SelectStagedOriginatorEnvelopes :many +SELECT + * +FROM + staged_originator_envelopes +WHERE + id > @last_seen_id +ORDER BY + id ASC +LIMIT @num_rows; + +-- name: DeleteStagedOriginatorEnvelope :execrows +DELETE FROM staged_originator_envelopes +WHERE id = @id; + diff --git a/pkg/db/queries/models.go b/pkg/db/queries/models.go index 4b9e273f..e363190d 100644 --- a/pkg/db/queries/models.go +++ b/pkg/db/queries/models.go @@ -18,7 +18,8 @@ type AddressLog struct { type GatewayEnvelope struct { ID int64 - OriginatorSid int64 + OriginatorID int32 + SequenceID int64 Topic []byte OriginatorEnvelope []byte } diff --git a/pkg/db/queries/queries.sql.go b/pkg/db/queries/queries.sql.go index 16caf939..978fc077 100644 --- a/pkg/db/queries/queries.sql.go +++ b/pkg/db/queries/queries.sql.go @@ -9,10 +9,49 @@ import ( "context" ) -const insertNodeInfo = `-- name: InsertNodeInfo :one +const deleteStagedOriginatorEnvelope = `-- name: DeleteStagedOriginatorEnvelope :execrows +DELETE FROM staged_originator_envelopes +WHERE id = $1 +` + +func (q *Queries) DeleteStagedOriginatorEnvelope(ctx context.Context, id int64) (int64, error) { + result, err := q.db.ExecContext(ctx, deleteStagedOriginatorEnvelope, id) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +const insertGatewayEnvelope = `-- name: InsertGatewayEnvelope :execrows +SELECT + insert_gateway_envelope($1, $2, $3, $4) +` + +type InsertGatewayEnvelopeParams struct { + OriginatorID int32 + SequenceID int64 + Topic []byte + OriginatorEnvelope []byte +} + +func (q *Queries) InsertGatewayEnvelope(ctx context.Context, arg InsertGatewayEnvelopeParams) (int64, error) { + result, err := q.db.ExecContext(ctx, insertGatewayEnvelope, + arg.OriginatorID, + arg.SequenceID, + arg.Topic, + arg.OriginatorEnvelope, + ) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +const insertNodeInfo = `-- name: InsertNodeInfo :execrows INSERT INTO node_info(node_id, public_key) VALUES ($1, $2) - RETURNING node_id, public_key, singleton_id +ON CONFLICT + DO NOTHING ` type InsertNodeInfoParams struct { @@ -20,18 +59,19 @@ type InsertNodeInfoParams struct { PublicKey []byte } -func (q *Queries) InsertNodeInfo(ctx context.Context, arg InsertNodeInfoParams) (NodeInfo, error) { - row := q.db.QueryRowContext(ctx, insertNodeInfo, arg.NodeID, arg.PublicKey) - var i NodeInfo - err := row.Scan(&i.NodeID, &i.PublicKey, &i.SingletonID) - return i, err +func (q *Queries) InsertNodeInfo(ctx context.Context, arg InsertNodeInfoParams) (int64, error) { + result, err := q.db.ExecContext(ctx, insertNodeInfo, arg.NodeID, arg.PublicKey) + if err != nil { + return 0, err + } + return result.RowsAffected() } const insertStagedOriginatorEnvelope = `-- name: InsertStagedOriginatorEnvelope :one -INSERT INTO staged_originator_envelopes(payer_envelope) - VALUES ($1) -RETURNING +SELECT id, originator_time, payer_envelope +FROM + insert_staged_originator_envelope($1) ` func (q *Queries) InsertStagedOriginatorEnvelope(ctx context.Context, payerEnvelope []byte) (StagedOriginatorEnvelope, error) { @@ -42,7 +82,12 @@ func (q *Queries) InsertStagedOriginatorEnvelope(ctx context.Context, payerEnvel } const selectNodeInfo = `-- name: SelectNodeInfo :one -SELECT node_id, public_key, singleton_id FROM node_info WHERE singleton_id = 1 +SELECT + node_id, public_key, singleton_id +FROM + node_info +WHERE + singleton_id = 1 ` func (q *Queries) SelectNodeInfo(ctx context.Context) (NodeInfo, error) { @@ -51,3 +96,43 @@ func (q *Queries) SelectNodeInfo(ctx context.Context) (NodeInfo, error) { err := row.Scan(&i.NodeID, &i.PublicKey, &i.SingletonID) return i, err } + +const selectStagedOriginatorEnvelopes = `-- name: SelectStagedOriginatorEnvelopes :many +SELECT + id, originator_time, payer_envelope +FROM + staged_originator_envelopes +WHERE + id > $1 +ORDER BY + id ASC +LIMIT $2 +` + +type SelectStagedOriginatorEnvelopesParams struct { + LastSeenID int64 + NumRows int32 +} + +func (q *Queries) SelectStagedOriginatorEnvelopes(ctx context.Context, arg SelectStagedOriginatorEnvelopesParams) ([]StagedOriginatorEnvelope, error) { + rows, err := q.db.QueryContext(ctx, selectStagedOriginatorEnvelopes, arg.LastSeenID, arg.NumRows) + if err != nil { + return nil, err + } + defer rows.Close() + var items []StagedOriginatorEnvelope + for rows.Next() { + var i StagedOriginatorEnvelope + if err := rows.Scan(&i.ID, &i.OriginatorTime, &i.PayerEnvelope); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/pkg/migrations/00001_init-schema.up.sql b/pkg/migrations/00001_init-schema.up.sql index a9ff3532..191da914 100644 --- a/pkg/migrations/00001_init-schema.up.sql +++ b/pkg/migrations/00001_init-schema.up.sql @@ -3,7 +3,6 @@ CREATE TABLE node_info( node_id INTEGER NOT NULL, public_key BYTEA NOT NULL, - singleton_id SMALLINT PRIMARY KEY DEFAULT 1, CONSTRAINT is_singleton CHECK (singleton_id = 1) ); @@ -12,15 +11,35 @@ CREATE TABLE node_info( CREATE TABLE gateway_envelopes( -- used to construct gateway_sid id BIGSERIAL PRIMARY KEY, - originator_sid BIGINT NOT NULL, + originator_id INT NOT NULL, + sequence_id BIGINT NOT NULL, topic BYTEA NOT NULL, originator_envelope BYTEA NOT NULL ); + -- Client queries CREATE INDEX idx_gateway_envelopes_topic ON gateway_envelopes(topic); + -- Node queries -CREATE UNIQUE INDEX idx_gateway_envelopes_originator_sid ON gateway_envelopes(originator_sid); +CREATE UNIQUE INDEX idx_gateway_envelopes_originator_sid ON gateway_envelopes(originator_id, sequence_id); +CREATE FUNCTION insert_gateway_envelope(originator_id INT, sequence_id BIGINT, topic BYTEA, originator_envelope BYTEA) + RETURNS SETOF gateway_envelopes + AS $$ +BEGIN + -- Ensures that the generated sequence ID matches the insertion order + -- Only released at the end of the enclosing transaction - beware if called within a long transaction + PERFORM + pg_advisory_xact_lock(hashtext('gateway_envelopes_sequence')); + RETURN QUERY INSERT INTO gateway_envelopes(originator_id, sequence_id, topic, originator_envelope) + VALUES(originator_id, sequence_id, topic, originator_envelope) + ON CONFLICT + DO NOTHING + RETURNING + *; +END; +$$ +LANGUAGE plpgsql; -- Process for originating envelopes: -- 1. Perform any necessary validation @@ -38,6 +57,22 @@ CREATE TABLE staged_originator_envelopes( payer_envelope BYTEA NOT NULL ); +CREATE FUNCTION insert_staged_originator_envelope(payer_envelope BYTEA) + RETURNS SETOF staged_originator_envelopes + AS $$ +BEGIN + PERFORM + pg_advisory_xact_lock(hashtext('staged_originator_envelopes_sequence')); + RETURN QUERY INSERT INTO staged_originator_envelopes(payer_envelope) + VALUES(payer_envelope) + ON CONFLICT + DO NOTHING + RETURNING + *; +END; +$$ +LANGUAGE plpgsql; + -- A cached view for looking up the inbox_id that an address belongs to. -- Relies on a total ordering of updates across all inbox_ids, from which this -- view can be deterministically generated. @@ -46,6 +81,6 @@ CREATE TABLE address_log( inbox_id BYTEA NOT NULL, association_sequence_id BIGINT, revocation_sequence_id BIGINT, - PRIMARY KEY (address, inbox_id) ); + diff --git a/pkg/registrant/registrant.go b/pkg/registrant/registrant.go index f4629ba8..e642a7e1 100644 --- a/pkg/registrant/registrant.go +++ b/pkg/registrant/registrant.go @@ -49,7 +49,7 @@ func NewRegistrant( } func (r *Registrant) sid(localID int64) (uint64, error) { - if !utils.IsValidLocalID(localID) { + if !utils.IsValidSequenceID(localID) { return 0, fmt.Errorf("Invalid local ID %d, likely due to ID exhaustion", localID) } return utils.SID(r.record.NodeID, localID), nil @@ -121,7 +121,7 @@ func getRegistryRecord( // - Running multiple nodes with different private keys against the same DB // - Changing a server's configuration while pointing to data in an existing DB func ensureDatabaseMatches(ctx context.Context, db *queries.Queries, record *registry.Node) error { - _, err := db.InsertNodeInfo( + numRows, err := db.InsertNodeInfo( ctx, queries.InsertNodeInfoParams{ NodeID: int32(record.NodeID), @@ -129,6 +129,10 @@ func ensureDatabaseMatches(ctx context.Context, db *queries.Queries, record *reg }, ) if err != nil { + return fmt.Errorf("unable to insert node info into database: %v", err) + } + + if numRows == 0 { nodeInfo, err := db.SelectNodeInfo(ctx) if err != nil { return fmt.Errorf("unable to retrieve node info from database: %v", err) diff --git a/pkg/utils/sid.go b/pkg/utils/sid.go index b4b8acc6..8229b0cc 100644 --- a/pkg/utils/sid.go +++ b/pkg/utils/sid.go @@ -1,13 +1,13 @@ package utils // SIDS are 64-bit numbers consisting of 16 bits for the node ID -// followed by 48 bits for the sequence ID (local ID). This file +// followed by 48 bits for the sequence ID. This file // contains methods for reading and constructing sids. // // We also leverage type-checking throughout the repo to avoid confusion: // - SIDs are uint64 // - node IDs are uint16 -// - local IDs are int64 +// - sequence IDs are int64 const ( // Number of bits used for node ID @@ -20,7 +20,7 @@ const ( localIDMask uint64 = ^nodeIDMask ) -func IsValidLocalID(localID int64) bool { +func IsValidSequenceID(localID int64) bool { return localID > 0 && localID>>localIDBits == 0 } @@ -28,7 +28,7 @@ func NodeID(sid uint64) uint16 { return uint16(sid >> localIDBits) } -func LocalID(sid uint64) int64 { +func SequenceID(sid uint64) int64 { return int64(sid & localIDMask) } diff --git a/pkg/utils/sid_test.go b/pkg/utils/sid_test.go index 7880b2aa..bbe67b22 100644 --- a/pkg/utils/sid_test.go +++ b/pkg/utils/sid_test.go @@ -6,20 +6,20 @@ import ( "github.com/stretchr/testify/require" ) -func TestInvalidLocalID(t *testing.T) { - localID := int64(-1) - require.False(t, IsValidLocalID(localID)) - localID = int64(0) - require.False(t, IsValidLocalID(localID)) - localID = int64(0b0000000000000001000000000000000000000000000000000000000000000000) - require.False(t, IsValidLocalID(localID)) +func TestInvalidSequenceID(t *testing.T) { + sequenceID := int64(-1) + require.False(t, IsValidSequenceID(sequenceID)) + sequenceID = int64(0) + require.False(t, IsValidSequenceID(sequenceID)) + sequenceID = int64(0b0000000000000001000000000000000000000000000000000000000000000000) + require.False(t, IsValidSequenceID(sequenceID)) } -func TestValidLocalID(t *testing.T) { - localID := int64(1) - require.True(t, IsValidLocalID(localID)) - localID = int64(0b0000000000000000111111111111111111111111111111111111111111111111) - require.True(t, IsValidLocalID(localID)) +func TestValidSequenceID(t *testing.T) { + sequenceID := int64(1) + require.True(t, IsValidSequenceID(sequenceID)) + sequenceID = int64(0b0000000000000000111111111111111111111111111111111111111111111111) + require.True(t, IsValidSequenceID(sequenceID)) } func TestGetNodeID(t *testing.T) { @@ -29,23 +29,23 @@ func TestGetNodeID(t *testing.T) { require.Equal(t, uint16(1), NodeID(sid)) } -func TestGetLocalID(t *testing.T) { +func TestGetSequenceID(t *testing.T) { sid := uint64(0b0000000000000001111111111111111111111111111111111111111111111111) - require.Equal(t, int64(0b0000000000000000111111111111111111111111111111111111111111111111), LocalID(sid)) + require.Equal(t, int64(0b0000000000000000111111111111111111111111111111111111111111111111), SequenceID(sid)) sid = uint64(0b0000000000000001000000000000000000000000000000000000000000000000) - require.Equal(t, int64(0), LocalID(sid)) + require.Equal(t, int64(0), SequenceID(sid)) sid = uint64(0b0000000000000001000000000000000000000000000000000000000000000001) - require.Equal(t, int64(1), LocalID(sid)) + require.Equal(t, int64(1), SequenceID(sid)) } func TestGetSID(t *testing.T) { nodeID := uint16(1) - localID := int64(1) - require.Equal(t, uint64(0b0000000000000001000000000000000000000000000000000000000000000001), SID(nodeID, localID)) + sequenceID := int64(1) + require.Equal(t, uint64(0b0000000000000001000000000000000000000000000000000000000000000001), SID(nodeID, sequenceID)) nodeID = uint16(1) - localID = int64(0) - require.Equal(t, uint64(0b0000000000000001000000000000000000000000000000000000000000000000), SID(nodeID, localID)) + sequenceID = int64(0) + require.Equal(t, uint64(0b0000000000000001000000000000000000000000000000000000000000000000), SID(nodeID, sequenceID)) nodeID = uint16(0) - localID = int64(1) - require.Equal(t, uint64(0b0000000000000000000000000000000000000000000000000000000000000001), SID(nodeID, localID)) + sequenceID = int64(1) + require.Equal(t, uint64(0b0000000000000000000000000000000000000000000000000000000000000001), SID(nodeID, sequenceID)) }