Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak in MqttClient #1742

Merged
merged 2 commits into from
Jul 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 82 additions & 49 deletions Sming/Core/Network/MqttClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,38 @@

#include "Clock.h"

// Content length set to this value to indicate data refers to a stream, not a buffer
#define MQTT_PUBLISH_STREAM 0

mqtt_serialiser_t MqttClient::serialiser;
mqtt_parser_callbacks_t MqttClient::callbacks;

static mqtt_message_t* createMessage(mqtt_type_t messageType)
{
auto message = new mqtt_message_t;
if(message != nullptr) {
mqtt_message_init(message);
message->common.type = messageType;
}
return message;
}

static void deleteMessage(mqtt_message_t* message)
{
mqtt_message_clear(message, 0);
delete message;
}

static void clearMessage(mqtt_message_t& message)
{
mqtt_message_clear(&message, 0);
}

static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString)
{
destBuffer.length = sourceString.length();
destBuffer.data = (uint8_t*)malloc(sourceString.length());
MQTT_FREE(destBuffer.data); // Avoid memory leaks
destBuffer.data = (uint8_t*)MQTT_MALLOC(sourceString.length());
if(destBuffer.data == nullptr) {
debug_e("Not enough memory");
return false;
Expand All @@ -32,11 +55,6 @@ static bool copyString(mqtt_buffer_t& destBuffer, const String& sourceString)
return true;
}

#define COPY_STRING(TO, FROM) \
if(!copyString(TO, FROM)) { \
return false; \
}

MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpClient(autoDestruct)
{
// TODO:...
Expand Down Expand Up @@ -67,16 +85,15 @@ MqttClient::MqttClient(bool withDefaultPayloadParser, bool autoDestruct) : TcpCl
MqttClient::~MqttClient()
{
while(requestQueue.count() != 0) {
mqtt_message_clear(requestQueue.dequeue(), 1);
deleteMessage(requestQueue.dequeue());
}

mqtt_message_clear(&connectMessage, 0);
if(outgoingMessage != nullptr) {
mqtt_message_clear(outgoingMessage, 1);
outgoingMessage = nullptr;
clearMessage(connectMessage);
if(outgoingMessage != &connectMessage) {
deleteMessage(outgoingMessage);
}

mqtt_message_clear(&incomingMessage, 0);
outgoingMessage = nullptr;
clearMessage(incomingMessage);
}

bool MqttClient::onTcpReceive(TcpClient& client, char* data, int size)
Expand Down Expand Up @@ -188,10 +205,8 @@ bool MqttClient::setWill(const String& topic, const String& message, uint8_t fla
connectMessage.connect.flags.will_qos = (flags >> 1) & 0x03;
connectMessage.connect.flags.will = 1;

COPY_STRING(connectMessage.connect.will_topic, topic);
COPY_STRING(connectMessage.connect.will_message, message);

return true;
return copyString(connectMessage.connect.will_topic, topic) &&
copyString(connectMessage.connect.will_message, message);
}

bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslOptions)
Expand All @@ -210,24 +225,31 @@ bool MqttClient::connect(const Url& url, const String& clientName, uint32_t sslO

debug_d("MQTT start connection");

COPY_STRING(connectMessage.connect.protocol_name, F("MQTT"));
bool res = copyString(connectMessage.connect.protocol_name, F("MQTT"));

connectMessage.connect.keep_alive = keepAlive;

COPY_STRING(connectMessage.connect.client_id, clientName);
res &= copyString(connectMessage.connect.client_id, clientName);

if(url.User.length() > 0) {
connectMessage.connect.flags.username_follows = 1;
COPY_STRING(connectMessage.connect.username, url.User);
res &= copyString(connectMessage.connect.username, url.User);
if(url.Password.length() > 0) {
connectMessage.connect.flags.password_follows = 1;
COPY_STRING(connectMessage.connect.password, url.Password);
res &= copyString(connectMessage.connect.password, url.Password);
}
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
memcpy(message, &connectMessage, sizeof(mqtt_message_t));
requestQueue.enqueue(message);
if(!res) {
debug_e("MQTT out of memory");
return false;
}

// We'll pick up connectMessage before sending any other queued messages
if(connectQueued) {
debug_i("MQTT replacing connect message");
}
connectQueued = true;

return TcpClient::connect(url.Host, url.getPort(), useSsl, sslOptions);
}
Expand All @@ -238,16 +260,16 @@ bool MqttClient::publish(const String& topic, const String& content, uint8_t fla
return false;
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
mqtt_message_init(message);
message->common.type = MQTT_TYPE_PUBLISH;
auto message = createMessage(MQTT_TYPE_PUBLISH);

message->common.retain = static_cast<mqtt_retain_t>((flags >> 0) & 0x01);
message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03);
message->common.dup = static_cast<mqtt_dup_t>((flags >> 3) & 0x01);

COPY_STRING(message->publish.topic_name, topic);
COPY_STRING(message->publish.content, content);
if(!copyString(message->publish.topic_name, topic) || !copyString(message->publish.content, content)) {
delete message;
return false;
}

return requestQueue.enqueue(message);
}
Expand All @@ -263,15 +285,18 @@ bool MqttClient::publish(const String& topic, IDataSourceStream* stream, uint8_t
return false;
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
mqtt_message_init(message);
message->common.type = MQTT_TYPE_PUBLISH;
auto message = createMessage(MQTT_TYPE_PUBLISH);

message->common.retain = static_cast<mqtt_retain_t>((flags >> 0) & 0x01);
message->common.qos = static_cast<mqtt_qos_t>((flags >> 1) & 0x03);
message->common.dup = static_cast<mqtt_dup_t>((flags >> 3) & 0x01);

COPY_STRING(message->publish.topic_name, topic);
if(!copyString(message->publish.topic_name, topic)) {
delete message;
delete stream;
return false;
}

message->publish.content.length = MQTT_PUBLISH_STREAM;
message->publish.content.data = (uint8_t*)stream;

Expand All @@ -286,13 +311,14 @@ bool MqttClient::subscribe(const String& topic)
return false;
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
mqtt_message_init(message);
message->common.type = MQTT_TYPE_SUBSCRIBE;
message->subscribe.topics = (mqtt_topicpair_t*)malloc(sizeof(mqtt_topicpair_t));
memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t));
auto message = createMessage(MQTT_TYPE_SUBSCRIBE);

COPY_STRING(message->subscribe.topics->name, topic);
message->subscribe.topics = (mqtt_topicpair_t*)MQTT_MALLOC(sizeof(mqtt_topicpair_t));
memset(message->subscribe.topics, 0, sizeof(mqtt_topicpair_t));
if(!copyString(message->subscribe.topics->name, topic)) {
delete message;
return false;
}

return requestQueue.enqueue(message);
}
Expand All @@ -305,12 +331,14 @@ bool MqttClient::unsubscribe(const String& topic)
return false;
}

mqtt_message_t* message = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
mqtt_message_init(message);
message->common.type = MQTT_TYPE_SUBSCRIBE;
message->unsubscribe.topics = (mqtt_topic_t*)malloc(sizeof(mqtt_topic_t));
auto message = createMessage(MQTT_TYPE_SUBSCRIBE);

message->unsubscribe.topics = (mqtt_topic_t*)MQTT_MALLOC(sizeof(mqtt_topic_t));
memset(message->unsubscribe.topics, 0, sizeof(mqtt_topic_t));
COPY_STRING(message->unsubscribe.topics->name, topic);
if(!copyString(message->unsubscribe.topics->name, topic)) {
delete message;
return false;
}

return requestQueue.enqueue(message);
}
Expand All @@ -320,18 +348,23 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent)
switch(state) {
REENTER:
case eMCS_Ready: {
mqtt_message_clear(outgoingMessage, 1);
outgoingMessage = requestQueue.dequeue();
if(outgoingMessage != &connectMessage) {
deleteMessage(outgoingMessage);
}
if(connectQueued) {
outgoingMessage = &connectMessage;
connectQueued = false;
} else {
outgoingMessage = requestQueue.dequeue();
}
if(!outgoingMessage) {
// Send PINGREQ every PingRepeatTime time, if there is no outgoing traffic
// PingRepeatTime should be <= keepAlive
if(!(lastMessage && (millis() - lastMessage >= pingRepeatTime * 1000))) {
break;
}

outgoingMessage = (mqtt_message_t*)malloc(sizeof(mqtt_message_t));
memset(outgoingMessage, 0, sizeof(mqtt_message_t));
outgoingMessage->common.type = MQTT_TYPE_PINGREQ;
outgoingMessage = createMessage(MQTT_TYPE_PINGREQ);
}

IDataSourceStream* payloadStream = nullptr;
Expand Down
1 change: 1 addition & 0 deletions Sming/Core/Network/MqttClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class MqttClient : protected TcpClient
// messages
MqttRequestQueue requestQueue;
mqtt_message_t connectMessage;
bool connectQueued = false; ///< True if our connect message needs to be sent
mqtt_message_t* outgoingMessage = nullptr;
mqtt_message_t incomingMessage;

Expand Down