Skip to content

Commit

Permalink
Merge pull request #452 from akokhanovskyi/fix-356
Browse files Browse the repository at this point in the history
Propagate MQTT connect error (fixes #356)
  • Loading branch information
Al S-M authored Sep 18, 2020
2 parents 5feda7b + 7facab9 commit ba85050
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 66 deletions.
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
DEBUG.Println(CLI, "socket connected to broker")

// Now we send the perform the MQTT connection handshake
rc, sessionPresent = ConnectMQTT(conn, cm, protocolVersion)
rc, sessionPresent, err = connectMQTT(conn, cm, protocolVersion)
if rc == packets.Accepted {
break // successfully connected
}
Expand Down
146 changes: 81 additions & 65 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package mqtt

import (
"errors"
"io"
"net"
"reflect"
"strings"
Expand All @@ -26,11 +28,18 @@ import (

const closedNetConnErrorText = "use of closed network connection" // error string for closed conn (https://golang.org/src/net/error_test.go)

// ConnectMQTT takes a connected net.Conn and performs the initial MQTT handshake. Paramaters are:
// ConnectMQTT takes a connected net.Conn and performs the initial MQTT handshake. Parameters are:
// conn - Connected net.Conn
// cm - Connect Packet with everything other than the protocolname/version populated (historical reasons)
// cm - Connect Packet with everything other than the protocol name/version populated (historical reasons)
// protocolVersion - The protocol version to attempt to connect with
//
// Note that, for backward compatibility, ConnectMQTT() suppresses the actual connection error (compare to connectMQTT()).
func ConnectMQTT(conn net.Conn, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool) {
rc, sessionPresent, _ := connectMQTT(conn, cm, protocolVersion)
return rc, sessionPresent
}

func connectMQTT(conn io.ReadWriter, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool, error) {
switch protocolVersion {
case 3:
DEBUG.Println(CLI, "Using MQTT 3.1 protocol")
Expand All @@ -49,43 +58,46 @@ func ConnectMQTT(conn net.Conn, cm *packets.ConnectPacket, protocolVersion uint)
cm.ProtocolName = "MQTT"
cm.ProtocolVersion = 4
}

if err := cm.Write(conn); err != nil {
ERROR.Println(CLI, err)
return packets.ErrNetworkError, false, err
}

rc, sessionPresent := verifyCONNACK(conn)
return rc, sessionPresent
rc, sessionPresent, err := verifyCONNACK(conn)
return rc, sessionPresent, err
}

// This function is only used for receiving a connack
// when the connection is first started.
// This prevents receiving incoming data while resume
// is in progress if clean session is false.
func verifyCONNACK(conn net.Conn) (byte, bool) {
func verifyCONNACK(conn io.Reader) (byte, bool, error) {
DEBUG.Println(NET, "connect started")

ca, err := packets.ReadPacket(conn)
if err != nil {
ERROR.Println(NET, "connect got error", err)
return packets.ErrNetworkError, false
return packets.ErrNetworkError, false, err
}

if ca == nil {
ERROR.Println(NET, "received nil packet")
return packets.ErrNetworkError, false
return packets.ErrNetworkError, false, errors.New("nil CONNACK packet")
}

msg, ok := ca.(*packets.ConnackPacket)
if !ok {
ERROR.Println(NET, "received msg that was not CONNACK")
return packets.ErrNetworkError, false
return packets.ErrNetworkError, false, errors.New("non-CONNACK first packet received")
}

DEBUG.Println(NET, "received connack")
return msg.ReturnCode, msg.SessionPresent
return msg.ReturnCode, msg.SessionPresent, nil
}

// inbound encapuslates the output from startIncoming.
// err - If != nil then an error has occured
// err - If != nil then an error has occurred
// cp - A control packet received over the network link
type inbound struct {
err error
Expand All @@ -95,12 +107,13 @@ type inbound struct {
// startIncoming initiates a goroutine that reads incoming messages off the wire and sends them to the channel (returned).
// If there are any issues with the network connection then the returned cahnnel will be closed and the goroutine will exit
// (so closing the connection will terminate the goroutine)
func startIncoming(conn net.Conn) <-chan inbound {
func startIncoming(conn io.Reader) <-chan inbound {
var err error
var cp packets.ControlPacket
ibound := make(chan inbound)

DEBUG.Println(NET, "incoming started")

go func() {
for {
if cp, err = packets.ReadPacket(conn); err != nil {
Expand All @@ -118,35 +131,36 @@ func startIncoming(conn net.Conn) <-chan inbound {
ibound <- inbound{cp: cp}
}
}()

return ibound
}

// incommingComms encapuslates the possible output of the incommingComms routine. If err != nil then an error has occured and
// incomingComms encapuslates the possible output of the incomingComms routine. If err != nil then an error has occurred and
// the routine will have terminated; otherwise one of the other members should be non-nil
type incommingComms struct {
err error // If non-nil then there has been an error (ignore everything else)
outbound *PacketAndToken // Packet (with token) than needs to be sent out (e.g. an acknowledgement)
incommingPub *packets.PublishPacket // A new publish has been received; this will need to be passed on to our user
type incomingComms struct {
err error // If non-nil then there has been an error (ignore everything else)
outbound *PacketAndToken // Packet (with token) than needs to be sent out (e.g. an acknowledgement)
incomingPub *packets.PublishPacket // A new publish has been received; this will need to be passed on to our user
}

// startIncommingComms initiates incomming communications; this includes starting a goroutine to process incomming
// startIncomingComms initiates incoming communications; this includes starting a goroutine to process incoming
// messages.
// Accepts a channel of inbound messages from the store (persistanced messages); note this must be closed as soon as the
// everything in the store has been sent.
// Returns a channel that will be passed any received packets; this will be closed on a network error (and inboundFromStore closed)
func startIncommingComms(conn net.Conn,
func startIncomingComms(conn io.Reader,
c commsFns,
inboundFromStore <-chan packets.ControlPacket,
) <-chan incommingComms {
) <-chan incomingComms {
ibound := startIncoming(conn) // Start goroutine that reads from network connection
output := make(chan incommingComms)
output := make(chan incomingComms)

DEBUG.Println(NET, "startIncommingComms started")
DEBUG.Println(NET, "startIncomingComms started")
go func() {
for {
if inboundFromStore == nil && ibound == nil {
close(output)
DEBUG.Println(NET, "startIncommingComms goroutine complete")
DEBUG.Println(NET, "startIncomingComms goroutine complete")
return // As soon as ibound is closed we can exit (should have already processed an error)
}
DEBUG.Println(NET, "logic waiting for msg on ibound")
Expand All @@ -156,22 +170,22 @@ func startIncommingComms(conn net.Conn,
select {
case msg, ok = <-inboundFromStore:
if !ok {
DEBUG.Println(NET, "startIncommingComms: inboundFromStore complete")
DEBUG.Println(NET, "startIncomingComms: inboundFromStore complete")
inboundFromStore = nil // should happen quickly as this is only for persisted messages
continue
}
DEBUG.Println(NET, "startIncommingComms: got msg from store")
DEBUG.Println(NET, "startIncomingComms: got msg from store")
case ibMsg, ok := <-ibound:
if !ok {
DEBUG.Println(NET, "startIncommingComms: ibound complete")
DEBUG.Println(NET, "startIncomingComms: ibound complete")
ibound = nil
continue
}
DEBUG.Println(NET, "startIncommingComms: got msg on ibound")
DEBUG.Println(NET, "startIncomingComms: got msg on ibound")
// If the inbound comms routine encounters any issues it will send us an error.
if ibMsg.err != nil {
output <- incommingComms{err: ibMsg.err}
continue // Usually the channel will be closed immediatly after sending an error but safer that we do not assume this
output <- incomingComms{err: ibMsg.err}
continue // Usually the channel will be closed immediately after sending an error but safer that we do not assume this
}
msg = ibMsg.cp

Expand All @@ -181,44 +195,45 @@ func startIncommingComms(conn net.Conn,

switch m := msg.(type) {
case *packets.PingrespPacket:
DEBUG.Println(NET, "startIncommingComms: received pingresp")
DEBUG.Println(NET, "startIncomingComms: received pingresp")
c.pingRespReceived()
case *packets.SubackPacket:
DEBUG.Println(NET, "startIncommingComms: received suback, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received suback, id:", m.MessageID)
token := c.getToken(m.MessageID)
switch t := token.(type) {
case *SubscribeToken:
DEBUG.Println(NET, "startIncommingComms: granted qoss", m.ReturnCodes)

if t, ok := token.(*SubscribeToken); ok {
DEBUG.Println(NET, "startIncomingComms: granted qoss", m.ReturnCodes)
for i, qos := range m.ReturnCodes {
t.subResult[t.subs[i]] = qos
}
}

token.flowComplete()
c.freeID(m.MessageID)
case *packets.UnsubackPacket:
DEBUG.Println(NET, "startIncommingComms: received unsuback, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received unsuback, id:", m.MessageID)
c.getToken(m.MessageID).flowComplete()
c.freeID(m.MessageID)
case *packets.PublishPacket:
DEBUG.Println(NET, "startIncommingComms: received publish, msgId:", m.MessageID)
output <- incommingComms{incommingPub: m}
DEBUG.Println(NET, "startIncomingComms: received publish, msgId:", m.MessageID)
output <- incomingComms{incomingPub: m}
case *packets.PubackPacket:
DEBUG.Println(NET, "startIncommingComms: received puback, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received puback, id:", m.MessageID)
c.getToken(m.MessageID).flowComplete()
c.freeID(m.MessageID)
case *packets.PubrecPacket:
DEBUG.Println(NET, "startIncommingComms: received pubrec, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received pubrec, id:", m.MessageID)
prel := packets.NewControlPacket(packets.Pubrel).(*packets.PubrelPacket)
prel.MessageID = m.MessageID
output <- incommingComms{outbound: &PacketAndToken{p: prel, t: nil}}
output <- incomingComms{outbound: &PacketAndToken{p: prel, t: nil}}
case *packets.PubrelPacket:
DEBUG.Println(NET, "startIncommingComms: received pubrel, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received pubrel, id:", m.MessageID)
pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket)
pc.MessageID = m.MessageID
c.persistOutbound(pc)
output <- incommingComms{outbound: &PacketAndToken{p: pc, t: nil}}
output <- incomingComms{outbound: &PacketAndToken{p: pc, t: nil}}
case *packets.PubcompPacket:
DEBUG.Println(NET, "startIncommingComms: received pubcomp, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received pubcomp, id:", m.MessageID)
c.getToken(m.MessageID).flowComplete()
c.freeID(m.MessageID)
}
Expand All @@ -229,14 +244,14 @@ func startIncommingComms(conn net.Conn,

// startOutgoingComms initiates a go routint to transmit outgoing packets.
// Pass in an open network connection and channels for outbound messages (including those triggered
// directly from incomming comms).
// directly from incoming comms).
// Returns a channel that will receive details of any errors (closed when the goroutine exits)
// This function wil only terminate when all input channels are closed
func startOutgoingComms(conn net.Conn,
c commsFns,
oboundp <-chan *PacketAndToken,
obound <-chan *PacketAndToken,
oboundFromIncomming <-chan *PacketAndToken,
oboundFromIncoming <-chan *PacketAndToken,
) <-chan error {
errChan := make(chan error)
DEBUG.Println(NET, "outgoing started")
Expand All @@ -248,7 +263,7 @@ func startOutgoingComms(conn net.Conn,
// This goroutine will only exits when all of the input channels we receive on have been closed. This approach is taken to avoid any
// deadlocks (if the connection goes down there are limited options as to what we can do with anything waiting on us and
// throwing away the packets seems the best option)
if oboundp == nil && obound == nil && oboundFromIncomming == nil {
if oboundp == nil && obound == nil && oboundFromIncoming == nil {
DEBUG.Println(NET, "outgoing comms stopping")
close(errChan)
return
Expand Down Expand Up @@ -306,22 +321,22 @@ func startOutgoingComms(conn net.Conn,
errChan <- err
continue
}
switch msg.p.(type) {
case *packets.DisconnectPacket:

if _, ok := msg.p.(*packets.DisconnectPacket); ok {
msg.t.(*DisconnectToken).flowComplete()
DEBUG.Println(NET, "outbound wrote disconnect, closing connection")
// As per the MQTT spec "After sending a DISCONNECT Packet the Client MUST close the Network Connection"
// Closing the connection will cause the goroutines to end in sequence (starting with incomming comms)
// Closing the connection will cause the goroutines to end in sequence (starting with incoming comms)
conn.Close()
}
case msg, ok := <-oboundFromIncomming: // message triggered by an inbound message (PubrecPacket or PubrelPacket)
case msg, ok := <-oboundFromIncoming: // message triggered by an inbound message (PubrecPacket or PubrelPacket)
if !ok {
oboundFromIncomming = nil
oboundFromIncoming = nil
continue
}
DEBUG.Println(NET, "obound from incomming msg to write, type", reflect.TypeOf(msg.p), " ID ", msg.p.Details().MessageID)
DEBUG.Println(NET, "obound from incoming msg to write, type", reflect.TypeOf(msg.p), " ID ", msg.p.Details().MessageID)
if err := msg.p.Write(conn); err != nil {
ERROR.Println(NET, "outgoing oboundFromIncomming reporting error", err)
ERROR.Println(NET, "outgoing oboundFromIncoming reporting error", err)
if msg.t != nil {
msg.t.setError(err)
}
Expand All @@ -338,7 +353,7 @@ func startOutgoingComms(conn net.Conn,
// commsFns provide access to the client state (messageids, requesting disconnection and updating timing)
type commsFns interface {
getToken(id uint16) tokenCompletor // Retrieve the token for the specified messageid (if none then a dummy token must be returned)
freeID(id uint16) // Release the specified messageid (clearing out of any persistant store)
freeID(id uint16) // Release the specified messageid (clearing out of any persistent store)
UpdateLastReceived() // Must be called whenever a packet is received
UpdateLastSent() // Must be called whenever a packet is successfully sent
getWriteTimeOut() time.Duration // Return the writetimeout (or 0 if none)
Expand All @@ -348,28 +363,29 @@ type commsFns interface {
}

// startComms initiates goroutines that handles communications over the network connection
// Messages will be stored (via commsFns) and deleted from the store as neccessary
// Messages will be stored (via commsFns) and deleted from the store as necessary
// It returns two channels:
// packets.PublishPacket - Will receive publish packets received over the network. Closed when incomming comms routines exit (on shutdown or if network link closed)
// packets.PublishPacket - Will receive publish packets received over the network.
// Closed when incoming comms routines exit (on shutdown or if network link closed)
// error - Any errors will be sent on this channel. The channel is closed when all comms routines have shut down
//
// Note: The comms routines monitoring oboundp and obound will not shutdown until those channels are both closed. Any messages received between the
// connection being closed and those channels being closed will generate errors (and nothing will be sent). That way the chance of a deadlock is
// minimised.
func startComms(conn net.Conn, // Network connection (must be active)
c commsFns, // getters and setters to enable us to cleanly interact with client
inboundFromStore <-chan packets.ControlPacket, // Inbound packets from the persistance store (should be closed relatively soon after startup)
inboundFromStore <-chan packets.ControlPacket, // Inbound packets from the persistence store (should be closed relatively soon after startup)
oboundp <-chan *PacketAndToken,
obound <-chan *PacketAndToken) (
<-chan *packets.PublishPacket, // Publishpackages received over the network
<-chan error, // Any errors (should generally trigger a disconnect)
) {
// Start inbound comms handler; this needs to be able to transmit messages so we start a go routine to add these to the priority outbound channel
ibound := startIncommingComms(conn, c, inboundFromStore)
outboundFromIncomming := make(chan *PacketAndToken) // Will accept outgoing messages triggered by startIncommingComms (e.g. acknowledgements)
ibound := startIncomingComms(conn, c, inboundFromStore)
outboundFromIncoming := make(chan *PacketAndToken) // Will accept outgoing messages triggered by startIncomingComms (e.g. acknowledgements)

// Start the outgoing handler. It is important to note that output from startIncommingComms is fed into startOutgoingComms (for ACK's)
oboundErr := startOutgoingComms(conn, c, oboundp, obound, outboundFromIncomming)
// Start the outgoing handler. It is important to note that output from startIncomingComms is fed into startOutgoingComms (for ACK's)
oboundErr := startOutgoingComms(conn, c, oboundp, obound, outboundFromIncoming)
DEBUG.Println(NET, "startComms started")

// Run up go routines to handle the output from the above comms functions - these are handled in seperate
Expand All @@ -388,17 +404,17 @@ func startComms(conn net.Conn, // Network connection (must be active)
continue
}
if ic.outbound != nil {
outboundFromIncomming <- ic.outbound
outboundFromIncoming <- ic.outbound
continue
}
if ic.incommingPub != nil {
outPublish <- ic.incommingPub
if ic.incomingPub != nil {
outPublish <- ic.incomingPub
continue
}
ERROR.Println(STR, "startComms received empty incommingComms msg")
ERROR.Println(STR, "startComms received empty incomingComms msg")
}
// Close channels that will not be written to again (allowing other routines to exit)
close(outboundFromIncomming)
close(outboundFromIncoming)
close(outPublish)
wg.Done()
}()
Expand Down

0 comments on commit ba85050

Please sign in to comment.