From 5cce9a1f0d0418e6ad9f6292eef25bbabb13a2d8 Mon Sep 17 00:00:00 2001 From: Daniel Wasserman Date: Fri, 5 Jul 2024 15:19:14 -0500 Subject: [PATCH] Fix: bulk quotes auth --- services/rfq/api/rest/server.go | 35 +++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/services/rfq/api/rest/server.go b/services/rfq/api/rest/server.go index d67a544999..03cfcdeed5 100644 --- a/services/rfq/api/rest/server.go +++ b/services/rfq/api/rest/server.go @@ -192,8 +192,8 @@ func (r *QuoterAPIServer) Run(ctx context.Context) error { func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var loggedRequest interface{} - var destChainID uint32 var err error + destChainIDs := []uint32{} // Parse the dest chain id from the request switch c.Request.URL.Path { @@ -201,14 +201,23 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { var req model.PutQuoteRequest err = c.BindJSON(&req) if err == nil { - destChainID = uint32(req.DestChainID) + destChainIDs = append(destChainIDs, uint32(req.DestChainID)) + loggedRequest = &req + } + case BulkQuotesRoute: + var req model.PutBulkQuotesRequest + err = c.BindJSON(&req) + if err == nil { + for _, quote := range req.Quotes { + destChainIDs = append(destChainIDs, uint32(quote.DestChainID)) + } loggedRequest = &req } case AckRoute: var req model.PutAckRequest err = c.BindJSON(&req) if err == nil { - destChainID = uint32(req.DestChainID) + destChainIDs = append(destChainIDs, uint32(req.DestChainID)) loggedRequest = &req } default: @@ -221,11 +230,21 @@ func (r *QuoterAPIServer) AuthMiddleware() gin.HandlerFunc { } // Authenticate and fetch the address from the request - addressRecovered, err := r.checkRole(c, destChainID) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"msg": err.Error()}) - c.Abort() - return + var addressRecovered *common.Address + for _, destChainID := range destChainIDs { + addr, err := r.checkRole(c, destChainID) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"msg": err.Error()}) + c.Abort() + return + } + if addressRecovered == nil { + addressRecovered = &addr + } else if *addressRecovered != addr { + c.JSON(http.StatusBadRequest, gin.H{"msg": "relayer address mismatch"}) + c.Abort() + return + } } // Log and pass to the next middleware if authentication succeeds