diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 2c1dcc3c5c223c..cb7cda6f748697 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -433,8 +433,8 @@ CHIP_ERROR ExchangeContext::HandleMessage(const PacketHeader & packetHeader, con MessageHandled(); }); - ReturnErrorOnFailure(mDispatch->OnMessageReceived(packetHeader.GetFlags(), payloadHeader, packetHeader.GetMessageId(), - peerAddress, msgFlags, GetReliableMessageContext())); + ReturnErrorOnFailure(mDispatch->OnMessageReceived(payloadHeader, packetHeader.GetMessageId(), peerAddress, msgFlags, + GetReliableMessageContext())); if (IsAckPending() && !mDelegate) { diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp index 2da637bea71311..055b987577223b 100644 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ b/src/messaging/ExchangeMessageDispatch.cpp @@ -94,15 +94,10 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uin return CHIP_NO_ERROR; } -CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, - uint32_t messageId, const Transport::PeerAddress & peerAddress, - MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext) +CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, + const Transport::PeerAddress & peerAddress, MessageFlags msgFlags, + ReliableMessageContext * reliableMessageContext) { - if (IsEncryptionRequired()) - { - VerifyOrReturnError(headerFlags.Has(Header::FlagValues::kEncryptedMessage), CHIP_ERROR_INVALID_ARGUMENT); - } - ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()), CHIP_ERROR_INVALID_ARGUMENT); diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index 65bb350bf44a13..4bf43fa7b293ac 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -59,14 +59,13 @@ class ExchangeMessageDispatch : public ReferenceCounted virtual CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const = 0; - virtual CHIP_ERROR OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, uint32_t messageId, + virtual CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext); protected: virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; virtual bool IsReliableTransmissionAllowed() const { return true; } - virtual bool IsEncryptionRequired() const { return true; } }; } // namespace Messaging diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 1e94e5a1a49078..17b6c364d1e5ca 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -147,11 +147,7 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessa bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } - bool IsEncryptionRequired() const override { return mRequireEncryption; } - bool mRetainMessageOnSend = true; - - bool mRequireEncryption = false; }; class MockSessionEstablishmentDelegate : public ExchangeDelegate @@ -420,52 +416,6 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); } -void CheckUnencryptedMessageReceiveFailure(nlTestSuite * inSuite, void * inContext) -{ - TestContext & ctx = *reinterpret_cast(inContext); - - ctx.GetInetLayer().SystemLayer()->Init(); - - chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); - NL_TEST_ASSERT(inSuite, !buffer.IsNull()); - - MockSessionEstablishmentDelegate mockReceiver; - CHIP_ERROR err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - // Expect the received messages to be encrypted - mockReceiver.mMessageDispatch.mRequireEncryption = true; - - MockSessionEstablishmentDelegate mockSender; - ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender); - NL_TEST_ASSERT(inSuite, exchange != nullptr); - - ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); - ReliableMessageContext * rc = exchange->GetReliableMessageContext(); - NL_TEST_ASSERT(inSuite, rm != nullptr); - NL_TEST_ASSERT(inSuite, rc != nullptr); - - err = mockSender.mMessageDispatch.Init(); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - gLoopback.mSentMessageCount = 0; - gLoopback.mNumMessagesToDrop = 0; - gLoopback.mDroppedMessageCount = 0; - - err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer)); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - // Test that the message was actually sent (and not dropped) - NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); - NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 0); - // Test that the message was dropped by the receiver - NL_TEST_ASSERT(inSuite, !mockReceiver.IsOnMessageReceivedCalled); - - // Since peer dropped the message, we might have pending acks. Let's clear the table - rm->ClearRetransTable(rc); - - exchange->Close(); -} - void CheckResendApplicationMessageWithPeerExchange(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); @@ -1187,7 +1137,6 @@ const nlTest sTests[] = NL_TEST_DEF("Test sending an unsolicited ack-soliciting 'standalone ack' message", CheckSendUnsolicitedStandaloneAckMessage), NL_TEST_DEF("Test ReliableMessageMgr::CheckSendStandaloneAckMessage", CheckSendStandaloneAckMessage), NL_TEST_DEF("Test command, response, default response, with receiver closing exchange after sending response", CheckMessageAfterClosed), - NL_TEST_DEF("Test that unencrypted message is dropped if exchange requires encryption", CheckUnencryptedMessageReceiveFailure), NL_TEST_SENTINEL() }; diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index 8cb8ee03cc2183..59914b415231d9 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -47,15 +47,13 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::SendPreparedMessage(SecureSessi return mTransportMgr->SendMessage(mPeerAddress, preparedMessage.CastToWritable()); } -CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const Header::Flags & headerFlags, - const PayloadHeader & payloadHeader, uint32_t messageId, +CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, Messaging::MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext) { mPeerAddress = peerAddress; - return ExchangeMessageDispatch::OnMessageReceived(headerFlags, payloadHeader, messageId, peerAddress, msgFlags, - reliableMessageContext); + return ExchangeMessageDispatch::OnMessageReceived(payloadHeader, messageId, peerAddress, msgFlags, reliableMessageContext); } bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index 225a3410de5350..1a58d00f98e859 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -47,7 +47,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi EncryptedPacketBufferHandle & out) override; CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const override; - CHIP_ERROR OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, uint32_t messageId, + CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, Messaging::MessageFlags msgFlags, Messaging::ReliableMessageContext * reliableMessageContext) override; @@ -64,8 +64,6 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi return (mPeerAddress.GetTransportType() == Transport::Type::kUdp); } - bool IsEncryptionRequired() const override { return false; } - private: TransportMgrBase * mTransportMgr = nullptr; Transport::PeerAddress mPeerAddress;