Skip to content

Commit

Permalink
Add SessionSharedPtr, Change SessionHolder to a weak ref
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed May 24, 2022
1 parent 352333a commit 8fcb4de
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 15 deletions.
3 changes: 3 additions & 0 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric
ReturnErrorCodeIf(fabric == nullptr, CHIP_ERROR_INVALID_ARGUMENT);

err = Init(sessionManager, delegate);
GrabUnauthenticatedSession(exchangeCtxt->GetSessionHandle());

mRole = CryptoContext::SessionRole::kInitiator;

Expand Down Expand Up @@ -1642,6 +1643,8 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea
Protocols::SecureChannel::MsgType msgType = static_cast<Protocols::SecureChannel::MsgType>(payloadHeader.GetMessageType());
SuccessOrExit(err);

GrabUnauthenticatedSession(ec->GetSessionHandle());

// By default, CHIP_ERROR_INVALID_MESSAGE_TYPE is returned if in the current state
// a message handler is not defined for the received message type.
err = CHIP_ERROR_INVALID_MESSAGE_TYPE;
Expand Down
3 changes: 3 additions & 0 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, uint32_t peerSetUp
MATTER_TRACE_EVENT_SCOPE("Pair", "PASESession");
ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT);
CHIP_ERROR err = Init(sessionManager, peerSetUpPINCode, delegate);
GrabUnauthenticatedSession(exchangeCtxt->GetSessionHandle());
SuccessOrExit(err);

mRole = CryptoContext::SessionRole::kInitiator;
Expand Down Expand Up @@ -803,6 +804,8 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl
CHIP_ERROR err = ValidateReceivedMessage(exchange, payloadHeader, msg);
SuccessOrExit(err);

GrabUnauthenticatedSession(exchange->GetSessionHandle());

switch (static_cast<MsgType>(payloadHeader.GetMessageType()))
{
case MsgType::PBKDFParamRequest:
Expand Down
7 changes: 5 additions & 2 deletions src/protocols/secure_channel/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ namespace chip {

CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager)
{
auto handle = sessionManager.AllocateSession(GetSecureSessionType());
Optional<SessionHandle> handle = sessionManager.AllocateSession(GetSecureSessionType());
VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY);
VerifyOrReturnError(mSecureSessionHolder.GrabPairing(handle.Value()), CHIP_ERROR_INTERNAL);
mSessionManager = &sessionManager;
mSecureSessionRef = handle.Value().ToShared();
mSessionManager = &sessionManager;
return CHIP_NO_ERROR;
}

Expand Down Expand Up @@ -153,7 +154,9 @@ void PairingSession::Clear()
mExchangeCtxt = nullptr;
}

mUnauthenticatedSessionRef.Release();
mSecureSessionHolder.Release();
mSecureSessionRef.Release();
mPeerSessionId.ClearValue();
mSessionManager = nullptr;
}
Expand Down
5 changes: 5 additions & 0 deletions src/protocols/secure_channel/PairingSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <protocols/secure_channel/StatusReport.h>
#include <transport/CryptoContext.h>
#include <transport/SecureSession.h>
#include <transport/SessionSharedPtr.h>

namespace chip {

Expand Down Expand Up @@ -89,6 +90,8 @@ class DLL_EXPORT PairingSession : public SessionDelegate
TLV::TLVWriter & tlvWriter);

protected:
void GrabUnauthenticatedSession(const SessionHandle & session) { mUnauthenticatedSessionRef = session.ToShared(); }

/**
* Allocate a secure session object from the passed session manager for the
* pending session establishment operation.
Expand Down Expand Up @@ -174,6 +177,8 @@ class DLL_EXPORT PairingSession : public SessionDelegate

protected:
CryptoContext::SessionRole mRole;
SessionSharedPtr mUnauthenticatedSessionRef; // Hold the unauthenticated session to prevent it from releasing
SessionSharedPtr mSecureSessionRef; // Hold the secure session to prevent it from releasing
SessionHolderWithDelegate mSecureSessionHolder;
// mSessionManager is set if we actually allocate a secure session, so we
// can clean it up later as needed.
Expand Down
1 change: 1 addition & 0 deletions src/transport/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ static_library("transport") {
"SessionManager.h",
"SessionMessageCounter.h",
"SessionMessageDelegate.h",
"SessionSharedPtr.h",
"TransportMgr.h",
"TransportMgrBase.cpp",
"TransportMgrBase.h",
Expand Down
4 changes: 3 additions & 1 deletion src/transport/SessionHandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <access/SubjectDescriptor.h>
#include <lib/support/ReferenceCountedHandle.h>
#include <transport/SessionSharedPtr.h>

namespace chip {

Expand All @@ -40,12 +41,13 @@ class SessionHandle
SessionHandle(Transport::Session & session) : mSession(session) {}
~SessionHandle() {}

SessionHandle(const SessionHandle &) = delete;
SessionHandle(const SessionHandle &) = default;
SessionHandle operator=(const SessionHandle &) = delete;
SessionHandle(SessionHandle &&) = default;
SessionHandle & operator=(SessionHandle &&) = delete;

bool operator==(const SessionHandle & that) const { return &mSession.Get() == &that.mSession.Get(); }
SessionSharedPtr ToShared() const { return SessionSharedPtr(mSession.Get()); }

Transport::Session * operator->() const { return mSession.operator->(); }

Expand Down
14 changes: 7 additions & 7 deletions src/transport/SessionHolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ SessionHolder::SessionHolder(const SessionHolder & that) : IntrusiveListNodeBase
mSession = that.mSession;
if (mSession.HasValue())
{
mSession.Value()->AddHolder(*this);
mSession.Value().get().AddHolder(*this);
}
}

Expand All @@ -40,7 +40,7 @@ SessionHolder::SessionHolder(SessionHolder && that) : IntrusiveListNodeBase()
mSession = that.mSession;
if (mSession.HasValue())
{
mSession.Value()->AddHolder(*this);
mSession.Value().get().AddHolder(*this);
}

that.Release();
Expand All @@ -53,7 +53,7 @@ SessionHolder & SessionHolder::operator=(const SessionHolder & that)
mSession = that.mSession;
if (mSession.HasValue())
{
mSession.Value()->AddHolder(*this);
mSession.Value().get().AddHolder(*this);
}

return *this;
Expand All @@ -66,7 +66,7 @@ SessionHolder & SessionHolder::operator=(SessionHolder && that)
mSession = that.mSession;
if (mSession.HasValue())
{
mSession.Value()->AddHolder(*this);
mSession.Value().get().AddHolder(*this);
}

that.Release();
Expand All @@ -84,7 +84,7 @@ bool SessionHolder::GrabPairing(const SessionHandle & session)
if (!session->AsSecureSession()->IsPairing())
return false;

mSession.Emplace(session.mSession);
mSession.Emplace(session.mSession.Get());
session->AddHolder(*this);
return true;
}
Expand All @@ -96,7 +96,7 @@ bool SessionHolder::Grab(const SessionHandle & session)
if (!session->IsActiveSession())
return false;

mSession.Emplace(session.mSession);
mSession.Emplace(session.mSession.Get());
session->AddHolder(*this);
return true;
}
Expand All @@ -105,7 +105,7 @@ void SessionHolder::Release()
{
if (mSession.HasValue())
{
mSession.Value()->RemoveHolder(*this);
mSession.Value().get().RemoveHolder(*this);
mSession.ClearValue();
}
}
Expand Down
12 changes: 7 additions & 5 deletions src/transport/SessionHolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#pragma once

#include <functional>

#include <lib/core/Optional.h>
#include <lib/support/IntrusiveList.h>
#include <transport/SessionDelegate.h>
Expand Down Expand Up @@ -44,14 +46,14 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase

bool Contains(const SessionHandle & session) const
{
return mSession.HasValue() && &mSession.Value().Get() == &session.mSession.Get();
return mSession.HasValue() && &mSession.Value().get() == &session.mSession.Get();
}

bool GrabPairing(const SessionHandle & session); // Should be only used inside CASE/PASE pairing.
bool Grab(const SessionHandle & session);
void Release();

operator bool() const { return mSession.HasValue(); }
explicit operator bool() const { return mSession.HasValue(); }
Optional<SessionHandle> Get() const
{
//
Expand All @@ -60,14 +62,14 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase
//
// So, construct a new Optional<SessionHandle> from the underlying Transport::Session reference.
//
return mSession.HasValue() ? chip::MakeOptional<SessionHandle>(mSession.Value().Get())
return mSession.HasValue() ? chip::MakeOptional<SessionHandle>(mSession.Value().get())
: chip::Optional<SessionHandle>::Missing();
}

Transport::Session * operator->() const { return &mSession.Value().Get(); }
Transport::Session * operator->() const { return &mSession.Value().get(); }

private:
Optional<ReferenceCountedHandle<Transport::Session>> mSession;
Optional<std::reference_wrapper<Transport::Session>> mSession;
};

// @brief Extends SessionHolder to allow propagate OnSessionReleased event to an extra given destination
Expand Down
66 changes: 66 additions & 0 deletions src/transport/SessionSharedPtr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <lib/core/Optional.h>
#include <lib/support/ReferenceCountedHandle.h>

namespace chip {

namespace Transport {
class Session;
} // namespace Transport

class SessionHolder;

/** @brief
* A shared_ptr like smart pointer to manage the lifetime of a Session object. Like a shared_ptr,
* this object can start out not tracking any Session and be attached to a Session there-after.
* The underlying Session is guaranteed to remain active and resident until all references to it from SessionSharedPtr
* instances in the system have gone away, at which point it invokes its custom destructor.
*
* Just because a Session is refcounted does not mean it actually gets destroyed upon reaching a count of 0.
* UnauthenticatedSession and SecureSession have different logic that gets invoked when the count hits 0.
*
* This should really only be used during session setup by the entity setting up the session.
* Once setup, the session should transfer ownership to the SessionManager at which point,
* all clients in the system should only be holding SessionWeakPtrs (SessionWeakPtr doesn't exist yet, but once
* #18399 is complete, SessionHolder will become SessionWeakPtr).
*
* This is copy-constructible.
*/
class SessionSharedPtr
{
public:
SessionSharedPtr() {}
SessionSharedPtr(Transport::Session & session) : mSession(InPlace, session) {}

SessionSharedPtr(const SessionSharedPtr &) = default;
SessionSharedPtr & operator=(const SessionSharedPtr &) = default;

/*
* If we're currently pointing to a valid session, remove ourselves
* as a shared owner of that session. If there are no more shared owners
* on that session, that session MAY be reclaimed.
*/
void Release() { mSession.ClearValue(); }

private:
Optional<ReferenceCountedHandle<Transport::Session>> mSession;
};

} // namespace chip

0 comments on commit 8fcb4de

Please sign in to comment.