Skip to content

Commit

Permalink
Remove time source template argument of sessions (#12791)
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost authored and pull[bot] committed Oct 14, 2023
1 parent beb2dfe commit 3384227
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 51 deletions.
10 changes: 4 additions & 6 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,11 @@ class SecureSession
};

SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId,
FabricIndex fabric, const ReliableMessageProtocolConfig & config, System::Clock::Timestamp currentTime) :
FabricIndex fabric, const ReliableMessageProtocolConfig & config) :
mSecureSessionType(secureSessionType),
mPeerNodeId(peerNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId),
mFabric(fabric), mMRPConfig(config)
{
SetLastActivityTime(currentTime);
}
mFabric(fabric), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}

SecureSession(SecureSession &&) = delete;
SecureSession(const SecureSession &) = delete;
Expand All @@ -95,7 +93,7 @@ class SecureSession
FabricIndex GetFabricIndex() const { return mFabric; }

System::Clock::Timestamp GetLastActivityTime() const { return mLastActivityTime; }
void SetLastActivityTime(System::Clock::Timestamp value) { mLastActivityTime = value; }
void MarkActive() { mLastActivityTime = System::SystemClock().GetMonotonicTimestamp(); }

CryptoContext & GetCryptoContext() { return mCryptoContext; }

Expand Down
15 changes: 3 additions & 12 deletions src/transport/SecureSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ constexpr const uint16_t kAnyKeyId = 0xffff;
* - handle session active time and expiration
* - allocate and free space for sessions.
*/
template <size_t kMaxSessionCount, Time::Source kTimeSource = Time::Source::kSystem>
template <size_t kMaxSessionCount>
class SecureSessionTable
{
public:
Expand All @@ -64,8 +64,7 @@ class SecureSessionTable
CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric,
const ReliableMessageProtocolConfig & config)
{
return mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config,
mTimeSource.GetMonotonicTimestamp());
return mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config);
}

void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); }
Expand Down Expand Up @@ -98,9 +97,6 @@ class SecureSessionTable
return result;
}

/// Convenience method to mark a session as active
void MarkSessionActive(SecureSession * state) { state->SetLastActivityTime(mTimeSource.GetMonotonicTimestamp()); }

/**
* Iterates through all active sessions and expires any sessions with an idle time
* larger than the given amount.
Expand All @@ -110,9 +106,8 @@ class SecureSessionTable
template <typename Callback>
void ExpireInactiveSessions(System::Clock::Timestamp maxIdleTime, Callback callback)
{
const System::Clock::Timestamp currentTime = mTimeSource.GetMonotonicTimestamp();
mEntries.ForEachActiveObject([&](auto session) {
if (session->GetLastActivityTime() + maxIdleTime < currentTime)
if (session->GetLastActivityTime() + maxIdleTime < System::SystemClock().GetMonotonicTimestamp())
{
callback(*session);
ReleaseSession(session);
Expand All @@ -121,11 +116,7 @@ class SecureSessionTable
});
}

/// Allows access to the underlying time source used for keeping track of session active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
Time::TimeSource<kTimeSource> mTimeSource;
BitMapObjectPool<SecureSession, kMaxSessionCount> mEntries;
};

Expand Down
11 changes: 6 additions & 5 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ CHIP_ERROR SessionManager::SendPreparedMessage(const SessionHandle & sessionHand
}

// This marks any connection where we send data to as 'active'
mSecureSessions.MarkSessionActive(session);
session->MarkActive();

destination = &session->GetPeerAddress();

Expand All @@ -241,7 +241,7 @@ CHIP_ERROR SessionManager::SendPreparedMessage(const SessionHandle & sessionHand
else
{
auto unauthenticated = sessionHandle.GetUnauthenticatedSession();
mUnauthenticatedSessions.MarkSessionActive(unauthenticated);
unauthenticated->MarkActive();
destination = &unauthenticated->GetPeerAddress();

ChipLogProgress(Inet,
Expand Down Expand Up @@ -439,7 +439,7 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr
}
VerifyOrDie(err == CHIP_NO_ERROR);

mUnauthenticatedSessions.MarkSessionActive(session);
session->MarkActive();

PayloadHeader payloadHeader;
ReturnOnFailure(payloadHeader.DecodeAndConsume(msg));
Expand Down Expand Up @@ -502,7 +502,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea
return;
}

mSecureSessions.MarkSessionActive(session);
session->MarkActive();

if (isDuplicate == SessionMessageDelegate::DuplicateMessage::Yes && !payloadHeader.NeedsAck())
{
Expand Down Expand Up @@ -628,7 +628,8 @@ void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param)
// TODO(#2279): session expiration is currently disabled until rekeying is supported
// the #ifdef should be removed after that.
mgr->mSecureSessions.ExpireInactiveSessions(
CHIP_PEER_CONNECTION_TIMEOUT_MS, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); });
System::SystemClock().GetMonotonicTimestamp(), System::Clock::Milliseconds32(CHIP_PEER_CONNECTION_TIMEOUT_MS),
[this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); });
#endif
mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer
}
Expand Down
19 changes: 4 additions & 15 deletions src/transport/UnauthenticatedSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, U
{
public:
UnauthenticatedSession(const PeerAddress & address, const ReliableMessageProtocolConfig & config) :
mPeerAddress(address), mMRPConfig(config)
mPeerAddress(address), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}

UnauthenticatedSession(const UnauthenticatedSession &) = delete;
Expand All @@ -58,7 +58,7 @@ class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, U
UnauthenticatedSession & operator=(UnauthenticatedSession &&) = delete;

System::Clock::Timestamp GetLastActivityTime() const { return mLastActivityTime; }
void SetLastActivityTime(System::Clock::Timestamp value) { mLastActivityTime = value; }
void MarkActive() { mLastActivityTime = System::SystemClock().GetMonotonicTimestamp(); }

const PeerAddress & GetPeerAddress() const { return mPeerAddress; }

Expand All @@ -69,9 +69,8 @@ class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, U
PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; }

private:
System::Clock::Timestamp mLastActivityTime = System::Clock::kZero;

const PeerAddress mPeerAddress;
System::Clock::Timestamp mLastActivityTime;
ReliableMessageProtocolConfig mMRPConfig;
PeerMessageCounter mPeerMessageCounter;
};
Expand All @@ -84,7 +83,7 @@ class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, U
* hold by using UnauthenticatedSessionHandle, which increase the reference
* count by 1. If the reference count is not 0, the entry won't be pruned.
*/
template <size_t kMaxSessionCount, Time::Source kTimeSource = Time::Source::kSystem>
template <size_t kMaxSessionCount>
class UnauthenticatedSessionTable
{
public:
Expand Down Expand Up @@ -114,15 +113,6 @@ class UnauthenticatedSessionTable
}
}

/// Mark a session as active
void MarkSessionActive(UnauthenticatedSessionHandle session)
{
session->SetLastActivityTime(mTimeSource.GetMonotonicTimestamp());
}

/// Allows access to the underlying time source used for keeping track of session active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
/**
* Allocates a new session out of the internal resource pool.
Expand Down Expand Up @@ -224,7 +214,6 @@ class UnauthenticatedSessionTable
return false;
}

Time::TimeSource<Time::Source::kSystem> mTimeSource;
BitMapObjectPool<UnauthenticatedSession, kMaxSessionCount> mEntries;
};

Expand Down
60 changes: 47 additions & 13 deletions src/transport/tests/TestPeerConnections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,33 @@ const CATValues kPeer1CATs = { { 0xABCD0001, 0xABCE0100, 0xABCD0020 } };
const CATValues kPeer2CATs = { { 0xABCD0012, kUndefinedCAT, kUndefinedCAT } };
const CATValues kPeer3CATs;

class MockClock : public System::Clock::ClockBase
{
public:
System::Clock::Microseconds64 GetMonotonicMicroseconds64() override { return timeSource.GetMonotonicTimestamp(); }
System::Clock::Milliseconds64 GetMonotonicMilliseconds64() override { return timeSource.GetMonotonicTimestamp(); }
CHIP_ERROR GetClock_RealTime(System::Clock::Microseconds64 & aCurTime) override { return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; }
CHIP_ERROR GetClock_RealTimeMS(System::Clock::Milliseconds64 & aCurTime) override
{
return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE;
}
CHIP_ERROR SetClock_RealTime(System::Clock::Microseconds64 aNewCurTime) override { return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; }

System::Clock::Timestamp GetMonotonicTimestamp() { return timeSource.GetMonotonicTimestamp(); }
void SetMonotonicTimestamp(System::Clock::Timestamp value) { timeSource.SetMonotonicTimestamp(value); }

private:
Time::TimeSource<Time::Source::kTest> timeSource;
};

void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext)
{
SecureSession * statePtr;
SecureSessionTable<2, Time::Source::kTest> connections;
connections.GetTimeSource().SetMonotonicTimestamp(100_ms64);
SecureSessionTable<2> connections;
MockClock clock;
System::Clock::ClockBase * realClock = &System::SystemClock();
System::Clock::Internal::SetSystemClockForTesting(&clock);
clock.SetMonotonicTimestamp(100_ms64);
CATValues peerCATs;

// Node ID 1, peer key 1, local key 2
Expand All @@ -90,12 +112,16 @@ void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext)
statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr == nullptr);
System::Clock::Internal::SetSystemClockForTesting(realClock);
}

void TestFindByKeyId(nlTestSuite * inSuite, void * inContext)
{
SecureSession * statePtr;
SecureSessionTable<2, Time::Source::kTest> connections;
SecureSessionTable<2> connections;
MockClock clock;
System::Clock::ClockBase * realClock = &System::SystemClock();
System::Clock::Internal::SetSystemClockForTesting(&clock);

// Node ID 1, peer key 1, local key 2
statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
Expand All @@ -112,6 +138,8 @@ void TestFindByKeyId(nlTestSuite * inSuite, void * inContext)

NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3));
NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4));

System::Clock::Internal::SetSystemClockForTesting(realClock);
}

struct ExpiredCallInfo
Expand All @@ -125,25 +153,29 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
{
ExpiredCallInfo callInfo;
SecureSession * statePtr;
SecureSessionTable<2, Time::Source::kTest> connections;
SecureSessionTable<2> connections;

MockClock clock;
System::Clock::ClockBase * realClock = &System::SystemClock();
System::Clock::Internal::SetSystemClockForTesting(&clock);

connections.GetTimeSource().SetMonotonicTimestamp(100_ms64);
clock.SetMonotonicTimestamp(100_ms64);

// Node ID 1, peer key 1, local key 2
statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr != nullptr);
statePtr->SetPeerAddress(kPeer1Addr);

connections.GetTimeSource().SetMonotonicTimestamp(200_ms64);
clock.SetMonotonicTimestamp(200_ms64);
// Node ID 2, peer key 3, local key 4
statePtr = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr != nullptr);
statePtr->SetPeerAddress(kPeer2Addr);

// cannot add before expiry
connections.GetTimeSource().SetMonotonicTimestamp(300_ms64);
clock.SetMonotonicTimestamp(300_ms64);
statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr == nullptr);
Expand All @@ -160,24 +192,24 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2));

// now that the connections were expired, we can add peer3
connections.GetTimeSource().SetMonotonicTimestamp(300_ms64);
clock.SetMonotonicTimestamp(300_ms64);
// Node ID 3, peer key 5, local key 6
statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr != nullptr);
statePtr->SetPeerAddress(kPeer3Addr);

connections.GetTimeSource().SetMonotonicTimestamp(400_ms64);
clock.SetMonotonicTimestamp(400_ms64);
NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSessionByLocalKey(4));

connections.MarkSessionActive(statePtr);
NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTime() == connections.GetTimeSource().GetMonotonicTimestamp());
statePtr->MarkActive();
NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTime() == clock.GetMonotonicTimestamp());

// At this time:
// Peer 3 active at time 300
// Peer 2 active at time 400

connections.GetTimeSource().SetMonotonicTimestamp(500_ms64);
clock.SetMonotonicTimestamp(500_ms64);
callInfo.callCount = 0;
connections.ExpireInactiveSessions(150_ms64, [&callInfo](const SecureSession & state) {
callInfo.callCount++;
Expand All @@ -202,7 +234,7 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6));

// peer 1 and 2 are active
connections.GetTimeSource().SetMonotonicTimestamp(1000_ms64);
clock.SetMonotonicTimestamp(1000_ms64);
callInfo.callCount = 0;
connections.ExpireInactiveSessions(100_ms64, [&callInfo](const SecureSession & state) {
callInfo.callCount++;
Expand All @@ -213,6 +245,8 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2));
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4));
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6));

System::Clock::Internal::SetSystemClockForTesting(realClock);
}

} // namespace
Expand Down

0 comments on commit 3384227

Please sign in to comment.