Skip to content

Commit

Permalink
Use refcounter for secure session (#17599)
Browse files Browse the repository at this point in the history
* Use refcounter for secure session

* Address comments: SessionHolder::Grab/GrabPairing return false on failure
  • Loading branch information
kghost authored and pull[bot] committed May 31, 2022
1 parent 81e0d6b commit 639801d
Show file tree
Hide file tree
Showing 18 changed files with 261 additions and 301 deletions.
22 changes: 12 additions & 10 deletions src/app/OperationalDeviceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ bool OperationalDeviceProxy::AttachToExistingSecureSession()
ScopedNodeId peerNodeId(mPeerId.GetNodeId(), mFabricInfo->GetFabricIndex());
auto sessionHandle =
mInitParams.sessionManager->FindSecureSessionForNode(peerNodeId, MakeOptional(Transport::SecureSession::Type::kCASE));
if (sessionHandle.HasValue())
{
ChipLogProgress(Controller, "Found an existing secure session to [" ChipLogFormatX64 "-" ChipLogFormatX64 "]!",
ChipLogValueX64(mPeerId.GetCompressedFabricId()), ChipLogValueX64(mPeerId.GetNodeId()));
mDeviceAddress = sessionHandle.Value()->AsSecureSession()->GetPeerAddress();
mSecureSession.Grab(sessionHandle.Value());
return true;
}
if (!sessionHandle.HasValue())
return false;

return false;
ChipLogProgress(Controller, "Found an existing secure session to [" ChipLogFormatX64 "-" ChipLogFormatX64 "]!",
ChipLogValueX64(mPeerId.GetCompressedFabricId()), ChipLogValueX64(mPeerId.GetNodeId()));
mDeviceAddress = sessionHandle.Value()->AsSecureSession()->GetPeerAddress();
if (!mSecureSession.Grab(sessionHandle.Value()))
return false;

return true;
}

void OperationalDeviceProxy::Connect(Callback::Callback<OnDeviceConnected> * onConnection,
Expand Down Expand Up @@ -305,7 +305,9 @@ void OperationalDeviceProxy::OnSessionEstablished(const SessionHandle & session)
VerifyOrReturn(mState != State::Uninitialized,
ChipLogError(Controller, "HandleCASEConnected was called while the device was not initialized"));

mSecureSession.Grab(session);
if (!mSecureSession.Grab(session))
return; // Got an invalid session, do not change any state

MoveToState(State::SecureConnected);
DequeueConnectionCallbacks(CHIP_NO_ERROR);

Expand Down
7 changes: 6 additions & 1 deletion src/controller/CommissioneeDeviceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ CHIP_ERROR CommissioneeDeviceProxy::UpdateDeviceData(const Transport::PeerAddres
CHIP_ERROR CommissioneeDeviceProxy::SetConnected(const SessionHandle & session)
{
VerifyOrReturnError(mState == ConnectionState::Connecting, CHIP_ERROR_INCORRECT_STATE);
if (!mSecureSession.Grab(session))
{
mState = ConnectionState::NotConnected;
return CHIP_ERROR_INTERNAL;
}

mState = ConnectionState::SecureConnected;
mSecureSession.Grab(session);
return CHIP_NO_ERROR;
}

Expand Down
17 changes: 11 additions & 6 deletions src/lib/core/ReferenceCounted.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,26 @@ class DeleteDeletor
static void Release(T * obj) { chip::Platform::Delete(obj); }
};

template <class T>
class NoopDeletor
{
public:
static void Release(T * obj) {}
};

/**
* A reference counted object maintains a count of usages and when the usage
* count drops to 0, it deletes itself.
*/
template <class Subclass, class Deletor = DeleteDeletor<Subclass>, int kInitRefCount = 1>
template <class Subclass, class Deletor = DeleteDeletor<Subclass>, int kInitRefCount = 1, typename CounterType = uint32_t>
class ReferenceCounted
{
public:
using count_type = uint32_t;

/** Adds one to the usage count of this class */
Subclass * Retain()
{
VerifyOrDie(!kInitRefCount || mRefCount > 0);
VerifyOrDie(mRefCount < std::numeric_limits<count_type>::max());
VerifyOrDie(mRefCount < std::numeric_limits<CounterType>::max());
++mRefCount;

return static_cast<Subclass *>(this);
Expand All @@ -71,10 +76,10 @@ class ReferenceCounted
}

/** Get the current reference counter value */
count_type GetReferenceCount() const { return mRefCount; }
CounterType GetReferenceCount() const { return mRefCount; }

private:
count_type mRefCount = kInitRefCount;
CounterType mRefCount = kInitRefCount;
};

} // namespace chip
5 changes: 4 additions & 1 deletion src/messaging/ExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ CHIP_ERROR ExchangeManager::Shutdown()

ExchangeContext * ExchangeManager::NewContext(const SessionHandle & session, ExchangeDelegate * delegate)
{
// Disallow creating exchange on an inactive session
VerifyOrReturnError(session->IsActiveSession(), nullptr);
return mContextPool.CreateObject(this, mNextExchangeId++, session, true, delegate);
}

Expand Down Expand Up @@ -230,10 +232,11 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const
packetHeader.GetDestinationGroupId().Value());
}

// Do not handle unsolicited messages on a inactive session.
// If it's not a duplicate message, search for an unsolicited message handler if it is marked as being sent by an initiator.
// Since we didn't find an existing exchange that matches the message, it must be an unsolicited message. However all
// unsolicited messages must be marked as being from an initiator.
if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsInitiator())
if (session->IsActiveSession() && !msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsInitiator())
{
// Search for an unsolicited message handler that can handle the message. Prefer handlers that can explicitly
// handle the message type over handlers that handle all messages for a profile.
Expand Down
21 changes: 4 additions & 17 deletions src/protocols/secure_channel/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ namespace chip {

CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager)
{
auto handle = sessionManager.AllocateSession();
auto handle = sessionManager.AllocateSession(GetSecureSessionType());
VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY);
mSecureSessionHolder.Grab(handle.Value());
VerifyOrReturnError(mSecureSessionHolder.GrabPairing(handle.Value()), CHIP_ERROR_INTERNAL);
mSessionManager = &sessionManager;
return CHIP_NO_ERROR;
}
Expand All @@ -48,8 +48,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress &

// Call Activate last, otherwise errors on anything after would lead to
// a partially valid session.
secureSession->Activate(GetSecureSessionType(), GetLocalScopedNodeId(), GetPeer(), GetPeerCATs(), peerSessionId,
mRemoteMRPConfig);
secureSession->Activate(GetLocalScopedNodeId(), GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig);

ChipLogDetail(Inet, "New secure session created for device " ChipLogFormatScopedNodeId ", LSID:%d PSID:%d!",
ChipLogValueScopedNodeId(GetPeer()), secureSession->GetLocalSessionId(), peerSessionId);
Expand Down Expand Up @@ -154,19 +153,7 @@ void PairingSession::Clear()
mExchangeCtxt = nullptr;
}

if (mSecureSessionHolder)
{
auto session = mSecureSessionHolder.Get();
// Call Release before ExpirePairing because we don't want to receive OnSessionReleased() event here
mSecureSessionHolder.Release();
if (!session.Value()->AsSecureSession()->IsActiveSession() && mSessionManager != nullptr)
{
// Make sure to clean up our pending session, since we're the only
// ones who have access to it do do so.
mSessionManager->ExpirePairing(session.Value());
}
}

mSecureSessionHolder.Release();
mPeerSessionId.ClearValue();
mSessionManager = nullptr;
}
Expand Down
27 changes: 23 additions & 4 deletions src/transport/GroupSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,30 @@

#include <app/util/basic-types.h>
#include <lib/core/GroupId.h>
#include <lib/core/ReferenceCounted.h>
#include <lib/support/Pool.h>
#include <transport/Session.h>

namespace chip {
namespace Transport {

class IncomingGroupSession : public Session
class IncomingGroupSession : public Session, public ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>
{
public:
IncomingGroupSession(GroupId group, FabricIndex fabricIndex, NodeId peerNodeId) : mGroupId(group), mPeerNodeId(peerNodeId)
{
SetFabricIndex(fabricIndex);
}
~IncomingGroupSession() override { NotifySessionReleased(); }
~IncomingGroupSession() override
{
NotifySessionReleased();
VerifyOrDie(GetReferenceCount() == 0);
}

void Retain() override { ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>::Retain(); }
void Release() override { ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>::Release(); }

bool IsActiveSession() const override { return true; }

Session::SessionType GetSessionType() const override { return Session::SessionType::kGroupIncoming; }
#if CHIP_PROGRESS_LOGGING
Expand Down Expand Up @@ -74,11 +84,20 @@ class IncomingGroupSession : public Session
const NodeId mPeerNodeId;
};

class OutgoingGroupSession : public Session
class OutgoingGroupSession : public Session, public ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>
{
public:
OutgoingGroupSession(GroupId group, FabricIndex fabricIndex) : mGroupId(group) { SetFabricIndex(fabricIndex); }
~OutgoingGroupSession() override { NotifySessionReleased(); }
~OutgoingGroupSession() override
{
NotifySessionReleased();
VerifyOrDie(GetReferenceCount() == 0);
}

void Retain() override { ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>::Retain(); }
void Release() override { ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>::Release(); }

bool IsActiveSession() const override { return true; }

Session::SessionType GetSessionType() const override { return Session::SessionType::kGroupOutgoing; }
#if CHIP_PROGRESS_LOGGING
Expand Down
29 changes: 29 additions & 0 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,39 @@

#include <access/AuthMode.h>
#include <transport/SecureSession.h>
#include <transport/SecureSessionTable.h>

namespace chip {
namespace Transport {

void SecureSessionDeleter::Release(SecureSession * entry)
{
entry->mTable.ReleaseSession(entry);
}

void SecureSession::MarkForRemoval()
{
ChipLogDetail(Inet, "SecureSession MarkForRemoval %p Type:%d LSID:%d", this, to_underlying(mSecureSessionType),
mLocalSessionId);
ReferenceCountedHandle<Transport::Session> ref(*this);
switch (mState)
{
case State::kPairing:
mState = State::kPendingRemoval;
// Interrupt the pairing
NotifySessionReleased();
return;
case State::kActive:
Release(); // Decrease the ref which is retained at Activate
mState = State::kPendingRemoval;
NotifySessionReleased();
return;
case State::kPendingRemoval:
// Do nothing
return;
}
}

Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const
{
Access::SubjectDescriptor subjectDescriptor;
Expand Down
Loading

0 comments on commit 639801d

Please sign in to comment.