Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that we send MRP acks to incoming messages as needed. #19398

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/messaging/ApplicationExchangeDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
namespace chip {
namespace Messaging {

bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type)
bool ApplicationExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8_t type)
{
// TODO: Change this check to only include the protocol and message types that are allowed
switch (protocol)
if (protocol == Protocols::SecureChannel::Id)
{
case Protocols::SecureChannel::Id.GetProtocolId():
switch (type)
{
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PBKDFParamRequest):
Expand All @@ -49,11 +48,8 @@ bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t ty
default:
break;
}
break;

default:
break;
}

return true;
}

Expand Down
2 changes: 1 addition & 1 deletion src/messaging/ApplicationExchangeDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ApplicationExchangeDispatch : public ExchangeMessageDispatch
~ApplicationExchangeDispatch() override {}

protected:
bool MessagePermitted(uint16_t protocol, uint8_t type) override;
bool MessagePermitted(Protocols::Id protocol, uint8_t type) override;
};

} // namespace Messaging
Expand Down
20 changes: 18 additions & 2 deletions src/messaging/ExchangeContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,21 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload
MessageHandled();
});

ReturnErrorOnFailure(mDispatch.OnMessageReceived(messageCounter, payloadHeader, msgFlags, GetReliableMessageContext()));
if (mDispatch.IsReliableTransmissionAllowed() && !IsGroupExchangeContext())
{
if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() &&
payloadHeader.GetAckMessageCounter().HasValue())
{
HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value());
}

if (payloadHeader.NeedsAck())
{
// An acknowledgment needs to be sent back to the peer for this message on this exchange,

HandleNeedsAck(messageCounter, msgFlags);
}
}

if (IsAckPending() && !mDelegate)
{
Expand Down Expand Up @@ -487,7 +501,9 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload
// is implicitly that response.
SetResponseExpected(false);

if (mDelegate != nullptr)
// Don't send messages on to our delegate if our dispatch does not allow
// those messages.
if (mDelegate != nullptr && mDispatch.MessagePermitted(payloadHeader.GetProtocolID(), payloadHeader.GetMessageType()))
{
return mDelegate->OnMessageReceived(this, payloadHeader, std::move(msgBuf));
}
Expand Down
27 changes: 1 addition & 26 deletions src/messaging/ExchangeMessageDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager,
bool isReliableTransmission, Protocols::Id protocol, uint8_t type,
System::PacketBufferHandle && message)
{
ReturnErrorCodeIf(!MessagePermitted(protocol.GetProtocolId(), type), CHIP_ERROR_INVALID_ARGUMENT);
ReturnErrorCodeIf(!MessagePermitted(protocol, type), CHIP_ERROR_INVALID_ARGUMENT);

PayloadHeader payloadHeader;
payloadHeader.SetExchangeID(exchangeId).SetMessageType(protocol, type).SetInitiator(isInitiator);
Expand Down Expand Up @@ -113,30 +113,5 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager,
return CHIP_NO_ERROR;
}

CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader,
MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext)
{
ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()),
CHIP_ERROR_INVALID_ARGUMENT);

if (IsReliableTransmissionAllowed() && !reliableMessageContext->GetExchangeContext()->IsGroupExchangeContext())
{
if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() &&
payloadHeader.GetAckMessageCounter().HasValue())
{
reliableMessageContext->HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value());
}

if (payloadHeader.NeedsAck())
{
// An acknowledgment needs to be sent back to the peer for this message on this exchange,

ReturnErrorOnFailure(reliableMessageContext->HandleNeedsAck(messageCounter, msgFlags));
}
}

return CHIP_NO_ERROR;
}

} // namespace Messaging
} // namespace chip
6 changes: 2 additions & 4 deletions src/messaging/ExchangeMessageDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#pragma once

#include <messaging/Flags.h>
#include <protocols/Protocols.h>
#include <transport/SessionManager.h>

namespace chip {
Expand All @@ -42,11 +43,8 @@ class ExchangeMessageDispatch
CHIP_ERROR SendMessage(SessionManager * sessionManager, const SessionHandle & session, uint16_t exchangeId, bool isInitiator,
ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol,
uint8_t type, System::PacketBufferHandle && message);
CHIP_ERROR OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, MessageFlags msgFlags,
ReliableMessageContext * reliableMessageContext);

protected:
virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0;
virtual bool MessagePermitted(Protocols::Id protocol, uint8_t type) = 0;

// TODO: remove IsReliableTransmissionAllowed, this function should be provided over session.
virtual bool IsReliableTransmissionAllowed() const { return true; }
Expand Down
18 changes: 15 additions & 3 deletions src/messaging/ExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,23 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const
ChipLogDetail(ExchangeManager, "Handling via exchange: " ChipLogFormatExchange ", Delegate: %p", ChipLogValueExchange(ec),
ec->GetDelegate());

if (ec->IsEncryptionRequired() != packetHeader.IsEncrypted())
// Make sure the exchange stays alive through the code below even if we
// close it before calling HandleMessage.
ExchangeHandle ref(*ec);

// Ignore encryption-required mismatches for emphemeral exchanges,
// because those never have delegates anyway.
if (matchingUMH != nullptr && ec->IsEncryptionRequired() != packetHeader.IsEncrypted())
{
ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_INVALID_MESSAGE_TYPE));
// We want to still to do MRP processing for this message, but we do
// not want to deliver it to the application. Just close the
// exchange (which will notify the delegate, null it out, etc), then
// go ahead and call HandleMessage() on it to do the MRP
// processing.null out the delegate on the exchange, pretend to
// matchingUMH that exchange creation failed, so it cleans up the
// delegate, then tell the exchagne to handle the message.
ChipLogProgress(ExchangeManager, "OnMessageReceived encryption mismatch");
ec->Close();
return;
kghost marked this conversation as resolved.
Show resolved Hide resolved
}

CHIP_ERROR err = ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, msgFlags, std::move(msgBuf));
Expand Down
2 changes: 1 addition & 1 deletion src/messaging/tests/TestReliableMessageProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ApplicationEx
public:
bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; }

bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; }
bool MessagePermitted(Protocols::Id protocol, uint8_t type) override { return true; }

bool IsEncryptionRequired() const override { return mRequireEncryption; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ namespace chip {

using namespace Messaging;

bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type)
bool SessionEstablishmentExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8_t type)
{
switch (protocol)
if (protocol == Protocols::SecureChannel::Id)
{
case Protocols::SecureChannel::Id.GetProtocolId():
switch (type)
{
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::StandaloneAck):
Expand All @@ -52,11 +51,8 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, u
default:
break;
}
break;

default:
break;
}

return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi
~SessionEstablishmentExchangeDispatch() override {}

protected:
bool MessagePermitted(uint16_t protocol, uint8_t type) override;
bool MessagePermitted(Protocols::Id, uint8_t type) override;
bool IsEncryptionRequired() const override { return false; }
};

Expand Down
9 changes: 4 additions & 5 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,11 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
ctx.DrainAndServiceIO();

auto & loopback = ctx.GetLoopback();
NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 1);
// There should have been two message sent: Sigma1 and an ack.
NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 2);

// Clear pending packet in CRMP
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = context->GetReliableMessageContext();
rm->ClearRetransTable(rc);
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);

loopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST;

Expand Down
9 changes: 4 additions & 5 deletions src/protocols/secure_channel/tests/TestPASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,11 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
&delegate) == CHIP_NO_ERROR);
ctx.DrainAndServiceIO();

NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 1);
// There should have been two messages sent: PBKDFParamRequest and an ack.
NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 2);

// Clear pending packet in CRMP
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = context->GetReliableMessageContext();
rm->ClearRetransTable(rc);
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);

loopback.Reset();
loopback.mSentMessageCount = 0;
Expand Down