Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ephemeral node IDs for unsecured sessions #14382

Merged
merged 3 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P
uint32_t messageCounter = counter.Value();
ReturnErrorOnFailure(counter.Advance());
packetHeader.SetMessageCounter(messageCounter);
Transport::UnauthenticatedSession * session = sessionHandle->AsUnauthenticatedSession();
switch (session->GetSessionRole())
{
case Transport::UnauthenticatedSession::SessionRole::kInitiator:
packetHeader.SetSourceNodeId(session->GetEphemeralInitiatorNodeID());
break;
case Transport::UnauthenticatedSession::SessionRole::kResponder:
packetHeader.SetDestinationNodeId(session->GetEphemeralInitiatorNodeID());
break;
}

// Trace after all headers are settled.
CHIP_TRACE_MESSAGE_SENT(payloadHeader, packetHeader, message->Start(), message->TotalLength());
Expand Down Expand Up @@ -401,7 +411,7 @@ void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::
}
else
{
MessageDispatch(packetHeader, peerAddress, std::move(msg));
UnauthenticatedMessageDispatch(packetHeader, peerAddress, std::move(msg));
}
}

Expand Down Expand Up @@ -437,18 +447,45 @@ void SessionManager::RefreshSessionOperationalData(const SessionHandle & session
});
}

void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg)
void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg)
{
Optional<SessionHandle> optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, GetLocalMRPConfig());
if (!optionalSession.HasValue())
Optional<NodeId> source = packetHeader.GetSourceNodeId();
Optional<NodeId> destination = packetHeader.GetDestinationNodeId();
if ((source.HasValue() && destination.HasValue()) || (!source.HasValue() && !destination.HasValue()))
{
ChipLogError(Inet, "UnauthenticatedSession exhausted");
return;
ChipLogProgress(Inet,
"Received malformed unsecure packet with source 0x" ChipLogFormatX64 " destination 0x" ChipLogFormatX64,
ChipLogValueX64(source.ValueOr(kUndefinedNodeId)), ChipLogValueX64(destination.ValueOr(kUndefinedNodeId)));
return; // ephemeral node id is only assigned to the initiator, there should be one and only one node id exists.
kghost marked this conversation as resolved.
Show resolved Hide resolved
}

Optional<SessionHandle> optionalSession;
if (source.HasValue())
{
// Assume peer is the initiator, we are the responder.
optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), GetLocalMRPConfig());
if (!optionalSession.HasValue())
{
ChipLogError(Inet, "UnauthenticatedSession exhausted");
return;
}
}
else
{
// Assume peer is the responder, we are the initiator.
optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value());
if (!optionalSession.HasValue())
{
ChipLogProgress(Inet, "Received unknown unsecure packet for initiator 0x" ChipLogFormatX64,
ChipLogValueX64(destination.Value()));
return;
}
}

const SessionHandle & session = optionalSession.Value();
Transport::UnauthenticatedSession * unsecuredSession = session->AsUnauthenticatedSession();
unsecuredSession->SetPeerAddress(peerAddress);
SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No;

// Verify message counter
Expand Down
13 changes: 10 additions & 3 deletions src/transport/SessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <utility>

#include <crypto/RandUtils.h>
#include <inet/IPAddress.h>
#include <lib/core/CHIPCore.h>
#include <lib/support/CodeUtils.h>
Expand Down Expand Up @@ -206,7 +207,13 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
Optional<SessionHandle> CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & config)
{
return mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, config);
// Allocate ephemeralInitiatorNodeID in Operational Node ID range
NodeId ephemeralInitiatorNodeID;
do
{
ephemeralInitiatorNodeID = static_cast<NodeId>(Crypto::GetRandU64());
} while (!IsOperationalNodeId(ephemeralInitiatorNodeID));
return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config);
}

// TODO: implements group sessions
Expand Down Expand Up @@ -279,8 +286,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
void SecureGroupMessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg);

void MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg);
void UnauthenticatedMessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg);

void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source);

Expand Down
129 changes: 69 additions & 60 deletions src/transport/UnauthenticatedSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,15 @@ class UnauthenticatedSessionDeleter
class UnauthenticatedSession : public Session, public ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>
{
public:
UnauthenticatedSession(const PeerAddress & address, const ReliableMessageProtocolConfig & config) :
mPeerAddress(address), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
enum class SessionRole
{
kInitiator,
kResponder,
};

UnauthenticatedSession(SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) :
mEphemeralInitiatorNodeId(ephemeralInitiatorNodeID), mSessionRole(sessionRole),
mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}
~UnauthenticatedSession() { NotifySessionReleased(); }

Expand Down Expand Up @@ -88,8 +95,22 @@ class UnauthenticatedSession : public Session, public ReferenceCounted<Unauthent
return System::Clock::Timeout();
}

NodeId GetPeerNodeId() const { return kUndefinedNodeId; }
NodeId GetPeerNodeId() const
{
if (mSessionRole == SessionRole::kInitiator)
{
return kUndefinedNodeId;
}
else
{
return mEphemeralInitiatorNodeId;
}
}

SessionRole GetSessionRole() const { return mSessionRole; }
NodeId GetEphemeralInitiatorNodeID() const { return mEphemeralInitiatorNodeId; }
const PeerAddress & GetPeerAddress() const { return mPeerAddress; }
void SetPeerAddress(const PeerAddress & peerAddress) { mPeerAddress = peerAddress; }

void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; }

Expand All @@ -98,7 +119,9 @@ class UnauthenticatedSession : public Session, public ReferenceCounted<Unauthent
PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; }

private:
const PeerAddress mPeerAddress;
const NodeId mEphemeralInitiatorNodeId;
const SessionRole mSessionRole;
PeerAddress mPeerAddress;
System::Clock::Timestamp mLastActivityTime;
ReliableMessageProtocolConfig mMRPConfig;
PeerMessageCounter mPeerMessageCounter;
Expand All @@ -118,20 +141,50 @@ class UnauthenticatedSessionTable
~UnauthenticatedSessionTable() { mEntries.ReleaseAll(); }

/**
* Get a session given the peer address. If the session doesn't exist in the cache, allocate a new entry for it.
* Get a responder session with the given ephemeralInitiatorNodeID. If the session doesn't exist in the cache, allocate a new
* entry for it.
*
* @return the session found or allocated, nullptr if not found and allocation failed.
* @return the session found or allocated, or Optional::Missing if not found and allocation failed.
*/
CHECK_RETURN_VALUE
Optional<SessionHandle> FindOrAllocateEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config)
Optional<SessionHandle> FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config)
{
UnauthenticatedSession * result = FindEntry(address);
UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID);
if (result != nullptr)
return MakeOptional<SessionHandle>(*result);

CHIP_ERROR err = AllocEntry(address, config, result);
CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, config, result);
if (err == CHIP_NO_ERROR)
{
return MakeOptional<SessionHandle>(*result);
}
else
{
return Optional<SessionHandle>::Missing();
}
}

CHECK_RETURN_VALUE Optional<SessionHandle> FindInitiator(NodeId ephemeralInitiatorNodeID)
{
UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID);
if (result != nullptr)
{
return MakeOptional<SessionHandle>(*result);
}
else
{
return Optional<SessionHandle>::Missing();
}
}
tcarmelveilleux marked this conversation as resolved.
Show resolved Hide resolved

CHECK_RETURN_VALUE Optional<SessionHandle> AllocInitiator(NodeId ephemeralInitiatorNodeID, const PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & config)
{
UnauthenticatedSession * result = nullptr;
CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, config, result);
if (err == CHIP_NO_ERROR)
{
result->SetPeerAddress(peerAddress);
return MakeOptional<SessionHandle>(*result);
}
else
Expand All @@ -148,10 +201,10 @@ class UnauthenticatedSessionTable
* CHIP_ERROR_NO_MEMORY).
*/
CHECK_RETURN_VALUE
CHIP_ERROR AllocEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config,
UnauthenticatedSession *& entry)
CHIP_ERROR AllocEntry(UnauthenticatedSession::SessionRole sessionRole, NodeId ephemeralInitiatorNodeID,
const ReliableMessageProtocolConfig & config, UnauthenticatedSession *& entry)
{
entry = mEntries.CreateObject(address, config);
entry = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config);
if (entry != nullptr)
return CHIP_NO_ERROR;

Expand All @@ -161,21 +214,16 @@ class UnauthenticatedSessionTable
return CHIP_ERROR_NO_MEMORY;
}

mEntries.ResetObject(entry, address, config);
mEntries.ResetObject(entry, sessionRole, ephemeralInitiatorNodeID, config);
return CHIP_NO_ERROR;
}

/**
* Get a session using given address
*
* @return the peer found, nullptr if not found
*/
CHECK_RETURN_VALUE
UnauthenticatedSession * FindEntry(const PeerAddress & address)
CHECK_RETURN_VALUE UnauthenticatedSession * FindEntry(UnauthenticatedSession::SessionRole sessionRole,
NodeId ephemeralInitiatorNodeID)
{
UnauthenticatedSession * result = nullptr;
mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) {
if (MatchPeerAddress(entry->GetPeerAddress(), address))
if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID)
{
result = entry;
return Loop::Break;
Expand All @@ -202,45 +250,6 @@ class UnauthenticatedSessionTable
return result;
}

// A temporary solution for #11120
// Enforce interface match if not null
static bool MatchInterface(Inet::InterfaceId i1, Inet::InterfaceId i2)
{
if (i1.IsPresent() && i2.IsPresent())
{
return i1 == i2;
}
else
{
// One of the interfaces is null.
return true;
}
}

static bool MatchPeerAddress(const PeerAddress & a1, const PeerAddress & a2)
{
if (a1.GetTransportType() != a2.GetTransportType())
return false;

switch (a1.GetTransportType())
{
case Transport::Type::kUndefined:
return false;
case Transport::Type::kUdp:
case Transport::Type::kTcp:
return a1.GetIPAddress() == a2.GetIPAddress() && a1.GetPort() == a2.GetPort() &&
// Enforce interface equal-ness if the address is link-local, otherwise ignore interface
// Use MatchInterface for a temporary solution for #11120
(a1.GetIPAddress().IsIPv6LinkLocal() ? a1.GetInterface() == a2.GetInterface()
: MatchInterface(a1.GetInterface(), a2.GetInterface()));
case Transport::Type::kBle:
// TODO: complete BLE address comparation
return true;
}

return false;
}

ObjectPool<UnauthenticatedSession, kMaxSessionCount> mEntries;
};

Expand Down