diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 792f60af032149..9c62a48e046908 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -57,6 +57,39 @@ using TestContext = Test::LoopbackMessagingContext; namespace chip { namespace { +class TemporarySessionManager +{ +public: + TemporarySessionManager(nlTestSuite * suite, TestContext & ctx) : mCtx(ctx) + { + NL_TEST_ASSERT(suite, + CHIP_NO_ERROR == + mSessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &ctx.GetMessageCounterManager(), + &mStorage, &ctx.GetFabricTable())); + // The setup here is really weird: we are using one session manager for + // the actual messages we send (the PASE handshake, so the + // unauthenticated sessions) and a different one for allocating the PASE + // sessions. Since our Init() set us up as the thing to handle messages + // on the transport manager, undo that. + mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager()); + } + + ~TemporarySessionManager() + { + mSessionManager.Shutdown(); + // Reset the session manager on the transport again, just in case + // shutdown messed with it. + mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager()); + } + + operator SessionManager &() { return mSessionManager; } + +private: + TestContext & mCtx; + TestPersistentStorageDelegate mStorage; + SessionManager mSessionManager; +}; + CHIP_ERROR InitFabricTable(chip::FabricTable & fabricTable, chip::TestPersistentStorageDelegate * testStorage, chip::Crypto::OperationalKeystore * opKeyStore, chip::Credentials::PersistentStorageOpCertStore * opCertStore) @@ -281,7 +314,8 @@ class TestCASESession void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegate; @@ -312,7 +346,7 @@ void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inCont void TestCASESession::SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SessionManager sessionManager; + TemporarySessionManager sessionManager(inSuite, ctx); // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegate; @@ -425,7 +459,9 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S void TestCASESession::SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestCASESecurePairingDelegate delegateCommissioner; CASESession pairingCommissioner; pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider); @@ -902,13 +938,13 @@ void TestCASESession::SessionResumptionStorage(nlTestSuite * inSuite, void * inC #if CONFIG_BUILD_FOR_HOST_UNIT_TEST void TestCASESession::SimulateUpdateNOCInvalidatePendingEstablishment(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestCASESecurePairingDelegate delegateCommissioner; CASESession pairingCommissioner; pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider); - TestContext & ctx = *reinterpret_cast(inContext); - TestCASESecurePairingDelegate delegateAccessory; CASESession pairingAccessory; diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index cbf3bbe1a39f5b..932d665a22b2c8 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -111,12 +111,45 @@ class MockAppDelegate : public ExchangeDelegate void OnResponseTimeout(ExchangeContext * ec) override {} }; +class TemporarySessionManager +{ +public: + TemporarySessionManager(nlTestSuite * suite, TestContext & ctx) : mCtx(ctx) + { + NL_TEST_ASSERT(suite, + CHIP_NO_ERROR == + mSessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &ctx.GetMessageCounterManager(), + &mStorage, &ctx.GetFabricTable())); + // The setup here is really weird: we are using one session manager for + // the actual messages we send (the PASE handshake, so the + // unauthenticated sessions) and a different one for allocating the PASE + // sessions. Since our Init() set us up as the thing to handle messages + // on the transport manager, undo that. + mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager()); + } + + ~TemporarySessionManager() + { + mSessionManager.Shutdown(); + // Reset the session manager on the transport again, just in case + // shutdown messed with it. + mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager()); + } + + operator SessionManager &() { return mSessionManager; } + +private: + TestContext & mCtx; + TestPersistentStorageDelegate mStorage; + SessionManager mSessionManager; +}; + using namespace System::Clock::Literals; void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SessionManager sessionManager; + TemporarySessionManager sessionManager(inSuite, ctx); // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; @@ -157,7 +190,7 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SessionManager sessionManager; + TemporarySessionManager sessionManager(inSuite, ctx); // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; @@ -285,11 +318,12 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional::Missing(), @@ -298,11 +332,12 @@ void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, @@ -312,11 +347,12 @@ void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, @@ -326,11 +362,12 @@ void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inCon void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); ReliableMessageProtocolConfig commissionerConfig(1000_ms32, 10000_ms32); ReliableMessageProtocolConfig deviceConfig(2000_ms32, 7000_ms32); @@ -341,11 +378,12 @@ void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContex void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); loopback.mNumMessagesToDrop = 2; SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, @@ -358,7 +396,7 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SessionManager sessionManager; + TemporarySessionManager sessionManager(inSuite, ctx); TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index 45c5468ca61a46..da302ea9bb571c 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -702,7 +702,19 @@ static void RandomSessionIdAllocatorOffset(nlTestSuite * inSuite, SessionManager void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) { + TestContext & ctx = *reinterpret_cast(inContext); + + FabricTableHolder fabricTableHolder; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTableHolder.Init()); + + secure_channel::MessageCounterManager messageCounterManager; + TestPersistentStorageDelegate deviceStorage1, deviceStorage2; + SessionManager sessionManager; + NL_TEST_ASSERT(inSuite, + CHIP_NO_ERROR == + sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage1, + &fabricTableHolder.GetFabricTable())); // Allocate a session. uint16_t sessionId1; @@ -735,10 +747,24 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) } // Reconstruct the Session Manager to reset state. + sessionManager.Shutdown(); sessionManager.~SessionManager(); new (&sessionManager) SessionManager(); + NL_TEST_ASSERT(inSuite, + CHIP_NO_ERROR == + sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage2, + &fabricTableHolder.GetFabricTable())); + + // Allocate a single session so we know what random id we are starting at. + { + auto handle = sessionManager.AllocateSession( + Transport::SecureSession::Type::kPASE, + ScopedNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId), kUndefinedFabricIndex)); + NL_TEST_ASSERT(inSuite, handle.HasValue()); + prevSessionId = handle.Value()->AsSecureSession()->GetLocalSessionId(); + handle.Value()->AsSecureSession()->MarkForEviction(); + } - prevSessionId = 0; // Verify that we increment session ID by 1 for each allocation (except for // the wraparound case where we skip session ID 0), even when allocated // sessions are immediately freed. @@ -886,6 +912,8 @@ void SessionCounterExhaustedTest(nlTestSuite * inSuite, void * inContext) static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext) { + TestContext & ctx = *reinterpret_cast(inContext); + IPAddress addr; IPAddress::FromString("::1", addr); @@ -894,9 +922,17 @@ static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext) FabricIndex aliceFabricIndex = 1; FabricIndex bobFabricIndex = 1; + FabricTableHolder fabricTableHolder; + secure_channel::MessageCounterManager messageCounterManager; + TestPersistentStorageDelegate deviceStorage; + SessionManager sessionManager; - secure_channel::MessageCounterManager gMessageCounterManager; - chip::TestPersistentStorageDelegate deviceStorage; + + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTableHolder.Init()); + NL_TEST_ASSERT(inSuite, + CHIP_NO_ERROR == + sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage, + &fabricTableHolder.GetFabricTable())); Transport::PeerAddress peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));