From c64e34a654b98a793391304b0a9e573d0319cd56 Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Thu, 27 Jan 2022 15:50:01 +0800 Subject: [PATCH] Add ephemeral node IDs for unsecured sessions --- src/transport/SessionManager.cpp | 41 ++++++- src/transport/SessionManager.h | 4 +- src/transport/UnauthenticatedSessionTable.h | 129 +++++++++++--------- 3 files changed, 109 insertions(+), 65 deletions(-) diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 8397aefe1b8c58..2fe6231386b269 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -180,6 +180,16 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P uint32_t messageCounter = counter.Value(); ReturnErrorOnFailure(counter.Advance()); packetHeader.SetMessageCounter(messageCounter); + Transport::UnauthenticatedSession * session = sessionHandle->AsUnauthenticatedSession(); + switch (session->GetSessionRole()) + { + case Transport::UnauthenticatedSession::SessionRole::kInitiator: + packetHeader.SetSourceNodeId(session->GetEphemeralInitiatorNodeID()); + break; + case Transport::UnauthenticatedSession::SessionRole::kResponder: + packetHeader.SetDestinationNodeId(session->GetEphemeralInitiatorNodeID()); + break; + } // Trace after all headers are settled. CHIP_TRACE_MESSAGE_SENT(payloadHeader, packetHeader, message->Start(), message->TotalLength()); @@ -430,11 +440,33 @@ void SessionManager::RefreshSessionOperationalData(const SessionHandle & session void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) { - Optional optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, gDefaultMRPConfig); - if (!optionalSession.HasValue()) + Optional source = packetHeader.GetSourceNodeId(); + Optional destination = packetHeader.GetDestinationNodeId(); + if ((source.HasValue() && destination.HasValue()) || (!source.HasValue() && !destination.HasValue())) { - ChipLogError(Inet, "UnauthenticatedSession exhausted"); - return; + return; // ephemeral node id is only assigned to the initiator, there should be one and only one node id exists. + } + + Optional optionalSession; + if (source.HasValue()) + { + // Assume peer is the initiator, we are the responder. + optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), gDefaultMRPConfig); + if (!optionalSession.HasValue()) + { + ChipLogError(Inet, "UnauthenticatedSession exhausted"); + return; + } + } + else + { + // Assume peer is the responder, we are the initiator. + optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value()); + if (!optionalSession.HasValue()) + { + ChipLogProgress(Inet, "Received unknown unsecure packet for initiator 0x%" PRId64, destination.Value()); + return; + } } const SessionHandle & session = optionalSession.Value(); @@ -464,6 +496,7 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr } unsecuredSession->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); + unsecuredSession->SetPeerAddress(peerAddress); if (mCB != nullptr) { diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 7d30ba0f2364cf..56dcb2372654fc 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -27,6 +27,7 @@ #include +#include #include #include #include @@ -206,7 +207,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config) { - return mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, config); + NodeId ephemeralInitiatorNodeID = static_cast(Crypto::GetRandU64()); + return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config); } // TODO: implements group sessions diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 0b645d011a79e3..656948471ac5b8 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -46,8 +46,15 @@ class UnauthenticatedSessionDeleter class UnauthenticatedSession : public Session, public ReferenceCounted { public: - UnauthenticatedSession(const PeerAddress & address, const ReliableMessageProtocolConfig & config) : - mPeerAddress(address), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config) + enum class SessionRole + { + kInitiator, + kResponder, + }; + + UnauthenticatedSession(SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) : + mEphemeralInitiatorNodeId(ephemeralInitiatorNodeID), mSessionRole(sessionRole), + mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config) {} ~UnauthenticatedSession() { NotifySessionReleased(); } @@ -88,8 +95,22 @@ class UnauthenticatedSession : public Session, public ReferenceCounted FindOrAllocateEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config) + Optional FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) { - UnauthenticatedSession * result = FindEntry(address); + UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID); if (result != nullptr) return MakeOptional(*result); - CHIP_ERROR err = AllocEntry(address, config, result); + CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, config, result); + if (err == CHIP_NO_ERROR) + { + return MakeOptional(*result); + } + else + { + return Optional::Missing(); + } + } + + CHECK_RETURN_VALUE Optional FindInitiator(NodeId ephemeralInitiatorNodeID) + { + UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID); + if (result != nullptr) + { + return MakeOptional(*result); + } + else + { + return Optional::Missing(); + } + } + + CHECK_RETURN_VALUE Optional AllocInitiator(NodeId ephemeralInitiatorNodeID, const PeerAddress & peerAddress, + const ReliableMessageProtocolConfig & config) + { + UnauthenticatedSession * result = nullptr; + CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, config, result); if (err == CHIP_NO_ERROR) { + result->SetPeerAddress(peerAddress); return MakeOptional(*result); } else @@ -148,10 +201,10 @@ class UnauthenticatedSessionTable * CHIP_ERROR_NO_MEMORY). */ CHECK_RETURN_VALUE - CHIP_ERROR AllocEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config, - UnauthenticatedSession *& entry) + CHIP_ERROR AllocEntry(UnauthenticatedSession::SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, + const ReliableMessageProtocolConfig & config, UnauthenticatedSession *& entry) { - entry = mEntries.CreateObject(address, config); + entry = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config); if (entry != nullptr) return CHIP_NO_ERROR; @@ -161,21 +214,16 @@ class UnauthenticatedSessionTable return CHIP_ERROR_NO_MEMORY; } - mEntries.ResetObject(entry, address, config); + mEntries.ResetObject(entry, sessionRole, ephemeralInitiatorNodeID, config); return CHIP_NO_ERROR; } - /** - * Get a session using given address - * - * @return the peer found, nullptr if not found - */ - CHECK_RETURN_VALUE - UnauthenticatedSession * FindEntry(const PeerAddress & address) + CHECK_RETURN_VALUE UnauthenticatedSession * FindEntry(UnauthenticatedSession::SessionRole sessionRole, + NodeId ephemeralInitiatorNodeID) { UnauthenticatedSession * result = nullptr; mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) { - if (MatchPeerAddress(entry->GetPeerAddress(), address)) + if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID) { result = entry; return Loop::Break; @@ -202,45 +250,6 @@ class UnauthenticatedSessionTable return result; } - // A temporary solution for #11120 - // Enforce interface match if not null - static bool MatchInterface(Inet::InterfaceId i1, Inet::InterfaceId i2) - { - if (i1.IsPresent() && i2.IsPresent()) - { - return i1 == i2; - } - else - { - // One of the interfaces is null. - return true; - } - } - - static bool MatchPeerAddress(const PeerAddress & a1, const PeerAddress & a2) - { - if (a1.GetTransportType() != a2.GetTransportType()) - return false; - - switch (a1.GetTransportType()) - { - case Transport::Type::kUndefined: - return false; - case Transport::Type::kUdp: - case Transport::Type::kTcp: - return a1.GetIPAddress() == a2.GetIPAddress() && a1.GetPort() == a2.GetPort() && - // Enforce interface equal-ness if the address is link-local, otherwise ignore interface - // Use MatchInterface for a temporary solution for #11120 - (a1.GetIPAddress().IsIPv6LinkLocal() ? a1.GetInterface() == a2.GetInterface() - : MatchInterface(a1.GetInterface(), a2.GetInterface())); - case Transport::Type::kBle: - // TODO: complete BLE address comparation - return true; - } - - return false; - } - BitMapObjectPool mEntries; };