From d36b8b2ae4610d44ca17ed717a7128fb9f96810e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Ortu=C3=B1o?= Date: Thu, 8 Sep 2022 16:46:30 +0200 Subject: [PATCH 1/2] hook: add context execution field MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Miguel Ángel Ortuño --- pkg/admin/server/service.go | 6 ++-- pkg/c2s/in.go | 7 ++-- pkg/cluster/connmanager/manager.go | 4 +-- pkg/cluster/connmanager/manager_test.go | 9 +++-- pkg/cluster/memberlist/kv_memberlist.go | 7 ++-- pkg/cluster/resourcemanager/kvmanager.go | 2 +- pkg/component/component.go | 10 +++--- pkg/component/xep0114/in.go | 7 ++-- pkg/hook/hooks.go | 11 +++--- pkg/hook/hooks_test.go | 23 ++++++------ pkg/module/module.go | 10 +++--- pkg/module/offline/offline.go | 18 +++++----- pkg/module/offline/offline_test.go | 9 +++-- pkg/module/roster/roster.go | 15 ++++---- pkg/module/roster/roster_test.go | 30 +++++++++------- pkg/module/xep0012/last.go | 22 ++++++------ pkg/module/xep0012/last_test.go | 6 ++-- pkg/module/xep0030/disco.go | 7 ++-- pkg/module/xep0030/disco_test.go | 20 ++++++----- pkg/module/xep0049/private.go | 14 ++++---- pkg/module/xep0054/vcard.go | 14 ++++---- pkg/module/xep0115/caps.go | 18 +++++----- pkg/module/xep0115/caps_test.go | 6 ++-- pkg/module/xep0191/blocklist.go | 35 +++++++++++------- pkg/module/xep0191/blocklist_test.go | 9 +++-- pkg/module/xep0198/stream.go | 10 +++--- pkg/module/xep0198/stream_test.go | 46 ++++++++++++++---------- pkg/module/xep0199/ping.go | 6 ++-- pkg/module/xep0199/ping_test.go | 6 ++-- pkg/module/xep0280/carbons.go | 10 +++--- pkg/module/xep0280/carbons_test.go | 11 +++--- pkg/s2s/in.go | 7 ++-- pkg/s2s/out.go | 7 ++-- 33 files changed, 248 insertions(+), 174 deletions(-) diff --git a/pkg/admin/server/service.go b/pkg/admin/server/service.go index 04f91f1e0..67d3dde58 100644 --- a/pkg/admin/server/service.go +++ b/pkg/admin/server/service.go @@ -68,10 +68,11 @@ func (s *usersService) CreateUser(ctx context.Context, req *userspb.CreateUserRe return nil, err } // run user created hook - _, err := s.hk.Run(ctx, hook.UserCreated, &hook.ExecutionContext{ + _, err := s.hk.Run(hook.UserCreated, &hook.ExecutionContext{ Info: &hook.UserInfo{ Username: username, }, + Context: ctx, }) if err != nil { return nil, err @@ -102,10 +103,11 @@ func (s *usersService) DeleteUser(ctx context.Context, req *userspb.DeleteUserRe return nil, status.Error(codes.Internal, err.Error()) } // run user deleted hook - _, err := s.hk.Run(ctx, hook.UserDeleted, &hook.ExecutionContext{ + _, err := s.hk.Run(hook.UserDeleted, &hook.ExecutionContext{ Info: &hook.UserInfo{ Username: username, }, + Context: ctx, }) if err != nil { return nil, err diff --git a/pkg/c2s/in.go b/pkg/c2s/in.go index 32709514f..11f5d30a6 100644 --- a/pkg/c2s/in.go +++ b/pkg/c2s/in.go @@ -1212,9 +1212,10 @@ func (s *inC2S) getState() state { } func (s *inC2S) runHook(ctx context.Context, hookName string, inf *hook.C2SStreamInfo) (halt bool, err error) { - return s.hk.Run(ctx, hookName, &hook.ExecutionContext{ - Info: inf, - Sender: s, + return s.hk.Run(hookName, &hook.ExecutionContext{ + Info: inf, + Sender: s, + Context: ctx, }) } diff --git a/pkg/cluster/connmanager/manager.go b/pkg/cluster/connmanager/manager.go index 16555caba..a40be1f1c 100644 --- a/pkg/cluster/connmanager/manager.go +++ b/pkg/cluster/connmanager/manager.go @@ -108,7 +108,7 @@ func (m *Manager) Stop(_ context.Context) error { return nil } -func (m *Manager) onMemberListUpdated(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Manager) onMemberListUpdated(execCtx *hook.ExecutionContext) error { m.mu.Lock() defer m.mu.Unlock() @@ -125,7 +125,7 @@ func (m *Manager) onMemberListUpdated(ctx context.Context, execCtx *hook.Executi // dial connections to new registered members... for _, member := range inf.Registered { cl := newConn(member.Host, member.Port, member.APIVer) - if err := cl.dialContext(ctx); err != nil { + if err := cl.dialContext(execCtx.Context); err != nil { level.Warn(m.logger).Log("msg", "failed to dial cluster conn", "err", err) continue } diff --git a/pkg/cluster/connmanager/manager_test.go b/pkg/cluster/connmanager/manager_test.go index 358c09b78..60f4910ba 100644 --- a/pkg/cluster/connmanager/manager_test.go +++ b/pkg/cluster/connmanager/manager_test.go @@ -46,21 +46,23 @@ func TestConnections_UpdateMembers(t *testing.T) { _ = connMng.Start(context.Background()) // register cluster member - _, _ = hk.Run(context.Background(), hook.MemberListUpdated, &hook.ExecutionContext{ + _, _ = hk.Run(hook.MemberListUpdated, &hook.ExecutionContext{ Info: &hook.MemberListInfo{ Registered: []clustermodel.Member{ {InstanceID: "a1234", Host: "192.168.2.1", Port: 1234, APIVer: version.ClusterAPIVersion}, }, }, + Context: context.Background(), }) conn1, err1 := connMng.GetConnection("a1234") // register cluster member - _, _ = hk.Run(context.Background(), hook.MemberListUpdated, &hook.ExecutionContext{ + _, _ = hk.Run(hook.MemberListUpdated, &hook.ExecutionContext{ Info: &hook.MemberListInfo{ UnregisteredKeys: []string{"a1234"}, }, + Context: context.Background(), }) conn2, err2 := connMng.GetConnection("a1234") @@ -94,12 +96,13 @@ func TestConnections_IncompatibleClusterAPI(t *testing.T) { _ = connMng.Start(context.Background()) incompVer := version.NewVersion(version.ClusterAPIVersion.Major()+1, 0, 0) - _, _ = hk.Run(context.Background(), hook.MemberListUpdated, &hook.ExecutionContext{ + _, _ = hk.Run(hook.MemberListUpdated, &hook.ExecutionContext{ Info: &hook.MemberListInfo{ Registered: []clustermodel.Member{ {InstanceID: "a1234", Host: "192.168.2.1", Port: 1234, APIVer: incompVer}, }, }, + Context: context.Background(), }) // then diff --git a/pkg/cluster/memberlist/kv_memberlist.go b/pkg/cluster/memberlist/kv_memberlist.go index 9d83be338..80aa1bc9a 100644 --- a/pkg/cluster/memberlist/kv_memberlist.go +++ b/pkg/cluster/memberlist/kv_memberlist.go @@ -245,9 +245,10 @@ func (ml *KVMemberList) processKVEvents(ctx context.Context, kvEvents []kvtypes. } func (ml *KVMemberList) runHook(ctx context.Context, inf *hook.MemberListInfo) error { - _, err := ml.hk.Run(ctx, hook.MemberListUpdated, &hook.ExecutionContext{ - Info: inf, - Sender: ml, + _, err := ml.hk.Run(hook.MemberListUpdated, &hook.ExecutionContext{ + Info: inf, + Sender: ml, + Context: ctx, }) return err } diff --git a/pkg/cluster/resourcemanager/kvmanager.go b/pkg/cluster/resourcemanager/kvmanager.go index f45807ace..90c13e3a8 100644 --- a/pkg/cluster/resourcemanager/kvmanager.go +++ b/pkg/cluster/resourcemanager/kvmanager.go @@ -213,7 +213,7 @@ func (m *kvManager) Stop(_ context.Context) error { return nil } -func (m *kvManager) onMemberListUpdated(_ context.Context, execCtx *hook.ExecutionContext) error { +func (m *kvManager) onMemberListUpdated(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.MemberListInfo) if len(inf.UnregisteredKeys) == 0 { return nil diff --git a/pkg/component/component.go b/pkg/component/component.go index 6ad48b5b6..0b36da842 100644 --- a/pkg/component/component.go +++ b/pkg/component/component.go @@ -157,11 +157,12 @@ func (c *Components) Start(ctx context.Context) error { } level.Info(c.logger).Log("msg", "started components", "total_components", len(c.comps)) - _, err := c.hk.Run(ctx, hook.ComponentsStarted, &hook.ExecutionContext{ + _, err := c.hk.Run(hook.ComponentsStarted, &hook.ExecutionContext{ Info: &hook.ComponentsInfo{ Hosts: hosts, }, - Sender: c, + Sender: c, + Context: ctx, }) return err } @@ -181,11 +182,12 @@ func (c *Components) Stop(ctx context.Context) error { } level.Info(c.logger).Log("msg", "stopped components", "total_components", len(c.comps)) - _, err := c.hk.Run(ctx, hook.ComponentsStopped, &hook.ExecutionContext{ + _, err := c.hk.Run(hook.ComponentsStopped, &hook.ExecutionContext{ Info: &hook.ComponentsInfo{ Hosts: hosts, }, - Sender: c, + Sender: c, + Context: ctx, }) return err } diff --git a/pkg/component/xep0114/in.go b/pkg/component/xep0114/in.go index 563d4299b..b166a3629 100644 --- a/pkg/component/xep0114/in.go +++ b/pkg/component/xep0114/in.go @@ -453,9 +453,10 @@ func (s *inComponent) getState() inComponentState { } func (s *inComponent) runHook(ctx context.Context, hookName string, inf *hook.ExternalComponentInfo) (halt bool, err error) { - return s.hk.Run(ctx, hookName, &hook.ExecutionContext{ - Info: inf, - Sender: s, + return s.hk.Run(hookName, &hook.ExecutionContext{ + Info: inf, + Sender: s, + Context: ctx, }) } diff --git a/pkg/hook/hooks.go b/pkg/hook/hooks.go index 51b445bac..f61542dee 100644 --- a/pkg/hook/hooks.go +++ b/pkg/hook/hooks.go @@ -38,15 +38,16 @@ const ( ) // Handler defines a generic hook handler function. -type Handler func(ctx context.Context, execCtx *ExecutionContext) error +type Handler func(execCtx *ExecutionContext) error // ErrStopped error is returned by a handler to halt hook execution. var ErrStopped = errors.New("hook: execution stopped") // ExecutionContext defines a hook execution info context. type ExecutionContext struct { - Info interface{} - Sender interface{} + Info interface{} + Sender interface{} + Context context.Context } type handler struct { @@ -101,13 +102,13 @@ func (h *Hooks) RemoveHook(hook string, hnd Handler) { // Run invokes all hook handlers in order. // If halted return value is true no more handlers are invoked. -func (h *Hooks) Run(ctx context.Context, hook string, execCtx *ExecutionContext) (halted bool, err error) { +func (h *Hooks) Run(hook string, execCtx *ExecutionContext) (halted bool, err error) { h.mu.RLock() defer h.mu.RUnlock() handlers := h.handlers[hook] for _, handler := range handlers { - err := handler.h(ctx, execCtx) + err := handler.h(execCtx) switch { case err == nil: break diff --git a/pkg/hook/hooks_test.go b/pkg/hook/hooks_test.go index 466f7975d..2f29e07c9 100644 --- a/pkg/hook/hooks_test.go +++ b/pkg/hook/hooks_test.go @@ -15,7 +15,6 @@ package hook import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -43,9 +42,9 @@ func TestHooks_Remove(t *testing.T) { h := NewHooks() // when - var hnd1 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { return nil } - var hnd2 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { return nil } - var hnd3 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { return nil } + var hnd1 Handler = func(execCtx *ExecutionContext) error { return nil } + var hnd2 Handler = func(execCtx *ExecutionContext) error { return nil } + var hnd3 Handler = func(execCtx *ExecutionContext) error { return nil } h.AddHook("h1", hnd1, 0) h.AddHook("h1", hnd2, 0) @@ -65,15 +64,15 @@ func TestHooks_Run(t *testing.T) { // when var i int - var hnd1 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil } - var hnd2 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil } - var hnd3 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil } + var hnd1 Handler = func(execCtx *ExecutionContext) error { i++; return nil } + var hnd2 Handler = func(execCtx *ExecutionContext) error { i++; return nil } + var hnd3 Handler = func(execCtx *ExecutionContext) error { i++; return nil } h.AddHook("h1", hnd1, 0) h.AddHook("h1", hnd2, 0) h.AddHook("h1", hnd3, 0) - halted, err := h.Run(context.Background(), "h1", nil) + halted, err := h.Run("h1", nil) // then require.Nil(t, err) @@ -88,15 +87,15 @@ func TestHooks_HaltedRun(t *testing.T) { // when var i int - var hnd1 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil } - var hnd2 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return ErrStopped } - var hnd3 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil } + var hnd1 Handler = func(execCtx *ExecutionContext) error { i++; return nil } + var hnd2 Handler = func(execCtx *ExecutionContext) error { i++; return ErrStopped } + var hnd3 Handler = func(execCtx *ExecutionContext) error { i++; return nil } h.AddHook("h1", hnd1, 10) h.AddHook("h1", hnd2, 5) h.AddHook("h1", hnd3, 0) - halted, err := h.Run(context.Background(), "h1", nil) + halted, err := h.Run("h1", nil) // then require.Nil(t, err) diff --git a/pkg/module/module.go b/pkg/module/module.go index 939f2aaef..3019c348a 100644 --- a/pkg/module/module.go +++ b/pkg/module/module.go @@ -104,11 +104,12 @@ func (m *Modules) Start(ctx context.Context) error { "iq_processors_count", len(m.iqProcessors), "mods_count", len(m.mods), ) - _, err := m.hk.Run(ctx, hook.ModulesStarted, &hook.ExecutionContext{ + _, err := m.hk.Run(hook.ModulesStarted, &hook.ExecutionContext{ Info: &hook.ModulesInfo{ ModuleNames: modNames, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } @@ -127,11 +128,12 @@ func (m *Modules) Stop(ctx context.Context) error { "iq_processors_count", len(m.iqProcessors), "mods_count", len(m.mods), ) - _, err := m.hk.Run(ctx, hook.ModulesStopped, &hook.ExecutionContext{ + _, err := m.hk.Run(hook.ModulesStopped, &hook.ExecutionContext{ Info: &hook.ModulesInfo{ ModuleNames: modNames, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } diff --git a/pkg/module/offline/offline.go b/pkg/module/offline/offline.go index 5c3b9bc03..49f018468 100644 --- a/pkg/module/offline/offline.go +++ b/pkg/module/offline/offline.go @@ -118,7 +118,7 @@ func (m *Offline) Stop(_ context.Context) error { return nil } -func (m *Offline) onWillRouteElement(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Offline) onWillRouteElement(execCtx *hook.ExecutionContext) error { var elem stravaganza.Element switch inf := execCtx.Info.(type) { @@ -135,17 +135,17 @@ func (m *Offline) onWillRouteElement(ctx context.Context, execCtx *hook.Executio if !m.hosts.IsLocalHost(toJID.Domain()) { return nil } - rss, err := m.resMng.GetResources(ctx, toJID.Node()) + rss, err := m.resMng.GetResources(execCtx.Context, toJID.Node()) if err != nil { return err } if len(rss) > 0 { return nil } - return m.archiveMessage(ctx, msg) + return m.archiveMessage(execCtx.Context, msg) } -func (m *Offline) onC2SPresenceRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Offline) onC2SPresenceRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) pr := inf.Element.(*stravaganza.Presence) @@ -156,11 +156,12 @@ func (m *Offline) onC2SPresenceRecv(ctx context.Context, execCtx *hook.Execution if !pr.IsAvailable() || pr.Priority() < 0 { return nil } - return m.deliverOfflineMessages(ctx, toJID.Node()) + return m.deliverOfflineMessages(execCtx.Context, toJID.Node()) } -func (m *Offline) onUserDeleted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Offline) onUserDeleted(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.UserInfo) + ctx := execCtx.Context lockID := offlineQueueLockID(inf.Username) @@ -226,12 +227,13 @@ func (m *Offline) archiveMessage(ctx context.Context, msg *stravaganza.Message) if err := m.rep.InsertOfflineMessage(ctx, dMsg, username); err != nil { return err } - _, err = m.hk.Run(ctx, hook.OfflineMessageArchived, &hook.ExecutionContext{ + _, err = m.hk.Run(hook.OfflineMessageArchived, &hook.ExecutionContext{ Info: &hook.OfflineInfo{ Username: username, Message: dMsg, }, - Sender: m, + Sender: m, + Context: ctx, }) if err != nil { return err diff --git a/pkg/module/offline/offline_test.go b/pkg/module/offline/offline_test.go index f79238ee6..b59985572 100644 --- a/pkg/module/offline/offline_test.go +++ b/pkg/module/offline/offline_test.go @@ -70,10 +70,11 @@ func TestOffline_ArchiveOfflineMessage(t *testing.T) { _ = m.Start(context.Background()) defer func() { _ = m.Stop(context.Background()) }() - _, _ = hk.Run(context.Background(), hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: msg, }, + Context: context.Background(), }) // then @@ -132,10 +133,11 @@ func TestOffline_ArchiveOfflineMessageQueueFull(t *testing.T) { _ = m.Start(context.Background()) defer func() { _ = m.Stop(context.Background()) }() - halted, err := hk.Run(context.Background(), hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: msg, }, + Context: context.Background(), }) // then @@ -202,10 +204,11 @@ func TestOffline_DeliverOfflineMessages(t *testing.T) { pr := xmpputil.MakePresence(fromJID, toJID, stravaganza.AvailableType, nil) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: pr, }, + Context: context.Background(), }) // then diff --git a/pkg/module/roster/roster.go b/pkg/module/roster/roster.go index a26f83e67..71f09229a 100644 --- a/pkg/module/roster/roster.go +++ b/pkg/module/roster/roster.go @@ -132,7 +132,7 @@ func (r *Roster) Stop(_ context.Context) error { return nil } -func (r *Roster) onPresenceRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (r *Roster) onPresenceRecv(execCtx *hook.ExecutionContext) error { var pr *stravaganza.Presence switch inf := execCtx.Info.(type) { case *hook.C2SStreamInfo: @@ -145,14 +145,16 @@ func (r *Roster) onPresenceRecv(ctx context.Context, execCtx *hook.ExecutionCont if pr.ToJID().IsFull() { return nil } - if err := r.processPresence(ctx, pr); err != nil { + if err := r.processPresence(execCtx.Context, pr); err != nil { return fmt.Errorf("roster: failed to process C2S presence: %s", err) } return nil } -func (r *Roster) onUserDeleted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (r *Roster) onUserDeleted(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.UserInfo) + ctx := execCtx.Context + return r.rep.InTransaction(ctx, func(ctx context.Context, tx repository.Transaction) error { if err := tx.DeleteRosterNotifications(ctx, inf.Username); err != nil { return err @@ -838,9 +840,10 @@ func (r *Roster) getStream(username, resource string) (stream.C2S, error) { } func (r *Roster) runHook(ctx context.Context, hookName string, inf *hook.RosterInfo) error { - _, err := r.hk.Run(ctx, hookName, &hook.ExecutionContext{ - Info: inf, - Sender: r, + _, err := r.hk.Run(hookName, &hook.ExecutionContext{ + Info: inf, + Sender: r, + Context: ctx, }) return err } diff --git a/pkg/module/roster/roster_test.go b/pkg/module/roster/roster_test.go index 9efa06191..948797bcb 100644 --- a/pkg/module/roster/roster_test.go +++ b/pkg/module/roster/roster_test.go @@ -390,8 +390,9 @@ func TestRoster_Subscribe(t *testing.T) { pr := xmpputil.MakePresence(fromJID, toJID, stravaganza.SubscribeType, nil) _ = r.Start(context.Background()) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: pr}, + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: pr}, + Context: context.Background(), }) // then @@ -492,8 +493,9 @@ func TestRoster_Subscribed(t *testing.T) { pr := xmpputil.MakePresence(fromJID, toJID, stravaganza.SubscribedType, nil) _ = r.Start(context.Background()) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: pr}, + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: pr}, + Context: context.Background(), }) // then @@ -602,8 +604,9 @@ func TestRoster_Unsubscribe(t *testing.T) { pr := xmpputil.MakePresence(fromJID, toJID, stravaganza.UnsubscribeType, nil) _ = r.Start(context.Background()) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: pr}, + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: pr}, + Context: context.Background(), }) // then @@ -715,8 +718,9 @@ func TestRoster_Unsubscribed(t *testing.T) { pr := xmpputil.MakePresence(fromJID, toJID, stravaganza.UnsubscribedType, nil) _ = r.Start(context.Background()) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: pr}, + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: pr}, + Context: context.Background(), }) // then @@ -811,8 +815,9 @@ func TestRoster_Probe(t *testing.T) { pr := xmpputil.MakePresence(fromJID, toJID, stravaganza.ProbeType, nil) _ = r.Start(context.Background()) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: pr}, + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: pr}, + Context: context.Background(), }) // then @@ -920,8 +925,9 @@ func TestRoster_Available(t *testing.T) { pr := xmpputil.MakePresence(fromJID, toJID, stravaganza.AvailableType, nil) _ = r.Start(context.Background()) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: pr}, + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: pr}, + Context: context.Background(), }) // then diff --git a/pkg/module/xep0012/last.go b/pkg/module/xep0012/last.go index 103ad4caa..173cb357b 100644 --- a/pkg/module/xep0012/last.go +++ b/pkg/module/xep0012/last.go @@ -132,7 +132,7 @@ func (m *Last) Stop(_ context.Context) error { return nil } -func (m *Last) onElementRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Last) onElementRecv(execCtx *hook.ExecutionContext) error { var iq *stravaganza.IQ var ok bool @@ -147,7 +147,7 @@ func (m *Last) onElementRecv(ctx context.Context, execCtx *hook.ExecutionContext if !ok { return nil } - return m.processIncomingIQ(ctx, iq) + return m.processIncomingIQ(execCtx.Context, iq) } func (m *Last) processIncomingIQ(ctx context.Context, iq *stravaganza.IQ) error { @@ -169,15 +169,15 @@ func (m *Last) processIncomingIQ(ctx context.Context, iq *stravaganza.IQ) error return nil } -func (m *Last) onUserDeleted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Last) onUserDeleted(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.UserInfo) - return m.rep.DeleteLast(ctx, inf.Username) + return m.rep.DeleteLast(execCtx.Context, inf.Username) } -func (m *Last) onC2SPresenceRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Last) onC2SPresenceRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) pr := inf.Element.(*stravaganza.Presence) - return m.processC2SPresence(ctx, pr) + return m.processC2SPresence(execCtx.Context, pr) } func (m *Last) processC2SPresence(ctx context.Context, pr *stravaganza.Presence) error { @@ -212,12 +212,13 @@ func (m *Last) getServerLastActivity(ctx context.Context, iq *stravaganza.IQ) er level.Info(m.logger).Log("msg", "sent server uptime", "username", iq.FromJID().Node()) - _, err := m.hk.Run(ctx, hook.LastActivityFetched, &hook.ExecutionContext{ + _, err := m.hk.Run(hook.LastActivityFetched, &hook.ExecutionContext{ Info: &hook.LastActivityInfo{ Username: iq.FromJID().Node(), JID: iq.ToJID(), }, - Sender: m, + Sender: m, + Context: ctx, }) return err } @@ -255,12 +256,13 @@ func (m *Last) getAccountLastActivity(ctx context.Context, iq *stravaganza.IQ) e level.Info(m.logger).Log("msg", "sent last activity", "username", fromJID.Node(), "target", toJID.Node()) - _, err = m.hk.Run(ctx, hook.LastActivityFetched, &hook.ExecutionContext{ + _, err = m.hk.Run(hook.LastActivityFetched, &hook.ExecutionContext{ Info: &hook.LastActivityInfo{ Username: fromJID.Node(), JID: toJID, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } diff --git a/pkg/module/xep0012/last_test.go b/pkg/module/xep0012/last_test.go index 20c127e0b..05b69366e 100644 --- a/pkg/module/xep0012/last_test.go +++ b/pkg/module/xep0012/last_test.go @@ -264,10 +264,11 @@ func TestLast_InterceptInboundElement(t *testing.T) { _ = m.Start(context.Background()) defer func() { _ = m.Stop(context.Background()) }() - halted, err := m.hk.Run(context.Background(), hook.C2SStreamElementReceived, &hook.ExecutionContext{ + halted, err := m.hk.Run(hook.C2SStreamElementReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: iq, }, + Context: context.Background(), }) // then @@ -303,11 +304,12 @@ func TestLast_ProcessPresence(t *testing.T) { defer func() { _ = m.Stop(context.Background()) }() jd0, _ := jid.NewWithString("ortuman@jackal.im/yard", true) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ JID: jd0, Element: xmpputil.MakePresence(jd0, jd0.ToBareJID(), stravaganza.UnavailableType, nil), }, + Context: context.Background(), }) // then diff --git a/pkg/module/xep0030/disco.go b/pkg/module/xep0030/disco.go index b5a062d62..45a930a92 100644 --- a/pkg/module/xep0030/disco.go +++ b/pkg/module/xep0030/disco.go @@ -161,7 +161,7 @@ func (m *Disco) AccountProvider() InfoProvider { return m.accProv } -func (m *Disco) onModulesStarted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Disco) onModulesStarted(execCtx *hook.ExecutionContext) error { mods := execCtx.Sender.(modules) m.mu.Lock() @@ -169,8 +169,9 @@ func (m *Disco) onModulesStarted(ctx context.Context, execCtx *hook.ExecutionCon m.accProv = newAccountProvider(mods.AllModules(), m.rosRep, m.resMng) m.mu.Unlock() - _, err := m.hk.Run(ctx, hook.DiscoProvidersStarted, &hook.ExecutionContext{ - Sender: m, + _, err := m.hk.Run(hook.DiscoProvidersStarted, &hook.ExecutionContext{ + Sender: m, + Context: execCtx.Context, }) return err } diff --git a/pkg/module/xep0030/disco_test.go b/pkg/module/xep0030/disco_test.go index 141dbb098..c949634a4 100644 --- a/pkg/module/xep0030/disco_test.go +++ b/pkg/module/xep0030/disco_test.go @@ -56,8 +56,9 @@ func TestDisco_GetServerInfo(t *testing.T) { modsMock.AllModulesFunc = func() []module.Module { return []module.Module{modMock, d} } - _, _ = hk.Run(context.Background(), hook.ModulesStarted, &hook.ExecutionContext{ - Sender: modsMock, + _, _ = hk.Run(hook.ModulesStarted, &hook.ExecutionContext{ + Sender: modsMock, + Context: context.Background(), }) // when @@ -126,8 +127,9 @@ func TestDisco_GetServerItems(t *testing.T) { modsMock.AllModulesFunc = func() []module.Module { return nil } - _, _ = hk.Run(context.Background(), hook.ModulesStarted, &hook.ExecutionContext{ - Sender: modsMock, + _, _ = hk.Run(hook.ModulesStarted, &hook.ExecutionContext{ + Sender: modsMock, + Context: context.Background(), }) // when @@ -197,8 +199,9 @@ func TestDisco_GetAccountInfo(t *testing.T) { modsMock.AllModulesFunc = func() []module.Module { return []module.Module{modMock, d} } - _, _ = hk.Run(context.Background(), hook.ModulesStarted, &hook.ExecutionContext{ - Sender: modsMock, + _, _ = hk.Run(hook.ModulesStarted, &hook.ExecutionContext{ + Sender: modsMock, + Context: context.Background(), }) // when @@ -271,8 +274,9 @@ func TestDisco_GetAccountItems(t *testing.T) { modsMock.AllModulesFunc = func() []module.Module { return nil } - _, _ = hk.Run(context.Background(), hook.ModulesStarted, &hook.ExecutionContext{ - Sender: modsMock, + _, _ = hk.Run(hook.ModulesStarted, &hook.ExecutionContext{ + Sender: modsMock, + Context: context.Background(), }) // when diff --git a/pkg/module/xep0049/private.go b/pkg/module/xep0049/private.go index c983a033b..a568a60ce 100644 --- a/pkg/module/xep0049/private.go +++ b/pkg/module/xep0049/private.go @@ -120,9 +120,9 @@ func (m *Private) Stop(_ context.Context) error { return nil } -func (m *Private) onUserDeleted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Private) onUserDeleted(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.UserInfo) - return m.rep.DeletePrivates(ctx, inf.Username) + return m.rep.DeletePrivates(execCtx.Context, inf.Username) } func (m *Private) getPrivate(ctx context.Context, iq *stravaganza.IQ, q stravaganza.Element) error { @@ -160,12 +160,13 @@ func (m *Private) getPrivate(ctx context.Context, iq *stravaganza.IQ, q stravaga _, _ = m.router.Route(ctx, resIQ) // run private fetched hook - _, err = m.hk.Run(ctx, hook.PrivateFetched, &hook.ExecutionContext{ + _, err = m.hk.Run(hook.PrivateFetched, &hook.ExecutionContext{ Info: &hook.PrivateInfo{ Username: username, Private: prvElem, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } @@ -189,12 +190,13 @@ func (m *Private) setPrivate(ctx context.Context, iq *stravaganza.IQ, q stravaga level.Info(m.logger).Log("msg", "saved private XML", "username", username, "namespace", ns) // run private updated hook - _, err := m.hk.Run(ctx, hook.PrivateUpdated, &hook.ExecutionContext{ + _, err := m.hk.Run(hook.PrivateUpdated, &hook.ExecutionContext{ Info: &hook.PrivateInfo{ Username: username, Private: prv, }, - Sender: m, + Sender: m, + Context: ctx, }) if err != nil { return err diff --git a/pkg/module/xep0054/vcard.go b/pkg/module/xep0054/vcard.go index a27c3220c..7aac3b10a 100644 --- a/pkg/module/xep0054/vcard.go +++ b/pkg/module/xep0054/vcard.go @@ -110,9 +110,9 @@ func (m *VCard) Stop(_ context.Context) error { return nil } -func (m *VCard) onUserDeleted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *VCard) onUserDeleted(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.UserInfo) - return m.rep.DeleteVCard(ctx, inf.Username) + return m.rep.DeleteVCard(execCtx.Context, inf.Username) } func (m *VCard) getVCard(ctx context.Context, iq *stravaganza.IQ) error { @@ -141,12 +141,13 @@ func (m *VCard) getVCard(ctx context.Context, iq *stravaganza.IQ) error { _, _ = m.router.Route(ctx, resIQ) // run vCard fetched hook - _, err = m.hk.Run(ctx, hook.VCardFetched, &hook.ExecutionContext{ + _, err = m.hk.Run(hook.VCardFetched, &hook.ExecutionContext{ Info: &hook.VCardInfo{ Username: toJID.Node(), VCard: vCard, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } @@ -175,12 +176,13 @@ func (m *VCard) setVCard(ctx context.Context, iq *stravaganza.IQ) error { _, _ = m.router.Route(ctx, xmpputil.MakeResultIQ(iq, nil)) // run vCard updated hook - _, err = m.hk.Run(ctx, hook.VCardUpdated, &hook.ExecutionContext{ + _, err = m.hk.Run(hook.VCardUpdated, &hook.ExecutionContext{ Info: &hook.VCardInfo{ Username: toJID.Node(), VCard: vCard, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } diff --git a/pkg/module/xep0115/caps.go b/pkg/module/xep0115/caps.go index 0f5eae810..433e14e89 100644 --- a/pkg/module/xep0115/caps.go +++ b/pkg/module/xep0115/caps.go @@ -172,31 +172,31 @@ func (m *Capabilities) Stop(_ context.Context) error { return nil } -func (m *Capabilities) onC2SPresenceRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Capabilities) onC2SPresenceRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) pr := inf.Element.(*stravaganza.Presence) - return m.processPresence(ctx, pr) + return m.processPresence(execCtx.Context, pr) } -func (m *Capabilities) onS2SPresenceRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Capabilities) onS2SPresenceRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.S2SStreamInfo) pr := inf.Element.(*stravaganza.Presence) - return m.processPresence(ctx, pr) + return m.processPresence(execCtx.Context, pr) } -func (m *Capabilities) onC2SIQRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Capabilities) onC2SIQRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) iq := inf.Element.(*stravaganza.IQ) - return m.processIQ(ctx, iq) + return m.processIQ(execCtx.Context, iq) } -func (m *Capabilities) onS2SIQRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Capabilities) onS2SIQRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.S2SStreamInfo) iq := inf.Element.(*stravaganza.IQ) - return m.processIQ(ctx, iq) + return m.processIQ(execCtx.Context, iq) } -func (m *Capabilities) onDiscoProvidersStarted(_ context.Context, execCtx *hook.ExecutionContext) error { +func (m *Capabilities) onDiscoProvidersStarted(execCtx *hook.ExecutionContext) error { disc := execCtx.Sender.(*xep0030.Disco) m.mu.Lock() m.srvProv = disc.ServerProvider() diff --git a/pkg/module/xep0115/caps_test.go b/pkg/module/xep0115/caps_test.go index 71eeec984..259c3a6ca 100644 --- a/pkg/module/xep0115/caps_test.go +++ b/pkg/module/xep0115/caps_test.go @@ -70,10 +70,11 @@ func TestCapabilities_RequestDiscoInfo(t *testing.T) { Build() pr := xmpputil.MakePresence(jd0, jd1, stravaganza.AvailableType, []stravaganza.Element{cElem}) - _, _ = hk.Run(context.Background(), hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamPresenceReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: pr, }, + Context: context.Background(), }) // then @@ -139,10 +140,11 @@ func TestCapabilities_ProcessDiscoInfo(t *testing.T) { _ = c.Start(context.Background()) defer func() { _ = c.Stop(context.Background()) }() - _, _ = hk.Run(context.Background(), hook.C2SStreamIQReceived, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamIQReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: discoIQ, }, + Context: context.Background(), }) // then diff --git a/pkg/module/xep0191/blocklist.go b/pkg/module/xep0191/blocklist.go index 4637e8b87..5ea173e6a 100644 --- a/pkg/module/xep0191/blocklist.go +++ b/pkg/module/xep0191/blocklist.go @@ -148,8 +148,10 @@ func (m *BlockList) Stop(_ context.Context) error { return nil } -func (m *BlockList) onC2SElementRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *BlockList) onC2SElementRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) + ctx := execCtx.Context + stanza, ok := inf.Element.(stravaganza.Stanza) if !ok { return nil @@ -157,8 +159,10 @@ func (m *BlockList) onC2SElementRecv(ctx context.Context, execCtx *hook.Executio return m.processIncomingStanza(ctx, stanza) } -func (m *BlockList) onS2SElementRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *BlockList) onS2SElementRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.S2SStreamInfo) + ctx := execCtx.Context + stanza, ok := inf.Element.(stravaganza.Stanza) if !ok { return nil @@ -166,8 +170,10 @@ func (m *BlockList) onS2SElementRecv(ctx context.Context, execCtx *hook.Executio return m.processIncomingStanza(ctx, stanza) } -func (m *BlockList) onC2SElementWillRoute(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *BlockList) onC2SElementWillRoute(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) + ctx := execCtx.Context + stanza, ok := inf.Element.(stravaganza.Stanza) if !ok { return nil @@ -175,8 +181,10 @@ func (m *BlockList) onC2SElementWillRoute(ctx context.Context, execCtx *hook.Exe return m.processOutgoingStanza(ctx, stanza) } -func (m *BlockList) onS2SElementWillRoute(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *BlockList) onS2SElementWillRoute(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.S2SStreamInfo) + ctx := execCtx.Context + stanza, ok := inf.Element.(stravaganza.Stanza) if !ok { return nil @@ -184,9 +192,9 @@ func (m *BlockList) onS2SElementWillRoute(ctx context.Context, execCtx *hook.Exe return m.processOutgoingStanza(ctx, stanza) } -func (m *BlockList) onUserDeleted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *BlockList) onUserDeleted(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.UserInfo) - return m.rep.DeleteBlockListItems(ctx, inf.Username) + return m.rep.DeleteBlockListItems(execCtx.Context, inf.Username) } func (m *BlockList) processIncomingStanza(ctx context.Context, stanza stravaganza.Stanza) error { @@ -297,12 +305,13 @@ func (m *BlockList) getBlockList(ctx context.Context, iq *stravaganza.IQ) error j, _ := jid.NewWithString(itm.Jid, false) allJIDs = append(allJIDs, *j) } - _, err = m.hk.Run(ctx, hook.BlockListFetched, &hook.ExecutionContext{ + _, err = m.hk.Run(hook.BlockListFetched, &hook.ExecutionContext{ Info: &hook.BlockListInfo{ Username: username, JIDs: allJIDs, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } @@ -380,12 +389,13 @@ func (m *BlockList) blockJIDs(ctx context.Context, iq *stravaganza.IQ, block str m.sendPush(ctx, block, rss) // run hook - _, err = m.hk.Run(ctx, hook.BlockListItemsBlocked, &hook.ExecutionContext{ + _, err = m.hk.Run(hook.BlockListItemsBlocked, &hook.ExecutionContext{ Info: &hook.BlockListInfo{ Username: username, JIDs: blockJIDs, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } @@ -452,12 +462,13 @@ func (m *BlockList) unblockJIDs(ctx context.Context, iq *stravaganza.IQ, unblock m.sendPush(ctx, unblock, rss) // run hook - _, err = m.hk.Run(ctx, hook.BlockListItemsUnblocked, &hook.ExecutionContext{ + _, err = m.hk.Run(hook.BlockListItemsUnblocked, &hook.ExecutionContext{ Info: &hook.BlockListInfo{ Username: username, JIDs: unblockJIDs, }, - Sender: m, + Sender: m, + Context: ctx, }) return err } diff --git a/pkg/module/xep0191/blocklist_test.go b/pkg/module/xep0191/blocklist_test.go index e218451f3..1a5f4ba75 100644 --- a/pkg/module/xep0191/blocklist_test.go +++ b/pkg/module/xep0191/blocklist_test.go @@ -345,10 +345,11 @@ func TestBlockList_UserDeleted(t *testing.T) { _ = bl.Start(context.Background()) defer func() { _ = bl.Stop(context.Background()) }() - _, _ = hk.Run(context.Background(), hook.UserDeleted, &hook.ExecutionContext{ + _, _ = hk.Run(hook.UserDeleted, &hook.ExecutionContext{ Info: &hook.UserInfo{ Username: "ortuman", }, + Context: context.Background(), }) // then @@ -397,10 +398,11 @@ func TestBlockList_InterceptIncomingStanza(t *testing.T) { _ = bl.Start(context.Background()) defer func() { _ = bl.Stop(context.Background()) }() - halted, err := hk.Run(context.Background(), hook.C2SStreamElementReceived, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamElementReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: msg, }, + Context: context.Background(), }) // then @@ -459,10 +461,11 @@ func TestBlockList_InterceptOutgoingStanza(t *testing.T) { _ = bl.Start(context.Background()) defer func() { _ = bl.Stop(context.Background()) }() - halted, err := hk.Run(context.Background(), hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: msg, }, + Context: context.Background(), }) // then diff --git a/pkg/module/xep0198/stream.go b/pkg/module/xep0198/stream.go index fbea2b856..3dab7ffb4 100644 --- a/pkg/module/xep0198/stream.go +++ b/pkg/module/xep0198/stream.go @@ -170,8 +170,10 @@ func (m *Stream) Stop(_ context.Context) error { return nil } -func (m *Stream) onElementRecv(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (m *Stream) onElementRecv(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) + ctx := execCtx.Context + stm := execCtx.Sender.(stream.C2S) if inf.Element.Attribute(stravaganza.Namespace) == streamNamespace { if err := m.processCmd(ctx, inf.Element, stm); err != nil { @@ -191,7 +193,7 @@ func (m *Stream) onElementRecv(ctx context.Context, execCtx *hook.ExecutionConte return nil } -func (m *Stream) onElementSent(_ context.Context, execCtx *hook.ExecutionContext) error { +func (m *Stream) onElementSent(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) stanza, ok := inf.Element.(stravaganza.Stanza) if !ok { @@ -220,7 +222,7 @@ func (m *Stream) onElementSent(_ context.Context, execCtx *hook.ExecutionContext return nil } -func (m *Stream) onDisconnect(_ context.Context, execCtx *hook.ExecutionContext) error { +func (m *Stream) onDisconnect(execCtx *hook.ExecutionContext) error { stm := execCtx.Sender.(stream.C2S) if !stm.Info().Bool(enabledInfoKey) { return nil @@ -255,7 +257,7 @@ func (m *Stream) onDisconnect(_ context.Context, execCtx *hook.ExecutionContext) return hook.ErrStopped } -func (m *Stream) onTerminate(_ context.Context, execCtx *hook.ExecutionContext) error { +func (m *Stream) onTerminate(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) stm := execCtx.Sender.(stream.C2S) if !stm.Info().Bool(enabledInfoKey) { diff --git a/pkg/module/xep0198/stream_test.go b/pkg/module/xep0198/stream_test.go index c350d84fc..799316991 100644 --- a/pkg/module/xep0198/stream_test.go +++ b/pkg/module/xep0198/stream_test.go @@ -111,13 +111,14 @@ func TestStream_Enable(t *testing.T) { _ = sm.Start(context.Background()) defer func() { _ = sm.Stop(context.Background()) }() - halted, err := hk.Run(context.Background(), hook.C2SStreamElementReceived, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamElementReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: stravaganza.NewBuilder("enable"). WithAttribute(stravaganza.Namespace, streamNamespace). Build(), }, - Sender: stmMock, + Sender: stmMock, + Context: context.Background(), }) // then @@ -177,9 +178,10 @@ func TestStream_InStanza(t *testing.T) { _ = sm.Start(context.Background()) defer func() { _ = sm.Stop(context.Background()) }() - _, err := hk.Run(context.Background(), hook.C2SStreamElementReceived, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: testMsg}, - Sender: stmMock, + _, err := hk.Run(hook.C2SStreamElementReceived, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: testMsg}, + Sender: stmMock, + Context: context.Background(), }) // then @@ -232,9 +234,10 @@ func TestStream_OutStanza(t *testing.T) { _ = sm.Start(context.Background()) defer func() { _ = sm.Stop(context.Background()) }() - _, err := hk.Run(context.Background(), hook.C2SStreamElementSent, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: testMsg}, - Sender: stmMock, + _, err := hk.Run(hook.C2SStreamElementSent, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: testMsg}, + Sender: stmMock, + Context: context.Background(), }) // then @@ -303,9 +306,10 @@ func TestStream_OutStanzaMaxQueueSizeReached(t *testing.T) { _ = sm.Start(context.Background()) defer func() { _ = sm.Stop(context.Background()) }() - _, err := hk.Run(context.Background(), hook.C2SStreamElementSent, &hook.ExecutionContext{ - Info: &hook.C2SStreamInfo{Element: testMsg2}, - Sender: stmMock, + _, err := hk.Run(hook.C2SStreamElementSent, &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{Element: testMsg2}, + Sender: stmMock, + Context: context.Background(), }) // then @@ -405,13 +409,14 @@ func TestStream_HandleR(t *testing.T) { _ = sm.Start(context.Background()) defer func() { _ = sm.Stop(context.Background()) }() - halted, err := hk.Run(context.Background(), hook.C2SStreamElementReceived, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamElementReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: stravaganza.NewBuilder("r"). WithAttribute(stravaganza.Namespace, streamNamespace). Build(), }, - Sender: stmMock, + Sender: stmMock, + Context: context.Background(), }) // then @@ -487,14 +492,15 @@ func TestStream_HandleA(t *testing.T) { _ = sm.Start(context.Background()) defer func() { _ = sm.Stop(context.Background()) }() - halted, err := hk.Run(context.Background(), hook.C2SStreamElementReceived, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamElementReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: stravaganza.NewBuilder("a"). WithAttribute(stravaganza.Namespace, streamNamespace). WithAttribute("h", "21"). Build(), }, - Sender: stmMock, + Sender: stmMock, + Context: context.Background(), }) // then @@ -592,7 +598,7 @@ func TestStream_Resume(t *testing.T) { _ = sm.Start(context.Background()) defer func() { _ = sm.Stop(context.Background()) }() - halted, err := hk.Run(context.Background(), hook.C2SStreamElementReceived, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamElementReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: stravaganza.NewBuilder("resume"). WithAttribute(stravaganza.Namespace, streamNamespace). @@ -600,7 +606,8 @@ func TestStream_Resume(t *testing.T) { WithAttribute("h", "21"). Build(), }, - Sender: stmMock, + Sender: stmMock, + Context: context.Background(), }) // then @@ -708,7 +715,7 @@ func TestStream_ResumeRemote(t *testing.T) { _ = sm.Start(context.Background()) defer func() { _ = sm.Stop(context.Background()) }() - halted, err := hk.Run(context.Background(), hook.C2SStreamElementReceived, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamElementReceived, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: stravaganza.NewBuilder("resume"). WithAttribute(stravaganza.Namespace, streamNamespace). @@ -716,7 +723,8 @@ func TestStream_ResumeRemote(t *testing.T) { WithAttribute("h", "21"). Build(), }, - Sender: stmMock, + Sender: stmMock, + Context: context.Background(), }) // then diff --git a/pkg/module/xep0199/ping.go b/pkg/module/xep0199/ping.go index 5bc6b4d13..1eb3dd2c3 100644 --- a/pkg/module/xep0199/ping.go +++ b/pkg/module/xep0199/ping.go @@ -148,13 +148,13 @@ func (p *Ping) sendPongReply(ctx context.Context, pingIQ *stravaganza.IQ) error return nil } -func (p *Ping) onBinded(_ context.Context, execCtx *hook.ExecutionContext) error { +func (p *Ping) onBinded(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) p.schedulePing(inf.JID) return nil } -func (p *Ping) onRecvElement(_ context.Context, execCtx *hook.ExecutionContext) error { +func (p *Ping) onRecvElement(execCtx *hook.ExecutionContext) error { stm := execCtx.Sender.(stream.C2S) if !stm.IsBinded() { return nil @@ -165,7 +165,7 @@ func (p *Ping) onRecvElement(_ context.Context, execCtx *hook.ExecutionContext) return nil } -func (p *Ping) onDisconnect(_ context.Context, execCtx *hook.ExecutionContext) error { +func (p *Ping) onDisconnect(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) if jd := inf.JID; jd != nil { p.cancelTimers(jd) diff --git a/pkg/module/xep0199/ping_test.go b/pkg/module/xep0199/ping_test.go index ee376f81f..d3643b390 100644 --- a/pkg/module/xep0199/ping_test.go +++ b/pkg/module/xep0199/ping_test.go @@ -82,11 +82,12 @@ func TestPing_SendPing(t *testing.T) { // when _ = p.Start(context.Background()) - _, _ = hk.Run(context.Background(), hook.C2SStreamBinded, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamBinded, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ ID: "c2s1", JID: jd, }, + Context: context.Background(), }) time.Sleep(time.Second) // wait until ping is triggered @@ -128,11 +129,12 @@ func TestPing_Timeout(t *testing.T) { // when _ = p.Start(context.Background()) - _, _ = hk.Run(context.Background(), hook.C2SStreamBinded, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamBinded, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ ID: "c2s1", JID: jd, }, + Context: context.Background(), }) time.Sleep(time.Second) // wait until ping is triggered diff --git a/pkg/module/xep0280/carbons.go b/pkg/module/xep0280/carbons.go index 2fbf3a45c..02916a62c 100644 --- a/pkg/module/xep0280/carbons.go +++ b/pkg/module/xep0280/carbons.go @@ -134,7 +134,7 @@ func (p *Carbons) ProcessIQ(ctx context.Context, iq *stravaganza.IQ) error { return nil } -func (p *Carbons) onC2SElementWillRoute(_ context.Context, execCtx *hook.ExecutionContext) error { +func (p *Carbons) onC2SElementWillRoute(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) msg, ok := inf.Element.(*stravaganza.Message) @@ -145,7 +145,7 @@ func (p *Carbons) onC2SElementWillRoute(_ context.Context, execCtx *hook.Executi return nil } -func (p *Carbons) onS2SElementWillRoute(_ context.Context, execCtx *hook.ExecutionContext) error { +func (p *Carbons) onS2SElementWillRoute(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.S2SStreamInfo) msg, ok := inf.Element.(*stravaganza.Message) @@ -156,8 +156,9 @@ func (p *Carbons) onS2SElementWillRoute(_ context.Context, execCtx *hook.Executi return nil } -func (p *Carbons) onC2SMessageRouted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (p *Carbons) onC2SMessageRouted(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) + ctx := execCtx.Context msg, ok := inf.Element.(*stravaganza.Message) if !ok { @@ -166,8 +167,9 @@ func (p *Carbons) onC2SMessageRouted(ctx context.Context, execCtx *hook.Executio return p.processMessage(ctx, msg, inf.Targets) } -func (p *Carbons) onS2SMessageRouted(ctx context.Context, execCtx *hook.ExecutionContext) error { +func (p *Carbons) onS2SMessageRouted(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.S2SStreamInfo) + ctx := execCtx.Context msg, ok := inf.Element.(*stravaganza.Message) if !ok { diff --git a/pkg/module/xep0280/carbons_test.go b/pkg/module/xep0280/carbons_test.go index e061cb35a..2f8994aee 100644 --- a/pkg/module/xep0280/carbons_test.go +++ b/pkg/module/xep0280/carbons_test.go @@ -214,12 +214,13 @@ func TestCarbons_SentCC(t *testing.T) { _ = c.Start(context.Background()) defer func() { _ = c.Stop(context.Background()) }() - _, _ = hk.Run(context.Background(), hook.S2SInStreamMessageRouted, &hook.ExecutionContext{ + _, _ = hk.Run(hook.S2SInStreamMessageRouted, &hook.ExecutionContext{ Info: &hook.S2SStreamInfo{ Sender: "jackal.im", Target: "jabber.org", Element: msg, }, + Context: context.Background(), }) // then @@ -286,11 +287,12 @@ func TestCarbons_ReceivedCC(t *testing.T) { _ = c.Start(context.Background()) defer func() { _ = c.Stop(context.Background()) }() - _, _ = hk.Run(context.Background(), hook.C2SStreamMessageRouted, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamMessageRouted, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Targets: []jid.JID{*jd2}, Element: msg, }, + Context: context.Background(), }) // then @@ -334,8 +336,9 @@ func TestCarbons_InterceptStanza(t *testing.T) { hInf := &hook.C2SStreamInfo{ Element: msg, } - _, err := hk.Run(context.Background(), hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ - Info: hInf, + _, err := hk.Run(hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ + Info: hInf, + Context: context.Background(), }) // then diff --git a/pkg/s2s/in.go b/pkg/s2s/in.go index eb38ca7d5..695c8d202 100644 --- a/pkg/s2s/in.go +++ b/pkg/s2s/in.go @@ -828,9 +828,10 @@ func (s *inS2S) getState() inState { } func (s *inS2S) runHook(ctx context.Context, hookName string, inf *hook.S2SStreamInfo) (halt bool, err error) { - return s.hk.Run(ctx, hookName, &hook.ExecutionContext{ - Info: inf, - Sender: s, + return s.hk.Run(hookName, &hook.ExecutionContext{ + Info: inf, + Sender: s, + Context: ctx, }) } diff --git a/pkg/s2s/out.go b/pkg/s2s/out.go index 4b328fb9e..e0fa2cb8c 100644 --- a/pkg/s2s/out.go +++ b/pkg/s2s/out.go @@ -597,9 +597,10 @@ func (s *outS2S) runHook(ctx context.Context, hookName string, inf *hook.S2SStre if s.typ == dialbackType { return nil } - _, err := s.hk.Run(ctx, hookName, &hook.ExecutionContext{ - Info: inf, - Sender: s, + _, err := s.hk.Run(hookName, &hook.ExecutionContext{ + Info: inf, + Sender: s, + Context: ctx, }) return err } From bd834ae10d1cb64e30d7629d7621f47236804e7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Ortu=C3=B1o?= Date: Thu, 8 Sep 2022 16:50:15 +0200 Subject: [PATCH 2/2] updated CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Miguel Ángel Ortuño --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad9fb1940..f6f213d3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## jackal - main / unreleased * [ENHANCEMENT] Re-enable TLS 1.3 channel binding during auth using [RFC 9266](https://www.rfc-editor.org/rfc/rfc9266). +* [ENHANCEMENT] hook: include propagated context into execution parameter. [249](https://github.com/ortuman/jackal/pull/249) ## 0.61.0 (2022/06/06)