diff --git a/.github/.cSpellWords.txt b/.github/.cSpellWords.txt index 3159b02..b00d6ae 100644 --- a/.github/.cSpellWords.txt +++ b/.github/.cSpellWords.txt @@ -40,6 +40,7 @@ compatibil coremqtt COSE CLRF +Cqqpk CSRS ctest Customisation @@ -80,6 +81,7 @@ evlog EVTOLUX evtolux feip +Fgdszfl FIFOS fftr fftwrap @@ -110,6 +112,7 @@ iridix IRQN ISRAM istty +Ixyh JITP JITR Jytl @@ -150,6 +153,7 @@ Mqtt MQTT MQTT's mqttexample +Mqud MVEI mytestthing myTestThing @@ -233,16 +237,23 @@ unusued utilises USART UYVY +Vbex VCLK VECTACTIVE venv vmean vsocket vsync +Vwij YPJLH WGHT wght WLATENCY +Xcycu xtea +Xwzpc +Xxrcgxi zeroize ZEROIZE +ZIUVJ +Zwjr diff --git a/applications/keyword_detection/ml_interface.cc b/applications/keyword_detection/ml_interface.cc index 15663c1..c775edf 100644 --- a/applications/keyword_detection/ml_interface.cc +++ b/applications/keyword_detection/ml_interface.cc @@ -85,68 +85,67 @@ extern QueueHandle_t xMlMqttQueue; #ifdef AUDIO_VSI -#include "Driver_SAI.h" + #include "Driver_SAI.h" -#define AUDIO_BLOCK_NUM (4) -#define AUDIO_BLOCK_SIZE (3200) -#define AUDIO_BUFFER_SIZE (AUDIO_BLOCK_NUM * AUDIO_BLOCK_SIZE) + #define AUDIO_BLOCK_NUM ( 4 ) + #define AUDIO_BLOCK_SIZE ( 3200 ) + #define AUDIO_BUFFER_SIZE ( AUDIO_BLOCK_NUM * AUDIO_BLOCK_SIZE ) -// audio constants -__attribute__((section(".bss.NoInit.vsi_audio_buffer"))) __attribute__((aligned(4))) -int16_t shared_audio_buffer[AUDIO_BUFFER_SIZE / 2]; -const int kAudioSampleFrequency = 16000; +/* audio constants */ + __attribute__( ( section( ".bss.NoInit.vsi_audio_buffer" ) ) ) __attribute__( ( aligned( 4 ) ) ) + int16_t shared_audio_buffer[ AUDIO_BUFFER_SIZE / 2 ]; + const int kAudioSampleFrequency = 16000; -extern ARM_DRIVER_SAI Driver_SAI0; -extern TaskHandle_t xVsiTaskHandle; + extern ARM_DRIVER_SAI Driver_SAI0; + extern TaskHandle_t xVsiTaskHandle; -uint32_t ulVsiEvent; + uint32_t ulVsiEvent; #else /* !defined(AUDIO_VSI) */ -#include "InputFiles.hpp" + #include "InputFiles.hpp" #endif /* AUDIO_VSI */ -// Define tensor arena and declare functions required to access the model +/* Define tensor arena and declare functions required to access the model */ namespace arm { namespace app { -uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE; +uint8_t tensorArena[ ACTIVATION_BUF_SZ ] ACTIVATION_BUF_ATTRIBUTE; namespace kws { -extern uint8_t *GetModelPointer(); +extern uint8_t * GetModelPointer(); extern size_t GetModelLen(); } /* namespace kws */ } /* namespace app */ } /* namespace arm */ namespace { - -typedef struct { +typedef struct +{ ml_processing_state_t state; } ml_mqtt_msg_t; -// Import +/* Import */ using namespace arm::app; ml_processing_change_handler_t ml_processing_change_handler = NULL; -void *ml_processing_change_ptr = NULL; +void * ml_processing_change_ptr = NULL; const std::array, 12> label_to_state{ - std::pair{"_silence_", ML_SILENCE}, - std::pair{"_unknown_", ML_UNKNOWN}, - std::pair{"yes", ML_HEARD_YES}, - std::pair{"no", ML_HEARD_NO}, - std::pair{"up", ML_HEARD_UP}, - std::pair{"down", ML_HEARD_DOWN}, - std::pair{"left", ML_HEARD_LEFT}, - std::pair{"right", ML_HEARD_RIGHT}, - std::pair{"on", ML_HEARD_ON}, - std::pair{"off", ML_HEARD_OFF}, - std::pair{"go", ML_HEARD_GO}, - std::pair{"stop", ML_HEARD_STOP}, + std::pair{ "_silence_", ML_SILENCE }, + std::pair{ "_unknown_", ML_UNKNOWN }, + std::pair{ "yes", ML_HEARD_YES }, + std::pair{ "no", ML_HEARD_NO }, + std::pair{ "up", ML_HEARD_UP }, + std::pair{ "down", ML_HEARD_DOWN }, + std::pair{ "left", ML_HEARD_LEFT }, + std::pair{ "right", ML_HEARD_RIGHT }, + std::pair{ "on", ML_HEARD_ON }, + std::pair{ "off", ML_HEARD_OFF }, + std::pair{ "go", ML_HEARD_GO }, + std::pair{ "stop", ML_HEARD_STOP }, }; extern "C" { - static void prvAppPublishCommandCallback( MQTTAgentCommandContext_t * pxCommandContext, MQTTAgentReturnInfo_t * pxReturnInfo ) { @@ -160,9 +159,9 @@ static void prvAppPublishCommandCallback( MQTTAgentCommandContext_t * pxCommandC static void prvMqttSendMessage( const char * message ) { - static MQTTPublishInfo_t publishInfo = { (MQTTQoS_t)0 }; + static MQTTPublishInfo_t publishInfo = { ( MQTTQoS_t ) 0 }; static MQTTAgentCommandInfo_t xCommandParams = { 0 }; - static MQTTAgentCommandContext_t xCommandContext = { (MQTTStatus_t)0 }; + static MQTTAgentCommandContext_t xCommandContext = { ( MQTTStatus_t ) 0 }; MQTTStatus_t mqttStatus = MQTTBadParameter; publishInfo.pTopicName = mqttexampleTOPIC; @@ -211,14 +210,14 @@ static void prvMqttSendMessage( const char * message ) } } -static const char *prvGetInferenceResultString(ml_processing_state_t ref_state) +static const char * prvGetInferenceResultString( ml_processing_state_t ref_state ) { - return (label_to_state[ref_state].first); + return( label_to_state[ ref_state ].first ); } void vMlTaskInferenceStart() { - if(xSystemEvents == NULL) + if( xSystemEvents == NULL ) { LogError( ( "xSystemEvents is not initialised\r\n" ) ); return; @@ -226,14 +225,14 @@ void vMlTaskInferenceStart() LogInfo( ( "Signal task inference start\r\n" ) ); - ( void )xEventGroupClearBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP ); + ( void ) xEventGroupClearBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP ); - ( void )xEventGroupSetBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_START ); + ( void ) xEventGroupSetBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START ); } void vMlTaskInferenceStop() { - if(xSystemEvents == NULL) + if( xSystemEvents == NULL ) { LogError( ( "xSystemEvents is not initialised\r\n" ) ); return; @@ -241,83 +240,89 @@ void vMlTaskInferenceStop() LogInfo( ( "Signal task inference stop\r\n" ) ); - ( void )xEventGroupClearBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_START ); + ( void ) xEventGroupClearBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START ); - ( void )xEventGroupSetBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP ); + ( void ) xEventGroupSetBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP ); } void vStartMlTask( void ) { - if (xTaskCreate( vMlTask, - "ML_TASK", - appCONFIG_ML_TASK_STACK_SIZE, - NULL, - appCONFIG_ML_TASK_PRIORITY, - NULL ) != pdPASS) { + if( xTaskCreate( vMlTask, + "ML_TASK", + appCONFIG_ML_TASK_STACK_SIZE, + NULL, + appCONFIG_ML_TASK_PRIORITY, + NULL ) != pdPASS ) + { LogError( ( "Failed to create Ml Task\r\n" ) ); } } void vStartMlMqttTask( void ) { - if (xTaskCreate( vMlMqttTask, - "ML_MQTT", - appCONFIG_ML_MQTT_TASK_STACK_SIZE, - NULL, - appCONFIG_ML_MQTT_TASK_PRIORITY, - NULL ) != pdPASS) { + if( xTaskCreate( vMlMqttTask, + "ML_MQTT", + appCONFIG_ML_MQTT_TASK_STACK_SIZE, + NULL, + appCONFIG_ML_MQTT_TASK_PRIORITY, + NULL ) != pdPASS ) + { LogError( ( "Failed to create Ml Mqtt Task\r\n" ) ); } } +} /* extern "C" */ -} // extern "C" - -static void prvSetMlProcessingState(ml_processing_state_t new_state) +static void prvSetMlProcessingState( ml_processing_state_t new_state ) { - // In this use case, only changes in state are relevant. Additionally, - // this avoids reporting the same keyword detected twice in adjacent, - // overlapping inference windows. - static ml_processing_state_t ml_processing_state{ML_SILENCE}; - if (new_state != ml_processing_state) { - if(xMlMqttQueue == NULL) + /* In this use case, only changes in state are relevant. Additionally, */ + /* this avoids reporting the same keyword detected twice in adjacent, */ + /* overlapping inference windows. */ + static ml_processing_state_t ml_processing_state{ ML_SILENCE }; + + if( new_state != ml_processing_state ) + { + if( xMlMqttQueue == NULL ) { LogError( ( "xMlMqttQueue is not initialised\r\n" ) ); return; } - const ml_mqtt_msg_t msg = {new_state}; - if (xQueueSendToBack(xMlMqttQueue, (void *)&msg, (TickType_t)0) != pdPASS) { + const ml_mqtt_msg_t msg = { new_state }; + + if( xQueueSendToBack( xMlMqttQueue, ( void * ) &msg, ( TickType_t ) 0 ) != pdPASS ) + { LogError( ( "Failed to send message to xMlMqttQueue\r\n" ) ); } ml_processing_state = new_state; - if (ml_processing_change_handler) { + + if( ml_processing_change_handler ) + { ml_processing_change_handler_t handler = ml_processing_change_handler; - void *handler_instance = ml_processing_change_ptr; + void * handler_instance = ml_processing_change_ptr; - handler(handler_instance, new_state); + handler( handler_instance, new_state ); } } } -// Model +/* Model */ arm::app::ApplicationContext caseContext; #ifdef AUDIO_VSI - extern "C" { -// Audio driver data -void (*pxOnVsiEvent)(void *); -void *pvVsiContext = nullptr; +/* Audio driver data */ +void (* pxOnVsiEvent)( void * ); +void * pvVsiContext = nullptr; } -// Audio driver callback function for event management -// Note: This function cannot contain any logging function -// because it would be called in an ISR and it is not permitted -// to use logging calls inside the ISR. -static void prvArmSaiSignalEvent(uint32_t event) +/* Audio driver callback function for event management */ +/* Note: This function cannot contain any logging function */ +/* because it would be called in an ISR and it is not permitted */ +/* to use logging calls inside the ISR. */ +static void prvArmSaiSignalEvent( uint32_t event ) { - if(xVsiTaskHandle == NULL) + if( xVsiTaskHandle == NULL ) { LogError( ( "VSI Task is not created\r\n" ) ); return; @@ -331,32 +336,38 @@ static void prvArmSaiSignalEvent(uint32_t event) portYIELD_FROM_ISR( xHigherPriorityTaskWoken ); } -static int prvAudioDrvSetup(void (*event_handler)(void *), void *event_handler_ptr) +static int prvAudioDrvSetup( void ( * event_handler )( void * ), + void * event_handler_ptr ) { - if (Driver_SAI0.Initialize(prvArmSaiSignalEvent) != ARM_DRIVER_OK) { + if( Driver_SAI0.Initialize( prvArmSaiSignalEvent ) != ARM_DRIVER_OK ) + { LogError( ( "Failed to set up FVP VSI!\n" ) ); return -1; } - if (Driver_SAI0.PowerControl(ARM_POWER_FULL) != ARM_DRIVER_OK) { + if( Driver_SAI0.PowerControl( ARM_POWER_FULL ) != ARM_DRIVER_OK ) + { LogError( ( "Failed to set the driver to operate with full power!\n" ) ); return -1; } - if (Driver_SAI0.Control(ARM_SAI_CONTROL_RX, 1, 0) != ARM_DRIVER_OK) { + if( Driver_SAI0.Control( ARM_SAI_CONTROL_RX, 1, 0 ) != ARM_DRIVER_OK ) + { LogError( ( "Failed to enable the VSI receiver!\n" ) ); return -1; } - if (Driver_SAI0.Control(ARM_SAI_CONFIGURE_RX | ARM_SAI_PROTOCOL_USER | ARM_SAI_DATA_SIZE(16), - AUDIO_BLOCK_SIZE, - static_cast(kAudioSampleFrequency)) - != ARM_DRIVER_OK) { + if( Driver_SAI0.Control( ARM_SAI_CONFIGURE_RX | ARM_SAI_PROTOCOL_USER | ARM_SAI_DATA_SIZE( 16 ), + AUDIO_BLOCK_SIZE, + static_cast( kAudioSampleFrequency ) ) + != ARM_DRIVER_OK ) + { LogError( ( "Failed to configure the receiver!\n" ) ); return -1; } - if (Driver_SAI0.Receive(reinterpret_cast(shared_audio_buffer), AUDIO_BLOCK_NUM) != ARM_DRIVER_OK) { + if( Driver_SAI0.Receive( reinterpret_cast( shared_audio_buffer ), AUDIO_BLOCK_NUM ) != ARM_DRIVER_OK ) + { LogError( ( "Failed to start receiving the data!\n" ) ); return -1; } @@ -373,103 +384,120 @@ static int prvAudioDrvSetup(void (*event_handler)(void *), void *event_handler_p * If data is not available, the audio processing thread goes to sleep until it * is woken up by the audio driver. */ -template struct CircularSlidingWindow { - CircularSlidingWindow( - const T *buffer, size_t block_size, size_t block_count, size_t window_size, size_t stride_size) - : buffer{buffer}, block_size{block_size}, block_count{block_count}, window_size{window_size}, stride_size{ - stride_size} +template struct CircularSlidingWindow +{ + CircularSlidingWindow( const T * buffer, + size_t block_size, + size_t block_count, + size_t window_size, + size_t stride_size ) + : buffer{ buffer }, block_size{ block_size }, block_count{ block_count }, window_size{ window_size }, stride_size{ + stride_size } { - // These are the requirements for the algorithm. - assert(stride_size < block_size); - assert(window_size > stride_size); - assert(block_size > window_size); - assert(block_size % stride_size == 0); - assert(window_size % stride_size == 0); - prvCreateBinarySemaphore(&xSlidingWindowSemaphore); + /* These are the requirements for the algorithm. */ + assert( stride_size < block_size ); + assert( window_size > stride_size ); + assert( block_size > window_size ); + assert( block_size % stride_size == 0 ); + assert( window_size % stride_size == 0 ); + prvCreateBinarySemaphore( &xSlidingWindowSemaphore ); } ~CircularSlidingWindow() { - if(xSlidingWindowSemaphore != NULL) + if( xSlidingWindowSemaphore != NULL ) { - vSemaphoreDelete(xSlidingWindowSemaphore); + vSemaphoreDelete( xSlidingWindowSemaphore ); } } - void next(T *dest) + void next( T * dest ) { - // Compute the block that contains the stride + /* Compute the block that contains the stride */ size_t first_block = current_stride / prvStridesPerBlock(); - auto last_block = ((current_stride * stride_size + window_size - 1) / block_size) % block_count; + auto last_block = ( ( current_stride * stride_size + window_size - 1 ) / block_size ) % block_count; - // Go to sleep if one of the block that contains the next stride is being written. - // If the stride is already loaded, copy it into the destination buffer. - while (first_block == prvGetBlockUnderWrite() || last_block == prvGetBlockUnderWrite()) { - if (xSlidingWindowSemaphore != NULL) { + /* Go to sleep if one of the block that contains the next stride is being written. */ + /* If the stride is already loaded, copy it into the destination buffer. */ + while( first_block == prvGetBlockUnderWrite() || last_block == prvGetBlockUnderWrite() ) + { + if( xSlidingWindowSemaphore != NULL ) + { BaseType_t ret = xSemaphoreTake( xSlidingWindowSemaphore, portMAX_DELAY ); - if (ret != pdTRUE) { + + if( ret != pdTRUE ) + { LogError( ( "xSemaphoreTake xSlidingWindowSemaphore failed %ld\r\n", ret ) ); } } } - // Copy the data into the destination buffer - auto begin = buffer + (current_stride * stride_size); + /* Copy the data into the destination buffer */ + auto begin = buffer + ( current_stride * stride_size ); - // Memory to copy may not be seqquential if a window span on two blocks. - if (last_block < first_block) { - // Copy end of the buffer - auto buffer_end = buffer + (block_size * block_count); - std::copy(begin, buffer_end, dest); - // Copy remaining from the begining + /* Memory to copy may not be seqquential if a window span on two blocks. */ + if( last_block < first_block ) + { + /* Copy end of the buffer */ + auto buffer_end = buffer + ( block_size * block_count ); + std::copy( begin, buffer_end, dest ); + /* Copy remaining from the begining */ auto offset = buffer_end - begin; - std::copy(buffer, buffer + (window_size - offset), dest + offset); - } else { - std::copy(begin, begin + window_size, dest); + std::copy( buffer, buffer + ( window_size - offset ), dest + offset ); + } + else + { + std::copy( begin, begin + window_size, dest ); } - // Compute the next stride + /* Compute the next stride */ ++current_stride; current_stride %= prvStrideCount(); } - // This is called from ISR - static void prvSignalBlockWritten(void *ptr) + /* This is called from ISR */ + static void prvSignalBlockWritten( void * ptr ) { - auto *self = reinterpret_cast *>(ptr); - // Update block ID - self->block_under_write = ((self->block_under_write + 1) % self->block_count); + auto * self = reinterpret_cast *>( ptr ); - if(self->xSlidingWindowSemaphore != NULL) + /* Update block ID */ + self->block_under_write = ( ( self->block_under_write + 1 ) % self->block_count ); + + if( self->xSlidingWindowSemaphore != NULL ) { BaseType_t yield = pdFALSE; - // Wakeup task waiting - if(xSemaphoreGiveFromISR(self->xSlidingWindowSemaphore, &yield) == pdTRUE) + + /* Wakeup task waiting */ + if( xSemaphoreGiveFromISR( self->xSlidingWindowSemaphore, &yield ) == pdTRUE ) { - portYIELD_FROM_ISR (yield); + portYIELD_FROM_ISR( yield ); } } - // safe to return as this can signal multiple times before the reader acquires the semaphore. + + /* safe to return as this can signal multiple times before the reader acquires the semaphore. */ } - static void prvCreateBinarySemaphore(SemaphoreHandle_t* xSemaphore) + static void prvCreateBinarySemaphore( SemaphoreHandle_t * xSemaphore ) { - *xSemaphore = xSemaphoreCreateBinary(); - if (*xSemaphore == NULL) { + * xSemaphore = xSemaphoreCreateBinary(); + + if( *xSemaphore == NULL ) + { LogError( ( "xSemaphoreCreateBinary failed \r\n" ) ); } - if (xSemaphoreGive (*xSemaphore) != pdPASS) { + if( xSemaphoreGive( *xSemaphore ) != pdPASS ) + { LogError( ( "xSemaphoreGive xSemaphore failed \r\n" ) ); - vSemaphoreDelete (*xSemaphore); - *xSemaphore = NULL; + vSemaphoreDelete( *xSemaphore ); + * xSemaphore = NULL; } } private: size_t prvStrideCount() const { - return ((block_size * block_count) / stride_size); + return( ( block_size * block_count ) / stride_size ); } size_t prvStridesPerBlock() const @@ -485,7 +513,7 @@ template struct CircularSlidingWindow { return result; } - const T *buffer; + const T * buffer; size_t block_size; /* write size */ size_t block_count; size_t window_size; @@ -494,7 +522,6 @@ template struct CircularSlidingWindow { size_t current_stride = 0; SemaphoreHandle_t xSlidingWindowSemaphore; }; - #endif /* AUDIO_VSI */ /** @@ -504,7 +531,7 @@ template struct CircularSlidingWindow { * @param[in] results Vector of classification results to be displayed. * @return true if successful, false otherwise. **/ -static bool prvPresentInferenceResult(const arm::app::kws::KwsResult &result); +static bool prvPresentInferenceResult( const arm::app::kws::KwsResult &result ); /** * @brief Returns a function to perform feature calculation and populates input tensor data with @@ -520,65 +547,73 @@ static bool prvPresentInferenceResult(const arm::app::kws::KwsResult &result); * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). * @return Function to be called providing audio sample and sliding window index. */ -static std::function &, int, bool, size_t)> -prvGetFeatureCalculator(audio::MicroNetKwsMFCC &mfcc, TfLiteTensor *inputTensor, size_t cacheSize); +static std::function &, int, bool, size_t )> prvGetFeatureCalculator( audio::MicroNetKwsMFCC &mfcc, + TfLiteTensor * inputTensor, + size_t cacheSize ); -// Convert labels into ml_processing_state_t -static ml_processing_state_t prvConvertInferenceResult(const std::string &label) +/* Convert labels into ml_processing_state_t */ +static ml_processing_state_t prvConvertInferenceResult( const std::string &label ) { - for (const auto &label_to_state_pair : label_to_state) { - if (label == label_to_state_pair.first) { + for( const auto &label_to_state_pair : label_to_state ) + { + if( label == label_to_state_pair.first ) + { return label_to_state_pair.second; } } + return ML_UNKNOWN; } -static void prvProcessAudio(ApplicationContext &ctx) +static void prvProcessAudio( ApplicationContext &ctx ) { - // Constants + /* Constants */ constexpr int minTensorDims = - static_cast((arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx) - ? arm::app::MicroNetKwsModel::ms_inputRowsIdx - : arm::app::MicroNetKwsModel::ms_inputColsIdx); + static_cast( ( arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx ) + ? arm::app::MicroNetKwsModel::ms_inputRowsIdx + : arm::app::MicroNetKwsModel::ms_inputColsIdx ); - // Get the global model - auto &model = ctx.Get("model"); + /* Get the global model */ + auto &model = ctx.Get( "model" ); - if (!model.IsInited()) { + if( !model.IsInited() ) + { LogError( ( "Model is not initialised! Terminating processing.\n" ) ); return; } - const auto frameLength = ctx.Get("frameLength"); // 640 - const auto frameStride = ctx.Get("frameStride"); // 320 - const auto scoreThreshold = ctx.Get("scoreThreshold"); // 0.8 + const auto frameLength = ctx.Get( "frameLength" ); /* 640 */ + const auto frameStride = ctx.Get( "frameStride" ); /* 320 */ + const auto scoreThreshold = ctx.Get( "scoreThreshold" ); /* 0.8 */ - // Input and output tensors - TfLiteTensor *outputTensor = model.GetOutputTensor(0); - TfLiteTensor *inputTensor = model.GetInputTensor(0); + /* Input and output tensors */ + TfLiteTensor * outputTensor = model.GetOutputTensor( 0 ); + TfLiteTensor * inputTensor = model.GetInputTensor( 0 ); - if (!inputTensor->dims) { + if( !inputTensor->dims ) + { LogError( ( "Invalid input tensor dims\n" ) ); return; - } else if (inputTensor->dims->size < minTensorDims) { + } + else if( inputTensor->dims->size < minTensorDims ) + { LogError( ( "Input tensor dimension should be >= %d\n", minTensorDims ) ); return; } - TfLiteIntArray *inputShape = model.GetInputShape(0); - const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx]; - const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; + TfLiteIntArray * inputShape = model.GetInputShape( 0 ); + const uint32_t kNumCols = inputShape->data[ arm::app::MicroNetKwsModel::ms_inputColsIdx ]; + const uint32_t kNumRows = inputShape->data[ arm::app::MicroNetKwsModel::ms_inputRowsIdx ]; - audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength); + audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC( kNumCols, frameLength ); mfcc.Init(); /* Deduce the data length required for 1 inference from the network parameters. */ - auto audioDataWindowSize = kNumRows * frameStride + (frameLength - frameStride); // 16000 -#ifdef AUDIO_VSI - auto mfccWindowSize = frameLength; // 640 -#endif /* AUDIO_VSI */ - auto mfccWindowStride = frameStride; // 320 + auto audioDataWindowSize = kNumRows * frameStride + ( frameLength - frameStride ); /* 16000 */ + #ifdef AUDIO_VSI + auto mfccWindowSize = frameLength; /* 640 */ + #endif /* AUDIO_VSI */ + auto mfccWindowStride = frameStride; /* 320 */ /* We choose to move by half the window size => for a 1 second window size * there is an overlap of 0.5 seconds. */ @@ -586,12 +621,13 @@ static void prvProcessAudio(ApplicationContext &ctx) /* To have the previously calculated features re-usable, stride must be multiple * of MFCC features window stride. */ - if (0 != audioDataStride % mfccWindowStride) { + if( 0 != audioDataStride % mfccWindowStride ) + { /* Reduce the stride. */ - audioDataStride -= audioDataStride % mfccWindowStride; // 8000 + audioDataStride -= audioDataStride % mfccWindowStride; /* 8000 */ } - auto nMfccVectorsInAudioStride = audioDataStride / mfccWindowStride; // 25 + auto nMfccVectorsInAudioStride = audioDataStride / mfccWindowStride; /* 25 */ /* We expect to be sampling 1 second worth of data at a time. * NOTE: This is only used for time stamp calculation. */ @@ -602,177 +638,196 @@ static void prvProcessAudio(ApplicationContext &ctx) auto numberOfReusedFeatureVectors = nMfccVectorsInAudioStride; /* Construct feature calculation function. */ - auto mfccFeatureCalc = prvGetFeatureCalculator(mfcc, inputTensor, numberOfReusedFeatureVectors); + auto mfccFeatureCalc = prvGetFeatureCalculator( mfcc, inputTensor, numberOfReusedFeatureVectors ); - if (!mfccFeatureCalc) { + if( !mfccFeatureCalc ) + { LogError( ( "No feature calculator available" ) ); return; } -#ifdef AUDIO_VSI - - // Initialize the sliding window - auto circularSlider = CircularSlidingWindow( - shared_audio_buffer, AUDIO_BLOCK_SIZE / sizeof(int16_t), AUDIO_BLOCK_NUM, mfccWindowSize, mfccWindowStride); - - // Initialize the audio driver. It is delayed until that point to avoid drop - // of starting frames. - prvAudioDrvSetup(&decltype(circularSlider)::prvSignalBlockWritten, &circularSlider); - - bool first_iteration = true; - auto mfccAudioData = std::vector(mfccWindowSize, 0); - size_t audio_index = 0; + #ifdef AUDIO_VSI + /* Initialize the sliding window */ + auto circularSlider = CircularSlidingWindow( + shared_audio_buffer, AUDIO_BLOCK_SIZE / sizeof( int16_t ), AUDIO_BLOCK_NUM, mfccWindowSize, mfccWindowStride ); -#endif /* AUDIO_VSI */ + /* Initialize the audio driver. It is delayed until that point to avoid drop */ + /* of starting frames. */ + prvAudioDrvSetup( &decltype( circularSlider )::prvSignalBlockWritten, &circularSlider ); - while (true) { + bool first_iteration = true; + auto mfccAudioData = std::vector( mfccWindowSize, 0 ); + size_t audio_index = 0; + #endif /* AUDIO_VSI */ -#ifdef AUDIO_VSI + while( true ) + { + #ifdef AUDIO_VSI + LogInfo( ( "Running inference as audio input is received from the Virtual Streaming Interface\r\n" ) ); - LogInfo( ( "Running inference as audio input is received from the Virtual Streaming Interface\r\n" ) ); + while( true ) + { + EventBits_t flags = xEventGroupWaitBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP, pdTRUE, pdFAIL, 10 ); - while (true) { - EventBits_t flags = xEventGroupWaitBits(xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP, pdTRUE, pdFAIL, 10); + if( flags & EVENT_MASK_ML_STOP ) + { + /* jump out to outer loop */ + LogInfo( ( "Stopping audio processing\r\n" ) ); + break; + } - if (flags & EVENT_MASK_ML_STOP) { - /* jump out to outer loop */ - LogInfo( ( "Stopping audio processing\r\n" ) ); - break; - } + /* The first window does not have cache ready. */ + bool useCache = first_iteration == false && numberOfReusedFeatureVectors > 0; + size_t stride_index = 0; - /* The first window does not have cache ready. */ - bool useCache = first_iteration == false && numberOfReusedFeatureVectors > 0; - size_t stride_index = 0; + while( stride_index < ( audioDataWindowSize / mfccWindowStride ) ) + { + if( !useCache || ( stride_index >= numberOfReusedFeatureVectors ) ) + { + circularSlider.next( mfccAudioData.data() ); + } - while (stride_index < (audioDataWindowSize / mfccWindowStride)) { - if (!useCache || stride_index >= numberOfReusedFeatureVectors) { - circularSlider.next(mfccAudioData.data()); + /* Compute features for this window and write them to input tensor. */ + mfccFeatureCalc( mfccAudioData, stride_index, useCache, nMfccVectorsInAudioStride ); + ++stride_index; } - /* Compute features for this window and write them to input tensor. */ - mfccFeatureCalc(mfccAudioData, stride_index, useCache, nMfccVectorsInAudioStride); - ++stride_index; - } - - /* Run inference over this audio clip sliding window. */ - if (!model.RunInference()) { - LogError( ( "Failed to run inference" ) ); - return; - } - - std::vector classificationResult; - auto &classifier = ctx.Get("classifier"); - classifier.GetClassificationResults( - outputTensor, classificationResult, ctx.Get &>("labels"), 1, true); + /* Run inference over this audio clip sliding window. */ + if( !model.RunInference() ) + { + LogError( ( "Failed to run inference" ) ); + return; + } - auto result = kws::KwsResult( - classificationResult, audio_index * secondsPerSample * audioDataStride, audio_index, scoreThreshold); + std::vector classificationResult; + auto &classifier = ctx.Get( "classifier" ); + classifier.GetClassificationResults( + outputTensor, classificationResult, ctx.Get &>( "labels" ), 1, true ); - if (result.m_resultVec.empty()) { - prvSetMlProcessingState(ML_UNKNOWN); - } else { - prvSetMlProcessingState(prvConvertInferenceResult(result.m_resultVec[0].m_label)); - } + auto result = kws::KwsResult( + classificationResult, audio_index * secondsPerSample * audioDataStride, audio_index, scoreThreshold ); - if (prvPresentInferenceResult(result) != true) { - LogError( ( "Failed to present inference result" ) ); - return; - } - first_iteration = false; - ++audio_index; - } /* while (true) */ + if( result.m_resultVec.empty() ) + { + prvSetMlProcessingState( ML_UNKNOWN ); + } + else + { + prvSetMlProcessingState( prvConvertInferenceResult( result.m_resultVec[ 0 ].m_label ) ); + } -#else /* !defined(AUDIO_VSI) */ + if( prvPresentInferenceResult( result ) != true ) + { + LogError( ( "Failed to present inference result" ) ); + return; + } - LogInfo( ( "Running inference on an audio clip in local memory\r\n" ) ); + first_iteration = false; + ++audio_index; + } /* while (true) */ + #else /* !defined(AUDIO_VSI) */ + LogInfo( ( "Running inference on an audio clip in local memory\r\n" ) ); - const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx]; - const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; + const uint32_t numMfccFeatures = inputShape->data[ MicroNetKwsModel::ms_inputColsIdx ]; + const uint32_t numMfccFrames = inputShape->data[ arm::app::MicroNetKwsModel::ms_inputRowsIdx ]; - KwsPreProcess preProcess = KwsPreProcess( - inputTensor, numMfccFeatures, numMfccFrames, ctx.Get("frameLength"), ctx.Get("frameStride")); + KwsPreProcess preProcess = KwsPreProcess( + inputTensor, numMfccFeatures, numMfccFrames, ctx.Get( "frameLength" ), ctx.Get( "frameStride" ) ); - std::vector singleInfResult; - KwsPostProcess postProcess = KwsPostProcess(outputTensor, - ctx.Get("classifier"), - ctx.Get &>("labels"), - singleInfResult); + std::vector singleInfResult; + KwsPostProcess postProcess = KwsPostProcess( outputTensor, + ctx.Get( "classifier" ), + ctx.Get &>( "labels" ), + singleInfResult ); - /* Creating a sliding window through the whole audio clip. */ - auto audioDataSlider = audio::SlidingWindow( - GetAudioArray(0), GetAudioArraySize(0), preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride); + /* Creating a sliding window through the whole audio clip. */ + auto audioDataSlider = audio::SlidingWindow( + GetAudioArray( 0 ), GetAudioArraySize( 0 ), preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride ); - /* Start sliding through audio clip. */ - while (audioDataSlider.HasNext()) { - EventBits_t flags = xEventGroupWaitBits(xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP, pdTRUE, pdFAIL, 10); + /* Start sliding through audio clip. */ + while( audioDataSlider.HasNext() ) + { + EventBits_t flags = xEventGroupWaitBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP, pdTRUE, pdFAIL, 10 ); - if (flags & EVENT_MASK_ML_STOP) { - /* Jump out to the outer loop, which may restart inference on an EVENT_MASK_ML_START signal */ - LogInfo( ( "Inference stopped by a signal.\r\n" ) ); - break; - } + if( flags & EVENT_MASK_ML_STOP ) + { + /* Jump out to the outer loop, which may restart inference on an EVENT_MASK_ML_START signal */ + LogInfo( ( "Inference stopped by a signal.\r\n" ) ); + break; + } - const int16_t *inferenceWindow = audioDataSlider.Next(); - if (!preProcess.DoPreProcess(inferenceWindow, audioDataSlider.Index())) { - LogError( ( "Pre-processing failed." ) ); - return; - } + const int16_t * inferenceWindow = audioDataSlider.Next(); - if (!model.RunInference()) { - LogError( ( "Inference failed." ) ); - return; - } + if( !preProcess.DoPreProcess( inferenceWindow, audioDataSlider.Index() ) ) + { + LogError( ( "Pre-processing failed." ) ); + return; + } - if (!postProcess.DoPostProcess()) { - LogError( ( "Post-processing failed." ) ); - return; - } + if( !model.RunInference() ) + { + LogError( ( "Inference failed." ) ); + return; + } - auto result = kws::KwsResult(singleInfResult, - audioDataSlider.Index() * secondsPerSample * preProcess.m_audioDataStride, - audioDataSlider.Index(), - scoreThreshold); + if( !postProcess.DoPostProcess() ) + { + LogError( ( "Post-processing failed." ) ); + return; + } - if (result.m_resultVec.empty()) { - prvSetMlProcessingState(ML_UNKNOWN); - } else { - prvSetMlProcessingState(prvConvertInferenceResult(result.m_resultVec[0].m_label)); - } + auto result = kws::KwsResult( singleInfResult, + audioDataSlider.Index() * secondsPerSample * preProcess.m_audioDataStride, + audioDataSlider.Index(), + scoreThreshold ); - if (prvPresentInferenceResult(result) != true) { - LogError( ( "Failed to present inference result" ) ); - return; - } - } /* while (audioDataSlider.HasNext()) */ + if( result.m_resultVec.empty() ) + { + prvSetMlProcessingState( ML_UNKNOWN ); + } + else + { + prvSetMlProcessingState( prvConvertInferenceResult( result.m_resultVec[ 0 ].m_label ) ); + } -#endif /* AUDIO_VSI */ + if( prvPresentInferenceResult( result ) != true ) + { + LogError( ( "Failed to present inference result" ) ); + return; + } + } /* while (audioDataSlider.HasNext()) */ + #endif /* AUDIO_VSI */ - EventBits_t flags = xEventGroupWaitBits(xSystemEvents, (EventBits_t)EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY); + EventBits_t flags = xEventGroupWaitBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY ); - if (flags & EVENT_MASK_ML_START) { + if( flags & EVENT_MASK_ML_START ) + { LogInfo( ( "Restarting audio processing %u\r\n", flags ) ); } } /* while (true) */ } -static bool prvPresentInferenceResult(const arm::app::kws::KwsResult &result) +static bool prvPresentInferenceResult( const arm::app::kws::KwsResult &result ) { - /* Display each result */ - if (result.m_resultVec.empty()) { + if( result.m_resultVec.empty() ) + { LogInfo( ( "For timestamp: %f (inference #: %" PRIu32 "); label: %s; threshold: %f\n", - (double)result.m_timeStamp, - result.m_inferenceNumber, - "", - 0. ) ); - } else { - for (uint32_t i = 0; i < result.m_resultVec.size(); ++i) { + ( double ) result.m_timeStamp, + result.m_inferenceNumber, + "", + 0. ) ); + } + else + { + for( uint32_t i = 0; i < result.m_resultVec.size(); ++i ) + { LogInfo( ( "For timestamp: %f (inference #: %" PRIu32 "); label: %s, score: %f; threshold: %f\n", - (double)result.m_timeStamp, - result.m_inferenceNumber, - result.m_resultVec[i].m_label.c_str(), - result.m_resultVec[i].m_normalisedVal, - (double)result.m_threshold ) ); + ( double ) result.m_timeStamp, + result.m_inferenceNumber, + result.m_resultVec[ i ].m_label.c_str(), + result.m_resultVec[ i ].m_normalisedVal, + ( double ) result.m_threshold ) ); } } @@ -792,153 +847,170 @@ static bool prvPresentInferenceResult(const arm::app::kws::KwsResult &result) * @param[in] compute Features calculator function. * @return Lambda function to compute features. */ -template -std::function &, size_t, bool, size_t)> -FeatureCalc(TfLiteTensor *inputTensor, size_t cacheSize, std::function(std::vector &)> compute) +template +std::function &, size_t, bool, size_t )> FeatureCalc( TfLiteTensor * inputTensor, + size_t cacheSize, + std::function( std::vector & )> compute ) { /* Feature cache to be captured by lambda function. */ - static std::vector> featureCache = std::vector>(cacheSize); - - return [=](std::vector &audioDataWindow, size_t index, bool useCache, size_t featuresOverlapIndex) { - T *tensorData = tflite::GetTensorData(inputTensor); - std::vector features; - - /* Reuse features from cache if cache is ready and sliding windows overlap. - * Overlap is in the beginning of sliding window with a size of a feature cache. */ - if (useCache && index < featureCache.size()) { - features = std::move(featureCache[index]); - } else { - features = std::move(compute(audioDataWindow)); - } - auto size = features.size(); - auto sizeBytes = sizeof(T) * size; - std::memcpy(tensorData + (index * size), features.data(), sizeBytes); - - /* Start renewing cache as soon iteration goes out of the windows overlap. */ - if (index >= featuresOverlapIndex) { - featureCache[index - featuresOverlapIndex] = std::move(features); - } + static std::vector > featureCache = std::vector >( cacheSize ); + + return [ = ]( std::vector &audioDataWindow, size_t index, bool useCache, size_t featuresOverlapIndex ) { + T * tensorData = tflite::GetTensorData( inputTensor ); + std::vector features; + + /* Reuse features from cache if cache is ready and sliding windows overlap. + * Overlap is in the beginning of sliding window with a size of a feature cache. */ + if( useCache && ( index < featureCache.size() ) ) + { + features = std::move( featureCache[ index ] ); + } + else + { + features = std::move( compute( audioDataWindow ) ); + } + + auto size = features.size(); + auto sizeBytes = sizeof( T ) * size; + std::memcpy( tensorData + ( index * size ), features.data(), sizeBytes ); + + /* Start renewing cache as soon iteration goes out of the windows overlap. */ + if( index >= featuresOverlapIndex ) + { + featureCache[ index - featuresOverlapIndex ] = std::move( features ); + } }; } -template std::function &, size_t, bool, size_t)> FeatureCalc( - TfLiteTensor *inputTensor, size_t cacheSize, std::function(std::vector &)> compute); +template std::function &, size_t, bool, size_t )> FeatureCalc( TfLiteTensor * inputTensor, + size_t cacheSize, + std::function( std::vector & )> compute ); -template std::function &, size_t, bool, size_t)> FeatureCalc( - TfLiteTensor *inputTensor, size_t cacheSize, std::function(std::vector &)> compute); +template std::function &, size_t, bool, size_t )> FeatureCalc( TfLiteTensor * inputTensor, + size_t cacheSize, + std::function( std::vector & )> compute ); -template std::function &, size_t, bool, size_t)> FeatureCalc( - TfLiteTensor *inputTensor, size_t cacheSize, std::function(std::vector &)> compute); +template std::function &, size_t, bool, size_t )> FeatureCalc( TfLiteTensor * inputTensor, + size_t cacheSize, + std::function( std::vector & )> compute ); -template std::function &, size_t, bool, size_t)> FeatureCalc( - TfLiteTensor *inputTensor, size_t cacheSize, std::function(std::vector &)> compute); +template std::function &, size_t, bool, size_t )> FeatureCalc( TfLiteTensor * inputTensor, + size_t cacheSize, + std::function( std::vector & )> compute ); -static std::function &, int, bool, size_t)> -prvGetFeatureCalculator(audio::MicroNetKwsMFCC &mfcc, TfLiteTensor *inputTensor, size_t cacheSize) +static std::function &, int, bool, size_t )> prvGetFeatureCalculator( audio::MicroNetKwsMFCC &mfcc, + TfLiteTensor * inputTensor, + size_t cacheSize ) { - std::function &, size_t, bool, size_t)> mfccFeatureCalc; + std::function &, size_t, bool, size_t )> mfccFeatureCalc; TfLiteQuantization quant = inputTensor->quantization; - if (kTfLiteAffineQuantization == quant.type) { - auto *quantParams = static_cast(quant.params); - const float quantScale = quantParams->scale->data[0]; - const int quantOffset = quantParams->zero_point->data[0]; + if( kTfLiteAffineQuantization == quant.type ) + { + auto * quantParams = static_cast( quant.params ); + const float quantScale = quantParams->scale->data[ 0 ]; + const int quantOffset = quantParams->zero_point->data[ 0 ]; - switch (inputTensor->type) { - case kTfLiteInt8: { + switch( inputTensor->type ) + { + case kTfLiteInt8: mfccFeatureCalc = - FeatureCalc(inputTensor, cacheSize, [=, &mfcc](std::vector &audioDataWindow) { - return mfcc.MfccComputeQuant(audioDataWindow, quantScale, quantOffset); - }); + FeatureCalc( inputTensor, cacheSize, [ =, &mfcc ]( std::vector &audioDataWindow ) { + return mfcc.MfccComputeQuant( audioDataWindow, quantScale, quantOffset ); + } ); break; - } - case kTfLiteUInt8: { + + case kTfLiteUInt8: mfccFeatureCalc = - FeatureCalc(inputTensor, cacheSize, [=, &mfcc](std::vector &audioDataWindow) { - return mfcc.MfccComputeQuant(audioDataWindow, quantScale, quantOffset); - }); + FeatureCalc( inputTensor, cacheSize, [ =, &mfcc ]( std::vector &audioDataWindow ) { + return mfcc.MfccComputeQuant( audioDataWindow, quantScale, quantOffset ); + } ); break; - } - case kTfLiteInt16: { + + case kTfLiteInt16: mfccFeatureCalc = - FeatureCalc(inputTensor, cacheSize, [=, &mfcc](std::vector &audioDataWindow) { - return mfcc.MfccComputeQuant(audioDataWindow, quantScale, quantOffset); - }); + FeatureCalc( inputTensor, cacheSize, [ =, &mfcc ]( std::vector &audioDataWindow ) { + return mfcc.MfccComputeQuant( audioDataWindow, quantScale, quantOffset ); + } ); break; - } + default: - LogError( ( "Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type) ) ); + LogError( ( "Tensor type %s not supported\n", TfLiteTypeGetName( inputTensor->type ) ) ); } - - } else { - mfccFeatureCalc = FeatureCalc(inputTensor, cacheSize, [&mfcc](std::vector &audioDataWindow) { - return mfcc.MfccCompute(audioDataWindow); - }); } + else + { + mfccFeatureCalc = FeatureCalc( inputTensor, cacheSize, [ &mfcc ]( std::vector &audioDataWindow ) { + return mfcc.MfccCompute( audioDataWindow ); + } ); + } + return mfccFeatureCalc; } - -} // anonymous namespace +} /* anonymous namespace */ #ifdef USE_ETHOS -extern struct ethosu_driver ethosu_drv; /* Default Ethos-U55 device driver */ + extern struct ethosu_driver ethosu_drv; /* Default Ethos-U55 device driver */ /** * @brief Initialises the Arm Ethos-U55 NPU * @return 0 if successful, error code otherwise **/ -static int prvArmNpuInit(void); + static int prvArmNpuInit( void ); -static int prvArmNpuInit(void) -{ - int err = 0; + static int prvArmNpuInit( void ) + { + int err = 0; - SCB_EnableICache(); - SCB_EnableDCache(); + SCB_EnableICache(); + SCB_EnableDCache(); -#if defined(ETHOS_U_NPU_TIMING_ADAPTER_ENABLED) - /* If the platform has timing adapter blocks along with Ethos-U core - * block, initialise them here. */ - if (0 != (err = arm_ethosu_timing_adapter_init())) { - LogError( ("Failed to init timing adapter\n") ); - return err; - } -#endif /* ETHOS_U_NPU_TIMING_ADAPTER_ENABLED */ + #if defined( ETHOS_U_NPU_TIMING_ADAPTER_ENABLED ) - // Initialize the ethos NPU - if (0 != (err = arm_ethosu_npu_init())) { - LogError( ("Failed to init arm npu\n") ); - return err; - } + /* If the platform has timing adapter blocks along with Ethos-U core + * block, initialise them here. */ + if( 0 != ( err = arm_ethosu_timing_adapter_init() ) ) + { + LogError( ( "Failed to init timing adapter\n" ) ); + return err; + } + #endif /* ETHOS_U_NPU_TIMING_ADAPTER_ENABLED */ + + /* Initialize the ethos NPU */ + if( 0 != ( err = arm_ethosu_npu_init() ) ) + { + LogError( ( "Failed to init arm npu\n" ) ); + return err; + } - LogInfo( ( "Ethos-U55 device initialised\n" ) ); + LogInfo( ( "Ethos-U55 device initialised\n" ) ); - /* Get Ethos-U55 version */ - struct ethosu_driver_version driver_version; - struct ethosu_hw_info hw_info; + /* Get Ethos-U55 version */ + struct ethosu_driver_version driver_version; + struct ethosu_hw_info hw_info; - ethosu_get_driver_version(&driver_version); - ethosu_get_hw_info(ðosu_drv, &hw_info); + ethosu_get_driver_version( &driver_version ); + ethosu_get_hw_info( ðosu_drv, &hw_info ); - LogInfo( ( "Ethos-U version info:\n" ) ); - LogInfo( ( "\tArch: v%" PRIu32 ".%" PRIu32 ".%" PRIu32 "\n", - hw_info.version.arch_major_rev, - hw_info.version.arch_minor_rev, - hw_info.version.arch_patch_rev ) ); - LogInfo( ( "\tDriver: v%" PRIu8 ".%" PRIu8 ".%" PRIu8 "\n", - driver_version.major, - driver_version.minor, - driver_version.patch ) ); - LogInfo( ( "\tMACs/cc: %" PRIu32 "\n", (uint32_t)(1 << hw_info.cfg.macs_per_cc) ) ); - LogInfo( ( "\tCmd stream: v%" PRIu32 "\n", hw_info.cfg.cmd_stream_version ) ); + LogInfo( ( "Ethos-U version info:\n" ) ); + LogInfo( ( "\tArch: v%" PRIu32 ".%" PRIu32 ".%" PRIu32 "\n", + hw_info.version.arch_major_rev, + hw_info.version.arch_minor_rev, + hw_info.version.arch_patch_rev ) ); + LogInfo( ( "\tDriver: v%" PRIu8 ".%" PRIu8 ".%" PRIu8 "\n", + driver_version.major, + driver_version.minor, + driver_version.patch ) ); + LogInfo( ( "\tMACs/cc: %" PRIu32 "\n", ( uint32_t ) ( 1 << hw_info.cfg.macs_per_cc ) ) ); + LogInfo( ( "\tCmd stream: v%" PRIu32 "\n", hw_info.cfg.cmd_stream_version ) ); - return 0; -} + return 0; + } #endif /* USE_ETHOS */ extern "C" { - -void vRegisterMlProcessingChangeCb(ml_processing_change_handler_t handler, void *ctx) +void vRegisterMlProcessingChangeCb( ml_processing_change_handler_t handler, + void * ctx ) { ml_processing_change_handler = handler; ml_processing_change_ptr = ctx; @@ -948,74 +1020,81 @@ static int prvMlInterfaceInit() { static arm::app::MicroNetKwsModel model; /* Model wrapper object. */ -#ifdef USE_ETHOS - // Initialize the ethos U55 - if (prvArmNpuInit() != 0) { - LogError( ( "Failed to arm npu\n" ) ); - return -1; - } -#endif /* USE_ETHOS */ + #ifdef USE_ETHOS + /* Initialize the ethos U55 */ + if( prvArmNpuInit() != 0 ) + { + LogError( ( "Failed to arm npu\n" ) ); + return -1; + } + #endif /* USE_ETHOS */ /* Load the model. */ - if (!model.Init(::arm::app::tensorArena, - sizeof(::arm::app::tensorArena), - ::arm::app::kws::GetModelPointer(), - ::arm::app::kws::GetModelLen())) { + if( !model.Init( ::arm::app::tensorArena, + sizeof( ::arm::app::tensorArena ), + ::arm::app::kws::GetModelPointer(), + ::arm::app::kws::GetModelLen() ) ) + { LogError( ( "Failed to initialise model\n" ) ); return -1; } /* Instantiate application context. */ - caseContext.Set("model", model); - caseContext.Set("frameLength", arm::app::kws::g_FrameLength); - caseContext.Set("frameStride", arm::app::kws::g_FrameStride); - caseContext.Set("scoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set( "model", model ); + caseContext.Set( "frameLength", arm::app::kws::g_FrameLength ); + caseContext.Set( "frameStride", arm::app::kws::g_FrameStride ); + caseContext.Set( "scoreThreshold", arm::app::kws::g_ScoreThreshold ); /* Normalised score threshold. */ - static KwsClassifier classifier; /* classifier wrapper object. */ - caseContext.Set("classifier", classifier); + static KwsClassifier classifier; /* classifier wrapper object. */ + caseContext.Set( "classifier", classifier ); static std::vector labels; - GetLabelsVector(labels); + GetLabelsVector( labels ); - caseContext.Set &>("labels", labels); + caseContext.Set &>( "labels", labels ); PrintTensorFlowVersion(); LogInfo( ( "*** ML interface initialised\r\n" ) ); return 0; } -void vMlTask(void *arg) +void vMlTask( void * arg ) { - (void)arg; + ( void ) arg; - EventBits_t flags = xEventGroupWaitBits(xSystemEvents, (EventBits_t)EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY); + EventBits_t flags = xEventGroupWaitBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY ); - if (flags & EVENT_MASK_ML_START) + if( flags & EVENT_MASK_ML_START ) { LogInfo( ( "Initial start of audio processing\r\n" ) ); } - if (prvMlInterfaceInit() < 0) { + if( prvMlInterfaceInit() < 0 ) + { LogError( ( "prvMlInterfaceInit failed\r\n" ) ); return; } - prvProcessAudio(caseContext); + prvProcessAudio( caseContext ); } -void vMlMqttTask(void *arg) +void vMlMqttTask( void * arg ) { - (void)arg; + ( void ) arg; - while (1) { + while( 1 ) + { ml_mqtt_msg_t msg; - if (xQueueReceive (xMlMqttQueue, &msg, portMAX_DELAY) == pdPASS) { - prvMqttSendMessage(prvGetInferenceResultString(msg.state)); - } else { + + if( xQueueReceive( xMlMqttQueue, &msg, portMAX_DELAY ) == pdPASS ) + { + prvMqttSendMessage( prvGetInferenceResultString( msg.state ) ); + } + else + { LogError( ( "xQueueReceive Ml Mqtt Queue failed\r\n" ) ); return; } } } - -} // extern "C" +} /* extern "C" */ diff --git a/applications/object_detection/ml_interface.cc b/applications/object_detection/ml_interface.cc old mode 100755 new mode 100644 index c162300..7ba4698 --- a/applications/object_detection/ml_interface.cc +++ b/applications/object_detection/ml_interface.cc @@ -82,29 +82,30 @@ extern MQTTAgentContext_t xGlobalMqttAgentContext; extern EventGroupHandle_t xSystemEvents; extern QueueHandle_t xMlMqttQueue; -// Define tensor arena and declare functions required to access the model +/* Define tensor arena and declare functions required to access the model */ namespace arm { - namespace app { - uint8_t ucTensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE; - namespace object_detection { - extern uint8_t *GetModelPointer(); - extern size_t GetModelLen(); - } /* namespace object_detection */ - } /* namespace app */ +namespace app { +uint8_t ucTensorArena[ ACTIVATION_BUF_SZ ] ACTIVATION_BUF_ATTRIBUTE; +namespace object_detection { +extern uint8_t * GetModelPointer(); +extern size_t GetModelLen(); +} /* namespace object_detection */ +} /* namespace app */ } /* namespace arm */ namespace { - -// Import +/* Import */ using namespace arm::app; -// Model +/* Model */ arm::app::ApplicationContext xCaseContext; -static int prvProcessImage(ApplicationContext &xApplicationContext, const uint8_t *pucImage, struct DetectRegion_t *pxCResults, uint32_t *pulResultsNum); +static int prvProcessImage( ApplicationContext &xApplicationContext, + const uint8_t * pucImage, + struct DetectRegion_t * pxCResults, + uint32_t * pulResultsNum ); extern "C" { - static void prvAppPublishCommandCallback( MQTTAgentCommandContext_t * pxCommandContext, MQTTAgentReturnInfo_t * pxReturnInfo ) { @@ -118,9 +119,9 @@ static void prvAppPublishCommandCallback( MQTTAgentCommandContext_t * pxCommandC static void prvMqttSendMessage( const char * pcMessage ) { - static MQTTPublishInfo_t xPublishInfo = { (MQTTQoS_t)0 }; + static MQTTPublishInfo_t xPublishInfo = { ( MQTTQoS_t ) 0 }; static MQTTAgentCommandInfo_t xCommandParams = { 0 }; - static MQTTAgentCommandContext_t xCommandContext = { (MQTTStatus_t)0 }; + static MQTTAgentCommandContext_t xCommandContext = { ( MQTTStatus_t ) 0 }; MQTTStatus_t xMqttStatus = MQTTBadParameter; xPublishInfo.pTopicName = mqttexampleTOPIC; @@ -169,36 +170,35 @@ static void prvMqttSendMessage( const char * pcMessage ) } } -void vMlTaskInferenceStart(void) +void vMlTaskInferenceStart( void ) { - if(xSystemEvents == NULL) + if( xSystemEvents == NULL ) { LogError( ( "xSystemEvents is not initialised\r\n" ) ); return; } LogInfo( ( "Signal task inference start\r\n" ) ); - ( void )xEventGroupClearBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP ); - ( void )xEventGroupSetBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_START ); - + ( void ) xEventGroupClearBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP ); + ( void ) xEventGroupSetBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START ); } -void vMlTaskInferenceStop(void) +void vMlTaskInferenceStop( void ) { - if(xSystemEvents == NULL) + if( xSystemEvents == NULL ) { LogError( ( "xSystemEvents is not initialised\r\n" ) ); return; } LogInfo( ( "Signal task inference stop\r\n" ) ); - ( void )xEventGroupClearBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_START ); - ( void )xEventGroupSetBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP ); + ( void ) xEventGroupClearBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START ); + ( void ) xEventGroupSetBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP ); } -void vStartMlTask( void *pvParameters ) +void vStartMlTask( void * pvParameters ) { - if ( + if( xTaskCreate( vMlTask, "ML_TASK", @@ -206,15 +206,16 @@ void vStartMlTask( void *pvParameters ) pvParameters, appCONFIG_ML_TASK_PRIORITY, NULL - ) != pdPASS - ) { + ) != pdPASS + ) + { LogError( ( "Failed to create ML Task\r\n" ) ); } } void vStartMlMqttTask( void ) { - if ( + if( xTaskCreate( vMlMqttTask, "ML_MQTT", @@ -222,38 +223,47 @@ void vStartMlMqttTask( void ) NULL, appCONFIG_ML_MQTT_TASK_PRIORITY, NULL - ) != pdPASS - ) { + ) != pdPASS + ) + { LogError( ( "Failed to create ML Mqtt Task\r\n" ) ); } } -int32_t lMLRunInference(const uint8_t *pucImg, struct DetectRegion_t *pxResults, uint32_t *pulResultsNum) +int32_t lMLRunInference( const uint8_t * pucImg, + struct DetectRegion_t * pxResults, + uint32_t * pulResultsNum ) { - prvProcessImage(xCaseContext, pucImg, pxResults, pulResultsNum); + prvProcessImage( xCaseContext, pucImg, pxResults, pulResultsNum ); return 0; } -} // extern "C" { +} /* extern "C" { */ -static void prvSetMlProcessingstate(const char *pcInferenceResult) +static void prvSetMlProcessingstate( const char * pcInferenceResult ) { - size_t xMsgLen = strlen(pcInferenceResult) + 1; - char *pcMsgResult = reinterpret_cast(malloc(xMsgLen)); - if (pcMsgResult) { - if(xMlMqttQueue == NULL) + size_t xMsgLen = strlen( pcInferenceResult ) + 1; + char * pcMsgResult = reinterpret_cast( malloc( xMsgLen ) ); + + if( pcMsgResult ) + { + if( xMlMqttQueue == NULL ) { LogError( ( "xMlMqttQueue is not initialised\r\n" ) ); - free(reinterpret_cast(pcMsgResult)); + free( reinterpret_cast( pcMsgResult ) ); return; } - memcpy(pcMsgResult, pcInferenceResult, xMsgLen); - const MLMqttMsg_t msg = {pcMsgResult}; - if (xQueueSendToBack(xMlMqttQueue, (void *)&msg, (TickType_t)0) != pdTRUE) { + memcpy( pcMsgResult, pcInferenceResult, xMsgLen ); + const MLMqttMsg_t msg = { pcMsgResult }; + + if( xQueueSendToBack( xMlMqttQueue, ( void * ) &msg, ( TickType_t ) 0 ) != pdTRUE ) + { LogError( ( "Failed to send message to xMlMqttQueue\r\n" ) ); - free(reinterpret_cast(pcMsgResult)); + free( reinterpret_cast( pcMsgResult ) ); } - } else { + } + else + { LogWarn( ( "Failed to send ml processing inference_result (alloc failure)" ) ); } } @@ -265,100 +275,116 @@ static void prvSetMlProcessingstate(const char *pcInferenceResult) * @param[in] results Vector of classification results to be displayed. * @return true if successful, false otherwise. **/ -static bool prvPresentInferenceResult(const std::vector &xResults); +static bool prvPresentInferenceResult( const std::vector &xResults ); -static int prvProcessImage(ApplicationContext &xApplicationContext, const uint8_t *pucImage, struct DetectRegion_t *pxCResults, uint32_t *pulResultsNum) +static int prvProcessImage( ApplicationContext &xApplicationContext, + const uint8_t * pucImage, + struct DetectRegion_t * pxCResults, + uint32_t * pulResultsNum ) { - // Get the global model - auto &xModel = xApplicationContext.Get("model"); + /* Get the global model */ + auto &xModel = xApplicationContext.Get( "model" ); - if (!xModel.IsInited()) { - LogError( ("Model is not initialised! Terminating processing.\n") ); + if( !xModel.IsInited() ) + { + LogError( ( "Model is not initialised! Terminating processing.\n" ) ); return -1; } - TfLiteTensor *xInputTensor = xModel.GetInputTensor(0); - TfLiteTensor *xOutputTensor0 = xModel.GetOutputTensor(0); - TfLiteTensor *xOutputTensor1 = xModel.GetOutputTensor(1); + TfLiteTensor * xInputTensor = xModel.GetInputTensor( 0 ); + TfLiteTensor * xOutputTensor0 = xModel.GetOutputTensor( 0 ); + TfLiteTensor * xOutputTensor1 = xModel.GetOutputTensor( 1 ); - if (!xInputTensor->dims) { - LogError( ("Invalid input tensor dims\n") ); + if( !xInputTensor->dims ) + { + LogError( ( "Invalid input tensor dims\n" ) ); return -1; - } else if (xInputTensor->dims->size < 3) { - LogError( ("Input tensor dimension should be >= 3\n") ); + } + else if( xInputTensor->dims->size < 3 ) + { + LogError( ( "Input tensor dimension should be >= 3\n" ) ); return -1; } - TfLiteIntArray *pxInputShape = xModel.GetInputShape(0); + TfLiteIntArray * pxInputShape = xModel.GetInputShape( 0 ); - const int lInputImgCols = pxInputShape->data[YoloFastestModel::ms_inputColsIdx]; - const int lInputImgRows = pxInputShape->data[YoloFastestModel::ms_inputRowsIdx]; + const int lInputImgCols = pxInputShape->data[ YoloFastestModel::ms_inputColsIdx ]; + const int lInputImgRows = pxInputShape->data[ YoloFastestModel::ms_inputRowsIdx ]; /* Set up pre and post-processing. */ /* RGB to grayscale skipped, already done outside */ - DetectorPreProcess xPreProcess = DetectorPreProcess(xInputTensor, false, xModel.IsDataSigned()); + DetectorPreProcess xPreProcess = DetectorPreProcess( xInputTensor, false, xModel.IsDataSigned() ); std::vector xResults; - const object_detection::PostProcessParams xPostProcessParams{lInputImgRows, - lInputImgCols, - object_detection::originalImageSize, - object_detection::anchor1, - object_detection::anchor2}; - DetectorPostProcess xPostProcess = DetectorPostProcess(xOutputTensor0, xOutputTensor1, xResults, xPostProcessParams); + const object_detection::PostProcessParams xPostProcessParams{ lInputImgRows, + lInputImgCols, + object_detection::originalImageSize, + object_detection::anchor1, + object_detection::anchor2 }; + DetectorPostProcess xPostProcess = DetectorPostProcess( xOutputTensor0, xOutputTensor1, xResults, xPostProcessParams ); /* Ensure there are no results leftover from previous inference when running all. */ xResults.clear(); /* Run the pre-processing, inference and post-processing. */ - if (!xPreProcess.DoPreProcess(pucImage, xInputTensor->bytes)) { - LogError( ("Pre-processing failed.") ); + if( !xPreProcess.DoPreProcess( pucImage, xInputTensor->bytes ) ) + { + LogError( ( "Pre-processing failed." ) ); return -1; } /* Run inference over this image. */ - info("Running inference on image at addr 0x%x\n", (uint32_t)pucImage); + info( "Running inference on image at addr 0x%x\n", ( uint32_t ) pucImage ); - if (!xModel.RunInference()) { - LogError( ("Inference failed.") ); + if( !xModel.RunInference() ) + { + LogError( ( "Inference failed." ) ); return -1; } - if (!xPostProcess.DoPostProcess()) { - LogError( ("Post-processing failed.") ); + if( !xPostProcess.DoPostProcess() ) + { + LogError( ( "Post-processing failed." ) ); return -1; } - for (uint32_t i = 0; i < xResults.size() && i < *pulResultsNum; ++i) { - pxCResults[i].ulX = xResults[i].m_x0; - pxCResults[i].ulY = xResults[i].m_y0; - pxCResults[i].ulW = xResults[i].m_w; - pxCResults[i].ulH = xResults[i].m_h; + for( uint32_t i = 0; i < xResults.size() && i < *pulResultsNum; ++i ) + { + pxCResults[ i ].ulX = xResults[ i ].m_x0; + pxCResults[ i ].ulY = xResults[ i ].m_y0; + pxCResults[ i ].ulW = xResults[ i ].m_w; + pxCResults[ i ].ulH = xResults[ i ].m_h; } - if (!prvPresentInferenceResult(xResults)) { + + if( !prvPresentInferenceResult( xResults ) ) + { return -1; } - if (*pulResultsNum > xResults.size()) { + + if( *pulResultsNum > xResults.size() ) + { *pulResultsNum = xResults.size(); } return 0; } -static bool prvPresentInferenceResult(const std::vector &xResults) +static bool prvPresentInferenceResult( const std::vector &xResults ) { /* If profiling is enabled, and the time is valid. */ - LogInfo( ("Final results:\n") ); - LogInfo( ("Total number of inferences: 1\n") ); - LogInfo( ("Detected faces: %d\n", xResults.size()) ); - - for (uint32_t i = 0; i < xResults.size(); ++i) { - LogInfo( ("%" PRIu32 ") (%f) -> %s {x=%d,y=%d,w=%d,h=%d}\n", - i, - xResults[i].m_normalisedVal, - "Detection box:", - xResults[i].m_x0, - xResults[i].m_y0, - xResults[i].m_w, - xResults[i].m_h) ); + LogInfo( ( "Final results:\n" ) ); + LogInfo( ( "Total number of inferences: 1\n" ) ); + LogInfo( ( "Detected faces: %d\n", xResults.size() ) ); + + for( uint32_t i = 0; i < xResults.size(); ++i ) + { + LogInfo( ( "%" PRIu32 ") (%f) -> %s {x=%d,y=%d,w=%d,h=%d}\n", + i, + xResults[ i ].m_normalisedVal, + "Detection box:", + xResults[ i ].m_x0, + xResults[ i ].m_y0, + xResults[ i ].m_w, + xResults[ i ].m_h ) ); } std::string xFinalResultStr = "Detected faces: "; @@ -366,89 +392,92 @@ static bool prvPresentInferenceResult(const std::vector("model", xModel); + xCaseContext.Set( "model", xModel ); PrintTensorFlowVersion(); LogInfo( ( "*** ML interface initialised\r\n" ) ); return 0; } -void vMlTask(void *pvParameters) +void vMlTask( void * pvParameters ) { LogInfo( ( "ML Task start\r\n" ) ); EventBits_t xFlags = xEventGroupWaitBits( - xSystemEvents, (EventBits_t)EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY - ); + xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY + ); - if (xFlags & EVENT_MASK_ML_START) { + if( xFlags & EVENT_MASK_ML_START ) + { LogInfo( ( "Initial start of image processing\r\n" ) ); } - if (prvMlInterfaceInit() < 0) { + if( prvMlInterfaceInit() < 0 ) + { LogError( ( "prvMlInterfaceInit failed\r\n" ) ); return; } vStartISPDemo(); - while (1) { + while( 1 ) + { xFlags = xEventGroupWaitBits( - xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP, pdTRUE, pdFAIL, 300 - ); + xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP, pdTRUE, pdFAIL, 300 + ); - if (xFlags & EVENT_MASK_ML_STOP) { + if( xFlags & EVENT_MASK_ML_STOP ) + { LogInfo( ( "Stopping image processing\r\n" ) ); break; } } } -void vMlMqttTask(void *pvParameters) +void vMlMqttTask( void * pvParameters ) { - (void)pvParameters; + ( void ) pvParameters; - while (1) { + while( 1 ) + { MLMqttMsg_t xMsg; - if (xQueueReceive (xMlMqttQueue, &xMsg, portMAX_DELAY) == pdTRUE) { - prvMqttSendMessage(xMsg.pcResult); - free(reinterpret_cast(xMsg.pcResult)); - } else { + + if( xQueueReceive( xMlMqttQueue, &xMsg, portMAX_DELAY ) == pdTRUE ) + { + prvMqttSendMessage( xMsg.pcResult ); + free( reinterpret_cast( xMsg.pcResult ) ); + } + else + { LogError( ( "xQueueReceive ML MQTT msg queue failed\r\n" ) ); } } } - -} // extern "C" +} /* extern "C" */ diff --git a/applications/speech_recognition/dsp/src/dsp_interfaces.cpp b/applications/speech_recognition/dsp/src/dsp_interfaces.cpp index 2b91198..c9c743f 100644 --- a/applications/speech_recognition/dsp/src/dsp_interfaces.cpp +++ b/applications/speech_recognition/dsp/src/dsp_interfaces.cpp @@ -1,4 +1,4 @@ -/* Copyright 2022-2023 Arm Limited and/or its affiliates +/* Copyright 2022-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: MIT */ @@ -31,71 +31,76 @@ extern "C" { static float audio_timestamp = 0.0; -void vSetAudioTimestamp(float timestamp) { +void vSetAudioTimestamp( float timestamp ) +{ taskENTER_CRITICAL(); audio_timestamp = timestamp; taskEXIT_CRITICAL(); } -float xGetAudioTimestamp() { +float xGetAudioTimestamp() +{ taskENTER_CRITICAL(); float timestamp = audio_timestamp; taskEXIT_CRITICAL(); return timestamp; } -DspAudioSource::DspAudioSource(const int16_t* audiobuffer, size_t block_count ): - block_count{block_count}, - audiobuffer{audiobuffer} +DspAudioSource::DspAudioSource( const int16_t * audiobuffer, + size_t block_count ) : + block_count{ block_count }, + audiobuffer{ audiobuffer } { - } -const int16_t *DspAudioSource::pxGetCurrentBuffer() +const int16_t * DspAudioSource::pxGetCurrentBuffer() { -#ifndef AUDIO_VSI - // Update block ID - current_block = (current_block + 1) % block_count; -#endif + #ifndef AUDIO_VSI + /* Update block ID */ + current_block = ( current_block + 1 ) % block_count; + #endif - return(audiobuffer + current_block*(AUDIO_BLOCK_SIZE/2)); + return( audiobuffer + current_block * ( AUDIO_BLOCK_SIZE / 2 ) ); } #ifdef AUDIO_VSI -void DspAudioSource::vWaitForNewBuffer() -{ - xSemaphoreTake( this->semaphore, portMAX_DELAY ); -} + void DspAudioSource::vWaitForNewBuffer() + { + xSemaphoreTake( this->semaphore, portMAX_DELAY ); + } -void DspAudioSource::prvNewAudioBlockReceived(void* ptr) -{ - auto* self = reinterpret_cast(ptr); + void DspAudioSource::prvNewAudioBlockReceived( void * ptr ) + { + auto * self = reinterpret_cast( ptr ); - // Update block ID - self->current_block = self->block_under_write; - self->block_under_write = ((self->block_under_write + 1) % self->block_count); + /* Update block ID */ + self->current_block = self->block_under_write; + self->block_under_write = ( ( self->block_under_write + 1 ) % self->block_count ); - if ( self->semaphore != NULL ) - { - BaseType_t yield = pdFALSE; - // Wakeup task waiting - if(xSemaphoreGiveFromISR(self->semaphore, &yield) == pdTRUE) + if( self->semaphore != NULL ) { - portYIELD_FROM_ISR (yield); + BaseType_t yield = pdFALSE; + + /* Wakeup task waiting */ + if( xSemaphoreGiveFromISR( self->semaphore, &yield ) == pdTRUE ) + { + portYIELD_FROM_ISR( yield ); + } } } -}; -#endif +#endif /* ifdef AUDIO_VSI */ -static bool prvDspMlLock(SemaphoreHandle_t ml_fifo_mutex) +static bool prvDspMlLock( SemaphoreHandle_t ml_fifo_mutex ) { - if ( ml_fifo_mutex == NULL ) { + if( ml_fifo_mutex == NULL ) + { return false; } - if ( xSemaphoreTake( ml_fifo_mutex, portMAX_DELAY ) != pdTRUE ) { + if( xSemaphoreTake( ml_fifo_mutex, portMAX_DELAY ) != pdTRUE ) + { LogError( ( "Failed to acquire ml_fifo_mutex" ) ); return false; } @@ -103,13 +108,15 @@ static bool prvDspMlLock(SemaphoreHandle_t ml_fifo_mutex) return true; } -static bool prvDspMlUnlock(SemaphoreHandle_t ml_fifo_mutex) +static bool prvDspMlUnlock( SemaphoreHandle_t ml_fifo_mutex ) { - if ( ml_fifo_mutex == NULL ) { + if( ml_fifo_mutex == NULL ) + { return false; } - if ( xSemaphoreGive( ml_fifo_mutex ) != pdTRUE ) { + if( xSemaphoreGive( ml_fifo_mutex ) != pdTRUE ) + { LogError( ( "Failed to release ml_fifo_mutex" ) ); return false; } @@ -118,10 +125,10 @@ static bool prvDspMlUnlock(SemaphoreHandle_t ml_fifo_mutex) } -DSPML::DSPML(size_t bufferLengthInSamples ):nbSamples(bufferLengthInSamples) +DSPML::DSPML( size_t bufferLengthInSamples ) : nbSamples( bufferLengthInSamples ) { - bufferA=static_cast(malloc(bufferLengthInSamples*sizeof(int16_t))); - bufferB=static_cast(malloc(bufferLengthInSamples*sizeof(int16_t))); + bufferA = static_cast( malloc( bufferLengthInSamples * sizeof( int16_t ) ) ); + bufferB = static_cast( malloc( bufferLengthInSamples * sizeof( int16_t ) ) ); dspBuffer = bufferA; mlBuffer = bufferB; @@ -129,39 +136,39 @@ DSPML::DSPML(size_t bufferLengthInSamples ):nbSamples(bufferLengthInSamples) DSPML::~DSPML() { - free(bufferA); - free(bufferB); + free( bufferA ); + free( bufferB ); } -void DSPML::vCopyToDSPBufferFrom(int16_t * buf) +void DSPML::vCopyToDSPBufferFrom( int16_t * buf ) { - prvDspMlLock(mutex); - memcpy(dspBuffer,buf,sizeof(int16_t)*nbSamples); - prvDspMlUnlock(mutex); - + prvDspMlLock( mutex ); + memcpy( dspBuffer, buf, sizeof( int16_t ) * nbSamples ); + prvDspMlUnlock( mutex ); } -void DSPML::vCopyFromMLBufferInto(int16_t * buf) +void DSPML::vCopyFromMLBufferInto( int16_t * buf ) { - prvDspMlLock(mutex); - memcpy(buf,mlBuffer,sizeof(int16_t)*nbSamples); - prvDspMlUnlock(mutex); + prvDspMlLock( mutex ); + memcpy( buf, mlBuffer, sizeof( int16_t ) * nbSamples ); + prvDspMlUnlock( mutex ); } void DSPML::vSwapBuffersAndWakeUpMLThread() { - int16_t* tmp; + int16_t * tmp; - prvDspMlLock(mutex); - tmp=dspBuffer; - dspBuffer=mlBuffer; - mlBuffer=tmp; - prvDspMlUnlock(mutex); + prvDspMlLock( mutex ); + tmp = dspBuffer; + dspBuffer = mlBuffer; + mlBuffer = tmp; + prvDspMlUnlock( mutex ); BaseType_t yield = pdFALSE; - if (xSemaphoreGiveFromISR(semaphore, &yield) == pdTRUE) + + if( xSemaphoreGiveFromISR( semaphore, &yield ) == pdTRUE ) { - portYIELD_FROM_ISR (yield); + portYIELD_FROM_ISR( yield ); } } diff --git a/applications/speech_recognition/dsp/src/dsp_task.cc b/applications/speech_recognition/dsp/src/dsp_task.cc index a56f302..570480d 100644 --- a/applications/speech_recognition/dsp/src/dsp_task.cc +++ b/applications/speech_recognition/dsp/src/dsp_task.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 Arm Limited and/or its affiliates +/* Copyright 2023-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: MIT */ @@ -33,96 +33,101 @@ extern EventGroupHandle_t xSystemEvents; #ifdef AUDIO_VSI -#include "Driver_SAI.h" + #include "Driver_SAI.h" -// audio constants -__attribute__((section(".bss.NoInit.vsi_audio_buffer"))) __attribute__((aligned(4))) -int16_t shared_audio_buffer[AUDIO_BUFFER_SIZE / 2]; +/* audio constants */ + __attribute__( ( section( ".bss.NoInit.vsi_audio_buffer" ) ) ) __attribute__( ( aligned( 4 ) ) ) + int16_t shared_audio_buffer[ AUDIO_BUFFER_SIZE / 2 ]; -extern ARM_DRIVER_SAI Driver_SAI0; -extern TaskHandle_t xVsiTaskHandle; + extern ARM_DRIVER_SAI Driver_SAI0; + extern TaskHandle_t xVsiTaskHandle; -uint32_t ulVsiEvent; + uint32_t ulVsiEvent; -extern "C" { -// Audio driver data -void (*pxOnVsiEvent)(void *); -void *pvVsiContext = nullptr; -} + extern "C" { +/* Audio driver data */ + void (* pxOnVsiEvent)( void * ); + void * pvVsiContext = nullptr; + } -// Audio driver callback function for event management -static void prvArmSaiSignalEvent(uint32_t event) -{ - if(xVsiTaskHandle == NULL) +/* Audio driver callback function for event management */ + static void prvArmSaiSignalEvent( uint32_t event ) { - LogError( ( "VSI Task is not created\r\n" ) ); - return; - } + if( xVsiTaskHandle == NULL ) + { + LogError( ( "VSI Task is not created\r\n" ) ); + return; + } - BaseType_t xHigherPriorityTaskWoken = pdFALSE; - ulVsiEvent = event; + BaseType_t xHigherPriorityTaskWoken = pdFALSE; + ulVsiEvent = event; - vTaskNotifyGiveFromISR( xVsiTaskHandle, &xHigherPriorityTaskWoken ); + vTaskNotifyGiveFromISR( xVsiTaskHandle, &xHigherPriorityTaskWoken ); - portYIELD_FROM_ISR( xHigherPriorityTaskWoken ); -} + portYIELD_FROM_ISR( xHigherPriorityTaskWoken ); + } -static int prvAudioDrvSetup(void (*event_handler)(void *), void *event_handler_ptr) -{ - if (Driver_SAI0.Initialize(prvArmSaiSignalEvent) != ARM_DRIVER_OK) { - LogError( ( "Failed to set up FVP VSI!\n" ) ); - return -1; - } + static int prvAudioDrvSetup( void ( * event_handler )( void * ), + void * event_handler_ptr ) + { + if( Driver_SAI0.Initialize( prvArmSaiSignalEvent ) != ARM_DRIVER_OK ) + { + LogError( ( "Failed to set up FVP VSI!\n" ) ); + return -1; + } - if (Driver_SAI0.PowerControl(ARM_POWER_FULL) != ARM_DRIVER_OK) { - LogError( ( "Failed to set the driver to operate with full power!\n" ) ); - return -1; - } + if( Driver_SAI0.PowerControl( ARM_POWER_FULL ) != ARM_DRIVER_OK ) + { + LogError( ( "Failed to set the driver to operate with full power!\n" ) ); + return -1; + } - if (Driver_SAI0.Control(ARM_SAI_CONTROL_RX, 1, 0) != ARM_DRIVER_OK) { - LogError( ( "Failed to enable the VSI receiver!\n" ) ); - return -1; - } + if( Driver_SAI0.Control( ARM_SAI_CONTROL_RX, 1, 0 ) != ARM_DRIVER_OK ) + { + LogError( ( "Failed to enable the VSI receiver!\n" ) ); + return -1; + } - if ( - Driver_SAI0.Control( - ARM_SAI_CONFIGURE_RX | ARM_SAI_PROTOCOL_USER | ARM_SAI_DATA_SIZE(16), - AUDIO_BLOCK_SIZE, - static_cast(SAMPLE_RATE)) != ARM_DRIVER_OK - ) { - LogError( ( "Failed to configure the receiver!\n" ) ); - return -1; - } + if( + Driver_SAI0.Control( + ARM_SAI_CONFIGURE_RX | ARM_SAI_PROTOCOL_USER | ARM_SAI_DATA_SIZE( 16 ), + AUDIO_BLOCK_SIZE, + static_cast( SAMPLE_RATE ) ) != ARM_DRIVER_OK + ) + { + LogError( ( "Failed to configure the receiver!\n" ) ); + return -1; + } - if ( - Driver_SAI0.Receive( - reinterpret_cast(shared_audio_buffer), AUDIO_BLOCK_NUM - ) != ARM_DRIVER_OK - ) { - LogError( ( "Failed to start receiving the data!\n" ) ); - return -1; - } + if( + Driver_SAI0.Receive( + reinterpret_cast( shared_audio_buffer ), AUDIO_BLOCK_NUM + ) != ARM_DRIVER_OK + ) + { + LogError( ( "Failed to start receiving the data!\n" ) ); + return -1; + } - pxOnVsiEvent = event_handler; - pvVsiContext = event_handler_ptr; + pxOnVsiEvent = event_handler; + pvVsiContext = event_handler_ptr; - return 0; -} + return 0; + } #else /* !defined(AUDIO_VSI) */ -#include "InputFiles.hpp" + #include "InputFiles.hpp" #endif // AUDIO_VSI extern "C" { - -void vDspStart(void) +void vDspStart( void ) { - if(xSystemEvents == NULL) + if( xSystemEvents == NULL ) { LogError( ( "xSystemEvents is not initialised\r\n" ) ); return; @@ -130,12 +135,12 @@ void vDspStart(void) LogInfo( ( "DSP task start\r\n" ) ); - ( void )xEventGroupSetBits( xSystemEvents, (EventBits_t)EVENT_MASK_DSP_START ); + ( void ) xEventGroupSetBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_DSP_START ); } -void vDspStop(void) +void vDspStop( void ) { - if(xSystemEvents == NULL) + if( xSystemEvents == NULL ) { LogError( ( "xSystemEvents is not initialised\r\n" ) ); return; @@ -143,70 +148,73 @@ void vDspStop(void) LogInfo( ( "DSP task stop\r\n" ) ); - ( void )xEventGroupClearBits( xSystemEvents, (EventBits_t)EVENT_MASK_DSP_START ); + ( void ) xEventGroupClearBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_DSP_START ); } -} // extern "C" +} /* extern "C" */ -void *pvDspGetMlConnection(void) +void * pvDspGetMlConnection( void ) { - auto dspMLConnection = new DSPML(AUDIOFEATURELENGTH); - return static_cast(dspMLConnection); + auto dspMLConnection = new DSPML( AUDIOFEATURELENGTH ); + + return static_cast( dspMLConnection ); } -void vDspTask(void *pvParameters) +void vDspTask( void * pvParameters ) { LogInfo( ( "DSP Task start\r\n" ) ); -#ifdef AUDIO_VSI - bool first_launch = true; - const int16_t *audioBuf = shared_audio_buffer; - auto audioSource = DspAudioSource(audioBuf, AUDIO_BLOCK_NUM); -#else - const int16_t *audioBuf = GetAudioArray(0); - // This integer division for calculating the number of blocks means that, - // any remainder data at the end of the audio clip that's smaller than a - // block will not be accounted for. This will not have a major impact on - // the inference result as a block is only a small fraction of a second. - const size_t audioBlockNum = (size_t)GetAudioArraySize(0) / (AUDIO_BLOCK_SIZE / sizeof(uint16_t)); - auto audioSource = DspAudioSource(audioBuf, audioBlockNum); -#endif - - DSPML *dspMLConnection = static_cast(pvParameters); - - while (1) { - // Wait for the start message - EventBits_t flags = xEventGroupWaitBits(xSystemEvents, (EventBits_t)EVENT_MASK_DSP_START, pdFAIL, pdFAIL, portMAX_DELAY); + #ifdef AUDIO_VSI + bool first_launch = true; + const int16_t * audioBuf = shared_audio_buffer; + auto audioSource = DspAudioSource( audioBuf, AUDIO_BLOCK_NUM ); + #else + const int16_t * audioBuf = GetAudioArray( 0 ); + /* This integer division for calculating the number of blocks means that, */ + /* any remainder data at the end of the audio clip that's smaller than a */ + /* block will not be accounted for. This will not have a major impact on */ + /* the inference result as a block is only a small fraction of a second. */ + const size_t audioBlockNum = ( size_t ) GetAudioArraySize( 0 ) / ( AUDIO_BLOCK_SIZE / sizeof( uint16_t ) ); + auto audioSource = DspAudioSource( audioBuf, audioBlockNum ); + #endif /* ifdef AUDIO_VSI */ + + DSPML * dspMLConnection = static_cast( pvParameters ); + + while( 1 ) + { + /* Wait for the start message */ + EventBits_t flags = xEventGroupWaitBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_DSP_START, pdFAIL, pdFAIL, portMAX_DELAY ); - if (flags & EVENT_MASK_DSP_START) + if( flags & EVENT_MASK_DSP_START ) { LogInfo( ( "Initial start of audio processing\r\n" ) ); } -#ifdef AUDIO_VSI - if (first_launch) { - prvAudioDrvSetup(&DspAudioSource::prvNewAudioBlockReceived, &audioSource); - first_launch = false; - } -#endif - - // Launch the CMSIS-DSP synchronous data flow. - // This compute graph is defined in graph.py - // It can be regenerated with - // pip install cmsisdsp - // python graph.py + #ifdef AUDIO_VSI + if( first_launch ) + { + prvAudioDrvSetup( &DspAudioSource::prvNewAudioBlockReceived, &audioSource ); + first_launch = false; + } + #endif + + /* Launch the CMSIS-DSP synchronous data flow. */ + /* This compute graph is defined in graph.py */ + /* It can be regenerated with */ + /* pip install cmsisdsp */ + /* python graph.py */ int error; - uint32_t nbSched=ulScheduler(&error,&audioSource, dspMLConnection); + uint32_t nbSched = ulScheduler( &error, &audioSource, dspMLConnection ); LogInfo( ( "Synchronous Dataflow Scheduler ended with error %d after %i schedule loops\r\n", - error, - nbSched - ) ); + error, + nbSched + ) ); } } -void vStartDSPTask( void *pvParameters ) +void vStartDSPTask( void * pvParameters ) { - if ( + if( xTaskCreate( vDspTask, "DSP_TASK", @@ -214,8 +222,9 @@ void vStartDSPTask( void *pvParameters ) pvParameters, appCONFIG_DSP_TASK_PRIORITY, NULL - ) != pdPASS - ) { + ) != pdPASS + ) + { LogError( ( "Failed to create DSP Task\r\n" ) ); } } diff --git a/applications/speech_recognition/dsp/src/scheduler.cpp b/applications/speech_recognition/dsp/src/scheduler.cpp index 0770432..afa19fc 100644 --- a/applications/speech_recognition/dsp/src/scheduler.cpp +++ b/applications/speech_recognition/dsp/src/scheduler.cpp @@ -1,11 +1,11 @@ -/* Copyright 2022-2023 Arm Limited and/or its affiliates +/* Copyright 2022-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: MIT */ /* -Generated with CMSIS-DSP SDF Scripts. -*/ + * Generated with CMSIS-DSP SDF Scripts. + */ #include "FreeRTOS.h" #include "arm_math.h" @@ -17,191 +17,192 @@ Generated with CMSIS-DSP SDF Scripts. #include "task.h" /*********** -FIFO buffers -************/ -#define FIFOSIZE0 1600 -#define FIFOSIZE1 16000 -#define FIFOSIZE2 47360 + * FIFO buffers + ************/ +#define FIFOSIZE0 1600 +#define FIFOSIZE1 16000 +#define FIFOSIZE2 47360 -#define BUFFERSIZE0 1600 -int16_t buf0[BUFFERSIZE0]={0}; +#define BUFFERSIZE0 1600 +int16_t buf0[ BUFFERSIZE0 ] = { 0 }; -#define BUFFERSIZE1 16000 -int16_t buf1[BUFFERSIZE1]={0}; +#define BUFFERSIZE1 16000 +int16_t buf1[ BUFFERSIZE1 ] = { 0 }; -#define BUFFERSIZE2 47360 -int16_t buf2[BUFFERSIZE2]={0}; +#define BUFFERSIZE2 47360 +int16_t buf2[ BUFFERSIZE2 ] = { 0 }; extern EventGroupHandle_t xSystemEvents; -uint32_t ulScheduler( - int *error, - DspAudioSource *dspAudio, - DSPML *dspMLConnection -) { -// Define CHECKERROR_OR_PAUSE -// This updated version of CHECKERROR verify if the task must be stopped or not -#define CHECKERROR_OR_PAUSE \ - if (sdfError < 0) {\ - break; \ - } else { \ - if ((xEventGroupGetBits (xSystemEvents) & EVENT_MASK_DSP_START) == 0U) { \ - break; \ - } \ +uint32_t ulScheduler( int * error, + DspAudioSource * dspAudio, + DSPML * dspMLConnection ) +{ +/* Define CHECKERROR_OR_PAUSE */ +/* This updated version of CHECKERROR verify if the task must be stopped or not */ +#define CHECKERROR_OR_PAUSE \ + if( sdfError < 0 ) { \ + break; \ + } \ + else { \ + if( ( xEventGroupGetBits( xSystemEvents ) & EVENT_MASK_DSP_START ) == 0U ) { \ + break; \ + } \ } - int sdfError=0; - uint32_t nbSchedule=0; + int sdfError = 0; + uint32_t nbSchedule = 0; /* - Create FIFOs objects - */ - FIFO fifo0(buf0); - FIFO fifo1(buf1); - FIFO fifo2(buf2); + * Create FIFOs objects + */ + FIFO fifo0( buf0 ); + FIFO fifo1( buf1 ); + FIFO fifo2( buf2 ); /* - Create node objects - */ - SlidingBuffer audioWin(fifo1,fifo2); - DSP dsp(fifo0,fifo1); - MicrophoneSource mic(fifo0,dspAudio); - ML ml(fifo2,dspMLConnection); + * Create node objects + */ + SlidingBuffer audioWin( fifo1, fifo2 ); + DSP dsp( fifo0, fifo1 ); + MicrophoneSource mic( fifo0, dspAudio ); + ML ml( fifo2, dspMLConnection ); /* Run several schedule iterations */ - while(sdfError==0) + while( sdfError == 0 ) { - /* Run a schedule iteration */ - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = mic.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = dsp.run(); - CHECKERROR_OR_PAUSE; - sdfError = audioWin.run(); - CHECKERROR_OR_PAUSE; - sdfError = ml.run(); - CHECKERROR_OR_PAUSE; + /* Run a schedule iteration */ + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = mic.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = dsp.run(); + CHECKERROR_OR_PAUSE; + sdfError = audioWin.run(); + CHECKERROR_OR_PAUSE; + sdfError = ml.run(); + CHECKERROR_OR_PAUSE; - nbSchedule++; + nbSchedule++; - // Add delay to allow some time for the connectivity task - // to send and receive messages to and from the cloud. - vTaskDelay(140); + /* Add delay to allow some time for the connectivity task */ + /* to send and receive messages to and from the cloud. */ + vTaskDelay( 140 ); } - *error=sdfError; - return(nbSchedule); + + *error = sdfError; + return( nbSchedule ); } diff --git a/applications/speech_recognition/ml_interface.cc b/applications/speech_recognition/ml_interface.cc index ee4dba4..4cacd63 100644 --- a/applications/speech_recognition/ml_interface.cc +++ b/applications/speech_recognition/ml_interface.cc @@ -90,28 +90,27 @@ extern MQTTAgentContext_t xGlobalMqttAgentContext; extern EventGroupHandle_t xSystemEvents; extern QueueHandle_t xMlMqttQueue; -// Define tensor arena and declare functions required to access the model +/* Define tensor arena and declare functions required to access the model */ namespace arm { namespace app { -uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE; +uint8_t tensorArena[ ACTIVATION_BUF_SZ ] ACTIVATION_BUF_ATTRIBUTE; namespace asr { -extern uint8_t *GetModelPointer(); +extern uint8_t * GetModelPointer(); extern size_t GetModelLen(); } /* namespace asr */ } /* namespace app */ } /* namespace arm */ namespace { - -typedef struct { - char *result; +typedef struct +{ + char * result; } ml_mqtt_msg_t; -// Import +/* Import */ using namespace arm::app; extern "C" { - static void prvAppPublishCommandCallback( MQTTAgentCommandContext_t * pxCommandContext, MQTTAgentReturnInfo_t * pxReturnInfo ) { @@ -125,9 +124,9 @@ static void prvAppPublishCommandCallback( MQTTAgentCommandContext_t * pxCommandC static void prvMqttSendMessage( const char * message ) { - static MQTTPublishInfo_t publishInfo = { (MQTTQoS_t)0 }; + static MQTTPublishInfo_t publishInfo = { ( MQTTQoS_t ) 0 }; static MQTTAgentCommandInfo_t xCommandParams = { 0 }; - static MQTTAgentCommandContext_t xCommandContext = { (MQTTStatus_t)0 }; + static MQTTAgentCommandContext_t xCommandContext = { ( MQTTStatus_t ) 0 }; MQTTStatus_t mqttStatus = MQTTBadParameter; publishInfo.pTopicName = mqttexampleTOPIC; @@ -176,36 +175,35 @@ static void prvMqttSendMessage( const char * message ) } } -void vMlTaskInferenceStart(void) +void vMlTaskInferenceStart( void ) { - if(xSystemEvents == NULL) + if( xSystemEvents == NULL ) { LogError( ( "xSystemEvents is not initialised\r\n" ) ); return; } LogInfo( ( "Signal task inference start\r\n" ) ); - ( void )xEventGroupClearBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP ); - ( void )xEventGroupSetBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_START ); - + ( void ) xEventGroupClearBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP ); + ( void ) xEventGroupSetBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START ); } -void vMlTaskInferenceStop(void) +void vMlTaskInferenceStop( void ) { - if(xSystemEvents == NULL) + if( xSystemEvents == NULL ) { LogError( ( "xSystemEvents is not initialised\r\n" ) ); return; } LogInfo( ( "Signal task inference stop\r\n" ) ); - ( void )xEventGroupClearBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_START ); - ( void )xEventGroupSetBits( xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP ); + ( void ) xEventGroupClearBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START ); + ( void ) xEventGroupSetBits( xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP ); } -void vStartMlTask( void *pvParameters ) +void vStartMlTask( void * pvParameters ) { - if ( + if( xTaskCreate( vMlTask, "ML_TASK", @@ -213,15 +211,16 @@ void vStartMlTask( void *pvParameters ) pvParameters, appCONFIG_ML_TASK_PRIORITY, NULL - ) != pdPASS - ) { + ) != pdPASS + ) + { LogError( ( "Failed to create ML Task\r\n" ) ); } } void vStartMlMqttTask( void ) { - if ( + if( xTaskCreate( vMlMqttTask, "ML_MQTT", @@ -229,37 +228,44 @@ void vStartMlMqttTask( void ) NULL, appCONFIG_ML_MQTT_TASK_PRIORITY, NULL - ) != pdPASS - ) { + ) != pdPASS + ) + { LogError( ( "Failed to create ML Mqtt Task\r\n" ) ); } } -} // extern "C" { +} /* extern "C" { */ -static void prvSetMlProcessingstate(const char *inference_result) +static void prvSetMlProcessingstate( const char * inference_result ) { - size_t msg_len = strlen(inference_result) + 1; - char *msg_result = reinterpret_cast(malloc(msg_len)); - if (msg_result) { - if(xMlMqttQueue == NULL) + size_t msg_len = strlen( inference_result ) + 1; + char * msg_result = reinterpret_cast( malloc( msg_len ) ); + + if( msg_result ) + { + if( xMlMqttQueue == NULL ) { LogError( ( "xMlMqttQueue is not initialised\r\n" ) ); - free(reinterpret_cast(msg_result)); + free( reinterpret_cast( msg_result ) ); return; } - memcpy(msg_result, inference_result, msg_len); - const ml_mqtt_msg_t msg = {msg_result}; - if (xQueueSendToBack(xMlMqttQueue, (void *)&msg, (TickType_t)0) != pdTRUE) { + memcpy( msg_result, inference_result, msg_len ); + const ml_mqtt_msg_t msg = { msg_result }; + + if( xQueueSendToBack( xMlMqttQueue, ( void * ) &msg, ( TickType_t ) 0 ) != pdTRUE ) + { LogError( ( "Failed to send message to xMlMqttQueue\r\n" ) ); - free(reinterpret_cast(msg_result)); + free( reinterpret_cast( msg_result ) ); } - } else { + } + else + { LogWarn( ( "Failed to send ml processing inference_result (alloc failure)" ) ); } } -// Model +/* Model */ arm::app::ApplicationContext caseContext; /** @@ -269,104 +275,112 @@ arm::app::ApplicationContext caseContext; * @param[in] results Vector of classification results to be displayed. * @return true if successful, false otherwise. **/ -static bool prvPresentInferenceResult(const std::vector &results); +static bool prvPresentInferenceResult( const std::vector &results ); -static void prvProcessAudio(ApplicationContext &ctx, DSPML *dspMLConnection) +static void prvProcessAudio( ApplicationContext &ctx, + DSPML * dspMLConnection ) { /* Get model reference. */ - auto &model = ctx.Get("model"); - if (!model.IsInited()) { + auto &model = ctx.Get( "model" ); + + if( !model.IsInited() ) + { LogError( ( "Model is not initialised! Terminating processing.\n" ) ); return; } /* Get score threshold to be applied for the classifier (post-inference). */ - auto scoreThreshold = ctx.Get("scoreThreshold"); + auto scoreThreshold = ctx.Get( "scoreThreshold" ); /* Get tensors. Dimensions of the tensor should have been verified by * the callee. */ - TfLiteTensor *inputTensor = model.GetInputTensor(0); - TfLiteTensor *outputTensor = model.GetOutputTensor(0); - TfLiteIntArray *inputShape = model.GetInputShape(0); + TfLiteTensor * inputTensor = model.GetInputTensor( 0 ); + TfLiteTensor * outputTensor = model.GetOutputTensor( 0 ); + TfLiteIntArray * inputShape = model.GetInputShape( 0 ); /* Populate MFCC related parameters. */ - auto mfccFrameLen = ctx.Get("frameLength"); - auto mfccFrameStride = ctx.Get("frameStride"); + auto mfccFrameLen = ctx.Get( "frameLength" ); + auto mfccFrameStride = ctx.Get( "frameStride" ); /* Populate ASR inference context and inner lengths for input. */ - auto inputCtxLen = ctx.Get("ctxLen"); + auto inputCtxLen = ctx.Get( "ctxLen" ); /* Get pre/post-processing objects. */ - AsrPreProcess preProcess = AsrPreProcess(inputTensor, - Wav2LetterModel::ms_numMfccFeatures, - inputShape->data[Wav2LetterModel::ms_inputRowsIdx], - mfccFrameLen, - mfccFrameStride); + AsrPreProcess preProcess = AsrPreProcess( inputTensor, + Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[ Wav2LetterModel::ms_inputRowsIdx ], + mfccFrameLen, + mfccFrameStride ); std::vector singleInfResult; - const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen); - AsrPostProcess postProcess = AsrPostProcess(outputTensor, - ctx.Get("classifier"), - ctx.Get &>("labels"), - singleInfResult, - outputCtxLen, - Wav2LetterModel::ms_blankTokenIdx, - Wav2LetterModel::ms_outputRowsIdx); - - const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; + const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen( model, inputCtxLen ); + AsrPostProcess postProcess = AsrPostProcess( outputTensor, + ctx.Get( "classifier" ), + ctx.Get &>( "labels" ), + singleInfResult, + outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, + Wav2LetterModel::ms_outputRowsIdx ); + + const uint32_t inputRows = inputTensor->dims->data[ arm::app::Wav2LetterModel::ms_inputRowsIdx ]; /* Audio data stride corresponds to inputInnerLen feature vectors. */ const uint32_t audioParamsWinLen = inputRows * mfccFrameStride; - auto inferenceWindow = std::vector(audioParamsWinLen, 0); + auto inferenceWindow = std::vector( audioParamsWinLen, 0 ); size_t inferenceWindowLen = audioParamsWinLen; - // Start processing audio data as it arrive + /* Start processing audio data as it arrive */ uint32_t inferenceIndex = 0; - // We do inference on 2 audio segments before reporting a result - // We do not have the concept of audio clip in a streaming application - // so we need to decide when a sentenced is finished to start a recognition. - // It was arbitrarily chosen to be 2 inferences. - // In a real app, a voice activity detector would probably be used - // to detect a long silence between 2 sentences. + /* We do inference on 2 audio segments before reporting a result */ + /* We do not have the concept of audio clip in a streaming application */ + /* so we need to decide when a sentenced is finished to start a recognition. */ + /* It was arbitrarily chosen to be 2 inferences. */ + /* In a real app, a voice activity detector would probably be used */ + /* to detect a long silence between 2 sentences. */ const uint32_t maxNbInference = 2; std::vector results; - while (true) { - while (true) { + while( true ) + { + while( true ) + { EventBits_t flags = xEventGroupWaitBits( - xSystemEvents, (EventBits_t)EVENT_MASK_ML_STOP, pdTRUE, pdFAIL, 300 - ); + xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_STOP, pdTRUE, pdFAIL, 300 + ); - if (flags & EVENT_MASK_ML_STOP) { + if( flags & EVENT_MASK_ML_STOP ) + { LogInfo( ( "Stopping audio processing\r\n" ) ); break; } - // Wait for the DSP task signal to start the recognition + /* Wait for the DSP task signal to start the recognition */ dspMLConnection->vWaitForDSPData(); - int16_t *p = inferenceWindow.data(); - dspMLConnection->vCopyFromMLBufferInto(p); + int16_t * p = inferenceWindow.data(); + dspMLConnection->vCopyFromMLBufferInto( p ); - // This timestamp is corresponding to the time when - // inference is starting and not to the time of the - // beginning of the audio segment used for this inference. + /* This timestamp is corresponding to the time when */ + /* inference is starting and not to the time of the */ + /* beginning of the audio segment used for this inference. */ float currentTimeStamp = xGetAudioTimestamp(); LogInfo( ( "Inference %i/%i\n", inferenceIndex + 1, maxNbInference ) ); /* Run the pre-processing, inference and post-processing. */ - if (!preProcess.DoPreProcess(inferenceWindow.data(), inferenceWindowLen)) { + if( !preProcess.DoPreProcess( inferenceWindow.data(), inferenceWindowLen ) ) + { LogError( ( "Pre-processing failed." ) ); } -#ifdef AUDIO_VSI - LogInfo( ( "Start running inference on audio input from the Virtual Streaming Interface\r\n" ) ); -#else - LogInfo( ( "Start running inference on an audio clip in local memory\r\n" ) ); -#endif + #ifdef AUDIO_VSI + LogInfo( ( "Start running inference on audio input from the Virtual Streaming Interface\r\n" ) ); + #else + LogInfo( ( "Start running inference on an audio clip in local memory\r\n" ) ); + #endif /* Run inference over this audio clip sliding window. */ - if (!model.RunInference()) { + if( !model.RunInference() ) + { LogError( ( "Failed to run inference" ) ); return; } @@ -374,226 +388,242 @@ static void prvProcessAudio(ApplicationContext &ctx, DSPML *dspMLConnection) LogDebug( ( "Doing post processing\n" ) ); /* Post processing needs to know if we are on the last audio window. */ - // postProcess.m_lastIteration = !audioDataSlider.HasNext(); - if (!postProcess.DoPostProcess()) { + /* postProcess.m_lastIteration = !audioDataSlider.HasNext(); */ + if( !postProcess.DoPostProcess() ) + { LogError( ( "Post-processing failed." ) ); } LogInfo( ( "Inference done\n" ) ); std::vector classificationResult; - auto &classifier = ctx.Get("classifier"); + auto &classifier = ctx.Get( "classifier" ); classifier.GetClassificationResults( outputTensor, classificationResult, - ctx.Get &>("labels"), + ctx.Get &>( "labels" ), 1, true - ); + ); auto result = asr::AsrResult( classificationResult, currentTimeStamp, inferenceIndex, scoreThreshold - ); + ); - results.emplace_back(result); + results.emplace_back( result ); inferenceIndex = inferenceIndex + 1; - if (inferenceIndex == maxNbInference) { + if( inferenceIndex == maxNbInference ) + { inferenceIndex = 0; - ctx.Set>("results", results); + ctx.Set >( "results", results ); - if (!prvPresentInferenceResult(results)) { + if( !prvPresentInferenceResult( results ) ) + { return; } results.clear(); } - // Inference loop + /* Inference loop */ } /* while (true) */ EventBits_t flags = xEventGroupWaitBits( - xSystemEvents, (EventBits_t)EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY - ); + xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY + ); - if (flags & EVENT_MASK_ML_START) { + if( flags & EVENT_MASK_ML_START ) + { LogInfo( ( "Restarting audio processing %u\r\n", flags ) ); } } /* while (true) */ } -static bool prvPresentInferenceResult(const std::vector &results) +static bool prvPresentInferenceResult( const std::vector &results ) { LogInfo( ( "Final results:\n" ) ); LogInfo( ( "Total number of inferences: %zu\n", results.size() ) ); /* Results from multiple inferences should be combined before processing. */ std::vector combinedResults; - for (auto &result : results) { - combinedResults.insert(combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end()); + + for( auto &result : results ) + { + combinedResults.insert( combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end() ); } /* Get each inference result string using the decoder. */ - for (const auto &result : results) { - std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec); + for( const auto &result : results ) + { + std::string infResultStr = audio::asr::DecodeOutput( result.m_resultVec ); LogInfo( ( "For timestamp: %f (inference #: %" PRIu32 "); label: %s\r\n", - (double)result.m_timeStamp, - result.m_inferenceNumber, - infResultStr.c_str() ) ); + ( double ) result.m_timeStamp, + result.m_inferenceNumber, + infResultStr.c_str() ) ); } /* Get the decoded result for the combined result. */ - std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); + std::string finalResultStr = audio::asr::DecodeOutput( combinedResults ); LogInfo( ( "Complete recognition: %s\n", finalResultStr.c_str() ) ); - // Send the inference result - prvSetMlProcessingstate(finalResultStr.c_str()); + /* Send the inference result */ + prvSetMlProcessingstate( finalResultStr.c_str() ); return true; } - -} // anonymous namespace +} /* anonymous namespace */ #ifdef USE_ETHOS -extern struct ethosu_driver ethosu_drv; /* Default Ethos-U55 device driver */ + extern struct ethosu_driver ethosu_drv; /* Default Ethos-U55 device driver */ /** * @brief Initialises the Arm Ethos-U55 NPU * @return 0 if successful, error code otherwise **/ -static int prvArmNpuInit(void); + static int prvArmNpuInit( void ); -static int prvArmNpuInit(void) -{ - int err = 0; + static int prvArmNpuInit( void ) + { + int err = 0; - SCB_EnableICache(); - SCB_EnableDCache(); + SCB_EnableICache(); + SCB_EnableDCache(); -#if defined(ETHOS_U_NPU_TIMING_ADAPTER_ENABLED) - /* If the platform has timing adapter blocks along with Ethos-U core - * block, initialise them here. */ - if (0 != (err = arm_ethosu_timing_adapter_init())) { - LogError( ("Failed to init timing adapter\n") ); - return err; - } -#endif /* ETHOS_U_NPU_TIMING_ADAPTER_ENABLED */ + #if defined( ETHOS_U_NPU_TIMING_ADAPTER_ENABLED ) - // Initialize the ethos NPU - if (0 != (err = arm_ethosu_npu_init())) { - LogError( ("Failed to init arm npu\n") ); - return err; - } + /* If the platform has timing adapter blocks along with Ethos-U core + * block, initialise them here. */ + if( 0 != ( err = arm_ethosu_timing_adapter_init() ) ) + { + LogError( ( "Failed to init timing adapter\n" ) ); + return err; + } + #endif /* ETHOS_U_NPU_TIMING_ADAPTER_ENABLED */ - LogInfo( ( "Ethos-U55 device initialised\n" ) ); + /* Initialize the ethos NPU */ + if( 0 != ( err = arm_ethosu_npu_init() ) ) + { + LogError( ( "Failed to init arm npu\n" ) ); + return err; + } - /* Get Ethos-U55 version */ - struct ethosu_driver_version driver_version; - struct ethosu_hw_info hw_info; + LogInfo( ( "Ethos-U55 device initialised\n" ) ); - ethosu_get_driver_version(&driver_version); - ethosu_get_hw_info(ðosu_drv, &hw_info); + /* Get Ethos-U55 version */ + struct ethosu_driver_version driver_version; + struct ethosu_hw_info hw_info; - LogInfo( ( "Ethos-U version info:\n" ) ); - LogInfo( ( "\tArch: v%" PRIu32 ".%" PRIu32 ".%" PRIu32 "\n", - hw_info.version.arch_major_rev, - hw_info.version.arch_minor_rev, - hw_info.version.arch_patch_rev ) ); - LogInfo( ( "\tDriver: v%" PRIu8 ".%" PRIu8 ".%" PRIu8 "\n", - driver_version.major, - driver_version.minor, - driver_version.patch ) ); - LogInfo( ( "\tMACs/cc: %" PRIu32 "\n", (uint32_t)(1 << hw_info.cfg.macs_per_cc) ) ); - LogInfo( ( "\tCmd stream: v%" PRIu32 "\n", hw_info.cfg.cmd_stream_version ) ); + ethosu_get_driver_version( &driver_version ); + ethosu_get_hw_info( ðosu_drv, &hw_info ); - return 0; -} + LogInfo( ( "Ethos-U version info:\n" ) ); + LogInfo( ( "\tArch: v%" PRIu32 ".%" PRIu32 ".%" PRIu32 "\n", + hw_info.version.arch_major_rev, + hw_info.version.arch_minor_rev, + hw_info.version.arch_patch_rev ) ); + LogInfo( ( "\tDriver: v%" PRIu8 ".%" PRIu8 ".%" PRIu8 "\n", + driver_version.major, + driver_version.minor, + driver_version.patch ) ); + LogInfo( ( "\tMACs/cc: %" PRIu32 "\n", ( uint32_t ) ( 1 << hw_info.cfg.macs_per_cc ) ) ); + LogInfo( ( "\tCmd stream: v%" PRIu32 "\n", hw_info.cfg.cmd_stream_version ) ); + + return 0; + } #endif /* USE_ETHOS */ extern "C" { - -static int prvMlInterfaceInit(void) +static int prvMlInterfaceInit( void ) { static arm::app::Wav2LetterModel model; /* Model wrapper object. */ static arm::app::AsrClassifier classifier; /* Classifier wrapper object. */ static std::vector labels; -#ifdef USE_ETHOS - // Initialize the ethos U55 - if (prvArmNpuInit() != 0) { - LogError( ( "Failed to arm npu\n" ) ); - return -1; - } -#endif /* USE_ETHOS */ + #ifdef USE_ETHOS + /* Initialize the ethos U55 */ + if( prvArmNpuInit() != 0 ) + { + LogError( ( "Failed to arm npu\n" ) ); + return -1; + } + #endif /* USE_ETHOS */ /* Load the model. */ - if (!model.Init(::arm::app::tensorArena, - sizeof(::arm::app::tensorArena), - ::arm::app::asr::GetModelPointer(), - ::arm::app::asr::GetModelLen())) { + if( !model.Init( ::arm::app::tensorArena, + sizeof( ::arm::app::tensorArena ), + ::arm::app::asr::GetModelPointer(), + ::arm::app::asr::GetModelLen() ) ) + { LogError( ( "Failed to initialise model\n" ) ); return -1; } /* Initialise post-processing. */ - GetLabelsVector(labels); + GetLabelsVector( labels ); /* Instantiate application context. */ - caseContext.Set("model", model); - caseContext.Set("frameLength", g_FrameLength); - caseContext.Set("frameStride", g_FrameStride); - caseContext.Set("ctxLen", g_ctxLen); + caseContext.Set( "model", model ); + caseContext.Set( "frameLength", g_FrameLength ); + caseContext.Set( "frameStride", g_FrameStride ); + caseContext.Set( "ctxLen", g_ctxLen ); - caseContext.Set("scoreThreshold", g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set( "scoreThreshold", g_ScoreThreshold ); /* Normalised score threshold. */ - caseContext.Set &>("labels", labels); - caseContext.Set("classifier", classifier); + caseContext.Set &>( "labels", labels ); + caseContext.Set( "classifier", classifier ); PrintTensorFlowVersion(); LogInfo( ( "*** ML interface initialised\r\n" ) ); return 0; } -void vMlTask(void *pvParameters) +void vMlTask( void * pvParameters ) { LogInfo( ( "ML Task start\r\n" ) ); - DSPML *dspMLConnection = static_cast(pvParameters); + DSPML * dspMLConnection = static_cast( pvParameters ); EventBits_t flags = xEventGroupWaitBits( - xSystemEvents, (EventBits_t)EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY - ); + xSystemEvents, ( EventBits_t ) EVENT_MASK_ML_START, pdTRUE, pdFAIL, portMAX_DELAY + ); - if (flags & EVENT_MASK_ML_START) { + if( flags & EVENT_MASK_ML_START ) + { LogInfo( ( "Initial start of audio processing\r\n" ) ); } - if (prvMlInterfaceInit() < 0) { + if( prvMlInterfaceInit() < 0 ) + { LogError( ( "prvMlInterfaceInit failed\r\n" ) ); return; } - prvProcessAudio(caseContext, dspMLConnection); + prvProcessAudio( caseContext, dspMLConnection ); } -void vMlMqttTask(void *pvParameters) +void vMlMqttTask( void * pvParameters ) { - (void)pvParameters; + ( void ) pvParameters; - while (1) { + while( 1 ) + { ml_mqtt_msg_t msg; - if (xQueueReceive (xMlMqttQueue, &msg, portMAX_DELAY) == pdTRUE) { - prvMqttSendMessage(msg.result); - free(reinterpret_cast(msg.result)); - } else { + + if( xQueueReceive( xMlMqttQueue, &msg, portMAX_DELAY ) == pdTRUE ) + { + prvMqttSendMessage( msg.result ); + free( reinterpret_cast( msg.result ) ); + } + else + { LogError( ( "xQueueReceive ML MQTT msg queue failed\r\n" ) ); } } } - -} // extern "C" +} /* extern "C" */ diff --git a/components/aws_iot/coremqtt_agent/integration/tests/test_freertos_agent_message.cpp b/components/aws_iot/coremqtt_agent/integration/tests/test_freertos_agent_message.cpp index 0b1ffec..d0aa523 100644 --- a/components/aws_iot/coremqtt_agent/integration/tests/test_freertos_agent_message.cpp +++ b/components/aws_iot/coremqtt_agent/integration/tests/test_freertos_agent_message.cpp @@ -8,8 +8,8 @@ #include "gtest/gtest.h" extern "C" { - #include "freertos_agent_message.h" - #include "core_mqtt_agent_message_interface.h" +#include "freertos_agent_message.h" +#include "core_mqtt_agent_message_interface.h" } DEFINE_FFF_GLOBALS @@ -18,67 +18,71 @@ class TestFreertosAgentMessage : public ::testing::Test { public: TestFreertosAgentMessage() { - RESET_FAKE(xQueueSendToBack); - RESET_FAKE(xQueueReceive); + RESET_FAKE( xQueueSendToBack ); + RESET_FAKE( xQueueReceive ); } }; -TEST_F(TestFreertosAgentMessage, sending_nullptr_message_returns_false) +TEST_F( TestFreertosAgentMessage, sending_nullptr_message_returns_false ) { - MQTTAgentCommand_t *command; - EXPECT_FALSE(Agent_MessageSend(nullptr, &command, 1)); + MQTTAgentCommand_t * command; + + EXPECT_FALSE( Agent_MessageSend( nullptr, &command, 1 ) ); } -TEST_F(TestFreertosAgentMessage, sending_nullptr_command_returns_false) +TEST_F( TestFreertosAgentMessage, sending_nullptr_command_returns_false ) { MQTTAgentMessageContext_t message; - EXPECT_FALSE(Agent_MessageSend(&message, nullptr, 1)); + + EXPECT_FALSE( Agent_MessageSend( &message, nullptr, 1 ) ); } -TEST_F(TestFreertosAgentMessage, failing_to_send_a_message_returns_false) +TEST_F( TestFreertosAgentMessage, failing_to_send_a_message_returns_false ) { xQueueSendToBack_fake.return_val = pdFAIL; MQTTAgentMessageContext_t message; - MQTTAgentCommand_t *command; - EXPECT_FALSE(Agent_MessageSend(&message, &command, 1)); + MQTTAgentCommand_t * command; + EXPECT_FALSE( Agent_MessageSend( &message, &command, 1 ) ); } -TEST_F(TestFreertosAgentMessage, successfully_sending_a_message_returns_true) +TEST_F( TestFreertosAgentMessage, successfully_sending_a_message_returns_true ) { xQueueSendToBack_fake.return_val = pdPASS; MQTTAgentMessageContext_t message; - MQTTAgentCommand_t *command; - EXPECT_TRUE(Agent_MessageSend(&message, &command, 1)); + MQTTAgentCommand_t * command; + EXPECT_TRUE( Agent_MessageSend( &message, &command, 1 ) ); } -TEST_F(TestFreertosAgentMessage, request_to_receive_with_nullptr_message_returns_false) +TEST_F( TestFreertosAgentMessage, request_to_receive_with_nullptr_message_returns_false ) { - MQTTAgentCommand_t *command; - EXPECT_FALSE(Agent_MessageReceive(nullptr, &command, 1)); + MQTTAgentCommand_t * command; + + EXPECT_FALSE( Agent_MessageReceive( nullptr, &command, 1 ) ); } -TEST_F(TestFreertosAgentMessage, request_to_receive_with_nullptr_command_returns_false) +TEST_F( TestFreertosAgentMessage, request_to_receive_with_nullptr_command_returns_false ) { MQTTAgentMessageContext_t message; - EXPECT_FALSE(Agent_MessageReceive(&message, nullptr, 1)); + + EXPECT_FALSE( Agent_MessageReceive( &message, nullptr, 1 ) ); } -TEST_F(TestFreertosAgentMessage, failing_to_receive_a_message_returns_false) +TEST_F( TestFreertosAgentMessage, failing_to_receive_a_message_returns_false ) { xQueueReceive_fake.return_val = pdFAIL; MQTTAgentMessageContext_t message; - MQTTAgentCommand_t *command; - EXPECT_FALSE(Agent_MessageReceive(&message, &command, 1)); + MQTTAgentCommand_t * command; + EXPECT_FALSE( Agent_MessageReceive( &message, &command, 1 ) ); } -TEST_F(TestFreertosAgentMessage, successfully_receiving_a_message_returns_true) +TEST_F( TestFreertosAgentMessage, successfully_receiving_a_message_returns_true ) { xQueueReceive_fake.return_val = pdPASS; MQTTAgentMessageContext_t message; - MQTTAgentCommand_t *command; - EXPECT_TRUE(Agent_MessageReceive(&message, &command, 1)); + MQTTAgentCommand_t * command; + EXPECT_TRUE( Agent_MessageReceive( &message, &command, 1 ) ); } diff --git a/components/aws_iot/coremqtt_agent/integration/tests/test_mqtt_agent_task.cpp b/components/aws_iot/coremqtt_agent/integration/tests/test_mqtt_agent_task.cpp index 00c94c4..2815f0c 100644 --- a/components/aws_iot/coremqtt_agent/integration/tests/test_mqtt_agent_task.cpp +++ b/components/aws_iot/coremqtt_agent/integration/tests/test_mqtt_agent_task.cpp @@ -12,391 +12,434 @@ using namespace std; extern "C" { - #include "FreeRTOSConfig.h" - #include "backoff_algorithm.h" - #include "core_mqtt_agent_message_interface.h" - #include "core_mqtt_serializer.h" - #include "event_groups.h" - #include "events.h" - #include "logging_stack.h" - #include "mqtt_agent_task.h" - #include "psa/crypto.h" - #include "psa/error.h" - #include "task.h" - #include "queue.h" +#include "FreeRTOSConfig.h" +#include "backoff_algorithm.h" +#include "core_mqtt_agent_message_interface.h" +#include "core_mqtt_serializer.h" +#include "event_groups.h" +#include "events.h" +#include "logging_stack.h" +#include "mqtt_agent_task.h" +#include "psa/crypto.h" +#include "psa/error.h" +#include "task.h" +#include "queue.h" + +/* + * Exposed static functions to test + * The below functions are found in `mqtt_agent_task.c`, which is + * found by the inclusion of `mqtt_agent_task.h`. + */ - /* - Exposed static functions to test - The below functions are found in `mqtt_agent_task.c`, which is - found by the inclusion of `mqtt_agent_task.h`. - */ - - extern void prvMQTTAgentTask( void * pParam ); - extern BaseType_t prvSocketConnect( NetworkContext_t * pNetworkContext ); - extern void prvDisconnectFromMQTTBroker( void ); - extern MQTTStatus_t prvMQTTInit( void ); - extern MQTTStatus_t prvMQTTConnect( void ); - extern uint32_t prvGetTimeMs( void ); - extern UBaseType_t prvGetRandomNumber( void ); - extern BaseType_t prvSocketConnect( NetworkContext_t * pxNetworkContext ); - extern BaseType_t prvSocketDisconnect( NetworkContext_t * pxNetworkContext ); - extern void prvIncomingPublishCallback( MQTTAgentContext_t * pMqttAgentContext, - uint16_t packetId, - MQTTPublishInfo_t * pxPublishInfo ); - extern void prvReSubscriptionCommandCallback( MQTTAgentCommandContext_t * pxCommandContext, - MQTTAgentReturnInfo_t * pxReturnInfo ); - extern MQTTStatus_t prvHandleResubscribe( void ); - extern MQTTStatus_t prvMQTTConnect( void ); - extern void prvDisconnectCommandCallback( MQTTAgentCommandContext_t * pxCommandContext, - MQTTAgentReturnInfo_t * pxReturnInfo ); - extern void prvDisconnectFromMQTTBroker( void ); - extern void prvMQTTAgentTask( void * pParam ); - - /* Directly copy-paste mock headers from the file under test's directory. - Otherwise, the non-mock files are detected. */ - - /* subscription_manager.h */ - #ifndef SUBSCRIPTION_MANAGER_H +extern void prvMQTTAgentTask( void * pParam ); +extern BaseType_t prvSocketConnect( NetworkContext_t * pNetworkContext ); +extern void prvDisconnectFromMQTTBroker( void ); +extern MQTTStatus_t prvMQTTInit( void ); +extern MQTTStatus_t prvMQTTConnect( void ); +extern uint32_t prvGetTimeMs( void ); +extern UBaseType_t prvGetRandomNumber( void ); +extern BaseType_t prvSocketConnect( NetworkContext_t * pxNetworkContext ); +extern BaseType_t prvSocketDisconnect( NetworkContext_t * pxNetworkContext ); +extern void prvIncomingPublishCallback( MQTTAgentContext_t * pMqttAgentContext, + uint16_t packetId, + MQTTPublishInfo_t * pxPublishInfo ); +extern void prvReSubscriptionCommandCallback( MQTTAgentCommandContext_t * pxCommandContext, + MQTTAgentReturnInfo_t * pxReturnInfo ); +extern MQTTStatus_t prvHandleResubscribe( void ); +extern MQTTStatus_t prvMQTTConnect( void ); +extern void prvDisconnectCommandCallback( MQTTAgentCommandContext_t * pxCommandContext, + MQTTAgentReturnInfo_t * pxReturnInfo ); +extern void prvDisconnectFromMQTTBroker( void ); +extern void prvMQTTAgentTask( void * pParam ); + +/* Directly copy-paste mock headers from the file under test's directory. + * Otherwise, the non-mock files are detected. */ + +/* subscription_manager.h */ +#ifndef SUBSCRIPTION_MANAGER_H #define SUBSCRIPTION_MANAGER_H - typedef struct SubscriptionElement { - int usFilterStringLength; - const char * pcSubscriptionFilterString; - } SubscriptionElement_t; + typedef struct SubscriptionElement + { + int usFilterStringLength; + const char * pcSubscriptionFilterString; + } SubscriptionElement_t; - #ifndef SUBSCRIPTION_MANAGER_MAX_SUBSCRIPTIONS - #define SUBSCRIPTION_MANAGER_MAX_SUBSCRIPTIONS 10U - #endif + #ifndef SUBSCRIPTION_MANAGER_MAX_SUBSCRIPTIONS + #define SUBSCRIPTION_MANAGER_MAX_SUBSCRIPTIONS 10U + #endif - DECLARE_FAKE_VOID_FUNC( removeSubscription, - const char *, - uint16_t ); + DECLARE_FAKE_VOID_FUNC( removeSubscription, + const char *, + uint16_t ); - DECLARE_FAKE_VALUE_FUNC( bool, - handleIncomingPublishes, - MQTTPublishInfo_t * ); + DECLARE_FAKE_VALUE_FUNC( bool, + handleIncomingPublishes, + MQTTPublishInfo_t * ); - DEFINE_FAKE_VOID_FUNC( removeSubscription, - const char *, - uint16_t ); + DEFINE_FAKE_VOID_FUNC( removeSubscription, + const char *, + uint16_t ); - DEFINE_FAKE_VALUE_FUNC( bool, - handleIncomingPublishes, - MQTTPublishInfo_t * ); + DEFINE_FAKE_VALUE_FUNC( bool, + handleIncomingPublishes, + MQTTPublishInfo_t * ); - SubscriptionElement_t xGlobalSubscriptionList[ SUBSCRIPTION_MANAGER_MAX_SUBSCRIPTIONS ]; - #endif /* SUBSCRIPTION_MANAGER_H */ + SubscriptionElement_t xGlobalSubscriptionList[ SUBSCRIPTION_MANAGER_MAX_SUBSCRIPTIONS ]; +#endif /* SUBSCRIPTION_MANAGER_H */ - /* freertos_command_pool.h */ - #ifndef FREERTOS_COMMAND_POOL_H +/* freertos_command_pool.h */ +#ifndef FREERTOS_COMMAND_POOL_H #define FREERTOS_COMMAND_POOL_H - DECLARE_FAKE_VOID_FUNC( Agent_InitializePool ); - DECLARE_FAKE_VALUE_FUNC( MQTTAgentCommand_t *, - Agent_GetCommand, - uint32_t ); - DECLARE_FAKE_VALUE_FUNC( bool, - Agent_ReleaseCommand, - MQTTAgentCommand_t * ); - DEFINE_FAKE_VOID_FUNC( Agent_InitializePool ); - DEFINE_FAKE_VALUE_FUNC( MQTTAgentCommand_t *, - Agent_GetCommand, - uint32_t ); - DEFINE_FAKE_VALUE_FUNC( bool, - Agent_ReleaseCommand, - MQTTAgentCommand_t * ); - #endif /* FREERTOS_COMMAND_POOL_H */ - - /* Functions usually defined by main.c */ - DEFINE_FAKE_VOID_FUNC( vAssertCalled, - const char *, - unsigned long ); - DEFINE_FAKE_VOID_FUNC_VARARG( SdkLogError, - const char *, - ... ); - DEFINE_FAKE_VOID_FUNC_VARARG( SdkLogWarn, - const char *, - ... ); - DEFINE_FAKE_VOID_FUNC_VARARG( SdkLogInfo, - const char *, - ... ); - DEFINE_FAKE_VOID_FUNC_VARARG( SdkLogDebug, - const char *, - ... ); - + DECLARE_FAKE_VOID_FUNC( Agent_InitializePool ); + DECLARE_FAKE_VALUE_FUNC( MQTTAgentCommand_t *, + Agent_GetCommand, + uint32_t ); + DECLARE_FAKE_VALUE_FUNC( bool, + Agent_ReleaseCommand, + MQTTAgentCommand_t * ); + DEFINE_FAKE_VOID_FUNC( Agent_InitializePool ); + DEFINE_FAKE_VALUE_FUNC( MQTTAgentCommand_t *, + Agent_GetCommand, + uint32_t ); + DEFINE_FAKE_VALUE_FUNC( bool, + Agent_ReleaseCommand, + MQTTAgentCommand_t * ); +#endif /* FREERTOS_COMMAND_POOL_H */ + +/* Functions usually defined by main.c */ +DEFINE_FAKE_VOID_FUNC( vAssertCalled, + const char *, + unsigned long ); +DEFINE_FAKE_VOID_FUNC_VARARG( SdkLogError, + const char *, + ... ); +DEFINE_FAKE_VOID_FUNC_VARARG( SdkLogWarn, + const char *, + ... ); +DEFINE_FAKE_VOID_FUNC_VARARG( SdkLogInfo, + const char *, + ... ); +DEFINE_FAKE_VOID_FUNC_VARARG( SdkLogDebug, + const char *, + ... ); } DEFINE_FFF_GLOBALS -#define ASSERTION_FAIL int +#define ASSERTION_FAIL int /* Mock for vAssertCalled */ -void throw_assertion_failure ( const char * pcFile, - unsigned long ulLine ) { - throw (1); +void throw_assertion_failure( const char * pcFile, + unsigned long ulLine ) +{ + throw ( 1 ); + /* - Behaviour wanted: - - Encounters assertion fail, stops running any more code. E.g. does not go to next line. - - But checks all assertions in the google test program hold. - */ + * Behaviour wanted: + * - Encounters assertion fail, stops running any more code. E.g. does not go to next line. + * - But checks all assertions in the google test program hold. + */ } /* Being under this test class denotes a valid test that needs a corresponding fix, but we do not want clogging the testsuite. */ class SkipTest : public ::testing::Test { - protected: - void SetUp() override { - GTEST_SKIP() << "Skipping all tests under this suite"; - } +protected: + void SetUp() override + { + GTEST_SKIP() << "Skipping all tests under this suite"; + } }; class TestMqttAgentTask : public ::testing::Test { public: TestMqttAgentTask() { - RESET_FAKE(Agent_InitializePool); - RESET_FAKE(BackoffAlgorithm_InitializeParams); - RESET_FAKE(BackoffAlgorithm_GetNextBackoff); - RESET_FAKE(handleIncomingPublishes); - RESET_FAKE(MQTTAgent_CancelAll); - RESET_FAKE(MQTTAgent_CommandLoop); - RESET_FAKE(MQTTAgent_Init); - RESET_FAKE(MQTTAgent_ResumeSession) - RESET_FAKE(MQTTAgent_Subscribe); - RESET_FAKE(MQTT_Connect); - RESET_FAKE(MQTT_Status_strerror); - RESET_FAKE(psa_generate_random); - RESET_FAKE(SdkLogError); - RESET_FAKE(SdkLogInfo); - RESET_FAKE(SdkLogWarn); - RESET_FAKE(Transport_Disconnect); - RESET_FAKE(Transport_Connect); - RESET_FAKE(vAssertCalled); - RESET_FAKE(vTaskDelay); - RESET_FAKE(vTaskDelete); - RESET_FAKE(vWaitUntilNetworkIsUp); - RESET_FAKE(xEventGroupClearBits); - RESET_FAKE(xEventGroupSetBits); - RESET_FAKE(xTaskCreate); - RESET_FAKE(xTaskGetCurrentTaskHandle); - RESET_FAKE(xTaskGetTickCount); - RESET_FAKE(xTaskNotifyWait); - RESET_FAKE(xQueueCreateStatic); - - // Wrap functions expected to fail an assertion in EXPECT_THROW from GoogleTest. + RESET_FAKE( Agent_InitializePool ); + RESET_FAKE( BackoffAlgorithm_InitializeParams ); + RESET_FAKE( BackoffAlgorithm_GetNextBackoff ); + RESET_FAKE( handleIncomingPublishes ); + RESET_FAKE( MQTTAgent_CancelAll ); + RESET_FAKE( MQTTAgent_CommandLoop ); + RESET_FAKE( MQTTAgent_Init ); + RESET_FAKE( MQTTAgent_ResumeSession ) + RESET_FAKE( MQTTAgent_Subscribe ); + RESET_FAKE( MQTT_Connect ); + RESET_FAKE( MQTT_Status_strerror ); + RESET_FAKE( psa_generate_random ); + RESET_FAKE( SdkLogError ); + RESET_FAKE( SdkLogInfo ); + RESET_FAKE( SdkLogWarn ); + RESET_FAKE( Transport_Disconnect ); + RESET_FAKE( Transport_Connect ); + RESET_FAKE( vAssertCalled ); + RESET_FAKE( vTaskDelay ); + RESET_FAKE( vTaskDelete ); + RESET_FAKE( vWaitUntilNetworkIsUp ); + RESET_FAKE( xEventGroupClearBits ); + RESET_FAKE( xEventGroupSetBits ); + RESET_FAKE( xTaskCreate ); + RESET_FAKE( xTaskGetCurrentTaskHandle ); + RESET_FAKE( xTaskGetTickCount ); + RESET_FAKE( xTaskNotifyWait ); + RESET_FAKE( xQueueCreateStatic ); + + /* Wrap functions expected to fail an assertion in EXPECT_THROW from GoogleTest. */ vAssertCalled_fake.custom_fake = throw_assertion_failure; } }; /* Test helper functions */ -void expect_no_errors ( void ) { - ASSERT_EQ(SdkLogError_fake.call_count, 0); - ASSERT_EQ(vAssertCalled_fake.call_count, 0); +void expect_no_errors( void ) +{ + ASSERT_EQ( SdkLogError_fake.call_count, 0 ); + ASSERT_EQ( vAssertCalled_fake.call_count, 0 ); } -void expect_errors ( void ) { - ASSERT_NE(SdkLogError_fake.call_count + vAssertCalled_fake.call_count, 0) << "Expected an error reported by LogError or an assertion failure."; +void expect_errors( void ) +{ + ASSERT_NE( SdkLogError_fake.call_count + vAssertCalled_fake.call_count, 0 ) << "Expected an error reported by LogError or an assertion failure."; } -// expect throws -void expect_errors_or_warnings ( void ) { - ASSERT_NE(SdkLogError_fake.call_count + SdkLogWarn_fake.call_count + vAssertCalled_fake.call_count, 0) << "Expected an error reported by LogError, LogWarn or an assertion failure."; +/* expect throws */ +void expect_errors_or_warnings( void ) +{ + ASSERT_NE( SdkLogError_fake.call_count + SdkLogWarn_fake.call_count + vAssertCalled_fake.call_count, 0 ) << "Expected an error reported by LogError, LogWarn or an assertion failure."; } -void expect_no_errors_or_warnings ( void ) { - ASSERT_EQ(SdkLogError_fake.call_count + SdkLogWarn_fake.call_count + vAssertCalled_fake.call_count, 0); +void expect_no_errors_or_warnings( void ) +{ + ASSERT_EQ( SdkLogError_fake.call_count + SdkLogWarn_fake.call_count + vAssertCalled_fake.call_count, 0 ); } /* Custom fake for xEventGroupClearBits */ int expect_clearing_MQTT_event_mask( void * unused, - const int mask ) { - EXPECT_EQ (mask, EVENT_MASK_MQTT_CONNECTED); + const int mask ) +{ + EXPECT_EQ( mask, EVENT_MASK_MQTT_CONNECTED ); return 1; } -int expect_mqtt_connected_event_mask (void * unused, - const int mask) { - EXPECT_EQ(mask, EVENT_MASK_MQTT_CONNECTED); +int expect_mqtt_connected_event_mask( void * unused, + const int mask ) +{ + EXPECT_EQ( mask, EVENT_MASK_MQTT_CONNECTED ); return 1; } /* Custom fake for xTaskGetTickCount() */ TickType_t sharedCounter = 0; -TickType_t increment_shared_counter_and_return( void ) { +TickType_t increment_shared_counter_and_return( void ) +{ sharedCounter = sharedCounter + 1; return sharedCounter; } /* Custom fake for psa_generate_random */ -psa_status_t set_random_variable_to_three_and_return_success (uint8_t * randomVar, size_t outputSize) { +psa_status_t set_random_variable_to_three_and_return_success( uint8_t * randomVar, + size_t outputSize ) +{ *randomVar = 3; return PSA_SUCCESS; } /* Custom fakes for xTaskNotify */ -BaseType_t return_pdpass_and_expect_task_handle_points_to_five (TaskHandle_t handle, uint32_t returnCode, eNotifyAction unused) { - EXPECT_EQ(*handle, 5); +BaseType_t return_pdpass_and_expect_task_handle_points_to_five( TaskHandle_t handle, + uint32_t returnCode, + eNotifyAction unused ) +{ + EXPECT_EQ( *handle, 5 ); return pdPASS; } -BaseType_t return_pdpass_and_expect_mqtt_success_return_code (TaskHandle_t handle, uint32_t returnCode, eNotifyAction unused) { - EXPECT_EQ(returnCode, MQTTSuccess); +BaseType_t return_pdpass_and_expect_mqtt_success_return_code( TaskHandle_t handle, + uint32_t returnCode, + eNotifyAction unused ) +{ + EXPECT_EQ( returnCode, MQTTSuccess ); return pdPASS; } -BaseType_t return_pdpass_and_expect_mqtt_bad_parameter_return_code (TaskHandle_t handle, uint32_t returnCode, eNotifyAction unused) { - EXPECT_EQ(returnCode, MQTTBadParameter); +BaseType_t return_pdpass_and_expect_mqtt_bad_parameter_return_code( TaskHandle_t handle, + uint32_t returnCode, + eNotifyAction unused ) +{ + EXPECT_EQ( returnCode, MQTTBadParameter ); return pdPASS; } /* Custom fake for Transport_Connect */ -TransportStatus_t check_if_timeout_less_than_ten_seconds ( - NetworkContext_t * pNetworkContext, - const ServerInfo_t * pServerInfo, - const TLSParams_t * pTLSParams, - uint32_t sendTimeoutMs, - uint32_t recvTimeoutMs ) { +TransportStatus_t check_if_timeout_less_than_ten_seconds( NetworkContext_t * pNetworkContext, + const ServerInfo_t * pServerInfo, + const TLSParams_t * pTLSParams, + uint32_t sendTimeoutMs, + uint32_t recvTimeoutMs ) +{ uint32_t MAX_TIMEOUT_IN_MS = 10000; - EXPECT_LE(sendTimeoutMs, MAX_TIMEOUT_IN_MS); - EXPECT_LE(recvTimeoutMs, MAX_TIMEOUT_IN_MS); + + EXPECT_LE( sendTimeoutMs, MAX_TIMEOUT_IN_MS ); + EXPECT_LE( recvTimeoutMs, MAX_TIMEOUT_IN_MS ); return TRANSPORT_STATUS_SUCCESS; } /* The file under test contains static functions which the tests in this file assume are made visible -by conditional compiling macros. This test verifies these macros are defined. */ -TEST_F(TestMqttAgentTask, Can_test_static_functions) { + * by conditional compiling macros. This test verifies these macros are defined. */ +TEST_F( TestMqttAgentTask, Can_test_static_functions ) +{ #ifndef UNIT_TESTING - FAIL() << "The macro UNIT_TESTING is not defined, please add this to your CMake compile definitions."; + FAIL() << "The macro UNIT_TESTING is not defined, please add this to your CMake compile definitions."; #endif /* UNIT_TESTING */ } /* Testing vStartMqttAgentTask */ -TEST_F(TestMqttAgentTask, Starting_the_agent_creates_a_task) +TEST_F( TestMqttAgentTask, Starting_the_agent_creates_a_task ) { - ASSERT_EQ(xTaskCreate_fake.call_count, 0); - vStartMqttAgentTask (); - ASSERT_NE(xTaskCreate_fake.call_count, 0); + ASSERT_EQ( xTaskCreate_fake.call_count, 0 ); + vStartMqttAgentTask(); + ASSERT_NE( xTaskCreate_fake.call_count, 0 ); expect_no_errors(); } /* Testing prvSocketConnect */ -TEST_F(TestMqttAgentTask, Socket_connect_tries_to_make_a_connection) { - EXPECT_EQ(Transport_Connect_fake.call_count, 0); +TEST_F( TestMqttAgentTask, Socket_connect_tries_to_make_a_connection ) +{ + EXPECT_EQ( Transport_Connect_fake.call_count, 0 ); prvSocketConnect( 0 ); - EXPECT_NE(Transport_Connect_fake.call_count, 0); + EXPECT_NE( Transport_Connect_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Socket_connect_returns_success_on_successful_connection) { +TEST_F( TestMqttAgentTask, Socket_connect_returns_success_on_successful_connection ) +{ Transport_Connect_fake.return_val = TRANSPORT_STATUS_SUCCESS; - EXPECT_EQ(prvSocketConnect ( 0 ), pdPASS); + EXPECT_EQ( prvSocketConnect( 0 ), pdPASS ); expect_no_errors(); } -TEST_F(TestMqttAgentTask, Socket_connect_does_not_error_on_unsuccessful_connection) { +TEST_F( TestMqttAgentTask, Socket_connect_does_not_error_on_unsuccessful_connection ) +{ Transport_Connect_fake.return_val = TRANSPORT_STATUS_CONNECT_FAILURE; - prvSocketConnect ( 0 ); + prvSocketConnect( 0 ); expect_no_errors(); } -TEST_F(TestMqttAgentTask, Socket_connection_reattempt_does_not_continue_past_reasonable_time) { +TEST_F( TestMqttAgentTask, Socket_connection_reattempt_does_not_continue_past_reasonable_time ) +{ Transport_Connect_fake.custom_fake = check_if_timeout_less_than_ten_seconds; - prvSocketConnect ( 0 ); + prvSocketConnect( 0 ); expect_no_errors(); } -TEST_F(TestMqttAgentTask, Socket_connect_returns_failure_on_unsuccessful_connection) { +TEST_F( TestMqttAgentTask, Socket_connect_returns_failure_on_unsuccessful_connection ) +{ Transport_Connect_fake.return_val = TRANSPORT_STATUS_CONNECT_FAILURE; - EXPECT_EQ(prvSocketConnect ( 0 ), pdFALSE); + EXPECT_EQ( prvSocketConnect( 0 ), pdFALSE ); expect_no_errors(); } /* Testing prvSocketDisconnect */ -TEST_F(TestMqttAgentTask, Socket_disconnect_tries_to_close_a_connection) { +TEST_F( TestMqttAgentTask, Socket_disconnect_tries_to_close_a_connection ) +{ Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_SUCCESS; xEventGroupClearBits_fake.return_val = 1; int dummy = 1; - EXPECT_EQ(Transport_Disconnect_fake.call_count, 0); + EXPECT_EQ( Transport_Disconnect_fake.call_count, 0 ); prvSocketDisconnect( &dummy ); - EXPECT_NE(Transport_Disconnect_fake.call_count, 0); + EXPECT_NE( Transport_Disconnect_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Socket_disconnect_returns_success_when_disconnecting_succeeds) { +TEST_F( TestMqttAgentTask, Socket_disconnect_returns_success_when_disconnecting_succeeds ) +{ Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_SUCCESS; xEventGroupClearBits_fake.return_val = 1; int dummy = 1; - EXPECT_EQ(prvSocketDisconnect( &dummy ), pdPASS); + EXPECT_EQ( prvSocketDisconnect( &dummy ), pdPASS ); } -TEST_F(TestMqttAgentTask, Socket_disconnect_returns_failure_when_disconnecting_fails) { +TEST_F( TestMqttAgentTask, Socket_disconnect_returns_failure_when_disconnecting_fails ) +{ Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_INVALID_PARAMETER; xEventGroupClearBits_fake.return_val = 1; int dummy = 1; - EXPECT_EQ(prvSocketDisconnect( &dummy ), pdFAIL); + EXPECT_EQ( prvSocketDisconnect( &dummy ), pdFAIL ); } -TEST_F(TestMqttAgentTask, Socket_disconnect_informs_system_there_is_no_MQTT_connection_on_connection_closure) { +TEST_F( TestMqttAgentTask, Socket_disconnect_informs_system_there_is_no_MQTT_connection_on_connection_closure ) +{ Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_SUCCESS; xEventGroupClearBits_fake.custom_fake = expect_clearing_MQTT_event_mask; - EXPECT_EQ(xEventGroupClearBits_fake.call_count, 0); + EXPECT_EQ( xEventGroupClearBits_fake.call_count, 0 ); int dummy = 1; prvSocketDisconnect( &dummy ); - EXPECT_NE(xEventGroupClearBits_fake.call_count, 0); + EXPECT_NE( xEventGroupClearBits_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Socket_disconnect_does_not_inform_system_there_is_no_MQTT_connection_if_connection_not_closed) { +TEST_F( TestMqttAgentTask, Socket_disconnect_does_not_inform_system_there_is_no_MQTT_connection_if_connection_not_closed ) +{ Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_INVALID_PARAMETER; xEventGroupClearBits_fake.custom_fake = expect_clearing_MQTT_event_mask; - EXPECT_EQ(xEventGroupClearBits_fake.call_count, 0); + EXPECT_EQ( xEventGroupClearBits_fake.call_count, 0 ); int dummy = 1; prvSocketDisconnect( &dummy ); - EXPECT_EQ(xEventGroupClearBits_fake.call_count, 0); + EXPECT_EQ( xEventGroupClearBits_fake.call_count, 0 ); } /* Testing prvDisconnectFromMQTTBroker */ -TEST_F(TestMqttAgentTask, Disconnect_from_broker_tries_to_disconnect_from_MQTT_broker) { +TEST_F( TestMqttAgentTask, Disconnect_from_broker_tries_to_disconnect_from_MQTT_broker ) +{ int dummy = 1; + xTaskGetCurrentTaskHandle_fake.return_val = &dummy; xTaskNotifyWait_fake.return_val = 1; Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_SUCCESS; xEventGroupClearBits_fake.return_val = 1; - EXPECT_EQ(MQTTAgent_Disconnect_fake.call_count, 0); + EXPECT_EQ( MQTTAgent_Disconnect_fake.call_count, 0 ); prvDisconnectFromMQTTBroker(); - EXPECT_NE(MQTTAgent_Disconnect_fake.call_count, 0); + EXPECT_NE( MQTTAgent_Disconnect_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Disconnect_from_broker_tries_to_close_TCP_connection_if_successful) { +TEST_F( TestMqttAgentTask, Disconnect_from_broker_tries_to_close_TCP_connection_if_successful ) +{ int dummy = 1; + xTaskGetCurrentTaskHandle_fake.return_val = &dummy; xTaskNotifyWait_fake.return_val = 1; Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_SUCCESS; xEventGroupClearBits_fake.return_val = 1; - // Successful mqtt closure. + /* Successful mqtt closure. */ MQTTAgent_Disconnect_fake.return_val = MQTTSuccess; - EXPECT_EQ(Transport_Disconnect_fake.call_count, 0); + EXPECT_EQ( Transport_Disconnect_fake.call_count, 0 ); prvDisconnectFromMQTTBroker(); - EXPECT_NE(Transport_Disconnect_fake.call_count, 0); + EXPECT_NE( Transport_Disconnect_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Disconnect_from_broker_errors_on_failure_to_close_TCP_connection) { +TEST_F( TestMqttAgentTask, Disconnect_from_broker_errors_on_failure_to_close_TCP_connection ) +{ int dummy = 1; + xTaskGetCurrentTaskHandle_fake.return_val = &dummy; xTaskNotifyWait_fake.return_val = 1; xEventGroupClearBits_fake.return_val = 1; - // Fails to close mqtt connection. + /* Fails to close mqtt connection. */ MQTTAgent_Disconnect_fake.return_val = MQTTSuccess; Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_INVALID_PARAMETER; prvDisconnectFromMQTTBroker(); expect_errors(); } -TEST_F(TestMqttAgentTask, Waits_for_MQTT_connection_to_close_before_trying_to_close_TCP_connection) { +TEST_F( TestMqttAgentTask, Waits_for_MQTT_connection_to_close_before_trying_to_close_TCP_connection ) +{ int dummy = 1; + xTaskGetCurrentTaskHandle_fake.return_val = &dummy; xTaskNotifyWait_fake.return_val = 1; xEventGroupClearBits_fake.return_val = 1; - // Closes mqtt connection but fails for TCP. + /* Closes mqtt connection but fails for TCP. */ MQTTAgent_Disconnect_fake.return_val = MQTTSuccess; Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_SUCCESS; - EXPECT_EQ(xTaskNotifyWait_fake.call_count, 0); + EXPECT_EQ( xTaskNotifyWait_fake.call_count, 0 ); prvDisconnectFromMQTTBroker(); - EXPECT_NE(xTaskNotifyWait_fake.call_count, 0); + EXPECT_NE( xTaskNotifyWait_fake.call_count, 0 ); - // Check this is not the case if MQTT closure fails. - RESET_FAKE(xTaskNotifyWait); + /* Check this is not the case if MQTT closure fails. */ + RESET_FAKE( xTaskNotifyWait ); MQTTAgent_Disconnect_fake.return_val = MQTTBadParameter; - EXPECT_EQ(xTaskNotifyWait_fake.call_count, 0) << "Failed to reset fake."; - EXPECT_THROW(prvDisconnectFromMQTTBroker(); - EXPECT_EQ(xTaskNotifyWait_fake.call_count, 0), ASSERTION_FAIL); + EXPECT_EQ( xTaskNotifyWait_fake.call_count, 0 ) << "Failed to reset fake."; + EXPECT_THROW( prvDisconnectFromMQTTBroker(); + EXPECT_EQ( xTaskNotifyWait_fake.call_count, 0 ), ASSERTION_FAIL ); } -TEST_F(TestMqttAgentTask, Disconnect_from_broker_errors_on_failure_to_close_MQTT_connection) { +TEST_F( TestMqttAgentTask, Disconnect_from_broker_errors_on_failure_to_close_MQTT_connection ) +{ int dummy = 1; + xTaskGetCurrentTaskHandle_fake.return_val = &dummy; xTaskNotifyWait_fake.return_val = 1; Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_SUCCESS; xEventGroupClearBits_fake.return_val = 1; - // Fails to close mqtt connection. + /* Fails to close mqtt connection. */ MQTTAgent_Disconnect_fake.return_val = MQTTBadParameter; Transport_Disconnect_fake.return_val = TRANSPORT_STATUS_SUCCESS; - EXPECT_THROW(prvDisconnectFromMQTTBroker(), ASSERTION_FAIL); + EXPECT_THROW( prvDisconnectFromMQTTBroker(), ASSERTION_FAIL ); expect_errors(); } @@ -420,26 +463,30 @@ class TestMqttAgentTaskConnect : public TestMqttAgentTask { MQTTAgent_Disconnect_fake.return_val = MQTTSuccess; } }; -TEST_F(TestMqttAgentTaskConnect, MQTT_connect_tries_to_create_a_connection) { - EXPECT_EQ(MQTT_Connect_fake.call_count, 0); +TEST_F( TestMqttAgentTaskConnect, MQTT_connect_tries_to_create_a_connection ) +{ + EXPECT_EQ( MQTT_Connect_fake.call_count, 0 ); prvMQTTConnect(); - EXPECT_NE(MQTT_Connect_fake.call_count, 0); + EXPECT_NE( MQTT_Connect_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTaskConnect, MQTT_connect_returns_failure_if_connection_fails) { +TEST_F( TestMqttAgentTaskConnect, MQTT_connect_returns_failure_if_connection_fails ) +{ MQTT_Connect_fake.return_val = MQTTBadParameter; - EXPECT_NE(prvMQTTConnect(), MQTTSuccess); + EXPECT_NE( prvMQTTConnect(), MQTTSuccess ); } -TEST_F(TestMqttAgentTaskConnect, MQTT_connect_returns_success_if_connection_succeeds) { +TEST_F( TestMqttAgentTaskConnect, MQTT_connect_returns_success_if_connection_succeeds ) +{ MQTT_Connect_fake.return_val = MQTTSuccess; - EXPECT_EQ(prvMQTTConnect(), MQTTSuccess); -} -// This test would be ideal, but is not possible to write neatly for the file. -// TEST_F(TestMqttAgentTask, MQTT_connect_updates_system_flags_when_creating_a_new_connection) { -// EXPECT_EQ(xEventGroupSetBits_fake.call_count, 0); -// prvMQTTConnect(); -// EXPECT_NE(xEventGroupSetBits_fake.call_count, 0); -// } -TEST_F(TestMqttAgentTaskConnect, MQTT_connect_does_not_set_incorrect_system_flags_when_creating_a_new_connection) { + EXPECT_EQ( prvMQTTConnect(), MQTTSuccess ); +} +/* This test would be ideal, but is not possible to write neatly for the file. */ +/* TEST_F(TestMqttAgentTask, MQTT_connect_updates_system_flags_when_creating_a_new_connection) { */ +/* EXPECT_EQ(xEventGroupSetBits_fake.call_count, 0); */ +/* prvMQTTConnect(); */ +/* EXPECT_NE(xEventGroupSetBits_fake.call_count, 0); */ +/* } */ +TEST_F( TestMqttAgentTaskConnect, MQTT_connect_does_not_set_incorrect_system_flags_when_creating_a_new_connection ) +{ xEventGroupSetBits_fake.custom_fake = expect_mqtt_connected_event_mask; prvMQTTConnect(); } @@ -448,42 +495,50 @@ TEST_F(TestMqttAgentTaskConnect, MQTT_connect_does_not_set_incorrect_system_flag /* Testing prvGetRandomNumber */ -TEST_F(TestMqttAgentTask, Random_generation_calls_random_library_function) { +TEST_F( TestMqttAgentTask, Random_generation_calls_random_library_function ) +{ psa_generate_random_fake.return_val = PSA_SUCCESS; xTaskGetTickCount_fake.return_val = 3; - EXPECT_EQ(psa_generate_random_fake.call_count, 0); + EXPECT_EQ( psa_generate_random_fake.call_count, 0 ); prvGetRandomNumber(); - EXPECT_NE(psa_generate_random_fake.call_count, 0); + EXPECT_NE( psa_generate_random_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Random_generation_does_not_error_if_random_library_function_succeeds) { +TEST_F( TestMqttAgentTask, Random_generation_does_not_error_if_random_library_function_succeeds ) +{ psa_generate_random_fake.return_val = PSA_SUCCESS; xTaskGetTickCount_fake.return_val = 3; prvGetRandomNumber(); expect_no_errors(); } -TEST_F(TestMqttAgentTask, Random_generation_reports_library_failure) { +TEST_F( TestMqttAgentTask, Random_generation_reports_library_failure ) +{ psa_generate_random_fake.return_val = PSA_ERROR_PROGRAMMER_ERROR; xTaskGetTickCount_fake.return_val = 3; prvGetRandomNumber(); expect_errors_or_warnings(); } -TEST_F(TestMqttAgentTask, Random_generation_gives_same_output_for_same_value_given_by_random_library) { - // Randomizing outputs should be handled by correct library functions only. +TEST_F( TestMqttAgentTask, Random_generation_gives_same_output_for_same_value_given_by_random_library ) +{ + /* Randomizing outputs should be handled by correct library functions only. */ xTaskGetTickCount_fake.custom_fake = increment_shared_counter_and_return; psa_generate_random_fake.custom_fake = set_random_variable_to_three_and_return_success; UBaseType_t expected = prvGetRandomNumber(); - for (int count=0; count < 5; count++) { - EXPECT_EQ(expected, prvGetRandomNumber()); + + for( int count = 0; count < 5; count++ ) + { + EXPECT_EQ( expected, prvGetRandomNumber() ); } } /* Testing prvIncomingPublishCallback */ -TEST_F(TestMqttAgentTask, Publish_callback_calls_publish_handler) { +TEST_F( TestMqttAgentTask, Publish_callback_calls_publish_handler ) +{ bool handled = true; + handleIncomingPublishes_fake.return_val = handled; - EXPECT_EQ(handleIncomingPublishes_fake.call_count, 0); + EXPECT_EQ( handleIncomingPublishes_fake.call_count, 0 ); int dummy = 5; MQTTAgentContext_t mqttAgentContext = { &dummy }; uint16_t dummyId = 10; @@ -495,11 +550,13 @@ TEST_F(TestMqttAgentTask, Publish_callback_calls_publish_handler) { xPublishInfo.topicNameLength = 6; xPublishInfo.pPayload = &dummy; xPublishInfo.payloadLength = 1; - prvIncomingPublishCallback(&mqttAgentContext, dummyId, &xPublishInfo); - EXPECT_NE(handleIncomingPublishes_fake.call_count, 0); + prvIncomingPublishCallback( &mqttAgentContext, dummyId, &xPublishInfo ); + EXPECT_NE( handleIncomingPublishes_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Publish_callback_does_not_error_if_publish_handled) { +TEST_F( TestMqttAgentTask, Publish_callback_does_not_error_if_publish_handled ) +{ bool handled = true; + handleIncomingPublishes_fake.return_val = handled; int dummy = 5; @@ -513,11 +570,13 @@ TEST_F(TestMqttAgentTask, Publish_callback_does_not_error_if_publish_handled) { xPublishInfo.topicNameLength = 6; xPublishInfo.pPayload = &dummy; xPublishInfo.payloadLength = 1; - prvIncomingPublishCallback(&mqttAgentContext, dummyId, &xPublishInfo); + prvIncomingPublishCallback( &mqttAgentContext, dummyId, &xPublishInfo ); expect_no_errors_or_warnings(); } -TEST_F(TestMqttAgentTask, Publish_callback_generates_log_if_handler_fails) { +TEST_F( TestMqttAgentTask, Publish_callback_generates_log_if_handler_fails ) +{ bool handled = false; + handleIncomingPublishes_fake.return_val = handled; int dummy = 5; @@ -531,55 +590,64 @@ TEST_F(TestMqttAgentTask, Publish_callback_generates_log_if_handler_fails) { xPublishInfo.topicNameLength = 6; xPublishInfo.pPayload = &dummy; xPublishInfo.payloadLength = 1; - prvIncomingPublishCallback(&mqttAgentContext, dummyId, &xPublishInfo); + prvIncomingPublishCallback( &mqttAgentContext, dummyId, &xPublishInfo ); expect_errors_or_warnings(); } /* Testing prvReSubscriptionCommandCallback */ -TEST_F(TestMqttAgentTask, Command_callback_does_not_error_if_given_MQTT_success) { - // All topic filters are already in subscription list. +TEST_F( TestMqttAgentTask, Command_callback_does_not_error_if_given_MQTT_success ) +{ + /* All topic filters are already in subscription list. */ MQTTAgentCommandContext_t pxCommandContext = {}; - uint8_t subackCodes[] ={1,2,3}; - MQTTAgentReturnInfo_t returnInfo = {MQTTSuccess, subackCodes}; - prvReSubscriptionCommandCallback(&pxCommandContext, &returnInfo); + uint8_t subackCodes[] = { 1, 2, 3 }; + MQTTAgentReturnInfo_t returnInfo = { MQTTSuccess, subackCodes }; + + prvReSubscriptionCommandCallback( &pxCommandContext, &returnInfo ); expect_no_errors(); } /* Testing prvMQTTInit */ -TEST_F(TestMqttAgentTask, MQTT_init_tries_to_create_a_command_queue) { - QueueDefinition queue = {10}; +TEST_F( TestMqttAgentTask, MQTT_init_tries_to_create_a_command_queue ) +{ + QueueDefinition queue = { 10 }; + xQueueCreateStatic_fake.return_val = &queue; MQTTAgent_Init_fake.return_val = MQTTSuccess; xEventGroupSetBits_fake.return_val = 1; - EXPECT_EQ(xQueueCreateStatic_fake.call_count, 0); + EXPECT_EQ( xQueueCreateStatic_fake.call_count, 0 ); prvMQTTInit(); - EXPECT_NE(xQueueCreateStatic_fake.call_count, 0); + EXPECT_NE( xQueueCreateStatic_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, MQTT_init_errors_if_command_queue_creation_returns_nullptr) { - // E.g. happens if out of memory. +TEST_F( TestMqttAgentTask, MQTT_init_errors_if_command_queue_creation_returns_nullptr ) +{ + /* E.g. happens if out of memory. */ xQueueCreateStatic_fake.return_val = nullptr; MQTTAgent_Init_fake.return_val = MQTTSuccess; xEventGroupSetBits_fake.return_val = 1; - EXPECT_EQ(xQueueCreateStatic_fake.call_count, 0); - EXPECT_THROW(prvMQTTInit(), ASSERTION_FAIL); - EXPECT_NE(xQueueCreateStatic_fake.call_count, 0); + EXPECT_EQ( xQueueCreateStatic_fake.call_count, 0 ); + EXPECT_THROW( prvMQTTInit(), ASSERTION_FAIL ); + EXPECT_NE( xQueueCreateStatic_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, MQTT_init_tries_to_initialise_MQTT_library) { - QueueDefinition queue = {10}; +TEST_F( TestMqttAgentTask, MQTT_init_tries_to_initialise_MQTT_library ) +{ + QueueDefinition queue = { 10 }; + xQueueCreateStatic_fake.return_val = &queue; MQTTAgent_Init_fake.return_val = MQTTSuccess; xEventGroupSetBits_fake.return_val = 1; - EXPECT_EQ(MQTTAgent_Init_fake.call_count, 0); + EXPECT_EQ( MQTTAgent_Init_fake.call_count, 0 ); prvMQTTInit(); - EXPECT_NE(MQTTAgent_Init_fake.call_count, 0); + EXPECT_NE( MQTTAgent_Init_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, MQTT_init_errors_if_cannot_initialise_MQTT_library) { - QueueDefinition queue = {10}; +TEST_F( TestMqttAgentTask, MQTT_init_errors_if_cannot_initialise_MQTT_library ) +{ + QueueDefinition queue = { 10 }; + xQueueCreateStatic_fake.return_val = &queue; MQTTAgent_Init_fake.return_val = MQTTBadParameter; xEventGroupSetBits_fake.return_val = 1; @@ -587,35 +655,41 @@ TEST_F(TestMqttAgentTask, MQTT_init_errors_if_cannot_initialise_MQTT_library) { prvMQTTInit(); expect_errors(); } -TEST_F(TestMqttAgentTask, MQTT_init_returns_success_if_successful) { - QueueDefinition queue = {10}; +TEST_F( TestMqttAgentTask, MQTT_init_returns_success_if_successful ) +{ + QueueDefinition queue = { 10 }; + xQueueCreateStatic_fake.return_val = &queue; MQTTAgent_Init_fake.return_val = MQTTSuccess; xEventGroupSetBits_fake.return_val = 1; - EXPECT_EQ(prvMQTTInit(), MQTTSuccess); + EXPECT_EQ( prvMQTTInit(), MQTTSuccess ); } -TEST_F(TestMqttAgentTask, MQTT_init_sets_system_MQTT_init_event_flag_if_successful) { - QueueDefinition queue = {10}; +TEST_F( TestMqttAgentTask, MQTT_init_sets_system_MQTT_init_event_flag_if_successful ) +{ + QueueDefinition queue = { 10 }; + xQueueCreateStatic_fake.return_val = &queue; MQTTAgent_Init_fake.return_val = MQTTSuccess; xEventGroupSetBits_fake.return_val = 1; - EXPECT_EQ(xEventGroupSetBits_fake.call_count, 0); + EXPECT_EQ( xEventGroupSetBits_fake.call_count, 0 ); prvMQTTInit(); - EXPECT_NE(xEventGroupSetBits_fake.call_count, 0); + EXPECT_NE( xEventGroupSetBits_fake.call_count, 0 ); } /* Testing prvHandleResubscribe */ -TEST_F(TestMqttAgentTask, Resubscribe_returns_success_if_subscription_succeeds) { +TEST_F( TestMqttAgentTask, Resubscribe_returns_success_if_subscription_succeeds ) +{ MQTTAgent_Subscribe_fake.return_val = MQTTSuccess; MQTT_Status_strerror_fake.return_val = "dummy"; - EXPECT_EQ(prvHandleResubscribe(), MQTTSuccess); + EXPECT_EQ( prvHandleResubscribe(), MQTTSuccess ); } /* Testing prvDisconnectCommandCallback */ -TEST_F(TestMqttAgentTask, Callback_for_MQTT_disconnect_tries_to_notify_task_waiting_on_MQTT_disconnection) { +TEST_F( TestMqttAgentTask, Callback_for_MQTT_disconnect_tries_to_notify_task_waiting_on_MQTT_disconnection ) +{ xTaskNotify_fake.return_val = pdPASS; int dummyTask = 5; @@ -624,11 +698,12 @@ TEST_F(TestMqttAgentTask, Callback_for_MQTT_disconnect_tries_to_notify_task_wait MQTTAgentReturnInfo_t xReturnInfo; xReturnInfo.returnCode = MQTTSuccess; - EXPECT_EQ(xTaskNotify_fake.call_count, 0); - prvDisconnectCommandCallback(&xCommandContext, &xReturnInfo); - EXPECT_NE(xTaskNotify_fake.call_count, 0); + EXPECT_EQ( xTaskNotify_fake.call_count, 0 ); + prvDisconnectCommandCallback( &xCommandContext, &xReturnInfo ); + EXPECT_NE( xTaskNotify_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Callback_for_MQTT_disconnect_does_not_notify_if_no_tasks_are_waiting) { +TEST_F( TestMqttAgentTask, Callback_for_MQTT_disconnect_does_not_notify_if_no_tasks_are_waiting ) +{ xTaskNotify_fake.return_val = pdPASS; MQTTAgentCommandContext_t xCommandContext; @@ -636,11 +711,12 @@ TEST_F(TestMqttAgentTask, Callback_for_MQTT_disconnect_does_not_notify_if_no_tas MQTTAgentReturnInfo_t xReturnInfo; xReturnInfo.returnCode = MQTTSuccess; - EXPECT_EQ(xTaskNotify_fake.call_count, 0); - prvDisconnectCommandCallback(&xCommandContext, &xReturnInfo); - EXPECT_EQ(xTaskNotify_fake.call_count, 0); + EXPECT_EQ( xTaskNotify_fake.call_count, 0 ); + prvDisconnectCommandCallback( &xCommandContext, &xReturnInfo ); + EXPECT_EQ( xTaskNotify_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTask, Callback_for_MQTT_disconnect_notifies_tasks_with_correct_MQTT_status_return_code) { +TEST_F( TestMqttAgentTask, Callback_for_MQTT_disconnect_notifies_tasks_with_correct_MQTT_status_return_code ) +{ xTaskNotify_fake.custom_fake = return_pdpass_and_expect_mqtt_success_return_code; int dummyTask = 5; @@ -650,14 +726,15 @@ TEST_F(TestMqttAgentTask, Callback_for_MQTT_disconnect_notifies_tasks_with_corre xReturnInfo.returnCode = MQTTBadParameter; xTaskNotify_fake.custom_fake = return_pdpass_and_expect_mqtt_bad_parameter_return_code; - EXPECT_EQ(xTaskNotify_fake.call_count, 0); - prvDisconnectCommandCallback(&xCommandContext, &xReturnInfo); - EXPECT_NE(xTaskNotify_fake.call_count, 0); + EXPECT_EQ( xTaskNotify_fake.call_count, 0 ); + prvDisconnectCommandCallback( &xCommandContext, &xReturnInfo ); + EXPECT_NE( xTaskNotify_fake.call_count, 0 ); xReturnInfo.returnCode = MQTTBadParameter; - prvDisconnectCommandCallback(&xCommandContext, &xReturnInfo); + prvDisconnectCommandCallback( &xCommandContext, &xReturnInfo ); } -TEST_F(TestMqttAgentTask, Callback_for_MQTT_disconnect_notifies_the_correct_task) { +TEST_F( TestMqttAgentTask, Callback_for_MQTT_disconnect_notifies_the_correct_task ) +{ xTaskNotify_fake.custom_fake = return_pdpass_and_expect_task_handle_points_to_five; int dummyTask = 5; @@ -666,9 +743,9 @@ TEST_F(TestMqttAgentTask, Callback_for_MQTT_disconnect_notifies_the_correct_task MQTTAgentReturnInfo_t xReturnInfo; xReturnInfo.returnCode = MQTTSuccess; - EXPECT_EQ(xTaskNotify_fake.call_count, 0); - prvDisconnectCommandCallback(&xCommandContext, &xReturnInfo); - EXPECT_NE(xTaskNotify_fake.call_count, 0); + EXPECT_EQ( xTaskNotify_fake.call_count, 0 ); + prvDisconnectCommandCallback( &xCommandContext, &xReturnInfo ); + EXPECT_NE( xTaskNotify_fake.call_count, 0 ); } /* Testing prvMQTTAgentTask */ @@ -678,7 +755,8 @@ class TestMqttAgentTaskMainFunction : public TestMqttAgentTask { TestMqttAgentTaskMainFunction() { /* So that MQTT can initialise */ - QueueDefinition queue = {10}; + QueueDefinition queue = { 10 }; + xQueueCreateStatic_fake.return_val = &queue; MQTTAgent_Init_fake.return_val = MQTTSuccess; xEventGroupSetBits_fake.return_val = 1; @@ -703,29 +781,34 @@ class TestMqttAgentTaskMainFunction : public TestMqttAgentTask { } }; -TEST_F(TestMqttAgentTaskMainFunction, Agent_task_waits_for_network_to_start) { - EXPECT_EQ(vWaitUntilNetworkIsUp_fake.call_count, 0); - prvMQTTAgentTask ( nullptr ); - EXPECT_NE(vWaitUntilNetworkIsUp_fake.call_count, 0); +TEST_F( TestMqttAgentTaskMainFunction, Agent_task_waits_for_network_to_start ) +{ + EXPECT_EQ( vWaitUntilNetworkIsUp_fake.call_count, 0 ); + prvMQTTAgentTask( nullptr ); + EXPECT_NE( vWaitUntilNetworkIsUp_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTaskMainFunction, Agent_task_initialises_MQTT_library) { - EXPECT_EQ(MQTTAgent_Init_fake.call_count, 0); - prvMQTTAgentTask ( nullptr ); - EXPECT_NE(MQTTAgent_Init_fake.call_count, 0); +TEST_F( TestMqttAgentTaskMainFunction, Agent_task_initialises_MQTT_library ) +{ + EXPECT_EQ( MQTTAgent_Init_fake.call_count, 0 ); + prvMQTTAgentTask( nullptr ); + EXPECT_NE( MQTTAgent_Init_fake.call_count, 0 ); } -TEST_F(TestMqttAgentTaskMainFunction, Agent_task_cannot_continue_if_mqtt_initialisation_fails) { +TEST_F( TestMqttAgentTaskMainFunction, Agent_task_cannot_continue_if_mqtt_initialisation_fails ) +{ /* MQTT needs a command queue and the agent to be intialised */ - QueueDefinition queue = {10}; + QueueDefinition queue = { 10 }; + xQueueCreateStatic_fake.return_val = &queue; MQTTAgent_Init_fake.return_val = MQTTBadParameter; xEventGroupSetBits_fake.return_val = 1; MQTTAgent_Init_fake.return_val = MQTTBadParameter; - EXPECT_THROW(prvMQTTAgentTask( nullptr ), ASSERTION_FAIL); + EXPECT_THROW( prvMQTTAgentTask( nullptr ), ASSERTION_FAIL ); } -TEST_F(TestMqttAgentTaskMainFunction, Agent_task_initialises_MQTT_pool) { - EXPECT_EQ(Agent_InitializePool_fake.call_count, 0); - prvMQTTAgentTask ( nullptr ); - EXPECT_NE(Agent_InitializePool_fake.call_count, 0); +TEST_F( TestMqttAgentTaskMainFunction, Agent_task_initialises_MQTT_pool ) +{ + EXPECT_EQ( Agent_InitializePool_fake.call_count, 0 ); + prvMQTTAgentTask( nullptr ); + EXPECT_NE( Agent_InitializePool_fake.call_count, 0 ); } diff --git a/release_changes/202409181115.change b/release_changes/202409181115.change new file mode 100644 index 0000000..dbb2af8 --- /dev/null +++ b/release_changes/202409181115.change @@ -0,0 +1,2 @@ +spell-checker: Add missing excluded words. +formatting: Add CPP files to uncrustify check. diff --git a/tools/scripts/run_uncrustify.sh b/tools/scripts/run_uncrustify.sh index 28f5624..d445b4d 100755 --- a/tools/scripts/run_uncrustify.sh +++ b/tools/scripts/run_uncrustify.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright 2023 Arm Limited and/or its affiliates +# Copyright 2023-2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: MIT @@ -21,4 +21,4 @@ do done exclude_pattern+="./build" -fdfind -E $exclude_pattern -e c -e h --exec uncrustify --no-backup --replace --if-changed -c tools/uncrustify.cfg +fdfind -E $exclude_pattern -e c -e h -e cc -e cpp --exec uncrustify --no-backup --replace --if-changed -c tools/uncrustify.cfg