diff --git a/Sming/Core/Network/MqttClient.cpp b/Sming/Core/Network/MqttClient.cpp index 28379d0daf..3fc02b1275 100644 --- a/Sming/Core/Network/MqttClient.cpp +++ b/Sming/Core/Network/MqttClient.cpp @@ -15,15 +15,38 @@ #include "Clock.h" +// Content length set to this value to indicate data refers to a stream, not a buffer #define MQTT_PUBLISH_STREAM 0 mqtt_serialiser_t MqttClient::serialiser; mqtt_parser_callbacks_t MqttClient::callbacks; +static mqtt_message_t* createMessage(mqtt_type_t messageType) +{ + auto message = new mqtt_message_t; + if(message != nullptr) { + mqtt_message_init(message); + message->common.type = messageType; + } + return message; +} + +static void deleteMessage(mqtt_message_t* message) +{ + mqtt_message_clear(message, 0); + delete message; +} + +static void clearMessage(mqtt_message_t& message) +{ + mqtt_message_clear(&message, 0); +} + static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString) { destBuffer.length = sourceString.length(); - destBuffer.data = (uint8_t*)malloc(sourceString.length()); + MQTT_FREE(destBuffer.data); // Avoid memory leaks + destBuffer.data = (uint8_t*)MQTT_MALLOC(sourceString.length()); if(destBuffer.data == nullptr) { debug_e("Not enough memory"); return false; @@ -32,11 +55,6 @@ static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString) return true; } -#define COPY_STRING(TO, FROM) \ - if(!copyString(TO, FROM)) { \ - return false; \ - } - MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpClient(autoDestruct) { // TODO:... @@ -67,16 +85,15 @@ MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpCl MqttClient::~MqttClient() { while(requestQueue.count() != 0) { - mqtt_message_clear(requestQueue.dequeue(), 1); + deleteMessage(requestQueue.dequeue()); } - mqtt_message_clear(&connectMessage, 0); - if(outgoingMessage != nullptr) { - mqtt_message_clear(outgoingMessage, 1); - outgoingMessage = nullptr; + clearMessage(connectMessage); + if(outgoingMessage != &connectMessage) { + deleteMessage(outgoingMessage); } - - mqtt_message_clear(&incomingMessage, 0); + outgoingMessage = nullptr; + clearMessage(incomingMessage); } bool MqttClient::onTcpReceive(TcpClient& client, char* data, int size) @@ -188,10 +205,8 @@ bool MqttClient::setWill(const String& topic, const String& message, uint8_t fla connectMessage.connect.flags.will_qos = (flags >> 1) & 0x03; connectMessage.connect.flags.will = 1; - COPY_STRING(connectMessage.connect.will_topic, topic); - COPY_STRING(connectMessage.connect.will_message, message); - - return true; + return copyString(connectMessage.connect.will_topic, topic) && + copyString(connectMessage.connect.will_message, message); } bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslOptions) @@ -210,24 +225,31 @@ bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslO debug_d("MQTT start connection"); - COPY_STRING(connectMessage.connect.protocol_name, F("MQTT")); + bool res = copyString(connectMessage.connect.protocol_name, F("MQTT")); connectMessage.connect.keep_alive = keepAlive; - COPY_STRING(connectMessage.connect.client_id, clientName); + res &= copyString(connectMessage.connect.client_id, clientName); if(url.User.length() > 0) { connectMessage.connect.flags.username_follows = 1; - COPY_STRING(connectMessage.connect.username, url.User); + res &= copyString(connectMessage.connect.username, url.User); if(url.Password.length() > 0) { connectMessage.connect.flags.password_follows = 1; - COPY_STRING(connectMessage.connect.password, url.Password); + res &= copyString(connectMessage.connect.password, url.Password); } } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - memcpy(message, &connectMessage, sizeof(mqtt_message_t)); - requestQueue.enqueue(message); + if(!res) { + debug_e("MQTT out of memory"); + return false; + } + + // We'll pick up connectMessage before sending any other queued messages + if(connectQueued) { + debug_i("MQTT replacing connect message"); + } + connectQueued = true; return TcpClient::connect(url.Host, url.getPort(), useSsl, sslOptions); } @@ -238,16 +260,16 @@ bool MqttClient::publish(const String& topic, const String& content, uint8_t fla return false; } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - mqtt_message_init(message); - message->common.type = MQTT_TYPE_PUBLISH; + auto message = createMessage(MQTT_TYPE_PUBLISH); message->common.retain = static_cast((flags >> 0) & 0x01); message->common.qos = static_cast((flags >> 1) & 0x03); message->common.dup = static_cast((flags >> 3) & 0x01); - COPY_STRING(message->publish.topic_name, topic); - COPY_STRING(message->publish.content, content); + if(!copyString(message->publish.topic_name, topic) || !copyString(message->publish.content, content)) { + delete message; + return false; + } return requestQueue.enqueue(message); } @@ -263,15 +285,18 @@ bool MqttClient::publish(const String& topic, IDataSourceStream* stream, uint8_t return false; } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - mqtt_message_init(message); - message->common.type = MQTT_TYPE_PUBLISH; + auto message = createMessage(MQTT_TYPE_PUBLISH); message->common.retain = static_cast((flags >> 0) & 0x01); message->common.qos = static_cast((flags >> 1) & 0x03); message->common.dup = static_cast((flags >> 3) & 0x01); - COPY_STRING(message->publish.topic_name, topic); + if(!copyString(message->publish.topic_name, topic)) { + delete message; + delete stream; + return false; + } + message->publish.content.length = MQTT_PUBLISH_STREAM; message->publish.content.data = (uint8_t*)stream; @@ -286,13 +311,14 @@ bool MqttClient::subscribe(const String& topic) return false; } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - mqtt_message_init(message); - message->common.type = MQTT_TYPE_SUBSCRIBE; - message->subscribe.topics = (mqtt_topicpair_t*)malloc(sizeof(mqtt_topicpair_t)); - memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t)); + auto message = createMessage(MQTT_TYPE_SUBSCRIBE); - COPY_STRING(message->subscribe.topics->name, topic); + message->subscribe.topics = (mqtt_topicpair_t*)MQTT_MALLOC(sizeof(mqtt_topicpair_t)); + memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t)); + if(!copyString(message->subscribe.topics->name, topic)) { + delete message; + return false; + } return requestQueue.enqueue(message); } @@ -305,12 +331,14 @@ bool MqttClient::unsubscribe(const String& topic) return false; } - mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - mqtt_message_init(message); - message->common.type = MQTT_TYPE_SUBSCRIBE; - message->unsubscribe.topics = (mqtt_topic_t*)malloc(sizeof(mqtt_topic_t)); + auto message = createMessage(MQTT_TYPE_SUBSCRIBE); + + message->unsubscribe.topics = (mqtt_topic_t*)MQTT_MALLOC(sizeof(mqtt_topic_t)); memset(message->unsubscribe.topics, 0, sizeof(mqtt_topic_t)); - COPY_STRING(message->unsubscribe.topics->name, topic); + if(!copyString(message->unsubscribe.topics->name, topic)) { + delete message; + return false; + } return requestQueue.enqueue(message); } @@ -320,8 +348,15 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent) switch(state) { REENTER: case eMCS_Ready: { - mqtt_message_clear(outgoingMessage, 1); - outgoingMessage = requestQueue.dequeue(); + if(outgoingMessage != &connectMessage) { + deleteMessage(outgoingMessage); + } + if(connectQueued) { + outgoingMessage = &connectMessage; + connectQueued = false; + } else { + outgoingMessage = requestQueue.dequeue(); + } if(!outgoingMessage) { // Send PINGREQ every PingRepeatTime time, if there is no outgoing traffic // PingRepeatTime should be <= keepAlive @@ -329,9 +364,7 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent) break; } - outgoingMessage = (mqtt_message_t*)malloc(sizeof(mqtt_message_t)); - memset(outgoingMessage, 0, sizeof(mqtt_message_t)); - outgoingMessage->common.type = MQTT_TYPE_PINGREQ; + outgoingMessage = createMessage(MQTT_TYPE_PINGREQ); } IDataSourceStream* payloadStream = nullptr; diff --git a/Sming/Core/Network/MqttClient.h b/Sming/Core/Network/MqttClient.h index f6d5bb5ea1..302032e689 100644 --- a/Sming/Core/Network/MqttClient.h +++ b/Sming/Core/Network/MqttClient.h @@ -302,6 +302,7 @@ class MqttClient : protected TcpClient // messages MqttRequestQueue requestQueue; mqtt_message_t connectMessage; + bool connectQueued = false; ///< True if our connect message needs to be sent mqtt_message_t* outgoingMessage = nullptr; mqtt_message_t incomingMessage;