Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some cleanup in secure session manager code #4296

Merged
merged 5 commits into from
Jan 11, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/messaging/ReliableMessageManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ CHIP_ERROR ReliableMessageManager::SendFromRetransTable(RetransTableEntry * entr

if (rc)
{
err = mSessionMgr->SendMessage(std::move(entry->retainedBuf), &entry->retainedBuf);
err = mSessionMgr->SendEncryptedMessage(std::move(entry->retainedBuf), &entry->retainedBuf);

if (err == CHIP_NO_ERROR)
{
Expand Down
235 changes: 123 additions & 112 deletions src/transport/SecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <platform/CHIPDeviceLayer.h>
#include <support/CodeUtils.h>
#include <support/ReturnMacros.h>
#include <support/SafeInt.h>
#include <support/logging/CHIPLogging.h>
#include <transport/RendezvousSession.h>
Expand Down Expand Up @@ -63,9 +64,8 @@ SecureSessionMgr::~SecureSessionMgr()

CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLayer, TransportMgrBase * transportMgr)
{
CHIP_ERROR err = CHIP_NO_ERROR;
VerifyOrExit(mState == State::kNotReady, err = CHIP_ERROR_INCORRECT_STATE);
VerifyOrExit(transportMgr != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(mState == State::kNotReady, CHIP_ERROR_INCORRECT_STATE);
VerifyOrReturnError(transportMgr != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

mState = State::kInitialized;
mLocalNodeId = localNodeId;
Expand All @@ -78,150 +78,166 @@ CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLaye

mTransportMgr->SetSecureSessionMgr(this);

exit:
return err;
return CHIP_NO_ERROR;
}

CHIP_ERROR SecureSessionMgr::SendMessage(NodeId peerNodeId, System::PacketBufferHandle msgBuf)
{
PayloadHeader payloadHeader;

return SendMessage(payloadHeader, peerNodeId, std::move(msgBuf));
PayloadHeader unusedPayloadHeader;
return SendMessage(unusedPayloadHeader, peerNodeId, std::move(msgBuf));
}

CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we call these ResendMessage?

I think there should be a difference between stuff taking headers and things that decode packet/payload headers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about SendEncryptedMessage? From the API perspective, it does not know the CRMP is using it for resends.

EncryptedPacketBufferHandle * bufferRetainSlot)
{
return SendMessage(payloadHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot, false);
PacketHeader ununsedPacketHeader;
return SendMessage(payloadHeader, ununsedPacketHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot,
EncryptionState::kUnencrypted);
}

CHIP_ERROR SecureSessionMgr::SendMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot)
CHIP_ERROR SecureSessionMgr::SendEncryptedMessage(EncryptedPacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot)
{
VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);

uint16_t headerSize = 0;
PacketHeader packetHeader;
ReturnErrorOnFailure(packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize));

VerifyOrReturnError(packetHeader.GetDestinationNodeId().HasValue(), CHIP_ERROR_INVALID_DESTINATION_NODE_ID);
NodeId peerNodeId = packetHeader.GetDestinationNodeId().Value();

// Advancing the start to encrypted header, since the transport will attach the packet header on top of it
msgBuf->SetStart(msgBuf->Start() + headerSize);

PayloadHeader payloadHeader;
ReturnErrorOnFailure(
SendMessage(payloadHeader, packetHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot, EncryptionState::kEncrypted));

return SendMessage(payloadHeader, 0, std::move(msgBuf), bufferRetainSlot, true);
return CHIP_NO_ERROR;
}

CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot, bool isEncrypted)
CHIP_ERROR SecureSessionMgr::EncryptPayload(Transport::PeerConnectionState * state, PayloadHeader & payloadHeader,
PacketHeader & packetHeader, NodeId peerNodeId, System::PacketBufferHandle & msgBuf)
{
CHIP_ERROR err = CHIP_NO_ERROR;
PeerConnectionState * state = nullptr;
PacketHeader packetHeader;
uint16_t headerSize = 0;
CHIP_ERROR err = CHIP_NO_ERROR;
uint8_t * data = nullptr;
uint32_t payloadLength = 0; // Make sure it's big enough to add two 16-bit ints without overflowing.
uint16_t totalLen = 0;
uint16_t taglen = 0;
uint16_t actualEncodedHeaderSize;
MessageAuthenticationCode mac;

uint32_t msgId = state->GetSendMessageIndex();

static_assert(std::is_same<decltype(msgBuf->TotalLength()), uint16_t>::value,
"Addition to generate payloadLength might overflow");

uint16_t headerSize = payloadHeader.EncodeSizeBytes();
payloadLength = static_cast<uint32_t>(headerSize + msgBuf->TotalLength());
VerifyOrExit(CanCastTo<uint16_t>(payloadLength), err = CHIP_ERROR_NO_MEMORY);

packetHeader
.SetSourceNodeId(mLocalNodeId) //
.SetDestinationNodeId(peerNodeId) //
.SetMessageId(msgId) //
.SetEncryptionKeyID(state->GetLocalKeyID()) //
.SetPayloadLength(static_cast<uint16_t>(payloadLength));
packetHeader.GetFlags().Set(Header::FlagValues::kSecure);

VerifyOrExit(msgBuf->EnsureReservedSize(headerSize), err = CHIP_ERROR_NO_MEMORY);

msgBuf->SetStart(msgBuf->Start() - headerSize);
data = msgBuf->Start();
totalLen = msgBuf->TotalLength();

err = payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize);
SuccessOrExit(err);

if (isEncrypted)
{
err = packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize);
SuccessOrExit(err);
err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, mac);
SuccessOrExit(err);
err = mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen);
SuccessOrExit(err);

VerifyOrExit(CanCastTo<uint16_t>(totalLen + taglen), err = CHIP_ERROR_INTERNAL);
msgBuf->SetDataLength(static_cast<uint16_t>(totalLen + taglen));

VerifyOrExit(packetHeader.GetDestinationNodeId().HasValue(), err = CHIP_ERROR_INVALID_DESTINATION_NODE_ID);
peerNodeId = packetHeader.GetDestinationNodeId().Value();
ChipLogDetail(Inet, "Secure transport encrypted msg %u", msgId);

exit:
if (err != CHIP_NO_ERROR)
{
const char * errStr = ErrorStr(err);
ChipLogError(Inet, "Secure transport failed to encrypt msg %u: %s", state->GetSendMessageIndex(), errStr);
pan-apple marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
state->IncrementSendMessageIndex();
}

state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr);
return err;
}

CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, PacketHeader & packetHeader, NodeId peerNodeId,
System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot,
EncryptionState encryptionState)
{
CHIP_ERROR err = CHIP_NO_ERROR;
PeerConnectionState * state = nullptr;
uint8_t * msgStart = nullptr;
uint16_t msgLen = 0;
uint16_t headerSize = 0;

// Hold the reference to encrypted message in stack variable.
// In case of any failures, the reference is not returned, and this stack variable
// will automatically free the reference on returning from the function.
EncryptedPacketBufferHandle encryptedMsg;

VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE);

VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrExit(!msgBuf->HasChainedBuffer(), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH);
VerifyOrExit(msgBuf->TotalLength() < kMax_SecureSDU_Length, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH);

// Find an active connection to the specified peer node
state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr);
VerifyOrExit(state != nullptr, err = CHIP_ERROR_INVALID_DESTINATION_NODE_ID);

// This marks any connection where we send data to as 'active'
mPeerConnections.MarkConnectionActive(state);

if (encryptionState == EncryptionState::kUnencrypted)
{
uint8_t * data = nullptr;
uint8_t * p = nullptr;
uint32_t msgId = 0;
uint32_t payloadLength = 0; // Make sure it's big enough to add two 16-bit ints without overflowing.
uint16_t len = 0;
MessageAuthenticationCode mac;

if (!isEncrypted)
{
msgId = state->GetSendMessageIndex();

static_assert(std::is_same<decltype(msgBuf->TotalLength()), uint16_t>::value,
"Addition to generate payloadLength might overflow");

headerSize = payloadHeader.EncodeSizeBytes();
payloadLength = static_cast<uint32_t>(headerSize + msgBuf->TotalLength());
VerifyOrExit(CanCastTo<uint16_t>(payloadLength), err = CHIP_ERROR_NO_MEMORY);

packetHeader
.SetSourceNodeId(mLocalNodeId) //
.SetDestinationNodeId(peerNodeId) //
.SetMessageId(msgId) //
.SetEncryptionKeyID(state->GetLocalKeyID()) //
.SetPayloadLength(static_cast<uint16_t>(payloadLength));
packetHeader.GetFlags().Set(Header::FlagValues::kSecure);
}
else
{
// Advancing the start to encrypted header, since the transport will attach the packet header on top of it
msgBuf->SetStart(msgBuf->Start() + headerSize);
}

ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, peerNodeId);

// Encrypt the packet if it's not already encrypted
if (!isEncrypted)
{
uint16_t totalLen = 0;
uint16_t taglen = 0;
uint16_t actualEncodedHeaderSize;

VerifyOrExit(msgBuf->EnsureReservedSize(headerSize), err = CHIP_ERROR_NO_MEMORY);

msgBuf->SetStart(msgBuf->Start() - headerSize);
data = msgBuf->Start();
totalLen = msgBuf->TotalLength();

err = payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize);
SuccessOrExit(err);
err = EncryptPayload(state, payloadHeader, packetHeader, peerNodeId, msgBuf);
SuccessOrExit(err);
}

err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, mac);
SuccessOrExit(err);
err = mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen);
SuccessOrExit(err);
// The start of buffer points to the beginning of the encrypted header, and the length of buffer
// contains both the encrypted header and encrypted data.
// Locally store the start and length of the retained buffer after accounting for the size of packet header.
headerSize = packetHeader.EncodeSizeBytes();

VerifyOrExit(CanCastTo<uint16_t>(totalLen + taglen), err = CHIP_ERROR_INTERNAL);
msgBuf->SetDataLength(static_cast<uint16_t>(totalLen + taglen));
msgStart = static_cast<uint8_t *>(msgBuf->Start() - headerSize);
msgLen = static_cast<uint16_t>(msgBuf->DataLength() + headerSize);

ChipLogDetail(Inet, "Secure transport transmitting msg %u after encryption", msgId);
}

if (bufferRetainSlot)
{
// The start of buffer points to the beginning of the encrypted header, and the length of buffer
// contains both the encrypted header and encrypted data.
// Locally store the start and length of the retained buffer after accounting for the size of packet header.
headerSize = packetHeader.EncodeSizeBytes();
// Retain the PacketBuffer in case it's needed for retransmissions.
encryptedMsg = msgBuf.Retain();
encryptedMsg.mMsgId = packetHeader.GetMessageId();

p = static_cast<uint8_t *>(msgBuf->Start() - headerSize);
len = static_cast<uint16_t>(msgBuf->DataLength() + headerSize);
ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, peerNodeId);

// Retain the PacketBuffer for following retransmit.
*bufferRetainSlot = msgBuf.Retain();
bufferRetainSlot->mMsgId = msgId;
}
err = mTransportMgr->SendMessage(packetHeader, state->GetPeerAddress(), std::move(msgBuf));
SuccessOrExit(err);

err = mTransportMgr->SendMessage(packetHeader, state->GetPeerAddress(), std::move(msgBuf));
if (bufferRetainSlot != nullptr)
{
// Rewind the start and len of the buffer back to pre-send state for following possible retransmition.
encryptedMsg->SetStart(msgStart);
encryptedMsg->SetDataLength(msgLen);

if (bufferRetainSlot)
{
// Rewind the start and len of the buffer back to pre-send state for following possible retransmition.
(*bufferRetainSlot)->SetStart(p);
(*bufferRetainSlot)->SetDataLength(len);
}
(*bufferRetainSlot) = std::move(encryptedMsg);
bufferRetainSlot->mMsgId = encryptedMsg.mMsgId;
pan-apple marked this conversation as resolved.
Show resolved Hide resolved
}
SuccessOrExit(err);

if (!isEncrypted)
state->IncrementSendMessageIndex();

exit:
if (!msgBuf.IsNull())
Expand All @@ -231,10 +247,6 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p
{
ChipLogError(Inet, "Secure transport could not find a valid PeerConnection: %s", errStr);
}
else
{
ChipLogError(Inet, "Secure transport failed to encrypt msg %u: %s", state->GetSendMessageIndex(), errStr);
}
}

return err;
Expand All @@ -259,8 +271,8 @@ CHIP_ERROR SecureSessionMgr::NewPairing(const Optional<Transport::PeerAddress> &
ChipLogDetail(Inet, "New pairing for device %llu, key %d!!", peerNodeId, peerKeyId);

state = nullptr;
err = mPeerConnections.CreateNewPeerConnectionState(Optional<NodeId>::Value(peerNodeId), peerKeyId, localKeyId, &state);
SuccessOrExit(err);
ReturnErrorOnFailure(
mPeerConnections.CreateNewPeerConnectionState(Optional<NodeId>::Value(peerNodeId), peerKeyId, localKeyId, &state));

if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any)
{
Expand All @@ -279,7 +291,6 @@ CHIP_ERROR SecureSessionMgr::NewPairing(const Optional<Transport::PeerAddress> &
strlen(kSpake2pI2RSessionInfo), state->GetSecureSession());
}

exit:
return err;
}

Expand Down
16 changes: 13 additions & 3 deletions src/transport/SecureSessionMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate
CHIP_ERROR SendMessage(NodeId peerNodeId, System::PacketBufferHandle msgBuf);
CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot = nullptr);
CHIP_ERROR SendMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot);
CHIP_ERROR SendEncryptedMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot);

/**
* @brief
Expand Down Expand Up @@ -217,6 +217,12 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate
kInitialized, /**< State when the object is ready connect to other peers. */
};

enum class EncryptionState
{
kEncrypted,
pan-apple marked this conversation as resolved.
Show resolved Hide resolved
kUnencrypted,
};

System::Layer * mSystemLayer = nullptr;
NodeId mLocalNodeId; // < Id of the current node
Transport::PeerConnections<CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE> mPeerConnections; // < Active connections to other peers
Expand All @@ -225,8 +231,12 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate
SecureSessionMgrDelegate * mCB = nullptr;
TransportMgrBase * mTransportMgr = nullptr;

CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf,
EncryptedPacketBufferHandle * bufferRetainSlot, bool isEncrypted);
CHIP_ERROR EncryptPayload(Transport::PeerConnectionState * state, PayloadHeader & payloadHeader, PacketHeader & packetHeader,
NodeId peerNodeId, System::PacketBufferHandle & msgBuf);

CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, PacketHeader & packetHeader, NodeId peerNodeId,
System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot,
EncryptionState encrypted);

/** Schedules a new oneshot timer for checking connection expiry. */
void ScheduleExpiryTimer();
Expand Down
2 changes: 1 addition & 1 deletion src/transport/tests/TestSecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext)
err = secureSessionMgr.SendMessage(payloadHeader, kDestinationNodeId, std::move(buffer), &msgBuf);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

err = secureSessionMgr.SendMessage(std::move(msgBuf), nullptr);
err = secureSessionMgr.SendEncryptedMessage(std::move(msgBuf), nullptr);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
}

Expand Down