Skip to content

Commit

Permalink
Group Data Provider: Listener added.
Browse files Browse the repository at this point in the history
  • Loading branch information
rcasallas-silabs committed Dec 15, 2021
1 parent ebca375 commit 9e8c991
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 15 deletions.
46 changes: 45 additions & 1 deletion src/credentials/GroupDataProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,27 @@ class GroupDataProvider
}
};

/**
* Interface to listen for changes in the Group info.
*/
class Listener
{
public:
virtual ~Listener() = default;
/**
* Callback invoked when a new group is added.
*
* @param[in] new_group GroupInfo structure of the new group.
*/
virtual void OnGroupAdded(chip::FabricIndex fabric_index, const GroupInfo & new_group) = 0;
/**
* Callback invoked when an existing group is removed.
*
* @param[in] removed_state GroupInfo structure of the removed group.
*/
virtual void OnGroupRemoved(chip::FabricIndex fabric_index, const GroupInfo & old_group) = 0;
};

/**
* Template used to iterate the stored group data
*/
Expand Down Expand Up @@ -233,14 +254,15 @@ class GroupDataProvider
virtual CHIP_ERROR SetGroupKeyAt(chip::FabricIndex fabric_index, size_t index, const GroupKey & info) = 0;
virtual CHIP_ERROR GetGroupKeyAt(chip::FabricIndex fabric_index, size_t index, GroupKey & info) = 0;
virtual CHIP_ERROR RemoveGroupKeyAt(chip::FabricIndex fabric_index, size_t index) = 0;

/**
* Creates an iterator that may be used to obtain the list of (group, keyset) pairs associated with the given fabric.
* The number of concurrent instances of this iterator is limited. In order to release the allocated memory,
* the iterator's Release() method must be called after the iteration is finished.
* @retval An instance of GroupKeyIterator on success
* @retval nullptr if no iterator instances are available.
*/
virtual GroupKeyIterator * IterateGroupKey(chip::FabricIndex fabric_index) = 0;
virtual GroupKeyIterator * IterateGroupKeys(chip::FabricIndex fabric_index) = 0;

//
// Key Sets
Expand All @@ -263,6 +285,28 @@ class GroupDataProvider

// General
virtual CHIP_ERROR Decrypt(PacketHeader packetHeader, PayloadHeader & payloadHeader, System::PacketBufferHandle & msg) = 0;

// Listener
void SetListener(Listener * listener) { mListener = listener; };
void RemoveListener() { mListener = nullptr; };

protected:
void GroupAdded(chip::FabricIndex fabric_index, const GroupInfo & new_group)
{
if (mListener)
{
mListener->OnGroupAdded(fabric_index, new_group);
}
}
void GroupRemoved(chip::FabricIndex fabric_index, const GroupInfo & old_group)
{
if (mListener)
{
mListener->OnGroupRemoved(fabric_index, old_group);
}
}

Listener * mListener = nullptr;
};

/**
Expand Down
62 changes: 55 additions & 7 deletions src/credentials/GroupDataProviderImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ CHIP_ERROR GroupDataProviderImpl::SetGroupInfoAt(chip::FabricIndex fabric_index,

FabricData fabric(fabric_index);
GroupData group;
bool new_group = false;

// Load fabric, defaults to zero
CHIP_ERROR err = fabric.Load(mStorage);
Expand All @@ -759,19 +760,31 @@ CHIP_ERROR GroupDataProviderImpl::SetGroupInfoAt(chip::FabricIndex fabric_index,
VerifyOrReturnError(!found || (group.index == index), CHIP_ERROR_DUPLICATE_KEY_ID);

found = group.Get(mStorage, fabric, index);
new_group = (group.group_id != info.group_id);
group.group_id = info.group_id;
group.SetName(info.name);

if (found)
{
// Update existing group
return group.Save(mStorage);
// Update existing entry
if (new_group)
{
// New group, clear endpoints
RemoveEndpoints(fabric_index, group.group_id);
}
ReturnErrorOnFailure(group.Save(mStorage));
if (new_group)
{
GroupAdded(fabric_index, group);
}
return CHIP_NO_ERROR;
}

// Insert last
VerifyOrReturnError(fabric.group_count == index, CHIP_ERROR_INVALID_ARGUMENT);

group.next = 0;
group.group_id = info.group_id;
group.next = 0;
ReturnErrorOnFailure(group.Save(mStorage));

if (group.first)
Expand All @@ -789,7 +802,9 @@ CHIP_ERROR GroupDataProviderImpl::SetGroupInfoAt(chip::FabricIndex fabric_index,
}
// Update fabric
fabric.group_count++;
return fabric.Save(mStorage);
ReturnErrorOnFailure(fabric.Save(mStorage));
GroupAdded(fabric_index, group);
return CHIP_NO_ERROR;
}

CHIP_ERROR GroupDataProviderImpl::GetGroupInfoAt(chip::FabricIndex fabric_index, size_t index, GroupInfo & info)
Expand Down Expand Up @@ -850,7 +865,9 @@ CHIP_ERROR GroupDataProviderImpl::RemoveGroupInfoAt(chip::FabricIndex fabric_ind
fabric.group_count--;
}
// Update fabric info
return fabric.Save(mStorage);
ReturnErrorOnFailure(fabric.Save(mStorage));
GroupRemoved(fabric_index, group);
return CHIP_NO_ERROR;
}

bool GroupDataProviderImpl::HasEndpoint(chip::FabricIndex fabric_index, chip::GroupId group_id, chip::EndpointId endpoint_id)
Expand Down Expand Up @@ -893,7 +910,9 @@ CHIP_ERROR GroupDataProviderImpl::AddEndpoint(chip::FabricIndex fabric_index, ch
// Update fabric
fabric.first_group = group.id;
fabric.group_count++;
return fabric.Save(mStorage);
ReturnErrorOnFailure(fabric.Save(mStorage));
GroupAdded(fabric_index, group);
return CHIP_NO_ERROR;
}

// Existing group
Expand Down Expand Up @@ -1149,6 +1168,35 @@ void GroupDataProviderImpl::EndpointIteratorImpl::Release()
mProvider.mEndpointIterators.ReleaseObject(this);
}

CHIP_ERROR GroupDataProviderImpl::RemoveEndpoints(chip::FabricIndex fabric_index, chip::GroupId group_id)
{
VerifyOrReturnError(mInitialized, CHIP_ERROR_INTERNAL);

FabricData fabric(fabric_index);
GroupData group;

VerifyOrReturnError(CHIP_NO_ERROR == fabric.Load(mStorage), CHIP_ERROR_INVALID_FABRIC_ID);
VerifyOrReturnError(group.Find(mStorage, fabric, group_id), CHIP_ERROR_KEY_NOT_FOUND);

EndpointData endpoint(fabric_index, group.id, group.first_endpoint);
size_t endpoint_index = 0;
while (endpoint_index < group.endpoint_count)
{
if (CHIP_NO_ERROR != endpoint.Load(mStorage))
{
break;
}
endpoint.Delete(mStorage);
endpoint.id = endpoint.next;
endpoint_index++;
}
group.first_endpoint = kInvalidEndpointId;
group.endpoint_count = 0;
ReturnErrorOnFailure(group.Save(mStorage));

return CHIP_NO_ERROR;
}

//
// Group-Key map
//
Expand Down Expand Up @@ -1250,7 +1298,7 @@ CHIP_ERROR GroupDataProviderImpl::RemoveGroupKeyAt(chip::FabricIndex fabric_inde
return fabric.Save(mStorage);
}

GroupDataProvider::GroupKeyIterator * GroupDataProviderImpl::IterateGroupKey(chip::FabricIndex fabric_index)
GroupDataProvider::GroupKeyIterator * GroupDataProviderImpl::IterateGroupKeys(chip::FabricIndex fabric_index)
{
VerifyOrReturnError(mInitialized, nullptr);
return mGroupKeyIterators.CreateObject(*this, fabric_index);
Expand Down
3 changes: 2 additions & 1 deletion src/credentials/GroupDataProviderImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class GroupDataProviderImpl : public GroupDataProvider
CHIP_ERROR SetGroupKeyAt(chip::FabricIndex fabric_index, size_t index, const GroupKey & info) override;
CHIP_ERROR GetGroupKeyAt(chip::FabricIndex fabric_index, size_t index, GroupKey & info) override;
CHIP_ERROR RemoveGroupKeyAt(chip::FabricIndex fabric_index, size_t index) override;
GroupKeyIterator * IterateGroupKey(chip::FabricIndex fabric_index) override;
GroupKeyIterator * IterateGroupKeys(chip::FabricIndex fabric_index) override;

//
// Key Sets
Expand Down Expand Up @@ -147,6 +147,7 @@ class GroupDataProviderImpl : public GroupDataProvider
size_t mCount = 0;
size_t mTotal = 0;
};
CHIP_ERROR RemoveEndpoints(chip::FabricIndex fabric_index, chip::GroupId group_id);

chip::PersistentStorageDelegate & mStorage;
bool mInitialized = false;
Expand Down
64 changes: 58 additions & 6 deletions src/credentials/tests/TestGroupDataProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,37 @@ static KeySet kKeySet1(kKeysetId1, KeySet::SecurityPolicy::kLowLatency, 1);
static KeySet kKeySet2(kKeysetId2, KeySet::SecurityPolicy::kLowLatency, 2);
static KeySet kKeySet3(kKeysetId3, KeySet::SecurityPolicy::kStandard, 3);

class TestListener : public GroupDataProvider::Listener
{
public:
chip::FabricIndex fabric_index = kUndefinedFabricIndex;
GroupInfo latest;
size_t added_count = 0;
size_t removed_count = 0;

void Reset()
{
fabric_index = kUndefinedFabricIndex;
latest = GroupInfo();
added_count = 0;
removed_count = 0;
}

void OnGroupAdded(chip::FabricIndex fabric, const GroupInfo & new_group) override
{
fabric_index = fabric;
latest = new_group;
added_count++;
}
void OnGroupRemoved(chip::FabricIndex fabric, const GroupInfo & old_group) override
{
fabric_index = fabric;
latest = old_group;
removed_count++;
}
};
static TestListener sListener;

void TestStorageDelegate(nlTestSuite * apSuite, void * apContext)
{
chip::TestPersistentStorageDelegate delegate;
Expand Down Expand Up @@ -139,6 +170,8 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext)

// Set Group Info

sListener.Reset();

// Out-of-order
NL_TEST_ASSERT(apSuite, CHIP_ERROR_INVALID_ARGUMENT == provider->SetGroupInfoAt(kFabric1, 2, kGroupInfo1_1));

Expand All @@ -158,6 +191,10 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext)
NL_TEST_ASSERT(apSuite, CHIP_ERROR_INVALID_FABRIC_ID == provider->GetGroupInfoAt(kUndefinedFabricIndex, 0, group));
NL_TEST_ASSERT(apSuite, CHIP_ERROR_NOT_FOUND == provider->GetGroupInfoAt(kFabric2, 999, group));

NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo2_3);
NL_TEST_ASSERT(apSuite, 6 == sListener.added_count);
NL_TEST_ASSERT(apSuite, 0 == sListener.removed_count);

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfoAt(kFabric2, 2, group));
NL_TEST_ASSERT(apSuite, group == kGroupInfo2_3);
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfoAt(kFabric2, 1, group));
Expand All @@ -175,6 +212,9 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext)

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->RemoveGroupInfo(kFabric1, kGroup3));
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->RemoveGroupInfoAt(kFabric2, 0));
NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo2_1);
NL_TEST_ASSERT(apSuite, 6 == sListener.added_count);
NL_TEST_ASSERT(apSuite, 2 == sListener.removed_count);

// Remaining entries shift up

Expand All @@ -199,6 +239,9 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext)

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfoAt(kFabric2, 0, group));
NL_TEST_ASSERT(apSuite, group == kGroupInfo3_4);
NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo3_4);
NL_TEST_ASSERT(apSuite, 8 == sListener.added_count);
NL_TEST_ASSERT(apSuite, 2 == sListener.removed_count);

// Overwrite existing group, index must match

Expand All @@ -213,18 +256,24 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext)

NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfoAt(kFabric2, 1, group));
NL_TEST_ASSERT(apSuite, group == kGroupInfo1_3);
NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo3_4);
NL_TEST_ASSERT(apSuite, 8 == sListener.added_count);
NL_TEST_ASSERT(apSuite, 2 == sListener.removed_count);

// By group_id

// New
// Override existing
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupInfo(kFabric1, kGroupInfo3_5));
// Override
// New group
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupInfo(kFabric2, kGroupInfo3_2));
// Not found
NL_TEST_ASSERT(apSuite, CHIP_ERROR_NOT_FOUND == provider->GetGroupInfo(kFabric2, kGroup5, group));
// Existing
NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfo(kFabric2, kGroup2, group));
NL_TEST_ASSERT(apSuite, group == kGroupInfo3_2);
NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo3_2);
NL_TEST_ASSERT(apSuite, 9 == sListener.added_count);
NL_TEST_ASSERT(apSuite, 2 == sListener.removed_count);
}

void TestGroupInfoIterator(nlTestSuite * apSuite, void * apContext)
Expand Down Expand Up @@ -456,7 +505,7 @@ void TestEndpointIterator(nlTestSuite * apSuite, void * apContext)
}
}

void TestGroupKey(nlTestSuite * apSuite, void * apContext)
void TestGroupKeys(nlTestSuite * apSuite, void * apContext)
{
GroupDataProvider * provider = GetGroupDataProvider();
NL_TEST_ASSERT(apSuite, provider);
Expand Down Expand Up @@ -577,7 +626,7 @@ void TestGroupKeyIterator(nlTestSuite * apSuite, void * apContext)
kGroup1Keyset0, kGroup1Keyset1, kGroup1Keyset2, kGroup1Keyset3 };
size_t expected_f1_count = sizeof(expected_f1) / sizeof(GroupKey);

auto it = provider->IterateGroupKey(kFabric1);
auto it = provider->IterateGroupKeys(kFabric1);
size_t count = 0;
NL_TEST_ASSERT(apSuite, it);
if (it)
Expand All @@ -596,7 +645,7 @@ void TestGroupKeyIterator(nlTestSuite * apSuite, void * apContext)
GroupKey expected_f2[] = { kGroup2Keyset0, kGroup2Keyset1, kGroup2Keyset2, kGroup2Keyset3 };
size_t expected_f2_count = sizeof(expected_f2) / sizeof(GroupKey);

it = provider->IterateGroupKey(kFabric2);
it = provider->IterateGroupKeys(kFabric2);
NL_TEST_ASSERT(apSuite, it);
if (it)
{
Expand Down Expand Up @@ -956,6 +1005,9 @@ int Test_Setup(void * inContext)
VerifyOrReturnError(CHIP_NO_ERROR == chip::Platform::MemoryInit(), FAILURE);
VerifyOrReturnError(CHIP_NO_ERROR == sProvider.Init(), FAILURE);

// Event listener
sProvider.SetListener(&chip::app::TestGroups::sListener);

memcpy(chip::app::TestGroups::kKeySet0.epoch_keys, kEpochKeys0, sizeof(kEpochKeys0));
memcpy(chip::app::TestGroups::kKeySet1.epoch_keys, kEpochKeys1, sizeof(kEpochKeys1));
memcpy(chip::app::TestGroups::kKeySet2.epoch_keys, kEpochKeys2, sizeof(kEpochKeys2));
Expand All @@ -982,7 +1034,7 @@ const nlTest sTests[] = { NL_TEST_DEF("TestStorageDelegate", chip::app::TestGrou
NL_TEST_DEF("TestGroupInfoIterator", chip::app::TestGroups::TestGroupInfoIterator),
NL_TEST_DEF("TestEndpoints", chip::app::TestGroups::TestEndpoints),
NL_TEST_DEF("TestEndpointIterator", chip::app::TestGroups::TestEndpointIterator),
NL_TEST_DEF("TestGroupKey", chip::app::TestGroups::TestGroupKey),
NL_TEST_DEF("TestGroupKeys", chip::app::TestGroups::TestGroupKeys),
NL_TEST_DEF("TestGroupKeyIterator", chip::app::TestGroups::TestGroupKeyIterator),
NL_TEST_DEF("TestKeySets", chip::app::TestGroups::TestKeySets),
NL_TEST_DEF("TestKeySetIterator", chip::app::TestGroups::TestKeySetIterator),
Expand Down

0 comments on commit 9e8c991

Please sign in to comment.