diff --git a/src/lib/core/CHIPError.cpp b/src/lib/core/CHIPError.cpp index 5b398d7de8800e..706972fa59b567 100644 --- a/src/lib/core/CHIPError.cpp +++ b/src/lib/core/CHIPError.cpp @@ -437,6 +437,9 @@ bool FormatCHIPError(char * buf, uint16_t bufSize, CHIP_ERROR err) case CHIP_ERROR_DRBG_ENTROPY_SOURCE_FAILED.AsInteger(): desc = "DRBG entropy source failed to generate entropy data"; break; + case CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED.AsInteger(): + desc = "Message counter exhausted"; + break; case CHIP_ERROR_FABRIC_EXISTS.AsInteger(): desc = "Trying to add a NOC for a fabric that already exists"; break; diff --git a/src/lib/core/CHIPError.h b/src/lib/core/CHIPError.h index a9d104780d1a6d..4e67ef507c9b97 100644 --- a/src/lib/core/CHIPError.h +++ b/src/lib/core/CHIPError.h @@ -1554,7 +1554,13 @@ using CHIP_ERROR = ::chip::ChipError; */ #define CHIP_ERROR_IM_MALFORMED_STATUS_RESPONSE_MESSAGE CHIP_CORE_ERROR(0x7c) -// unused CHIP_CORE_ERROR(0x7d) +/** + * @def CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED + * + * @brief + * The message counter of the session is exhausted, the session should be closed. + */ +#define CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED CHIP_CORE_ERROR(0x7d) /** * @def CHIP_ERROR_FABRIC_EXISTS diff --git a/src/lib/core/tests/TestCHIPErrorStr.cpp b/src/lib/core/tests/TestCHIPErrorStr.cpp index 259d6bb416f492..67a9c40775455a 100644 --- a/src/lib/core/tests/TestCHIPErrorStr.cpp +++ b/src/lib/core/tests/TestCHIPErrorStr.cpp @@ -173,6 +173,7 @@ static const CHIP_ERROR kTestElements[] = CHIP_ERROR_IM_MALFORMED_ATTRIBUTE_REPORT_IB, CHIP_ERROR_IM_MALFORMED_EVENT_STATUS_IB, CHIP_ERROR_IM_MALFORMED_STATUS_RESPONSE_MESSAGE, + CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED, CHIP_ERROR_FABRIC_EXISTS, CHIP_ERROR_KEY_NOT_FOUND_FROM_PEER, CHIP_ERROR_WRONG_ENCRYPTION_TYPE_FROM_PEER, diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp index 55d992c1655d83..170d68f7c03e8c 100644 --- a/src/protocols/secure_channel/PairingSession.cpp +++ b/src/protocols/secure_channel/PairingSession.cpp @@ -44,7 +44,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress & ReturnErrorOnFailure(DeriveSecureSession(secureSession->GetCryptoContext())); uint16_t peerSessionId = GetPeerSessionId(); secureSession->SetPeerAddress(peerAddress); - secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(Transport::PeerMessageCounter::kInitialSyncValue); // Call Activate last, otherwise errors on anything after would lead to // a partially valid session. diff --git a/src/transport/MessageCounter.h b/src/transport/MessageCounter.h index 26f50568343047..70ce525e851b80 100644 --- a/src/transport/MessageCounter.h +++ b/src/transport/MessageCounter.h @@ -33,8 +33,7 @@ namespace chip { * of message counter * * 1. Global unencrypted message counter - * 2. Global encrypted message counter - * 3. Session message counter + * 2. Secure session message counter * * There will be separate implementations for each type */ @@ -51,6 +50,7 @@ class MessageCounter virtual ~MessageCounter() = default; virtual Type GetType() const = 0; + virtual bool IsValid() const = 0; virtual uint32_t Value() const = 0; /** Get current value */ virtual CHIP_ERROR Advance() = 0; /** Advance the counter */ }; @@ -63,6 +63,7 @@ class GlobalUnencryptedMessageCounter : public MessageCounter void Init(); Type GetType() const override { return GlobalUnencrypted; } + bool IsValid() const override { return true; } uint32_t Value() const override { return mValue; } CHIP_ERROR Advance() override { @@ -77,7 +78,8 @@ class GlobalUnencryptedMessageCounter : public MessageCounter class LocalSessionMessageCounter : public MessageCounter { public: - static constexpr uint32_t kInitialSyncValue = 0; ///< Used for initializing peer counter + static constexpr uint32_t kInvalidMessageCounter = 0; + static constexpr uint32_t kMaxMessageCounter = 0xFFFFFFFF; static constexpr uint32_t kMessageCounterRandomInitMask = 0x0FFFFFFF; ///< 28-bit mask /** @@ -88,9 +90,21 @@ class LocalSessionMessageCounter : public MessageCounter LocalSessionMessageCounter() { mValue = (Crypto::GetRandU32() & kMessageCounterRandomInitMask) + 1; } Type GetType() const override { return Session; } + bool IsValid() const override { return mValue != kInvalidMessageCounter; } uint32_t Value() const override { return mValue; } CHIP_ERROR Advance() override { + if (mValue == kInvalidMessageCounter) + { + return CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED; + } + + if (mValue == kMaxMessageCounter) + { + mValue = kInvalidMessageCounter; + return CHIP_NO_ERROR; + } + ++mValue; return CHIP_NO_ERROR; } diff --git a/src/transport/PeerMessageCounter.h b/src/transport/PeerMessageCounter.h index edfac7bf8ddb5a..e02981d98a3e9d 100644 --- a/src/transport/PeerMessageCounter.h +++ b/src/transport/PeerMessageCounter.h @@ -32,7 +32,8 @@ namespace Transport { class PeerMessageCounter { public: - static constexpr size_t kChallengeSize = 8; + static constexpr size_t kChallengeSize = 8; + static constexpr uint32_t kInitialSyncValue = 0; PeerMessageCounter() : mStatus(Status::NotSynced) {} ~PeerMessageCounter() { Reset(); } diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index f68ae57f535dae..48b388472bd494 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -188,6 +188,7 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P MessageCounter & counter = session->GetSessionMessageCounter().GetLocalMessageCounter(); uint32_t messageCounter = counter.Value(); + VerifyOrReturnError(counter.IsValid(), CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED); packetHeader .SetMessageCounter(messageCounter) // .SetSessionId(session->GetPeerSessionId()) // @@ -408,7 +409,7 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH ByteSpan secret(reinterpret_cast(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE), secretLen); ReturnErrorOnFailure(secureSession->GetCryptoContext().InitFromSecret( secret, ByteSpan(nullptr, 0), CryptoContext::SessionInfoType::kSessionEstablishment, role)); - secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(Transport::PeerMessageCounter::kInitialSyncValue); sessionHolder.Grab(session.Value()); return CHIP_NO_ERROR; } @@ -675,7 +676,7 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade // Handle Group message counter here spec 4.7.3 // spec 4.5.1.2 for msg counter - chip::Transport::PeerMessageCounter * counter = nullptr; + Transport::PeerMessageCounter * counter = nullptr; if (CHIP_NO_ERROR == mGroupPeerMsgCounter.FindOrAddPeer(groupContext.fabric_index, packetHeader.GetSourceNodeId().Value(), diff --git a/src/transport/tests/TestPeerMessageCounter.cpp b/src/transport/tests/TestPeerMessageCounter.cpp index 63a4e5a89f6de5..2c10ce4c8e6a0a 100644 --- a/src/transport/tests/TestPeerMessageCounter.cpp +++ b/src/transport/tests/TestPeerMessageCounter.cpp @@ -187,7 +187,7 @@ void UnicastSmallStepTest(nlTestSuite * inSuite, void * inContext) for (uint32_t k = 1; k <= 2 * CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE; k++) { chip::Transport::PeerMessageCounter counter; - counter.SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + counter.SetCounter(chip::Transport::PeerMessageCounter::kInitialSyncValue); if (counter.VerifyEncryptedUnicast(n) == CHIP_NO_ERROR) { // Act like we got this counter value on the wire. @@ -259,7 +259,7 @@ void UnicastLargeStepTest(nlTestSuite * inSuite, void * inContext) for (uint32_t k = (static_cast(1 << 31) - 5); k <= (static_cast(1 << 31) - 1); k++) { chip::Transport::PeerMessageCounter counter; - counter.SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + counter.SetCounter(chip::Transport::PeerMessageCounter::kInitialSyncValue); if (counter.VerifyEncryptedUnicast(n) == CHIP_NO_ERROR) { // Act like we got this counter value on the wire.