Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove time source template argument of sessions #12791

Merged
merged 1 commit into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,11 @@ class SecureSession
};

SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId,
FabricIndex fabric, const ReliableMessageProtocolConfig & config, System::Clock::Timestamp currentTime) :
FabricIndex fabric, const ReliableMessageProtocolConfig & config) :
mSecureSessionType(secureSessionType),
mPeerNodeId(peerNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId),
mFabric(fabric), mMRPConfig(config)
{
SetLastActivityTime(currentTime);
}
mFabric(fabric), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}

SecureSession(SecureSession &&) = delete;
SecureSession(const SecureSession &) = delete;
Expand All @@ -95,7 +93,7 @@ class SecureSession
FabricIndex GetFabricIndex() const { return mFabric; }

System::Clock::Timestamp GetLastActivityTime() const { return mLastActivityTime; }
void SetLastActivityTime(System::Clock::Timestamp value) { mLastActivityTime = value; }
void MarkActive() { mLastActivityTime = System::SystemClock().GetMonotonicTimestamp(); }

CryptoContext & GetCryptoContext() { return mCryptoContext; }

Expand Down
15 changes: 3 additions & 12 deletions src/transport/SecureSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ constexpr const uint16_t kAnyKeyId = 0xffff;
* - handle session active time and expiration
* - allocate and free space for sessions.
*/
template <size_t kMaxSessionCount, Time::Source kTimeSource = Time::Source::kSystem>
template <size_t kMaxSessionCount>
class SecureSessionTable
{
public:
Expand All @@ -64,8 +64,7 @@ class SecureSessionTable
CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric,
const ReliableMessageProtocolConfig & config)
{
return mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config,
mTimeSource.GetMonotonicTimestamp());
return mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config);
}

void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); }
Expand Down Expand Up @@ -98,9 +97,6 @@ class SecureSessionTable
return result;
}

/// Convenience method to mark a session as active
void MarkSessionActive(SecureSession * state) { state->SetLastActivityTime(mTimeSource.GetMonotonicTimestamp()); }

/**
* Iterates through all active sessions and expires any sessions with an idle time
* larger than the given amount.
Expand All @@ -110,9 +106,8 @@ class SecureSessionTable
template <typename Callback>
void ExpireInactiveSessions(System::Clock::Timestamp maxIdleTime, Callback callback)
{
const System::Clock::Timestamp currentTime = mTimeSource.GetMonotonicTimestamp();
mEntries.ForEachActiveObject([&](auto session) {
if (session->GetLastActivityTime() + maxIdleTime < currentTime)
if (session->GetLastActivityTime() + maxIdleTime < System::SystemClock().GetMonotonicTimestamp())
{
callback(*session);
ReleaseSession(session);
Expand All @@ -121,11 +116,7 @@ class SecureSessionTable
});
}

/// Allows access to the underlying time source used for keeping track of session active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
Time::TimeSource<kTimeSource> mTimeSource;
BitMapObjectPool<SecureSession, kMaxSessionCount> mEntries;
};

Expand Down
11 changes: 6 additions & 5 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ CHIP_ERROR SessionManager::SendPreparedMessage(const SessionHandle & sessionHand
}

// This marks any connection where we send data to as 'active'
mSecureSessions.MarkSessionActive(session);
session->MarkActive();

destination = &session->GetPeerAddress();

Expand All @@ -241,7 +241,7 @@ CHIP_ERROR SessionManager::SendPreparedMessage(const SessionHandle & sessionHand
else
{
auto unauthenticated = sessionHandle.GetUnauthenticatedSession();
mUnauthenticatedSessions.MarkSessionActive(unauthenticated);
unauthenticated->MarkActive();
destination = &unauthenticated->GetPeerAddress();

ChipLogProgress(Inet,
Expand Down Expand Up @@ -439,7 +439,7 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr
}
VerifyOrDie(err == CHIP_NO_ERROR);

mUnauthenticatedSessions.MarkSessionActive(session);
session->MarkActive();

PayloadHeader payloadHeader;
ReturnOnFailure(payloadHeader.DecodeAndConsume(msg));
Expand Down Expand Up @@ -502,7 +502,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea
return;
}

mSecureSessions.MarkSessionActive(session);
session->MarkActive();

if (isDuplicate == SessionMessageDelegate::DuplicateMessage::Yes && !payloadHeader.NeedsAck())
{
Expand Down Expand Up @@ -628,7 +628,8 @@ void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param)
// TODO(#2279): session expiration is currently disabled until rekeying is supported
// the #ifdef should be removed after that.
mgr->mSecureSessions.ExpireInactiveSessions(
CHIP_PEER_CONNECTION_TIMEOUT_MS, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); });
System::SystemClock().GetMonotonicTimestamp(), System::Clock::Milliseconds32(CHIP_PEER_CONNECTION_TIMEOUT_MS),
[this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); });
#endif
mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer
}
Expand Down
19 changes: 4 additions & 15 deletions src/transport/UnauthenticatedSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, U
{
public:
UnauthenticatedSession(const PeerAddress & address, const ReliableMessageProtocolConfig & config) :
mPeerAddress(address), mMRPConfig(config)
mPeerAddress(address), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}

UnauthenticatedSession(const UnauthenticatedSession &) = delete;
Expand All @@ -58,7 +58,7 @@ class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, U
UnauthenticatedSession & operator=(UnauthenticatedSession &&) = delete;

System::Clock::Timestamp GetLastActivityTime() const { return mLastActivityTime; }
void SetLastActivityTime(System::Clock::Timestamp value) { mLastActivityTime = value; }
void MarkActive() { mLastActivityTime = System::SystemClock().GetMonotonicTimestamp(); }

const PeerAddress & GetPeerAddress() const { return mPeerAddress; }

Expand All @@ -69,9 +69,8 @@ class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, U
PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; }

private:
System::Clock::Timestamp mLastActivityTime = System::Clock::kZero;

const PeerAddress mPeerAddress;
System::Clock::Timestamp mLastActivityTime;
ReliableMessageProtocolConfig mMRPConfig;
PeerMessageCounter mPeerMessageCounter;
};
Expand All @@ -84,7 +83,7 @@ class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, U
* hold by using UnauthenticatedSessionHandle, which increase the reference
* count by 1. If the reference count is not 0, the entry won't be pruned.
*/
template <size_t kMaxSessionCount, Time::Source kTimeSource = Time::Source::kSystem>
template <size_t kMaxSessionCount>
class UnauthenticatedSessionTable
{
public:
Expand Down Expand Up @@ -114,15 +113,6 @@ class UnauthenticatedSessionTable
}
}

/// Mark a session as active
void MarkSessionActive(UnauthenticatedSessionHandle session)
{
session->SetLastActivityTime(mTimeSource.GetMonotonicTimestamp());
}

/// Allows access to the underlying time source used for keeping track of session active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
/**
* Allocates a new session out of the internal resource pool.
Expand Down Expand Up @@ -224,7 +214,6 @@ class UnauthenticatedSessionTable
return false;
}

Time::TimeSource<Time::Source::kSystem> mTimeSource;
BitMapObjectPool<UnauthenticatedSession, kMaxSessionCount> mEntries;
};

Expand Down
60 changes: 47 additions & 13 deletions src/transport/tests/TestPeerConnections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,33 @@ const CATValues kPeer1CATs = { { 0xABCD0001, 0xABCE0100, 0xABCD0020 } };
const CATValues kPeer2CATs = { { 0xABCD0012, kUndefinedCAT, kUndefinedCAT } };
const CATValues kPeer3CATs;

class MockClock : public System::Clock::ClockBase
{
public:
System::Clock::Microseconds64 GetMonotonicMicroseconds64() override { return timeSource.GetMonotonicTimestamp(); }
System::Clock::Milliseconds64 GetMonotonicMilliseconds64() override { return timeSource.GetMonotonicTimestamp(); }
CHIP_ERROR GetClock_RealTime(System::Clock::Microseconds64 & aCurTime) override { return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; }
CHIP_ERROR GetClock_RealTimeMS(System::Clock::Milliseconds64 & aCurTime) override
{
return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE;
}
CHIP_ERROR SetClock_RealTime(System::Clock::Microseconds64 aNewCurTime) override { return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; }

System::Clock::Timestamp GetMonotonicTimestamp() { return timeSource.GetMonotonicTimestamp(); }
void SetMonotonicTimestamp(System::Clock::Timestamp value) { timeSource.SetMonotonicTimestamp(value); }

private:
Time::TimeSource<Time::Source::kTest> timeSource;
};

void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext)
{
SecureSession * statePtr;
SecureSessionTable<2, Time::Source::kTest> connections;
connections.GetTimeSource().SetMonotonicTimestamp(100_ms64);
SecureSessionTable<2> connections;
MockClock clock;
System::Clock::ClockBase * realClock = &System::SystemClock();
System::Clock::Internal::SetSystemClockForTesting(&clock);
clock.SetMonotonicTimestamp(100_ms64);
CATValues peerCATs;

// Node ID 1, peer key 1, local key 2
Expand All @@ -90,12 +112,16 @@ void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext)
statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr == nullptr);
System::Clock::Internal::SetSystemClockForTesting(realClock);
}

void TestFindByKeyId(nlTestSuite * inSuite, void * inContext)
{
SecureSession * statePtr;
SecureSessionTable<2, Time::Source::kTest> connections;
SecureSessionTable<2> connections;
MockClock clock;
System::Clock::ClockBase * realClock = &System::SystemClock();
System::Clock::Internal::SetSystemClockForTesting(&clock);

// Node ID 1, peer key 1, local key 2
statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
Expand All @@ -112,6 +138,8 @@ void TestFindByKeyId(nlTestSuite * inSuite, void * inContext)

NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3));
NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4));

System::Clock::Internal::SetSystemClockForTesting(realClock);
}

struct ExpiredCallInfo
Expand All @@ -125,25 +153,29 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
{
ExpiredCallInfo callInfo;
SecureSession * statePtr;
SecureSessionTable<2, Time::Source::kTest> connections;
SecureSessionTable<2> connections;

MockClock clock;
System::Clock::ClockBase * realClock = &System::SystemClock();
System::Clock::Internal::SetSystemClockForTesting(&clock);

connections.GetTimeSource().SetMonotonicTimestamp(100_ms64);
clock.SetMonotonicTimestamp(100_ms64);

// Node ID 1, peer key 1, local key 2
statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr != nullptr);
statePtr->SetPeerAddress(kPeer1Addr);

connections.GetTimeSource().SetMonotonicTimestamp(200_ms64);
clock.SetMonotonicTimestamp(200_ms64);
// Node ID 2, peer key 3, local key 4
statePtr = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr != nullptr);
statePtr->SetPeerAddress(kPeer2Addr);

// cannot add before expiry
connections.GetTimeSource().SetMonotonicTimestamp(300_ms64);
clock.SetMonotonicTimestamp(300_ms64);
statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr == nullptr);
Expand All @@ -160,24 +192,24 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2));

// now that the connections were expired, we can add peer3
connections.GetTimeSource().SetMonotonicTimestamp(300_ms64);
clock.SetMonotonicTimestamp(300_ms64);
// Node ID 3, peer key 5, local key 6
statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
gDefaultMRPConfig);
NL_TEST_ASSERT(inSuite, statePtr != nullptr);
statePtr->SetPeerAddress(kPeer3Addr);

connections.GetTimeSource().SetMonotonicTimestamp(400_ms64);
clock.SetMonotonicTimestamp(400_ms64);
NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSessionByLocalKey(4));

connections.MarkSessionActive(statePtr);
NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTime() == connections.GetTimeSource().GetMonotonicTimestamp());
statePtr->MarkActive();
NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTime() == clock.GetMonotonicTimestamp());

// At this time:
// Peer 3 active at time 300
// Peer 2 active at time 400

connections.GetTimeSource().SetMonotonicTimestamp(500_ms64);
clock.SetMonotonicTimestamp(500_ms64);
callInfo.callCount = 0;
connections.ExpireInactiveSessions(150_ms64, [&callInfo](const SecureSession & state) {
callInfo.callCount++;
Expand All @@ -202,7 +234,7 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6));

// peer 1 and 2 are active
connections.GetTimeSource().SetMonotonicTimestamp(1000_ms64);
clock.SetMonotonicTimestamp(1000_ms64);
callInfo.callCount = 0;
connections.ExpireInactiveSessions(100_ms64, [&callInfo](const SecureSession & state) {
callInfo.callCount++;
Expand All @@ -213,6 +245,8 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2));
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4));
NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6));

System::Clock::Internal::SetSystemClockForTesting(realClock);
}

} // namespace
Expand Down