Skip to content

Commit

Permalink
Feat: add checkRoleParallel()
Browse files Browse the repository at this point in the history
  • Loading branch information
dwasse committed Nov 18, 2024
1 parent cf4ece0 commit 981722c
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions services/rfq/api/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"

"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -275,7 +276,6 @@ func (r *QuoterAPIServer) Run(ctx context.Context) error {
func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var loggedRequest interface{}
var useV1 bool
var err error
destChainIDs := []uint32{}

Expand All @@ -288,7 +288,6 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc {
destChainIDs = append(destChainIDs, uint32(req.DestChainID))
loggedRequest = &req
}
useV1 = true
case BulkQuotesRoute:
var req model.PutBulkQuotesRequest
err = c.BindJSON(&req)
Expand All @@ -298,15 +297,13 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc {
}
loggedRequest = &req
}
useV1 = true
case AckRoute:
var req model.PutAckRequest
err = c.BindJSON(&req)
if err == nil {
destChainIDs = append(destChainIDs, uint32(req.DestChainID))
loggedRequest = &req
}
useV1 = true
case RFQRoute, RFQStreamRoute:
chainsHeader := c.GetHeader(ChainsHeader)
if chainsHeader != "" {
Expand All @@ -330,7 +327,7 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc {
// Authenticate and fetch the address from the request
var addressRecovered *common.Address
for _, destChainID := range destChainIDs {
addr, err := r.checkRole(c, destChainID, useV1)
addr, err := r.checkRoleParallel(c, destChainID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"msg": err.Error()})
c.Abort()
Expand All @@ -357,6 +354,35 @@ type roleContract interface {
HasRole(opts *bind.CallOpts, role [32]byte, account common.Address) (bool, error)
}

func (r *QuoterAPIServer) checkRoleParallel(c *gin.Context, destChainID uint32) (addressRecovered common.Address, err error) {
g := new(errgroup.Group)
var v1Addr, v2Addr common.Address
var v1Err, v2Err error

g.Go(func() error {
v1Addr, v1Err = r.checkRole(c, destChainID, true)
return v1Err
})

g.Go(func() error {
v2Addr, v2Err = r.checkRole(c, destChainID, false)
return v2Err
})

err = g.Wait()
if v1Addr != (common.Address{}) {
return v1Addr, nil
}
if v2Addr != (common.Address{}) {
return v2Addr, nil
}
if err != nil {
return common.Address{}, fmt.Errorf("role check failed: %w", err)
}

return common.Address{}, fmt.Errorf("role check failed for both v1 and v2")
}

func (r *QuoterAPIServer) checkRole(c *gin.Context, destChainID uint32, useV1 bool) (addressRecovered common.Address, err error) {
var bridge roleContract
var roleCache *ttlcache.Cache[string, bool]
Expand Down

0 comments on commit 981722c

Please sign in to comment.