diff --git a/src/app/CASEClient.cpp b/src/app/CASEClient.cpp index f514b3aa4c089c..f6fa82c686e827 100644 --- a/src/app/CASEClient.cpp +++ b/src/app/CASEClient.cpp @@ -19,21 +19,20 @@ namespace chip { -CASEClient::CASEClient(const CASEClientInitParams & params) : mInitParams(params) {} - void CASEClient::SetRemoteMRPIntervals(const ReliableMessageProtocolConfig & remoteMRPConfig) { mCASESession.SetRemoteMRPConfig(remoteMRPConfig); } -CHIP_ERROR CASEClient::EstablishSession(const ScopedNodeId & peer, const Transport::PeerAddress & peerAddress, +CHIP_ERROR CASEClient::EstablishSession(const CASEClientInitParams & params, const ScopedNodeId & peer, + const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & remoteMRPConfig, SessionEstablishmentDelegate * delegate) { - VerifyOrReturnError(mInitParams.fabricTable != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(params.fabricTable != nullptr, CHIP_ERROR_INVALID_ARGUMENT); // Create a UnauthenticatedSession for CASE pairing. - Optional session = mInitParams.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig); + Optional session = params.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig); VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY); // Allocate the exchange immediately before calling CASESession::EstablishSession. @@ -42,13 +41,13 @@ CHIP_ERROR CASEClient::EstablishSession(const ScopedNodeId & peer, const Transpo // free it on error, but can only do this if it is actually called. // Allocating the exchange context right before calling EstablishSession // ensures that if allocation succeeds, CASESession has taken ownership. - Messaging::ExchangeContext * exchange = mInitParams.exchangeMgr->NewContext(session.Value(), &mCASESession); + Messaging::ExchangeContext * exchange = params.exchangeMgr->NewContext(session.Value(), &mCASESession); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); - mCASESession.SetGroupDataProvider(mInitParams.groupDataProvider); - ReturnErrorOnFailure(mCASESession.EstablishSession(*mInitParams.sessionManager, mInitParams.fabricTable, peer, exchange, - mInitParams.sessionResumptionStorage, mInitParams.certificateValidityPolicy, - delegate, mInitParams.mrpLocalConfig)); + mCASESession.SetGroupDataProvider(params.groupDataProvider); + ReturnErrorOnFailure(mCASESession.EstablishSession(*params.sessionManager, params.fabricTable, peer, exchange, + params.sessionResumptionStorage, params.certificateValidityPolicy, delegate, + params.mrpLocalConfig)); return CHIP_NO_ERROR; } diff --git a/src/app/CASEClient.h b/src/app/CASEClient.h index 33dfad16bab0cf..3a5aa8ded7ff08 100644 --- a/src/app/CASEClient.h +++ b/src/app/CASEClient.h @@ -34,23 +34,31 @@ struct CASEClientInitParams Messaging::ExchangeManager * exchangeMgr = nullptr; FabricTable * fabricTable = nullptr; Credentials::GroupDataProvider * groupDataProvider = nullptr; + Optional mrpLocalConfig = Optional::Missing(); - Optional mrpLocalConfig = Optional::Missing(); + CHIP_ERROR Validate() const + { + // sessionResumptionStorage can be nullptr when resumption is disabled. + // certificateValidityPolicy is optional, too. + ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE); + ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); + ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE); + ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE); + + return CHIP_NO_ERROR; + } }; class DLL_EXPORT CASEClient { public: - CASEClient(const CASEClientInitParams & params); - void SetRemoteMRPIntervals(const ReliableMessageProtocolConfig & remoteMRPConfig); - CHIP_ERROR EstablishSession(const ScopedNodeId & peer, const Transport::PeerAddress & peerAddress, - const ReliableMessageProtocolConfig & remoteMRPConfig, SessionEstablishmentDelegate * delegate); + CHIP_ERROR EstablishSession(const CASEClientInitParams & params, const ScopedNodeId & peer, + const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & remoteMRPConfig, + SessionEstablishmentDelegate * delegate); private: - CASEClientInitParams mInitParams; - CASESession mCASESession; }; diff --git a/src/app/CASEClientPool.h b/src/app/CASEClientPool.h index f44d487771c8fc..41b372cd030e5c 100644 --- a/src/app/CASEClientPool.h +++ b/src/app/CASEClientPool.h @@ -25,7 +25,7 @@ namespace chip { class CASEClientPoolDelegate { public: - virtual CASEClient * Allocate(CASEClientInitParams params) = 0; + virtual CASEClient * Allocate() = 0; virtual void Release(CASEClient * client) = 0; @@ -38,7 +38,7 @@ class CASEClientPool : public CASEClientPoolDelegate public: ~CASEClientPool() override { mClientPool.ReleaseAll(); } - CASEClient * Allocate(CASEClientInitParams params) override { return mClientPool.CreateObject(params); } + CASEClient * Allocate() override { return mClientPool.CreateObject(); } void Release(CASEClient * client) override { mClientPool.ReleaseObject(client); } diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp index 447d7a663d388f..9d4f3814943e42 100644 --- a/src/app/CASESessionManager.cpp +++ b/src/app/CASESessionManager.cpp @@ -41,7 +41,7 @@ void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Cal { ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing OperationalSessionSetup instance found"); - session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this); + session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, mConfig.clientPool, peerId, this); if (session == nullptr) { @@ -83,7 +83,7 @@ void CASESessionManager::UpdatePeerAddress(ScopedNodeId peerId) { ChipLogDetail(CASESessionManager, "UpdatePeerAddress: No existing OperationalSessionSetup instance found"); - session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this); + session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, mConfig.clientPool, peerId, this); if (session == nullptr) { ChipLogDetail(CASESessionManager, "UpdatePeerAddress: Failed to allocate OperationalSessionSetup instance"); diff --git a/src/app/CASESessionManager.h b/src/app/CASESessionManager.h index ddc971a1e7de86..1e901478aaf6b7 100644 --- a/src/app/CASESessionManager.h +++ b/src/app/CASESessionManager.h @@ -36,7 +36,8 @@ class OperationalSessionSetupPoolDelegate; struct CASESessionManagerConfig { - DeviceProxyInitParams sessionInitParams; + CASEClientInitParams sessionInitParams; + CASEClientPoolDelegate * clientPool = nullptr; OperationalSessionSetupPoolDelegate * sessionSetupPool = nullptr; }; diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp index 38aa12566c59b8..2997c877c4beee 100644 --- a/src/app/OperationalSessionSetup.cpp +++ b/src/app/OperationalSessionSetup.cpp @@ -221,12 +221,10 @@ void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & ad CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessageProtocolConfig & config) { - mCASEClient = mInitParams.clientPool->Allocate(CASEClientInitParams{ - mInitParams.sessionManager, mInitParams.sessionResumptionStorage, mInitParams.certificateValidityPolicy, - mInitParams.exchangeMgr, mFabricTable, mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); + mCASEClient = mClientPool->Allocate(); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); - CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, config, this); + CHIP_ERROR err = mCASEClient->EstablishSession(mInitParams, mPeerId, mDeviceAddress, config, this); if (err != CHIP_NO_ERROR) { CleanupCASEClient(); @@ -330,7 +328,7 @@ void OperationalSessionSetup::CleanupCASEClient() { if (mCASEClient) { - mInitParams.clientPool->Release(mCASEClient); + mClientPool->Release(mCASEClient); mCASEClient = nullptr; } } @@ -364,7 +362,7 @@ OperationalSessionSetup::~OperationalSessionSetup() if (mCASEClient) { // Make sure we don't leak it. - mInitParams.clientPool->Release(mCASEClient); + mClientPool->Release(mCASEClient); } } @@ -382,7 +380,7 @@ CHIP_ERROR OperationalSessionSetup::LookupPeerAddress() return CHIP_NO_ERROR; } - auto const * fabricInfo = mFabricTable->FindFabricWithIndex(mPeerId.GetFabricIndex()); + auto const * fabricInfo = mInitParams.fabricTable->FindFabricWithIndex(mPeerId.GetFabricIndex()); VerifyOrReturnError(fabricInfo != nullptr, CHIP_ERROR_INVALID_FABRIC_INDEX); PeerId peerId(fabricInfo->GetCompressedFabricId(), mPeerId.GetNodeId()); diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h index cde2fddc6dac0b..6ed951f46cf654 100644 --- a/src/app/OperationalSessionSetup.h +++ b/src/app/OperationalSessionSetup.h @@ -45,31 +45,6 @@ namespace chip { -struct DeviceProxyInitParams -{ - SessionManager * sessionManager = nullptr; - SessionResumptionStorage * sessionResumptionStorage = nullptr; - Credentials::CertificateValidityPolicy * certificateValidityPolicy = nullptr; - Messaging::ExchangeManager * exchangeMgr = nullptr; - FabricTable * fabricTable = nullptr; - CASEClientPoolDelegate * clientPool = nullptr; - Credentials::GroupDataProvider * groupDataProvider = nullptr; - - Optional mrpLocalConfig = Optional::Missing(); - - CHIP_ERROR Validate() const - { - ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE); - // sessionResumptionStorage can be nullptr when resumption is disabled - ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); - ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE); - ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE); - ReturnErrorCodeIf(clientPool == nullptr, CHIP_ERROR_INCORRECT_STATE); - - return CHIP_NO_ERROR; - } -}; - class OperationalSessionSetup; /** @@ -171,20 +146,20 @@ class DLL_EXPORT OperationalSessionSetup : public SessionDelegate, public: ~OperationalSessionSetup() override; - OperationalSessionSetup(DeviceProxyInitParams & params, ScopedNodeId peerId, + OperationalSessionSetup(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) : mSecureSession(*this) { mInitParams = params; - if (params.Validate() != CHIP_NO_ERROR || releaseDelegate == nullptr) + if (params.Validate() != CHIP_NO_ERROR || clientPool == nullptr || releaseDelegate == nullptr) { mState = State::Uninitialized; return; } + mClientPool = clientPool; mSystemLayer = params.exchangeMgr->GetSessionManager()->SystemLayer(); mPeerId = peerId; - mFabricTable = params.fabricTable; mReleaseDelegate = releaseDelegate; mState = State::NeedsAddress; mAddressLookupHandle.SetListener(this); @@ -260,8 +235,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionDelegate, SecureConnected, // CASE session established. }; - DeviceProxyInitParams mInitParams; - FabricTable * mFabricTable = nullptr; + CASEClientInitParams mInitParams; + CASEClientPoolDelegate * mClientPool = nullptr; System::Layer * mSystemLayer; // mCASEClient is only non-null if we are in State::Connecting or just diff --git a/src/app/OperationalSessionSetupPool.h b/src/app/OperationalSessionSetupPool.h index 50d1fbd1567bca..8f40b37ebf5c4b 100644 --- a/src/app/OperationalSessionSetupPool.h +++ b/src/app/OperationalSessionSetupPool.h @@ -27,8 +27,8 @@ namespace chip { class OperationalSessionSetupPoolDelegate { public: - virtual OperationalSessionSetup * Allocate(DeviceProxyInitParams & params, ScopedNodeId peerId, - OperationalSessionReleaseDelegate * releaseDelegate) = 0; + virtual OperationalSessionSetup * Allocate(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, + ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) = 0; virtual void Release(OperationalSessionSetup * device) = 0; @@ -47,10 +47,10 @@ class OperationalSessionSetupPool : public OperationalSessionSetupPoolDelegate public: ~OperationalSessionSetupPool() override { mSessionSetupPool.ReleaseAll(); } - OperationalSessionSetup * Allocate(DeviceProxyInitParams & params, ScopedNodeId peerId, - OperationalSessionReleaseDelegate * releaseDelegate) override + OperationalSessionSetup * Allocate(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, + ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) override { - return mSessionSetupPool.CreateObject(params, peerId, releaseDelegate); + return mSessionSetupPool.CreateObject(params, clientPool, peerId, releaseDelegate); } void Release(OperationalSessionSetup * device) override { mSessionSetupPool.ReleaseObject(device); } diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 3705cb47d3212c..b6c0d571d9e727 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -290,11 +290,11 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) .certificateValidityPolicy = mCertificateValidityPolicy, .exchangeMgr = &mExchangeMgr, .fabricTable = &mFabrics, - .clientPool = &mCASEClientPool, .groupDataProvider = mGroupsProvider, .mrpLocalConfig = GetLocalMRPConfig(), }, - .sessionSetupPool = &mSessionSetupPool, + .clientPool = &mCASEClientPool, + .sessionSetupPool = &mSessionSetupPool, }; err = mCASESessionManager.Init(&DeviceLayer::SystemLayer(), caseSessionManagerConfig); diff --git a/src/app/tests/TestOperationalDeviceProxy.cpp b/src/app/tests/TestOperationalDeviceProxy.cpp index 5d6f928fdc4d7d..3b8252323f920a 100644 --- a/src/app/tests/TestOperationalDeviceProxy.cpp +++ b/src/app/tests/TestOperationalDeviceProxy.cpp @@ -69,7 +69,7 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, VerifyOrDie(groupDataProvider.Init() == CHIP_NO_ERROR); // TODO: Set IPK in groupDataProvider - DeviceProxyInitParams params = { + CASEClientInitParams params = { .sessionManager = &sessionManager, .sessionResumptionStorage = &sessionResumptionStorage, .exchangeMgr = &exchangeMgr, diff --git a/src/controller/CHIPDeviceControllerFactory.cpp b/src/controller/CHIPDeviceControllerFactory.cpp index 84a7c4388525bd..25457368d945a3 100644 --- a/src/controller/CHIPDeviceControllerFactory.cpp +++ b/src/controller/CHIPDeviceControllerFactory.cpp @@ -245,18 +245,18 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) stateParams.sessionSetupPool = Platform::New(); stateParams.caseClientPool = Platform::New(); - DeviceProxyInitParams deviceInitParams = { + CASEClientInitParams sessionInitParams = { .sessionManager = stateParams.sessionMgr, .sessionResumptionStorage = stateParams.sessionResumptionStorage.get(), .exchangeMgr = stateParams.exchangeMgr, .fabricTable = stateParams.fabricTable, - .clientPool = stateParams.caseClientPool, .groupDataProvider = stateParams.groupDataProvider, .mrpLocalConfig = GetLocalMRPConfig(), }; CASESessionManagerConfig sessionManagerConfig = { - .sessionInitParams = deviceInitParams, + .sessionInitParams = sessionInitParams, + .clientPool = stateParams.caseClientPool, .sessionSetupPool = stateParams.sessionSetupPool, }; diff --git a/src/protocols/secure_channel/CASEServer.h b/src/protocols/secure_channel/CASEServer.h index 8c8f79f547fc7f..894d496c93bfac 100644 --- a/src/protocols/secure_channel/CASEServer.h +++ b/src/protocols/secure_channel/CASEServer.h @@ -69,7 +69,7 @@ class CASEServer : public SessionEstablishmentDelegate, void OnResponseTimeout(Messaging::ExchangeContext * ec) override {} Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return GetSession().GetMessageDispatch(); } - virtual CASESession & GetSession() { return mPairingSession; } + CASESession & GetSession() { return mPairingSession; } private: Messaging::ExchangeManager * mExchangeManager = nullptr; diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 8af555796e5bec..bcfe98f1d57af0 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -124,15 +124,6 @@ class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate uint32_t mNumPairingComplete = 0; }; -class CASEServerForTest : public CASEServer -{ -public: - CASESession & GetSession() override { return mCaseSession; } - -private: - CASESession mCaseSession; -}; - class TestOperationalKeystore : public chip::Crypto::OperationalKeystore { public: @@ -469,7 +460,7 @@ void TestCASESession::SecurePairingHandshakeTest(nlTestSuite * inSuite, void * i SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, delegateCommissioner); } -CASEServerForTest gPairingServer; +CASEServer gPairingServer; void TestCASESession::SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext) {