Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
c2s: propagate listener port into stream context (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
ortuman authored May 1, 2022
1 parent 7edd710 commit 1c3b383
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 40 deletions.
44 changes: 26 additions & 18 deletions pkg/c2s/in.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ func (a *authState) reset() {
}

type inC2S struct {
ctx context.Context
cancelFn context.CancelFunc
id stream.C2SID
cfg inCfg
tr transport.Transport
Expand Down Expand Up @@ -128,6 +130,7 @@ type inC2S struct {
}

func newInC2S(
ctx context.Context,
cfg inCfg,
tr transport.Transport,
authenticators []auth.Authenticator,
Expand Down Expand Up @@ -159,25 +162,29 @@ func newInC2S(
},
sLogger,
)
ctx, cancelFn := context.WithCancel(ctx)

// init stream
stm := &inC2S{
id: id,
cfg: cfg,
tr: tr,
inf: c2smodel.NewInfoMap(),
session: session,
authSt: authState{authenticators: authenticators},
hosts: hosts,
router: router,
comps: comps,
mods: mods,
resMng: resMng,
shapers: shapers,
rq: runqueue.New(id.String()),
doneCh: make(chan struct{}),
state: inConnecting,
hk: hk,
logger: sLogger,
ctx: ctx,
cancelFn: cancelFn,
id: id,
cfg: cfg,
tr: tr,
inf: c2smodel.NewInfoMap(),
session: session,
authSt: authState{authenticators: authenticators},
hosts: hosts,
router: router,
comps: comps,
mods: mods,
resMng: resMng,
shapers: shapers,
rq: runqueue.New(id.String()),
doneCh: make(chan struct{}),
state: inConnecting,
hk: hk,
logger: sLogger,
}
if cfg.useTLS {
stm.flags.setSecured() // stream already secured
Expand Down Expand Up @@ -1137,6 +1144,7 @@ func (s *inC2S) terminate(ctx context.Context) error {
if err != nil {
return err
}
s.cancelFn()
close(s.doneCh) // signal termination

s.setState(inTerminated)
Expand Down Expand Up @@ -1216,7 +1224,7 @@ func (s *inC2S) runHook(ctx context.Context, hookName string, inf *hook.C2SStrea
}

func (s *inC2S) requestContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), s.cfg.reqTimeout)
return context.WithTimeout(s.ctx, s.cfg.reqTimeout)
}

var currentID uint64
Expand Down
30 changes: 19 additions & 11 deletions pkg/c2s/in_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ func TestInC2S_SendElement(t *testing.T) {
return nil
}
s := &inC2S{
session: sessMock,
rq: runqueue.New("in_c2s:test"),
hk: hook.NewHooks(),
ctx: context.Background(),
cancelFn: func() {},
session: sessMock,
rq: runqueue.New("in_c2s:test"),
hk: hook.NewHooks(),
}
// when
stanza := stravaganza.NewBuilder("auth").
Expand Down Expand Up @@ -111,14 +113,16 @@ func TestInC2S_Disconnect(t *testing.T) {
return c2sRouterMock
}
s := &inC2S{
state: inBinded,
session: sessMock,
tr: trMock,
router: routerMock,
resMng: rmMock,
rq: runqueue.New("in_c2s:test"),
doneCh: make(chan struct{}),
hk: hook.NewHooks(),
ctx: context.Background(),
cancelFn: func() {},
state: inBinded,
session: sessMock,
tr: trMock,
router: routerMock,
resMng: rmMock,
rq: runqueue.New("in_c2s:test"),
doneCh: make(chan struct{}),
hk: hook.NewHooks(),
}
// when
s.Disconnect(streamerror.E(streamerror.SystemShutdown))
Expand Down Expand Up @@ -721,6 +725,8 @@ func TestInC2S_HandleSessionElement(t *testing.T) {

userJID, _ := jid.NewWithString("ortuman@localhost", true)
stm := &inC2S{
ctx: context.Background(),
cancelFn: func() {},
cfg: inCfg{
reqTimeout: time.Minute,
maxStanzaSize: 8192,
Expand Down Expand Up @@ -819,6 +825,8 @@ func TestInC2S_HandleSessionError(t *testing.T) {
}

stm := &inC2S{
ctx: context.Background(),
cancelFn: func() {},
cfg: inCfg{
reqTimeout: time.Minute,
maxStanzaSize: 8192,
Expand Down
5 changes: 5 additions & 0 deletions pkg/c2s/socket_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"sync/atomic"
"time"

contextutil "github.com/ortuman/jackal/pkg/util/context"

kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/ortuman/jackal/pkg/auth"
Expand Down Expand Up @@ -61,6 +63,7 @@ var resConflictMap = map[string]resourceConflict{

// SocketListener represents a C2S socket listener type.
type SocketListener struct {
ctx context.Context
cfg ListenerConfig
extAuth *auth.External
hosts *host.Hosts
Expand Down Expand Up @@ -136,6 +139,7 @@ func newSocketListener(
)
}
ln := &SocketListener{
ctx: contextutil.InjectListenerPort(context.Background(), cfg.Port),
cfg: cfg,
extAuth: extAuth,
hosts: hosts,
Expand Down Expand Up @@ -221,6 +225,7 @@ func (l *SocketListener) Stop(ctx context.Context) error {
func (l *SocketListener) handleConn(conn net.Conn) {
tr := transport.NewSocketTransport(conn, l.cfg.ConnectTimeout, l.cfg.KeepAliveTimeout)
stm, err := newInC2S(
l.ctx,
l.getInConfig(),
tr,
l.getAuthenticators(tr),
Expand Down
14 changes: 7 additions & 7 deletions pkg/cluster/instance/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import (
)

const (
envInstanceID = "JACKAL_INSTANCE_ID"
envInstanceFQDN = "JACKAL_INSTANCE_FQDN"
envInstanceID = "JACKAL_INSTANCE_ID"
envHostName = "JACKAL_HOSTNAME"
)

var (
instID, hostName string
instID, hostIP string
)

var (
Expand All @@ -38,7 +38,7 @@ var (

func init() {
instID = getID()
hostName = getHostname()
hostIP = getHostname()
}

// ID returns local instance identifier.
Expand All @@ -52,7 +52,7 @@ func ID() string {
// Hostname returns local instance host name.
func Hostname() string {
if readCachedResults {
return hostName
return hostIP
}
return getHostname()
}
Expand All @@ -66,15 +66,15 @@ func getID() string {
}

func getHostname() string {
fqdn := os.Getenv(envInstanceFQDN)
fqdn := os.Getenv(envHostName)
if len(fqdn) > 0 {
return fqdn
}
hn, err := getLocalHostname()
if err == nil && len(hn) > 0 {
return hn
}
return "localhost" // fallback to 'localhost' name
return "localhost" // fallback to 'localhost' ip
}

func getLocalHostname() (string, error) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/cluster/instance/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestRandomIdentifier(t *testing.T) {

func TestFQDNHostname(t *testing.T) {
// given
_ = os.Setenv(envInstanceFQDN, "xmpp1.jackal.im")
_ = os.Setenv(envHostName, "xmpp1.jackal.im")
readCachedResults = false

// when
Expand All @@ -72,7 +72,7 @@ func TestFQDNHostname(t *testing.T) {

func TestIPHostname(t *testing.T) {
// given
_ = os.Setenv(envInstanceFQDN, "")
_ = os.Setenv(envHostName, "")

interfaceAddresses = func() ([]net.Addr, error) {
return []net.Addr{&net.IPNet{
Expand All @@ -91,7 +91,7 @@ func TestIPHostname(t *testing.T) {

func TestFallbackHostname(t *testing.T) {
// given
_ = os.Setenv(envInstanceFQDN, "")
_ = os.Setenv(envHostName, "")

interfaceAddresses = func() ([]net.Addr, error) {
return nil, errors.New("foo error")
Expand Down
7 changes: 6 additions & 1 deletion pkg/module/xep0198/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
xmppparser "github.com/ortuman/jackal/pkg/parser"
"github.com/ortuman/jackal/pkg/router"
"github.com/ortuman/jackal/pkg/router/stream"
contextutil "github.com/ortuman/jackal/pkg/util/context"
)

const (
Expand Down Expand Up @@ -328,7 +329,7 @@ func (m *Stream) handleEnable(ctx context.Context, stm stream.C2S) error {
stm.SendElement(stravaganza.NewBuilder("enabled").
WithAttribute(stravaganza.Namespace, streamNamespace).
WithAttribute("id", smID).
WithAttribute("location", instance.Hostname()).
WithAttribute("location", getLocation(ctx)).
WithAttribute("resume", "true").
Build(),
)
Expand Down Expand Up @@ -484,3 +485,7 @@ func decodeSMID(smID string) (jd *jid.JID, nonce []byte, err error) {
func queueKey(jd *jid.JID) string {
return jd.String()
}

func getLocation(ctx context.Context) string {
return instance.Hostname() + ":" + strconv.Itoa(contextutil.ExtractListenerPort(ctx))
}
37 changes: 37 additions & 0 deletions pkg/util/context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2022 The jackal Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package context

import "context"

type contextKey string

const (
listenerPortKey contextKey = "ln_port"
)

// InjectListenerPort returns a ctx derived context injecting a listener port.
func InjectListenerPort(ctx context.Context, port int) context.Context {
return context.WithValue(ctx, listenerPortKey, port)
}

// ExtractListenerPort extracts a listener port from ctx context.
func ExtractListenerPort(ctx context.Context) int {
port, ok := ctx.Value(listenerPortKey).(int)
if !ok {
return 0
}
return port
}

0 comments on commit 1c3b383

Please sign in to comment.