Skip to content

Commit

Permalink
[app] use PeerId as the only key for CASESessionManager
Browse files Browse the repository at this point in the history
This prevents accidental misuse of a wrong node in a different fabric.
  • Loading branch information
gjc13 committed Dec 16, 2021
1 parent 5988375 commit e073bb3
Show file tree
Hide file tree
Showing 16 changed files with 49 additions and 46 deletions.
2 changes: 1 addition & 1 deletion examples/chip-tool/commands/clusters/ModelCommand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void ModelCommand::OnDeviceConnectedFn(void * context, ChipDevice * device)
command->SendCommand(device, command->mEndPointId);
}

void ModelCommand::OnDeviceConnectionFailureFn(void * context, NodeId deviceId, CHIP_ERROR err)
void ModelCommand::OnDeviceConnectionFailureFn(void * context, PeerId peerId, CHIP_ERROR err)
{
LogErrorOnFailure(err);

Expand Down
2 changes: 1 addition & 1 deletion examples/chip-tool/commands/clusters/ModelCommand.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ModelCommand : public CHIPCommand
uint8_t mEndPointId;

static void OnDeviceConnectedFn(void * context, ChipDevice * device);
static void OnDeviceConnectionFailureFn(void * context, NodeId deviceId, CHIP_ERROR error);
static void OnDeviceConnectionFailureFn(void * context, PeerId peerId, CHIP_ERROR error);

chip::Callback::Callback<chip::OnDeviceConnected> mOnDeviceConnectedCallback;
chip::Callback::Callback<chip::OnDeviceConnectionFailure> mOnDeviceConnectionFailureCallback;
Expand Down
1 change: 1 addition & 0 deletions examples/chip-tool/commands/common/CHIPCommand.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class CHIPCommand : public Command
using ChipDeviceController = ::chip::Controller::DeviceController;
using IPAddress = ::chip::Inet::IPAddress;
using NodeId = ::chip::NodeId;
using PeerId = ::chip::PeerId;
using PeerAddress = ::chip::Transport::PeerAddress;

CHIPCommand(const char * commandName) : Command(commandName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void OpenCommissioningWindowCommand::OnDeviceConnectedFn(void * context, chip::O
VerifyOrReturn(command != nullptr, ChipLogError(chipTool, "OnDeviceConnectedFn: context is null"));
command->OpenCommissioningWindow();
}
void OpenCommissioningWindowCommand::OnDeviceConnectionFailureFn(void * context, NodeId remoteId, CHIP_ERROR err)
void OpenCommissioningWindowCommand::OnDeviceConnectionFailureFn(void * context, PeerId peerId, CHIP_ERROR err)
{
LogErrorOnFailure(err);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class OpenCommissioningWindowCommand : public CHIPCommand

CHIP_ERROR OpenCommissioningWindow();
static void OnDeviceConnectedFn(void * context, chip::OperationalDeviceProxy * device);
static void OnDeviceConnectionFailureFn(void * context, NodeId deviceId, CHIP_ERROR error);
static void OnDeviceConnectionFailureFn(void * context, PeerId peerId, CHIP_ERROR error);
static void OnOpenCommissioningWindowResponse(void * context, NodeId deviceId, CHIP_ERROR status, chip::SetupPayload payload);

chip::Callback::Callback<chip::OnDeviceConnected> mOnDeviceConnectedCallback;
Expand Down
4 changes: 2 additions & 2 deletions examples/chip-tool/commands/tests/TestCommand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ void TestCommand::OnDeviceConnectedFn(void * context, chip::OperationalDevicePro
command->NextTest();
}

void TestCommand::OnDeviceConnectionFailureFn(void * context, NodeId deviceId, CHIP_ERROR error)
void TestCommand::OnDeviceConnectionFailureFn(void * context, PeerId peerId, CHIP_ERROR error)
{
ChipLogProgress(chipTool, " **** Test Setup: Device Connection Failure [deviceId=%" PRIu64 ". Error %" CHIP_ERROR_FORMAT "\n]",
deviceId, error.Format());
peerId.GetNodeId(), error.Format());
auto * command = static_cast<TestCommand *>(context);
VerifyOrReturn(command != nullptr, ChipLogError(chipTool, "Test command context is null"));
command->SetCommandExitStatus(error);
Expand Down
2 changes: 1 addition & 1 deletion examples/chip-tool/commands/tests/TestCommand.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class TestCommand : public CHIPCommand
chip::NodeId mNodeId;

static void OnDeviceConnectedFn(void * context, chip::OperationalDeviceProxy * device);
static void OnDeviceConnectionFailureFn(void * context, NodeId deviceId, CHIP_ERROR error);
static void OnDeviceConnectionFailureFn(void * context, PeerId peerId, CHIP_ERROR error);
static void OnWaitForMsFn(chip::System::Layer * systemLayer, void * context);

CHIP_ERROR ContinueOnChipMainThread() { return WaitForMs(0); };
Expand Down
25 changes: 12 additions & 13 deletions src/app/CASESessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,16 @@

namespace chip {

CHIP_ERROR CASESessionManager::FindOrEstablishSession(FabricInfo * fabric, NodeId nodeId,
Callback::Callback<OnDeviceConnected> * onConnection,
CHIP_ERROR CASESessionManager::FindOrEstablishSession(PeerId peerId, Callback::Callback<OnDeviceConnected> * onConnection,
Callback::Callback<OnDeviceConnectionFailure> * onFailure)
{
Dnssd::ResolvedNodeData resolutionData;

PeerId peerId = fabric->GetPeerIdForNode(nodeId);
// PeerId peerId = fabric->GetPeerIdForNode(nodeId);

bool nodeIDWasResolved = (mConfig.dnsCache != nullptr && mConfig.dnsCache->Lookup(peerId, resolutionData) == CHIP_NO_ERROR);

OperationalDeviceProxy * session = FindExistingSession(nodeId);
OperationalDeviceProxy * session = FindExistingSession(peerId);
if (session == nullptr)
{
// TODO - Implement LRU to evict least recently used session to handle mActiveSessions pool exhaustion
Expand All @@ -46,7 +45,7 @@ CHIP_ERROR CASESessionManager::FindOrEstablishSession(FabricInfo * fabric, NodeI

if (session == nullptr)
{
onFailure->mCall(onFailure->mContext, nodeId, CHIP_ERROR_NO_MEMORY);
onFailure->mCall(onFailure->mContext, peerId, CHIP_ERROR_NO_MEMORY);
return CHIP_ERROR_NO_MEMORY;
}
}
Expand All @@ -64,9 +63,9 @@ CHIP_ERROR CASESessionManager::FindOrEstablishSession(FabricInfo * fabric, NodeI
return err;
}

void CASESessionManager::ReleaseSession(NodeId nodeId)
void CASESessionManager::ReleaseSession(PeerId peerId)
{
ReleaseSession(FindExistingSession(nodeId));
ReleaseSession(FindExistingSession(peerId));
}

CHIP_ERROR CASESessionManager::ResolveDeviceAddress(FabricInfo * fabric, NodeId nodeId)
Expand All @@ -84,7 +83,7 @@ void CASESessionManager::OnNodeIdResolved(const Dnssd::ResolvedNodeData & nodeDa
LogErrorOnFailure(mConfig.dnsCache->Insert(nodeData));
}

OperationalDeviceProxy * session = FindExistingSession(nodeData.mPeerId.GetNodeId());
OperationalDeviceProxy * session = FindExistingSession(nodeData.mPeerId);
VerifyOrReturn(session != nullptr,
ChipLogDetail(Controller, "OnNodeIdResolved was called for a device with no active sessions, ignoring it."));

Expand All @@ -96,17 +95,17 @@ void CASESessionManager::OnNodeIdResolutionFailed(const PeerId & peer, CHIP_ERRO
ChipLogError(Controller, "Error resolving node id: %s", ErrorStr(error));
}

CHIP_ERROR CASESessionManager::GetPeerAddress(FabricInfo * fabric, NodeId nodeId, Transport::PeerAddress & addr)
CHIP_ERROR CASESessionManager::GetPeerAddress(PeerId peerId, Transport::PeerAddress & addr)
{
if (mConfig.dnsCache != nullptr)
{
Dnssd::ResolvedNodeData resolutionData;
ReturnErrorOnFailure(mConfig.dnsCache->Lookup(fabric->GetPeerIdForNode(nodeId), resolutionData));
ReturnErrorOnFailure(mConfig.dnsCache->Lookup(peerId, resolutionData));
addr = OperationalDeviceProxy::ToPeerAddress(resolutionData);
return CHIP_NO_ERROR;
}

OperationalDeviceProxy * session = FindExistingSession(nodeId);
OperationalDeviceProxy * session = FindExistingSession(peerId);
VerifyOrReturnError(session != nullptr, CHIP_ERROR_NOT_CONNECTED);
addr = session->GetPeerAddress();
return CHIP_NO_ERROR;
Expand All @@ -125,9 +124,9 @@ OperationalDeviceProxy * CASESessionManager::FindSession(SessionHandle session)
return mConfig.devicePool->FindDevice(session);
}

OperationalDeviceProxy * CASESessionManager::FindExistingSession(NodeId id)
OperationalDeviceProxy * CASESessionManager::FindExistingSession(PeerId peerId)
{
return mConfig.devicePool->FindDevice(id);
return mConfig.devicePool->FindDevice(peerId);
}

void CASESessionManager::ReleaseSession(OperationalDeviceProxy * session)
Expand Down
8 changes: 4 additions & 4 deletions src/app/CASESessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ class CASESessionManager : public SessionReleaseDelegate, public Dnssd::Resolver
* these will be used to inform the caller about successful or failed connection establishment.
* If the connection is already established, the `onConnection` callback will be immediately called.
*/
CHIP_ERROR FindOrEstablishSession(FabricInfo * fabric, NodeId nodeId, Callback::Callback<OnDeviceConnected> * onConnection,
CHIP_ERROR FindOrEstablishSession(PeerId peerId, Callback::Callback<OnDeviceConnected> * onConnection,
Callback::Callback<OnDeviceConnectionFailure> * onFailure);

OperationalDeviceProxy * FindExistingSession(NodeId nodeId);
OperationalDeviceProxy * FindExistingSession(PeerId peerId);

void ReleaseSession(NodeId nodeId);
void ReleaseSession(PeerId peerId);

/**
* This API triggers the DNS-SD resolution for the given node ID. The node ID will be looked up
Expand All @@ -103,7 +103,7 @@ class CASESessionManager : public SessionReleaseDelegate, public Dnssd::Resolver
* an ongoing session with the peer node. If the session doesn't exist, the API will return
* `CHIP_ERROR_NOT_CONNECTED` error.
*/
CHIP_ERROR GetPeerAddress(FabricInfo * fabric, NodeId nodeId, Transport::PeerAddress & addr);
CHIP_ERROR GetPeerAddress(PeerId peerId, Transport::PeerAddress & addr);

//////////// SessionReleaseDelegate Implementation ///////////////
void OnSessionReleased(SessionHandle session) override;
Expand Down
4 changes: 2 additions & 2 deletions src/app/OperationalDeviceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ CHIP_ERROR OperationalDeviceProxy::Connect(Callback::Callback<OnDeviceConnected>

if (err != CHIP_NO_ERROR && onFailure != nullptr)
{
onFailure->mCall(onFailure->mContext, mPeerId.GetNodeId(), err);
onFailure->mCall(onFailure->mContext, mPeerId, err);
}

return err;
Expand Down Expand Up @@ -205,7 +205,7 @@ void OperationalDeviceProxy::DequeueConnectionFailureCallbacks(CHIP_ERROR error,
cb->Cancel();
if (executeCallback)
{
cb->mCall(cb->mContext, mPeerId.GetNodeId(), error);
cb->mCall(cb->mContext, mPeerId, error);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/app/OperationalDeviceProxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct DeviceProxyInitParams
class OperationalDeviceProxy;

typedef void (*OnDeviceConnected)(void * context, OperationalDeviceProxy * device);
typedef void (*OnDeviceConnectionFailure)(void * context, NodeId deviceId, CHIP_ERROR error);
typedef void (*OnDeviceConnectionFailure)(void * context, PeerId peerId, CHIP_ERROR error);

class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, SessionReleaseDelegate, public SessionEstablishmentDelegate
{
Expand Down
6 changes: 3 additions & 3 deletions src/app/OperationalDeviceProxyPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class OperationalDeviceProxyPoolDelegate

virtual OperationalDeviceProxy * FindDevice(SessionHandle session) = 0;

virtual OperationalDeviceProxy * FindDevice(NodeId id) = 0;
virtual OperationalDeviceProxy * FindDevice(PeerId peerId) = 0;

virtual ~OperationalDeviceProxyPoolDelegate() {}
};
Expand Down Expand Up @@ -74,11 +74,11 @@ class OperationalDeviceProxyPool : public OperationalDeviceProxyPoolDelegate
return foundDevice;
}

OperationalDeviceProxy * FindDevice(NodeId id) override
OperationalDeviceProxy * FindDevice(PeerId peerId) override
{
OperationalDeviceProxy * foundDevice = nullptr;
mDevicePool.ForEachActiveObject([&](auto * activeDevice) {
if (activeDevice->GetDeviceId() == id)
if (activeDevice->GetPeerId() == peerId)
{
foundDevice = activeDevice;
return Loop::Break;
Expand Down
9 changes: 5 additions & 4 deletions src/app/clusters/ota-requestor/OTARequestor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ void OTARequestor::ConnectToProvider(OnConnectedAction onConnectedAction)

ChipLogDetail(SoftwareUpdate, "Establishing session to provider node ID 0x" ChipLogFormatX64 " on fabric index %d",
ChipLogValueX64(mProviderNodeId), mProviderFabricIndex);
CHIP_ERROR err = mCASESessionManager->FindOrEstablishSession(fabricInfo, mProviderNodeId, &mOnConnectedCallback,
&mOnConnectionFailureCallback);
CHIP_ERROR err = mCASESessionManager->FindOrEstablishSession(fabricInfo->GetPeerIdForNode(mProviderNodeId),
&mOnConnectedCallback, &mOnConnectionFailureCallback);
VerifyOrReturn(err == CHIP_NO_ERROR,
ChipLogError(SoftwareUpdate, "Cannot establish connection to provider: %" CHIP_ERROR_FORMAT, err.Format()));
}
Expand Down Expand Up @@ -350,9 +350,10 @@ OTARequestor::OTATriggerResult OTARequestor::TriggerImmediateQuery()
}

// Called whenever FindOrEstablishSession fails
void OTARequestor::OnConnectionFailure(void * context, NodeId deviceId, CHIP_ERROR error)
void OTARequestor::OnConnectionFailure(void * context, PeerId peerId, CHIP_ERROR error)
{
ChipLogError(SoftwareUpdate, "Failed to connect to node 0x%" PRIX64 ": %" CHIP_ERROR_FORMAT, deviceId, error.Format());
ChipLogError(SoftwareUpdate, "Failed to connect to node 0x%" PRIX64 ": %" CHIP_ERROR_FORMAT, peerId.GetNodeId(),
error.Format());
}

void OTARequestor::ApplyUpdate()
Expand Down
2 changes: 1 addition & 1 deletion src/app/clusters/ota-requestor/OTARequestor.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class OTARequestor : public OTARequestorInterface
* Session connection callbacks
*/
static void OnConnected(void * context, OperationalDeviceProxy * deviceProxy);
static void OnConnectionFailure(void * context, NodeId deviceId, CHIP_ERROR error);
static void OnConnectionFailure(void * context, PeerId peerId, CHIP_ERROR error);
Callback::Callback<OnDeviceConnected> mOnConnectedCallback;
Callback::Callback<OnDeviceConnectionFailure> mOnConnectionFailureCallback;

Expand Down
18 changes: 9 additions & 9 deletions src/controller/CHIPDeviceController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ void DeviceController::ReleaseOperationalDevice(NodeId remoteDeviceId)
{
VerifyOrReturn(mState == State::Initialized,
ChipLogError(Controller, "ReleaseOperationalDevice was called in incorrect state"));
mCASESessionManager->ReleaseSession(remoteDeviceId);
mCASESessionManager->ReleaseSession(mFabricInfo->GetPeerIdForNode(remoteDeviceId));
}

void DeviceController::OnSessionReleased(SessionHandle session)
Expand Down Expand Up @@ -352,7 +352,7 @@ CHIP_ERROR DeviceController::GetPeerAddressAndPort(PeerId peerId, Inet::IPAddres
{
VerifyOrReturnError(mState == State::Initialized, CHIP_ERROR_INCORRECT_STATE);
Transport::PeerAddress peerAddr;
ReturnErrorOnFailure(mCASESessionManager->GetPeerAddress(mFabricInfo, peerId.GetNodeId(), peerAddr));
ReturnErrorOnFailure(mCASESessionManager->GetPeerAddress(mFabricInfo->GetPeerIdForNode(peerId.GetNodeId()), peerAddr));
addr = peerAddr.GetIPAddress();
port = peerAddr.GetPort();
return CHIP_NO_ERROR;
Expand All @@ -379,7 +379,7 @@ void DeviceController::OnVIDReadResponse(void * context, uint16_t value)
controller->mSetupPayload.vendorID = value;

OperationalDeviceProxy * device =
controller->mCASESessionManager->FindExistingSession(controller->mDeviceWithCommissioningWindowOpen);
controller->mCASESessionManager->FindExistingSession(controller->GetPeerIdWithCommissioningWindowOpen());
if (device == nullptr)
{
ChipLogError(Controller, "Could not find device for opening commissioning window");
Expand Down Expand Up @@ -476,7 +476,7 @@ CHIP_ERROR DeviceController::OpenCommissioningWindowWithCallback(NodeId deviceId

if (callback != nullptr && mCommissioningWindowOption != CommissioningWindowOption::kOriginalSetupCode && readVIDPIDAttributes)
{
OperationalDeviceProxy * device = mCASESessionManager->FindExistingSession(mDeviceWithCommissioningWindowOpen);
OperationalDeviceProxy * device = mCASESessionManager->FindExistingSession(GetPeerIdWithCommissioningWindowOpen());
VerifyOrReturnError(device != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

constexpr EndpointId kBasicClusterEndpoint = 0;
Expand All @@ -495,7 +495,7 @@ CHIP_ERROR DeviceController::OpenCommissioningWindowInternal()
ChipLogProgress(Controller, "OpenCommissioningWindow for device ID %" PRIu64, mDeviceWithCommissioningWindowOpen);
VerifyOrReturnError(mState == State::Initialized, CHIP_ERROR_INCORRECT_STATE);

OperationalDeviceProxy * device = mCASESessionManager->FindExistingSession(mDeviceWithCommissioningWindowOpen);
OperationalDeviceProxy * device = mCASESessionManager->FindExistingSession(GetPeerIdWithCommissioningWindowOpen());
VerifyOrReturnError(device != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

constexpr EndpointId kAdministratorCommissioningClusterEndpoint = 0;
Expand Down Expand Up @@ -1645,8 +1645,8 @@ void DeviceCommissioner::OnNodeIdResolved(const chip::Dnssd::ResolvedNodeData &

mDNSCache.Insert(nodeData);

mCASESessionManager->FindOrEstablishSession(mFabricInfo, nodeData.mPeerId.GetNodeId(), &mOnDeviceConnectedCallback,
&mOnDeviceConnectionFailureCallback);
mCASESessionManager->FindOrEstablishSession(mFabricInfo->GetPeerIdForNode(nodeData.mPeerId.GetNodeId()),
&mOnDeviceConnectedCallback, &mOnDeviceConnectionFailureCallback);
DeviceController::OnNodeIdResolved(nodeData);
}

Expand Down Expand Up @@ -1688,15 +1688,15 @@ void DeviceCommissioner::OnDeviceConnectedFn(void * context, OperationalDevicePr
}
}

void DeviceCommissioner::OnDeviceConnectionFailureFn(void * context, NodeId deviceId, CHIP_ERROR error)
void DeviceCommissioner::OnDeviceConnectionFailureFn(void * context, PeerId peerId, CHIP_ERROR error)
{
DeviceCommissioner * commissioner = static_cast<DeviceCommissioner *>(context);
ChipLogProgress(Controller, "Device connection failed. Error %s", ErrorStr(error));
VerifyOrReturn(commissioner != nullptr,
ChipLogProgress(Controller, "Device connection failure callback with null context. Ignoring"));
VerifyOrReturn(commissioner->mPairingDelegate != nullptr,
ChipLogProgress(Controller, "Device connection failure callback with null pairing delegate. Ignoring"));
commissioner->mPairingDelegate->OnCommissioningComplete(deviceId, error);
commissioner->mPairingDelegate->OnCommissioningComplete(peerId.GetNodeId(), error);
}

void DeviceCommissioner::PerformCommissioningStep(DeviceProxy * proxy, CommissioningStage step, CommissioningParameters & params,
Expand Down
6 changes: 4 additions & 2 deletions src/controller/CHIPDeviceController.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class DLL_EXPORT DeviceController : public SessionReleaseDelegate,
Callback::Callback<OnDeviceConnectionFailure> * onFailure)
{
VerifyOrReturnError(mState == State::Initialized, CHIP_ERROR_INCORRECT_STATE);
return mCASESessionManager->FindOrEstablishSession(mFabricInfo, deviceId, onConnection, onFailure);
return mCASESessionManager->FindOrEstablishSession(mFabricInfo->GetPeerIdForNode(deviceId), onConnection, onFailure);
}

/**
Expand Down Expand Up @@ -399,6 +399,8 @@ class DLL_EXPORT DeviceController : public SessionReleaseDelegate,

CHIP_ERROR OpenCommissioningWindowInternal();

PeerId GetPeerIdWithCommissioningWindowOpen() { return mFabricInfo->GetPeerIdForNode(mDeviceWithCommissioningWindowOpen); }

// TODO - Support opening commissioning window simultaneously on multiple devices
Callback::Callback<OnOpenCommissioningWindow> * mCommissioningWindowCallback = nullptr;
SetupPayload mSetupPayload;
Expand Down Expand Up @@ -745,7 +747,7 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController,
static void OnRootCertFailureResponse(void * context, uint8_t status);

static void OnDeviceConnectedFn(void * context, OperationalDeviceProxy * device);
static void OnDeviceConnectionFailureFn(void * context, NodeId deviceId, CHIP_ERROR error);
static void OnDeviceConnectionFailureFn(void * context, PeerId peerId, CHIP_ERROR error);

static void OnDeviceNOCChainGeneration(void * context, CHIP_ERROR status, const ByteSpan & noc, const ByteSpan & icac,
const ByteSpan & rcac);
Expand Down

0 comments on commit e073bb3

Please sign in to comment.