From 32b43865c5990011771c40dded0ffae8ecac5149 Mon Sep 17 00:00:00 2001 From: Jens Alfke Date: Fri, 10 May 2024 13:56:57 -0700 Subject: [PATCH] Allow a Hook to reload an individual client session In a clustered environment, client connections are distributed among multiple Server instances on different machines. After a client disconnects, leaving behind a persistent session state, its next login is likely to be on a different node. Because of this, in such a setup an individual Server instance should only keep Client instances corresponding to online client connections, and it should be able to reload an individual client's state (presumably from persistent storage) when that client connects. This commit adds support for such an environment by adding a new hook `StoredClientByID`. An implementation finds and returns any persistent client data for a given session ID. In practice the only necessary information turned out to be the saved subscriptions and in-flight ack messages. The hook also returns the prior 'Remote' property since the server logs that. The Server method `inheritClientSession` is extended to call this hook if there is no matching in-memory Client session. If the hook returns session data, it installs it into the Client object in the same way as the existing code. At the end of the Server method `attachClient`, after disconnection, the existence of a `StoredClientByID` hook is checked; if present, the method expires the Client instance so it won't hang around in memory and so the next connection will go to the hook to reload state. --- hooks.go | 26 +++++++++++++++++ server.go | 84 +++++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/hooks.go b/hooks.go index 4da709f7..4ff75777 100644 --- a/hooks.go +++ b/hooks.go @@ -55,6 +55,7 @@ const ( StoredInflightMessages StoredRetainedMessages StoredSysInfo + StoredClientByID ) var ( @@ -114,6 +115,7 @@ type Hook interface { StoredInflightMessages() ([]storage.Message, error) StoredRetainedMessages() ([]storage.Message, error) StoredSysInfo() (storage.SystemInfo, error) + StoredClientByID(id string, username []byte) (string, []storage.Subscription, []storage.Message, error) } // HookOptions contains values which are inherited from the server on initialisation. @@ -679,6 +681,25 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { return false } +// StoredClientByID returns the state of the stored client with the given session ID, if any. +func (h *Hooks) StoredClientByID(id string, username []byte) (oldRemote string, subs []storage.Subscription, msgs []storage.Message, err error) { + for _, hook := range h.GetAll() { + if hook.Provides(StoredClientByID) { + oldRemote, subs, msgs, err = hook.StoredClientByID(id, username) + if err != nil { + h.Log.Error("failed to load client by ID", "error", err, "hook", hook.ID()) + return + } + + if oldRemote != "" && err == nil { + return + } + } + } + + return +} + // HookBase provides a set of default methods for each hook. It should be embedded in // all hooks. type HookBase struct { @@ -859,3 +880,8 @@ func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) { func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) { return } + +// StoredClientByID returns the state of the stored client with the given session ID, if any. +func (h *HookBase) StoredClientByID(id string, username []byte) (oldRemote string, subs []storage.Subscription, msgs []storage.Message, err error) { + return +} diff --git a/server.go b/server.go index 4ad91822..0430b700 100644 --- a/server.go +++ b/server.go @@ -485,6 +485,11 @@ func (s *Server) attachClient(cl *Client, listener string) error { expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) s.hooks.OnDisconnect(cl, err, expire) + if s.hooks.Provides(StoredClientByID) { + // Hooks are capable of reloading a persistent client session, so I can forget it + expire = true + } + if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 { cl.ClearInflights() s.UnsubscribeClient(cl) @@ -596,6 +601,42 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { return true // [MQTT-3.2.2-3] } + // Look up a stored client that's not in memory yet: + if s.hooks.Provides(StoredClientByID) { + oldRemote, subs, msgs, err := s.hooks.StoredClientByID(cl.ID, cl.Properties.Username) + if err == nil && oldRemote != "" { + // Instantiate in-flight messages to deliver: + if len(msgs) > 0 { + inf := NewInflights() + for _, msg := range msgs { + inf.Set(msg.ToPacket()) + } + cl.State.Inflight = inf + } + + // Instantiate stored subscriptions: + for _, sub := range subs { + sb := packets.Subscription{ + Filter: sub.Filter, + RetainHandling: sub.RetainHandling, + Qos: sub.Qos, + RetainAsPublished: sub.RetainAsPublished, + NoLocal: sub.NoLocal, + Identifier: sub.Identifier, + } + existed := !s.Topics.Subscribe(cl.ID, sb) // [MQTT-3.8.4-3] + if !existed { + atomic.AddInt64(&s.Info.Subscriptions, 1) + } + cl.State.Subscriptions.Add(sb.Filter, sb) + } + + s.Log.Debug("session taken over (persistent)", "client", cl.ID, "old_remote", oldRemote, "new_remote", cl.Net.Remote) + + return true + } + } + if atomic.LoadInt64(&s.Info.ClientsConnected) > atomic.LoadInt64(&s.Info.ClientsMaximum) { atomic.AddInt64(&s.Info.ClientsMaximum, 1) } @@ -1014,6 +1055,7 @@ func (s *Server) publishToSubscribers(pk packets.Packet) { } } +// publishToClient delivers a published message to a single subscriber client. func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) { if sub.NoLocal && pk.Origin == cl.ID { return pk, nil // [MQTT-3.8.3-3] @@ -1636,24 +1678,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) { // loadClients restores clients from the datastore. func (s *Server) loadClients(v []storage.Client) { for _, c := range v { - cl := s.NewClient(nil, c.Listener, c.ID, false) - cl.Properties.Username = c.Username - cl.Properties.Clean = c.Clean - cl.Properties.ProtocolVersion = c.ProtocolVersion - cl.Properties.Props = packets.Properties{ - SessionExpiryInterval: c.Properties.SessionExpiryInterval, - SessionExpiryIntervalFlag: c.Properties.SessionExpiryIntervalFlag, - AuthenticationMethod: c.Properties.AuthenticationMethod, - AuthenticationData: c.Properties.AuthenticationData, - RequestProblemInfoFlag: c.Properties.RequestProblemInfoFlag, - RequestProblemInfo: c.Properties.RequestProblemInfo, - RequestResponseInfo: c.Properties.RequestResponseInfo, - ReceiveMaximum: c.Properties.ReceiveMaximum, - TopicAliasMaximum: c.Properties.TopicAliasMaximum, - User: c.Properties.User, - MaximumPacketSize: c.Properties.MaximumPacketSize, - } - cl.Properties.Will = Will(c.Will) + cl := s.newClientFromStorage(&c) // cancel the context, update cl.State such as disconnected time and stopCause. cl.Stop(packets.ErrServerShuttingDown) @@ -1669,6 +1694,29 @@ func (s *Server) loadClients(v []storage.Client) { } } +// newClientFromStorage creates a Client from a storage.Client. +func (s *Server) newClientFromStorage(c *storage.Client) *Client { + cl := s.NewClient(nil, c.Listener, c.ID, false) + cl.Properties.Username = c.Username + cl.Properties.Clean = c.Clean + cl.Properties.ProtocolVersion = c.ProtocolVersion + cl.Properties.Props = packets.Properties{ + SessionExpiryInterval: c.Properties.SessionExpiryInterval, + SessionExpiryIntervalFlag: c.Properties.SessionExpiryIntervalFlag, + AuthenticationMethod: c.Properties.AuthenticationMethod, + AuthenticationData: c.Properties.AuthenticationData, + RequestProblemInfoFlag: c.Properties.RequestProblemInfoFlag, + RequestProblemInfo: c.Properties.RequestProblemInfo, + RequestResponseInfo: c.Properties.RequestResponseInfo, + ReceiveMaximum: c.Properties.ReceiveMaximum, + TopicAliasMaximum: c.Properties.TopicAliasMaximum, + User: c.Properties.User, + MaximumPacketSize: c.Properties.MaximumPacketSize, + } + cl.Properties.Will = Will(c.Will) + return cl +} + // loadInflight restores inflight messages from the datastore. func (s *Server) loadInflight(v []storage.Message) { for _, msg := range v {