Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use BitMapPool for SecureSessionTable #11110

Merged
merged 1 commit into from
Oct 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 16 additions & 27 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,43 +49,31 @@ static constexpr uint32_t kUndefinedMessageIndex = UINT32_MAX;
class SecureSession
{
public:
SecureSession() : mPeerAddress(PeerAddress::Uninitialized()) {}
SecureSession(const PeerAddress & addr) : mPeerAddress(addr) {}
SecureSession(PeerAddress && addr) : mPeerAddress(addr) {}
SecureSession(uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric, uint64_t currentTime) :
mPeerNodeId(peerNodeId), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId), mFabric(fabric)
{
SetLastActivityTimeMs(currentTime);
}

SecureSession(SecureSession &&) = default;
SecureSession(const SecureSession &) = default;
SecureSession & operator=(const SecureSession &) = default;
SecureSession & operator=(SecureSession &&) = default;
SecureSession(SecureSession &&) = delete;
SecureSession(const SecureSession &) = delete;
SecureSession & operator=(const SecureSession &) = delete;
SecureSession & operator=(SecureSession &&) = delete;

const PeerAddress & GetPeerAddress() const { return mPeerAddress; }
PeerAddress & GetPeerAddress() { return mPeerAddress; }
void SetPeerAddress(const PeerAddress & address) { mPeerAddress = address; }

NodeId GetPeerNodeId() const { return mPeerNodeId; }
void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; }

uint16_t GetPeerSessionId() const { return mPeerSessionId; }
void SetPeerSessionId(uint16_t id) { mPeerSessionId = id; }

// TODO: Rename KeyID to SessionID
uint16_t GetLocalSessionId() const { return mLocalSessionId; }
void SetLocalSessionId(uint16_t id) { mLocalSessionId = id; }
uint16_t GetPeerSessionId() const { return mPeerSessionId; }
FabricIndex GetFabricIndex() const { return mFabric; }

uint64_t GetLastActivityTimeMs() const { return mLastActivityTimeMs; }
void SetLastActivityTimeMs(uint64_t value) { mLastActivityTimeMs = value; }

CryptoContext & GetCryptoContext() { return mCryptoContext; }

FabricIndex GetFabricIndex() const { return mFabric; }
void SetFabricIndex(FabricIndex fabricIndex) { mFabric = fabricIndex; }

bool IsInitialized()
{
return (mPeerAddress.IsInitialized() || mPeerNodeId != kUndefinedNodeId || mPeerSessionId != UINT16_MAX ||
mLocalSessionId != UINT16_MAX);
}

CHIP_ERROR EncryptBeforeSend(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header,
MessageAuthenticationCode & mac) const
{
Expand All @@ -101,14 +89,15 @@ class SecureSession
SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; }

private:
const NodeId mPeerNodeId;
const uint16_t mLocalSessionId;
const uint16_t mPeerSessionId;
const FabricIndex mFabric;

PeerAddress mPeerAddress;
NodeId mPeerNodeId = kUndefinedNodeId;
uint16_t mPeerSessionId = UINT16_MAX;
uint16_t mLocalSessionId = UINT16_MAX;
uint64_t mLastActivityTimeMs = 0;
CryptoContext mCryptoContext;
SessionMessageCounter mSessionMessageCounter;
FabricIndex mFabric = kUndefinedFabricIndex;
};

} // namespace Transport
Expand Down
158 changes: 27 additions & 131 deletions src/transport/SecureSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <lib/core/CHIPError.h>
#include <lib/support/CodeUtils.h>
#include <lib/support/Pool.h>
#include <system/TimeSource.h>
#include <transport/SecureSession.h>

Expand All @@ -43,152 +44,55 @@ class SecureSessionTable
/**
* Allocates a new secure session out of the internal resource pool.
*
* @param peerNode represents peer Node's ID
* @param peerSessionId represents the encryption key ID assigned by peer node
* @param localSessionId represents the encryption key ID assigned by local node
* @param state [out] will contain the session if one was available. May be null if no return value is desired.
* @param peerNodeId represents peer Node's ID
* @param peerSessionId represents the encryption key ID assigned by peer node
* @param fabric represents fabric ID for the session
*
* @note the newly created state will have an 'active' time set based on the current time source.
*
* @returns CHIP_NO_ERROR if state could be initialized. May fail if maximum session count
* has been reached (with CHIP_ERROR_NO_MEMORY).
*/
CHECK_RETURN_VALUE
CHIP_ERROR CreateNewSecureSession(NodeId peerNode, uint16_t peerSessionId, uint16_t localSessionId, SecureSession ** state)
SecureSession * CreateNewSecureSession(uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric)
{
CHIP_ERROR err = CHIP_ERROR_NO_MEMORY;

if (state)
{
*state = nullptr;
}

for (size_t i = 0; i < kMaxSessionCount; i++)
{
if (!mStates[i].IsInitialized())
{
mStates[i] = SecureSession();
mStates[i].SetPeerNodeId(peerNode);
mStates[i].SetPeerSessionId(peerSessionId);
mStates[i].SetLocalSessionId(localSessionId);
mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs());

if (state)
{
*state = &mStates[i];
}

err = CHIP_NO_ERROR;
break;
}
}

return err;
return mEntries.CreateObject(localSessionId, peerNodeId, peerSessionId, fabric, mTimeSource.GetCurrentMonotonicTimeMs());
}

/**
* Get a secure session given a Node Id.
*
* @param nodeId is the session to find (based on nodeId).
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
*
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindSecureSession(NodeId nodeId, SecureSession * begin)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

if (begin >= iter && begin < &mStates[kMaxSessionCount])
{
iter = begin + 1;
}
void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); }

for (; iter < &mStates[kMaxSessionCount]; iter++)
{
if (!iter->IsInitialized())
{
continue;
}
if (iter->GetPeerNodeId() == nodeId)
{
state = iter;
break;
}
}
return state;
template <typename Function>
bool ForEachSession(Function && function)
{
return mEntries.ForEachActiveObject(std::forward<Function>(function));
}

/**
* Get a secure session given a Node Id and Peer's Encryption Key Id.
*
* @param localSessionId Encryption key ID used by the local node.
* @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start.
*
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId, SecureSession * begin)
SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId)
{
SecureSession * state = nullptr;
SecureSession * iter = &mStates[0];

if (begin >= iter && begin < &mStates[kMaxSessionCount])
{
iter = begin + 1;
}

for (; iter < &mStates[kMaxSessionCount]; iter++)
{
if (!iter->IsInitialized())
SecureSession * result = nullptr;
mEntries.ForEachActiveObject([&](auto session) {
if (session->GetLocalSessionId() == localSessionId)
{
continue;
result = session;
return false;
}
if (iter->GetLocalSessionId() == localSessionId)
{
state = iter;
break;
}
}
return state;
}

/**
* Get the first session that matches the given fabric index.
*
* @param fabric The fabric index to match
*
* @return the session found, nullptr if not found
*/
CHECK_RETURN_VALUE
SecureSession * FindSecureSessionByFabric(FabricIndex fabric)
{
for (auto & state : mStates)
{
if (!state.IsInitialized())
{
continue;
}
if (state.GetFabricIndex() == fabric)
{
return &state;
}
}
return nullptr;
return true;
});
return result;
}

/// Convenience method to mark a session as active
void MarkSessionActive(SecureSession * state) { state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); }

/// Convenience method to expired a session and fired the related callback
template <typename Callback>
void MarkSessionExpired(SecureSession * state, Callback callback)
{
callback(*state);
*state = SecureSession(PeerAddress::Uninitialized());
}

/**
* Iterates through all active sessions and expires any sessions with an idle time
* larger than the given amount.
Expand All @@ -199,30 +103,22 @@ class SecureSessionTable
void ExpireInactiveSessions(uint64_t maxIdleTimeMs, Callback callback)
{
const uint64_t currentTime = mTimeSource.GetCurrentMonotonicTimeMs();

for (size_t i = 0; i < kMaxSessionCount; i++)
{
if (!mStates[i].IsInitialized())
mEntries.ForEachActiveObject([&](auto session) {
if (session->GetLastActivityTimeMs() + maxIdleTimeMs < currentTime)
{
continue; // not an active session
callback(*session);
ReleaseSession(session);
}

uint64_t sessionActiveTime = mStates[i].GetLastActivityTimeMs();
if (sessionActiveTime + maxIdleTimeMs >= currentTime)
{
continue; // not expired
}

MarkSessionExpired(&mStates[i], callback);
}
return true;
});
}

/// Allows access to the underlying time source used for keeping track of session active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
Time::TimeSource<kTimeSource> mTimeSource;
SecureSession mStates[kMaxSessionCount];
BitMapObjectPool<SecureSession, kMaxSessionCount> mEntries;
};

} // namespace Transport
Expand Down
Loading