diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index b8f387258e93ea..ffaaf5d1ceae3a 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -425,8 +425,8 @@ CHIP_ERROR ExchangeContext::HandleMessage(const PacketHeader & packetHeader, con MessageHandled(); }); - ReturnErrorOnFailure(mDispatch->OnMessageReceived(payloadHeader, packetHeader.GetMessageId(), peerAddress, msgFlags, - GetReliableMessageContext())); + ReturnErrorOnFailure(mDispatch->OnMessageReceived(packetHeader.GetFlags(), payloadHeader, packetHeader.GetMessageId(), + peerAddress, msgFlags, GetReliableMessageContext())); if (IsAckPending() && !mDelegate) { diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp index 055b987577223b..2da637bea71311 100644 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ b/src/messaging/ExchangeMessageDispatch.cpp @@ -94,10 +94,15 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uin return CHIP_NO_ERROR; } -CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, - const Transport::PeerAddress & peerAddress, MessageFlags msgFlags, - ReliableMessageContext * reliableMessageContext) +CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, + uint32_t messageId, const Transport::PeerAddress & peerAddress, + MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext) { + if (IsEncryptionRequired()) + { + VerifyOrReturnError(headerFlags.Has(Header::FlagValues::kEncryptedMessage), CHIP_ERROR_INVALID_ARGUMENT); + } + ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()), CHIP_ERROR_INVALID_ARGUMENT); diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index 4bf43fa7b293ac..65bb350bf44a13 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -59,13 +59,14 @@ class ExchangeMessageDispatch : public ReferenceCounted virtual CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const = 0; - virtual CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, + virtual CHIP_ERROR OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext); protected: virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; virtual bool IsReliableTransmissionAllowed() const { return true; } + virtual bool IsEncryptionRequired() const { return true; } }; } // namespace Messaging diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index e62658f8e11069..a34fb6cfe5e866 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -147,7 +147,11 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessa bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } + bool IsEncryptionRequired() const override { return mRequireEncryption; } + bool mRetainMessageOnSend = true; + + bool mRequireEncryption = false; }; class MockSessionEstablishmentDelegate : public ExchangeDelegate @@ -421,6 +425,44 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) rm->ClearRetransTable(rc); } +void CheckUnencryptedMessageReceiveFailure(nlTestSuite * inSuite, void * inContext) +{ + TestContext & ctx = *reinterpret_cast(inContext); + + ctx.GetInetLayer().SystemLayer()->Init(); + + chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); + NL_TEST_ASSERT(inSuite, !buffer.IsNull()); + + MockSessionEstablishmentDelegate mockReceiver; + CHIP_ERROR err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // Expect the received messages to be encrypted + mockReceiver.mMessageDispatch.mRequireEncryption = true; + + MockSessionEstablishmentDelegate mockSender; + ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender); + NL_TEST_ASSERT(inSuite, exchange != nullptr); + + err = mockSender.mMessageDispatch.Init(); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + gLoopback.mSentMessageCount = 0; + gLoopback.mNumMessagesToDrop = 0; + gLoopback.mDroppedMessageCount = 0; + + err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer)); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + // Test that the message was actually sent (and not dropped) + NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); + NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 0); + // Test that the message was dropped by the receiver + NL_TEST_ASSERT(inSuite, !mockReceiver.IsOnMessageReceivedCalled); + + exchange->Close(); +} + void CheckResendApplicationMessageWithPeerExchange(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); @@ -1161,6 +1203,7 @@ const nlTest sTests[] = NL_TEST_DEF("Test sending an unsolicited ack-soliciting 'standalone ack' message", CheckSendUnsolicitedStandaloneAckMessage), NL_TEST_DEF("Test ReliableMessageMgr::CheckSendStandaloneAckMessage", CheckSendStandaloneAckMessage), NL_TEST_DEF("Test command, response, default response, with receiver closing exchange after sending response", CheckMessageAfterClosed), + NL_TEST_DEF("Test that unencrypted message is dropped if exchange requires encryption", CheckUnencryptedMessageReceiveFailure), NL_TEST_SENTINEL() }; diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index 59914b415231d9..8cb8ee03cc2183 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -47,13 +47,15 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::SendPreparedMessage(SecureSessi return mTransportMgr->SendMessage(mPeerAddress, preparedMessage.CastToWritable()); } -CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, +CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const Header::Flags & headerFlags, + const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, Messaging::MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext) { mPeerAddress = peerAddress; - return ExchangeMessageDispatch::OnMessageReceived(payloadHeader, messageId, peerAddress, msgFlags, reliableMessageContext); + return ExchangeMessageDispatch::OnMessageReceived(headerFlags, payloadHeader, messageId, peerAddress, msgFlags, + reliableMessageContext); } bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index 1a58d00f98e859..225a3410de5350 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -47,7 +47,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi EncryptedPacketBufferHandle & out) override; CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const override; - CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, + CHIP_ERROR OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, Messaging::MessageFlags msgFlags, Messaging::ReliableMessageContext * reliableMessageContext) override; @@ -64,6 +64,8 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi return (mPeerAddress.GetTransportType() == Transport::Type::kUdp); } + bool IsEncryptionRequired() const override { return false; } + private: TransportMgrBase * mTransportMgr = nullptr; Transport::PeerAddress mPeerAddress;