diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 16577c8dc61a1e..323c010f3fe57c 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -49,43 +49,31 @@ static constexpr uint32_t kUndefinedMessageIndex = UINT32_MAX; class SecureSession { public: - SecureSession() : mPeerAddress(PeerAddress::Uninitialized()) {} - SecureSession(const PeerAddress & addr) : mPeerAddress(addr) {} - SecureSession(PeerAddress && addr) : mPeerAddress(addr) {} + SecureSession(uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric, uint64_t currentTime) : + mPeerNodeId(peerNodeId), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId), mFabric(fabric) + { + SetLastActivityTimeMs(currentTime); + } - SecureSession(SecureSession &&) = default; - SecureSession(const SecureSession &) = default; - SecureSession & operator=(const SecureSession &) = default; - SecureSession & operator=(SecureSession &&) = default; + SecureSession(SecureSession &&) = delete; + SecureSession(const SecureSession &) = delete; + SecureSession & operator=(const SecureSession &) = delete; + SecureSession & operator=(SecureSession &&) = delete; const PeerAddress & GetPeerAddress() const { return mPeerAddress; } PeerAddress & GetPeerAddress() { return mPeerAddress; } void SetPeerAddress(const PeerAddress & address) { mPeerAddress = address; } NodeId GetPeerNodeId() const { return mPeerNodeId; } - void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; } - - uint16_t GetPeerSessionId() const { return mPeerSessionId; } - void SetPeerSessionId(uint16_t id) { mPeerSessionId = id; } - - // TODO: Rename KeyID to SessionID uint16_t GetLocalSessionId() const { return mLocalSessionId; } - void SetLocalSessionId(uint16_t id) { mLocalSessionId = id; } + uint16_t GetPeerSessionId() const { return mPeerSessionId; } + FabricIndex GetFabricIndex() const { return mFabric; } uint64_t GetLastActivityTimeMs() const { return mLastActivityTimeMs; } void SetLastActivityTimeMs(uint64_t value) { mLastActivityTimeMs = value; } CryptoContext & GetCryptoContext() { return mCryptoContext; } - FabricIndex GetFabricIndex() const { return mFabric; } - void SetFabricIndex(FabricIndex fabricIndex) { mFabric = fabricIndex; } - - bool IsInitialized() - { - return (mPeerAddress.IsInitialized() || mPeerNodeId != kUndefinedNodeId || mPeerSessionId != UINT16_MAX || - mLocalSessionId != UINT16_MAX); - } - CHIP_ERROR EncryptBeforeSend(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header, MessageAuthenticationCode & mac) const { @@ -101,14 +89,15 @@ class SecureSession SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; } private: + const NodeId mPeerNodeId; + const uint16_t mLocalSessionId; + const uint16_t mPeerSessionId; + const FabricIndex mFabric; + PeerAddress mPeerAddress; - NodeId mPeerNodeId = kUndefinedNodeId; - uint16_t mPeerSessionId = UINT16_MAX; - uint16_t mLocalSessionId = UINT16_MAX; uint64_t mLastActivityTimeMs = 0; CryptoContext mCryptoContext; SessionMessageCounter mSessionMessageCounter; - FabricIndex mFabric = kUndefinedFabricIndex; }; } // namespace Transport diff --git a/src/transport/SecureSessionTable.h b/src/transport/SecureSessionTable.h index 00b126b5fb5765..01fa082055892f 100644 --- a/src/transport/SecureSessionTable.h +++ b/src/transport/SecureSessionTable.h @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -43,10 +44,10 @@ class SecureSessionTable /** * Allocates a new secure session out of the internal resource pool. * - * @param peerNode represents peer Node's ID - * @param peerSessionId represents the encryption key ID assigned by peer node * @param localSessionId represents the encryption key ID assigned by local node - * @param state [out] will contain the session if one was available. May be null if no return value is desired. + * @param peerNodeId represents peer Node's ID + * @param peerSessionId represents the encryption key ID assigned by peer node + * @param fabric represents fabric ID for the session * * @note the newly created state will have an 'active' time set based on the current time source. * @@ -54,141 +55,44 @@ class SecureSessionTable * has been reached (with CHIP_ERROR_NO_MEMORY). */ CHECK_RETURN_VALUE - CHIP_ERROR CreateNewSecureSession(NodeId peerNode, uint16_t peerSessionId, uint16_t localSessionId, SecureSession ** state) + SecureSession * CreateNewSecureSession(uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric) { - CHIP_ERROR err = CHIP_ERROR_NO_MEMORY; - - if (state) - { - *state = nullptr; - } - - for (size_t i = 0; i < kMaxSessionCount; i++) - { - if (!mStates[i].IsInitialized()) - { - mStates[i] = SecureSession(); - mStates[i].SetPeerNodeId(peerNode); - mStates[i].SetPeerSessionId(peerSessionId); - mStates[i].SetLocalSessionId(localSessionId); - mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); - - if (state) - { - *state = &mStates[i]; - } - - err = CHIP_NO_ERROR; - break; - } - } - - return err; + return mEntries.CreateObject(localSessionId, peerNodeId, peerSessionId, fabric, mTimeSource.GetCurrentMonotonicTimeMs()); } - /** - * Get a secure session given a Node Id. - * - * @param nodeId is the session to find (based on nodeId). - * @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start. - * - * @return the state found, nullptr if not found - */ - CHECK_RETURN_VALUE - SecureSession * FindSecureSession(NodeId nodeId, SecureSession * begin) - { - SecureSession * state = nullptr; - SecureSession * iter = &mStates[0]; - - if (begin >= iter && begin < &mStates[kMaxSessionCount]) - { - iter = begin + 1; - } + void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); } - for (; iter < &mStates[kMaxSessionCount]; iter++) - { - if (!iter->IsInitialized()) - { - continue; - } - if (iter->GetPeerNodeId() == nodeId) - { - state = iter; - break; - } - } - return state; + template + bool ForEachSession(Function && function) + { + return mEntries.ForEachActiveObject(std::forward(function)); } /** * Get a secure session given a Node Id and Peer's Encryption Key Id. * * @param localSessionId Encryption key ID used by the local node. - * @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start. * * @return the state found, nullptr if not found */ CHECK_RETURN_VALUE - SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId, SecureSession * begin) + SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId) { - SecureSession * state = nullptr; - SecureSession * iter = &mStates[0]; - - if (begin >= iter && begin < &mStates[kMaxSessionCount]) - { - iter = begin + 1; - } - - for (; iter < &mStates[kMaxSessionCount]; iter++) - { - if (!iter->IsInitialized()) + SecureSession * result = nullptr; + mEntries.ForEachActiveObject([&](auto session) { + if (session->GetLocalSessionId() == localSessionId) { - continue; + result = session; + return false; } - if (iter->GetLocalSessionId() == localSessionId) - { - state = iter; - break; - } - } - return state; - } - - /** - * Get the first session that matches the given fabric index. - * - * @param fabric The fabric index to match - * - * @return the session found, nullptr if not found - */ - CHECK_RETURN_VALUE - SecureSession * FindSecureSessionByFabric(FabricIndex fabric) - { - for (auto & state : mStates) - { - if (!state.IsInitialized()) - { - continue; - } - if (state.GetFabricIndex() == fabric) - { - return &state; - } - } - return nullptr; + return true; + }); + return result; } /// Convenience method to mark a session as active void MarkSessionActive(SecureSession * state) { state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); } - /// Convenience method to expired a session and fired the related callback - template - void MarkSessionExpired(SecureSession * state, Callback callback) - { - callback(*state); - *state = SecureSession(PeerAddress::Uninitialized()); - } - /** * Iterates through all active sessions and expires any sessions with an idle time * larger than the given amount. @@ -199,22 +103,14 @@ class SecureSessionTable void ExpireInactiveSessions(uint64_t maxIdleTimeMs, Callback callback) { const uint64_t currentTime = mTimeSource.GetCurrentMonotonicTimeMs(); - - for (size_t i = 0; i < kMaxSessionCount; i++) - { - if (!mStates[i].IsInitialized()) + mEntries.ForEachActiveObject([&](auto session) { + if (session->GetLastActivityTimeMs() + maxIdleTimeMs < currentTime) { - continue; // not an active session + callback(*session); + ReleaseSession(session); } - - uint64_t sessionActiveTime = mStates[i].GetLastActivityTimeMs(); - if (sessionActiveTime + maxIdleTimeMs >= currentTime) - { - continue; // not expired - } - - MarkSessionExpired(&mStates[i], callback); - } + return true; + }); } /// Allows access to the underlying time source used for keeping track of session active time @@ -222,7 +118,7 @@ class SecureSessionTable private: Time::TimeSource mTimeSource; - SecureSession mStates[kMaxSessionCount]; + BitMapObjectPool mEntries; }; } // namespace Transport diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index f8aa215273eb33..7f8838ebd68b71 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -103,7 +103,7 @@ void SessionManager::Shutdown() mCB = nullptr; } -CHIP_ERROR SessionManager::PrepareMessage(SessionHandle session, PayloadHeader & payloadHeader, +CHIP_ERROR SessionManager::PrepareMessage(SessionHandle sessionHandle, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, EncryptedPacketBufferHandle & preparedMessage) { PacketHeader packetHeader; @@ -115,26 +115,26 @@ CHIP_ERROR SessionManager::PrepareMessage(SessionHandle session, PayloadHeader & #if CHIP_PROGRESS_LOGGING NodeId destination; #endif // CHIP_PROGRESS_LOGGING - if (session.IsSecure()) + if (sessionHandle.IsSecure()) { - SecureSession * state = GetSecureSession(session); - if (state == nullptr) + SecureSession * session = GetSecureSession(sessionHandle); + if (session == nullptr) { return CHIP_ERROR_NOT_CONNECTED; } - MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *state); - ReturnErrorOnFailure(SecureMessageCodec::Encrypt(state, payloadHeader, packetHeader, message, counter)); + MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *session); + ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message, counter)); #if CHIP_PROGRESS_LOGGING - destination = state->GetPeerNodeId(); + destination = session->GetPeerNodeId(); #endif // CHIP_PROGRESS_LOGGING } else { ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); - MessageCounter & counter = session.GetUnauthenticatedSession()->GetLocalMessageCounter(); + MessageCounter & counter = sessionHandle.GetUnauthenticatedSession()->GetLocalMessageCounter(); uint32_t messageCounter = counter.Value(); ReturnErrorOnFailure(counter.Advance()); @@ -149,7 +149,7 @@ CHIP_ERROR SessionManager::PrepareMessage(SessionHandle session, PayloadHeader & "Prepared %s message %p to 0x" ChipLogFormatX64 " of type " ChipLogFormatMessageType " and protocolId " ChipLogFormatProtocolId " on exchange " ChipLogFormatExchangeId " with MessageCounter:" ChipLogFormatMessageCounter ".", - session.IsSecure() ? "encrypted" : "plaintext", &preparedMessage, ChipLogValueX64(destination), + sessionHandle.IsSecure() ? "encrypted" : "plaintext", &preparedMessage, ChipLogValueX64(destination), payloadHeader.GetMessageType(), ChipLogValueProtocolId(payloadHeader.GetProtocolID()), ChipLogValueExchangeIdFromSentHeader(payloadHeader), packetHeader.GetMessageCounter()); @@ -159,37 +159,37 @@ CHIP_ERROR SessionManager::PrepareMessage(SessionHandle session, PayloadHeader & return CHIP_NO_ERROR; } -CHIP_ERROR SessionManager::SendPreparedMessage(SessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) +CHIP_ERROR SessionManager::SendPreparedMessage(SessionHandle sessionHandle, const EncryptedPacketBufferHandle & preparedMessage) { VerifyOrReturnError(mState == State::kInitialized, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(!preparedMessage.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); const Transport::PeerAddress * destination; - if (session.IsSecure()) + if (sessionHandle.IsSecure()) { // Find an active connection to the specified peer node - SecureSession * state = GetSecureSession(session); - if (state == nullptr) + SecureSession * session = GetSecureSession(sessionHandle); + if (session == nullptr) { ChipLogError(Inet, "Secure transport could not find a valid PeerConnection"); return CHIP_ERROR_NOT_CONNECTED; } // This marks any connection where we send data to as 'active' - mPeerConnections.MarkSessionActive(state); + mSecureSessions.MarkSessionActive(session); - destination = &state->GetPeerAddress(); + destination = &session->GetPeerAddress(); ChipLogProgress(Inet, "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to 0x" ChipLogFormatX64 " at monotonic time: %" PRId64 " msec", - "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(), ChipLogValueX64(state->GetPeerNodeId()), - System::SystemClock().GetMonotonicMilliseconds64().count()); + "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(), + ChipLogValueX64(session->GetPeerNodeId()), System::SystemClock().GetMonotonicMilliseconds64().count()); } else { - auto unauthenticated = session.GetUnauthenticatedSession(); + auto unauthenticated = sessionHandle.GetUnauthenticatedSession(); mUnauthenticatedSessions.MarkSessionActive(unauthenticated); destination = &unauthenticated->GetPeerAddress(); @@ -215,44 +215,39 @@ CHIP_ERROR SessionManager::SendPreparedMessage(SessionHandle session, const Encr } } -void SessionManager::ExpirePairing(SessionHandle session) +void SessionManager::ExpirePairing(SessionHandle sessionHandle) { - SecureSession * state = GetSecureSession(session); - if (state != nullptr) + SecureSession * session = GetSecureSession(sessionHandle); + if (session != nullptr) { - mPeerConnections.MarkSessionExpired(state, - [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); + HandleConnectionExpired(*session); + mSecureSessions.ReleaseSession(session); } } void SessionManager::ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric) { - SecureSession * state = mPeerConnections.FindSecureSession(peerNodeId, nullptr); - while (state != nullptr) - { - if (fabric == state->GetFabricIndex()) + mSecureSessions.ForEachSession([&](auto session) { + if (session->GetPeerNodeId() == peerNodeId && session->GetFabricIndex() == fabric) { - mPeerConnections.MarkSessionExpired( - state, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); - state = mPeerConnections.FindSecureSession(peerNodeId, nullptr); + HandleConnectionExpired(*session); + mSecureSessions.ReleaseSession(session); } - else - { - state = mPeerConnections.FindSecureSession(peerNodeId, state); - } - } + return true; + }); } void SessionManager::ExpireAllPairingsForFabric(FabricIndex fabric) { ChipLogDetail(Inet, "Expiring all connections for fabric %d!!", fabric); - SecureSession * state = mPeerConnections.FindSecureSessionByFabric(fabric); - while (state != nullptr) - { - mPeerConnections.MarkSessionExpired(state, - [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); - state = mPeerConnections.FindSecureSessionByFabric(fabric); - } + mSecureSessions.ForEachSession([&](auto session) { + if (session->GetFabricIndex() == fabric) + { + HandleConnectionExpired(*session); + mSecureSessions.ReleaseSession(session); + } + return true; + }); } CHIP_ERROR SessionManager::NewPairing(const Optional & peerAddr, NodeId peerNodeId, @@ -260,30 +255,27 @@ CHIP_ERROR SessionManager::NewPairing(const Optional & p { uint16_t peerSessionId = pairing->GetPeerSessionId(); uint16_t localSessionId = pairing->GetLocalSessionId(); - SecureSession * state = mPeerConnections.FindSecureSessionByLocalKey(localSessionId, nullptr); + SecureSession * session = mSecureSessions.FindSecureSessionByLocalKey(localSessionId); // Find any existing connection with the same local key ID - if (state) + if (session) { - mPeerConnections.MarkSessionExpired(state, - [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); + HandleConnectionExpired(*session); + mSecureSessions.ReleaseSession(session); } ChipLogDetail(Inet, "New secure session created for device 0x" ChipLogFormatX64 ", key %d!!", ChipLogValueX64(peerNodeId), peerSessionId); - state = nullptr; - ReturnErrorOnFailure(mPeerConnections.CreateNewSecureSession(peerNodeId, peerSessionId, localSessionId, &state)); - ReturnErrorCodeIf(state == nullptr, CHIP_ERROR_NO_MEMORY); - - state->SetFabricIndex(fabric); + session = mSecureSessions.CreateNewSecureSession(localSessionId, peerNodeId, peerSessionId, fabric); + ReturnErrorCodeIf(session == nullptr, CHIP_ERROR_NO_MEMORY); if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any) { - state->SetPeerAddress(peerAddr.Value()); + session->SetPeerAddress(peerAddr.Value()); } else if (peerAddr.HasValue() && peerAddr.Value().GetTransportType() == Transport::Type::kBle) { - state->SetPeerAddress(peerAddr.Value()); + session->SetPeerAddress(peerAddr.Value()); } else if (peerAddr.HasValue() && (peerAddr.Value().GetTransportType() == Transport::Type::kTcp || @@ -292,12 +284,13 @@ CHIP_ERROR SessionManager::NewPairing(const Optional & p return CHIP_ERROR_INVALID_ARGUMENT; } - ReturnErrorOnFailure(pairing->DeriveSecureSession(state->GetCryptoContext(), direction)); + ReturnErrorOnFailure(pairing->DeriveSecureSession(session->GetCryptoContext(), direction)); if (mCB != nullptr) { - state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(pairing->GetPeerCounter()); - mCB->OnNewConnection(SessionHandle(state->GetPeerNodeId(), state->GetLocalSessionId(), state->GetPeerSessionId(), fabric)); + session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(pairing->GetPeerCounter()); + mCB->OnNewConnection( + SessionHandle(session->GetPeerNodeId(), session->GetLocalSessionId(), session->GetPeerSessionId(), fabric)); } return CHIP_NO_ERROR; @@ -390,7 +383,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea { CHIP_ERROR err = CHIP_NO_ERROR; - SecureSession * state = mPeerConnections.FindSecureSessionByLocalKey(packetHeader.GetSessionId(), nullptr); + SecureSession * session = mSecureSessions.FindSecureSessionByLocalKey(packetHeader.GetSessionId()); PayloadHeader payloadHeader; @@ -398,14 +391,14 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea VerifyOrExit(!msg.IsNull(), ChipLogError(Inet, "Secure transport received NULL packet, discarding")); - if (state == nullptr) + if (session == nullptr) { ChipLogError(Inet, "Data received on an unknown connection (%d). Dropping it!!", packetHeader.GetSessionId()); ExitNow(err = CHIP_ERROR_KEY_NOT_FOUND_FROM_PEER); } // Decrypt and verify the message before message counter verification or any further processing. - VerifyOrExit(CHIP_NO_ERROR == SecureMessageCodec::Decrypt(state, payloadHeader, packetHeader, msg), + VerifyOrExit(CHIP_NO_ERROR == SecureMessageCodec::Decrypt(session, payloadHeader, packetHeader, msg), ChipLogError(Inet, "Secure transport received message, but failed to decode/authenticate it, discarding")); // Verify message counter @@ -415,14 +408,14 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea } else { - if (!state->GetSessionMessageCounter().GetPeerMessageCounter().IsSynchronized()) + if (!session->GetSessionMessageCounter().GetPeerMessageCounter().IsSynchronized()) { // Queue and start message sync procedure err = mMessageCounterManager->QueueReceivedMessageAndStartSync( packetHeader, - SessionHandle(state->GetPeerNodeId(), state->GetLocalSessionId(), state->GetPeerSessionId(), - state->GetFabricIndex()), - state, peerAddress, std::move(msg)); + SessionHandle(session->GetPeerNodeId(), session->GetLocalSessionId(), session->GetPeerSessionId(), + session->GetFabricIndex()), + session, peerAddress, std::move(msg)); if (err != CHIP_NO_ERROR) { @@ -439,7 +432,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea return; } - err = state->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageCounter()); + err = session->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageCounter()); if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED) { isDuplicate = SessionManagerDelegate::DuplicateMessage::Yes; @@ -452,7 +445,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea SuccessOrExit(err); } - mPeerConnections.MarkSessionActive(state); + mSecureSessions.MarkSessionActive(session); if (isDuplicate == SessionManagerDelegate::DuplicateMessage::Yes && !payloadHeader.NeedsAck()) { @@ -474,22 +467,22 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea } else { - state->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); + session->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); } // TODO: once mDNS address resolution is available reconsider if this is required // This updates the peer address once a packet is received from a new address // and serves as a way to auto-detect peer changing IPs. - if (state->GetPeerAddress() != peerAddress) + if (session->GetPeerAddress() != peerAddress) { - state->SetPeerAddress(peerAddress); + session->SetPeerAddress(peerAddress); } if (mCB != nullptr) { - SessionHandle session(state->GetPeerNodeId(), state->GetLocalSessionId(), state->GetPeerSessionId(), - state->GetFabricIndex()); - mCB->OnMessageReceived(packetHeader, payloadHeader, session, peerAddress, isDuplicate, std::move(msg)); + SessionHandle sessionHandle(session->GetPeerNodeId(), session->GetLocalSessionId(), session->GetPeerSessionId(), + session->GetFabricIndex()); + mCB->OnMessageReceived(packetHeader, payloadHeader, sessionHandle, peerAddress, isDuplicate, std::move(msg)); } exit: @@ -542,8 +535,8 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade if (mCB != nullptr) { // TODO: Update Session Handle for Group messages. - // SessionHandle session(state->GetPeerNodeId(), state->GetLocalSessionId(), state->GetPeerSessionId(), - // state->GetFabricIndex()); + // SessionHandle session(session->GetPeerNodeId(), session->GetLocalSessionId(), session->GetPeerSessionId(), + // session->GetFabricIndex()); // mCB->OnMessageReceived(packetHeader, payloadHeader, nullptr, peerAddress, isDuplicate, std::move(msg)); } @@ -554,18 +547,18 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade } } -void SessionManager::HandleConnectionExpired(const Transport::SecureSession & state) +void SessionManager::HandleConnectionExpired(const Transport::SecureSession & session) { ChipLogDetail(Inet, "Marking old secure session for device 0x" ChipLogFormatX64 " as expired", - ChipLogValueX64(state.GetPeerNodeId())); + ChipLogValueX64(session.GetPeerNodeId())); if (mCB != nullptr) { - mCB->OnConnectionExpired( - SessionHandle(state.GetPeerNodeId(), state.GetLocalSessionId(), state.GetPeerSessionId(), state.GetFabricIndex())); + mCB->OnConnectionExpired(SessionHandle(session.GetPeerNodeId(), session.GetLocalSessionId(), session.GetPeerSessionId(), + session.GetFabricIndex())); } - mTransportMgr->Disconnect(state.GetPeerAddress()); + mTransportMgr->Disconnect(session.GetPeerAddress()); } void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param) @@ -574,7 +567,7 @@ void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param) #if CHIP_CONFIG_SESSION_REKEYING // TODO(#2279): session expiration is currently disabled until rekeying is supported // the #ifdef should be removed after that. - mgr->mPeerConnections.ExpireInactiveSessions( + mgr->mSecureSessions.ExpireInactiveSessions( CHIP_PEER_CONNECTION_TIMEOUT_MS, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); #endif mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer @@ -584,7 +577,7 @@ SecureSession * SessionManager::GetSecureSession(SessionHandle session) { if (session.mLocalSessionId.HasValue()) { - return mPeerConnections.FindSecureSessionByLocalKey(session.mLocalSessionId.Value(), nullptr); + return mSecureSessions.FindSecureSessionByLocalKey(session.mLocalSessionId.Value()); } else { @@ -594,10 +587,18 @@ SecureSession * SessionManager::GetSecureSession(SessionHandle session) SessionHandle SessionManager::FindSecureSessionForNode(NodeId peerNodeId) { - SecureSession * session = mPeerConnections.FindSecureSession(peerNodeId, nullptr); - VerifyOrDie(session != nullptr); - return SessionHandle(session->GetPeerNodeId(), session->GetLocalSessionId(), session->GetPeerSessionId(), - session->GetFabricIndex()); + SecureSession * found = nullptr; + mSecureSessions.ForEachSession([&](auto session) { + if (session->GetPeerNodeId() == peerNodeId) + { + found = session; + return false; + } + return true; + }); + + VerifyOrDie(found != nullptr); + return SessionHandle(found->GetPeerNodeId(), found->GetLocalSessionId(), found->GetPeerSessionId(), found->GetFabricIndex()); } } // namespace chip diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 5fe4cbd7797ccb..0248623ce58720 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -289,8 +289,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate System::Layer * mSystemLayer = nullptr; Transport::UnauthenticatedSessionTable mUnauthenticatedSessions; - Transport::SecureSessionTable mPeerConnections; // < Active connections to other peers - State mState; // < Initialization state of the object + Transport::SecureSessionTable mSecureSessions; // < Active connections to other peers + State mState; // < Initialization state of the object SessionManagerDelegate * mCB = nullptr; TransportMgrBase * mTransportMgr = nullptr; diff --git a/src/transport/tests/TestPeerConnections.cpp b/src/transport/tests/TestPeerConnections.cpp index 55ee2b2665b5d3..128f7fc4e9d7b7 100644 --- a/src/transport/tests/TestPeerConnections.cpp +++ b/src/transport/tests/TestPeerConnections.cpp @@ -53,87 +53,43 @@ const NodeId kPeer3NodeId = 81; void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext) { - CHIP_ERROR err; SecureSession * statePtr; SecureSessionTable<2, Time::Source::kTest> connections; connections.GetTimeSource().SetCurrentMonotonicTimeMs(100); // Node ID 1, peer key 1, local key 2 - err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, nullptr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + statePtr = connections.CreateNewSecureSession(2, kPeer1NodeId, 1, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr != nullptr); // Node ID 2, peer key 3, local key 4 - err = connections.CreateNewSecureSession(kPeer2NodeId, 3, 4, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + statePtr = connections.CreateNewSecureSession(4, kPeer2NodeId, 3, 0 /* fabricIndex */); NL_TEST_ASSERT(inSuite, statePtr != nullptr); NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer2NodeId); NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTimeMs() == 100); // Insufficient space for new connections. Object is max size 2 - err = connections.CreateNewSecureSession(kPeer3NodeId, 5, 6, &statePtr); - NL_TEST_ASSERT(inSuite, err != CHIP_NO_ERROR); -} - -void TestFindByNodeId(nlTestSuite * inSuite, void * inContext) -{ - CHIP_ERROR err; - SecureSession * statePtr; - SecureSessionTable<3, Time::Source::kTest> connections; - - // Node ID 1, peer key 1, local key 2 - err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - statePtr->SetPeerAddress(kPeer1Addr); - - // Node ID 2, peer key 3, local key 4 - err = connections.CreateNewSecureSession(kPeer2NodeId, 3, 4, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - statePtr->SetPeerAddress(kPeer2Addr); - - // Same Node ID 1, peer key 5, local key 6 - err = connections.CreateNewSecureSession(kPeer1NodeId, 5, 6, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - statePtr->SetPeerAddress(kPeer3Addr); - - NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSession(kPeer1NodeId, nullptr)); - char buf[100]; - statePtr->GetPeerAddress().ToString(buf); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer1Addr); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer1NodeId); - - NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSession(kPeer1NodeId, statePtr)); - statePtr->GetPeerAddress().ToString(buf); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer3Addr); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer1NodeId); - - NL_TEST_ASSERT(inSuite, (statePtr = connections.FindSecureSession(kPeer1NodeId, statePtr)) == nullptr); - - NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSession(kPeer2NodeId, nullptr)); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer2Addr); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer2NodeId); - - NL_TEST_ASSERT(inSuite, !connections.FindSecureSession(kPeer3NodeId, nullptr)); + statePtr = connections.CreateNewSecureSession(6, kPeer3NodeId, 5, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr == nullptr); } void TestFindByKeyId(nlTestSuite * inSuite, void * inContext) { - CHIP_ERROR err; SecureSession * statePtr; SecureSessionTable<2, Time::Source::kTest> connections; // Node ID 1, peer key 1, local key 2 - err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + statePtr = connections.CreateNewSecureSession(2, kPeer1NodeId, 1, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr != nullptr); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(1, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(1)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2)); // Node ID 2, peer key 3, local key 4 - err = connections.CreateNewSecureSession(kPeer2NodeId, 3, 4, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + statePtr = connections.CreateNewSecureSession(4, kPeer2NodeId, 3, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr != nullptr); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4)); } struct ExpiredCallInfo @@ -145,7 +101,6 @@ struct ExpiredCallInfo void TestExpireConnections(nlTestSuite * inSuite, void * inContext) { - CHIP_ERROR err; ExpiredCallInfo callInfo; SecureSession * statePtr; SecureSessionTable<2, Time::Source::kTest> connections; @@ -153,20 +108,20 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) connections.GetTimeSource().SetCurrentMonotonicTimeMs(100); // Node ID 1, peer key 1, local key 2 - err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + statePtr = connections.CreateNewSecureSession(2, kPeer1NodeId, 1, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr != nullptr); statePtr->SetPeerAddress(kPeer1Addr); connections.GetTimeSource().SetCurrentMonotonicTimeMs(200); // Node ID 2, peer key 3, local key 4 - err = connections.CreateNewSecureSession(kPeer2NodeId, 3, 4, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + statePtr = connections.CreateNewSecureSession(4, kPeer2NodeId, 3, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr != nullptr); statePtr->SetPeerAddress(kPeer2Addr); // cannot add before expiry connections.GetTimeSource().SetCurrentMonotonicTimeMs(300); - err = connections.CreateNewSecureSession(kPeer3NodeId, 5, 6, &statePtr); - NL_TEST_ASSERT(inSuite, err != CHIP_NO_ERROR); + statePtr = connections.CreateNewSecureSession(6, kPeer3NodeId, 5, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr == nullptr); // at time 300, this expires ip addr 1 connections.ExpireInactiveSessions(150, [&callInfo](const SecureSession & state) { @@ -177,17 +132,17 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer1NodeId); NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer1Addr); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2)); // now that the connections were expired, we can add peer3 connections.GetTimeSource().SetCurrentMonotonicTimeMs(300); // Node ID 3, peer key 5, local key 6 - err = connections.CreateNewSecureSession(kPeer3NodeId, 5, 6, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + statePtr = connections.CreateNewSecureSession(6, kPeer3NodeId, 5, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr != nullptr); statePtr->SetPeerAddress(kPeer3Addr); connections.GetTimeSource().SetCurrentMonotonicTimeMs(400); - NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSessionByLocalKey(4, nullptr)); + NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSessionByLocalKey(4)); connections.MarkSessionActive(statePtr); NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTimeMs() == connections.GetTimeSource().GetCurrentMonotonicTimeMs()); @@ -208,16 +163,16 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer3NodeId); NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer3Addr); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6)); // Node ID 1, peer key 1, local key 2 - err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, &statePtr); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6, nullptr)); + statePtr = connections.CreateNewSecureSession(2, kPeer1NodeId, 1, 0 /* fabricIndex */); + NL_TEST_ASSERT(inSuite, statePtr != nullptr); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6)); // peer 1 and 2 are active connections.GetTimeSource().SetCurrentMonotonicTimeMs(1000); @@ -228,9 +183,9 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) callInfo.lastCallPeerAddress = state.GetPeerAddress(); }); NL_TEST_ASSERT(inSuite, callInfo.callCount == 2); // everything expired - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6)); } } // namespace @@ -239,7 +194,6 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) static const nlTest sTests[] = { NL_TEST_DEF("BasicFunctionality", TestBasicFunctionality), - NL_TEST_DEF("FindByNodeId", TestFindByNodeId), NL_TEST_DEF("FindByKeyId", TestFindByKeyId), NL_TEST_DEF("ExpireConnections", TestExpireConnections), NL_TEST_SENTINEL()