Skip to content

Commit

Permalink
WIP: MQTT fixes (#2986)
Browse files Browse the repository at this point in the history
* mqtt: expose "connfail" callback via :on()

This makes it just like all the other callbacks in the module and is a
revision of behavior called out in
#2967

* mqtt: clarify when puback callback fires

* mqtt: Don't reference stack buffers from the heap

The confusingly-named "mqtt_connection_t" object is just a triple of
  - a serialized mqtt message pointer and length
  - a buffer pointer (to which the above can be written)
  - a message identifier

The last of these must be passed around the mqtt state machine, but the
first two are very local and the buffer is always sourced from the C
stack.  Unfortunately, because the entire structure is persisted in the
heap, some callers assume that they can always use the structure without
reinitialization (see mqtt_socket_close), which will trash the C stack.

Sever the pairing between message id and local state, punt the local
state entirely out of the heap, and rename things to be less confusing.
  • Loading branch information
nwf authored Mar 14, 2020
1 parent c116d9d commit 787ac7c
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 190 deletions.
94 changes: 62 additions & 32 deletions app/modules/mqtt.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ typedef struct mqtt_state_t
{
uint16_t port;
mqtt_connect_info_t* connect_info;
mqtt_connection_t mqtt_connection;
msg_queue_t* pending_msg_q;
uint16_t next_message_id;

uint8_t * recv_buffer; // heap buffer for multi-packet rx
uint8_t * recv_buffer_wp; // write pointer in multi-packet rx
Expand Down Expand Up @@ -108,6 +108,15 @@ static void mqtt_socket_reconnected(void *arg, sint8_t err);
static void mqtt_socket_connected(void *arg);
static void mqtt_connack_fail(lmqtt_userdata * mud, int reason_code);

static uint16_t mqtt_next_message_id(lmqtt_userdata * mud)
{
mud->mqtt_state.next_message_id++;
if (mud->mqtt_state.next_message_id == 0)
mud->mqtt_state.next_message_id++;

return mud->mqtt_state.next_message_id;
}

static void mqtt_socket_disconnected(void *arg) // tcp only
{
NODE_DBG("enter mqtt_socket_disconnected.\n");
Expand Down Expand Up @@ -399,7 +408,8 @@ static void mqtt_socket_received(void *arg, char *pdata, unsigned short len)

// temp buffer for control messages
uint8_t temp_buffer[MQTT_BUF_SIZE];
mqtt_msg_init(&mud->mqtt_state.mqtt_connection, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_buffer_t msgb;
mqtt_msg_init(&msgb, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_t *temp_msg = NULL;

lua_State *L = lua_getstate();
Expand Down Expand Up @@ -450,8 +460,6 @@ static void mqtt_socket_received(void *arg, char *pdata, unsigned short len)
mud->connState = MQTT_DATA;
NODE_DBG("MQTT: Connected\r\n");
mud->keepalive_sent = 0;
luaL_unref(L, LUA_REGISTRYINDEX, mud->cb_connect_fail_ref);
mud->cb_connect_fail_ref = LUA_NOREF;
if(mud->cb_connect_ref == LUA_NOREF)
break;
if(mud->self_ref == LUA_NOREF)
Expand Down Expand Up @@ -492,12 +500,12 @@ static void mqtt_socket_received(void *arg, char *pdata, unsigned short len)
// buffering and special code to handle this corner-case. Server will most likely have
// written all to OS socket anyway, and not be aware that we "should" not have received it all yet.
if(msg_qos == 1){
temp_msg = mqtt_msg_puback(&mud->mqtt_state.mqtt_connection, msg_id);
temp_msg = mqtt_msg_puback(&msgb, msg_id);
msg_enqueue(&(mud->mqtt_state.pending_msg_q), temp_msg,
msg_id, MQTT_MSG_TYPE_PUBACK, (int)mqtt_get_qos(temp_msg->data) );
}
else if(msg_qos == 2){
temp_msg = mqtt_msg_pubrec(&mud->mqtt_state.mqtt_connection, msg_id);
temp_msg = mqtt_msg_pubrec(&msgb, msg_id);
msg_enqueue(&(mud->mqtt_state.pending_msg_q), temp_msg,
msg_id, MQTT_MSG_TYPE_PUBREC, (int)mqtt_get_qos(temp_msg->data) );
}
Expand Down Expand Up @@ -596,12 +604,12 @@ static void mqtt_socket_received(void *arg, char *pdata, unsigned short len)
break;
case MQTT_MSG_TYPE_PUBLISH:
if(msg_qos == 1){
temp_msg = mqtt_msg_puback(&mud->mqtt_state.mqtt_connection, msg_id);
temp_msg = mqtt_msg_puback(&msgb, msg_id);
msg_enqueue(&(mud->mqtt_state.pending_msg_q), temp_msg,
msg_id, MQTT_MSG_TYPE_PUBACK, (int)mqtt_get_qos(temp_msg->data) );
}
else if(msg_qos == 2){
temp_msg = mqtt_msg_pubrec(&mud->mqtt_state.mqtt_connection, msg_id);
temp_msg = mqtt_msg_pubrec(&msgb, msg_id);
msg_enqueue(&(mud->mqtt_state.pending_msg_q), temp_msg,
msg_id, MQTT_MSG_TYPE_PUBREC, (int)mqtt_get_qos(temp_msg->data) );
}
Expand Down Expand Up @@ -629,7 +637,7 @@ static void mqtt_socket_received(void *arg, char *pdata, unsigned short len)
NODE_DBG("MQTT: Publish with QoS = 2 Received PUBREC\r\n");
// Note: actually, should not destroy the msg until PUBCOMP is received.
msg_destroy(msg_dequeue(&(mud->mqtt_state.pending_msg_q)));
temp_msg = mqtt_msg_pubrel(&mud->mqtt_state.mqtt_connection, msg_id);
temp_msg = mqtt_msg_pubrel(&msgb, msg_id);
msg_enqueue(&(mud->mqtt_state.pending_msg_q), temp_msg,
msg_id, MQTT_MSG_TYPE_PUBREL, (int)mqtt_get_qos(temp_msg->data) );
NODE_DBG("MQTT: Response PUBREL\r\n");
Expand All @@ -638,7 +646,7 @@ static void mqtt_socket_received(void *arg, char *pdata, unsigned short len)
case MQTT_MSG_TYPE_PUBREL:
if(pending_msg && pending_msg->msg_type == MQTT_MSG_TYPE_PUBREC && pending_msg->msg_id == msg_id){
msg_destroy(msg_dequeue(&(mud->mqtt_state.pending_msg_q)));
temp_msg = mqtt_msg_pubcomp(&mud->mqtt_state.mqtt_connection, msg_id);
temp_msg = mqtt_msg_pubcomp(&msgb, msg_id);
msg_enqueue(&(mud->mqtt_state.pending_msg_q), temp_msg,
msg_id, MQTT_MSG_TYPE_PUBCOMP, (int)mqtt_get_qos(temp_msg->data) );
NODE_DBG("MQTT: Response PUBCOMP\r\n");
Expand All @@ -658,7 +666,7 @@ static void mqtt_socket_received(void *arg, char *pdata, unsigned short len)
}
break;
case MQTT_MSG_TYPE_PINGREQ:
temp_msg = mqtt_msg_pingresp(&mud->mqtt_state.mqtt_connection);
temp_msg = mqtt_msg_pingresp(&msgb);
msg_enqueue(&(mud->mqtt_state.pending_msg_q), temp_msg,
msg_id, MQTT_MSG_TYPE_PINGRESP, (int)mqtt_get_qos(temp_msg->data) );
NODE_DBG("MQTT: Response PINGRESP\r\n");
Expand Down Expand Up @@ -770,10 +778,12 @@ static void mqtt_socket_connected(void *arg)
espconn_regist_disconcb(pesp_conn, mqtt_socket_disconnected);

uint8_t temp_buffer[MQTT_BUF_SIZE];
// call mqtt_connect() to start a mqtt connect stage.
mqtt_msg_init(&mud->mqtt_state.mqtt_connection, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_t* temp_msg = mqtt_msg_connect(&mud->mqtt_state.mqtt_connection, mud->mqtt_state.connect_info);
mqtt_message_buffer_t msgb;
mqtt_msg_init(&msgb, temp_buffer, MQTT_BUF_SIZE);

mqtt_message_t* temp_msg = mqtt_msg_connect(&msgb, mud->mqtt_state.connect_info);
NODE_DBG("Send MQTT connection infomation, data len: %d, d[0]=%d \r\n", temp_msg->length, temp_msg->data[0]);

mud->event_timeout = MQTT_SEND_TIMEOUT;
// not queue this message. should send right now. or should enqueue this before head.
#ifdef CLIENT_SSL_ENABLE
Expand Down Expand Up @@ -879,9 +889,11 @@ void mqtt_socket_timer(void *arg)
mqtt_socket_reconnected(mud->pesp_conn, 0);
} else {
uint8_t temp_buffer[MQTT_BUF_SIZE];
mqtt_msg_init(&mud->mqtt_state.mqtt_connection, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_buffer_t msgb;
mqtt_msg_init(&msgb, temp_buffer, MQTT_BUF_SIZE);

NODE_DBG("\r\nMQTT: Send keepalive packet\r\n");
mqtt_message_t* temp_msg = mqtt_msg_pingreq(&mud->mqtt_state.mqtt_connection);
mqtt_message_t* temp_msg = mqtt_msg_pingreq(&msgb);
msg_queue_t *node = msg_enqueue( &(mud->mqtt_state.pending_msg_q), temp_msg,
0, MQTT_MSG_TYPE_PINGREQ, (int)mqtt_get_qos(temp_msg->data) );
mud->keepalive_sent = 1;
Expand Down Expand Up @@ -1380,8 +1392,12 @@ static int mqtt_socket_close( lua_State* L )

sint8 espconn_status = ESPCONN_CONN;
if (mud->connected) {
uint8_t temp_buffer[MQTT_BUF_SIZE];
mqtt_message_buffer_t msgb;
mqtt_msg_init(&msgb, temp_buffer, MQTT_BUF_SIZE);

// Send disconnect message
mqtt_message_t* temp_msg = mqtt_msg_disconnect(&mud->mqtt_state.mqtt_connection);
mqtt_message_t* temp_msg = mqtt_msg_disconnect(&msgb);
NODE_DBG("Send MQTT disconnect infomation, data len: %d, d[0]=%d \r\n", temp_msg->length, temp_msg->data[0]);

#ifdef CLIENT_SSL_ENABLE
Expand Down Expand Up @@ -1437,6 +1453,9 @@ static int mqtt_socket_on( lua_State* L )
if( sl == 7 && strcmp(method, "connect") == 0){
luaL_unref(L, LUA_REGISTRYINDEX, mud->cb_connect_ref);
mud->cb_connect_ref = luaL_ref(L, LUA_REGISTRYINDEX);
}else if( sl == 7 && strcmp(method, "connfail") == 0){
luaL_unref(L, LUA_REGISTRYINDEX, mud->cb_connect_fail_ref);
mud->cb_connect_ref = luaL_ref(L, LUA_REGISTRYINDEX);
}else if( sl == 7 && strcmp(method, "offline") == 0){
luaL_unref(L, LUA_REGISTRYINDEX, mud->cb_disconnect_ref);
mud->cb_disconnect_ref = luaL_ref(L, LUA_REGISTRYINDEX);
Expand Down Expand Up @@ -1468,7 +1487,7 @@ static int mqtt_socket_unsubscribe( lua_State* L ) {
NODE_DBG("enter mqtt_socket_unsubscribe.\n");

uint8_t stack = 1;
uint16_t msg_id = 0;
uint16_t msg_id;
const char *topic;
size_t il;
lmqtt_userdata *mud;
Expand Down Expand Up @@ -1496,7 +1515,8 @@ static int mqtt_socket_unsubscribe( lua_State* L ) {
}

uint8_t temp_buffer[MQTT_BUF_SIZE];
mqtt_msg_init(&mud->mqtt_state.mqtt_connection, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_buffer_t msgb;
mqtt_msg_init(&msgb, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_t *temp_msg = NULL;

if( lua_istable( L, stack ) ) {
Expand All @@ -1510,9 +1530,10 @@ static int mqtt_socket_unsubscribe( lua_State* L ) {
topic = luaL_checkstring( L, -2 );

if (topic_count == 0) {
temp_msg = mqtt_msg_unsubscribe_init( &mud->mqtt_state.mqtt_connection, &msg_id );
msg_id = mqtt_next_message_id(mud);
temp_msg = mqtt_msg_unsubscribe_init( &msgb, msg_id );
}
temp_msg = mqtt_msg_unsubscribe_topic( &mud->mqtt_state.mqtt_connection, topic );
temp_msg = mqtt_msg_unsubscribe_topic( &msgb, topic );
topic_count++;

NODE_DBG("topic: %s - length: %d\n", topic, temp_msg->length);
Expand All @@ -1533,7 +1554,7 @@ static int mqtt_socket_unsubscribe( lua_State* L ) {
return luaL_error( L, "buffer overflow, can't enqueue all unsubscriptions" );
}

temp_msg = mqtt_msg_unsubscribe_fini( &mud->mqtt_state.mqtt_connection );
temp_msg = mqtt_msg_unsubscribe_fini( &msgb );
if (temp_msg->length == 0) {
return luaL_error( L, "buffer overflow, can't enqueue all unsubscriptions" );
}
Expand All @@ -1546,7 +1567,8 @@ static int mqtt_socket_unsubscribe( lua_State* L ) {
if( topic == NULL ){
return luaL_error( L, "need topic name" );
}
temp_msg = mqtt_msg_unsubscribe( &mud->mqtt_state.mqtt_connection, topic, &msg_id );
msg_id = mqtt_next_message_id(mud);
temp_msg = mqtt_msg_unsubscribe( &msgb, topic, msg_id );
}

if( lua_type( L, stack ) == LUA_TFUNCTION || lua_type( L, stack ) == LUA_TLIGHTFUNCTION ) { // TODO: this will overwrite the previous one.
Expand Down Expand Up @@ -1580,7 +1602,7 @@ static int mqtt_socket_subscribe( lua_State* L ) {
NODE_DBG("enter mqtt_socket_subscribe.\n");

uint8_t stack = 1, qos = 0;
uint16_t msg_id = 0;
uint16_t msg_id;
const char *topic;
size_t il;
lmqtt_userdata *mud;
Expand Down Expand Up @@ -1608,7 +1630,8 @@ static int mqtt_socket_subscribe( lua_State* L ) {
}

uint8_t temp_buffer[MQTT_BUF_SIZE];
mqtt_msg_init(&mud->mqtt_state.mqtt_connection, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_buffer_t msgb;
mqtt_msg_init(&msgb, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_t *temp_msg = NULL;

if( lua_istable( L, stack ) ) {
Expand All @@ -1623,9 +1646,10 @@ static int mqtt_socket_subscribe( lua_State* L ) {
qos = luaL_checkinteger( L, -1 );

if (topic_count == 0) {
temp_msg = mqtt_msg_subscribe_init( &mud->mqtt_state.mqtt_connection, &msg_id );
msg_id = mqtt_next_message_id(mud);
temp_msg = mqtt_msg_subscribe_init( &msgb, msg_id );
}
temp_msg = mqtt_msg_subscribe_topic( &mud->mqtt_state.mqtt_connection, topic, qos );
temp_msg = mqtt_msg_subscribe_topic( &msgb, topic, qos );
topic_count++;

NODE_DBG("topic: %s - qos: %d, length: %d\n", topic, qos, temp_msg->length);
Expand All @@ -1646,7 +1670,7 @@ static int mqtt_socket_subscribe( lua_State* L ) {
return luaL_error( L, "buffer overflow, can't enqueue all subscriptions" );
}

temp_msg = mqtt_msg_subscribe_fini( &mud->mqtt_state.mqtt_connection );
temp_msg = mqtt_msg_subscribe_fini( &msgb );
if (temp_msg->length == 0) {
return luaL_error( L, "buffer overflow, can't enqueue all subscriptions" );
}
Expand All @@ -1660,7 +1684,8 @@ static int mqtt_socket_subscribe( lua_State* L ) {
return luaL_error( L, "need topic name" );
}
qos = luaL_checkinteger( L, stack );
temp_msg = mqtt_msg_subscribe( &mud->mqtt_state.mqtt_connection, topic, qos, &msg_id );
msg_id = mqtt_next_message_id(mud);
temp_msg = mqtt_msg_subscribe( &msgb, topic, qos, msg_id );
stack++;
}

Expand Down Expand Up @@ -1732,12 +1757,17 @@ static int mqtt_socket_publish( lua_State* L )
uint8_t retain = luaL_checkinteger( L, stack);
stack ++;

if (qos != 0) {
msg_id = mqtt_next_message_id(mud);
}

uint8_t temp_buffer[MQTT_BUF_SIZE];
mqtt_msg_init(&mud->mqtt_state.mqtt_connection, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_t *temp_msg = mqtt_msg_publish(&mud->mqtt_state.mqtt_connection,
mqtt_message_buffer_t msgb;
mqtt_msg_init(&msgb, temp_buffer, MQTT_BUF_SIZE);
mqtt_message_t *temp_msg = mqtt_msg_publish(&msgb,
topic, payload, l,
qos, retain,
&msg_id);
msg_id);

if (lua_type(L, stack) == LUA_TFUNCTION || lua_type(L, stack) == LUA_TLIGHTFUNCTION){
lua_pushvalue(L, stack); // copy argument (func) to the top of stack
Expand Down
Loading

0 comments on commit 787ac7c

Please sign in to comment.