From c052e2830357373b41b566d60c198388e161938d Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Wed, 2 Jun 2021 12:31:33 +0800 Subject: [PATCH] Fix CRMP resend null out retained buffer --- src/messaging/ApplicationExchangeDispatch.cpp | 14 +-- src/messaging/ApplicationExchangeDispatch.h | 8 +- src/messaging/ExchangeMessageDispatch.cpp | 15 ++-- src/messaging/ExchangeMessageDispatch.h | 19 +--- src/messaging/ReliableMessageMgr.cpp | 3 +- .../tests/TestReliableMessageProtocol.cpp | 28 ++---- .../SessionEstablishmentExchangeDispatch.cpp | 33 ++----- .../SessionEstablishmentExchangeDispatch.h | 10 +-- src/transport/SecureSessionMgr.cpp | 89 ++++++++----------- src/transport/SecureSessionMgr.h | 28 +++--- src/transport/tests/TestSecureSessionMgr.cpp | 51 +++++++---- 11 files changed, 126 insertions(+), 172 deletions(-) diff --git a/src/messaging/ApplicationExchangeDispatch.cpp b/src/messaging/ApplicationExchangeDispatch.cpp index 80a1a438647187..ce6f42f9fd8752 100644 --- a/src/messaging/ApplicationExchangeDispatch.cpp +++ b/src/messaging/ApplicationExchangeDispatch.cpp @@ -26,17 +26,17 @@ namespace chip { namespace Messaging { -CHIP_ERROR ApplicationExchangeDispatch::SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) +CHIP_ERROR ApplicationExchangeDispatch::PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & preparedMessage) { - return mSessionMgr->SendMessage(session, payloadHeader, std::move(message), retainedMessage); + return mSessionMgr->PrepareMessage(session, payloadHeader, std::move(message), preparedMessage); } -CHIP_ERROR ApplicationExchangeDispatch::ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const +CHIP_ERROR ApplicationExchangeDispatch::SendPreparedMessage(SecureSessionHandle session, + const EncryptedPacketBufferHandle & message) const { - return mSessionMgr->SendEncryptedMessage(session, std::move(message), retainedMessage); + return mSessionMgr->SendPreparedMessage(session, message); } bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) diff --git a/src/messaging/ApplicationExchangeDispatch.h b/src/messaging/ApplicationExchangeDispatch.h index 4958cf11735ba8..6c72f178a96e86 100644 --- a/src/messaging/ApplicationExchangeDispatch.h +++ b/src/messaging/ApplicationExchangeDispatch.h @@ -45,15 +45,13 @@ class ApplicationExchangeDispatch : public ExchangeMessageDispatch return ExchangeMessageDispatch::Init(reliableMessageMgr); } - CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const override; + CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & preparedMessage) override; + CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & message) const override; SecureSessionMgr * GetSessionMgr() const { return mSessionMgr; } protected: - CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) override; - bool MessagePermitted(uint16_t protocol, uint8_t type) override; private: diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp index c998ef6968864e..c0aa1f22db02d3 100644 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ b/src/messaging/ExchangeMessageDispatch.cpp @@ -75,24 +75,25 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uin // Add to Table for subsequent sending ReturnErrorOnFailure(mReliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry)); - CHIP_ERROR err = SendMessageImpl(session, payloadHeader, std::move(message), &entry->retainedBuf); + CHIP_ERROR err = PrepareMessage(session, payloadHeader, std::move(message), entry->retainedBuf); if (err != CHIP_NO_ERROR) { // Remove from table ChipLogError(ExchangeManager, "Failed to send message with err %s", ::chip::ErrorStr(err)); mReliableMessageMgr->ClearRetransTable(*entry); - ReturnErrorOnFailure(err); - } - else - { - mReliableMessageMgr->StartRetransmision(entry); + return err; } + + ReturnErrorOnFailure(SendPreparedMessage(session, entry->retainedBuf)); + mReliableMessageMgr->StartRetransmision(entry); } else { // If the channel itself is providing reliability, let's not request CRMP acks payloadHeader.SetNeedsAck(false); - ReturnErrorOnFailure(SendMessageImpl(session, payloadHeader, std::move(message), nullptr)); + EncryptedPacketBufferHandle preparedMessage; + ReturnErrorOnFailure(PrepareMessage(session, payloadHeader, std::move(message), preparedMessage)); + ReturnErrorOnFailure(SendPreparedMessage(session, preparedMessage)); } return CHIP_NO_ERROR; diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index d5cc4111d1d338..40a499092bb519 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -48,16 +48,9 @@ class ExchangeMessageDispatch : public ReferenceCounted ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol, uint8_t type, System::PacketBufferHandle && message); - /** - * The 'message' and 'retainedMessage' arguments may point to the same - * handle. Therefore, callees _must_ ensure that any moving out of - * 'message' happens before writing to *retainedMessage. - */ - virtual CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const - { - return CHIP_ERROR_NOT_IMPLEMENTED; - } + virtual CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && message, EncryptedPacketBufferHandle & preparedMessage) = 0; + virtual CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & message) const = 0; virtual CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, @@ -65,11 +58,7 @@ class ExchangeMessageDispatch : public ReferenceCounted protected: virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; - - virtual CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, EncryptedPacketBufferHandle * retainedMessage) = 0; - - virtual bool IsReliableTransmissionAllowed() { return true; } + virtual bool IsReliableTransmissionAllowed() const { return true; } private: ReliableMessageMgr * mReliableMessageMgr = nullptr; diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp index ab84a91cb1031a..05846907382a78 100644 --- a/src/messaging/ReliableMessageMgr.cpp +++ b/src/messaging/ReliableMessageMgr.cpp @@ -366,8 +366,7 @@ CHIP_ERROR ReliableMessageMgr::SendFromRetransTable(RetransTableEntry * entry) const ExchangeMessageDispatch * dispatcher = rc->GetExchangeContext()->GetMessageDispatch(); VerifyOrExit(dispatcher != nullptr, err = CHIP_ERROR_INCORRECT_STATE); - err = - dispatcher->ResendMessage(rc->GetExchangeContext()->GetSecureSession(), std::move(entry->retainedBuf), &entry->retainedBuf); + err = dispatcher->SendPreparedMessage(rc->GetExchangeContext()->GetSecureSession(), entry->retainedBuf); SuccessOrExit(err); // Update the counters diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 5d0179743cdff9..72d8d37240bda3 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -113,37 +113,25 @@ class MockAppDelegate : public ExchangeDelegate class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDispatch { public: - CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) override + CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & preparedMessage) override { PacketHeader packetHeader; ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); - if (retainedMessage != nullptr && mRetainMessageOnSend) - { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); - } - return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message)); + preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(message)); + return CHIP_NO_ERROR; } - CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const override + CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & message) const override { - // Our send path needs a (writable) PacketBuffer, so get that from the - // EncryptedPacketBufferHandle. Note that we have to do this before we - // set *retainedMessage, because 'message' and '*retainedMessage' might - // be the same memory location and we have to guarantee that we move out - // of 'message' before we write to *retainedMessage. - System::PacketBufferHandle writableBuf(std::move(message).CastToWritable()); - if (retainedMessage != nullptr && mRetainMessageOnSend) - { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(writableBuf.Retain()); - } - return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(writableBuf)); + return gTransportMgr.SendMessage(Transport::PeerAddress(), message.Retain()); } + bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; } + bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } bool mRetainMessageOnSend = true; diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index a84defaeadee7f..da0f56c7b1b911 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -28,40 +28,23 @@ namespace chip { using namespace Messaging; -CHIP_ERROR SessionEstablishmentExchangeDispatch::SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) +CHIP_ERROR SessionEstablishmentExchangeDispatch::PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & preparedMessage) { - ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); PacketHeader packetHeader; - ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); - if (retainedMessage != nullptr) - { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); - ChipLogError(Inet, "RETAINED IN SESS: %p %d", retainedMessage, (*retainedMessage).IsNull()); - } - return mTransportMgr->SendMessage(mPeerAddress, std::move(message)); + preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(message)); + return CHIP_NO_ERROR; } -CHIP_ERROR SessionEstablishmentExchangeDispatch::ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const +CHIP_ERROR SessionEstablishmentExchangeDispatch::SendPreparedMessage(SecureSessionHandle session, + const EncryptedPacketBufferHandle & message) const { ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); - - // Our send path needs a (writable) PacketBuffer, so get that from the - // EncryptedPacketBufferHandle. Note that we have to do this before we set - // *retainedMessage, because 'message' and '*retainedMessage' might be the - // same memory location and we have to guarantee that we move out of - // 'message' before we write to *retainedMessage. - System::PacketBufferHandle writableBuf(std::move(message).CastToWritable()); - if (retainedMessage != nullptr) - { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(writableBuf.Retain()); - } - return mTransportMgr->SendMessage(mPeerAddress, std::move(writableBuf)); + return mTransportMgr->SendMessage(mPeerAddress, message.Retain()); } CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index a6e9a669727fb1..359b4b06f23f15 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -43,8 +43,9 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi return ExchangeMessageDispatch::Init(reliableMessageMgr); } - CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) const override; + CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, + EncryptedPacketBufferHandle & out) override; + CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & message) const override; CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, @@ -55,12 +56,9 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi void SetPeerAddress(const Transport::PeerAddress & address) { mPeerAddress = address; } protected: - CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle * retainedMessage) override; - bool MessagePermitted(uint16_t protocol, uint8_t type) override; - bool IsReliableTransmissionAllowed() override + bool IsReliableTransmissionAllowed() const override { // If the underlying transport is UDP. return (mPeerAddress.GetTransportType() == Transport::Type::kUdp); diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index d15a1b40fcc4d2..ef999425a0db89 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -122,44 +122,49 @@ Transport::Type SecureSessionMgr::GetTransportType(NodeId peerNodeId) return Transport::Type::kUndefined; } -CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot) +CHIP_ERROR SecureSessionMgr::PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle & preparedMessage) { - PacketHeader unusedPacketHeader; - return SendMessage(session, payloadHeader, unusedPacketHeader, std::move(msgBuf), bufferRetainSlot, - EncryptionState::kPayloadIsUnencrypted); -} + PacketHeader packetHeader; + if (IsControlMessage(payloadHeader)) + { + packetHeader.SetSecureSessionControlMsg(true); + } -CHIP_ERROR SecureSessionMgr::SendEncryptedMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && msgBuf, - EncryptedPacketBufferHandle * bufferRetainSlot) -{ - VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(!msgBuf.HasChainedBuffer(), CHIP_ERROR_INVALID_MESSAGE_LENGTH); + PeerConnectionState * state = GetPeerConnectionState(session); + if (state == nullptr) + { + return CHIP_ERROR_NOT_CONNECTED; + }; - // Our send path needs a (writable) PacketBuffer (e.g. so it can encode a - // PacketHeader into it), so get that from the EncryptedPacketBufferHandle. - System::PacketBufferHandle mutableBuf(std::move(msgBuf).CastToWritable()); + Transport::AdminPairingInfo * admin = mAdmins->FindAdminWithId(state->GetAdminId()); + if (admin == nullptr) + { + return CHIP_ERROR_INCORRECT_STATE; + } - // Advancing the start to encrypted header, since SendMessage will attach the packet header on top of it. - PacketHeader packetHeader; - ReturnErrorOnFailure(packetHeader.DecodeAndConsume(mutableBuf)); + NodeId localNodeId = admin->GetNodeId(); + MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *state); + ReturnErrorOnFailure(SecureMessageCodec::Encode(localNodeId, state, payloadHeader, packetHeader, msgBuf, counter)); - PayloadHeader payloadHeader; - return SendMessage(session, payloadHeader, packetHeader, std::move(mutableBuf), bufferRetainSlot, - EncryptionState::kPayloadIsEncrypted); + ReturnErrorOnFailure(packetHeader.EncodeBeforeData(msgBuf)); + + preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(msgBuf)); + ChipLogProgress(Inet, "Prepared msg %p from 0x" ChipLogFormatX64 " to 0x" ChipLogFormatX64 ".", &preparedMessage, + ChipLogValueX64(localNodeId), ChipLogValueX64(state->GetPeerNodeId())); + + return CHIP_NO_ERROR; } -CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, PacketHeader & packetHeader, - System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot, - EncryptionState encryptionState) +CHIP_ERROR SecureSessionMgr::SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) { - CHIP_ERROR err = CHIP_NO_ERROR; - PeerConnectionState * state = nullptr; - NodeId localNodeId = mLocalNodeId; - - Transport::AdminPairingInfo * admin = nullptr; + CHIP_ERROR err = CHIP_NO_ERROR; + PeerConnectionState * state = nullptr; + PacketBufferHandle msgBuf; VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(!preparedMessage.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT); + msgBuf = preparedMessage.Retain(); VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT); VerifyOrExit(!msgBuf->HasChainedBuffer(), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); @@ -169,33 +174,9 @@ CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHea // This marks any connection where we send data to as 'active' mPeerConnections.MarkConnectionActive(state); - admin = mAdmins->FindAdminWithId(state->GetAdminId()); - VerifyOrExit(admin != nullptr, err = CHIP_ERROR_INCORRECT_STATE); - localNodeId = admin->GetNodeId(); - - if (IsControlMessage(payloadHeader)) - { - packetHeader.SetSecureSessionControlMsg(true); - } - - if (encryptionState == EncryptionState::kPayloadIsUnencrypted) - { - MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *state); - err = SecureMessageCodec::Encode(localNodeId, state, payloadHeader, packetHeader, msgBuf, counter); - SuccessOrExit(err); - } - - err = packetHeader.EncodeBeforeData(msgBuf); - SuccessOrExit(err); - - // Retain the packet buffer in case it's needed for retransmissions. - if (bufferRetainSlot != nullptr) - { - *bufferRetainSlot = EncryptedPacketBufferHandle::MarkEncrypted(msgBuf.Retain()); - } - ChipLogProgress(Inet, "Sending msg from 0x" ChipLogFormatX64 " to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", - ChipLogValueX64(localNodeId), ChipLogValueX64(state->GetPeerNodeId()), System::Layer::GetClock_MonotonicMS()); + ChipLogProgress(Inet, "Sending msg %p to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", &preparedMessage, + ChipLogValueX64(state->GetPeerNodeId()), System::Layer::GetClock_MonotonicMS()); if (state->GetTransport() != nullptr) { diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index 576834fdfa93f9..ec102c2f5c3c28 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -68,6 +68,8 @@ class EncryptedPacketBufferHandle final : private System::PacketBufferHandle uint32_t GetMsgId() const; + PacketBufferHandle Retain() const { return PacketBufferHandle::Retain(); } + /** * Creates a copy of the data in this packet. * @@ -182,17 +184,23 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate /** * @brief - * Send a message to a currently connected peer. + * This function takes the payload and returns the final message which can be send multiple times. * * @details - * msgBuf contains the data to be transmitted. If bufferRetainSlot is not null and this function - * returns success, the encrypted data that was sent, as well as various other information needed - * to retransmit it, will be stored in *bufferRetainSlot. + * It contains following preparation: + * 1. Encrypt the msgBuf + * 2. construct the packet header + * 3. Encode the packet header and prepend it to message. + * Returns a prepared message in preparedMessage. + */ + CHIP_ERROR PrepareMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && msgBuf, + EncryptedPacketBufferHandle & preparedMessage); + + /** + * @brief + * Send a prepared message to a currently connected peer. */ - CHIP_ERROR SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && msgBuf, - EncryptedPacketBufferHandle * bufferRetainSlot = nullptr); - CHIP_ERROR SendEncryptedMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && msgBuf, - EncryptedPacketBufferHandle * bufferRetainSlot); + CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage); Transport::PeerConnectionState * GetPeerConnectionState(SecureSessionHandle session); @@ -301,10 +309,6 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate GlobalUnencryptedMessageCounter mGlobalUnencryptedMessageCounter; GlobalEncryptedMessageCounter mGlobalEncryptedMessageCounter; - CHIP_ERROR SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, PacketHeader & packetHeader, - System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot, - EncryptionState encryptionState); - /** Schedules a new oneshot timer for checking connection expiry. */ void ScheduleExpiryTimer(); diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index 69a6bb456a1f70..c149baaa19952b 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -220,7 +220,11 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) // Set the protocol ID and message type for this header. payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest); - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer)); + EncryptedPacketBufferHandle preparedMessage; + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -231,7 +235,10 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) callback.LargeMessageSent = true; - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(large_buffer)); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(large_buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); @@ -244,7 +251,7 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) callback.LargeMessageSent = true; - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(extra_large_buffer)); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(extra_large_buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_TOO_LONG); } @@ -302,7 +309,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) callback.ReceiveHandlerCallCount = 0; PayloadHeader payloadHeader; - EncryptedPacketBufferHandle msgBuf; + EncryptedPacketBufferHandle preparedMessage; // Set the exchange ID for this header. payloadHeader.SetExchangeID(0); @@ -312,7 +319,10 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetInitiator(true); - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer), &msgBuf); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // Reset receive side message counter, or duplicated message will be denied. @@ -321,7 +331,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(msgBuf), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); @@ -381,7 +391,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) callback.ReceiveHandlerCallCount = 0; PayloadHeader payloadHeader; - EncryptedPacketBufferHandle msgBuf; + EncryptedPacketBufferHandle preparedMessage; // Set the exchange ID for this header. payloadHeader.SetExchangeID(0); @@ -391,7 +401,10 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetInitiator(true); - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer), &msgBuf); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -404,14 +417,14 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) PacketHeader packetHeader; // Change Destination Node ID - EncryptedPacketBufferHandle badDestNodeIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle badDestNodeIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badDestNodeIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, packetHeader.GetDestinationNodeId().Value() == kDestinationNodeId); packetHeader.SetDestinationNodeId(kSourceNodeId); NL_TEST_ASSERT(inSuite, badDestNodeIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badDestNodeIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, badDestNodeIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -420,13 +433,13 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Source Node ID - EncryptedPacketBufferHandle badSrcNodeIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle badSrcNodeIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badSrcNodeIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); packetHeader.SetSourceNodeId(kDestinationNodeId); NL_TEST_ASSERT(inSuite, badSrcNodeIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badSrcNodeIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, badSrcNodeIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -435,13 +448,13 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Source Node ID - EncryptedPacketBufferHandle noDstNodeIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle noDstNodeIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, noDstNodeIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); packetHeader.ClearDestinationNodeId(); NL_TEST_ASSERT(inSuite, noDstNodeIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(noDstNodeIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, noDstNodeIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -450,14 +463,14 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Message ID - EncryptedPacketBufferHandle badMessageIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle badMessageIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badMessageIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); uint32_t msgID = packetHeader.GetMessageId(); packetHeader.SetMessageId(msgID + 1); NL_TEST_ASSERT(inSuite, badMessageIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badMessageIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, badMessageIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -466,14 +479,14 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Key ID - EncryptedPacketBufferHandle badKeyIdMsg = msgBuf.CloneData(); + EncryptedPacketBufferHandle badKeyIdMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badKeyIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); // the secure channel is setup to use key ID 1, and 2. So let's use 3 here. packetHeader.SetEncryptionKeyID(3); NL_TEST_ASSERT(inSuite, badKeyIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badKeyIdMsg), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, badKeyIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); /* -------------------------------------------------------------------------------------------*/ @@ -482,7 +495,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); // Send the correct encrypted msg - err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(msgBuf), nullptr); + err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2);