Skip to content

Commit

Permalink
Enforce maximum NONCE (#19037)
Browse files Browse the repository at this point in the history
* Enforce maximum NONCE

* Add test-case

* Resove conversations, fix comments

* Resolve comments
  • Loading branch information
kghost authored and pull[bot] committed Oct 12, 2023
1 parent e824cda commit df74de1
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 30 deletions.
3 changes: 3 additions & 0 deletions src/lib/core/CHIPError.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ bool FormatCHIPError(char * buf, uint16_t bufSize, CHIP_ERROR err)
case CHIP_ERROR_DRBG_ENTROPY_SOURCE_FAILED.AsInteger():
desc = "DRBG entropy source failed to generate entropy data";
break;
case CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED.AsInteger():
desc = "Message counter exhausted";
break;
case CHIP_ERROR_FABRIC_EXISTS.AsInteger():
desc = "Trying to add a NOC for a fabric that already exists";
break;
Expand Down
8 changes: 7 additions & 1 deletion src/lib/core/CHIPError.h
Original file line number Diff line number Diff line change
Expand Up @@ -1554,7 +1554,13 @@ using CHIP_ERROR = ::chip::ChipError;
*/
#define CHIP_ERROR_IM_MALFORMED_STATUS_RESPONSE_MESSAGE CHIP_CORE_ERROR(0x7c)

// unused CHIP_CORE_ERROR(0x7d)
/**
* @def CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED
*
* @brief
* The message counter of the session is exhausted, the session should be closed.
*/
#define CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED CHIP_CORE_ERROR(0x7d)

/**
* @def CHIP_ERROR_FABRIC_EXISTS
Expand Down
1 change: 1 addition & 0 deletions src/lib/core/tests/TestCHIPErrorStr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ static const CHIP_ERROR kTestElements[] =
CHIP_ERROR_IM_MALFORMED_ATTRIBUTE_REPORT_IB,
CHIP_ERROR_IM_MALFORMED_EVENT_STATUS_IB,
CHIP_ERROR_IM_MALFORMED_STATUS_RESPONSE_MESSAGE,
CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED,
CHIP_ERROR_FABRIC_EXISTS,
CHIP_ERROR_KEY_NOT_FOUND_FROM_PEER,
CHIP_ERROR_WRONG_ENCRYPTION_TYPE_FROM_PEER,
Expand Down
2 changes: 1 addition & 1 deletion src/protocols/secure_channel/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress &
ReturnErrorOnFailure(DeriveSecureSession(secureSession->GetCryptoContext()));
uint16_t peerSessionId = GetPeerSessionId();
secureSession->SetPeerAddress(peerAddress);
secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue);
secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(Transport::PeerMessageCounter::kInitialSyncValue);

// Call Activate last, otherwise errors on anything after would lead to
// a partially valid session.
Expand Down
2 changes: 1 addition & 1 deletion src/transport/MessageCounter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace chip {

void GlobalUnencryptedMessageCounter::Init()
{
mValue = Crypto::GetRandU32();
mLastUsedValue = Crypto::GetRandU32();
}

} // namespace chip
41 changes: 23 additions & 18 deletions src/transport/MessageCounter.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@
namespace chip {

/**
* MessageCounter represents a local message counter. There are 3 types
* of message counter
* MessageCounter represents a local message counter. There are 2 types of message counter
*
* 1. Global unencrypted message counter
* 2. Global encrypted message counter
* 3. Session message counter
* 2. Secure session message counter
*
* There will be separate implementations for each type
*/
Expand All @@ -50,53 +48,60 @@ class MessageCounter

virtual ~MessageCounter() = default;

virtual Type GetType() const = 0;
virtual uint32_t Value() const = 0; /** Get current value */
virtual CHIP_ERROR Advance() = 0; /** Advance the counter */
virtual Type GetType() const = 0;
virtual CHIP_ERROR AdvanceAndConsume(uint32_t & fetch) = 0; /** Advance the counter, and feed the new counter to fetch */
};

class GlobalUnencryptedMessageCounter : public MessageCounter
{
public:
GlobalUnencryptedMessageCounter() : mValue(0) {}
GlobalUnencryptedMessageCounter() : mLastUsedValue(0) {}

void Init();

Type GetType() const override { return GlobalUnencrypted; }
uint32_t Value() const override { return mValue; }
CHIP_ERROR Advance() override
CHIP_ERROR AdvanceAndConsume(uint32_t & fetch) override
{
++mValue;
fetch = ++mLastUsedValue;
return CHIP_NO_ERROR;
}

private:
uint32_t mValue;
uint32_t mLastUsedValue;
};

class LocalSessionMessageCounter : public MessageCounter
{
public:
static constexpr uint32_t kInitialSyncValue = 0; ///< Used for initializing peer counter
static constexpr uint32_t kMessageCounterMax = 0xFFFFFFFF;
static constexpr uint32_t kMessageCounterRandomInitMask = 0x0FFFFFFF; ///< 28-bit mask

/**
* Initialize a local message counter with random value between [1, 2^28]. This increases the difficulty of traffic analysis
* attacks by making it harder to determine how long a particular session has been open. The initial counter is always 1 or
* higher to guarantee first message is always greater than initial peer counter set to 0.
*
* The mLastUsedValue is the predecessor of the initial value, it will be advanced before using, so don't need to add 1 here.
*/
LocalSessionMessageCounter() { mValue = (Crypto::GetRandU32() & kMessageCounterRandomInitMask) + 1; }
LocalSessionMessageCounter() { mLastUsedValue = (Crypto::GetRandU32() & kMessageCounterRandomInitMask); }

Type GetType() const override { return Session; }
uint32_t Value() const override { return mValue; }
CHIP_ERROR Advance() override
CHIP_ERROR AdvanceAndConsume(uint32_t & fetch) override
{
++mValue;
if (mLastUsedValue == kMessageCounterMax)
{
return CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED;
}

fetch = ++mLastUsedValue;
return CHIP_NO_ERROR;
}

// Test-only function to set the counter value
void TestSetCounter(uint32_t value) { mLastUsedValue = value; }

private:
uint32_t mValue;
uint32_t mLastUsedValue;
};

} // namespace chip
3 changes: 2 additions & 1 deletion src/transport/PeerMessageCounter.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace Transport {
class PeerMessageCounter
{
public:
static constexpr size_t kChallengeSize = 8;
static constexpr size_t kChallengeSize = 8;
static constexpr uint32_t kInitialSyncValue = 0;

PeerMessageCounter() : mStatus(Status::NotSynced) {}
~PeerMessageCounter() { Reset(); }
Expand Down
12 changes: 6 additions & 6 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P
}

MessageCounter & counter = session->GetSessionMessageCounter().GetLocalMessageCounter();
uint32_t messageCounter = counter.Value();
uint32_t messageCounter;
ReturnErrorOnFailure(counter.AdvanceAndConsume(messageCounter));
packetHeader
.SetMessageCounter(messageCounter) //
.SetSessionId(session->GetPeerSessionId()) //
Expand All @@ -201,7 +202,6 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P
CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), messageCounter, sourceNodeId);

ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session->GetCryptoContext(), nonce, payloadHeader, packetHeader, message));
ReturnErrorOnFailure(counter.Advance());

#if CHIP_PROGRESS_LOGGING
destination = session->GetPeerNodeId();
Expand All @@ -211,8 +211,8 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P
break;
case Transport::Session::SessionType::kUnauthenticated: {
MessageCounter & counter = mGlobalUnencryptedMessageCounter;
uint32_t messageCounter = counter.Value();
ReturnErrorOnFailure(counter.Advance());
uint32_t messageCounter;
ReturnErrorOnFailure(counter.AdvanceAndConsume(messageCounter));
packetHeader.SetMessageCounter(messageCounter);
Transport::UnauthenticatedSession * session = sessionHandle->AsUnauthenticatedSession();
switch (session->GetSessionRole())
Expand Down Expand Up @@ -408,7 +408,7 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH
ByteSpan secret(reinterpret_cast<const uint8_t *>(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE), secretLen);
ReturnErrorOnFailure(secureSession->GetCryptoContext().InitFromSecret(
secret, ByteSpan(nullptr, 0), CryptoContext::SessionInfoType::kSessionEstablishment, role));
secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue);
secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(Transport::PeerMessageCounter::kInitialSyncValue);
sessionHolder.Grab(session.Value());
return CHIP_NO_ERROR;
}
Expand Down Expand Up @@ -675,7 +675,7 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade

// Handle Group message counter here spec 4.7.3
// spec 4.5.1.2 for msg counter
chip::Transport::PeerMessageCounter * counter = nullptr;
Transport::PeerMessageCounter * counter = nullptr;

if (CHIP_NO_ERROR ==
mGroupPeerMsgCounter.FindOrAddPeer(groupContext.fabric_index, packetHeader.GetSourceNodeId().Value(),
Expand Down
4 changes: 2 additions & 2 deletions src/transport/tests/TestPeerMessageCounter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ void UnicastSmallStepTest(nlTestSuite * inSuite, void * inContext)
for (uint32_t k = 1; k <= 2 * CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE; k++)
{
chip::Transport::PeerMessageCounter counter;
counter.SetCounter(LocalSessionMessageCounter::kInitialSyncValue);
counter.SetCounter(chip::Transport::PeerMessageCounter::kInitialSyncValue);
if (counter.VerifyEncryptedUnicast(n) == CHIP_NO_ERROR)
{
// Act like we got this counter value on the wire.
Expand Down Expand Up @@ -259,7 +259,7 @@ void UnicastLargeStepTest(nlTestSuite * inSuite, void * inContext)
for (uint32_t k = (static_cast<uint32_t>(1 << 31) - 5); k <= (static_cast<uint32_t>(1 << 31) - 1); k++)
{
chip::Transport::PeerMessageCounter counter;
counter.SetCounter(LocalSessionMessageCounter::kInitialSyncValue);
counter.SetCounter(chip::Transport::PeerMessageCounter::kInitialSyncValue);
if (counter.VerifyEncryptedUnicast(n) == CHIP_NO_ERROR)
{
// Act like we got this counter value on the wire.
Expand Down
78 changes: 78 additions & 0 deletions src/transport/tests/TestSessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,83 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext)
sessionManager.Shutdown();
}

void SessionCounterExhaustedTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

IPAddress addr;
IPAddress::FromString("::1", addr);
CHIP_ERROR err = CHIP_NO_ERROR;

FabricTable fabricTable;
SessionManager sessionManager;
secure_channel::MessageCounterManager gMessageCounterManager;
chip::TestPersistentStorageDelegate deviceStorage;

NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.Init(&deviceStorage));
NL_TEST_ASSERT(inSuite,
CHIP_NO_ERROR ==
sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &gMessageCounterManager, &deviceStorage,
&fabricTable));

Transport::PeerAddress peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));

FabricIndex aliceFabricIndex;
FabricInfo aliceFabric;
aliceFabric.TestOnlyBuildFabric(GetRootACertAsset().mCert, GetIAA1CertAsset().mCert, GetNodeA1CertAsset().mCert,
GetNodeA1CertAsset().mKey);
NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabricForTest(aliceFabric, &aliceFabricIndex));

FabricIndex bobFabricIndex;
FabricInfo bobFabric;
bobFabric.TestOnlyBuildFabric(GetRootACertAsset().mCert, GetIAA1CertAsset().mCert, GetNodeA2CertAsset().mCert,
GetNodeA2CertAsset().mKey);
NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabricForTest(bobFabric, &bobFabricIndex));

SessionHolder aliceToBobSession;
err = sessionManager.InjectPaseSessionWithTestKey(aliceToBobSession, 2,
fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), 1,
aliceFabricIndex, peer, CryptoContext::SessionRole::kInitiator);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

SessionHolder bobToAliceSession;
err = sessionManager.InjectPaseSessionWithTestKey(bobToAliceSession, 1,
fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), 2,
bobFabricIndex, peer, CryptoContext::SessionRole::kResponder);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// ==== Set counter value to max ====
LocalSessionMessageCounter & counter = static_cast<LocalSessionMessageCounter &>(
aliceToBobSession.Get().Value()->AsSecureSession()->GetSessionMessageCounter().GetLocalMessageCounter());
counter.TestSetCounter(LocalSessionMessageCounter::kMessageCounterMax - 1);

// ==== Build a valid message with max counter value ====
chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD));
NL_TEST_ASSERT(inSuite, !buffer.IsNull());

PayloadHeader payloadHeader;

// Set the exchange ID for this header.
payloadHeader.SetExchangeID(0);

// Set the protocol ID and message type for this header.
payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest);

EncryptedPacketBufferHandle preparedMessage;
err = sessionManager.PrepareMessage(aliceToBobSession.Get().Value(), payloadHeader, std::move(buffer), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// ==== Build another message which will fail becuase message counter is exhausted ====
chip::System::PacketBufferHandle buffer2 = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD));
NL_TEST_ASSERT(inSuite, !buffer2.IsNull());

EncryptedPacketBufferHandle preparedMessage2;
err = sessionManager.PrepareMessage(aliceToBobSession.Get().Value(), payloadHeader, std::move(buffer2), preparedMessage2);
NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_COUNTER_EXHAUSTED);

sessionManager.Shutdown();
}

// Test Suite

/**
Expand All @@ -774,6 +851,7 @@ const nlTest sTests[] =
NL_TEST_DEF("Old counter Test", SendPacketWithOldCounterTest),
NL_TEST_DEF("Too-old counter Test", SendPacketWithTooOldCounterTest),
NL_TEST_DEF("Session Allocation Test", SessionAllocationTest),
NL_TEST_DEF("Session Counter Exhausted Test", SessionCounterExhaustedTest),

NL_TEST_SENTINEL()
};
Expand Down

0 comments on commit df74de1

Please sign in to comment.