diff --git a/src/messaging/ApplicationExchangeDispatch.cpp b/src/messaging/ApplicationExchangeDispatch.cpp index 3e2ad293b7d34b..170e6bb702864d 100644 --- a/src/messaging/ApplicationExchangeDispatch.cpp +++ b/src/messaging/ApplicationExchangeDispatch.cpp @@ -26,12 +26,11 @@ namespace chip { namespace Messaging { -bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) +bool ApplicationExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8_t type) { // TODO: Change this check to only include the protocol and message types that are allowed - switch (protocol) + if (protocol == Protocols::SecureChannel::Id) { - case Protocols::SecureChannel::Id.GetProtocolId(): switch (type) { case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamRequest): @@ -49,11 +48,8 @@ bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t ty default: break; } - break; - - default: - break; } + return true; } diff --git a/src/messaging/ApplicationExchangeDispatch.h b/src/messaging/ApplicationExchangeDispatch.h index 6316df05f1b44f..04d7a0d1b4c7d1 100644 --- a/src/messaging/ApplicationExchangeDispatch.h +++ b/src/messaging/ApplicationExchangeDispatch.h @@ -43,7 +43,7 @@ class ApplicationExchangeDispatch : public ExchangeMessageDispatch ~ApplicationExchangeDispatch() override {} protected: - bool MessagePermitted(uint16_t protocol, uint8_t type) override; + bool MessagePermitted(Protocols::Id protocol, uint8_t type) override; }; } // namespace Messaging diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index b0aeac785b9beb..cb67cdbbe03a9b 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -458,7 +458,21 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload MessageHandled(); }); - ReturnErrorOnFailure(mDispatch.OnMessageReceived(messageCounter, payloadHeader, msgFlags, GetReliableMessageContext())); + if (mDispatch.IsReliableTransmissionAllowed() && !IsGroupExchangeContext()) + { + if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() && + payloadHeader.GetAckMessageCounter().HasValue()) + { + HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value()); + } + + if (payloadHeader.NeedsAck()) + { + // An acknowledgment needs to be sent back to the peer for this message on this exchange, + + HandleNeedsAck(messageCounter, msgFlags); + } + } if (IsAckPending() && !mDelegate) { @@ -487,7 +501,9 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload // is implicitly that response. SetResponseExpected(false); - if (mDelegate != nullptr) + // Don't send messages on to our delegate if our dispatch does not allow + // those messages. + if (mDelegate != nullptr && mDispatch.MessagePermitted(payloadHeader.GetProtocolID(), payloadHeader.GetMessageType())) { return mDelegate->OnMessageReceived(this, payloadHeader, std::move(msgBuf)); } diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp index 3e318a72d247af..4bd14e89da600f 100644 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ b/src/messaging/ExchangeMessageDispatch.cpp @@ -46,7 +46,7 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager, bool isReliableTransmission, Protocols::Id protocol, uint8_t type, System::PacketBufferHandle && message) { - ReturnErrorCodeIf(!MessagePermitted(protocol.GetProtocolId(), type), CHIP_ERROR_INVALID_ARGUMENT); + ReturnErrorCodeIf(!MessagePermitted(protocol, type), CHIP_ERROR_INVALID_ARGUMENT); PayloadHeader payloadHeader; payloadHeader.SetExchangeID(exchangeId).SetMessageType(protocol, type).SetInitiator(isInitiator); @@ -113,30 +113,5 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager, return CHIP_NO_ERROR; } -CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, - MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext) -{ - ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()), - CHIP_ERROR_INVALID_ARGUMENT); - - if (IsReliableTransmissionAllowed() && !reliableMessageContext->GetExchangeContext()->IsGroupExchangeContext()) - { - if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() && - payloadHeader.GetAckMessageCounter().HasValue()) - { - reliableMessageContext->HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value()); - } - - if (payloadHeader.NeedsAck()) - { - // An acknowledgment needs to be sent back to the peer for this message on this exchange, - - ReturnErrorOnFailure(reliableMessageContext->HandleNeedsAck(messageCounter, msgFlags)); - } - } - - return CHIP_NO_ERROR; -} - } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index 6c6fefb24fcdeb..35f12e06e80d36 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -24,6 +24,7 @@ #pragma once #include +#include #include namespace chip { @@ -42,11 +43,8 @@ class ExchangeMessageDispatch CHIP_ERROR SendMessage(SessionManager * sessionManager, const SessionHandle & session, uint16_t exchangeId, bool isInitiator, ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol, uint8_t type, System::PacketBufferHandle && message); - CHIP_ERROR OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, MessageFlags msgFlags, - ReliableMessageContext * reliableMessageContext); -protected: - virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; + virtual bool MessagePermitted(Protocols::Id protocol, uint8_t type) = 0; // TODO: remove IsReliableTransmissionAllowed, this function should be provided over session. virtual bool IsReliableTransmissionAllowed() const { return true; } diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index fc5bea10b83e22..1a7a3ca4cc1fb1 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -310,11 +310,23 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const ChipLogDetail(ExchangeManager, "Handling via exchange: " ChipLogFormatExchange ", Delegate: %p", ChipLogValueExchange(ec), ec->GetDelegate()); - if (ec->IsEncryptionRequired() != packetHeader.IsEncrypted()) + // Make sure the exchange stays alive through the code below even if we + // close it before calling HandleMessage. + ExchangeHandle ref(*ec); + + // Ignore encryption-required mismatches for emphemeral exchanges, + // because those never have delegates anyway. + if (matchingUMH != nullptr && ec->IsEncryptionRequired() != packetHeader.IsEncrypted()) { - ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_INVALID_MESSAGE_TYPE)); + // We want to still to do MRP processing for this message, but we do + // not want to deliver it to the application. Just close the + // exchange (which will notify the delegate, null it out, etc), then + // go ahead and call HandleMessage() on it to do the MRP + // processing.null out the delegate on the exchange, pretend to + // matchingUMH that exchange creation failed, so it cleans up the + // delegate, then tell the exchagne to handle the message. + ChipLogProgress(ExchangeManager, "OnMessageReceived encryption mismatch"); ec->Close(); - return; } CHIP_ERROR err = ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, msgFlags, std::move(msgBuf)); diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 8cd06bb79015b2..4ddf2c63919759 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -132,7 +132,7 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ApplicationEx public: bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; } - bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } + bool MessagePermitted(Protocols::Id protocol, uint8_t type) override { return true; } bool IsEncryptionRequired() const override { return mRequireEncryption; } diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index 240c3ada9dc087..3a26d0df1ee06d 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -28,11 +28,10 @@ namespace chip { using namespace Messaging; -bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type) +bool SessionEstablishmentExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8_t type) { - switch (protocol) + if (protocol == Protocols::SecureChannel::Id) { - case Protocols::SecureChannel::Id.GetProtocolId(): switch (type) { case static_cast(Protocols::SecureChannel::MsgType::StandaloneAck): @@ -52,11 +51,8 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, u default: break; } - break; - - default: - break; } + return false; } diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index 9cb75a596ff6a4..72083b45dd909c 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -41,7 +41,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi ~SessionEstablishmentExchangeDispatch() override {} protected: - bool MessagePermitted(uint16_t protocol, uint8_t type) override; + bool MessagePermitted(Protocols::Id, uint8_t type) override; bool IsEncryptionRequired() const override { return false; } }; diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 816682767df078..c66c350b5b4a7b 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -235,12 +235,11 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ctx.DrainAndServiceIO(); auto & loopback = ctx.GetLoopback(); - NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 1); + // There should have been two message sent: Sigma1 and an ack. + NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 2); - // Clear pending packet in CRMP - ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); - ReliableMessageContext * rc = context->GetReliableMessageContext(); - rm->ClearRetransTable(rc); + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); loopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 5c4b5243f91368..6b60ff3bd413ab 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -177,12 +177,11 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); - NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 1); + // There should have been two messages sent: PBKDFParamRequest and an ack. + NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 2); - // Clear pending packet in CRMP - ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); - ReliableMessageContext * rc = context->GetReliableMessageContext(); - rm->ClearRetransTable(rc); + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); loopback.Reset(); loopback.mSentMessageCount = 0;