Skip to content

Commit

Permalink
Start adding support for keyword triggers with multiple keywords
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Oct 11, 2023
1 parent 3cf5f98 commit ad59f75
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 86 deletions.
95 changes: 44 additions & 51 deletions core/models/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ const NilTriggerID = TriggerID(0)
// Trigger represents a trigger in an organization
type Trigger struct {
t struct {
ID TriggerID `json:"id"`
FlowID FlowID `json:"flow_id"`
TriggerType TriggerType `json:"trigger_type"`
Keyword string `json:"keyword"`
MatchType MatchType `json:"match_type"`
ChannelID ChannelID `json:"channel_id"`
ReferrerID string `json:"referrer_id"`
IncludeGroupIDs []GroupID `json:"include_group_ids"`
ExcludeGroupIDs []GroupID `json:"exclude_group_ids"`
ContactIDs []ContactID `json:"contact_ids,omitempty"`
ID TriggerID `json:"id"`
FlowID FlowID `json:"flow_id"`
TriggerType TriggerType `json:"trigger_type"`
Keywords pq.StringArray `json:"keywords"`
MatchType MatchType `json:"match_type"`
ChannelID ChannelID `json:"channel_id"`
ReferrerID string `json:"referrer_id"`
IncludeGroupIDs []GroupID `json:"include_group_ids"`
ExcludeGroupIDs []GroupID `json:"exclude_group_ids"`
ContactIDs []ContactID `json:"contact_ids,omitempty"`
}
}

Expand All @@ -68,7 +68,7 @@ func (t *Trigger) ID() TriggerID { return t.t.ID }

func (t *Trigger) FlowID() FlowID { return t.t.FlowID }
func (t *Trigger) TriggerType() TriggerType { return t.t.TriggerType }
func (t *Trigger) Keyword() string { return t.t.Keyword }
func (t *Trigger) Keywords() []string { return []string(t.t.Keywords) }
func (t *Trigger) MatchType() MatchType { return t.t.MatchType }
func (t *Trigger) ChannelID() ChannelID { return t.t.ChannelID }
func (t *Trigger) ReferrerID() string { return t.t.ReferrerID }
Expand All @@ -82,20 +82,9 @@ func (t *Trigger) KeywordMatchType() triggers.KeywordMatchType {
return triggers.KeywordMatchTypeOnlyWord
}

// Match returns the match for this trigger, if any
func (t *Trigger) Match() *triggers.KeywordMatch {
if t.Keyword() != "" {
return &triggers.KeywordMatch{
Type: t.KeywordMatchType(),
Keyword: t.Keyword(),
}
}
return nil
}

// loadTriggers loads all non-schedule triggers for the passed in org
func loadTriggers(ctx context.Context, db *sql.DB, orgID OrgID) ([]*Trigger, error) {
rows, err := db.QueryContext(ctx, selectTriggersSQL, orgID)
rows, err := db.QueryContext(ctx, sqlSelectTriggersByOrg, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying triggers for org: %d", orgID)
}
Expand All @@ -116,7 +105,7 @@ func loadTriggers(ctx context.Context, db *sql.DB, orgID OrgID) ([]*Trigger, err
}

// FindMatchingMsgTrigger finds the best match trigger for an incoming message from the given contact
func FindMatchingMsgTrigger(oa *OrgAssets, channel *Channel, contact *flows.Contact, text string) *Trigger {
func FindMatchingMsgTrigger(oa *OrgAssets, channel *Channel, contact *flows.Contact, text string) (*Trigger, string) {
// determine our message keyword
words := utils.TokenizeString(text)
keyword := ""
Expand All @@ -127,19 +116,29 @@ func FindMatchingMsgTrigger(oa *OrgAssets, channel *Channel, contact *flows.Cont
only = len(words) == 1
}

// for each candidate trigger, the keyword that matched
candidateKeywords := make(map[*Trigger]string, 10)

candidates := findTriggerCandidates(oa, KeywordTriggerType, func(t *Trigger) bool {
return envs.CollateEquals(oa.Env(), t.Keyword(), keyword) && (t.MatchType() == MatchFirst || (t.MatchType() == MatchOnly && only))
for _, k := range t.Keywords() {
m := envs.CollateEquals(oa.Env(), k, keyword) && (t.MatchType() == MatchFirst || (t.MatchType() == MatchOnly && only))
if m {
candidateKeywords[t] = k
return true
}
}
return false
})

// if we have a matching keyword trigger return that, otherwise we move on to catchall triggers..
byKeyword := findBestTriggerMatch(candidates, channel, contact)
if byKeyword != nil {
return byKeyword
return byKeyword, candidateKeywords[byKeyword]
}

candidates = findTriggerCandidates(oa, CatchallTriggerType, nil)

return findBestTriggerMatch(candidates, channel, contact)
return findBestTriggerMatch(candidates, channel, contact), ""
}

// FindMatchingIncomingCallTrigger finds the best match trigger for incoming calls
Expand Down Expand Up @@ -306,30 +305,24 @@ func triggerMatchQualifiers(t *Trigger, channel *Channel, contactGroups map[Grou
return true, score
}

const selectTriggersSQL = `
SELECT ROW_TO_JSON(r) FROM (SELECT
t.id as id,
t.flow_id as flow_id,
t.trigger_type as trigger_type,
t.keyword as keyword,
t.match_type as match_type,
t.channel_id as channel_id,
COALESCE(t.referrer_id, '') as referrer_id,
ARRAY_REMOVE(ARRAY_AGG(DISTINCT ig.contactgroup_id), NULL) as include_group_ids,
ARRAY_REMOVE(ARRAY_AGG(DISTINCT eg.contactgroup_id), NULL) as exclude_group_ids
FROM
triggers_trigger t
LEFT OUTER JOIN triggers_trigger_groups ig ON t.id = ig.trigger_id
LEFT OUTER JOIN triggers_trigger_exclude_groups eg ON t.id = eg.trigger_id
WHERE
t.org_id = $1 AND
t.is_active = TRUE AND
t.is_archived = FALSE AND
t.trigger_type != 'S'
GROUP BY
t.id
) r;
`
const sqlSelectTriggersByOrg = `
SELECT ROW_TO_JSON(r) FROM (
SELECT
t.id as id,
t.flow_id as flow_id,
t.trigger_type as trigger_type,
CASE WHEN t.keyword IS NOT NULL AND t.keyword != '' THEN ARRAY[t.keyword] ELSE NULL END as keywords,
t.match_type as match_type,
t.channel_id as channel_id,
COALESCE(t.referrer_id, '') as referrer_id,
ARRAY_REMOVE(ARRAY_AGG(DISTINCT ig.contactgroup_id), NULL) as include_group_ids,
ARRAY_REMOVE(ARRAY_AGG(DISTINCT eg.contactgroup_id), NULL) as exclude_group_ids
FROM triggers_trigger t
LEFT OUTER JOIN triggers_trigger_groups ig ON t.id = ig.trigger_id
LEFT OUTER JOIN triggers_trigger_exclude_groups eg ON t.id = eg.trigger_id
WHERE t.org_id = $1 AND t.is_active = TRUE AND t.is_archived = FALSE AND t.trigger_type != 'S'
GROUP BY t.id
) r;`

const selectTriggersByContactIDsSQL = `
SELECT
Expand Down
64 changes: 33 additions & 31 deletions core/models/triggers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/nyaruka/mailroom/core/models"
"github.com/nyaruka/mailroom/testsuite"
"github.com/nyaruka/mailroom/testsuite/testdata"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -29,7 +28,7 @@ func TestLoadTriggers(t *testing.T) {
id models.TriggerID
type_ models.TriggerType
flowID models.FlowID
keyword string
keywords []string
keywordMatchType models.MatchType
referrerID string
includeGroups []models.GroupID
Expand All @@ -41,22 +40,22 @@ func TestLoadTriggers(t *testing.T) {
id: testdata.InsertKeywordTrigger(rt, testdata.Org1, testdata.Favorites, "join", models.MatchFirst, nil, nil, nil),
type_: models.KeywordTriggerType,
flowID: testdata.Favorites.ID,
keyword: "join",
keywords: []string{"join"},
keywordMatchType: models.MatchFirst,
},
{
id: testdata.InsertKeywordTrigger(rt, testdata.Org1, testdata.Favorites, "join", models.MatchFirst, nil, nil, testdata.TwilioChannel),
type_: models.KeywordTriggerType,
flowID: testdata.Favorites.ID,
keyword: "join",
keywords: []string{"join"},
keywordMatchType: models.MatchFirst,
channelID: testdata.TwilioChannel.ID,
},
{
id: testdata.InsertKeywordTrigger(rt, testdata.Org1, testdata.PickANumber, "start", models.MatchOnly, []*testdata.Group{testdata.DoctorsGroup, testdata.TestersGroup}, []*testdata.Group{farmersGroup}, nil),
type_: models.KeywordTriggerType,
flowID: testdata.PickANumber.ID,
keyword: "start",
keywords: []string{"start"},
keywordMatchType: models.MatchOnly,
includeGroups: []models.GroupID{testdata.DoctorsGroup.ID, testdata.TestersGroup.ID},
excludeGroups: []models.GroupID{farmersGroup.ID},
Expand All @@ -69,9 +68,10 @@ func TestLoadTriggers(t *testing.T) {
excludeGroups: []models.GroupID{farmersGroup.ID},
},
{
id: testdata.InsertIncomingCallTrigger(rt, testdata.Org1, testdata.Favorites, []*testdata.Group{testdata.DoctorsGroup, testdata.TestersGroup}, []*testdata.Group{farmersGroup}, testdata.TwilioChannel),
type_: models.IncomingCallTriggerType,
flowID: testdata.Favorites.ID,
id: testdata.InsertIncomingCallTrigger(rt, testdata.Org1, testdata.Favorites, []*testdata.Group{testdata.DoctorsGroup, testdata.TestersGroup}, []*testdata.Group{farmersGroup}, testdata.TwilioChannel),
type_: models.IncomingCallTriggerType,
flowID: testdata.Favorites.ID,

includeGroups: []models.GroupID{testdata.DoctorsGroup.ID, testdata.TestersGroup.ID},
excludeGroups: []models.GroupID{farmersGroup.ID},
channelID: testdata.TwilioChannel.ID,
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestLoadTriggers(t *testing.T) {
assert.Equal(t, tc.id, actual.ID(), "id mismatch in trigger #%d", i)
assert.Equal(t, tc.type_, actual.TriggerType(), "type mismatch in trigger #%d", i)
assert.Equal(t, tc.flowID, actual.FlowID(), "flow id mismatch in trigger #%d", i)
assert.Equal(t, tc.keyword, actual.Keyword(), "keyword mismatch in trigger #%d", i)
assert.Equal(t, tc.keywords, actual.Keywords(), "keywords mismatch in trigger #%d", i)
assert.Equal(t, tc.keywordMatchType, actual.MatchType(), "match type mismatch in trigger #%d", i)
assert.Equal(t, tc.referrerID, actual.ReferrerID(), "referrer id mismatch in trigger #%d", i)
assert.ElementsMatch(t, tc.includeGroups, actual.IncludeGroupIDs(), "include groups mismatch in trigger #%d", i)
Expand Down Expand Up @@ -172,34 +172,36 @@ func TestFindMatchingMsgTrigger(t *testing.T) {
channel *models.Channel
contact *flows.Contact
expectedTriggerID models.TriggerID
expectedKeyword string
}{
{" join ", nil, cathy, joinID},
{"JOIN", nil, cathy, joinID},
{"JOIN", twilioChannels[0], cathy, joinTwilioOnlyID},
{"JOIN", facebookChannels[0], cathy, joinID},
{"join this", nil, cathy, joinID},
{"resist", nil, george, resistID},
{"resist", twilioChannels[0], george, resistTwilioOnlyID},
{"resist", nil, bob, doctorsID},
{"resist", twilioChannels[0], cathy, resistTwilioOnlyID},
{"resist", nil, cathy, doctorsAndNotTestersID},
{"resist this", nil, cathy, doctorsCatchallID},
{" 👍 ", nil, george, emojiID},
{"👍🏾", nil, george, emojiID}, // is 👍 + 🏾
{"😀👍", nil, george, othersAllID},
{"other", nil, cathy, doctorsCatchallID},
{"other", nil, george, othersAllID},
{"", nil, george, othersAllID},
{"start", twilioChannels[0], cathy, startTwilioOnlyID},
{"start", facebookChannels[0], cathy, doctorsCatchallID},
{"start", twilioChannels[0], george, startTwilioOnlyID},
{"start", facebookChannels[0], george, othersAllID},
{" join ", nil, cathy, joinID, "join"},
{"JOIN", nil, cathy, joinID, "join"},
{"JOIN", twilioChannels[0], cathy, joinTwilioOnlyID, "join"},
{"JOIN", facebookChannels[0], cathy, joinID, "join"},
{"join this", nil, cathy, joinID, "join"},
{"resist", nil, george, resistID, "resist"},
{"resist", twilioChannels[0], george, resistTwilioOnlyID, "resist"},
{"resist", nil, bob, doctorsID, "resist"},
{"resist", twilioChannels[0], cathy, resistTwilioOnlyID, "resist"},
{"resist", nil, cathy, doctorsAndNotTestersID, "resist"},
{"resist this", nil, cathy, doctorsCatchallID, ""},
{" 👍 ", nil, george, emojiID, "👍"},
{"👍🏾", nil, george, emojiID, "👍"}, // is 👍 + 🏾
{"😀👍", nil, george, othersAllID, ""},
{"other", nil, cathy, doctorsCatchallID, ""},
{"other", nil, george, othersAllID, ""},
{"", nil, george, othersAllID, ""},
{"start", twilioChannels[0], cathy, startTwilioOnlyID, "start"},
{"start", facebookChannels[0], cathy, doctorsCatchallID, ""},
{"start", twilioChannels[0], george, startTwilioOnlyID, "start"},
{"start", facebookChannels[0], george, othersAllID, ""},
}

for _, tc := range tcs {
trigger := models.FindMatchingMsgTrigger(oa, tc.channel, tc.contact, tc.text)
trigger, keyword := models.FindMatchingMsgTrigger(oa, tc.channel, tc.contact, tc.text)

assertTrigger(t, tc.expectedTriggerID, trigger, "trigger mismatch for %s sending '%s'", tc.contact.Name(), tc.text)
assert.Equal(t, tc.expectedKeyword, keyword, "keyword mismatch for %s sending '%s'", tc.contact.Name(), tc.text)
}
}

Expand Down
9 changes: 7 additions & 2 deletions core/tasks/handler/contact_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e
}

// find any matching triggers
trigger := models.FindMatchingMsgTrigger(oa, channel, contact, event.Text)
trigger, keyword := models.FindMatchingMsgTrigger(oa, channel, contact, event.Text)

// look for a waiting session for this contact
session, err := models.FindWaitingSessionForContact(ctx, rt.DB, rt.SessionStorage, oa, models.FlowTypeMessaging, contact)
Expand Down Expand Up @@ -480,8 +480,13 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e
return nil
}

tb := triggers.NewBuilder(oa.Env(), flow.Reference(), contact).Msg(msgIn)
if keyword != "" {
tb = tb.WithMatch(&triggers.KeywordMatch{Type: trigger.KeywordMatchType(), Keyword: keyword})
}

// otherwise build the trigger and start the flow directly
trigger := triggers.NewBuilder(oa.Env(), flow.Reference(), contact).Msg(msgIn).WithMatch(trigger.Match()).Build()
trigger := tb.Build()
_, err = runner.StartFlowForContacts(ctx, rt, oa, flow, []*models.Contact{modelContact}, []flows.Trigger{trigger}, flowMsgHook, true)
if err != nil {
return errors.Wrapf(err, "error starting flow for contact")
Expand Down
8 changes: 6 additions & 2 deletions web/simulation/simulation.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func handleResume(ctx context.Context, rt *runtime.Runtime, r *resumeRequest) (a
// if this is a msg resume we want to check whether it might be caught by a trigger
if resume.Type() == resumes.TypeMsg {
msgResume := resume.(*resumes.MsgResume)
trigger := models.FindMatchingMsgTrigger(oa, nil, msgResume.Contact(), msgResume.Msg().Text())
trigger, keyword := models.FindMatchingMsgTrigger(oa, nil, msgResume.Contact(), msgResume.Msg().Text())
if trigger != nil {
var flow *models.Flow
for _, r := range session.Runs() {
Expand Down Expand Up @@ -228,7 +228,11 @@ func handleResume(ctx context.Context, rt *runtime.Runtime, r *resumeRequest) (a
// non-simulation IVR triggers to use that so that this is consistent.
sessionTrigger = tb.Manual().WithCall(testChannel, testURN).Build()
} else {
sessionTrigger = tb.Msg(msgResume.Msg()).WithMatch(trigger.Match()).Build()
mtb := tb.Msg(msgResume.Msg())
if keyword != "" {
mtb = mtb.WithMatch(&triggers.KeywordMatch{Type: trigger.KeywordMatchType(), Keyword: keyword})
}
sessionTrigger = mtb.Build()
}

return triggerFlow(ctx, rt, oa, sessionTrigger)
Expand Down

0 comments on commit ad59f75

Please sign in to comment.