From 93e1f20435d612d14be9cbd67eddddb2b9bc19b1 Mon Sep 17 00:00:00 2001 From: mikee47 <mike@sillyhouse.net> Date: Fri, 5 Jul 2019 22:39:52 +0100 Subject: [PATCH 1/2] Queue at most one connect message and give it priority over others * Handle connect message outside of queue, always give it priority * In `copyString`, free destination buffer first to avoid memory leaks * Use helper functions to create, delete and clear messages * Use `MQTT_MALLOC` instead of `malloc` for messages and other items --- Sming/Core/Network/MqttClient.cpp | 84 +++++++++++++++++++------------ Sming/Core/Network/MqttClient.h | 1 + 2 files changed, 54 insertions(+), 31 deletions(-) diff --git a/Sming/Core/Network/MqttClient.cpp b/Sming/Core/Network/MqttClient.cpp index 28379d0daf..a1778b49bd 100644 --- a/Sming/Core/Network/MqttClient.cpp +++ b/Sming/Core/Network/MqttClient.cpp @@ -15,15 +15,38 @@ #include "Clock.h" +// Content length set to this value to indicate data refers to a stream, not a buffer #define MQTT_PUBLISH_STREAM 0 mqtt_serialiser_t MqttClient::serialiser; mqtt_parser_callbacks_t MqttClient::callbacks; +static mqtt_message_t* createMessage(mqtt_type_t messageType) +{ + auto message = new mqtt_message_t; + if(message != nullptr) { + mqtt_message_init(message); + message->common.type = messageType; + } + return message; +} + +static void deleteMessage(mqtt_message_t* message) +{ + mqtt_message_clear(message, 0); + delete message; +} + +static void clearMessage(mqtt_message_t& message) +{ + mqtt_message_clear(&message, 0); +} + static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString) { destBuffer.length = sourceString.length(); - destBuffer.data = (uint8_t*)malloc(sourceString.length()); + MQTT_FREE(destBuffer.data); // Avoid memory leaks + destBuffer.data = (uint8_t*)MQTT_MALLOC(sourceString.length()); if(destBuffer.data == nullptr) { debug_e("Not enough memory"); return false; @@ -67,16 +90,15 @@ MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpCl MqttClient::~MqttClient() { while(requestQueue.count() != 0) { - mqtt_message_clear(requestQueue.dequeue(), 1); + deleteMessage(requestQueue.dequeue()); } - mqtt_message_clear(&connectMessage, 0); - if(outgoingMessage != nullptr) { - mqtt_message_clear(outgoingMessage, 1); - outgoingMessage = nullptr; + clearMessage(connectMessage); + if(outgoingMessage != &connectMessage) { + deleteMessage(outgoingMessage); } - - mqtt_message_clear(&incomingMessage, 0); + outgoingMessage = nullptr; + clearMessage(incomingMessage); } bool MqttClient::onTcpReceive(TcpClient& client, char* data, int size) @@ -225,9 +247,11 @@ bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslO } } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - memcpy(message, &connectMessage, sizeof(mqtt_message_t)); - requestQueue.enqueue(message); + // We'll pick up connectMessage before sending any other queued messages + if(connectQueued) { + debug_i("MQTT replacing connect message"); + } + connectQueued = true; return TcpClient::connect(url.Host, url.getPort(), useSsl, sslOptions); } @@ -238,9 +262,7 @@ bool MqttClient::publish(const String& topic, const String& content, uint8_t fla return false; } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - mqtt_message_init(message); - message->common.type = MQTT_TYPE_PUBLISH; + auto message = createMessage(MQTT_TYPE_PUBLISH); message->common.retain = static_cast<mqtt_retain_t>((flags >> 0) & 0x01); message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03); @@ -263,9 +285,7 @@ bool MqttClient::publish(const String& topic, IDataSourceStream* stream, uint8_t return false; } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - mqtt_message_init(message); - message->common.type = MQTT_TYPE_PUBLISH; + auto message = createMessage(MQTT_TYPE_PUBLISH); message->common.retain = static_cast<mqtt_retain_t>((flags >> 0) & 0x01); message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03); @@ -286,12 +306,10 @@ bool MqttClient::subscribe(const String& topic) return false; } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - mqtt_message_init(message); - message->common.type = MQTT_TYPE_SUBSCRIBE; - message->subscribe.topics = (mqtt_topicpair_t*)malloc(sizeof(mqtt_topicpair_t)); - memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t)); + auto message = createMessage(MQTT_TYPE_SUBSCRIBE); + message->subscribe.topics = (mqtt_topicpair_t*)MQTT_MALLOC(sizeof(mqtt_topicpair_t)); + memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t)); COPY_STRING(message->subscribe.topics->name, topic); return requestQueue.enqueue(message); @@ -305,10 +323,9 @@ bool MqttClient::unsubscribe(const String& topic) return false; } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - mqtt_message_init(message); - message->common.type = MQTT_TYPE_SUBSCRIBE; - message->unsubscribe.topics = (mqtt_topic_t*)malloc(sizeof(mqtt_topic_t)); + auto message = createMessage(MQTT_TYPE_SUBSCRIBE); + + message->unsubscribe.topics = (mqtt_topic_t*)MQTT_MALLOC(sizeof(mqtt_topic_t)); memset(message->unsubscribe.topics, 0, sizeof(mqtt_topic_t)); COPY_STRING(message->unsubscribe.topics->name, topic); @@ -320,8 +337,15 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent) switch(state) { REENTER: case eMCS_Ready: { - mqtt_message_clear(outgoingMessage, 1); - outgoingMessage = requestQueue.dequeue(); + if(outgoingMessage != &connectMessage) { + deleteMessage(outgoingMessage); + } + if(connectQueued) { + outgoingMessage = &connectMessage; + connectQueued = false; + } else { + outgoingMessage = requestQueue.dequeue(); + } if(!outgoingMessage) { // Send PINGREQ every PingRepeatTime time, if there is no outgoing traffic // PingRepeatTime should be <= keepAlive @@ -329,9 +353,7 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent) break; } - outgoingMessage = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - memset(outgoingMessage, 0, sizeof(mqtt_message_t)); - outgoingMessage->common.type = MQTT_TYPE_PINGREQ; + outgoingMessage = createMessage(MQTT_TYPE_PINGREQ); } IDataSourceStream* payloadStream = nullptr; diff --git a/Sming/Core/Network/MqttClient.h b/Sming/Core/Network/MqttClient.h index f6d5bb5ea1..302032e689 100644 --- a/Sming/Core/Network/MqttClient.h +++ b/Sming/Core/Network/MqttClient.h @@ -302,6 +302,7 @@ class MqttClient : protected TcpClient // messages MqttRequestQueue requestQueue; mqtt_message_t connectMessage; + bool connectQueued = false; ///< True if our connect message needs to be sent mqtt_message_t* outgoingMessage = nullptr; mqtt_message_t incomingMessage; From 457a6554b920d279a933dbc9468675d24d7b0607 Mon Sep 17 00:00:00 2001 From: mikee47 <mike@sillyhouse.net> Date: Fri, 5 Jul 2019 23:03:20 +0100 Subject: [PATCH 2/2] Remove COPY_STRING macro and handle allocation failures --- Sming/Core/Network/MqttClient.cpp | 47 +++++++++++++++++++------------ 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/Sming/Core/Network/MqttClient.cpp b/Sming/Core/Network/MqttClient.cpp index a1778b49bd..3fc02b1275 100644 --- a/Sming/Core/Network/MqttClient.cpp +++ b/Sming/Core/Network/MqttClient.cpp @@ -55,11 +55,6 @@ static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString) return true; } -#define COPY_STRING(TO, FROM) \ - if(!copyString(TO, FROM)) { \ - return false; \ - } - MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpClient(autoDestruct) { // TODO:... @@ -210,10 +205,8 @@ bool MqttClient::setWill(const String& topic, const String& message, uint8_t fla connectMessage.connect.flags.will_qos = (flags >> 1) & 0x03; connectMessage.connect.flags.will = 1; - COPY_STRING(connectMessage.connect.will_topic, topic); - COPY_STRING(connectMessage.connect.will_message, message); - - return true; + return copyString(connectMessage.connect.will_topic, topic) && + copyString(connectMessage.connect.will_message, message); } bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslOptions) @@ -232,21 +225,26 @@ bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslO debug_d("MQTT start connection"); - COPY_STRING(connectMessage.connect.protocol_name, F("MQTT")); + bool res = copyString(connectMessage.connect.protocol_name, F("MQTT")); connectMessage.connect.keep_alive = keepAlive; - COPY_STRING(connectMessage.connect.client_id, clientName); + res &= copyString(connectMessage.connect.client_id, clientName); if(url.User.length() > 0) { connectMessage.connect.flags.username_follows = 1; - COPY_STRING(connectMessage.connect.username, url.User); + res &= copyString(connectMessage.connect.username, url.User); if(url.Password.length() > 0) { connectMessage.connect.flags.password_follows = 1; - COPY_STRING(connectMessage.connect.password, url.Password); + res &= copyString(connectMessage.connect.password, url.Password); } } + if(!res) { + debug_e("MQTT out of memory"); + return false; + } + // We'll pick up connectMessage before sending any other queued messages if(connectQueued) { debug_i("MQTT replacing connect message"); @@ -268,8 +266,10 @@ bool MqttClient::publish(const String& topic, const String& content, uint8_t fla message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03); message->common.dup = static_cast<mqtt_dup_t>((flags >> 3) & 0x01); - COPY_STRING(message->publish.topic_name, topic); - COPY_STRING(message->publish.content, content); + if(!copyString(message->publish.topic_name, topic) || !copyString(message->publish.content, content)) { + delete message; + return false; + } return requestQueue.enqueue(message); } @@ -291,7 +291,12 @@ bool MqttClient::publish(const String& topic, IDataSourceStream* stream, uint8_t message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03); message->common.dup = static_cast<mqtt_dup_t>((flags >> 3) & 0x01); - COPY_STRING(message->publish.topic_name, topic); + if(!copyString(message->publish.topic_name, topic)) { + delete message; + delete stream; + return false; + } + message->publish.content.length = MQTT_PUBLISH_STREAM; message->publish.content.data = (uint8_t*)stream; @@ -310,7 +315,10 @@ bool MqttClient::subscribe(const String& topic) message->subscribe.topics = (mqtt_topicpair_t*)MQTT_MALLOC(sizeof(mqtt_topicpair_t)); memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t)); - COPY_STRING(message->subscribe.topics->name, topic); + if(!copyString(message->subscribe.topics->name, topic)) { + delete message; + return false; + } return requestQueue.enqueue(message); } @@ -327,7 +335,10 @@ bool MqttClient::unsubscribe(const String& topic) message->unsubscribe.topics = (mqtt_topic_t*)MQTT_MALLOC(sizeof(mqtt_topic_t)); memset(message->unsubscribe.topics, 0, sizeof(mqtt_topic_t)); - COPY_STRING(message->unsubscribe.topics->name, topic); + if(!copyString(message->unsubscribe.topics->name, topic)) { + delete message; + return false; + } return requestQueue.enqueue(message); }