diff --git a/pkg/c2s/in.go b/pkg/c2s/in.go index 2a4eaf0f3..226bf5724 100644 --- a/pkg/c2s/in.go +++ b/pkg/c2s/in.go @@ -101,8 +101,6 @@ func (a *authState) reset() { } type inC2S struct { - ctx context.Context - cancelFn context.CancelFunc id stream.C2SID cfg inCfg tr transport.Transport @@ -130,7 +128,6 @@ type inC2S struct { } func newInC2S( - ctx context.Context, cfg inCfg, tr transport.Transport, authenticators []auth.Authenticator, @@ -162,29 +159,25 @@ func newInC2S( }, sLogger, ) - ctx, cancelFn := context.WithCancel(ctx) - // init stream stm := &inC2S{ - ctx: ctx, - cancelFn: cancelFn, - id: id, - cfg: cfg, - tr: tr, - inf: c2smodel.NewInfoMap(), - session: session, - authSt: authState{authenticators: authenticators}, - hosts: hosts, - router: router, - comps: comps, - mods: mods, - resMng: resMng, - shapers: shapers, - rq: runqueue.New(id.String()), - doneCh: make(chan struct{}), - state: inConnecting, - hk: hk, - logger: sLogger, + id: id, + cfg: cfg, + tr: tr, + inf: c2smodel.NewInfoMap(), + session: session, + authSt: authState{authenticators: authenticators}, + hosts: hosts, + router: router, + comps: comps, + mods: mods, + resMng: resMng, + shapers: shapers, + rq: runqueue.New(id.String()), + doneCh: make(chan struct{}), + state: inConnecting, + hk: hk, + logger: sLogger, } if cfg.useTLS { stm.flags.setSecured() // stream already secured @@ -1144,7 +1137,6 @@ func (s *inC2S) terminate(ctx context.Context) error { if err != nil { return err } - s.cancelFn() close(s.doneCh) // signal termination s.setState(inTerminated) @@ -1224,7 +1216,7 @@ func (s *inC2S) runHook(ctx context.Context, hookName string, inf *hook.C2SStrea } func (s *inC2S) requestContext() (context.Context, context.CancelFunc) { - return context.WithTimeout(s.ctx, s.cfg.reqTimeout) + return context.WithTimeout(context.Background(), s.cfg.reqTimeout) } var currentID uint64 diff --git a/pkg/c2s/in_test.go b/pkg/c2s/in_test.go index 63e8cdfc2..56de60d13 100644 --- a/pkg/c2s/in_test.go +++ b/pkg/c2s/in_test.go @@ -59,11 +59,9 @@ func TestInC2S_SendElement(t *testing.T) { return nil } s := &inC2S{ - ctx: context.Background(), - cancelFn: func() {}, - session: sessMock, - rq: runqueue.New("in_c2s:test"), - hk: hook.NewHooks(), + session: sessMock, + rq: runqueue.New("in_c2s:test"), + hk: hook.NewHooks(), } // when stanza := stravaganza.NewBuilder("auth"). @@ -113,16 +111,14 @@ func TestInC2S_Disconnect(t *testing.T) { return c2sRouterMock } s := &inC2S{ - ctx: context.Background(), - cancelFn: func() {}, - state: inBinded, - session: sessMock, - tr: trMock, - router: routerMock, - resMng: rmMock, - rq: runqueue.New("in_c2s:test"), - doneCh: make(chan struct{}), - hk: hook.NewHooks(), + state: inBinded, + session: sessMock, + tr: trMock, + router: routerMock, + resMng: rmMock, + rq: runqueue.New("in_c2s:test"), + doneCh: make(chan struct{}), + hk: hook.NewHooks(), } // when s.Disconnect(streamerror.E(streamerror.SystemShutdown)) @@ -725,8 +721,6 @@ func TestInC2S_HandleSessionElement(t *testing.T) { userJID, _ := jid.NewWithString("ortuman@localhost", true) stm := &inC2S{ - ctx: context.Background(), - cancelFn: func() {}, cfg: inCfg{ reqTimeout: time.Minute, maxStanzaSize: 8192, @@ -825,8 +819,6 @@ func TestInC2S_HandleSessionError(t *testing.T) { } stm := &inC2S{ - ctx: context.Background(), - cancelFn: func() {}, cfg: inCfg{ reqTimeout: time.Minute, maxStanzaSize: 8192, diff --git a/pkg/c2s/socket_listener.go b/pkg/c2s/socket_listener.go index 3c248b98a..487cfb4b0 100644 --- a/pkg/c2s/socket_listener.go +++ b/pkg/c2s/socket_listener.go @@ -22,8 +22,6 @@ import ( "sync/atomic" "time" - contextutil "github.com/ortuman/jackal/pkg/util/context" - kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/ortuman/jackal/pkg/auth" @@ -63,7 +61,6 @@ var resConflictMap = map[string]resourceConflict{ // SocketListener represents a C2S socket listener type. type SocketListener struct { - ctx context.Context cfg ListenerConfig extAuth *auth.External hosts *host.Hosts @@ -139,7 +136,6 @@ func newSocketListener( ) } ln := &SocketListener{ - ctx: contextutil.InjectListenerPort(context.Background(), cfg.Port), cfg: cfg, extAuth: extAuth, hosts: hosts, @@ -225,7 +221,6 @@ func (l *SocketListener) Stop(ctx context.Context) error { func (l *SocketListener) handleConn(conn net.Conn) { tr := transport.NewSocketTransport(conn, l.cfg.ConnectTimeout, l.cfg.KeepAliveTimeout) stm, err := newInC2S( - l.ctx, l.getInConfig(), tr, l.getAuthenticators(tr), diff --git a/pkg/cluster/instance/instance.go b/pkg/cluster/instance/instance.go index cb46b0a94..7a988c98a 100644 --- a/pkg/cluster/instance/instance.go +++ b/pkg/cluster/instance/instance.go @@ -15,8 +15,6 @@ package instance import ( - "errors" - "net" "os" "github.com/google/uuid" @@ -24,21 +22,18 @@ import ( const ( envInstanceID = "JACKAL_INSTANCE_ID" - envHostName = "JACKAL_HOSTNAME" ) var ( - instID, hostIP string + instID string ) var ( - readCachedResults = true - interfaceAddresses = net.InterfaceAddrs + readCachedResults = true ) func init() { instID = getID() - hostIP = getHostname() } // ID returns local instance identifier. @@ -49,14 +44,6 @@ func ID() string { return getID() } -// Hostname returns local instance host name. -func Hostname() string { - if readCachedResults { - return hostIP - } - return getHostname() -} - func getID() string { id := os.Getenv(envInstanceID) if len(id) == 0 { @@ -64,31 +51,3 @@ func getID() string { } return id } - -func getHostname() string { - fqdn := os.Getenv(envHostName) - if len(fqdn) > 0 { - return fqdn - } - hn, err := getLocalHostname() - if err == nil && len(hn) > 0 { - return hn - } - return "localhost" // fallback to 'localhost' ip -} - -func getLocalHostname() (string, error) { - addresses, err := interfaceAddresses() - if err != nil { - return "", err - } - - for _, addr := range addresses { - if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { - if ipNet.IP.To4() != nil { - return ipNet.IP.String(), nil - } - } - } - return "", errors.New("instance: failed to get local ip") -} diff --git a/pkg/cluster/instance/instance_test.go b/pkg/cluster/instance/instance_test.go index 3907b7d6f..607ed4796 100644 --- a/pkg/cluster/instance/instance_test.go +++ b/pkg/cluster/instance/instance_test.go @@ -15,23 +15,12 @@ package instance import ( - "errors" - "net" "os" "testing" "github.com/stretchr/testify/require" ) -func init() { - interfaceAddresses = func() ([]net.Addr, error) { - return []net.Addr{&net.IPNet{ - IP: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 0, 13}, - Mask: []byte{255, 255, 255, 0}, - }}, nil - } -} - func TestOsEnvironmentIdentifier(t *testing.T) { // given readCachedResults = false @@ -57,50 +46,3 @@ func TestRandomIdentifier(t *testing.T) { // then require.True(t, len(id) > 0) } - -func TestFQDNHostname(t *testing.T) { - // given - _ = os.Setenv(envHostName, "xmpp1.jackal.im") - readCachedResults = false - - // when - hn := Hostname() - - // then - require.Equal(t, "xmpp1.jackal.im", hn) -} - -func TestIPHostname(t *testing.T) { - // given - _ = os.Setenv(envHostName, "") - - interfaceAddresses = func() ([]net.Addr, error) { - return []net.Addr{&net.IPNet{ - IP: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 0, 13}, - Mask: []byte{255, 255, 255, 0}, - }}, nil - } - readCachedResults = false - - // when - hn := Hostname() - - // then - require.Equal(t, "192.168.0.13", hn) -} - -func TestFallbackHostname(t *testing.T) { - // given - _ = os.Setenv(envHostName, "") - - interfaceAddresses = func() ([]net.Addr, error) { - return nil, errors.New("foo error") - } - readCachedResults = false - - // when - hn := Hostname() - - // then - require.Equal(t, "localhost", hn) -} diff --git a/pkg/cluster/memberlist/kv_memberlist.go b/pkg/cluster/memberlist/kv_memberlist.go index 96f8e5128..9d83be338 100644 --- a/pkg/cluster/memberlist/kv_memberlist.go +++ b/pkg/cluster/memberlist/kv_memberlist.go @@ -16,6 +16,7 @@ package memberlist import ( "context" + "errors" "fmt" "net" "strconv" @@ -194,9 +195,13 @@ func (ml *KVMemberList) getMembers(ctx context.Context) ([]clustermodel.Member, } func (ml *KVMemberList) getLocalMember() (*clustermodel.Member, error) { + hostIP, err := getHostIP() + if err != nil { + return nil, err + } return &clustermodel.Member{ InstanceID: instance.ID(), - Host: instance.Hostname(), + Host: hostIP, Port: ml.localPort, APIVer: version.ClusterAPIVersion, }, nil @@ -276,3 +281,19 @@ func localMemberKey() string { func isLocalMemberKey(k string) bool { return k == localMemberKey() } + +func getHostIP() (string, error) { + addresses, err := net.InterfaceAddrs() + if err != nil { + return "", err + } + + for _, addr := range addresses { + if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + return ipNet.IP.String(), nil + } + } + } + return "", errors.New("instance: failed to get local ip") +} diff --git a/pkg/cluster/memberlist/kv_memberlist_test.go b/pkg/cluster/memberlist/kv_memberlist_test.go index 843036413..a2229e08f 100644 --- a/pkg/cluster/memberlist/kv_memberlist_test.go +++ b/pkg/cluster/memberlist/kv_memberlist_test.go @@ -39,7 +39,7 @@ func TestMemberList_Join(t *testing.T) { } kvMock.GetPrefixFunc = func(ctx context.Context, prefix string) (map[string][]byte, error) { return map[string][]byte{ - fmt.Sprintf("i://%s", instance.ID()): []byte(fmt.Sprintf("a=%s:4312 cv=v1.0.0", instance.Hostname())), + fmt.Sprintf("i://%s", instance.ID()): []byte(fmt.Sprintf("a=%s:4312 cv=v1.0.0", "10.106.0.5")), "i://b3fd": []byte("a=192.168.0.12:1456 cv=v1.0.0"), }, nil } @@ -113,7 +113,7 @@ func TestMemberList_WatchChanges(t *testing.T) { } kvMock.GetPrefixFunc = func(ctx context.Context, prefix string) (map[string][]byte, error) { return map[string][]byte{ - fmt.Sprintf("i://%s", instance.ID()): []byte(fmt.Sprintf("a=%s:4312 cv=v1.0.0", instance.Hostname())), + fmt.Sprintf("i://%s", instance.ID()): []byte(fmt.Sprintf("a=%s:4312 cv=v1.0.0", "10.106.0.5")), "i://b3fd": []byte("a=192.168.0.12:1456 cv=v1.0.0"), }, nil } diff --git a/pkg/module/xep0198/stream.go b/pkg/module/xep0198/stream.go index 3d9f2494c..a196e49c3 100644 --- a/pkg/module/xep0198/stream.go +++ b/pkg/module/xep0198/stream.go @@ -30,14 +30,12 @@ import ( "github.com/jackal-xmpp/stravaganza" streamerror "github.com/jackal-xmpp/stravaganza/errors/stream" "github.com/jackal-xmpp/stravaganza/jid" - "github.com/ortuman/jackal/pkg/cluster/instance" "github.com/ortuman/jackal/pkg/cluster/resourcemanager" "github.com/ortuman/jackal/pkg/hook" "github.com/ortuman/jackal/pkg/host" xmppparser "github.com/ortuman/jackal/pkg/parser" "github.com/ortuman/jackal/pkg/router" "github.com/ortuman/jackal/pkg/router/stream" - contextutil "github.com/ortuman/jackal/pkg/util/context" ) const ( @@ -329,7 +327,6 @@ func (m *Stream) handleEnable(ctx context.Context, stm stream.C2S) error { stm.SendElement(stravaganza.NewBuilder("enabled"). WithAttribute(stravaganza.Namespace, streamNamespace). WithAttribute("id", smID). - WithAttribute("location", getLocation(ctx)). WithAttribute("resume", "true"). Build(), ) @@ -485,7 +482,3 @@ func decodeSMID(smID string) (jd *jid.JID, nonce []byte, err error) { func queueKey(jd *jid.JID) string { return jd.String() } - -func getLocation(ctx context.Context) string { - return instance.Hostname() + ":" + strconv.Itoa(contextutil.ExtractListenerPort(ctx)) -} diff --git a/pkg/util/context/context.go b/pkg/util/context/context.go deleted file mode 100644 index c4415421e..000000000 --- a/pkg/util/context/context.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2022 The jackal Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import "context" - -type contextKey string - -const ( - listenerPortKey contextKey = "ln_port" -) - -// InjectListenerPort returns a ctx derived context injecting a listener port. -func InjectListenerPort(ctx context.Context, port int) context.Context { - return context.WithValue(ctx, listenerPortKey, port) -} - -// ExtractListenerPort extracts a listener port from ctx context. -func ExtractListenerPort(ctx context.Context) int { - port, ok := ctx.Value(listenerPortKey).(int) - if !ok { - return 0 - } - return port -}