Skip to content

Commit

Permalink
Fix UnauthenticatedSession leak (#10754) (#11325)
Browse files Browse the repository at this point in the history
* Fix UnauthenticatedSession leak

* Resovle comments
  • Loading branch information
kghost authored Nov 3, 2021
1 parent 18bbe0a commit d0c7d97
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 42 deletions.
9 changes: 9 additions & 0 deletions src/lib/core/Optional.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@

namespace chip {

/// An empty class type used to indicate optional type with uninitialized state.
struct NullOptionalType
{
explicit NullOptionalType() = default;
};
constexpr NullOptionalType NullOptional{};

/**
* Pairs an object with a boolean value to determine if the object value
* is actually valid or not.
Expand All @@ -38,6 +45,8 @@ class Optional
{
public:
constexpr Optional() : mHasValue(false) {}
constexpr Optional(NullOptionalType) : mHasValue(false) {}

~Optional()
{
if (mHasValue)
Expand Down
12 changes: 6 additions & 6 deletions src/transport/SecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ CHIP_ERROR SecureSessionMgr::SendPreparedMessage(SessionHandle session, const En
else
{
auto unauthenticated = session.GetUnauthenticatedSession();
mUnauthenticatedSessions.MarkSessionActive(unauthenticated.Get());
mUnauthenticatedSessions.MarkSessionActive(unauthenticated);
destination = &unauthenticated->GetPeerAddress();

ChipLogProgress(Inet, "Sending %s msg %p to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", "plaintext",
Expand Down Expand Up @@ -329,13 +329,14 @@ void SecureSessionMgr::OnMessageReceived(const PeerAddress & peerAddress, System
void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg)
{
Transport::UnauthenticatedSession * session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress);
if (session == nullptr)
Optional<Transport::UnauthenticatedSessionHandle> optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress);
if (!optionalSession.HasValue())
{
ChipLogError(Inet, "UnauthenticatedSession exhausted");
return;
}

Transport::UnauthenticatedSessionHandle session = optionalSession.Value();
SecureSessionMgrDelegate::DuplicateMessage isDuplicate = SecureSessionMgrDelegate::DuplicateMessage::No;

// Verify message counter
Expand All @@ -348,7 +349,7 @@ void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const
}
VerifyOrDie(err == CHIP_NO_ERROR);

mUnauthenticatedSessions.MarkSessionActive(*session);
mUnauthenticatedSessions.MarkSessionActive(session);

PayloadHeader payloadHeader;
ReturnOnFailure(payloadHeader.DecodeAndConsume(msg));
Expand All @@ -357,8 +358,7 @@ void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const

if (mCB != nullptr)
{
mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(Transport::UnauthenticatedSessionHandle(*session)),
peerAddress, isDuplicate, std::move(msg));
mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(session), peerAddress, isDuplicate, std::move(msg));
}
}

Expand Down
7 changes: 2 additions & 5 deletions src/transport/SecureSessionMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,8 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate

Optional<SessionHandle> CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress)
{
Transport::UnauthenticatedSession * session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress);
if (session == nullptr)
return Optional<SessionHandle>::Missing();

return Optional<SessionHandle>::Value(SessionHandle(Transport::UnauthenticatedSessionHandle(*session)));
Optional<Transport::UnauthenticatedSessionHandle> session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress);
return session.HasValue() ? MakeOptional<SessionHandle>(session.Value()) : NullOptional;
}

private:
Expand Down
65 changes: 34 additions & 31 deletions src/transport/UnauthenticatedSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class UnauthenticatedSessionDeleter
* @brief
* An UnauthenticatedSession stores the binding of TransportAddress, and message counters.
*/
class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter>
class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>
{
public:
UnauthenticatedSession(const PeerAddress & address) : mPeerAddress(address) {}
Expand Down Expand Up @@ -82,6 +82,39 @@ template <size_t kMaxConnectionCount, Time::Source kTimeSource = Time::Source::k
class UnauthenticatedSessionTable
{
public:
/**
* Get a session given the peer address. If the session doesn't exist in the cache, allocate a new entry for it.
*
* @return the session found or allocated, nullptr if not found and allocation failed.
*/
CHECK_RETURN_VALUE
Optional<UnauthenticatedSessionHandle> FindOrAllocateEntry(const PeerAddress & address)
{
UnauthenticatedSession * result = FindEntry(address);
if (result != nullptr)
return MakeOptional<UnauthenticatedSessionHandle>(*result);

CHIP_ERROR err = AllocEntry(address, result);
if (err == CHIP_NO_ERROR)
{
return MakeOptional<UnauthenticatedSessionHandle>(*result);
}
else
{
return Optional<UnauthenticatedSessionHandle>::Missing();
}
}

/// Mark a session as active
void MarkSessionActive(UnauthenticatedSessionHandle session)
{
session->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs());
}

/// Allows access to the underlying time source used for keeping track of connection active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
/**
* Allocates a new session out of the internal resource pool.
*
Expand Down Expand Up @@ -125,36 +158,6 @@ class UnauthenticatedSessionTable
return result;
}

/**
* Get a peer given the peer id. If the peer doesn't exist in the cache, allocate a new entry for it.
*
* @return the peer found or allocated, nullptr if not found and allocate failed.
*/
CHECK_RETURN_VALUE
UnauthenticatedSession * FindOrAllocateEntry(const PeerAddress & address)
{
UnauthenticatedSession * result = FindEntry(address);
if (result != nullptr)
return result;

CHIP_ERROR err = AllocEntry(address, result);
if (err == CHIP_NO_ERROR)
{
return result;
}
else
{
return nullptr;
}
}

/// Mark a session as active
void MarkSessionActive(UnauthenticatedSession & entry) { entry.SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); }

/// Allows access to the underlying time source used for keeping track of connection active time
Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
UnauthenticatedSession * FindLeastRecentUsedEntry()
{
UnauthenticatedSession * result = nullptr;
Expand Down

0 comments on commit d0c7d97

Please sign in to comment.