Skip to content

Commit

Permalink
Add task to handle throttled expirations
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Nov 19, 2024
1 parent 0eafabc commit ee6a392
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 81 deletions.
53 changes: 53 additions & 0 deletions core/tasks/expirations/bulk_expire.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package expirations

import (
"context"
"fmt"
"time"

"github.com/nyaruka/mailroom/core/models"
"github.com/nyaruka/mailroom/core/tasks"
"github.com/nyaruka/mailroom/core/tasks/handler"
"github.com/nyaruka/mailroom/core/tasks/handler/ctasks"
"github.com/nyaruka/mailroom/runtime"
)

// TypeBulkExpire is the type name under which BulkExpireTask is registered and
// which BulkExpireTask.Type returns.
const TypeBulkExpire = "bulk_expire"

// init registers a factory for BulkExpireTask under the TypeBulkExpire type name.
func init() {
	newTask := func() tasks.Task { return new(BulkExpireTask) }

	tasks.RegisterType(TypeBulkExpire, newTask)
}

// BulkExpireTask is the payload of a bulk expire task: a batch of expired waits,
// each of which Perform turns into a wait expiration task queued for its contact.
type BulkExpireTask struct {
// Expirations are the expired waits to queue expiration handling for.
Expirations []*ExpiredWait `json:"expirations"`
}

// Type returns the type name of this task, which is always TypeBulkExpire.
func (t *BulkExpireTask) Type() string {
return TypeBulkExpire
}

// Timeout is the maximum amount of time the task can run for, which is one hour.
func (t *BulkExpireTask) Timeout() time.Duration {
return time.Hour
}

// WithAssets returns models.RefreshNone, i.e. this task requests no refresh of org assets.
// NOTE(review): the exact refresh semantics are defined in the models package - confirm there.
func (t *BulkExpireTask) WithAssets() models.Refresh {
return models.RefreshNone
}

// Perform queues one wait expiration handler task per expiration in the payload,
// each addressed to the expiration's contact within the org of the given assets.
// It stops at the first expiration that fails to queue and returns that error
// wrapped with the session ID; expirations queued before the failure stay queued.
func (t *BulkExpireTask) Perform(ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets) error {
	conn := rt.RP.Get()
	defer conn.Close()

	for i := range t.Expirations {
		wait := t.Expirations[i]

		if err := handler.QueueTask(conn, oa.OrgID(), wait.ContactID, ctasks.NewWaitExpiration(wait.SessionID, wait.WaitExpiresOn)); err != nil {
			return fmt.Errorf("error queuing handle task for expiration on session #%d: %w", wait.SessionID, err)
		}
	}

	return nil
}
36 changes: 36 additions & 0 deletions core/tasks/expirations/bulk_expire_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package expirations_test

import (
"testing"
"time"

"github.com/nyaruka/gocommon/dates"
"github.com/nyaruka/mailroom/core/tasks/expirations"
"github.com/nyaruka/mailroom/testsuite"
"github.com/nyaruka/mailroom/testsuite/testdata"
"github.com/stretchr/testify/assert"
)

// TestBulkExpire queues a bulk expire task holding two expired waits, flushes the
// task queues, and checks that each contact ended up with one expiration event task.
func TestBulkExpire(t *testing.T) {
	_, rt := testsuite.Runtime()
	defer testsuite.Reset(testsuite.ResetRedis)

	// freeze the clock so that queued_on in the contact tasks is predictable
	defer dates.SetNowFunc(time.Now)
	frozenNow := time.Date(2024, 11, 15, 13, 59, 0, 0, time.UTC)
	dates.SetNowFunc(dates.NewFixedNow(frozenNow))

	cathyWait := &expirations.ExpiredWait{
		SessionID:     123456,
		ContactID:     testdata.Cathy.ID,
		WaitExpiresOn: time.Date(2024, 11, 15, 13, 57, 0, 0, time.UTC),
	}
	bobWait := &expirations.ExpiredWait{
		SessionID:     234567,
		ContactID:     testdata.Bob.ID,
		WaitExpiresOn: time.Date(2024, 11, 15, 13, 58, 0, 0, time.UTC),
	}

	bulkTask := &expirations.BulkExpireTask{Expirations: []*expirations.ExpiredWait{cathyWait, bobWait}}
	testsuite.QueueBatchTask(t, rt, testdata.Org1, bulkTask)

	// only the bulk task itself should have been performed
	flushed := testsuite.FlushTasks(t, rt, "batch", "throttled")
	assert.Equal(t, map[string]int{"bulk_expire": 1}, flushed)

	// and each contact should now have a single expiration event queued
	cathyTasks := []string{
		`{"type":"expiration_event","task":{"session_id":123456,"time":"2024-11-15T13:57:00Z"},"queued_on":"2024-11-15T13:59:00Z"}`,
	}
	testsuite.AssertContactTasks(t, testdata.Org1, testdata.Cathy, cathyTasks)

	bobTasks := []string{
		`{"type":"expiration_event","task":{"session_id":234567,"time":"2024-11-15T13:58:00Z"},"queued_on":"2024-11-15T13:59:00Z"}`,
	}
	testsuite.AssertContactTasks(t, testdata.Org1, testdata.Bob, bobTasks)
}
126 changes: 68 additions & 58 deletions core/tasks/expirations/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"slices"
"time"

"github.com/nyaruka/mailroom/core/ivr"
Expand All @@ -15,22 +16,22 @@ import (
"github.com/nyaruka/redisx"
)

const (
expireBatchSize = 500
)

func init() {
tasks.RegisterCron("run_expirations", NewExpirationsCron())
tasks.RegisterCron("run_expirations", NewExpirationsCron(10, 100))
tasks.RegisterCron("expire_ivr_calls", &VoiceExpirationsCron{})
}

type ExpirationsCron struct {
marker *redisx.IntervalSet
marker *redisx.IntervalSet
bulkThreshold int // use bulk task for any org with this or more expirations
bulkBatchSize int // number of expirations to queue in a single bulk task
}

func NewExpirationsCron() *ExpirationsCron {
func NewExpirationsCron(bulkThreshold, bulkBatchSize int) *ExpirationsCron {
return &ExpirationsCron{
marker: redisx.NewIntervalSet("run_expirations", time.Hour*24, 2),
marker: redisx.NewIntervalSet("run_expirations", time.Hour*24, 2),
bulkThreshold: bulkThreshold,
bulkBatchSize: bulkBatchSize,
}
}

Expand All @@ -45,50 +46,42 @@ func (c *ExpirationsCron) AllInstances() bool {
// handles waiting messaging sessions whose waits have expired, resuming those that can be resumed,
// and expiring those that can't
func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[string]any, error) {
rc := rt.RP.Get()
defer rc.Close()

// we expire sessions that can't be resumed in batches
expiredSessions := make([]models.SessionID, 0, expireBatchSize)

// select messaging sessions with expired waits
rows, err := rt.DB.QueryxContext(ctx, sqlSelectExpiredWaits)
if err != nil {
return nil, fmt.Errorf("error querying for expired waits: %w", err)
return nil, fmt.Errorf("error querying sessions with expired waits: %w", err)
}
defer rows.Close()

numExpired, numDupes, numQueued := 0, 0, 0
taskID := func(w *ExpiredWait) string {
return fmt.Sprintf("%d:%s", w.SessionID, w.WaitExpiresOn.Format(time.RFC3339))
}

// scan and organize by org
byOrg := make(map[models.OrgID][]*ExpiredWait, 50)

// the sessions that can't be resumed and will be exited
toExit := make([]models.SessionID, 0, 100)

rc := rt.RP.Get()
defer rc.Close()

numDupes, numQueuedHandler, numQueuedBulk, numExited := 0, 0, 0, 0

for rows.Next() {
expiredWait := &ExpiredWait{}
err := rows.StructScan(expiredWait)
if err != nil {
if err := rows.StructScan(expiredWait); err != nil {
return nil, fmt.Errorf("error scanning expired wait: %w", err)
}

// if it can't be resumed, add to batch to be expired
if !expiredWait.WaitResumes {
expiredSessions = append(expiredSessions, expiredWait.SessionID)

// batch is full? commit it
if len(expiredSessions) == expireBatchSize {
err = models.ExitSessions(ctx, rt.DB, expiredSessions, models.SessionStatusExpired)
if err != nil {
return nil, fmt.Errorf("error expiring batch of sessions: %w", err)
}
expiredSessions = expiredSessions[:0]
}

numExpired++
toExit = append(toExit, expiredWait.SessionID)
continue
}

// create a contact task to resume this session
taskID := fmt.Sprintf("%d:%s", expiredWait.SessionID, expiredWait.WaitExpiresOn.Format(time.RFC3339))
queued, err := c.marker.IsMember(rc, taskID)
// check whether we've already queued this
queued, err := c.marker.IsMember(rc, taskID(expiredWait))
if err != nil {
return nil, fmt.Errorf("error checking whether expiration is queued: %w", err)
return nil, fmt.Errorf("error checking whether expiration is already queued: %w", err)
}

// already queued? move on
Expand All @@ -97,30 +90,47 @@ func (c *ExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (map[str
continue
}

// ok, queue this task
err = handler.QueueTask(rc, expiredWait.OrgID, expiredWait.ContactID, ctasks.NewWaitExpiration(expiredWait.SessionID, expiredWait.WaitExpiresOn))
if err != nil {
return nil, fmt.Errorf("error adding new expiration task: %w", err)
}
byOrg[expiredWait.OrgID] = append(byOrg[expiredWait.OrgID], expiredWait)
}

// and mark it as queued
err = c.marker.Add(rc, taskID)
if err != nil {
return nil, fmt.Errorf("error marking expiration task as queued: %w", err)
}
for orgID, expirations := range byOrg {
throttle := len(expirations) >= c.bulkThreshold

for batch := range slices.Chunk(expirations, c.bulkBatchSize) {
if throttle {
if err := tasks.Queue(rc, tasks.ThrottledQueue, orgID, &BulkExpireTask{Expirations: batch}, true); err != nil {
return nil, fmt.Errorf("error queuing bulk expiration task to throttle queue: %w", err)
}
numQueuedBulk += len(batch)
}

for _, exp := range batch {
if !throttle {
err := handler.QueueTask(rc, orgID, exp.ContactID, ctasks.NewWaitExpiration(exp.SessionID, exp.WaitExpiresOn))
if err != nil {
return nil, fmt.Errorf("error queuing expiration task to handler queue: %w", err)
}
numQueuedHandler++
}

numQueued++
// mark as queued
if err = c.marker.Add(rc, taskID(exp)); err != nil {
return nil, fmt.Errorf("error marking expiration task as queued: %w", err)
}
}
}
}

// commit any stragglers
if len(expiredSessions) > 0 {
err = models.ExitSessions(ctx, rt.DB, expiredSessions, models.SessionStatusExpired)
// exit the sessions that can't be resumed
for batch := range slices.Chunk(toExit, 500) {
err = models.ExitSessions(ctx, rt.DB, batch, models.SessionStatusExpired)
if err != nil {
return nil, fmt.Errorf("error expiring runs and sessions: %w", err)
return nil, fmt.Errorf("error exiting expired sessions: %w", err)
}
numExited += len(batch)
}

return map[string]any{"expired": numExpired, "dupes": numDupes, "queued": numQueued}, nil
return map[string]any{"exited": numExited, "dupes": numDupes, "queued_handler": numQueuedHandler, "queued_bulk": numQueuedBulk}, nil
}

const sqlSelectExpiredWaits = `
Expand All @@ -131,11 +141,11 @@ const sqlSelectExpiredWaits = `
LIMIT 25000`

type ExpiredWait struct {
SessionID models.SessionID `db:"session_id"`
OrgID models.OrgID `db:"org_id"`
WaitExpiresOn time.Time `db:"wait_expires_on"`
WaitResumes bool `db:"wait_resume_on_expire"`
ContactID models.ContactID `db:"contact_id"`
SessionID models.SessionID `db:"session_id" json:"session_id"`
OrgID models.OrgID `db:"org_id" json:"-"`
WaitExpiresOn time.Time `db:"wait_expires_on" json:"wait_expires_on"`
WaitResumes bool `db:"wait_resume_on_expire" json:"-"`
ContactID models.ContactID `db:"contact_id" json:"contact_id"`
}

type VoiceExpirationsCron struct{}
Expand All @@ -158,7 +168,7 @@ func (c *VoiceExpirationsCron) Run(ctx context.Context, rt *runtime.Runtime) (ma
// select voice sessions with expired waits
rows, err := rt.DB.QueryxContext(ctx, sqlSelectExpiredVoiceWaits)
if err != nil {
return nil, fmt.Errorf("error querying for expired waits: %w", err)
return nil, fmt.Errorf("error querying voice sessions with expired waits: %w", err)
}
defer rows.Close()

Expand Down
55 changes: 38 additions & 17 deletions core/tasks/expirations/cron_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package expirations_test

import (
"fmt"
"testing"
"time"

"github.com/nyaruka/gocommon/dbutil/assertdb"
"github.com/nyaruka/gocommon/i18n"
"github.com/nyaruka/gocommon/jsonx"
"github.com/nyaruka/gocommon/uuids"
"github.com/nyaruka/goflow/flows"
_ "github.com/nyaruka/mailroom/core/handlers"
"github.com/nyaruka/mailroom/core/models"
"github.com/nyaruka/mailroom/core/tasks"
Expand Down Expand Up @@ -50,13 +53,17 @@ func TestExpirations(t *testing.T) {
r6ID := testdata.InsertFlowRun(rt, testdata.Org1, s5ID, blake, testdata.Favorites, models.RunStatusActive, "")
r7ID := testdata.InsertFlowRun(rt, testdata.Org1, s5ID, blake, testdata.Favorites, models.RunStatusWaiting, "")

time.Sleep(5 * time.Millisecond)
// for other org create 6 waiting sessions that will expire
for i := range 6 {
c := testdata.InsertContact(rt, testdata.Org2, flows.ContactUUID(uuids.NewV4()), fmt.Sprint(i), i18n.NilLanguage, models.ContactStatusActive)
testdata.InsertWaitingSession(rt, testdata.Org2, c, models.FlowTypeMessaging, testdata.Favorites, models.NilCallID, time.Now(), time.Now(), true, nil)
}

// expire our sessions...
cron := expirations.NewExpirationsCron()
cron := expirations.NewExpirationsCron(3, 5)
res, err := cron.Run(ctx, rt)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"dupes": 0, "expired": 1, "queued": 2}, res)
assert.Equal(t, map[string]any{"exited": 1, "dupes": 0, "queued_bulk": 6, "queued_handler": 2}, res)

// Cathy's session should be expired along with its runs
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowsession WHERE id = $1;`, s1ID).Columns(map[string]any{"status": "X"})
Expand All @@ -80,29 +87,43 @@ func TestExpirations(t *testing.T) {
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowrun WHERE id = $1;`, r6ID).Columns(map[string]any{"status": "A"})
assertdb.Query(t, rt.DB, `SELECT status FROM flows_flowrun WHERE id = $1;`, r7ID).Columns(map[string]any{"status": "W"})

// should have created two expiration tasks
task, err := tasks.HandlerQueue.Pop(rc)
// should have created two handler tasks for org 1
task1, err := tasks.HandlerQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org1.ID), task1.OwnerID)
assert.Equal(t, "handle_contact_event", task1.Type)
task2, err := tasks.HandlerQueue.Pop(rc)
assert.NoError(t, err)
assert.NotNil(t, task)
assert.Equal(t, int(testdata.Org1.ID), task2.OwnerID)
assert.Equal(t, "handle_contact_event", task2.Type)

// check the first task
// decode the tasks to check contacts
eventTask := &handler.HandleContactEventTask{}
jsonx.MustUnmarshal(task.Task, eventTask)
jsonx.MustUnmarshal(task1.Task, eventTask)
assert.Equal(t, testdata.George.ID, eventTask.ContactID)

task, err = tasks.HandlerQueue.Pop(rc)
assert.NoError(t, err)
assert.NotNil(t, task)

// check the second task
eventTask = &handler.HandleContactEventTask{}
jsonx.MustUnmarshal(task.Task, eventTask)
jsonx.MustUnmarshal(task2.Task, eventTask)
assert.Equal(t, blake.ID, eventTask.ContactID)

// no other tasks
task, err = tasks.HandlerQueue.Pop(rc)
// no other
task, err := tasks.HandlerQueue.Pop(rc)
assert.NoError(t, err)
assert.Nil(t, task)

// should have created two throttled bulk tasks for org 2
task3, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org2.ID), task3.OwnerID)
assert.Equal(t, "bulk_expire", task3.Type)
task4, err := tasks.ThrottledQueue.Pop(rc)
assert.NoError(t, err)
assert.Equal(t, int(testdata.Org2.ID), task4.OwnerID)
assert.Equal(t, "bulk_expire", task4.Type)

// if task runs again, these tasks won't be re-queued
res, err = cron.Run(ctx, rt)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"exited": 0, "dupes": 8, "queued_handler": 0, "queued_bulk": 0}, res)
}

func TestExpireVoiceSessions(t *testing.T) {
Expand Down
Loading

0 comments on commit ee6a392

Please sign in to comment.