Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix timeout calculation in sendVector function to account for overflow #250

Merged
merged 7 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions source/core_mqtt.c
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ static int32_t sendMessageVector( MQTTContext_t * pContext,
size_t ioVecCount )
{
int32_t sendResult;
uint32_t timeoutMs;
uint32_t startTime;
TransportOutVector_t * pIoVectIterator;
size_t vectorsToBeSent = ioVecCount;
size_t bytesToSend = 0U;
Expand All @@ -788,8 +788,8 @@ static int32_t sendMessageVector( MQTTContext_t * pContext,
/* Reset the iterator to point to the first entry in the array. */
pIoVectIterator = pIoVec;

/* Set the timeout. */
timeoutMs = pContext->getTime() + MQTT_SEND_TIMEOUT_MS;
/* Note the start time. */
startTime = pContext->getTime();

while( ( bytesSentOrError < ( int32_t ) bytesToSend ) && ( bytesSentOrError >= 0 ) )
{
Expand Down Expand Up @@ -832,7 +832,7 @@ static int32_t sendMessageVector( MQTTContext_t * pContext,
}

/* Check for timeout. */
if( pContext->getTime() >= timeoutMs )
if( calculateElapsedTime( pContext->getTime(), startTime ) > MQTT_SEND_TIMEOUT_MS )
AniruddhaKanhere marked this conversation as resolved.
Show resolved Hide resolved
{
LogError( ( "sendMessageVector: Unable to send packet: Timed out." ) );
break;
Expand Down Expand Up @@ -1701,7 +1701,7 @@ static MQTTStatus_t receiveSingleIteration( MQTTContext_t * pContext,
pContext->index += ( size_t ) recvBytes;

status = MQTT_ProcessIncomingPacketTypeAndLength( pContext->networkBuffer.pBuffer,
&pContext->index,
&( pContext->index ),
&incomingPacket );

totalMQTTPacketLength = incomingPacket.remainingLength + incomingPacket.headerLength;
Expand Down
2 changes: 2 additions & 0 deletions test/unit-test/core_mqtt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,6 @@ struct NetworkContext

#define MQTT_SUB_UNSUB_MAX_VECTORS ( 6U )

#define MQTT_SEND_TIMEOUT_MS ( 20U )

#endif /* ifndef CORE_MQTT_CONFIG_H_ */
117 changes: 117 additions & 0 deletions test/unit-test/core_mqtt_utest.c
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,37 @@ static uint32_t getTime( void )
return globalEntryTime++;
}

static int32_t getTimeMockCallLimit = -1;

/**
* @brief A mocked timer query function that increments on every call. This
* guarantees that only a single iteration runs in the ProcessLoop for ease
* of testing. Additionally, this tests whether the number of calls to this
* function have exceeded than the set limit and asserts.
*/
static uint32_t getTimeMock( void )
{
TEST_ASSERT_GREATER_THAN_INT32( -1, getTimeMockCallLimit-- );
return globalEntryTime++;
}

static int32_t getTimeMockBigTimeStepCallLimit = -1;

/**
* @brief A mocked timer query function that increments by MQTT_SEND_TIMEOUT_MS
* to simulate the time consumed by a long running high priority task on every
* call. Additionally, this tests whether the number of calls to this function
* have exceeded than the set limit and asserts.
*/
static uint32_t getTimeMockBigTimeStep( void )
{
TEST_ASSERT_GREATER_THAN_INT32( -1, getTimeMockBigTimeStepCallLimit-- );

globalEntryTime += MQTT_SEND_TIMEOUT_MS;
return globalEntryTime;
}


/**
* @brief A mocked timer function that could be used on a device with no system
* time.
Expand Down Expand Up @@ -4755,6 +4786,92 @@ void test_MQTT_Subscribe_error_paths2( void )
TEST_ASSERT_EQUAL( MQTTSendFailed, mqttStatus );
}

/**
* @brief This test case verifies that MQTT_Subscribe returns MQTTSendFailed
* if transport interface send fails and timer overflows.
*/
void test_MQTT_Subscribe_error_paths_timerOverflowCheck( void )
{
MQTTStatus_t mqttStatus;
MQTTContext_t context = { 0 };
TransportInterface_t transport = { 0 };
MQTTFixedBuffer_t networkBuffer = { 0 };
MQTTSubscribeInfo_t subscribeInfo = { 0 };
size_t remainingLength = MQTT_SAMPLE_REMAINING_LENGTH;
size_t packetSize = MQTT_SAMPLE_REMAINING_LENGTH;

globalEntryTime = UINT32_MAX - 2U;

/* The timer function can be called a maximum of these many times
* (which is way less than UINT32_MAX). This ensures that if overflow
* check is not correct, then the timer mock call will fail and assert. */
getTimeMockCallLimit = MQTT_SEND_TIMEOUT_MS + 1;

/* Verify that an error is propagated when transport interface returns an error. */
setupNetworkBuffer( &networkBuffer );
setupSubscriptionInfo( &subscribeInfo );
subscribeInfo.qos = MQTTQoS0;
setupTransportInterface( &transport );
transport.writev = NULL;
/* Case when there is timeout in sending data through transport send. */
transport.send = transportSendNoBytes; /* Use the mock function that returns zero bytes sent. */

/* Initialize context. */
mqttStatus = MQTT_Init( &context, &transport, getTimeMock, eventCallback, &networkBuffer );
TEST_ASSERT_EQUAL( MQTTSuccess, mqttStatus );

MQTT_GetSubscribePacketSize_ExpectAnyArgsAndReturn( MQTTSuccess );
MQTT_GetSubscribePacketSize_ReturnThruPtr_pPacketSize( &packetSize );
MQTT_GetSubscribePacketSize_ReturnThruPtr_pRemainingLength( &remainingLength );
MQTT_SerializeSubscribeHeader_Stub( MQTT_SerializeSubscribedHeader_cb );
mqttStatus = MQTT_Subscribe( &context, &subscribeInfo, 1, MQTT_FIRST_VALID_PACKET_ID );
TEST_ASSERT_EQUAL( MQTTSendFailed, mqttStatus );
TEST_ASSERT_EQUAL( -1, getTimeMockCallLimit );
}

/**
* @brief This test case verifies that MQTT_Subscribe returns MQTTSendFailed
* if transport interface send fails and timer overflows.
*/
void test_MQTT_Subscribe_error_paths_timerOverflowCheck1( void )
{
MQTTStatus_t mqttStatus;
MQTTContext_t context = { 0 };
TransportInterface_t transport = { 0 };
MQTTFixedBuffer_t networkBuffer = { 0 };
MQTTSubscribeInfo_t subscribeInfo = { 0 };
size_t remainingLength = MQTT_SAMPLE_REMAINING_LENGTH;
size_t packetSize = MQTT_SAMPLE_REMAINING_LENGTH;

globalEntryTime = UINT32_MAX - MQTT_SEND_TIMEOUT_MS + 1;

/* The timer function can be called a exactly 2 times. First when setting
* the initial time, next time when checking for timeout.
*/
getTimeMockBigTimeStepCallLimit = 2;

/* Verify that an error is propagated when transport interface returns an error. */
setupNetworkBuffer( &networkBuffer );
setupSubscriptionInfo( &subscribeInfo );
subscribeInfo.qos = MQTTQoS0;
setupTransportInterface( &transport );
transport.writev = NULL;
/* Case when there is timeout in sending data through transport send. */
transport.send = transportSendNoBytes; /* Use the mock function that returns zero bytes sent. */

/* Initialize context. */
mqttStatus = MQTT_Init( &context, &transport, getTimeMockBigTimeStep, eventCallback, &networkBuffer );
TEST_ASSERT_EQUAL( MQTTSuccess, mqttStatus );

MQTT_GetSubscribePacketSize_ExpectAnyArgsAndReturn( MQTTSuccess );
MQTT_GetSubscribePacketSize_ReturnThruPtr_pPacketSize( &packetSize );
MQTT_GetSubscribePacketSize_ReturnThruPtr_pRemainingLength( &remainingLength );
MQTT_SerializeSubscribeHeader_Stub( MQTT_SerializeSubscribedHeader_cb );
mqttStatus = MQTT_Subscribe( &context, &subscribeInfo, 1, MQTT_FIRST_VALID_PACKET_ID );
TEST_ASSERT_EQUAL( MQTTSendFailed, mqttStatus );
TEST_ASSERT_EQUAL( -1, getTimeMockBigTimeStepCallLimit );
}

/* ========================================================================== */

/**
Expand Down