From f82ea4196cfb09834242f44d328ad41895d72085 Mon Sep 17 00:00:00 2001 From: Artur Troian Date: Thu, 31 Aug 2017 17:55:46 +0300 Subject: [PATCH] Ref #47 Optimize send/recv routines Signed-off-by: Artur Troian --- clients/sessions.go | 11 +- connection/connection.go | 231 +++++++++-------- connection/net.go | 519 ------------------------------------- connection/netCallbacks.go | 127 ++++----- connection/publisher.go | 68 ----- connection/receiver.go | 138 ++++++++++ connection/transmitter.go | 214 +++++++++++++++ transport/connTCP.go | 10 + transport/tcp.go | 16 +- 9 files changed, 548 insertions(+), 786 deletions(-) delete mode 100644 connection/net.go delete mode 100644 connection/publisher.go create mode 100644 connection/receiver.go create mode 100644 connection/transmitter.go diff --git a/clients/sessions.go b/clients/sessions.go index 45ff2a0..f067bc5 100644 --- a/clients/sessions.go +++ b/clients/sessions.go @@ -89,21 +89,22 @@ func NewManager(c *Config) (*Manager, error) { allowReplace: c.AllowReplace, offlineQoS0: c.OfflineQoS0, quit: make(chan struct{}), - //sessions: make(map[string]*session), - //subscribers: make(map[string]subscriber.ConnectionProvider), - log: configuration.GetProdLogger().Named("sessions"), + log: configuration.GetProdLogger().Named("sessions"), } m.persistence, _ = c.Persist.Sessions() + var err error + m.log.Info("Loading sessions. 
Might take a while") + // load sessions for fill systree // those sessions having either will delay or expire are created with and timer started - if err := m.loadSessions(); err != nil { + if err = m.loadSessions(); err != nil { return nil, err } - if err := m.loadSubscribers(); err != nil { + if err = m.loadSubscribers(); err != nil { return nil, err } diff --git a/connection/connection.go b/connection/connection.go index 36828c1..fc28b5b 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -81,37 +81,35 @@ type Config struct { // Type session type Type struct { - conn *netConn - pubIn *ackQueue - pubOut *ackQueue - publisher *publisher - flowControl *packetsFlowControl - onDisconnect onDisconnect - subscriber subscriber.ConnectionProvider - messenger types.TopicMessenger - auth auth.SessionPermissions - id string - username string - quit chan struct{} - onStart types.Once - onStop types.Once + conn net.Conn + pubIn *ackQueue + pubOut *ackQueue + flowControl *packetsFlowControl + tx *transmitter + rx *receiver + onDisconnect onDisconnect + subscriber subscriber.ConnectionProvider + messenger types.TopicMessenger + auth auth.SessionPermissions + id string + username string + quit chan struct{} + onStart types.Once + //onStop types.Once onConnDisconnect types.Once - started sync.WaitGroup + //started sync.WaitGroup retained struct { lock sync.Mutex list []*packet.Publish } - log struct { - prod *zap.Logger - dev *zap.Logger - } - + log *zap.Logger sendQuota int32 offlineQoS0 bool clean bool version packet.ProtocolVersion + will bool } type unacknowledgedPublish struct { @@ -128,19 +126,44 @@ func New(c *Config) (s *Type, err error) { messenger: c.Messenger, version: c.Version, clean: c.Clean, + conn: c.Conn, onDisconnect: c.OnDisconnect, sendQuota: int32(c.SendQuota), quit: make(chan struct{}), + will: true, } - s.started.Add(1) - s.publisher = newPublisher(s.quit) + s.pubIn = newAckQueue(s.onReleaseIn) + s.pubOut = newAckQueue(s.onReleaseOut) 
s.flowControl = newFlowControl(s.quit, c.PreserveOrder) + s.log = configuration.GetProdLogger().Named("connection." + s.id) + + s.tx = newTransmitter(&transmitterConfig{ + quit: s.quit, + log: s.log, + id: s.id, + pubIn: s.pubIn, + pubOut: s.pubOut, + flowControl: s.flowControl, + conn: s.conn, + onDisconnect: s.onConnectionClose, + }) - s.log.prod = configuration.GetProdLogger().Named("connection." + s.id) - s.log.dev = configuration.GetDevLogger().Named("connection." + s.id) + s.rx = newReceiver(&receiverConfig{ + quit: s.quit, + conn: s.conn, + onDisconnect: s.onConnectionClose, + onPacket: s.processIncoming, + version: s.version, + will: &s.will, + }) + + if c.KeepAlive > 0 { + s.rx.keepAlive = time.Second * time.Duration(c.KeepAlive) + s.rx.keepAlive = s.rx.keepAlive + (s.rx.keepAlive / 2) + } - // publisher queue is ready, assign to subscriber new online callback + // transmitter queue is ready, assign to subscriber new online callback // and signal it to forward messages to online callback by creating channel s.subscriber.Online(s.onSubscribedPublish) if !s.clean && c.State != nil { @@ -148,41 +171,14 @@ func New(c *Config) (s *Type, err error) { s.loadPersistence(c.State) // nolint: errcheck } - s.pubIn = newAckQueue(s.onReleaseIn) - s.pubOut = newAckQueue(s.onReleaseOut) - - // Run connection - // error check check at deferred call - s.conn, err = newNet( - &netConfig{ - id: s.id, - conn: c.Conn, - keepAlive: c.KeepAlive, - protoVersion: c.Version, - on: onProcess{ - publish: s.onPublish, - ack: s.onAck, - subscribe: s.onSubscribe, - unSubscribe: s.onUnSubscribe, - disconnect: s.onConnectionClose, - }, - packetsMetric: c.Metric.Packets(), - }) return } -// Start signal session to start processing messages. 
-// Effective only first invoke +// Start run connection func (s *Type) Start() { s.onStart.Do(func() { - s.conn.start() - - s.publisher.stopped.Add(1) - s.publisher.started.Add(1) - go s.publishWorker() - s.publisher.started.Wait() - - s.started.Done() + s.tx.run() + s.rx.run() }) } @@ -190,19 +186,73 @@ func (s *Type) Start() { // or session is being replaced // Effective only first invoke func (s *Type) Stop(reason packet.ReasonCode) { - s.onStop.Do(func() { - s.conn.stop(&reason) - }) + s.onConnectionClose(true) } -func (s *Type) loadPersistence(state *persistenceTypes.SessionMessages) (err error) { - //defer s.publisher.cond.L.Unlock() - //s.publisher.cond.L.Lock() +func (s *Type) processIncoming(p packet.Provider) error { + var err error + var resp packet.Provider + switch pkt := p.(type) { + case *packet.Publish: + resp = s.onPublish(pkt) + case *packet.Ack: + resp = s.onAck(pkt) + case *packet.Subscribe: + resp = s.onSubscribe(pkt) + case *packet.UnSubscribe: + resp = s.onUnSubscribe(pkt) + case *packet.PingReq: + // For PINGREQ message, we should send back PINGRESP + mR, _ := packet.NewMessage(s.version, packet.PINGRESP) + resp, _ = mR.(*packet.PingResp) + case *packet.Disconnect: + // For DISCONNECT message, we should quit without sending Will + s.will = false + + //if s.version == packet.ProtocolV50 { + // // FIXME: CodeRefusedBadUsernameOrPassword has same id as CodeDisconnectWithWill + // if m.ReasonCode() == packet.CodeRefusedBadUsernameOrPassword { + // s.will = true + // } + // + // expireIn := time.Duration(0) + // if val, e := m.PropertyGet(packet.PropertySessionExpiryInterval); e == nil { + // expireIn = time.Duration(val.(uint32)) + // } + // + // // If the Session Expiry Interval in the CONNECT packet was zero, then it is a Protocol Error to set a non- + // // zero Session Expiry Interval in the DISCONNECT packet sent by the Client. 
If such a non-zero Session + // // Expiry Interval is received by the Server, it does not treat it as a valid DISCONNECT packet. The Server + // // uses DISCONNECT with Reason Code 0x82 (Protocol Error) as described in section 4.13. + // if s.expireIn != nil && *s.expireIn == 0 && expireIn != 0 { + // m, _ := packet.NewMessage(packet.ProtocolV50, packet.DISCONNECT) + // msg, _ := m.(*packet.Disconnect) + // msg.SetReasonCode(packet.CodeProtocolError) + // s.WriteMessage(msg, true) // nolint: errcheck + // } + //} + //return + err = errors.New("disconnect") + default: + s.log.Error("Unsupported incoming message type", + zap.String("ClientID", s.id), + zap.String("type", p.Type().Name())) + return nil + } + + if resp != nil { + s.tx.sendPacket(resp) + } + + return err +} + +func (s *Type) loadPersistence(state *persistenceTypes.SessionMessages) (err error) { for _, d := range state.OutMessages { var msg packet.Provider if msg, _, err = packet.Decode(s.version, d); err != nil { - s.log.prod.Error("Couldn't decode persisted message", zap.Error(err)) + s.log.Error("Couldn't decode persisted message", zap.Error(err)) err = ErrPersistence return } @@ -210,22 +260,22 @@ func (s *Type) loadPersistence(state *persistenceTypes.SessionMessages) (err err switch m := msg.(type) { case *packet.Publish: if m.QoS() == packet.QoS2 && m.Dup() { - s.log.prod.Error("3: QoS2 DUP") + s.log.Error("3: QoS2 DUP") } } - s.publisher.pushFront(msg) + s.tx.loadFront(msg) } for _, d := range state.UnAckMessages { var msg packet.Provider if msg, _, err = packet.Decode(s.version, d); err != nil { - s.log.prod.Error("Couldn't decode persisted message", zap.Error(err)) + s.log.Error("Couldn't decode persisted message", zap.Error(err)) err = ErrPersistence return } - s.publisher.pushFront(&unacknowledgedPublish{msg: msg}) + s.tx.loadFront(&unacknowledgedPublish{msg: msg}) } return @@ -238,15 +288,15 @@ func (s *Type) loadPersistence(state *persistenceTypes.SessionMessages) (err err // For the server, 
when this method is called, it means there's a message that // should be published to the client on the other end of this connection. So we // will call publish() to send the message. -func (s *Type) onSubscribedPublish(msg *packet.Publish) error { - _m, _ := packet.NewMessage(s.version, packet.PUBLISH) - m := _m.(*packet.Publish) +func (s *Type) onSubscribedPublish(p *packet.Publish) error { + _pkt, _ := packet.NewMessage(s.version, packet.PUBLISH) + pkt := _pkt.(*packet.Publish) // [MQTT-3.3.1-9] // [MQTT-3.3.1-3] - m.Set(msg.Topic(), msg.Payload(), msg.QoS(), false, false) // nolint: errcheck + pkt.Set(p.Topic(), p.Payload(), p.QoS(), false, false) // nolint: errcheck - s.publisher.pushBack(m) + s.tx.queuePacket(pkt) return nil } @@ -256,7 +306,7 @@ func (s *Type) publishToTopic(msg *packet.Publish) error { // [MQTT-3.3.1.3] if msg.Retain() { if err := s.messenger.Retain(msg); err != nil { - s.log.prod.Error("Error retaining message", zap.String("ClientID", s.id), zap.Error(err)) + s.log.Error("Error retaining message", zap.String("ClientID", s.id), zap.Error(err)) } // [MQTT-3.3.1-7] @@ -272,7 +322,7 @@ func (s *Type) publishToTopic(msg *packet.Publish) error { } if err := s.messenger.Publish(msg); err != nil { - s.log.prod.Error("Couldn't publish", zap.String("ClientID", s.id), zap.Error(err)) + s.log.Error("Couldn't publish", zap.String("ClientID", s.id), zap.Error(err)) } return nil @@ -298,38 +348,3 @@ func (s *Type) onReleaseOut(msg packet.Provider) { s.flowControl.release(id) } } - -// publishWorker publish messages coming from subscribed topics -func (s *Type) publishWorker() { - //defer s.Stop(message.CodeSuccess) - defer s.publisher.stopped.Done() - s.publisher.started.Done() - - value := s.publisher.waitForMessage() - for ; value != nil; value = s.publisher.waitForMessage() { - var msg packet.Provider - switch m := value.(type) { - case *packet.Publish: - if m.QoS() != packet.QoS0 { - if id, err := s.flowControl.acquire(); err == nil { - 
m.SetPacketID(id) - s.pubOut.store(m) - } else { - // if acquire id returned error session is about to exit. Queue message back and get away - s.publisher.pushBack(m) - return - } - } - msg = m - case *unacknowledgedPublish: - msg = m.msg - s.pubOut.store(msg) - } - - if _, err := s.conn.WriteMessage(msg, false); err != nil { - // Couldn't deliver message to client thus requeue it back - // Error during write means connection has been closed/timed-out etc - return - } - } -} diff --git a/connection/net.go b/connection/net.go deleted file mode 100644 index b772098..0000000 --- a/connection/net.go +++ /dev/null @@ -1,519 +0,0 @@ -package connection - -import ( - "encoding/binary" - "io" - "net" - "time" - - "errors" - "sync" - - "github.com/troian/goring" - "github.com/troian/surgemq/configuration" - "github.com/troian/surgemq/packet" - "github.com/troian/surgemq/systree" - "github.com/troian/surgemq/types" - "go.uber.org/zap" -) - -// onProcess callbacks to parent session -type onProcess struct { - // Publish call when PUBLISH message received - publish func(msg *packet.Publish) error - - // Ack call when PUBACK/PUBREC/PUBREL/PUBCOMP received - ack func(msg packet.Provider) error - - // Subscribe call when SUBSCRIBE message received - subscribe func(msg *packet.Subscribe) error - - // UnSubscribe call when UNSUBSCRIBE message received - unSubscribe func(msg *packet.UnSubscribe) (*packet.UnSubAck, error) - - // Disconnect call when connection falls into error or received DISCONNECT message - disconnect func(will bool) -} - -// netConfig of connection -type netConfig struct { - // On parent session callbacks - on onProcess - - // Conn is network connection - conn net.Conn - - // PacketsMetric interface to metric packets - packetsMetric systree.PacketsMetric - - // ID ClientID - id string - - // KeepAlive - keepAlive uint16 - - // ProtoVersion MQTT protocol version - protoVersion packet.ProtocolVersion -} - -type keepAlive struct { - period time.Duration - conn 
net.Conn - timer *time.Timer -} - -func (k *keepAlive) Read(b []byte) (int, error) { - if k.period > 0 { - if !k.timer.Stop() { - <-k.timer.C - } - k.timer.Reset(k.period) - } - return k.conn.Read(b) -} - -// netConn implementation of the connection -type netConn struct { - // Incoming data buffer. Bytes are read from the connection and put in here - in *goring.Buffer - config *netConfig - - sendTicker *time.Timer - currLock sync.Mutex - currOutBuffer net.Buffers - outBuffers chan net.Buffers - keepAlive keepAlive - - // Wait for the various goroutines to finish starting and stopping - wg struct { - routines struct { - started sync.WaitGroup - stopped sync.WaitGroup - } - conn struct { - started sync.WaitGroup - stopped sync.WaitGroup - } - } - - log struct { - prod *zap.Logger - dev *zap.Logger - } - - // Quit signal for determining when this service should end. If channel is closed, then exit - expireIn *time.Duration - done chan struct{} - onStop types.Once - will bool -} - -// newNet connection -func newNet(config *netConfig) (f *netConn, err error) { - defer func() { - if err != nil { - close(f.done) - } - }() - - f = &netConn{ - config: config, - done: make(chan struct{}), - will: true, - outBuffers: make(chan net.Buffers, 5), - sendTicker: time.NewTimer(5 * time.Millisecond), - } - - f.sendTicker.Stop() - - f.log.prod = configuration.GetProdLogger().Named("session.conn." + config.id) - f.log.dev = configuration.GetDevLogger().Named("session.conn." 
+ config.id) - - f.wg.conn.started.Add(1) - f.wg.conn.stopped.Add(1) - - // Create the incoming ring buffer - f.in, err = goring.New(goring.DefaultBufferSize) - if err != nil { - return nil, err - } - - f.keepAlive.conn = f.config.conn - - if f.config.keepAlive > 0 { - f.keepAlive.period = time.Second * time.Duration(f.config.keepAlive) - f.keepAlive.period = f.keepAlive.period + (f.keepAlive.period / 2) - } - - return f, nil -} - -// Start serving messages over this connection -func (s *netConn) start() { - defer s.wg.conn.started.Done() - - if s.keepAlive.period > 0 { - s.wg.routines.stopped.Add(4) - } else { - s.wg.routines.stopped.Add(3) - } - - // these routines must start in specified order - // and next proceed next one only when previous finished - s.wg.routines.started.Add(1) - go s.sender() - s.wg.routines.started.Wait() - - s.wg.routines.started.Add(1) - go s.processIncoming() - s.wg.routines.started.Wait() - - if s.keepAlive.period > 0 { - s.wg.routines.started.Add(1) - go s.readTimeOutWorker() - s.wg.routines.started.Wait() - } - - s.wg.routines.started.Add(1) - go s.receiver() - s.wg.routines.started.Wait() -} - -// Stop connection. 
Effective is only first invoke -func (s *netConn) stop(reason *packet.ReasonCode) { - s.onStop.Do(func() { - // wait if stop invoked before start finished - s.wg.conn.started.Wait() - - // Close quit channel, effectively telling all the goroutines it's time to quit - if !s.isDone() { - close(s.done) - } - - if s.config.protoVersion == packet.ProtocolV50 && reason != nil { - m, _ := packet.NewMessage(packet.ProtocolV50, packet.DISCONNECT) - msg, _ := m.(*packet.Disconnect) - msg.SetReasonCode(*reason) - - // TODO: send it over - } - - if err := s.config.conn.Close(); err != nil { - s.log.prod.Error("close connection", zap.String("ClientID", s.config.id), zap.Error(err)) - } - - if err := s.in.Close(); err != nil { - s.log.prod.Error("close input buffer error", zap.String("ClientID", s.config.id), zap.Error(err)) - } - - s.sendTicker.Stop() - - // Wait for all the connection goroutines are finished - s.wg.routines.stopped.Wait() - - s.wg.conn.stopped.Done() - - s.config.on.disconnect(s.will) - }) -} - -// isDone -func (s *netConn) isDone() bool { - select { - case <-s.done: - return true - default: - return false - } -} - -// onRoutineReturn -func (s *netConn) onRoutineReturn() { - s.wg.routines.stopped.Done() - s.stop(nil) -} - -// reads message income messages -func (s *netConn) processIncoming() { - defer s.onRoutineReturn() - - s.wg.routines.started.Done() - - for !s.isDone() || s.in.Len() > 0 { - // 1. firstly lets peak message type and total length - mType, total, err := s.peekMessageSize() - if err != nil { - if err != io.EOF { - s.log.prod.Error("Error peeking next message size", zap.String("ClientID", s.config.id), zap.Error(err)) - } - return - } - - if ok, e := mType.Valid(s.config.protoVersion); !ok { - s.log.prod.Error("Invalid message type received", zap.String("ClientID", s.config.id), zap.Error(e)) - return - } - - var msg packet.Provider - - // 2. 
Now read message including fixed header - msg, _, err = s.readMessage(total) - if err != nil { - if err != io.EOF { - s.log.prod.Error("Couldn't read message", - zap.String("ClientID", s.config.id), - zap.Error(err), - zap.Int("total len", total)) - } - return - } - - s.config.packetsMetric.Received(msg.Type()) - - // 3. Put message for further processing - var resp packet.Provider - switch m := msg.(type) { - case *packet.Publish: - err = s.config.on.publish(m) - case *packet.Ack: - err = s.config.on.ack(msg) - case *packet.Subscribe: - err = s.config.on.subscribe(m) - case *packet.UnSubscribe: - resp, _ = s.config.on.unSubscribe(m) - _, err = s.WriteMessage(resp, false) - case *packet.PingReq: - // For PINGREQ message, we should send back PINGRESP - mR, _ := packet.NewMessage(s.config.protoVersion, packet.PINGRESP) - resp, _ := mR.(*packet.PingResp) - _, err = s.WriteMessage(resp, false) - case *packet.Disconnect: - // For DISCONNECT message, we should quit without sending Will - s.will = false - - if s.config.protoVersion == packet.ProtocolV50 { - // FIXME: CodeRefusedBadUsernameOrPassword has same id as CodeDisconnectWithWill - if m.ReasonCode() == packet.CodeRefusedBadUsernameOrPassword { - s.will = true - } - - expireIn := time.Duration(0) - if val, e := m.PropertyGet(packet.PropertySessionExpiryInterval); e == nil { - expireIn = time.Duration(val.(uint32)) - } - - // If the Session Expiry Interval in the CONNECT packet was zero, then it is a Protocol Error to set a non- - // zero Session Expiry Interval in the DISCONNECT packet sent by the Client. If such a non-zero Session - // Expiry Interval is received by the Server, it does not treat it as a valid DISCONNECT packet. The Server - // uses DISCONNECT with Reason Code 0x82 (Protocol Error) as described in section 4.13. 
- if s.expireIn != nil && *s.expireIn == 0 && expireIn != 0 { - m, _ := packet.NewMessage(packet.ProtocolV50, packet.DISCONNECT) - msg, _ := m.(*packet.Disconnect) - msg.SetReasonCode(packet.CodeProtocolError) - s.WriteMessage(msg, true) // nolint: errcheck - } - } - return - default: - s.log.prod.Error("Unsupported incoming message type", - zap.String("ClientID", s.config.id), - zap.String("type", msg.Type().Name())) - return - } - - if err != nil { - return - } - } -} - -func (s *netConn) readTimeOutWorker() { - defer s.onRoutineReturn() - - s.keepAlive.timer = time.NewTimer(s.keepAlive.period) - s.wg.routines.started.Done() - - select { - case <-s.keepAlive.timer.C: - s.log.prod.Error("Keep alive timed out") - return - case <-s.done: - s.keepAlive.timer.Stop() - return - } -} - -// receiver reads data from the network, and writes the data into the incoming buffer -func (s *netConn) receiver() { - defer s.onRoutineReturn() - - s.wg.routines.started.Done() - - s.in.ReadFrom(&s.keepAlive) // nolint: errcheck -} - -// sender writes data from the outgoing buffer to the network -func (s *netConn) sender() { - defer s.onRoutineReturn() - s.wg.routines.started.Done() - - for { - bufs := net.Buffers{} - select { - case <-s.sendTicker.C: - s.currLock.Lock() - s.outBuffers <- s.currOutBuffer - s.currOutBuffer = net.Buffers{} - s.currLock.Unlock() - case buf, ok := <-s.outBuffers: - s.sendTicker.Stop() - if !ok { - return - } - bufs = buf - case <-s.done: - s.sendTicker.Stop() - close(s.outBuffers) - return - } - - if len(bufs) > 0 { - if _, err := bufs.WriteTo(s.config.conn); err != nil { - return - } - } - } -} - -// peekMessageSize reads, but not commits, enough bytes to determine the size of -// the next message and returns the type and size. 
-func (s *netConn) peekMessageSize() (packet.Type, int, error) { - var b []byte - var err error - cnt := 2 - - if s.in == nil { - err = goring.ErrNotReady - return 0, 0, err - } - - // Let's read enough bytes to get the message header (msg type, remaining length) - for { - // If we have read 5 bytes and still not done, then there's a problem. - if cnt > 5 { - return 0, 0, errors.New("sendrecv/peekMessageSize: 4th byte of remaining length has continuation bit set") - } - - // Peek cnt bytes from the input buffer. - b, err = s.in.ReadWait(cnt) - if err != nil { - return 0, 0, err - } - - // If not enough bytes are returned, then continue until there's enough. - if len(b) < cnt { - continue - } - - // If we got enough bytes, then check the last byte to see if the continuation - // bit is set. If so, increment cnt and continue peeking - if b[cnt-1] >= 0x80 { - cnt++ - } else { - break - } - } - - // Get the remaining length of the message - remLen, m := binary.Uvarint(b[1:]) - - // Total message length is remlen + 1 (msg type) + m (remlen bytes) - total := int(remLen) + 1 + m - - mType := packet.Type(b[0] >> 4) - - return mType, total, err -} - -// readMessage reads and copies a message from the buffer. The buffer bytes are -// committed as a result of the read. 
-func (s *netConn) readMessage(total int) (packet.Provider, int, error) { - defer func() { - if int64(len(s.in.ExternalBuf)) > s.in.Size() { - s.in.ExternalBuf = make([]byte, s.in.Size()) - } - }() - - var err error - var n int - var msg packet.Provider - - if s.in == nil { - err = goring.ErrNotReady - return nil, 0, err - } - - if len(s.in.ExternalBuf) < total { - s.in.ExternalBuf = make([]byte, total) - } - - // Read until we get total bytes - l := 0 - toRead := total - for l < total { - n, err = s.in.Read(s.in.ExternalBuf[l : l+toRead]) - l += n - toRead -= n - if err != nil { - return nil, 0, err - } - } - - var dTotal int - if msg, dTotal, err = packet.Decode(s.config.protoVersion, s.in.ExternalBuf[:total]); err == nil && total != dTotal { - s.log.prod.Error("Incoming and outgoing length does not match", - zap.Int("in", total), - zap.Int("out", dTotal)) - return nil, 0, goring.ErrNotReady - } - - return msg, n, err -} - -// WriteMessage writes a message to the outgoing buffer -func (s *netConn) WriteMessage(msg packet.Provider, lastMessage bool) (int, error) { - if s.isDone() { - return 0, goring.ErrNotReady - } - - if lastMessage { - close(s.done) - } - - var total int - - expectedSize, err := msg.Size() - if err != nil { - return 0, err - } - - buf := make([]byte, expectedSize) - total, err = msg.Encode(buf) - if err != nil { - return 0, err - } - - s.currLock.Lock() - s.currOutBuffer = append(s.currOutBuffer, buf) - if len(s.currOutBuffer) == 1 { - s.sendTicker.Reset(1 * time.Millisecond) - } - - if len(s.currOutBuffer) == 10 { - s.outBuffers <- s.currOutBuffer - s.currOutBuffer = net.Buffers{} - } - s.currLock.Unlock() - - return total, err -} diff --git a/connection/netCallbacks.go b/connection/netCallbacks.go index 5d4774d..bd4ae1b 100644 --- a/connection/netCallbacks.go +++ b/connection/netCallbacks.go @@ -30,23 +30,23 @@ func (s *Type) getState() *persistenceTypes.SessionMessages { unAckMessages := [][]byte{} var next *list.Element - for elem := 
s.publisher.messages.Front(); elem != nil; elem = next { + for elem := s.tx.messages.Front(); elem != nil; elem = next { next = elem.Next() - switch m := s.publisher.messages.Remove(elem).(type) { + switch m := s.tx.messages.Remove(elem).(type) { case *packet.Publish: qos := m.QoS() if qos != packet.QoS0 || (s.offlineQoS0 && qos == packet.QoS0) { // make sure message has some IDType to prevent encode error m.SetPacketID(0) if buf, err := encodeMessage(m); err != nil { - s.log.prod.Error("Couldn't encode message for persistence", zap.Error(err)) + s.log.Error("Couldn't encode message for persistence", zap.Error(err)) } else { outMessages = append(outMessages, buf) } } case *unacknowledgedPublish: if buf, err := encodeMessage(m.msg); err != nil { - s.log.prod.Error("Couldn't encode message for persistence", zap.Error(err)) + s.log.Error("Couldn't encode message for persistence", zap.Error(err)) } else { unAckMessages = append(unAckMessages, buf) } @@ -62,7 +62,7 @@ func (s *Type) getState() *persistenceTypes.SessionMessages { } if buf, err := encodeMessage(m); err != nil { - s.log.prod.Error("Couldn't encode message for persistence", zap.Error(err)) + s.log.Error("Couldn't encode message for persistence", zap.Error(err)) } else { unAckMessages = append(unAckMessages, buf) } @@ -81,17 +81,18 @@ func (s *Type) onConnectionClose(will bool) { ExpireAt: nil, } - s.started.Wait() + //s.started.Wait() - s.subscriber.Offline(s.clean) + if err := s.conn.Close(); err != nil { + s.log.Error("close connection", zap.String("ClientID", s.id), zap.Error(err)) + } close(s.quit) - s.publisher.cond.L.Lock() - s.publisher.cond.Broadcast() - s.publisher.cond.L.Unlock() - // Wait writer to finish it's job - s.publisher.stopped.Wait() + s.subscriber.Offline(s.clean) + + s.tx.shutdown() + s.rx.shutdown() if !s.clean { params.State = s.getState() @@ -115,12 +116,12 @@ func (s *Type) onConnectionClose(will bool) { // On QoS == 0, we should just take the next step, no ack required // On QoS 
== 1, send back PUBACK, then take the next step // On QoS == 2, we need to put it in the ack queue, send back PUBREC -func (s *Type) onPublish(msg *packet.Publish) error { +func (s *Type) onPublish(msg *packet.Publish) packet.Provider { // check for topic access - var err error reason := packet.CodeSuccess + var resp packet.Provider // This case is for V5.0 actually as ack messages may return status. // To deal with V3.1.1 two ways left: // - ignore the message but send acks @@ -131,35 +132,34 @@ func (s *Type) onPublish(msg *packet.Publish) error { switch msg.QoS() { case packet.QoS2: - m, _ := packet.NewMessage(s.version, packet.PUBREC) - resp, _ := m.(*packet.Ack) + resp, _ = packet.NewMessage(s.version, packet.PUBREC) + r, _ := resp.(*packet.Ack) id, _ := msg.ID() - resp.SetPacketID(id) - resp.SetReason(reason) + r.SetPacketID(id) + r.SetReason(reason) - _, err = s.conn.WriteMessage(resp, false) // [MQTT-4.3.3-9] // store incoming QoS 2 message before sending PUBREC as theoretically PUBREL // might come before store in case message store done after write PUBREC - if err == nil && reason < packet.CodeUnspecifiedError { + if reason < packet.CodeUnspecifiedError { s.pubIn.store(msg) } case packet.QoS1: - m, _ := packet.NewMessage(s.version, packet.PUBACK) - resp, _ := m.(*packet.Ack) + resp, _ = packet.NewMessage(s.version, packet.PUBACK) + r, _ := resp.(*packet.Ack) id, _ := msg.ID() reason := packet.CodeSuccess - resp.SetPacketID(id) - resp.SetReason(reason) - _, err = s.conn.WriteMessage(resp, false) + r.SetPacketID(id) + r.SetReason(reason) + //_, err = s.conn.WriteMessage(resp, false) // [MQTT-4.3.2-4] - if err == nil && reason < packet.CodeUnspecifiedError { - if err = s.publishToTopic(msg); err != nil { - s.log.prod.Error("Couldn't publish message", + if reason < packet.CodeUnspecifiedError { + if err := s.publishToTopic(msg); err != nil { + s.log.Error("Couldn't publish message", zap.String("ClientID", s.id), zap.Uint8("QoS", uint8(msg.QoS())), 
zap.Error(err)) @@ -167,16 +167,20 @@ func (s *Type) onPublish(msg *packet.Publish) error { } case packet.QoS0: // QoS 0 // [MQTT-4.3.1] - err = s.publishToTopic(msg) + if err := s.publishToTopic(msg); err != nil { + s.log.Error("Couldn't publish message", + zap.String("ClientID", s.id), + zap.Uint8("QoS", uint8(msg.QoS())), + zap.Error(err)) + } } - return err + return resp } // onAck handle ack acknowledgment received from remote -func (s *Type) onAck(msg packet.Provider) error { - var err error - +func (s *Type) onAck(msg packet.Provider) packet.Provider { + var resp packet.Provider switch mIn := msg.(type) { case *packet.Ack: switch msg.Type() { @@ -197,36 +201,31 @@ func (s *Type) onAck(msg packet.Provider) error { } if !discard { - m, _ := packet.NewMessage(s.version, packet.PUBREL) - resp, _ := m.(*packet.Ack) + resp, _ = packet.NewMessage(s.version, packet.PUBREL) + r, _ := resp.(*packet.Ack) id, _ := msg.ID() - resp.SetPacketID(id) + r.SetPacketID(id) // 2. Put PUBREL into ack queue // Do it before writing into network as theoretically response may come // faster than put into queue s.pubOut.store(resp) - - // 2. 
Try send PUBREL reply - _, err = s.conn.WriteMessage(resp, false) } case packet.PUBREL: // Remote has released PUBLISH - m, _ := packet.NewMessage(s.version, packet.PUBCOMP) - resp, _ := m.(*packet.Ack) + resp, _ = packet.NewMessage(s.version, packet.PUBCOMP) + r, _ := resp.(*packet.Ack) id, _ := msg.ID() - resp.SetPacketID(id) + r.SetPacketID(id) s.pubIn.release(msg) - - _, err = s.conn.WriteMessage(resp, false) case packet.PUBCOMP: // PUBREL message has been acknowledged, release from queue s.pubOut.release(msg) default: - s.log.prod.Error("Unsupported ack message type", + s.log.Error("Unsupported ack message type", zap.String("ClientID", s.id), zap.String("type", msg.Type().Name())) } @@ -234,10 +233,10 @@ func (s *Type) onAck(msg packet.Provider) error { //return err } - return err + return resp } -func (s *Type) onSubscribe(msg *packet.Subscribe) error { +func (s *Type) onSubscribe(msg *packet.Subscribe) packet.Provider { m, _ := packet.NewMessage(s.version, packet.SUBACK) resp, _ := m.(*packet.SubAck) @@ -249,9 +248,6 @@ func (s *Type) onSubscribe(msg *packet.Subscribe) error { iter := msg.Topics().Iterator() for kv, ok := iter(); ok; kv, ok = iter() { - //for _, t := range topics { - // Let topic manager know we want to listen to given topic - //qos := msg.TopicQos(t) t := kv.Key.(string) ops := kv.Value.(packet.SubscriptionOptions) @@ -286,31 +282,14 @@ func (s *Type) onSubscribe(msg *packet.Subscribe) error { retCodes = append(retCodes, reason) - s.log.dev.Debug("Subscribing", + s.log.Debug("Subscribing", zap.String("ClientID", s.id), zap.String("topic", t), zap.Uint8("result_code", uint8(reason))) } if err := resp.AddReturnCodes(retCodes); err != nil { - return err - } - - if _, err := s.conn.WriteMessage(resp, false); err != nil { - s.log.prod.Error("Couldn't send SUBACK. 
Proceed to unsubscribe", - zap.String("ClientID", s.id), - zap.Error(err)) - - iter = msg.Topics().Iterator() - for kv, ok := iter(); ok; kv, ok = iter() { - t := kv.Key.(string) - - if err = s.subscriber.UnSubscribe(t); err != nil { - s.log.prod.Error("Couldn't unsubscribe from topic", zap.Error(err)) - } - } - - return err + return nil } // Now put retained messages into publish queue @@ -321,13 +300,13 @@ func (s *Type) onSubscribe(msg *packet.Subscribe) error { // [MQTT-3.3.1-8] msg.Set(rm.Topic(), rm.Payload(), rm.QoS(), true, false) // nolint: errcheck - s.publisher.pushBack(m) + s.tx.sendPacket(msg) } - return nil + return resp } -func (s *Type) onUnSubscribe(msg *packet.UnSubscribe) (*packet.UnSubAck, error) { +func (s *Type) onUnSubscribe(msg *packet.UnSubscribe) packet.Provider { iter := msg.Topics().Iterator() var retCodes []packet.ReasonCode @@ -340,7 +319,7 @@ func (s *Type) onUnSubscribe(msg *packet.UnSubscribe) (*packet.UnSubAck, error) if authorized { if err := s.subscriber.UnSubscribe(t); err != nil { - s.log.prod.Error("Couldn't unsubscribe from topic", zap.Error(err)) + s.log.Error("Couldn't unsubscribe from topic", zap.Error(err)) } else { reason = packet.CodeNoSubscriptionExisted } @@ -358,5 +337,5 @@ func (s *Type) onUnSubscribe(msg *packet.UnSubscribe) (*packet.UnSubAck, error) resp.SetPacketID(id) resp.AddReturnCodes(retCodes) // nolint: errcheck - return resp, nil + return resp } diff --git a/connection/publisher.go b/connection/publisher.go deleted file mode 100644 index 6e3c51d..0000000 --- a/connection/publisher.go +++ /dev/null @@ -1,68 +0,0 @@ -package connection - -import ( - "sync" - - "container/list" -) - -type publisher struct { - // signal publisher goroutine to exit on channel close - quit chan struct{} - // make sure writer started before exiting from Start() - started sync.WaitGroup - // make sure writer has finished before any finalization - stopped sync.WaitGroup - messages *list.List - cond *sync.Cond -} - -func (p 
*publisher) isDone() bool { - select { - case <-p.quit: - return true - default: - } - - return false -} - -func newPublisher(quit chan struct{}) *publisher { - return &publisher{ - messages: list.New(), - cond: sync.NewCond(new(sync.Mutex)), - quit: quit, - } -} - -func (p *publisher) waitForMessage() interface{} { - if p.isDone() { - return nil - } - - defer p.cond.L.Unlock() - p.cond.L.Lock() - - for p.messages.Len() == 0 { - p.cond.Wait() - if p.isDone() { - return nil - } - } - - return p.messages.Remove(p.messages.Front()) -} - -func (p *publisher) pushFront(value interface{}) { - p.cond.L.Lock() - p.messages.PushFront(value) - p.cond.L.Unlock() - p.cond.Signal() -} - -func (p *publisher) pushBack(value interface{}) { - p.cond.L.Lock() - p.messages.PushBack(value) - p.cond.L.Unlock() - p.cond.Signal() -} diff --git a/connection/receiver.go b/connection/receiver.go new file mode 100644 index 0000000..38f2f3d --- /dev/null +++ b/connection/receiver.go @@ -0,0 +1,138 @@ +package connection + +import ( + "bufio" + "encoding/binary" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/troian/surgemq/packet" +) + +type receiverConfig struct { + conn net.Conn + quit chan struct{} + keepAlive time.Duration // nolint: structcheck + onPacket func(packet.Provider) error + onDisconnect func(bool) + will *bool + version packet.ProtocolVersion +} + +type receiver struct { + receiverConfig + wg sync.WaitGroup + running uint32 + recv []byte + remainingRecv int +} + +func newReceiver(config *receiverConfig) *receiver { + r := &receiver{ + receiverConfig: *config, + } + return r +} + +func (r *receiver) shutdown() { + r.wg.Wait() +} + +func (r *receiver) run() { + if atomic.CompareAndSwapUint32(&r.running, 0, 1) { + r.wg.Wait() + r.wg.Add(1) + if r.keepAlive > 0 { + r.conn.SetReadDeadline(time.Now().Add(r.keepAlive)) // nolint: errcheck + } else { + r.conn.SetReadDeadline(time.Time{}) // nolint: errcheck + } + + go r.routine() + } +} + +func (r *receiver) 
routine() { + var err error + defer func() { + r.wg.Done() + r.onDisconnect(*r.will) + }() + + buf := bufio.NewReader(r.conn) + + for atomic.LoadUint32(&r.running) == 1 { + var pkt packet.Provider + if pkt, err = r.readPacket(buf); err != nil || pkt == nil { + atomic.StoreUint32(&r.running, 0) + } else { + if r.keepAlive > 0 { + r.conn.SetReadDeadline(time.Now().Add(r.keepAlive)) // nolint: errcheck + } + if err = r.onPacket(pkt); err != nil { + atomic.StoreUint32(&r.running, 0) + } + } + } +} + +func (r *receiver) readPacket(buf *bufio.Reader) (packet.Provider, error) { + var err error + + if len(r.recv) == 0 { + var header []byte + peekCount := 2 + // Let's read enough bytes to get the message header (msg type, remaining length) + for { + // If we have read 5 bytes and still not done, then there's a problem. + if peekCount > 5 { + return nil, errors.New("sendrecv/peekMessageSize: 4th byte of remaining length has continuation bit set") + } + + header, err = buf.Peek(peekCount) + if err != nil { + return nil, err + } + + // If not enough bytes are returned, then continue until there's enough. + if len(header) < peekCount { + continue + } + + // If we got enough bytes, then check the last byte to see if the continuation + // bit is set. 
If so, increment cnt and continue peeking + if header[peekCount-1] >= 0x80 { + peekCount++ + } else { + break + } + } + + // Get the remaining length of the message + remLen, m := binary.Uvarint(header[1:]) + // Total message length is remlen + 1 (msg type) + m (remlen bytes) + r.remainingRecv = int(remLen) + 1 + m + r.recv = make([]byte, r.remainingRecv) + } + + offset := len(r.recv) - r.remainingRecv + + for offset != r.remainingRecv { + var n int + n, err = buf.Read(r.recv[offset:]) + offset += n + if err != nil { + return nil, err + } + } + + var pkt packet.Provider + pkt, _, err = packet.Decode(r.version, r.recv) + r.recv = []byte{} + r.remainingRecv = 0 + + return pkt, err +} diff --git a/connection/transmitter.go b/connection/transmitter.go new file mode 100644 index 0000000..bf52094 --- /dev/null +++ b/connection/transmitter.go @@ -0,0 +1,214 @@ +package connection + +import ( + "container/list" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/troian/surgemq/packet" + "go.uber.org/zap" +) + +type transmitterConfig struct { + id string + quit chan struct{} + flowControl *packetsFlowControl + pubIn *ackQueue + pubOut *ackQueue + log *zap.Logger + conn net.Conn + onDisconnect func(bool) +} + +type transmitter struct { + transmitterConfig + lock sync.Mutex + messages *list.List + available chan int + running uint32 + timer *time.Timer + wg sync.WaitGroup +} + +func newTransmitter(config *transmitterConfig) *transmitter { + p := &transmitter{ + transmitterConfig: *config, + messages: list.New(), + available: make(chan int, 1), + timer: time.NewTimer(1 * time.Second), + } + + p.timer.Stop() + + return p +} + +func (p *transmitter) shutdown() { + p.timer.Stop() + p.wg.Wait() + select { + case <-p.available: + default: + close(p.available) + } +} + +func (p *transmitter) loadFront(value interface{}) { + p.lock.Lock() + p.messages.PushFront(value) + p.lock.Unlock() + p.signalAvailable() +} + +//func (p *transmitter) loadBack(value interface{}) { 
+// p.lock.Lock() +// p.messages.PushBack(value) +// p.lock.Unlock() +// p.signalAvailable() +//} + +func (p *transmitter) sendPacket(pkt packet.Provider) { + p.lock.Lock() + p.messages.PushFront(pkt) + p.lock.Unlock() + p.signalAvailable() + p.run() +} + +func (p *transmitter) queuePacket(pkt packet.Provider) { + p.lock.Lock() + p.messages.PushBack(pkt) + p.lock.Unlock() + p.signalAvailable() + p.run() +} + +func (p *transmitter) signalAvailable() { + select { + case p.available <- 1: + default: + } +} + +func (p *transmitter) run() { + if atomic.CompareAndSwapUint32(&p.running, 0, 1) { + p.wg.Wait() + p.wg.Add(1) + go p.routine() + } +} + +func (p *transmitter) flushBuffers(buf net.Buffers) error { + _, e := buf.WriteTo(p.conn) + buf = net.Buffers{} + // todo metrics + return e +} + +func (p *transmitter) routine() { + var err error + + defer func() { + p.wg.Done() + + if err != nil { + p.onDisconnect(true) + } + }() + + sendBuffers := net.Buffers{} + for atomic.LoadUint32(&p.running) == 1 { + select { + case <-p.timer.C: + if err = p.flushBuffers(sendBuffers); err != nil { + atomic.StoreUint32(&p.running, 0) + return + } + + sendBuffers = net.Buffers{} + p.lock.Lock() + l := p.messages.Len() + p.lock.Unlock() + + if l != 0 { + p.signalAvailable() + } else { + atomic.StoreUint32(&p.running, 0) + return + } + + case <-p.available: + p.lock.Lock() + + var elem *list.Element + + if elem = p.messages.Front(); elem == nil { + p.lock.Unlock() + atomic.StoreUint32(&p.running, 0) + break + } + + value := p.messages.Remove(p.messages.Front()) + + if p.messages.Len() != 0 { + p.signalAvailable() + } + p.lock.Unlock() + + var msg packet.Provider + switch m := value.(type) { + case *packet.Publish: + if m.QoS() != packet.QoS0 { + var id packet.IDType + if id, err = p.flowControl.acquire(); err == nil { + m.SetPacketID(id) + p.pubOut.store(m) + } else { + // if acquire id returned error session is about to exit. 
Queue message back and get away + p.lock.Lock() + p.messages.PushBack(m) + p.lock.Unlock() + err = errors.New("exit") + atomic.StoreUint32(&p.running, 0) + return + } + } + msg = m + case *unacknowledgedPublish: + msg = m.msg + p.pubOut.store(msg) + default: + msg = m.(packet.Provider) + } + + var sz int + if sz, err = msg.Size(); err != nil { + p.log.Error("Couldn't calculate message size", zap.String("ClientID", p.id), zap.Error(err)) + return + } + + buf := make([]byte, sz) + _, err = msg.Encode(buf) + sendBuffers = append(sendBuffers, buf) + + l := len(sendBuffers) + if l == 1 { + p.timer.Reset(1 * time.Millisecond) + } else if l == 5 { + p.timer.Stop() + if err = p.flushBuffers(sendBuffers); err != nil { + atomic.StoreUint32(&p.running, 0) + } + + sendBuffers = net.Buffers{} + } + case <-p.quit: + err = errors.New("exit") + atomic.StoreUint32(&p.running, 0) + return + } + } +} diff --git a/transport/connTCP.go b/transport/connTCP.go index 2ccfd50..8c55d05 100644 --- a/transport/connTCP.go +++ b/transport/connTCP.go @@ -1,7 +1,9 @@ package transport import ( + "errors" "net" + "os" "time" "github.com/troian/surgemq/systree" @@ -62,3 +64,11 @@ func (c *connTCP) SetReadDeadline(t time.Time) error { func (c *connTCP) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + +func (c *connTCP) File() (*os.File, error) { + if conn, ok := c.conn.(*net.TCPConn); ok { + return conn.File() + } + + return nil, errors.New("not implemented") +} diff --git a/transport/tcp.go b/transport/tcp.go index feb90ec..c88c24c 100644 --- a/transport/tcp.go +++ b/transport/tcp.go @@ -12,18 +12,10 @@ import ( // ConfigTCP configuration of tcp transport type ConfigTCP struct { - // Scheme - Scheme string - - // Host - Host string - - // CertFile - CertFile string - - // KeyFile - KeyFile string - + Scheme string + Host string + CertFile string + KeyFile string transport *Config }