From 6b41190fc42404d1bd7ad629e7dffe115f7ccb30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Ortu=C3=B1o?= Date: Sun, 1 May 2022 17:23:22 +0200 Subject: [PATCH] c2s: propagate listener port into stream context --- pkg/c2s/in.go | 44 ++++++++++++++++----------- pkg/c2s/in_test.go | 30 +++++++++++------- pkg/c2s/socket_listener.go | 5 +++ pkg/cluster/instance/instance.go | 14 ++++----- pkg/cluster/instance/instance_test.go | 6 ++-- pkg/module/xep0198/stream.go | 7 ++++- pkg/util/context/context.go | 37 ++++++++++++++++++++++ 7 files changed, 103 insertions(+), 40 deletions(-) create mode 100644 pkg/util/context/context.go diff --git a/pkg/c2s/in.go b/pkg/c2s/in.go index 226bf5724..2a4eaf0f3 100644 --- a/pkg/c2s/in.go +++ b/pkg/c2s/in.go @@ -101,6 +101,8 @@ func (a *authState) reset() { } type inC2S struct { + ctx context.Context + cancelFn context.CancelFunc id stream.C2SID cfg inCfg tr transport.Transport @@ -128,6 +130,7 @@ type inC2S struct { } func newInC2S( + ctx context.Context, cfg inCfg, tr transport.Transport, authenticators []auth.Authenticator, @@ -159,25 +162,29 @@ func newInC2S( }, sLogger, ) + ctx, cancelFn := context.WithCancel(ctx) + // init stream stm := &inC2S{ - id: id, - cfg: cfg, - tr: tr, - inf: c2smodel.NewInfoMap(), - session: session, - authSt: authState{authenticators: authenticators}, - hosts: hosts, - router: router, - comps: comps, - mods: mods, - resMng: resMng, - shapers: shapers, - rq: runqueue.New(id.String()), - doneCh: make(chan struct{}), - state: inConnecting, - hk: hk, - logger: sLogger, + ctx: ctx, + cancelFn: cancelFn, + id: id, + cfg: cfg, + tr: tr, + inf: c2smodel.NewInfoMap(), + session: session, + authSt: authState{authenticators: authenticators}, + hosts: hosts, + router: router, + comps: comps, + mods: mods, + resMng: resMng, + shapers: shapers, + rq: runqueue.New(id.String()), + doneCh: make(chan struct{}), + state: inConnecting, + hk: hk, + logger: sLogger, } if cfg.useTLS { stm.flags.setSecured() // stream already secured @@ -1137,6 +1144,7 @@ func (s *inC2S) terminate(ctx context.Context) error { if err != nil { return err } + s.cancelFn() close(s.doneCh) // signal termination s.setState(inTerminated) @@ -1216,7 +1224,7 @@ func (s *inC2S) runHook(ctx context.Context, hookName string, inf *hook.C2SStrea } func (s *inC2S) requestContext() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), s.cfg.reqTimeout) + return context.WithTimeout(s.ctx, s.cfg.reqTimeout) } var currentID uint64 diff --git a/pkg/c2s/in_test.go b/pkg/c2s/in_test.go index 56de60d13..63e8cdfc2 100644 --- a/pkg/c2s/in_test.go +++ b/pkg/c2s/in_test.go @@ -59,9 +59,11 @@ func TestInC2S_SendElement(t *testing.T) { return nil } s := &inC2S{ - session: sessMock, - rq: runqueue.New("in_c2s:test"), - hk: hook.NewHooks(), + ctx: context.Background(), + cancelFn: func() {}, + session: sessMock, + rq: runqueue.New("in_c2s:test"), + hk: hook.NewHooks(), } // when stanza := stravaganza.NewBuilder("auth"). @@ -111,14 +113,16 @@ func TestInC2S_Disconnect(t *testing.T) { return c2sRouterMock } s := &inC2S{ - state: inBinded, - session: sessMock, - tr: trMock, - router: routerMock, - resMng: rmMock, - rq: runqueue.New("in_c2s:test"), - doneCh: make(chan struct{}), - hk: hook.NewHooks(), + ctx: context.Background(), + cancelFn: func() {}, + state: inBinded, + session: sessMock, + tr: trMock, + router: routerMock, + resMng: rmMock, + rq: runqueue.New("in_c2s:test"), + doneCh: make(chan struct{}), + hk: hook.NewHooks(), } // when s.Disconnect(streamerror.E(streamerror.SystemShutdown)) @@ -721,6 +725,8 @@ func TestInC2S_HandleSessionElement(t *testing.T) { userJID, _ := jid.NewWithString("ortuman@localhost", true) stm := &inC2S{ + ctx: context.Background(), + cancelFn: func() {}, cfg: inCfg{ reqTimeout: time.Minute, maxStanzaSize: 8192, @@ -819,6 +825,8 @@ func TestInC2S_HandleSessionError(t *testing.T) { } stm := &inC2S{ + ctx: context.Background(), + cancelFn: func() {}, cfg: inCfg{ reqTimeout: time.Minute, maxStanzaSize: 8192, diff --git a/pkg/c2s/socket_listener.go b/pkg/c2s/socket_listener.go index 487cfb4b0..3c248b98a 100644 --- a/pkg/c2s/socket_listener.go +++ b/pkg/c2s/socket_listener.go @@ -22,6 +22,8 @@ import ( "sync/atomic" "time" + contextutil "github.com/ortuman/jackal/pkg/util/context" + kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/ortuman/jackal/pkg/auth" @@ -61,6 +63,7 @@ var resConflictMap = map[string]resourceConflict{ // SocketListener represents a C2S socket listener type. type SocketListener struct { + ctx context.Context cfg ListenerConfig extAuth *auth.External hosts *host.Hosts @@ -136,6 +139,7 @@ func newSocketListener( ) } ln := &SocketListener{ + ctx: contextutil.InjectListenerPort(context.Background(), cfg.Port), cfg: cfg, extAuth: extAuth, hosts: hosts, @@ -221,6 +225,7 @@ func (l *SocketListener) Stop(ctx context.Context) error { func (l *SocketListener) handleConn(conn net.Conn) { tr := transport.NewSocketTransport(conn, l.cfg.ConnectTimeout, l.cfg.KeepAliveTimeout) stm, err := newInC2S( + l.ctx, l.getInConfig(), tr, l.getAuthenticators(tr), diff --git a/pkg/cluster/instance/instance.go b/pkg/cluster/instance/instance.go index 74788c521..cb46b0a94 100644 --- a/pkg/cluster/instance/instance.go +++ b/pkg/cluster/instance/instance.go @@ -23,12 +23,12 @@ import ( ) const ( - envInstanceID = "JACKAL_INSTANCE_ID" - envInstanceFQDN = "JACKAL_INSTANCE_FQDN" + envInstanceID = "JACKAL_INSTANCE_ID" + envHostName = "JACKAL_HOSTNAME" ) var ( - instID, hostName string + instID, hostIP string ) var ( @@ -38,7 +38,7 @@ var ( func init() { instID = getID() - hostName = getHostname() + hostIP = getHostname() } // ID returns local instance identifier. @@ -52,7 +52,7 @@ func ID() string { // Hostname returns local instance host name. func Hostname() string { if readCachedResults { - return hostName + return hostIP } return getHostname() } @@ -66,7 +66,7 @@ func getID() string { } func getHostname() string { - fqdn := os.Getenv(envInstanceFQDN) + fqdn := os.Getenv(envHostName) if len(fqdn) > 0 { return fqdn } @@ -74,7 +74,7 @@ func getHostname() string { if err == nil && len(hn) > 0 { return hn } - return "localhost" // fallback to 'localhost' name + return "localhost" // fallback to 'localhost' ip } func getLocalHostname() (string, error) { diff --git a/pkg/cluster/instance/instance_test.go b/pkg/cluster/instance/instance_test.go index 0525bc130..3907b7d6f 100644 --- a/pkg/cluster/instance/instance_test.go +++ b/pkg/cluster/instance/instance_test.go @@ -60,7 +60,7 @@ func TestRandomIdentifier(t *testing.T) { func TestFQDNHostname(t *testing.T) { // given - _ = os.Setenv(envInstanceFQDN, "xmpp1.jackal.im") + _ = os.Setenv(envHostName, "xmpp1.jackal.im") readCachedResults = false // when @@ -72,7 +72,7 @@ func TestFQDNHostname(t *testing.T) { func TestIPHostname(t *testing.T) { // given - _ = os.Setenv(envInstanceFQDN, "") + _ = os.Setenv(envHostName, "") interfaceAddresses = func() ([]net.Addr, error) { return []net.Addr{&net.IPNet{ @@ -91,7 +91,7 @@ func TestIPHostname(t *testing.T) { func TestFallbackHostname(t *testing.T) { // given - _ = os.Setenv(envInstanceFQDN, "") + _ = os.Setenv(envHostName, "") interfaceAddresses = func() ([]net.Addr, error) { return nil, errors.New("foo error") diff --git a/pkg/module/xep0198/stream.go b/pkg/module/xep0198/stream.go index 39a136912..3d9f2494c 100644 --- a/pkg/module/xep0198/stream.go +++ b/pkg/module/xep0198/stream.go @@ -37,6 +37,7 @@ import ( xmppparser "github.com/ortuman/jackal/pkg/parser" "github.com/ortuman/jackal/pkg/router" "github.com/ortuman/jackal/pkg/router/stream" + contextutil "github.com/ortuman/jackal/pkg/util/context" ) const ( @@ -328,7 +329,7 @@ func (m *Stream) handleEnable(ctx context.Context, stm stream.C2S) error { stm.SendElement(stravaganza.NewBuilder("enabled"). WithAttribute(stravaganza.Namespace, streamNamespace). WithAttribute("id", smID). - WithAttribute("location", instance.Hostname()). + WithAttribute("location", getLocation(ctx)). WithAttribute("resume", "true"). Build(), ) @@ -484,3 +485,7 @@ func decodeSMID(smID string) (jd *jid.JID, nonce []byte, err error) { func queueKey(jd *jid.JID) string { return jd.String() } + +func getLocation(ctx context.Context) string { + return instance.Hostname() + ":" + strconv.Itoa(contextutil.ExtractListenerPort(ctx)) +} diff --git a/pkg/util/context/context.go b/pkg/util/context/context.go new file mode 100644 index 000000000..c4415421e --- /dev/null +++ b/pkg/util/context/context.go @@ -0,0 +1,37 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import "context" + +type contextKey string + +const ( + listenerPortKey contextKey = "ln_port" +) + +// InjectListenerPort returns a ctx derived context injecting a listener port. +func InjectListenerPort(ctx context.Context, port int) context.Context { + return context.WithValue(ctx, listenerPortKey, port) +} + +// ExtractListenerPort extracts a listener port from ctx context. +func ExtractListenerPort(ctx context.Context) int { + port, ok := ctx.Value(listenerPortKey).(int) + if !ok { + return 0 + } + return port +}