diff --git a/src/app/CASEClient.cpp b/src/app/CASEClient.cpp index 32b842a2ba2dee..d3f8fb2d32aa1c 100644 --- a/src/app/CASEClient.cpp +++ b/src/app/CASEClient.cpp @@ -35,8 +35,8 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres Optional session = mInitParams.sessionManager->CreateUnauthenticatedSession(peerAddress, mrpConfig); VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY); - uint16_t keyID = 0; - ReturnErrorOnFailure(mInitParams.idAllocator->Allocate(keyID)); + SessionHolder secureSessionHolder = mInitParams.sessionManager->AllocateSession(); + VerifyOrReturnError(secureSessionHolder, CHIP_ERROR_NO_MEMORY); // Allocate the exchange immediately before calling CASESession::EstablishSession. // @@ -48,8 +48,8 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); mCASESession.SetGroupDataProvider(mInitParams.groupDataProvider); - ReturnErrorOnFailure(mCASESession.EstablishSession(peerAddress, mInitParams.fabricInfo, peer.GetNodeId(), keyID, exchange, this, - mInitParams.mrpLocalConfig)); + ReturnErrorOnFailure(mCASESession.EstablishSession(peerAddress, mInitParams.fabricInfo, peer.GetNodeId(), secureSessionHolder, + exchange, this, mInitParams.mrpLocalConfig)); mConnectionSuccessCallback = onConnection; mConnectionFailureCallback = onFailure; mConectionContext = context; @@ -61,8 +61,6 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres void CASEClient::OnSessionEstablishmentError(CHIP_ERROR error) { - mInitParams.idAllocator->Free(mCASESession.GetLocalSessionId()); - if (mConnectionFailureCallback) { mConnectionFailureCallback(mConectionContext, this, error); diff --git a/src/app/CASEClient.h b/src/app/CASEClient.h index 6a1c708fc3a7de..fba48029a0ecca 100644 --- a/src/app/CASEClient.h +++ b/src/app/CASEClient.h @@ -21,7 +21,6 @@ #include #include #include -#include namespace chip { @@ -34,7 +33,6 @@ struct CASEClientInitParams { SessionManager * sessionManager = nullptr; Messaging::ExchangeManager * exchangeMgr = nullptr; - SessionIDAllocator * idAllocator = nullptr; FabricInfo * fabricInfo = nullptr; Credentials::GroupDataProvider * groupDataProvider = nullptr; diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index efc18d11c3a9b5..0ce22499730342 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -167,9 +167,9 @@ bool OperationalDeviceProxy::GetAddress(Inet::IPAddress & addr, uint16_t & port) CHIP_ERROR OperationalDeviceProxy::EstablishConnection() { - mCASEClient = mInitParams.clientPool->Allocate( - CASEClientInitParams{ mInitParams.sessionManager, mInitParams.exchangeMgr, mInitParams.idAllocator, mFabricInfo, - mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); + mCASEClient = + mInitParams.clientPool->Allocate(CASEClientInitParams{ mInitParams.sessionManager, mInitParams.exchangeMgr, mFabricInfo, + mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, mMRPConfig, HandleCASEConnected, HandleCASEConnectionFailure, this); diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 463f05061716ad..16e4f6e63b6594 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -38,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -51,7 +50,6 @@ struct DeviceProxyInitParams { SessionManager * sessionManager = nullptr; Messaging::ExchangeManager * exchangeMgr = nullptr; - SessionIDAllocator * idAllocator = nullptr; FabricTable * fabricTable = nullptr; CASEClientPoolDelegate * clientPool = nullptr; Credentials::GroupDataProvider * groupDataProvider = nullptr; @@ -62,7 +60,6 @@ struct DeviceProxyInitParams { ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); - ReturnErrorCodeIf(idAllocator == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(clientPool == nullptr, CHIP_ERROR_INCORRECT_STATE); diff --git a/src/app/server/CommissioningWindowManager.cpp b/src/app/server/CommissioningWindowManager.cpp index a6fc8ae7c9f5a3..f21016812a69c8 100644 --- a/src/app/server/CommissioningWindowManager.cpp +++ b/src/app/server/CommissioningWindowManager.cpp @@ -176,8 +176,8 @@ CHIP_ERROR CommissioningWindowManager::AdvertiseAndListenForPASE() { VerifyOrReturnError(mCommissioningTimeoutTimerArmed, CHIP_ERROR_INCORRECT_STATE); - uint16_t keyID = 0; - ReturnErrorOnFailure(mIDAllocator->Allocate(keyID)); + SessionHolder secureSessionHolder = mServer->GetSecureSessionManager().AllocateSession(); + VerifyOrReturnError(secureSessionHolder, CHIP_ERROR_NO_MEMORY); mPairingSession.Clear(); @@ -188,9 +188,9 @@ CHIP_ERROR CommissioningWindowManager::AdvertiseAndListenForPASE() if (mUseECM) { ReturnErrorOnFailure(SetTemporaryDiscriminator(mECMDiscriminator)); - ReturnErrorOnFailure( - mPairingSession.WaitForPairing(mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), keyID, - Optional::Value(GetLocalMRPConfig()), this)); + ReturnErrorOnFailure(mPairingSession.WaitForPairing( + mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), secureSessionHolder, + Optional::Value(GetLocalMRPConfig()), this)); } else { @@ -211,8 +211,9 @@ CHIP_ERROR CommissioningWindowManager::AdvertiseAndListenForPASE() ReturnErrorOnFailure(verifier.Deserialize(ByteSpan(serializedVerifier))); - ReturnErrorOnFailure(mPairingSession.WaitForPairing( - verifier, iterationCount, saltSpan, keyID, Optional::Value(GetLocalMRPConfig()), this)); + ReturnErrorOnFailure(mPairingSession.WaitForPairing(verifier, iterationCount, saltSpan, secureSessionHolder, + Optional::Value(GetLocalMRPConfig()), + this)); } ReturnErrorOnFailure(StartAdvertisement()); diff --git a/src/app/server/CommissioningWindowManager.h b/src/app/server/CommissioningWindowManager.h index fffe78b4ade5cc..6cc557eaf9cd75 100644 --- a/src/app/server/CommissioningWindowManager.h +++ b/src/app/server/CommissioningWindowManager.h @@ -23,7 +23,6 @@ #include #include #include -#include #include namespace chip { @@ -65,8 +64,6 @@ class CommissioningWindowManager : public SessionEstablishmentDelegate, public a void SetAppDelegate(AppDelegate * delegate) { mAppDelegate = delegate; } - void SetSessionIDAllocator(SessionIDAllocator * idAllocator) { mIDAllocator = idAllocator; } - /** * Open the pairing window using default configured parameters. */ @@ -146,7 +143,6 @@ class CommissioningWindowManager : public SessionEstablishmentDelegate, public a bool mIsBLE = true; - SessionIDAllocator * mIDAllocator = nullptr; PASESession mPairingSession; uint8_t mFailedCommissioningAttempts = 0; diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 498169e2bde733..a9c8fd56483d2c 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -111,7 +111,6 @@ CHIP_ERROR Server::Init(AppDelegate * delegate, uint16_t secureServicePort, uint SuccessOrExit(err = mCommissioningWindowManager.Init(this)); mCommissioningWindowManager.SetAppDelegate(delegate); - mCommissioningWindowManager.SetSessionIDAllocator(&mSessionIDAllocator); // Set up attribute persistence before we try to bring up the data model // handler. @@ -241,7 +240,6 @@ CHIP_ERROR Server::Init(AppDelegate * delegate, uint16_t secureServicePort, uint .sessionInitParams = { .sessionManager = &mSessions, .exchangeMgr = &mExchangeMgr, - .idAllocator = &mSessionIDAllocator, .fabricTable = &mFabrics, .clientPool = &mCASEClientPool, .groupDataProvider = &mGroupsProvider, diff --git a/src/app/server/Server.h b/src/app/server/Server.h index 78adc35bfb292e..afba9c4b17d3a0 100644 --- a/src/app/server/Server.h +++ b/src/app/server/Server.h @@ -80,8 +80,6 @@ class Server Messaging::ExchangeManager & GetExchangeManager() { return mExchangeMgr; } - SessionIDAllocator & GetSessionIDAllocator() { return mSessionIDAllocator; } - SessionManager & GetSecureSessionManager() { return mSessions; } TransportMgrBase & GetTransportManager() { return mTransports; } @@ -248,12 +246,10 @@ class Server Messaging::ExchangeManager mExchangeMgr; FabricTable mFabrics; - SessionIDAllocator mSessionIDAllocator; secure_channel::MessageCounterManager mMessageCounterManager; #if CHIP_DEVICE_CONFIG_ENABLE_COMMISSIONER_DISCOVERY_CLIENT chip::Protocols::UserDirectedCommissioning::UserDirectedCommissioningClient gUDCClient; #endif // CHIP_DEVICE_CONFIG_ENABLE_COMMISSIONER_DISCOVERY_CLIENT - SecurePairingUsingTestSecret mTestPairing; CommissioningWindowManager mCommissioningWindowManager; // Both PersistentStorageDelegate, and GroupDataProvider should be injected by the applications diff --git a/src/app/tests/TestOperationalDeviceProxy.cpp b/src/app/tests/TestOperationalDeviceProxy.cpp index cf31ff169520ee..0a32373f57c443 100644 --- a/src/app/tests/TestOperationalDeviceProxy.cpp +++ b/src/app/tests/TestOperationalDeviceProxy.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -56,7 +55,6 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, VerifyOrDie(fabric != nullptr); secure_channel::MessageCounterManager messageCounterManager; chip::TestPersistentStorageDelegate deviceStorage; - SessionIDAllocator idAllocator; GroupDataProviderImpl groupDataProvider; systemLayer.Init(); @@ -72,7 +70,6 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, DeviceProxyInitParams params = { .sessionManager = &sessionManager, .exchangeMgr = &exchangeMgr, - .idAllocator = &idAllocator, .fabricInfo = fabric, .groupDataProvider = &groupDataProvider, }; diff --git a/src/app/tests/integration/chip_im_initiator.cpp b/src/app/tests/integration/chip_im_initiator.cpp index a016509340f082..90bb9caf41027e 100644 --- a/src/app/tests/integration/chip_im_initiator.cpp +++ b/src/app/tests/integration/chip_im_initiator.cpp @@ -433,6 +433,7 @@ CHIP_ERROR EstablishSecureSession() chip::SecurePairingUsingTestSecret * testSecurePairingSecret = chip::Platform::New(); VerifyOrExit(testSecurePairingSecret != nullptr, err = CHIP_ERROR_NO_MEMORY); + testSecurePairingSecret->Init(gSessionManager); // Attempt to connect to the peer. err = gSessionManager.NewPairing(gSession, diff --git a/src/app/tests/integration/chip_im_responder.cpp b/src/app/tests/integration/chip_im_responder.cpp index fe793f407c8a0f..b05fb611e4a73a 100644 --- a/src/app/tests/integration/chip_im_responder.cpp +++ b/src/app/tests/integration/chip_im_responder.cpp @@ -197,6 +197,7 @@ int main(int argc, char * argv[]) InitializeEventLogging(&gExchangeManager); + gTestPairing.Init(gSessionManager); err = gSessionManager.NewPairing(gSession, peer, chip::kTestControllerNodeId, &gTestPairing, chip::CryptoContext::SessionRole::kResponder, gFabricIndex); SuccessOrExit(err); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index c1c182f18851b4..5f888d3d837edd 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -413,7 +413,6 @@ ControllerDeviceInitParams DeviceController::GetControllerDeviceInitParams() .exchangeMgr = mSystemState->ExchangeMgr(), .udpEndPointManager = mSystemState->UDPEndPointManager(), .storageDelegate = mStorageDelegate, - .idAllocator = mSystemState->SessionIDAlloc(), .fabricsTable = mSystemState->Fabrics(), }; } @@ -610,8 +609,7 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re Messaging::ExchangeContext * exchangeCtxt = nullptr; Optional session; - - uint16_t keyID = 0; + SessionHolder secureSessionHolder; VerifyOrExit(mState == State::Initialized, err = CHIP_ERROR_INCORRECT_STATE); VerifyOrExit(mDeviceInPASEEstablishment == nullptr, err = CHIP_ERROR_INCORRECT_STATE); @@ -677,8 +675,8 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re session = mSystemState->SessionMgr()->CreateUnauthenticatedSession(params.GetPeerAddress(), device->GetMRPConfig()); VerifyOrExit(session.HasValue(), err = CHIP_ERROR_NO_MEMORY); - err = mSystemState->SessionIDAlloc()->Allocate(keyID); - SuccessOrExit(err); + secureSessionHolder = mSystemState->SessionMgr()->AllocateSession(); + VerifyOrExit(secureSessionHolder, CHIP_ERROR_NO_MEMORY); // TODO - Remove use of SetActive/IsActive from CommissioneeDeviceProxy device->SetActive(true); @@ -692,7 +690,7 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re exchangeCtxt = mSystemState->ExchangeMgr()->NewContext(session.Value(), &device->GetPairing()); VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL); - err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID, + err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), secureSessionHolder, Optional::Value(GetLocalMRPConfig()), exchangeCtxt, this); SuccessOrExit(err); diff --git a/src/controller/CHIPDeviceControllerFactory.cpp b/src/controller/CHIPDeviceControllerFactory.cpp index 3f0c03cf1f7717..f5e5aab66bb278 100644 --- a/src/controller/CHIPDeviceControllerFactory.cpp +++ b/src/controller/CHIPDeviceControllerFactory.cpp @@ -214,14 +214,12 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) chip::app::DnssdServer::Instance().StartServer(); } - stateParams.sessionIDAllocator = Platform::New(); stateParams.operationalDevicePool = Platform::New(); stateParams.caseClientPool = Platform::New(); DeviceProxyInitParams deviceInitParams = { .sessionManager = stateParams.sessionMgr, .exchangeMgr = stateParams.exchangeMgr, - .idAllocator = stateParams.sessionIDAllocator, .fabricTable = stateParams.fabricTable, .clientPool = stateParams.caseClientPool, .groupDataProvider = stateParams.groupDataProvider, @@ -336,13 +334,8 @@ CHIP_ERROR DeviceControllerSystemState::Shutdown() mCASESessionManager = nullptr; } - // mSessionIDAllocator, mCASEClientPool, and mDevicePool must be deallocated + // mCASEClientPool and mDevicePool must be deallocated // after mCASESessionManager, which uses them. - if (mSessionIDAllocator != nullptr) - { - Platform::Delete(mSessionIDAllocator); - mSessionIDAllocator = nullptr; - } if (mOperationalDevicePool != nullptr) { diff --git a/src/controller/CHIPDeviceControllerSystemState.h b/src/controller/CHIPDeviceControllerSystemState.h index a3818527e99463..d6cba1c7eceee0 100644 --- a/src/controller/CHIPDeviceControllerSystemState.h +++ b/src/controller/CHIPDeviceControllerSystemState.h @@ -36,7 +36,6 @@ #include #include #include -#include #include #include @@ -88,7 +87,6 @@ struct DeviceControllerSystemStateParams FabricTable * fabricTable = nullptr; CASEServer * caseServer = nullptr; CASESessionManager * caseSessionManager = nullptr; - SessionIDAllocator * sessionIDAllocator = nullptr; OperationalDevicePool * operationalDevicePool = nullptr; CASEClientPool * caseClientPool = nullptr; Credentials::GroupDataProvider * groupDataProvider = nullptr; @@ -109,8 +107,8 @@ class DeviceControllerSystemState mUDPEndPointManager(params.udpEndPointManager), mTransportMgr(params.transportMgr), mSessionMgr(params.sessionMgr), mExchangeMgr(params.exchangeMgr), mMessageCounterManager(params.messageCounterManager), mFabrics(params.fabricTable), mCASEServer(params.caseServer), mCASESessionManager(params.caseSessionManager), - mSessionIDAllocator(params.sessionIDAllocator), mOperationalDevicePool(params.operationalDevicePool), - mCASEClientPool(params.caseClientPool), mGroupDataProvider(params.groupDataProvider) + mOperationalDevicePool(params.operationalDevicePool), mCASEClientPool(params.caseClientPool), + mGroupDataProvider(params.groupDataProvider) { #if CONFIG_NETWORK_LAYER_BLE mBleLayer = params.bleLayer; @@ -143,8 +141,7 @@ class DeviceControllerSystemState { return mSystemLayer != nullptr && mUDPEndPointManager != nullptr && mTransportMgr != nullptr && mSessionMgr != nullptr && mExchangeMgr != nullptr && mMessageCounterManager != nullptr && mFabrics != nullptr && mCASESessionManager != nullptr && - mSessionIDAllocator != nullptr && mOperationalDevicePool != nullptr && mCASEClientPool != nullptr && - mGroupDataProvider != nullptr; + mOperationalDevicePool != nullptr && mCASEClientPool != nullptr && mGroupDataProvider != nullptr; }; System::Layer * SystemLayer() { return mSystemLayer; }; @@ -159,7 +156,6 @@ class DeviceControllerSystemState Ble::BleLayer * BleLayer() { return mBleLayer; }; #endif CASESessionManager * CASESessionMgr() const { return mCASESessionManager; } - SessionIDAllocator * SessionIDAlloc() const { return mSessionIDAllocator; } Credentials::GroupDataProvider * GetGroupDataProvider() const { return mGroupDataProvider; } private: @@ -178,7 +174,6 @@ class DeviceControllerSystemState FabricTable * mFabrics = nullptr; CASEServer * mCASEServer = nullptr; CASESessionManager * mCASESessionManager = nullptr; - SessionIDAllocator * mSessionIDAllocator = nullptr; OperationalDevicePool * mOperationalDevicePool = nullptr; CASEClientPool * mCASEClientPool = nullptr; Credentials::GroupDataProvider * mGroupDataProvider = nullptr; diff --git a/src/controller/CommissioneeDeviceProxy.h b/src/controller/CommissioneeDeviceProxy.h index 6d053ffec9b9d2..fa9278c593713f 100644 --- a/src/controller/CommissioneeDeviceProxy.h +++ b/src/controller/CommissioneeDeviceProxy.h @@ -39,7 +39,6 @@ #include #include #include -#include #include #include #include @@ -69,7 +68,6 @@ struct ControllerDeviceInitParams Messaging::ExchangeManager * exchangeMgr = nullptr; Inet::EndPointManager * udpEndPointManager = nullptr; PersistentStorageDelegate * storageDelegate = nullptr; - SessionIDAllocator * idAllocator = nullptr; #if CONFIG_NETWORK_LAYER_BLE Ble::BleLayer * bleLayer = nullptr; #endif @@ -120,7 +118,6 @@ class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegat mExchangeMgr = params.exchangeMgr; mUDPEndPointManager = params.udpEndPointManager; mFabricIndex = fabric; - mIDAllocator = params.idAllocator; #if CONFIG_NETWORK_LAYER_BLE mBleLayer = params.bleLayer; #endif @@ -287,8 +284,6 @@ class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegat CHIP_ERROR LoadSecureSessionParametersIfNeeded(bool & didLoad); FabricIndex mFabricIndex = kUndefinedFabricIndex; - - SessionIDAllocator * mIDAllocator = nullptr; }; } // namespace chip diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp index 9fe19f46ca78ba..81231681ff0216 100644 --- a/src/messaging/tests/MessagingContext.cpp +++ b/src/messaging/tests/MessagingContext.cpp @@ -95,6 +95,10 @@ CHIP_ERROR MessagingContext::ShutdownAndRestoreExisting(MessagingContext & exist CHIP_ERROR MessagingContext::CreateSessionBobToAlice() { + if (!mPairingBobToAlice.GetSecureSessionHolder()) + { + mPairingBobToAlice.Init(mSessionManager); + } return mSessionManager.NewPairing(mSessionBobToAlice, Optional::Value(mAliceAddress), GetAliceFabric()->GetNodeId(), &mPairingBobToAlice, CryptoContext::SessionRole::kInitiator, mBobFabricIndex); @@ -102,6 +106,10 @@ CHIP_ERROR MessagingContext::CreateSessionBobToAlice() CHIP_ERROR MessagingContext::CreateSessionAliceToBob() { + if (!mPairingAliceToBob.GetSecureSessionHolder()) + { + mPairingAliceToBob.Init(mSessionManager); + } return mSessionManager.NewPairing(mSessionAliceToBob, Optional::Value(mBobAddress), GetBobFabric()->GetNodeId(), &mPairingAliceToBob, CryptoContext::SessionRole::kResponder, mAliceFabricIndex); diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index bff1a6e9606f72..902749df580196 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -71,8 +71,9 @@ class MessagingContext : public PlatformMemoryUser public: MessagingContext() : mInitialized(false), mAliceAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT + 1)), - mBobAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)), mPairingAliceToBob(kBobKeyId, kAliceKeyId), - mPairingBobToAlice(kAliceKeyId, kBobKeyId) + mBobAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)), + mPairingAliceToBob(kBobKeyId, kAliceKeyId, GetSecureSessionManager()), + mPairingBobToAlice(kAliceKeyId, kBobKeyId, GetSecureSessionManager()) {} ~MessagingContext() { VerifyOrDie(mInitialized == false); } diff --git a/src/messaging/tests/echo/echo_requester.cpp b/src/messaging/tests/echo/echo_requester.cpp index e8c35e798d8937..f090811a0a440c 100644 --- a/src/messaging/tests/echo/echo_requester.cpp +++ b/src/messaging/tests/echo/echo_requester.cpp @@ -155,6 +155,7 @@ CHIP_ERROR EstablishSecureSession() chip::Optional peerAddr; chip::SecurePairingUsingTestSecret * testSecurePairingSecret = chip::Platform::New(); VerifyOrExit(testSecurePairingSecret != nullptr, err = CHIP_ERROR_NO_MEMORY); + testSecurePairingSecret->Init(gSessionManager); if (gUseTCP) { diff --git a/src/messaging/tests/echo/echo_responder.cpp b/src/messaging/tests/echo/echo_responder.cpp index c8255b1ef2703c..1d1618bfcbbe2a 100644 --- a/src/messaging/tests/echo/echo_responder.cpp +++ b/src/messaging/tests/echo/echo_responder.cpp @@ -123,6 +123,7 @@ int main(int argc, char * argv[]) SuccessOrExit(err); } + gTestPairing.Init(gSessionManager); err = gSessionManager.NewPairing(gSession, peer, chip::kTestControllerNodeId, &gTestPairing, chip::CryptoContext::SessionRole::kResponder, gFabricIndex); SuccessOrExit(err); diff --git a/src/protocols/secure_channel/BUILD.gn b/src/protocols/secure_channel/BUILD.gn index 5037378f195e9f..4cc84703343fcd 100644 --- a/src/protocols/secure_channel/BUILD.gn +++ b/src/protocols/secure_channel/BUILD.gn @@ -18,8 +18,6 @@ static_library("secure_channel") { "SessionEstablishmentDelegate.h", "SessionEstablishmentExchangeDispatch.cpp", "SessionEstablishmentExchangeDispatch.h", - "SessionIDAllocator.cpp", - "SessionIDAllocator.h", "StatusReport.cpp", "StatusReport.h", ] diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 59dcffe9943f2c..822d1d9586d9de 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -74,12 +74,13 @@ CHIP_ERROR CASEServer::InitCASEHandshake(Messaging::ExchangeContext * ec) } #endif - ReturnErrorOnFailure(mSessionIDAllocator.Allocate(mSessionKeyId)); + SessionHolder secureSessionHolder = mSessionManager->AllocateSession(); + VerifyOrReturnError(secureSessionHolder, CHIP_ERROR_NO_MEMORY); // Setup CASE state machine using the credentials for the current fabric. GetSession().SetGroupDataProvider(mGroupDataProvider); ReturnErrorOnFailure(GetSession().ListenForSessionEstablishment( - mSessionKeyId, mFabrics, this, Optional::Value(GetLocalMRPConfig()))); + secureSessionHolder, mFabrics, this, Optional::Value(GetLocalMRPConfig()))); // Hand over the exchange context to the CASE session. ec->SetDelegate(&GetSession()); @@ -123,7 +124,6 @@ void CASEServer::Cleanup() void CASEServer::OnSessionEstablishmentError(CHIP_ERROR err) { ChipLogError(Inet, "CASE Session establishment failed: %s", ErrorStr(err)); - mSessionIDAllocator.Free(mSessionKeyId); Cleanup(); } diff --git a/src/protocols/secure_channel/CASEServer.h b/src/protocols/secure_channel/CASEServer.h index f57ef94baaf87e..6e93f558a88ee3 100644 --- a/src/protocols/secure_channel/CASEServer.h +++ b/src/protocols/secure_channel/CASEServer.h @@ -24,7 +24,6 @@ #include #include #include -#include namespace chip { @@ -63,7 +62,6 @@ class CASEServer : public SessionEstablishmentDelegate, public Messaging::Exchan Messaging::ExchangeManager * mExchangeManager = nullptr; CASESession mPairingSession; - uint16_t mSessionKeyId = 0; SessionManager * mSessionManager = nullptr; #if CONFIG_NETWORK_LAYER_BLE Ble::BleLayer * mBleLayer = nullptr; @@ -71,7 +69,6 @@ class CASEServer : public SessionEstablishmentDelegate, public Messaging::Exchan FabricTable * mFabrics = nullptr; Credentials::GroupDataProvider * mGroupDataProvider = nullptr; - SessionIDAllocator mSessionIDAllocator; CHIP_ERROR InitCASEHandshake(Messaging::ExchangeContext * ec); diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 0ef0f76c0f6efc..b42ced79221b2e 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -187,18 +187,18 @@ CHIP_ERROR CASESession::FromCachable(const CASESessionCachable & cachableSession return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::Init(uint16_t localSessionId, SessionEstablishmentDelegate * delegate) +CHIP_ERROR CASESession::Init(SessionHolder secureSessionHolder, SessionEstablishmentDelegate * delegate) { VerifyOrReturnError(delegate != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(mGroupDataProvider != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(secureSessionHolder && secureSessionHolder->IsSecureSession(), CHIP_ERROR_INVALID_ARGUMENT); Clear(); ReturnErrorOnFailure(mCommissioningHash.Begin()); + SetSecureSessionHolder(secureSessionHolder); mDelegate = delegate; - SetLocalSessionId(localSessionId); mValidContext.Reset(); mValidContext.mRequiredKeyUsages.Set(KeyUsageFlags::kDigitalSignature); @@ -208,11 +208,12 @@ CHIP_ERROR CASESession::Init(uint16_t localSessionId, SessionEstablishmentDelega } CHIP_ERROR -CASESession::ListenForSessionEstablishment(uint16_t localSessionId, FabricTable * fabrics, SessionEstablishmentDelegate * delegate, +CASESession::ListenForSessionEstablishment(SessionHolder secureSessionHolder, FabricTable * fabrics, + SessionEstablishmentDelegate * delegate, Optional mrpConfig) { VerifyOrReturnError(fabrics != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - ReturnErrorOnFailure(Init(localSessionId, delegate)); + ReturnErrorOnFailure(Init(secureSessionHolder, delegate)); mFabricsTable = fabrics; mLocalMRPConfig = mrpConfig; @@ -225,7 +226,7 @@ CASESession::ListenForSessionEstablishment(uint16_t localSessionId, FabricTable } CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddress, FabricInfo * fabric, NodeId peerNodeId, - uint16_t localSessionId, ExchangeContext * exchangeCtxt, + SessionHolder secureSessionHolder, ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate, Optional mrpConfig) { MATTER_TRACE_EVENT_SCOPE("EstablishSession", "CASESession"); @@ -241,7 +242,7 @@ CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddres ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); ReturnErrorCodeIf(fabric == nullptr, CHIP_ERROR_INVALID_ARGUMENT); - err = Init(localSessionId, delegate); + err = Init(secureSessionHolder, delegate); // We are setting the exchange context specifically before checking for error. // This is to make sure the exchange will get closed if Init() returned an error. @@ -358,6 +359,9 @@ CHIP_ERROR CASESession::SendSigma1() TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; uint8_t destinationIdentifier[kSHA256_Hash_Length] = { 0 }; + // Validate that we have a session ID allocated. + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + // Generate an ephemeral keypair ReturnErrorOnFailure(mEphemeralKey.Initialize()); @@ -372,7 +376,7 @@ CHIP_ERROR CASESession::SendSigma1() ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), ByteSpan(mInitiatorRandom))); // Retrieve Session Identifier - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId().Value())); // Generate a Destination Identifier based on the node we are attempting to reach { ReturnErrorCodeIf(mFabricInfo == nullptr, CHIP_ERROR_INCORRECT_STATE); @@ -582,6 +586,9 @@ CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) System::PacketBufferHandle msg_R2_resume; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + // Validate that we have a session ID allocated. + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + msg_R2_resume = System::PacketBufferHandle::New(max_sigma2_resume_data_len); VerifyOrReturnError(!msg_R2_resume.IsNull(), CHIP_ERROR_NO_MEMORY); @@ -600,7 +607,7 @@ CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), resumeMICSpan)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId().Value())); if (mLocalMRPConfig.HasValue()) { @@ -625,6 +632,9 @@ CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) CHIP_ERROR CASESession::SendSigma2() { MATTER_TRACE_EVENT_SCOPE("SendSigma2", "CASESession"); + + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mFabricInfo != nullptr, CHIP_ERROR_INCORRECT_STATE); ByteSpan icaCert; @@ -724,7 +734,7 @@ CHIP_ERROR CASESession::SendSigma2() tlvWriterMsg2.Init(std::move(msg_R2)); ReturnErrorOnFailure(tlvWriterMsg2.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(TLV::ContextTag(1), &msg_rand[0], sizeof(msg_rand))); - ReturnErrorOnFailure(tlvWriterMsg2.Put(TLV::ContextTag(2), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriterMsg2.Put(TLV::ContextTag(2), GetLocalSessionId().Value())); ReturnErrorOnFailure( tlvWriterMsg2.PutBytes(TLV::ContextTag(3), mEphemeralKey.Pubkey(), static_cast(mEphemeralKey.Pubkey().Length()))); ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(TLV::ContextTag(4), msg_R2_Encrypted.Get(), diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 63e51073b64606..1aa2a1620d790b 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -79,14 +79,14 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin * @brief * Initialize using configured fabrics and wait for session establishment requests. * - * @param mySessionId Session ID to be assigned to the secure session on the peer node + * @param secureSessionHolder Pre-allocated SecureSession holder from SessionManager * @param fabrics Table of fabrics that are currently configured on the device * @param delegate Callback object * * @return CHIP_ERROR The result of initialization */ CHIP_ERROR ListenForSessionEstablishment( - uint16_t mySessionId, FabricTable * fabrics, SessionEstablishmentDelegate * delegate, + SessionHolder secureSessionHolder, FabricTable * fabrics, SessionEstablishmentDelegate * delegate, Optional mrpConfig = Optional::Missing()); /** @@ -96,15 +96,16 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin * @param peerAddress Address of peer with which to establish a session. * @param fabric The fabric that should be used for connecting with the peer * @param peerNodeId Node id of the peer node - * @param mySessionId Session ID to be assigned to the secure session on the peer node + * @param secureSessionHolder Pre-allocated SecureSession holder from SessionManager * @param exchangeCtxt The exchange context to send and receive messages with the peer * @param delegate Callback object * * @return CHIP_ERROR The result of initialization */ CHIP_ERROR - EstablishSession(const Transport::PeerAddress peerAddress, FabricInfo * fabric, NodeId peerNodeId, uint16_t mySessionId, - Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate, + EstablishSession(const Transport::PeerAddress peerAddress, FabricInfo * fabric, NodeId peerNodeId, + SessionHolder secureSessionHolder, Messaging::ExchangeContext * exchangeCtxt, + SessionEstablishmentDelegate * delegate, Optional mrpConfig = Optional::Missing()); /** @@ -190,7 +191,7 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin kSentSigma2Resume = 4, }; - CHIP_ERROR Init(uint16_t mySessionId, SessionEstablishmentDelegate * delegate); + CHIP_ERROR Init(SessionHolder secureSessionHolder, SessionEstablishmentDelegate * delegate); // On success, sets mIpk to the correct value for outgoing Sigma1 based on internal state CHIP_ERROR RecoverInitiatorIpk(); diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index c714f860225ae4..8bca46e21161ed 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -117,75 +117,10 @@ void PASESession::DiscardExchange() } } -CHIP_ERROR PASESession::Serialize(PASESessionSerialized & output) -{ - PASESessionSerializable serializable; - VerifyOrReturnError(BASE64_ENCODED_LEN(sizeof(serializable)) <= sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT); - - ReturnErrorOnFailure(ToSerializable(serializable)); - - uint16_t serializedLen = chip::Base64Encode(Uint8::to_const_uchar(reinterpret_cast(&serializable)), - static_cast(sizeof(serializable)), Uint8::to_char(output.inner)); - VerifyOrReturnError(serializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(serializedLen < sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT); - output.inner[serializedLen] = '\0'; - - return CHIP_NO_ERROR; -} - -CHIP_ERROR PASESession::Deserialize(PASESessionSerialized & input) -{ - PASESessionSerializable serializable; - size_t maxlen = BASE64_ENCODED_LEN(sizeof(serializable)); - size_t len = strnlen(Uint8::to_char(input.inner), maxlen); - uint16_t deserializedLen = 0; - - VerifyOrReturnError(len < sizeof(PASESessionSerialized), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_INVALID_ARGUMENT); - - memset(&serializable, 0, sizeof(serializable)); - deserializedLen = - Base64Decode(Uint8::to_const_char(input.inner), static_cast(len), Uint8::to_uchar((uint8_t *) &serializable)); - - VerifyOrReturnError(deserializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(deserializedLen <= sizeof(serializable), CHIP_ERROR_INVALID_ARGUMENT); - - return FromSerializable(serializable); -} - -CHIP_ERROR PASESession::ToSerializable(PASESessionSerializable & serializable) -{ - VerifyOrReturnError(CanCastTo(mKeLen), CHIP_ERROR_INTERNAL); - - memset(&serializable, 0, sizeof(serializable)); - serializable.mKeLen = static_cast(mKeLen); - serializable.mPairingComplete = (mPairingComplete) ? 1 : 0; - serializable.mLocalSessionId = GetLocalSessionId(); - serializable.mPeerSessionId = GetPeerSessionId(); - - memcpy(serializable.mKe, mKe, mKeLen); - - return CHIP_NO_ERROR; -} - -CHIP_ERROR PASESession::FromSerializable(const PASESessionSerializable & serializable) -{ - mPairingComplete = (serializable.mPairingComplete == 1); - mKeLen = static_cast(serializable.mKeLen); - - VerifyOrReturnError(mKeLen <= sizeof(mKe), CHIP_ERROR_INVALID_ARGUMENT); - memset(mKe, 0, sizeof(mKe)); - memcpy(mKe, serializable.mKe, mKeLen); - - SetLocalSessionId(serializable.mLocalSessionId); - SetPeerSessionId(serializable.mPeerSessionId); - - return CHIP_NO_ERROR; -} - -CHIP_ERROR PASESession::Init(uint16_t mySessionId, uint32_t setupCode, SessionEstablishmentDelegate * delegate) +CHIP_ERROR PASESession::Init(SessionHolder secureSessionHolder, uint32_t setupCode, SessionEstablishmentDelegate * delegate) { VerifyOrReturnError(delegate != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(secureSessionHolder && secureSessionHolder->IsSecureSession(), CHIP_ERROR_INVALID_ARGUMENT); // Reset any state maintained by PASESession object (in case it's being reused for pairing) Clear(); @@ -194,9 +129,9 @@ CHIP_ERROR PASESession::Init(uint16_t mySessionId, uint32_t setupCode, SessionEs ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ Uint8::from_const_char(kSpake2pContext), strlen(kSpake2pContext) })); mDelegate = delegate; - - ChipLogDetail(SecureChannel, "Assigned local session key ID %d", mySessionId); - SetLocalSessionId(mySessionId); + SetSecureSessionHolder(secureSessionHolder); + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + ChipLogDetail(SecureChannel, "Assigned local session key ID %d", GetLocalSessionId().Value()); ReturnErrorCodeIf(setupCode >= (1 << kSetupPINCodeFieldLengthInBits), CHIP_ERROR_INVALID_ARGUMENT); mSetupPINCode = setupCode; @@ -233,7 +168,7 @@ CHIP_ERROR PASESession::SetupSpake2p() } CHIP_ERROR PASESession::WaitForPairing(const Spake2pVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, - uint16_t mySessionId, Optional mrpConfig, + SessionHolder secureSessionHolder, Optional mrpConfig, SessionEstablishmentDelegate * delegate) { // Return early on error here, as we have not initialized any state yet @@ -242,7 +177,7 @@ CHIP_ERROR PASESession::WaitForPairing(const Spake2pVerifier & verifier, uint32_ ReturnErrorCodeIf(salt.size() < kSpake2p_Min_PBKDF_Salt_Length || salt.size() > kSpake2p_Max_PBKDF_Salt_Length, CHIP_ERROR_INVALID_ARGUMENT); - CHIP_ERROR err = Init(mySessionId, kSetupPINCodeUndefinedValue, delegate); + CHIP_ERROR err = Init(secureSessionHolder, kSetupPINCodeUndefinedValue, delegate); // From here onwards, let's go to exit on error, as some state might have already // been initialized SuccessOrExit(err); @@ -279,13 +214,13 @@ CHIP_ERROR PASESession::WaitForPairing(const Spake2pVerifier & verifier, uint32_ return err; } -CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId, +CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, SessionHolder secureSessionHolder, Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate) { MATTER_TRACE_EVENT_SCOPE("Pair", "PASESession"); ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); - CHIP_ERROR err = Init(mySessionId, peerSetUpPINCode, delegate); + CHIP_ERROR err = Init(secureSessionHolder, peerSetUpPINCode, delegate); SuccessOrExit(err); mExchangeCtxt = exchangeCtxt; @@ -334,6 +269,9 @@ CHIP_ERROR PASESession::DeriveSecureSession(CryptoContext & session, CryptoConte CHIP_ERROR PASESession::SendPBKDFParamRequest() { MATTER_TRACE_EVENT_SCOPE("SendPBKDFParamRequest", "PASESession"); + + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; @@ -353,7 +291,7 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(1), mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId().Value())); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), kDefaultCommissioningPasscodeId)); ReturnErrorOnFailure(tlvWriter.PutBoolean(TLV::ContextTag(4), mHavePBKDFParameters)); if (mLocalMRPConfig.HasValue()) @@ -442,6 +380,9 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool initiatorHasPBKDFParams) { MATTER_TRACE_EVENT_SCOPE("SendPBKDFParamResponse", "PASESession"); + + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; @@ -464,7 +405,7 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in // The initiator random value is being sent back in the response as required by the specifications ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), initiatorRandom)); ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId().Value())); if (!initiatorHasPBKDFParams) { diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index a6f9b687c6566a..d1093c6de332b4 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -83,34 +83,34 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin * @brief * Initialize using PASE verifier and wait for pairing requests. * - * @param verifier PASE verifier to be used for SPAKE2P pairing - * @param pbkdf2IterCount Iteration count for PBKDF2 function - * @param salt Salt to be used for SPAKE2P operation - * @param mySessionId Session ID to be assigned to the secure session on the peer node - * @param delegate Callback object + * @param verifier PASE verifier to be used for SPAKE2P pairing + * @param pbkdf2IterCount Iteration count for PBKDF2 function + * @param salt Salt to be used for SPAKE2P operation + * @param secureSessionHolder Pre-allocated SecureSession holder from SessionManager + * @param delegate Callback object * * @return CHIP_ERROR The result of initialization */ CHIP_ERROR WaitForPairing(const Spake2pVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, - uint16_t mySessionId, Optional mrpConfig, + SessionHolder secureSessionHolder, Optional mrpConfig, SessionEstablishmentDelegate * delegate); /** * @brief * Create a pairing request using peer's setup PIN code. * - * @param peerAddress Address of peer to pair - * @param peerSetUpPINCode Setup PIN code of the peer device - * @param mySessionId Session ID to be assigned to the secure session on the peer node - * @param exchangeCtxt The exchange context to send and receive messages with the peer - * Note: It's expected that the caller of this API hands over the - * ownership of the exchangeCtxt to PASESession object. PASESession - * will close the exchange on (successful/failed) handshake completion. - * @param delegate Callback object + * @param peerAddress Address of peer to pair + * @param peerSetUpPINCode Setup PIN code of the peer device + * @param secureSessionHolder Pre-allocated SecureSession holder from SessionManager + * @param exchangeCtxt The exchange context to send and receive messages with the peer + * Note: It's expected that the caller of this API hands over the + * ownership of the exchangeCtxt to PASESession object. PASESession + * will close the exchange on (successful/failed) handshake completion. + * @param delegate Callback object * * @return CHIP_ERROR The result of initialization */ - CHIP_ERROR Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId, + CHIP_ERROR Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, SessionHolder secureSessionHolder, Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate); @@ -141,30 +141,6 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin */ CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override; - /** @brief Serialize the Pairing Session to a string. - * - * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise - **/ - CHIP_ERROR Serialize(PASESessionSerialized & output); - - /** @brief Deserialize the Pairing Session from the string. - * - * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise - **/ - CHIP_ERROR Deserialize(PASESessionSerialized & input); - - /** @brief Serialize the PASESession to the given serializable data structure for secure pairing - * - * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise - **/ - CHIP_ERROR ToSerializable(PASESessionSerializable & output); - - /** @brief Reconstruct secure pairing class from the serializable data structure. - * - * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise - **/ - CHIP_ERROR FromSerializable(const PASESessionSerializable & output); - // TODO: remove Clear, we should create a new instance instead reset the old instance. /** @brief This function zeroes out and resets the memory used by the object. **/ @@ -205,7 +181,7 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin kUnexpected = 0xff, }; - CHIP_ERROR Init(uint16_t mySessionId, uint32_t setupCode, SessionEstablishmentDelegate * delegate); + CHIP_ERROR Init(SessionHolder secureSessionHolder, uint32_t setupCode, SessionEstablishmentDelegate * delegate); CHIP_ERROR ValidateReceivedMessage(Messaging::ExchangeContext * exchange, const PayloadHeader & payloadHeader, const System::PacketBufferHandle & msg); @@ -296,14 +272,20 @@ class SecurePairingUsingTestSecret : public PairingSession { // Do not set to 0 to prevent unwanted unsecured session // since the session type is unknown. - SetLocalSessionId(1); SetPeerSessionId(1); } - SecurePairingUsingTestSecret(uint16_t peerSessionId, uint16_t localSessionId) : - PairingSession(Transport::SecureSession::Type::kPASE) + void Init(SessionManager & sessionManager) { - SetLocalSessionId(localSessionId); + // Do not set to 0 to prevent unwanted unsecured session + // since the session type is unknown. + SetSecureSessionHolder(sessionManager.AllocateSession(mLocalSessionId)); + } + + SecurePairingUsingTestSecret(uint16_t peerSessionId, uint16_t localSessionId, SessionManager & sessionManager) : + PairingSession(Transport::SecureSession::Type::kPASE), mLocalSessionId(localSessionId) + { + SetSecureSessionHolder(sessionManager.AllocateSession(localSessionId)); SetPeerSessionId(peerSessionId); } @@ -314,28 +296,11 @@ class SecurePairingUsingTestSecret : public PairingSession CryptoContext::SessionInfoType::kSessionEstablishment, role); } - CHIP_ERROR ToSerializable(PASESessionSerializable & serializable) - { - size_t secretLen = strlen(kTestSecret); - - memset(&serializable, 0, sizeof(serializable)); - serializable.mKeLen = static_cast(secretLen); - serializable.mPairingComplete = 1; - serializable.mLocalSessionId = GetLocalSessionId(); - serializable.mPeerSessionId = GetPeerSessionId(); - - memcpy(serializable.mKe, kTestSecret, secretLen); - return CHIP_NO_ERROR; - } - private: + // Do not set to 0 to prevent unwanted unsecured session + // since the session type is unknown. + uint16_t mLocalSessionId = 1; const char * kTestSecret = CHIP_CONFIG_TEST_SHARED_SECRET_VALUE; }; -typedef struct PASESessionSerialized -{ - // Extra uint64_t to account for padding bytes (NULL termination, and some decoding overheads) - uint8_t inner[BASE64_ENCODED_LEN(sizeof(PASESessionSerializable) + sizeof(uint64_t))]; -} PASESessionSerialized; - } // namespace chip diff --git a/src/protocols/secure_channel/SessionIDAllocator.cpp b/src/protocols/secure_channel/SessionIDAllocator.cpp deleted file mode 100644 index 9df635d98d84fd..00000000000000 --- a/src/protocols/secure_channel/SessionIDAllocator.cpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include - -namespace chip { - -uint16_t SessionIDAllocator::sNextAvailable = 1; - -CHIP_ERROR SessionIDAllocator::Allocate(uint16_t & id) -{ - VerifyOrReturnError(sNextAvailable < kMaxSessionID, CHIP_ERROR_NO_MEMORY); - VerifyOrReturnError(sNextAvailable > kUnsecuredSessionId, CHIP_ERROR_INTERNAL); - id = sNextAvailable; - - // TODO - Update SessionID allocator to use freed session IDs - sNextAvailable++; - - return CHIP_NO_ERROR; -} - -void SessionIDAllocator::Free(uint16_t id) -{ - // As per spec 4.4.1.3 Session ID of 0 is reserved for Unsecure communication - if (sNextAvailable > (kUnsecuredSessionId + 1) && (sNextAvailable - 1) == id) - { - sNextAvailable--; - } -} - -CHIP_ERROR SessionIDAllocator::Reserve(uint16_t id) -{ - VerifyOrReturnError(id < kMaxSessionID, CHIP_ERROR_NO_MEMORY); - if (id >= sNextAvailable) - { - sNextAvailable = id; - sNextAvailable++; - } - - // TODO - Check if ID is already allocated in SessionIDAllocator::Reserve() - - return CHIP_NO_ERROR; -} - -CHIP_ERROR SessionIDAllocator::ReserveUpTo(uint16_t id) -{ - VerifyOrReturnError(id < kMaxSessionID, CHIP_ERROR_NO_MEMORY); - if (id >= sNextAvailable) - { - sNextAvailable = id; - sNextAvailable++; - } - - // TODO - Update ReserveUpTo to mark all IDs in use - // Current SessionIDAllocator only tracks the smallest unused session ID. - // If/when we change it to track all in use IDs, we should also update ReserveUpTo - // to reserve all individual session IDs, instead of just setting the sNextAvailable. - - return CHIP_NO_ERROR; -} - -uint16_t SessionIDAllocator::Peek() -{ - return sNextAvailable; -} - -} // namespace chip diff --git a/src/protocols/secure_channel/SessionIDAllocator.h b/src/protocols/secure_channel/SessionIDAllocator.h deleted file mode 100644 index c56b292c1b3ed8..00000000000000 --- a/src/protocols/secure_channel/SessionIDAllocator.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -// Spec 4.4.1.3 -// ===== Session ID (16 bits) -// An unsigned integer value identifying the session associated with this message. -// The session identifies the particular key used to encrypt a message out of the set of -// available keys (either session or group), and the particular encryption/message -// integrity algorithm to use for the message.The Session ID field is always present. -// A Session ID of 0 SHALL indicate an unsecured session with no encryption or message integrity checking. -// -// The Session ID is allocated from a global numerical space shared across all fabrics and nodes on the resident process instance. -// - -namespace chip { - -class SessionIDAllocator -{ -public: - SessionIDAllocator() {} - ~SessionIDAllocator() {} - - CHIP_ERROR Allocate(uint16_t & id); - void Free(uint16_t id); - CHIP_ERROR Reserve(uint16_t id); - CHIP_ERROR ReserveUpTo(uint16_t id); - uint16_t Peek(); - -private: - static constexpr uint16_t kMaxSessionID = UINT16_MAX; - static constexpr uint16_t kUnsecuredSessionId = 0; - - static uint16_t sNextAvailable; -}; - -} // namespace chip diff --git a/src/protocols/secure_channel/tests/BUILD.gn b/src/protocols/secure_channel/tests/BUILD.gn index c2ce3df2fe54b7..070a757fbdcf97 100644 --- a/src/protocols/secure_channel/tests/BUILD.gn +++ b/src/protocols/secure_channel/tests/BUILD.gn @@ -15,7 +15,6 @@ chip_test_suite("tests") { # TODO - Fix Message Counter Sync to use group key # "TestMessageCounterManager.cpp", "TestPASESession.cpp", - "TestSessionIDAllocator.cpp", "TestStatusReport.cpp", ] diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index c89f68b929cd5f..1a0ac71ee1668b 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -179,16 +179,25 @@ void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) TestCASESecurePairingDelegate delegate; CASESession pairing; FabricTable fabrics; + SessionManager sessionManager; + SessionHolder secureSessionHolder; + SessionHolder emptySessionHolder; NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kCASE); CATValues peerCATs; peerCATs = pairing.GetPeerCATs(); NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kUndefinedCATs, sizeof(CATValues)) == 0); + secureSessionHolder = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, secureSessionHolder); pairing.SetGroupDataProvider(&gDeviceGroupDataProvider); - NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(0, nullptr, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); - NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(0, nullptr, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); - NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(0, &fabrics, &delegate) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairing.ListenForSessionEstablishment(emptySessionHolder, nullptr, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT(inSuite, + pairing.ListenForSessionEstablishment(emptySessionHolder, nullptr, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT(inSuite, + pairing.ListenForSessionEstablishment(emptySessionHolder, &fabrics, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(secureSessionHolder, &fabrics, &delegate) == CHIP_NO_ERROR); } void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) @@ -202,22 +211,25 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) FabricInfo * fabric = gCommissionerFabrics.FindFabricWithIndex(gCommissionerFabricIndex); NL_TEST_ASSERT(inSuite, fabric != nullptr); + SessionManager sessionManager; + SessionHolder secureSessionHolder = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, secureSessionHolder); ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), nullptr, Node01_01, 0, nullptr, - nullptr) != CHIP_NO_ERROR); + pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), nullptr, Node01_01, secureSessionHolder, + nullptr, nullptr) != CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, nullptr, - nullptr) != CHIP_NO_ERROR); + pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, secureSessionHolder, + nullptr, nullptr) != CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, context, - &delegate) == CHIP_NO_ERROR); + pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, secureSessionHolder, + context, &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); @@ -237,8 +249,8 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, context1, - &delegate) == CHIP_ERROR_BAD_REQUEST); + pairing1.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, secureSessionHolder, + context1, &delegate) == CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); gLoopback.mMessageSendError = CHIP_NO_ERROR; @@ -254,6 +266,9 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte CASESession pairingAccessory; CASESessionCachable serializableCommissioner; CASESessionCachable serializableAccessory; + SessionManager sessionManager; + SessionHolder secureSessionHolder = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, secureSessionHolder); gLoopback.mSentMessageCount = 0; @@ -268,10 +283,12 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte pairingAccessory.SetGroupDataProvider(&gDeviceGroupDataProvider); NL_TEST_ASSERT(inSuite, - pairingAccessory.ListenForSessionEstablishment(0, &gDeviceFabrics, &delegateAccessory) == CHIP_NO_ERROR); + pairingAccessory.ListenForSessionEstablishment(secureSessionHolder, &gDeviceFabrics, &delegateAccessory) == + CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, - pairingCommissioner.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, - contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, + secureSessionHolder, contextCommissioner, + &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 5); @@ -304,6 +321,11 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte auto * pairingCommissioner = chip::Platform::New(); pairingCommissioner->SetGroupDataProvider(&gCommissionerGroupDataProvider); + SessionManager sessionManager; + + SessionHolder secureSessionHolder = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, secureSessionHolder); + TestContext & ctx = *reinterpret_cast(inContext); gLoopback.mSentMessageCount = 0; @@ -322,8 +344,9 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, fabric != nullptr); NL_TEST_ASSERT(inSuite, - pairingCommissioner->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, - contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, + secureSessionHolder, contextCommissioner, + &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 5); @@ -334,8 +357,9 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte ExchangeContext * contextCommissioner1 = ctx.NewUnauthenticatedExchangeToBob(pairingCommissioner1); NL_TEST_ASSERT(inSuite, - pairingCommissioner1->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, - contextCommissioner1, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner1->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, + secureSessionHolder, contextCommissioner1, + &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); chip::Platform::Delete(pairingCommissioner); diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 27fb6a9ac875ee..af65b3680e2fb4 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -114,37 +114,49 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; PASESession pairing; + SessionManager sessionManager; + SessionHolder secureSessionHolder; + SessionHolder emptySessionHolder; NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kPASE); CATValues peerCATs; peerCATs = pairing.GetPeerCATs(); NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kUndefinedCATs, sizeof(CATValues)) == 0); + secureSessionHolder = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, secureSessionHolder); gLoopback.Reset(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(nullptr, 0), 0, - Optional::Missing(), + pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(nullptr, 0), + emptySessionHolder, Optional::Missing(), &delegate) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, - ByteSpan(reinterpret_cast("saltSalt"), 8), 0, + ByteSpan(reinterpret_cast("saltSalt"), 8), emptySessionHolder, Optional::Missing(), nullptr) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, - ByteSpan(reinterpret_cast("saltSalt"), 8), 0, + ByteSpan(reinterpret_cast("saltSalt"), 8), emptySessionHolder, Optional::Missing(), &delegate) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(sTestSpake2p01_Salt), - 0, Optional::Missing(), &delegate) == CHIP_NO_ERROR); + emptySessionHolder, Optional::Missing(), + &delegate) == CHIP_ERROR_INVALID_ARGUMENT); + ctx.DrainAndServiceIO(); + + NL_TEST_ASSERT(inSuite, + pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(sTestSpake2p01_Salt), + secureSessionHolder, Optional::Missing(), + &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); } @@ -154,20 +166,24 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; - PASESession pairing; + SessionManager sessionManager; + SessionHolder secureSessionHolder; + + secureSessionHolder = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, secureSessionHolder); gLoopback.Reset(); ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, - pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, 0, + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, secureSessionHolder, Optional::Missing(), nullptr, nullptr) != CHIP_NO_ERROR); gLoopback.Reset(); NL_TEST_ASSERT(inSuite, - pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, 0, + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, secureSessionHolder, Optional::Missing(), context, &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); @@ -185,7 +201,7 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) PASESession pairing1; ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, 0, + pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, secureSessionHolder, Optional::Missing(), context1, &delegate) == CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); @@ -202,6 +218,11 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P TestSecurePairingDelegate delegateAccessory; PASESession pairingAccessory; + SessionManager sessionManager; + SessionHolder secureSessionHolder; + + secureSessionHolder = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, secureSessionHolder); gLoopback.mSentMessageCount = 0; @@ -226,13 +247,14 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, - ByteSpan(sTestSpake2p01_Salt), 0, mrpAccessoryConfig, + ByteSpan(sTestSpake2p01_Salt), secureSessionHolder, mrpAccessoryConfig, &delegateAccessory) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, 0, - mrpCommissionerConfig, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, + secureSessionHolder, mrpCommissionerConfig, contextCommissioner, + &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); while (gLoopback.mMessageDropped) @@ -332,6 +354,9 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) TestSecurePairingDelegate delegateAccessory; PASESession pairingAccessory; + SessionManager sessionManager; + SessionHolder secureSessionHolder; + gLoopback.Reset(); gLoopback.mSentMessageCount = 0; @@ -339,8 +364,10 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); ReliableMessageContext * rc = contextCommissioner->GetReliableMessageContext(); + secureSessionHolder = sessionManager.AllocateSession(); NL_TEST_ASSERT(inSuite, rm != nullptr); NL_TEST_ASSERT(inSuite, rc != nullptr); + NL_TEST_ASSERT(inSuite, secureSessionHolder); contextCommissioner->GetSessionHandle()->AsUnauthenticatedSession()->SetMRPConfig({ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL @@ -352,13 +379,14 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) Protocols::SecureChannel::MsgType::PBKDFParamRequest, &pairingAccessory) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, - pairingAccessory.WaitForPairing( - sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(sTestSpake2p01_Salt), 0, - Optional::Missing(), &delegateAccessory) == CHIP_NO_ERROR); + pairingAccessory.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, + ByteSpan(sTestSpake2p01_Salt), secureSessionHolder, + Optional::Missing(), + &delegateAccessory) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, 0, + pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, secureSessionHolder, Optional::Missing(), contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); @@ -369,75 +397,6 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingErrors == 1); } -void SecurePairingDeserialize(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner, - PASESession & deserialized) -{ - PASESessionSerialized serialized; - gLoopback.Reset(); - NL_TEST_ASSERT(inSuite, pairingCommissioner.Serialize(serialized) == CHIP_NO_ERROR); - - NL_TEST_ASSERT(inSuite, deserialized.Deserialize(serialized) == CHIP_NO_ERROR); - - // Serialize from the deserialized session, and check we get the same string back - PASESessionSerialized serialized2; - NL_TEST_ASSERT(inSuite, deserialized.Serialize(serialized2) == CHIP_NO_ERROR); - - NL_TEST_ASSERT(inSuite, strncmp(Uint8::to_char(serialized.inner), Uint8::to_char(serialized2.inner), sizeof(serialized)) == 0); -} - -void SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) -{ - TestSecurePairingDelegate delegateCommissioner; - - // Allocate on the heap to avoid stack overflow in some restricted test scenarios (e.g. QEMU) - auto * testPairingSession1 = chip::Platform::New(); - auto * testPairingSession2 = chip::Platform::New(); - - gLoopback.Reset(); - - SecurePairingHandshakeTestCommon(inSuite, inContext, *testPairingSession1, Optional::Missing(), - Optional::Missing(), delegateCommissioner); - SecurePairingDeserialize(inSuite, inContext, *testPairingSession1, *testPairingSession2); - - const uint8_t plain_text[] = { 0x86, 0x74, 0x64, 0xe5, 0x0b, 0xd4, 0x0d, 0x90, 0xe1, 0x17, 0xa3, 0x2d, 0x4b, 0xd4, 0xe1, 0xe6 }; - uint8_t encrypted[64]; - PacketHeader header; - MessageAuthenticationCode mac; - - header.SetSessionId(1); - NL_TEST_ASSERT(inSuite, header.IsEncrypted() == true); - NL_TEST_ASSERT(inSuite, header.MICTagLength() == 16); - - // Let's try encrypting using original session, and decrypting using deserialized - { - CryptoContext session1; - - CHIP_ERROR err = testPairingSession1->DeriveSecureSession(session1, CryptoContext::SessionRole::kInitiator); - - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - CryptoContext::NonceStorage nonce; - CryptoContext::BuildNonce(nonce, header.GetSecurityFlags(), header.GetMessageCounter(), kUndefinedNodeId); - err = session1.Encrypt(plain_text, sizeof(plain_text), encrypted, nonce, header, mac); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - } - - { - CryptoContext session2; - NL_TEST_ASSERT(inSuite, - testPairingSession2->DeriveSecureSession(session2, CryptoContext::SessionRole::kResponder) == CHIP_NO_ERROR); - - uint8_t decrypted[64]; - CryptoContext::NonceStorage nonce; - CryptoContext::BuildNonce(nonce, header.GetSecurityFlags(), header.GetMessageCounter(), kUndefinedNodeId); - NL_TEST_ASSERT(inSuite, session2.Decrypt(encrypted, sizeof(plain_text), decrypted, nonce, header, mac) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, memcmp(plain_text, decrypted, sizeof(plain_text)) == 0); - } - - chip::Platform::Delete(testPairingSession1); - chip::Platform::Delete(testPairingSession2); -} - void PASEVerifierSerializeTest(nlTestSuite * inSuite, void * inContext) { Spake2pVerifier verifier; @@ -474,7 +433,6 @@ static const nlTest sTests[] = NL_TEST_DEF("Handshake with Both MRP Parameters", SecurePairingHandshakeWithAllMRPTest), NL_TEST_DEF("Handshake with packet loss", SecurePairingHandshakeWithPacketLossTest), NL_TEST_DEF("Failed Handshake", SecurePairingFailedHandshake), - NL_TEST_DEF("Serialize", SecurePairingSerializeTest), NL_TEST_DEF("PASE Verifier Serialize", PASEVerifierSerializeTest), NL_TEST_SENTINEL() diff --git a/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp b/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp deleted file mode 100644 index 55ce38f24b30d8..00000000000000 --- a/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -using namespace chip; - -void TestSessionIDAllocator_Free(nlTestSuite * inSuite, void * inContext) -{ - SessionIDAllocator allocator; - uint16_t i = allocator.Peek(); - - uint16_t id; - - for (uint16_t j = 0; j < 17; j++) - { - CHIP_ERROR err = allocator.Allocate(id); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, id == static_cast(i + j)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + j + 1)); - } - - // Free an intermediate ID - allocator.Free(10); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 17)); - - // Free the last allocated ID - allocator.Free(static_cast(i + 16)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 16)); - - // Free some random unallocated ID - allocator.Free(100); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 16)); -} - -void TestSessionIDAllocator_Reserve(nlTestSuite * inSuite, void * inContext) -{ - SessionIDAllocator allocator; - uint16_t i = allocator.Peek(); - uint16_t id; - - for (uint16_t j = 0; j < 17; j++) - { - CHIP_ERROR err = allocator.Allocate(id); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, id == static_cast(i + j)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + j + 1)); - } - - i = allocator.Peek(); - allocator.Reserve(static_cast(i + 100)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 101)); -} - -void TestSessionIDAllocator_ReserveUpTo(nlTestSuite * inSuite, void * inContext) -{ - SessionIDAllocator allocator; - uint16_t i = allocator.Peek(); - - i = allocator.Peek(); - allocator.Reserve(static_cast(i + 100)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 101)); -} - -// Test Suite - -/** - * Test Suite that lists all the test functions. - */ -// clang-format off -static const nlTest sTests[] = -{ - NL_TEST_DEF("SessionIDAllocator_Free", TestSessionIDAllocator_Free), - NL_TEST_DEF("SessionIDAllocator_Reserve", TestSessionIDAllocator_Reserve), - NL_TEST_DEF("SessionIDAllocator_ReserveUpTo", TestSessionIDAllocator_ReserveUpTo), - - NL_TEST_SENTINEL() -}; -// clang-format on - -/** - * Set up the test suite. - */ -static int TestSetup(void * inContext) -{ - CHIP_ERROR error = chip::Platform::MemoryInit(); - if (error != CHIP_NO_ERROR) - return FAILURE; - return SUCCESS; -} - -/** - * Tear down the test suite. - */ -static int TestTeardown(void * inContext) -{ - chip::Platform::MemoryShutdown(); - return SUCCESS; -} - -// clang-format off -static nlTestSuite sSuite = -{ - "Test-CHIP-SessionIDAllocator", - &sTests[0], - TestSetup, - TestTeardown, -}; -// clang-format on - -/** - * Main - */ -int TestSessionIDAllocator() -{ - // Run test suit against one context - nlTestRunner(&sSuite, nullptr); - - return (nlTestRunnerStats(&sSuite)); -} - -CHIP_REGISTER_TEST_SUITE(TestSessionIDAllocator) diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index 1a8682f8578f4b..46128a4db1faf9 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -50,11 +50,18 @@ class DLL_EXPORT PairingSession CATValues GetPeerCATs() const { return mPeerCATs; } - // TODO: the local key id should be allocateed at start - // mLocalSessionId should be const and assigned at the construction, such that GetLocalSessionId will always return a valid key - // id , and SetLocalSessionId is not necessary. - uint16_t GetLocalSessionId() const { return mLocalSessionId; } - bool IsValidLocalSessionId() const { return mLocalSessionId != kInvalidKeyId; } + Optional GetLocalSessionId() const + { + Optional localSessionId; + VerifyOrExit(mSecureSessionHolder, localSessionId = Optional::Missing()); + VerifyOrExit(mSecureSessionHolder->GetSessionType() == Transport::Session::SessionType::kSecure, + localSessionId = Optional::Missing()); + localSessionId.SetValue(mSecureSessionHolder->AsSecureSession()->GetLocalSessionId()); + exit: + return localSessionId; + } + + SessionHolder & GetSecureSessionHolder() { return mSecureSessionHolder; } uint16_t GetPeerSessionId() const { @@ -97,7 +104,7 @@ class DLL_EXPORT PairingSession void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; } void SetPeerCATs(CATValues peerCATs) { mPeerCATs = peerCATs; } void SetPeerSessionId(uint16_t id) { mPeerSessionId.SetValue(id); } - void SetLocalSessionId(uint16_t id) { mLocalSessionId = id; } + void SetSecureSessionHolder(SessionHolder holder) { mSecureSessionHolder = holder; } void SetPeerAddress(const Transport::PeerAddress & address) { mPeerAddress = address; } virtual void OnSuccessStatusReport() {} virtual CHIP_ERROR OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) @@ -170,7 +177,7 @@ class DLL_EXPORT PairingSession mPeerCATs = kUndefinedCATs; mPeerAddress = Transport::PeerAddress::Uninitialized(); mPeerSessionId.ClearValue(); - mLocalSessionId = kInvalidKeyId; + mSecureSessionHolder.Release(); } private: @@ -178,10 +185,7 @@ class DLL_EXPORT PairingSession NodeId mPeerNodeId = kUndefinedNodeId; CATValues mPeerCATs; - // TODO: the local key id should be allocateed at start - // then we can remove kInvalidKeyId - static constexpr uint16_t kInvalidKeyId = UINT16_MAX; - uint16_t mLocalSessionId = kInvalidKeyId; + SessionHolder mSecureSessionHolder; // TODO: decouple peer address into transport, such that pairing session do not need to handle peer address Transport::PeerAddress mPeerAddress = Transport::PeerAddress::Uninitialized(); diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 8146884cf15373..e4c1d4783d2b21 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -61,6 +61,14 @@ class SecureSession : public Session kUndefined = 0, kPASE = 1, kCASE = 2, + // kPending denotes a secure session object that is internally + // reserved by the stack before and during session establishment. + // + // Although the stack can tolerate eviction of these (releasing one + // out from under the holder would exhibit as CHIP_ERROR_INCORRECT_STATE + // during CASE or PASE), intent is that we should not and would leave + // these untouched until CASE or PASE complete. + kPending = 3, }; SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, @@ -71,6 +79,32 @@ class SecureSession : public Session { SetFabricIndex(fabric); } + + /** + * @brief + * Construct a secure session to associate with a pending secure + * session establishment attempt. A pending secure session object + * receives a session ID, but no other state. + */ + SecureSession(uint16_t localSessionId) : + SecureSession(Type::kPending, localSessionId, kUndefinedNodeId, CATValues{}, 0, kUndefinedFabricIndex, GetLocalMRPConfig()) + {} + + /** + * Activate a pending Secure Session that had been reserved during CASE or + * PASE, setting internal state according to the parameters used and + * discovered during session establishment. + */ + void Activate(Type secureSessionType, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric, + const ReliableMessageProtocolConfig & config) + { + mSecureSessionType = secureSessionType; + mPeerNodeId = peerNodeId; + mPeerCATs = peerCATs; + mPeerSessionId = peerSessionId; + mMRPConfig = config; + SetFabricIndex(fabric); + } ~SecureSession() override { NotifySessionReleased(); } SecureSession(SecureSession &&) = delete; @@ -141,11 +175,11 @@ class SecureSession : public Session SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; } private: - const Type mSecureSessionType; - const NodeId mPeerNodeId; - const CATValues mPeerCATs; + Type mSecureSessionType; + NodeId mPeerNodeId; + CATValues mPeerCATs; const uint16_t mLocalSessionId; - const uint16_t mPeerSessionId; + uint16_t mPeerSessionId; PeerAddress mPeerAddress; System::Clock::Timestamp mLastActivityTime; diff --git a/src/transport/SecureSessionTable.h b/src/transport/SecureSessionTable.h index 9dce2198c3bbc4..197db3b1b8c615 100644 --- a/src/transport/SecureSessionTable.h +++ b/src/transport/SecureSessionTable.h @@ -25,10 +25,8 @@ namespace chip { namespace Transport { -// TODO; use 0xffff to match any key id, this is a temporary solution for -// InteractionModel, where key id is not obtainable. This will be removed when -// InteractionModel is migrated to messaging layer -constexpr const uint16_t kAnyKeyId = 0xffff; +constexpr uint16_t kMaxSessionID = UINT16_MAX; +constexpr uint16_t kUnsecuredSessionId = 0; /** * Handles a set of sessions. @@ -69,6 +67,49 @@ class SecureSessionTable return result != nullptr ? MakeOptional(*result) : Optional::Missing(); } + /** + * Allocates a new secure session out of the internal resource pool with the + * specified session ID. The returned secure session will not become active + * until the call to SecureSession::Activate. + * + * @returns allocated session on success, else failure + */ + CHECK_RETURN_VALUE + Optional CreateNewSecureSession(uint16_t localSessionId) + { + Optional rv = Optional::Missing(); + SecureSession * allocated = nullptr; + VerifyOrExit(!FindSecureSessionByLocalKey(localSessionId).HasValue(), rv = Optional::Missing()); + allocated = mEntries.CreateObject(localSessionId); + VerifyOrExit(allocated != nullptr, rv = Optional::Missing()); + rv = MakeOptional(*allocated); + exit: + return rv; + } + + /** + * Allocates a new secure session out of the internal resource pool with a + * non-colliding session ID and increments mNextSessionId to give a clue to + * the allocator for the next allocation. The secure session session will + * become active unitl the call to SecureSession::Activate. + * + * @returns allocated session on success, else failure + */ + CHECK_RETURN_VALUE + Optional CreateNewSecureSession() + { + Optional rv = Optional::Missing(); + auto sessionId = FindUnusedSessionId(); + SecureSession * allocated = nullptr; + VerifyOrExit(sessionId.HasValue(), rv = Optional::Missing()); + allocated = mEntries.CreateObject(sessionId.Value()); + VerifyOrExit(allocated != nullptr, rv = Optional::Missing()); + rv = MakeOptional(*allocated); + mNextSessionId = sessionId.Value() == kMaxSessionID ? kUnsecuredSessionId + 1 : sessionId.Value() + 1; + exit: + return rv; + } + void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); } template @@ -78,7 +119,7 @@ class SecureSessionTable } /** - * Get a secure session given a Node Id and Peer's Encryption Key Id. + * Get a secure session given a Encryption key ID. * * @param localSessionId Encryption key ID used by the local node. * @@ -109,7 +150,8 @@ class SecureSessionTable void ExpireInactiveSessions(System::Clock::Timestamp maxIdleTime, Callback callback) { mEntries.ForEachActiveObject([&](auto session) { - if (session->GetLastActivityTime() + maxIdleTime < System::SystemClock().GetMonotonicTimestamp()) + if (session->GetSecureSessionType() != SecureSession::Type::kPending && + session->GetLastActivityTime() + maxIdleTime < System::SystemClock().GetMonotonicTimestamp()) { callback(*session); ReleaseSession(session); @@ -119,7 +161,76 @@ class SecureSessionTable } private: + /** + * Find an available session ID that is unused in the secure sesion table. + * + * The search algorithm iterates over the session ID space in the outer loop + * and the session table in the inner loop to locate an available session ID + * from the starting mNextSessionId clue. + * + * Outer-loop iteration considers the session ID space in 64-entry buckets + * to give us runtime of O(kMaxSessionCount^2 / 64). This is the fastest + * we can be without a sorted session table or additional storage. + * + * @return an unused session ID if any is found, else nothing + */ + CHECK_RETURN_VALUE + Optional FindUnusedSessionId() + { + uint16_t candidate_base = 0; + uint64_t candidate_mask = 0; + for (uint32_t i = 0; i <= kMaxSessionID; i += 64) + { + // Candidate_base is the base Session ID we are searching from. + // We have a 64-bit mask anchored at this ID at iterate over the + // whole session table, marking bits in the mask for in-use IDs. + // If we can iterate through the entire session table and have + // any bits free in the mask, we have available session IDs. + candidate_base = static_cast(i) + mNextSessionId; + candidate_mask = 0; + { + uint16_t shift = kUnsecuredSessionId - candidate_base; + if (shift <= 63) + { + candidate_mask |= (1ULL << shift); // kUnsecuredSessionId is never available + } + } + mEntries.ForEachActiveObject([&](auto session) { + uint16_t shift = session->GetLocalSessionId() - candidate_base; + if (shift <= 63) + { + candidate_mask |= (1ULL << shift); + } + if (candidate_mask == UINT64_MAX) + { + return Loop::Break; // No bits clear means this bucket is full. + } + return Loop::Continue; + }); + if (candidate_mask != UINT64_MAX) + { + break; // Any bit clear means we have an available ID in this bucket. + } + } + if (candidate_mask != UINT64_MAX) + { + uint16_t offset = 0; + while (candidate_mask & 1) + { + candidate_mask >>= 1; + ++offset; + } + uint16_t available = candidate_base + offset; + return MakeOptional(available); + } + else + { + return Optional::Missing(); + } + } + BitMapObjectPool mEntries; + uint16_t mNextSessionId = 0; }; } // namespace Transport diff --git a/src/transport/Session.h b/src/transport/Session.h index 8c1ea3291e9bc9..f10954b1f07801 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -85,6 +85,8 @@ class Session return GetSessionType() == SessionType::kGroupIncoming || GetSessionType() == SessionType::kGroupOutgoing; } + bool IsSecureSession() const { return GetSessionType() == SessionType::kSecure; } + protected: // This should be called by sub-classes at the very beginning of the destructor, before any data field is disposed, such that // the session is still functional during the callback. diff --git a/src/transport/SessionHolder.h b/src/transport/SessionHolder.h index ded017e210da9f..c750be24a6b01f 100644 --- a/src/transport/SessionHolder.h +++ b/src/transport/SessionHolder.h @@ -68,6 +68,8 @@ class SessionHolderWithDelegate : public SessionHolder { public: SessionHolderWithDelegate(SessionReleaseDelegate & delegate) : mDelegate(delegate) {} + SessionHolderWithDelegate(SessionHolder holder, SessionReleaseDelegate & delegate) : SessionHolder(holder), mDelegate(delegate) + {} operator bool() const { return SessionHolder::operator bool(); } void OnSessionReleased() override diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 652193a7d93fc1..6ea98374b529ad 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -372,27 +372,51 @@ void SessionManager::ExpireAllPairingsForFabric(FabricIndex fabric) }); } +SessionHolder SessionManager::AllocateSession() +{ + SessionHolder holder; + Optional session = mSecureSessions.CreateNewSecureSession(); + VerifyOrExit(session.HasValue(), holder = holder); + holder.Grab(session.Value()); +exit: + return holder; +} + +SessionHolder SessionManager::AllocateSession(uint16_t sessionId) +{ + // If we forego SessionManager session ID allocation, we can have a + // collission. In case of such a collission, we must evict first. + Optional oldSession = mSecureSessions.FindSecureSessionByLocalKey(sessionId); + if (oldSession.HasValue()) + { + mSecureSessions.ReleaseSession(oldSession.Value()->AsSecureSession()); + } + SessionHolder holder; + Optional session = mSecureSessions.CreateNewSecureSession(sessionId); + VerifyOrExit(session.HasValue(), holder = holder); + holder.Grab(session.Value()); +exit: + return holder; +} + CHIP_ERROR SessionManager::NewPairing(SessionHolder & sessionHolder, const Optional & peerAddr, NodeId peerNodeId, PairingSession * pairing, CryptoContext::SessionRole direction, FabricIndex fabric) { - uint16_t peerSessionId = pairing->GetPeerSessionId(); - uint16_t localSessionId = pairing->GetLocalSessionId(); - Optional session = mSecureSessions.FindSecureSessionByLocalKey(localSessionId); - - // Find any existing connection with the same local key ID - if (session.HasValue()) - { - mSecureSessions.ReleaseSession(session.Value()->AsSecureSession()); - } + uint16_t peerSessionId = pairing->GetPeerSessionId(); + SecureSession * secureSession; + uint16_t localSessionId; + sessionHolder = pairing->GetSecureSessionHolder(); + VerifyOrReturnError(sessionHolder, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(sessionHolder->IsSecureSession(), CHIP_ERROR_INCORRECT_STATE); + secureSession = sessionHolder->AsSecureSession(); + localSessionId = secureSession->GetLocalSessionId(); ChipLogDetail(Inet, "New secure session created for device 0x" ChipLogFormatX64 ", LSID:%d PSID:%d!", ChipLogValueX64(peerNodeId), localSessionId, peerSessionId); - session = mSecureSessions.CreateNewSecureSession(pairing->GetSecureSessionType(), localSessionId, peerNodeId, - pairing->GetPeerCATs(), peerSessionId, fabric, pairing->GetMRPConfig()); - ReturnErrorCodeIf(!session.HasValue(), CHIP_ERROR_NO_MEMORY); + secureSession->Activate(pairing->GetSecureSessionType(), peerNodeId, pairing->GetPeerCATs(), peerSessionId, fabric, + pairing->GetMRPConfig()); - Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any) { secureSession->SetPeerAddress(peerAddr.Value()); @@ -411,7 +435,7 @@ CHIP_ERROR SessionManager::NewPairing(SessionHolder & sessionHolder, const Optio ReturnErrorOnFailure(pairing->DeriveSecureSession(secureSession->GetCryptoContext(), direction)); secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); - sessionHolder.Grab(session.Value()); + return CHIP_NO_ERROR; } diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 87aef5649a4959..065621c346df0f 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -168,6 +168,21 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate CHIP_ERROR NewPairing(SessionHolder & sessionHolder, const Optional & peerAddr, NodeId peerNodeId, PairingSession * pairing, CryptoContext::SessionRole direction, FabricIndex fabric); + /** + * @brief + * Allocate a secure session and non-colliding session ID in the secure + * session table. + */ + SessionHolder AllocateSession(); + + /** + * @brief + * Allocate a secure session in the secure session table at the specified + * session ID. If the session ID collides with an existing session, evict + * it. + */ + SessionHolder AllocateSession(uint16_t localSessionId); + void ExpirePairing(const SessionHandle & session); void ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric); void ExpireAllPairingsForFabric(FabricIndex fabric); diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index 4cda4fcaa699f3..e4b96d5288e5c7 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -168,12 +168,12 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -283,12 +283,12 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -384,12 +384,12 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -497,36 +497,36 @@ void StaleConnectionDropTest(nlTestSuite * inSuite, void * inContext) SessionHolderWithDelegate session5(callback); // First pairing - SecurePairingUsingTestSecret pairing1(1, 1); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing1(1, 1, sessionManager); err = sessionManager.NewPairing(session1, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with different peer node ID and different local key ID (same peer key ID) - SecurePairingUsingTestSecret pairing2(1, 2); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing2(1, 2, sessionManager); err = sessionManager.NewPairing(session2, peer, kSourceNodeId, &pairing2, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with undefined node ID and different local key ID (same peer key ID) - SecurePairingUsingTestSecret pairing3(1, 3); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing3(1, 3, sessionManager); err = sessionManager.NewPairing(session3, peer, kUndefinedNodeId, &pairing3, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with same local key ID, and a given node ID - SecurePairingUsingTestSecret pairing4(1, 2); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing4(1, 2, sessionManager); err = sessionManager.NewPairing(session4, peer, kSourceNodeId, &pairing4, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped); // New pairing with same local key ID, and undefined node ID - SecurePairingUsingTestSecret pairing5(1, 1); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing5(1, 1, sessionManager); err = sessionManager.NewPairing(session5, peer, kUndefinedNodeId, &pairing5, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped); @@ -590,12 +590,12 @@ void SendPacketWithOldCounterTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -704,12 +704,12 @@ void SendPacketWithTooOldCounterTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -764,6 +764,56 @@ void SendPacketWithTooOldCounterTest(nlTestSuite * inSuite, void * inContext) sessionManager.Shutdown(); } +void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) +{ + SessionManager sessionManager; + TestSessionReleaseCallback callback; + + // Allocate a session. + SessionHolderWithDelegate session1(sessionManager.AllocateSession(), callback); + NL_TEST_ASSERT(inSuite, session1); + auto sessionId1 = session1->AsSecureSession()->GetLocalSessionId(); + + // Allocate a session at a colliding ID, verify eviction. + callback.mOldConnectionDropped = false; + SessionHolderWithDelegate session2(sessionManager.AllocateSession(sessionId1), callback); + NL_TEST_ASSERT(inSuite, session2); + + auto prevSessionId = sessionId1; + // Verify that session IDs monotonically increase, except for the + // wraparound case where we skip session ID 0. + for (uint32_t i = 0; i < 10; ++i) + { + auto session = sessionManager.AllocateSession(); + if (!session) + { + break; + } + auto sessionId = session->AsSecureSession()->GetLocalSessionId(); + NL_TEST_ASSERT(inSuite, sessionId - prevSessionId == 1 || (sessionId == 1 && prevSessionId == 65535)); + prevSessionId = sessionId; + } + + sessionManager.~SessionManager(); + new (&sessionManager) SessionManager(); + + prevSessionId = 0; + // Verify that session IDs monotonically increase, even when releasing + // sessions, except for the wraparound case where we skip session ID 0. + for (uint32_t i = 0; i < UINT16_MAX + 10; ++i) + { + auto session = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, session); + auto sessionId = session->AsSecureSession()->GetLocalSessionId(); + NL_TEST_ASSERT(inSuite, sessionId - prevSessionId == 1 || (sessionId == 1 && prevSessionId == 65535)); + prevSessionId = sessionId; + sessionManager.ExpirePairing(session.Get()); + } + + // Verify that session IDs monotonically increase, even when we free sessions. + sessionManager.Shutdown(); +} + // Test Suite /** @@ -779,6 +829,7 @@ const nlTest sTests[] = NL_TEST_DEF("Drop stale connection Test", StaleConnectionDropTest), NL_TEST_DEF("Old counter Test", SendPacketWithOldCounterTest), NL_TEST_DEF("Too-old counter Test", SendPacketWithTooOldCounterTest), + NL_TEST_DEF("Session Allocation Test", SessionAllocationTest), NL_TEST_SENTINEL() }; diff --git a/third_party/efr32_sdk/repo b/third_party/efr32_sdk/repo index d90c8379f259e1..9b527ad5d35e72 160000 --- a/third_party/efr32_sdk/repo +++ b/third_party/efr32_sdk/repo @@ -1 +1 @@ -Subproject commit d90c8379f259e1a6b33c1ec147289f5e55d2b6a8 +Subproject commit 9b527ad5d35e72f0266b54ed64d74ebe0170aa45