diff --git a/examples/chip-tool/commands/tests/TestCommand.cpp b/examples/chip-tool/commands/tests/TestCommand.cpp index 24d8558ede1a34..aed75cd215b068 100644 --- a/examples/chip-tool/commands/tests/TestCommand.cpp +++ b/examples/chip-tool/commands/tests/TestCommand.cpp @@ -32,7 +32,17 @@ CHIP_ERROR TestCommand::RunCommand() CHIP_ERROR TestCommand::WaitForCommissionee(chip::NodeId nodeId) { - CurrentCommissioner().ReleaseOperationalDevice(nodeId); + chip::FabricIndex fabricIndex; + + ReturnErrorOnFailure(CurrentCommissioner().GetFabricIndex(&fabricIndex)); + + // + // There's a chance the commissionee may have rebooted before this call here as part of a test flow + // or is just starting out fresh outright. Let's make sure we're not re-using any cached CASE sessions + // that will now be stale and mismatched with the peer, causing subsequent interactions to fail. + // + CurrentCommissioner().SessionMgr()->ExpireAllPairings(nodeId, fabricIndex); + return CurrentCommissioner().GetConnectedDevice(nodeId, &mOnDeviceConnectedCallback, &mOnDeviceConnectionFailureCallback); } diff --git a/examples/tv-casting-app/linux/main.cpp b/examples/tv-casting-app/linux/main.cpp index c305746333e07f..2c5815069fec16 100644 --- a/examples/tv-casting-app/linux/main.cpp +++ b/examples/tv-casting-app/linux/main.cpp @@ -332,6 +332,10 @@ class TargetEndpointInfo class TargetVideoPlayerInfo { public: + TargetVideoPlayerInfo() : + mOnConnectedCallback(HandleDeviceConnected, this), mOnConnectionFailureCallback(HandleDeviceConnectionFailure, this) + {} + bool IsInitialized() { return mInitialized; } CHIP_ERROR Initialize(NodeId nodeId, FabricIndex fabricIndex) @@ -359,21 +363,30 @@ class TargetVideoPlayerInfo .clientPool = &gCASEClientPool, }; - PeerId peerID = fabric->GetPeerIdForNode(nodeId); - mOperationalDeviceProxy = chip::Platform::New(initParams, peerID); + PeerId peerID = fabric->GetPeerIdForNode(nodeId); + + // + // TODO: The code here is assuming that we can create an OperationalDeviceProxy instance and attach it immediately + // to a CASE session that just got established to us by the tv-app. While this will work most of the time, + // this is a dangerous assumption to make since it is entirely possible for that secure session to have been + // evicted in the time since that session was established to the point here when we desire to interact back + // with that peer. If that is the case, our `OnConnected` callback will not get invoked syncronously and + // mOperationalDeviceProxy will still have a value of null, triggering the check below to fail. + // + mOperationalDeviceProxy = nullptr; + CHIP_ERROR err = + server->GetCASESessionManager()->FindOrEstablishSession(peerID, &mOnConnectedCallback, &mOnConnectionFailureCallback); + if (err != CHIP_NO_ERROR) + { + ChipLogError(AppServer, "Could not establish a session to the peer"); + return err; + } - // TODO: figure out why this doesn't work so that we can remove OperationalDeviceProxy creation above, - // and remove the FindSecureSessionForNode and SetConnectedSession calls below - // mOperationalDeviceProxy = server->GetCASESessionManager()->FindExistingSession(nodeId); if (mOperationalDeviceProxy == nullptr) { - ChipLogError(AppServer, "Failed in creating an instance of OperationalDeviceProxy"); + ChipLogError(AppServer, "Failed to find an existing instance of OperationalDeviceProxy to the peer"); return CHIP_ERROR_INVALID_ARGUMENT; } - ChipLogError(AppServer, "Created an instance of OperationalDeviceProxy"); - - SessionHandle handle = server->GetSecureSessionManager().FindSecureSessionForNode(nodeId); - mOperationalDeviceProxy->SetConnectedSession(handle); mInitialized = true; return CHIP_NO_ERROR; @@ -451,12 +464,27 @@ class TargetVideoPlayerInfo } private: + static void HandleDeviceConnected(void * context, OperationalDeviceProxy * device) + { + TargetVideoPlayerInfo * _this = static_cast(context); + _this->mOperationalDeviceProxy = device; + } + + static void HandleDeviceConnectionFailure(void * context, PeerId peerId, CHIP_ERROR error) + { + TargetVideoPlayerInfo * _this = static_cast(context); + _this->mOperationalDeviceProxy = nullptr; + } + static constexpr size_t kMaxNumberOfEndpoints = 5; TargetEndpointInfo mEndpoints[kMaxNumberOfEndpoints]; NodeId mNodeId; FabricIndex mFabricIndex; OperationalDeviceProxy * mOperationalDeviceProxy; + Callback::Callback mOnConnectedCallback; + Callback::Callback mOnConnectionFailureCallback; + bool mInitialized = false; }; TargetVideoPlayerInfo gTargetVideoPlayerInfo; diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp index 02f6ea1ced98b8..597be920fe984a 100644 --- a/src/app/CASESessionManager.cpp +++ b/src/app/CASESessionManager.cpp @@ -39,7 +39,7 @@ CHIP_ERROR CASESessionManager::FindOrEstablishSession(PeerId peerId, Callback::C OperationalDeviceProxy * session = FindExistingSession(peerId); if (session == nullptr) { - ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing session found"); + ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing OperationalDeviceProxy instance found"); session = mConfig.devicePool->Allocate(mConfig.sessionInitParams, peerId); diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index ef7ea1ce0738b6..404c6181ca46ee 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -29,6 +29,7 @@ #include "CASEClient.h" #include "CommandSender.h" #include "ReadPrepareParams.h" +#include "transport/SecureSession.h" #include #include @@ -57,10 +58,35 @@ void OperationalDeviceProxy::MoveToState(State aTargetState) } } +bool OperationalDeviceProxy::AttachToExistingSecureSession() +{ + VerifyOrReturnError(mState == State::NeedsAddress || mState == State::Initialized, false); + + ScopedNodeId peerNodeId(mPeerId.GetNodeId(), mFabricInfo->GetFabricIndex()); + auto sessionHandle = mInitParams.sessionManager->FindSecureSessionForNode(peerNodeId, Transport::SecureSession::Type::kCASE); + if (sessionHandle.HasValue()) + { + ChipLogProgress(Controller, "Found an existing secure session to [" ChipLogFormatX64 "-" ChipLogFormatX64 "]!", + ChipLogValueX64(mPeerId.GetCompressedFabricId()), ChipLogValueX64(mPeerId.GetNodeId())); + mSecureSession.Grab(sessionHandle.Value()); + return true; + } + + return false; +} + CHIP_ERROR OperationalDeviceProxy::Connect(Callback::Callback * onConnection, Callback::Callback * onFailure) { - CHIP_ERROR err = CHIP_NO_ERROR; + CHIP_ERROR err = CHIP_NO_ERROR; + bool isConnected = false; + + // + // Always enqueue our user provided callbacks into our callback list. + // If anything goes wrong below, we'll trigger failures (including any queued from + // a previous iteration which in theory shouldn't happen, but this is written to be more defensive) + // + EnqueueConnectionCallbacks(onConnection, onFailure); switch (mState) { @@ -69,35 +95,47 @@ CHIP_ERROR OperationalDeviceProxy::Connect(Callback::Callback break; case State::NeedsAddress: - err = LookupPeerAddress(); - EnqueueConnectionCallbacks(onConnection, onFailure); + isConnected = AttachToExistingSecureSession(); + if (!isConnected) + { + err = LookupPeerAddress(); + } + break; case State::Initialized: - err = EstablishConnection(); - if (err == CHIP_NO_ERROR) + isConnected = AttachToExistingSecureSession(); + if (!isConnected) { - EnqueueConnectionCallbacks(onConnection, onFailure); + err = EstablishConnection(); } + break; + case State::Connecting: - EnqueueConnectionCallbacks(onConnection, onFailure); break; case State::SecureConnected: - if (onConnection != nullptr) - { - onConnection->mCall(onConnection->mContext, this); - } + isConnected = true; break; default: err = CHIP_ERROR_INCORRECT_STATE; } - if (err != CHIP_NO_ERROR && onFailure != nullptr) + if (isConnected) + { + MoveToState(State::SecureConnected); + } + + // + // Dequeue all our callbacks on either encountering an error + // or if we successfully connected. Both should not be set + // simultaneously. + // + if (err != CHIP_NO_ERROR || isConnected) { - onFailure->mCall(onFailure->mContext, mPeerId, err); + DequeueConnectionCallbacks(err); } return err; @@ -133,7 +171,7 @@ CHIP_ERROR OperationalDeviceProxy::UpdateDeviceData(const Transport::PeerAddress err = EstablishConnection(); if (err != CHIP_NO_ERROR) { - OnSessionEstablishmentError(err); + DequeueConnectionCallbacks(err); } } else @@ -194,35 +232,43 @@ void OperationalDeviceProxy::EnqueueConnectionCallbacks(Callback::Callback * cb = Callback::Callback::FromCancelable(ready.mNext); + Callback::Callback * cb = + Callback::Callback::FromCancelable(failureReady.mNext); cb->Cancel(); - if (executeCallback) + + if (error != CHIP_NO_ERROR) { - cb->mCall(cb->mContext, this); + cb->mCall(cb->mContext, mPeerId, error); } } -} -void OperationalDeviceProxy::DequeueConnectionFailureCallbacks(CHIP_ERROR error, bool executeCallback) -{ - Cancelable ready; - mConnectionFailure.DequeueAll(ready); - while (ready.mNext != &ready) + while (successReady.mNext != &successReady) { - Callback::Callback * cb = - Callback::Callback::FromCancelable(ready.mNext); + Callback::Callback * cb = Callback::Callback::FromCancelable(successReady.mNext); cb->Cancel(); - if (executeCallback) + if (error == CHIP_NO_ERROR) { - cb->mCall(cb->mContext, mPeerId, error); + cb->mCall(cb->mContext, this); } } } @@ -234,13 +280,20 @@ void OperationalDeviceProxy::HandleCASEConnectionFailure(void * context, CASECli ChipLogError(Controller, "HandleCASEConnectionFailure was called while the device was not initialized")); VerifyOrReturn(client == device->mCASEClient, ChipLogError(Controller, "HandleCASEConnectionFailure for unknown CASEClient")); + // + // We don't need to reset the state all the way back to NeedsAddress since all that transpired + // was just CASE connection failure. So let's re-use the cached address to re-do CASE again + // if need-be. + // device->MoveToState(State::Initialized); device->CloseCASESession(); - device->DequeueConnectionSuccessCallbacks(/* executeCallback */ false); - device->DequeueConnectionFailureCallbacks(error, /* executeCallback */ true); - // Do not touch device anymore; it might have been destroyed by a failure + device->DequeueConnectionCallbacks(error); + + // + // Do not touch device instance anymore; it might have been destroyed by a failure // callback. + // } void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * client) @@ -254,19 +307,18 @@ void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * cl if (err != CHIP_NO_ERROR) { device->HandleCASEConnectionFailure(context, client, err); - // Do not touch device anymore; it might have been destroyed by a - // HandleCASEConnectionFailure. } else { device->MoveToState(State::SecureConnected); - device->CloseCASESession(); - device->DequeueConnectionFailureCallbacks(CHIP_NO_ERROR, /* executeCallback */ false); - device->DequeueConnectionSuccessCallbacks(/* executeCallback */ true); - // Do not touch device anymore; it might have been destroyed by a - // success callback. + device->DequeueConnectionCallbacks(CHIP_NO_ERROR); } + + // + // Do not touch this instance anymore; it might have been destroyed by a + // callback. + // } CHIP_ERROR OperationalDeviceProxy::Disconnect() @@ -285,12 +337,6 @@ CHIP_ERROR OperationalDeviceProxy::Disconnect() return CHIP_NO_ERROR; } -void OperationalDeviceProxy::SetConnectedSession(const SessionHandle & handle) -{ - mSecureSession.Grab(handle); - MoveToState(State::SecureConnected); -} - void OperationalDeviceProxy::Clear() { if (mCASEClient) @@ -367,8 +413,7 @@ void OperationalDeviceProxy::OnNodeAddressResolutionFailed(const PeerId & peerId ChipLogError(Discovery, "Operational discovery failed for 0x" ChipLogFormatX64 ": %" CHIP_ERROR_FORMAT, ChipLogValueX64(peerId.GetNodeId()), reason.Format()); - DequeueConnectionSuccessCallbacks(/* executeCallback */ false); - DequeueConnectionFailureCallbacks(reason, /* executeCallback */ true); + DequeueConnectionCallbacks(reason); } } // namespace chip diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 7f1b9d715b23f6..ea5e60892369b2 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -91,6 +91,10 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, { public: ~OperationalDeviceProxy() override; + + // + // TODO: Should not be PeerId, but rather, ScopedNodeId + // OperationalDeviceProxy(DeviceProxyInitParams & params, PeerId peerId) : mSecureSession(*this) { mInitParams = params; @@ -159,15 +163,6 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, */ CHIP_ERROR Disconnect() override; - /** - * Use SetConnectedSession if 'this' object is a newly allocated device proxy. - * It will take an existing session, such as the one established - * during commissioning, and use it for this device proxy. - * - * Note: Avoid using this function generally as it is Deprecated - */ - void SetConnectedSession(const SessionHandle & handle); - NodeId GetDeviceId() const override { return mPeerId.GetNodeId(); } /** @@ -268,6 +263,15 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, CHIP_ERROR EstablishConnection(); + /* + * This checks to see if an existing CASE session exists to the peer within the SessionManager + * and if one exists, to load that into mSecureSession. + * + * Returns true if a valid session was found, false otherwise. + * + */ + bool AttachToExistingSecureSession(); + bool IsSecureConnected() const override { return mState == State::SecureConnected; } static void HandleCASEConnected(void * context, CASEClient * client); @@ -280,8 +284,15 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, void EnqueueConnectionCallbacks(Callback::Callback * onConnection, Callback::Callback * onFailure); - void DequeueConnectionSuccessCallbacks(bool executeCallback); - void DequeueConnectionFailureCallbacks(CHIP_ERROR error, bool executeCallback); + /* + * This dequeues all failure and success callbacks and appropriately + * invokes either set depending on the value of error. + * + * If error == CHIP_NO_ERROR, only success callbacks are invoked. + * Otherwise, only failure callbacks are invoked. + * + */ + void DequeueConnectionCallbacks(CHIP_ERROR error); }; } // namespace chip diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index e09490ff5c4413..d2aa1172e05c4f 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -144,6 +144,16 @@ class DLL_EXPORT DeviceController : public SessionRecoveryDelegate, public Abstr */ virtual CHIP_ERROR Shutdown(); + SessionManager * SessionMgr() + { + if (mSystemState) + { + return mSystemState->SessionMgr(); + } + + return nullptr; + } + CHIP_ERROR GetPeerAddressAndPort(PeerId peerId, Inet::IPAddress & addr, uint16_t & port); /** diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 730a64e7abf80c..f0b2047683a45a 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -812,11 +812,12 @@ void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param) mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer } -SessionHandle SessionManager::FindSecureSessionForNode(NodeId peerNodeId) +Optional SessionManager::FindSecureSessionForNode(ScopedNodeId peerNodeId, Transport::SecureSession::Type type) { SecureSession * found = nullptr; - mSecureSessions.ForEachSession([&](auto session) { - if (session->GetPeerNodeId() == peerNodeId) + mSecureSessions.ForEachSession([&peerNodeId, type, &found](auto session) { + if (session->GetPeer() == peerNodeId && + (type == SecureSession::Type::kUndefined || type == session->GetSecureSessionType())) { found = session; return Loop::Break; @@ -824,8 +825,7 @@ SessionHandle SessionManager::FindSecureSessionForNode(NodeId peerNodeId) return Loop::Continue; }); - VerifyOrDie(found != nullptr); - return SessionHandle(*found); + return found != nullptr ? MakeOptional(*found) : Optional::Missing(); } /** diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 87c430646ddd87..244155482bcebd 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -249,9 +249,17 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config); } - // TODO: this is a temporary solution for legacy tests which use nodeId to send packets - // and tv-casting-app that uses the TV's node ID to find the associated secure session - SessionHandle FindSecureSessionForNode(NodeId peerNodeId); + // + // Find an existing secure session given a peer's scoped NodeId and a type of session to match against. + // If matching against all types of sessions is desired, kUndefined should be passed into type. + // + // If a valid session is found, an Optional with the value set to the SessionHandle of the session + // is returned. Otherwise, an Optional with no value set is returned. + // + // + Optional + FindSecureSessionForNode(ScopedNodeId peerNodeId, + Transport::SecureSession::Type type = Transport::SecureSession::Type::kUndefined); using SessionHandleCallback = bool (*)(void * context, SessionHandle & sessionHandle); CHIP_ERROR ForEachSessionHandle(void * context, SessionHandleCallback callback);