diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 393ae04a2699e3..3b8b5f8c08462f 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -199,6 +199,7 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric ReturnErrorCodeIf(fabric == nullptr, CHIP_ERROR_INVALID_ARGUMENT); err = Init(sessionManager, delegate); + GrabUnauthenticatedSession(exchangeCtxt->GetSessionHandle()); mRole = CryptoContext::SessionRole::kInitiator; @@ -1642,6 +1643,8 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea Protocols::SecureChannel::MsgType msgType = static_cast(payloadHeader.GetMessageType()); SuccessOrExit(err); + GrabUnauthenticatedSession(ec->GetSessionHandle()); + // By default, CHIP_ERROR_INVALID_MESSAGE_TYPE is returned if in the current state // a message handler is not defined for the received message type. err = CHIP_ERROR_INVALID_MESSAGE_TYPE; diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index c513f1d2808c99..c8feed1c74f8dd 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -208,6 +208,7 @@ CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, uint32_t peerSetUp MATTER_TRACE_EVENT_SCOPE("Pair", "PASESession"); ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); CHIP_ERROR err = Init(sessionManager, peerSetUpPINCode, delegate); + GrabUnauthenticatedSession(exchangeCtxt->GetSessionHandle()); SuccessOrExit(err); mRole = CryptoContext::SessionRole::kInitiator; @@ -803,6 +804,8 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl CHIP_ERROR err = ValidateReceivedMessage(exchange, payloadHeader, msg); SuccessOrExit(err); + GrabUnauthenticatedSession(exchange->GetSessionHandle()); + switch (static_cast(payloadHeader.GetMessageType())) { case MsgType::PBKDFParamRequest: diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp index 55d992c1655d83..e818085578559e 100644 --- a/src/protocols/secure_channel/PairingSession.cpp +++ b/src/protocols/secure_channel/PairingSession.cpp @@ -25,10 +25,11 @@ namespace chip { CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager) { - auto handle = sessionManager.AllocateSession(GetSecureSessionType()); + Optional handle = sessionManager.AllocateSession(GetSecureSessionType()); VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY); VerifyOrReturnError(mSecureSessionHolder.GrabPairing(handle.Value()), CHIP_ERROR_INTERNAL); - mSessionManager = &sessionManager; + mSecureSessionRef = handle.Value().ToShared(); + mSessionManager = &sessionManager; return CHIP_NO_ERROR; } @@ -153,7 +154,9 @@ void PairingSession::Clear() mExchangeCtxt = nullptr; } + mUnauthenticatedSessionRef.Release(); mSecureSessionHolder.Release(); + mSecureSessionRef.Release(); mPeerSessionId.ClearValue(); mSessionManager = nullptr; } diff --git a/src/protocols/secure_channel/PairingSession.h b/src/protocols/secure_channel/PairingSession.h index 25e0c9f95bd06b..6d8947217fc1ac 100644 --- a/src/protocols/secure_channel/PairingSession.h +++ b/src/protocols/secure_channel/PairingSession.h @@ -33,6 +33,7 @@ #include #include #include +#include namespace chip { @@ -89,6 +90,8 @@ class DLL_EXPORT PairingSession : public SessionDelegate TLV::TLVWriter & tlvWriter); protected: + void GrabUnauthenticatedSession(const SessionHandle & session) { mUnauthenticatedSessionRef = session.ToShared(); } + /** * Allocate a secure session object from the passed session manager for the * pending session establishment operation. @@ -174,6 +177,8 @@ class DLL_EXPORT PairingSession : public SessionDelegate protected: CryptoContext::SessionRole mRole; + SessionSharedPtr mUnauthenticatedSessionRef; // Hold the unauthenticated session to prevent it from releasing + SessionSharedPtr mSecureSessionRef; // Hold the secure session to prevent it from releasing SessionHolderWithDelegate mSecureSessionHolder; // mSessionManager is set if we actually allocate a secure session, so we // can clean it up later as needed. diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index defde06d387fdc..ff8afecd0db5bc 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -47,6 +47,7 @@ static_library("transport") { "SessionManager.h", "SessionMessageCounter.h", "SessionMessageDelegate.h", + "SessionSharedPtr.h", "TransportMgr.h", "TransportMgrBase.cpp", "TransportMgrBase.h", diff --git a/src/transport/SessionHandle.h b/src/transport/SessionHandle.h index e5087133083968..e2a591d0eb707a 100644 --- a/src/transport/SessionHandle.h +++ b/src/transport/SessionHandle.h @@ -19,6 +19,7 @@ #include #include +#include namespace chip { @@ -46,6 +47,7 @@ class SessionHandle SessionHandle & operator=(SessionHandle &&) = delete; bool operator==(const SessionHandle & that) const { return &mSession.Get() == &that.mSession.Get(); } + SessionSharedPtr ToShared() const { return SessionSharedPtr(mSession.Get()); } Transport::Session * operator->() const { return mSession.operator->(); } diff --git a/src/transport/SessionHolder.cpp b/src/transport/SessionHolder.cpp index f9420a88e7a306..fe64621c9381d9 100644 --- a/src/transport/SessionHolder.cpp +++ b/src/transport/SessionHolder.cpp @@ -31,7 +31,7 @@ SessionHolder::SessionHolder(const SessionHolder & that) : IntrusiveListNodeBase mSession = that.mSession; if (mSession.HasValue()) { - mSession.Value()->AddHolder(*this); + mSession.Value().get().AddHolder(*this); } } @@ -40,7 +40,7 @@ SessionHolder::SessionHolder(SessionHolder && that) : IntrusiveListNodeBase() mSession = that.mSession; if (mSession.HasValue()) { - mSession.Value()->AddHolder(*this); + mSession.Value().get().AddHolder(*this); } that.Release(); @@ -53,7 +53,7 @@ SessionHolder & SessionHolder::operator=(const SessionHolder & that) mSession = that.mSession; if (mSession.HasValue()) { - mSession.Value()->AddHolder(*this); + mSession.Value().get().AddHolder(*this); } return *this; @@ -66,7 +66,7 @@ SessionHolder & SessionHolder::operator=(SessionHolder && that) mSession = that.mSession; if (mSession.HasValue()) { - mSession.Value()->AddHolder(*this); + mSession.Value().get().AddHolder(*this); } that.Release(); @@ -84,7 +84,7 @@ bool SessionHolder::GrabPairing(const SessionHandle & session) if (!session->AsSecureSession()->IsPairing()) return false; - mSession.Emplace(session.mSession); + mSession.Emplace(session.mSession.Get()); session->AddHolder(*this); return true; } @@ -96,7 +96,7 @@ bool SessionHolder::Grab(const SessionHandle & session) if (!session->IsActiveSession()) return false; - mSession.Emplace(session.mSession); + mSession.Emplace(session.mSession.Get()); session->AddHolder(*this); return true; } @@ -105,7 +105,7 @@ void SessionHolder::Release() { if (mSession.HasValue()) { - mSession.Value()->RemoveHolder(*this); + mSession.Value().get().RemoveHolder(*this); mSession.ClearValue(); } } diff --git a/src/transport/SessionHolder.h b/src/transport/SessionHolder.h index 0d9477b142be52..96a651e190ef39 100644 --- a/src/transport/SessionHolder.h +++ b/src/transport/SessionHolder.h @@ -16,6 +16,8 @@ #pragma once +#include + #include #include #include @@ -44,14 +46,14 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase bool Contains(const SessionHandle & session) const { - return mSession.HasValue() && &mSession.Value().Get() == &session.mSession.Get(); + return mSession.HasValue() && &mSession.Value().get() == &session.mSession.Get(); } bool GrabPairing(const SessionHandle & session); // Should be only used inside CASE/PASE pairing. bool Grab(const SessionHandle & session); void Release(); - operator bool() const { return mSession.HasValue(); } + explicit operator bool() const { return mSession.HasValue(); } Optional Get() const { // @@ -60,14 +62,14 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase // // So, construct a new Optional from the underlying Transport::Session reference. // - return mSession.HasValue() ? chip::MakeOptional(mSession.Value().Get()) + return mSession.HasValue() ? chip::MakeOptional(mSession.Value().get()) : chip::Optional::Missing(); } - Transport::Session * operator->() const { return &mSession.Value().Get(); } + Transport::Session * operator->() const { return &mSession.Value().get(); } private: - Optional> mSession; + Optional> mSession; }; // @brief Extends SessionHolder to allow propagate OnSessionReleased event to an extra given destination diff --git a/src/transport/SessionSharedPtr.h b/src/transport/SessionSharedPtr.h new file mode 100644 index 00000000000000..7e387090861771 --- /dev/null +++ b/src/transport/SessionSharedPtr.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace chip { + +namespace Transport { +class Session; +} // namespace Transport + +class SessionHolder; + +/** @brief + * A shared_ptr like smart pointer to manage the lifetime of a Session object. Like a shared_ptr, + * this object can start out not tracking any Session and be attached to a Session there-after. + * The underlying Session is guaranteed to remain active and resident until all references to it from SessionSharedPtr + * instances in the system have gone away, at which point it invokes its custom destructor. + * + * Just because a Session is refcounted does not mean it actually gets destroyed upon reaching a count of 0. + * UnauthenticatedSession and SecureSession have different logic that gets invoked when the count hits 0. + * + * This should really only be used during session setup by the entity setting up the session. + * Once setup, the session should transfer ownership to the SessionManager at which point, + * all clients in the system should only be holding SessionWeakPtrs (SessionWeakPtr doesn't exist yet, but once + * #18399 is complete, SessionHolder will become SessionWeakPtr). + * + * This is copy-constructible. + */ +class SessionSharedPtr +{ +public: + SessionSharedPtr() {} + SessionSharedPtr(Transport::Session & session) : mSession(InPlace, session) {} + + SessionSharedPtr(const SessionSharedPtr &) = default; + SessionSharedPtr & operator=(const SessionSharedPtr &) = default; + + /* + * If we're currently pointing to a valid session, remove ourselves + * as a shared owner of that session. If there are no more shared owners + * on that session, that session MAY be reclaimed. + */ + void Release() { mSession.ClearValue(); } + +private: + Optional> mSession; +}; + +} // namespace chip