Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change SessionHolder to a weak ref #18397

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric
ReturnErrorCodeIf(fabric == nullptr, CHIP_ERROR_INVALID_ARGUMENT);

err = Init(sessionManager, delegate);
GrabUnauthenticatedSession(exchangeCtxt->GetSessionHandle());

mRole = CryptoContext::SessionRole::kInitiator;

Expand Down Expand Up @@ -1642,6 +1643,8 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea
Protocols::SecureChannel::MsgType msgType = static_cast<Protocols::SecureChannel::MsgType>(payloadHeader.GetMessageType());
SuccessOrExit(err);

GrabUnauthenticatedSession(ec->GetSessionHandle());

// By default, CHIP_ERROR_INVALID_MESSAGE_TYPE is returned if in the current state
// a message handler is not defined for the received message type.
err = CHIP_ERROR_INVALID_MESSAGE_TYPE;
Expand Down
3 changes: 3 additions & 0 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, uint32_t peerSetUp
MATTER_TRACE_EVENT_SCOPE("Pair", "PASESession");
ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT);
CHIP_ERROR err = Init(sessionManager, peerSetUpPINCode, delegate);
GrabUnauthenticatedSession(exchangeCtxt->GetSessionHandle());
SuccessOrExit(err);

mRole = CryptoContext::SessionRole::kInitiator;
Expand Down Expand Up @@ -803,6 +804,8 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl
CHIP_ERROR err = ValidateReceivedMessage(exchange, payloadHeader, msg);
SuccessOrExit(err);

GrabUnauthenticatedSession(exchange->GetSessionHandle());

switch (static_cast<MsgType>(payloadHeader.GetMessageType()))
{
case MsgType::PBKDFParamRequest:
Expand Down
7 changes: 5 additions & 2 deletions src/protocols/secure_channel/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ namespace chip {

CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager)
{
auto handle = sessionManager.AllocateSession(GetSecureSessionType());
Optional<SessionHandle> handle = sessionManager.AllocateSession(GetSecureSessionType());
VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY);
VerifyOrReturnError(mSecureSessionHolder.GrabPairing(handle.Value()), CHIP_ERROR_INTERNAL);
mSessionManager = &sessionManager;
mSecureSessionRef = handle.Value().ToShared();
mSessionManager = &sessionManager;
return CHIP_NO_ERROR;
}

Expand Down Expand Up @@ -153,7 +154,9 @@ void PairingSession::Clear()
mExchangeCtxt = nullptr;
}

mUnauthenticatedSessionRef.Release();
mSecureSessionHolder.Release();
mSecureSessionRef.Release();
mPeerSessionId.ClearValue();
mSessionManager = nullptr;
}
Expand Down
5 changes: 5 additions & 0 deletions src/protocols/secure_channel/PairingSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <protocols/secure_channel/StatusReport.h>
#include <transport/CryptoContext.h>
#include <transport/SecureSession.h>
#include <transport/SessionSharedPtr.h>

namespace chip {

Expand Down Expand Up @@ -89,6 +90,8 @@ class DLL_EXPORT PairingSession : public SessionDelegate
TLV::TLVWriter & tlvWriter);

protected:
void GrabUnauthenticatedSession(const SessionHandle & session) { mUnauthenticatedSessionRef = session.ToShared(); }

/**
* Allocate a secure session object from the passed session manager for the
* pending session establishment operation.
Expand Down Expand Up @@ -174,6 +177,8 @@ class DLL_EXPORT PairingSession : public SessionDelegate

protected:
CryptoContext::SessionRole mRole;
SessionSharedPtr mUnauthenticatedSessionRef; // Hold the unauthenticated session to prevent it from releasing
SessionSharedPtr mSecureSessionRef; // Hold the secure session to prevent it from releasing
SessionHolderWithDelegate mSecureSessionHolder;
// mSessionManager is set if we actually allocate a secure session, so we
// can clean it up later as needed.
Expand Down
1 change: 1 addition & 0 deletions src/transport/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ static_library("transport") {
"SessionManager.h",
"SessionMessageCounter.h",
"SessionMessageDelegate.h",
"SessionSharedPtr.h",
"TransportMgr.h",
"TransportMgrBase.cpp",
"TransportMgrBase.h",
Expand Down
2 changes: 2 additions & 0 deletions src/transport/SessionHandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <access/SubjectDescriptor.h>
#include <lib/support/ReferenceCountedHandle.h>
#include <transport/SessionSharedPtr.h>

namespace chip {

Expand Down Expand Up @@ -46,6 +47,7 @@ class SessionHandle
SessionHandle & operator=(SessionHandle &&) = delete;

bool operator==(const SessionHandle & that) const { return &mSession.Get() == &that.mSession.Get(); }
SessionSharedPtr ToShared() const { return SessionSharedPtr(mSession.Get()); }

Transport::Session * operator->() const { return mSession.operator->(); }

Expand Down
14 changes: 7 additions & 7 deletions src/transport/SessionHolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ SessionHolder::SessionHolder(const SessionHolder & that) : IntrusiveListNodeBase
mSession = that.mSession;
if (mSession.HasValue())
{
mSession.Value()->AddHolder(*this);
mSession.Value().get().AddHolder(*this);
}
}

Expand All @@ -40,7 +40,7 @@ SessionHolder::SessionHolder(SessionHolder && that) : IntrusiveListNodeBase()
mSession = that.mSession;
if (mSession.HasValue())
{
mSession.Value()->AddHolder(*this);
mSession.Value().get().AddHolder(*this);
}

that.Release();
Expand All @@ -53,7 +53,7 @@ SessionHolder & SessionHolder::operator=(const SessionHolder & that)
mSession = that.mSession;
if (mSession.HasValue())
{
mSession.Value()->AddHolder(*this);
mSession.Value().get().AddHolder(*this);
}

return *this;
Expand All @@ -66,7 +66,7 @@ SessionHolder & SessionHolder::operator=(SessionHolder && that)
mSession = that.mSession;
if (mSession.HasValue())
{
mSession.Value()->AddHolder(*this);
mSession.Value().get().AddHolder(*this);
}

that.Release();
Expand All @@ -84,7 +84,7 @@ bool SessionHolder::GrabPairing(const SessionHandle & session)
if (!session->AsSecureSession()->IsPairing())
return false;

mSession.Emplace(session.mSession);
mSession.Emplace(session.mSession.Get());
session->AddHolder(*this);
return true;
}
Expand All @@ -96,7 +96,7 @@ bool SessionHolder::Grab(const SessionHandle & session)
if (!session->IsActiveSession())
return false;

mSession.Emplace(session.mSession);
mSession.Emplace(session.mSession.Get());
session->AddHolder(*this);
return true;
}
Expand All @@ -105,7 +105,7 @@ void SessionHolder::Release()
{
if (mSession.HasValue())
{
mSession.Value()->RemoveHolder(*this);
mSession.Value().get().RemoveHolder(*this);
mSession.ClearValue();
}
}
Expand Down
12 changes: 7 additions & 5 deletions src/transport/SessionHolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#pragma once

#include <functional>

#include <lib/core/Optional.h>
#include <lib/support/IntrusiveList.h>
#include <transport/SessionDelegate.h>
Expand Down Expand Up @@ -44,14 +46,14 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase

bool Contains(const SessionHandle & session) const
{
return mSession.HasValue() && &mSession.Value().Get() == &session.mSession.Get();
return mSession.HasValue() && &mSession.Value().get() == &session.mSession.Get();
}

bool GrabPairing(const SessionHandle & session); // Should be only used inside CASE/PASE pairing.
bool Grab(const SessionHandle & session);
void Release();

operator bool() const { return mSession.HasValue(); }
explicit operator bool() const { return mSession.HasValue(); }
Optional<SessionHandle> Get() const
{
//
Expand All @@ -60,14 +62,14 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase
//
// So, construct a new Optional<SessionHandle> from the underlying Transport::Session reference.
//
return mSession.HasValue() ? chip::MakeOptional<SessionHandle>(mSession.Value().Get())
return mSession.HasValue() ? chip::MakeOptional<SessionHandle>(mSession.Value().get())
: chip::Optional<SessionHandle>::Missing();
}

Transport::Session * operator->() const { return &mSession.Value().Get(); }
Transport::Session * operator->() const { return &mSession.Value().get(); }

private:
Optional<ReferenceCountedHandle<Transport::Session>> mSession;
Optional<std::reference_wrapper<Transport::Session>> mSession;
Copy link
Contributor

@mrjerryjohns mrjerryjohns May 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't all of this just become Transport::Session *mSession? What's the value to saying it's an "optional reference"?

};

// @brief Extends SessionHolder to allow propagate OnSessionReleased event to an extra given destination
Expand Down
66 changes: 66 additions & 0 deletions src/transport/SessionSharedPtr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <lib/core/Optional.h>
#include <lib/support/ReferenceCountedHandle.h>

namespace chip {

namespace Transport {
class Session;
} // namespace Transport

class SessionHolder;

/** @brief
* A shared_ptr like smart pointer to manage the lifetime of a Session object. Like a shared_ptr,
* this object can start out not tracking any Session and be attached to a Session there-after.
* The underlying Session is guaranteed to remain active and resident until all references to it from SessionSharedPtr
* instances in the system have gone away, at which point it invokes its custom destructor.
*
* Just because a Session is refcounted does not mean it actually gets destroyed upon reaching a count of 0.
* UnauthenticatedSession and SecureSession have different logic that gets invoked when the count hits 0.
*
* This should really only be used during session setup by the entity setting up the session.
* Once setup, the session should transfer ownership to the SessionManager at which point,
* all clients in the system should only be holding SessionWeakPtrs (SessionWeakPtr doesn't exist yet, but once
* #18399 is complete, SessionHolder will become SessionWeakPtr).
*
* This is copy-constructible.
*/
class SessionSharedPtr
{
public:
SessionSharedPtr() {}
SessionSharedPtr(Transport::Session & session) : mSession(InPlace, session) {}

SessionSharedPtr(const SessionSharedPtr &) = default;
SessionSharedPtr & operator=(const SessionSharedPtr &) = default;

/*
* If we're currently pointing to a valid session, remove ourselves
* as a shared owner of that session. If there are no more shared owners
* on that session, that session MAY be reclaimed.
*/
void Release() { mSession.ClearValue(); }

private:
Optional<ReferenceCountedHandle<Transport::Session>> mSession;
};

} // namespace chip