diff --git a/contrib/screener-api/screener/screener.go b/contrib/screener-api/screener/screener.go index e71395767d..e7465d3665 100644 --- a/contrib/screener-api/screener/screener.go +++ b/contrib/screener-api/screener/screener.go @@ -12,6 +12,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" "github.com/ipfs/go-log" "github.com/synapsecns/sanguine/contrib/screener-api/client" "github.com/synapsecns/sanguine/contrib/screener-api/config" @@ -85,7 +86,7 @@ func NewScreener(ctx context.Context, cfg config.Config, metricHandler metrics.H screener.router.Handle(http.MethodGet, "/:ruleset/address/:address", screener.screenAddress) // idk the middleware is faking up - screener.router.Handle(http.MethodPost, "/api/data/sync", screener.blacklistAddress) + screener.router.Handle(http.MethodPost, "/api/data/sync", screener.authMiddleware(), screener.blacklistAddress) return &screener, nil } @@ -124,33 +125,26 @@ func (s *screenerImpl) fetchBlacklist(ctx context.Context) { func (s *screenerImpl) blacklistAddress(c *gin.Context) { var blacklistBody client.BlackListBody - // grab the body - if err := c.ShouldBindJSON(&blacklistBody); err != nil { + // Grab the body of the JSON request and unmarshal it into the blacklistBody struct. + if err := c.ShouldBindBodyWith(&blacklistBody, binding.JSON); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - type_req := blacklistBody.TypeReq - id := blacklistBody.Id - data := blacklistBody.Data - address := blacklistBody.Address - network := blacklistBody.Network - tag := blacklistBody.Tag - remark := blacklistBody.Remark - address = strings.ToLower(address) + blacklistedAddress := db.BlacklistedAddress{ + TypeReq: blacklistBody.TypeReq, + Id: blacklistBody.Id, + Data: blacklistBody.Data, + Network: blacklistBody.Network, + Tag: blacklistBody.Tag, + Remark: blacklistBody.Remark, + Address: strings.ToLower(blacklistBody.Address), + } - switch type_req { + switch blacklistBody.TypeReq { case "create": - if err := s.db.PutBlacklistedAddress(c, db.BlacklistedAddress{ - Id: id, - TypeReq: type_req, - Data: data, - Address: address, - Network: network, - Tag: tag, - Remark: remark, - }); err != nil { + if err := s.db.PutBlacklistedAddress(c, blacklistedAddress); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -159,15 +153,7 @@ func (s *screenerImpl) blacklistAddress(c *gin.Context) { return case "update": - if err := s.db.UpdateBlacklistedAddress(c, id, db.BlacklistedAddress{ - Id: id, - TypeReq: type_req, - Data: data, - Address: address, - Network: network, - Tag: tag, - Remark: remark, - }); err != nil { + if err := s.db.UpdateBlacklistedAddress(c, blacklistedAddress.Id, blacklistedAddress); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -176,10 +162,11 @@ func (s *screenerImpl) blacklistAddress(c *gin.Context) { return case "delete": - if err := s.db.DeleteBlacklistedAddress(c, address); err != nil { + if err := s.db.DeleteBlacklistedAddress(c, blacklistedAddress.Address); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + c.JSON(http.StatusOK, gin.H{"status": "success"}) return @@ -190,30 +177,37 @@ func (s *screenerImpl) blacklistAddress(c *gin.Context) { } -func (s *screenerImpl) authMiddleware(c *gin.Context) { - var blacklistBody client.BlackListBody +// This function takes the HTTP headers and the body of the request and reconstructs the signature to +// compare it with the signature provided. If they match, the request is allowed to pass through. +func (s *screenerImpl) authMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var blacklistBody client.BlackListBody - if err := c.ShouldBindJSON(&blacklistBody); err != nil { - // c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - c.JSON(http.StatusBadRequest, gin.H{"error": "Auth middleware fucked up"}) - return - } + if err := c.ShouldBindBodyWith(&blacklistBody, binding.JSON); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + c.Abort() + return + } - nonce := c.GetHeader("nonce") - timestamp := c.GetHeader("timestamp") - appid := c.GetHeader("appid") - queryString := c.GetHeader("queryString") - if nonce == "" || timestamp == "" || appid == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "missing headers"}) - c.Abort() - } + nonce := c.GetHeader("nonce") + timestamp := c.GetHeader("timestamp") + appid := c.GetHeader("appid") + queryString := c.GetHeader("queryString") + if nonce == "" || timestamp == "" || appid == "" { + c.JSON(http.StatusConflict, gin.H{"error": "missing headers"}) + c.Abort() + return + } - // reconstruct signature - expected := client.GenerateSignature("appsecret", appid, timestamp, nonce, queryString, blacklistBody) + // reconstruct signature + expected := client.GenerateSignature("appsecret", appid, timestamp, nonce, queryString, blacklistBody) - if c.GetHeader("Signature") != expected { - c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) - c.Abort() + if c.GetHeader("Signature") != expected { + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized your mom"}) + c.Abort() + return + } + c.Next() } }