Skip to content

Commit

Permalink
Fix unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
bzbarsky-apple committed Jul 14, 2022
1 parent 78c92ec commit 3df266b
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 27 deletions.
48 changes: 42 additions & 6 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,39 @@ using TestContext = Test::LoopbackMessagingContext;
namespace chip {
namespace {

class TemporarySessionManager
{
public:
TemporarySessionManager(nlTestSuite * suite, TestContext & ctx) : mCtx(ctx)
{
NL_TEST_ASSERT(suite,
CHIP_NO_ERROR ==
mSessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &ctx.GetMessageCounterManager(),
&mStorage, &ctx.GetFabricTable()));
// The setup here is really weird: we are using one session manager for
// the actual messages we send (the PASE handshake, so the
// unauthenticated sessions) and a different one for allocating the PASE
// sessions. Since our Init() set us up as the thing to handle messages
// on the transport manager, undo that.
mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager());
}

~TemporarySessionManager()
{
mSessionManager.Shutdown();
// Reset the session manager on the transport again, just in case
// shutdown messed with it.
mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager());
}

operator SessionManager &() { return mSessionManager; }

private:
TestContext & mCtx;
TestPersistentStorageDelegate mStorage;
SessionManager mSessionManager;
};

CHIP_ERROR InitFabricTable(chip::FabricTable & fabricTable, chip::TestPersistentStorageDelegate * testStorage,
chip::Crypto::OperationalKeystore * opKeyStore,
chip::Credentials::PersistentStorageOpCertStore * opCertStore)
Expand Down Expand Up @@ -281,7 +314,8 @@ class TestCASESession

void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

// Test all combinations of invalid parameters
TestCASESecurePairingDelegate delegate;
Expand Down Expand Up @@ -312,7 +346,7 @@ void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inCont
void TestCASESession::SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;
TemporarySessionManager sessionManager(inSuite, ctx);

// Test all combinations of invalid parameters
TestCASESecurePairingDelegate delegate;
Expand Down Expand Up @@ -425,7 +459,9 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S

void TestCASESession::SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestCASESecurePairingDelegate delegateCommissioner;
CASESession pairingCommissioner;
pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider);
Expand Down Expand Up @@ -902,13 +938,13 @@ void TestCASESession::SessionResumptionStorage(nlTestSuite * inSuite, void * inC
#if CONFIG_BUILD_FOR_HOST_UNIT_TEST
void TestCASESession::SimulateUpdateNOCInvalidatePendingEstablishment(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestCASESecurePairingDelegate delegateCommissioner;
CASESession pairingCommissioner;
pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider);

TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

TestCASESecurePairingDelegate delegateAccessory;
CASESession pairingAccessory;

Expand Down
74 changes: 56 additions & 18 deletions src/protocols/secure_channel/tests/TestPASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,45 @@ class MockAppDelegate : public ExchangeDelegate
void OnResponseTimeout(ExchangeContext * ec) override {}
};

class TemporarySessionManager
{
public:
TemporarySessionManager(nlTestSuite * suite, TestContext & ctx) : mCtx(ctx)
{
NL_TEST_ASSERT(suite,
CHIP_NO_ERROR ==
mSessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &ctx.GetMessageCounterManager(),
&mStorage, &ctx.GetFabricTable()));
// The setup here is really weird: we are using one session manager for
// the actual messages we send (the PASE handshake, so the
// unauthenticated sessions) and a different one for allocating the PASE
// sessions. Since our Init() set us up as the thing to handle messages
// on the transport manager, undo that.
mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager());
}

~TemporarySessionManager()
{
mSessionManager.Shutdown();
// Reset the session manager on the transport again, just in case
// shutdown messed with it.
mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager());
}

operator SessionManager &() { return mSessionManager; }

private:
TestContext & mCtx;
TestPersistentStorageDelegate mStorage;
SessionManager mSessionManager;
};

using namespace System::Clock::Literals;

void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;
TemporarySessionManager sessionManager(inSuite, ctx);

// Test all combinations of invalid parameters
TestSecurePairingDelegate delegate;
Expand Down Expand Up @@ -157,7 +190,7 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;
TemporarySessionManager sessionManager(inSuite, ctx);

// Test all combinations of invalid parameters
TestSecurePairingDelegate delegate;
Expand Down Expand Up @@ -285,11 +318,12 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S

void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Optional<ReliableMessageProtocolConfig>::Missing(),
Expand All @@ -298,11 +332,12 @@ void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)

void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32);
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Expand All @@ -312,11 +347,12 @@ void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void *

void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32);
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Expand All @@ -326,11 +362,12 @@ void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inCon

void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
ReliableMessageProtocolConfig commissionerConfig(1000_ms32, 10000_ms32);
ReliableMessageProtocolConfig deviceConfig(2000_ms32, 7000_ms32);
Expand All @@ -341,11 +378,12 @@ void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContex

void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
loopback.mNumMessagesToDrop = 2;
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Expand All @@ -358,7 +396,7 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo
void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
Expand Down
42 changes: 39 additions & 3 deletions src/transport/tests/TestSessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,19 @@ static void RandomSessionIdAllocatorOffset(nlTestSuite * inSuite, SessionManager

void SessionAllocationTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

FabricTableHolder fabricTableHolder;
NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTableHolder.Init());

secure_channel::MessageCounterManager messageCounterManager;
TestPersistentStorageDelegate deviceStorage1, deviceStorage2;

SessionManager sessionManager;
NL_TEST_ASSERT(inSuite,
CHIP_NO_ERROR ==
sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage1,
&fabricTableHolder.GetFabricTable()));

// Allocate a session.
uint16_t sessionId1;
Expand Down Expand Up @@ -735,10 +747,24 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext)
}

// Reconstruct the Session Manager to reset state.
sessionManager.Shutdown();
sessionManager.~SessionManager();
new (&sessionManager) SessionManager();
NL_TEST_ASSERT(inSuite,
CHIP_NO_ERROR ==
sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage2,
&fabricTableHolder.GetFabricTable()));

// Allocate a single session so we know what random id we are starting at.
{
auto handle = sessionManager.AllocateSession(
Transport::SecureSession::Type::kPASE,
ScopedNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId), kUndefinedFabricIndex));
NL_TEST_ASSERT(inSuite, handle.HasValue());
prevSessionId = handle.Value()->AsSecureSession()->GetLocalSessionId();
handle.Value()->AsSecureSession()->MarkForEviction();
}

prevSessionId = 0;
// Verify that we increment session ID by 1 for each allocation (except for
// the wraparound case where we skip session ID 0), even when allocated
// sessions are immediately freed.
Expand Down Expand Up @@ -886,6 +912,8 @@ void SessionCounterExhaustedTest(nlTestSuite * inSuite, void * inContext)

static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

IPAddress addr;
IPAddress::FromString("::1", addr);

Expand All @@ -894,9 +922,17 @@ static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext)
FabricIndex aliceFabricIndex = 1;
FabricIndex bobFabricIndex = 1;

FabricTableHolder fabricTableHolder;
secure_channel::MessageCounterManager messageCounterManager;
TestPersistentStorageDelegate deviceStorage;

SessionManager sessionManager;
secure_channel::MessageCounterManager gMessageCounterManager;
chip::TestPersistentStorageDelegate deviceStorage;

NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTableHolder.Init());
NL_TEST_ASSERT(inSuite,
CHIP_NO_ERROR ==
sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage,
&fabricTableHolder.GetFabricTable()));

Transport::PeerAddress peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));

Expand Down

0 comments on commit 3df266b

Please sign in to comment.