From fa1878438093c0f80e672bc0edf00a02934df2b8 Mon Sep 17 00:00:00 2001 From: Boris Zbarsky Date: Thu, 9 Jun 2022 13:09:11 -0400 Subject: [PATCH] Ensure that we send MRP acks to incoming messages as needed. There were two ways we could fail to send an ack to an incoming reliable message: 1) If we found no matching handler, and hence created an ephemeral exchange to handle the message, but the message was unencrypted. In this case our ephemeral exchange would return true for IsEncryptionRequired(), because it would default to an ApplicationExchangeDispatch, and we would never call into ExchangeContext::HandleMessage. 2) If ExchangeMessageDispatch::MessagePermitted returned false for the message. In particular, for an ApplicationExchangeDispatch, this would happen for all the handshake messages except StatusReport. The fix for issue 1 is to ensure we always call into HandleMEssage if we manage to allocate an exchange. If there is an encryption mismatch, which only matters when the exchange is non-ephemeral, we close the exchange first to prevent event delivery to the delegate. The fix for issue 2 is to move the MRP processing out of ExchangeMessageDispatch and into ExchangeContext, and to move the MessagePermitted check so the only thing it prevents is delivery of the message to the delegate, not any other processing by the exchange. Fixes https://github.com/project-chip/connectedhomeip/issues/10515 --- src/messaging/ApplicationExchangeDispatch.cpp | 10 +++---- src/messaging/ApplicationExchangeDispatch.h | 2 +- src/messaging/ExchangeContext.cpp | 20 ++++++++++++-- src/messaging/ExchangeMessageDispatch.cpp | 27 +------------------ src/messaging/ExchangeMessageDispatch.h | 6 ++--- src/messaging/ExchangeMgr.cpp | 18 ++++++++++--- .../tests/TestReliableMessageProtocol.cpp | 2 +- .../SessionEstablishmentExchangeDispatch.cpp | 10 +++---- .../SessionEstablishmentExchangeDispatch.h | 2 +- .../secure_channel/tests/TestCASESession.cpp | 9 +++---- .../secure_channel/tests/TestPASESession.cpp | 9 +++---- 11 files changed, 53 insertions(+), 62 deletions(-) diff --git a/src/messaging/ApplicationExchangeDispatch.cpp b/src/messaging/ApplicationExchangeDispatch.cpp index 3e2ad293b7d34b..170e6bb702864d 100644 --- a/src/messaging/ApplicationExchangeDispatch.cpp +++ b/src/messaging/ApplicationExchangeDispatch.cpp @@ -26,12 +26,11 @@ namespace chip { namespace Messaging { -bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) +bool ApplicationExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8_t type) { // TODO: Change this check to only include the protocol and message types that are allowed - switch (protocol) + if (protocol == Protocols::SecureChannel::Id) { - case Protocols::SecureChannel::Id.GetProtocolId(): switch (type) { case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamRequest): @@ -49,11 +48,8 @@ bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t ty default: break; } - break; - - default: - break; } + return true; } diff --git a/src/messaging/ApplicationExchangeDispatch.h b/src/messaging/ApplicationExchangeDispatch.h index 6316df05f1b44f..04d7a0d1b4c7d1 100644 --- a/src/messaging/ApplicationExchangeDispatch.h +++ b/src/messaging/ApplicationExchangeDispatch.h @@ -43,7 +43,7 @@ class ApplicationExchangeDispatch : public ExchangeMessageDispatch ~ApplicationExchangeDispatch() override {} protected: - bool MessagePermitted(uint16_t protocol, uint8_t type) override; + bool MessagePermitted(Protocols::Id protocol, uint8_t type) override; }; } // namespace Messaging diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index b0aeac785b9beb..cb67cdbbe03a9b 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -458,7 +458,21 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload MessageHandled(); }); - ReturnErrorOnFailure(mDispatch.OnMessageReceived(messageCounter, payloadHeader, msgFlags, GetReliableMessageContext())); + if (mDispatch.IsReliableTransmissionAllowed() && !IsGroupExchangeContext()) + { + if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() && + payloadHeader.GetAckMessageCounter().HasValue()) + { + HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value()); + } + + if (payloadHeader.NeedsAck()) + { + // An acknowledgment needs to be sent back to the peer for this message on this exchange, + + HandleNeedsAck(messageCounter, msgFlags); + } + } if (IsAckPending() && !mDelegate) { @@ -487,7 +501,9 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload // is implicitly that response. SetResponseExpected(false); - if (mDelegate != nullptr) + // Don't send messages on to our delegate if our dispatch does not allow + // those messages. + if (mDelegate != nullptr && mDispatch.MessagePermitted(payloadHeader.GetProtocolID(), payloadHeader.GetMessageType())) { return mDelegate->OnMessageReceived(this, payloadHeader, std::move(msgBuf)); } diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp index 3e318a72d247af..4bd14e89da600f 100644 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ b/src/messaging/ExchangeMessageDispatch.cpp @@ -46,7 +46,7 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager, bool isReliableTransmission, Protocols::Id protocol, uint8_t type, System::PacketBufferHandle && message) { - ReturnErrorCodeIf(!MessagePermitted(protocol.GetProtocolId(), type), CHIP_ERROR_INVALID_ARGUMENT); + ReturnErrorCodeIf(!MessagePermitted(protocol, type), CHIP_ERROR_INVALID_ARGUMENT); PayloadHeader payloadHeader; payloadHeader.SetExchangeID(exchangeId).SetMessageType(protocol, type).SetInitiator(isInitiator); @@ -113,30 +113,5 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager, return CHIP_NO_ERROR; } -CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, - MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext) -{ - ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()), - CHIP_ERROR_INVALID_ARGUMENT); - - if (IsReliableTransmissionAllowed() && !reliableMessageContext->GetExchangeContext()->IsGroupExchangeContext()) - { - if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() && - payloadHeader.GetAckMessageCounter().HasValue()) - { - reliableMessageContext->HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value()); - } - - if (payloadHeader.NeedsAck()) - { - // An acknowledgment needs to be sent back to the peer for this message on this exchange, - - ReturnErrorOnFailure(reliableMessageContext->HandleNeedsAck(messageCounter, msgFlags)); - } - } - - return CHIP_NO_ERROR; -} - } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index 6c6fefb24fcdeb..35f12e06e80d36 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -24,6 +24,7 @@ #pragma once #include +#include #include namespace chip { @@ -42,11 +43,8 @@ class ExchangeMessageDispatch CHIP_ERROR SendMessage(SessionManager * sessionManager, const SessionHandle & session, uint16_t exchangeId, bool isInitiator, ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol, uint8_t type, System::PacketBufferHandle && message); - CHIP_ERROR OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, MessageFlags msgFlags, - ReliableMessageContext * reliableMessageContext); -protected: - virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; + virtual bool MessagePermitted(Protocols::Id protocol, uint8_t type) = 0; // TODO: remove IsReliableTransmissionAllowed, this function should be provided over session. virtual bool IsReliableTransmissionAllowed() const { return true; } diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index fc5bea10b83e22..1a7a3ca4cc1fb1 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -310,11 +310,23 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const ChipLogDetail(ExchangeManager, "Handling via exchange: " ChipLogFormatExchange ", Delegate: %p", ChipLogValueExchange(ec), ec->GetDelegate()); - if (ec->IsEncryptionRequired() != packetHeader.IsEncrypted()) + // Make sure the exchange stays alive through the code below even if we + // close it before calling HandleMessage. + ExchangeHandle ref(*ec); + + // Ignore encryption-required mismatches for emphemeral exchanges, + // because those never have delegates anyway. + if (matchingUMH != nullptr && ec->IsEncryptionRequired() != packetHeader.IsEncrypted()) { - ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_INVALID_MESSAGE_TYPE)); + // We want to still to do MRP processing for this message, but we do + // not want to deliver it to the application. Just close the + // exchange (which will notify the delegate, null it out, etc), then + // go ahead and call HandleMessage() on it to do the MRP + // processing.null out the delegate on the exchange, pretend to + // matchingUMH that exchange creation failed, so it cleans up the + // delegate, then tell the exchagne to handle the message. + ChipLogProgress(ExchangeManager, "OnMessageReceived encryption mismatch"); ec->Close(); - return; } CHIP_ERROR err = ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, msgFlags, std::move(msgBuf)); diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 8cd06bb79015b2..4ddf2c63919759 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -132,7 +132,7 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ApplicationEx public: bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; } - bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } + bool MessagePermitted(Protocols::Id protocol, uint8_t type) override { return true; } bool IsEncryptionRequired() const override { return mRequireEncryption; } diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index 240c3ada9dc087..3a26d0df1ee06d 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -28,11 +28,10 @@ namespace chip { using namespace Messaging; -bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) +bool SessionEstablishmentExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8_t type) { - switch (protocol) + if (protocol == Protocols::SecureChannel::Id) { - case Protocols::SecureChannel::Id.GetProtocolId(): switch (type) { case static_cast(Protocols::SecureChannel::MsgType::StandaloneAck): @@ -52,11 +51,8 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, u default: break; } - break; - - default: - break; } + return false; } diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index 9cb75a596ff6a4..72083b45dd909c 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -41,7 +41,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi ~SessionEstablishmentExchangeDispatch() override {} protected: - bool MessagePermitted(uint16_t protocol, uint8_t type) override; + bool MessagePermitted(Protocols::Id, uint8_t type) override; bool IsEncryptionRequired() const override { return false; } }; diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 816682767df078..c66c350b5b4a7b 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -235,12 +235,11 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ctx.DrainAndServiceIO(); auto & loopback = ctx.GetLoopback(); - NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 1); + // There should have been two message sent: Sigma1 and an ack. + NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 2); - // Clear pending packet in CRMP - ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); - ReliableMessageContext * rc = context->GetReliableMessageContext(); - rm->ClearRetransTable(rc); + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); loopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 5c4b5243f91368..6b60ff3bd413ab 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -177,12 +177,11 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); - NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 1); + // There should have been two messages sent: PBKDFParamRequest and an ack. + NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 2); - // Clear pending packet in CRMP - ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); - ReliableMessageContext * rc = context->GetReliableMessageContext(); - rm->ClearRetransTable(rc); + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); loopback.Reset(); loopback.mSentMessageCount = 0;