From 10299465e2166f8f5c67f1f9504d2a30887c8ba6 Mon Sep 17 00:00:00 2001 From: Jerry Johns Date: Fri, 5 Nov 2021 12:01:02 -0700 Subject: [PATCH] Add async message dispatch to loopback (#11461) * Add async message dispatch to loopback This PR was triggered by some test failures in some of the end-to-end IM unit tests that utilized the loopback transport to send/receive payloads from client to server and back. Since the current loopback transport processes 'transmitted' messages synchronously without completing the execution of the original context, it results in call flows that are not typical of actual devices interacting with each other. This resulted in a use-after-free error where, upon calling SendMessage() within the CommandSender, the synchronous execution resulted in the eventual destruction of the original CommandSender object immediately after SendMessage() was called. This PR adds to the existing loopback interface support for asynchronous dispatch and handling of transmitted messages, which is more representative of real-world CHIP node interactions. It utilizes SystemLayer::ScheduleWork to handle the processing of the sent message as a bottom-half handler. It also adds a DrainAndServiceIO method on the AppContext that will automatically drain and service the IO until all messages have been handled. Tests: - Ensured the TestCommand failure doesn't happen again. 
* Apply suggestions from code review Co-authored-by: Boris Zbarsky Co-authored-by: Boris Zbarsky --- src/app/tests/AppTestContext.h | 54 +++++++++++++++++++ .../tests/data_model/TestCommands.cpp | 12 ++++- src/controller/tests/data_model/TestRead.cpp | 12 ++++- src/controller/tests/data_model/TestWrite.cpp | 6 ++- src/transport/TransportMgr.h | 3 ++ src/transport/raw/Tuple.h | 7 +++ src/transport/raw/tests/NetworkTestHelpers.h | 52 +++++++++++++++++- 7 files changed, 142 insertions(+), 4 deletions(-) diff --git a/src/app/tests/AppTestContext.h b/src/app/tests/AppTestContext.h index 469dbf23c5c09c..d12e0c98346a6d 100644 --- a/src/app/tests/AppTestContext.h +++ b/src/app/tests/AppTestContext.h @@ -15,6 +15,7 @@ */ #pragma once +#include "system/SystemClock.h" #include #include @@ -36,12 +37,65 @@ class AppContext : public MessagingContext // Shutdown all layers, finalize operations CHIP_ERROR Shutdown(); + /* + * For unit-tests that simulate end-to-end transmission and reception of messages in loopback mode, + * this mode better replicates a real-functioning stack that correctly handles the processing + * of a transmitted message as an asynchronous, bottom half handler dispatched after the current execution context has + completed. + * This is achieved using SystemLayer::ScheduleWork. + + * This should be used in conjunction with the DrainAndServiceIO function below to correctly service and drain the event queue. + * + */ + void EnableAsyncDispatch() + { + auto & impl = mTransportManager.GetTransport().GetImplAtIndex<0>(); + impl.EnableAsyncDispatch(&mIOContext.GetSystemLayer()); + } + + /* + * This drives the servicing of events using the embedded IOContext while there are pending + * messages in the loopback transport's pending message queue. This should run to completion + * in well-behaved logic (i.e there isn't an indefinite ping-pong of messages transmitted back + * and forth). 
+ * + * Consequently, this is guarded with a user-provided timeout to ensure we don't have unit-tests that stall + * in CI due to bugs in the code that is being tested. + * + * This DOES NOT ensure that all pending events are serviced to completion (i.e timers, any ScheduleWork calls). + * + */ + void DrainAndServiceIO(System::Clock::Timeout maxWait = chip::System::Clock::Seconds16(5)) + { + auto & impl = mTransportManager.GetTransport().GetImplAtIndex<0>(); + System::Clock::Timestamp startTime = System::SystemClock().GetMonotonicTimestamp(); + + while (impl.HasPendingMessages()) + { + mIOContext.DriveIO(); + if ((System::SystemClock().GetMonotonicTimestamp() - startTime) >= maxWait) + { + break; + } + } + } + static int Initialize(void * context) { auto * ctx = static_cast(context); return ctx->Init() == CHIP_NO_ERROR ? SUCCESS : FAILURE; } + static int InitializeAsync(void * context) + { + auto * ctx = static_cast(context); + + VerifyOrReturnError(ctx->Init() == CHIP_NO_ERROR, FAILURE); + ctx->EnableAsyncDispatch(); + + return SUCCESS; + } + static int Finalize(void * context) { auto * ctx = static_cast(context); diff --git a/src/controller/tests/data_model/TestCommands.cpp b/src/controller/tests/data_model/TestCommands.cpp index 058b19a8089e3c..8421c0db738654 100644 --- a/src/controller/tests/data_model/TestCommands.cpp +++ b/src/controller/tests/data_model/TestCommands.cpp @@ -199,6 +199,8 @@ void TestCommandInteraction::TestDataResponse(nlTestSuite * apSuite, void * apCo chip::Controller::InvokeCommandRequest( &ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, onSuccessWasCalled && !onFailureWasCalled); NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); } @@ -231,6 +233,8 @@ void TestCommandInteraction::TestSuccessNoDataResponse(nlTestSuite * apSuite, vo chip::Controller::InvokeCommandRequest(&ctx.GetExchangeManager(), sessionHandle, 
kTestEndpointId, request, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, onSuccessWasCalled && !onFailureWasCalled && statusCheck); NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); } @@ -263,6 +267,8 @@ void TestCommandInteraction::TestFailure(nlTestSuite * apSuite, void * apContext chip::Controller::InvokeCommandRequest(&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, !onSuccessWasCalled && onFailureWasCalled && statusCheck); NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); } @@ -296,6 +302,8 @@ void TestCommandInteraction::TestSuccessNoDataResponseWithClusterStatus(nlTestSu chip::Controller::InvokeCommandRequest(&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, onSuccessWasCalled && !onFailureWasCalled && statusCheck); NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); } @@ -329,6 +337,8 @@ void TestCommandInteraction::TestFailureWithClusterStatus(nlTestSuite * apSuite, chip::Controller::InvokeCommandRequest(&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, !onSuccessWasCalled && onFailureWasCalled && statusCheck); NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); } @@ -350,7 +360,7 @@ nlTestSuite sSuite = { "TestCommands", &sTests[0], - TestContext::Initialize, + TestContext::InitializeAsync, TestContext::Finalize }; // clang-format on diff --git a/src/controller/tests/data_model/TestRead.cpp b/src/controller/tests/data_model/TestRead.cpp index 96730b9c30b7c2..e37509f6052be8 100644 --- a/src/controller/tests/data_model/TestRead.cpp +++ b/src/controller/tests/data_model/TestRead.cpp @@ -144,7 +144,9 @@ void 
TestReadInteraction::TestDataResponse(nlTestSuite * apSuite, void * apConte chip::Controller::ReadAttribute( &ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); chip::app::InteractionModelEngine::GetInstance()->GetReportingEngine().Run(); + ctx.DrainAndServiceIO(); NL_TEST_ASSERT(apSuite, onSuccessCbInvoked && !onFailureCbInvoked); NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadClients() == 0); @@ -177,7 +179,9 @@ void TestReadInteraction::TestAttributeError(nlTestSuite * apSuite, void * apCon chip::Controller::ReadAttribute( &ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); chip::app::InteractionModelEngine::GetInstance()->GetReportingEngine().Run(); + ctx.DrainAndServiceIO(); NL_TEST_ASSERT(apSuite, !onSuccessCbInvoked && onFailureCbInvoked); NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadClients() == 0); @@ -210,11 +214,15 @@ void TestReadInteraction::TestReadTimeout(nlTestSuite * apSuite, void * apContex chip::Controller::ReadAttribute( &ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadClients() == 1); NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 2); ctx.GetExchangeManager().OnConnectionExpired(ctx.GetSessionBobToAlice()); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, !onSuccessCbInvoked && onFailureCbInvoked); NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadClients() == 0); @@ -223,7 +231,9 @@ void TestReadInteraction::TestReadTimeout(nlTestSuite * apSuite, void * apContex // // NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 1); + ctx.DrainAndServiceIO(); 
chip::app::InteractionModelEngine::GetInstance()->GetReportingEngine().Run(); + ctx.DrainAndServiceIO(); ctx.GetExchangeManager().OnConnectionExpired(ctx.GetSessionAliceToBob()); @@ -250,7 +260,7 @@ nlTestSuite sSuite = { "TestRead", &sTests[0], - TestContext::Initialize, + TestContext::InitializeAsync, TestContext::Finalize }; // clang-format on diff --git a/src/controller/tests/data_model/TestWrite.cpp b/src/controller/tests/data_model/TestWrite.cpp index ef8c9519f22ac8..fe9384d90de596 100644 --- a/src/controller/tests/data_model/TestWrite.cpp +++ b/src/controller/tests/data_model/TestWrite.cpp @@ -151,6 +151,8 @@ void TestWriteInteraction::TestDataResponse(nlTestSuite * apSuite, void * apCont chip::Controller::WriteAttribute( &ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, value, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, onSuccessCbInvoked && !onFailureCbInvoked); NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveWriteHandlers() == 0); NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); @@ -190,6 +192,8 @@ void TestWriteInteraction::TestAttributeError(nlTestSuite * apSuite, void * apCo chip::Controller::WriteAttribute( &ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, value, onSuccessCb, onFailureCb); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, !onSuccessCbInvoked && onFailureCbInvoked); NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveWriteHandlers() == 0); NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); @@ -209,7 +213,7 @@ nlTestSuite sSuite = { "TestWrite", &sTests[0], - TestContext::Initialize, + TestContext::InitializeAsync, TestContext::Finalize }; // clang-format on diff --git a/src/transport/TransportMgr.h b/src/transport/TransportMgr.h index 6688a1de4ddc6a..494db39de964fd 100644 --- a/src/transport/TransportMgr.h +++ b/src/transport/TransportMgr.h @@ -78,6 
+78,9 @@ class TransportMgr : public TransportMgrBase private: Transport::Tuple mTransport; + +public: + auto & GetTransport() { return mTransport; } }; } // namespace chip diff --git a/src/transport/raw/Tuple.h b/src/transport/raw/Tuple.h index 4127ef2995b75f..9cb7ab81c12382 100644 --- a/src/transport/raw/Tuple.h +++ b/src/transport/raw/Tuple.h @@ -239,6 +239,13 @@ class Tuple : public Base CHIP_ERROR InitImpl(RawTransportDelegate * delegate) { return CHIP_NO_ERROR; } std::tuple mTransports; + +public: + template + auto GetImplAtIndex() -> decltype(std::get(mTransports)) & + { + return std::get(mTransports); + } }; } // namespace Transport diff --git a/src/transport/raw/tests/NetworkTestHelpers.h b/src/transport/raw/tests/NetworkTestHelpers.h index 52dc3402094461..b7e6b881c0f4b3 100644 --- a/src/transport/raw/tests/NetworkTestHelpers.h +++ b/src/transport/raw/tests/NetworkTestHelpers.h @@ -27,6 +27,7 @@ #include #include +#include namespace chip { namespace Test { @@ -63,6 +64,32 @@ class LoopbackTransport : public Transport::Base /// Transports are required to have a constructor that takes exactly one argument CHIP_ERROR Init(const char *) { return CHIP_NO_ERROR; } + /* + * For unit-tests that simulate end-to-end transmission and reception of messages in loopback mode, + * this mode better replicates a real-functioning stack that correctly handles the processing + * of a transmitted message as an asynchronous, bottom half handler dispatched after the current execution context has + * completed. This is achieved using SystemLayer::ScheduleWork. 
+ */ + void EnableAsyncDispatch(System::Layer * aSystemLayer) + { + mSystemLayer = aSystemLayer; + mAsyncMessageDispatch = true; + } + + bool HasPendingMessages() { return !mPendingMessageQueue.empty(); } + + static void OnMessageReceived(System::Layer * aSystemLayer, void * aAppState) + { + LoopbackTransport * _this = static_cast(aAppState); + + while (!_this->mPendingMessageQueue.empty()) + { + auto item = std::move(_this->mPendingMessageQueue.front()); + _this->mPendingMessageQueue.pop(); + _this->HandleMessageReceived(item.mDestinationAddress, std::move(item.mPendingMessage)); + } + } + CHIP_ERROR SendMessage(const Transport::PeerAddress & address, System::PacketBufferHandle && msgBuf) override { ReturnErrorOnFailure(mMessageSendError); @@ -71,7 +98,16 @@ class LoopbackTransport : public Transport::Base if (mNumMessagesToDrop == 0) { System::PacketBufferHandle receivedMessage = msgBuf.CloneData(); - HandleMessageReceived(address, std::move(receivedMessage)); + + if (mAsyncMessageDispatch) + { + mPendingMessageQueue.push(PendingMessageItem(address, std::move(receivedMessage))); + mSystemLayer->ScheduleWork(OnMessageReceived, this); + } + else + { + HandleMessageReceived(address, std::move(receivedMessage)); + } } else { @@ -93,9 +129,23 @@ class LoopbackTransport : public Transport::Base mMessageSendError = CHIP_NO_ERROR; } + struct PendingMessageItem + { + PendingMessageItem(const Transport::PeerAddress destinationAddress, System::PacketBufferHandle && pendingMessage) : + mDestinationAddress(destinationAddress), mPendingMessage(std::move(pendingMessage)) + {} + + const Transport::PeerAddress mDestinationAddress; + System::PacketBufferHandle mPendingMessage; + }; + // Hook for subclasses to perform custom logic on message drops. 
virtual void MessageDropped() {} + System::Layer * mSystemLayer = nullptr; + bool mAsyncMessageDispatch = false; + std::queue mPendingMessageQueue; + Transport::PeerAddress mTxAddress; uint32_t mNumMessagesToDrop = 0; uint32_t mDroppedMessageCount = 0; uint32_t mSentMessageCount = 0;