Skip to content

Commit

Permalink
Fix memory leak in MqttClient (SmingHub#1742)
Browse files Browse the repository at this point in the history
* Handle connect message outside of queue, always give it priority
* In `copyString`, free destination buffer first to avoid memory leaks
* Use helper functions to create, delete and clear messages
* Use `MQTT_MALLOC` instead of `malloc` for messages and other items

* Remove COPY_STRING macro and handle allocation failures
  • Loading branch information
mikee47 authored and slav-at-attachix committed Jul 7, 2019
1 parent e4b63db commit 6f93d73
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 49 deletions.
131 changes: 82 additions & 49 deletions Sming/SmingCore/Network/MqttClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,38 @@

#include "../Clock.h"

// Content length set to this value to indicate data refers to a stream, not a buffer
#define MQTT_PUBLISH_STREAM 0

mqtt_serialiser_t MqttClient::serialiser;
mqtt_parser_callbacks_t MqttClient::callbacks;

static mqtt_message_t* createMessage(mqtt_type_t messageType)
{
auto message = new mqtt_message_t;
if(message != nullptr) {
mqtt_message_init(message);
message->common.type = messageType;
}
return message;
}

static void deleteMessage(mqtt_message_t* message)
{
mqtt_message_clear(message, 0);
delete message;
}

static void clearMessage(mqtt_message_t& message)
{
mqtt_message_clear(&message, 0);
}

static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString)
{
destBuffer.length = sourceString.length();
destBuffer.data = (uint8_t*)malloc(sourceString.length());
MQTT_FREE(destBuffer.data); // Avoid memory leaks
destBuffer.data = (uint8_t*)MQTT_MALLOC(sourceString.length());
if(destBuffer.data == nullptr) {
debug_e("Not enough memory");
return false;
Expand All @@ -32,11 +55,6 @@ static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString)
return true;
}

#define COPY_STRING(TO, FROM) \
if(!copyString(TO, FROM)) { \
return false; \
}

MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpClient(autoDestruct)
{
// TODO:...
Expand Down Expand Up @@ -67,16 +85,15 @@ MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpCl
MqttClient::~MqttClient()
{
while(requestQueue.count() != 0) {
mqtt_message_clear(requestQueue.dequeue(), 1);
deleteMessage(requestQueue.dequeue());
}

mqtt_message_clear(&connectMessage, 0);
if(outgoingMessage != nullptr) {
mqtt_message_clear(outgoingMessage, 1);
outgoingMessage = nullptr;
clearMessage(connectMessage);
if(outgoingMessage != &connectMessage) {
deleteMessage(outgoingMessage);
}

mqtt_message_clear(&incomingMessage, 0);
outgoingMessage = nullptr;
clearMessage(incomingMessage);
}

bool MqttClient::onTcpReceive(TcpClient& client, char* data, int size)
Expand Down Expand Up @@ -188,10 +205,8 @@ bool MqttClient::setWill(const String& topic, const String& message, uint8_t fla
connectMessage.connect.flags.will_qos = (flags >> 1) & 0x03;
connectMessage.connect.flags.will = 1;

COPY_STRING(connectMessage.connect.will_topic, topic);
COPY_STRING(connectMessage.connect.will_message, message);

return true;
return copyString(connectMessage.connect.will_topic, topic) &&
copyString(connectMessage.connect.will_message, message);
}

bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslOptions)
Expand All @@ -210,24 +225,31 @@ bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslO

debug_d("MQTT start connection");

COPY_STRING(connectMessage.connect.protocol_name, F("MQTT"));
bool res = copyString(connectMessage.connect.protocol_name, F("MQTT"));

connectMessage.connect.keep_alive = keepAlive;

COPY_STRING(connectMessage.connect.client_id, clientName);
res &= copyString(connectMessage.connect.client_id, clientName);

if(url.User.length() > 0) {
connectMessage.connect.flags.username_follows = 1;
COPY_STRING(connectMessage.connect.username, url.User);
res &= copyString(connectMessage.connect.username, url.User);
if(url.Password.length() > 0) {
connectMessage.connect.flags.password_follows = 1;
COPY_STRING(connectMessage.connect.password, url.Password);
res &= copyString(connectMessage.connect.password, url.Password);
}
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
memcpy(message, &connectMessage, sizeof(mqtt_message_t));
requestQueue.enqueue(message);
if(!res) {
debug_e("MQTT out of memory");
return false;
}

// We'll pick up connectMessage before sending any other queued messages
if(connectQueued) {
debug_i("MQTT replacing connect message");
}
connectQueued = true;

return TcpClient::connect(url.Host, url.getPort(), useSsl, sslOptions);
}
Expand All @@ -238,16 +260,16 @@ bool MqttClient::publish(const String& topic, const String& content, uint8_t fla
return false;
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
mqtt_message_init(message);
message->common.type = MQTT_TYPE_PUBLISH;
auto message = createMessage(MQTT_TYPE_PUBLISH);

message->common.retain = static_cast<mqtt_retain_t>((flags >> 0) & 0x01);
message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03);
message->common.dup = static_cast<mqtt_dup_t>((flags >> 3) & 0x01);

COPY_STRING(message->publish.topic_name, topic);
COPY_STRING(message->publish.content, content);
if(!copyString(message->publish.topic_name, topic) || !copyString(message->publish.content, content)) {
delete message;
return false;
}

return requestQueue.enqueue(message);
}
Expand All @@ -263,15 +285,18 @@ bool MqttClient::publish(const String& topic, IDataSourceStream* stream, uint8_t
return false;
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
mqtt_message_init(message);
message->common.type = MQTT_TYPE_PUBLISH;
auto message = createMessage(MQTT_TYPE_PUBLISH);

message->common.retain = static_cast<mqtt_retain_t>((flags >> 0) & 0x01);
message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03);
message->common.dup = static_cast<mqtt_dup_t>((flags >> 3) & 0x01);

COPY_STRING(message->publish.topic_name, topic);
if(!copyString(message->publish.topic_name, topic)) {
delete message;
delete stream;
return false;
}

message->publish.content.length = MQTT_PUBLISH_STREAM;
message->publish.content.data = (uint8_t*)stream;

Expand All @@ -286,13 +311,14 @@ bool MqttClient::subscribe(const String& topic)
return false;
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
mqtt_message_init(message);
message->common.type = MQTT_TYPE_SUBSCRIBE;
message->subscribe.topics = (mqtt_topicpair_t*)malloc(sizeof(mqtt_topicpair_t));
memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t));
auto message = createMessage(MQTT_TYPE_SUBSCRIBE);

COPY_STRING(message->subscribe.topics->name, topic);
message->subscribe.topics = (mqtt_topicpair_t*)MQTT_MALLOC(sizeof(mqtt_topicpair_t));
memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t));
if(!copyString(message->subscribe.topics->name, topic)) {
delete message;
return false;
}

return requestQueue.enqueue(message);
}
Expand All @@ -305,12 +331,14 @@ bool MqttClient::unsubscribe(const String& topic)
return false;
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
mqtt_message_init(message);
message->common.type = MQTT_TYPE_SUBSCRIBE;
message->unsubscribe.topics = (mqtt_topic_t*)malloc(sizeof(mqtt_topic_t));
auto message = createMessage(MQTT_TYPE_SUBSCRIBE);

message->unsubscribe.topics = (mqtt_topic_t*)MQTT_MALLOC(sizeof(mqtt_topic_t));
memset(message->unsubscribe.topics, 0, sizeof(mqtt_topic_t));
COPY_STRING(message->unsubscribe.topics->name, topic);
if(!copyString(message->unsubscribe.topics->name, topic)) {
delete message;
return false;
}

return requestQueue.enqueue(message);
}
Expand All @@ -320,18 +348,23 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent)
switch(state) {
REENTER:
case eMCS_Ready: {
mqtt_message_clear(outgoingMessage, 1);
outgoingMessage = requestQueue.dequeue();
if(outgoingMessage != &connectMessage) {
deleteMessage(outgoingMessage);
}
if(connectQueued) {
outgoingMessage = &connectMessage;
connectQueued = false;
} else {
outgoingMessage = requestQueue.dequeue();
}
if(!outgoingMessage) {
// Send PINGREQ every PingRepeatTime time, if there is no outgoing traffic
// PingRepeatTime should be <= keepAlive
if(!(lastMessage && (millis() - lastMessage >= pingRepeatTime * 1000))) {
break;
}

outgoingMessage = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
memset(outgoingMessage, 0, sizeof(mqtt_message_t));
outgoingMessage->common.type = MQTT_TYPE_PINGREQ;
outgoingMessage = createMessage(MQTT_TYPE_PINGREQ);
}

IDataSourceStream* payloadStream = nullptr;
Expand Down
1 change: 1 addition & 0 deletions Sming/SmingCore/Network/MqttClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ class MqttClient : protected TcpClient
// messages
MqttRequestQueue requestQueue;
mqtt_message_t connectMessage;
bool connectQueued = false; ///< True if our connect message needs to be sent
mqtt_message_t* outgoingMessage = nullptr;
mqtt_message_t incomingMessage;

Expand Down

0 comments on commit 6f93d73

Please sign in to comment.