diff --git a/src/credentials/GroupDataProvider.h b/src/credentials/GroupDataProvider.h index 7ac92ccd908eee..d3fbc57cf03576 100644 --- a/src/credentials/GroupDataProvider.h +++ b/src/credentials/GroupDataProvider.h @@ -141,6 +141,27 @@ class GroupDataProvider } }; + /** + * Interface to listen for changes in the Group info. + */ + class GroupListener + { + public: + virtual ~GroupListener() = default; + /** + * Callback invoked when a new group is added. + * + * @param[in] new_group GroupInfo structure of the new group. + */ + virtual void OnGroupAdded(chip::FabricIndex fabric_index, const GroupInfo & new_group) = 0; + /** + * Callback invoked when an existing group is removed. + * + * @param[in] removed_state GroupInfo structure of the removed group. + */ + virtual void OnGroupRemoved(chip::FabricIndex fabric_index, const GroupInfo & old_group) = 0; + }; + /** * Template used to iterate the stored group data */ @@ -211,16 +232,16 @@ class GroupDataProvider // Iterators /** * Creates an iterator that may be used to obtain the list of groups associated with the given fabric. - * The number of concurrent instances of this iterator is limited. In order to release the allocated memory, - * the iterator's Release() method must be called after the iteration is finished. + * In order to release the allocated memory, the Release() method must be called after the iteration is finished. + * Modifying the group table during the iteration is currently not supported, and may yield unexpected behaviour. * @retval An instance of EndpointIterator on success * @retval nullptr if no iterator instances are available. */ virtual GroupInfoIterator * IterateGroupInfo(chip::FabricIndex fabric_index) = 0; /** * Creates an iterator that may be used to obtain the list of (group, endpoint) pairs associated with the given fabric. - * The number of concurrent instances of this iterator is limited. In order to release the allocated memory, - * the iterator's Release() method must be called after the iteration is finished. + * In order to release the allocated memory, the Release() method must be called after the iteration is finished. + * Modifying the group table during the iteration is currently not supported, and may yield unexpected behaviour. * @retval An instance of EndpointIterator on success * @retval nullptr if no iterator instances are available. */ @@ -233,14 +254,15 @@ class GroupDataProvider virtual CHIP_ERROR SetGroupKeyAt(chip::FabricIndex fabric_index, size_t index, const GroupKey & info) = 0; virtual CHIP_ERROR GetGroupKeyAt(chip::FabricIndex fabric_index, size_t index, GroupKey & info) = 0; virtual CHIP_ERROR RemoveGroupKeyAt(chip::FabricIndex fabric_index, size_t index) = 0; + /** * Creates an iterator that may be used to obtain the list of (group, keyset) pairs associated with the given fabric. - * The number of concurrent instances of this iterator is limited. In order to release the allocated memory, - * the iterator's Release() method must be called after the iteration is finished. + * In order to release the allocated memory, the Release() method must be called after the iteration is finished. + * Modifying the keyset mappings during the iteration is currently not supported, and may yield unexpected behaviour. * @retval An instance of GroupKeyIterator on success * @retval nullptr if no iterator instances are available. */ - virtual GroupKeyIterator * IterateGroupKey(chip::FabricIndex fabric_index) = 0; + virtual GroupKeyIterator * IterateGroupKeys(chip::FabricIndex fabric_index) = 0; // // Key Sets @@ -251,8 +273,8 @@ class GroupDataProvider virtual CHIP_ERROR RemoveKeySet(chip::FabricIndex fabric_index, chip::KeysetId keyset_id) = 0; /** * Creates an iterator that may be used to obtain the list of key sets associated with the given fabric. - * The number of concurrent instances of this iterator is limited. In order to release the allocated memory, - * the iterator's Release() method must be called after the iteration is finished. + * In order to release the allocated memory, the Release() method must be called after the iteration is finished. + * Modifying the key sets table during the iteration is currently not supported, and may yield unexpected behaviour. * @retval An instance of KeySetIterator on success * @retval nullptr if no iterator instances are available. */ @@ -263,6 +285,20 @@ class GroupDataProvider // General virtual CHIP_ERROR Decrypt(PacketHeader packetHeader, PayloadHeader & payloadHeader, System::PacketBufferHandle & msg) = 0; + + // Listener + void SetListener(GroupListener * listener) { mListener = listener; }; + void RemoveListener() { mListener = nullptr; }; + +protected: + void GroupAdded(chip::FabricIndex fabric_index, const GroupInfo & new_group) + { + if (mListener) + { + mListener->OnGroupAdded(fabric_index, new_group); + } + } + GroupListener * mListener = nullptr; }; /** diff --git a/src/credentials/GroupDataProviderImpl.cpp b/src/credentials/GroupDataProviderImpl.cpp index 985885246261c8..8873a27c870621 100644 --- a/src/credentials/GroupDataProviderImpl.cpp +++ b/src/credentials/GroupDataProviderImpl.cpp @@ -758,20 +758,32 @@ CHIP_ERROR GroupDataProviderImpl::SetGroupInfoAt(chip::FabricIndex fabric_index, bool found = group.Find(mStorage, fabric, info.group_id); VerifyOrReturnError(!found || (group.index == index), CHIP_ERROR_DUPLICATE_KEY_ID); - found = group.Get(mStorage, fabric, index); - group.group_id = info.group_id; + found = group.Get(mStorage, fabric, index); + const bool new_group = (group.group_id != info.group_id); + group.group_id = info.group_id; group.SetName(info.name); if (found) { - // Update existing group - return group.Save(mStorage); + // Update existing entry + if (new_group) + { + // New group, clear endpoints + RemoveEndpoints(fabric_index, group.group_id); + } + ReturnErrorOnFailure(group.Save(mStorage)); + if (new_group) + { + GroupAdded(fabric_index, group); + } + return CHIP_NO_ERROR; } // Insert last VerifyOrReturnError(fabric.group_count == index, CHIP_ERROR_INVALID_ARGUMENT); - group.next = 0; + group.group_id = info.group_id; + group.next = 0; ReturnErrorOnFailure(group.Save(mStorage)); if (group.first) @@ -789,7 +801,9 @@ CHIP_ERROR GroupDataProviderImpl::SetGroupInfoAt(chip::FabricIndex fabric_index, } // Update fabric fabric.group_count++; - return fabric.Save(mStorage); + ReturnErrorOnFailure(fabric.Save(mStorage)); + GroupAdded(fabric_index, group); + return CHIP_NO_ERROR; } CHIP_ERROR GroupDataProviderImpl::GetGroupInfoAt(chip::FabricIndex fabric_index, size_t index, GroupInfo & info) @@ -850,7 +864,12 @@ CHIP_ERROR GroupDataProviderImpl::RemoveGroupInfoAt(chip::FabricIndex fabric_ind fabric.group_count--; } // Update fabric info - return fabric.Save(mStorage); + ReturnErrorOnFailure(fabric.Save(mStorage)); + if (mListener) + { + mListener->OnGroupRemoved(fabric_index, group); + } + return CHIP_NO_ERROR; } bool GroupDataProviderImpl::HasEndpoint(chip::FabricIndex fabric_index, chip::GroupId group_id, chip::EndpointId endpoint_id) @@ -893,7 +912,9 @@ CHIP_ERROR GroupDataProviderImpl::AddEndpoint(chip::FabricIndex fabric_index, ch // Update fabric fabric.first_group = group.id; fabric.group_count++; - return fabric.Save(mStorage); + ReturnErrorOnFailure(fabric.Save(mStorage)); + GroupAdded(fabric_index, group); + return CHIP_NO_ERROR; } // Existing group @@ -1149,6 +1170,32 @@ void GroupDataProviderImpl::EndpointIteratorImpl::Release() mProvider.mEndpointIterators.ReleaseObject(this); } +CHIP_ERROR GroupDataProviderImpl::RemoveEndpoints(chip::FabricIndex fabric_index, chip::GroupId group_id) +{ + VerifyOrReturnError(mInitialized, CHIP_ERROR_INTERNAL); + + FabricData fabric(fabric_index); + GroupData group; + + VerifyOrReturnError(CHIP_NO_ERROR == fabric.Load(mStorage), CHIP_ERROR_INVALID_FABRIC_ID); + VerifyOrReturnError(group.Find(mStorage, fabric, group_id), CHIP_ERROR_KEY_NOT_FOUND); + + EndpointData endpoint(fabric_index, group.id, group.first_endpoint); + size_t endpoint_index = 0; + while (endpoint_index < group.endpoint_count) + { + ReturnErrorOnFailure(endpoint.Load(mStorage)); + endpoint.Delete(mStorage); + endpoint.id = endpoint.next; + endpoint_index++; + } + group.first_endpoint = kInvalidEndpointId; + group.endpoint_count = 0; + ReturnErrorOnFailure(group.Save(mStorage)); + + return CHIP_NO_ERROR; +} + // // Group-Key map // @@ -1250,7 +1297,7 @@ CHIP_ERROR GroupDataProviderImpl::RemoveGroupKeyAt(chip::FabricIndex fabric_inde return fabric.Save(mStorage); } -GroupDataProvider::GroupKeyIterator * GroupDataProviderImpl::IterateGroupKey(chip::FabricIndex fabric_index) +GroupDataProvider::GroupKeyIterator * GroupDataProviderImpl::IterateGroupKeys(chip::FabricIndex fabric_index) { VerifyOrReturnError(mInitialized, nullptr); return mGroupKeyIterators.CreateObject(*this, fabric_index); diff --git a/src/credentials/GroupDataProviderImpl.h b/src/credentials/GroupDataProviderImpl.h index 0462b61bf3904e..25c6bc11deae0f 100644 --- a/src/credentials/GroupDataProviderImpl.h +++ b/src/credentials/GroupDataProviderImpl.h @@ -62,7 +62,7 @@ class GroupDataProviderImpl : public GroupDataProvider CHIP_ERROR SetGroupKeyAt(chip::FabricIndex fabric_index, size_t index, const GroupKey & info) override; CHIP_ERROR GetGroupKeyAt(chip::FabricIndex fabric_index, size_t index, GroupKey & info) override; CHIP_ERROR RemoveGroupKeyAt(chip::FabricIndex fabric_index, size_t index) override; - GroupKeyIterator * IterateGroupKey(chip::FabricIndex fabric_index) override; + GroupKeyIterator * IterateGroupKeys(chip::FabricIndex fabric_index) override; // // Key Sets @@ -147,6 +147,7 @@ class GroupDataProviderImpl : public GroupDataProvider size_t mCount = 0; size_t mTotal = 0; }; + CHIP_ERROR RemoveEndpoints(chip::FabricIndex fabric_index, chip::GroupId group_id); chip::PersistentStorageDelegate & mStorage; bool mInitialized = false; diff --git a/src/credentials/tests/TestGroupDataProvider.cpp b/src/credentials/tests/TestGroupDataProvider.cpp index fc50519fe6ea23..9d39efbafe8713 100644 --- a/src/credentials/tests/TestGroupDataProvider.cpp +++ b/src/credentials/tests/TestGroupDataProvider.cpp @@ -95,6 +95,37 @@ static KeySet kKeySet1(kKeysetId1, KeySet::SecurityPolicy::kLowLatency, 1); static KeySet kKeySet2(kKeysetId2, KeySet::SecurityPolicy::kLowLatency, 2); static KeySet kKeySet3(kKeysetId3, KeySet::SecurityPolicy::kStandard, 3); +class TestListener : public GroupDataProvider::GroupListener +{ +public: + chip::FabricIndex fabric_index = kUndefinedFabricIndex; + GroupInfo latest; + size_t added_count = 0; + size_t removed_count = 0; + + void Reset() + { + fabric_index = kUndefinedFabricIndex; + latest = GroupInfo(); + added_count = 0; + removed_count = 0; + } + + void OnGroupAdded(chip::FabricIndex fabric, const GroupInfo & new_group) override + { + fabric_index = fabric; + latest = new_group; + added_count++; + } + void OnGroupRemoved(chip::FabricIndex fabric, const GroupInfo & old_group) override + { + fabric_index = fabric; + latest = old_group; + removed_count++; + } +}; +static TestListener sListener; + void TestStorageDelegate(nlTestSuite * apSuite, void * apContext) { chip::TestPersistentStorageDelegate delegate; @@ -139,6 +170,8 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext) // Set Group Info + sListener.Reset(); + // Out-of-order NL_TEST_ASSERT(apSuite, CHIP_ERROR_INVALID_ARGUMENT == provider->SetGroupInfoAt(kFabric1, 2, kGroupInfo1_1)); @@ -158,6 +191,10 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext) NL_TEST_ASSERT(apSuite, CHIP_ERROR_INVALID_FABRIC_ID == provider->GetGroupInfoAt(kUndefinedFabricIndex, 0, group)); NL_TEST_ASSERT(apSuite, CHIP_ERROR_NOT_FOUND == provider->GetGroupInfoAt(kFabric2, 999, group)); + NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo2_3); + NL_TEST_ASSERT(apSuite, 6 == sListener.added_count); + NL_TEST_ASSERT(apSuite, 0 == sListener.removed_count); + NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfoAt(kFabric2, 2, group)); NL_TEST_ASSERT(apSuite, group == kGroupInfo2_3); NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfoAt(kFabric2, 1, group)); @@ -175,6 +212,9 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext) NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->RemoveGroupInfo(kFabric1, kGroup3)); NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->RemoveGroupInfoAt(kFabric2, 0)); + NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo2_1); + NL_TEST_ASSERT(apSuite, 6 == sListener.added_count); + NL_TEST_ASSERT(apSuite, 2 == sListener.removed_count); // Remaining entries shift up @@ -199,6 +239,9 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext) NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfoAt(kFabric2, 0, group)); NL_TEST_ASSERT(apSuite, group == kGroupInfo3_4); + NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo3_4); + NL_TEST_ASSERT(apSuite, 8 == sListener.added_count); + NL_TEST_ASSERT(apSuite, 2 == sListener.removed_count); // Overwrite existing group, index must match @@ -213,18 +256,24 @@ void TestGroupInfo(nlTestSuite * apSuite, void * apContext) NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfoAt(kFabric2, 1, group)); NL_TEST_ASSERT(apSuite, group == kGroupInfo1_3); + NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo3_4); + NL_TEST_ASSERT(apSuite, 8 == sListener.added_count); + NL_TEST_ASSERT(apSuite, 2 == sListener.removed_count); // By group_id - // New + // Override existing NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupInfo(kFabric1, kGroupInfo3_5)); - // Override + // New group NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->SetGroupInfo(kFabric2, kGroupInfo3_2)); // Not found NL_TEST_ASSERT(apSuite, CHIP_ERROR_NOT_FOUND == provider->GetGroupInfo(kFabric2, kGroup5, group)); // Existing NL_TEST_ASSERT(apSuite, CHIP_NO_ERROR == provider->GetGroupInfo(kFabric2, kGroup2, group)); NL_TEST_ASSERT(apSuite, group == kGroupInfo3_2); + NL_TEST_ASSERT(apSuite, sListener.latest == kGroupInfo3_2); + NL_TEST_ASSERT(apSuite, 9 == sListener.added_count); + NL_TEST_ASSERT(apSuite, 2 == sListener.removed_count); } void TestGroupInfoIterator(nlTestSuite * apSuite, void * apContext) @@ -456,7 +505,7 @@ void TestEndpointIterator(nlTestSuite * apSuite, void * apContext) } } -void TestGroupKey(nlTestSuite * apSuite, void * apContext) +void TestGroupKeys(nlTestSuite * apSuite, void * apContext) { GroupDataProvider * provider = GetGroupDataProvider(); NL_TEST_ASSERT(apSuite, provider); @@ -577,7 +626,7 @@ void TestGroupKeyIterator(nlTestSuite * apSuite, void * apContext) kGroup1Keyset0, kGroup1Keyset1, kGroup1Keyset2, kGroup1Keyset3 }; size_t expected_f1_count = sizeof(expected_f1) / sizeof(GroupKey); - auto it = provider->IterateGroupKey(kFabric1); + auto it = provider->IterateGroupKeys(kFabric1); size_t count = 0; NL_TEST_ASSERT(apSuite, it); if (it) @@ -596,7 +645,7 @@ void TestGroupKeyIterator(nlTestSuite * apSuite, void * apContext) GroupKey expected_f2[] = { kGroup2Keyset0, kGroup2Keyset1, kGroup2Keyset2, kGroup2Keyset3 }; size_t expected_f2_count = sizeof(expected_f2) / sizeof(GroupKey); - it = provider->IterateGroupKey(kFabric2); + it = provider->IterateGroupKeys(kFabric2); NL_TEST_ASSERT(apSuite, it); if (it) { @@ -956,6 +1005,9 @@ int Test_Setup(void * inContext) VerifyOrReturnError(CHIP_NO_ERROR == chip::Platform::MemoryInit(), FAILURE); VerifyOrReturnError(CHIP_NO_ERROR == sProvider.Init(), FAILURE); + // Event listener + sProvider.SetListener(&chip::app::TestGroups::sListener); + memcpy(chip::app::TestGroups::kKeySet0.epoch_keys, kEpochKeys0, sizeof(kEpochKeys0)); memcpy(chip::app::TestGroups::kKeySet1.epoch_keys, kEpochKeys1, sizeof(kEpochKeys1)); memcpy(chip::app::TestGroups::kKeySet2.epoch_keys, kEpochKeys2, sizeof(kEpochKeys2)); @@ -982,7 +1034,7 @@ const nlTest sTests[] = { NL_TEST_DEF("TestStorageDelegate", chip::app::TestGrou NL_TEST_DEF("TestGroupInfoIterator", chip::app::TestGroups::TestGroupInfoIterator), NL_TEST_DEF("TestEndpoints", chip::app::TestGroups::TestEndpoints), NL_TEST_DEF("TestEndpointIterator", chip::app::TestGroups::TestEndpointIterator), - NL_TEST_DEF("TestGroupKey", chip::app::TestGroups::TestGroupKey), + NL_TEST_DEF("TestGroupKeys", chip::app::TestGroups::TestGroupKeys), NL_TEST_DEF("TestGroupKeyIterator", chip::app::TestGroups::TestGroupKeyIterator), NL_TEST_DEF("TestKeySets", chip::app::TestGroups::TestKeySets), NL_TEST_DEF("TestKeySetIterator", chip::app::TestGroups::TestKeySetIterator),