diff --git a/src/transport/MessageCounter.h b/src/transport/MessageCounter.h index 70ce525e851b80..4ca7dc7206b967 100644 --- a/src/transport/MessageCounter.h +++ b/src/transport/MessageCounter.h @@ -109,6 +109,9 @@ class LocalSessionMessageCounter : public MessageCounter return CHIP_NO_ERROR; } + // Test-only function to set the counter value + void TestSetCounter(uint32_t value) { mValue = value; } + private: uint32_t mValue; }; diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index cc9179a92b2330..e77f9222f1ef3a 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -759,6 +759,83 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) sessionManager.Shutdown(); } +void SessionCounterExhaustedTest(nlTestSuite * inSuite, void * inContext) +{ + TestContext & ctx = *reinterpret_cast(inContext); + + IPAddress addr; + IPAddress::FromString("::1", addr); + CHIP_ERROR err = CHIP_NO_ERROR; + + FabricTable fabricTable; + SessionManager sessionManager; + secure_channel::MessageCounterManager gMessageCounterManager; + chip::TestPersistentStorageDelegate deviceStorage; + + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.Init(&deviceStorage)); + NL_TEST_ASSERT(inSuite, + CHIP_NO_ERROR == + sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &gMessageCounterManager, &deviceStorage, + &fabricTable)); + + Transport::PeerAddress peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); + + FabricIndex aliceFabricIndex; + FabricInfo aliceFabric; + aliceFabric.TestOnlyBuildFabric(GetRootACertAsset().mCert, GetIAA1CertAsset().mCert, GetNodeA1CertAsset().mCert, + GetNodeA1CertAsset().mKey); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabricForTest(aliceFabric, &aliceFabricIndex)); + + FabricIndex bobFabricIndex; + FabricInfo bobFabric; + bobFabric.TestOnlyBuildFabric(GetRootACertAsset().mCert, GetIAA1CertAsset().mCert, GetNodeA2CertAsset().mCert, + GetNodeA2CertAsset().mKey); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabricForTest(bobFabric, &bobFabricIndex)); + + SessionHolder aliceToBobSession; + err = sessionManager.InjectPaseSessionWithTestKey(aliceToBobSession, 2, + fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), 1, + aliceFabricIndex, peer, CryptoContext::SessionRole::kInitiator); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + SessionHolder bobToAliceSession; + err = sessionManager.InjectPaseSessionWithTestKey(bobToAliceSession, 1, + fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), 2, + bobFabricIndex, peer, CryptoContext::SessionRole::kResponder); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // ==== Set counter value to max ==== + LocalSessionMessageCounter & counter = static_cast( + aliceToBobSession.Get().Value()->AsSecureSession()->GetSessionMessageCounter().GetLocalMessageCounter()); + counter.TestSetCounter(LocalSessionMessageCounter::kMaxMessageCounter); + + // ==== Build a valid message with max counter value ==== + chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); + NL_TEST_ASSERT(inSuite, !buffer.IsNull()); + + PayloadHeader payloadHeader; + + // Set the exchange ID for this header. + payloadHeader.SetExchangeID(0); + + // Set the protocol ID and message type for this header. + payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest); + + EncryptedPacketBufferHandle preparedMessage; + err = sessionManager.PrepareMessage(aliceToBobSession.Get().Value(), payloadHeader, std::move(buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // ==== Build another message which will fail becuase message counter is exhausted ==== + chip::System::PacketBufferHandle buffer2 = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); + NL_TEST_ASSERT(inSuite, !buffer2.IsNull()); + + EncryptedPacketBufferHandle preparedMessage2; + err = sessionManager.PrepareMessage(aliceToBobSession.Get().Value(), payloadHeader, std::move(buffer2), preparedMessage2); + NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED); + + sessionManager.Shutdown(); +} + // Test Suite /** @@ -774,6 +851,7 @@ const nlTest sTests[] = NL_TEST_DEF("Old counter Test", SendPacketWithOldCounterTest), NL_TEST_DEF("Too-old counter Test", SendPacketWithTooOldCounterTest), NL_TEST_DEF("Session Allocation Test", SessionAllocationTest), + NL_TEST_DEF("Session Counter Exhausted Test", SessionCounterExhaustedTest), NL_TEST_SENTINEL() };