Skip to content

Commit

Permalink
Prevent reception of unencrypted messages in an encrypted message exc…
Browse files Browse the repository at this point in the history
…hange (#8556)

* Prevent injection of unencrypted messages in an application message exchange

* fix test
  • Loading branch information
pan-apple authored and pull[bot] committed Sep 1, 2021
1 parent 63e5730 commit 1076380
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/messaging/ExchangeContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,8 @@ CHIP_ERROR ExchangeContext::HandleMessage(const PacketHeader & packetHeader, con
MessageHandled();
});

ReturnErrorOnFailure(mDispatch->OnMessageReceived(payloadHeader, packetHeader.GetMessageId(), peerAddress, msgFlags,
GetReliableMessageContext()));
ReturnErrorOnFailure(mDispatch->OnMessageReceived(packetHeader.GetFlags(), payloadHeader, packetHeader.GetMessageId(),
peerAddress, msgFlags, GetReliableMessageContext()));

if (IsAckPending() && !mDelegate)
{
Expand Down
11 changes: 8 additions & 3 deletions src/messaging/ExchangeMessageDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,15 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uin
return CHIP_NO_ERROR;
}

CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress, MessageFlags msgFlags,
ReliableMessageContext * reliableMessageContext)
CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader,
uint32_t messageId, const Transport::PeerAddress & peerAddress,
MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext)
{
if (IsEncryptionRequired())
{
VerifyOrReturnError(headerFlags.Has(Header::FlagValues::kEncryptedMessage), CHIP_ERROR_INVALID_ARGUMENT);
}

ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()),
CHIP_ERROR_INVALID_ARGUMENT);

Expand Down
3 changes: 2 additions & 1 deletion src/messaging/ExchangeMessageDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ class ExchangeMessageDispatch : public ReferenceCounted<ExchangeMessageDispatch>
virtual CHIP_ERROR SendPreparedMessage(SecureSessionHandle session,
const EncryptedPacketBufferHandle & preparedMessage) const = 0;

virtual CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
virtual CHIP_ERROR OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress, MessageFlags msgFlags,
ReliableMessageContext * reliableMessageContext);

protected:
virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0;
virtual bool IsReliableTransmissionAllowed() const { return true; }
virtual bool IsEncryptionRequired() const { return true; }
};

} // namespace Messaging
Expand Down
51 changes: 51 additions & 0 deletions src/messaging/tests/TestReliableMessageProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessa

bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; }

bool IsEncryptionRequired() const override { return mRequireEncryption; }

bool mRetainMessageOnSend = true;

bool mRequireEncryption = false;
};

class MockSessionEstablishmentDelegate : public ExchangeDelegate
Expand Down Expand Up @@ -416,6 +420,52 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);
}

void CheckUnencryptedMessageReceiveFailure(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

ctx.GetInetLayer().SystemLayer()->Init();

chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD));
NL_TEST_ASSERT(inSuite, !buffer.IsNull());

MockSessionEstablishmentDelegate mockReceiver;
CHIP_ERROR err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// Expect the received messages to be encrypted
mockReceiver.mMessageDispatch.mRequireEncryption = true;

MockSessionEstablishmentDelegate mockSender;
ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender);
NL_TEST_ASSERT(inSuite, exchange != nullptr);

ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = exchange->GetReliableMessageContext();
NL_TEST_ASSERT(inSuite, rm != nullptr);
NL_TEST_ASSERT(inSuite, rc != nullptr);

err = mockSender.mMessageDispatch.Init();
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

gLoopback.mSentMessageCount = 0;
gLoopback.mNumMessagesToDrop = 0;
gLoopback.mDroppedMessageCount = 0;

err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer));
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
// Test that the message was actually sent (and not dropped)
NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1);
NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 0);
// Test that the message was dropped by the receiver
NL_TEST_ASSERT(inSuite, !mockReceiver.IsOnMessageReceivedCalled);

// Since peer dropped the message, we might have pending acks. Let's clear the table
rm->ClearRetransTable(rc);

exchange->Close();
}

void CheckResendApplicationMessageWithPeerExchange(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
Expand Down Expand Up @@ -1137,6 +1187,7 @@ const nlTest sTests[] =
NL_TEST_DEF("Test sending an unsolicited ack-soliciting 'standalone ack' message", CheckSendUnsolicitedStandaloneAckMessage),
NL_TEST_DEF("Test ReliableMessageMgr::CheckSendStandaloneAckMessage", CheckSendStandaloneAckMessage),
NL_TEST_DEF("Test command, response, default response, with receiver closing exchange after sending response", CheckMessageAfterClosed),
NL_TEST_DEF("Test that unencrypted message is dropped if exchange requires encryption", CheckUnencryptedMessageReceiveFailure),

NL_TEST_SENTINEL()
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::SendPreparedMessage(SecureSessi
return mTransportMgr->SendMessage(mPeerAddress, preparedMessage.CastToWritable());
}

CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const Header::Flags & headerFlags,
const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress,
Messaging::MessageFlags msgFlags,
ReliableMessageContext * reliableMessageContext)
{
mPeerAddress = peerAddress;
return ExchangeMessageDispatch::OnMessageReceived(payloadHeader, messageId, peerAddress, msgFlags, reliableMessageContext);
return ExchangeMessageDispatch::OnMessageReceived(headerFlags, payloadHeader, messageId, peerAddress, msgFlags,
reliableMessageContext);
}

bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi
EncryptedPacketBufferHandle & out) override;
CHIP_ERROR SendPreparedMessage(SecureSessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const override;

CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
CHIP_ERROR OnMessageReceived(const Header::Flags & headerFlags, const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress, Messaging::MessageFlags msgFlags,
Messaging::ReliableMessageContext * reliableMessageContext) override;

Expand All @@ -64,6 +64,8 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi
return (mPeerAddress.GetTransportType() == Transport::Type::kUdp);
}

bool IsEncryptionRequired() const override { return false; }

private:
TransportMgrBase * mTransportMgr = nullptr;
Transport::PeerAddress mPeerAddress;
Expand Down

0 comments on commit 1076380

Please sign in to comment.