From 157867464ca767f399144df4efa21190af5108bc Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Sat, 23 Oct 2021 05:16:54 +0800 Subject: [PATCH] Fix UnauthenticatedSession leak (#10754) * Fix UnauthenticatedSession leak * Resovle comments --- src/lib/core/Optional.h | 9 +++ src/transport/SessionManager.cpp | 12 ++-- src/transport/SessionManager.h | 7 +-- src/transport/UnauthenticatedSessionTable.h | 65 +++++++++++---------- 4 files changed, 51 insertions(+), 42 deletions(-) diff --git a/src/lib/core/Optional.h b/src/lib/core/Optional.h index 6e14e00f506bc5..c94edfb0d7fa30 100644 --- a/src/lib/core/Optional.h +++ b/src/lib/core/Optional.h @@ -29,6 +29,13 @@ namespace chip { +/// An empty class type used to indicate optional type with uninitialized state. +struct NullOptionalType +{ + explicit NullOptionalType() = default; +}; +constexpr NullOptionalType NullOptional{}; + /** * Pairs an object with a boolean value to determine if the object value * is actually valid or not. @@ -38,6 +45,8 @@ class Optional { public: constexpr Optional() : mHasValue(false) {} + constexpr Optional(NullOptionalType) : mHasValue(false) {} + ~Optional() { if (mHasValue) diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 6bf6d5d3553bf4..0fccd619a16e87 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -189,7 +189,7 @@ CHIP_ERROR SessionManager::SendPreparedMessage(SessionHandle session, const Encr else { auto unauthenticated = session.GetUnauthenticatedSession(); - mUnauthenticatedSessions.MarkSessionActive(unauthenticated.Get()); + mUnauthenticatedSessions.MarkSessionActive(unauthenticated); destination = &unauthenticated->GetPeerAddress(); ChipLogProgress(Inet, @@ -339,13 +339,14 @@ void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System:: void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) { - Transport::UnauthenticatedSession * session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); - if (session == nullptr) + Optional optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); + if (!optionalSession.HasValue()) { ChipLogError(Inet, "UnauthenticatedSession exhausted"); return; } + Transport::UnauthenticatedSessionHandle session = optionalSession.Value(); SessionManagerDelegate::DuplicateMessage isDuplicate = SessionManagerDelegate::DuplicateMessage::No; // Verify message counter @@ -357,7 +358,7 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr } VerifyOrDie(err == CHIP_NO_ERROR); - mUnauthenticatedSessions.MarkSessionActive(*session); + mUnauthenticatedSessions.MarkSessionActive(session); PayloadHeader payloadHeader; ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); @@ -374,8 +375,7 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr if (mCB != nullptr) { - mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(Transport::UnauthenticatedSessionHandle(*session)), - peerAddress, isDuplicate, std::move(msg)); + mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(session), peerAddress, isDuplicate, std::move(msg)); } } diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 159251a8202421..43f6e0adbf1a12 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -265,11 +265,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress) { - Transport::UnauthenticatedSession * session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); - if (session == nullptr) - return Optional::Missing(); - - return Optional::Value(SessionHandle(Transport::UnauthenticatedSessionHandle(*session))); + Optional session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); + return session.HasValue() ? MakeOptional(session.Value()) : NullOptional; } private: diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 701b27afe1bcca..70b46f26bf56b2 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -44,7 +44,7 @@ class UnauthenticatedSessionDeleter * @brief * An UnauthenticatedSession stores the binding of TransportAddress, and message counters. */ -class UnauthenticatedSession : public ReferenceCounted +class UnauthenticatedSession : public ReferenceCounted { public: UnauthenticatedSession(const PeerAddress & address) : mPeerAddress(address) { mLocalMessageCounter.Init(); } @@ -82,6 +82,39 @@ template FindOrAllocateEntry(const PeerAddress & address) + { + UnauthenticatedSession * result = FindEntry(address); + if (result != nullptr) + return MakeOptional(*result); + + CHIP_ERROR err = AllocEntry(address, result); + if (err == CHIP_NO_ERROR) + { + return MakeOptional(*result); + } + else + { + return Optional::Missing(); + } + } + + /// Mark a session as active + void MarkSessionActive(UnauthenticatedSessionHandle session) + { + session->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); + } + + /// Allows access to the underlying time source used for keeping track of connection active time + Time::TimeSource & GetTimeSource() { return mTimeSource; } + +private: /** * Allocates a new session out of the internal resource pool. * @@ -125,36 +158,6 @@ class UnauthenticatedSessionTable return result; } - /** - * Get a peer given the peer id. If the peer doesn't exist in the cache, allocate a new entry for it. - * - * @return the peer found or allocated, nullptr if not found and allocate failed. - */ - CHECK_RETURN_VALUE - UnauthenticatedSession * FindOrAllocateEntry(const PeerAddress & address) - { - UnauthenticatedSession * result = FindEntry(address); - if (result != nullptr) - return result; - - CHIP_ERROR err = AllocEntry(address, result); - if (err == CHIP_NO_ERROR) - { - return result; - } - else - { - return nullptr; - } - } - - /// Mark a session as active - void MarkSessionActive(UnauthenticatedSession & entry) { entry.SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); } - - /// Allows access to the underlying time source used for keeping track of connection active time - Time::TimeSource & GetTimeSource() { return mTimeSource; } - -private: UnauthenticatedSession * FindLeastRecentUsedEntry() { UnauthenticatedSession * result = nullptr;