From 93e1f20435d612d14be9cbd67eddddb2b9bc19b1 Mon Sep 17 00:00:00 2001
From: mikee47 <mike@sillyhouse.net>
Date: Fri, 5 Jul 2019 22:39:52 +0100
Subject: [PATCH 1/2] Queue at most one connect message and give it priority
 over others

* Handle connect message outside of queue, always give it priority
* In `copyString`, free destination buffer first to avoid memory leaks
* Use helper functions to create, delete and clear messages
* Use `MQTT_MALLOC` instead of `malloc` for messages and other items
---
 Sming/Core/Network/MqttClient.cpp | 84 +++++++++++++++++++------------
 Sming/Core/Network/MqttClient.h   |  1 +
 2 files changed, 54 insertions(+), 31 deletions(-)

diff --git a/Sming/Core/Network/MqttClient.cpp b/Sming/Core/Network/MqttClient.cpp
index 28379d0daf..a1778b49bd 100644
--- a/Sming/Core/Network/MqttClient.cpp
+++ b/Sming/Core/Network/MqttClient.cpp
@@ -15,15 +15,38 @@
 
 #include "Clock.h"
 
+// Content length set to this value to indicate data refers to a stream, not a buffer
 #define MQTT_PUBLISH_STREAM 0
 
 mqtt_serialiser_t MqttClient::serialiser;
 mqtt_parser_callbacks_t MqttClient::callbacks;
 
+static mqtt_message_t* createMessage(mqtt_type_t messageType)
+{
+	auto message = new mqtt_message_t;
+	if(message != nullptr) {
+		mqtt_message_init(message);
+		message->common.type = messageType;
+	}
+	return message;
+}
+
+static void deleteMessage(mqtt_message_t* message)
+{
+	mqtt_message_clear(message, 0);
+	delete message;
+}
+
+static void clearMessage(mqtt_message_t& message)
+{
+	mqtt_message_clear(&message, 0);
+}
+
 static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString)
 {
 	destBuffer.length = sourceString.length();
-	destBuffer.data = (uint8_t*)malloc(sourceString.length());
+	MQTT_FREE(destBuffer.data); // Avoid memory leaks
+	destBuffer.data = (uint8_t*)MQTT_MALLOC(sourceString.length());
 	if(destBuffer.data == nullptr) {
 		debug_e("Not enough memory");
 		return false;
@@ -67,16 +90,15 @@ MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpCl
 MqttClient::~MqttClient()
 {
 	while(requestQueue.count() != 0) {
-		mqtt_message_clear(requestQueue.dequeue(), 1);
+		deleteMessage(requestQueue.dequeue());
 	}
 
-	mqtt_message_clear(&connectMessage, 0);
-	if(outgoingMessage != nullptr) {
-		mqtt_message_clear(outgoingMessage, 1);
-		outgoingMessage = nullptr;
+	clearMessage(connectMessage);
+	if(outgoingMessage != &connectMessage) {
+		deleteMessage(outgoingMessage);
 	}
-
-	mqtt_message_clear(&incomingMessage, 0);
+	outgoingMessage = nullptr;
+	clearMessage(incomingMessage);
 }
 
 bool MqttClient::onTcpReceive(TcpClient& client, char* data, int size)
@@ -225,9 +247,11 @@ bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslO
 		}
 	}
 
-	mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
-	memcpy(message, &connectMessage, sizeof(mqtt_message_t));
-	requestQueue.enqueue(message);
+	// We'll pick up connectMessage before sending any other queued messages
+	if(connectQueued) {
+		debug_i("MQTT replacing connect message");
+	}
+	connectQueued = true;
 
 	return TcpClient::connect(url.Host, url.getPort(), useSsl, sslOptions);
 }
@@ -238,9 +262,7 @@ bool MqttClient::publish(const String& topic, const String& content, uint8_t fla
 		return false;
 	}
 
-	mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
-	mqtt_message_init(message);
-	message->common.type = MQTT_TYPE_PUBLISH;
+	auto message = createMessage(MQTT_TYPE_PUBLISH);
 
 	message->common.retain = static_cast<mqtt_retain_t>((flags >> 0) & 0x01);
 	message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03);
@@ -263,9 +285,7 @@ bool MqttClient::publish(const String& topic, IDataSourceStream* stream, uint8_t
 		return false;
 	}
 
-	mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
-	mqtt_message_init(message);
-	message->common.type = MQTT_TYPE_PUBLISH;
+	auto message = createMessage(MQTT_TYPE_PUBLISH);
 
 	message->common.retain = static_cast<mqtt_retain_t>((flags >> 0) & 0x01);
 	message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03);
@@ -286,12 +306,10 @@ bool MqttClient::subscribe(const String& topic)
 		return false;
 	}
 
-	mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
-	mqtt_message_init(message);
-	message->common.type = MQTT_TYPE_SUBSCRIBE;
-	message->subscribe.topics = (mqtt_topicpair_t*)malloc(sizeof(mqtt_topicpair_t));
-	memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t));
+	auto message = createMessage(MQTT_TYPE_SUBSCRIBE);
 
+	message->subscribe.topics = (mqtt_topicpair_t*)MQTT_MALLOC(sizeof(mqtt_topicpair_t));
+	memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t));
 	COPY_STRING(message->subscribe.topics->name, topic);
 
 	return requestQueue.enqueue(message);
@@ -305,10 +323,9 @@ bool MqttClient::unsubscribe(const String& topic)
 		return false;
 	}
 
-	mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
-	mqtt_message_init(message);
-	message->common.type = MQTT_TYPE_SUBSCRIBE;
-	message->unsubscribe.topics = (mqtt_topic_t*)malloc(sizeof(mqtt_topic_t));
+	auto message = createMessage(MQTT_TYPE_SUBSCRIBE);
+
+	message->unsubscribe.topics = (mqtt_topic_t*)MQTT_MALLOC(sizeof(mqtt_topic_t));
 	memset(message->unsubscribe.topics, 0, sizeof(mqtt_topic_t));
 	COPY_STRING(message->unsubscribe.topics->name, topic);
 
@@ -320,8 +337,15 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent)
 	switch(state) {
 	REENTER:
 	case eMCS_Ready: {
-		mqtt_message_clear(outgoingMessage, 1);
-		outgoingMessage = requestQueue.dequeue();
+		if(outgoingMessage != &connectMessage) {
+			deleteMessage(outgoingMessage);
+		}
+		if(connectQueued) {
+			outgoingMessage = &connectMessage;
+			connectQueued = false;
+		} else {
+			outgoingMessage = requestQueue.dequeue();
+		}
 		if(!outgoingMessage) {
 			// Send PINGREQ every PingRepeatTime time, if there is no outgoing traffic
 			// PingRepeatTime should be <= keepAlive
@@ -329,9 +353,7 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent)
 				break;
 			}
 
-			outgoingMessage = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
-			memset(outgoingMessage, 0, sizeof(mqtt_message_t));
-			outgoingMessage->common.type = MQTT_TYPE_PINGREQ;
+			outgoingMessage = createMessage(MQTT_TYPE_PINGREQ);
 		}
 
 		IDataSourceStream* payloadStream = nullptr;
diff --git a/Sming/Core/Network/MqttClient.h b/Sming/Core/Network/MqttClient.h
index f6d5bb5ea1..302032e689 100644
--- a/Sming/Core/Network/MqttClient.h
+++ b/Sming/Core/Network/MqttClient.h
@@ -302,6 +302,7 @@ class MqttClient : protected TcpClient
 	// messages
 	MqttRequestQueue requestQueue;
 	mqtt_message_t connectMessage;
+	bool connectQueued = false; ///< True if our connect message needs to be sent
 	mqtt_message_t* outgoingMessage = nullptr;
 	mqtt_message_t incomingMessage;
 

From 457a6554b920d279a933dbc9468675d24d7b0607 Mon Sep 17 00:00:00 2001
From: mikee47 <mike@sillyhouse.net>
Date: Fri, 5 Jul 2019 23:03:20 +0100
Subject: [PATCH 2/2] Remove COPY_STRING macro and handle allocation failures

---
 Sming/Core/Network/MqttClient.cpp | 47 +++++++++++++++++++------------
 1 file changed, 29 insertions(+), 18 deletions(-)

diff --git a/Sming/Core/Network/MqttClient.cpp b/Sming/Core/Network/MqttClient.cpp
index a1778b49bd..3fc02b1275 100644
--- a/Sming/Core/Network/MqttClient.cpp
+++ b/Sming/Core/Network/MqttClient.cpp
@@ -55,11 +55,6 @@ static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString)
 	return true;
 }
 
-#define COPY_STRING(TO, FROM)                                                                                          \
-	if(!copyString(TO, FROM)) {                                                                                        \
-		return false;                                                                                                  \
-	}
-
 MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpClient(autoDestruct)
 {
 	// TODO:...
@@ -210,10 +205,8 @@ bool MqttClient::setWill(const String& topic, const String& message, uint8_t fla
 	connectMessage.connect.flags.will_qos = (flags >> 1) & 0x03;
 	connectMessage.connect.flags.will = 1;
 
-	COPY_STRING(connectMessage.connect.will_topic, topic);
-	COPY_STRING(connectMessage.connect.will_message, message);
-
-	return true;
+	return copyString(connectMessage.connect.will_topic, topic) &&
+		   copyString(connectMessage.connect.will_message, message);
 }
 
 bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslOptions)
@@ -232,21 +225,26 @@ bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslO
 
 	debug_d("MQTT start connection");
 
-	COPY_STRING(connectMessage.connect.protocol_name, F("MQTT"));
+	bool res = copyString(connectMessage.connect.protocol_name, F("MQTT"));
 
 	connectMessage.connect.keep_alive = keepAlive;
 
-	COPY_STRING(connectMessage.connect.client_id, clientName);
+	res &= copyString(connectMessage.connect.client_id, clientName);
 
 	if(url.User.length() > 0) {
 		connectMessage.connect.flags.username_follows = 1;
-		COPY_STRING(connectMessage.connect.username, url.User);
+		res &= copyString(connectMessage.connect.username, url.User);
 		if(url.Password.length() > 0) {
 			connectMessage.connect.flags.password_follows = 1;
-			COPY_STRING(connectMessage.connect.password, url.Password);
+			res &= copyString(connectMessage.connect.password, url.Password);
 		}
 	}
 
+	if(!res) {
+		debug_e("MQTT out of memory");
+		return false;
+	}
+
 	// We'll pick up connectMessage before sending any other queued messages
 	if(connectQueued) {
 		debug_i("MQTT replacing connect message");
@@ -268,8 +266,10 @@ bool MqttClient::publish(const String& topic, const String& content, uint8_t fla
 	message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03);
 	message->common.dup = static_cast<mqtt_dup_t>((flags >> 3) & 0x01);
 
-	COPY_STRING(message->publish.topic_name, topic);
-	COPY_STRING(message->publish.content, content);
+	if(!copyString(message->publish.topic_name, topic) || !copyString(message->publish.content, content)) {
+		delete message;
+		return false;
+	}
 
 	return requestQueue.enqueue(message);
 }
@@ -291,7 +291,12 @@ bool MqttClient::publish(const String& topic, IDataSourceStream* stream, uint8_t
 	message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03);
 	message->common.dup = static_cast<mqtt_dup_t>((flags >> 3) & 0x01);
 
-	COPY_STRING(message->publish.topic_name, topic);
+	if(!copyString(message->publish.topic_name, topic)) {
+		delete message;
+		delete stream;
+		return false;
+	}
+
 	message->publish.content.length = MQTT_PUBLISH_STREAM;
 	message->publish.content.data = (uint8_t*)stream;
 
@@ -310,7 +315,10 @@ bool MqttClient::subscribe(const String& topic)
 
 	message->subscribe.topics = (mqtt_topicpair_t*)MQTT_MALLOC(sizeof(mqtt_topicpair_t));
 	memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t));
-	COPY_STRING(message->subscribe.topics->name, topic);
+	if(!copyString(message->subscribe.topics->name, topic)) {
+		delete message;
+		return false;
+	}
 
 	return requestQueue.enqueue(message);
 }
@@ -327,7 +335,10 @@ bool MqttClient::unsubscribe(const String& topic)
 
 	message->unsubscribe.topics = (mqtt_topic_t*)MQTT_MALLOC(sizeof(mqtt_topic_t));
 	memset(message->unsubscribe.topics, 0, sizeof(mqtt_topic_t));
-	COPY_STRING(message->unsubscribe.topics->name, topic);
+	if(!copyString(message->unsubscribe.topics->name, topic)) {
+		delete message;
+		return false;
+	}
 
 	return requestQueue.enqueue(message);
 }