Skip to content

Commit

Permalink
Group Cryptography: Message encryption implemented in the Group Data … (
Browse files Browse the repository at this point in the history
#14313)

* Group Cryptography: Message encryption implemented in the Group Data Provider.

* Group Cryptography: Review comments applied.

* Update src/credentials/GroupDataProviderImpl.h

Co-authored-by: Boris Zbarsky <[email protected]>

Co-authored-by: Justin Wood <[email protected]>
Co-authored-by: Boris Zbarsky <[email protected]>
  • Loading branch information
3 people authored and pull[bot] committed Feb 10, 2024
1 parent 7a0015e commit 251b565
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 78 deletions.
19 changes: 2 additions & 17 deletions src/credentials/GroupDataProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,22 +146,6 @@ class GroupDataProvider
VerifyOrReturnError(this->policy == other.policy && this->num_keys_used == other.num_keys_used, false);
return !memcmp(this->epoch_keys, other.epoch_keys, this->num_keys_used * sizeof(EpochKey));
}

ByteSpan GetCurrentKey()
{
// An epoch key update SHALL order the keys from oldest to newest,
// the current epoch key having the second newest time
switch (this->num_keys_used)
{
case 1:
case 2:
return ByteSpan(epoch_keys[0].key, EpochKey::kLengthBytes);
case 3:
return ByteSpan(epoch_keys[1].key, EpochKey::kLengthBytes);
default:
return ByteSpan(nullptr, 0);
}
}
};

/**
Expand Down Expand Up @@ -317,7 +301,8 @@ class GroupDataProvider
virtual CHIP_ERROR RemoveFabric(FabricIndex fabric_index) = 0;

// Decryption
virtual GroupSessionIterator * IterateGroupSessions(uint16_t session_id) = 0;
virtual GroupSessionIterator * IterateGroupSessions(uint16_t session_id) = 0;
virtual Crypto::SymmetricKeyContext * GetKeyContext(FabricIndex fabric_index, GroupId group_id) = 0;

// Listener
void SetListener(GroupListener * listener) { mListener = listener; };
Expand Down
71 changes: 56 additions & 15 deletions src/credentials/GroupDataProviderImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,23 @@ struct KeySetData : PersistentData<kPersistentBufferMax>
next = 0xffff;
}

OperationalKey * GetCurrentKey()
{
// An epoch key update SHALL order the keys from oldest to newest,
// the current epoch key having the second newest time if time
// synchronization is not achieved or guaranteed.
switch (this->keys_count)
{
case 1:
case 2:
return &operational_keys[0];
case 3:
return &operational_keys[1];
default:
return nullptr;
}
}

CHIP_ERROR Serialize(TLV::TLVWriter & writer) const override
{
TLV::TLVType container;
Expand Down Expand Up @@ -1584,9 +1601,10 @@ CHIP_ERROR GroupDataProviderImpl::SetKeySet(chip::FabricIndex fabric_index, cons
for (size_t i = 0; i < in_keyset.num_keys_used; ++i)
{
ByteSpan epoch_key(in_keyset.epoch_keys[i].key, Crypto::CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES);
MutableByteSpan key(keyset.operational_keys[i].value, Crypto::CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES);
ReturnErrorOnFailure(Crypto::DeriveGroupOperationalKey(epoch_key, key));
ReturnErrorOnFailure(Crypto::DeriveGroupSessionId(key, keyset.operational_keys[i].hash));
uint8_t key[Crypto::CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES];
MutableByteSpan key_span(key, sizeof(key));
ReturnErrorOnFailure(Crypto::DeriveGroupOperationalKey(epoch_key, key_span));
ReturnErrorOnFailure(Crypto::DeriveGroupSessionId(ByteSpan(key, sizeof(key)), keyset.operational_keys[i].hash));
}

if (found)
Expand Down Expand Up @@ -1754,31 +1772,54 @@ CHIP_ERROR GroupDataProviderImpl::RemoveFabric(chip::FabricIndex fabric_index)
// Cryptography
//

CHIP_ERROR GroupDataProviderImpl::GroupKeyContext::SetKey(const ByteSpan & value)
Crypto::SymmetricKeyContext * GroupDataProviderImpl::GetKeyContext(FabricIndex fabric_index, GroupId group_id)
{
VerifyOrReturnError(value.size() == Crypto::CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES, CHIP_ERROR_BUFFER_TOO_SMALL);
memcpy(mKey, value.data(), value.size());
return CHIP_NO_ERROR;
FabricData fabric(fabric_index);
VerifyOrReturnError(CHIP_NO_ERROR == fabric.Load(mStorage), nullptr);

KeyMapData mapping(fabric.fabric_index, fabric.first_map);

// Look for the target group in the fabric's keyset-group pairs
for (uint16_t i = 0; i < fabric.map_count; ++i, mapping.id = mapping.next)
{
VerifyOrReturnError(CHIP_NO_ERROR == mapping.Load(mStorage), nullptr);
// GroupKeySetID of 0 is reserved for the Identity Protection Key (IPK)
if (mapping.keyset_id > 0 && mapping.group_id == group_id)
{
// Group found, get the keyset
KeySetData keyset;
VerifyOrReturnError(keyset.Find(mStorage, fabric, mapping.keyset_id), nullptr);
OperationalKey * key = keyset.GetCurrentKey();
if (nullptr != key)
{
return mKeyContexPool.CreateObject(*this, ByteSpan(key->value, Crypto::CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES),
key->hash);
}
}
}
return nullptr;
}

void GroupDataProviderImpl::GroupKeyContext::Clear()
void GroupDataProviderImpl::GroupKeyContext::Release()
{
memset(mKey, 0, sizeof(mKey));
memset(mKeyValue, 0, sizeof(mKeyValue));
mProvider.mKeyContexPool.ReleaseObject(this);
}

CHIP_ERROR GroupDataProviderImpl::GroupKeyContext::EncryptMessage(MutableByteSpan & plaintext, const ByteSpan & aad,
const ByteSpan & nonce, MutableByteSpan & out_mic) const
{
uint8_t * output = plaintext.data();
return Crypto::AES_CCM_encrypt(plaintext.data(), plaintext.size(), aad.data(), aad.size(), mKey, Crypto::kAES_CCM128_Key_Length,
nonce.data(), nonce.size(), output, out_mic.data(), out_mic.size());
return Crypto::AES_CCM_encrypt(plaintext.data(), plaintext.size(), aad.data(), aad.size(), mKeyValue,
Crypto::kAES_CCM128_Key_Length, nonce.data(), nonce.size(), output, out_mic.data(),
out_mic.size());
}

CHIP_ERROR GroupDataProviderImpl::GroupKeyContext::DecryptMessage(MutableByteSpan & ciphertext, const ByteSpan & aad,
const ByteSpan & nonce, const ByteSpan & mic) const
{
uint8_t * output = ciphertext.data();
return Crypto::AES_CCM_decrypt(ciphertext.data(), ciphertext.size(), aad.data(), aad.size(), mic.data(), mic.size(), mKey,
return Crypto::AES_CCM_decrypt(ciphertext.data(), ciphertext.size(), aad.data(), aad.size(), mic.data(), mic.size(), mKeyValue,
Crypto::kAES_CCM128_Key_Length, nonce.data(), nonce.size(), output);
}

Expand All @@ -1801,7 +1842,7 @@ GroupDataProviderImpl::GroupSessionIterator * GroupDataProviderImpl::IterateGrou
}

GroupDataProviderImpl::GroupSessionIteratorImpl::GroupSessionIteratorImpl(GroupDataProviderImpl & provider, uint16_t session_id) :
mProvider(provider), mSessionId(session_id)
mProvider(provider), mSessionId(session_id), mKeyContext(provider)
{
FabricList fabric_list;
ReturnOnFailure(fabric_list.Load(provider.mStorage));
Expand Down Expand Up @@ -1896,10 +1937,10 @@ bool GroupDataProviderImpl::GroupSessionIteratorImpl::Next(GroupSession & output
OperationalKey & key = keyset.operational_keys[mKeyIndex++];
if (key.hash == mSessionId)
{
mKey.SetKey(ByteSpan(key.value, Crypto::CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES));
mKeyContext.SetKey(ByteSpan(key.value, sizeof(key.value)), mSessionId);
output.fabric_index = fabric.fabric_index;
output.group_id = mapping.group_id;
output.key = &mKey;
output.key = &mKeyContext;
return true;
}
}
Expand Down
27 changes: 22 additions & 5 deletions src/credentials/GroupDataProviderImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class GroupDataProviderImpl : public GroupDataProvider
CHIP_ERROR RemoveFabric(FabricIndex fabric_index) override;

// Decryption
Crypto::SymmetricKeyContext * GetKeyContext(FabricIndex fabric_index, GroupId group_id) override;
GroupSessionIterator * IterateGroupSessions(uint16_t session_id) override;

protected:
Expand Down Expand Up @@ -142,9 +143,20 @@ class GroupDataProviderImpl : public GroupDataProvider
class GroupKeyContext : public Crypto::SymmetricKeyContext
{
public:
GroupKeyContext() = default;
CHIP_ERROR SetKey(const ByteSpan & value);
void Clear();
GroupKeyContext(GroupDataProviderImpl & provider) : mProvider(provider) {}

GroupKeyContext(GroupDataProviderImpl & provider, const ByteSpan & key, uint16_t hash) : mProvider(provider)
{
SetKey(key, hash);
}

void SetKey(const ByteSpan & key, uint16_t hash)
{
mKeyHash = hash;
memcpy(mKeyValue, key.data(), std::min(key.size(), sizeof(mKeyValue)));
}

uint16_t GetKeyHash() override { return mKeyHash; }

CHIP_ERROR EncryptMessage(MutableByteSpan & plaintext, const ByteSpan & aad, const ByteSpan & nonce,
MutableByteSpan & out_mic) const override;
Expand All @@ -155,8 +167,12 @@ class GroupDataProviderImpl : public GroupDataProvider
CHIP_ERROR DecryptPrivacy(MutableByteSpan & header, uint16_t session_id, const ByteSpan & payload,
const ByteSpan & mic) const override;

void Release() override;

protected:
uint8_t mKey[Crypto::CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES];
GroupDataProviderImpl & mProvider;
uint16_t mKeyHash = 0;
uint8_t mKeyValue[Crypto::CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES] = { 0 };
};

class KeySetIteratorImpl : public KeySetIterator
Expand Down Expand Up @@ -195,7 +211,7 @@ class GroupDataProviderImpl : public GroupDataProvider
uint16_t mKeyIndex = 0;
uint16_t mKeyCount = 0;
bool mFirstMap = true;
GroupKeyContext mKey;
GroupKeyContext mKeyContext;
};
CHIP_ERROR RemoveEndpoints(FabricIndex fabric_index, GroupId group_id);

Expand All @@ -206,6 +222,7 @@ class GroupDataProviderImpl : public GroupDataProvider
BitMapObjectPool<EndpointIteratorImpl, kIteratorsMax> mEndpointIterators;
BitMapObjectPool<KeySetIteratorImpl, kIteratorsMax> mKeySetIterators;
BitMapObjectPool<GroupSessionIteratorImpl, kIteratorsMax> mGroupSessionsIterator;
BitMapObjectPool<GroupKeyContext, kIteratorsMax> mKeyContexPool;
};

} // namespace Credentials
Expand Down
97 changes: 56 additions & 41 deletions src/credentials/tests/TestGroupDataProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,44 +1054,65 @@ void TestGroupDecryption(nlTestSuite * apSuite, void * apContext)
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->AddEndpoint(kFabric2, kGroup3, kEndpointId3));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->AddEndpoint(kFabric2, kGroup3, kEndpointId4));

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 0, kGroup3Keyset0));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 1, kGroup3Keyset1));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 2, kGroup3Keyset2));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 3, kGroup3Keyset3));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 4, kGroup1Keyset0));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 5, kGroup1Keyset1));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 6, kGroup1Keyset2));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 7, kGroup1Keyset3));

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric2, 0, kGroup2Keyset0));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric2, 1, kGroup2Keyset1));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric2, 2, kGroup2Keyset2));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric2, 3, kGroup2Keyset3));

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric1, kKeySet0));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric1, kKeySet1));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric1, kKeySet2));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric1, kKeySet3));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric2, kKeySet0));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric2, kKeySet3));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric2, kKeySet2));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric2, kKeySet1));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetKeySet(kFabric2, kKeySet3));

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 0, kGroup1Keyset0));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 1, kGroup1Keyset2));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 2, kGroup3Keyset0));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric1, 3, kGroup3Keyset2));

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric2, 0, kGroup2Keyset1));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupKeyAt(kFabric2, 1, kGroup2Keyset3));

const uint8_t kMessage[] = { 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9 };
const uint8_t nonce[13] = { 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x18, 0x1a, 0x1b, 0x1c };
const uint8_t aad[40] = { 0x0a, 0x1a, 0x2a, 0x3a, 0x4a, 0x5a, 0x6a, 0x7a, 0x8a, 0x9a, 0x0b, 0x1b, 0x2b, 0x3b,
const size_t kMessageLength = 10;
const uint8_t kMessage[kMessageLength] = { 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9 };
const uint8_t nonce[13] = { 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x18, 0x1a, 0x1b, 0x1c };
const uint8_t aad[40] = { 0x0a, 0x1a, 0x2a, 0x3a, 0x4a, 0x5a, 0x6a, 0x7a, 0x8a, 0x9a, 0x0b, 0x1b, 0x2b, 0x3b,
0x4b, 0x5b, 0x6b, 0x7b, 0x8b, 0x9b, 0x0c, 0x1c, 0x2c, 0x3c, 0x4c, 0x5c, 0x6c, 0x7c,
0x8c, 0x9c, 0x0d, 0x1d, 0x2d, 0x3d, 0x4d, 0x5d, 0x6d, 0x7d, 0x8d, 0x9d };
uint8_t mic[16] = {
uint8_t mic[16] = {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
};
uint8_t buffer[32];

std::set<std::pair<FabricIndex, GroupId>> expected = { { kFabric1, kGroup1 }, { kFabric1, kGroup3 }, { kFabric2, kGroup2 } };
uint8_t ciphertext_buffer[kMessageLength];
uint8_t plaintext_buffer[kMessageLength];
MutableByteSpan ciphertext(ciphertext_buffer, sizeof(ciphertext_buffer));
MutableByteSpan plaintext(plaintext_buffer, sizeof(plaintext_buffer));
MutableByteSpan tag(mic, sizeof(mic));

//
// Encrypt
//

// Load the plaintext to encrypt
memcpy(ciphertext_buffer, kMessage, sizeof(kMessage));

// Get the key context
Crypto::SymmetricKeyContext * key_context = provider->GetKeyContext(kFabric2, kGroup2);
NL_TEST_ASSERT(apSuite, nullptr != key_context);
uint16_t session_id = key_context->GetKeyHash();

const uint16_t kSessionId = 0xbc66;
// Encrypt the message
NL_TEST_ASSERT(apSuite,
CHIP_NO_ERROR ==
key_context->EncryptMessage(ciphertext, ByteSpan(aad, sizeof(aad)), ByteSpan(nonce, sizeof(nonce)), tag));

// The ciphertext must be different to the original message
NL_TEST_ASSERT(apSuite, memcmp(ciphertext.data(), kMessage, sizeof(kMessage)));
key_context->Release();

//
// Decrypt
//

const std::set<std::pair<FabricIndex, GroupId>> expected = { { kFabric2, kGroup2 } };

// Iterate all keys that matches the incoming session
GroupSession session;
auto it = provider->IterateGroupSessions(kSessionId);
auto it = provider->IterateGroupSessions(session_id);
size_t count = 0, total = 0;

NL_TEST_ASSERT(apSuite, it);
Expand All @@ -1101,27 +1122,21 @@ void TestGroupDecryption(nlTestSuite * apSuite, void * apContext)
NL_TEST_ASSERT(apSuite, expected.size() == total);
while (it->Next(session))
{
std::pair<FabricIndex, GroupId> result(session.fabric_index, session.group_id);
NL_TEST_ASSERT(apSuite, expected.count(result) > 0);
std::pair<FabricIndex, GroupId> found(session.fabric_index, session.group_id);
NL_TEST_ASSERT(apSuite, expected.count(found) > 0);
NL_TEST_ASSERT(apSuite, session.key != nullptr);

memcpy(buffer, kMessage, sizeof(kMessage));
MutableByteSpan message(buffer, sizeof(kMessage));
MutableByteSpan tag(mic, sizeof(mic));
// Load ciphertext to decrypt
memcpy(plaintext_buffer, ciphertext_buffer, sizeof(plaintext_buffer));

// Encrypt
// Decrypt de ciphertext
NL_TEST_ASSERT(
apSuite,
CHIP_NO_ERROR ==
session.key->EncryptMessage(message, ByteSpan(aad, sizeof(aad)), ByteSpan(nonce, sizeof(nonce)), tag));
NL_TEST_ASSERT(apSuite, memcmp(message.data(), kMessage, sizeof(kMessage)));
session.key->DecryptMessage(plaintext, ByteSpan(aad, sizeof(aad)), ByteSpan(nonce, sizeof(nonce)), tag));

// Decrypt
NL_TEST_ASSERT(
apSuite,
CHIP_NO_ERROR ==
session.key->DecryptMessage(message, ByteSpan(aad, sizeof(aad)), ByteSpan(nonce, sizeof(nonce)), tag));
NL_TEST_ASSERT(apSuite, 0 == memcmp(message.data(), kMessage, sizeof(kMessage)));
// The new plaintext must match the original message
NL_TEST_ASSERT(apSuite, 0 == memcmp(plaintext.data(), kMessage, sizeof(kMessage)));
count++;
}
NL_TEST_ASSERT(apSuite, count == total);
Expand Down
11 changes: 11 additions & 0 deletions src/crypto/CHIPCryptoPAL.h
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,12 @@ CHIP_ERROR ExtractDNAttributeFromX509Cert(MatterOid matterOid, const ByteSpan &
class SymmetricKeyContext
{
public:
/**
* @brief Returns the symmetric key hash
* @return Group Key Hash
*/
virtual uint16_t GetKeyHash() = 0;

virtual ~SymmetricKeyContext() = default;
/**
* @brief Perform the message encryption as described in 4.7.2. (Security Processing of Outgoing Messages)
Expand Down Expand Up @@ -1389,6 +1395,11 @@ class SymmetricKeyContext
*/
virtual CHIP_ERROR DecryptPrivacy(MutableByteSpan & header, uint16_t session_id, const ByteSpan & payload,
const ByteSpan & mic) const = 0;

/**
* @brief Release the dynamic memory used to allocate this instance of the SymmetricKeyContext
*/
virtual void Release() = 0;
};

/**
Expand Down

0 comments on commit 251b565

Please sign in to comment.