diff --git a/src/lib/core/Optional.h b/src/lib/core/Optional.h index 6e14e00f506bc5..c94edfb0d7fa30 100644 --- a/src/lib/core/Optional.h +++ b/src/lib/core/Optional.h @@ -29,6 +29,13 @@ namespace chip { +/// An empty class type used to indicate optional type with uninitialized state. +struct NullOptionalType +{ + explicit NullOptionalType() = default; +}; +constexpr NullOptionalType NullOptional{}; + /** * Pairs an object with a boolean value to determine if the object value * is actually valid or not. @@ -38,6 +45,8 @@ class Optional { public: constexpr Optional() : mHasValue(false) {} + constexpr Optional(NullOptionalType) : mHasValue(false) {} + ~Optional() { if (mHasValue) diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 41103b38be6ddb..89d91a455c3823 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -181,7 +181,7 @@ CHIP_ERROR SecureSessionMgr::SendPreparedMessage(SessionHandle session, const En else { auto unauthenticated = session.GetUnauthenticatedSession(); - mUnauthenticatedSessions.MarkSessionActive(unauthenticated.Get()); + mUnauthenticatedSessions.MarkSessionActive(unauthenticated); destination = &unauthenticated->GetPeerAddress(); ChipLogProgress(Inet, "Sending %s msg %p to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", "plaintext", @@ -329,13 +329,14 @@ void SecureSessionMgr::OnMessageReceived(const PeerAddress & peerAddress, System void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) { - Transport::UnauthenticatedSession * session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); - if (session == nullptr) + Optional optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); + if (!optionalSession.HasValue()) { ChipLogError(Inet, "UnauthenticatedSession exhausted"); return; } + Transport::UnauthenticatedSessionHandle session = optionalSession.Value(); SecureSessionMgrDelegate::DuplicateMessage isDuplicate = SecureSessionMgrDelegate::DuplicateMessage::No; // Verify message counter @@ -348,7 +349,7 @@ void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const } VerifyOrDie(err == CHIP_NO_ERROR); - mUnauthenticatedSessions.MarkSessionActive(*session); + mUnauthenticatedSessions.MarkSessionActive(session); PayloadHeader payloadHeader; ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); @@ -357,8 +358,7 @@ void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const if (mCB != nullptr) { - mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(Transport::UnauthenticatedSessionHandle(*session)), - peerAddress, isDuplicate, std::move(msg)); + mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(session), peerAddress, isDuplicate, std::move(msg)); } } diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index 942f05f374ac6d..d1b04fa474a816 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -266,11 +266,8 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress) { - Transport::UnauthenticatedSession * session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); - if (session == nullptr) - return Optional::Missing(); - - return Optional::Value(SessionHandle(Transport::UnauthenticatedSessionHandle(*session))); + Optional session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); + return session.HasValue() ? MakeOptional(session.Value()) : NullOptional; } private: diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 9f03662dcbe413..8219b8202088e8 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -44,7 +44,7 @@ class UnauthenticatedSessionDeleter * @brief * An UnauthenticatedSession stores the binding of TransportAddress, and message counters. */ -class UnauthenticatedSession : public ReferenceCounted +class UnauthenticatedSession : public ReferenceCounted { public: UnauthenticatedSession(const PeerAddress & address) : mPeerAddress(address) {} @@ -82,6 +82,39 @@ template FindOrAllocateEntry(const PeerAddress & address) + { + UnauthenticatedSession * result = FindEntry(address); + if (result != nullptr) + return MakeOptional(*result); + + CHIP_ERROR err = AllocEntry(address, result); + if (err == CHIP_NO_ERROR) + { + return MakeOptional(*result); + } + else + { + return Optional::Missing(); + } + } + + /// Mark a session as active + void MarkSessionActive(UnauthenticatedSessionHandle session) + { + session->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); + } + + /// Allows access to the underlying time source used for keeping track of connection active time + Time::TimeSource & GetTimeSource() { return mTimeSource; } + +private: /** * Allocates a new session out of the internal resource pool. * @@ -125,36 +158,6 @@ class UnauthenticatedSessionTable return result; } - /** - * Get a peer given the peer id. If the peer doesn't exist in the cache, allocate a new entry for it. - * - * @return the peer found or allocated, nullptr if not found and allocate failed. - */ - CHECK_RETURN_VALUE - UnauthenticatedSession * FindOrAllocateEntry(const PeerAddress & address) - { - UnauthenticatedSession * result = FindEntry(address); - if (result != nullptr) - return result; - - CHIP_ERROR err = AllocEntry(address, result); - if (err == CHIP_NO_ERROR) - { - return result; - } - else - { - return nullptr; - } - } - - /// Mark a session as active - void MarkSessionActive(UnauthenticatedSession & entry) { entry.SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); } - - /// Allows access to the underlying time source used for keeping track of connection active time - Time::TimeSource & GetTimeSource() { return mTimeSource; } - -private: UnauthenticatedSession * FindLeastRecentUsedEntry() { UnauthenticatedSession * result = nullptr;