Skip to content

Commit

Permalink
RFQ Relayer: restrict state transitions (#2787)
Browse files Browse the repository at this point in the history
- **New Features**
  - Introduced a new 'Keys()' method in map mutex interfaces for retrieving all keys.
  - Added logging for state transitions in quote requests. 

- **Bug Fixes**
  - Enhanced the `UpdateQuoteRequestStatus` function to check for valid state transitions and log any invalid transitions.

- **Tests**
  - Added new test cases for key retrieval in different types of map mutexes.

---------

Co-authored-by: Trajan0x <[email protected]>
  • Loading branch information
dwasse and trajan0x authored Jul 2, 2024
1 parent 0d573f2 commit 9cdc208
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 34 deletions.
13 changes: 13 additions & 0 deletions core/mapmutex/mapmutex.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
type untypedMapMutex interface {
Lock(key interface{}) Unlocker
TryLock(key interface{}) (Unlocker, bool)
Keys() []interface{}
}

type untypedMapMutexImpl struct {
Expand Down Expand Up @@ -81,6 +82,18 @@ func (m *untypedMapMutexImpl) TryLock(key interface{}) (Unlocker, bool) {
return nil, false
}

// Keys returns all keys in the map.
func (m *untypedMapMutexImpl) Keys() []interface{} {
m.ml.Lock()
defer m.ml.Unlock()

keys := make([]interface{}, 0, len(m.ma))
for k := range m.ma {
keys = append(keys, k)
}
return keys
}

// Unlock releases the lock for this entry.
func (me *mentry) Unlock() {
m := me.m
Expand Down
22 changes: 22 additions & 0 deletions core/mapmutex/mapmutex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,28 @@ func (s MapMutexSuite) TestExampleMapMutex() {
NotPanics(s.T(), ExampleStringMapMutex)
}

func (s MapMutexSuite) TestKeys() {
s.T().Run("StringMapMutexKeys", func(t *testing.T) {
mapMutex := mapmutex.NewStringMapMutex()
mapMutex.Lock("lock1")
Equal(t, "lock1", mapMutex.Keys()[0])
Equal(t, 1, len(mapMutex.Keys()))
})
s.T().Run("StringerMapMutexKeys", func(t *testing.T) {
mapMutex := mapmutex.NewStringerMapMutex()
vitalik := common.HexToAddress("0xab5801a7d398351b8be11c439e05c5b3259aec9b")
mapMutex.Lock(vitalik)
Equal(t, vitalik.String(), mapMutex.Keys()[0])
Equal(t, 1, len(mapMutex.Keys()))
})
s.T().Run("IntMapMutexKeys", func(t *testing.T) {
mapMutex := mapmutex.NewIntMapMutex()
mapMutex.Lock(1)
Equal(t, 1, mapMutex.Keys()[0])
Equal(t, 1, len(mapMutex.Keys()))
})
}

func (s MapMutexSuite) TestMapMutex() {
//nolint:gosec
r := rand.New(rand.NewSource(42))
Expand Down
37 changes: 36 additions & 1 deletion core/mapmutex/stringer.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package mapmutex

import "fmt"
import (
"fmt"
)

// StringerMapMutex is an implementation of mapMutex for the fmt.Stringer conforming types.
type StringerMapMutex interface {
Lock(key fmt.Stringer) Unlocker
TryLock(key fmt.Stringer) (Unlocker, bool)
Keys() []string
}

// stringerLockerImpl is the implementation of StringerMapMutex.
Expand All @@ -22,6 +25,16 @@ func (s stringerLockerImpl) Lock(key fmt.Stringer) Unlocker {
return s.mapMux.Lock(key.String())
}

// Keys returns the keys of the map.
func (s stringerLockerImpl) Keys() []string {
var keys []string
for _, key := range s.mapMux.Keys() {
// nolint: forcetypeassert
keys = append(keys, key.(string))
}
return keys
}

// NewStringerMapMutex creates an initialized locker that locks on fmt.String.
func NewStringerMapMutex() StringerMapMutex {
return &stringerLockerImpl{
Expand All @@ -33,6 +46,7 @@ func NewStringerMapMutex() StringerMapMutex {
type StringMapMutex interface {
Lock(key string) Unlocker
TryLock(key string) (Unlocker, bool)
Keys() []string
}

// stringMutexImpl locks on a string type.
Expand All @@ -57,10 +71,21 @@ func (s stringMutexImpl) TryLock(key string) (Unlocker, bool) {
return s.mapMux.TryLock(key)
}

// Keys returns the keys of the map.
func (s stringMutexImpl) Keys() []string {
keys := []string{}
for _, key := range s.mapMux.Keys() {
// nolint: forcetypeassert
keys = append(keys, key.(string))
}
return keys
}

// IntMapMutex is a map mutex that allows locking on an int.
type IntMapMutex interface {
Lock(key int) Unlocker
TryLock(key int) (Unlocker, bool)
Keys() []int
}

// intMapMux locks on an int.
Expand All @@ -77,6 +102,16 @@ func (i intMapMux) Lock(key int) Unlocker {
return i.mapMux.Lock(key)
}

// Keys returns the keys of the map.
func (i intMapMux) Keys() []int {
var keys []int
for _, key := range i.mapMux.Keys() {
// nolint: forcetypeassert
keys = append(keys, key.(int))
}
return keys
}

// NewIntMapMutex creates a map mutex for locking on an integer.
func NewIntMapMutex() IntMapMutex {
return &intMapMux{
Expand Down
20 changes: 19 additions & 1 deletion services/rfq/relayer/reldb/base/quote.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,18 @@ func (s Store) GetQuoteResultsByStatus(ctx context.Context, matchStatuses ...rel
}

// UpdateQuoteRequestStatus todo: db test.
func (s Store) UpdateQuoteRequestStatus(ctx context.Context, id [32]byte, status reldb.QuoteRequestStatus) error {
func (s Store) UpdateQuoteRequestStatus(ctx context.Context, id [32]byte, status reldb.QuoteRequestStatus, prevStatus *reldb.QuoteRequestStatus) error {
if prevStatus == nil {
req, err := s.GetQuoteRequestByID(ctx, id)
if err != nil {
return fmt.Errorf("could not get quote: %w", err)
}
prevStatus = &req.Status
}
if !isValidStateTransition(*prevStatus, status) {
return nil
}

tx := s.DB().WithContext(ctx).Model(&RequestForQuote{}).
Where(fmt.Sprintf("%s = ?", transactionIDFieldName), hexutil.Encode(id[:])).
Update(statusFieldName, status)
Expand Down Expand Up @@ -120,3 +131,10 @@ func (s Store) UpdateRelayNonce(ctx context.Context, id [32]byte, nonce uint64)
}
return nil
}

func isValidStateTransition(prevStatus, status reldb.QuoteRequestStatus) bool {
if status == reldb.DeadlineExceeded || status == reldb.WillNotProcess {
return true
}
return status >= prevStatus
}
2 changes: 1 addition & 1 deletion services/rfq/relayer/reldb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Writer interface {
// StoreRebalance stores a rebalance.
StoreRebalance(ctx context.Context, rebalance Rebalance) error
// UpdateQuoteRequestStatus updates the status of a quote request
UpdateQuoteRequestStatus(ctx context.Context, id [32]byte, status QuoteRequestStatus) error
UpdateQuoteRequestStatus(ctx context.Context, id [32]byte, status QuoteRequestStatus, prevStatus *QuoteRequestStatus) error
// UpdateRebalance updates the status of a rebalance action.
// If the origin is supplied, it will be used to update the ID for the corresponding rebalance model.
UpdateRebalance(ctx context.Context, rebalance Rebalance, updateID bool) error
Expand Down
14 changes: 7 additions & 7 deletions services/rfq/relayer/service/chainindexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,41 +79,41 @@ func (r *Relayer) runChainIndexer(ctx context.Context, chainID int) (err error)
}
case *fastbridge.FastBridgeBridgeRelayed:
// blocking lock on the txid mutex to ensure state transitions are not overrwitten
unlocker := r.relayMtx.Lock(hexutil.Encode(event.TransactionId[:]))
unlocker := r.handlerMtx.Lock(hexutil.Encode(event.TransactionId[:]))
defer unlocker.Unlock()

// it wasn't me
if event.Relayer != r.signer.Address() {
//nolint: wrapcheck
return r.db.UpdateQuoteRequestStatus(ctx, event.TransactionId, reldb.RelayRaceLost)
return r.db.UpdateQuoteRequestStatus(ctx, event.TransactionId, reldb.RelayRaceLost, nil)
}

err = r.handleRelayLog(ctx, event)
if err != nil {
return fmt.Errorf("could not handle relay: %w", err)
}
case *fastbridge.FastBridgeBridgeProofProvided:
unlocker := r.relayMtx.Lock(hexutil.Encode(event.TransactionId[:]))
unlocker := r.handlerMtx.Lock(hexutil.Encode(event.TransactionId[:]))
defer unlocker.Unlock()

// it wasn't me
if event.Relayer != r.signer.Address() {
//nolint: wrapcheck
return r.db.UpdateQuoteRequestStatus(ctx, event.TransactionId, reldb.RelayRaceLost)
return r.db.UpdateQuoteRequestStatus(ctx, event.TransactionId, reldb.RelayRaceLost, nil)
}

err = r.handleProofProvided(ctx, event)
if err != nil {
return fmt.Errorf("could not handle proof provided: %w", err)
}
case *fastbridge.FastBridgeBridgeDepositClaimed:
unlocker := r.relayMtx.Lock(hexutil.Encode(event.TransactionId[:]))
unlocker := r.handlerMtx.Lock(hexutil.Encode(event.TransactionId[:]))
defer unlocker.Unlock()

// it wasn't me
if event.Relayer != r.signer.Address() {
//nolint: wrapcheck
return r.db.UpdateQuoteRequestStatus(ctx, event.TransactionId, reldb.RelayRaceLost)
return r.db.UpdateQuoteRequestStatus(ctx, event.TransactionId, reldb.RelayRaceLost, nil)
}

err = r.handleDepositClaimed(ctx, event)
Expand Down Expand Up @@ -206,7 +206,7 @@ func getDecimalsKey(addr common.Address, chainID uint32) string {
}

func (r *Relayer) handleDepositClaimed(ctx context.Context, event *fastbridge.FastBridgeBridgeDepositClaimed) error {
err := r.db.UpdateQuoteRequestStatus(ctx, event.TransactionId, reldb.ClaimCompleted)
err := r.db.UpdateQuoteRequestStatus(ctx, event.TransactionId, reldb.ClaimCompleted, nil)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand Down
26 changes: 13 additions & 13 deletions services/rfq/relayer/service/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (r *Relayer) handleBridgeRequestedLog(parentCtx context.Context, req *fastb
metrics.EndSpanWithErr(span, err)
}()

unlocker, ok := r.relayMtx.TryLock(hexutil.Encode(req.TransactionId[:]))
unlocker, ok := r.handlerMtx.TryLock(hexutil.Encode(req.TransactionId[:]))
if !ok {
span.SetAttributes(attribute.Bool("locked", true))
// already processing this request
Expand Down Expand Up @@ -142,7 +142,7 @@ func (q *QuoteRequestHandler) handleSeen(ctx context.Context, span trace.Span, r
return fmt.Errorf("could not determine if should process: %w", err)
}
if !shouldProcess {
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.WillNotProcess)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.WillNotProcess, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand All @@ -167,7 +167,7 @@ func (q *QuoteRequestHandler) handleSeen(ctx context.Context, span trace.Span, r
if errors.Is(err, inventory.ErrUnsupportedChain) {
// don't process request if chain is currently unsupported
span.AddEvent("dropping unsupported chain")
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.WillNotProcess)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.WillNotProcess, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand All @@ -179,7 +179,7 @@ func (q *QuoteRequestHandler) handleSeen(ctx context.Context, span trace.Span, r

// check if we have enough inventory to handle the request
if committableBalance.Cmp(request.Transaction.DestAmount) < 0 {
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.NotEnoughInventory)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.NotEnoughInventory, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand All @@ -206,7 +206,7 @@ func (q *QuoteRequestHandler) handleSeen(ctx context.Context, span trace.Span, r
}

request.Status = reldb.CommittedPending
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.CommittedPending)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.CommittedPending, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand Down Expand Up @@ -270,7 +270,7 @@ func (q *QuoteRequestHandler) handleCommitPending(ctx context.Context, span trac
}

request.Status = reldb.CommittedConfirmed
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.CommittedConfirmed)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.CommittedConfirmed, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand Down Expand Up @@ -300,7 +300,7 @@ func (q *QuoteRequestHandler) handleCommitConfirmed(ctx context.Context, span tr
span.AddEvent("relay successfully submitted")
span.SetAttributes(attribute.Int("relay_nonce", int(nonce)))

err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.RelayStarted)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.RelayStarted, &request.Status)
if err != nil {
return fmt.Errorf("could not update quote request status: %w", err)
}
Expand Down Expand Up @@ -332,7 +332,7 @@ func (r *Relayer) handleRelayLog(ctx context.Context, req *fastbridge.FastBridge
}

// TODO: this can still get re-orged
err = r.db.UpdateQuoteRequestStatus(ctx, req.TransactionId, reldb.RelayCompleted)
err = r.db.UpdateQuoteRequestStatus(ctx, req.TransactionId, reldb.RelayCompleted, nil)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand Down Expand Up @@ -361,7 +361,7 @@ func (q *QuoteRequestHandler) handleRelayCompleted(ctx context.Context, _ trace.
return fmt.Errorf("could not submit transaction: %w", err)
}

err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.ProvePosting)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.ProvePosting, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand All @@ -375,7 +375,7 @@ func (q *QuoteRequestHandler) handleRelayCompleted(ctx context.Context, _ trace.
func (r *Relayer) handleProofProvided(ctx context.Context, req *fastbridge.FastBridgeBridgeProofProvided) (err error) {
// TODO: this can still get re-orged
// ALso: we should make sure the previous status is ProvePosting
err = r.db.UpdateQuoteRequestStatus(ctx, req.TransactionId, reldb.ProvePosted)
err = r.db.UpdateQuoteRequestStatus(ctx, req.TransactionId, reldb.ProvePosted, nil)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand Down Expand Up @@ -408,7 +408,7 @@ func (q *QuoteRequestHandler) handleProofPosted(ctx context.Context, _ trace.Spa
}

if bs == fastbridge.RelayerClaimed.Int() {
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.ClaimCompleted)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.ClaimCompleted, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand Down Expand Up @@ -443,7 +443,7 @@ func (q *QuoteRequestHandler) handleProofPosted(ctx context.Context, _ trace.Spa
return fmt.Errorf("could not submit transaction: %w", err)
}

err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.ClaimPending)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.ClaimPending, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand All @@ -460,7 +460,7 @@ func (q *QuoteRequestHandler) handleNotEnoughInventory(ctx context.Context, _ tr
}
// if committableBalance > destAmount
if committableBalance.Cmp(request.Transaction.DestAmount) > 0 {
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.CommittedPending)
err = q.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.CommittedPending, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions services/rfq/relayer/service/relayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ type Relayer struct {
decimalsCache *xsync.MapOf[string, *uint8]
// semaphore is used to limit the number of concurrent requests
semaphore *semaphore.Weighted
// relayMtx is used to synchronize handling of relay requests
relayMtx mapmutex.StringMapMutex
// handlerMtx is used to synchronize handling of relay requests
handlerMtx mapmutex.StringMapMutex
}

var logger = log.Logger("relayer")
Expand Down Expand Up @@ -155,7 +155,7 @@ func NewRelayer(ctx context.Context, metricHandler metrics.Handler, cfg relconfi
apiServer: apiServer,
apiClient: apiClient,
semaphore: semaphore.NewWeighted(maxConcurrentRequests),
relayMtx: mapmutex.NewStringMapMutex(),
handlerMtx: mapmutex.NewStringMapMutex(),
}
return &rel, nil
}
Expand Down Expand Up @@ -391,7 +391,7 @@ func (r *Relayer) processRequest(parentCtx context.Context, request reldb.QuoteR

// if deadline < now
if request.Transaction.Deadline.Cmp(big.NewInt(time.Now().Unix())) < 0 && request.Status.Int() < reldb.RelayCompleted.Int() {
err = r.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.DeadlineExceeded)
err = r.db.UpdateQuoteRequestStatus(ctx, request.TransactionID, reldb.DeadlineExceeded, &request.Status)
if err != nil {
return fmt.Errorf("could not update request status: %w", err)
}
Expand Down
Loading

0 comments on commit 9cdc208

Please sign in to comment.