Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rfq-api): add v2 contracts to rfq api endpoint [SLT-429] #3387

Merged
merged 9 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion contrib/opbot/botmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,8 @@
return nil, fmt.Errorf("error getting chain client for chain ID %d: %w", chainID, err)
}

contractAddress, ok := contracts.Contracts[chainID]
// TODO: handle v2 contract if specified
contractAddress, ok := contracts.ContractsV1[chainID]

Check warning on line 375 in contrib/opbot/botmd/commands.go

View check run for this annotation

Codecov / codecov/patch

contrib/opbot/botmd/commands.go#L375

Added line #L375 was not covered by tests
Comment on lines +374 to +375
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add test coverage for contract version handling

The new contract version logic lacks test coverage. This is critical as it affects core functionality like refunds.

Consider adding these test cases:

  1. Test successful v1 contract retrieval
  2. Test successful v2 contract retrieval (after implementation)
  3. Test fallback behavior
  4. Test error cases for missing contracts

Would you like me to help generate the test cases?

🧰 Tools
🪛 GitHub Check: codecov/patch

[warning] 375-375: contrib/opbot/botmd/commands.go#L375
Added line #L375 was not covered by tests

if !ok {
return nil, fmt.Errorf("no contract address for chain ID")
}
Expand Down
6 changes: 5 additions & 1 deletion services/rfq/api/client/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ func (c *ClientSuite) SetupTest() {
DSN: filet.TmpFile(c.T(), "", "").Name(),
},
OmniRPCURL: testOmnirpc,
Bridges: map[uint32]string{
FastBridgeContractsV1: map[uint32]string{
1: ethFastBridgeAddress.Hex(),
42161: arbFastBridgeAddress.Hex(),
},
FastBridgeContractsV2: map[uint32]string{
1: ethFastBridgeAddress.Hex(),
42161: arbFastBridgeAddress.Hex(),
},
Expand Down
14 changes: 7 additions & 7 deletions services/rfq/api/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ type DatabaseConfig struct {

// Config is the configuration for the RFQ Quoter.
type Config struct {
Database DatabaseConfig `yaml:"database"`
OmniRPCURL string `yaml:"omnirpc_url"`
// bridges is a map of chainid->address
Bridges map[uint32]string `yaml:"bridges"`
Port string `yaml:"port"`
RelayAckTimeout time.Duration `yaml:"relay_ack_timeout"`
MaxQuoteAge time.Duration `yaml:"max_quote_age"`
Database DatabaseConfig `yaml:"database"`
OmniRPCURL string `yaml:"omnirpc_url"`
FastBridgeContractsV1 map[uint32]string `yaml:"fast_bridge_contracts_v1"`
FastBridgeContractsV2 map[uint32]string `yaml:"fast_bridge_contracts_v2"`
Port string `yaml:"port"`
RelayAckTimeout time.Duration `yaml:"relay_ack_timeout"`
MaxQuoteAge time.Duration `yaml:"max_quote_age"`
ChiTimesChi marked this conversation as resolved.
Show resolved Hide resolved
}

const defaultRelayAckTimeout = 30 * time.Second
Expand Down
8 changes: 5 additions & 3 deletions services/rfq/api/model/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ type PutRelayAckResponse struct {
RelayerAddress string `json:"relayer_address"`
}

// GetContractsResponse contains the schema for a GET /contract response.
// GetContractsResponse contains the schema for a GET /contracts response.
type GetContractsResponse struct {
// Contracts is a map of chain id to contract address
Contracts map[uint32]string `json:"contracts"`
// ContractsV1 is a map of chain id to contract address for v1 fast bridge contracts
ContractsV1 map[uint32]string `json:"contracts_v1"`
// ContractsV2 is a map of chain id to contract address for v2 fast bridge contracts
ContractsV2 map[uint32]string `json:"contracts_v2"`
}

// ActiveRFQMessage represents the general structure of WebSocket messages for Active RFQ.
Expand Down
41 changes: 32 additions & 9 deletions services/rfq/api/rest/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strconv"
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/synapsecns/sanguine/services/rfq/api/config"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -69,7 +70,7 @@ func (h *Handler) ModifyQuote(c *gin.Context) {
return
}

dbQuote, err := parseDBQuote(*putRequest, relayerAddr)
dbQuote, err := parseDBQuote(h.cfg, *putRequest, relayerAddr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
Expand Down Expand Up @@ -117,7 +118,7 @@ func (h *Handler) ModifyBulkQuotes(c *gin.Context) {

dbQuotes := []*db.Quote{}
for _, quoteReq := range putRequest.Quotes {
dbQuote, err := parseDBQuote(quoteReq, relayerAddr)
dbQuote, err := parseDBQuote(h.cfg, quoteReq, relayerAddr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid quote request"})
return
Expand All @@ -134,7 +135,7 @@ func (h *Handler) ModifyBulkQuotes(c *gin.Context) {
}

//nolint:gosec
func parseDBQuote(putRequest model.PutRelayerQuoteRequest, relayerAddr interface{}) (*db.Quote, error) {
func parseDBQuote(cfg config.Config, putRequest model.PutRelayerQuoteRequest, relayerAddr interface{}) (*db.Quote, error) {
destAmount, err := decimal.NewFromString(putRequest.DestAmount)
if err != nil {
return nil, fmt.Errorf("invalid DestAmount")
Expand All @@ -147,6 +148,12 @@ func parseDBQuote(putRequest model.PutRelayerQuoteRequest, relayerAddr interface
if err != nil {
return nil, fmt.Errorf("invalid FixedFee")
}

err = validateFastBridgeAddresses(cfg, putRequest)
if err != nil {
return nil, fmt.Errorf("invalid fast bridge addresses: %w", err)
}

// nolint: forcetypeassert
return &db.Quote{
OriginChainID: uint64(putRequest.OriginChainID),
Expand All @@ -163,6 +170,24 @@ func parseDBQuote(putRequest model.PutRelayerQuoteRequest, relayerAddr interface
}, nil
}

//nolint:gosec
func validateFastBridgeAddresses(cfg config.Config, putRequest model.PutRelayerQuoteRequest) error {
// Check V1 contracts
isV1Origin := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress)
isV1Dest := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.DestChainID)]) == common.HexToAddress(putRequest.DestFastBridgeAddress)

// Check V2 contracts
isV2Origin := common.HexToAddress(cfg.FastBridgeContractsV2[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress)
isV2Dest := common.HexToAddress(cfg.FastBridgeContractsV2[uint32(putRequest.DestChainID)]) == common.HexToAddress(putRequest.DestFastBridgeAddress)

// Valid if both addresses match either V1 or V2
if (isV1Origin && isV1Dest) || (isV2Origin && isV2Dest) {
return nil
}

return fmt.Errorf("origin and destination fast bridge addresses must match either V1 or V2")
}
Comment on lines +174 to +189
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Address potential integer overflow and optimize validation logic

  1. There's a risk of integer overflow when converting chainIDs from int to uint32.
  2. The validation logic could be more efficient.

Fix the integer overflow risk:

-       isV1Origin := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress)
+       if putRequest.OriginChainID < 0 || putRequest.DestChainID < 0 {
+           return fmt.Errorf("chain IDs must be non-negative")
+       }
+       originChainID := uint32(putRequest.OriginChainID)
+       destChainID := uint32(putRequest.DestChainID)
+       isV1Origin := common.HexToAddress(cfg.FastBridgeContractsV1[originChainID]) == common.HexToAddress(putRequest.OriginFastBridgeAddress)

Optimize the validation logic:

-       isV1Origin := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress)
-       isV1Dest := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.DestChainID)]) == common.HexToAddress(putRequest.DestFastBridgeAddress)
-       isV2Origin := common.HexToAddress(cfg.FastBridgeContractsV2[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress)
-       isV2Dest := common.HexToAddress(cfg.FastBridgeContractsV2[uint32(putRequest.DestChainID)]) == common.HexToAddress(putRequest.DestFastBridgeAddress)
+       originAddr := common.HexToAddress(putRequest.OriginFastBridgeAddress)
+       destAddr := common.HexToAddress(putRequest.DestFastBridgeAddress)
+       
+       v1OriginAddr := common.HexToAddress(cfg.FastBridgeContractsV1[originChainID])
+       v1DestAddr := common.HexToAddress(cfg.FastBridgeContractsV1[destChainID])
+       if originAddr == v1OriginAddr && destAddr == v1DestAddr {
+           return nil
+       }
+       
+       v2OriginAddr := common.HexToAddress(cfg.FastBridgeContractsV2[originChainID])
+       v2DestAddr := common.HexToAddress(cfg.FastBridgeContractsV2[destChainID])
+       if originAddr == v2OriginAddr && destAddr == v2DestAddr {
+           return nil
+       }
+       
+       return fmt.Errorf("origin and destination fast bridge addresses must match either V1 or V2")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func validateFastBridgeAddresses(cfg config.Config, putRequest model.PutRelayerQuoteRequest) error {
// Check V1 contracts
isV1Origin := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress)
isV1Dest := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.DestChainID)]) == common.HexToAddress(putRequest.DestFastBridgeAddress)
// Check V2 contracts
isV2Origin := common.HexToAddress(cfg.FastBridgeContractsV2[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress)
isV2Dest := common.HexToAddress(cfg.FastBridgeContractsV2[uint32(putRequest.DestChainID)]) == common.HexToAddress(putRequest.DestFastBridgeAddress)
// Valid if both addresses match either V1 or V2
if (isV1Origin && isV1Dest) || (isV2Origin && isV2Dest) {
return nil
}
return fmt.Errorf("origin and destination fast bridge addresses must match either V1 or V2")
}
func validateFastBridgeAddresses(cfg config.Config, putRequest model.PutRelayerQuoteRequest) error {
if putRequest.OriginChainID < 0 || putRequest.DestChainID < 0 {
return fmt.Errorf("chain IDs must be non-negative")
}
originChainID := uint32(putRequest.OriginChainID)
destChainID := uint32(putRequest.DestChainID)
originAddr := common.HexToAddress(putRequest.OriginFastBridgeAddress)
destAddr := common.HexToAddress(putRequest.DestFastBridgeAddress)
v1OriginAddr := common.HexToAddress(cfg.FastBridgeContractsV1[originChainID])
v1DestAddr := common.HexToAddress(cfg.FastBridgeContractsV1[destChainID])
if originAddr == v1OriginAddr && destAddr == v1DestAddr {
return nil
}
v2OriginAddr := common.HexToAddress(cfg.FastBridgeContractsV2[originChainID])
v2DestAddr := common.HexToAddress(cfg.FastBridgeContractsV2[destChainID])
if originAddr == v2OriginAddr && destAddr == v2DestAddr {
return nil
}
return fmt.Errorf("origin and destination fast bridge addresses must match either V1 or V2")
}
🧰 Tools
🪛 GitHub Check: Lint (services/rfq)

[failure] 175-175:
G115: integer overflow conversion int -> uint32 (gosec)


[failure] 176-176:
G115: integer overflow conversion int -> uint32 (gosec)


[failure] 179-179:
G115: integer overflow conversion int -> uint32 (gosec)


//nolint:gosec
func quoteResponseFromDBQuote(dbQuote *db.Quote) *model.GetQuoteResponse {
return &model.GetQuoteResponse{
Expand Down Expand Up @@ -301,12 +326,10 @@ func dbActiveQuoteRequestToModel(dbQuote *db.ActiveQuoteRequest) *model.GetOpenQ
// @Header 200 {string} X-Api-Version "API Version Number - See docs for more info"
// @Router /contracts [get].
func (h *Handler) GetContracts(c *gin.Context) {
// Convert quotes from db model to api model
contracts := make(map[uint32]string)
for chainID, address := range h.cfg.Bridges {
contracts[chainID] = address
}
c.JSON(http.StatusOK, model.GetContractsResponse{Contracts: contracts})
c.JSON(http.StatusOK, model.GetContractsResponse{
ContractsV1: h.cfg.FastBridgeContractsV1,
ContractsV2: h.cfg.FastBridgeContractsV2,
})
}

func filterQuoteAge(cfg config.Config, dbQuotes []*db.Quote) []*db.Quote {
Expand Down
144 changes: 110 additions & 34 deletions services/rfq/api/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"

"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
Expand All @@ -35,6 +36,7 @@ import (
"github.com/synapsecns/sanguine/services/rfq/api/docs"
"github.com/synapsecns/sanguine/services/rfq/api/model"
"github.com/synapsecns/sanguine/services/rfq/contracts/fastbridge"
"github.com/synapsecns/sanguine/services/rfq/contracts/fastbridgev2"
"github.com/synapsecns/sanguine/services/rfq/relayer/relapi"
)

Expand All @@ -51,15 +53,17 @@ func getCurrentVersion() (string, error) {
// QuoterAPIServer is a struct that holds the configuration, database connection, gin engine, RPC client, metrics handler, and fast bridge contracts.
// It is used to initialize and run the API server.
type QuoterAPIServer struct {
cfg config.Config
db db.APIDB
engine *gin.Engine
upgrader websocket.Upgrader
omnirpcClient omniClient.RPCClient
handler metrics.Handler
meter metric.Meter
fastBridgeContracts map[uint32]*fastbridge.FastBridge
roleCache map[uint32]*ttlcache.Cache[string, bool]
cfg config.Config
db db.APIDB
engine *gin.Engine
upgrader websocket.Upgrader
omnirpcClient omniClient.RPCClient
handler metrics.Handler
meter metric.Meter
fastBridgeContractsV1 map[uint32]*fastbridge.FastBridge
fastBridgeContractsV2 map[uint32]*fastbridgev2.FastBridgeV2
roleCacheV1 map[uint32]*ttlcache.Cache[string, bool]
roleCacheV2 map[uint32]*ttlcache.Cache[string, bool]
Comment on lines +63 to +66
Copy link
Contributor

@coderabbitai coderabbitai bot Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Refactor to reduce duplication between V1 and V2 contract fields

The QuoterAPIServer struct introduces separate fields for V1 and V2 contracts and role caches:

fastBridgeContractsV1 map[uint32]*fastbridge.FastBridge
fastBridgeContractsV2 map[uint32]*fastbridgev2.FastBridgeV2
roleCacheV1           map[uint32]*ttlcache.Cache[string, bool]
roleCacheV2           map[uint32]*ttlcache.Cache[string, bool]

Consider refactoring these fields to reduce duplication. One approach is to use a unified structure or map that can handle multiple versions, possibly by introducing a version key or encapsulating the contract and cache data together.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dwasse thoughts on this? Not sure if worth it, but I see that the new roleContract type could be potentially used here, as we only check roles anyway

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My plan here was to keep these as separate fields temporarily while we support backwards compatibility, then remove the old fields. We could opt for a more generic approach if we want to support multiple versions down the line- do you think that will be the case?

// relayAckCache contains a set of transactionID values that reflect
// transactions that have been acked for relay
relayAckCache *ttlcache.Cache[string, string]
Expand Down Expand Up @@ -96,23 +100,47 @@ func NewAPI(

docs.SwaggerInfo.Title = "RFQ Quoter API"

bridges := make(map[uint32]*fastbridge.FastBridge)
roles := make(map[uint32]*ttlcache.Cache[string, bool])
for chainID, bridge := range cfg.Bridges {
fastBridgeContractsV1 := make(map[uint32]*fastbridge.FastBridge)
rolesV1 := make(map[uint32]*ttlcache.Cache[string, bool])
for chainID, contract := range cfg.FastBridgeContractsV1 {
chainClient, err := omniRPCClient.GetChainClient(ctx, int(chainID))
if err != nil {
return nil, fmt.Errorf("could not create omnirpc client: %w", err)
}
bridges[chainID], err = fastbridge.NewFastBridge(common.HexToAddress(bridge), chainClient)
fastBridgeContractsV1[chainID], err = fastbridge.NewFastBridge(common.HexToAddress(contract), chainClient)
if err != nil {
return nil, fmt.Errorf("could not create bridge contract: %w", err)
}

// create the roles cache
roles[chainID] = ttlcache.New[string, bool](
rolesV1[chainID] = ttlcache.New[string, bool](
ttlcache.WithTTL[string, bool](cacheInterval),
)
roleCache := roles[chainID]
roleCache := rolesV1[chainID]
go roleCache.Start()
go func() {
<-ctx.Done()
roleCache.Stop()
}()
}
Comment on lines +103 to +125
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Reduce code duplication in contract initialization

The initialization code for v1 and v2 contracts is nearly identical, which could lead to maintenance challenges.

Consider consolidating the initialization logic into a helper function:

+func initializeContracts[T roleContract](
+    ctx context.Context,
+    cfg map[uint32]string,
+    newContract func(common.Address, bind.ContractBackend) (T, error),
+    omniRPCClient omniClient.RPCClient,
+) (map[uint32]T, map[uint32]*ttlcache.Cache[string, bool], error) {
+    contracts := make(map[uint32]T)
+    roles := make(map[uint32]*ttlcache.Cache[string, bool])
+    
+    for chainID, contract := range cfg {
+        chainClient, err := omniRPCClient.GetChainClient(ctx, int(chainID))
+        if err != nil {
+            return nil, nil, fmt.Errorf("could not create omnirpc client: %w", err)
+        }
+        
+        contractInstance, err := newContract(common.HexToAddress(contract), chainClient)
+        if err != nil {
+            return nil, nil, fmt.Errorf("could not create bridge contract: %w", err)
+        }
+        contracts[chainID] = contractInstance
+        
+        roles[chainID] = ttlcache.New[string, bool](
+            ttlcache.WithTTL[string, bool](cacheInterval),
+        )
+        roleCache := roles[chainID]
+        go roleCache.Start()
+        go func() {
+            <-ctx.Done()
+            roleCache.Stop()
+        }()
+    }
+    return contracts, roles, nil
+}

Usage:

-fastBridgeContractsV1 := make(map[uint32]*fastbridge.FastBridge)
-rolesV1 := make(map[uint32]*ttlcache.Cache[string, bool])
-for chainID, contract := range cfg.FastBridgeContractsV1 {
-    // ... initialization code
-}
+fastBridgeContractsV1, rolesV1, err := initializeContracts(
+    ctx,
+    cfg.FastBridgeContractsV1,
+    fastbridge.NewFastBridge,
+    omniRPCClient,
+)
+if err != nil {
+    return nil, err
+}

Also applies to: 127-141


fastBridgeContractsV2 := make(map[uint32]*fastbridgev2.FastBridgeV2)
rolesV2 := make(map[uint32]*ttlcache.Cache[string, bool])
for chainID, contract := range cfg.FastBridgeContractsV2 {
chainClient, err := omniRPCClient.GetChainClient(ctx, int(chainID))
if err != nil {
return nil, fmt.Errorf("could not create omnirpc client: %w", err)
}
fastBridgeContractsV2[chainID], err = fastbridgev2.NewFastBridgeV2(common.HexToAddress(contract), chainClient)
if err != nil {
return nil, fmt.Errorf("could not create bridge contract: %w", err)
}

// create the roles cache
rolesV2[chainID] = ttlcache.New[string, bool](
ttlcache.WithTTL[string, bool](cacheInterval),
)
roleCache := rolesV2[chainID]
go roleCache.Start()
go func() {
<-ctx.Done()
Expand All @@ -132,17 +160,19 @@ func NewAPI(
}()

q := &QuoterAPIServer{
cfg: cfg,
db: store,
omnirpcClient: omniRPCClient,
handler: handler,
meter: handler.Meter(meterName),
fastBridgeContracts: bridges,
roleCache: roles,
relayAckCache: relayAckCache,
ackMux: sync.Mutex{},
wsClients: xsync.NewMapOf[WsClient](),
pubSubManager: NewPubSubManager(),
cfg: cfg,
db: store,
omnirpcClient: omniRPCClient,
handler: handler,
meter: handler.Meter(meterName),
fastBridgeContractsV1: fastBridgeContractsV1,
fastBridgeContractsV2: fastBridgeContractsV2,
roleCacheV1: rolesV1,
roleCacheV2: rolesV2,
relayAckCache: relayAckCache,
ackMux: sync.Mutex{},
wsClients: xsync.NewMapOf[WsClient](),
pubSubManager: NewPubSubManager(),
}

// Prometheus metrics setup
Expand Down Expand Up @@ -298,7 +328,7 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc {
// Authenticate and fetch the address from the request
var addressRecovered *common.Address
for _, destChainID := range destChainIDs {
addr, err := r.checkRole(c, destChainID)
addr, err := r.checkRoleParallel(c, destChainID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"msg": err.Error()})
c.Abort()
Expand All @@ -321,11 +351,57 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc {
}
}

func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) {
bridge, ok := r.fastBridgeContracts[destChainID]
if !ok {
err = fmt.Errorf("dest chain id not supported: %d", destChainID)
return addressRecovered, err
type roleContract interface {
HasRole(opts *bind.CallOpts, role [32]byte, account common.Address) (bool, error)
}

func (r *QuoterAPIServer) checkRoleParallel(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) {
g := new(errgroup.Group)
var v1Addr, v2Addr common.Address
var v1Err, v2Err error

g.Go(func() error {
v1Addr, v1Err = r.checkRole(c, destChainID, true)
return v1Err
})

g.Go(func() error {
v2Addr, v2Err = r.checkRole(c, destChainID, false)
return v2Err
})

err = g.Wait()
if v1Addr != (common.Address{}) {
return v1Addr, nil
}
if v2Addr != (common.Address{}) {
return v2Addr, nil
}
if err != nil {
return common.Address{}, fmt.Errorf("role check failed: %w", err)
}

return common.Address{}, fmt.Errorf("role check failed for both v1 and v2")
}
Comment on lines +358 to +385
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve error handling in parallel role checking

The current implementation might mask specific errors when both v1 and v2 checks fail.

Consider enhancing error reporting:

func (r *QuoterAPIServer) checkRoleParallel(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) {
    g := new(errgroup.Group)
    var v1Addr, v2Addr common.Address
    var v1Err, v2Err error

    g.Go(func() error {
        v1Addr, v1Err = r.checkRole(c, destChainID, true)
        return v1Err
    })

    g.Go(func() error {
        v2Addr, v2Err = r.checkRole(c, destChainID, false)
        return v2Err
    })

-   err = g.Wait()
+   _ = g.Wait() // We want to collect both errors regardless of g.Wait()
    if v1Addr != (common.Address{}) {
        return v1Addr, nil
    }
    if v2Addr != (common.Address{}) {
        return v2Addr, nil
    }
-   if err != nil {
-       return common.Address{}, fmt.Errorf("role check failed: %w", err)
-   }
-   return common.Address{}, fmt.Errorf("role check failed for both v1 and v2")
+   return common.Address{}, fmt.Errorf("role check failed: v1 error: %v, v2 error: %v", v1Err, v2Err)
}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func (r *QuoterAPIServer) checkRoleParallel(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) {
g := new(errgroup.Group)
var v1Addr, v2Addr common.Address
var v1Err, v2Err error
g.Go(func() error {
v1Addr, v1Err = r.checkRole(c, destChainID, true)
return v1Err
})
g.Go(func() error {
v2Addr, v2Err = r.checkRole(c, destChainID, false)
return v2Err
})
err = g.Wait()
if v1Addr != (common.Address{}) {
return v1Addr, nil
}
if v2Addr != (common.Address{}) {
return v2Addr, nil
}
if err != nil {
return common.Address{}, fmt.Errorf("role check failed: %w", err)
}
return common.Address{}, fmt.Errorf("role check failed for both v1 and v2")
}
func (r *QuoterAPIServer) checkRoleParallel(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) {
g := new(errgroup.Group)
var v1Addr, v2Addr common.Address
var v1Err, v2Err error
g.Go(func() error {
v1Addr, v1Err = r.checkRole(c, destChainID, true)
return v1Err
})
g.Go(func() error {
v2Addr, v2Err = r.checkRole(c, destChainID, false)
return v2Err
})
_ = g.Wait() // We want to collect both errors regardless of g.Wait()
if v1Addr != (common.Address{}) {
return v1Addr, nil
}
if v2Addr != (common.Address{}) {
return v2Addr, nil
}
return common.Address{}, fmt.Errorf("role check failed: v1 error: %v, v2 error: %v", v1Err, v2Err)
}


func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32, useV1 bool) (addressRecovered common.Address, err error) {
var bridge roleContract
var roleCache *ttlcache.Cache[string, bool]
var ok bool
if useV1 {
bridge, ok = r.fastBridgeContractsV1[destChainID]
if !ok {
err = fmt.Errorf("dest chain id not supported: %d", destChainID)
return addressRecovered, err
}
roleCache = r.roleCacheV1[destChainID]
} else {
bridge, ok = r.fastBridgeContractsV2[destChainID]
if !ok {
err = fmt.Errorf("dest chain id not supported: %d", destChainID)
return addressRecovered, err
}
roleCache = r.roleCacheV2[destChainID]
}

ops := &bind.CallOpts{Context: c}
Expand All @@ -340,7 +416,7 @@ func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32) (address
}

// Check and update cache
cachedRoleItem := r.roleCache[destChainID].Get(addressRecovered.Hex())
cachedRoleItem := roleCache.Get(addressRecovered.Hex())
var hasRole bool

if cachedRoleItem == nil || cachedRoleItem.IsExpired() {
Expand All @@ -350,7 +426,7 @@ func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32) (address
return addressRecovered, fmt.Errorf("unable to check relayer role on-chain: %w", err)
}
// Update cache
r.roleCache[destChainID].Set(addressRecovered.Hex(), hasRole, cacheInterval)
roleCache.Set(addressRecovered.Hex(), hasRole, cacheInterval)
} else {
// Use cached value
hasRole = cachedRoleItem.Value()
Expand Down
3 changes: 2 additions & 1 deletion services/rfq/api/rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,5 +607,6 @@ func (c *ServerSuite) TestContracts() {
contracts, err := client.GetRFQContracts(c.GetTestContext())
c.Require().NoError(err)

c.Require().Len(contracts.Contracts, 2)
c.Require().Len(contracts.ContractsV1, 2)
c.Require().Len(contracts.ContractsV2, 2)
}
6 changes: 5 additions & 1 deletion services/rfq/api/rest/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ func (c *ServerSuite) SetupTest() {
DSN: filet.TmpFile(c.T(), "", "").Name(),
},
OmniRPCURL: testOmnirpc,
Bridges: map[uint32]string{
FastBridgeContractsV1: map[uint32]string{
1: ethFastBridgeAddress.Hex(),
42161: arbFastBridgeAddress.Hex(),
},
Comment on lines +85 to +88
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Address type inconsistency in chain ID mappings

The FastBridge contract maps in the test config use uint32 for chain IDs, but the implementation uses uint64 (see fastBridgeAddressMap field). This type mismatch could cause issues in production code.

Apply this diff to fix the type inconsistency:

-       FastBridgeContractsV1: map[uint32]string{
+       FastBridgeContractsV1: map[uint64]string{
        },
-       FastBridgeContractsV2: map[uint32]string{
+       FastBridgeContractsV2: map[uint64]string{
        },

Also applies to: 89-92

FastBridgeContractsV2: map[uint32]string{
1: ethFastBridgeAddress.Hex(),
42161: arbFastBridgeAddress.Hex(),
},
Expand Down
2 changes: 1 addition & 1 deletion services/rfq/e2e/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (i *IntegrationSuite) setupQuoterAPI() {
DSN: dbPath,
},
OmniRPCURL: i.omniServer,
Bridges: map[uint32]string{
FastBridgeContractsV1: map[uint32]string{
originBackendChainID: i.manager.Get(i.GetTestContext(), i.originBackend, testutil.FastBridgeType).Address().String(),
destBackendChainID: i.manager.Get(i.GetTestContext(), i.destBackend, testutil.FastBridgeType).Address().String(),
},
Expand Down
Loading