diff --git a/examples/shell/shell_common/include/Globals.h b/examples/shell/shell_common/include/Globals.h index d91620ca0ad171..69162be92f5204 100644 --- a/examples/shell/shell_common/include/Globals.h +++ b/examples/shell/shell_common/include/Globals.h @@ -24,7 +24,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include #if INET_CONFIG_ENABLE_TCP_ENDPOINT diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp index c15ddcabe01c04..162ae7021a9f9f 100644 --- a/src/app/CASESessionManager.cpp +++ b/src/app/CASESessionManager.cpp @@ -30,62 +30,55 @@ CHIP_ERROR CASESessionManager::Init(chip::System::Layer * systemLayer, const CAS } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onFailure + Callback::Callback * onFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES -) + TransportPayloadCapability transportPayloadCapability) { - FindOrEstablishSessionHelper(peerId, onConnection, onFailure, nullptr + FindOrEstablishSessionHelper(peerId, onConnection, onFailure, nullptr, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - attemptCount, onRetry + attemptCount, onRetry, #endif - ); + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onSetupFailure + Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { - FindOrEstablishSessionHelper(peerId, onConnection, nullptr, onSetupFailure + FindOrEstablishSessionHelper(peerId, onConnection, nullptr, onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - attemptCount, onRetry + attemptCount, onRetry, #endif - ); + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - std::nullptr_t + std::nullptr_t, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { - FindOrEstablishSessionHelper(peerId, onConnection, nullptr, nullptr + FindOrEstablishSessionHelper(peerId, onConnection, nullptr, nullptr, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - attemptCount, onRetry + attemptCount, onRetry, #endif - ); + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSessionHelper(const ScopedNodeId & peerId, Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure + Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { ChipLogDetail(CASESessionManager, "FindOrEstablishSession: PeerId = [%d:" ChipLogFormatX64 "]", peerId.GetFabricIndex(), ChipLogValueX64(peerId.GetNodeId())); @@ -124,12 +117,12 @@ void CASESessionManager::FindOrEstablishSessionHelper(const ScopedNodeId & peerI if (onFailure != nullptr) { - session->Connect(onConnection, onFailure); + session->Connect(onConnection, onFailure, transportPayloadCapability); } if (onSetupFailure != nullptr) { - session->Connect(onConnection, onSetupFailure); + session->Connect(onConnection, onSetupFailure, transportPayloadCapability); } } @@ -143,10 +136,11 @@ void CASESessionManager::ReleaseAllSessions() mConfig.sessionSetupPool->ReleaseAllSessionSetup(); } -CHIP_ERROR CASESessionManager::GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr) +CHIP_ERROR CASESessionManager::GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr, + TransportPayloadCapability transportPayloadCapability) { ReturnErrorOnFailure(mConfig.sessionInitParams.Validate()); - auto optionalSessionHandle = FindExistingSession(peerId); + auto optionalSessionHandle = FindExistingSession(peerId, transportPayloadCapability); ReturnErrorCodeIf(!optionalSessionHandle.HasValue(), CHIP_ERROR_NOT_CONNECTED); addr = optionalSessionHandle.Value()->AsSecureSession()->GetPeerAddress(); return CHIP_NO_ERROR; @@ -182,10 +176,11 @@ OperationalSessionSetup * CASESessionManager::FindExistingSessionSetup(const Sco return mConfig.sessionSetupPool->FindSessionSetup(peerId, forAddressUpdate); } -Optional CASESessionManager::FindExistingSession(const ScopedNodeId & peerId) const +Optional CASESessionManager::FindExistingSession(const ScopedNodeId & peerId, + const TransportPayloadCapability transportPayloadCapability) const { - return mConfig.sessionInitParams.sessionManager->FindSecureSessionForNode(peerId, - MakeOptional(Transport::SecureSession::Type::kCASE)); + return mConfig.sessionInitParams.sessionManager->FindSecureSessionForNode( + peerId, MakeOptional(Transport::SecureSession::Type::kCASE), transportPayloadCapability); } void CASESessionManager::ReleaseSession(OperationalSessionSetup * session) diff --git a/src/app/CASESessionManager.h b/src/app/CASESessionManager.h index e78478852b640e..38b39108b43b7e 100644 --- a/src/app/CASESessionManager.h +++ b/src/app/CASESessionManager.h @@ -26,6 +26,7 @@ #include #include #include +#include #include namespace chip { @@ -78,12 +79,11 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * setup is not successful. */ void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onFailure + Callback::Callback * onFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr + uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /** * Find an existing session for the given node ID or trigger a new session request. @@ -106,14 +106,14 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * @param onSetupFailure A callback to be called upon an extended device connection failure. * @param attemptCount The number of retry attempts if session setup fails (default is 1). * @param onRetry A callback to be called on a retry attempt (enabled by a config flag). + * @param transportPayloadCapability An indicator of what payload types the session needs to be able to transport. */ void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onSetupFailure + Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr + uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /** * Find an existing session for the given node ID or trigger a new session request. @@ -134,13 +134,13 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * @param onConnection A callback to be called upon successful connection establishment. * @param attemptCount The number of retry attempts if session setup fails (default is 1). * @param onRetry A callback to be called on a retry attempt (enabled by a config flag). + * @param transportPayloadCapability An indicator of what payload types the session needs to be able to transport. */ - void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, std::nullptr_t + void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, std::nullptr_t, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr + uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); void ReleaseSessionsForFabric(FabricIndex fabricIndex); @@ -154,7 +154,8 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * an ongoing session with the peer node. If the session doesn't exist, the API will return * `CHIP_ERROR_NOT_CONNECTED` error. */ - CHIP_ERROR GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr); + CHIP_ERROR GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); //////////// OperationalSessionReleaseDelegate Implementation /////////////// void ReleaseSession(OperationalSessionSetup * device) override; @@ -165,15 +166,17 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess private: OperationalSessionSetup * FindExistingSessionSetup(const ScopedNodeId & peerId, bool forAddressUpdate = false) const; - Optional FindExistingSession(const ScopedNodeId & peerId) const; + Optional FindExistingSession( + const ScopedNodeId & peerId, + const TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload) const; void FindOrEstablishSessionHelper(const ScopedNodeId & peerId, Callback::Callback * onConnection, Callback::Callback * onFailure, Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif - ); + TransportPayloadCapability transportPayloadCapability); CASESessionManagerConfig mConfig; }; diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp index 157de953116ef5..9d91b7b573a59f 100644 --- a/src/app/OperationalSessionSetup.cpp +++ b/src/app/OperationalSessionSetup.cpp @@ -76,8 +76,8 @@ bool OperationalSessionSetup::AttachToExistingSecureSession() mState == State::WaitingForRetry, false); - auto sessionHandle = - mInitParams.sessionManager->FindSecureSessionForNode(mPeerId, MakeOptional(Transport::SecureSession::Type::kCASE)); + auto sessionHandle = mInitParams.sessionManager->FindSecureSessionForNode( + mPeerId, MakeOptional(Transport::SecureSession::Type::kCASE), mTransportPayloadCapability); if (!sessionHandle.HasValue()) return false; @@ -93,11 +93,13 @@ bool OperationalSessionSetup::AttachToExistingSecureSession() void OperationalSessionSetup::Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure) + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability) { CHIP_ERROR err = CHIP_NO_ERROR; bool isConnected = false; + mTransportPayloadCapability = transportPayloadCapability; // // Always enqueue our user provided callbacks into our callback list. // If anything goes wrong below, we'll trigger failures (including any queued from @@ -180,15 +182,17 @@ void OperationalSessionSetup::Connect(Callback::Callback * on } void OperationalSessionSetup::Connect(Callback::Callback * onConnection, - Callback::Callback * onFailure) + Callback::Callback * onFailure, + TransportPayloadCapability transportPayloadCapability) { - Connect(onConnection, onFailure, nullptr); + Connect(onConnection, onFailure, nullptr, transportPayloadCapability); } void OperationalSessionSetup::Connect(Callback::Callback * onConnection, - Callback::Callback * onSetupFailure) + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability) { - Connect(onConnection, nullptr, onSetupFailure); + Connect(onConnection, nullptr, onSetupFailure, transportPayloadCapability); } void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & addr, const ReliableMessageProtocolConfig & config) @@ -288,6 +292,16 @@ void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & ad CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessageProtocolConfig & config) { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // TODO: Combine LargePayload flag with DNS-SD advertisements from peer. + // Issue #32348. + if (mTransportPayloadCapability == TransportPayloadCapability::kLargePayload) + { + // Set the transport type for carrying large payloads + mDeviceAddress.SetTransportType(chip::Transport::Type::kTcp); + } +#endif + mCASEClient = mClientPool->Allocate(); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h index 1de8c305353314..5955dbab0713bd 100644 --- a/src/app/OperationalSessionSetup.h +++ b/src/app/OperationalSessionSetup.h @@ -210,8 +210,12 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * `onFailure` may be called before the Connect call returns, for error * cases that are detected synchronously (e.g. inability to start an address * lookup). + * + * `transportPayloadCapability` is set to kLargePayload when the session needs to be established + * over a transport that allows large payloads to be transferred, e.g., TCP. */ - void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure); + void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /* * This function can be called to establish a secure session with the device. @@ -227,8 +231,12 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * * `onSetupFailure` may be called before the Connect call returns, for error cases that are detected synchronously * (e.g. inability to start an address lookup). + * + * `transportPayloadCapability` is set to kLargePayload when the session needs to be established + * over a transport that allows large payloads to be transferred, e.g., TCP. */ - void Connect(Callback::Callback * onConnection, Callback::Callback * onSetupFailure); + void Connect(Callback::Callback * onConnection, Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); bool IsForAddressUpdate() const { return mPerformingAddressUpdate; } @@ -318,6 +326,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, System::Clock::Milliseconds16 mRequestedBusyDelay = System::Clock::kZero; #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES || CHIP_CONFIG_ENABLE_BUSY_HANDLING_FOR_OPERATIONAL_SESSION_SETUP + TransportPayloadCapability mTransportPayloadCapability = TransportPayloadCapability::kMRPPayload; + #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES // When we TryNextResult on the resolver, it will synchronously call back // into our OnNodeAddressResolved when it succeeds. We need to track @@ -351,7 +361,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, void CleanupCASEClient(); void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure); + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); void EnqueueConnectionCallbacks(Callback::Callback * onConnection, Callback::Callback * onFailure, diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index cd9c56f9f021f7..1ad98d271262da 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -71,6 +71,9 @@ using chip::Transport::BleListenParameters; #endif using chip::Transport::PeerAddress; using chip::Transport::UdpListenParameters; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +using chip::Transport::TcpListenParameters; +#endif namespace { @@ -201,6 +204,12 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) #if CONFIG_NETWORK_LAYER_BLE , BleListenParameters(DeviceLayer::ConnectivityMgr().GetBleLayer()) +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + TcpListenParameters(DeviceLayer::TCPEndPointManager()) + .SetAddressType(IPAddressType::kIPv6) + .SetListenPort(mOperationalServicePort) #endif ); diff --git a/src/app/server/Server.h b/src/app/server/Server.h index 8f6fcd5abecc66..d649e0fc923896 100644 --- a/src/app/server/Server.h +++ b/src/app/server/Server.h @@ -77,6 +77,12 @@ namespace chip { inline constexpr size_t kMaxBlePendingPackets = 1; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +inline constexpr size_t kMaxTcpActiveConnectionCount = CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS; + +inline constexpr size_t kMaxTcpPendingPackets = CHIP_CONFIG_MAX_TCP_PENDING_PACKETS; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // // NOTE: Please do not alter the order of template specialization here as the logic // in the Server impl depends on this. @@ -89,6 +95,10 @@ using ServerTransportMgr = chip::TransportMgr +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + chip::Transport::TCP #endif >; diff --git a/src/app/tests/TestCommissionManager.cpp b/src/app/tests/TestCommissionManager.cpp index 0076fd6a55718f..188f51aee6a609 100644 --- a/src/app/tests/TestCommissionManager.cpp +++ b/src/app/tests/TestCommissionManager.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -104,6 +105,9 @@ void InitializeChip(nlTestSuite * suite) static chip::SimpleTestEventTriggerDelegate sSimpleTestEventTriggerDelegate; initParams.testEventTriggerDelegate = &sSimpleTestEventTriggerDelegate; (void) initParams.InitializeStaticResourcesBeforeServerInit(); + // Set a randomized server port(slightly shifted from CHIP_PORT) for testing + initParams.operationalServicePort = static_cast(initParams.operationalServicePort + chip::Crypto::GetRandU16() % 20); + err = chip::Server::GetInstance().Init(initParams); NL_TEST_ASSERT(suite, err == CHIP_NO_ERROR); diff --git a/src/controller/CHIPDeviceControllerFactory.cpp b/src/controller/CHIPDeviceControllerFactory.cpp index 198d0c47c9f614..eb52e389a6b9d6 100644 --- a/src/controller/CHIPDeviceControllerFactory.cpp +++ b/src/controller/CHIPDeviceControllerFactory.cpp @@ -161,6 +161,12 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) #if CONFIG_NETWORK_LAYER_BLE , Transport::BleListenParameters(stateParams.bleLayer) +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + Transport::TcpListenParameters(stateParams.tcpEndPointManager) + .SetAddressType(IPAddressType::kIPv6) + .SetListenPort(params.listenPort) #endif )); diff --git a/src/controller/CHIPDeviceControllerSystemState.h b/src/controller/CHIPDeviceControllerSystemState.h index 389bb557f6c0cc..1ea0593d594993 100644 --- a/src/controller/CHIPDeviceControllerSystemState.h +++ b/src/controller/CHIPDeviceControllerSystemState.h @@ -57,16 +57,27 @@ namespace chip { inline constexpr size_t kMaxDeviceTransportBlePendingPackets = 1; -using DeviceTransportMgr = TransportMgr /* BLE */ + , + Transport::BLE /* BLE */ +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + Transport::TCP #endif - >; + >; namespace Controller { diff --git a/src/include/platform/internal/GenericPlatformManagerImpl.ipp b/src/include/platform/internal/GenericPlatformManagerImpl.ipp index d6f29ed515e034..64878f5928cd9b 100644 --- a/src/include/platform/internal/GenericPlatformManagerImpl.ipp +++ b/src/include/platform/internal/GenericPlatformManagerImpl.ipp @@ -89,6 +89,16 @@ CHIP_ERROR GenericPlatformManagerImpl::_InitChipStack() } SuccessOrExit(err); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Initialize the CHIP TCP layer. + err = TCPEndPointManager()->Init(SystemLayer()); + if (err != CHIP_NO_ERROR) + { + ChipLogError(DeviceLayer, "TCP initialization failed: %" CHIP_ERROR_FORMAT, err.Format()); + } + SuccessOrExit(err); +#endif + // TODO Perform dynamic configuration of the core CHIP objects based on stored settings. // Initialize the CHIP BLE manager. @@ -132,6 +142,10 @@ void GenericPlatformManagerImpl::_Shutdown() ChipLogError(DeviceLayer, "Inet Layer shutdown"); UDPEndPointManager()->Shutdown(); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + TCPEndPointManager()->Shutdown(); +#endif + #if CHIP_DEVICE_CONFIG_ENABLE_CHIPOBLE ChipLogError(DeviceLayer, "BLE shutdown"); BLEMgr().Shutdown(); diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index f36a274be1da2a..bca044d91f230d 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -670,5 +670,12 @@ void ExchangeContext::ExchangeSessionHolder::GrabExpiredSession(const SessionHan GrabUnchecked(session); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void ExchangeContext::OnSessionConnectionClosed(CHIP_ERROR conErr) +{ + // TODO: Handle connection closure at the ExchangeContext level. +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index fc20a5aace5273..47cf0ddbef2783 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -86,6 +86,9 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, NewSessionHandlingPolicy GetNewSessionHandlingPolicy() override { return NewSessionHandlingPolicy::kStayAtOldSession; } void OnSessionReleased() override; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void OnSessionConnectionClosed(CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Send a CHIP message on this exchange. * diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 3971864bb7e06e..a184723726781e 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -77,6 +77,9 @@ CHIP_ERROR ExchangeManager::Init(SessionManager * sessionManager) sessionManager->SetMessageDelegate(this); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + sessionManager->SetConnectionDelegate(this); +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT mReliableMessageMgr.Init(sessionManager->SystemLayer()); mState = State::kState_Initialized; @@ -413,5 +416,18 @@ void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegate * deleg }); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void ExchangeManager::OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) +{ + mContextPool.ForEachActiveObject([&](auto * ec) { + if (ec->HasSessionHandle() && ec->GetSessionHandle() == session) + { + ec->OnSessionConnectionClosed(conErr); + } + return Loop::Continue; + }); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index 48c6f8df673fb6..b6e416cdbf74e7 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -49,6 +49,10 @@ static constexpr int16_t kAnyMessageType = -1; * handling the registration/unregistration of unsolicited message handlers. */ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + public SessionConnectionDelegate +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT { friend class ExchangeContext; @@ -242,6 +246,9 @@ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override; void SendStandaloneAckIfNeeded(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session, MessageFlags msgFlags, System::PacketBufferHandle && msgBuf); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; } // namespace Messaging diff --git a/src/messaging/tests/echo/common.cpp b/src/messaging/tests/echo/common.cpp index 80befc92d27ae6..10eeeb67738024 100644 --- a/src/messaging/tests/echo/common.cpp +++ b/src/messaging/tests/echo/common.cpp @@ -51,10 +51,6 @@ void InitializeChip() err = chip::DeviceLayer::PlatformMgr().InitChipStack(); SuccessOrExit(err); - // Initialize TCP. - err = chip::DeviceLayer::TCPEndPointManager()->Init(chip::DeviceLayer::SystemLayer()); - SuccessOrExit(err); - exit: if (err != CHIP_NO_ERROR) { @@ -68,6 +64,5 @@ void ShutdownChip() gMessageCounterManager.Shutdown(); gExchangeManager.Shutdown(); gSessionManager.Shutdown(); - (void) chip::DeviceLayer::TCPEndPointManager()->Shutdown(); chip::DeviceLayer::PlatformMgr().Shutdown(); } diff --git a/src/messaging/tests/echo/echo_requester.cpp b/src/messaging/tests/echo/echo_requester.cpp index 1bff8e70f55d8d..b06edb359b989f 100644 --- a/src/messaging/tests/echo/echo_requester.cpp +++ b/src/messaging/tests/echo/echo_requester.cpp @@ -35,7 +35,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include #include @@ -49,6 +51,9 @@ namespace { // Max value for the number of EchoRequests sent. constexpr size_t kMaxEchoCount = 3; +// Max value for the number of tcp connect attempts. +constexpr size_t kMaxTCPConnectAttempts = 3; + // The CHIP Echo interval time. constexpr chip::System::Clock::Timeout gEchoInterval = chip::System::Clock::Seconds16(1); @@ -62,9 +67,21 @@ chip::TransportMgrAsSecureSession()->SetTCPConnection(conn); + } + + printf("Successfully established secure session with peer at %s\n", peerAddrBuf); } return err; } +void CloseConnection() +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + chip::Transport::PeerAddress peerAddr = chip::Transport::PeerAddress::TCP(gDestAddr, CHIP_PORT); + + gSessionManager.TCPDisconnect(peerAddr); + + peerAddr.ToString(peerAddrBuf); + printf("Connection closed to peer at %s\n", peerAddrBuf); + + gClientConEstablished = false; + gClientConInProgress = false; +} + +void HandleConnectionAttemptComplete(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR err) +{ + chip::DeviceLayer::PlatformMgr().StopEventLoopTask(); + + if (err != CHIP_NO_ERROR) + { + printf("Connection FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + gTCPConnAttemptCount++; + return; + } + + err = EstablishSecureSession(conn); + if (err != CHIP_NO_ERROR) + { + printf("Secure session FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + return; + } + + gClientConEstablished = true; + gClientConInProgress = false; +} + +void HandleConnectionClosed(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + CloseConnection(); +} + +void EstablishTCPConnection() +{ + CHIP_ERROR err = CHIP_NO_ERROR; + // Previous connection attempt underway. + if (gClientConInProgress) + { + return; + } + + gClientConEstablished = false; + + chip::Transport::PeerAddress peerAddr = chip::Transport::PeerAddress::TCP(gDestAddr, CHIP_PORT); + + // Connect to the peer + err = gSessionManager.TCPConnect(peerAddr, &gAppTCPConnCbCtxt, &gActiveTCPConnState); + if (err != CHIP_NO_ERROR) + { + printf("Connection FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + gTCPConnAttemptCount++; + return; + } + + gClientConInProgress = true; +} + void HandleEchoResponseReceived(chip::Messaging::ExchangeContext * ec, chip::System::PacketBufferHandle && payload) { chip::System::Clock::Timestamp respTime = chip::System::SystemClock().GetMonotonicTimestamp(); @@ -236,6 +342,10 @@ int main(int argc, char * argv[]) err = gSessionManager.Init(&chip::DeviceLayer::SystemLayer(), &gTCPManager, &gMessageCounterManager, &gStorage, &gFabricTable, gSessionKeystore); SuccessOrExit(err); + + gAppTCPConnCbCtxt.appContext = nullptr; + gAppTCPConnCbCtxt.connCompleteCb = HandleConnectionAttemptComplete; + gAppTCPConnCbCtxt.connClosedCb = HandleConnectionClosed; } else { @@ -255,9 +365,29 @@ int main(int argc, char * argv[]) err = gMessageCounterManager.Init(&gExchangeManager); SuccessOrExit(err); - // Start the CHIP connection to the CHIP echo responder. - err = EstablishSecureSession(); - SuccessOrExit(err); + if (gUseTCP) + { + + while (!gClientConEstablished) + { + // For TCP transport, attempt to establish the connection to the CHIP echo responder. + // On Connection completion, call EstablishSecureSession(conn); + EstablishTCPConnection(); + + chip::DeviceLayer::PlatformMgr().RunEventLoop(); + + if (gTCPConnAttemptCount > kMaxTCPConnectAttempts) + { + ExitNow(); + } + } + } + else + { + // Start the CHIP session to the CHIP echo responder. + err = EstablishSecureSession(nullptr); + SuccessOrExit(err); + } err = gEchoClient.Init(&gExchangeManager, gSession.Get().Value()); SuccessOrExit(err); @@ -274,14 +404,14 @@ int main(int argc, char * argv[]) if (gUseTCP) { - gTCPManager.Disconnect(chip::Transport::PeerAddress::TCP(gDestAddr)); + gTCPManager.TCPDisconnect(chip::Transport::PeerAddress::TCP(gDestAddr)); } gTCPManager.Close(); Shutdown(); exit: - if ((err != CHIP_NO_ERROR) || (gEchoRespCount != kMaxEchoCount)) + if ((err != CHIP_NO_ERROR) || (gEchoRespCount != kMaxEchoCount) || (gTCPConnAttemptCount > kMaxTCPConnectAttempts)) { printf("ChipEchoClient failed: %s\n", chip::ErrorStr(err)); exit(EXIT_FAILURE); diff --git a/src/messaging/tests/echo/echo_responder.cpp b/src/messaging/tests/echo/echo_responder.cpp index 140aab00b0ff23..0cd2366a1f6c2b 100644 --- a/src/messaging/tests/echo/echo_responder.cpp +++ b/src/messaging/tests/echo/echo_responder.cpp @@ -35,7 +35,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include namespace { diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 745bd3898fa865..ccc2bc2188d176 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -419,6 +419,21 @@ void CASESession::Clear() mPeerNodeId = kUndefinedNodeId; mFabricsTable = nullptr; mFabricIndex = kUndefinedFabricIndex; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Clear the context object. + mTCPConnCbCtxt.appContext = nullptr; + mTCPConnCbCtxt.connCompleteCb = nullptr; + mTCPConnCbCtxt.connClosedCb = nullptr; + mTCPConnCbCtxt.connReceivedCb = nullptr; + + if (mPeerConnState && mPeerConnState->mConnectionState != Transport::TCPState::kConnected) + { + // Abort the connection if the CASESession is being destroyed and the + // connection is in the middle of being set up. + mSessionManager->TCPDisconnect(mPeerConnState, /* shouldAbort = */ true); + mPeerConnState = nullptr; + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT } void CASESession::InvalidateIfPendingEstablishmentOnFabric(FabricIndex fabricIndex) @@ -446,7 +461,9 @@ CHIP_ERROR CASESession::Init(SessionManager & sessionManager, Credentials::Certi ReturnErrorOnFailure(mCommissioningHash.Begin()); - mDelegate = delegate; + mDelegate = delegate; + mSessionManager = &sessionManager; + ReturnErrorOnFailure(AllocateSecureSession(sessionManager, sessionEvictionHint)); mValidContext.Reset(); @@ -454,6 +471,11 @@ CHIP_ERROR CASESession::Init(SessionManager & sessionManager, Credentials::Certi mValidContext.mRequiredKeyPurposes.Set(KeyPurposeFlags::kServerAuth); mValidContext.mValidityPolicy = policy; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + mTCPConnCbCtxt.appContext = this; + mTCPConnCbCtxt.connCompleteCb = HandleConnectionAttemptComplete; + mTCPConnCbCtxt.connClosedCb = HandleConnectionClosed; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT return CHIP_NO_ERROR; } @@ -516,12 +538,18 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric // This is to make sure the exchange will get closed if Init() returned an error. mExchangeCtxt.Emplace(*exchangeCtxt); + Transport::PeerAddress peerAddress = mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress(); + // From here onwards, let's go to exit on error, as some state might have already // been initialized SuccessOrExit(err); SuccessOrExit(err = fabricTable->AddFabricDelegate(this)); + // Set the PeerAddress in the secure session up front to indicate the + // Transport Type of the session that is being set up. + mSecureSessionHolder->AsSecureSession()->SetPeerAddress(peerAddress); + mFabricsTable = fabricTable; mFabricIndex = fabricInfo->GetFabricIndex(); mSessionResumptionStorage = sessionResumptionStorage; @@ -534,8 +562,18 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric ChipLogProgress(SecureChannel, "Initiating session on local FabricIndex %u from 0x" ChipLogFormatX64 " -> 0x" ChipLogFormatX64, static_cast(mFabricIndex), ChipLogValueX64(mLocalNodeId), ChipLogValueX64(mPeerNodeId)); - err = SendSigma1(); - SuccessOrExit(err); + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + err = sessionManager.TCPConnect(peerAddress, &mTCPConnCbCtxt, &mPeerConnState); + SuccessOrExit(err); +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } + else + { + err = SendSigma1(); + SuccessOrExit(err); + } exit: if (err != CHIP_NO_ERROR) @@ -648,6 +686,78 @@ CHIP_ERROR CASESession::RecoverInitiatorIpk() return CHIP_NO_ERROR; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void CASESession::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR err) +{ + VerifyOrReturn(conn != nullptr); + // conn->mAppState should not be NULL. SessionManager has already checked + // before calling this callback. + VerifyOrDie(conn->mAppState != nullptr); + + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + conn->mPeerAddr.ToString(peerAddrBuf); + + CASESession * caseSession = reinterpret_cast(conn->mAppState->appContext); + VerifyOrReturn(caseSession != nullptr); + + // Exit and disconnect if connection setup encountered an error. + SuccessOrExit(err); + + ChipLogDetail(SecureChannel, "TCP Connection established with %s before session establishment", peerAddrBuf); + + // Associate the connection with the current unauthenticated session for the + // CASE exchange. + caseSession->mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetTCPConnection(conn); + + // Associate the connection with the current secure session that is being + // set up. + caseSession->mSecureSessionHolder.Get().Value()->AsSecureSession()->SetTCPConnection(conn); + + // Send Sigma1 after connection is established for sessions over TCP + err = caseSession->SendSigma1(); + SuccessOrExit(err); + +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Connection establishment failed with peer at %s: %" CHIP_ERROR_FORMAT, peerAddrBuf, + err.Format()); + + // Close the underlying connection and ensure that the CASESession is + // not holding on to a stale ActiveTCPConnectionState. We call + // TCPDisconnect() here explicitly in order to abort the connection + // even after it establishes successfully, but SendSigma1() fails for + // some reason. + caseSession->mSessionManager->TCPDisconnect(conn, /* shouldAbort = */ true); + caseSession->mPeerConnState = nullptr; + + caseSession->Clear(); + } +} + +void CASESession::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + VerifyOrReturn(conn != nullptr); + // conn->mAppState should not be NULL. SessionManager has already checked + // before calling this callback. + VerifyOrDie(conn->mAppState != nullptr); + + CASESession * caseSession = reinterpret_cast(conn->mAppState->appContext); + VerifyOrReturn(caseSession != nullptr); + + // Drop our pointer to the now-invalid connection state. + // + // Since the connection is closed, message sends over the ExchangeContext + // will just fail and be handled like normal send errors. + // + // Additionally, SessionManager notifies (via ExchangeMgr) all ExchangeContexts on the + // connection closures for the attached sessions and the ExchangeContexts + // can close proactively if that's appropriate. + caseSession->mPeerConnState = nullptr; + ChipLogDetail(SecureChannel, "TCP Connection for this session has closed"); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR CASESession::SendSigma1() { MATTER_TRACE_SCOPE("SendSigma1", "CASESession"); @@ -2143,10 +2253,16 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST #if CHIP_CONFIG_SLOW_CRYPTO - if (msgType == Protocols::SecureChannel::MsgType::CASE_Sigma1 || msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2 || - msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume || - msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3) - { + if ((msgType == Protocols::SecureChannel::MsgType::CASE_Sigma1 || msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2 || + msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume || + msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3) && + mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress().GetTransportType() != + Transport::Type::kTcp) + { + // TODO: Rename FlushAcks() to something more semantically correct and + // call unconditionally for TCP or MRP from here. Inside, the + // PeerAddress type could be consulted to selectively flush MRP Acks + // when transport is not TCP. Issue #33183 SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks()); } #endif // CHIP_CONFIG_SLOW_CRYPTO diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index b7c6b429b950ad..045d1982dd723c 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -286,6 +286,22 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, void InvalidateIfPendingEstablishmentOnFabric(FabricIndex fabricIndex); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + static void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + static void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + + // Context to pass down when connecting to peer + Transport::AppTCPConnectionCallbackCtxt mTCPConnCbCtxt; + // Pointer to the underlying TCP connection state. Returned by the + // TCPConnect() method (on the connection Initiator side) when an + // ActiveTCPConnectionState object is allocated. This connection + // context is used on the CASE Initiator side to facilitate the + // invocation of the callbacks when the connection is established/closed. + // + // This pointer must be nulled out when the connection is closed. + Transport::ActiveTCPConnectionState * mPeerConnState = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + #if CONFIG_BUILD_FOR_HOST_UNIT_TEST void SetStopSigmaHandshakeAt(Optional state) { mStopHandshakeAtState = state; } #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST @@ -301,6 +317,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, uint8_t mIPK[kIPKSize]; SessionResumptionStorage * mSessionResumptionStorage = nullptr; + SessionManager * mSessionManager = nullptr; FabricTable * mFabricsTable = nullptr; FabricIndex mFabricIndex = kUndefinedFabricIndex; diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp index 1f7874bdf115dc..ae4ca272858a78 100644 --- a/src/protocols/secure_channel/PairingSession.cpp +++ b/src/protocols/secure_channel/PairingSession.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace chip { @@ -58,6 +59,18 @@ void PairingSession::Finish() { Transport::PeerAddress address = mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress(); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (address.GetTransportType() == Transport::Type::kTcp) + { + // Fetch the connection for the unauthenticated session used to set up + // the secure session. + Transport::ActiveTCPConnectionState * conn = + mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetTCPConnection(); + + // Associate the connection with the secure session being activated. + mSecureSessionHolder->AsSecureSession()->SetTCPConnection(conn); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT // Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that. DiscardExchange(); diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h b/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h index 86fd4d51ed4e48..7a496e5a2f0a33 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h @@ -539,7 +539,8 @@ class DLL_EXPORT UserDirectedCommissioningClient : public TransportMgrDelegate } private: - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; CommissionerDeclarationHandler * mCommissionerDeclarationHandler = nullptr; }; @@ -652,7 +653,8 @@ class DLL_EXPORT UserDirectedCommissioningServer : public TransportMgrDelegate void HandleNewUDC(const Transport::PeerAddress & source, IdentificationDeclaration & id); void HandleUDCCancel(IdentificationDeclaration & id); void HandleUDCCommissionerPasscodeReady(IdentificationDeclaration & id); - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; UDCClients mUdcClients; // < Active UDC clients diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp index 9fc43634d7ec7d..6d7c315ffb5308 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp @@ -24,6 +24,7 @@ */ #include "UserDirectedCommissioning.h" +#include #ifdef __ZEPHYR__ #include @@ -235,7 +236,8 @@ CHIP_ERROR CommissionerDeclaration::ReadPayload(uint8_t * udcPayload, size_t pay return CHIP_NO_ERROR; } -void UserDirectedCommissioningClient::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg) +void UserDirectedCommissioningClient::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { char addrBuffer[chip::Transport::PeerAddress::kMaxToStringSize]; source.ToString(addrBuffer); diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp index 2efd8d0a33de28..a06bffcec62b3a 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp @@ -26,6 +26,7 @@ #include "UserDirectedCommissioning.h" #include #include +#include #include @@ -33,7 +34,8 @@ namespace chip { namespace Protocols { namespace UserDirectedCommissioning { -void UserDirectedCommissioningServer::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg) +void UserDirectedCommissioningServer::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { char addrBuffer[chip::Transport::PeerAddress::kMaxToStringSize]; source.ToString(addrBuffer); diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index fa556d40afd5b0..c4649c9b92d27c 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -37,6 +37,7 @@ static_library("transport") { "SecureSessionTable.h", "Session.cpp", "Session.h", + "SessionConnectionDelegate.h", "SessionDelegate.h", "SessionHolder.cpp", "SessionHolder.h", diff --git a/src/transport/Session.h b/src/transport/Session.h index 0b6048d5c077c0..d9840ec3ee33f1 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -28,6 +28,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { namespace Transport { @@ -225,6 +228,15 @@ class Session bool IsUnauthenticatedSession() const { return GetSessionType() == SessionType::kUnauthenticated; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // This API is used to associate the connection with the session when the + // latter is about to be marked active. It is also used to reset the + // connection to a nullptr when the connection is lost and the session + // is marked as Defunct. + ActiveTCPConnectionState * GetTCPConnection() const { return mTCPConnection; } + void SetTCPConnection(ActiveTCPConnectionState * conn) { mTCPConnection = conn; } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + void DispatchSessionEvent(SessionDelegate::Event event) { // Holders might remove themselves when notified. @@ -264,6 +276,15 @@ class Session private: FabricIndex mFabricIndex = kUndefinedFabricIndex; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // The underlying TCP connection object over which the session is + // established. + // The lifetime of this member connection pointer is, essentially, the same + // as that of the underlying connection with the peer. + // It would remain as a nullptr for all sessions that are not set up over + // a TCP connection. + ActiveTCPConnectionState * mTCPConnection = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; // diff --git a/src/transport/SessionConnectionDelegate.h b/src/transport/SessionConnectionDelegate.h new file mode 100644 index 00000000000000..4557d0ae107e68 --- /dev/null +++ b/src/transport/SessionConnectionDelegate.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace chip { + +/** + * @brief + * This class defines a delegate that will be called by the SessionManager on + * specific connection-related (e.g. for TCP) events. If the user of SessionManager + * is interested in receiving these callbacks, they can specialize this class and + * handle each trigger in their implementation of this class. + */ +class DLL_EXPORT SessionConnectionDelegate +{ +public: + virtual ~SessionConnectionDelegate() {} + + /** + * @brief + * Called when the underlying connection for the session is closed. + * + * @param session The handle to the secure session + * @param conErr The connection error code + */ + virtual void OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) = 0; +}; + +} // namespace chip diff --git a/src/transport/SessionDelegate.h b/src/transport/SessionDelegate.h index b9e0a7b8b38b0c..503aaa2b0c4f5c 100644 --- a/src/transport/SessionDelegate.h +++ b/src/transport/SessionDelegate.h @@ -66,6 +66,10 @@ class DLL_EXPORT SessionDelegate * SessionManager to allocate a new session. If they desire to do so, it MUST be done asynchronously. */ virtual void OnSessionHang() {} + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + virtual void OnSessionConnectionClosed(CHIP_ERROR conErr) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; } // namespace chip diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 5adfde18b47ff4..d49e66fb6f5116 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -147,6 +147,11 @@ CHIP_ERROR SessionManager::Init(System::Layer * systemLayer, TransportMgrBase * mTransportMgr->SetSessionManager(this); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + mConnCompleteCb = nullptr; + mConnClosedCb = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + return CHIP_NO_ERROR; } @@ -602,7 +607,8 @@ CHIP_ERROR SessionManager::InjectCaseSessionWithTestKey(SessionHolder & sessionH return CHIP_NO_ERROR; } -void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg) +void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { PacketHeader partialPacketHeader; @@ -621,20 +627,151 @@ void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System:: } else { - SecureUnicastMessageDispatch(partialPacketHeader, peerAddress, std::move(msg)); + SecureUnicastMessageDispatch(partialPacketHeader, peerAddress, std::move(msg), ctxt); } } else { - UnauthenticatedMessageDispatch(partialPacketHeader, peerAddress, std::move(msg)); + UnauthenticatedMessageDispatch(partialPacketHeader, peerAddress, std::move(msg), ctxt); + } +} + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void SessionManager::HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + + VerifyOrReturn(conn != nullptr); + conn->mPeerAddr.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Received TCP connection request from %s.", peerAddrBuf); + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + if (appTCPConnCbCtxt != nullptr && appTCPConnCbCtxt->connReceivedCb != nullptr) + { + appTCPConnCbCtxt->connReceivedCb(conn); + } +} + +void SessionManager::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + VerifyOrReturn(conn != nullptr); + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + if (appTCPConnCbCtxt == nullptr) + { + TCPDisconnect(conn, /* shouldAbort = */ true); + return; + } + + if (appTCPConnCbCtxt->connCompleteCb != nullptr) + { + appTCPConnCbCtxt->connCompleteCb(conn, conErr); + } + else + { + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + conn->mPeerAddr.ToString(peerAddrBuf); + + ChipLogProgress(Inet, "TCP Connection established with peer %s, but no registered handler. Disconnecting.", peerAddrBuf); + + // Close the connection + TCPDisconnect(conn, /* shouldAbort = */ true); + } +} + +void SessionManager::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + VerifyOrReturn(conn != nullptr); + + // Mark the corresponding secure sessions as defunct + mSecureSessions.ForEachSession([&](auto session) { + if (session->IsActiveSession() && session->GetTCPConnection() == conn) + { + SessionHandle handle(*session); + // Notify the SessionConnection delegate of the connection + // closure. + if (mConnDelegate != nullptr) + { + mConnDelegate->OnTCPConnectionClosed(handle, conErr); + } + + // Dis-associate the connection from session by setting it to a + // nullptr. + session->SetTCPConnection(nullptr); + // Mark session as defunct + session->MarkAsDefunct(); + } + + return Loop::Continue; + }); + + // TODO: A mechanism to mark an unauthenticated session as unusable when + // the underlying connection is broken. Issue #32323 + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + VerifyOrReturn(appTCPConnCbCtxt != nullptr); + + if (appTCPConnCbCtxt->connClosedCb != nullptr) + { + appTCPConnCbCtxt->connClosedCb(conn, conErr); + } +} + +CHIP_ERROR SessionManager::TCPConnect(const PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(peerAddrBuf); + if (mTransportMgr != nullptr) + { + ChipLogProgress(Inet, "Connecting over TCP with peer at %s.", peerAddrBuf); + return mTransportMgr->TCPConnect(peerAddress, appState, peerConnState); + } + + ChipLogError(Inet, "The transport manager is not initialized. Unable to connect to peer at %s.", peerAddrBuf); + + return CHIP_ERROR_INCORRECT_STATE; +} + +CHIP_ERROR SessionManager::TCPDisconnect(const PeerAddress & peerAddress) +{ + if (mTransportMgr != nullptr) + { + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Disconnecting TCP connection from peer at %s.", peerAddrBuf); + mTransportMgr->TCPDisconnect(peerAddress); } + + return CHIP_NO_ERROR; } +void SessionManager::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) +{ + if (mTransportMgr != nullptr && conn != nullptr) + { + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + conn->mPeerAddr.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Disconnecting TCP connection from peer at %s.", peerAddrBuf); + mTransportMgr->TCPDisconnect(conn, shouldAbort); + } +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partialPacketHeader, - const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) + const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { MATTER_TRACE_SCOPE("Unauthenticated Message Dispatch", "SessionManager"); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (peerAddress.GetTransportType() == Transport::Type::kTcp && ctxt->conn == nullptr) + { + ChipLogError(Inet, "Connection object is missing for received message."); + return; + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // Drop unsecured messages with privacy enabled. if (partialPacketHeader.HasPrivacyFlag()) { @@ -660,7 +797,7 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial if (source.HasValue()) { // Assume peer is the initiator, we are the responder. - optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), GetDefaultMRPConfig()); + optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), GetDefaultMRPConfig(), peerAddress); if (!optionalSession.HasValue()) { ChipLogError(Inet, "UnauthenticatedSession exhausted"); @@ -670,7 +807,7 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial else { // Assume peer is the responder, we are the initiator. - optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value()); + optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value(), peerAddress); if (!optionalSession.HasValue()) { ChipLogProgress(Inet, "Received unknown unsecure packet for initiator 0x" ChipLogFormatX64, @@ -685,6 +822,25 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial CorrectPeerAddressInterfaceID(mutablePeerAddress); unsecuredSession->SetPeerAddress(mutablePeerAddress); SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Associate the unauthenticated session with the connection, if not done already. + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { + Transport::ActiveTCPConnectionState * sessionConn = unsecuredSession->GetTCPConnection(); + if (sessionConn == nullptr) + { + unsecuredSession->SetTCPConnection(ctxt->conn); + } + else + { + if (sessionConn != ctxt->conn) + { + ChipLogError(Inet, "Data received over wrong connection %p. Dropping it!", ctxt->conn); + return; + } + } + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT unsecuredSession->MarkActiveRx(); @@ -723,13 +879,55 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial } void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & partialPacketHeader, - const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) + const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { MATTER_TRACE_SCOPE("Secure Unicast Message Dispatch", "SessionManager"); CHIP_ERROR err = CHIP_NO_ERROR; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (peerAddress.GetTransportType() == Transport::Type::kTcp && ctxt->conn == nullptr) + { + ChipLogError(Inet, "Connection object is missing for received message."); + return; + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + Optional session = mSecureSessions.FindSecureSessionByLocalKey(partialPacketHeader.GetSessionId()); + if (!session.HasValue()) + { + ChipLogError(Inet, "Data received on an unknown session (LSID=%d). Dropping it!", partialPacketHeader.GetSessionId()); + return; + } + + Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); + Transport::PeerAddress mutablePeerAddress = peerAddress; + CorrectPeerAddressInterfaceID(mutablePeerAddress); + if (secureSession->GetPeerAddress() != mutablePeerAddress) + { + secureSession->SetPeerAddress(mutablePeerAddress); + } + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Associate the secure session with the connection, if not done already. + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { + Transport::ActiveTCPConnectionState * sessionConn = secureSession->GetTCPConnection(); + if (sessionConn == nullptr) + { + secureSession->SetTCPConnection(ctxt->conn); + } + else + { + if (sessionConn != ctxt->conn) + { + ChipLogError(Inet, "Data received over wrong connection %p. Dropping it!", ctxt->conn); + return; + } + } + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT PayloadHeader payloadHeader; @@ -751,14 +949,6 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & partialPa return; } - if (!session.HasValue()) - { - ChipLogError(Inet, "Data received on an unknown session (LSID=%d). Dropping it!", packetHeader.GetSessionId()); - return; - } - - Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); - // We need to allow through messages even on sessions that are pending // evictions, because for some cases (UpdateNOC, RemoveFabric, etc) there // can be a single exchange alive on the session waiting for a MRP ack, and @@ -816,13 +1006,6 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & partialPa secureSession->GetSessionMessageCounter().GetPeerMessageCounter().CommitEncryptedUnicast(packetHeader.GetMessageCounter()); } - Transport::PeerAddress mutablePeerAddress = peerAddress; - CorrectPeerAddressInterfaceID(mutablePeerAddress); - if (secureSession->GetPeerAddress() != mutablePeerAddress) - { - secureSession->SetPeerAddress(mutablePeerAddress); - } - if (mCB != nullptr) { MATTER_LOG_MESSAGE_RECEIVED(chip::Tracing::IncomingMessageType::kSecureUnicast, &payloadHeader, &packetHeader, @@ -1057,27 +1240,69 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack } Optional SessionManager::FindSecureSessionForNode(ScopedNodeId peerNodeId, - const Optional & type) + const Optional & type, + TransportPayloadCapability transportPayloadCapability) { - SecureSession * found = nullptr; - - mSecureSessions.ForEachSession([&peerNodeId, &type, &found](auto session) { + SecureSession * mrpSession = nullptr; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + SecureSession * tcpSession = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + + mSecureSessions.ForEachSession([&peerNodeId, &type, &mrpSession, +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + &tcpSession, +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + &transportPayloadCapability](auto session) { if (session->IsActiveSession() && session->GetPeer() == peerNodeId && (!type.HasValue() || type.Value() == session->GetSecureSessionType())) { - // - // Select the active session with the most recent activity to return back to the caller. - // - if ((found == nullptr) || (found->GetLastActivityTime() < session->GetLastActivityTime())) +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if ((transportPayloadCapability == TransportPayloadCapability::kMRPOrTCPCompatiblePayload || + transportPayloadCapability == TransportPayloadCapability::kLargePayload) && + session->GetTCPConnection() != nullptr) { - found = session; + // Set up a TCP transport based session as standby + if ((tcpSession == nullptr) || (tcpSession->GetLastActivityTime() < session->GetLastActivityTime())) + { + tcpSession = session; + } + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + + if ((mrpSession == nullptr) || (mrpSession->GetLastActivityTime() < session->GetLastActivityTime())) + { + mrpSession = session; } } return Loop::Continue; }); - return found != nullptr ? MakeOptional(*found) : Optional::Missing(); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (transportPayloadCapability == TransportPayloadCapability::kLargePayload) + { + return tcpSession != nullptr ? MakeOptional(*tcpSession) : Optional::Missing(); + } + + if (transportPayloadCapability == TransportPayloadCapability::kMRPOrTCPCompatiblePayload) + { + // If MRP-based session is available, use it. + if (mrpSession != nullptr) + { + return MakeOptional(*mrpSession); + } + + // Otherwise, look for a tcp-based session + if (tcpSession != nullptr) + { + return MakeOptional(*tcpSession); + } + + return Optional::Missing(); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + + return mrpSession != nullptr ? MakeOptional(*mrpSession) : Optional::Missing(); } /** diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 5f2e6f7603cad0..b7a1630b3d2851 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -52,8 +52,30 @@ #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + namespace chip { +/* + * This enum indicates whether a session needs to be established over a + * suitable transport that meets certain payload size requirements for + * transmitted messages. + * + */ +enum class TransportPayloadCapability : uint8_t +{ + kMRPPayload, // Transport requires the maximum payload size to fit within a single + // IPv6 packet(1280 bytes). + kLargePayload, // Transport needs to handle payloads larger than the single IPv6 + // packet, as supported by MRP. The transport of choice, in this + // case, is TCP. + kMRPOrTCPCompatiblePayload // This option provides the ability to use MRP + // as the preferred transport, but use a large + // payload transport if that is already + // available. +}; /** * @brief * Tracks ownership of a encrypted packet buffer. @@ -151,6 +173,10 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl /// ExchangeManager) void SetMessageDelegate(SessionMessageDelegate * cb) { mCB = cb; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void SetConnectionDelegate(SessionConnectionDelegate * cb) { mConnDelegate = cb; } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // Test-only: create a session on the fly. CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabricIndex, @@ -413,8 +439,34 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * * @param source the source address of the package * @param msgBuf the buffer containing a full CHIP message (except for the optional length field). + * @param ctxt pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const Transport::PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState); + + CHIP_ERROR TCPDisconnect(const Transport::PeerAddress & peerAddress); + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0); + + void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) override; + + void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + // Functors for callbacks into higher layers + using OnTCPConnectionReceivedCallback = void (*)(Transport::ActiveTCPConnectionState * conn); + + using OnTCPConnectionCompleteCallback = void (*)(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + + using OnTCPConnectionClosedCallback = void (*)(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config) @@ -436,8 +488,9 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl // is returned. Otherwise, an Optional with no value set is returned. // // - Optional FindSecureSessionForNode(ScopedNodeId peerNodeId, - const Optional & type = NullOptional); + Optional + FindSecureSessionForNode(ScopedNodeId peerNodeId, const Optional & type = NullOptional, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); using SessionHandleCallback = bool (*)(void * context, SessionHandle & sessionHandle); CHIP_ERROR ForEachSessionHandle(void * context, SessionHandleCallback callback); @@ -477,8 +530,22 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl State mState; // < Initialization state of the object chip::Transport::GroupOutgoingCounters mGroupClientCounter; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + OnTCPConnectionReceivedCallback mConnReceivedCb = nullptr; + OnTCPConnectionCompleteCallback mConnCompleteCb = nullptr; + OnTCPConnectionClosedCallback mConnClosedCb = nullptr; + + // Hold the TCPConnection callback context for the receiver application in the SessionManager. + // On receipt of a connection from a peer, the SessionManager + Transport::AppTCPConnectionCallbackCtxt * mServerTCPConnCbCtxt = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + SessionMessageDelegate * mCB = nullptr; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + SessionConnectionDelegate * mConnDelegate = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + TransportMgrBase * mTransportMgr = nullptr; Transport::MessageCounterManagerInterface * mMessageCounterManager = nullptr; @@ -491,9 +558,11 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * If the message decrypts successfully, this will be filled with a fully decoded PacketHeader. * @param[in] peerAddress The PeerAddress of the message as provided by the receiving Transport Endpoint. * @param msg The full message buffer, including header fields. + * @param ctxt The pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ void SecureUnicastMessageDispatch(const PacketHeader & partialPacketHeader, const Transport::PeerAddress & peerAddress, - System::PacketBufferHandle && msg); + System::PacketBufferHandle && msg, Transport::MessageTransportContext * ctxt = nullptr); /** * @brief Parse, decrypt, validate, and dispatch a secure group message. @@ -511,9 +580,11 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * @param partialPacketHeader The partial PacketHeader of the message after processing with DecodeFixed. * @param peerAddress The PeerAddress of the message as provided by the receiving Transport Endpoint. * @param msg The full message buffer, including header fields. + * @param ctxt The pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ void UnauthenticatedMessageDispatch(const PacketHeader & partialPacketHeader, const Transport::PeerAddress & peerAddress, - System::PacketBufferHandle && msg); + System::PacketBufferHandle && msg, Transport::MessageTransportContext * ctxt = nullptr); void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source); diff --git a/src/transport/TransportMgr.h b/src/transport/TransportMgr.h index 494db39de964fd..6e5f03ab4a44a6 100644 --- a/src/transport/TransportMgr.h +++ b/src/transport/TransportMgr.h @@ -34,6 +34,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { @@ -49,8 +52,26 @@ class TransportMgrDelegate * * @param source the source address of the package * @param msgBuf the buffer containing a full CHIP message (except for the optional length field). + * @param ctxt the pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ - virtual void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) = 0; + virtual void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) = 0; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * @brief + * Handle connection attempt completion. + * + * @param conn the connection object + * @param conErr the connection error on the attempt, or CHIP_NO_ERROR. + */ + virtual void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + + virtual void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + + virtual void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn){}; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; template diff --git a/src/transport/TransportMgrBase.cpp b/src/transport/TransportMgrBase.cpp index bde54a96c195b6..98997102eb92ca 100644 --- a/src/transport/TransportMgrBase.cpp +++ b/src/transport/TransportMgrBase.cpp @@ -28,11 +28,24 @@ CHIP_ERROR TransportMgrBase::SendMessage(const Transport::PeerAddress & address, return mTransport->SendMessage(address, std::move(msgBuf)); } -void TransportMgrBase::Disconnect(const Transport::PeerAddress & address) +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +CHIP_ERROR TransportMgrBase::TCPConnect(const Transport::PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) { - mTransport->Disconnect(address); + return mTransport->TCPConnect(address, appState, peerConnState); } +void TransportMgrBase::TCPDisconnect(const Transport::PeerAddress & address) +{ + mTransport->TCPDisconnect(address); +} + +void TransportMgrBase::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) +{ + mTransport->TCPDisconnect(conn, shouldAbort); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TransportMgrBase::Init(Transport::Base * transport) { if (mTransport != nullptr) @@ -41,6 +54,7 @@ CHIP_ERROR TransportMgrBase::Init(Transport::Base * transport) } mTransport = transport; mTransport->SetDelegate(this); + ChipLogDetail(Inet, "TransportMgr initialized"); return CHIP_NO_ERROR; } @@ -56,7 +70,8 @@ CHIP_ERROR TransportMgrBase::MulticastGroupJoinLeave(const Transport::PeerAddres return mTransport->MulticastGroupJoinLeave(address, join); } -void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) +void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { // This is the first point all incoming messages funnel through. Ensure // that our message receipts are all synchronized correctly. @@ -73,7 +88,7 @@ void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peer if (mSessionManager != nullptr) { - mSessionManager->OnMessageReceived(peerAddress, std::move(msg)); + mSessionManager->OnMessageReceived(peerAddress, std::move(msg), ctxt); } else { @@ -83,4 +98,60 @@ void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peer } } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void TransportMgrBase::HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionReceived(conn); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + + // Close connection here since no upper layer is interested in the + // connection. + if (tcp) + { + tcp->TCPDisconnect(conn, /* shouldAbort = */ true); + } + } +} + +void TransportMgrBase::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionAttemptComplete(conn, conErr); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + + // Close connection here since no upper layer is interested in the + // connection. + if (tcp) + { + tcp->TCPDisconnect(conn, /* shouldAbort = */ true); + } + } +} + +void TransportMgrBase::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionClosed(conn, conErr); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + if (tcp) + { + tcp->TCPDisconnect(conn, /* shouldAbort = */ true); + } + } +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace chip diff --git a/src/transport/TransportMgrBase.h b/src/transport/TransportMgrBase.h index e4942ca6ecd90b..2b0f33bfb9e3a7 100644 --- a/src/transport/TransportMgrBase.h +++ b/src/transport/TransportMgrBase.h @@ -21,6 +21,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { @@ -39,13 +42,29 @@ class TransportMgrBase : public Transport::RawTransportDelegate void Close(); - void Disconnect(const Transport::PeerAddress & address); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const Transport::PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState); + + void TCPDisconnect(const Transport::PeerAddress & address); + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0); + + void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) override; + + void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT void SetSessionManager(TransportMgrDelegate * sessionManager) { mSessionManager = sessionManager; } + TransportMgrDelegate * GetSessionManager() { return mSessionManager; }; + CHIP_ERROR MulticastGroupJoinLeave(const Transport::PeerAddress & address, bool join); - void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) override; + void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt = nullptr) override; private: TransportMgrDelegate * mSessionManager = nullptr; diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 1c3ff7ed55586a..913058229e1a6b 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -45,9 +45,10 @@ class UnauthenticatedSession : public Session, public ReferenceCounted & sessionTable) : - UnauthenticatedSession(sessionRole, ephemeralInitiatorNodeID, config) + UnauthenticatedSession(sessionRole, ephemeralInitiatorNodeID, peerAddress, config) #if CHIP_SYSTEM_CONFIG_POOL_USE_HEAP , mSessionTable(sessionTable) @@ -224,13 +225,16 @@ class UnauthenticatedSessionTable * @return the session found or allocated, or Optional::Missing if not found and allocation failed. */ CHECK_RETURN_VALUE - Optional FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) + Optional FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config, + const Transport::PeerAddress & peerAddress) { - UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID); + UnauthenticatedSession * result = + FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, peerAddress); if (result != nullptr) return MakeOptional(*result); - CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, config, result); + CHIP_ERROR err = + AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, peerAddress, config, result); if (err == CHIP_NO_ERROR) { return MakeOptional(*result); @@ -239,9 +243,11 @@ class UnauthenticatedSessionTable return Optional::Missing(); } - CHECK_RETURN_VALUE Optional FindInitiator(NodeId ephemeralInitiatorNodeID) + CHECK_RETURN_VALUE Optional FindInitiator(NodeId ephemeralInitiatorNodeID, + const Transport::PeerAddress & peerAddress) { - UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID); + UnauthenticatedSession * result = + FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, peerAddress); if (result != nullptr) { return MakeOptional(*result); @@ -254,7 +260,8 @@ class UnauthenticatedSessionTable const ReliableMessageProtocolConfig & config) { UnauthenticatedSession * result = nullptr; - CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, config, result); + CHIP_ERROR err = + AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, peerAddress, config, result); if (err == CHIP_NO_ERROR) { result->SetPeerAddress(peerAddress); @@ -276,9 +283,10 @@ class UnauthenticatedSessionTable */ CHECK_RETURN_VALUE CHIP_ERROR AllocEntry(UnauthenticatedSession::SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, - const ReliableMessageProtocolConfig & config, UnauthenticatedSession *& entry) + const PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config, + UnauthenticatedSession *& entry) { - auto entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config, *this); + auto entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, peerAddress, config, *this); if (entryToUse != nullptr) { entry = entryToUse; @@ -294,7 +302,7 @@ class UnauthenticatedSessionTable // Drop the least recent entry to allow for a new alloc. mEntries.ReleaseObject(entryToUse); - entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config, *this); + entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, peerAddress, config, *this); if (entryToUse == nullptr) { @@ -308,11 +316,13 @@ class UnauthenticatedSessionTable } CHECK_RETURN_VALUE UnauthenticatedSession * FindEntry(UnauthenticatedSession::SessionRole sessionRole, - NodeId ephemeralInitiatorNodeID) + NodeId ephemeralInitiatorNodeID, + const Transport::PeerAddress & peerAddress) { UnauthenticatedSession * result = nullptr; mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) { - if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID) + if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID && + entry->GetPeerAddress().GetTransportType() == peerAddress.GetTransportType()) { result = entry; return Loop::Break; diff --git a/src/transport/raw/ActiveTCPConnectionState.h b/src/transport/raw/ActiveTCPConnectionState.h new file mode 100644 index 00000000000000..0f53d4479e9300 --- /dev/null +++ b/src/transport/raw/ActiveTCPConnectionState.h @@ -0,0 +1,125 @@ +/* + * + * Copyright (c) 2023 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP Active Connection object that maintains TCP connections. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace chip { +namespace Transport { + +/** + * The State of the TCP connection + */ +enum class TCPState +{ + kNotReady = 0, /**< State before initialization. */ + kInitialized = 1, /**< State after class is listening and ready. */ + kConnecting = 3, /**< Connection with peer has been initiated. */ + kConnected = 4, /**< Connected with peer and ready for Send/Receive. */ + kClosed = 5, /**< Connection is closed. */ +}; + +struct AppTCPConnectionCallbackCtxt; +/** + * State for each active TCP connection + */ +struct ActiveTCPConnectionState +{ + + void Init(Inet::TCPEndPoint * endPoint, const PeerAddress & peerAddr) + { + mEndPoint = endPoint; + mPeerAddr = peerAddr; + mReceived = nullptr; + mAppState = nullptr; + } + + void Free() + { + mEndPoint->Free(); + mPeerAddr = PeerAddress::Uninitialized(); + mEndPoint = nullptr; + mReceived = nullptr; + mAppState = nullptr; + } + + bool InUse() const { return mEndPoint != nullptr; } + + bool IsConnected() const { return (mEndPoint != nullptr && mConnectionState == TCPState::kConnected); } + + bool IsConnecting() const { return (mEndPoint != nullptr && mConnectionState == TCPState::kConnecting); } + + // Associated endpoint. + Inet::TCPEndPoint * mEndPoint; + + // Peer Node Address + PeerAddress mPeerAddr; + + // Buffers received but not yet consumed. + System::PacketBufferHandle mReceived; + + // Current state of the connection + TCPState mConnectionState; + + // A pointer to an application-specific state object. It should + // represent an object that is at a layer above the SessionManager. The + // SessionManager would accept this object at the time of connecting to + // the peer, and percolate it down to the TransportManager that then, + // should store this state in the corresponding connection object that + // is created. + // At various connection events, this state is passed back to the + // corresponding application. + AppTCPConnectionCallbackCtxt * mAppState = nullptr; + + // KeepAlive interval in seconds + uint16_t mTCPKeepAliveIntervalSecs = CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS; + uint16_t mTCPMaxNumKeepAliveProbes = CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES; +}; + +// Functors for callbacks into higher layers +using OnTCPConnectionReceivedCallback = void (*)(ActiveTCPConnectionState * conn); + +using OnTCPConnectionCompleteCallback = void (*)(ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +using OnTCPConnectionClosedCallback = void (*)(ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +/* + * Application callback state that is passed down at connection establishment + * stage. + * */ +struct AppTCPConnectionCallbackCtxt +{ + void * appContext = nullptr; // A pointer to an application context object. + OnTCPConnectionReceivedCallback connReceivedCb = nullptr; + OnTCPConnectionCompleteCallback connCompleteCb = nullptr; + OnTCPConnectionClosedCallback connClosedCb = nullptr; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/raw/BUILD.gn b/src/transport/raw/BUILD.gn index 736b16cdb08477..3e8d3d4761dca3 100644 --- a/src/transport/raw/BUILD.gn +++ b/src/transport/raw/BUILD.gn @@ -14,6 +14,7 @@ import("//build_overrides/chip.gni") import("${chip_root}/src/ble/ble.gni") +import("${chip_root}/src/inet/inet.gni") static_library("raw") { output_name = "libRawTransport" @@ -23,13 +24,20 @@ static_library("raw") { "MessageHeader.cpp", "MessageHeader.h", "PeerAddress.h", - "TCP.cpp", - "TCP.h", "Tuple.h", "UDP.cpp", "UDP.h", ] + if (chip_inet_config_enable_tcp_endpoint) { + sources += [ + "ActiveTCPConnectionState.h", + "TCP.cpp", + "TCP.h", + "TCPConfig.h", + ] + } + if (chip_config_network_layer_ble) { sources += [ "BLE.cpp", diff --git a/src/transport/raw/Base.h b/src/transport/raw/Base.h index 66c01a5fcc3d5d..920f932b82be41 100644 --- a/src/transport/raw/Base.h +++ b/src/transport/raw/Base.h @@ -24,20 +24,38 @@ #pragma once #include +#include #include #include #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { namespace Transport { +struct MessageTransportContext +{ +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + ActiveTCPConnectionState * conn = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT +}; + class RawTransportDelegate { public: virtual ~RawTransportDelegate() {} - virtual void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) = 0; + virtual void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + MessageTransportContext * ctxt = nullptr) = 0; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + virtual void HandleConnectionReceived(ActiveTCPConnectionState * conn){}; + virtual void HandleConnectionAttemptComplete(ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + virtual void HandleConnectionClosed(ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; /** @@ -77,10 +95,26 @@ class Base */ virtual bool CanListenMulticast() { return false; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * Connect to the specified peer. + */ + virtual CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return CHIP_NO_ERROR; + } + /** * Handle disconnection from the specified peer if currently connected to it. */ - virtual void Disconnect(const PeerAddress & address) {} + virtual void TCPDisconnect(const PeerAddress & address) {} + + /** + * Disconnect on the active connection that is passed in. + */ + virtual void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Enable Listening for multicast messages ( IPV6 UDP only) @@ -97,12 +131,31 @@ class Base * Method used by subclasses to notify that a packet has been received after * any associated headers have been decoded. */ - void HandleMessageReceived(const PeerAddress & source, System::PacketBufferHandle && buffer) + void HandleMessageReceived(const PeerAddress & source, System::PacketBufferHandle && buffer, + MessageTransportContext * ctxt = nullptr) + { + mDelegate->HandleMessageReceived(source, std::move(buffer), ctxt); + } + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Handle an incoming connection request from a peer. + void HandleConnectionReceived(ActiveTCPConnectionState * conn) { mDelegate->HandleConnectionReceived(conn); } + + // Callback during connection establishment to notify of success or any + // error. + void HandleConnectionAttemptComplete(ActiveTCPConnectionState * conn, CHIP_ERROR conErr) + { + mDelegate->HandleConnectionAttemptComplete(conn, conErr); + } + + // Callback to notify the higher layer of an unexpected connection closure. + void HandleConnectionClosed(ActiveTCPConnectionState * conn, CHIP_ERROR conErr) { - mDelegate->HandleMessageReceived(source, std::move(buffer)); + mDelegate->HandleConnectionClosed(conn, conErr); } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT - RawTransportDelegate * mDelegate; + RawTransportDelegate * mDelegate = nullptr; }; } // namespace Transport diff --git a/src/transport/raw/TCP.cpp b/src/transport/raw/TCP.cpp index a1b1df1fe45a2d..e33590a63a6fdf 100644 --- a/src/transport/raw/TCP.cpp +++ b/src/transport/raw/TCP.cpp @@ -67,8 +67,7 @@ void TCPBase::CloseActiveConnections() { if (mActiveConnections[i].InUse()) { - mActiveConnections[i].Free(); - mUsedEndPointCount--; + CloseConnectionInternal(&mActiveConnections[i], CHIP_NO_ERROR, SuppressCallback::Yes); } } } @@ -77,7 +76,7 @@ CHIP_ERROR TCPBase::Init(TcpListenParameters & params) { CHIP_ERROR err = CHIP_NO_ERROR; - VerifyOrExit(mState == State::kNotReady, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(mState == TCPState::kNotReady, err = CHIP_ERROR_INCORRECT_STATE); #if INET_CONFIG_ENABLE_TCP_ENDPOINT err = params.GetEndPointManager()->NewEndPoint(&mListenSocket); @@ -90,23 +89,21 @@ CHIP_ERROR TCPBase::Init(TcpListenParameters & params) params.GetInterfaceId().IsPresent()); SuccessOrExit(err); + mListenSocket->mAppState = reinterpret_cast(this); + mListenSocket->OnConnectionReceived = HandleIncomingConnection; + mListenSocket->OnAcceptError = HandleAcceptError; + + mEndpointType = params.GetAddressType(); + err = mListenSocket->Listen(kListenBacklogSize); SuccessOrExit(err); - mListenSocket->mAppState = reinterpret_cast(this); - mListenSocket->OnDataReceived = OnTcpReceive; - mListenSocket->OnConnectComplete = OnConnectionComplete; - mListenSocket->OnConnectionClosed = OnConnectionClosed; - mListenSocket->OnConnectionReceived = OnConnectionReceived; - mListenSocket->OnAcceptError = OnAcceptError; - mEndpointType = params.GetAddressType(); - - mState = State::kInitialized; + mState = TCPState::kInitialized; exit: if (err != CHIP_NO_ERROR) { - ChipLogError(Inet, "Failed to initialize TCP transport: %s", ErrorStr(err)); + ChipLogError(Inet, "Failed to initialize TCP transport: %" CHIP_ERROR_FORMAT, err.Format()); if (mListenSocket) { mListenSocket->Free(); @@ -124,10 +121,24 @@ void TCPBase::Close() mListenSocket->Free(); mListenSocket = nullptr; } - mState = State::kNotReady; + mState = TCPState::kNotReady; +} + +ActiveTCPConnectionState * TCPBase::AllocateConnection() +{ + for (size_t i = 0; i < mActiveConnectionsSize; i++) + { + if (!mActiveConnections[i].InUse()) + { + return &mActiveConnections[i]; + } + } + + return nullptr; } -TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const PeerAddress & address) +// Find an ActiveTCPConnectionState corresponding to a peer address +ActiveTCPConnectionState * TCPBase::FindActiveConnection(const PeerAddress & address) { if (address.GetTransportType() != Type::kTcp) { @@ -136,7 +147,7 @@ TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const PeerAddress for (size_t i = 0; i < mActiveConnectionsSize; i++) { - if (!mActiveConnections[i].InUse()) + if (!mActiveConnections[i].IsConnected()) { continue; } @@ -153,8 +164,26 @@ TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const PeerAddress return nullptr; } -TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const Inet::TCPEndPoint * endPoint) +// Find the ActiveTCPConnectionState for a given TCPEndPoint +ActiveTCPConnectionState * TCPBase::FindActiveConnection(const Inet::TCPEndPoint * endPoint) { + for (size_t i = 0; i < mActiveConnectionsSize; i++) + { + if (mActiveConnections[i].mEndPoint == endPoint && mActiveConnections[i].IsConnected()) + { + return &mActiveConnections[i]; + } + } + return nullptr; +} + +ActiveTCPConnectionState * TCPBase::FindInUseConnection(const Inet::TCPEndPoint * endPoint) +{ + if (endPoint == nullptr) + { + return nullptr; + } + for (size_t i = 0; i < mActiveConnectionsSize; i++) { if (mActiveConnections[i].mEndPoint == endPoint) @@ -172,7 +201,7 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: // - actual data VerifyOrReturnError(address.GetTransportType() == Type::kTcp, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(mState == State::kInitialized, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mState == TCPState::kInitialized, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(kPacketSizeBytes + msgBuf->DataLength() <= std::numeric_limits::max(), CHIP_ERROR_INVALID_ARGUMENT); @@ -186,7 +215,7 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: // Reuse existing connection if one exists, otherwise a new one // will be established - ActiveConnectionState * connection = FindActiveConnection(address); + ActiveTCPConnectionState * connection = FindActiveConnection(address); if (connection != nullptr) { @@ -196,8 +225,46 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: return SendAfterConnect(address, std::move(msgBuf)); } +CHIP_ERROR TCPBase::StartConnect(const PeerAddress & addr, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** outPeerConnState) +{ +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + ActiveTCPConnectionState * activeConnection = nullptr; + Inet::TCPEndPoint * endPoint = nullptr; + *outPeerConnState = nullptr; + ReturnErrorOnFailure(mListenSocket->GetEndPointManager().NewEndPoint(&endPoint)); + + auto EndPointDeletor = [](Inet::TCPEndPoint * e) { e->Free(); }; + std::unique_ptr endPointHolder(endPoint, EndPointDeletor); + + endPoint->mAppState = reinterpret_cast(this); + endPoint->OnConnectComplete = HandleTCPEndPointConnectComplete; + endPoint->SetConnectTimeout(mConnectTimeout); + + activeConnection = AllocateConnection(); + VerifyOrReturnError(activeConnection != nullptr, CHIP_ERROR_NO_MEMORY); + activeConnection->Init(endPoint, addr); + activeConnection->mAppState = appState; + activeConnection->mConnectionState = TCPState::kConnecting; + // Set the return value of the peer connection state to the allocated + // connection. + *outPeerConnState = activeConnection; + + ReturnErrorOnFailure(endPoint->Connect(addr.GetIPAddress(), addr.GetPort(), addr.GetInterface())); + + mUsedEndPointCount++; + + endPointHolder.release(); + + return CHIP_NO_ERROR; +#else + return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; +#endif +} + CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBufferHandle && msg) { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT // This will initiate a connection to the specified peer bool alreadyConnecting = false; @@ -224,28 +291,13 @@ CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBuf // Ensures sufficient active connections size exist VerifyOrReturnError(mUsedEndPointCount < mActiveConnectionsSize, CHIP_ERROR_NO_MEMORY); -#if INET_CONFIG_ENABLE_TCP_ENDPOINT - Inet::TCPEndPoint * endPoint = nullptr; - ReturnErrorOnFailure(mListenSocket->GetEndPointManager().NewEndPoint(&endPoint)); - auto EndPointDeletor = [](Inet::TCPEndPoint * e) { e->Free(); }; - std::unique_ptr endPointHolder(endPoint, EndPointDeletor); - - endPoint->mAppState = reinterpret_cast(this); - endPoint->OnDataReceived = OnTcpReceive; - endPoint->OnConnectComplete = OnConnectionComplete; - endPoint->OnConnectionClosed = OnConnectionClosed; - endPoint->OnConnectionReceived = OnConnectionReceived; - endPoint->OnAcceptError = OnAcceptError; - endPoint->OnPeerClose = OnPeerClosed; - - ReturnErrorOnFailure(endPoint->Connect(addr.GetIPAddress(), addr.GetPort(), addr.GetInterface())); + Transport::ActiveTCPConnectionState * peerConnState = nullptr; + ReturnErrorOnFailure(StartConnect(addr, nullptr, &peerConnState)); // enqueue the packet once the connection succeeds VerifyOrReturnError(mPendingPackets.CreateObject(addr, std::move(msg)) != nullptr, CHIP_ERROR_NO_MEMORY); mUsedEndPointCount++; - endPointHolder.release(); - return CHIP_NO_ERROR; #else return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; @@ -255,7 +307,7 @@ CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBuf CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const PeerAddress & peerAddress, System::PacketBufferHandle && buffer) { - ActiveConnectionState * state = FindActiveConnection(endPoint); + ActiveTCPConnectionState * state = FindActiveConnection(endPoint); VerifyOrReturnError(state != nullptr, CHIP_ERROR_INTERNAL); state->mReceived.AddToEnd(std::move(buffer)); @@ -275,6 +327,7 @@ CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const Pe uint16_t messageSize = LittleEndian::Get16(messageSizeBuf); if (messageSize >= kMaxMessageSize) { + // This message is too long for upper layers. return CHIP_ERROR_MESSAGE_TOO_LONG; } @@ -291,12 +344,15 @@ CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const Pe return CHIP_NO_ERROR; } -CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, ActiveConnectionState * state, uint16_t messageSize) +CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, ActiveTCPConnectionState * state, uint16_t messageSize) { // We enter with `state->mReceived` containing at least one full message, perhaps in a chain. // `state->mReceived->Start()` currently points to the message data. // On exit, `state->mReceived` will have had `messageSize` bytes consumed, no matter what. System::PacketBufferHandle message; + MessageTransportContext msgContext; + msgContext.conn = state; + if (state->mReceived->DataLength() == messageSize) { // In this case, the head packet buffer contains exactly the message. @@ -321,23 +377,53 @@ CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, Active message->SetDataLength(messageSize); } - HandleMessageReceived(peerAddress, std::move(message)); + HandleMessageReceived(peerAddress, std::move(message), &msgContext); return CHIP_NO_ERROR; } -void TCPBase::ReleaseActiveConnection(Inet::TCPEndPoint * endPoint) +void TCPBase::CloseConnectionInternal(ActiveTCPConnectionState * connection, CHIP_ERROR err, SuppressCallback suppressCallback) { - for (size_t i = 0; i < mActiveConnectionsSize; i++) + TCPState prevState; + + if (connection == nullptr) { - if (mActiveConnections[i].mEndPoint == endPoint) + return; + } + + if (connection->mConnectionState != TCPState::kClosed && connection->mEndPoint) + { + if (err == CHIP_NO_ERROR) { - mActiveConnections[i].Free(); - mUsedEndPointCount--; + connection->mEndPoint->Close(); } + else + { + connection->mEndPoint->Abort(); + } + + prevState = connection->mConnectionState; + connection->mConnectionState = TCPState::kClosed; + + if (suppressCallback == SuppressCallback::No) + { + if (prevState == TCPState::kConnecting) + { + // Call upper layer connection attempt complete handler + HandleConnectionAttemptComplete(connection, err); + } + else + { + // Call upper layer connection closed handler + HandleConnectionClosed(connection, err); + } + } + + connection->Free(); + mUsedEndPointCount--; } } -CHIP_ERROR TCPBase::OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer) +CHIP_ERROR TCPBase::HandleTCPEndPointDataReceived(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer) { Inet::IPAddress ipAddress; uint16_t port; @@ -353,13 +439,13 @@ CHIP_ERROR TCPBase::OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBuf if (err != CHIP_NO_ERROR) { // Connection could need to be closed at this point - ChipLogError(Inet, "Failed to accept received TCP message: %s", ErrorStr(err)); + ChipLogError(Inet, "Failed to accept received TCP message: %" CHIP_ERROR_FORMAT, err.Format()); return CHIP_ERROR_UNEXPECTED_EVENT; } return CHIP_NO_ERROR; } -void TCPBase::OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR inetErr) +void TCPBase::HandleTCPEndPointConnectComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR conErr) { CHIP_ERROR err = CHIP_NO_ERROR; bool foundPendingPacket = false; @@ -367,157 +453,229 @@ void TCPBase::OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR inet Inet::IPAddress ipAddress; uint16_t port; Inet::InterfaceId interfaceId; + ActiveTCPConnectionState * activeConnection = nullptr; endPoint->GetPeerInfo(&ipAddress, &port); endPoint->GetInterfaceId(&interfaceId); + char addrStr[Transport::PeerAddress::kMaxToStringSize]; PeerAddress addr = PeerAddress::TCP(ipAddress, port, interfaceId); + addr.ToString(addrStr); - // Send any pending packets - tcp->mPendingPackets.ForEachActiveObject([&](PendingPacket * pending) { - if (pending->mPeerAddress == addr) + if (conErr == CHIP_NO_ERROR) + { + // Set the Data received handler when connection completes + endPoint->OnDataReceived = HandleTCPEndPointDataReceived; + endPoint->OnDataSent = nullptr; + endPoint->OnConnectionClosed = HandleTCPEndPointConnectionClosed; + + activeConnection = tcp->FindInUseConnection(endPoint); + VerifyOrDie(activeConnection != nullptr); + + // Set to Connected state + activeConnection->mConnectionState = TCPState::kConnected; + + // Disable TCP Nagle buffering by setting TCP_NODELAY socket option to true. + // This is to expedite transmission of payload data and not rely on the + // network stack's configuration of collating enough data in the TCP + // window to begin transmission. + err = endPoint->EnableNoDelay(); + if (err != CHIP_NO_ERROR) { - foundPendingPacket = true; - System::PacketBufferHandle buffer = std::move(pending->mPacketBuffer); - tcp->mPendingPackets.ReleaseObject(pending); + tcp->CloseConnectionInternal(activeConnection, err, SuppressCallback::No); + return; + } - if ((inetErr == CHIP_NO_ERROR) && (err == CHIP_NO_ERROR)) + // Send any pending packets that are queued for this connection + tcp->mPendingPackets.ForEachActiveObject([&](PendingPacket * pending) { + if (pending->mPeerAddress == addr) { - err = endPoint->Send(std::move(buffer)); + foundPendingPacket = true; + System::PacketBufferHandle buffer = std::move(pending->mPacketBuffer); + tcp->mPendingPackets.ReleaseObject(pending); + + if ((conErr == CHIP_NO_ERROR) && (err == CHIP_NO_ERROR)) + { + err = endPoint->Send(std::move(buffer)); + } } - } - return Loop::Continue; - }); + return Loop::Continue; + }); - if (err == CHIP_NO_ERROR) - { - err = inetErr; - } + // Set the TCPKeepalive configurations on the established connection + endPoint->EnableKeepAlive(activeConnection->mTCPKeepAliveIntervalSecs, activeConnection->mTCPMaxNumKeepAliveProbes); - if (!foundPendingPacket && (err == CHIP_NO_ERROR)) - { - // Force a close: new connections are only expected when a - // new buffer is being sent. - ChipLogError(Inet, "Connection accepted without pending buffers"); - err = CHIP_ERROR_CONNECTION_CLOSED_UNEXPECTEDLY; - } + ChipLogProgress(Inet, "Connection established successfully with %s.", addrStr); - // cleanup packets or mark as free - if (err != CHIP_NO_ERROR) - { - ChipLogError(Inet, "Connection complete encountered an error: %s", ErrorStr(err)); - endPoint->Free(); - tcp->mUsedEndPointCount--; + // Let higher layer/delegate know that connection is successfully + // established + tcp->HandleConnectionAttemptComplete(activeConnection, CHIP_NO_ERROR); } else { - bool connectionStored = false; - for (size_t i = 0; i < tcp->mActiveConnectionsSize; i++) - { - if (!tcp->mActiveConnections[i].InUse()) - { - tcp->mActiveConnections[i].Init(endPoint); - connectionStored = true; - break; - } - } - - // since we track end points counts, we always expect to store the - // connection. - if (!connectionStored) - { - endPoint->Free(); - ChipLogError(Inet, "Internal logic error: insufficient space to store active connection"); - } + ChipLogError(Inet, "Connection establishment with %s encountered an error: %" CHIP_ERROR_FORMAT, addrStr, err.Format()); + endPoint->Free(); + tcp->mUsedEndPointCount--; } } -void TCPBase::OnConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) +void TCPBase::HandleTCPEndPointConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) { - TCPBase * tcp = reinterpret_cast(endPoint->mAppState); + TCPBase * tcp = reinterpret_cast(endPoint->mAppState); + ActiveTCPConnectionState * activeConnection = tcp->FindInUseConnection(endPoint); - ChipLogProgress(Inet, "Connection closed."); + if (activeConnection == nullptr) + { + endPoint->Free(); + return; + } - ChipLogProgress(Inet, "Freeing closed connection."); - tcp->ReleaseActiveConnection(endPoint); + if (err == CHIP_NO_ERROR && activeConnection->IsConnected()) + { + err = CHIP_ERROR_CONNECTION_CLOSED_UNEXPECTEDLY; + } + + tcp->CloseConnectionInternal(activeConnection, err, SuppressCallback::No); } -void TCPBase::OnConnectionReceived(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, - const Inet::IPAddress & peerAddress, uint16_t peerPort) +// Handler for incoming connection requests from peer nodes +void TCPBase::HandleIncomingConnection(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, + const Inet::IPAddress & peerAddress, uint16_t peerPort) { - TCPBase * tcp = reinterpret_cast(listenEndPoint->mAppState); + TCPBase * tcp = reinterpret_cast(listenEndPoint->mAppState); + ActiveTCPConnectionState * activeConnection = nullptr; + Inet::InterfaceId interfaceId; + Inet::IPAddress ipAddress; + uint16_t port; + + endPoint->GetPeerInfo(&ipAddress, &port); + endPoint->GetInterfaceId(&interfaceId); + PeerAddress addr = PeerAddress::TCP(ipAddress, port, interfaceId); if (tcp->mUsedEndPointCount < tcp->mActiveConnectionsSize) { - // have space to use one more (even if considering pending connections) - for (size_t i = 0; i < tcp->mActiveConnectionsSize; i++) - { - if (!tcp->mActiveConnections[i].InUse()) - { - tcp->mActiveConnections[i].Init(endPoint); - tcp->mUsedEndPointCount++; - break; - } - } + activeConnection = tcp->AllocateConnection(); + + endPoint->mAppState = listenEndPoint->mAppState; + endPoint->OnDataReceived = HandleTCPEndPointDataReceived; + endPoint->OnDataSent = nullptr; + endPoint->OnConnectionClosed = HandleTCPEndPointConnectionClosed; + + // By default, disable TCP Nagle buffering by setting TCP_NODELAY socket option to true + endPoint->EnableNoDelay(); - endPoint->mAppState = listenEndPoint->mAppState; - endPoint->OnDataReceived = OnTcpReceive; - endPoint->OnConnectComplete = OnConnectionComplete; - endPoint->OnConnectionClosed = OnConnectionClosed; - endPoint->OnConnectionReceived = OnConnectionReceived; - endPoint->OnAcceptError = OnAcceptError; - endPoint->OnPeerClose = OnPeerClosed; + // Update state for the active connection + activeConnection->Init(endPoint, addr); + tcp->mUsedEndPointCount++; + activeConnection->mConnectionState = TCPState::kConnected; + + char addrStr[Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(addrStr); + ChipLogProgress(Inet, "Incoming connection established with peer at %s.", addrStr); + + // Call the upper layer handler for incoming connection received. + tcp->HandleConnectionReceived(activeConnection); } else { - ChipLogError(Inet, "Insufficient connection space to accept new connections"); + ChipLogError(Inet, "Insufficient connection space to accept new connections."); endPoint->Free(); + listenEndPoint->OnAcceptError(endPoint, CHIP_ERROR_TOO_MANY_CONNECTIONS); } } -void TCPBase::OnAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) +void TCPBase::HandleAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) +{ + endPoint->Free(); + ChipLogError(Inet, "Accept error: %" CHIP_ERROR_FORMAT, err.Format()); +} + +CHIP_ERROR TCPBase::TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** outPeerConnState) { - ChipLogError(Inet, "Accept error: %s", ErrorStr(err)); + VerifyOrReturnError(mState == TCPState::kInitialized, CHIP_ERROR_INCORRECT_STATE); + + // Verify that PeerAddress AddressType is TCP + VerifyOrReturnError(address.GetTransportType() == Transport::Type::kTcp, CHIP_ERROR_INVALID_ARGUMENT); + + VerifyOrReturnError(mUsedEndPointCount < mActiveConnectionsSize, CHIP_ERROR_NO_MEMORY); + + char addrStr[Transport::PeerAddress::kMaxToStringSize]; + address.ToString(addrStr); + ChipLogProgress(Inet, "Connecting to peer %s.", addrStr); + + ReturnErrorOnFailure(StartConnect(address, appState, outPeerConnState)); + + return CHIP_NO_ERROR; } -void TCPBase::Disconnect(const PeerAddress & address) +void TCPBase::TCPDisconnect(const PeerAddress & address) { + CHIP_ERROR err = CHIP_NO_ERROR; // Closes an existing connection for (size_t i = 0; i < mActiveConnectionsSize; i++) { - if (mActiveConnections[i].InUse()) + if (mActiveConnections[i].IsConnected()) { Inet::IPAddress ipAddress; uint16_t port; Inet::InterfaceId interfaceId; - mActiveConnections[i].mEndPoint->GetPeerInfo(&ipAddress, &port); - mActiveConnections[i].mEndPoint->GetInterfaceId(&interfaceId); - if (address == PeerAddress::TCP(ipAddress, port, interfaceId)) + err = mActiveConnections[i].mEndPoint->GetPeerInfo(&ipAddress, &port); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Inet, "TCPDisconnect: GetPeerInfo error: %" CHIP_ERROR_FORMAT, err.Format()); + return; + } + + err = mActiveConnections[i].mEndPoint->GetInterfaceId(&interfaceId); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Inet, "TCPDisconnect: GetInterfaceId error: %" CHIP_ERROR_FORMAT, err.Format()); + return; + } + // if (address == PeerAddress::TCP(ipAddress, port, interfaceId)) + if (ipAddress == address.GetIPAddress() && port == address.GetPort()) { + char addrStr[Transport::PeerAddress::kMaxToStringSize]; + address.ToString(addrStr); + ChipLogProgress(Inet, "Disconnecting with peer %s.", addrStr); + // NOTE: this leaves the socket in TIME_WAIT. // Calling Abort() would clean it since SO_LINGER would be set to 0, // however this seems not to be useful. - mActiveConnections[i].Free(); - mUsedEndPointCount--; + CloseConnectionInternal(&mActiveConnections[i], CHIP_NO_ERROR, SuppressCallback::Yes); } } } } -void TCPBase::OnPeerClosed(Inet::TCPEndPoint * endPoint) +void TCPBase::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) { - TCPBase * tcp = reinterpret_cast(endPoint->mAppState); - ChipLogProgress(Inet, "Freeing connection: connection closed by peer"); + if (conn == nullptr) + { + ChipLogError(Inet, "Failed to Disconnect. Passed in Connection is null."); + return; + } - tcp->ReleaseActiveConnection(endPoint); + // This call should be able to disconnect the connection either when it is + // already established, or when it is being set up. + if ((conn->IsConnected() && shouldAbort) || conn->IsConnecting()) + { + CloseConnectionInternal(conn, CHIP_ERROR_CONNECTION_ABORTED, SuppressCallback::Yes); + } + + if (conn->IsConnected() && !shouldAbort) + { + CloseConnectionInternal(conn, CHIP_NO_ERROR, SuppressCallback::Yes); + } } bool TCPBase::HasActiveConnections() const { for (size_t i = 0; i < mActiveConnectionsSize; i++) { - if (mActiveConnections[i].InUse()) + if (mActiveConnections[i].IsConnected()) { return true; } diff --git a/src/transport/raw/TCP.h b/src/transport/raw/TCP.h index d9f78be1771b0f..bb4671215b96c8 100644 --- a/src/transport/raw/TCP.h +++ b/src/transport/raw/TCP.h @@ -34,7 +34,9 @@ #include #include #include +#include #include +#include namespace chip { namespace Transport { @@ -96,45 +98,23 @@ struct PendingPacket /** Implements a transport using TCP. */ class DLL_EXPORT TCPBase : public Base { - /** - * The State of the TCP connection - */ - enum class State - { - kNotReady = 0, /**< State before initialization. */ - kInitialized = 1, /**< State after class is listening and ready. */ - }; protected: - /** - * State for each active connection - */ - struct ActiveConnectionState + enum class ShouldAbort : uint8_t { - void Init(Inet::TCPEndPoint * endPoint) - { - mEndPoint = endPoint; - mReceived = nullptr; - } - - void Free() - { - mEndPoint->Free(); - mEndPoint = nullptr; - mReceived = nullptr; - } - bool InUse() const { return mEndPoint != nullptr; } - - // Associated endpoint. - Inet::TCPEndPoint * mEndPoint; + Yes, + No + }; - // Buffers received but not yet consumed. - System::PacketBufferHandle mReceived; + enum class SuppressCallback : uint8_t + { + Yes, + No }; public: using PendingPacketPoolType = PoolInterface; - TCPBase(ActiveConnectionState * activeConnectionsBuffer, size_t bufferSize, PendingPacketPoolType & packetBuffers) : + TCPBase(ActiveTCPConnectionState * activeConnectionsBuffer, size_t bufferSize, PendingPacketPoolType & packetBuffers) : mActiveConnections(activeConnectionsBuffer), mActiveConnectionsSize(bufferSize), mPendingPackets(packetBuffers) { // activeConnectionsBuffer must be initialized by the caller. @@ -153,6 +133,13 @@ class DLL_EXPORT TCPBase : public Base */ CHIP_ERROR Init(TcpListenParameters & params); + /** + * Set the timeout (in milliseconds) for the node to wait for the TCP + * connection attempt to complete. + * + */ + void SetConnectTimeout(const uint32_t connTimeoutMsecs) { mConnectTimeout = connTimeoutMsecs; } + /** * Close the open endpoint without destroying the object */ @@ -160,14 +147,46 @@ class DLL_EXPORT TCPBase : public Base CHIP_ERROR SendMessage(const PeerAddress & address, System::PacketBufferHandle && msgBuf) override; - void Disconnect(const PeerAddress & address) override; + /* + * Connect to the given peerAddress over TCP. + * + * @param address The address of the peer. + * + * @param appState Context passed in by the application to be sent back + * via the connection attempt complete callback when + * connection attempt with peer completes. + * + * @param outPeerConnState Pointer to pointer to the active TCP connection state. This is + * an output parameter that is allocated by the + * transport layer and held by the caller object. + * This allows the caller object to abort the + * connection attempt if the caller object dies + * before the attempt completes. + * + */ + CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** outPeerConnState) override; + + void TCPDisconnect(const PeerAddress & address) override; + + // Close an active connection (corresponding to the passed + // ActiveTCPConnectionState object) + // and release from the pool. + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = false) override; bool CanSendToPeer(const PeerAddress & address) override { - return (mState == State::kInitialized) && (address.GetTransportType() == Type::kTcp) && + return (mState == TCPState::kInitialized) && (address.GetTransportType() == Type::kTcp) && (address.GetIPAddress().Type() == mEndpointType); } + const Optional GetConnectionPeerAddress(const Inet::TCPEndPoint * con) + { + ActiveTCPConnectionState * activeConState = FindActiveConnection(con); + + return activeConState != nullptr ? MakeOptional(activeConState->mPeerAddr) : Optional::Missing(); + } + /** * Helper method to determine if IO processing is still required for a TCP transport * before everything is cleaned up (socket closing is async, so after calling 'Close' on @@ -183,12 +202,22 @@ class DLL_EXPORT TCPBase : public Base private: friend class TCPTest; + /** + * Allocate an unused connection from the pool + * + */ + ActiveTCPConnectionState * AllocateConnection(); /** * Find an active connection to the given peer or return nullptr if * no active connection exists. */ - ActiveConnectionState * FindActiveConnection(const PeerAddress & addr); - ActiveConnectionState * FindActiveConnection(const Inet::TCPEndPoint * endPoint); + ActiveTCPConnectionState * FindActiveConnection(const PeerAddress & addr); + ActiveTCPConnectionState * FindActiveConnection(const Inet::TCPEndPoint * endPoint); + + /** + * Find an allocated connection that matches the corresponding TCPEndPoint. + */ + ActiveTCPConnectionState * FindInUseConnection(const Inet::TCPEndPoint * endPoint); /** * Sends the specified message once a connection has been established. @@ -223,46 +252,63 @@ class DLL_EXPORT TCPBase : public Base * is no other data). * @param[in] messageSize Size of the single message. */ - CHIP_ERROR ProcessSingleMessage(const PeerAddress & peerAddress, ActiveConnectionState * state, uint16_t messageSize); + CHIP_ERROR ProcessSingleMessage(const PeerAddress & peerAddress, ActiveTCPConnectionState * state, uint16_t messageSize); - // Release an active connection (corresponding to the passed TCPEndPoint) - // from the pool. - void ReleaseActiveConnection(Inet::TCPEndPoint * endPoint); + /** + * Initiate a connection to the given peer. On connection completion, + * HandleTCPConnectComplete callback would be called. + * + */ + CHIP_ERROR StartConnect(const PeerAddress & addr, AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** outPeerConnState); + + /** + * Gracefully Close or Abort a given connection. + * + */ + void CloseConnectionInternal(ActiveTCPConnectionState * connection, CHIP_ERROR err, SuppressCallback suppressCallback); + + // Close the listening socket endpoint + void CloseListeningSocket(); // Callback handler for TCPEndPoint. TCP message receive handler. // @see TCPEndpoint::OnDataReceivedFunct - static CHIP_ERROR OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer); + static CHIP_ERROR HandleTCPEndPointDataReceived(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer); // Callback handler for TCPEndPoint. Called when a connection has been completed. // @see TCPEndpoint::OnConnectCompleteFunct - static void OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); + static void HandleTCPEndPointConnectComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); // Callback handler for TCPEndPoint. Called when a connection has been closed. // @see TCPEndpoint::OnConnectionClosedFunct - static void OnConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); - - // Callback handler for TCPEndPoint. Callend when a peer closes the connection. - // @see TCPEndpoint::OnPeerCloseFunct - static void OnPeerClosed(Inet::TCPEndPoint * endPoint); + static void HandleTCPEndPointConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); // Callback handler for TCPEndPoint. Called when a connection is received on the listening port. // @see TCPEndpoint::OnConnectionReceivedFunct - static void OnConnectionReceived(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, - const Inet::IPAddress & peerAddress, uint16_t peerPort); + static void HandleIncomingConnection(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, + const Inet::IPAddress & peerAddress, uint16_t peerPort); - // Called on accept error + // Callback handler for handling accept error // @see TCPEndpoint::OnAcceptErrorFunct - static void OnAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); + static void HandleAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); Inet::TCPEndPoint * mListenSocket = nullptr; ///< TCP socket used by the transport Inet::IPAddressType mEndpointType = Inet::IPAddressType::kUnknown; ///< Socket listening type - State mState = State::kNotReady; ///< State of the TCP transport + TCPState mState = TCPState::kNotReady; ///< State of the TCP transport + + // The configured timeout for the connection attempt to the peer, before + // giving up. + uint32_t mConnectTimeout = CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS; + + // The max payload size of data over a TCP connection that is transmissible + // at a time. + uint32_t mMaxTCPPayloadSize = CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES; // Number of active and 'pending connection' endpoints size_t mUsedEndPointCount = 0; // Currently active connections - ActiveConnectionState * mActiveConnections; + ActiveTCPConnectionState * mActiveConnections; const size_t mActiveConnectionsSize; // Data to be sent when connections succeed @@ -277,14 +323,15 @@ class TCP : public TCPBase { for (size_t i = 0; i < kActiveConnectionsSize; ++i) { - mConnectionsBuffer[i].Init(nullptr); + mConnectionsBuffer[i].Init(nullptr, PeerAddress::Uninitialized()); } } + ~TCP() override { mPendingPackets.ReleaseAll(); } private: friend class TCPTest; - TCPBase::ActiveConnectionState mConnectionsBuffer[kActiveConnectionsSize]; + ActiveTCPConnectionState mConnectionsBuffer[kActiveConnectionsSize]; PoolImpl mPendingPackets; }; diff --git a/src/transport/raw/TCPConfig.h b/src/transport/raw/TCPConfig.h new file mode 100644 index 00000000000000..d54a9466b4d294 --- /dev/null +++ b/src/transport/raw/TCPConfig.h @@ -0,0 +1,127 @@ +/* + * + * Copyright (c) 2023 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines default compile-time configuration constants + * for CHIP. + * + * Package integrators that wish to override these values should + * either use preprocessor definitions or create a project- + * specific chipProjectConfig.h header and then assert + * HAVE_CHIPPROJECTCONFIG_H via the package configuration tool + * via --with-chip-project-includes=DIR where DIR is the + * directory that contains the header. + * + * + */ + +#pragma once + +#include + +namespace chip { + +/** + * @def CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS + * + * @brief Maximum Number of TCP connections a device can simultaneously have + */ +#ifndef CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS +#define CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS 4 +#endif + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT && CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS < 1 +#error "If TCP is enabled, the device needs to support at least 1 TCP connection" +#endif + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT && CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS > INET_CONFIG_NUM_TCP_ENDPOINTS +#error "If TCP is enabled, the maximum number of connections cannot exceed the number of tcp endpoints" +#endif + +/** + * @def CHIP_CONFIG_MAX_TCP_PENDING_PACKETS + * + * @brief Maximum Number of outstanding pending packets in the queue before a TCP connection + * needs to be established + */ +#ifndef CHIP_CONFIG_MAX_TCP_PENDING_PACKETS +#define CHIP_CONFIG_MAX_TCP_PENDING_PACKETS 4 +#endif + +/** + * @def CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES + * + * @brief Maximum payload size of a message over a TCP connection + */ +#ifndef CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES +#define CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES 1000000 +#endif + +/** + * @def CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS + * + * @brief + * This defines the default timeout for the TCP connect + * attempt to either succeed or notify the caller of an + * error. + * + */ +#ifndef CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS +#define CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS (10000) +#endif // CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS + +/** + * @def CHIP_CONFIG_KEEPALIVE_INTERVAL_SECS + * + * @brief + * This defines the default interval (in seconds) between + * keepalive probes for a TCP connection. + * This value also controls the time between last data + * packet sent and the transmission of the first keepalive + * probe. + * + */ +#ifndef CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS +#define CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS (25) +#endif // CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS + +/** + * @def CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES + * + * @brief + * This defines the default value for the maximum number of + * keepalive probes for a TCP connection. + * + */ +#ifndef CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES +#define CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES (5) +#endif // CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES + +/** + * @def CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS + * + * @brief + * This defines the default value for the maximum timeout + * of unacknowledged data for a TCP connection. + * + */ +#ifndef CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS +#define CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS (30) +#endif // CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS + +} // namespace chip diff --git a/src/transport/raw/Tuple.h b/src/transport/raw/Tuple.h index 743b9b9b0e09aa..e3e52a171abcda 100644 --- a/src/transport/raw/Tuple.h +++ b/src/transport/raw/Tuple.h @@ -93,7 +93,20 @@ class Tuple : public Base bool CanSendToPeer(const PeerAddress & address) override { return CanSendToPeerImpl<0>(address); } - void Disconnect(const PeerAddress & address) override { return DisconnectImpl<0>(address); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) override + { + return TCPConnectImpl<0>(address, appState, peerConnState); + } + + void TCPDisconnect(const PeerAddress & address) override { return TCPDisconnectImpl<0>(address); } + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) override + { + return TCPDisconnectImpl<0>(conn, shouldAbort); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT void Close() override { return CloseImpl<0>(); } @@ -138,26 +151,78 @@ class Tuple : public Base return false; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * Recursive TCPConnect implementation iterating through transport members. + * + * @tparam N the index of the underlying transport to send disconnect to + * + * @param address what address to connect to. + */ + template ::type * = nullptr> + CHIP_ERROR TCPConnectImpl(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + Base * base = &std::get(mTransports); + if (base->CanSendToPeer(address)) + { + return base->TCPConnect(address, appState, peerConnState); + } + return TCPConnectImpl(address, appState, peerConnState); + } + + /** + * TCPConnectImpl template for out of range N. + */ + template = sizeof...(TransportTypes))>::type * = nullptr> + CHIP_ERROR TCPConnectImpl(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return CHIP_ERROR_NO_MESSAGE_HANDLER; + } + /** * Recursive disconnect implementation iterating through transport members. * * @tparam N the index of the underlying transport to send disconnect to * - * @param address what address to check. + * @param address what address to disconnect from. + */ + template ::type * = nullptr> + void TCPDisconnectImpl(const PeerAddress & address) + { + std::get(mTransports).TCPDisconnect(address); + TCPDisconnectImpl(address); + } + + /** + * TCPDisconnectImpl template for out of range N. + */ + template = sizeof...(TransportTypes))>::type * = nullptr> + void TCPDisconnectImpl(const PeerAddress & address) + {} + + /** + * Recursive disconnect implementation iterating through transport members. + * + * @tparam N the index of the underlying transport to send disconnect to + * + * @param conn pointer to the connection to the peer. */ template ::type * = nullptr> - void DisconnectImpl(const PeerAddress & address) + void TCPDisconnectImpl(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) { - std::get(mTransports).Disconnect(address); - DisconnectImpl(address); + std::get(mTransports).TCPDisconnect(conn, shouldAbort); + TCPDisconnectImpl(conn, shouldAbort); } /** - * DisconnectImpl template for out of range N. + * TCPDisconnectImpl template for out of range N. */ template = sizeof...(TransportTypes))>::type * = nullptr> - void DisconnectImpl(const PeerAddress & address) + void TCPDisconnectImpl(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Recursive disconnect implementation iterating through transport members. diff --git a/src/transport/raw/tests/BUILD.gn b/src/transport/raw/tests/BUILD.gn index c655586c5a0e35..973626af7f340b 100644 --- a/src/transport/raw/tests/BUILD.gn +++ b/src/transport/raw/tests/BUILD.gn @@ -14,7 +14,9 @@ import("//build_overrides/build.gni") import("//build_overrides/chip.gni") +import("//build_overrides/nlunit_test.gni") import("//build_overrides/pigweed.gni") +import("${chip_root}/src/inet/inet.gni") import("${chip_root}/build/chip/chip_test_suite.gni") static_library("helpers") { @@ -40,10 +42,13 @@ chip_test_suite("tests") { test_sources = [ "TestMessageHeader.cpp", "TestPeerAddress.cpp", - "TestTCP.cpp", "TestUDP.cpp", ] + if (chip_inet_config_enable_tcp_endpoint) { + test_sources += [ "TestTCP.cpp" ] + } + public_deps = [ ":helpers", "${chip_root}/src/inet/tests:helpers", diff --git a/src/transport/raw/tests/TestTCP.cpp b/src/transport/raw/tests/TestTCP.cpp index 93414e01e3d653..6a60fb330c70f7 100644 --- a/src/transport/raw/tests/TestTCP.cpp +++ b/src/transport/raw/tests/TestTCP.cpp @@ -23,6 +23,7 @@ #include "NetworkTestHelpers.h" +#include #include #include #include @@ -30,7 +31,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include @@ -47,6 +50,9 @@ namespace { constexpr size_t kMaxTcpActiveConnectionCount = 4; constexpr size_t kMaxTcpPendingPackets = 4; constexpr uint16_t kPacketSizeBytes = static_cast(sizeof(uint16_t)); +uint16_t gChipTCPPort = static_cast(CHIP_PORT + chip::Crypto::GetRandU16() % 100); +chip::Transport::AppTCPConnectionCallbackCtxt gAppTCPConnCbCtxt; +chip::Transport::ActiveTCPConnectionState * gActiveTCPConnState = nullptr; using TCPImpl = Transport::TCP; @@ -71,7 +77,8 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate mCallback = callback; mCallbackData = callback_data; } - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * transCtxt = nullptr) override { PacketHeader packetHeader; @@ -82,12 +89,54 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate EXPECT_EQ(mCallback(msgBuf->Start(), msgBuf->DataLength(), mReceiveHandlerCallCount, mCallbackData), 0); } + ChipLogProgress(Inet, "Message Receive Handler called"); + mReceiveHandlerCallCount++; } + void HandleConnectionAttemptComplete(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override + { + chip::Transport::AppTCPConnectionCallbackCtxt * appConnCbCtxt = nullptr; + VerifyOrReturn(conn != nullptr); + + mHandleConnectionCompleteCalled = true; + appConnCbCtxt = conn->mAppState; + VerifyOrReturn(appConnCbCtxt != nullptr); + + if (appConnCbCtxt->connCompleteCb != nullptr) + { + appConnCbCtxt->connCompleteCb(conn, conErr); + } + else + { + ChipLogProgress(Inet, "Connection established. App callback missing."); + } + } + + void HandleConnectionClosed(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override + { + chip::Transport::AppTCPConnectionCallbackCtxt * appConnCbCtxt = nullptr; + VerifyOrReturn(conn != nullptr); + + mHandleConnectionCloseCalled = true; + appConnCbCtxt = conn->mAppState; + VerifyOrReturn(appConnCbCtxt != nullptr); + + if (appConnCbCtxt->connClosedCb != nullptr) + { + appConnCbCtxt->connClosedCb(conn, conErr); + } + else + { + ChipLogProgress(Inet, "Connection Closed. App callback missing."); + } + } + void InitializeMessageTest(TCPImpl & tcp, const IPAddress & addr) { - CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(mContext->GetTCPEndPointManager()).SetAddressType(addr.Type())); + CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(mContext->GetTCPEndPointManager()) + .SetAddressType(addr.Type()) + .SetListenPort(gChipTCPPort)); // retry a few times in case the port is somehow in use. // this is a WORKAROUND for flaky testing if we run tests very fast after each other. @@ -106,7 +155,9 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate { ChipLogProgress(NotSpecified, "RETRYING tcp initialization"); chip::test_utils::SleepMillis(100); - err = tcp.Init(Transport::TcpListenParameters(mContext->GetTCPEndPointManager()).SetAddressType(addr.Type())); + err = tcp.Init(Transport::TcpListenParameters(mContext->GetTCPEndPointManager()) + .SetAddressType(addr.Type()) + .SetListenPort(gChipTCPPort)); } EXPECT_EQ(err, CHIP_NO_ERROR); @@ -114,7 +165,14 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate mTransportMgrBase.SetSessionManager(this); mTransportMgrBase.Init(&tcp); - mReceiveHandlerCallCount = 0; + mReceiveHandlerCallCount = 0; + mHandleConnectionCompleteCalled = false; + mHandleConnectionCloseCalled = false; + + gAppTCPConnCbCtxt.appContext = nullptr; + gAppTCPConnCbCtxt.connReceivedCb = nullptr; + gAppTCPConnCbCtxt.connCompleteCb = nullptr; + gAppTCPConnCbCtxt.connClosedCb = nullptr; } void SingleMessageTest(TCPImpl & tcp, const IPAddress & addr) @@ -132,7 +190,7 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate EXPECT_EQ(err, CHIP_NO_ERROR); // Should be able to send a message to itself by just calling send. - err = tcp.SendMessage(Transport::PeerAddress::TCP(addr), std::move(buffer)); + err = tcp.SendMessage(Transport::PeerAddress::TCP(addr, gChipTCPPort), std::move(buffer)); EXPECT_EQ(err, CHIP_NO_ERROR); mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mReceiveHandlerCallCount != 0; }); @@ -141,39 +199,114 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate SetCallback(nullptr); } - void FinalizeMessageTest(TCPImpl & tcp, const IPAddress & addr) + void ConnectTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection + CHIP_ERROR err = tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + EXPECT_EQ(err, CHIP_NO_ERROR); + + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return tcp.HasActiveConnections(); }); + EXPECT_EQ(tcp.HasActiveConnections(), true); + } + + void HandleConnectCompleteCbCalledTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection and connection complete + // handler being called. + CHIP_ERROR err = tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + EXPECT_EQ(err, CHIP_NO_ERROR); + + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mHandleConnectionCompleteCalled; }); + EXPECT_EQ(mHandleConnectionCompleteCalled, true); + } + + void HandleConnectCloseCbCalledTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection and connection complete + // handler being called. + CHIP_ERROR err = tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + EXPECT_EQ(err, CHIP_NO_ERROR); + + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mHandleConnectionCompleteCalled; }); + EXPECT_EQ(mHandleConnectionCompleteCalled, true); + + tcp.TCPDisconnect(Transport::PeerAddress::TCP(addr, gChipTCPPort)); + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return !tcp.HasActiveConnections(); }); + EXPECT_EQ(mHandleConnectionCloseCalled, true); + } + + void DisconnectTest(TCPImpl & tcp, chip::Transport::ActiveTCPConnectionState * conn) { // Disconnect and wait for seeing peer close - tcp.Disconnect(Transport::PeerAddress::TCP(addr)); + tcp.TCPDisconnect(conn, true); mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return !tcp.HasActiveConnections(); }); + EXPECT_EQ(tcp.HasActiveConnections(), false); + } + + void DisconnectTest(TCPImpl & tcp, const IPAddress & addr) + { + // Disconnect and wait for seeing peer close + tcp.TCPDisconnect(Transport::PeerAddress::TCP(addr, gChipTCPPort)); + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return !tcp.HasActiveConnections(); }); + EXPECT_EQ(tcp.HasActiveConnections(), false); + } + + CHIP_ERROR TCPConnect(const Transport::PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return mTransportMgrBase.TCPConnect(peerAddress, appState, peerConnState); + } + + using OnTCPConnectionReceivedCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn); + + using OnTCPConnectionCompleteCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn, + CHIP_ERROR conErr); + + using OnTCPConnectionClosedCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn, + CHIP_ERROR conErr); + + void SetConnectionCallbacks(OnTCPConnectionCompleteCallback connCompleteCb, OnTCPConnectionClosedCallback connClosedCb, + OnTCPConnectionReceivedCallback connReceivedCb) + { + mConnCompleteCb = connCompleteCb; + mConnClosedCb = connClosedCb; + mConnReceivedCb = connReceivedCb; } int mReceiveHandlerCallCount = 0; + bool mHandleConnectionCompleteCalled = false; + + bool mHandleConnectionCloseCalled = false; + private: TestContext * mContext; MessageReceivedCallback mCallback; void * mCallbackData; TransportMgrBase mTransportMgrBase; + OnTCPConnectionCompleteCallback mConnCompleteCb = nullptr; + OnTCPConnectionClosedCallback mConnClosedCb = nullptr; + OnTCPConnectionReceivedCallback mConnReceivedCb = nullptr; }; -/////////////////////////// Init test - class TestTCP : public ::testing::Test, public chip::Test::IOContext { protected: void SetUp() { ASSERT_EQ(Init(), CHIP_NO_ERROR); } void TearDown() { Shutdown(); } + /////////////////////////// Init test void CheckSimpleInitTest(Inet::IPAddressType type) { TCPImpl tcp; - CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(GetTCPEndPointManager()).SetAddressType(type)); + CHIP_ERROR err = + tcp.Init(Transport::TcpListenParameters(GetTCPEndPointManager()).SetAddressType(type).SetListenPort(gChipTCPPort)); EXPECT_EQ(err, CHIP_NO_ERROR); } + /////////////////////////// Messaging test void CheckMessageTest(const IPAddress & addr) { TCPImpl tcp; @@ -181,7 +314,48 @@ class TestTCP : public ::testing::Test, public chip::Test::IOContext MockTransportMgrDelegate gMockTransportMgrDelegate(this); gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); - gMockTransportMgrDelegate.FinalizeMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); + } + + void ConnectToSelfTest(const IPAddress & addr) + { + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(this); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.ConnectTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); + } + + void ConnectSendMessageThenCloseTest(const IPAddress & addr) + { + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(this); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.ConnectTest(tcp, addr); + gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); + } + + void HandleConnCompleteTest(const IPAddress & addr) + { + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(this); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.HandleConnectCompleteCbCalledTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); + } + + void HandleConnCloseTest(const IPAddress & addr) + { + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(this); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.HandleConnectCloseCbCalledTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); } }; @@ -211,6 +385,57 @@ TEST_F(TestTCP, CheckMessageTest6) CheckMessageTest(addr); } +#if INET_CONFIG_ENABLE_IPV4 +TEST_F(TestTCP, ConnectToSelfTest4) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + ConnectToSelfTest(addr); +} + +TEST_F(TestTCP, ConnectSendMessageThenCloseTest4) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + ConnectSendMessageThenCloseTest(addr); +} + +TEST_F(TestTCP, HandleConnCompleteCalledTest4) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + HandleConnCompleteTest(addr); +} +#endif // INET_CONFIG_ENABLE_IPV4 + +TEST_F(TestTCP, ConnectToSelfTest6) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + ConnectToSelfTest(addr); +} + +TEST_F(TestTCP, ConnectSendMessageThenCloseTest6) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + ConnectSendMessageThenCloseTest(addr); +} + +TEST_F(TestTCP, HandleConnCompleteCalledTest6) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + HandleConnCompleteTest(addr); +} + +TEST_F(TestTCP, HandleConnCloseCalledTest6) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + HandleConnCloseTest(addr); +} + // Generates a packet buffer or a chain of packet buffers for a single message. struct TestData { @@ -381,8 +606,8 @@ class TCPTest // (The current TCPEndPoint implementation is not effectively mockable.) gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); - Transport::PeerAddress lPeerAddress = Transport::PeerAddress::TCP(addr); - TCPBase::ActiveConnectionState * state = tcp.FindActiveConnection(lPeerAddress); + Transport::PeerAddress lPeerAddress = Transport::PeerAddress::TCP(addr, gChipTCPPort); + chip::Transport::ActiveTCPConnectionState * state = tcp.FindActiveConnection(lPeerAddress); ASSERT_NE(state, nullptr); Inet::TCPEndPoint * lEndPoint = state->mEndPoint; ASSERT_NE(lEndPoint, nullptr); @@ -433,7 +658,7 @@ class TCPTest EXPECT_EQ(err, CHIP_ERROR_MESSAGE_TOO_LONG); EXPECT_EQ(gMockTransportMgrDelegate.mReceiveHandlerCallCount, 0); - gMockTransportMgrDelegate.FinalizeMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); } }; } // namespace Transport diff --git a/src/transport/raw/tests/TestUDP.cpp b/src/transport/raw/tests/TestUDP.cpp index 70077a6fa297e9..6a22a821fdfc94 100644 --- a/src/transport/raw/tests/TestUDP.cpp +++ b/src/transport/raw/tests/TestUDP.cpp @@ -52,7 +52,8 @@ class MockTransportMgrDelegate : public TransportMgrDelegate MockTransportMgrDelegate() {} ~MockTransportMgrDelegate() override {} - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * transCtxt = nullptr) override { PacketHeader packetHeader;