From df74de161936834944c7b1b371af434aad003044 Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Thu, 9 Jun 2022 13:32:17 +0800 Subject: [PATCH] Enforce maximum NONCE (#19037) * Enforce maximum NONCE * Add test-case * Resove conversations, fix comments * Resolve comments --- src/lib/core/CHIPError.cpp | 3 + src/lib/core/CHIPError.h | 8 +- src/lib/core/tests/TestCHIPErrorStr.cpp | 1 + .../secure_channel/PairingSession.cpp | 2 +- src/transport/MessageCounter.cpp | 2 +- src/transport/MessageCounter.h | 41 +++++----- src/transport/PeerMessageCounter.h | 3 +- src/transport/SessionManager.cpp | 12 +-- .../tests/TestPeerMessageCounter.cpp | 4 +- src/transport/tests/TestSessionManager.cpp | 78 +++++++++++++++++++ 10 files changed, 124 insertions(+), 30 deletions(-) diff --git a/src/lib/core/CHIPError.cpp b/src/lib/core/CHIPError.cpp index 5b398d7de8800e..706972fa59b567 100644 --- a/src/lib/core/CHIPError.cpp +++ b/src/lib/core/CHIPError.cpp @@ -437,6 +437,9 @@ bool FormatCHIPError(char * buf, uint16_t bufSize, CHIP_ERROR err) case CHIP_ERROR_DRBG_ENTROPY_SOURCE_FAILED.AsInteger(): desc = "DRBG entropy source failed to generate entropy data"; break; + case CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED.AsInteger(): + desc = "Message counter exhausted"; + break; case CHIP_ERROR_FABRIC_EXISTS.AsInteger(): desc = "Trying to add a NOC for a fabric that already exists"; break; diff --git a/src/lib/core/CHIPError.h b/src/lib/core/CHIPError.h index a9d104780d1a6d..4e67ef507c9b97 100644 --- a/src/lib/core/CHIPError.h +++ b/src/lib/core/CHIPError.h @@ -1554,7 +1554,13 @@ using CHIP_ERROR = ::chip::ChipError; */ #define CHIP_ERROR_IM_MALFORMED_STATUS_RESPONSE_MESSAGE CHIP_CORE_ERROR(0x7c) -// unused CHIP_CORE_ERROR(0x7d) +/** + * @def CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED + * + * @brief + * The message counter of the session is exhausted, the session should be closed. + */ +#define CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED CHIP_CORE_ERROR(0x7d) /** * @def CHIP_ERROR_FABRIC_EXISTS diff --git a/src/lib/core/tests/TestCHIPErrorStr.cpp b/src/lib/core/tests/TestCHIPErrorStr.cpp index 259d6bb416f492..67a9c40775455a 100644 --- a/src/lib/core/tests/TestCHIPErrorStr.cpp +++ b/src/lib/core/tests/TestCHIPErrorStr.cpp @@ -173,6 +173,7 @@ static const CHIP_ERROR kTestElements[] = CHIP_ERROR_IM_MALFORMED_ATTRIBUTE_REPORT_IB, CHIP_ERROR_IM_MALFORMED_EVENT_STATUS_IB, CHIP_ERROR_IM_MALFORMED_STATUS_RESPONSE_MESSAGE, + CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED, CHIP_ERROR_FABRIC_EXISTS, CHIP_ERROR_KEY_NOT_FOUND_FROM_PEER, CHIP_ERROR_WRONG_ENCRYPTION_TYPE_FROM_PEER, diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp index 806d0ec4306785..259f66c53be98e 100644 --- a/src/protocols/secure_channel/PairingSession.cpp +++ b/src/protocols/secure_channel/PairingSession.cpp @@ -44,7 +44,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress & ReturnErrorOnFailure(DeriveSecureSession(secureSession->GetCryptoContext())); uint16_t peerSessionId = GetPeerSessionId(); secureSession->SetPeerAddress(peerAddress); - secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(Transport::PeerMessageCounter::kInitialSyncValue); // Call Activate last, otherwise errors on anything after would lead to // a partially valid session. diff --git a/src/transport/MessageCounter.cpp b/src/transport/MessageCounter.cpp index 59fae9bfc9f2e7..99930cfd47c46f 100644 --- a/src/transport/MessageCounter.cpp +++ b/src/transport/MessageCounter.cpp @@ -29,7 +29,7 @@ namespace chip { void GlobalUnencryptedMessageCounter::Init() { - mValue = Crypto::GetRandU32(); + mLastUsedValue = Crypto::GetRandU32(); } } // namespace chip diff --git a/src/transport/MessageCounter.h b/src/transport/MessageCounter.h index 26f50568343047..06934c0ef8ea5f 100644 --- a/src/transport/MessageCounter.h +++ b/src/transport/MessageCounter.h @@ -29,12 +29,10 @@ namespace chip { /** - * MessageCounter represents a local message counter. There are 3 types - * of message counter + * MessageCounter represents a local message counter. There are 2 types of message counter * * 1. Global unencrypted message counter - * 2. Global encrypted message counter - * 3. Session message counter + * 2. Secure session message counter * * There will be separate implementations for each type */ @@ -50,53 +48,60 @@ class MessageCounter virtual ~MessageCounter() = default; - virtual Type GetType() const = 0; - virtual uint32_t Value() const = 0; /** Get current value */ - virtual CHIP_ERROR Advance() = 0; /** Advance the counter */ + virtual Type GetType() const = 0; + virtual CHIP_ERROR AdvanceAndConsume(uint32_t & fetch) = 0; /** Advance the counter, and feed the new counter to fetch */ }; class GlobalUnencryptedMessageCounter : public MessageCounter { public: - GlobalUnencryptedMessageCounter() : mValue(0) {} + GlobalUnencryptedMessageCounter() : mLastUsedValue(0) {} void Init(); Type GetType() const override { return GlobalUnencrypted; } - uint32_t Value() const override { return mValue; } - CHIP_ERROR Advance() override + CHIP_ERROR AdvanceAndConsume(uint32_t & fetch) override { - ++mValue; + fetch = ++mLastUsedValue; return CHIP_NO_ERROR; } private: - uint32_t mValue; + uint32_t mLastUsedValue; }; class LocalSessionMessageCounter : public MessageCounter { public: - static constexpr uint32_t kInitialSyncValue = 0; ///< Used for initializing peer counter + static constexpr uint32_t kMessageCounterMax = 0xFFFFFFFF; static constexpr uint32_t kMessageCounterRandomInitMask = 0x0FFFFFFF; ///< 28-bit mask /** * Initialize a local message counter with random value between [1, 2^28]. This increases the difficulty of traffic analysis * attacks by making it harder to determine how long a particular session has been open. The initial counter is always 1 or * higher to guarantee first message is always greater than initial peer counter set to 0. + * + * The mLastUsedValue is the predecessor of the initial value, it will be advanced before using, so don't need to add 1 here. */ - LocalSessionMessageCounter() { mValue = (Crypto::GetRandU32() & kMessageCounterRandomInitMask) + 1; } + LocalSessionMessageCounter() { mLastUsedValue = (Crypto::GetRandU32() & kMessageCounterRandomInitMask); } Type GetType() const override { return Session; } - uint32_t Value() const override { return mValue; } - CHIP_ERROR Advance() override + CHIP_ERROR AdvanceAndConsume(uint32_t & fetch) override { - ++mValue; + if (mLastUsedValue == kMessageCounterMax) + { + return CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED; + } + + fetch = ++mLastUsedValue; return CHIP_NO_ERROR; } + // Test-only function to set the counter value + void TestSetCounter(uint32_t value) { mLastUsedValue = value; } + private: - uint32_t mValue; + uint32_t mLastUsedValue; }; } // namespace chip diff --git a/src/transport/PeerMessageCounter.h b/src/transport/PeerMessageCounter.h index edfac7bf8ddb5a..e02981d98a3e9d 100644 --- a/src/transport/PeerMessageCounter.h +++ b/src/transport/PeerMessageCounter.h @@ -32,7 +32,8 @@ namespace Transport { class PeerMessageCounter { public: - static constexpr size_t kChallengeSize = 8; + static constexpr size_t kChallengeSize = 8; + static constexpr uint32_t kInitialSyncValue = 0; PeerMessageCounter() : mStatus(Status::NotSynced) {} ~PeerMessageCounter() { Reset(); } diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index f68ae57f535dae..0298e4595f09ab 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -187,7 +187,8 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P } MessageCounter & counter = session->GetSessionMessageCounter().GetLocalMessageCounter(); - uint32_t messageCounter = counter.Value(); + uint32_t messageCounter; + ReturnErrorOnFailure(counter.AdvanceAndConsume(messageCounter)); packetHeader .SetMessageCounter(messageCounter) // .SetSessionId(session->GetPeerSessionId()) // @@ -201,7 +202,6 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), messageCounter, sourceNodeId); ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session->GetCryptoContext(), nonce, payloadHeader, packetHeader, message)); - ReturnErrorOnFailure(counter.Advance()); #if CHIP_PROGRESS_LOGGING destination = session->GetPeerNodeId(); @@ -211,8 +211,8 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P break; case Transport::Session::SessionType::kUnauthenticated: { MessageCounter & counter = mGlobalUnencryptedMessageCounter; - uint32_t messageCounter = counter.Value(); - ReturnErrorOnFailure(counter.Advance()); + uint32_t messageCounter; + ReturnErrorOnFailure(counter.AdvanceAndConsume(messageCounter)); packetHeader.SetMessageCounter(messageCounter); Transport::UnauthenticatedSession * session = sessionHandle->AsUnauthenticatedSession(); switch (session->GetSessionRole()) @@ -408,7 +408,7 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH ByteSpan secret(reinterpret_cast(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE), secretLen); ReturnErrorOnFailure(secureSession->GetCryptoContext().InitFromSecret( secret, ByteSpan(nullptr, 0), CryptoContext::SessionInfoType::kSessionEstablishment, role)); - secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(Transport::PeerMessageCounter::kInitialSyncValue); sessionHolder.Grab(session.Value()); return CHIP_NO_ERROR; } @@ -675,7 +675,7 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade // Handle Group message counter here spec 4.7.3 // spec 4.5.1.2 for msg counter - chip::Transport::PeerMessageCounter * counter = nullptr; + Transport::PeerMessageCounter * counter = nullptr; if (CHIP_NO_ERROR == mGroupPeerMsgCounter.FindOrAddPeer(groupContext.fabric_index, packetHeader.GetSourceNodeId().Value(), diff --git a/src/transport/tests/TestPeerMessageCounter.cpp b/src/transport/tests/TestPeerMessageCounter.cpp index 63a4e5a89f6de5..2c10ce4c8e6a0a 100644 --- a/src/transport/tests/TestPeerMessageCounter.cpp +++ b/src/transport/tests/TestPeerMessageCounter.cpp @@ -187,7 +187,7 @@ void UnicastSmallStepTest(nlTestSuite * inSuite, void * inContext) for (uint32_t k = 1; k <= 2 * CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE; k++) { chip::Transport::PeerMessageCounter counter; - counter.SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + counter.SetCounter(chip::Transport::PeerMessageCounter::kInitialSyncValue); if (counter.VerifyEncryptedUnicast(n) == CHIP_NO_ERROR) { // Act like we got this counter value on the wire. @@ -259,7 +259,7 @@ void UnicastLargeStepTest(nlTestSuite * inSuite, void * inContext) for (uint32_t k = (static_cast(1 << 31) - 5); k <= (static_cast(1 << 31) - 1); k++) { chip::Transport::PeerMessageCounter counter; - counter.SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + counter.SetCounter(chip::Transport::PeerMessageCounter::kInitialSyncValue); if (counter.VerifyEncryptedUnicast(n) == CHIP_NO_ERROR) { // Act like we got this counter value on the wire. diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index d3d60855879bd9..ccf751bf594420 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -759,6 +759,83 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) sessionManager.Shutdown(); } +void SessionCounterExhaustedTest(nlTestSuite * inSuite, void * inContext) +{ + TestContext & ctx = *reinterpret_cast(inContext); + + IPAddress addr; + IPAddress::FromString("::1", addr); + CHIP_ERROR err = CHIP_NO_ERROR; + + FabricTable fabricTable; + SessionManager sessionManager; + secure_channel::MessageCounterManager gMessageCounterManager; + chip::TestPersistentStorageDelegate deviceStorage; + + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.Init(&deviceStorage)); + NL_TEST_ASSERT(inSuite, + CHIP_NO_ERROR == + sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &gMessageCounterManager, &deviceStorage, + &fabricTable)); + + Transport::PeerAddress peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); + + FabricIndex aliceFabricIndex; + FabricInfo aliceFabric; + aliceFabric.TestOnlyBuildFabric(GetRootACertAsset().mCert, GetIAA1CertAsset().mCert, GetNodeA1CertAsset().mCert, + GetNodeA1CertAsset().mKey); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabricForTest(aliceFabric, &aliceFabricIndex)); + + FabricIndex bobFabricIndex; + FabricInfo bobFabric; + bobFabric.TestOnlyBuildFabric(GetRootACertAsset().mCert, GetIAA1CertAsset().mCert, GetNodeA2CertAsset().mCert, + GetNodeA2CertAsset().mKey); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabricForTest(bobFabric, &bobFabricIndex)); + + SessionHolder aliceToBobSession; + err = sessionManager.InjectPaseSessionWithTestKey(aliceToBobSession, 2, + fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), 1, + aliceFabricIndex, peer, CryptoContext::SessionRole::kInitiator); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + SessionHolder bobToAliceSession; + err = sessionManager.InjectPaseSessionWithTestKey(bobToAliceSession, 1, + fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), 2, + bobFabricIndex, peer, CryptoContext::SessionRole::kResponder); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // ==== Set counter value to max ==== + LocalSessionMessageCounter & counter = static_cast( + aliceToBobSession.Get().Value()->AsSecureSession()->GetSessionMessageCounter().GetLocalMessageCounter()); + counter.TestSetCounter(LocalSessionMessageCounter::kMessageCounterMax - 1); + + // ==== Build a valid message with max counter value ==== + chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); + NL_TEST_ASSERT(inSuite, !buffer.IsNull()); + + PayloadHeader payloadHeader; + + // Set the exchange ID for this header. + payloadHeader.SetExchangeID(0); + + // Set the protocol ID and message type for this header. + payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest); + + EncryptedPacketBufferHandle preparedMessage; + err = sessionManager.PrepareMessage(aliceToBobSession.Get().Value(), payloadHeader, std::move(buffer), preparedMessage); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // ==== Build another message which will fail becuase message counter is exhausted ==== + chip::System::PacketBufferHandle buffer2 = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); + NL_TEST_ASSERT(inSuite, !buffer2.IsNull()); + + EncryptedPacketBufferHandle preparedMessage2; + err = sessionManager.PrepareMessage(aliceToBobSession.Get().Value(), payloadHeader, std::move(buffer2), preparedMessage2); + NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED); + + sessionManager.Shutdown(); +} + // Test Suite /** @@ -774,6 +851,7 @@ const nlTest sTests[] = NL_TEST_DEF("Old counter Test", SendPacketWithOldCounterTest), NL_TEST_DEF("Too-old counter Test", SendPacketWithTooOldCounterTest), NL_TEST_DEF("Session Allocation Test", SessionAllocationTest), + NL_TEST_DEF("Session Counter Exhausted Test", SessionCounterExhaustedTest), NL_TEST_SENTINEL() };