diff --git a/contrib/opbot/botmd/commands.go b/contrib/opbot/botmd/commands.go index bf49c6a98f..02258337af 100644 --- a/contrib/opbot/botmd/commands.go +++ b/contrib/opbot/botmd/commands.go @@ -371,7 +371,8 @@ func (b *Bot) makeFastBridge(ctx context.Context, chainID uint32) (*fastbridge.F return nil, fmt.Errorf("error getting chain client for chain ID %d: %w", chainID, err) } - contractAddress, ok := contracts.Contracts[chainID] + // TODO: handle v2 contract if specified + contractAddress, ok := contracts.ContractsV1[chainID] if !ok { return nil, fmt.Errorf("no contract address for chain ID") } diff --git a/services/rfq/api/client/suite_test.go b/services/rfq/api/client/suite_test.go index e87436fcca..183cd0b13b 100644 --- a/services/rfq/api/client/suite_test.go +++ b/services/rfq/api/client/suite_test.go @@ -85,7 +85,11 @@ func (c *ClientSuite) SetupTest() { DSN: filet.TmpFile(c.T(), "", "").Name(), }, OmniRPCURL: testOmnirpc, - Bridges: map[uint32]string{ + FastBridgeContractsV1: map[uint32]string{ + 1: ethFastBridgeAddress.Hex(), + 42161: arbFastBridgeAddress.Hex(), + }, + FastBridgeContractsV2: map[uint32]string{ 1: ethFastBridgeAddress.Hex(), 42161: arbFastBridgeAddress.Hex(), }, diff --git a/services/rfq/api/config/config.go b/services/rfq/api/config/config.go index 4c4456f246..13225a1c95 100644 --- a/services/rfq/api/config/config.go +++ b/services/rfq/api/config/config.go @@ -19,13 +19,13 @@ type DatabaseConfig struct { // Config is the configuration for the RFQ Quoter. type Config struct { - Database DatabaseConfig `yaml:"database"` - OmniRPCURL string `yaml:"omnirpc_url"` - // bridges is a map of chainid->address - Bridges map[uint32]string `yaml:"bridges"` - Port string `yaml:"port"` - RelayAckTimeout time.Duration `yaml:"relay_ack_timeout"` - MaxQuoteAge time.Duration `yaml:"max_quote_age"` + Database DatabaseConfig `yaml:"database"` + OmniRPCURL string `yaml:"omnirpc_url"` + FastBridgeContractsV1 map[uint32]string `yaml:"fast_bridge_contracts_v1"` + FastBridgeContractsV2 map[uint32]string `yaml:"fast_bridge_contracts_v2"` + Port string `yaml:"port"` + RelayAckTimeout time.Duration `yaml:"relay_ack_timeout"` + MaxQuoteAge time.Duration `yaml:"max_quote_age"` } const defaultRelayAckTimeout = 30 * time.Second diff --git a/services/rfq/api/model/response.go b/services/rfq/api/model/response.go index b6624ff6b9..bcdb58d462 100644 --- a/services/rfq/api/model/response.go +++ b/services/rfq/api/model/response.go @@ -41,10 +41,12 @@ type PutRelayAckResponse struct { RelayerAddress string `json:"relayer_address"` } -// GetContractsResponse contains the schema for a GET /contract response. +// GetContractsResponse contains the schema for a GET /contracts response. type GetContractsResponse struct { - // Contracts is a map of chain id to contract address - Contracts map[uint32]string `json:"contracts"` + // ContractsV1 is a map of chain id to contract address for v1 fast bridge contracts + ContractsV1 map[uint32]string `json:"contracts_v1"` + // ContractsV2 is a map of chain id to contract address for v2 fast bridge contracts + ContractsV2 map[uint32]string `json:"contracts_v2"` } // ActiveRFQMessage represents the general structure of WebSocket messages for Active RFQ. diff --git a/services/rfq/api/rest/handler.go b/services/rfq/api/rest/handler.go index dd91873362..6fd40bec0b 100644 --- a/services/rfq/api/rest/handler.go +++ b/services/rfq/api/rest/handler.go @@ -6,6 +6,7 @@ import ( "strconv" "time" + "github.com/ethereum/go-ethereum/common" "github.com/synapsecns/sanguine/services/rfq/api/config" "github.com/gin-gonic/gin" @@ -69,7 +70,7 @@ func (h *Handler) ModifyQuote(c *gin.Context) { return } - dbQuote, err := parseDBQuote(*putRequest, relayerAddr) + dbQuote, err := parseDBQuote(h.cfg, *putRequest, relayerAddr) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -117,7 +118,7 @@ func (h *Handler) ModifyBulkQuotes(c *gin.Context) { dbQuotes := []*db.Quote{} for _, quoteReq := range putRequest.Quotes { - dbQuote, err := parseDBQuote(quoteReq, relayerAddr) + dbQuote, err := parseDBQuote(h.cfg, quoteReq, relayerAddr) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid quote request"}) return @@ -134,7 +135,7 @@ func (h *Handler) ModifyBulkQuotes(c *gin.Context) { } //nolint:gosec -func parseDBQuote(putRequest model.PutRelayerQuoteRequest, relayerAddr interface{}) (*db.Quote, error) { +func parseDBQuote(cfg config.Config, putRequest model.PutRelayerQuoteRequest, relayerAddr interface{}) (*db.Quote, error) { destAmount, err := decimal.NewFromString(putRequest.DestAmount) if err != nil { return nil, fmt.Errorf("invalid DestAmount") @@ -147,6 +148,12 @@ func parseDBQuote(putRequest model.PutRelayerQuoteRequest, relayerAddr interface if err != nil { return nil, fmt.Errorf("invalid FixedFee") } + + err = validateFastBridgeAddresses(cfg, putRequest) + if err != nil { + return nil, fmt.Errorf("invalid fast bridge addresses: %w", err) + } + // nolint: forcetypeassert return &db.Quote{ OriginChainID: uint64(putRequest.OriginChainID), @@ -163,6 +170,24 @@ func parseDBQuote(putRequest model.PutRelayerQuoteRequest, relayerAddr interface }, nil } +//nolint:gosec +func validateFastBridgeAddresses(cfg config.Config, putRequest model.PutRelayerQuoteRequest) error { + // Check V1 contracts + isV1Origin := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress) + isV1Dest := common.HexToAddress(cfg.FastBridgeContractsV1[uint32(putRequest.DestChainID)]) == common.HexToAddress(putRequest.DestFastBridgeAddress) + + // Check V2 contracts + isV2Origin := common.HexToAddress(cfg.FastBridgeContractsV2[uint32(putRequest.OriginChainID)]) == common.HexToAddress(putRequest.OriginFastBridgeAddress) + isV2Dest := common.HexToAddress(cfg.FastBridgeContractsV2[uint32(putRequest.DestChainID)]) == common.HexToAddress(putRequest.DestFastBridgeAddress) + + // Valid if both addresses match either V1 or V2 + if (isV1Origin && isV1Dest) || (isV2Origin && isV2Dest) { + return nil + } + + return fmt.Errorf("origin and destination fast bridge addresses must match either V1 or V2") +} + //nolint:gosec func quoteResponseFromDBQuote(dbQuote *db.Quote) *model.GetQuoteResponse { return &model.GetQuoteResponse{ @@ -301,12 +326,10 @@ func dbActiveQuoteRequestToModel(dbQuote *db.ActiveQuoteRequest) *model.GetOpenQ // @Header 200 {string} X-Api-Version "API Version Number - See docs for more info" // @Router /contracts [get]. func (h *Handler) GetContracts(c *gin.Context) { - // Convert quotes from db model to api model - contracts := make(map[uint32]string) - for chainID, address := range h.cfg.Bridges { - contracts[chainID] = address - } - c.JSON(http.StatusOK, model.GetContractsResponse{Contracts: contracts}) + c.JSON(http.StatusOK, model.GetContractsResponse{ + ContractsV1: h.cfg.FastBridgeContractsV1, + ContractsV2: h.cfg.FastBridgeContractsV2, + }) } func filterQuoteAge(cfg config.Config, dbQuotes []*db.Quote) []*db.Quote { diff --git a/services/rfq/api/rest/server.go b/services/rfq/api/rest/server.go index d748e55295..cdf3d59529 100644 --- a/services/rfq/api/rest/server.go +++ b/services/rfq/api/rest/server.go @@ -20,6 +20,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" @@ -35,6 +36,7 @@ import ( "github.com/synapsecns/sanguine/services/rfq/api/docs" "github.com/synapsecns/sanguine/services/rfq/api/model" "github.com/synapsecns/sanguine/services/rfq/contracts/fastbridge" + "github.com/synapsecns/sanguine/services/rfq/contracts/fastbridgev2" "github.com/synapsecns/sanguine/services/rfq/relayer/relapi" ) @@ -51,15 +53,17 @@ func getCurrentVersion() (string, error) { // QuoterAPIServer is a struct that holds the configuration, database connection, gin engine, RPC client, metrics handler, and fast bridge contracts. // It is used to initialize and run the API server. type QuoterAPIServer struct { - cfg config.Config - db db.APIDB - engine *gin.Engine - upgrader websocket.Upgrader - omnirpcClient omniClient.RPCClient - handler metrics.Handler - meter metric.Meter - fastBridgeContracts map[uint32]*fastbridge.FastBridge - roleCache map[uint32]*ttlcache.Cache[string, bool] + cfg config.Config + db db.APIDB + engine *gin.Engine + upgrader websocket.Upgrader + omnirpcClient omniClient.RPCClient + handler metrics.Handler + meter metric.Meter + fastBridgeContractsV1 map[uint32]*fastbridge.FastBridge + fastBridgeContractsV2 map[uint32]*fastbridgev2.FastBridgeV2 + roleCacheV1 map[uint32]*ttlcache.Cache[string, bool] + roleCacheV2 map[uint32]*ttlcache.Cache[string, bool] // relayAckCache contains a set of transactionID values that reflect // transactions that have been acked for relay relayAckCache *ttlcache.Cache[string, string] @@ -96,23 +100,47 @@ func NewAPI( docs.SwaggerInfo.Title = "RFQ Quoter API" - bridges := make(map[uint32]*fastbridge.FastBridge) - roles := make(map[uint32]*ttlcache.Cache[string, bool]) - for chainID, bridge := range cfg.Bridges { + fastBridgeContractsV1 := make(map[uint32]*fastbridge.FastBridge) + rolesV1 := make(map[uint32]*ttlcache.Cache[string, bool]) + for chainID, contract := range cfg.FastBridgeContractsV1 { chainClient, err := omniRPCClient.GetChainClient(ctx, int(chainID)) if err != nil { return nil, fmt.Errorf("could not create omnirpc client: %w", err) } - bridges[chainID], err = fastbridge.NewFastBridge(common.HexToAddress(bridge), chainClient) + fastBridgeContractsV1[chainID], err = fastbridge.NewFastBridge(common.HexToAddress(contract), chainClient) if err != nil { return nil, fmt.Errorf("could not create bridge contract: %w", err) } // create the roles cache - roles[chainID] = ttlcache.New[string, bool]( + rolesV1[chainID] = ttlcache.New[string, bool]( ttlcache.WithTTL[string, bool](cacheInterval), ) - roleCache := roles[chainID] + roleCache := rolesV1[chainID] + go roleCache.Start() + go func() { + <-ctx.Done() + roleCache.Stop() + }() + } + + fastBridgeContractsV2 := make(map[uint32]*fastbridgev2.FastBridgeV2) + rolesV2 := make(map[uint32]*ttlcache.Cache[string, bool]) + for chainID, contract := range cfg.FastBridgeContractsV2 { + chainClient, err := omniRPCClient.GetChainClient(ctx, int(chainID)) + if err != nil { + return nil, fmt.Errorf("could not create omnirpc client: %w", err) + } + fastBridgeContractsV2[chainID], err = fastbridgev2.NewFastBridgeV2(common.HexToAddress(contract), chainClient) + if err != nil { + return nil, fmt.Errorf("could not create bridge contract: %w", err) + } + + // create the roles cache + rolesV2[chainID] = ttlcache.New[string, bool]( + ttlcache.WithTTL[string, bool](cacheInterval), + ) + roleCache := rolesV2[chainID] go roleCache.Start() go func() { <-ctx.Done() @@ -132,17 +160,19 @@ func NewAPI( }() q := &QuoterAPIServer{ - cfg: cfg, - db: store, - omnirpcClient: omniRPCClient, - handler: handler, - meter: handler.Meter(meterName), - fastBridgeContracts: bridges, - roleCache: roles, - relayAckCache: relayAckCache, - ackMux: sync.Mutex{}, - wsClients: xsync.NewMapOf[WsClient](), - pubSubManager: NewPubSubManager(), + cfg: cfg, + db: store, + omnirpcClient: omniRPCClient, + handler: handler, + meter: handler.Meter(meterName), + fastBridgeContractsV1: fastBridgeContractsV1, + fastBridgeContractsV2: fastBridgeContractsV2, + roleCacheV1: rolesV1, + roleCacheV2: rolesV2, + relayAckCache: relayAckCache, + ackMux: sync.Mutex{}, + wsClients: xsync.NewMapOf[WsClient](), + pubSubManager: NewPubSubManager(), } // Prometheus metrics setup @@ -298,7 +328,7 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { // Authenticate and fetch the address from the request var addressRecovered *common.Address for _, destChainID := range destChainIDs { - addr, err := r.checkRole(c, destChainID) + addr, err := r.checkRoleParallel(c, destChainID) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"msg": err.Error()}) c.Abort() @@ -321,11 +351,57 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { } } -func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) { - bridge, ok := r.fastBridgeContracts[destChainID] - if !ok { - err = fmt.Errorf("dest chain id not supported: %d", destChainID) - return addressRecovered, err +type roleContract interface { + HasRole(opts *bind.CallOpts, role [32]byte, account common.Address) (bool, error) +} + +func (r *QuoterAPIServer) checkRoleParallel(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) { + g := new(errgroup.Group) + var v1Addr, v2Addr common.Address + var v1Err, v2Err error + + g.Go(func() error { + v1Addr, v1Err = r.checkRole(c, destChainID, true) + return v1Err + }) + + g.Go(func() error { + v2Addr, v2Err = r.checkRole(c, destChainID, false) + return v2Err + }) + + err = g.Wait() + if v1Addr != (common.Address{}) { + return v1Addr, nil + } + if v2Addr != (common.Address{}) { + return v2Addr, nil + } + if err != nil { + return common.Address{}, fmt.Errorf("role check failed: %w", err) + } + + return common.Address{}, fmt.Errorf("role check failed for both v1 and v2") +} + +func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32, useV1 bool) (addressRecovered common.Address, err error) { + var bridge roleContract + var roleCache *ttlcache.Cache[string, bool] + var ok bool + if useV1 { + bridge, ok = r.fastBridgeContractsV1[destChainID] + if !ok { + err = fmt.Errorf("dest chain id not supported: %d", destChainID) + return addressRecovered, err + } + roleCache = r.roleCacheV1[destChainID] + } else { + bridge, ok = r.fastBridgeContractsV2[destChainID] + if !ok { + err = fmt.Errorf("dest chain id not supported: %d", destChainID) + return addressRecovered, err + } + roleCache = r.roleCacheV2[destChainID] } ops := &bind.CallOpts{Context: c} @@ -340,7 +416,7 @@ func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32) (address } // Check and update cache - cachedRoleItem := r.roleCache[destChainID].Get(addressRecovered.Hex()) + cachedRoleItem := roleCache.Get(addressRecovered.Hex()) var hasRole bool if cachedRoleItem == nil || cachedRoleItem.IsExpired() { @@ -350,7 +426,7 @@ func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32) (address return addressRecovered, fmt.Errorf("unable to check relayer role on-chain: %w", err) } // Update cache - r.roleCache[destChainID].Set(addressRecovered.Hex(), hasRole, cacheInterval) + roleCache.Set(addressRecovered.Hex(), hasRole, cacheInterval) } else { // Use cached value hasRole = cachedRoleItem.Value() diff --git a/services/rfq/api/rest/server_test.go b/services/rfq/api/rest/server_test.go index 8f9431ee75..e29143dfaa 100644 --- a/services/rfq/api/rest/server_test.go +++ b/services/rfq/api/rest/server_test.go @@ -607,5 +607,6 @@ func (c *ServerSuite) TestContracts() { contracts, err := client.GetRFQContracts(c.GetTestContext()) c.Require().NoError(err) - c.Require().Len(contracts.Contracts, 2) + c.Require().Len(contracts.ContractsV1, 2) + c.Require().Len(contracts.ContractsV2, 2) } diff --git a/services/rfq/api/rest/suite_test.go b/services/rfq/api/rest/suite_test.go index 755b4882ca..64454dcc02 100644 --- a/services/rfq/api/rest/suite_test.go +++ b/services/rfq/api/rest/suite_test.go @@ -82,7 +82,11 @@ func (c *ServerSuite) SetupTest() { DSN: filet.TmpFile(c.T(), "", "").Name(), }, OmniRPCURL: testOmnirpc, - Bridges: map[uint32]string{ + FastBridgeContractsV1: map[uint32]string{ + 1: ethFastBridgeAddress.Hex(), + 42161: arbFastBridgeAddress.Hex(), + }, + FastBridgeContractsV2: map[uint32]string{ 1: ethFastBridgeAddress.Hex(), 42161: arbFastBridgeAddress.Hex(), }, diff --git a/services/rfq/e2e/setup_test.go b/services/rfq/e2e/setup_test.go index c3a2cf9e7b..d8ce6d8fab 100644 --- a/services/rfq/e2e/setup_test.go +++ b/services/rfq/e2e/setup_test.go @@ -59,7 +59,7 @@ func (i *IntegrationSuite) setupQuoterAPI() { DSN: dbPath, }, OmniRPCURL: i.omniServer, - Bridges: map[uint32]string{ + FastBridgeContractsV1: map[uint32]string{ originBackendChainID: i.manager.Get(i.GetTestContext(), i.originBackend, testutil.FastBridgeType).Address().String(), destBackendChainID: i.manager.Get(i.GetTestContext(), i.destBackend, testutil.FastBridgeType).Address().String(), },