diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp index 2b78cb8f4c9388..660cd8ef8020c5 100644 --- a/src/app/tests/TestReadInteraction.cpp +++ b/src/app/tests/TestReadInteraction.cpp @@ -2696,6 +2696,20 @@ void TestReadInteraction::TestPostSubscribeRoundtripChunkReportTimeout(nlTestSui ctx.CreateSessionBobToAlice(); } +namespace { + +void CheckForInvalidAction(nlTestSuite * apSuite, Test::MessageCapturer & messageLog) +{ + NL_TEST_ASSERT(apSuite, messageLog.MessageCount() == 1); + NL_TEST_ASSERT(apSuite, messageLog.IsMessageType(0, Protocols::InteractionModel::MsgType::StatusResponse)); + CHIP_ERROR status; + NL_TEST_ASSERT(apSuite, + StatusResponse::ProcessStatusResponse(std::move(messageLog.MessagePayload(0)), status) == CHIP_NO_ERROR); + NL_TEST_ASSERT(apSuite, status == CHIP_IM_GLOBAL_STATUS(InvalidAction)); +} + +} // anonymous namespace + // Read Client sends the read request, Read Handler drops the response, then test injects unknown status reponse message for Read // Client. void TestReadInteraction::TestReadClientReceiveInvalidMessage(nlTestSuite * apSuite, void * apContext) @@ -2750,6 +2764,9 @@ void TestReadInteraction::TestReadClientReceiveInvalidMessage(nlTestSuite * apSu payloadHeader.SetExchangeID(0); payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse); + Test::MessageCapturer messageLog(ctx); + messageLog.mCaptureStandaloneAcks = false; + rm->ClearRetransTable(readClient.mExchange.Get()); NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1); @@ -2760,12 +2777,12 @@ void TestReadInteraction::TestReadClientReceiveInvalidMessage(nlTestSuite * apSu readClient.OnMessageReceived(readClient.mExchange.Get(), payloadHeader, std::move(msgBuf)); ctx.DrainAndServiceIO(); - // TODO: Need to validate what status is being sent to the ReadHandler // The ReadHandler closed its exchange when it sent the Report Data (which we dropped). // Since we synthesized the StatusResponse to the ReadClient, instead of sending it from the ReadHandler, // the only messages here are the ReadClient's StatusResponse to the unexpected message and an MRP ack. - NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); NL_TEST_ASSERT(apSuite, delegate.mError == CHIP_IM_GLOBAL_STATUS(Busy)); + + CheckForInvalidAction(apSuite, messageLog); } engine->Shutdown(); diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp index 127b1a152d7bf2..c31e66b7c0e47f 100644 --- a/src/messaging/tests/MessagingContext.cpp +++ b/src/messaging/tests/MessagingContext.cpp @@ -20,6 +20,7 @@ #include #include #include +#include namespace chip { namespace Test { @@ -204,5 +205,16 @@ Messaging::ExchangeContext * MessagingContext::NewExchangeToBob(Messaging::Excha return mExchangeManager.NewContext(GetSessionAliceToBob(), delegate); } +void MessageCapturer::OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + const SessionHandle & session, DuplicateMessage isDuplicate, + System::PacketBufferHandle && msgBuf) +{ + if (mCaptureStandaloneAcks || !payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::StandaloneAck)) + { + mCapturedMessages.emplace_back(Message{ packetHeader, payloadHeader, isDuplicate, msgBuf.CloneData() }); + } + mOriginalDelegate.OnMessageReceived(packetHeader, payloadHeader, session, isDuplicate, std::move(msgBuf)); +} + } // namespace Test } // namespace chip diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index 8c13fdd4f62a85..559a0466b7d930 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -30,6 +30,8 @@ #include +#include + namespace chip { namespace Test { @@ -65,7 +67,7 @@ class PlatformMemoryUser }; /** - * @brief The context of test cases for messaging layer. It wil initialize network layer and system layer, and create + * @brief The context of test cases for messaging layer. It will initialize network layer and system layer, and create * two secure sessions, connected with each other. Exchanges can be created for each secure session. */ class MessagingContext : public PlatformMemoryUser @@ -213,5 +215,53 @@ class LoopbackMessagingContext : public LoopbackTransportManager, public Messagi using LoopbackTransportManager::GetSystemLayer; }; +// Class that can be used to capture decrypted message traffic in tests using +// MessagingContext. +class MessageCapturer : public SessionMessageDelegate +{ +public: + MessageCapturer(MessagingContext & aContext) : + mSessionManager(aContext.GetSecureSessionManager()), mOriginalDelegate(aContext.GetExchangeManager()) + { + // Interpose ourselves into the message flow. + mSessionManager.SetMessageDelegate(this); + } + + ~MessageCapturer() + { + // Restore the normal message flow. + mSessionManager.SetMessageDelegate(&mOriginalDelegate); + } + + struct Message + { + PacketHeader mPacketHeader; + PayloadHeader mPayloadHeader; + DuplicateMessage mIsDuplicate; + System::PacketBufferHandle mPayload; + }; + + size_t MessageCount() const { return mCapturedMessages.size(); } + + template ::value>> + bool IsMessageType(size_t index, MessageType type) + { + return mCapturedMessages[index].mPayloadHeader.HasMessageType(type); + } + + System::PacketBufferHandle & MessagePayload(size_t index) { return mCapturedMessages[index].mPayload; } + + bool mCaptureStandaloneAcks = true; + +private: + // SessionMessageDelegate implementation. + void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session, + DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override; + + SessionManager & mSessionManager; + SessionMessageDelegate & mOriginalDelegate; + std::vector mCapturedMessages; +}; + } // namespace Test } // namespace chip diff --git a/src/transport/raw/MessageHeader.h b/src/transport/raw/MessageHeader.h index f253aa4f928751..c1ccf540b2f55a 100644 --- a/src/transport/raw/MessageHeader.h +++ b/src/transport/raw/MessageHeader.h @@ -430,6 +430,7 @@ class PayloadHeader { public: constexpr PayloadHeader() { SetProtocol(Protocols::NotSpecified); } + constexpr PayloadHeader(const PayloadHeader &) = default; PayloadHeader & operator=(const PayloadHeader &) = default; /** Get the Session ID from this header. */