From bda970e336ad0ae87a59fb77c03d505705c2b1a1 Mon Sep 17 00:00:00 2001 From: Andrew Kokhanovskyi Date: Thu, 17 Sep 2020 22:24:40 +0300 Subject: [PATCH] Propagate MQTT connect error (fixes #356) Signed-off-by: Andrew Kokhanovskyi --- client.go | 2 +- net.go | 94 ++++++++++++++++++++++++++++++++----------------------- 2 files changed, 56 insertions(+), 40 deletions(-) diff --git a/client.go b/client.go index b7795205..b69bfdcf 100644 --- a/client.go +++ b/client.go @@ -365,7 +365,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) { DEBUG.Println(CLI, "socket connected to broker") // Now we send the perform the MQTT connection handshake - rc, sessionPresent = ConnectMQTT(conn, cm, protocolVersion) + rc, sessionPresent, err = connectMQTT(conn, cm, protocolVersion) if rc == packets.Accepted { break // successfully connected } diff --git a/net.go b/net.go index 754b1d47..e87e67ad 100644 --- a/net.go +++ b/net.go @@ -15,6 +15,8 @@ package mqtt import ( + "errors" + "io" "net" "reflect" "strings" @@ -25,11 +27,18 @@ import ( const closedNetConnErrorText = "use of closed network connection" // error string for closed conn (https://golang.org/src/net/error_test.go) -// ConnectMQTT takes a connected net.Conn and performs the initial MQTT handshake. Paramaters are: +// ConnectMQTT takes a connected net.Conn and performs the initial MQTT handshake. Parameters are: // conn - Connected net.Conn -// cm - Connect Packet with everything other than the protocolname/version populated (historical reasons) +// cm - Connect Packet with everything other than the protocol name/version populated (historical reasons) // protocolVersion - The protocol version to attempt to connect with +// +// Note that, for backward compatibility, ConnectMQTT() suppresses the actual connection error (compare to connectMQTT()). func ConnectMQTT(conn net.Conn, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool) { + rc, sessionPresent, _ := connectMQTT(conn, cm, protocolVersion) + return rc, sessionPresent +} + +func connectMQTT(conn io.ReadWriter, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool, error) { switch protocolVersion { case 3: DEBUG.Println(CLI, "Using MQTT 3.1 protocol") @@ -48,43 +57,46 @@ func ConnectMQTT(conn net.Conn, cm *packets.ConnectPacket, protocolVersion uint) cm.ProtocolName = "MQTT" cm.ProtocolVersion = 4 } + if err := cm.Write(conn); err != nil { ERROR.Println(CLI, err) + return packets.ErrNetworkError, false, err } - rc, sessionPresent := verifyCONNACK(conn) - return rc, sessionPresent + rc, sessionPresent, err := verifyCONNACK(conn) + return rc, sessionPresent, err } // This function is only used for receiving a connack // when the connection is first started. // This prevents receiving incoming data while resume // is in progress if clean session is false. -func verifyCONNACK(conn net.Conn) (byte, bool) { +func verifyCONNACK(conn io.Reader) (byte, bool, error) { DEBUG.Println(NET, "connect started") ca, err := packets.ReadPacket(conn) if err != nil { ERROR.Println(NET, "connect got error", err) - return packets.ErrNetworkError, false + return packets.ErrNetworkError, false, err } + if ca == nil { ERROR.Println(NET, "received nil packet") - return packets.ErrNetworkError, false + return packets.ErrNetworkError, false, errors.New("nil CONNACK packet") } msg, ok := ca.(*packets.ConnackPacket) if !ok { ERROR.Println(NET, "received msg that was not CONNACK") - return packets.ErrNetworkError, false + return packets.ErrNetworkError, false, errors.New("non-CONNACK first packet received") } DEBUG.Println(NET, "received connack") - return msg.ReturnCode, msg.SessionPresent + return msg.ReturnCode, msg.SessionPresent, nil } // inbound encapuslates the output from startIncoming. -// err - If != nil then an error has occured +// err - If != nil then an error has occurred // cp - A control packet received over the network link type inbound struct { err error @@ -94,12 +106,13 @@ type inbound struct { // startIncoming initiates a goroutine that reads incoming messages off the wire and sends them to the channel (returned). // If there are any issues with the network connection then the returned cahnnel will be closed and the goroutine will exit // (so closing the connection will terminate the goroutine) -func startIncoming(conn net.Conn) <-chan inbound { +func startIncoming(conn io.Reader) <-chan inbound { var err error var cp packets.ControlPacket ibound := make(chan inbound) DEBUG.Println(NET, "incoming started") + go func() { for { if cp, err = packets.ReadPacket(conn); err != nil { @@ -117,35 +130,36 @@ func startIncoming(conn net.Conn) <-chan inbound { ibound <- inbound{cp: cp} } }() + return ibound } -// incommingComms encapuslates the possible output of the incommingComms routine. If err != nil then an error has occured and +// incomingComms encapuslates the possible output of the incomingComms routine. If err != nil then an error has occurred and // the routine will have terminated; otherwise one of the other members should be non-nil -type incommingComms struct { +type incomingComms struct { err error // If non-nil then there has been an error (ignore everything else) outbound *PacketAndToken // Packet (with token) than needs to be sent out (e.g. an acknowledgement) incommingPub *packets.PublishPacket // A new publish has been received; this will need to be passed on to our user } -// startIncommingComms initiates incomming communications; this includes starting a goroutine to process incomming +// startIncomingComms initiates incomming communications; this includes starting a goroutine to process incomming // messages. // Accepts a channel of inbound messages from the store (persistanced messages); note this must be closed as soon as the // everything in the store has been sent. // Returns a channel that will be passed any received packets; this will be closed on a network error (and inboundFromStore closed) -func startIncommingComms(conn net.Conn, +func startIncomingComms(conn io.Reader, c commsFns, inboundFromStore <-chan packets.ControlPacket, -) <-chan incommingComms { +) <-chan incomingComms { ibound := startIncoming(conn) // Start goroutine that reads from network connection - output := make(chan incommingComms) + output := make(chan incomingComms) - DEBUG.Println(NET, "startIncommingComms started") + DEBUG.Println(NET, "startIncomingComms started") go func() { for { if inboundFromStore == nil && ibound == nil { close(output) - DEBUG.Println(NET, "startIncommingComms goroutine complete") + DEBUG.Println(NET, "startIncomingComms goroutine complete") return // As soon as ibound is closed we can exit (should have already processed an error) } DEBUG.Println(NET, "logic waiting for msg on ibound") @@ -155,22 +169,22 @@ func startIncommingComms(conn net.Conn, select { case msg, ok = <-inboundFromStore: if !ok { - DEBUG.Println(NET, "startIncommingComms: inboundFromStore complete") + DEBUG.Println(NET, "startIncomingComms: inboundFromStore complete") inboundFromStore = nil // should happen quickly as this is only for persisted messages continue } - DEBUG.Println(NET, "startIncommingComms: got msg from store") + DEBUG.Println(NET, "startIncomingComms: got msg from store") case ibMsg, ok := <-ibound: if !ok { - DEBUG.Println(NET, "startIncommingComms: ibound complete") + DEBUG.Println(NET, "startIncomingComms: ibound complete") ibound = nil continue } - DEBUG.Println(NET, "startIncommingComms: got msg on ibound") + DEBUG.Println(NET, "startIncomingComms: got msg on ibound") // If the inbound comms routine encounters any issues it will send us an error. if ibMsg.err != nil { - output <- incommingComms{err: ibMsg.err} - continue // Usually the channel will be closed immediatly after sending an error but safer that we do not assume this + output <- incomingComms{err: ibMsg.err} + continue // Usually the channel will be closed immediately after sending an error but safer that we do not assume this } msg = ibMsg.cp @@ -185,13 +199,14 @@ func startIncommingComms(conn net.Conn, case *packets.SubackPacket: DEBUG.Println(NET, "received suback, id:", m.MessageID) token := c.getToken(m.MessageID) - switch t := token.(type) { - case *SubscribeToken: + + if t, ok := token.(*SubscribeToken); ok { DEBUG.Println(NET, "granted qoss", m.ReturnCodes) for i, qos := range m.ReturnCodes { t.subResult[t.subs[i]] = qos } } + token.flowComplete() c.freeID(m.MessageID) case *packets.UnsubackPacket: @@ -200,7 +215,7 @@ func startIncommingComms(conn net.Conn, c.freeID(m.MessageID) case *packets.PublishPacket: DEBUG.Println(NET, "received publish, msgId:", m.MessageID) - output <- incommingComms{incommingPub: m} + output <- incomingComms{incommingPub: m} case *packets.PubackPacket: DEBUG.Println(NET, "received puback, id:", m.MessageID) c.getToken(m.MessageID).flowComplete() @@ -209,13 +224,13 @@ func startIncommingComms(conn net.Conn, DEBUG.Println(NET, "received pubrec, id:", m.MessageID) prel := packets.NewControlPacket(packets.Pubrel).(*packets.PubrelPacket) prel.MessageID = m.MessageID - output <- incommingComms{outbound: &PacketAndToken{p: prel, t: nil}} + output <- incomingComms{outbound: &PacketAndToken{p: prel, t: nil}} case *packets.PubrelPacket: DEBUG.Println(NET, "received pubrel, id:", m.MessageID) pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket) pc.MessageID = m.MessageID c.persistOutbound(pc) - output <- incommingComms{outbound: &PacketAndToken{p: pc, t: nil}} + output <- incomingComms{outbound: &PacketAndToken{p: pc, t: nil}} case *packets.PubcompPacket: DEBUG.Println(NET, "received pubcomp, id:", m.MessageID) c.getToken(m.MessageID).flowComplete() @@ -304,8 +319,8 @@ func startOutgoingComms(conn net.Conn, errChan <- err continue } - switch msg.p.(type) { - case *packets.DisconnectPacket: + + if _, ok := msg.p.(*packets.DisconnectPacket); ok { msg.t.(*DisconnectToken).flowComplete() DEBUG.Println(NET, "outbound wrote disconnect, closing connection") // As per the MQTT spec "After sending a DISCONNECT Packet the Client MUST close the Network Connection" @@ -336,7 +351,7 @@ func startOutgoingComms(conn net.Conn, // commsFns provide access to the client state (messageids, requesting disconnection and updating timing) type commsFns interface { getToken(id uint16) tokenCompletor // Retrieve the token for the specified messageid (if none then a dummy token must be returned) - freeID(id uint16) // Release the specified messageid (clearing out of any persistant store) + freeID(id uint16) // Release the specified messageid (clearing out of any persistent store) UpdateLastReceived() // Must be called whenever a packet is received UpdateLastSent() // Must be called whenever a packet is successfully sent getWriteTimeOut() time.Duration // Return the writetimeout (or 0 if none) @@ -346,9 +361,10 @@ type commsFns interface { } // startComms initiates goroutines that handles communications over the network connection -// Messages will be stored (via commsFns) and deleted from the store as neccessary +// Messages will be stored (via commsFns) and deleted from the store as necessary // It returns two channels: -// packets.PublishPacket - Will receive publish packets received over the network. Closed when incomming comms routines exit (on shutdown or if network link closed) +// packets.PublishPacket - Will receive publish packets received over the network. +// Closed when incoming comms routines exit (on shutdown or if network link closed) // error - Any errors will be sent on this channel. The channel is closed when all comms routines have shut down // // Note: The comms routines monitoring oboundp and obound will not shutdown until those channels are both closed. Any messages received between the @@ -356,15 +372,15 @@ type commsFns interface { // minimised. func startComms(conn net.Conn, // Network connection (must be active) c commsFns, // getters and setters to enable us to cleanly interact with client - inboundFromStore <-chan packets.ControlPacket, // Inbound packets from the persistance store (should be closed relatively soon after startup) + inboundFromStore <-chan packets.ControlPacket, // Inbound packets from the persistence store (should be closed relatively soon after startup) oboundp <-chan *PacketAndToken, obound <-chan *PacketAndToken) ( <-chan *packets.PublishPacket, // Publishpackages received over the network <-chan error, // Any errors (should generally trigger a disconnect) ) { // Start inbound comms handler; this needs to be able to transmit messages so we start a go routine to add these to the priority outbound channel - ibound := startIncommingComms(conn, c, inboundFromStore) - outboundFromIncomming := make(chan *PacketAndToken) // Will accept outgoing messages triggered by startIncommingComms (e.g. acknowledgements) + ibound := startIncomingComms(conn, c, inboundFromStore) + outboundFromIncomming := make(chan *PacketAndToken) // Will accept outgoing messages triggered by startIncomingComms (e.g. acknowledgements) oboundErr := startOutgoingComms(conn, c, oboundp, obound, outboundFromIncomming) DEBUG.Println(NET, "startComms started") @@ -400,7 +416,7 @@ func startComms(conn net.Conn, // Network connection (must be active) outPublish <- ic.incommingPub break } - ERROR.Println(STR, "startComms received empty incommingComms msg") + ERROR.Println(STR, "startComms received empty incomingComms msg") case err, ok := <-oboundErr: if !ok { oboundErr = nil