diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp index 85af809681aedc..bbd4d289f58f05 100644 --- a/src/app/WriteClient.cpp +++ b/src/app/WriteClient.cpp @@ -423,22 +423,30 @@ CHIP_ERROR WriteClient::OnMessageReceived(Messaging::ExchangeContext * apExchang MoveToState(State::ResponseReceived); } - CHIP_ERROR err = CHIP_NO_ERROR; - + CHIP_ERROR err = CHIP_NO_ERROR; + bool sendStatusResponse = false; // Assert that the exchange context matches the client's current context. // This should never fail because even if SendWriteRequest is called // back-to-back, the second call will call Close() on the first exchange, // which clears the OnMessageReceived callback. VerifyOrExit(apExchangeContext == mExchangeCtx.Get(), err = CHIP_ERROR_INCORRECT_STATE); + sendStatusResponse = true; + if (mState == State::AwaitingTimedStatus) { - VerifyOrExit(aPayloadHeader.HasMessageType(MsgType::StatusResponse), err = CHIP_ERROR_INVALID_MESSAGE_TYPE); - CHIP_ERROR statusError = CHIP_NO_ERROR; - SuccessOrExit(err = StatusResponse::ProcessStatusResponse(std::move(aPayload), statusError)); - SuccessOrExit(err = statusError); - err = SendWriteRequest(); - + if (aPayloadHeader.HasMessageType(MsgType::StatusResponse)) + { + CHIP_ERROR statusError = CHIP_NO_ERROR; + SuccessOrExit(err = StatusResponse::ProcessStatusResponse(std::move(aPayload), statusError)); + sendStatusResponse = false; + SuccessOrExit(err = statusError); + err = SendWriteRequest(); + } + else + { + err = CHIP_ERROR_INVALID_MESSAGE_TYPE; + } // Skip all other processing here (which is for the response to the // write request), no matter whether err is success or not. goto exit; @@ -448,6 +456,7 @@ CHIP_ERROR WriteClient::OnMessageReceived(Messaging::ExchangeContext * apExchang { err = ProcessWriteResponseMessage(std::move(aPayload)); SuccessOrExit(err); + sendStatusResponse = false; if (!mChunks.IsNull()) { // Send the next chunk. @@ -475,6 +484,11 @@ CHIP_ERROR WriteClient::OnMessageReceived(Messaging::ExchangeContext * apExchang } } + if (sendStatusResponse) + { + StatusResponse::Send(Status::InvalidAction, apExchangeContext, false /*aExpectResponse*/); + } + if (mState != State::AwaitingResponse) { Close(); diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp index e0149945b07d19..dcd13e596859ea 100644 --- a/src/app/tests/TestWriteInteraction.cpp +++ b/src/app/tests/TestWriteInteraction.cpp @@ -60,6 +60,10 @@ class TestWriteInteraction static void TestWriteClientGroup(nlTestSuite * apSuite, void * apContext); static void TestWriteHandler(nlTestSuite * apSuite, void * apContext); static void TestWriteRoundtrip(nlTestSuite * apSuite, void * apContext); + static void TestWriteInvalidMessage1(nlTestSuite * apSuite, void * apContext); + static void TestWriteInvalidMessage2(nlTestSuite * apSuite, void * apContext); + static void TestWriteInvalidMessage3(nlTestSuite * apSuite, void * apContext); + static void TestWriteInvalidMessage4(nlTestSuite * apSuite, void * apContext); static void TestWriteRoundtripWithClusterObjects(nlTestSuite * apSuite, void * apContext); static void TestWriteRoundtripWithClusterObjectsVersionMatch(nlTestSuite * apSuite, void * apContext); static void TestWriteRoundtripWithClusterObjectsVersionMismatch(nlTestSuite * apSuite, void * apContext); @@ -98,6 +102,7 @@ class TestWriteClientCallback : public chip::app::WriteClient::Callback { mOnErrorCalled++; mLastErrorReason = app::StatusIB(chipError); + mError = chipError; } void OnDone(WriteClient * apWriteClient) override { mOnDoneCalled++; } @@ -106,6 +111,7 @@ class TestWriteClientCallback : public chip::app::WriteClient::Callback int mOnDoneCalled = 0; StatusIB mStatus; StatusIB mLastErrorReason; + CHIP_ERROR mError = CHIP_NO_ERROR; }; void TestWriteInteraction::AddAttributeDataIB(nlTestSuite * apSuite, void * apContext, WriteClient & aWriteClient) @@ -621,6 +627,275 @@ void TestWriteInteraction::TestWriteHandlerReceiveInvalidMessage(nlTestSuite * a ctx.CreateSessionBobToAlice(); } #endif + +// Write Client sends a write request, receives an unexpected message type, sends a status response to that. +void TestWriteInteraction::TestWriteInvalidMessage1(nlTestSuite * apSuite, void * apContext) +{ + TestContext & ctx = *static_cast(apContext); + + CHIP_ERROR err = CHIP_NO_ERROR; + + Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + // Shouldn't have anything in the retransmit table when starting the test. + NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0); + + TestWriteClientCallback callback; + auto * engine = chip::app::InteractionModelEngine::GetInstance(); + err = engine->Init(&ctx.GetExchangeManager(), &ctx.GetFabricTable()); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + app::WriteClient writeClient(engine->GetExchangeManager(), &callback, Optional::Missing()); + + System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); + AddAttributeDataIB(apSuite, apContext, writeClient); + + NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 0 && callback.mOnDoneCalled == 0); + + ctx.GetLoopback().mSentMessageCount = 0; + ctx.GetLoopback().mNumMessagesToDrop = 1; + ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1; + ctx.GetLoopback().mDroppedMessageCount = 0; + err = writeClient.SendWriteRequest(ctx.GetSessionBobToAlice()); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + ctx.DrainAndServiceIO(); + + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1); + + System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes); + NL_TEST_ASSERT(apSuite, !msgBuf.IsNull()); + System::PacketBufferTLVWriter writer; + writer.Init(std::move(msgBuf)); + ReportDataMessage::Builder response; + response.Init(&writer); + NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR); + PayloadHeader payloadHeader; + payloadHeader.SetExchangeID(0); + payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::ReportData); + + rm->ClearRetransTable(writeClient.mExchangeCtx.Get()); + ctx.GetLoopback().mSentMessageCount = 0; + ctx.GetLoopback().mNumMessagesToDrop = 0; + ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0; + ctx.GetLoopback().mDroppedMessageCount = 0; + err = writeClient.OnMessageReceived(writeClient.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf)); + NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_INVALID_MESSAGE_TYPE); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, callback.mError == CHIP_ERROR_INVALID_MESSAGE_TYPE); + NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 1 && callback.mOnDoneCalled == 1); + + // TODO: Check that the server gets the right status. + // Client sents status report with invalid action, server's exchange has been closed, so all it sends is an MRP Ack + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); + + engine->Shutdown(); + ctx.ExpireSessionAliceToBob(); + ctx.ExpireSessionBobToAlice(); + ctx.CreateSessionAliceToBob(); + ctx.CreateSessionBobToAlice(); +} + +// Write Client sends a write request, receives a malformed write response message, sends a Status Report. +void TestWriteInteraction::TestWriteInvalidMessage2(nlTestSuite * apSuite, void * apContext) +{ + TestContext & ctx = *static_cast(apContext); + + CHIP_ERROR err = CHIP_NO_ERROR; + + Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + // Shouldn't have anything in the retransmit table when starting the test. + NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0); + + TestWriteClientCallback callback; + auto * engine = chip::app::InteractionModelEngine::GetInstance(); + err = engine->Init(&ctx.GetExchangeManager(), &ctx.GetFabricTable()); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + app::WriteClient writeClient(engine->GetExchangeManager(), &callback, Optional::Missing()); + + System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); + AddAttributeDataIB(apSuite, apContext, writeClient); + + NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 0 && callback.mOnDoneCalled == 0); + + ctx.GetLoopback().mSentMessageCount = 0; + ctx.GetLoopback().mNumMessagesToDrop = 1; + ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1; + ctx.GetLoopback().mDroppedMessageCount = 0; + err = writeClient.SendWriteRequest(ctx.GetSessionBobToAlice()); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + ctx.DrainAndServiceIO(); + + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1); + + System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes); + NL_TEST_ASSERT(apSuite, !msgBuf.IsNull()); + System::PacketBufferTLVWriter writer; + writer.Init(std::move(msgBuf)); + WriteResponseMessage::Builder response; + response.Init(&writer); + NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR); + PayloadHeader payloadHeader; + payloadHeader.SetExchangeID(0); + payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::WriteResponse); + + rm->ClearRetransTable(writeClient.mExchangeCtx.Get()); + ctx.GetLoopback().mSentMessageCount = 0; + ctx.GetLoopback().mNumMessagesToDrop = 0; + ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0; + ctx.GetLoopback().mDroppedMessageCount = 0; + err = writeClient.OnMessageReceived(writeClient.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf)); + NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_END_OF_TLV); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, callback.mError == CHIP_ERROR_END_OF_TLV); + NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 1 && callback.mOnDoneCalled == 1); + + // Client sents status report with invalid action, server's exchange has been closed, so all it sends is an MRP Ack + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); + + engine->Shutdown(); + ctx.ExpireSessionAliceToBob(); + ctx.ExpireSessionBobToAlice(); + ctx.CreateSessionAliceToBob(); + ctx.CreateSessionBobToAlice(); +} + +// Write Client sends a write request, receives a malformed status response message. +void TestWriteInteraction::TestWriteInvalidMessage3(nlTestSuite * apSuite, void * apContext) +{ + TestContext & ctx = *static_cast(apContext); + + CHIP_ERROR err = CHIP_NO_ERROR; + + Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + // Shouldn't have anything in the retransmit table when starting the test. + NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0); + + TestWriteClientCallback callback; + auto * engine = chip::app::InteractionModelEngine::GetInstance(); + err = engine->Init(&ctx.GetExchangeManager(), &ctx.GetFabricTable()); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + app::WriteClient writeClient(engine->GetExchangeManager(), &callback, Optional::Missing()); + + System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); + AddAttributeDataIB(apSuite, apContext, writeClient); + + NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 0 && callback.mOnDoneCalled == 0); + + ctx.GetLoopback().mSentMessageCount = 0; + ctx.GetLoopback().mNumMessagesToDrop = 1; + ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1; + ctx.GetLoopback().mDroppedMessageCount = 0; + err = writeClient.SendWriteRequest(ctx.GetSessionBobToAlice()); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + ctx.DrainAndServiceIO(); + + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1); + + System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes); + NL_TEST_ASSERT(apSuite, !msgBuf.IsNull()); + System::PacketBufferTLVWriter writer; + writer.Init(std::move(msgBuf)); + StatusResponseMessage::Builder response; + response.Init(&writer); + NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR); + PayloadHeader payloadHeader; + payloadHeader.SetExchangeID(0); + payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse); + + rm->ClearRetransTable(writeClient.mExchangeCtx.Get()); + ctx.GetLoopback().mSentMessageCount = 0; + ctx.GetLoopback().mNumMessagesToDrop = 0; + ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0; + ctx.GetLoopback().mDroppedMessageCount = 0; + err = writeClient.OnMessageReceived(writeClient.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf)); + NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_END_OF_TLV); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, callback.mError == CHIP_ERROR_END_OF_TLV); + NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 1 && callback.mOnDoneCalled == 1); + + // TODO: Check that the server gets the right status + // Client sents status report with invalid action, server's exchange has been closed, so all it sends is an MRP ack. + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); + + engine->Shutdown(); + ctx.ExpireSessionAliceToBob(); + ctx.ExpireSessionBobToAlice(); + ctx.CreateSessionAliceToBob(); + ctx.CreateSessionBobToAlice(); +} + +// Write Client sends a write request, receives a busy status response message. +void TestWriteInteraction::TestWriteInvalidMessage4(nlTestSuite * apSuite, void * apContext) +{ + TestContext & ctx = *static_cast(apContext); + + CHIP_ERROR err = CHIP_NO_ERROR; + + Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + // Shouldn't have anything in the retransmit table when starting the test. + NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0); + + TestWriteClientCallback callback; + auto * engine = chip::app::InteractionModelEngine::GetInstance(); + err = engine->Init(&ctx.GetExchangeManager(), &ctx.GetFabricTable()); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + app::WriteClient writeClient(engine->GetExchangeManager(), &callback, Optional::Missing()); + + System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); + AddAttributeDataIB(apSuite, apContext, writeClient); + + NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 0 && callback.mOnDoneCalled == 0); + + ctx.GetLoopback().mSentMessageCount = 0; + ctx.GetLoopback().mNumMessagesToDrop = 1; + ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1; + ctx.GetLoopback().mDroppedMessageCount = 0; + err = writeClient.SendWriteRequest(ctx.GetSessionBobToAlice()); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + ctx.DrainAndServiceIO(); + + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1); + + System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes); + NL_TEST_ASSERT(apSuite, !msgBuf.IsNull()); + System::PacketBufferTLVWriter writer; + writer.Init(std::move(msgBuf)); + StatusResponseMessage::Builder response; + response.Init(&writer); + response.Status(Protocols::InteractionModel::Status::Busy); + NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR); + PayloadHeader payloadHeader; + payloadHeader.SetExchangeID(0); + payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse); + + rm->ClearRetransTable(writeClient.mExchangeCtx.Get()); + ctx.GetLoopback().mSentMessageCount = 0; + ctx.GetLoopback().mNumMessagesToDrop = 0; + ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0; + ctx.GetLoopback().mDroppedMessageCount = 0; + err = writeClient.OnMessageReceived(writeClient.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf)); + NL_TEST_ASSERT(apSuite, err == CHIP_IM_GLOBAL_STATUS(Busy)); + ctx.DrainAndServiceIO(); + NL_TEST_ASSERT(apSuite, callback.mError == CHIP_IM_GLOBAL_STATUS(Busy)); + NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 1 && callback.mOnDoneCalled == 1); + + // TODO: Check that the server gets the right status.. + // Client sents status report with invalid action, server's exchange has been closed, so it just sends an MRP ack. + NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2); + + engine->Shutdown(); + ctx.ExpireSessionAliceToBob(); + ctx.ExpireSessionBobToAlice(); + ctx.CreateSessionAliceToBob(); + ctx.CreateSessionBobToAlice(); +} + } // namespace app } // namespace chip @@ -643,6 +918,10 @@ const nlTest sTests[] = #if CONFIG_BUILD_FOR_HOST_UNIT_TEST NL_TEST_DEF("TestWriteHandlerReceiveInvalidMessage", chip::app::TestWriteInteraction::TestWriteHandlerReceiveInvalidMessage), #endif + NL_TEST_DEF("TestWriteInvalidMessage1", chip::app::TestWriteInteraction::TestWriteInvalidMessage1), + NL_TEST_DEF("TestWriteInvalidMessage2", chip::app::TestWriteInteraction::TestWriteInvalidMessage2), + NL_TEST_DEF("TestWriteInvalidMessage3", chip::app::TestWriteInteraction::TestWriteInvalidMessage3), + NL_TEST_DEF("TestWriteInvalidMessage4", chip::app::TestWriteInteraction::TestWriteInvalidMessage4), NL_TEST_SENTINEL() }; // clang-format on