From d060cb99c97b7fbf7f2f85745e23784f4497f816 Mon Sep 17 00:00:00 2001 From: Milos Zivkovic Date: Wed, 4 Oct 2023 16:55:32 +0200 Subject: [PATCH] Add unit test for restore execution --- backup/backup.go | 41 ++++++++++++++++----------- backup/backup_test.go | 12 +++++++- restore/config.go | 30 -------------------- restore/config_test.go | 28 ------------------ restore/mock_test.go | 49 ++++++++++++++++++++++++++++++++ restore/restore.go | 42 ++++++++++++++------------- restore/restore_test.go | 63 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 170 insertions(+), 95 deletions(-) delete mode 100644 restore/config.go delete mode 100644 restore/config_test.go create mode 100644 restore/mock_test.go create mode 100644 restore/restore_test.go diff --git a/backup/backup.go b/backup/backup.go index 4f6e50f..4238bb9 100644 --- a/backup/backup.go +++ b/backup/backup.go @@ -1,6 +1,7 @@ package backup import ( + "context" "encoding/json" "fmt" "io" @@ -12,6 +13,7 @@ import ( // ExecuteBackup executes the node backup process func ExecuteBackup( + ctx context.Context, client client.Client, writer io.Writer, logger log.Logger, @@ -30,26 +32,33 @@ func ExecuteBackup( // Gather the chain data from the node for block := cfg.FromBlock; block <= toBlock; block++ { - txs, txErr := client.GetBlockTransactions(block) - if txErr != nil { - return fmt.Errorf("unable to fetch block transactions, %w", txErr) - } - - // Save the block transaction data, if any - for _, tx := range txs { - data := &types.TxData{ - Tx: tx, - BlockNum: block, + select { + case <-ctx.Done(): + logger.Info("export procedure stopped") + + return nil + default: + txs, txErr := client.GetBlockTransactions(block) + if txErr != nil { + return fmt.Errorf("unable to fetch block transactions, %w", txErr) } - // Write the tx data to the file - if writeErr := writeTxData(writer, data); writeErr != nil { - return fmt.Errorf("unable to write tx data, %w", writeErr) + // Save the block transaction data, if any + for _, tx := range txs { + data := &types.TxData{ + Tx: tx, + BlockNum: block, + } + + // Write the tx data to the file + if writeErr := writeTxData(writer, data); writeErr != nil { + return fmt.Errorf("unable to write tx data, %w", writeErr) + } } - } - // Log the progress - logProgress(logger, cfg.FromBlock, toBlock, block) + // Log the progress + logProgress(logger, cfg.FromBlock, toBlock, block) + } } return nil diff --git a/backup/backup_test.go b/backup/backup_test.go index 07266bf..035fafa 100644 --- a/backup/backup_test.go +++ b/backup/backup_test.go @@ -2,6 +2,7 @@ package backup import ( "bufio" + "context" "encoding/json" "errors" "os" @@ -119,7 +120,16 @@ func TestBackup_ExecuteBackup(t *testing.T) { cfg.ToBlock = &toBlock // Run the backup procedure - require.NoError(t, ExecuteBackup(mockClient, tempFile, noop.New(), cfg)) + require.NoError( + t, + ExecuteBackup( + context.Background(), + mockClient, + tempFile, + noop.New(), + cfg, + ), + ) // Read the output file fileRaw, err := os.Open(tempFile.Name()) diff --git a/restore/config.go b/restore/config.go deleted file mode 100644 index f070a92..0000000 --- a/restore/config.go +++ /dev/null @@ -1,30 +0,0 @@ -package restore - -import "errors" - -const ( - DefaultRemote = "http://127.0.0.1:26657" -) - -var errInvalidRemote = errors.New("invalid remote address") - -// Config is the base chain restore config -type Config struct { - Remote string // the remote JSON-RPC URL of the chain -} - -// DefaultConfig returns the default restore configuration -func DefaultConfig() Config { - return Config{ - Remote: DefaultRemote, - } -} - -// ValidateConfig validates the base restore configuration -func ValidateConfig(cfg Config) error { - if cfg.Remote == "" { - return errInvalidRemote - } - - return nil -} diff --git a/restore/config_test.go b/restore/config_test.go deleted file mode 100644 index a585d00..0000000 --- a/restore/config_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package restore - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestConfig_ValidateConfig(t *testing.T) { - t.Parallel() - - t.Run("invalid remote address", func(t *testing.T) { - t.Parallel() - - cfg := DefaultConfig() - cfg.Remote = "" - - assert.ErrorIs(t, ValidateConfig(cfg), errInvalidRemote) - }) - - t.Run("valid configuration", func(t *testing.T) { - t.Parallel() - - cfg := DefaultConfig() - - assert.NoError(t, ValidateConfig(cfg)) - }) -} diff --git a/restore/mock_test.go b/restore/mock_test.go new file mode 100644 index 0000000..06f8913 --- /dev/null +++ b/restore/mock_test.go @@ -0,0 +1,49 @@ +package restore + +import ( + "context" + + "github.com/gnolang/gno/tm2/pkg/std" +) + +type ( + sendTransactionDelegate func(*std.Tx) error +) + +type mockClient struct { + sendTransactionFn sendTransactionDelegate +} + +func (m *mockClient) SendTransaction(tx *std.Tx) error { + if m.sendTransactionFn != nil { + return m.sendTransactionFn(tx) + } + + return nil +} + +type ( + nextDelegate func(context.Context) (*std.Tx, error) + closeDelegate func() error +) + +type mockSource struct { + nextFn nextDelegate + closeFn closeDelegate +} + +func (m *mockSource) Next(ctx context.Context) (*std.Tx, error) { + if m.nextFn != nil { + return m.nextFn(ctx) + } + + return nil, nil +} + +func (m *mockSource) Close() error { + if m.closeFn != nil { + return m.closeFn() + } + + return nil +} diff --git a/restore/restore.go b/restore/restore.go index cca6b40..3e51b89 100644 --- a/restore/restore.go +++ b/restore/restore.go @@ -14,17 +14,13 @@ import ( // ExecuteRestore executes the node restore process func ExecuteRestore( + ctx context.Context, client client.Client, source source.Source, logger log.Logger, - cfg Config, ) error { - // Verify the config - if cfgErr := ValidateConfig(cfg); cfgErr != nil { - return fmt.Errorf("invalid config, %w", cfgErr) - } - - defer func() { + // Set up the teardown + teardown := func() { if closeErr := source.Close(); closeErr != nil { logger.Error( "unable to gracefully close source", @@ -32,7 +28,9 @@ func ExecuteRestore( closeErr.Error(), ) } - }() + } + + defer teardown() var ( tx *std.Tx @@ -42,8 +40,12 @@ func ExecuteRestore( ) // Fetch next transactions - // TODO add ctx - for tx, nextErr = source.Next(context.Background()); nextErr == nil; { + for nextErr == nil { + tx, nextErr = source.Next(ctx) + if nextErr != nil { + break + } + // Send the transaction if sendErr := client.SendTransaction(tx); sendErr != nil { // Invalid transaction sends are only logged, @@ -61,16 +63,16 @@ func ExecuteRestore( } // Check if this is the end of the road - if errors.Is(nextErr, io.EOF) { - // No more transactions to apply - logger.Info( - "restore process finished", - "total", - totalTxs, - ) - - return nil + if !errors.Is(nextErr, io.EOF) { + return fmt.Errorf("unable to get next transaction, %w", nextErr) } - return fmt.Errorf("unable to get next transaction, %w", nextErr) + // No more transactions to apply + logger.Info( + "restore process finished", + "total", + totalTxs, + ) + + return nil } diff --git a/restore/restore_test.go b/restore/restore_test.go new file mode 100644 index 0000000..7940e13 --- /dev/null +++ b/restore/restore_test.go @@ -0,0 +1,63 @@ +package restore + +import ( + "context" + "io" + "testing" + + "github.com/gnolang/gno/tm2/pkg/std" + "github.com/gnolang/tx-archive/log/noop" + "github.com/stretchr/testify/assert" +) + +func TestRestore_ExecuteRestore(t *testing.T) { + t.Parallel() + + var ( + exampleTxCount = 10 + exampleTxGiven = 0 + + exampleTx = &std.Tx{ + Memo: "example tx", + } + + sentTxs = make([]*std.Tx, 0) + + mockClient = &mockClient{ + sendTransactionFn: func(tx *std.Tx) error { + sentTxs = append(sentTxs, tx) + + return nil + }, + } + mockSource = &mockSource{ + nextFn: func(ctx context.Context) (*std.Tx, error) { + if exampleTxGiven == exampleTxCount { + return nil, io.EOF + } + + exampleTxGiven++ + + return exampleTx, nil + }, + } + ) + + // Execute the restore + assert.NoError( + t, + ExecuteRestore( + context.Background(), + mockClient, + mockSource, + noop.New(), + ), + ) + + // Verify the restore was correct + assert.Len(t, sentTxs, exampleTxCount) + + for _, tx := range sentTxs { + assert.Equal(t, exampleTx, tx) + } +}