From 23d3a6d9d7918ba49fd7fdd7e6adf1f6e6022e21 Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Thu, 17 Jun 2021 00:03:50 +0800 Subject: [PATCH] Fix CRMP resend null out retained buffer (#7312) * Fix CRMP resend null out retained buffer * Resolve comments * Apply suggestions from code review Co-authored-by: Boris Zbarsky Co-authored-by: Boris Zbarsky --- src/messaging/ApplicationExchangeDispatch.cpp | 14 +-- src/messaging/ApplicationExchangeDispatch.h | 8 +- src/messaging/ExchangeMessageDispatch.cpp | 26 +++--- src/messaging/ExchangeMessageDispatch.h | 25 +++-- src/messaging/ReliableMessageMgr.cpp | 3 +- .../tests/TestReliableMessageProtocol.cpp | 28 ++---- .../SessionEstablishmentExchangeDispatch.cpp | 33 ++----- .../SessionEstablishmentExchangeDispatch.h | 10 +- src/transport/SecureSessionMgr.cpp | 93 ++++++++----------- src/transport/SecureSessionMgr.h | 31 +++---- src/transport/tests/TestSecureSessionMgr.cpp | 53 +++++++---- 11 files changed, 142 insertions(+), 182 deletions(-) diff --git a/src/messaging/ApplicationExchangeDispatch.cpp b/src/messaging/ApplicationExchangeDispatch.cpp index 80a1a438647187..d76f4b779fc2e6 100644 --- a/src/messaging/ApplicationExchangeDispatch.cpp +++ b/src/messaging/ApplicationExchangeDispatch.cpp @@ -26,17 +26,17 @@ namespace chip { namespace Messaging { -CHIP_ERROR ApplicationExchangeDispatch::SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) +CHIP_ERROR ApplicationExchangeDispatch::PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & preparedMessage) { - return mSessionMgr->SendMessage(session, payloadHeader, std::move(message), retainedMessage); + return mSessionMgr->BuildEncryptedMessagePayload(session, payloadHeader, std::move(message), preparedMessage); } -CHIP_ERROR ApplicationExchangeDispatch::ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const +CHIP_ERROR ApplicationExchangeDispatch::SendPreparedMessage(SecureSessionHandle session, + const EncryptedPacketBufferHandle & preparedMessage) const { - return mSessionMgr->SendEncryptedMessage(session, std::move(message), retainedMessage); + return mSessionMgr->SendPreparedMessage(session, preparedMessage); } bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) diff --git a/src/messaging/ApplicationExchangeDispatch.h b/src/messaging/ApplicationExchangeDispatch.h index ed330ac7e623ab..985ad73e56ec86 100644 --- a/src/messaging/ApplicationExchangeDispatch.h +++ b/src/messaging/ApplicationExchangeDispatch.h @@ -45,15 +45,13 @@ class ApplicationExchangeDispatch : public ExchangeMessageDispatch return ExchangeMessageDispatch::Init(); } - CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const override; + CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & preparedMessage) override; + CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & message) const override; SecureSessionMgr * GetSessionMgr() const { return mSessionMgr; } protected: - CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) override; - bool MessagePermitted(uint16_t protocol, uint8_t type) override; private: diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp index 22118e6ab12e58..fdd924c217f763 100644 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ b/src/messaging/ExchangeMessageDispatch.cpp @@ -29,6 +29,7 @@ #endif #include +#include #include #include @@ -76,25 +77,22 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uin // Add to Table for subsequent sending ReturnErrorOnFailure(reliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry)); - - CHIP_ERROR err = SendMessageImpl(session, payloadHeader, std::move(message), &entry->retainedBuf); - if (err != CHIP_NO_ERROR) - { - // Remove from table - ChipLogError(ExchangeManager, "Failed to send message with err %s", ::chip::ErrorStr(err)); - reliableMessageMgr->ClearRetransTable(*entry); - ReturnErrorOnFailure(err); - } - else - { - reliableMessageMgr->StartRetransmision(entry); - } + auto deleter = [reliableMessageMgr](ReliableMessageMgr::RetransTableEntry * e) { + reliableMessageMgr->ClearRetransTable(*e); + }; + std::unique_ptr entryOwner(entry, deleter); + + ReturnErrorOnFailure(PrepareMessage(session, payloadHeader, std::move(message), entryOwner->retainedBuf)); + ReturnErrorOnFailure(SendPreparedMessage(session, entryOwner->retainedBuf)); + reliableMessageMgr->StartRetransmision(entryOwner.release()); } else { // If the channel itself is providing reliability, let's not request MRP acks payloadHeader.SetNeedsAck(false); - ReturnErrorOnFailure(SendMessageImpl(session, payloadHeader, std::move(message), nullptr)); + EncryptedPacketBufferHandle preparedMessage; + ReturnErrorOnFailure(PrepareMessage(session, payloadHeader, std::move(message), preparedMessage)); + ReturnErrorOnFailure(SendPreparedMessage(session, preparedMessage)); } return CHIP_NO_ERROR; diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index 0b9719ee85dc42..e0d9dd60322f7a 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -45,15 +45,18 @@ class ExchangeMessageDispatch : public ReferenceCounted uint8_t type, System::PacketBufferHandle && message); /** - * The 'message' and 'retainedMessage' arguments may point to the same - * handle. Therefore, callees _must_ ensure that any moving out of - * 'message' happens before writing to *retainedMessage. + * @brief + * This interface takes the payload and returns the prepared message which can be send multiple times. + * + * @param session Peer node to which the payload to be sent + * @param payloadHeader The payloadHeader to be encoded into the packet + * @param message The payload to be sent + * @param preparedMessage The handle to hold the prepared message */ - virtual CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const - { - return CHIP_ERROR_NOT_IMPLEMENTED; - } + virtual CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && message, EncryptedPacketBufferHandle & preparedMessage) = 0; + virtual CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, + const EncryptedPacketBufferHandle & preparedMessage) const = 0; virtual CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, @@ -61,11 +64,7 @@ class ExchangeMessageDispatch : public ReferenceCounted protected: virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; - - virtual CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, EncryptedPacketBufferHandle * retainedMessage) = 0; - - virtual bool IsReliableTransmissionAllowed() { return true; } + virtual bool IsReliableTransmissionAllowed() const { return true; } }; } // namespace Messaging diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp index 0e7b7e9f832e1e..29fc998b95ae64 100644 --- a/src/messaging/ReliableMessageMgr.cpp +++ b/src/messaging/ReliableMessageMgr.cpp @@ -366,8 +366,7 @@ CHIP_ERROR ReliableMessageMgr::SendFromRetransTable(RetransTableEntry * entry) const ExchangeMessageDispatch * dispatcher = rc->GetExchangeContext()->GetMessageDispatch(); VerifyOrExit(dispatcher != nullptr, err = CHIP_ERROR_INCORRECT_STATE); - err = - dispatcher->ResendMessage(rc->GetExchangeContext()->GetSecureSession(), std::move(entry->retainedBuf), &entry->retainedBuf); + err = dispatcher->SendPreparedMessage(rc->GetExchangeContext()->GetSecureSession(), entry->retainedBuf); SuccessOrExit(err); // Update the counters diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 3e9ac2ff27819c..580f21578b6674 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -117,37 +117,25 @@ class MockAppDelegate : public ExchangeDelegate class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDispatch { public: - CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) override + CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & preparedMessage) override { PacketHeader packetHeader; ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); - if (retainedMessage != nullptr && mRetainMessageOnSend) - { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); - } - return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message)); + preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(message)); + return CHIP_NO_ERROR; } - CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const override + CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const override { - // Our send path needs a (writable) PacketBuffer, so get that from the - // EncryptedPacketBufferHandle. Note that we have to do this before we - // set *retainedMessage, because 'message' and '*retainedMessage' might - // be the same memory location and we have to guarantee that we move out - // of 'message' before we write to *retainedMessage. - System::PacketBufferHandle writableBuf(std::move(message).CastToWritable()); - if (retainedMessage != nullptr && mRetainMessageOnSend) - { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(writableBuf.Retain()); - } - return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(writableBuf)); + return gTransportMgr.SendMessage(Transport::PeerAddress(), preparedMessage.CastToWritable()); } + bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; } + bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } bool mRetainMessageOnSend = true; diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index a84defaeadee7f..3593427ec08acf 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -28,40 +28,23 @@ namespace chip { using namespace Messaging; -CHIP_ERROR SessionEstablishmentExchangeDispatch::SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) +CHIP_ERROR SessionEstablishmentExchangeDispatch::PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & preparedMessage) { - ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); PacketHeader packetHeader; - ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); - if (retainedMessage != nullptr) - { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); - ChipLogError(Inet, "RETAINED IN SESS: %p %d", retainedMessage, (*retainedMessage).IsNull()); - } - return mTransportMgr->SendMessage(mPeerAddress, std::move(message)); + preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(message)); + return CHIP_NO_ERROR; } -CHIP_ERROR SessionEstablishmentExchangeDispatch::ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const +CHIP_ERROR SessionEstablishmentExchangeDispatch::SendPreparedMessage(SecureSessionHandle session, + const EncryptedPacketBufferHandle & preparedMessage) const { ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); - - // Our send path needs a (writable) PacketBuffer, so get that from the - // EncryptedPacketBufferHandle. Note that we have to do this before we set - // *retainedMessage, because 'message' and '*retainedMessage' might be the - // same memory location and we have to guarantee that we move out of - // 'message' before we write to *retainedMessage. - System::PacketBufferHandle writableBuf(std::move(message).CastToWritable()); - if (retainedMessage != nullptr) - { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(writableBuf.Retain()); - } - return mTransportMgr->SendMessage(mPeerAddress, std::move(writableBuf)); + return mTransportMgr->SendMessage(mPeerAddress, preparedMessage.CastToWritable()); } CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index 8ac9c0a2df3705..e59bfb92630fce 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -43,8 +43,9 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi return ExchangeMessageDispatch::Init(); } - CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const override; + CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & out) override; + CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const override; CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, @@ -55,12 +56,9 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi void SetPeerAddress(const Transport::PeerAddress & address) { mPeerAddress = address; } protected: - CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) override; - bool MessagePermitted(uint16_t protocol, uint8_t type) override; - bool IsReliableTransmissionAllowed() override + bool IsReliableTransmissionAllowed() const override { // If the underlying transport is UDP. return (mPeerAddress.GetTransportType() == Transport::Type::kUdp); diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 89458b96ae67fb..162efd3cd5974c 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -122,44 +122,54 @@ Transport::Type SecureSessionMgr::GetTransportType(NodeId peerNodeId) return Transport::Type::kUndefined; } -CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot) +CHIP_ERROR SecureSessionMgr::BuildEncryptedMessagePayload(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && msgBuf, + EncryptedPacketBufferHandle & encryptedMessage) { - PacketHeader unusedPacketHeader; - return SendMessage(session, payloadHeader, unusedPacketHeader, std::move(msgBuf), bufferRetainSlot, - EncryptionState::kPayloadIsUnencrypted); -} + PacketHeader packetHeader; + if (IsControlMessage(payloadHeader)) + { + packetHeader.SetSecureSessionControlMsg(true); + } -CHIP_ERROR SecureSessionMgr::SendEncryptedMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && msgBuf, - EncryptedPacketBufferHandle * bufferRetainSlot) -{ - VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(!msgBuf.HasChainedBuffer(), CHIP_ERROR_INVALID_MESSAGE_LENGTH); + PeerConnectionState * state = GetPeerConnectionState(session); + if (state == nullptr) + { + return CHIP_ERROR_NOT_CONNECTED; + } - // Our send path needs a (writable) PacketBuffer (e.g. so it can encode a - // PacketHeader into it), so get that from the EncryptedPacketBufferHandle. - System::PacketBufferHandle mutableBuf(std::move(msgBuf).CastToWritable()); + Transport::AdminPairingInfo * admin = mAdmins->FindAdminWithId(state->GetAdminId()); + if (admin == nullptr) + { + return CHIP_ERROR_INCORRECT_STATE; + } - // Advancing the start to encrypted header, since SendMessage will attach the packet header on top of it. - PacketHeader packetHeader; - ReturnErrorOnFailure(packetHeader.DecodeAndConsume(mutableBuf)); + NodeId localNodeId = admin->GetNodeId(); + MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *state); + ReturnErrorOnFailure(SecureMessageCodec::Encode(localNodeId, state, payloadHeader, packetHeader, msgBuf, counter)); - PayloadHeader payloadHeader; - return SendMessage(session, payloadHeader, packetHeader, std::move(mutableBuf), bufferRetainSlot, - EncryptionState::kPayloadIsEncrypted); + ReturnErrorOnFailure(packetHeader.EncodeBeforeData(msgBuf)); + + encryptedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(msgBuf)); + ChipLogProgress(Inet, + "Encrypted message %p from 0x" ChipLogFormatX64 " to 0x" ChipLogFormatX64 " of type %d and protocolId %" PRIu32 + " on exchange %d.", + &encryptedMessage, ChipLogValueX64(localNodeId), ChipLogValueX64(state->GetPeerNodeId()), + payloadHeader.GetMessageType(), payloadHeader.GetProtocolID().ToFullyQualifiedSpecForm(), + payloadHeader.GetExchangeID()); + + return CHIP_NO_ERROR; } -CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, PacketHeader & packetHeader, - System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot, - EncryptionState encryptionState) +CHIP_ERROR SecureSessionMgr::SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) { CHIP_ERROR err = CHIP_NO_ERROR; PeerConnectionState * state = nullptr; - NodeId localNodeId = mLocalNodeId; - - Transport::AdminPairingInfo * admin = nullptr; + PacketBufferHandle msgBuf; VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(!preparedMessage.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT); + msgBuf = preparedMessage.CastToWritable(); VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT); VerifyOrExit(!msgBuf->HasChainedBuffer(), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); @@ -169,36 +179,9 @@ CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHea // This marks any connection where we send data to as 'active' mPeerConnections.MarkConnectionActive(state); - admin = mAdmins->FindAdminWithId(state->GetAdminId()); - VerifyOrExit(admin != nullptr, err = CHIP_ERROR_INCORRECT_STATE); - localNodeId = admin->GetNodeId(); - - if (IsControlMessage(payloadHeader)) - { - packetHeader.SetSecureSessionControlMsg(true); - } - - if (encryptionState == EncryptionState::kPayloadIsUnencrypted) - { - MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *state); - err = SecureMessageCodec::Encode(localNodeId, state, payloadHeader, packetHeader, msgBuf, counter); - SuccessOrExit(err); - } - - err = packetHeader.EncodeBeforeData(msgBuf); - SuccessOrExit(err); - - // Retain the packet buffer in case it's needed for retransmissions. - if (bufferRetainSlot != nullptr) - { - *bufferRetainSlot = EncryptedPacketBufferHandle::MarkEncrypted(msgBuf.Retain()); - } - - ChipLogProgress(Inet, "Send message of type %d and protocolId %" PRIu32 " on exchange %d", payloadHeader.GetMessageType(), - payloadHeader.GetProtocolID().ToFullyQualifiedSpecForm(), payloadHeader.GetExchangeID()); - ChipLogProgress(Inet, "Sending msg from 0x" ChipLogFormatX64 " to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", - ChipLogValueX64(localNodeId), ChipLogValueX64(state->GetPeerNodeId()), System::Layer::GetClock_MonotonicMS()); + ChipLogProgress(Inet, "Sending msg %p to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", &preparedMessage, + ChipLogValueX64(state->GetPeerNodeId()), System::Layer::GetClock_MonotonicMS()); if (state->GetTransport() != nullptr) { diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index 576834fdfa93f9..acdb01c1c7e2fa 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -106,11 +106,8 @@ class EncryptedPacketBufferHandle final : private System::PacketBufferHandle * Get a handle to the data that allows mutating the bytes. This should * only be used if absolutely necessary, because EncryptedPacketBufferHandle * represents a buffer that we want to resend as-is. - * - * We only allow doing this with an rvalue reference, so the fact that we - * are moving out of the EncryptedPacketBufferHandle is clear. */ - PacketBufferHandle CastToWritable() && { return PacketBufferHandle(std::move(*this)); } + PacketBufferHandle CastToWritable() const { return PacketBufferHandle::Retain(); } private: EncryptedPacketBufferHandle(PacketBufferHandle && aBuffer) : PacketBufferHandle(std::move(aBuffer)) {} @@ -182,17 +179,23 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate /** * @brief - * Send a message to a currently connected peer. + * This function takes the payload and returns an encrypted message which can be sent multiple times. * * @details - * msgBuf contains the data to be transmitted. If bufferRetainSlot is not null and this function - * returns success, the encrypted data that was sent, as well as various other information needed - * to retransmit it, will be stored in *bufferRetainSlot. + * It does the following: + * 1. Encrypt the msgBuf + * 2. construct the packet header + * 3. Encode the packet header and prepend it to message. + * Returns a encrypted message in encryptedMessage. + */ + CHIP_ERROR BuildEncryptedMessagePayload(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle & encryptedMessage); + + /** + * @brief + * Send a prepared message to a currently connected peer. */ - CHIP_ERROR SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && msgBuf, - EncryptedPacketBufferHandle * bufferRetainSlot = nullptr); - CHIP_ERROR SendEncryptedMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && msgBuf, - EncryptedPacketBufferHandle * bufferRetainSlot); + CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage); Transport::PeerConnectionState * GetPeerConnectionState(SecureSessionHandle session); @@ -301,10 +304,6 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate GlobalUnencryptedMessageCounter mGlobalUnencryptedMessageCounter; GlobalEncryptedMessageCounter mGlobalEncryptedMessageCounter; - CHIP_ERROR SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, PacketHeader & packetHeader, - System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot, - EncryptionState encryptionState); - /** Schedules a new oneshot timer for checking connection expiry. */ void ScheduleExpiryTimer(); diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index 69a6bb456a1f70..b9d0d42c76a0c7 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -220,7 +220,11 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) // Set the protocol ID and message type for this header. payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest); - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer)); + EncryptedPacketBufferHandle preparedMessage; + err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -231,7 +235,11 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) callback.LargeMessageSent = true; - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(large_buffer)); + err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(large_buffer), + preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); @@ -244,7 +252,8 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) callback.LargeMessageSent = true; - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(extra_large_buffer)); + err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(extra_large_buffer), + preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_TOO_LONG); } @@ -302,7 +311,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) callback.ReceiveHandlerCallCount = 0; PayloadHeader payloadHeader; - EncryptedPacketBufferHandle msgBuf; + EncryptedPacketBufferHandle preparedMessage; // Set the exchange ID for this header. payloadHeader.SetExchangeID(0); @@ -312,7 +321,10 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetInitiator(true); - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer), &msgBuf); + err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // Reset receive side message counter, or duplicated message will be denied. @@ -321,7 +333,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(msgBuf), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); @@ -381,7 +393,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) callback.ReceiveHandlerCallCount = 0; PayloadHeader payloadHeader; - EncryptedPacketBufferHandle msgBuf; + EncryptedPacketBufferHandle preparedMessage; // Set the exchange ID for this header. payloadHeader.SetExchangeID(0); @@ -391,7 +403,10 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetInitiator(true); - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer), &msgBuf); + err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -404,14 +419,14 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) PacketHeader packetHeader; // Change Destination Node ID - EncryptedPacketBufferHandle badDestNodeIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle badDestNodeIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badDestNodeIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, packetHeader.GetDestinationNodeId().Value() == kDestinationNodeId); packetHeader.SetDestinationNodeId(kSourceNodeId); NL_TEST_ASSERT(inSuite, badDestNodeIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badDestNodeIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, badDestNodeIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -420,13 +435,13 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Source Node ID - EncryptedPacketBufferHandle badSrcNodeIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle badSrcNodeIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badSrcNodeIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); packetHeader.SetSourceNodeId(kDestinationNodeId); NL_TEST_ASSERT(inSuite, badSrcNodeIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badSrcNodeIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, badSrcNodeIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -435,13 +450,13 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Source Node ID - EncryptedPacketBufferHandle noDstNodeIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle noDstNodeIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, noDstNodeIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); packetHeader.ClearDestinationNodeId(); NL_TEST_ASSERT(inSuite, noDstNodeIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(noDstNodeIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, noDstNodeIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -450,14 +465,14 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Message ID - EncryptedPacketBufferHandle badMessageIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle badMessageIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badMessageIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); uint32_t msgID = packetHeader.GetMessageId(); packetHeader.SetMessageId(msgID + 1); NL_TEST_ASSERT(inSuite, badMessageIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badMessageIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, badMessageIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -466,14 +481,14 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Key ID - EncryptedPacketBufferHandle badKeyIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle badKeyIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badKeyIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); // the secure channel is setup to use key ID 1, and 2. So let's use 3 here. packetHeader.SetEncryptionKeyID(3); NL_TEST_ASSERT(inSuite, badKeyIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badKeyIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, badKeyIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); /* -------------------------------------------------------------------------------------------*/ @@ -482,7 +497,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); // Send the correct encrypted msg - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(msgBuf), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2);