diff --git a/docs/relay/README.md b/docs/relay/README.md index f1dbffe81..07476babb 100644 --- a/docs/relay/README.md +++ b/docs/relay/README.md @@ -43,7 +43,8 @@ chainlink nodes solana create --name= --chain-id= --url=1: Uses multiple blocks, where n is number of blocks. DISCLAIMER: 1:1 ratio between n and RPC calls. + BlockHistorySize: ptr(uint64(15)), // 1: uses latest block; >1: Uses multiple blocks, where n is number of blocks. DISCLAIMER: 1:1 ratio between n and RPC calls. ComputeUnitLimitDefault: ptr(uint32(200_000)), // set to 0 to disable adding compute unit limit EstimateComputeUnitLimit: ptr(false), // set to false to disable compute unit limit estimation } @@ -43,6 +44,7 @@ type Config interface { TxTimeout() time.Duration TxRetryTimeout() time.Duration TxConfirmTimeout() time.Duration + TxExpirationRebroadcast() bool TxRetentionTimeout() time.Duration SkipPreflight() bool Commitment() rpc.CommitmentType @@ -68,6 +70,7 @@ type Chain struct { TxTimeout *config.Duration TxRetryTimeout *config.Duration TxConfirmTimeout *config.Duration + TxExpirationRebroadcast *bool TxRetentionTimeout *config.Duration SkipPreflight *bool Commitment *string @@ -105,6 +108,9 @@ func (c *Chain) SetDefaults() { if c.TxConfirmTimeout == nil { c.TxConfirmTimeout = defaultConfigSet.TxConfirmTimeout } + if c.TxExpirationRebroadcast == nil { + c.TxExpirationRebroadcast = defaultConfigSet.TxExpirationRebroadcast + } if c.TxRetentionTimeout == nil { c.TxRetentionTimeout = defaultConfigSet.TxRetentionTimeout } diff --git a/pkg/solana/config/mocks/config.go b/pkg/solana/config/mocks/config.go index 6f9ab913d..0ea855b0f 100644 --- a/pkg/solana/config/mocks/config.go +++ b/pkg/solana/config/mocks/config.go @@ -789,6 +789,51 @@ func (_c *Config_TxConfirmTimeout_Call) RunAndReturn(run func() time.Duration) * return _c } +// TxExpirationRebroadcast provides a mock function with given fields: +func (_m *Config) TxExpirationRebroadcast() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TxExpirationRebroadcast") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Config_TxExpirationRebroadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TxExpirationRebroadcast' +type Config_TxExpirationRebroadcast_Call struct { + *mock.Call +} + +// TxExpirationRebroadcast is a helper method to define mock.On call +func (_e *Config_Expecter) TxExpirationRebroadcast() *Config_TxExpirationRebroadcast_Call { + return &Config_TxExpirationRebroadcast_Call{Call: _e.mock.On("TxExpirationRebroadcast")} +} + +func (_c *Config_TxExpirationRebroadcast_Call) Run(run func()) *Config_TxExpirationRebroadcast_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Config_TxExpirationRebroadcast_Call) Return(_a0 bool) *Config_TxExpirationRebroadcast_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Config_TxExpirationRebroadcast_Call) RunAndReturn(run func() bool) *Config_TxExpirationRebroadcast_Call { + _c.Call.Return(run) + return _c +} + // TxRetentionTimeout provides a mock function with given fields: func (_m *Config) TxRetentionTimeout() time.Duration { ret := _m.Called() diff --git a/pkg/solana/config/toml.go b/pkg/solana/config/toml.go index 6e9eadc5d..9d46c0b65 100644 --- a/pkg/solana/config/toml.go +++ b/pkg/solana/config/toml.go @@ -155,6 +155,9 @@ func setFromChain(c, f *Chain) { if f.TxConfirmTimeout != nil { c.TxConfirmTimeout = f.TxConfirmTimeout } + if f.TxExpirationRebroadcast != nil { + c.TxExpirationRebroadcast = f.TxExpirationRebroadcast + } if f.TxRetentionTimeout != nil { c.TxRetentionTimeout = f.TxRetentionTimeout } @@ -241,6 +244,10 @@ func (c *TOMLConfig) TxConfirmTimeout() time.Duration { return c.Chain.TxConfirmTimeout.Duration() } +func (c *TOMLConfig) TxExpirationRebroadcast() bool { + return *c.Chain.TxExpirationRebroadcast +} + func (c *TOMLConfig) TxRetentionTimeout() time.Duration { return c.Chain.TxRetentionTimeout.Duration() } diff --git a/pkg/solana/relay.go b/pkg/solana/relay.go index 1f2fbdffd..fca61ba9f 100644 --- a/pkg/solana/relay.go +++ b/pkg/solana/relay.go @@ -24,7 +24,15 @@ import ( var _ TxManager = (*txm.Txm)(nil) type TxManager interface { - Enqueue(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...txm.SetTxConfig) error + // Enqueue adds a tx to the txm queue for processing and submitting to the Solana network. + // An error is returned if the txm is not ready, if the tx is invalid, or if the queue is full. + // + // Important Notes: + // - The tx must contain at least one account key. The first account will be used to sign the tx (fee payer's public key). + // - txCfgs can be used to set custom tx configurations. + // - If a txID is provided, it will be used to identify the tx. Otherwise, a random UUID will be generated. + // - The caller needs to set the tx.Message.RecentBlockhash and provide the corresponding lastValidBlockHeight. These values are obtained from the GetLatestBlockhash RPC call. + Enqueue(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, lastValidBlockHeight uint64, txCfgs ...txm.SetTxConfig) error } var _ relaytypes.Relayer = &Relayer{} //nolint:staticcheck diff --git a/pkg/solana/transmitter.go b/pkg/solana/transmitter.go index 951e9633e..537c72699 100644 --- a/pkg/solana/transmitter.go +++ b/pkg/solana/transmitter.go @@ -87,7 +87,7 @@ func (c *Transmitter) Transmit( // pass transmit payload to tx manager queue c.lggr.Debugf("Queuing transmit tx: state (%s) + transmissions (%s)", c.stateID.String(), c.transmissionsID.String()) - if err = c.txManager.Enqueue(ctx, c.stateID.String(), tx, nil); err != nil { + if err = c.txManager.Enqueue(ctx, c.stateID.String(), tx, nil, blockhash.Value.LastValidBlockHeight); err != nil { return fmt.Errorf("error on Transmit.txManager.Enqueue: %w", err) } return nil diff --git a/pkg/solana/transmitter_test.go b/pkg/solana/transmitter_test.go index 1d058d36a..f6db01d6c 100644 --- a/pkg/solana/transmitter_test.go +++ b/pkg/solana/transmitter_test.go @@ -27,7 +27,7 @@ type verifyTxSize struct { s *solana.PrivateKey } -func (txm verifyTxSize) Enqueue(_ context.Context, _ string, tx *solana.Transaction, txID *string, _ ...txm.SetTxConfig) error { +func (txm verifyTxSize) Enqueue(_ context.Context, _ string, tx *solana.Transaction, txID *string, _ uint64, _ ...txm.SetTxConfig) error { // additional components that transaction manager adds to the transaction require.NoError(txm.t, fees.SetComputeUnitPrice(tx, 0)) require.NoError(txm.t, fees.SetComputeUnitLimit(tx, 0)) diff --git a/pkg/solana/txm/pendingtx.go b/pkg/solana/txm/pendingtx.go index ecae7243b..b08039ab7 100644 --- a/pkg/solana/txm/pendingtx.go +++ b/pkg/solana/txm/pendingtx.go @@ -24,10 +24,13 @@ type PendingTxContext interface { New(msg pendingTx, sig solana.Signature, cancel context.CancelFunc) error // AddSignature adds a new signature for an existing transaction ID AddSignature(id string, sig solana.Signature) error - // Remove removes transaction and related signatures from storage if not in finalized or errored state - Remove(sig solana.Signature) (string, error) - // ListAll returns all of the signatures being tracked for all transactions not yet finalized or errored - ListAll() []solana.Signature + // Remove removes transaction, context and related signatures from storage associated to given tx id if not in finalized or errored state + Remove(id string) (string, error) + // ListAllSigs returns all of the signatures being tracked for all transactions not yet finalized or errored + ListAllSigs() []solana.Signature + // ListAllExpiredBroadcastedTxs returns all the txes that are in broadcasted state and have expired for given block number compared against lastValidBlockHeight (last valid block number) + // Passing maxUint64 as currBlockNumber will return all broadcasted txes. + ListAllExpiredBroadcastedTxs(currBlockNumber uint64) []pendingTx // Expired returns whether or not confirmation timeout amount of time has passed since creation Expired(sig solana.Signature, confirmationTimeout time.Duration) bool // OnProcessed marks transactions as Processed @@ -48,12 +51,13 @@ type PendingTxContext interface { // finishedTx is used to store info required to track transactions to finality or error type pendingTx struct { - tx solana.Transaction - cfg TxConfig - signatures []solana.Signature - id string - createTs time.Time - state TxState + tx solana.Transaction + cfg TxConfig + signatures []solana.Signature + id string + createTs time.Time + state TxState + lastValidBlockHeight uint64 // to track expiration, equivalent to last valid block number. } // finishedTx is used to store minimal info specifically for finalized or errored transactions for external status checks @@ -68,9 +72,9 @@ type pendingTxContext struct { cancelBy map[string]context.CancelFunc sigToID map[solana.Signature]string - broadcastedTxs map[string]pendingTx // transactions that require retry and bumping i.e broadcasted, processed - confirmedTxs map[string]pendingTx // transactions that require monitoring for re-org - finalizedErroredTxs map[string]finishedTx // finalized and errored transactions held onto for status + broadcastedProcessedTxs map[string]pendingTx // broadcasted and processed transactions that may require retry and bumping + confirmedTxs map[string]pendingTx // transactions that require monitoring for re-org + finalizedErroredTxs map[string]finishedTx // finalized and errored transactions held onto for status lock sync.RWMutex } @@ -80,9 +84,9 @@ func newPendingTxContext() *pendingTxContext { cancelBy: map[string]context.CancelFunc{}, sigToID: map[solana.Signature]string{}, - broadcastedTxs: map[string]pendingTx{}, - confirmedTxs: map[string]pendingTx{}, - finalizedErroredTxs: map[string]finishedTx{}, + broadcastedProcessedTxs: map[string]pendingTx{}, + confirmedTxs: map[string]pendingTx{}, + finalizedErroredTxs: map[string]finishedTx{}, } } @@ -92,8 +96,14 @@ func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel contex if _, exists := c.sigToID[sig]; exists { return ErrSigAlreadyExists } - // validate id does not exist - if _, exists := c.broadcastedTxs[tx.id]; exists { + // Check if ID already exists in any of the maps + if _, exists := c.broadcastedProcessedTxs[tx.id]; exists { + return ErrIDAlreadyExists + } + if _, exists := c.confirmedTxs[tx.id]; exists { + return ErrIDAlreadyExists + } + if _, exists := c.finalizedErroredTxs[tx.id]; exists { return ErrIDAlreadyExists } return nil @@ -107,7 +117,14 @@ func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel contex if _, exists := c.sigToID[sig]; exists { return "", ErrSigAlreadyExists } - if _, exists := c.broadcastedTxs[tx.id]; exists { + // Check if ID already exists in any of the maps + if _, exists := c.broadcastedProcessedTxs[tx.id]; exists { + return "", ErrIDAlreadyExists + } + if _, exists := c.confirmedTxs[tx.id]; exists { + return "", ErrIDAlreadyExists + } + if _, exists := c.finalizedErroredTxs[tx.id]; exists { return "", ErrIDAlreadyExists } // save cancel func @@ -118,7 +135,7 @@ func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel contex tx.createTs = time.Now() tx.state = Broadcasted // save to the broadcasted map since transaction was just broadcasted - c.broadcastedTxs[tx.id] = tx + c.broadcastedProcessedTxs[tx.id] = tx return "", nil }) return err @@ -132,7 +149,7 @@ func (c *pendingTxContext) AddSignature(id string, sig solana.Signature) error { } // new signatures should only be added for broadcasted transactions // otherwise, the transaction has transitioned states and no longer needs new signatures to track - if _, exists := c.broadcastedTxs[id]; !exists { + if _, exists := c.broadcastedProcessedTxs[id]; !exists { return ErrTransactionNotFound } return nil @@ -146,15 +163,15 @@ func (c *pendingTxContext) AddSignature(id string, sig solana.Signature) error { if _, exists := c.sigToID[sig]; exists { return "", ErrSigAlreadyExists } - if _, exists := c.broadcastedTxs[id]; !exists { + if _, exists := c.broadcastedProcessedTxs[id]; !exists { return "", ErrTransactionNotFound } c.sigToID[sig] = id - tx := c.broadcastedTxs[id] + tx := c.broadcastedProcessedTxs[id] // save new signature tx.signatures = append(tx.signatures, sig) // save updated tx to broadcasted map - c.broadcastedTxs[id] = tx + c.broadcastedProcessedTxs[id] = tx return "", nil }) return err @@ -162,14 +179,9 @@ func (c *pendingTxContext) AddSignature(id string, sig solana.Signature) error { // returns the id if removed (otherwise returns empty string) // removes transactions from any state except finalized and errored -func (c *pendingTxContext) Remove(sig solana.Signature) (id string, err error) { - err = c.withReadLock(func() error { - // check if already removed - id, sigExists := c.sigToID[sig] - if !sigExists { - return ErrSigDoesNotExist - } - _, broadcastedIDExists := c.broadcastedTxs[id] +func (c *pendingTxContext) Remove(id string) (string, error) { + err := c.withReadLock(func() error { + _, broadcastedIDExists := c.broadcastedProcessedTxs[id] _, confirmedIDExists := c.confirmedTxs[id] // transcation does not exist in tx maps if !broadcastedIDExists && !confirmedIDExists { @@ -183,14 +195,10 @@ func (c *pendingTxContext) Remove(sig solana.Signature) (id string, err error) { // upgrade to write lock if sig does not exist return c.withWriteLock(func() (string, error) { - id, sigExists := c.sigToID[sig] - if !sigExists { - return id, ErrSigDoesNotExist - } var tx pendingTx - if tempTx, exists := c.broadcastedTxs[id]; exists { + if tempTx, exists := c.broadcastedProcessedTxs[id]; exists { tx = tempTx - delete(c.broadcastedTxs, id) + delete(c.broadcastedProcessedTxs, id) } if tempTx, exists := c.confirmedTxs[id]; exists { tx = tempTx @@ -211,12 +219,26 @@ func (c *pendingTxContext) Remove(sig solana.Signature) (id string, err error) { }) } -func (c *pendingTxContext) ListAll() []solana.Signature { +func (c *pendingTxContext) ListAllSigs() []solana.Signature { c.lock.RLock() defer c.lock.RUnlock() return maps.Keys(c.sigToID) } +// ListAllExpiredBroadcastedTxs returns all the txes that are in broadcasted state and have expired for given block number compared against lastValidBlockHeight (last valid block number) +// Passing maxUint64 as currBlockNumber will return all broadcasted txes. +func (c *pendingTxContext) ListAllExpiredBroadcastedTxs(currBlockNumber uint64) []pendingTx { + c.lock.RLock() + defer c.lock.RUnlock() + expiredBroadcastedTxs := make([]pendingTx, 0, len(c.broadcastedProcessedTxs)) // worst case, all of them + for _, tx := range c.broadcastedProcessedTxs { + if tx.state == Broadcasted && tx.lastValidBlockHeight < currBlockNumber { + expiredBroadcastedTxs = append(expiredBroadcastedTxs, tx) + } + } + return expiredBroadcastedTxs +} + // Expired returns if the timeout for trying to confirm a signature has been reached func (c *pendingTxContext) Expired(sig solana.Signature, confirmationTimeout time.Duration) bool { c.lock.RLock() @@ -229,7 +251,7 @@ func (c *pendingTxContext) Expired(sig solana.Signature, confirmationTimeout tim if !exists { return false // return expired = false if timestamp does not exist (likely cleaned up by something else previously) } - if tx, exists := c.broadcastedTxs[id]; exists { + if tx, exists := c.broadcastedProcessedTxs[id]; exists { return time.Since(tx.createTs) > confirmationTimeout } if tx, exists := c.confirmedTxs[id]; exists { @@ -246,7 +268,7 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { return ErrSigDoesNotExist } // Transactions should only move to processed from broadcasted - tx, exists := c.broadcastedTxs[id] + tx, exists := c.broadcastedProcessedTxs[id] if !exists { return ErrTransactionNotFound } @@ -266,14 +288,14 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { if !sigExists { return id, ErrSigDoesNotExist } - tx, exists := c.broadcastedTxs[id] + tx, exists := c.broadcastedProcessedTxs[id] if !exists { return id, ErrTransactionNotFound } // update tx state to Processed tx.state = Processed // save updated tx back to the broadcasted map - c.broadcastedTxs[id] = tx + c.broadcastedProcessedTxs[id] = tx return id, nil }) } @@ -290,7 +312,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { return ErrAlreadyInExpectedState } // Transactions should only move to confirmed from broadcasted/processed - if _, exists := c.broadcastedTxs[id]; !exists { + if _, exists := c.broadcastedProcessedTxs[id]; !exists { return ErrTransactionNotFound } return nil @@ -305,7 +327,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { if !sigExists { return id, ErrSigDoesNotExist } - tx, exists := c.broadcastedTxs[id] + tx, exists := c.broadcastedProcessedTxs[id] if !exists { return id, ErrTransactionNotFound } @@ -319,7 +341,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { // move tx to confirmed map c.confirmedTxs[id] = tx // remove tx from broadcasted map - delete(c.broadcastedTxs, id) + delete(c.broadcastedProcessedTxs, id) return id, nil }) } @@ -331,7 +353,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti return ErrSigDoesNotExist } // Allow transactions to transition from broadcasted, processed, or confirmed state in case there are delays between status checks - _, broadcastedExists := c.broadcastedTxs[id] + _, broadcastedExists := c.broadcastedProcessedTxs[id] _, confirmedExists := c.confirmedTxs[id] if !broadcastedExists && !confirmedExists { return ErrTransactionNotFound @@ -350,7 +372,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti } var tx, tempTx pendingTx var broadcastedExists, confirmedExists bool - if tempTx, broadcastedExists = c.broadcastedTxs[id]; broadcastedExists { + if tempTx, broadcastedExists = c.broadcastedProcessedTxs[id]; broadcastedExists { tx = tempTx } if tempTx, confirmedExists = c.confirmedTxs[id]; confirmedExists { @@ -366,7 +388,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti delete(c.cancelBy, id) } // delete from broadcasted map, if exists - delete(c.broadcastedTxs, id) + delete(c.broadcastedProcessedTxs, id) // delete from confirmed map, if exists delete(c.confirmedTxs, id) // remove all related signatures from the sigToID map to skip picking up this tx in the confirmation logic @@ -397,7 +419,7 @@ func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time. if tx, exists := c.finalizedErroredTxs[id]; exists && tx.state == txState { return ErrAlreadyInExpectedState } - _, broadcastedExists := c.broadcastedTxs[id] + _, broadcastedExists := c.broadcastedProcessedTxs[id] _, confirmedExists := c.confirmedTxs[id] if broadcastedExists || confirmedExists { return ErrIDAlreadyExists @@ -410,10 +432,11 @@ func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time. // upgrade to write lock if id does not exist in other maps and is not in expected state already _, err = c.withWriteLock(func() (string, error) { - if tx, exists := c.finalizedErroredTxs[id]; exists && tx.state == txState { + tx, exists := c.finalizedErroredTxs[id] + if exists && tx.state == txState { return "", ErrAlreadyInExpectedState } - _, broadcastedExists := c.broadcastedTxs[id] + _, broadcastedExists := c.broadcastedProcessedTxs[id] _, confirmedExists := c.confirmedTxs[id] if broadcastedExists || confirmedExists { return "", ErrIDAlreadyExists @@ -437,7 +460,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D } // transaction can transition from any non-finalized state var broadcastedExists, confirmedExists bool - _, broadcastedExists = c.broadcastedTxs[id] + _, broadcastedExists = c.broadcastedProcessedTxs[id] _, confirmedExists = c.confirmedTxs[id] // transcation does not exist in any tx maps if !broadcastedExists && !confirmedExists { @@ -457,7 +480,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D } var tx, tempTx pendingTx var broadcastedExists, confirmedExists bool - if tempTx, broadcastedExists = c.broadcastedTxs[id]; broadcastedExists { + if tempTx, broadcastedExists = c.broadcastedProcessedTxs[id]; broadcastedExists { tx = tempTx } if tempTx, confirmedExists = c.confirmedTxs[id]; confirmedExists { @@ -473,7 +496,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D delete(c.cancelBy, id) } // delete from broadcasted map, if exists - delete(c.broadcastedTxs, id) + delete(c.broadcastedProcessedTxs, id) // delete from confirmed map, if exists delete(c.confirmedTxs, id) // remove all related signatures from the sigToID map to skip picking up this tx in the confirmation logic @@ -497,7 +520,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D func (c *pendingTxContext) GetTxState(id string) (TxState, error) { c.lock.RLock() defer c.lock.RUnlock() - if tx, exists := c.broadcastedTxs[id]; exists { + if tx, exists := c.broadcastedProcessedTxs[id]; exists { return tx.state, nil } if tx, exists := c.confirmedTxs[id]; exists { @@ -594,16 +617,20 @@ func (c *pendingTxContextWithProm) OnConfirmed(sig solana.Signature) (string, er return id, err } -func (c *pendingTxContextWithProm) Remove(sig solana.Signature) (string, error) { - return c.pendingTx.Remove(sig) +func (c *pendingTxContextWithProm) Remove(id string) (string, error) { + return c.pendingTx.Remove(id) } -func (c *pendingTxContextWithProm) ListAll() []solana.Signature { - sigs := c.pendingTx.ListAll() +func (c *pendingTxContextWithProm) ListAllSigs() []solana.Signature { + sigs := c.pendingTx.ListAllSigs() promSolTxmPendingTxs.WithLabelValues(c.chainID).Set(float64(len(sigs))) return sigs } +func (c *pendingTxContextWithProm) ListAllExpiredBroadcastedTxs(currBlockNumber uint64) []pendingTx { + return c.pendingTx.ListAllExpiredBroadcastedTxs(currBlockNumber) +} + func (c *pendingTxContextWithProm) Expired(sig solana.Signature, lifespan time.Duration) bool { return c.pendingTx.Expired(sig, lifespan) } diff --git a/pkg/solana/txm/pendingtx_test.go b/pkg/solana/txm/pendingtx_test.go index e7b7fc51e..a79f9f7aa 100644 --- a/pkg/solana/txm/pendingtx_test.go +++ b/pkg/solana/txm/pendingtx_test.go @@ -48,19 +48,21 @@ func TestPendingTxContext_add_remove_multiple(t *testing.T) { // cannot add signature for non existent ID require.Error(t, txs.AddSignature(uuid.New().String(), solana.Signature{})) - // return list of signatures - list := txs.ListAll() + list := make([]string, 0, n) + for _, id := range txs.sigToID { + list = append(list, id) + } assert.Equal(t, n, len(list)) // stop all sub processes for i := 0; i < len(list); i++ { - id, err := txs.Remove(list[i]) + txID := list[i] + _, err := txs.Remove(txID) assert.NoError(t, err) - assert.Equal(t, n-i-1, len(txs.ListAll())) - assert.Equal(t, ids[list[i]], id) + assert.Equal(t, n-i-1, len(txs.ListAllSigs())) // second remove should not return valid id - already removed - id, err = txs.Remove(list[i]) + id, err := txs.Remove(txID) require.Error(t, err) assert.Equal(t, "", id) } @@ -76,29 +78,55 @@ func TestPendingTxContext_new(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} err := txs.New(msg, sig, cancel) - require.NoError(t, err) + require.NoError(t, err, "expected no error when adding a new transaction") - // Check it exists in signature map + // Check it exists in signature map and mapped to the correct txID id, exists := txs.sigToID[sig] - require.True(t, exists) - require.Equal(t, msg.id, id) + require.True(t, exists, "signature should exist in sigToID map") + require.Equal(t, msg.id, id, "signature should map to correct transaction ID") - // Check it exists in broadcasted map - tx, exists := txs.broadcastedTxs[msg.id] - require.True(t, exists) - require.Len(t, tx.signatures, 1) - require.Equal(t, sig, tx.signatures[0]) + // Check it exists in broadcasted map and that sigs match + tx, exists := txs.broadcastedProcessedTxs[msg.id] + require.True(t, exists, "transaction should exist in broadcastedProcessedTxs map") + require.Len(t, tx.signatures, 1, "transaction should have one signature") + require.Equal(t, sig, tx.signatures[0], "signature should match") // Check status is Broadcasted - require.Equal(t, Broadcasted, tx.state) + require.Equal(t, Broadcasted, tx.state, "transaction state should be Broadcasted") - // Check it does not exist in confirmed map + // Check it does not exist in confirmed nor finalized maps _, exists = txs.confirmedTxs[msg.id] - require.False(t, exists) - - // Check it does not exist in finalized map + require.False(t, exists, "transaction should not exist in confirmedTxs map") _, exists = txs.finalizedErroredTxs[msg.id] - require.False(t, exists) + require.False(t, exists, "transaction should not exist in finalizedErroredTxs map") + + // Attempt to add the same transaction again with the same signature + err = txs.New(msg, sig, cancel) + require.ErrorIs(t, err, ErrSigAlreadyExists, "expected ErrSigAlreadyExists when adding duplicate signature") + + // Attempt to add a new transaction with the same transaction ID but different signature + err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding duplicate transaction ID") + + // Attempt to add a new transaction with a different transaction ID but same signature + err = txs.New(pendingTx{id: uuid.NewString()}, sig, cancel) + require.ErrorIs(t, err, ErrSigAlreadyExists, "expected ErrSigAlreadyExists when adding duplicate signature") + + // Simulate moving the transaction to confirmedTxs map + _, err = txs.OnConfirmed(sig) + require.NoError(t, err, "expected no error when confirming transaction") + + // Attempt to add a new transaction with the same ID (now in confirmedTxs) and new signature + err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding transaction ID that exists in confirmedTxs") + + // Simulate moving the transaction to finalizedErroredTxs map + _, err = txs.OnFinalized(sig, 10*time.Second) + require.NoError(t, err, "expected no error when finalizing transaction") + + // Attempt to add a new transaction with the same ID (now in finalizedErroredTxs) and new signature + err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding transaction ID that exists in finalizedErroredTxs") } func TestPendingTxContext_add_signature(t *testing.T) { @@ -127,7 +155,7 @@ func TestPendingTxContext_add_signature(t *testing.T) { require.Equal(t, msg.id, id) // Check broadcasted map - tx, exists := txs.broadcastedTxs[msg.id] + tx, exists := txs.broadcastedProcessedTxs[msg.id] require.True(t, exists) require.Len(t, tx.signatures, 2) require.Equal(t, sig1, tx.signatures[0]) @@ -216,7 +244,7 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.Equal(t, msg.id, id) // Check it exists in broadcasted map - tx, exists := txs.broadcastedTxs[msg.id] + tx, exists := txs.broadcastedProcessedTxs[msg.id] require.True(t, exists) require.Len(t, tx.signatures, 1) require.Equal(t, sig, tx.signatures[0]) @@ -351,7 +379,7 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists = txs.broadcastedTxs[msg.id] + _, exists = txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it exists in confirmed map @@ -463,7 +491,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -513,7 +541,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -558,7 +586,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -613,7 +641,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -651,7 +679,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -684,7 +712,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it exists in errored map @@ -718,7 +746,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -825,22 +853,27 @@ func TestPendingTxContext_remove(t *testing.T) { txs := newPendingTxContext() retentionTimeout := 5 * time.Second + broadcastedID := uuid.NewString() broadcastedSig1 := randomSignature(t) broadcastedSig2 := randomSignature(t) + processedID := uuid.NewString() processedSig := randomSignature(t) + confirmedID := uuid.NewString() confirmedSig := randomSignature(t) + finalizedID := uuid.NewString() finalizedSig := randomSignature(t) + erroredID := uuid.NewString() erroredSig := randomSignature(t) // Create new broadcasted transaction with extra sig - broadcastedMsg := pendingTx{id: uuid.NewString()} + broadcastedMsg := pendingTx{id: broadcastedID} err := txs.New(broadcastedMsg, broadcastedSig1, cancel) require.NoError(t, err) err = txs.AddSignature(broadcastedMsg.id, broadcastedSig2) require.NoError(t, err) // Create new processed transaction - processedMsg := pendingTx{id: uuid.NewString()} + processedMsg := pendingTx{id: processedID} err = txs.New(processedMsg, processedSig, cancel) require.NoError(t, err) id, err := txs.OnProcessed(processedSig) @@ -848,7 +881,7 @@ func TestPendingTxContext_remove(t *testing.T) { require.Equal(t, processedMsg.id, id) // Create new confirmed transaction - confirmedMsg := pendingTx{id: uuid.NewString()} + confirmedMsg := pendingTx{id: confirmedID} err = txs.New(confirmedMsg, confirmedSig, cancel) require.NoError(t, err) id, err = txs.OnConfirmed(confirmedSig) @@ -856,7 +889,7 @@ func TestPendingTxContext_remove(t *testing.T) { require.Equal(t, confirmedMsg.id, id) // Create new finalized transaction - finalizedMsg := pendingTx{id: uuid.NewString()} + finalizedMsg := pendingTx{id: finalizedID} err = txs.New(finalizedMsg, finalizedSig, cancel) require.NoError(t, err) id, err = txs.OnFinalized(finalizedSig, retentionTimeout) @@ -864,7 +897,7 @@ func TestPendingTxContext_remove(t *testing.T) { require.Equal(t, finalizedMsg.id, id) // Create new errored transaction - erroredMsg := pendingTx{id: uuid.NewString()} + erroredMsg := pendingTx{id: erroredID} err = txs.New(erroredMsg, erroredSig, cancel) require.NoError(t, err) id, err = txs.OnError(erroredSig, retentionTimeout, Errored, 0) @@ -872,11 +905,11 @@ func TestPendingTxContext_remove(t *testing.T) { require.Equal(t, erroredMsg.id, id) // Remove broadcasted transaction - id, err = txs.Remove(broadcastedSig1) + id, err = txs.Remove(broadcastedID) require.NoError(t, err) require.Equal(t, broadcastedMsg.id, id) // Check removed from broadcasted map - _, exists := txs.broadcastedTxs[broadcastedMsg.id] + _, exists := txs.broadcastedProcessedTxs[broadcastedMsg.id] require.False(t, exists) // Check all signatures removed from sig map _, exists = txs.sigToID[broadcastedSig1] @@ -885,18 +918,18 @@ func TestPendingTxContext_remove(t *testing.T) { require.False(t, exists) // Remove processed transaction - id, err = txs.Remove(processedSig) + id, err = txs.Remove(processedID) require.NoError(t, err) require.Equal(t, processedMsg.id, id) // Check removed from broadcasted map - _, exists = txs.broadcastedTxs[processedMsg.id] + _, exists = txs.broadcastedProcessedTxs[processedMsg.id] require.False(t, exists) // Check all signatures removed from sig map _, exists = txs.sigToID[processedSig] require.False(t, exists) // Remove confirmed transaction - id, err = txs.Remove(confirmedSig) + id, err = txs.Remove(confirmedID) require.NoError(t, err) require.Equal(t, confirmedMsg.id, id) // Check removed from confirmed map @@ -907,17 +940,17 @@ func TestPendingTxContext_remove(t *testing.T) { require.False(t, exists) // Check remove cannot be called on finalized transaction - id, err = txs.Remove(finalizedSig) + id, err = txs.Remove(finalizedID) require.Error(t, err) require.Equal(t, "", id) // Check remove cannot be called on errored transaction - id, err = txs.Remove(erroredSig) + id, err = txs.Remove(erroredID) require.Error(t, err) require.Equal(t, "", id) // Check sig list is empty after all removals - require.Empty(t, txs.ListAll()) + require.Empty(t, txs.ListAllSigs()) } func TestPendingTxContext_trim_finalized_errored_txs(t *testing.T) { t.Parallel() @@ -959,23 +992,24 @@ func TestPendingTxContext_expired(t *testing.T) { _, cancel := context.WithCancel(tests.Context(t)) sig := solana.Signature{} txs := newPendingTxContext() + txID := uuid.NewString() - msg := pendingTx{id: uuid.NewString()} + msg := pendingTx{id: txID} err := txs.New(msg, sig, cancel) assert.NoError(t, err) - msg, exists := txs.broadcastedTxs[msg.id] + msg, exists := txs.broadcastedProcessedTxs[msg.id] require.True(t, exists) // Set createTs to 10 seconds ago msg.createTs = time.Now().Add(-10 * time.Second) - txs.broadcastedTxs[msg.id] = msg + txs.broadcastedProcessedTxs[msg.id] = msg assert.False(t, txs.Expired(sig, 0*time.Second)) // false if timeout 0 assert.True(t, txs.Expired(sig, 5*time.Second)) // expired for 5s lifetime assert.False(t, txs.Expired(sig, 60*time.Second)) // not expired for 60s lifetime - id, err := txs.Remove(sig) + id, err := txs.Remove(txID) assert.NoError(t, err) assert.Equal(t, msg.id, id) assert.False(t, txs.Expired(sig, 60*time.Second)) // no longer exists, should return false @@ -1025,18 +1059,19 @@ func TestPendingTxContext_race(t *testing.T) { t.Run("remove", func(t *testing.T) { txCtx := newPendingTxContext() - msg := pendingTx{id: uuid.NewString()} + txID := uuid.NewString() + msg := pendingTx{id: txID} err := txCtx.New(msg, solana.Signature{}, func() {}) require.NoError(t, err) var wg sync.WaitGroup wg.Add(2) go func() { - assert.NotPanics(t, func() { txCtx.Remove(solana.Signature{}) }) //nolint // no need to check error + assert.NotPanics(t, func() { txCtx.Remove(txID) }) //nolint // no need to check error wg.Done() }() go func() { - assert.NotPanics(t, func() { txCtx.Remove(solana.Signature{}) }) //nolint // no need to check error + assert.NotPanics(t, func() { txCtx.Remove(txID) }) //nolint // no need to check error wg.Done() }() @@ -1137,3 +1172,157 @@ func randomSignature(t *testing.T) solana.Signature { return solana.SignatureFromBytes(sig) } + +func TestPendingTxContext_ListAllExpiredBroadcastedTxs(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T, ctx *pendingTxContext) + currBlockHeight uint64 + expectedTxIDs []string + }{ + { + name: "No broadcasted transactions", + setup: func(t *testing.T, ctx *pendingTxContext) { + // No setup needed; broadcastedProcessedTxs remains empty + }, + currBlockHeight: 1000, + expectedTxIDs: []string{}, + }, + { + name: "No expired broadcasted transactions", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 1500, + } + tx2 := pendingTx{ + id: "tx2", + state: Broadcasted, + lastValidBlockHeight: 1600, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + }, + currBlockHeight: 1400, + expectedTxIDs: []string{}, + }, + { + name: "Some expired broadcasted transactions", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 1000, + } + tx2 := pendingTx{ + id: "tx2", + state: Broadcasted, + lastValidBlockHeight: 1500, + } + tx3 := pendingTx{ + id: "tx3", + state: Broadcasted, + lastValidBlockHeight: 900, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + ctx.broadcastedProcessedTxs["tx3"] = tx3 + }, + currBlockHeight: 1200, + expectedTxIDs: []string{"tx1", "tx3"}, + }, + { + name: "All broadcasted transactions expired with maxUint64", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 1000, + } + tx2 := pendingTx{ + id: "tx2", + state: Broadcasted, + lastValidBlockHeight: 1500, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + }, + currBlockHeight: ^uint64(0), // maxUint64 + expectedTxIDs: []string{"tx1", "tx2"}, + }, + { + name: "Only broadcasted transactions are considered", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 800, + } + tx2 := pendingTx{ + id: "tx2", + state: Processed, // Not Broadcasted + lastValidBlockHeight: 700, + } + tx3 := pendingTx{ + id: "tx3", + state: Processed, // Not Broadcasted + lastValidBlockHeight: 600, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + ctx.broadcastedProcessedTxs["tx3"] = tx3 + }, + currBlockHeight: 900, + expectedTxIDs: []string{"tx1"}, + }, + { + name: "Broadcasted transactions with edge block heights", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 1000, + } + tx2 := pendingTx{ + id: "tx2", + state: Broadcasted, + lastValidBlockHeight: 999, + } + tx3 := pendingTx{ + id: "tx3", + state: Broadcasted, + lastValidBlockHeight: 1, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + ctx.broadcastedProcessedTxs["tx3"] = tx3 + }, + currBlockHeight: 1000, + expectedTxIDs: []string{"tx2", "tx3"}, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + // Initialize a new PendingTxContext + ctx := newPendingTxContext() + + // Setup the test case + tt.setup(t, ctx) + + // Execute the function under test + result := ctx.ListAllExpiredBroadcastedTxs(tt.currBlockHeight) + + // Extract the IDs from the result + var resultIDs []string + for _, tx := range result { + resultIDs = append(resultIDs, tx.id) + } + + // Assert that the expected IDs match the result IDs (order does not matter) + assert.ElementsMatch(t, tt.expectedTxIDs, resultIDs) + }) + } +} diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index 10cc1acd2..3e169d88a 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -142,6 +142,10 @@ func (txm *Txm) Start(ctx context.Context) error { }) } +// run is a goroutine that continuously processes transactions from the chSend channel. +// It attempts to send each transaction with retry logic and, upon success, enqueues the transaction for simulation. +// If a transaction fails to send, it logs the error and resets the client to handle potential bad RPCs. +// The function runs until the chStop channel signals to stop. func (txm *Txm) run() { defer txm.done.Done() ctx, cancel := txm.chStop.NewCtx() @@ -175,197 +179,198 @@ func (txm *Txm) run() { } } +// sendWithRetry attempts to send a transaction with exponential backoff retry logic. +// It builds, signs, sends the initial tx, and starts a retry routine with fee bumping if needed. +// The function returns the signed transaction, its ID, and the initial signature for use in simulation. func (txm *Txm) sendWithRetry(ctx context.Context, msg pendingTx) (solanaGo.Transaction, string, solanaGo.Signature, error) { - // get key - // fee payer account is index 0 account - // https://github.com/gagliardetto/solana-go/blob/main/transaction.go#L252 - key := msg.tx.Message.AccountKeys[0].String() - - // base compute unit price should only be calculated once - // prevent underlying base changing when bumping (could occur with RPC based estimation) - getFee := func(count int) fees.ComputeUnitPrice { - fee := fees.CalculateFee( - msg.cfg.BaseComputeUnitPrice, - msg.cfg.ComputeUnitPriceMax, - msg.cfg.ComputeUnitPriceMin, - uint(count), //nolint:gosec // reasonable number of bumps should never cause overflow - ) - return fees.ComputeUnitPrice(fee) - } - - baseTx := msg.tx - - // add compute unit limit instruction - static for the transaction - // skip if compute unit limit = 0 (otherwise would always fail) - if msg.cfg.ComputeUnitLimit != 0 { - if computeUnitLimitErr := fees.SetComputeUnitLimit(&baseTx, fees.ComputeUnitLimit(msg.cfg.ComputeUnitLimit)); computeUnitLimitErr != nil { - return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to add compute unit limit instruction: %w", computeUnitLimitErr) - } - } - - buildTx := func(ctx context.Context, base solanaGo.Transaction, retryCount int) (solanaGo.Transaction, error) { - newTx := base // make copy - - // set fee - // fee bumping can be enabled by moving the setting & signing logic to the broadcaster - if computeUnitErr := fees.SetComputeUnitPrice(&newTx, getFee(retryCount)); computeUnitErr != nil { - return solanaGo.Transaction{}, computeUnitErr - } - - // sign tx - txMsg, marshalErr := newTx.Message.MarshalBinary() - if marshalErr != nil { - return solanaGo.Transaction{}, fmt.Errorf("error in soltxm.SendWithRetry.MarshalBinary: %w", marshalErr) - } - sigBytes, signErr := txm.ks.Sign(ctx, key, txMsg) - if signErr != nil { - return solanaGo.Transaction{}, fmt.Errorf("error in soltxm.SendWithRetry.Sign: %w", signErr) - } - var finalSig [64]byte - copy(finalSig[:], sigBytes) - newTx.Signatures = append(newTx.Signatures, finalSig) - - return newTx, nil - } - - initTx, initBuildErr := buildTx(ctx, baseTx, 0) - if initBuildErr != nil { - return solanaGo.Transaction{}, "", solanaGo.Signature{}, initBuildErr + // Build and sign initial transaction setting compute unit price and limit + initTx, err := txm.buildTx(ctx, msg, 0) + if err != nil { + return solanaGo.Transaction{}, "", solanaGo.Signature{}, err } - // create timeout context + // Send initial transaction ctx, cancel := context.WithTimeout(ctx, msg.cfg.Timeout) - - // send initial tx (do not retry and exit early if fails) sig, initSendErr := txm.sendTx(ctx, &initTx) if initSendErr != nil { - cancel() // cancel context when exiting early + // Do not retry and exit early if fails + cancel() stateTransitionErr := txm.txs.OnPrebroadcastError(msg.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject) return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("tx failed initial transmit: %w", errors.Join(initSendErr, stateTransitionErr)) } - // store tx signature + cancel function - initStoreErr := txm.txs.New(msg, sig, cancel) - if initStoreErr != nil { - cancel() // cancel context when exiting early - return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save tx signature (%s) to inflight txs: %w", sig, initStoreErr) + // Store tx signature and cancel function + if err := txm.txs.New(msg, sig, cancel); err != nil { + cancel() // Cancel context when exiting early + return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save tx signature (%s) to inflight txs: %w", sig, err) } - // used for tracking rebroadcasting only in SendWithRetry - var sigs signatureList + txm.lggr.Debugw("tx initial broadcast", "id", msg.id, "fee", msg.cfg.BaseComputeUnitPrice, "signature", sig, "lastValidBlockHeight", msg.lastValidBlockHeight) + + // Initialize signature list with initialTx signature. This list will be used to add new signatures and track retry attempts. + sigs := &signatureList{} sigs.Allocate() if initSetErr := sigs.Set(0, sig); initSetErr != nil { return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save initial signature in signature list: %w", initSetErr) } - txm.lggr.Debugw("tx initial broadcast", "id", msg.id, "fee", getFee(0), "signature", sig) - + // pass in copy of msg (to build new tx with bumped fee) and broadcasted tx == initTx (to retry tx without bumping) txm.done.Add(1) - // retry with exponential backoff - // until context cancelled by timeout or called externally - // pass in copy of baseTx (used to build new tx with bumped fee) and broadcasted tx == initTx (used to retry tx without bumping) - go func(ctx context.Context, baseTx, currentTx solanaGo.Transaction) { + go func() { defer txm.done.Done() - deltaT := 1 // ms - tick := time.After(0) - bumpCount := 0 - bumpTime := time.Now() - var wg sync.WaitGroup + txm.retryTx(ctx, msg, initTx, sigs) + }() - for { - select { - case <-ctx.Done(): - // stop sending tx after retry tx ctx times out (does not stop confirmation polling for tx) - wg.Wait() - txm.lggr.Debugw("stopped tx retry", "id", msg.id, "signatures", sigs.List(), "err", context.Cause(ctx)) - return - case <-tick: - var shouldBump bool - // bump if period > 0 and past time - if msg.cfg.FeeBumpPeriod != 0 && time.Since(bumpTime) > msg.cfg.FeeBumpPeriod { - bumpCount++ - bumpTime = time.Now() - shouldBump = true - } + // Return signed tx, id, signature for use in simulation + return initTx, msg.id, sig, nil +} - // if fee should be bumped, build new tx and replace currentTx - if shouldBump { - var retryBuildErr error - currentTx, retryBuildErr = buildTx(ctx, baseTx, bumpCount) - if retryBuildErr != nil { - txm.lggr.Errorw("failed to build bumped retry tx", "error", retryBuildErr, "id", msg.id) - return // exit func if cannot build tx for retrying - } - ind := sigs.Allocate() - if ind != bumpCount { - txm.lggr.Errorw("INVARIANT VIOLATION: index (%d) != bumpCount (%d)", ind, bumpCount) - return - } - } +// buildTx builds and signs the transaction with the appropriate compute unit price. +func (txm *Txm) buildTx(ctx context.Context, msg pendingTx, retryCount int) (solanaGo.Transaction, error) { + // work with a copy + newTx := msg.tx - // take currentTx and broadcast, if bumped fee -> save signature to list - wg.Add(1) - go func(bump bool, count int, retryTx solanaGo.Transaction) { - defer wg.Done() - - retrySig, retrySendErr := txm.sendTx(ctx, &retryTx) - // this could occur if endpoint goes down or if ctx cancelled - if retrySendErr != nil { - if strings.Contains(retrySendErr.Error(), "context canceled") || strings.Contains(retrySendErr.Error(), "context deadline exceeded") { - txm.lggr.Debugw("ctx error on send retry transaction", "error", retrySendErr, "signatures", sigs.List(), "id", msg.id) - } else { - txm.lggr.Warnw("failed to send retry transaction", "error", retrySendErr, "signatures", sigs.List(), "id", msg.id) - } - return - } - - // save new signature if fee bumped - if bump { - if retryStoreErr := txm.txs.AddSignature(msg.id, retrySig); retryStoreErr != nil { - txm.lggr.Warnw("error in adding retry transaction", "error", retryStoreErr, "id", msg.id) - return - } - if setErr := sigs.Set(count, retrySig); setErr != nil { - // this should never happen - txm.lggr.Errorw("INVARIANT VIOLATION", "error", setErr) - } - txm.lggr.Debugw("tx rebroadcast with bumped fee", "id", msg.id, "fee", getFee(count), "signatures", sigs.List()) - } - - // prevent locking on waitgroup when ctx is closed - wait := make(chan struct{}) - go func() { - defer close(wait) - sigs.Wait(count) // wait until bump tx has set the tx signature to compare rebroadcast signatures - }() - select { - case <-ctx.Done(): - return - case <-wait: - } - - // this should never happen (should match the signature saved to sigs) - if fetchedSig, fetchErr := sigs.Get(count); fetchErr != nil || retrySig != fetchedSig { - txm.lggr.Errorw("original signature does not match retry signature", "expectedSignatures", sigs.List(), "receivedSignature", retrySig, "error", fetchErr) - } - }(shouldBump, bumpCount, currentTx) - } + // Set compute unit limit if specified + if msg.cfg.ComputeUnitLimit != 0 { + if err := fees.SetComputeUnitLimit(&newTx, fees.ComputeUnitLimit(msg.cfg.ComputeUnitLimit)); err != nil { + return solanaGo.Transaction{}, fmt.Errorf("failed to add compute unit limit instruction: %w", err) + } + } + + // Set compute unit price (fee) + fee := fees.ComputeUnitPrice( + fees.CalculateFee( + msg.cfg.BaseComputeUnitPrice, + msg.cfg.ComputeUnitPriceMax, + msg.cfg.ComputeUnitPriceMin, + uint(retryCount), //nolint:gosec // reasonable number of bumps should never cause overflow + )) + if err := fees.SetComputeUnitPrice(&newTx, fee); err != nil { + return solanaGo.Transaction{}, err + } + + // Sign transaction + // NOTE: fee payer account is index 0 account. https://github.com/gagliardetto/solana-go/blob/main/transaction.go#L252 + txMsg, err := newTx.Message.MarshalBinary() + if err != nil { + return solanaGo.Transaction{}, fmt.Errorf("error in MarshalBinary: %w", err) + } + sigBytes, err := txm.ks.Sign(ctx, msg.tx.Message.AccountKeys[0].String(), txMsg) + if err != nil { + return solanaGo.Transaction{}, fmt.Errorf("error in Sign: %w", err) + } + var finalSig [64]byte + copy(finalSig[:], sigBytes) + newTx.Signatures = append(newTx.Signatures, finalSig) + + return newTx, nil +} - // exponential increase in wait time, capped at 250ms - deltaT *= 2 - if deltaT > MaxRetryTimeMs { - deltaT = MaxRetryTimeMs +// retryTx contains the logic for retrying the transaction, including exponential backoff and fee bumping. +// Retries until context cancelled by timeout or called externally. +// It uses handleRetry helper function to handle each retry attempt. +func (txm *Txm) retryTx(ctx context.Context, msg pendingTx, currentTx solanaGo.Transaction, sigs *signatureList) { + deltaT := 1 // initial delay in ms + tick := time.After(0) + bumpCount := 0 + bumpTime := time.Now() + var wg sync.WaitGroup + + for { + select { + case <-ctx.Done(): + // stop sending tx after retry tx ctx times out (does not stop confirmation polling for tx) + wg.Wait() + txm.lggr.Debugw("stopped tx retry", "id", msg.id, "signatures", sigs.List(), "err", context.Cause(ctx)) + return + case <-tick: + // determines whether the fee should be bumped based on the fee bump period. + shouldBump := msg.cfg.FeeBumpPeriod != 0 && time.Since(bumpTime) > msg.cfg.FeeBumpPeriod + if shouldBump { + bumpCount++ + bumpTime = time.Now() + // Build new transaction with bumped fee and replace current tx + var err error + currentTx, err = txm.buildTx(ctx, msg, bumpCount) + if err != nil { + // Exit if unable to build transaction for retrying + txm.lggr.Errorw("failed to build bumped retry tx", "error", err, "id", msg.id) + return + } + // allocates space for new signature that will be introduced in handleRetry if needs bumping. + index := sigs.Allocate() + if index != bumpCount { + txm.lggr.Errorw("invariant violation: index does not match bumpCount", "index", index, "bumpCount", bumpCount) + return + } } - tick = time.After(time.Duration(deltaT) * time.Millisecond) + + // Start a goroutine to handle the retry attempt + // takes currentTx and rebroadcast. If needs bumping it will new signature to already allocated space in signatureList. + wg.Add(1) + go func(bump bool, count int, retryTx solanaGo.Transaction) { + defer wg.Done() + txm.handleRetry(ctx, msg, bump, count, retryTx, sigs) + }(shouldBump, bumpCount, currentTx) } - }(ctx, baseTx, initTx) - // return signed tx, id, signature for use in simulation - return initTx, msg.id, sig, nil + // updates the exponential backoff delay up to a maximum limit. + deltaT = deltaT * 2 + if deltaT > MaxRetryTimeMs { + deltaT = MaxRetryTimeMs + } + tick = time.After(time.Duration(deltaT) * time.Millisecond) + } +} + +// handleRetry handles the logic for each retry attempt, including sending the transaction, updating signatures, and logging. +func (txm *Txm) handleRetry(ctx context.Context, msg pendingTx, bump bool, count int, retryTx solanaGo.Transaction, sigs *signatureList) { + // send retry transaction + retrySig, err := txm.sendTx(ctx, &retryTx) + if err != nil { + // this could occur if endpoint goes down or if ctx cancelled + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + txm.lggr.Debugw("ctx error on send retry transaction", "error", err, "signatures", sigs.List(), "id", msg.id) + } else { + txm.lggr.Warnw("failed to send retry transaction", "error", err, "signatures", sigs.List(), "id", msg.id) + } + return + } + + // if bump is true, update signature list and set new signature in space already allocated. + if bump { + if err := txm.txs.AddSignature(msg.id, retrySig); err != nil { + txm.lggr.Warnw("error in adding retry transaction", "error", err, "id", msg.id) + return + } + if err := sigs.Set(count, retrySig); err != nil { + // this should never happen + txm.lggr.Errorw("INVARIANT VIOLATION: failed to set signature", "error", err, "id", msg.id) + return + } + txm.lggr.Debugw("tx rebroadcast with bumped fee", "id", msg.id, "retryCount", count, "fee", msg.cfg.BaseComputeUnitPrice, "signatures", sigs.List()) + } + + // prevent locking on waitgroup when ctx is closed + wait := make(chan struct{}) + go func() { + defer close(wait) + sigs.Wait(count) // wait until bump tx has set the tx signature to compare rebroadcast signatures + }() + select { + case <-ctx.Done(): + return + case <-wait: + } + + // this should never happen (should match the signature saved to sigs) + if fetchedSig, err := sigs.Get(count); err != nil || retrySig != fetchedSig { + txm.lggr.Errorw("original signature does not match retry signature", "expectedSignatures", sigs.List(), "receivedSignature", retrySig, "error", err) + } } -// goroutine that polls to confirm implementation -// cancels the exponential retry once confirmed +// confirm is a goroutine that continuously polls for transaction confirmations and handles rebroadcasts expired transactions if enabled. +// The function runs until the chStop channel signals to stop. func (txm *Txm) confirm() { defer txm.done.Done() ctx, cancel := txm.chStop.NewCtx() @@ -377,139 +382,227 @@ func (txm *Txm) confirm() { case <-ctx.Done(): return case <-tick: - // get list of tx signatures to confirm - sigs := txm.txs.ListAll() - - // exit switch if not txs to confirm - if len(sigs) == 0 { + // If no signatures to confirm and rebroadcast, we can break loop as there's nothing to process. + if txm.InflightTxs() == 0 { break } - // get client client, err := txm.client.Get() if err != nil { - txm.lggr.Errorw("failed to get client in soltxm.confirm", "error", err) - break // exit switch + txm.lggr.Errorw("failed to get client in txm.confirm", "error", err) + break + } + txm.processConfirmations(ctx, client) + if txm.cfg.TxExpirationRebroadcast() { + txm.rebroadcastExpiredTxs(ctx, client) } + } + tick = time.After(utils.WithJitter(txm.cfg.ConfirmPollPeriod())) + } +} - // batch sigs no more than MaxSigsToConfirm each - sigsBatch, err := utils.BatchSplit(sigs, MaxSigsToConfirm) - if err != nil { // this should never happen - txm.lggr.Fatalw("failed to batch signatures", "error", err) - break // exit switch +// processConfirmations checks the status of transaction signatures on-chain and updates our in-memory state accordingly. +// It splits the signatures into batches, retrieves their statuses with an RPC call, and processes each status accordingly. +// The function handles transitions, managing expiration, errors, and transitions between different states like broadcasted, processed, confirmed, and finalized. +// It also determines when to end polling based on the status of each signature cancelling the exponential retry. +func (txm *Txm) processConfirmations(ctx context.Context, client client.ReaderWriter) { + sigsBatch, err := utils.BatchSplit(txm.txs.ListAllSigs(), MaxSigsToConfirm) + if err != nil { // this should never happen + txm.lggr.Fatalw("failed to batch signatures", "error", err) + return + } + + var wg sync.WaitGroup + for i := 0; i < len(sigsBatch); i++ { + statuses, err := client.SignatureStatuses(ctx, sigsBatch[i]) + if err != nil { + txm.lggr.Errorw("failed to get signature statuses in txm.confirm", "error", err) + break + } + + wg.Add(1) + // nonblocking: process batches as soon as they come in + go func(index int) { + defer wg.Done() + + // to process successful first + sortedSigs, sortedRes, err := SortSignaturesAndResults(sigsBatch[i], statuses) + if err != nil { + txm.lggr.Errorw("sorting error", "error", err) + return } - // process signatures - processSigs := func(s []solanaGo.Signature, res []*rpc.SignatureStatusesResult) { - // sort signatures and results process successful first - s, res, err := SortSignaturesAndResults(s, res) - if err != nil { - txm.lggr.Errorw("sorting error", "error", err) - return + for j := 0; j < len(sortedRes); j++ { + sig, status := sortedSigs[j], sortedRes[j] + // sig not found could mean invalid tx or not picked up yet, keep polling + if status == nil { + txm.handleNotFoundSignatureStatus(sig) + continue } - for i := 0; i < len(res); i++ { - // if status is nil (sig not found), continue polling - // sig not found could mean invalid tx or not picked up yet - if res[i] == nil { - txm.lggr.Debugw("tx state: not found", - "signature", s[i], - ) - - // check confirm timeout exceeded - if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) - if err != nil { - txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) - } else { - txm.lggr.Debugw("failed to find transaction within confirm timeout", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout()) - } - } - continue - } - - // if signature has an error, end polling - if res[i].Err != nil { - // Process error to determine the corresponding state and type. - // Skip marking as errored if error considered to not be a failure. - if txState, errType := txm.ProcessError(s[i], res[i].Err, false); errType != NoFailure { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), txState, errType) - if err != nil { - txm.lggr.Infow(fmt.Sprintf("failed to mark transaction as %s", txState.String()), "id", id, "signature", s[i], "error", err) - } else { - txm.lggr.Debugw(fmt.Sprintf("marking transaction as %s", txState.String()), "id", id, "signature", s[i], "error", res[i].Err, "status", res[i].ConfirmationStatus) - } - } - continue - } + // if signature has an error, end polling unless blockhash not found and expiration rebroadcast is enabled + if status.Err != nil { + txm.handleErrorSignatureStatus(sig, status) + continue + } + switch status.ConfirmationStatus { + case rpc.ConfirmationStatusProcessed: // if signature is processed, keep polling for confirmed or finalized status - if res[i].ConfirmationStatus == rpc.ConfirmationStatusProcessed { - // update transaction state in local memory - id, err := txm.txs.OnProcessed(s[i]) - if err != nil && !errors.Is(err, ErrAlreadyInExpectedState) { - txm.lggr.Errorw("failed to mark transaction as processed", "signature", s[i], "error", err) - } else if err == nil { - txm.lggr.Debugw("marking transaction as processed", "id", id, "signature", s[i]) - } - // check confirm timeout exceeded if TxConfirmTimeout set - if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) - if err != nil { - txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) - } else { - txm.lggr.Debugw("tx failed to move beyond 'processed' within confirm timeout", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout()) - } - } - continue - } - + txm.handleProcessedSignatureStatus(sig) + continue + case rpc.ConfirmationStatusConfirmed: // if signature is confirmed, keep polling for finalized status - if res[i].ConfirmationStatus == rpc.ConfirmationStatusConfirmed { - id, err := txm.txs.OnConfirmed(s[i]) - if err != nil && !errors.Is(err, ErrAlreadyInExpectedState) { - txm.lggr.Errorw("failed to mark transaction as confirmed", "id", id, "signature", s[i], "error", err) - } else if err == nil { - txm.lggr.Debugw("marking transaction as confirmed", "id", id, "signature", s[i]) - } - continue - } - + txm.handleConfirmedSignatureStatus(sig) + continue + case rpc.ConfirmationStatusFinalized: // if signature is finalized, end polling - if res[i].ConfirmationStatus == rpc.ConfirmationStatusFinalized { - id, err := txm.txs.OnFinalized(s[i], txm.cfg.TxRetentionTimeout()) - if err != nil { - txm.lggr.Errorw("failed to mark transaction as finalized", "id", id, "signature", s[i], "error", err) - } else { - txm.lggr.Debugw("marking transaction as finalized", "id", id, "signature", s[i]) - } - continue - } + txm.handleFinalizedSignatureStatus(sig) + continue + default: + txm.lggr.Warnw("unknown confirmation status", "signature", sig, "status", status.ConfirmationStatus) + continue } } + }(i) + } + wg.Wait() // wait for processing to finish +} - // waitgroup for processing - var wg sync.WaitGroup +// handleNotFoundSignatureStatus handles the case where a transaction signature is not found on-chain. +// If the confirmation timeout has been exceeded it marks the transaction as errored. +func (txm *Txm) handleNotFoundSignatureStatus(sig solanaGo.Signature) { + txm.lggr.Debugw("tx state: not found", "signature", sig) + if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(sig, txm.cfg.TxConfirmTimeout()) { + id, err := txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) + if err != nil { + txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", sig, "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) + } else { + txm.lggr.Debugw("failed to find transaction within confirm timeout", "id", id, "signature", sig, "timeoutSeconds", txm.cfg.TxConfirmTimeout()) + } + } +} - // loop through batch - for i := 0; i < len(sigsBatch); i++ { - // fetch signature statuses - statuses, err := client.SignatureStatuses(ctx, sigsBatch[i]) - if err != nil { - txm.lggr.Errorw("failed to get signature statuses in soltxm.confirm", "error", err) - break // exit for loop - } +// handleErrorSignatureStatus handles the case where a transaction signature has an error on-chain. +// If the error is BlockhashNotFound and expiration rebroadcast is enabled, it skips error handling to allow rebroadcasting. +// Otherwise, it marks the transaction as errored. +func (txm *Txm) handleErrorSignatureStatus(sig solanaGo.Signature, status *rpc.SignatureStatusesResult) { + // We want to rebroadcast rather than drop tx if expiration rebroadcast is enabled when blockhash was not found. + // converting error to string so we are able to check if it contains the error message. + if status.Err != nil && strings.Contains(fmt.Sprintf("%v", status.Err), "BlockhashNotFound") && txm.cfg.TxExpirationRebroadcast() { + return + } - wg.Add(1) - // nonblocking: process batches as soon as they come in - go func(index int) { - defer wg.Done() - processSigs(sigsBatch[index], statuses) - }(i) - } - wg.Wait() // wait for processing to finish + // Process error to determine the corresponding state and type. + // Skip marking as errored if error considered to not be a failure. + if txState, errType := txm.ProcessError(sig, status.Err, false); errType != NoFailure { + id, err := txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), txState, errType) + if err != nil { + txm.lggr.Infow(fmt.Sprintf("failed to mark transaction as %s", txState.String()), "id", id, "signature", sig, "error", err) + } else { + txm.lggr.Debugw(fmt.Sprintf("marking transaction as %s", txState.String()), "id", id, "signature", sig, "error", status.Err, "status", status.ConfirmationStatus) } - tick = time.After(utils.WithJitter(txm.cfg.ConfirmPollPeriod())) + } +} + +// handleProcessedSignatureStatus handles the case where a transaction signature is in the "processed" state on-chain. +// It updates the transaction state in the local memory and checks if the confirmation timeout has been exceeded. +// If the timeout is exceeded, it marks the transaction as errored. +func (txm *Txm) handleProcessedSignatureStatus(sig solanaGo.Signature) { + // update transaction state in local memory + id, err := txm.txs.OnProcessed(sig) + if err != nil && !errors.Is(err, ErrAlreadyInExpectedState) { + txm.lggr.Errorw("failed to mark transaction as processed", "signature", sig, "error", err) + } else if err == nil { + txm.lggr.Debugw("marking transaction as processed", "id", id, "signature", sig) + } + // check confirm timeout exceeded if TxConfirmTimeout set + if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(sig, txm.cfg.TxConfirmTimeout()) { + id, err := txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) + if err != nil { + txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", sig, "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) + } else { + txm.lggr.Debugw("tx failed to move beyond 'processed' within confirm timeout", "id", id, "signature", sig, "timeoutSeconds", txm.cfg.TxConfirmTimeout()) + } + } +} + +// handleConfirmedSignatureStatus handles the case where a transaction signature is in the "confirmed" state on-chain. +// It updates the transaction state in the local memory. +func (txm *Txm) handleConfirmedSignatureStatus(sig solanaGo.Signature) { + id, err := txm.txs.OnConfirmed(sig) + if err != nil && !errors.Is(err, ErrAlreadyInExpectedState) { + txm.lggr.Errorw("failed to mark transaction as confirmed", "id", id, "signature", sig, "error", err) + } else if err == nil { + txm.lggr.Debugw("marking transaction as confirmed", "id", id, "signature", sig) + } +} + +// handleFinalizedSignatureStatus handles the case where a transaction signature is in the "finalized" state on-chain. +// It updates the transaction state in the local memory. +func (txm *Txm) handleFinalizedSignatureStatus(sig solanaGo.Signature) { + id, err := txm.txs.OnFinalized(sig, txm.cfg.TxRetentionTimeout()) + if err != nil { + txm.lggr.Errorw("failed to mark transaction as finalized", "id", id, "signature", sig, "error", err) + } else { + txm.lggr.Debugw("marking transaction as finalized", "id", id, "signature", sig) + } +} + +// rebroadcastExpiredTxs attempts to rebroadcast all transactions that are in broadcasted state and have expired. +// An expired tx is one where it's blockhash lastValidBlockHeight (last valid block number) is smaller than the current block height (block number). +// The function loops through all expired txes, rebroadcasts them with a new blockhash, and updates the lastValidBlockHeight. +// If any error occurs during rebroadcast attempt, they are discarded, and the function continues with the next transaction. +func (txm *Txm) rebroadcastExpiredTxs(ctx context.Context, client client.ReaderWriter) { + currBlock, err := client.GetLatestBlock(ctx) + if err != nil || currBlock == nil || currBlock.BlockHeight == nil { + txm.lggr.Errorw("failed to get current block height", "error", err) + return + } + + // Get all expired broadcasted transactions at current block number. Safe to quit if no txes are found. + expiredBroadcastedTxes := txm.txs.ListAllExpiredBroadcastedTxs(*currBlock.BlockHeight) + if len(expiredBroadcastedTxes) == 0 { + return + } + + blockhash, err := client.LatestBlockhash(ctx) + if err != nil { + txm.lggr.Errorw("failed to getLatestBlockhash for rebroadcast", "error", err) + return + } + if blockhash == nil || blockhash.Value == nil { + txm.lggr.Errorw("nil pointer returned from getLatestBlockhash for rebroadcast") + return + } + + // rebroadcast each expired tx after updating blockhash, lastValidBlockHeight and compute unit price (priority fee) + for _, tx := range expiredBroadcastedTxes { + txm.lggr.Debugw("transaction expired, rebroadcasting", "id", tx.id, "signature", tx.signatures, "lastValidBlockHeight", tx.lastValidBlockHeight, "currentBlockHeight", *currBlock.BlockHeight) + // Removes all signatures associated to prior tx and cancels context. + _, err := txm.txs.Remove(tx.id) + if err != nil { + txm.lggr.Errorw("failed to remove expired transaction", "id", tx.id, "error", err) + continue + } + + tx.tx.Message.RecentBlockhash = blockhash.Value.Blockhash + tx.cfg.BaseComputeUnitPrice = txm.fee.BaseComputeUnitPrice() + rebroadcastTx := pendingTx{ + tx: tx.tx, + cfg: tx.cfg, + id: tx.id, // using same id in case it was set by caller and we need to maintain it. + lastValidBlockHeight: blockhash.Value.LastValidBlockHeight, + } + // call sendWithRetry directly to avoid enqueuing + _, _, _, sendErr := txm.sendWithRetry(ctx, rebroadcastTx) + if sendErr != nil { + stateTransitionErr := txm.txs.OnPrebroadcastError(tx.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject) + txm.lggr.Errorw("failed to rebroadcast transaction", "id", tx.id, "error", errors.Join(sendErr, stateTransitionErr)) + continue + } + + txm.lggr.Debugw("rebroadcast transaction sent", "id", tx.id) } } @@ -580,7 +673,7 @@ func (txm *Txm) reap() { } // Enqueue enqueues a msg destined for the solana chain. -func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...SetTxConfig) error { +func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txLastValidBlockHeight uint64, txCfgs ...SetTxConfig) error { if err := txm.Ready(); err != nil { return fmt.Errorf("error in soltxm.Enqueue: %w", err) } @@ -628,9 +721,10 @@ func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Tran } msg := pendingTx{ - tx: *tx, - cfg: cfg, - id: id, + id: id, + tx: *tx, + cfg: cfg, + lastValidBlockHeight: txLastValidBlockHeight, } select { @@ -745,7 +839,7 @@ func (txm *Txm) simulateTx(ctx context.Context, tx *solanaGo.Transaction) (res * return } -// processError parses and handles relevant errors found in simulation results +// ProcessError parses and handles relevant errors found in simulation results func (txm *Txm) ProcessError(sig solanaGo.Signature, resErr interface{}, simulation bool) (txState TxState, errType TxErrType) { if resErr != nil { // handle various errors @@ -827,8 +921,9 @@ func (txm *Txm) ProcessError(sig solanaGo.Signature, resErr interface{}, simulat return } +// InflightTxs returns the number of signatures being tracked for all transactions not yet finalized or errored func (txm *Txm) InflightTxs() int { - return len(txm.txs.ListAll()) + return len(txm.txs.ListAllSigs()) } // Close close service diff --git a/pkg/solana/txm/txm_integration_test.go b/pkg/solana/txm/txm_integration_test.go new file mode 100644 index 000000000..154a42f6a --- /dev/null +++ b/pkg/solana/txm/txm_integration_test.go @@ -0,0 +1,187 @@ +//go:build integration + +package txm_test + +import ( + "context" + "testing" + "time" + + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/programs/system" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + + solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" +) + +func TestTxm_Integration_ExpirationRebroadcast(t *testing.T) { + t.Parallel() + url := solanaClient.SetupLocalSolNode(t) // live validator + + type TestCase struct { + name string + txExpirationRebroadcast bool + useValidBlockHash bool + expectRebroadcast bool + expectTransactionStatus types.TransactionStatus + } + + testCases := []TestCase{ + { + name: "WithRebroadcast", + txExpirationRebroadcast: true, + useValidBlockHash: false, + expectRebroadcast: true, + expectTransactionStatus: types.Finalized, + }, + { + name: "WithoutRebroadcast", + txExpirationRebroadcast: false, + useValidBlockHash: false, + expectRebroadcast: false, + expectTransactionStatus: types.Failed, + }, + { + name: "ConfirmedBeforeRebroadcast", + txExpirationRebroadcast: true, + useValidBlockHash: true, + expectRebroadcast: false, + expectTransactionStatus: types.Finalized, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, client, txmInstance, senderPubKey, receiverPubKey, observer := setup(t, url, tc.txExpirationRebroadcast) + + // Record initial balance + initSenderBalance, err := client.Balance(ctx, senderPubKey) + require.NoError(t, err) + const amount = 1 * solana.LAMPORTS_PER_SOL + + // Create and enqueue tx + txID := tc.name + tx, lastValidBlockHeight := createTransaction(ctx, t, client, senderPubKey, receiverPubKey, amount, tc.useValidBlockHash) + require.NoError(t, txmInstance.Enqueue(ctx, "", tx, &txID, lastValidBlockHeight)) + + // Wait for the transaction to reach the expected status + require.Eventually(t, func() bool { + status, statusErr := txmInstance.GetTransactionStatus(ctx, txID) + if statusErr != nil { + return false + } + return status == tc.expectTransactionStatus + }, 60*time.Second, 1*time.Second, "Transaction should eventually reach expected status") + + // Verify balances + finalSenderBalance, err := client.Balance(ctx, senderPubKey) + require.NoError(t, err) + finalReceiverBalance, err := client.Balance(ctx, receiverPubKey) + require.NoError(t, err) + + if tc.expectTransactionStatus == types.Finalized { + require.Less(t, finalSenderBalance, initSenderBalance, "Sender balance should decrease") + require.Equal(t, amount, finalReceiverBalance, "Receiver should receive the transferred amount") + } else { + require.Equal(t, initSenderBalance, finalSenderBalance, "Sender balance should remain the same") + require.Equal(t, uint64(0), finalReceiverBalance, "Receiver should not receive any funds") + } + + // Verify rebroadcast logs + rebroadcastLogs := observer.FilterMessageSnippet("rebroadcast transaction sent").Len() + rebroadcastLogs2 := observer.FilterMessageSnippet("transaction expired, rebroadcasting").Len() + if tc.expectRebroadcast { + require.Equal(t, 1, rebroadcastLogs, "Expected rebroadcast log message not found") + require.Equal(t, 1, rebroadcastLogs2, "Expected rebroadcast log message not found") + } else { + require.Equal(t, 0, rebroadcastLogs, "Rebroadcast should not occur") + require.Equal(t, 0, rebroadcastLogs2, "Rebroadcast should not occur") + } + }) + } +} + +func setup(t *testing.T, url string, txExpirationRebroadcast bool) (context.Context, *solanaClient.Client, *txm.Txm, solana.PublicKey, solana.PublicKey, *observer.ObservedLogs) { + ctx := tests.Context(t) + + // Generate sender and receiver keys and fund sender account + senderKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + senderPubKey := senderKey.PublicKey() + receiverKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + receiverPubKey := receiverKey.PublicKey() + solanaClient.FundTestAccounts(t, []solana.PublicKey{senderPubKey}, url) + + // Set up mock keystore with sender key + mkey := keyMocks.NewSimpleKeystore(t) + mkey.On("Sign", mock.Anything, senderPubKey.String(), mock.Anything).Return(func(_ context.Context, _ string, data []byte) []byte { + sig, _ := senderKey.Sign(data) + return sig[:] + }, nil) + + // Set configs + cfg := config.NewDefault() + cfg.Chain.TxExpirationRebroadcast = &txExpirationRebroadcast + cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(10 * time.Second) // to get the finalized tx status + + // Initialize the Solana client and TXM + lggr, obs := logger.TestObserved(t, zapcore.DebugLevel) + client, err := solanaClient.NewClient(url, cfg, 2*time.Second, lggr) + require.NoError(t, err) + loader := utils.NewLazyLoad(func() (solanaClient.ReaderWriter, error) { return client, nil }) + txmInstance := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) + servicetest.Run(t, txmInstance) + + return ctx, client, txmInstance, senderPubKey, receiverPubKey, obs +} + +// createTransaction is a helper function to create a transaction based on the test case. +func createTransaction(ctx context.Context, t *testing.T, client *solanaClient.Client, senderPubKey, receiverPubKey solana.PublicKey, amount uint64, useValidBlockHash bool) (*solana.Transaction, uint64) { + var blockhash solana.Hash + var lastValidBlockHeight uint64 + + if useValidBlockHash { + // Get a valid recent blockhash + recentBlockHashResult, err := client.LatestBlockhash(ctx) + require.NoError(t, err) + blockhash = recentBlockHashResult.Value.Blockhash + lastValidBlockHeight = recentBlockHashResult.Value.LastValidBlockHeight + } else { + // Use empty blockhash to simulate expiration + blockhash = solana.Hash{} + lastValidBlockHeight = 0 + } + + // Create the transaction + tx, err := solana.NewTransaction( + []solana.Instruction{ + system.NewTransferInstruction( + amount, + senderPubKey, + receiverPubKey, + ).Build(), + }, + blockhash, + solana.TransactionPayer(senderPubKey), + ) + require.NoError(t, err) + + return tx, lastValidBlockHeight +} diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index 0054e0a2b..13c861362 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -161,7 +161,6 @@ func TestTxm(t *testing.T) { return out }, nil, ) - // happy path (send => simulate success => tx: nil => tx: processed => tx: confirmed => finalized => done) t.Run("happyPath", func(t *testing.T) { sig := randomSignature(t) @@ -204,7 +203,8 @@ func TestTxm(t *testing.T) { // send tx testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // no transactions stored inflight txs list @@ -240,7 +240,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed // no transactions stored inflight txs list @@ -272,7 +273,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared quickly @@ -308,7 +310,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared after timeout @@ -348,7 +351,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared after timeout @@ -399,7 +403,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared after timeout @@ -441,7 +446,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared after timeout @@ -486,7 +492,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // inflight txs cleared after timeout @@ -538,7 +545,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // inflight txs cleared after timeout @@ -576,7 +584,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // inflight txs cleared after timeout @@ -622,7 +631,8 @@ func TestTxm(t *testing.T) { // send tx testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // no transactions stored inflight txs list @@ -676,7 +686,8 @@ func TestTxm(t *testing.T) { // send tx - with disabled fee bumping testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, SetFeeBumpPeriod(0))) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight, SetFeeBumpPeriod(0))) wg.Wait() // no transactions stored inflight txs list @@ -728,7 +739,8 @@ func TestTxm(t *testing.T) { // send tx - with disabled fee bumping and disabled compute unit limit testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, SetFeeBumpPeriod(0), SetComputeUnitLimit(0))) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight, SetFeeBumpPeriod(0), SetComputeUnitLimit(0))) wg.Wait() // no transactions stored inflight txs list @@ -836,7 +848,8 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout @@ -875,7 +888,8 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { // tx should be able to queue testTxID := uuid.NewString() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout @@ -920,7 +934,8 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { // tx should be able to queue testTxID := uuid.NewString() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait till send tx waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout @@ -1040,7 +1055,8 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { // send tx testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // no transactions stored inflight txs list @@ -1069,7 +1085,8 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { mc.On("SimulateTx", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("simulation failed")).Once() // tx should NOT be able to queue - assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, nil)) + lastValidBlockHeight := uint64(0) + assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, nil, lastValidBlockHeight)) }) t.Run("simulation_returns_error", func(t *testing.T) { @@ -1084,8 +1101,9 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { mc.On("SimulateTx", mock.Anything, simulateTx, mock.Anything).Return(&rpc.SimulateTransactionResult{Err: errors.New("InstructionError")}, nil).Once() txID := uuid.NewString() + lastValidBlockHeight := uint64(100) // tx should NOT be able to queue - assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) + assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) // tx should be stored in-memory and moved to errored state status, err := txm.GetTransactionStatus(ctx, txID) require.NoError(t, err) @@ -1131,6 +1149,7 @@ func TestTxm_Enqueue(t *testing.T) { ) require.NoError(t, err) + lastValidBlockHeight := uint64(0) invalidTx, err := solana.NewTransaction( []solana.Instruction{ system.NewTransferInstruction( @@ -1147,28 +1166,29 @@ func TestTxm_Enqueue(t *testing.T) { loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) txm := NewTxm("enqueue_test", loader, nil, cfg, mkey, lggr) - require.ErrorContains(t, txm.Enqueue(ctx, "txmUnstarted", &solana.Transaction{}, nil), "not started") + require.ErrorContains(t, txm.Enqueue(ctx, "txmUnstarted", &solana.Transaction{}, nil, lastValidBlockHeight), "not started") require.NoError(t, txm.Start(ctx)) t.Cleanup(func() { require.NoError(t, txm.Close()) }) txs := []struct { - name string - tx *solana.Transaction - fail bool + name string + tx *solana.Transaction + lastValidBlockHeight uint64 + fail bool }{ - {"success", tx, false}, - {"invalid_key", invalidTx, true}, - {"nil_pointer", nil, true}, - {"empty_tx", &solana.Transaction{}, true}, + {"success", tx, 100, false}, + {"invalid_key", invalidTx, 0, true}, + {"nil_pointer", nil, 0, true}, + {"empty_tx", &solana.Transaction{}, 0, true}, } for _, run := range txs { t.Run(run.name, func(t *testing.T) { if !run.fail { - assert.NoError(t, txm.Enqueue(ctx, run.name, run.tx, nil)) + assert.NoError(t, txm.Enqueue(ctx, run.name, run.tx, nil, run.lastValidBlockHeight)) return } - assert.Error(t, txm.Enqueue(ctx, run.name, run.tx, nil)) + assert.Error(t, txm.Enqueue(ctx, run.name, run.tx, nil, run.lastValidBlockHeight)) }) } } @@ -1186,3 +1206,406 @@ func addSigAndLimitToTx(t *testing.T, keystore SimpleKeystore, pubkey solana.Pub require.NoError(t, fees.SetComputeUnitLimit(&txCopy, limit)) return &txCopy } + +func TestTxm_ExpirationRebroadcast(t *testing.T) { + t.Parallel() + estimator := "fixed" + id := "mocknet-" + estimator + "-" + uuid.NewString() + cfg := config.NewDefault() + cfg.Chain.FeeEstimatorMode = &estimator + cfg.Chain.TxConfirmTimeout = relayconfig.MustNewDuration(5 * time.Second) + cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(10 * time.Second) // Enable retention to keep transactions after finality and be able to check their statuses. + lggr := logger.Test(t) + ctx := tests.Context(t) + + // Helper function to set up common test environment + setupTxmTest := func( + txExpirationRebroadcast bool, + latestBlockhashFunc func() (*rpc.GetLatestBlockhashResult, error), + getLatestBlockFunc func() (*rpc.GetBlockResult, error), + sendTxFunc func() (solana.Signature, error), + statuses map[solana.Signature]func() *rpc.SignatureStatusesResult, + ) (*Txm, *mocks.ReaderWriter, *keyMocks.SimpleKeystore) { + cfg.Chain.TxExpirationRebroadcast = &txExpirationRebroadcast + + mc := mocks.NewReaderWriter(t) + if latestBlockhashFunc != nil { + mc.On("LatestBlockhash", mock.Anything).Return( + func(_ context.Context) (*rpc.GetLatestBlockhashResult, error) { + return latestBlockhashFunc() + }, + ).Maybe() + } + if getLatestBlockFunc != nil { + mc.On("GetLatestBlock", mock.Anything).Return( + func(_ context.Context) (*rpc.GetBlockResult, error) { + return getLatestBlockFunc() + }, + ).Maybe() + } + if sendTxFunc != nil { + mc.On("SendTx", mock.Anything, mock.Anything).Return( + func(_ context.Context, _ *solana.Transaction) (solana.Signature, error) { + return sendTxFunc() + }, + ).Maybe() + } + + mc.On("SimulateTx", mock.Anything, mock.Anything, mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe() + if statuses != nil { + mc.On("SignatureStatuses", mock.Anything, mock.AnythingOfType("[]solana.Signature")).Return( + func(_ context.Context, sigs []solana.Signature) ([]*rpc.SignatureStatusesResult, error) { + var out []*rpc.SignatureStatusesResult + for _, sig := range sigs { + getStatus, exists := statuses[sig] + if !exists { + out = append(out, nil) + } else { + out = append(out, getStatus()) + } + } + return out, nil + }, + ).Maybe() + } + + mkey := keyMocks.NewSimpleKeystore(t) + mkey.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil) + + loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) + txm := NewTxm(id, loader, nil, cfg, mkey, lggr) + require.NoError(t, txm.Start(ctx)) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) + + return txm, mc, mkey + } + + // tracking prom metrics + prom := soltxmProm{id: id} + + t.Run("WithRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // Mock getLatestBlock to return a value greater than 0 for blockHeight + getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { + val := uint64(1500) + return &rpc.GetBlockResult{ + BlockHeight: &val, + }, nil + } + + rebroadcastCount := 0 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { rebroadcastCount++ }() + // rebroadcast call will go through because lastValidBlockHeight is bigger than blockHeight + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + nowTs := time.Now() + sigStatusCallCount := 0 + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // First transaction should be rebroadcasted. + if time.Since(nowTs) < cfg.TxConfirmTimeout()-2*time.Second { + return nil + } + // Second transaction should reach finalization. + sigStatusCallCount++ + if sigStatusCallCount == 1 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusProcessed, + } + } + if sigStatusCallCount == 2 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusConfirmed, + } + } + wg.Done() + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusFinalized, + } + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast" + lastValidBlockHeight := uint64(100) // lastValidBlockHeight is smaller than blockHeight + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.confirmed++ + prom.finalized++ + prom.assertEqual(t) + + // Check that transaction for txID has been finalized and rebroadcasted 1 time. + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + require.Equal(t, 1, rebroadcastCount) + }) + + t.Run("WithoutRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := false + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + rebroadcastCount := 0 + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + nowTs := time.Now() + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // Transaction remains unconfirmed and should not be rebroadcasted. + if time.Since(nowTs) < cfg.TxConfirmTimeout() { + return nil + } + wg.Done() + return nil + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, nil, nil, sendTxFunc, statuses) + + tx, _ := getTx(t, 5, mkey) + txID := "test-no-rebroadcast" + lastValidBlockHeight := uint64(0) // original lastValidBlockHeight is invalid + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.drop++ + prom.error++ + prom.assertEqual(t) + + // Check that transaction for txID has not been finalized and has not been rebroadcasted + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + require.Equal(t, 0, rebroadcastCount) + }) + + t.Run("WithMultipleRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // Mock getLatestBlock to return a value greater than 0 + getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { + val := uint64(1500) + return &rpc.GetBlockResult{ + BlockHeight: &val, + }, nil + } + + // Mock LatestBlockhash to return an invalid blockhash in the first 2 attempts to rebroadcast. + // the last one is valid because it is greater than the blockHeight + rebroadcastCount := 0 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { rebroadcastCount++ }() + if rebroadcastCount < 2 { + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(1000), + }, + }, nil + } + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + nowTs := time.Now() + sigStatusCallCount := 0 + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // transaction should be rebroadcasted multiple times. + if time.Since(nowTs) < cfg.TxConfirmTimeout()-2*time.Second { + return nil + } + // Second transaction should reach finalization. + sigStatusCallCount++ + if sigStatusCallCount == 1 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusProcessed, + } + } else if sigStatusCallCount == 2 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusConfirmed, + } + } + wg.Done() + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusFinalized, + } + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast" + lastValidBlockHeight := uint64(100) // lastValidBlockHeight is smaller than blockHeight + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.confirmed++ + prom.finalized++ + prom.assertEqual(t) + + // Check that transaction for txID has been finalized and rebroadcasted multiple times. + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + require.Equal(t, 3, rebroadcastCount) + }) + + t.Run("ConfirmedBeforeRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + // Mock getLatestBlock to return a value greater than 0 + getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { + val := uint64(1500) + return &rpc.GetBlockResult{ + BlockHeight: &val, + }, nil + } + + rebroadcastCount := 0 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { rebroadcastCount++ }() + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(1000), + }, + }, nil + } + + var wg sync.WaitGroup + wg.Add(1) + count := 0 + statuses[sig1] = func() *rpc.SignatureStatusesResult { + defer func() { count++ }() + + out := &rpc.SignatureStatusesResult{} + if count == 1 { + out.ConfirmationStatus = rpc.ConfirmationStatusConfirmed + return out + } + if count == 2 { + out.ConfirmationStatus = rpc.ConfirmationStatusFinalized + wg.Done() + return out + } + out.ConfirmationStatus = rpc.ConfirmationStatusProcessed + return out + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-confirmed-before-rebroadcast" + lastValidBlockHeight := uint64(1500) // original lastValidBlockHeight is valid + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.confirmed++ + prom.finalized++ + prom.assertEqual(t) + + // Check that transaction has been finalized without rebroadcast + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + require.Equal(t, 0, rebroadcastCount) + }) + + t.Run("RebroadcastWithError", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // To force rebroadcast, first call needs to be smaller than blockHeight + // following rebroadcast call will go through because lastValidBlockHeight will be bigger than blockHeight + getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { + val := uint64(1500) + return &rpc.GetBlockResult{ + BlockHeight: &val, + }, nil + } + + rebroadcastCount := 0 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { rebroadcastCount++ }() + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + var wg sync.WaitGroup + wg.Add(1) + count := 0 + statuses[sig1] = func() *rpc.SignatureStatusesResult { + defer func() { count++ }() + // Transaction remains unconfirmed + if count == 1 { + wg.Done() + } + return nil + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast-error" + lastValidBlockHeight := uint64(100) // lastValidBlockHeight is smaller than blockHeight + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.drop++ + prom.error++ + prom.assertEqual(t) + + // Transaction should be moved to failed after trying to rebroadcast 1 time. + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + require.Equal(t, 1, rebroadcastCount) + }) +} diff --git a/pkg/solana/txm/txm_load_test.go b/pkg/solana/txm/txm_load_test.go index 5d5a8061b..333c95e23 100644 --- a/pkg/solana/txm/txm_load_test.go +++ b/pkg/solana/txm/txm_load_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" @@ -22,6 +21,7 @@ import ( relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" ) @@ -83,12 +83,11 @@ func TestTxm_Integration(t *testing.T) { // already started assert.Error(t, txm.Start(ctx)) - - createTx := func(signer solana.PublicKey, sender solana.PublicKey, receiver solana.PublicKey, amt uint64) *solana.Transaction { + createTx := func(signer solana.PublicKey, sender solana.PublicKey, receiver solana.PublicKey, amt uint64) (*solana.Transaction, uint64) { // create transfer tx - hash, err := client.LatestBlockhash(ctx) - assert.NoError(t, err) - tx, err := solana.NewTransaction( + hash, blockhashErr := client.LatestBlockhash(ctx) + assert.NoError(t, blockhashErr) + tx, txErr := solana.NewTransaction( []solana.Instruction{ system.NewTransferInstruction( amt, @@ -99,22 +98,27 @@ func TestTxm_Integration(t *testing.T) { hash.Value.Blockhash, solana.TransactionPayer(signer), ) - require.NoError(t, err) - return tx + require.NoError(t, txErr) + return tx, hash.Value.LastValidBlockHeight } - // enqueue txs (must pass to move on to load test) - require.NoError(t, txm.Enqueue(ctx, "test_success_0", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL), nil)) - require.Error(t, txm.Enqueue(ctx, "test_invalidSigner", createTx(pubKeyReceiver, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL), nil)) // cannot sign tx before enqueuing - require.NoError(t, txm.Enqueue(ctx, "test_invalidReceiver", createTx(pubKey, pubKey, solana.PublicKey{}, solana.LAMPORTS_PER_SOL), nil)) + tx, lastValidBlockHeight := createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) + require.NoError(t, txm.Enqueue(ctx, "test_success_0", tx, nil, lastValidBlockHeight)) + tx2, lastValidBlockHeight2 := createTx(pubKeyReceiver, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) + require.Error(t, txm.Enqueue(ctx, "test_invalidSigner", tx2, nil, lastValidBlockHeight2)) // cannot sign tx before enqueuing + tx3, lastValidBlockHeight3 := createTx(pubKey, pubKey, solana.PublicKey{}, solana.LAMPORTS_PER_SOL) + require.NoError(t, txm.Enqueue(ctx, "test_invalidReceiver", tx3, nil, lastValidBlockHeight3)) time.Sleep(500 * time.Millisecond) // pause 0.5s for new blockhash - require.NoError(t, txm.Enqueue(ctx, "test_success_1", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL), nil)) - require.NoError(t, txm.Enqueue(ctx, "test_txFail", createTx(pubKey, pubKey, pubKeyReceiver, 1000*solana.LAMPORTS_PER_SOL), nil)) + tx4, lastValidBlockHeight4 := createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) + require.NoError(t, txm.Enqueue(ctx, "test_success_1", tx4, nil, lastValidBlockHeight4)) + tx5, lastValidBlockHeight5 := createTx(pubKey, pubKey, pubKeyReceiver, 1000*solana.LAMPORTS_PER_SOL) + require.NoError(t, txm.Enqueue(ctx, "test_txFail", tx5, nil, lastValidBlockHeight5)) // load test: try to overload txs, confirm, or simulation for i := 0; i < 1000; i++ { - assert.NoError(t, txm.Enqueue(ctx, fmt.Sprintf("load_%d", i), createTx(loadTestKey.PublicKey(), loadTestKey.PublicKey(), loadTestKey.PublicKey(), uint64(i)), nil)) - time.Sleep(10 * time.Millisecond) // ~100 txs per second (note: have run 5ms delays for ~200tx/s succesfully) + tx6, lastValidBlockHeight6 := createTx(loadTestKey.PublicKey(), loadTestKey.PublicKey(), loadTestKey.PublicKey(), uint64(i)) + assert.NoError(t, txm.Enqueue(ctx, fmt.Sprintf("load_%d", i), tx6, nil, lastValidBlockHeight6)) + time.Sleep(10 * time.Millisecond) // ~100 txs per second (note: have run 5ms delays for ~200tx/s successfully) } // check to make sure all txs are closed out from inflight list (longest should last MaxConfirmTimeout) diff --git a/pkg/solana/txm/txm_race_test.go b/pkg/solana/txm/txm_race_test.go index 42062718f..33ec0f7bf 100644 --- a/pkg/solana/txm/txm_race_test.go +++ b/pkg/solana/txm/txm_race_test.go @@ -62,7 +62,6 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { // assemble minimal tx for testing retry msg := NewTestMsg() - testRunner := func(t *testing.T, client solanaClient.ReaderWriter) { // build minimal txm loader := utils.NewLazyLoad(func() (solanaClient.ReaderWriter, error) { @@ -81,10 +80,8 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { lastLog := observer.All()[len(observer.All())-1] assert.Contains(t, lastLog.Message, "stopped tx retry") // assert that all retry goroutines exit successfully } - + client := clientmocks.NewReaderWriter(t) t.Run("delay in rebroadcasting tx", func(t *testing.T) { - client := clientmocks.NewReaderWriter(t) - // client mock txs := map[string]solanaGo.Signature{} var lock sync.RWMutex client.On("SendTx", mock.Anything, mock.Anything).Return( @@ -121,8 +118,6 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { }) t.Run("delay in broadcasting new tx", func(t *testing.T) { - client := clientmocks.NewReaderWriter(t) - // client mock txs := map[string]solanaGo.Signature{} var lock sync.RWMutex client.On("SendTx", mock.Anything, mock.Anything).Return( @@ -157,8 +152,6 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { }) t.Run("overlapping bumping tx", func(t *testing.T) { - client := clientmocks.NewReaderWriter(t) - // client mock txs := map[string]solanaGo.Signature{} var lock sync.RWMutex client.On("SendTx", mock.Anything, mock.Anything).Return( @@ -204,8 +197,7 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { }) t.Run("bumping tx errors and ctx cleans up waitgroup blocks", func(t *testing.T) { - client := clientmocks.NewReaderWriter(t) - // client mock - first tx is always successful + // first tx is always successful msg0 := NewTestMsg() require.NoError(t, fees.SetComputeUnitPrice(&msg0.tx, 0)) require.NoError(t, fees.SetComputeUnitLimit(&msg0.tx, 200_000)) @@ -217,7 +209,7 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { require.NoError(t, fees.SetComputeUnitPrice(&msg1.tx, 1)) require.NoError(t, fees.SetComputeUnitLimit(&msg1.tx, 200_000)) msg1.tx.Signatures = make([]solanaGo.Signature, 1) - client.On("SendTx", mock.Anything, &msg1.tx).Return(solanaGo.Signature{}, fmt.Errorf("BUMP FAILED")).Once() + client.On("SendTx", mock.Anything, &msg1.tx).Return(solanaGo.Signature{}, fmt.Errorf("BUMP FAILED")) client.On("SendTx", mock.Anything, &msg1.tx).Return(solanaGo.Signature{2}, nil) // init bump tx success, rebroadcast fails @@ -225,7 +217,7 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { require.NoError(t, fees.SetComputeUnitPrice(&msg2.tx, 2)) require.NoError(t, fees.SetComputeUnitLimit(&msg2.tx, 200_000)) msg2.tx.Signatures = make([]solanaGo.Signature, 1) - client.On("SendTx", mock.Anything, &msg2.tx).Return(solanaGo.Signature{3}, nil).Once() + client.On("SendTx", mock.Anything, &msg2.tx).Return(solanaGo.Signature{3}, nil) client.On("SendTx", mock.Anything, &msg2.tx).Return(solanaGo.Signature{}, fmt.Errorf("REBROADCAST FAILED")) // always successful @@ -234,7 +226,6 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { require.NoError(t, fees.SetComputeUnitLimit(&msg3.tx, 200_000)) msg3.tx.Signatures = make([]solanaGo.Signature, 1) client.On("SendTx", mock.Anything, &msg3.tx).Return(solanaGo.Signature{4}, nil) - testRunner(t, client) }) }