Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent reception of unencrypted messages in an encrypted message exchange #8556

Merged
merged 2 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/messaging/ExchangeContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,8 @@ CHIP_ERROR ExchangeContext::HandleMessage(const PacketHeader & packetHeader, con
MessageHandled();
});

ReturnErrorOnFailure(mDispatch->OnMessageReceived(payloadHeader, packetHeader.GetMessageId(), peerAddress, msgFlags,
GetReliableMessageContext()));
ReturnErrorOnFailure(mDispatch->OnMessageReceived(packetHeader.GetFlags(), payloadHeader, packetHeader.GetMessageId(),
peerAddress, msgFlags, GetReliableMessageContext()));

if (IsAckPending() && !mDelegate)
{
Expand Down
11 changes: 8 additions & 3 deletions src/messaging/ExchangeMessageDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,15 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uin
return CHIP_NO_ERROR;
}

CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress, MessageFlags msgFlags,
ReliableMessageContext * reliableMessageContext)
CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader,
uint32_t messageId, const Transport::PeerAddress & peerAddress,
MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext)
{
if (IsEncryptionRequired())
{
VerifyOrReturnError(headerFlags.Has(Header::FlagValues::kEncryptedMessage), CHIP_ERROR_INVALID_ARGUMENT);
}
pan-apple marked this conversation as resolved.
Show resolved Hide resolved

ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()),
CHIP_ERROR_INVALID_ARGUMENT);

Expand Down
3 changes: 2 additions & 1 deletion src/messaging/ExchangeMessageDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ class ExchangeMessageDispatch : public ReferenceCounted<ExchangeMessageDispatch>
virtual CHIP_ERROR SendPreparedMessage(SecureSessionHandle session,
const EncryptedPacketBufferHandle & preparedMessage) const = 0;

virtual CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
virtual CHIP_ERROR OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress, MessageFlags msgFlags,
ReliableMessageContext * reliableMessageContext);

protected:
virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0;
virtual bool IsReliableTransmissionAllowed() const { return true; }
virtual bool IsEncryptionRequired() const { return true; }
};

} // namespace Messaging
Expand Down
51 changes: 51 additions & 0 deletions src/messaging/tests/TestReliableMessageProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessa

bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; }

bool IsEncryptionRequired() const override { return mRequireEncryption; }

bool mRetainMessageOnSend = true;

bool mRequireEncryption = false;
};

class MockSessionEstablishmentDelegate : public ExchangeDelegate
Expand Down Expand Up @@ -421,6 +425,52 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext)
rm->ClearRetransTable(rc);
}

void CheckUnencryptedMessageReceiveFailure(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

ctx.GetInetLayer().SystemLayer()->Init();

chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD));
NL_TEST_ASSERT(inSuite, !buffer.IsNull());

MockSessionEstablishmentDelegate mockReceiver;
CHIP_ERROR err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// Expect the received messages to be encrypted
mockReceiver.mMessageDispatch.mRequireEncryption = true;

MockSessionEstablishmentDelegate mockSender;
ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender);
NL_TEST_ASSERT(inSuite, exchange != nullptr);

ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = exchange->GetReliableMessageContext();
NL_TEST_ASSERT(inSuite, rm != nullptr);
NL_TEST_ASSERT(inSuite, rc != nullptr);

err = mockSender.mMessageDispatch.Init();
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

gLoopback.mSentMessageCount = 0;
gLoopback.mNumMessagesToDrop = 0;
gLoopback.mDroppedMessageCount = 0;

err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer));
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
// Test that the message was actually sent (and not dropped)
NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1);
NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 0);
// Test that the message was dropped by the receiver
NL_TEST_ASSERT(inSuite, !mockReceiver.IsOnMessageReceivedCalled);

// Since peer dropped the message, we might have pending acks. Let's clear the table
rm->ClearRetransTable(rc);

exchange->Close();
}

void CheckResendApplicationMessageWithPeerExchange(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
Expand Down Expand Up @@ -1161,6 +1211,7 @@ const nlTest sTests[] =
NL_TEST_DEF("Test sending an unsolicited ack-soliciting 'standalone ack' message", CheckSendUnsolicitedStandaloneAckMessage),
NL_TEST_DEF("Test ReliableMessageMgr::CheckSendStandaloneAckMessage", CheckSendStandaloneAckMessage),
NL_TEST_DEF("Test command, response, default response, with receiver closing exchange after sending response", CheckMessageAfterClosed),
NL_TEST_DEF("Test that unencrypted message is dropped if exchange requires encryption", CheckUnencryptedMessageReceiveFailure),

NL_TEST_SENTINEL()
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::SendPreparedMessage(SecureSessi
return mTransportMgr->SendMessage(mPeerAddress, preparedMessage.CastToWritable());
}

CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const Header::Flags & headerFlags,
const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress,
Messaging::MessageFlags msgFlags,
ReliableMessageContext * reliableMessageContext)
{
mPeerAddress = peerAddress;
return ExchangeMessageDispatch::OnMessageReceived(payloadHeader, messageId, peerAddress, msgFlags, reliableMessageContext);
return ExchangeMessageDispatch::OnMessageReceived(headerFlags, payloadHeader, messageId, peerAddress, msgFlags,
reliableMessageContext);
}

bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi
EncryptedPacketBufferHandle & out) override;
CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const override;

CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
CHIP_ERROR OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress, Messaging::MessageFlags msgFlags,
Messaging::ReliableMessageContext * reliableMessageContext) override;

Expand All @@ -64,6 +64,8 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi
return (mPeerAddress.GetTransportType() == Transport::Type::kUdp);
}

bool IsEncryptionRequired() const override { return false; }

private:
TransportMgrBase * mTransportMgr = nullptr;
Transport::PeerAddress mPeerAddress;
Expand Down