diff --git a/examples/all-clusters-app/esp32/main/EchoServer.cpp b/examples/all-clusters-app/esp32/main/EchoServer.cpp index 9ddebc256ad30a..c469c5940141bf 100644 --- a/examples/all-clusters-app/esp32/main/EchoServer.cpp +++ b/examples/all-clusters-app/esp32/main/EchoServer.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include @@ -123,7 +124,7 @@ class EchoServerCallback : public SecureSessionMgrDelegate public: void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const Transport::PeerConnectionState * state, System::PacketBuffer * buffer, - SecureSessionMgrBase * mgr) override + SecureSessionMgr * mgr) override { CHIP_ERROR err; const size_t data_len = buffer->DataLength(); @@ -180,13 +181,13 @@ class EchoServerCallback : public SecureSessionMgrDelegate } } - void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgrBase * mgr) override + void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgr * mgr) override { ESP_LOGE(TAG, "ERROR: %s\n Got UDP error", ErrorStr(error)); statusLED1.BlinkOnError(); } - void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) override + void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) override { ESP_LOGI(TAG, "Received a new connection."); } @@ -219,15 +220,16 @@ class EchoServerCallback : public SecureSessionMgrDelegate EchoServerCallback gCallbacks; -SecureSessionMgr - sessions; +TransportMgr + gTransports; +SecureSessionMgr sessions; } // namespace namespace chip { -SecureSessionMgrBase & SessionManager() +SecureSessionMgr & SessionManager() { return sessions; } @@ -237,9 +239,10 @@ SecureSessionMgrBase & SessionManager() void startServer(NodeId localNodeId) { CHIP_ERROR err = CHIP_NO_ERROR; - err = sessions.Init(localNodeId, &DeviceLayer::SystemLayer, - UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv6).SetInterfaceId(NULL), - UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv4)); + err = gTransports.Init(UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv6).SetInterfaceId(nullptr), + UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv4)); + SuccessOrExit(err); + err = sessions.Init(localNodeId, &DeviceLayer::SystemLayer, &gTransports); SuccessOrExit(err); sessions.SetDelegate(&gCallbacks); diff --git a/examples/all-clusters-app/esp32/main/RendezvousDeviceDelegate.cpp b/examples/all-clusters-app/esp32/main/RendezvousDeviceDelegate.cpp index c735fbea4dc6ef..de539548964c2d 100644 --- a/examples/all-clusters-app/esp32/main/RendezvousDeviceDelegate.cpp +++ b/examples/all-clusters-app/esp32/main/RendezvousDeviceDelegate.cpp @@ -46,7 +46,7 @@ RendezvousDeviceDelegate::RendezvousDeviceDelegate() params.SetSetupPINCode(setupPINCode).SetBleLayer(DeviceLayer::ConnectivityMgr().GetBleLayer()); mRendezvousSession = chip::Platform::New(this); - err = mRendezvousSession->Init(params); + err = mRendezvousSession->Init(params, nullptr); exit: if (err != CHIP_NO_ERROR) diff --git a/examples/common/chip-app-server/DataModelHandler.cpp b/examples/common/chip-app-server/DataModelHandler.cpp index 77e9575b1e241b..18d0d5a5458ada 100644 --- a/examples/common/chip-app-server/DataModelHandler.cpp +++ b/examples/common/chip-app-server/DataModelHandler.cpp @@ -40,7 +40,7 @@ using namespace ::chip; * @param [in] buffer The buffer holding the message. This function guarantees * that it will free the buffer before returning. */ -void HandleDataModelMessage(const PacketHeader & header, System::PacketBuffer * buffer, SecureSessionMgrBase * mgr) +void HandleDataModelMessage(const PacketHeader & header, System::PacketBuffer * buffer, SecureSessionMgr * mgr) { EmberApsFrame frame; bool ok = extractApsFrame(buffer->Start(), buffer->DataLength(), &frame) > 0; diff --git a/examples/common/chip-app-server/RendezvousServer.cpp b/examples/common/chip-app-server/RendezvousServer.cpp index 13f1f877c47a08..5d8d8c6d7dfd39 100644 --- a/examples/common/chip-app-server/RendezvousServer.cpp +++ b/examples/common/chip-app-server/RendezvousServer.cpp @@ -36,9 +36,9 @@ namespace chip { RendezvousServer::RendezvousServer() : mRendezvousSession(this) {} -CHIP_ERROR RendezvousServer::Init(const RendezvousParameters & params) +CHIP_ERROR RendezvousServer::Init(const RendezvousParameters & params, TransportMgrBase * transportMgr) { - return mRendezvousSession.Init(params); + return mRendezvousSession.Init(params, transportMgr); } void RendezvousServer::OnRendezvousError(CHIP_ERROR err) @@ -56,7 +56,8 @@ void RendezvousServer::OnRendezvousConnectionClosed() ChipLogProgress(AppServer, "OnRendezvousConnectionClosed"); } -void RendezvousServer::OnRendezvousMessageReceived(PacketBuffer * buffer) +void RendezvousServer::OnRendezvousMessageReceived(const PacketHeader & packetHeader, const PeerAddress & peerAddress, + PacketBuffer * buffer) { chip::System::PacketBuffer::Free(buffer); } diff --git a/examples/common/chip-app-server/Server.cpp b/examples/common/chip-app-server/Server.cpp index 03100852d010b8..0410326ea457fe 100644 --- a/examples/common/chip-app-server/Server.cpp +++ b/examples/common/chip-app-server/Server.cpp @@ -47,7 +47,7 @@ class ServerCallback : public SecureSessionMgrDelegate public: void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const Transport::PeerConnectionState * state, System::PacketBuffer * buffer, - SecureSessionMgrBase * mgr) override + SecureSessionMgr * mgr) override { const size_t data_len = buffer->DataLength(); char src_addr[PeerAddress::kMaxToStringSize]; @@ -74,20 +74,21 @@ class ServerCallback : public SecureSessionMgrDelegate } } - void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) override + void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) override { ChipLogProgress(AppServer, "Received a new connection."); } }; -DemoSessionManager gSessions; +DemoTransportMgr gTransports; +SecureSessionMgr gSessions; ServerCallback gCallbacks; SecurePairingUsingTestSecret gTestPairing; RendezvousServer gRendezvousServer; } // namespace -SecureSessionMgrBase & chip::SessionManager() +SecureSessionMgr & chip::SessionManager() { return gSessions; } @@ -101,8 +102,11 @@ void InitServer() InitDataModelHandler(); - err = gSessions.Init(chip::kTestDeviceNodeId, &DeviceLayer::SystemLayer, - UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv6)); + // Init transport before operations with secure session mgr. + err = gTransports.Init(UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv6)); + SuccessOrExit(err); + + err = gSessions.Init(chip::kTestDeviceNodeId, &DeviceLayer::SystemLayer, &gTransports); SuccessOrExit(err); // This flag is used to bypass BLE in the cirque test @@ -116,7 +120,7 @@ void InitServer() params.SetSetupPINCode(pinCode) .SetLocalNodeId(chip::kTestDeviceNodeId) .SetBleLayer(DeviceLayer::ConnectivityMgr().GetBleLayer()); - SuccessOrExit(err = gRendezvousServer.Init(params)); + SuccessOrExit(err = gRendezvousServer.Init(params, &gTransports)); } #endif @@ -125,7 +129,6 @@ void InitServer() gSessions.SetDelegate(&gCallbacks); chip::Mdns::DiscoveryManager::GetInstance().StartPublishDevice(chip::Inet::kIPAddressType_IPv6); - exit: if (err != CHIP_NO_ERROR) { diff --git a/examples/common/chip-app-server/include/DataModelHandler.h b/examples/common/chip-app-server/include/DataModelHandler.h index f4ff4a9b30a5b4..35a4524c9458df 100644 --- a/examples/common/chip-app-server/include/DataModelHandler.h +++ b/examples/common/chip-app-server/include/DataModelHandler.h @@ -33,6 +33,5 @@ * @param [in] buffer The buffer holding the message. This function guarantees * that it will free the buffer before returning. */ -void HandleDataModelMessage(const chip::PacketHeader & header, chip::System::PacketBuffer * buffer, - chip::SecureSessionMgrBase * mgr); +void HandleDataModelMessage(const chip::PacketHeader & header, chip::System::PacketBuffer * buffer, chip::SecureSessionMgr * mgr); void InitDataModelHandler(); diff --git a/examples/common/chip-app-server/include/RendezvousServer.h b/examples/common/chip-app-server/include/RendezvousServer.h index a2fb114698d4a4..5af5e9b92d5f6d 100644 --- a/examples/common/chip-app-server/include/RendezvousServer.h +++ b/examples/common/chip-app-server/include/RendezvousServer.h @@ -27,16 +27,18 @@ class RendezvousServer : public RendezvousSessionDelegate public: RendezvousServer(); - CHIP_ERROR Init(const RendezvousParameters & params); + CHIP_ERROR Init(const RendezvousParameters & params, TransportMgrBase * transportMgr); //////////////// RendezvousSessionDelegate Implementation /////////////////// void OnRendezvousConnectionOpened() override; void OnRendezvousConnectionClosed() override; void OnRendezvousError(CHIP_ERROR err) override; - void OnRendezvousMessageReceived(System::PacketBuffer * buffer) override; + void OnRendezvousMessageReceived(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * buffer) override; void OnRendezvousComplete() override; void OnRendezvousStatusUpdate(Status status, CHIP_ERROR err) override; + RendezvousSession * GetRendezvousSession() { return &mRendezvousSession; }; private: RendezvousSession mRendezvousSession; diff --git a/examples/common/chip-app-server/include/Server.h b/examples/common/chip-app-server/include/Server.h index b4be5aba1f53ac..6bd260b9f23b8e 100644 --- a/examples/common/chip-app-server/include/Server.h +++ b/examples/common/chip-app-server/include/Server.h @@ -17,10 +17,10 @@ #pragma once -#include +#include #include -using DemoSessionManager = chip::SecureSessionMgr; +using DemoTransportMgr = chip::TransportMgr; /** * Initialize DataModelHandler and start CHIP datamodel server, the server diff --git a/examples/common/chip-app-server/include/SessionManager.h b/examples/common/chip-app-server/include/SessionManager.h index ebfc2d8c78f9f5..17d78a987de5ab 100644 --- a/examples/common/chip-app-server/include/SessionManager.h +++ b/examples/common/chip-app-server/include/SessionManager.h @@ -20,5 +20,5 @@ #include namespace chip { -SecureSessionMgrBase & SessionManager(); +SecureSessionMgr & SessionManager(); } // namespace chip diff --git a/examples/temperature-measurement-app/esp32/main/RendezvousDeviceDelegate.cpp b/examples/temperature-measurement-app/esp32/main/RendezvousDeviceDelegate.cpp index 66fb309773b97c..f95b23c6f480ee 100644 --- a/examples/temperature-measurement-app/esp32/main/RendezvousDeviceDelegate.cpp +++ b/examples/temperature-measurement-app/esp32/main/RendezvousDeviceDelegate.cpp @@ -44,7 +44,7 @@ RendezvousDeviceDelegate::RendezvousDeviceDelegate() params.SetSetupPINCode(setupPINCode).SetLocalNodeId(kLocalNodeId).SetBleLayer(DeviceLayer::ConnectivityMgr().GetBleLayer()); mRendezvousSession = new RendezvousSession(this); - err = mRendezvousSession->Init(params); + err = mRendezvousSession->Init(params, nullptr); exit: if (err != CHIP_NO_ERROR) diff --git a/examples/temperature-measurement-app/esp32/main/ResponseServer.cpp b/examples/temperature-measurement-app/esp32/main/ResponseServer.cpp index 58ea2b9f31d808..bbc857634d7d76 100644 --- a/examples/temperature-measurement-app/esp32/main/ResponseServer.cpp +++ b/examples/temperature-measurement-app/esp32/main/ResponseServer.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include @@ -62,7 +63,7 @@ class ResponseServerCallback : public SecureSessionMgrDelegate public: void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const Transport::PeerConnectionState * state, System::PacketBuffer * buffer, - SecureSessionMgrBase * mgr) override + SecureSessionMgr * mgr) override { CHIP_ERROR err; const size_t data_len = buffer->DataLength(); @@ -91,12 +92,12 @@ class ResponseServerCallback : public SecureSessionMgrDelegate } } - void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgrBase * mgr) override + void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgr * mgr) override { ESP_LOGE(TAG, "ERROR: %s\n Got UDP error", ErrorStr(error)); } - void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) override + void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) override { ESP_LOGI(TAG, "Received a new connection."); } @@ -128,16 +129,16 @@ class ResponseServerCallback : public SecureSessionMgrDelegate }; ResponseServerCallback gCallbacks; - -SecureSessionMgr - sessions; +TransportMgr + gTransports; +SecureSessionMgr sessions; } // namespace namespace chip { -SecureSessionMgrBase & SessionManager() +SecureSessionMgr & SessionManager() { return sessions; } @@ -147,9 +148,10 @@ SecureSessionMgrBase & SessionManager() void startServer(NodeId localNodeId) { CHIP_ERROR err = CHIP_NO_ERROR; - err = sessions.Init(localNodeId, &DeviceLayer::SystemLayer, - UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv6).SetInterfaceId(nullptr), - UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv4)); + err = gTransports.Init(UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv6).SetInterfaceId(nullptr), + UdpListenParameters(&DeviceLayer::InetLayer).SetAddressType(kIPAddressType_IPv4)); + SuccessOrExit(err); + err = sessions.Init(localNodeId, &DeviceLayer::SystemLayer, &gTransports); SuccessOrExit(err); sessions.SetDelegate(&gCallbacks); diff --git a/src/app/util/chip-message-send.cpp b/src/app/util/chip-message-send.cpp index 23ce2b1b714ca0..c0e3e5d15d0fb8 100644 --- a/src/app/util/chip-message-send.cpp +++ b/src/app/util/chip-message-send.cpp @@ -25,7 +25,7 @@ #include #include // PacketBuffer and the like #include -#include // For SecureSessionMgrBase +#include // For SecureSessionMgr using namespace chip; @@ -35,7 +35,7 @@ using namespace chip; // // https://github.com/project-chip/connectedhomeip/issues/2566 tracks that API. namespace chip { -extern SecureSessionMgrBase & SessionManager(); +extern SecureSessionMgr & SessionManager(); } EmberStatus chipSendUnicast(NodeId destination, EmberApsFrame * apsFrame, uint16_t messageLength, uint8_t * message) diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index 3218a40fbc369e..ca067cf701eb6f 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -165,8 +165,7 @@ CHIP_ERROR Device::Deserialize(const SerializedDevice & input) } void Device::OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, - SecureSessionMgrBase * mgr) + const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, SecureSessionMgr * mgr) { if (mState == ConnectionState::SecureConnected && mStatusDelegate != nullptr) { @@ -187,7 +186,7 @@ CHIP_ERROR Device::LoadSecureSessionParameters() err = pairingSession.FromSerializable(mPairing); SuccessOrExit(err); - err = mSessionManager->ResetTransport(Transport::UdpListenParameters(mInetLayer).SetAddressType(mDeviceAddr.Type())); + err = mTransportMgr->ResetTransport(Transport::UdpListenParameters(mInetLayer).SetAddressType(mDeviceAddr.Type())); SuccessOrExit(err); err = mSessionManager->NewPairing( diff --git a/src/controller/CHIPDevice.h b/src/controller/CHIPDevice.h index 380685aa3ba6d2..7b45f9dfff2acc 100644 --- a/src/controller/CHIPDevice.h +++ b/src/controller/CHIPDevice.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -41,6 +42,8 @@ class DeviceController; class DeviceStatusDelegate; struct SerializedDevice; +using DeviceTransportMgr = TransportMgr; + class DLL_EXPORT Device { public: @@ -93,11 +96,13 @@ class DLL_EXPORT Device * that of this device object. If these objects are freed, while the device object is * still using them, it can lead to unknown behavior and crashes. * + * @param[in] transportMgr Transport manager object pointer * @param[in] sessionMgr Secure session manager object pointer * @param[in] inetLayer InetLayer object pointer */ - void Init(SecureSessionMgr * sessionMgr, Inet::InetLayer * inetLayer) + void Init(DeviceTransportMgr * transportMgr, SecureSessionMgr * sessionMgr, Inet::InetLayer * inetLayer) { + mTransportMgr = transportMgr; mSessionManager = sessionMgr; mInetLayer = inetLayer; } @@ -113,16 +118,17 @@ class DLL_EXPORT Device * uninitialzed/unpaired device objects. The object is initialized only when the device * is actually paired. * + * @param[in] transportMgr Transport manager object pointer * @param[in] sessionMgr Secure session manager object pointer * @param[in] inetLayer InetLayer object pointer * @param[in] deviceId Node ID of the device * @param[in] devicePort Port on which device is listening (typically CHIP_PORT) * @param[in] interfaceId Local Interface ID that should be used to talk to the device */ - void Init(SecureSessionMgr * sessionMgr, Inet::InetLayer * inetLayer, NodeId deviceId, uint16_t devicePort, - Inet::InterfaceId interfaceId) + void Init(DeviceTransportMgr * transportMgr, SecureSessionMgr * sessionMgr, Inet::InetLayer * inetLayer, NodeId deviceId, + uint16_t devicePort, Inet::InterfaceId interfaceId) { - Init(sessionMgr, inetLayer); + Init(transportMgr, sessionMgr, inetLayer); mDeviceId = deviceId; mDevicePort = devicePort; mInterface = interfaceId; @@ -156,7 +162,7 @@ class DLL_EXPORT Device * @param[in] mgr Pointer to secure session manager which received the message */ void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, - const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, SecureSessionMgrBase * mgr); + const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, SecureSessionMgr * mgr); /** * @brief @@ -202,7 +208,9 @@ class DLL_EXPORT Device DeviceStatusDelegate * mStatusDelegate; - SecureSessionMgr * mSessionManager; + SecureSessionMgr * mSessionManager; + + DeviceTransportMgr * mTransportMgr; /** * @brief diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 54bbc86fb0dfdd..0c1461c7dcca8a 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -123,10 +123,13 @@ CHIP_ERROR DeviceController::Init(NodeId localDeviceId, PersistentStorageDelegat mStorageDelegate->SetDelegate(this); } - mSessionManager = chip::Platform::New>(); + mTransportMgr = chip::Platform::New(); + mSessionManager = chip::Platform::New(); - err = mSessionManager->Init(localDeviceId, mSystemLayer, - Transport::UdpListenParameters(mInetLayer).SetAddressType(Inet::kIPAddressType_IPv6)); + err = mTransportMgr->Init(Transport::UdpListenParameters(mInetLayer).SetAddressType(Inet::kIPAddressType_IPv6)); + SuccessOrExit(err); + + err = mSessionManager->Init(localDeviceId, mSystemLayer, mTransportMgr); SuccessOrExit(err); mSessionManager->SetDelegate(this); @@ -202,7 +205,7 @@ CHIP_ERROR DeviceController::GetDevice(NodeId deviceId, const SerializedDevice & err = device->Deserialize(deviceInfo); VerifyOrExit(err == CHIP_NO_ERROR, ReleaseDevice(device)); - device->Init(mSessionManager, mInetLayer); + device->Init(mTransportMgr, mSessionManager, mInetLayer); } *out_device = device; @@ -264,7 +267,7 @@ CHIP_ERROR DeviceController::GetDevice(NodeId deviceId, Device ** out_device) err = device->Deserialize(deviceInfo); VerifyOrExit(err == CHIP_NO_ERROR, ReleaseDevice(device)); - device->Init(mSessionManager, mInetLayer); + device->Init(mTransportMgr, mSessionManager, mInetLayer); } } @@ -309,11 +312,11 @@ CHIP_ERROR DeviceController::ServiceEventSignal() return err; } -void DeviceController::OnNewConnection(const Transport::PeerConnectionState * peerConnection, SecureSessionMgrBase * mgr) {} +void DeviceController::OnNewConnection(const Transport::PeerConnectionState * peerConnection, SecureSessionMgr * mgr) {} void DeviceController::OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, - SecureSessionMgrBase * mgr) + SecureSessionMgr * mgr) { CHIP_ERROR err = CHIP_NO_ERROR; uint16_t index = 0; @@ -467,10 +470,10 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam mRendezvousSession = chip::Platform::New(this); VerifyOrExit(mRendezvousSession != nullptr, err = CHIP_ERROR_NO_MEMORY); - err = mRendezvousSession->Init(params.SetLocalNodeId(mLocalDeviceId).SetRemoteNodeId(remoteDeviceId)); + err = mRendezvousSession->Init(params.SetLocalNodeId(mLocalDeviceId).SetRemoteNodeId(remoteDeviceId), mTransportMgr); SuccessOrExit(err); - device->Init(mSessionManager, mInetLayer, remoteDeviceId, remotePort, interfaceId); + device->Init(mTransportMgr, mSessionManager, mInetLayer, remoteDeviceId, remotePort, interfaceId); exit: if (err != CHIP_NO_ERROR) @@ -513,7 +516,7 @@ CHIP_ERROR DeviceCommissioner::PairTestDeviceWithoutSecurity(NodeId remoteDevice testSecurePairingSecret->ToSerializable(device->GetPairing()); - device->Init(mSessionManager, mInetLayer, remoteDeviceId, remotePort, interfaceId); + device->Init(mTransportMgr, mSessionManager, mInetLayer, remoteDeviceId, remotePort, interfaceId); device->SetAddress(deviceAddr); diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index e7ec8593b8354b..ad979f4a46b46c 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -37,6 +37,7 @@ #include #include #include +#include #include namespace chip { @@ -182,7 +183,8 @@ class DLL_EXPORT DeviceController : public SecureSessionMgrDelegate, public Pers bool mPairedDevicesInitialized; NodeId mLocalDeviceId; - SecureSessionMgr * mSessionManager; + DeviceTransportMgr * mTransportMgr; + SecureSessionMgr * mSessionManager; PersistentStorageDelegate * mStorageDelegate; Inet::InetLayer * mInetLayer; @@ -195,9 +197,9 @@ class DLL_EXPORT DeviceController : public SecureSessionMgrDelegate, public Pers //////////// SecureSessionMgrDelegate Implementation /////////////// void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, - SecureSessionMgrBase * mgr) override; + SecureSessionMgr * mgr) override; - void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) override; + void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) override; //////////// PersistentStorageResultDelegate Implementation /////////////// void OnValue(const char * key, const char * value) override; diff --git a/src/include/platform/CHIPDeviceLayer.h b/src/include/platform/CHIPDeviceLayer.h index 434998bb924af0..bbab2d48a701b1 100644 --- a/src/include/platform/CHIPDeviceLayer.h +++ b/src/include/platform/CHIPDeviceLayer.h @@ -29,6 +29,7 @@ #include #include #include +#include #if CHIP_DEVICE_CONFIG_ENABLE_SOFTWARE_UPDATE_MANAGER #include #endif // CHIP_DEVICE_CONFIG_ENABLE_SOFTWARE_UPDATE_MANAGER diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 12ac90a3a8c1c3..61d7c6ab762fcb 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -62,7 +62,7 @@ ExchangeManager::ExchangeManager() mState = State::kState_NotInitialized; } -CHIP_ERROR ExchangeManager::Init(SecureSessionMgrBase * sessionMgr) +CHIP_ERROR ExchangeManager::Init(SecureSessionMgr * sessionMgr) { if (mState != State::kState_NotInitialized) return CHIP_ERROR_INCORRECT_STATE; @@ -135,7 +135,7 @@ CHIP_ERROR ExchangeManager::UnregisterUnsolicitedMessageHandler(uint32_t protoco return UnregisterUMH(protocolId, static_cast(msgType)); } -void ExchangeManager::OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgrBase * msgLayer) +void ExchangeManager::OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgr * msgLayer) { ChipLogError(ExchangeManager, "Accept FAILED, err = %s", ErrorStr(error)); } @@ -285,12 +285,12 @@ CHIP_ERROR ExchangeManager::UnregisterUMH(uint32_t protocolId, int16_t msgType) void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, - SecureSessionMgrBase * msgLayer) + SecureSessionMgr * msgLayer) { DispatchMessage(packetHeader, payloadHeader, msgBuf); } -void ExchangeManager::OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) +void ExchangeManager::OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) { for (auto & ec : ContextPool) { diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index a9e056666fec9b..49772efc9fd5ea 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -56,14 +56,14 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate * construction until a call to Shutdown is made to terminate the * instance. * - * @param[in] sessionMgr A pointer to the SecureSessionMgrBase object. + * @param[in] sessionMgr A pointer to the SecureSessionMgr object. * * @retval #CHIP_ERROR_INCORRECT_STATE If the state is not equal to * kState_NotInitialized. * @retval #CHIP_NO_ERROR On success. * */ - CHIP_ERROR Init(SecureSessionMgrBase * sessionMgr); + CHIP_ERROR Init(SecureSessionMgr * sessionMgr); /** * Shutdown the ExchangeManager. This terminates this instance @@ -162,7 +162,7 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate void IncrementContextsInUse(); void DecrementContextsInUse(); - SecureSessionMgrBase * GetSessionMgr() const { return mSessionMgr; } + SecureSessionMgr * GetSessionMgr() const { return mSessionMgr; } size_t GetContextsInUse() const { return mContextsInUse; } @@ -182,7 +182,7 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate uint16_t mNextExchangeId; State mState; - SecureSessionMgrBase * mSessionMgr; + SecureSessionMgr * mSessionMgr; std::array ContextPool; size_t mContextsInUse; @@ -197,13 +197,13 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate CHIP_ERROR RegisterUMH(uint32_t protocolId, int16_t msgType, ExchangeDelegate * delegate); CHIP_ERROR UnregisterUMH(uint32_t protocolId, int16_t msgType); - void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgrBase * msgLayer) override; + void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgr * msgLayer) override; void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, - SecureSessionMgrBase * msgLayer) override; + SecureSessionMgr * msgLayer) override; - void OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) override; + void OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) override; }; } // namespace chip diff --git a/src/messaging/tests/TestExchangeMgr.cpp b/src/messaging/tests/TestExchangeMgr.cpp index d7cdbfe2d49fbf..c6733a3ace3a36 100644 --- a/src/messaging/tests/TestExchangeMgr.cpp +++ b/src/messaging/tests/TestExchangeMgr.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -83,16 +84,19 @@ void CheckSimpleInitTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SecureSessionMgr conn; + TransportMgr transportMgr; + SecureSessionMgr secureSessionMgr; CHIP_ERROR err; ctx.GetInetLayer().SystemLayer()->Init(nullptr); - err = conn.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), "LOOPBACK"); + err = transportMgr.Init("LOOPBACK"); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ExchangeManager exchangeMgr; - err = exchangeMgr.Init(&conn); + err = exchangeMgr.Init(&secureSessionMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); } @@ -100,16 +104,19 @@ void CheckNewContextTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SecureSessionMgr conn; + TransportMgr transportMgr; + SecureSessionMgr secureSessionMgr; CHIP_ERROR err; ctx.GetInetLayer().SystemLayer()->Init(nullptr); - err = conn.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), "LOOPBACK"); + err = transportMgr.Init("LOOPBACK"); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ExchangeManager exchangeMgr; - err = exchangeMgr.Init(&conn); + err = exchangeMgr.Init(&secureSessionMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); MockAppDelegate mockAppDelegate; @@ -131,16 +138,19 @@ void CheckFindContextTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SecureSessionMgr conn; + TransportMgr transportMgr; + SecureSessionMgr secureSessionMgr; CHIP_ERROR err; ctx.GetInetLayer().SystemLayer()->Init(nullptr); - err = conn.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), "LOOPBACK"); + err = transportMgr.Init("LOOPBACK"); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ExchangeManager exchangeMgr; - err = exchangeMgr.Init(&conn); + err = exchangeMgr.Init(&secureSessionMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); MockAppDelegate mockAppDelegate; @@ -159,16 +169,19 @@ void CheckUmhRegistrationTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SecureSessionMgr conn; + TransportMgr transportMgr; + SecureSessionMgr secureSessionMgr; CHIP_ERROR err; ctx.GetInetLayer().SystemLayer()->Init(nullptr); - err = conn.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), "LOOPBACK"); + err = transportMgr.Init("LOOPBACK"); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ExchangeManager exchangeMgr; - err = exchangeMgr.Init(&conn); + err = exchangeMgr.Init(&secureSessionMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); MockAppDelegate mockAppDelegate; @@ -197,25 +210,29 @@ void CheckExchangeMessages(nlTestSuite * inSuite, void * inContext) TestContext & ctx = *reinterpret_cast(inContext); CHIP_ERROR err; - SecureSessionMgr conn; + TransportMgr transportMgr; + SecureSessionMgr secureSessionMgr; IPAddress addr; IPAddress::FromString("127.0.0.1", addr); - SecurePairingUsingTestSecret pairing1(Optional::Value(kSourceNodeId), 1, 2); - Optional peer1(Transport::PeerAddress::UDP(addr, 1)); - err = conn.NewPairing(peer1, kSourceNodeId, &pairing1); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret pairing2(Optional::Value(kDestinationNodeId), 2, 1); - Optional peer2(Transport::PeerAddress::UDP(addr, 2)); - err = conn.NewPairing(peer2, kDestinationNodeId, &pairing2); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ctx.GetInetLayer().SystemLayer()->Init(nullptr); - err = conn.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), "LOOPBACK"); + err = transportMgr.Init("LOOPBACK"); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ExchangeManager exchangeMgr; - err = exchangeMgr.Init(&conn); + err = exchangeMgr.Init(&secureSessionMgr); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + SecurePairingUsingTestSecret pairing1(Optional::Value(kSourceNodeId), 1, 2); + Optional peer1(Transport::PeerAddress::UDP(addr, 1)); + err = secureSessionMgr.NewPairing(peer1, kSourceNodeId, &pairing1); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + SecurePairingUsingTestSecret pairing2(Optional::Value(kDestinationNodeId), 2, 1); + Optional peer2(Transport::PeerAddress::UDP(addr, 2)); + err = secureSessionMgr.NewPairing(peer2, kDestinationNodeId, &pairing2); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // create solicited exchange diff --git a/src/transport/BLE.cpp b/src/transport/BLE.cpp index e52af6bc6820f8..45eecc4c15a589 100644 --- a/src/transport/BLE.cpp +++ b/src/transport/BLE.cpp @@ -189,11 +189,24 @@ void BLE::OnBleConnectionError(void * appState, BLE_ERROR err) void BLE::OnBleEndPointReceive(BLEEndPoint * endPoint, PacketBuffer * buffer) { - BLE * ble = reinterpret_cast(endPoint->mAppState); + BLE * ble = reinterpret_cast(endPoint->mAppState); + CHIP_ERROR err = CHIP_NO_ERROR; if (ble->mDelegate) { - ble->mDelegate->OnRendezvousMessageReceived(buffer); + uint16_t headerSize = 0; + + PacketHeader header; + err = header.Decode(buffer->Start(), buffer->DataLength(), &headerSize); + SuccessOrExit(err); + + buffer->ConsumeHead(headerSize); + ble->mDelegate->OnRendezvousMessageReceived(header, Transport::PeerAddress(Transport::Type::kBle), buffer); + } +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(Inet, "Failed to receive BLE message: %s", ErrorStr(err)); } } diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index f6e05ba87c764c..2439d50afe987e 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -33,6 +33,8 @@ static_library("transport") { "SecureSession.h", "SecureSessionMgr.cpp", "SecureSessionMgr.h", + "TransportMgr.cpp", + "TransportMgr.h", ] if (chip_config_network_layer_ble) { diff --git a/src/transport/RendezvousParameters.h b/src/transport/RendezvousParameters.h index 34c61977a98ce7..fbf6ef9355989f 100644 --- a/src/transport/RendezvousParameters.h +++ b/src/transport/RendezvousParameters.h @@ -18,7 +18,7 @@ #pragma once #include - +#include #if CONFIG_NETWORK_LAYER_BLE #include #endif // CONFIG_NETWORK_LAYER_BLE @@ -45,6 +45,14 @@ class RendezvousParameters return *this; } + bool HasPeerAddress() const { return mPeerAddress.IsInitialized(); } + Transport::PeerAddress GetPeerAddress() const { return mPeerAddress; } + RendezvousParameters & SetPeerAddress(const Transport::PeerAddress & peerAddress) + { + mPeerAddress = peerAddress; + return *this; + } + bool HasDiscriminator() const { return mDiscriminator <= kMaxRendezvousDiscriminatorValue; } uint16_t GetDiscriminator() const { return mDiscriminator; } RendezvousParameters & SetDiscriminator(uint16_t discriminator) @@ -91,6 +99,7 @@ class RendezvousParameters private: Optional mLocalNodeId; ///< the local node id + Transport::PeerAddress mPeerAddress; ///< the peer node address Optional mRemoteNodeId; ///< the remote node id uint32_t mSetupPINCode = 0; ///< the target peripheral setup PIN Code uint16_t mDiscriminator = UINT16_MAX; ///< the target peripheral discriminator diff --git a/src/transport/RendezvousSession.cpp b/src/transport/RendezvousSession.cpp index 34727bce49e5a5..220c16aea5a3f0 100644 --- a/src/transport/RendezvousSession.cpp +++ b/src/transport/RendezvousSession.cpp @@ -24,6 +24,10 @@ #include #include #include +#include +#include +#include +#include #if CONFIG_NETWORK_LAYER_BLE #include @@ -35,15 +39,19 @@ static const char * kSpake2pKeyExchangeSalt = "SPAKE2P Key Exchange Salt" using namespace chip::Inet; using namespace chip::System; +using namespace chip::Transport; namespace chip { -CHIP_ERROR RendezvousSession::Init(const RendezvousParameters & params) +CHIP_ERROR RendezvousSession::Init(const RendezvousParameters & params, TransportMgrBase * transportMgr) { - mParams = params; + mParams = params; + mTransportMgr = transportMgr; VerifyOrReturnError(mDelegate != nullptr, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(mParams.HasSetupPINCode(), CHIP_ERROR_INVALID_ARGUMENT); + // TODO: BLE Should be a transport, in that case, RendezvousSession and BLE should decouple + if (params.GetPeerAddress().GetTransportType() == Transport::Type::kBle) #if CONFIG_NETWORK_LAYER_BLE { Transport::BLE * transport = chip::Platform::New(); @@ -51,6 +59,11 @@ CHIP_ERROR RendezvousSession::Init(const RendezvousParameters & params) ReturnErrorOnFailure(transport->Init(this, mParams)); } +#else + { + return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; + } +#endif // CONFIG_NETWORK_LAYER_BLE if (!mParams.IsController()) { @@ -58,11 +71,13 @@ CHIP_ERROR RendezvousSession::Init(const RendezvousParameters & params) } mNetworkProvision.Init(this); + // TODO: We should assmue mTransportMgr not null for IP rendezvous. + if (mTransportMgr != nullptr) + { + mTransportMgr->SetRendezvousSession(this); + } return CHIP_NO_ERROR; -#else - return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; -#endif // CONFIG_NETWORK_LAYER_BLE } RendezvousSession::~RendezvousSession() @@ -77,11 +92,28 @@ RendezvousSession::~RendezvousSession() } CHIP_ERROR RendezvousSession::SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags, - System::PacketBuffer * msgIn) + const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgIn) { - ReturnErrorCodeIf(mCurrentState != State::kSecurePairing, CHIP_ERROR_INCORRECT_STATE); + if (mCurrentState != State::kSecurePairing) + { + PacketBuffer::Free(msgIn); + return CHIP_ERROR_INCORRECT_STATE; + } - return mTransport->SendMessage(header, payloadFlags, Transport::PeerAddress::BLE(), msgIn); + if (peerAddress.GetTransportType() == Transport::Type::kBle) + { + return mTransport->SendMessage(header, payloadFlags, peerAddress, msgIn); + } + else if (mTransportMgr != nullptr) + { + return mTransportMgr->SendMessage(header, payloadFlags, peerAddress, msgIn); + } + else + { + PacketBuffer::Free(msgIn); + ChipLogError(Ble, "SendPairingMessage dropped since no transport mgr for IP rendezvous"); + return CHIP_ERROR_INVALID_ADDRESS; + } } CHIP_ERROR RendezvousSession::SendSecureMessage(Protocols::CHIPProtocolId protocol, uint8_t msgType, System::PacketBuffer * msgIn) @@ -255,18 +287,20 @@ void RendezvousSession::UpdateState(RendezvousSession::State newState, CHIP_ERRO } } -void RendezvousSession::OnRendezvousMessageReceived(PacketBuffer * msgBuf) +void RendezvousSession::OnRendezvousMessageReceived(const PacketHeader & packetHeader, const PeerAddress & peerAddress, + PacketBuffer * msgBuf) { CHIP_ERROR err = CHIP_NO_ERROR; + // TODO: RendezvousSession should handle SecurePairing messages only switch (mCurrentState) { case State::kSecurePairing: - err = HandlePairingMessage(msgBuf); + err = HandlePairingMessage(packetHeader, peerAddress, msgBuf); break; case State::kNetworkProvisioning: - err = HandleSecureMessage(msgBuf); + err = HandleSecureMessage(packetHeader, peerAddress, msgBuf); break; default: @@ -280,19 +314,21 @@ void RendezvousSession::OnRendezvousMessageReceived(PacketBuffer * msgBuf) } } -CHIP_ERROR RendezvousSession::HandlePairingMessage(PacketBuffer * msgBuf) +void RendezvousSession::OnMessageReceived(const PacketHeader & header, const Transport::PeerAddress & source, + System::PacketBuffer * msgBuf) { - PacketHeader packetHeader; - uint16_t headerSize = 0; - - ReturnErrorOnFailure(packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize)); - - msgBuf->ConsumeHead(headerSize); + // TODO: OnRendezvousMessageReceived can be renamed to OnMessageReceived after BLE becomes a transport. + this->OnRendezvousMessageReceived(header, source, msgBuf); +} - return mPairingSession.HandlePeerMessage(packetHeader, msgBuf); +CHIP_ERROR RendezvousSession::HandlePairingMessage(const PacketHeader & packetHeader, const PeerAddress & peerAddress, + PacketBuffer * msgBuf) +{ + return mPairingSession.HandlePeerMessage(packetHeader, peerAddress, msgBuf); } -CHIP_ERROR RendezvousSession::HandleSecureMessage(PacketBuffer * msgIn) +CHIP_ERROR RendezvousSession::HandleSecureMessage(const PacketHeader & packetHeader, const PeerAddress & peerAddress, + PacketBuffer * msgIn) { System::PacketBufferHandle msgBuf; msgBuf.Adopt(msgIn); @@ -304,11 +340,6 @@ CHIP_ERROR RendezvousSession::HandleSecureMessage(PacketBuffer * msgIn) ReturnErrorCodeIf(msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); - PacketHeader packetHeader; - ReturnErrorOnFailure(packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize)); - - msgBuf->ConsumeHead(headerSize); - // Check if the source and destination node IDs match with what we already know if (packetHeader.GetDestinationNodeId().HasValue() && mParams.HasLocalNodeId()) { @@ -392,7 +423,7 @@ CHIP_ERROR RendezvousSession::WaitForPairing(Optional nodeId, uint32_t s CHIP_ERROR RendezvousSession::Pair(Optional nodeId, uint32_t setupPINCode) { UpdateState(State::kSecurePairing); - return mPairingSession.Pair(setupPINCode, kSpake2p_Iteration_Count, + return mPairingSession.Pair(mParams.GetPeerAddress(), setupPINCode, kSpake2p_Iteration_Count, reinterpret_cast(kSpake2pKeyExchangeSalt), strlen(kSpake2pKeyExchangeSalt), nodeId, mNextKeyId++, this); } diff --git a/src/transport/RendezvousSession.h b/src/transport/RendezvousSession.h index 33f21fc7127234..9f79908e95e427 100644 --- a/src/transport/RendezvousSession.h +++ b/src/transport/RendezvousSession.h @@ -29,13 +29,17 @@ #include #include #include - +#include +#include +#include namespace chip { namespace DeviceLayer { class CHIPDeviceEvent; } +class SecureSessionMgr; + /** * RendezvousSession establishes and maintains the first connection between * a commissioner and a device. This connection is used in order to @@ -61,7 +65,8 @@ class CHIPDeviceEvent; class RendezvousSession : public SecurePairingSessionDelegate, public RendezvousSessionDelegate, public RendezvousDeviceCredentialsDelegate, - public NetworkProvisioningDelegate + public NetworkProvisioningDelegate, + public TransportMgrDelegate { public: enum State : uint8_t @@ -79,9 +84,11 @@ class RendezvousSession : public SecurePairingSessionDelegate, * @brief * Initialize the underlying transport using the RendezvousParameters passed in the constructor. * + * @param params The RendezvousParameters + * @param transportMgr The transport to use * @ return CHIP_ERROR The result of the initialization */ - CHIP_ERROR Init(const RendezvousParameters & params); + CHIP_ERROR Init(const RendezvousParameters & params, TransportMgrBase * transportMgr); /** * @brief @@ -95,7 +102,8 @@ class RendezvousSession : public SecurePairingSessionDelegate, Optional GetRemoteNodeId() const { return mParams.GetRemoteNodeId(); } //////////// SecurePairingSessionDelegate Implementation /////////////// - CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags, System::PacketBuffer * msgBuf) override; + CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags, + const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgBuf) override; void OnPairingError(CHIP_ERROR err) override; void OnPairingComplete() override; @@ -103,7 +111,8 @@ class RendezvousSession : public SecurePairingSessionDelegate, void OnRendezvousConnectionOpened() override; void OnRendezvousConnectionClosed() override; void OnRendezvousError(CHIP_ERROR err) override; - void OnRendezvousMessageReceived(System::/* */ PacketBuffer * buffer) override; + void OnRendezvousMessageReceived(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * buffer) override; //////////// RendezvousDeviceCredentialsDelegate Implementation /////////////// void SendNetworkCredentials(const char * ssid, const char * passwd) override; @@ -115,6 +124,10 @@ class RendezvousSession : public SecurePairingSessionDelegate, void OnNetworkProvisioningError(CHIP_ERROR error) override; void OnNetworkProvisioningComplete() override; + //////////// TransportMgrDelegate Implementation /////////////// + void OnMessageReceived(const PacketHeader & header, const Transport::PeerAddress & source, + System::PacketBuffer * msgBuf) override; + /** * @brief * Get the IP address assigned to the device during network provisioning @@ -125,11 +138,13 @@ class RendezvousSession : public SecurePairingSessionDelegate, const Inet::IPAddress & GetIPAddress() const { return mNetworkProvision.GetIPAddress(); } private: - CHIP_ERROR HandlePairingMessage(System::PacketBuffer * msgBug); + CHIP_ERROR HandlePairingMessage(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * msgBug); CHIP_ERROR Pair(Optional nodeId, uint32_t setupPINCode); CHIP_ERROR WaitForPairing(Optional nodeId, uint32_t setupPINCode); - CHIP_ERROR HandleSecureMessage(System::PacketBuffer * msgBuf); + CHIP_ERROR HandleSecureMessage(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * msgBuf); Transport::Base * mTransport = nullptr; ///< Underlying transport RendezvousSessionDelegate * mDelegate = nullptr; ///< Underlying transport events RendezvousParameters mParams; ///< Rendezvous configuration @@ -137,6 +152,7 @@ class RendezvousSession : public SecurePairingSessionDelegate, SecurePairingSession mPairingSession; NetworkProvisioning mNetworkProvision; SecureSession mSecureSession; + TransportMgrBase * mTransportMgr; uint32_t mSecureMessageIndex = 0; uint16_t mNextKeyId = 0; diff --git a/src/transport/RendezvousSessionDelegate.h b/src/transport/RendezvousSessionDelegate.h index b1c5773c856b2a..3cd6366341c933 100644 --- a/src/transport/RendezvousSessionDelegate.h +++ b/src/transport/RendezvousSessionDelegate.h @@ -19,6 +19,8 @@ #include #include +#include +#include namespace chip { @@ -39,8 +41,8 @@ class RendezvousSessionDelegate virtual void OnRendezvousConnectionClosed() {} virtual void OnRendezvousError(CHIP_ERROR err) {} virtual void OnRendezvousComplete() {} - virtual void OnRendezvousMessageReceived(System::PacketBuffer * buffer){}; - + virtual void OnRendezvousMessageReceived(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * buffer){}; virtual void OnRendezvousStatusUpdate(Status status, CHIP_ERROR err) {} }; diff --git a/src/transport/SecurePairingSession.cpp b/src/transport/SecurePairingSession.cpp index 338a7df28fcb37..29c5fa6ff285c6 100644 --- a/src/transport/SecurePairingSession.cpp +++ b/src/transport/SecurePairingSession.cpp @@ -211,15 +211,16 @@ CHIP_ERROR SecurePairingSession::AttachHeaderAndSend(uint8_t msgType, System::Pa VerifyOrExit(headerSize == actualEncodedHeaderSize, err = CHIP_ERROR_INTERNAL); err = mDelegate->SendPairingMessage(PacketHeader().SetSourceNodeId(mLocalNodeId).SetEncryptionKeyID(mLocalKeyId), - payloadHeader.GetEncodePacketFlags(), msgBuf.Release_ForNow()); + payloadHeader.GetEncodePacketFlags(), mPeerAddress, msgBuf.Release_ForNow()); SuccessOrExit(err); exit: return err; } -CHIP_ERROR SecurePairingSession::Pair(uint32_t peerSetUpPINCode, uint32_t pbkdf2IterCount, const uint8_t * salt, size_t saltLen, - Optional myNodeId, uint16_t myKeyId, SecurePairingSessionDelegate * delegate) +CHIP_ERROR SecurePairingSession::Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint32_t pbkdf2IterCount, + const uint8_t * salt, size_t saltLen, Optional myNodeId, uint16_t myKeyId, + SecurePairingSessionDelegate * delegate) { uint8_t X[kMAX_Point_Length]; size_t X_len = sizeof(X); @@ -230,6 +231,8 @@ CHIP_ERROR SecurePairingSession::Pair(uint32_t peerSetUpPINCode, uint32_t pbkdf2 CHIP_ERROR err = Init(peerSetUpPINCode, pbkdf2IterCount, salt, saltLen, myNodeId, myKeyId, delegate); SuccessOrExit(err); + mPeerAddress = peerAddress; + err = mSpake2p.BeginProver(reinterpret_cast(""), 0, reinterpret_cast(""), 0, &mWS[0][0], kSpake2p_WS_Length, &mWS[1][0], kSpake2p_WS_Length); SuccessOrExit(err); @@ -425,7 +428,8 @@ CHIP_ERROR SecurePairingSession::HandleCompute_cA(const PacketHeader & header, S return err; } -CHIP_ERROR SecurePairingSession::HandlePeerMessage(const PacketHeader & packetHeader, System::PacketBuffer * msgIn) +CHIP_ERROR SecurePairingSession::HandlePeerMessage(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * msgIn) { CHIP_ERROR err = CHIP_NO_ERROR; uint16_t headerSize = 0; @@ -443,6 +447,8 @@ CHIP_ERROR SecurePairingSession::HandlePeerMessage(const PacketHeader & packetHe VerifyOrExit(payloadHeader.GetProtocolID() == Protocols::kProtocol_SecurityChannel, err = CHIP_ERROR_INVALID_MESSAGE_TYPE); VerifyOrExit(payloadHeader.GetMessageType() == (uint8_t) mNextExpectedMsg, err = CHIP_ERROR_INVALID_MESSAGE_TYPE); + mPeerAddress = peerAddress; + switch (static_cast(payloadHeader.GetMessageType())) { case Spake2pMsgType::kSpake2pCompute_pA: diff --git a/src/transport/SecurePairingSession.h b/src/transport/SecurePairingSession.h index c2052e03ae6737..0205544add81b0 100644 --- a/src/transport/SecurePairingSession.h +++ b/src/transport/SecurePairingSession.h @@ -30,6 +30,8 @@ #include #include #include +#include +#include namespace chip { @@ -47,10 +49,12 @@ class DLL_EXPORT SecurePairingSessionDelegate * * @param header the message header for the sent message * @param payloadFlags payload encoding flags + * @param peerAddress the destination of the message * @param msgBuf the raw data for the message being sent * @return CHIP_ERROR Error thrown when sending the message */ - virtual CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags, System::PacketBuffer * msgBuf) + virtual CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags, + const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgBuf) { return CHIP_ERROR_NOT_IMPLEMENTED; } @@ -117,6 +121,7 @@ class DLL_EXPORT SecurePairingSession * @brief * Create a pairing request using peer's setup PIN code. * + * @param peerAddress Address of peer to pair * @param peerSetUpPINCode Setup PIN code of the peer device * @param pbkdf2IterCount Iteration count for PBKDF2 function * @param salt Salt to be used for SPAKE2P opertation @@ -127,8 +132,9 @@ class DLL_EXPORT SecurePairingSession * * @return CHIP_ERROR The result of initialization */ - CHIP_ERROR Pair(uint32_t peerSetUpPINCode, uint32_t pbkdf2IterCount, const uint8_t * salt, size_t saltLen, - Optional myNodeId, uint16_t myKeyId, SecurePairingSessionDelegate * delegate); + CHIP_ERROR Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint32_t pbkdf2IterCount, + const uint8_t * salt, size_t saltLen, Optional myNodeId, uint16_t myKeyId, + SecurePairingSessionDelegate * delegate); /** * @brief @@ -148,10 +154,12 @@ class DLL_EXPORT SecurePairingSession * Handler for peer's messages, exchanged during pairing handshake. * * @param packetHeader Message header for the received message + * @param peerAddress Source of the message * @param msg Message sent by the peer * @return CHIP_ERROR The result of message processing */ - virtual CHIP_ERROR HandlePeerMessage(const PacketHeader & packetHeader, System::PacketBuffer * msg); + virtual CHIP_ERROR HandlePeerMessage(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * msg); /** * @brief @@ -241,6 +249,8 @@ class DLL_EXPORT SecurePairingSession uint16_t mPeerKeyId; + Transport::PeerAddress mPeerAddress; + uint8_t mKe[kMAX_Hash_Length]; size_t mKeLen = sizeof(mKe); @@ -300,7 +310,11 @@ class SecurePairingUsingTestSecret : public SecurePairingSession return CHIP_NO_ERROR; } - CHIP_ERROR HandlePeerMessage(const PacketHeader & packetHeader, System::PacketBuffer * msg) override { return CHIP_NO_ERROR; } + CHIP_ERROR HandlePeerMessage(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * msg) override + { + return CHIP_NO_ERROR; + } }; typedef struct SecurePairingSessionSerialized diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 72c1174fe05b81..ba4d4fd300c99e 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -33,7 +33,12 @@ #include #include #include +#include #include +#include +#include + +#include namespace chip { @@ -49,44 +54,45 @@ using Transport::PeerConnectionState; // TODO: this should be checked within the transport message sending instead of the session management layer. static const size_t kMax_SecureSDU_Length = 1024; -SecureSessionMgrBase::SecureSessionMgrBase() : mState(State::kNotReady) {} +SecureSessionMgr::SecureSessionMgr() : mState(State::kNotReady) {} -SecureSessionMgrBase::~SecureSessionMgrBase() +SecureSessionMgr::~SecureSessionMgr() { CancelExpiryTimer(); } -CHIP_ERROR SecureSessionMgrBase::InitInternal(NodeId localNodeId, System::Layer * systemLayer, Transport::Base * transport) +CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLayer, TransportMgrBase * transportMgr) { CHIP_ERROR err = CHIP_NO_ERROR; VerifyOrExit(mState == State::kNotReady, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(transportMgr != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT); - mState = State::kInitialized; - mLocalNodeId = localNodeId; - mSystemLayer = systemLayer; - mTransport = transport; + mState = State::kInitialized; + mLocalNodeId = localNodeId; + mSystemLayer = systemLayer; + mTransportMgr = transportMgr; ChipLogProgress(Inet, "local node id is %llu\n", mLocalNodeId); - mTransport->SetMessageReceiveHandler(HandleDataReceived, this); - Mdns::DiscoveryManager::GetInstance().Init(); Mdns::DiscoveryManager::GetInstance().RegisterResolveDelegate(this); ScheduleExpiryTimer(); + mTransportMgr->SetSecureSessionMgr(this); + exit: return err; } -CHIP_ERROR SecureSessionMgrBase::SendMessage(NodeId peerNodeId, System::PacketBuffer * msgBuf) +CHIP_ERROR SecureSessionMgr::SendMessage(NodeId peerNodeId, System::PacketBuffer * msgBuf) { PayloadHeader payloadHeader; return SendMessage(payloadHeader, peerNodeId, msgBuf); } -CHIP_ERROR SecureSessionMgrBase::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBuffer * msgIn) +CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBuffer * msgIn) { System::PacketBufferHandle msgBuf; CHIP_ERROR err = CHIP_NO_ERROR; @@ -127,6 +133,7 @@ CHIP_ERROR SecureSessionMgrBase::SendMessage(PayloadHeader & payloadHeader, Node .SetMessageId(state->GetSendMessageIndex()) // .SetEncryptionKeyID(state->GetLocalKeyID()) // .SetPayloadLength(static_cast(payloadLength)); + packetHeader.GetFlags().Set(Header::FlagValues::kSecure); ChipLogProgress(Inet, "Sending msg from %llu to %llu\n", mLocalNodeId, peerNodeId); @@ -150,8 +157,8 @@ CHIP_ERROR SecureSessionMgrBase::SendMessage(PayloadHeader & payloadHeader, Node ChipLogDetail(Inet, "Secure transport transmitting msg %u after encryption", state->GetSendMessageIndex()); - err = mTransport->SendMessage(packetHeader, payloadHeader.GetEncodePacketFlags(), state->GetPeerAddress(), - msgBuf.Release_ForNow()); + err = mTransportMgr->SendMessage(packetHeader, payloadHeader.GetEncodePacketFlags(), state->GetPeerAddress(), + msgBuf.Release_ForNow()); } SuccessOrExit(err); state->IncrementSendMessageIndex(); @@ -173,8 +180,8 @@ CHIP_ERROR SecureSessionMgrBase::SendMessage(PayloadHeader & payloadHeader, Node return err; } -CHIP_ERROR SecureSessionMgrBase::NewPairing(const Optional & peerAddr, NodeId peerNodeId, - SecurePairingSession * pairing) +CHIP_ERROR SecureSessionMgr::NewPairing(const Optional & peerAddr, NodeId peerNodeId, + SecurePairingSession * pairing) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -221,7 +228,7 @@ CHIP_ERROR SecureSessionMgrBase::NewPairing(const OptionalStartTimer(CHIP_PEER_CONNECTION_TIMEOUT_CHECK_FREQUENCY_MS, SecureSessionMgrBase::ExpiryTimerCallback, this); + mSystemLayer->StartTimer(CHIP_PEER_CONNECTION_TIMEOUT_CHECK_FREQUENCY_MS, SecureSessionMgr::ExpiryTimerCallback, this); VerifyOrDie(err == CHIP_NO_ERROR); } -void SecureSessionMgrBase::CancelExpiryTimer() +void SecureSessionMgr::CancelExpiryTimer() { if (mSystemLayer != nullptr) { - mSystemLayer->CancelTimer(SecureSessionMgrBase::ExpiryTimerCallback, this); + mSystemLayer->CancelTimer(SecureSessionMgr::ExpiryTimerCallback, this); } } -void SecureSessionMgrBase::HandleDataReceived(const PacketHeader & packetHeader, const PeerAddress & peerAddress, - System::PacketBuffer * msgIn, SecureSessionMgrBase * connection) +void SecureSessionMgr::OnMessageReceived(const PacketHeader & packetHeader, const PeerAddress & peerAddress, + System::PacketBuffer * msgIn) { - CHIP_ERROR err = CHIP_NO_ERROR; - PeerConnectionState * state = connection->mPeerConnections.FindPeerConnectionState(packetHeader.GetSourceNodeId(), - packetHeader.GetEncryptionKeyID(), nullptr); + CHIP_ERROR err = CHIP_NO_ERROR; + PeerConnectionState * state = + mPeerConnections.FindPeerConnectionState(packetHeader.GetSourceNodeId(), packetHeader.GetEncryptionKeyID(), nullptr); PacketBufferHandle msg; PacketBufferHandle origMsg; @@ -295,7 +302,7 @@ void SecureSessionMgrBase::HandleDataReceived(const PacketHeader & packetHeader, state->SetPeerAddress(peerAddress); } - connection->mPeerConnections.MarkConnectionActive(state); + mPeerConnections.MarkConnectionActive(state); // TODO this is where messages should be decoded { @@ -343,20 +350,20 @@ void SecureSessionMgrBase::HandleDataReceived(const PacketHeader & packetHeader, state->SetPeerNodeId(packetHeader.GetSourceNodeId().Value()); } - if (connection->mCB != nullptr) + if (mCB != nullptr) { - connection->mCB->OnMessageReceived(packetHeader, payloadHeader, state, msg.Release_ForNow(), connection); + mCB->OnMessageReceived(packetHeader, payloadHeader, state, msg.Release_ForNow(), this); } } exit: - if (err != CHIP_NO_ERROR && connection->mCB != nullptr) + if (err != CHIP_NO_ERROR && mCB != nullptr) { - connection->mCB->OnReceiveError(err, peerAddress, connection); + mCB->OnReceiveError(err, peerAddress, this); } } -void SecureSessionMgrBase::HandleConnectionExpired(const Transport::PeerConnectionState & state) +void SecureSessionMgr::HandleConnectionExpired(const Transport::PeerConnectionState & state) { char addr[Transport::PeerAddress::kMaxToStringSize]; state.GetPeerAddress().ToString(addr, sizeof(addr)); @@ -368,12 +375,12 @@ void SecureSessionMgrBase::HandleConnectionExpired(const Transport::PeerConnecti mCB->OnConnectionExpired(&state, this); } - mTransport->Disconnect(state.GetPeerAddress()); + mTransportMgr->Disconnect(state.GetPeerAddress()); } -void SecureSessionMgrBase::ExpiryTimerCallback(System::Layer * layer, void * param, System::Error error) +void SecureSessionMgr::ExpiryTimerCallback(System::Layer * layer, void * param, System::Error error) { - SecureSessionMgrBase * mgr = reinterpret_cast(param); + SecureSessionMgr * mgr = reinterpret_cast(param); #if CHIP_CONFIG_SESSION_REKEYING // TODO(#2279): session expiration is currently disabled until rekeying is supported // the #ifdef should be removed after that. diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index f13274b55ee4e6..dcb73ee0fb4c1f 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -36,17 +36,18 @@ #include #include #include +#include #include #include namespace chip { -class SecureSessionMgrBase; +class SecureSessionMgr; /** * @brief * This class provides a skeleton for the callback functions. The functions will be - * called by SecureSssionMgrBase object on specific events. If the user of SecureSessionMgrBase + * called by SecureSssionMgrBase object on specific events. If the user of SecureSessionMgr * is interested in receiving these callbacks, they can specialize this class and handle * each trigger in their implementation of this class. */ @@ -66,7 +67,7 @@ class DLL_EXPORT SecureSessionMgrDelegate */ virtual void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const Transport::PeerConnectionState * state, System::PacketBuffer * msgBuf, - SecureSessionMgrBase * mgr) + SecureSessionMgr * mgr) {} /** @@ -77,7 +78,7 @@ class DLL_EXPORT SecureSessionMgrDelegate * @param source network entity that sent the message * @param mgr A pointer to the SecureSessionMgr */ - virtual void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgrBase * mgr) {} + virtual void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source, SecureSessionMgr * mgr) {} /** * @brief @@ -86,7 +87,7 @@ class DLL_EXPORT SecureSessionMgrDelegate * @param state connection state * @param mgr A pointer to the SecureSessionMgr */ - virtual void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) {} + virtual void OnNewConnection(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) {} /** * @brief @@ -95,7 +96,7 @@ class DLL_EXPORT SecureSessionMgrDelegate * @param state connection state * @param mgr A pointer to the SecureSessionMgr */ - virtual void OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgrBase * mgr) {} + virtual void OnConnectionExpired(const Transport::PeerConnectionState * state, SecureSessionMgr * mgr) {} /** * @brief @@ -105,12 +106,12 @@ class DLL_EXPORT SecureSessionMgrDelegate * @param nodeId The node ID resolved, 0 on error * @param mgr A pointer to the SecureSessionMgr */ - virtual void OnAddressResolved(CHIP_ERROR error, NodeId nodeId, SecureSessionMgrBase * mgr) {} + virtual void OnAddressResolved(CHIP_ERROR error, NodeId nodeId, SecureSessionMgr * mgr) {} virtual ~SecureSessionMgrDelegate() {} }; -class DLL_EXPORT SecureSessionMgrBase : public Mdns::ResolveDelegate +class DLL_EXPORT SecureSessionMgr : public Mdns::ResolveDelegate, public TransportMgrDelegate { public: /** @@ -123,8 +124,8 @@ class DLL_EXPORT SecureSessionMgrBase : public Mdns::ResolveDelegate */ CHIP_ERROR SendMessage(NodeId peerNodeId, System::PacketBuffer * msgBuf); CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBuffer * msgBuf); - SecureSessionMgrBase(); - ~SecureSessionMgrBase() override; + SecureSessionMgr(); + ~SecureSessionMgr() override; /** * @brief @@ -152,16 +153,27 @@ class DLL_EXPORT SecureSessionMgrBase : public Mdns::ResolveDelegate */ System::Layer * SystemLayer() { return mSystemLayer; } -protected: /** * @brief * Initialize a Secure Session Manager * * @param localNodeId Node id for the current node * @param systemLayer System, layer to use - * @param transport Underlying Transport to use + * @param transportMgr Transport to use + */ + CHIP_ERROR Init(NodeId localNodeId, System::Layer * systemLayer, TransportMgrBase * transportMgr); + +protected: + /** + * @brief + * Handle received secure message. Implements TransportMgrDelegate + * + * @param header the received message header + * @param source the source address of the package + * @param msgBuf the buffer of (encrypted) payload */ - CHIP_ERROR InitInternal(NodeId localNodeId, System::Layer * systemLayer, Transport::Base * transport); + void OnMessageReceived(const PacketHeader & header, const Transport::PeerAddress & source, + System::PacketBuffer * msgBuf) override; private: /** @@ -173,13 +185,13 @@ class DLL_EXPORT SecureSessionMgrBase : public Mdns::ResolveDelegate kInitialized, /**< State when the object is ready connect to other peers. */ }; - Transport::Base * mTransport = nullptr; System::Layer * mSystemLayer = nullptr; NodeId mLocalNodeId; // < Id of the current node Transport::PeerConnections mPeerConnections; // < Active connections to other peers State mState; // < Initialization state of the object - SecureSessionMgrDelegate * mCB = nullptr; + SecureSessionMgrDelegate * mCB = nullptr; + TransportMgrBase * mTransportMgr = nullptr; /** Schedules a new oneshot timer for checking connection expiry. */ void ScheduleExpiryTimer(); @@ -187,9 +199,6 @@ class DLL_EXPORT SecureSessionMgrBase : public Mdns::ResolveDelegate /** Cancels any active timers for connection expiry checks. */ void CancelExpiryTimer(); - static void HandleDataReceived(const PacketHeader & header, const Transport::PeerAddress & source, - System::PacketBuffer * msgBuf, SecureSessionMgrBase * transport); - /** * Called when a specific connection expires. */ @@ -203,44 +212,4 @@ class DLL_EXPORT SecureSessionMgrBase : public Mdns::ResolveDelegate void HandleNodeIdResolve(CHIP_ERROR error, NodeId nodeId, const Mdns::MdnsService & service) override; }; -/** - * A secure session manager that includes required underlying transports. - */ -template -class SecureSessionMgr : public SecureSessionMgrBase -{ -public: - /** - * @brief - * Initialize a Secure Session Manager - * - * @param localNodeId Node id for the current node - * @param systemLayer System, layer to use - * @param transportInitArgs Arguments to initialize the underlying transport - */ - template - CHIP_ERROR Init(NodeId localNodeId, System::Layer * systemLayer, Args &&... transportInitArgs) - { - CHIP_ERROR err = CHIP_NO_ERROR; - - err = mTransport.Init(std::forward(transportInitArgs)...); - SuccessOrExit(err); - - err = InitInternal(localNodeId, systemLayer, &mTransport); - SuccessOrExit(err); - - exit: - return err; - } - - template - CHIP_ERROR ResetTransport(Args &&... transportInitArgs) - { - return mTransport.Init(std::forward(transportInitArgs)...); - } - -private: - Transport::Tuple mTransport; -}; - } // namespace chip diff --git a/src/transport/TransportMgr.cpp b/src/transport/TransportMgr.cpp new file mode 100644 index 00000000000000..6425866ce79444 --- /dev/null +++ b/src/transport/TransportMgr.cpp @@ -0,0 +1,67 @@ +/* + * + * Copyright (c) 2020 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + */ + +/** + * @file + * This file implements a stateless TransportMgr, it will took a raw message + * buffer from transports, and then extract the message header without decode it. + * For secure messages, it will pass it to the SecureSessionMgr, and for unsecure + * messages (rendezvous messages), it will pass it to RendezvousSession. + * When sending messages, it will encode the packet header, and pass it to the + * transports. + * The whole process is fully stateless. + */ + +#include + +#include +#include +#include +#include +#include + +namespace chip { + +CHIP_ERROR TransportMgrBase::Init(Transport::Base * transport) +{ + if (mTransport != nullptr) + { + return CHIP_ERROR_INCORRECT_STATE; + } + mTransport = transport; + mTransport->SetMessageReceiveHandler(HandleMessageReceived, this); + ChipLogDetail(Inet, "TransportMgr initialized"); + return CHIP_NO_ERROR; +} + +void TransportMgrBase::HandleMessageReceived(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * msg, TransportMgrBase * dispatcher) +{ + TransportMgrDelegate * handler = + packetHeader.GetFlags().Has(Header::FlagValues::kSecure) ? dispatcher->mSecureSessionMgr : dispatcher->mRendezvous; + if (handler != nullptr) + { + handler->OnMessageReceived(packetHeader, peerAddress, msg); + } + else + { + char addrBuffer[Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(addrBuffer, sizeof(addrBuffer)); + ChipLogError(Inet, "%s message from %s is dropped since no corresponding handler is set in TransportMgr.", + packetHeader.GetFlags().Has(Header::FlagValues::kSecure) ? "Encrypted" : "Unencrypted", addrBuffer); + } +} +} // namespace chip diff --git a/src/transport/TransportMgr.h b/src/transport/TransportMgr.h new file mode 100644 index 00000000000000..af0e3c75ec56e2 --- /dev/null +++ b/src/transport/TransportMgr.h @@ -0,0 +1,109 @@ +/* + * + * Copyright (c) 2020 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + */ + +/** + * @file + * This file implements a stateless TransportMgr, it will took a raw message + * buffer from transports, and then extract the message header without decode it. + * For secure messages, it will pass it to the SecureSessionMgr, and for unsecure + * messages (rendezvous messages), it will pass it to RendezvousSession. + * When sending messages, it will encode the packet header, and pass it to the + * transports. + * The whole process is fully stateless. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace chip { + +class TransportMgrBase; + +class TransportMgrDelegate +{ +public: + virtual ~TransportMgrDelegate() = default; + /** + * @brief + * Handle received secure message. + * + * @param header the received message header + * @param source the source address of the package + * @param msgBuf the buffer of (encrypted) payload + */ + virtual void OnMessageReceived(const PacketHeader & header, const Transport::PeerAddress & source, + System::PacketBuffer * msgBuf) = 0; +}; + +class TransportMgrBase +{ +public: + CHIP_ERROR Init(Transport::Base * transport); + + CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const Transport::PeerAddress & address, + System::PacketBuffer * msgBuf) + { + return mTransport->SendMessage(header, payloadFlags, address, msgBuf); + } + + void Disconnect(const Transport::PeerAddress & address) { mTransport->Disconnect(address); } + + void SetSecureSessionMgr(TransportMgrDelegate * secureSessionMgr) { mSecureSessionMgr = secureSessionMgr; } + + void SetRendezvousSession(TransportMgrDelegate * rendezvousSessionMgr) { mRendezvous = rendezvousSessionMgr; } + +private: + static void HandleMessageReceived(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBuffer * msg, TransportMgrBase * dispatcher); + + TransportMgrDelegate * mSecureSessionMgr = nullptr; + TransportMgrDelegate * mRendezvous = nullptr; + Transport::Base * mTransport = nullptr; +}; + +template +class TransportMgr : public TransportMgrBase +{ +public: + template + CHIP_ERROR Init(Args &&... transportInitArgs) + { + CHIP_ERROR err = CHIP_NO_ERROR; + + err = mTransport.Init(std::forward(transportInitArgs)...); + SuccessOrExit(err); + err = TransportMgrBase::Init(&mTransport); + exit: + return err; + } + + template + CHIP_ERROR ResetTransport(Args &&... transportInitArgs) + { + return mTransport.Init(std::forward(transportInitArgs)...); + } + +private: + Transport::Tuple mTransport; +}; + +} // namespace chip diff --git a/src/transport/raw/MessageHeader.h b/src/transport/raw/MessageHeader.h index 13d1987317d65c..f32d52fd4ac186 100644 --- a/src/transport/raw/MessageHeader.h +++ b/src/transport/raw/MessageHeader.h @@ -74,16 +74,22 @@ enum class FlagValues : uint16_t /// Header flag specifying that it is a control message for secure session. kSecureSessionControlMessage = 0x0800, + /// Header flag specifying that it is a encrypted message. + kSecure = 0x0001, + }; using Flags = BitFlags; using ExFlags = BitFlags; // Header is a 16-bit value of the form -// | 4 bit | 4 bit | 4 bit | 4 bit | -// +---------+-------+---------+----------| -// | version | Flags | encType | reserved | -static constexpr uint16_t kFlagsMask = 0x0F00; +// | 4 bit | 4 bit |8 bit Security Flags| +// +---------+-------+--------------------| +// | version | Flags | P | C |Reserved| E | +// | | +---Encrypted +// | +----------------Control message (TODO: Implement this) +// +--------------------Privacy enhancements (TODO: Implement this) +static constexpr uint16_t kFlagsMask = 0x0F01; } // namespace Header diff --git a/src/transport/tests/TestSecurePairingSession.cpp b/src/transport/tests/TestSecurePairingSession.cpp index d50df175f8e164..dd8a78ef383d23 100644 --- a/src/transport/tests/TestSecurePairingSession.cpp +++ b/src/transport/tests/TestSecurePairingSession.cpp @@ -39,10 +39,11 @@ using namespace chip; class TestSecurePairingDelegate : public SecurePairingSessionDelegate { public: - CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags, System::PacketBuffer * msgBuf) override + CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags, + const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgBuf) override { mNumMessageSend++; - return (peer != nullptr) ? peer->HandlePeerMessage(header, msgBuf) : mMessageSendError; + return (peer != nullptr) ? peer->HandlePeerMessage(header, peerAddress, msgBuf) : mMessageSendError; } void OnPairingError(CHIP_ERROR error) override { mNumPairingErrors++; } @@ -81,12 +82,14 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) SecurePairingSession pairing; NL_TEST_ASSERT(inSuite, - pairing.Pair(1234, 500, nullptr, 0, Optional::Value(1), 0, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 500, nullptr, 0, Optional::Value(1), 0, + &delegate) == CHIP_ERROR_INVALID_ARGUMENT); NL_TEST_ASSERT(inSuite, - pairing.Pair(1234, 500, (const uint8_t *) "salt", 4, Optional::Value(1), 0, nullptr) == - CHIP_ERROR_INVALID_ARGUMENT); + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 500, (const uint8_t *) "salt", 4, + Optional::Value(1), 0, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); NL_TEST_ASSERT(inSuite, - pairing.Pair(1234, 500, (const uint8_t *) "salt", 4, Optional::Value(1), 0, &delegate) == CHIP_NO_ERROR); + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 500, (const uint8_t *) "salt", 4, + Optional::Value(1), 0, &delegate) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, delegate.mNumMessageSend == 1); @@ -95,8 +98,8 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) SecurePairingSession pairing1; NL_TEST_ASSERT(inSuite, - pairing1.Pair(1234, 500, (const uint8_t *) "salt", 4, Optional::Value(1), 0, &delegate) == - CHIP_ERROR_BAD_REQUEST); + pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 500, (const uint8_t *) "salt", 4, + Optional::Value(1), 0, &delegate) == CHIP_ERROR_BAD_REQUEST); } void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, SecurePairingSession & pairingCommissioner, @@ -113,8 +116,8 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S pairingAccessory.WaitForPairing(1234, 500, (const uint8_t *) "salt", 4, Optional::Value(1), 0, &delegateAccessory) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(1234, 500, (const uint8_t *) "salt", 4, Optional::Value(2), 0, - &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 500, (const uint8_t *) "salt", 4, + Optional::Value(2), 0, &delegateCommissioner) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumMessageSend == 1); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 1); diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index d8c4f5642b8df0..9795f437690dd9 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -67,7 +68,7 @@ class TestSessMgrCallback : public SecureSessionMgrDelegate { public: void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const PeerConnectionState * state, - System::PacketBuffer * msgBuf, SecureSessionMgrBase * mgr) override + System::PacketBuffer * msgBuf, SecureSessionMgr * mgr) override { NL_TEST_ASSERT(mSuite, header.GetSourceNodeId() == Optional::Value(kSourceNodeId)); NL_TEST_ASSERT(mSuite, header.GetDestinationNodeId() == Optional::Value(kDestinationNodeId)); @@ -81,10 +82,7 @@ class TestSessMgrCallback : public SecureSessionMgrDelegate ReceiveHandlerCallCount++; } - void OnNewConnection(const PeerConnectionState * state, SecureSessionMgrBase * mgr) override - { - NewConnectionHandlerCallCount++; - } + void OnNewConnection(const PeerConnectionState * state, SecureSessionMgr * mgr) override { NewConnectionHandlerCallCount++; } nlTestSuite * mSuite = nullptr; int ReceiveHandlerCallCount = 0; @@ -97,12 +95,15 @@ void CheckSimpleInitTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SecureSessionMgr conn; + TransportMgr transportMgr; + SecureSessionMgr secureSessionMgr; CHIP_ERROR err; ctx.GetInetLayer().SystemLayer()->Init(nullptr); - err = conn.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), "LOOPBACK"); + err = transportMgr.Init("LOOPBACK"); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); } @@ -124,29 +125,32 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) IPAddress::FromString("127.0.0.1", addr); CHIP_ERROR err = CHIP_NO_ERROR; - SecureSessionMgr conn; + TransportMgr transportMgr; + SecureSessionMgr secureSessionMgr; - err = conn.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), "LOOPBACK"); + err = transportMgr.Init("LOOPBACK"); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); callback.mSuite = inSuite; - conn.SetDelegate(&callback); + secureSessionMgr.SetDelegate(&callback); SecurePairingUsingTestSecret pairing1(Optional::Value(kSourceNodeId), 1, 2); Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); - err = conn.NewPairing(peer, kDestinationNodeId, &pairing1); + err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); SecurePairingUsingTestSecret pairing2(Optional::Value(kDestinationNodeId), 2, 1); - err = conn.NewPairing(peer, kSourceNodeId, &pairing2); + err = secureSessionMgr.NewPairing(peer, kSourceNodeId, &pairing2); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; - err = conn.SendMessage(kDestinationNodeId, buffer.Release_ForNow()); + err = secureSessionMgr.SendMessage(kDestinationNodeId, buffer.Release_ForNow()); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 0; });