Skip to content

Commit

Permalink
Use BitMapPool for SecureSessionTable
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed Oct 28, 2021
1 parent 36329e1 commit 16c747b
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 322 deletions.
43 changes: 16 additions & 27 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,43 +49,31 @@ static constexpr uint32_t kUndefinedMessageIndex = UINT32_MAX;
class SecureSession
{
public:
SecureSession() : mPeerAddress(PeerAddress::Uninitialized()) {}
SecureSession(const PeerAddress & addr) : mPeerAddress(addr) {}
SecureSession(PeerAddress && addr) : mPeerAddress(addr) {}
SecureSession(uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric, uint64_t time) :
mPeerNodeId(peerNodeId), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId), mFabric(fabric)
{
SetLastActivityTimeMs(time);
}

SecureSession(SecureSession &&) = default;
SecureSession(const SecureSession &) = default;
SecureSession & operator=(const SecureSession &) = default;
SecureSession & operator=(SecureSession &&) = default;
SecureSession(SecureSession &&) = delete;
SecureSession(const SecureSession &) = delete;
SecureSession & operator=(const SecureSession &) = delete;
SecureSession & operator=(SecureSession &&) = delete;

const PeerAddress & GetPeerAddress() const { return mPeerAddress; }
PeerAddress & GetPeerAddress() { return mPeerAddress; }
void SetPeerAddress(const PeerAddress & address) { mPeerAddress = address; }

NodeId GetPeerNodeId() const { return mPeerNodeId; }
void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; }

uint16_t GetPeerSessionId() const { return mPeerSessionId; }
void SetPeerSessionId(uint16_t id) { mPeerSessionId = id; }

// TODO: Rename KeyID to SessionID
uint16_t GetLocalSessionId() const { return mLocalSessionId; }
void SetLocalSessionId(uint16_t id) { mLocalSessionId = id; }
uint16_t GetPeerSessionId() const { return mPeerSessionId; }
FabricIndex GetFabricIndex() const { return mFabric; }

uint64_t GetLastActivityTimeMs() const { return mLastActivityTimeMs; }
void SetLastActivityTimeMs(uint64_t value) { mLastActivityTimeMs = value; }

CryptoContext & GetCryptoContext() { return mCryptoContext; }

FabricIndex GetFabricIndex() const { return mFabric; }
void SetFabricIndex(FabricIndex fabricIndex) { mFabric = fabricIndex; }

bool IsInitialized()
{
return (mPeerAddress.IsInitialized() || mPeerNodeId != kUndefinedNodeId || mPeerSessionId != UINT16_MAX ||
mLocalSessionId != UINT16_MAX);
}

CHIP_ERROR EncryptBeforeSend(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header,
MessageAuthenticationCode & mac) const
{
Expand All @@ -101,14 +89,15 @@ class SecureSession
SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; }

private:
const NodeId mPeerNodeId;
const uint16_t mLocalSessionId;
const uint16_t mPeerSessionId;
const FabricIndex mFabric;

PeerAddress mPeerAddress;
NodeId mPeerNodeId = kUndefinedNodeId;
uint16_t mPeerSessionId = UINT16_MAX;
uint16_t mLocalSessionId = UINT16_MAX;
uint64_t mLastActivityTimeMs = 0;
CryptoContext mCryptoContext;
SessionMessageCounter mSessionMessageCounter;
FabricIndex mFabric = kUndefinedFabricIndex;
};

} // namespace Transport
Expand Down
156 changes: 27 additions & 129 deletions src/transport/SecureSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <lib/core/CHIPError.h>
#include <lib/support/CodeUtils.h>
#include <lib/support/Pool.h>
#include <system/TimeSource.h>
#include <transport/SecureSession.h>

Expand All @@ -43,9 +44,10 @@ class SecureSessionTable
/**
* Allocates a new secure session out of the internal resource pool.
*
* @param peerNode represents peer Node's ID
* @param peerSessionId represents the encryption key ID assigned by peer node
* @param localSessionId represents the encryption key ID assigned by local node
* @param peerNodeId represents peer Node's ID
* @param peerSessionId represents the encryption key ID assigned by peer node
* @param fabric represents fabric ID for the session
* @param state [out] will contain the session if one was available. May be null if no return value is desired.
*
* @note the newly created state will have an 'active' time set based on the current time source.
Expand All @@ -54,70 +56,17 @@ class SecureSessionTable
* has been reached (with CHIP_ERROR_NO_MEMORY).
*/
CHECK_RETURN_VALUE
CHIP_ERROR CreateNewSecureSession(NodeId peerNode, uint16_t peerSessionId, uint16_t localSessionId, SecureSession ** state)
SecureSession * CreateNewSecureSession(uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric)
{
CHIP_ERROR err = CHIP_ERROR_NO_MEMORY;

if (state)
{
*state = nullptr;
}

for (size_t i = 0; i < kMaxSessionCount; i++)
{
if (!mStates[i].IsInitialized())
{
mStates[i] = SecureSession();
mStates[i].SetPeerNodeId(peerNode);
mStates[i].SetPeerSessionId(peerSessionId);
mStates[i].SetLocalSessionId(localSessionId);
mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs());

if (state)
{
*state = &mStates[i];
}

err = CHIP_NO_ERROR;
break;
}
}

return err;
return mEntries.CreateObject(localSessionId, peerNodeId, peerSessionId, fabric, mTimeSource.GetCurrentMonotonicTimeMs());
}

/**
* Get a secure session given a Node Id.
*
* @param nodeId is the session to find (based on nodeId).
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
*
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindSecureSession(NodeId nodeId, SecureSession * begin)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

if (begin >= iter && begin < &mStates[kMaxSessionCount])
{
iter = begin + 1;
}
void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); }

for (; iter < &mStates[kMaxSessionCount]; iter++)
{
if (!iter->IsInitialized())
{
continue;
}
if (iter->GetPeerNodeId() == nodeId)
{
state = iter;
break;
}
}
return state;
template <typename Function>
bool ForEachSession(Function && function)
{
return mEntries.ForEachActiveObject(std::forward<Function>(function));
}

/**
Expand All @@ -129,66 +78,23 @@ class SecureSessionTable
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId, SecureSession * begin)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

if (begin >= iter && begin < &mStates[kMaxSessionCount])
{
iter = begin + 1;
}

for (; iter < &mStates[kMaxSessionCount]; iter++)
{
if (!iter->IsInitialized())
{
continue;
}
if (iter->GetLocalSessionId() == localSessionId)
{
state = iter;
break;
}
}
return state;
}

/**
* Get the first session that matches the given fabric index.
*
* @param fabric The fabric index to match
*
* @return the session found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindSecureSessionByFabric(FabricIndex fabric)
SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId)
{
for (auto & state : mStates)
{
if (!state.IsInitialized())
{
continue;
}
if (state.GetFabricIndex() == fabric)
SecureSession * result = nullptr;
mEntries.ForEachActiveObject([&](auto session) {
if (session->GetLocalSessionId() == localSessionId)
{
return &state;
result = session;
return false;
}
}
return nullptr;
return true;
});
return result;
}

/// Convenience method to mark a session as active
void MarkSessionActive(SecureSession * state) { state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); }

/// Convenience method to expired a session and fired the related callback
template <typename Callback>
void MarkSessionExpired(SecureSession * state, Callback callback)
{
callback(*state);
*state = SecureSession(PeerAddress::Uninitialized());
}

/**
* Iterates through all active sessions and expires any sessions with an idle time
* larger than the given amount.
Expand All @@ -199,30 +105,22 @@ class SecureSessionTable
void ExpireInactiveSessions(uint64_t maxIdleTimeMs, Callback callback)
{
const uint64_t currentTime = mTimeSource.GetCurrentMonotonicTimeMs();

for (size_t i = 0; i < kMaxSessionCount; i++)
{
if (!mStates[i].IsInitialized())
{
continue; // not an active session
}

uint64_t sessionActiveTime = mStates[i].GetLastActivityTimeMs();
if (sessionActiveTime + maxIdleTimeMs >= currentTime)
mEntries.ForEachActiveObject([&](auto session) {
if (session->GetLastActivityTimeMs() + maxIdleTimeMs < currentTime)
{
continue; // not expired
callback(*session);
ReleaseSession(session);
}

MarkSessionExpired(&mStates[i], callback);
}
return true;
});
}

/// Allows access to the underlying time source used for keeping track of session active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
Time::TimeSource<kTimeSource> mTimeSource;
SecureSession mStates[kMaxSessionCount];
BitMapObjectPool<SecureSession, kMaxSessionCount> mEntries;
};

} // namespace Transport
Expand Down
Loading

0 comments on commit 16c747b

Please sign in to comment.