diff --git a/src/messaging/ReliableMessageManager.cpp b/src/messaging/ReliableMessageManager.cpp index 519a8efa364992..13945bb5eb2f21 100644 --- a/src/messaging/ReliableMessageManager.cpp +++ b/src/messaging/ReliableMessageManager.cpp @@ -409,7 +409,7 @@ CHIP_ERROR ReliableMessageManager::SendFromRetransTable(RetransTableEntry * entr if (rc) { - err = mSessionMgr->SendMessage(std::move(entry->retainedBuf), &entry->retainedBuf); + err = mSessionMgr->SendEncryptedMessage(std::move(entry->retainedBuf), &entry->retainedBuf); if (err == CHIP_NO_ERROR) { diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 3c47db4c262803..9773d9c2980b89 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -63,9 +64,8 @@ SecureSessionMgr::~SecureSessionMgr() CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLayer, TransportMgrBase * transportMgr) { - CHIP_ERROR err = CHIP_NO_ERROR; - VerifyOrExit(mState == State::kNotReady, err = CHIP_ERROR_INCORRECT_STATE); - VerifyOrExit(transportMgr != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(mState == State::kNotReady, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(transportMgr != nullptr, CHIP_ERROR_INVALID_ARGUMENT); mState = State::kInitialized; mLocalNodeId = localNodeId; @@ -78,48 +78,121 @@ CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLaye mTransportMgr->SetSecureSessionMgr(this); -exit: - return err; + return CHIP_NO_ERROR; } CHIP_ERROR SecureSessionMgr::SendMessage(NodeId peerNodeId, System::PacketBufferHandle msgBuf) { - PayloadHeader payloadHeader; - - return SendMessage(payloadHeader, peerNodeId, std::move(msgBuf)); + PayloadHeader unusedPayloadHeader; + return SendMessage(unusedPayloadHeader, peerNodeId, std::move(msgBuf)); } CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot) { - return SendMessage(payloadHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot, false); + PacketHeader ununsedPacketHeader; + return SendMessage(payloadHeader, ununsedPacketHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot, + EncryptionState::kPayloadIsUnencrypted); } -CHIP_ERROR SecureSessionMgr::SendMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot) +CHIP_ERROR SecureSessionMgr::SendEncryptedMessage(EncryptedPacketBufferHandle msgBuf, + EncryptedPacketBufferHandle * bufferRetainSlot) { + VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); + + uint16_t headerSize = 0; + PacketHeader packetHeader; + ReturnErrorOnFailure(packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize)); + + VerifyOrReturnError(packetHeader.GetDestinationNodeId().HasValue(), CHIP_ERROR_INVALID_DESTINATION_NODE_ID); + NodeId peerNodeId = packetHeader.GetDestinationNodeId().Value(); + + // Advancing the start to encrypted header, since the transport will attach the packet header on top of it + msgBuf->SetStart(msgBuf->Start() + headerSize); + PayloadHeader payloadHeader; + ReturnErrorOnFailure(SendMessage(payloadHeader, packetHeader, peerNodeId, std::move(msgBuf), bufferRetainSlot, + EncryptionState::kPayloadIsEncrypted)); - return SendMessage(payloadHeader, 0, std::move(msgBuf), bufferRetainSlot, true); + return CHIP_NO_ERROR; } -CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf, - EncryptedPacketBufferHandle * bufferRetainSlot, bool isEncrypted) +CHIP_ERROR SecureSessionMgr::EncryptPayload(Transport::PeerConnectionState * state, PayloadHeader & payloadHeader, + PacketHeader & packetHeader, NodeId peerNodeId, System::PacketBufferHandle & msgBuf) { - CHIP_ERROR err = CHIP_NO_ERROR; - PeerConnectionState * state = nullptr; - PacketHeader packetHeader; - uint16_t headerSize = 0; + CHIP_ERROR err = CHIP_NO_ERROR; + uint8_t * data = nullptr; + uint32_t payloadLength = 0; // Make sure it's big enough to add two 16-bit ints without overflowing. + uint16_t totalLen = 0; + uint16_t taglen = 0; + uint16_t actualEncodedHeaderSize; + MessageAuthenticationCode mac; + + uint32_t msgId = state->GetSendMessageIndex(); + + static_assert(std::is_sameTotalLength()), uint16_t>::value, + "Addition to generate payloadLength might overflow"); + + uint16_t headerSize = payloadHeader.EncodeSizeBytes(); + payloadLength = static_cast(headerSize + msgBuf->TotalLength()); + VerifyOrExit(CanCastTo(payloadLength), err = CHIP_ERROR_NO_MEMORY); + + packetHeader + .SetSourceNodeId(mLocalNodeId) // + .SetDestinationNodeId(peerNodeId) // + .SetMessageId(msgId) // + .SetEncryptionKeyID(state->GetLocalKeyID()) // + .SetPayloadLength(static_cast(payloadLength)); + packetHeader.GetFlags().Set(Header::FlagValues::kSecure); + + VerifyOrExit(msgBuf->EnsureReservedSize(headerSize), err = CHIP_ERROR_NO_MEMORY); + + msgBuf->SetStart(msgBuf->Start() - headerSize); + data = msgBuf->Start(); + totalLen = msgBuf->TotalLength(); + + err = payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize); + SuccessOrExit(err); - if (isEncrypted) - { - err = packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize); - SuccessOrExit(err); + err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, mac); + SuccessOrExit(err); + err = mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen); + SuccessOrExit(err); + + VerifyOrExit(CanCastTo(totalLen + taglen), err = CHIP_ERROR_INTERNAL); + msgBuf->SetDataLength(static_cast(totalLen + taglen)); - VerifyOrExit(packetHeader.GetDestinationNodeId().HasValue(), err = CHIP_ERROR_INVALID_DESTINATION_NODE_ID); - peerNodeId = packetHeader.GetDestinationNodeId().Value(); + ChipLogDetail(Inet, "Secure transport encrypted msg %u", msgId); + +exit: + if (err != CHIP_NO_ERROR) + { + const char * errStr = ErrorStr(err); + ChipLogError(Inet, "Secure transport failed to encrypt msg %u: %s", state->GetSendMessageIndex(), errStr); + } + else + { + state->IncrementSendMessageIndex(); } - state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr); + return err; +} + +CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, PacketHeader & packetHeader, NodeId peerNodeId, + System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot, + EncryptionState encryptionState) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + PeerConnectionState * state = nullptr; + uint8_t * msgStart = nullptr; + uint16_t msgLen = 0; + uint16_t headerSize = 0; + + // Hold the reference to encrypted message in stack variable. + // In case of any failures, the reference is not returned, and this stack variable + // will automatically free the reference on returning from the function. + EncryptedPacketBufferHandle encryptedMsg; + VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE); VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT); @@ -127,101 +200,43 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p VerifyOrExit(msgBuf->TotalLength() < kMax_SecureSDU_Length, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); // Find an active connection to the specified peer node + state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr); VerifyOrExit(state != nullptr, err = CHIP_ERROR_INVALID_DESTINATION_NODE_ID); // This marks any connection where we send data to as 'active' mPeerConnections.MarkConnectionActive(state); + if (encryptionState == EncryptionState::kPayloadIsUnencrypted) { - uint8_t * data = nullptr; - uint8_t * p = nullptr; - uint32_t msgId = 0; - uint32_t payloadLength = 0; // Make sure it's big enough to add two 16-bit ints without overflowing. - uint16_t len = 0; - MessageAuthenticationCode mac; - - if (!isEncrypted) - { - msgId = state->GetSendMessageIndex(); - - static_assert(std::is_sameTotalLength()), uint16_t>::value, - "Addition to generate payloadLength might overflow"); - - headerSize = payloadHeader.EncodeSizeBytes(); - payloadLength = static_cast(headerSize + msgBuf->TotalLength()); - VerifyOrExit(CanCastTo(payloadLength), err = CHIP_ERROR_NO_MEMORY); - - packetHeader - .SetSourceNodeId(mLocalNodeId) // - .SetDestinationNodeId(peerNodeId) // - .SetMessageId(msgId) // - .SetEncryptionKeyID(state->GetLocalKeyID()) // - .SetPayloadLength(static_cast(payloadLength)); - packetHeader.GetFlags().Set(Header::FlagValues::kSecure); - } - else - { - // Advancing the start to encrypted header, since the transport will attach the packet header on top of it - msgBuf->SetStart(msgBuf->Start() + headerSize); - } - - ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, peerNodeId); - - // Encrypt the packet if it's not already encrypted - if (!isEncrypted) - { - uint16_t totalLen = 0; - uint16_t taglen = 0; - uint16_t actualEncodedHeaderSize; - - VerifyOrExit(msgBuf->EnsureReservedSize(headerSize), err = CHIP_ERROR_NO_MEMORY); - - msgBuf->SetStart(msgBuf->Start() - headerSize); - data = msgBuf->Start(); - totalLen = msgBuf->TotalLength(); - - err = payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize); - SuccessOrExit(err); + err = EncryptPayload(state, payloadHeader, packetHeader, peerNodeId, msgBuf); + SuccessOrExit(err); + } - err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, mac); - SuccessOrExit(err); - err = mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen); - SuccessOrExit(err); + // The start of buffer points to the beginning of the encrypted header, and the length of buffer + // contains both the encrypted header and encrypted data. + // Locally store the start and length of the retained buffer after accounting for the size of packet header. + headerSize = packetHeader.EncodeSizeBytes(); - VerifyOrExit(CanCastTo(totalLen + taglen), err = CHIP_ERROR_INTERNAL); - msgBuf->SetDataLength(static_cast(totalLen + taglen)); + msgStart = static_cast(msgBuf->Start() - headerSize); + msgLen = static_cast(msgBuf->DataLength() + headerSize); - ChipLogDetail(Inet, "Secure transport transmitting msg %u after encryption", msgId); - } - - if (bufferRetainSlot) - { - // The start of buffer points to the beginning of the encrypted header, and the length of buffer - // contains both the encrypted header and encrypted data. - // Locally store the start and length of the retained buffer after accounting for the size of packet header. - headerSize = packetHeader.EncodeSizeBytes(); + // Retain the PacketBuffer in case it's needed for retransmissions. + encryptedMsg = msgBuf.Retain(); + encryptedMsg.mMsgId = packetHeader.GetMessageId(); - p = static_cast(msgBuf->Start() - headerSize); - len = static_cast(msgBuf->DataLength() + headerSize); + ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, peerNodeId); - // Retain the PacketBuffer for following retransmit. - *bufferRetainSlot = msgBuf.Retain(); - bufferRetainSlot->mMsgId = msgId; - } + err = mTransportMgr->SendMessage(packetHeader, state->GetPeerAddress(), std::move(msgBuf)); + SuccessOrExit(err); - err = mTransportMgr->SendMessage(packetHeader, state->GetPeerAddress(), std::move(msgBuf)); + if (bufferRetainSlot != nullptr) + { + // Rewind the start and len of the buffer back to pre-send state for following possible retransmition. + encryptedMsg->SetStart(msgStart); + encryptedMsg->SetDataLength(msgLen); - if (bufferRetainSlot) - { - // Rewind the start and len of the buffer back to pre-send state for following possible retransmition. - (*bufferRetainSlot)->SetStart(p); - (*bufferRetainSlot)->SetDataLength(len); - } + (*bufferRetainSlot) = std::move(encryptedMsg); } - SuccessOrExit(err); - - if (!isEncrypted) - state->IncrementSendMessageIndex(); exit: if (!msgBuf.IsNull()) @@ -231,10 +246,6 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p { ChipLogError(Inet, "Secure transport could not find a valid PeerConnection: %s", errStr); } - else - { - ChipLogError(Inet, "Secure transport failed to encrypt msg %u: %s", state->GetSendMessageIndex(), errStr); - } } return err; @@ -259,8 +270,8 @@ CHIP_ERROR SecureSessionMgr::NewPairing(const Optional & ChipLogDetail(Inet, "New pairing for device %llu, key %d!!", peerNodeId, peerKeyId); state = nullptr; - err = mPeerConnections.CreateNewPeerConnectionState(Optional::Value(peerNodeId), peerKeyId, localKeyId, &state); - SuccessOrExit(err); + ReturnErrorOnFailure( + mPeerConnections.CreateNewPeerConnectionState(Optional::Value(peerNodeId), peerKeyId, localKeyId, &state)); if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any) { @@ -279,7 +290,6 @@ CHIP_ERROR SecureSessionMgr::NewPairing(const Optional & strlen(kSpake2pI2RSessionInfo), state->GetSecureSession()); } -exit: return err; } diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index bb2e819d89e951..a2cbfa124ac938 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -157,7 +157,7 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate CHIP_ERROR SendMessage(NodeId peerNodeId, System::PacketBufferHandle msgBuf); CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot = nullptr); - CHIP_ERROR SendMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot); + CHIP_ERROR SendEncryptedMessage(EncryptedPacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot); /** * @brief @@ -217,6 +217,12 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate kInitialized, /**< State when the object is ready connect to other peers. */ }; + enum class EncryptionState + { + kPayloadIsEncrypted, + kPayloadIsUnencrypted, + }; + System::Layer * mSystemLayer = nullptr; NodeId mLocalNodeId; // < Id of the current node Transport::PeerConnections mPeerConnections; // < Active connections to other peers @@ -225,8 +231,12 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate SecureSessionMgrDelegate * mCB = nullptr; TransportMgrBase * mTransportMgr = nullptr; - CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, NodeId peerNodeId, System::PacketBufferHandle msgBuf, - EncryptedPacketBufferHandle * bufferRetainSlot, bool isEncrypted); + CHIP_ERROR EncryptPayload(Transport::PeerConnectionState * state, PayloadHeader & payloadHeader, PacketHeader & packetHeader, + NodeId peerNodeId, System::PacketBufferHandle & msgBuf); + + CHIP_ERROR SendMessage(PayloadHeader & payloadHeader, PacketHeader & packetHeader, NodeId peerNodeId, + System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot, + EncryptionState encryptionState); /** Schedules a new oneshot timer for checking connection expiry. */ void ScheduleExpiryTimer(); diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index 971360896e352d..8febdc4f1bf00f 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -242,7 +242,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendMessage(payloadHeader, kDestinationNodeId, std::move(buffer), &msgBuf); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - err = secureSessionMgr.SendMessage(std::move(msgBuf), nullptr); + err = secureSessionMgr.SendEncryptedMessage(std::move(msgBuf), nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); }