Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

hook: add context execution field #249

Merged
merged 2 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## jackal - main / unreleased

* [ENHANCEMENT] Re-enable TLS 1.3 channel binding during auth using [RFC 9266](https://www.rfc-editor.org/rfc/rfc9266).
* [ENHANCEMENT] hook: include propagated context into execution parameter. [249](https://github.com/ortuman/jackal/pull/249)

## 0.61.0 (2022/06/06)

Expand Down
6 changes: 4 additions & 2 deletions pkg/admin/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ func (s *usersService) CreateUser(ctx context.Context, req *userspb.CreateUserRe
return nil, err
}
// run user created hook
_, err := s.hk.Run(ctx, hook.UserCreated, &hook.ExecutionContext{
_, err := s.hk.Run(hook.UserCreated, &hook.ExecutionContext{
Info: &hook.UserInfo{
Username: username,
},
Context: ctx,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -102,10 +103,11 @@ func (s *usersService) DeleteUser(ctx context.Context, req *userspb.DeleteUserRe
return nil, status.Error(codes.Internal, err.Error())
}
// run user deleted hook
_, err := s.hk.Run(ctx, hook.UserDeleted, &hook.ExecutionContext{
_, err := s.hk.Run(hook.UserDeleted, &hook.ExecutionContext{
Info: &hook.UserInfo{
Username: username,
},
Context: ctx,
})
if err != nil {
return nil, err
Expand Down
7 changes: 4 additions & 3 deletions pkg/c2s/in.go
Original file line number Diff line number Diff line change
Expand Up @@ -1212,9 +1212,10 @@ func (s *inC2S) getState() state {
}

func (s *inC2S) runHook(ctx context.Context, hookName string, inf *hook.C2SStreamInfo) (halt bool, err error) {
return s.hk.Run(ctx, hookName, &hook.ExecutionContext{
Info: inf,
Sender: s,
return s.hk.Run(hookName, &hook.ExecutionContext{
Info: inf,
Sender: s,
Context: ctx,
})
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/cluster/connmanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (m *Manager) Stop(_ context.Context) error {
return nil
}

func (m *Manager) onMemberListUpdated(ctx context.Context, execCtx *hook.ExecutionContext) error {
func (m *Manager) onMemberListUpdated(execCtx *hook.ExecutionContext) error {
m.mu.Lock()
defer m.mu.Unlock()

Expand All @@ -125,7 +125,7 @@ func (m *Manager) onMemberListUpdated(ctx context.Context, execCtx *hook.Executi
// dial connections to new registered members...
for _, member := range inf.Registered {
cl := newConn(member.Host, member.Port, member.APIVer)
if err := cl.dialContext(ctx); err != nil {
if err := cl.dialContext(execCtx.Context); err != nil {
level.Warn(m.logger).Log("msg", "failed to dial cluster conn", "err", err)
continue
}
Expand Down
9 changes: 6 additions & 3 deletions pkg/cluster/connmanager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,23 @@ func TestConnections_UpdateMembers(t *testing.T) {
_ = connMng.Start(context.Background())

// register cluster member
_, _ = hk.Run(context.Background(), hook.MemberListUpdated, &hook.ExecutionContext{
_, _ = hk.Run(hook.MemberListUpdated, &hook.ExecutionContext{
Info: &hook.MemberListInfo{
Registered: []clustermodel.Member{
{InstanceID: "a1234", Host: "192.168.2.1", Port: 1234, APIVer: version.ClusterAPIVersion},
},
},
Context: context.Background(),
})

conn1, err1 := connMng.GetConnection("a1234")

// register cluster member
_, _ = hk.Run(context.Background(), hook.MemberListUpdated, &hook.ExecutionContext{
_, _ = hk.Run(hook.MemberListUpdated, &hook.ExecutionContext{
Info: &hook.MemberListInfo{
UnregisteredKeys: []string{"a1234"},
},
Context: context.Background(),
})

conn2, err2 := connMng.GetConnection("a1234")
Expand Down Expand Up @@ -94,12 +96,13 @@ func TestConnections_IncompatibleClusterAPI(t *testing.T) {
_ = connMng.Start(context.Background())

incompVer := version.NewVersion(version.ClusterAPIVersion.Major()+1, 0, 0)
_, _ = hk.Run(context.Background(), hook.MemberListUpdated, &hook.ExecutionContext{
_, _ = hk.Run(hook.MemberListUpdated, &hook.ExecutionContext{
Info: &hook.MemberListInfo{
Registered: []clustermodel.Member{
{InstanceID: "a1234", Host: "192.168.2.1", Port: 1234, APIVer: incompVer},
},
},
Context: context.Background(),
})

// then
Expand Down
7 changes: 4 additions & 3 deletions pkg/cluster/memberlist/kv_memberlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,10 @@ func (ml *KVMemberList) processKVEvents(ctx context.Context, kvEvents []kvtypes.
}

func (ml *KVMemberList) runHook(ctx context.Context, inf *hook.MemberListInfo) error {
_, err := ml.hk.Run(ctx, hook.MemberListUpdated, &hook.ExecutionContext{
Info: inf,
Sender: ml,
_, err := ml.hk.Run(hook.MemberListUpdated, &hook.ExecutionContext{
Info: inf,
Sender: ml,
Context: ctx,
})
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cluster/resourcemanager/kvmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func (m *kvManager) Stop(_ context.Context) error {
return nil
}

func (m *kvManager) onMemberListUpdated(_ context.Context, execCtx *hook.ExecutionContext) error {
func (m *kvManager) onMemberListUpdated(execCtx *hook.ExecutionContext) error {
inf := execCtx.Info.(*hook.MemberListInfo)
if len(inf.UnregisteredKeys) == 0 {
return nil
Expand Down
10 changes: 6 additions & 4 deletions pkg/component/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,12 @@ func (c *Components) Start(ctx context.Context) error {
}
level.Info(c.logger).Log("msg", "started components", "total_components", len(c.comps))

_, err := c.hk.Run(ctx, hook.ComponentsStarted, &hook.ExecutionContext{
_, err := c.hk.Run(hook.ComponentsStarted, &hook.ExecutionContext{
Info: &hook.ComponentsInfo{
Hosts: hosts,
},
Sender: c,
Sender: c,
Context: ctx,
})
return err
}
Expand All @@ -181,11 +182,12 @@ func (c *Components) Stop(ctx context.Context) error {
}
level.Info(c.logger).Log("msg", "stopped components", "total_components", len(c.comps))

_, err := c.hk.Run(ctx, hook.ComponentsStopped, &hook.ExecutionContext{
_, err := c.hk.Run(hook.ComponentsStopped, &hook.ExecutionContext{
Info: &hook.ComponentsInfo{
Hosts: hosts,
},
Sender: c,
Sender: c,
Context: ctx,
})
return err
}
7 changes: 4 additions & 3 deletions pkg/component/xep0114/in.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,10 @@ func (s *inComponent) getState() inComponentState {
}

func (s *inComponent) runHook(ctx context.Context, hookName string, inf *hook.ExternalComponentInfo) (halt bool, err error) {
return s.hk.Run(ctx, hookName, &hook.ExecutionContext{
Info: inf,
Sender: s,
return s.hk.Run(hookName, &hook.ExecutionContext{
Info: inf,
Sender: s,
Context: ctx,
})
}

Expand Down
11 changes: 6 additions & 5 deletions pkg/hook/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ const (
)

// Handler defines a generic hook handler function.
type Handler func(ctx context.Context, execCtx *ExecutionContext) error
type Handler func(execCtx *ExecutionContext) error

// ErrStopped error is returned by a handler to halt hook execution.
var ErrStopped = errors.New("hook: execution stopped")

// ExecutionContext defines a hook execution info context.
type ExecutionContext struct {
Info interface{}
Sender interface{}
Info interface{}
Sender interface{}
Context context.Context
}

type handler struct {
Expand Down Expand Up @@ -101,13 +102,13 @@ func (h *Hooks) RemoveHook(hook string, hnd Handler) {

// Run invokes all hook handlers in order.
// If halted return value is true no more handlers are invoked.
func (h *Hooks) Run(ctx context.Context, hook string, execCtx *ExecutionContext) (halted bool, err error) {
func (h *Hooks) Run(hook string, execCtx *ExecutionContext) (halted bool, err error) {
h.mu.RLock()
defer h.mu.RUnlock()

handlers := h.handlers[hook]
for _, handler := range handlers {
err := handler.h(ctx, execCtx)
err := handler.h(execCtx)
switch {
case err == nil:
break
Expand Down
23 changes: 11 additions & 12 deletions pkg/hook/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package hook

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -43,9 +42,9 @@ func TestHooks_Remove(t *testing.T) {
h := NewHooks()

// when
var hnd1 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { return nil }
var hnd2 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { return nil }
var hnd3 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { return nil }
var hnd1 Handler = func(execCtx *ExecutionContext) error { return nil }
var hnd2 Handler = func(execCtx *ExecutionContext) error { return nil }
var hnd3 Handler = func(execCtx *ExecutionContext) error { return nil }

h.AddHook("h1", hnd1, 0)
h.AddHook("h1", hnd2, 0)
Expand All @@ -65,15 +64,15 @@ func TestHooks_Run(t *testing.T) {

// when
var i int
var hnd1 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil }
var hnd2 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil }
var hnd3 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil }
var hnd1 Handler = func(execCtx *ExecutionContext) error { i++; return nil }
var hnd2 Handler = func(execCtx *ExecutionContext) error { i++; return nil }
var hnd3 Handler = func(execCtx *ExecutionContext) error { i++; return nil }

h.AddHook("h1", hnd1, 0)
h.AddHook("h1", hnd2, 0)
h.AddHook("h1", hnd3, 0)

halted, err := h.Run(context.Background(), "h1", nil)
halted, err := h.Run("h1", nil)

// then
require.Nil(t, err)
Expand All @@ -88,15 +87,15 @@ func TestHooks_HaltedRun(t *testing.T) {

// when
var i int
var hnd1 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil }
var hnd2 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return ErrStopped }
var hnd3 Handler = func(ctx context.Context, execCtx *ExecutionContext) error { i++; return nil }
var hnd1 Handler = func(execCtx *ExecutionContext) error { i++; return nil }
var hnd2 Handler = func(execCtx *ExecutionContext) error { i++; return ErrStopped }
var hnd3 Handler = func(execCtx *ExecutionContext) error { i++; return nil }

h.AddHook("h1", hnd1, 10)
h.AddHook("h1", hnd2, 5)
h.AddHook("h1", hnd3, 0)

halted, err := h.Run(context.Background(), "h1", nil)
halted, err := h.Run("h1", nil)

// then
require.Nil(t, err)
Expand Down
10 changes: 6 additions & 4 deletions pkg/module/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,12 @@ func (m *Modules) Start(ctx context.Context) error {
"iq_processors_count", len(m.iqProcessors),
"mods_count", len(m.mods),
)
_, err := m.hk.Run(ctx, hook.ModulesStarted, &hook.ExecutionContext{
_, err := m.hk.Run(hook.ModulesStarted, &hook.ExecutionContext{
Info: &hook.ModulesInfo{
ModuleNames: modNames,
},
Sender: m,
Sender: m,
Context: ctx,
})
return err
}
Expand All @@ -127,11 +128,12 @@ func (m *Modules) Stop(ctx context.Context) error {
"iq_processors_count", len(m.iqProcessors),
"mods_count", len(m.mods),
)
_, err := m.hk.Run(ctx, hook.ModulesStopped, &hook.ExecutionContext{
_, err := m.hk.Run(hook.ModulesStopped, &hook.ExecutionContext{
Info: &hook.ModulesInfo{
ModuleNames: modNames,
},
Sender: m,
Sender: m,
Context: ctx,
})
return err
}
Expand Down
18 changes: 10 additions & 8 deletions pkg/module/offline/offline.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (m *Offline) Stop(_ context.Context) error {
return nil
}

func (m *Offline) onWillRouteElement(ctx context.Context, execCtx *hook.ExecutionContext) error {
func (m *Offline) onWillRouteElement(execCtx *hook.ExecutionContext) error {
var elem stravaganza.Element

switch inf := execCtx.Info.(type) {
Expand All @@ -135,17 +135,17 @@ func (m *Offline) onWillRouteElement(ctx context.Context, execCtx *hook.Executio
if !m.hosts.IsLocalHost(toJID.Domain()) {
return nil
}
rss, err := m.resMng.GetResources(ctx, toJID.Node())
rss, err := m.resMng.GetResources(execCtx.Context, toJID.Node())
if err != nil {
return err
}
if len(rss) > 0 {
return nil
}
return m.archiveMessage(ctx, msg)
return m.archiveMessage(execCtx.Context, msg)
}

func (m *Offline) onC2SPresenceRecv(ctx context.Context, execCtx *hook.ExecutionContext) error {
func (m *Offline) onC2SPresenceRecv(execCtx *hook.ExecutionContext) error {
inf := execCtx.Info.(*hook.C2SStreamInfo)

pr := inf.Element.(*stravaganza.Presence)
Expand All @@ -156,11 +156,12 @@ func (m *Offline) onC2SPresenceRecv(ctx context.Context, execCtx *hook.Execution
if !pr.IsAvailable() || pr.Priority() < 0 {
return nil
}
return m.deliverOfflineMessages(ctx, toJID.Node())
return m.deliverOfflineMessages(execCtx.Context, toJID.Node())
}

func (m *Offline) onUserDeleted(ctx context.Context, execCtx *hook.ExecutionContext) error {
func (m *Offline) onUserDeleted(execCtx *hook.ExecutionContext) error {
inf := execCtx.Info.(*hook.UserInfo)
ctx := execCtx.Context

lockID := offlineQueueLockID(inf.Username)

Expand Down Expand Up @@ -226,12 +227,13 @@ func (m *Offline) archiveMessage(ctx context.Context, msg *stravaganza.Message)
if err := m.rep.InsertOfflineMessage(ctx, dMsg, username); err != nil {
return err
}
_, err = m.hk.Run(ctx, hook.OfflineMessageArchived, &hook.ExecutionContext{
_, err = m.hk.Run(hook.OfflineMessageArchived, &hook.ExecutionContext{
Info: &hook.OfflineInfo{
Username: username,
Message: dMsg,
},
Sender: m,
Sender: m,
Context: ctx,
})
if err != nil {
return err
Expand Down
Loading