Skip to content

Commit

Permalink
Merge pull request nyaruka#382 from nyaruka/fix-WA-rate-limit
Browse files Browse the repository at this point in the history
Throttle courier queues when the channel has rate limit redis key not…
  • Loading branch information
rowanseymour authored Dec 8, 2021
2 parents ca40349 + 9560dc9 commit 46475f1
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 2 deletions.
22 changes: 20 additions & 2 deletions handlers/whatsapp/whatsapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/buger/jsonparser"
"github.com/gomodule/redigo/redis"
"github.com/nyaruka/courier"
"github.com/nyaruka/courier/handlers"
"github.com/nyaruka/courier/utils"
Expand Down Expand Up @@ -495,6 +496,9 @@ const maxMsgLength = 4096
// SendMsg sends the passed in message, returning any error
func (h *handler) SendMsg(ctx context.Context, msg courier.Msg) (courier.MsgStatus, error) {
start := time.Now()
conn := h.Backend().RedisPool().Get()
defer conn.Close()

// get our token
token := msg.Channel().StringConfigForKey(courier.ConfigAuthToken, "")
if token == "" {
Expand Down Expand Up @@ -525,7 +529,7 @@ func (h *handler) SendMsg(ctx context.Context, msg courier.Msg) (courier.MsgStat

for i, payload := range payloads {
externalID := ""
wppID, externalID, logs, err = sendWhatsAppMsg(msg, sendPath, payload)
wppID, externalID, logs, err = sendWhatsAppMsg(conn, msg, sendPath, payload)
// add logs to our status
for _, log := range logs {
status.AddLog(log)
Expand Down Expand Up @@ -827,7 +831,7 @@ func (h *handler) fetchMediaID(msg courier.Msg, mimeType, mediaURL string) (stri
return mediaID, logs, nil
}

func sendWhatsAppMsg(msg courier.Msg, sendPath *url.URL, payload interface{}) (string, string, []*courier.ChannelLog, error) {
func sendWhatsAppMsg(rc redis.Conn, msg courier.Msg, sendPath *url.URL, payload interface{}) (string, string, []*courier.ChannelLog, error) {
start := time.Now()
jsonBody, err := json.Marshal(payload)

Expand All @@ -839,6 +843,20 @@ func sendWhatsAppMsg(msg courier.Msg, sendPath *url.URL, payload interface{}) (s
req, _ := http.NewRequest(http.MethodPost, sendPath.String(), bytes.NewReader(jsonBody))
req.Header = buildWhatsAppHeaders(msg.Channel())
rr, err := utils.MakeHTTPRequest(req)

if rr.StatusCode == 429 || rr.StatusCode == 503 {
rateLimitKey := fmt.Sprintf("rate_limit:%s", msg.Channel().UUID().String())
rc.Do("set", rateLimitKey, "engaged")

// The rate limit is 50 requests per second
// We pause sending 2 seconds so the limit count is reset
// TODO: In the future we should the header value when available
rc.Do("expire", rateLimitKey, 2)

log := courier.NewChannelLogFromRR("rate limit engaged", msg.Channel(), msg.ID(), rr).WithError("Message Send Error", err)
return "", "", []*courier.ChannelLog{log}, err
}

log := courier.NewChannelLogFromRR("Message Sent", msg.Channel(), msg.ID(), rr).WithError("Message Send Error", err)
errPayload := &mtErrorPayload{}
err = json.Unmarshal(rr.Body, errPayload)
Expand Down
6 changes: 6 additions & 0 deletions handlers/whatsapp/whatsapp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,12 @@ var defaultSendTestCases = []ChannelSendTestCase{
ResponseBody: `{ "errors": [{ "title": "Error Sending" }] }`, ResponseStatus: 403,
RequestBody: `{"to":"250788123123","type":"text","text":{"body":"Error"}}`,
SendPrep: setSendURL},
{Label: "Rate Limit Engaged",
Text: "Error", URN: "whatsapp:250788123123",
Status: "E",
ResponseBody: `{ "errors": [{ "title": "Too many requests" }] }`, ResponseStatus: 429,
RequestBody: `{"to":"250788123123","type":"text","text":{"body":"Error"}}`,
SendPrep: setSendURL},
{Label: "No Message ID",
Text: "Error", URN: "whatsapp:250788123123",
Status: "E",
Expand Down
14 changes: 14 additions & 0 deletions queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,24 @@ var luaPop = redis.NewScript(2, `-- KEYS: [EpochMS QueueType]
local delim = string.find(queue, "|")
local tps = 0
local tpsKey = ""
local queueName = ""
if delim then
queueName = string.sub(queue, string.len(KEYS[2])+2, delim-1)
tps = tonumber(string.sub(queue, delim+1))
end
if queueName then
local rateLimitKey = "rate_limit:" .. queueName
local rateLimitEngaged = redis.call("get", rateLimitKey)
if rateLimitEngaged then
redis.call("zincrby", KEYS[2] .. ":throttled", workers, queue)
redis.call("zrem", KEYS[2] .. ":active", queue)
return {"retry", ""}
end
end
-- if we have a tps, then check whether we exceed it
if tps > 0 then
tpsKey = queue .. ":tps:" .. math.floor(KEYS[1])
Expand Down
46 changes: 46 additions & 0 deletions queue/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func TestLua(t *testing.T) {
defer close(quitter)

rate := 10

for i := 0; i < 20; i++ {
err := PushOntoQueue(conn, "msgs", "chan1", rate, fmt.Sprintf(`[{"id":%d}]`, i), LowPriority)
assert.NoError(err)
Expand Down Expand Up @@ -166,6 +167,51 @@ func TestLua(t *testing.T) {
assert.NoError(err)
assert.Equal(EmptyQueue, queue)
assert.Empty(value)

err = PushOntoQueue(conn, "msgs", "chan1", rate, `[{"id":34}]`, HighPriority)
assert.NoError(err)

conn.Do("set", "rate_limit:chan1", "engaged")
conn.Do("EXPIRE", "rate_limit:chan1", 5)

// we have the rate limit set,
queue, value, err = PopFromQueue(conn, "msgs")
if value != "" && queue != EmptyQueue {
t.Fatal("Should be throttled")
}

time.Sleep(2 * time.Second)
queue, value, err = PopFromQueue(conn, "msgs")
if value != "" && queue != EmptyQueue {
t.Fatal("Should be throttled")
}

count, err = redis.Int(conn.Do("zcard", "msgs:throttled"))
assert.NoError(err)
assert.Equal(1, count, "Expected chan1 to be throttled")

count, err = redis.Int(conn.Do("zcard", "msgs:active"))
assert.NoError(err)
assert.Equal(0, count, "Expected chan1 to not be active")

// but if we wait for the rate limit to expire
time.Sleep(3 * time.Second)

// next should be 34
queue, value, err = PopFromQueue(conn, "msgs")
assert.NotEqual(queue, EmptyQueue)
assert.Equal(`{"id":34}`, value)
assert.NoError(err)

// nothing should be left
queue = Retry
for queue == Retry {
queue, value, err = PopFromQueue(conn, "msgs")
}
assert.NoError(err)
assert.Equal(EmptyQueue, queue)
assert.Empty(value)

}

func nTestThrottle(t *testing.T) {
Expand Down

0 comments on commit 46475f1

Please sign in to comment.