Skip to content

Commit

Permalink
Add ephemeral node IDs for unsecured sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed Jan 27, 2022
1 parent ea5612a commit c64e34a
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 65 deletions.
41 changes: 37 additions & 4 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,16 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P
uint32_t messageCounter = counter.Value();
ReturnErrorOnFailure(counter.Advance());
packetHeader.SetMessageCounter(messageCounter);
Transport::UnauthenticatedSession * session = sessionHandle->AsUnauthenticatedSession();
switch (session->GetSessionRole())
{
case Transport::UnauthenticatedSession::SessionRole::kInitiator:
packetHeader.SetSourceNodeId(session->GetEphemeralInitiatorNodeID());
break;
case Transport::UnauthenticatedSession::SessionRole::kResponder:
packetHeader.SetDestinationNodeId(session->GetEphemeralInitiatorNodeID());
break;
}

// Trace after all headers are settled.
CHIP_TRACE_MESSAGE_SENT(payloadHeader, packetHeader, message->Start(), message->TotalLength());
Expand Down Expand Up @@ -430,11 +440,33 @@ void SessionManager::RefreshSessionOperationalData(const SessionHandle & session
void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg)
{
Optional<SessionHandle> optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, gDefaultMRPConfig);
if (!optionalSession.HasValue())
Optional<NodeId> source = packetHeader.GetSourceNodeId();
Optional<NodeId> destination = packetHeader.GetDestinationNodeId();
if ((source.HasValue() && destination.HasValue()) || (!source.HasValue() && !destination.HasValue()))
{
ChipLogError(Inet, "UnauthenticatedSession exhausted");
return;
return; // ephemeral node id is only assigned to the initiator, there should be one and only one node id exists.
}

Optional<SessionHandle> optionalSession;
if (source.HasValue())
{
// Assume peer is the initiator, we are the responder.
optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), gDefaultMRPConfig);
if (!optionalSession.HasValue())
{
ChipLogError(Inet, "UnauthenticatedSession exhausted");
return;
}
}
else
{
// Assume peer is the responder, we are the initiator.
optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value());
if (!optionalSession.HasValue())
{
ChipLogProgress(Inet, "Received unknown unsecure packet for initiator 0x%" PRId64, destination.Value());
return;
}
}

const SessionHandle & session = optionalSession.Value();
Expand Down Expand Up @@ -464,6 +496,7 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr
}

unsecuredSession->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter());
unsecuredSession->SetPeerAddress(peerAddress);

if (mCB != nullptr)
{
Expand Down
4 changes: 3 additions & 1 deletion src/transport/SessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <utility>

#include <crypto/RandUtils.h>
#include <inet/IPAddress.h>
#include <lib/core/CHIPCore.h>
#include <lib/support/CodeUtils.h>
Expand Down Expand Up @@ -206,7 +207,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
Optional<SessionHandle> CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & config)
{
return mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, config);
NodeId ephemeralInitiatorNodeID = static_cast<NodeId>(Crypto::GetRandU64());
return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config);
}

// TODO: implements group sessions
Expand Down
129 changes: 69 additions & 60 deletions src/transport/UnauthenticatedSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,15 @@ class UnauthenticatedSessionDeleter
class UnauthenticatedSession : public Session, public ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>
{
public:
UnauthenticatedSession(const PeerAddress & address, const ReliableMessageProtocolConfig & config) :
mPeerAddress(address), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
enum class SessionRole
{
kInitiator,
kResponder,
};

UnauthenticatedSession(SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) :
mEphemeralInitiatorNodeId(ephemeralInitiatorNodeID), mSessionRole(sessionRole),
mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}
~UnauthenticatedSession() { NotifySessionReleased(); }

Expand Down Expand Up @@ -88,8 +95,22 @@ class UnauthenticatedSession : public Session, public ReferenceCounted<Unauthent
return System::Clock::Timeout();
}

NodeId GetPeerNodeId() const { return kUndefinedNodeId; }
NodeId GetPeerNodeId() const
{
if (mSessionRole == SessionRole::kInitiator)
{
return kUndefinedNodeId;
}
else
{
return mEphemeralInitiatorNodeId;
}
}

SessionRole GetSessionRole() const { return mSessionRole; }
NodeId GetEphemeralInitiatorNodeID() const { return mEphemeralInitiatorNodeId; }
const PeerAddress & GetPeerAddress() const { return mPeerAddress; }
void SetPeerAddress(const PeerAddress & peerAddress) { mPeerAddress = peerAddress; }

void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; }

Expand All @@ -98,7 +119,9 @@ class UnauthenticatedSession : public Session, public ReferenceCounted<Unauthent
PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; }

private:
const PeerAddress mPeerAddress;
const NodeId mEphemeralInitiatorNodeId;
const SessionRole mSessionRole;
PeerAddress mPeerAddress;
System::Clock::Timestamp mLastActivityTime;
ReliableMessageProtocolConfig mMRPConfig;
PeerMessageCounter mPeerMessageCounter;
Expand All @@ -118,20 +141,50 @@ class UnauthenticatedSessionTable
~UnauthenticatedSessionTable() { mEntries.ReleaseAll(); }

/**
* Get a session given the peer address. If the session doesn't exist in the cache, allocate a new entry for it.
* Get a responder session with the given ephemeralInitiatorNodeID. If the session doesn't exist in the cache, allocate a new
* entry for it.
*
* @return the session found or allocated, nullptr if not found and allocation failed.
* @return the session found or allocated, or Optional::Missing if not found and allocation failed.
*/
CHECK_RETURN_VALUE
Optional<SessionHandle> FindOrAllocateEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config)
Optional<SessionHandle> FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config)
{
UnauthenticatedSession * result = FindEntry(address);
UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID);
if (result != nullptr)
return MakeOptional<SessionHandle>(*result);

CHIP_ERROR err = AllocEntry(address, config, result);
CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, config, result);
if (err == CHIP_NO_ERROR)
{
return MakeOptional<SessionHandle>(*result);
}
else
{
return Optional<SessionHandle>::Missing();
}
}

CHECK_RETURN_VALUE Optional<SessionHandle> FindInitiator(NodeId ephemeralInitiatorNodeID)
{
UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID);
if (result != nullptr)
{
return MakeOptional<SessionHandle>(*result);
}
else
{
return Optional<SessionHandle>::Missing();
}
}

CHECK_RETURN_VALUE Optional<SessionHandle> AllocInitiator(NodeId ephemeralInitiatorNodeID, const PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & config)
{
UnauthenticatedSession * result = nullptr;
CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, config, result);
if (err == CHIP_NO_ERROR)
{
result->SetPeerAddress(peerAddress);
return MakeOptional<SessionHandle>(*result);
}
else
Expand All @@ -148,10 +201,10 @@ class UnauthenticatedSessionTable
* CHIP_ERROR_NO_MEMORY).
*/
CHECK_RETURN_VALUE
CHIP_ERROR AllocEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config,
UnauthenticatedSession *& entry)
CHIP_ERROR AllocEntry(UnauthenticatedSession::SessionRole sessionRole, NodeId ephemeralInitiatorNodeID,
const ReliableMessageProtocolConfig & config, UnauthenticatedSession *& entry)
{
entry = mEntries.CreateObject(address, config);
entry = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config);
if (entry != nullptr)
return CHIP_NO_ERROR;

Expand All @@ -161,21 +214,16 @@ class UnauthenticatedSessionTable
return CHIP_ERROR_NO_MEMORY;
}

mEntries.ResetObject(entry, address, config);
mEntries.ResetObject(entry, sessionRole, ephemeralInitiatorNodeID, config);
return CHIP_NO_ERROR;
}

/**
* Get a session using given address
*
* @return the peer found, nullptr if not found
*/
CHECK_RETURN_VALUE
UnauthenticatedSession * FindEntry(const PeerAddress & address)
CHECK_RETURN_VALUE UnauthenticatedSession * FindEntry(UnauthenticatedSession::SessionRole sessionRole,
NodeId ephemeralInitiatorNodeID)
{
UnauthenticatedSession * result = nullptr;
mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) {
if (MatchPeerAddress(entry->GetPeerAddress(), address))
if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID)
{
result = entry;
return Loop::Break;
Expand All @@ -202,45 +250,6 @@ class UnauthenticatedSessionTable
return result;
}

// A temporary solution for #11120
// Enforce interface match if not null
static bool MatchInterface(Inet::InterfaceId i1, Inet::InterfaceId i2)
{
if (i1.IsPresent() && i2.IsPresent())
{
return i1 == i2;
}
else
{
// One of the interfaces is null.
return true;
}
}

static bool MatchPeerAddress(const PeerAddress & a1, const PeerAddress & a2)
{
if (a1.GetTransportType() != a2.GetTransportType())
return false;

switch (a1.GetTransportType())
{
case Transport::Type::kUndefined:
return false;
case Transport::Type::kUdp:
case Transport::Type::kTcp:
return a1.GetIPAddress() == a2.GetIPAddress() && a1.GetPort() == a2.GetPort() &&
// Enforce interface equal-ness if the address is link-local, otherwise ignore interface
// Use MatchInterface for a temporary solution for #11120
(a1.GetIPAddress().IsIPv6LinkLocal() ? a1.GetInterface() == a2.GetInterface()
: MatchInterface(a1.GetInterface(), a2.GetInterface()));
case Transport::Type::kBle:
// TODO: complete BLE address comparation
return true;
}

return false;
}

BitMapObjectPool<UnauthenticatedSession, kMaxSessionCount> mEntries;
};

Expand Down

0 comments on commit c64e34a

Please sign in to comment.