Skip to content

Commit

Permalink
Some cleanup in secure session manager code (#4296)
Browse files Browse the repository at this point in the history
* Some cleanup in secure session manager code

* Address review comments

* address review comments

* review comments

* update enum variants
  • Loading branch information
pan-apple authored Jan 11, 2021
1 parent 8b42113 commit 832a4e1
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 117 deletions.
2 changes: 1 addition & 1 deletion src/messaging/ReliableMessageManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ CHIP_ERROR ReliableMessageManager::SendFromRetransTable(RetransTableEntry * entr

if (rc)
{
err = mSessionMgr->SendMessage(std::move(entry->retainedBuf), &entry->retainedBuf);
err = mSessionMgr->SendEncryptedMessage(std::move(entry->retainedBuf), &entry->retainedBuf);

if (err == CHIP_NO_ERROR)
{
Expand Down
234 changes: 122 additions & 112 deletions src/transport/SecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <platform/CHIPDeviceLayer.h>
#include <support/CodeUtils.h>
#include <support/ReturnMacros.h>
#include <support/SafeInt.h>
#include <support/logging/CHIPLogging.h>
#include <transport/RendezvousSession.h>
Expand Down Expand Up @@ -63,9 +64,8 @@ SecureSessionMgr::~SecureSessionMgr()

CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLayer, TransportMgrBase * transportMgr)
{
CHIP_ERROR err = CHIP_NO_ERROR;
VerifyOrExit(mState == State::kNotReady, err = CHIP_ERROR_INCORRECT_STATE);
VerifyOrExit(transportMgr != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(mState == State::kNotReady, CHIP_ERROR_INCORRECT_STATE);
VerifyOrReturnError(transportMgr != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

mState = State::kInitialized;
mLocalNodeId = localNodeId;
Expand All @@ -78,8 +78,7 @@ CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLaye

mTransportMgr->SetSecureSessionMgr(this);

exit:
return err;
return CHIP_NO_ERROR;
}

Transport::Type SecureSessionMgr::GetTransportType(NodeId peerNodeId)
Expand All @@ -96,144 +95,160 @@ Transport::Type SecureSessionMgr::GetTransportType(NodeId peerNodeId)

CHIP_ERROR SecureSessionMgr::SendMessage(NodeId peerNodeId, System::PacketBufferHandle msgBuf)
{
PayloadHeader payloadHeader;

return SendMessage(payloadHeader, peerNodeId, std::move(msgBuf));
PayloadHeader unusedPayloadHeader;
return SendMessage(unusedPayloadHeader, peerNodeId, std::move(msgBuf));
}

CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot)
{
return SendMessage(payloadHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot, false);
PacketHeader ununsedPacketHeader;
return SendMessage(payloadHeader, ununsedPacketHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot,
EncryptionState::kPayloadIsUnencrypted);
}

CHIP_ERROR SecureSessionMgr::SendMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot)
CHIP_ERROR SecureSessionMgr::SendEncryptedMessage(EncryptedPacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot)
{
VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);

uint16_t headerSize = 0;
PacketHeader packetHeader;
ReturnErrorOnFailure(packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize));

VerifyOrReturnError(packetHeader.GetDestinationNodeId().HasValue(), CHIP_ERROR_INVALID_DESTINATION_NODE_ID);
NodeId peerNodeId = packetHeader.GetDestinationNodeId().Value();

// Advancing the start to encrypted header, since the transport will attach the packet header on top of it
msgBuf->SetStart(msgBuf->Start() + headerSize);

PayloadHeader payloadHeader;
ReturnErrorOnFailure(SendMessage(payloadHeader, packetHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot,
EncryptionState::kPayloadIsEncrypted));

return SendMessage(payloadHeader, 0, std::move(msgBuf), bufferRetainSlot, true);
return CHIP_NO_ERROR;
}

CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot, bool isEncrypted)
CHIP_ERROR SecureSessionMgr::EncryptPayload(Transport::PeerConnectionState * state, PayloadHeader & payloadHeader,
PacketHeader & packetHeader, NodeId peerNodeId, System::PacketBufferHandle & msgBuf)
{
CHIP_ERROR err = CHIP_NO_ERROR;
PeerConnectionState * state = nullptr;
PacketHeader packetHeader;
uint16_t headerSize = 0;
CHIP_ERROR err = CHIP_NO_ERROR;
uint8_t * data = nullptr;
uint32_t payloadLength = 0; // Make sure it's big enough to add two 16-bit ints without overflowing.
uint16_t totalLen = 0;
uint16_t taglen = 0;
uint16_t actualEncodedHeaderSize;
MessageAuthenticationCode mac;

uint32_t msgId = state->GetSendMessageIndex();

static_assert(std::is_same<decltype(msgBuf->TotalLength()), uint16_t>::value,
"Addition to generate payloadLength might overflow");

uint16_t headerSize = payloadHeader.EncodeSizeBytes();
payloadLength = static_cast<uint32_t>(headerSize + msgBuf->TotalLength());
VerifyOrExit(CanCastTo<uint16_t>(payloadLength), err = CHIP_ERROR_NO_MEMORY);

packetHeader
.SetSourceNodeId(mLocalNodeId) //
.SetDestinationNodeId(peerNodeId) //
.SetMessageId(msgId) //
.SetEncryptionKeyID(state->GetLocalKeyID()) //
.SetPayloadLength(static_cast<uint16_t>(payloadLength));
packetHeader.GetFlags().Set(Header::FlagValues::kSecure);

VerifyOrExit(msgBuf->EnsureReservedSize(headerSize), err = CHIP_ERROR_NO_MEMORY);

msgBuf->SetStart(msgBuf->Start() - headerSize);
data = msgBuf->Start();
totalLen = msgBuf->TotalLength();

err = payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize);
SuccessOrExit(err);

if (isEncrypted)
{
err = packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize);
SuccessOrExit(err);
err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, mac);
SuccessOrExit(err);
err = mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen);
SuccessOrExit(err);

VerifyOrExit(CanCastTo<uint16_t>(totalLen + taglen), err = CHIP_ERROR_INTERNAL);
msgBuf->SetDataLength(static_cast<uint16_t>(totalLen + taglen));

VerifyOrExit(packetHeader.GetDestinationNodeId().HasValue(), err = CHIP_ERROR_INVALID_DESTINATION_NODE_ID);
peerNodeId = packetHeader.GetDestinationNodeId().Value();
ChipLogDetail(Inet, "Secure transport encrypted msg %u", msgId);

exit:
if (err != CHIP_NO_ERROR)
{
const char * errStr = ErrorStr(err);
ChipLogError(Inet, "Secure transport failed to encrypt msg %u: %s", state->GetSendMessageIndex(), errStr);
}
else
{
state->IncrementSendMessageIndex();
}

state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr);
return err;
}

CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, PacketHeader & packetHeader, NodeId peerNodeId,
System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot,
EncryptionState encryptionState)
{
CHIP_ERROR err = CHIP_NO_ERROR;
PeerConnectionState * state = nullptr;
uint8_t * msgStart = nullptr;
uint16_t msgLen = 0;
uint16_t headerSize = 0;

// Hold the reference to encrypted message in stack variable.
// In case of any failures, the reference is not returned, and this stack variable
// will automatically free the reference on returning from the function.
EncryptedPacketBufferHandle encryptedMsg;

VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE);

VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrExit(!msgBuf->HasChainedBuffer(), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH);
VerifyOrExit(msgBuf->TotalLength() < kMax_SecureSDU_Length, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH);

// Find an active connection to the specified peer node
state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr);
VerifyOrExit(state != nullptr, err = CHIP_ERROR_INVALID_DESTINATION_NODE_ID);

// This marks any connection where we send data to as 'active'
mPeerConnections.MarkConnectionActive(state);

if (encryptionState == EncryptionState::kPayloadIsUnencrypted)
{
uint8_t * data = nullptr;
uint8_t * p = nullptr;
uint32_t msgId = 0;
uint32_t payloadLength = 0; // Make sure it's big enough to add two 16-bit ints without overflowing.
uint16_t len = 0;
MessageAuthenticationCode mac;

if (!isEncrypted)
{
msgId = state->GetSendMessageIndex();

static_assert(std::is_same<decltype(msgBuf->TotalLength()), uint16_t>::value,
"Addition to generate payloadLength might overflow");

headerSize = payloadHeader.EncodeSizeBytes();
payloadLength = static_cast<uint32_t>(headerSize + msgBuf->TotalLength());
VerifyOrExit(CanCastTo<uint16_t>(payloadLength), err = CHIP_ERROR_NO_MEMORY);

packetHeader
.SetSourceNodeId(mLocalNodeId) //
.SetDestinationNodeId(peerNodeId) //
.SetMessageId(msgId) //
.SetEncryptionKeyID(state->GetLocalKeyID()) //
.SetPayloadLength(static_cast<uint16_t>(payloadLength));
packetHeader.GetFlags().Set(Header::FlagValues::kSecure);
}
else
{
// Advancing the start to encrypted header, since the transport will attach the packet header on top of it
msgBuf->SetStart(msgBuf->Start() + headerSize);
}

ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, peerNodeId);

// Encrypt the packet if it's not already encrypted
if (!isEncrypted)
{
uint16_t totalLen = 0;
uint16_t taglen = 0;
uint16_t actualEncodedHeaderSize;

VerifyOrExit(msgBuf->EnsureReservedSize(headerSize), err = CHIP_ERROR_NO_MEMORY);

msgBuf->SetStart(msgBuf->Start() - headerSize);
data = msgBuf->Start();
totalLen = msgBuf->TotalLength();

err = payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize);
SuccessOrExit(err);
err = EncryptPayload(state, payloadHeader, packetHeader, peerNodeId, msgBuf);
SuccessOrExit(err);
}

err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, mac);
SuccessOrExit(err);
err = mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen);
SuccessOrExit(err);
// The start of buffer points to the beginning of the encrypted header, and the length of buffer
// contains both the encrypted header and encrypted data.
// Locally store the start and length of the retained buffer after accounting for the size of packet header.
headerSize = packetHeader.EncodeSizeBytes();

VerifyOrExit(CanCastTo<uint16_t>(totalLen + taglen), err = CHIP_ERROR_INTERNAL);
msgBuf->SetDataLength(static_cast<uint16_t>(totalLen + taglen));
msgStart = static_cast<uint8_t *>(msgBuf->Start() - headerSize);
msgLen = static_cast<uint16_t>(msgBuf->DataLength() + headerSize);

ChipLogDetail(Inet, "Secure transport transmitting msg %u after encryption", msgId);
}

if (bufferRetainSlot)
{
// The start of buffer points to the beginning of the encrypted header, and the length of buffer
// contains both the encrypted header and encrypted data.
// Locally store the start and length of the retained buffer after accounting for the size of packet header.
headerSize = packetHeader.EncodeSizeBytes();
// Retain the PacketBuffer in case it's needed for retransmissions.
encryptedMsg = msgBuf.Retain();
encryptedMsg.mMsgId = packetHeader.GetMessageId();

p = static_cast<uint8_t *>(msgBuf->Start() - headerSize);
len = static_cast<uint16_t>(msgBuf->DataLength() + headerSize);
ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, peerNodeId);

// Retain the PacketBuffer for following retransmit.
*bufferRetainSlot = msgBuf.Retain();
bufferRetainSlot->mMsgId = msgId;
}
err = mTransportMgr->SendMessage(packetHeader, state->GetPeerAddress(), std::move(msgBuf));
SuccessOrExit(err);

err = mTransportMgr->SendMessage(packetHeader, state->GetPeerAddress(), std::move(msgBuf));
if (bufferRetainSlot != nullptr)
{
// Rewind the start and len of the buffer back to pre-send state for following possible retransmition.
encryptedMsg->SetStart(msgStart);
encryptedMsg->SetDataLength(msgLen);

if (bufferRetainSlot)
{
// Rewind the start and len of the buffer back to pre-send state for following possible retransmition.
(*bufferRetainSlot)->SetStart(p);
(*bufferRetainSlot)->SetDataLength(len);
}
(*bufferRetainSlot) = std::move(encryptedMsg);
}
SuccessOrExit(err);

if (!isEncrypted)
state->IncrementSendMessageIndex();

exit:
if (!msgBuf.IsNull())
Expand All @@ -243,10 +258,6 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p
{
ChipLogError(Inet, "Secure transport could not find a valid PeerConnection: %s", errStr);
}
else
{
ChipLogError(Inet, "Secure transport failed to encrypt msg %u: %s", state->GetSendMessageIndex(), errStr);
}
}

return err;
Expand All @@ -271,8 +282,8 @@ CHIP_ERROR SecureSessionMgr::NewPairing(const Optional<Transport::PeerAddress> &
ChipLogDetail(Inet, "New pairing for device %llu, key %d!!", peerNodeId, peerKeyId);

state = nullptr;
err = mPeerConnections.CreateNewPeerConnectionState(Optional<NodeId>::Value(peerNodeId), peerKeyId, localKeyId, &state);
SuccessOrExit(err);
ReturnErrorOnFailure(
mPeerConnections.CreateNewPeerConnectionState(Optional<NodeId>::Value(peerNodeId), peerKeyId, localKeyId, &state));

if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any)
{
Expand All @@ -291,7 +302,6 @@ CHIP_ERROR SecureSessionMgr::NewPairing(const Optional<Transport::PeerAddress> &
strlen(kSpake2pI2RSessionInfo), state->GetSecureSession());
}

exit:
return err;
}

Expand Down
16 changes: 13 additions & 3 deletions src/transport/SecureSessionMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate
CHIP_ERROR SendMessage(NodeId peerNodeId, System::PacketBufferHandle msgBuf);
CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot = nullptr);
CHIP_ERROR SendMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot);
CHIP_ERROR SendEncryptedMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot);

/**
* @brief
Expand Down Expand Up @@ -226,6 +226,12 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate
kInitialized, /**< State when the object is ready connect to other peers. */
};

enum class EncryptionState
{
kPayloadIsEncrypted,
kPayloadIsUnencrypted,
};

System::Layer * mSystemLayer = nullptr;
NodeId mLocalNodeId; // < Id of the current node
Transport::PeerConnections<CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE> mPeerConnections; // < Active connections to other peers
Expand All @@ -234,8 +240,12 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate
SecureSessionMgrDelegate * mCB = nullptr;
TransportMgrBase * mTransportMgr = nullptr;

CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot, bool isEncrypted);
CHIP_ERROR EncryptPayload(Transport::PeerConnectionState * state, PayloadHeader & payloadHeader, PacketHeader & packetHeader,
NodeId peerNodeId, System::PacketBufferHandle & msgBuf);

CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, PacketHeader & packetHeader, NodeId peerNodeId,
System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot,
EncryptionState encryptionState);

/** Schedules a new oneshot timer for checking connection expiry. */
void ScheduleExpiryTimer();
Expand Down
2 changes: 1 addition & 1 deletion src/transport/tests/TestSecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext)
err = secureSessionMgr.SendMessage(payloadHeader, kDestinationNodeId, std::move(buffer), &msgBuf);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

err = secureSessionMgr.SendMessage(std::move(msgBuf), nullptr);
err = secureSessionMgr.SendEncryptedMessage(std::move(msgBuf), nullptr);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
}

Expand Down

0 comments on commit 832a4e1

Please sign in to comment.