From 981722cb1bfbb8d8d7f9bbe8a4e630c972048d9e Mon Sep 17 00:00:00 2001 From: Daniel Wasserman Date: Mon, 18 Nov 2024 14:45:12 -0600 Subject: [PATCH] Feat: add checkRoleParallel() --- services/rfq/api/rest/server.go | 36 ++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/services/rfq/api/rest/server.go b/services/rfq/api/rest/server.go index af2ab4da5b..8bc3136b11 100644 --- a/services/rfq/api/rest/server.go +++ b/services/rfq/api/rest/server.go @@ -19,6 +19,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" @@ -275,7 +276,6 @@ func (r *QuoterAPIServer) Run(ctx context.Context) error { func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var loggedRequest interface{} - var useV1 bool var err error destChainIDs := []uint32{} @@ -288,7 +288,6 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { destChainIDs = append(destChainIDs, uint32(req.DestChainID)) loggedRequest = &req } - useV1 = true case BulkQuotesRoute: var req model.PutBulkQuotesRequest err = c.BindJSON(&req) @@ -298,7 +297,6 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { } loggedRequest = &req } - useV1 = true case AckRoute: var req model.PutAckRequest err = c.BindJSON(&req) @@ -306,7 +304,6 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { destChainIDs = append(destChainIDs, uint32(req.DestChainID)) loggedRequest = &req } - useV1 = true case RFQRoute, RFQStreamRoute: chainsHeader := c.GetHeader(ChainsHeader) if chainsHeader != "" { @@ -330,7 +327,7 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { // Authenticate and fetch the address from the request var addressRecovered *common.Address for _, destChainID := range destChainIDs { - addr, err := r.checkRole(c, destChainID, useV1) + addr, err := r.checkRoleParallel(c, destChainID) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"msg": err.Error()}) c.Abort() @@ -357,6 +354,35 @@ type roleContract interface { HasRole(opts *bind.CallOpts, role [32]byte, account common.Address) (bool, error) } +func (r *QuoterAPIServer) checkRoleParallel(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) { + g := new(errgroup.Group) + var v1Addr, v2Addr common.Address + var v1Err, v2Err error + + g.Go(func() error { + v1Addr, v1Err = r.checkRole(c, destChainID, true) + return v1Err + }) + + g.Go(func() error { + v2Addr, v2Err = r.checkRole(c, destChainID, false) + return v2Err + }) + + err = g.Wait() + if v1Addr != (common.Address{}) { + return v1Addr, nil + } + if v2Addr != (common.Address{}) { + return v2Addr, nil + } + if err != nil { + return common.Address{}, fmt.Errorf("role check failed: %w", err) + } + + return common.Address{}, fmt.Errorf("role check failed for both v1 and v2") +} + func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32, useV1 bool) (addressRecovered common.Address, err error) { var bridge roleContract var roleCache *ttlcache.Cache[string, bool]