diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index 57179710d84017..d5cc4111d1d338 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -48,6 +48,11 @@ class ExchangeMessageDispatch : public ReferenceCounted ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol, uint8_t type, System::PacketBufferHandle && message); + /** + * The 'message' and 'retainedMessage' arguments may point to the same + * handle. Therefore, callees _must_ ensure that any moving out of + * 'message' happens before writing to *retainedMessage. + */ virtual CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, EncryptedPacketBufferHandle * retainedMessage) const { diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index e26d7dcfcd4080..5d0179743cdff9 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -131,11 +131,17 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessa CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, EncryptedPacketBufferHandle * retainedMessage) const override { + // Our send path needs a (writable) PacketBuffer, so get that from the + // EncryptedPacketBufferHandle. Note that we have to do this before we + // set *retainedMessage, because 'message' and '*retainedMessage' might + // be the same memory location and we have to guarantee that we move out + // of 'message' before we write to *retainedMessage. + System::PacketBufferHandle writableBuf(std::move(message).CastToWritable()); if (retainedMessage != nullptr && mRetainMessageOnSend) { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); + *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(writableBuf.Retain()); } - return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message)); + return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(writableBuf)); } bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index 78fac185ec66f5..a84defaeadee7f 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -41,6 +41,7 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::SendMessageImpl(SecureSessionHa if (retainedMessage != nullptr) { *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); + ChipLogError(Inet, "RETAINED IN SESS: %p %d", retainedMessage, (*retainedMessage).IsNull()); } return mTransportMgr->SendMessage(mPeerAddress, std::move(message)); } @@ -50,11 +51,17 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::ResendMessage(SecureSessionHand { ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); + // Our send path needs a (writable) PacketBuffer, so get that from the + // EncryptedPacketBufferHandle. Note that we have to do this before we set + // *retainedMessage, because 'message' and '*retainedMessage' might be the + // same memory location and we have to guarantee that we move out of + // 'message' before we write to *retainedMessage. + System::PacketBufferHandle writableBuf(std::move(message).CastToWritable()); if (retainedMessage != nullptr) { - *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); + *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(writableBuf.Retain()); } - return mTransportMgr->SendMessage(mPeerAddress, std::move(message)); + return mTransportMgr->SendMessage(mPeerAddress, std::move(writableBuf)); } CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index cda9a7bc0e6562..d15a1b40fcc4d2 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -134,14 +134,18 @@ CHIP_ERROR SecureSessionMgr::SendEncryptedMessage(SecureSessionHandle session, E EncryptedPacketBufferHandle * bufferRetainSlot) { VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(!msgBuf->HasChainedBuffer(), CHIP_ERROR_INVALID_MESSAGE_LENGTH); + VerifyOrReturnError(!msgBuf.HasChainedBuffer(), CHIP_ERROR_INVALID_MESSAGE_LENGTH); + + // Our send path needs a (writable) PacketBuffer (e.g. so it can encode a + // PacketHeader into it), so get that from the EncryptedPacketBufferHandle. + System::PacketBufferHandle mutableBuf(std::move(msgBuf).CastToWritable()); // Advancing the start to encrypted header, since SendMessage will attach the packet header on top of it. PacketHeader packetHeader; - ReturnErrorOnFailure(packetHeader.DecodeAndConsume(msgBuf)); + ReturnErrorOnFailure(packetHeader.DecodeAndConsume(mutableBuf)); PayloadHeader payloadHeader; - return SendMessage(session, payloadHeader, packetHeader, std::move(msgBuf), bufferRetainSlot, + return SendMessage(session, payloadHeader, packetHeader, std::move(mutableBuf), bufferRetainSlot, EncryptionState::kPayloadIsEncrypted); } diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index f1a4ae78abaffd..576834fdfa93f9 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -53,7 +53,7 @@ namespace chip { * EncryptedPacketBufferHandle is a kind of PacketBufferHandle class and used to hold a packet buffer * object whose payload has already been encrypted. */ -class EncryptedPacketBufferHandle final : public System::PacketBufferHandle +class EncryptedPacketBufferHandle final : private System::PacketBufferHandle { public: EncryptedPacketBufferHandle() {} @@ -61,6 +61,11 @@ class EncryptedPacketBufferHandle final : public System::PacketBufferHandle void operator=(EncryptedPacketBufferHandle && aBuffer) { PacketBufferHandle::operator=(std::move(aBuffer)); } + using System::PacketBufferHandle::IsNull; + // Pass-through to HasChainedBuffer on our underlying buffer without + // exposing operator-> + bool HasChainedBuffer() const { return (*this)->HasChainedBuffer(); } + uint32_t GetMsgId() const; /** @@ -97,6 +102,16 @@ class EncryptedPacketBufferHandle final : public System::PacketBufferHandle return EncryptedPacketBufferHandle(std::move(aBuffer)); } + /** + * Get a handle to the data that allows mutating the bytes. This should + * only be used if absolutely necessary, because EncryptedPacketBufferHandle + * represents a buffer that we want to resend as-is. + * + * We only allow doing this with an rvalue reference, so the fact that we + * are moving out of the EncryptedPacketBufferHandle is clear. + */ + PacketBufferHandle CastToWritable() && { return PacketBufferHandle(std::move(*this)); } + private: EncryptedPacketBufferHandle(PacketBufferHandle && aBuffer) : PacketBufferHandle(std::move(aBuffer)) {} };