diff --git a/src/app/icd/client/CheckInHandler.cpp b/src/app/icd/client/CheckInHandler.cpp index e0edc649cc4563..f4eaa48b114116 100644 --- a/src/app/icd/client/CheckInHandler.cpp +++ b/src/app/icd/client/CheckInHandler.cpp @@ -52,8 +52,12 @@ CHIP_ERROR CheckInHandler::Init(Messaging::ExchangeManager * exchangeManager, IC { VerifyOrReturnError(exchangeManager != nullptr, CHIP_ERROR_INVALID_ARGUMENT); VerifyOrReturnError(clientStorage != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(delegate != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(engine != nullptr, CHIP_ERROR_INVALID_ARGUMENT); VerifyOrReturnError(mpExchangeManager == nullptr, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(mpICDClientStorage == nullptr, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mpCheckInDelegate == nullptr, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mpImEngine == nullptr, CHIP_ERROR_INCORRECT_STATE); mpExchangeManager = exchangeManager; mpICDClientStorage = clientStorage; @@ -113,6 +117,7 @@ CHIP_ERROR CheckInHandler::OnMessageReceived(Messaging::ExchangeContext * ec, co if (refreshKey) { + ChipLogProgress(ICD, "Key Refresh is required"); RefreshKeySender * refreshKeySender = mpCheckInDelegate->OnKeyRefreshNeeded(clientInfo, mpICDClientStorage); if (refreshKeySender == nullptr) { diff --git a/src/app/icd/client/DefaultICDClientStorage.cpp b/src/app/icd/client/DefaultICDClientStorage.cpp index 2b6b3a1aff1657..0c6a0f7e134bb4 100644 --- a/src/app/icd/client/DefaultICDClientStorage.cpp +++ b/src/app/icd/client/DefaultICDClientStorage.cpp @@ -51,6 +51,11 @@ CHIP_ERROR DefaultICDClientStorage::UpdateFabricList(FabricIndex fabricIndex) mFabricList.push_back(fabricIndex); + return StoreFabricList(); +} + +CHIP_ERROR DefaultICDClientStorage::StoreFabricList() +{ Platform::ScopedMemoryBuffer backingBuffer; size_t counter = mFabricList.size(); size_t total = kFabricIndexTlvSize * counter + kArrayOverHead; @@ -68,7 +73,7 @@ CHIP_ERROR DefaultICDClientStorage::UpdateFabricList(FabricIndex fabricIndex) const auto len = writer.GetLengthWritten(); VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_BUFFER_TOO_SMALL); - writer.Finalize(backingBuffer); + ReturnErrorOnFailure(writer.Finalize(backingBuffer)); return mpClientInfoStore->SyncSetKeyValue(DefaultStorageKeyAllocator::ICDFabricList().KeyName(), backingBuffer.Get(), static_cast(len)); } @@ -211,6 +216,7 @@ CHIP_ERROR DefaultICDClientStorage::Load(FabricIndex fabricIndex, std::vector 0, CHIP_NO_ERROR); size_t len = clientInfoSize * count + kArrayOverHead; Platform::ScopedMemoryBuffer backingBuffer; VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_BUFFER_TOO_SMALL); @@ -344,8 +350,21 @@ CHIP_ERROR DefaultICDClientStorage::SerializeToTlv(TLV::TLVWriter & writer, cons return writer.EndContainer(arrayType); } +bool DefaultICDClientStorage::FabricExists(FabricIndex fabricIndex) +{ + for (auto & fabric_idx : mFabricList) + { + if (fabric_idx == fabricIndex) + { + return true; + } + } + return false; +} + CHIP_ERROR DefaultICDClientStorage::StoreEntry(const ICDClientInfo & clientInfo) { + VerifyOrReturnError(FabricExists(clientInfo.peer_node.GetFabricIndex()), CHIP_ERROR_INVALID_FABRIC_INDEX); std::vector clientInfoVector; size_t clientInfoSize = MaxICDClientInfoSize(); ReturnErrorOnFailure(Load(clientInfo.peer_node.GetFabricIndex(), clientInfoVector, clientInfoSize)); @@ -359,7 +378,6 @@ CHIP_ERROR DefaultICDClientStorage::StoreEntry(const ICDClientInfo & clientInfo) break; } } - clientInfoVector.push_back(clientInfo); size_t total = clientInfoSize * clientInfoVector.size() + kArrayOverHead; Platform::ScopedMemoryBuffer backingBuffer; @@ -371,7 +389,7 @@ CHIP_ERROR DefaultICDClientStorage::StoreEntry(const ICDClientInfo & clientInfo) const auto len = writer.GetLengthWritten(); VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_BUFFER_TOO_SMALL); - writer.Finalize(backingBuffer); + ReturnErrorOnFailure(writer.Finalize(backingBuffer)); ReturnErrorOnFailure(mpClientInfoStore->SyncSetKeyValue( DefaultStorageKeyAllocator::ICDClientInfoKey(clientInfo.peer_node.GetFabricIndex()).KeyName(), backingBuffer.Get(), static_cast(len))); @@ -416,7 +434,7 @@ CHIP_ERROR DefaultICDClientStorage::UpdateEntryCountForFabric(FabricIndex fabric const auto len = writer.GetLengthWritten(); VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_BUFFER_TOO_SMALL); - writer.Finalize(backingBuffer); + ReturnErrorOnFailure(writer.Finalize(backingBuffer)); return mpClientInfoStore->SyncSetKeyValue(DefaultStorageKeyAllocator::FabricICDClientInfoCounter(fabricIndex).KeyName(), backingBuffer.Get(), static_cast(len)); @@ -424,9 +442,11 @@ CHIP_ERROR DefaultICDClientStorage::UpdateEntryCountForFabric(FabricIndex fabric CHIP_ERROR DefaultICDClientStorage::DeleteEntry(const ScopedNodeId & peerNode) { + VerifyOrReturnError(FabricExists(peerNode.GetFabricIndex()), CHIP_NO_ERROR); size_t clientInfoSize = 0; std::vector clientInfoVector; ReturnErrorOnFailure(Load(peerNode.GetFabricIndex(), clientInfoVector, clientInfoSize)); + VerifyOrReturnError(clientInfoVector.size() > 0, CHIP_NO_ERROR); for (auto it = clientInfoVector.begin(); it != clientInfoVector.end(); it++) { @@ -440,7 +460,6 @@ CHIP_ERROR DefaultICDClientStorage::DeleteEntry(const ScopedNodeId & peerNode) ReturnErrorOnFailure( mpClientInfoStore->SyncDeleteKeyValue(DefaultStorageKeyAllocator::ICDClientInfoKey(peerNode.GetFabricIndex()).KeyName())); - size_t total = clientInfoSize * clientInfoVector.size() + kArrayOverHead; Platform::ScopedMemoryBuffer backingBuffer; ReturnErrorCodeIf(!backingBuffer.Calloc(total), CHIP_ERROR_NO_MEMORY); @@ -451,7 +470,7 @@ CHIP_ERROR DefaultICDClientStorage::DeleteEntry(const ScopedNodeId & peerNode) const auto len = writer.GetLengthWritten(); VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_BUFFER_TOO_SMALL); - writer.Finalize(backingBuffer); + ReturnErrorOnFailure(writer.Finalize(backingBuffer)); ReturnErrorOnFailure( mpClientInfoStore->SyncSetKeyValue(DefaultStorageKeyAllocator::ICDClientInfoKey(peerNode.GetFabricIndex()).KeyName(), backingBuffer.Get(), static_cast(len))); @@ -461,6 +480,8 @@ CHIP_ERROR DefaultICDClientStorage::DeleteEntry(const ScopedNodeId & peerNode) CHIP_ERROR DefaultICDClientStorage::DeleteAllEntries(FabricIndex fabricIndex) { + VerifyOrReturnError(FabricExists(fabricIndex), CHIP_NO_ERROR); + size_t clientInfoSize = 0; std::vector clientInfoVector; ReturnErrorOnFailure(Load(fabricIndex, clientInfoVector, clientInfoSize)); @@ -471,7 +492,23 @@ CHIP_ERROR DefaultICDClientStorage::DeleteAllEntries(FabricIndex fabricIndex) } ReturnErrorOnFailure( mpClientInfoStore->SyncDeleteKeyValue(DefaultStorageKeyAllocator::ICDClientInfoKey(fabricIndex).KeyName())); - return mpClientInfoStore->SyncDeleteKeyValue(DefaultStorageKeyAllocator::FabricICDClientInfoCounter(fabricIndex).KeyName()); + ReturnErrorOnFailure( + mpClientInfoStore->SyncDeleteKeyValue(DefaultStorageKeyAllocator::FabricICDClientInfoCounter(fabricIndex).KeyName())); + + for (auto fabric = mFabricList.begin(); fabric != mFabricList.end(); fabric++) + { + if (*fabric == fabricIndex) + { + mFabricList.erase(fabric); + break; + } + } + + if (mFabricList.size() == 0) + { + return mpClientInfoStore->SyncDeleteKeyValue(DefaultStorageKeyAllocator::ICDFabricList().KeyName()); + } + return StoreFabricList(); } CHIP_ERROR DefaultICDClientStorage::ProcessCheckInPayload(const ByteSpan & payload, ICDClientInfo & clientInfo, diff --git a/src/app/icd/client/DefaultICDClientStorage.h b/src/app/icd/client/DefaultICDClientStorage.h index 76faeb9afcc7cd..13064fe649c61b 100644 --- a/src/app/icd/client/DefaultICDClientStorage.h +++ b/src/app/icd/client/DefaultICDClientStorage.h @@ -120,6 +120,12 @@ class DefaultICDClientStorage : public ICDClientStorage CHIP_ERROR ProcessCheckInPayload(const ByteSpan & payload, ICDClientInfo & clientInfo, Protocols::SecureChannel::CounterType & counter) override; +#if CONFIG_BUILD_FOR_HOST_UNIT_TEST + size_t GetFabricListSize() { return mFabricList.size(); } + + PersistentStorageDelegate * GetClientInfoStore() { return mpClientInfoStore; } +#endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST + protected: enum class ClientInfoTag : uint8_t { @@ -173,9 +179,12 @@ class DefaultICDClientStorage : public ICDClientStorage private: friend class ICDClientInfoIteratorImpl; + CHIP_ERROR StoreFabricList(); CHIP_ERROR LoadFabricList(); CHIP_ERROR LoadCounter(FabricIndex fabricIndex, size_t & count, size_t & clientInfoSize); + bool FabricExists(FabricIndex fabricIndex); + CHIP_ERROR IncreaseEntryCountForFabric(FabricIndex fabricIndex); CHIP_ERROR DecreaseEntryCountForFabric(FabricIndex fabricIndex); CHIP_ERROR UpdateEntryCountForFabric(FabricIndex fabricIndex, bool increase); diff --git a/src/app/tests/BUILD.gn b/src/app/tests/BUILD.gn index 104a57a2fc019a..09dcb304ea29f2 100644 --- a/src/app/tests/BUILD.gn +++ b/src/app/tests/BUILD.gn @@ -194,6 +194,7 @@ chip_test_suite("tests") { "TestBasicCommandPathRegistry.cpp", "TestBindingTable.cpp", "TestBuilderParser.cpp", + "TestCheckInHandler.cpp", "TestCommandHandlerInterfaceRegistry.cpp", "TestCommandInteraction.cpp", "TestCommandPathParams.cpp", @@ -236,6 +237,7 @@ chip_test_suite("tests") { "${chip_root}/src/app", "${chip_root}/src/app/codegen-data-model-provider:instance-header", "${chip_root}/src/app/common:cluster-objects", + "${chip_root}/src/app/icd/client:handler", "${chip_root}/src/app/icd/client:manager", "${chip_root}/src/app/tests:helpers", "${chip_root}/src/app/util/mock:mock_codegen_data_model", diff --git a/src/app/tests/TestCheckInHandler.cpp b/src/app/tests/TestCheckInHandler.cpp new file mode 100644 index 00000000000000..0d57eb6b42183c --- /dev/null +++ b/src/app/tests/TestCheckInHandler.cpp @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2024 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace chip; +using namespace app; +using namespace System; +using TestSessionKeystoreImpl = Crypto::DefaultSessionKeystore; + +constexpr uint8_t kKeyBuffer1[] = { + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f +}; + +constexpr uint8_t kKeyBuffer2[] = { + 0xf1, 0xe1, 0xd1, 0xc1, 0xb1, 0xa1, 0x91, 0x81, 0x71, 0x61, 0x51, 0x14, 0x31, 0x21, 0x11, 0x01 +}; + +class TestCheckInHandler : public chip::Test::AppContext +{ +}; + +class CheckInHandlerWrapper : public chip::app::CheckInHandler +{ +public: + CHIP_ERROR ValidateOnMessageReceived(Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader, + System::PacketBufferHandle && payload) + { + return OnMessageReceived(ec, payloadHeader, std::move(payload)); + } +}; + +TEST_F(TestCheckInHandler, TestOnMessageReceived) +{ + InteractionModelEngine * engine = InteractionModelEngine::GetInstance(); + EXPECT_EQ(engine->Init(&GetExchangeManager(), &GetFabricTable(), app::reporting::GetDefaultReportScheduler()), CHIP_NO_ERROR); + TestPersistentStorageDelegate clientInfoStorage; + TestSessionKeystoreImpl keystore; + + DefaultICDClientStorage manager; + EXPECT_EQ(manager.Init(&clientInfoStorage, &keystore), CHIP_NO_ERROR); + + DefaultCheckInDelegate checkInDelegate; + EXPECT_EQ(checkInDelegate.Init(&manager, engine), CHIP_NO_ERROR); + + CheckInHandlerWrapper checkInHandler; + EXPECT_EQ(checkInHandler.Init(&GetExchangeManager(), &manager, &checkInDelegate, engine), CHIP_NO_ERROR); + + FabricIndex fabricIdA = 1; + NodeId nodeIdA = 6666; + + ICDClientInfo clientInfoA; + clientInfoA.peer_node = ScopedNodeId(nodeIdA, fabricIdA); + clientInfoA.start_icd_counter = 0; + clientInfoA.offset = 0; + EXPECT_EQ(manager.UpdateFabricList(fabricIdA), CHIP_NO_ERROR); + EXPECT_EQ(manager.SetKey(clientInfoA, ByteSpan(kKeyBuffer1)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfoA), CHIP_NO_ERROR); + + FabricIndex fabricIdB = 2; + NodeId nodeIdB = 6667; + + ICDClientInfo clientInfoB; + clientInfoB.peer_node = ScopedNodeId(nodeIdB, fabricIdB); + clientInfoB.offset = 0; + EXPECT_EQ(manager.UpdateFabricList(fabricIdB), CHIP_NO_ERROR); + EXPECT_EQ(manager.SetKey(clientInfoB, ByteSpan(kKeyBuffer2)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfoB), CHIP_NO_ERROR); + + PayloadHeader payloadHeader; + payloadHeader.SetExchangeID(0); + payloadHeader.SetMessageType(chip::Protocols::SecureChannel::MsgType::ICD_CheckIn); + + uint32_t counter = 1; + System::PacketBufferHandle buffer1 = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output1{ buffer1->Start(), buffer1->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfoA.aes_key_handle, clientInfoA.hmac_key_handle, counter, ByteSpan(), output1), + CHIP_NO_ERROR); + + buffer1->SetDataLength(static_cast(output1.size())); + EXPECT_EQ(checkInHandler.ValidateOnMessageReceived(nullptr, payloadHeader, std::move(buffer1)), CHIP_NO_ERROR); + + ICDClientInfo clientInfo1; + auto * iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + while (iterator->Next(clientInfo1)) + { + if (clientInfo1.peer_node.GetNodeId() == nodeIdA && clientInfo1.peer_node.GetFabricIndex() == fabricIdA) + { + break; + } + } + iterator->Release(); + + EXPECT_EQ(clientInfo1.offset, counter - clientInfoA.start_icd_counter); + + // Validate duplicate check-in message handling + chip::System::PacketBufferHandle buffer2 = + MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output2{ buffer2->Start(), buffer2->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfoA.aes_key_handle, clientInfoA.hmac_key_handle, counter, ByteSpan(), output2), + CHIP_NO_ERROR); + + buffer2->SetDataLength(static_cast(output2.size())); + EXPECT_EQ(checkInHandler.ValidateOnMessageReceived(nullptr, payloadHeader, std::move(buffer2)), CHIP_NO_ERROR); + + ICDClientInfo clientInfo2; + iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + while (iterator->Next(clientInfo2)) + { + if (clientInfo2.peer_node.GetNodeId() == nodeIdA && clientInfo2.peer_node.GetFabricIndex() == fabricIdA) + { + break; + } + } + iterator->Release(); + EXPECT_EQ(clientInfo2.offset, counter - clientInfoA.start_icd_counter); + + // Validate second check-in message with increased counter + counter++; + System::PacketBufferHandle buffer3 = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output3{ buffer3->Start(), buffer3->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfoA.aes_key_handle, clientInfoA.hmac_key_handle, counter, ByteSpan(), output3), + CHIP_NO_ERROR); + buffer3->SetDataLength(static_cast(output3.size())); + EXPECT_EQ(checkInHandler.ValidateOnMessageReceived(nullptr, payloadHeader, std::move(buffer3)), CHIP_NO_ERROR); + ICDClientInfo clientInfo3; + iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + while (iterator->Next(clientInfo3)) + { + if (clientInfo3.peer_node.GetNodeId() == nodeIdA && clientInfo3.peer_node.GetFabricIndex() == fabricIdA) + { + break; + } + } + iterator->Release(); + EXPECT_EQ(clientInfo3.offset, counter - clientInfoA.start_icd_counter); + + // Validate check-in message from fabricB + counter++; + System::PacketBufferHandle buffer4 = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output4{ buffer4->Start(), buffer4->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfoB.aes_key_handle, clientInfoB.hmac_key_handle, counter, ByteSpan(), output4), + CHIP_NO_ERROR); + buffer4->SetDataLength(static_cast(output4.size())); + EXPECT_EQ(checkInHandler.ValidateOnMessageReceived(nullptr, payloadHeader, std::move(buffer4)), CHIP_NO_ERROR); + ICDClientInfo clientInfo4; + iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + while (iterator->Next(clientInfo4)) + { + if (clientInfo4.peer_node.GetNodeId() == nodeIdB && clientInfo4.peer_node.GetFabricIndex() == fabricIdB) + { + break; + } + } + iterator->Release(); + EXPECT_EQ(clientInfo4.offset, counter - clientInfoB.start_icd_counter); + + // Validate check-in message from removed fabricB + counter++; + System::PacketBufferHandle buffer5 = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output5{ buffer5->Start(), buffer5->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfoB.aes_key_handle, clientInfoB.hmac_key_handle, counter, ByteSpan(), output5), + CHIP_NO_ERROR); + + EXPECT_EQ(manager.DeleteAllEntries(fabricIdB), CHIP_NO_ERROR); + buffer5->SetDataLength(static_cast(output5.size())); + EXPECT_EQ(checkInHandler.ValidateOnMessageReceived(nullptr, payloadHeader, std::move(buffer5)), CHIP_NO_ERROR); + ICDClientInfo clientInfo5; + iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + bool located = false; + while (iterator->Next(clientInfo5)) + { + if (clientInfo5.peer_node.GetFabricIndex() == fabricIdB) + { + located = true; + break; + } + } + iterator->Release(); + EXPECT_FALSE(located); + + // Add back fabricB and validate check-in message again + EXPECT_EQ(manager.UpdateFabricList(fabricIdB), CHIP_NO_ERROR); + EXPECT_EQ(manager.SetKey(clientInfoB, ByteSpan(kKeyBuffer2)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfoB), CHIP_NO_ERROR); + counter++; + System::PacketBufferHandle buffer6 = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output6{ buffer6->Start(), buffer6->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfoB.aes_key_handle, clientInfoB.hmac_key_handle, counter, ByteSpan(), output6), + CHIP_NO_ERROR); + buffer6->SetDataLength(static_cast(output4.size())); + EXPECT_EQ(checkInHandler.ValidateOnMessageReceived(nullptr, payloadHeader, std::move(buffer6)), CHIP_NO_ERROR); + ICDClientInfo clientInfo6; + iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + while (iterator->Next(clientInfo6)) + { + if (clientInfo6.peer_node.GetNodeId() == nodeIdB && clientInfo6.peer_node.GetFabricIndex() == fabricIdB) + { + break; + } + } + iterator->Release(); + EXPECT_EQ(clientInfo6.offset, counter - clientInfoB.start_icd_counter); + + // Clear fabric table + EXPECT_EQ(manager.DeleteAllEntries(fabricIdA), CHIP_NO_ERROR); + EXPECT_EQ(manager.DeleteAllEntries(fabricIdB), CHIP_NO_ERROR); + // Add back fabricA and validate check-in message again + EXPECT_EQ(manager.UpdateFabricList(fabricIdA), CHIP_NO_ERROR); + EXPECT_EQ(manager.SetKey(clientInfoA, ByteSpan(kKeyBuffer1)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfoA), CHIP_NO_ERROR); + counter++; + System::PacketBufferHandle buffer7 = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output7{ buffer7->Start(), buffer7->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfoA.aes_key_handle, clientInfoA.hmac_key_handle, counter, ByteSpan(), output7), + CHIP_NO_ERROR); + buffer7->SetDataLength(static_cast(output7.size())); + EXPECT_EQ(checkInHandler.ValidateOnMessageReceived(nullptr, payloadHeader, std::move(buffer7)), CHIP_NO_ERROR); + ICDClientInfo clientInfo7; + iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + while (iterator->Next(clientInfo7)) + { + if (clientInfo7.peer_node.GetNodeId() == nodeIdA && clientInfo7.peer_node.GetFabricIndex() == fabricIdA) + { + break; + } + } + iterator->Release(); + EXPECT_EQ(clientInfo7.offset, counter - clientInfoA.start_icd_counter); + + // Validate IcdclientInfo is not updated when handling overlimited counter and fail to create case session + uint32_t old_start_icd_counter = clientInfo7.start_icd_counter; + uint32_t old_counter = counter; + counter = (1U << 31) + 100U + clientInfo7.start_icd_counter; + System::PacketBufferHandle buffer8 = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output8{ buffer8->Start(), buffer8->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfoA.aes_key_handle, clientInfoA.hmac_key_handle, counter, ByteSpan(), output8), + CHIP_NO_ERROR); + + buffer8->SetDataLength(static_cast(output8.size())); + EXPECT_EQ(checkInHandler.ValidateOnMessageReceived(nullptr, payloadHeader, std::move(buffer8)), CHIP_NO_ERROR); + ICDClientInfo clientInfo8; + iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + while (iterator->Next(clientInfo8)) + { + if (clientInfo8.peer_node.GetNodeId() == nodeIdA && clientInfo8.peer_node.GetFabricIndex() == fabricIdA) + { + break; + } + } + iterator->Release(); + EXPECT_EQ(clientInfo8.offset, old_counter - clientInfoA.start_icd_counter); + EXPECT_EQ(clientInfo8.start_icd_counter, old_start_icd_counter); + + checkInHandler.Shutdown(); +} diff --git a/src/app/tests/TestDefaultICDClientStorage.cpp b/src/app/tests/TestDefaultICDClientStorage.cpp index ac950fdbc75715..0ea48d3326742d 100644 --- a/src/app/tests/TestDefaultICDClientStorage.cpp +++ b/src/app/tests/TestDefaultICDClientStorage.cpp @@ -103,6 +103,12 @@ TEST_F(TestDefaultICDClientStorage, TestClientInfoCount) iterator->Release(); + EXPECT_TRUE(manager.GetClientInfoStore()->SyncDoesKeyExist(DefaultStorageKeyAllocator::ICDFabricList().KeyName())); + EXPECT_TRUE(manager.GetClientInfoStore()->SyncDoesKeyExist( + DefaultStorageKeyAllocator::FabricICDClientInfoCounter(fabricId).KeyName())); + EXPECT_TRUE( + manager.GetClientInfoStore()->SyncDoesKeyExist(DefaultStorageKeyAllocator::ICDClientInfoKey(fabricId).KeyName())); + // Delete all and verify iterator counts 0 EXPECT_EQ(manager.DeleteAllEntries(fabricId), CHIP_NO_ERROR); iterator = manager.IterateICDClientInfo(); @@ -116,6 +122,14 @@ TEST_F(TestDefaultICDClientStorage, TestClientInfoCount) } iterator->Release(); EXPECT_EQ(count, 0u); + EXPECT_EQ(manager.GetFabricListSize(), 0u); + EXPECT_FALSE(manager.GetClientInfoStore()->SyncDoesKeyExist(DefaultStorageKeyAllocator::ICDFabricList().KeyName())); + EXPECT_FALSE(manager.GetClientInfoStore()->SyncDoesKeyExist( + DefaultStorageKeyAllocator::FabricICDClientInfoCounter(fabricId).KeyName())); + EXPECT_FALSE( + manager.GetClientInfoStore()->SyncDoesKeyExist(DefaultStorageKeyAllocator::ICDClientInfoKey(fabricId).KeyName())); + + EXPECT_EQ(manager.DeleteAllEntries(fabricId), CHIP_NO_ERROR); } { @@ -174,6 +188,15 @@ TEST_F(TestDefaultICDClientStorage, TestClientInfoCountMultipleFabric) EXPECT_EQ(manager.DeleteEntry(ScopedNodeId(nodeId3, fabricId2)), CHIP_NO_ERROR); EXPECT_EQ(iterator->Count(), 0u); + EXPECT_EQ(manager.GetFabricListSize(), 2u); + EXPECT_TRUE(manager.GetClientInfoStore()->SyncDoesKeyExist(DefaultStorageKeyAllocator::ICDFabricList().KeyName())); + EXPECT_TRUE(manager.GetClientInfoStore()->SyncDoesKeyExist( + DefaultStorageKeyAllocator::FabricICDClientInfoCounter(fabricId1).KeyName())); + EXPECT_TRUE(manager.GetClientInfoStore()->SyncDoesKeyExist(DefaultStorageKeyAllocator::ICDClientInfoKey(fabricId1).KeyName())); + EXPECT_TRUE(manager.GetClientInfoStore()->SyncDoesKeyExist( + DefaultStorageKeyAllocator::FabricICDClientInfoCounter(fabricId2).KeyName())); + EXPECT_TRUE(manager.GetClientInfoStore()->SyncDoesKeyExist(DefaultStorageKeyAllocator::ICDClientInfoKey(fabricId2).KeyName())); + // Verify ClientInfos manually count correctly size_t count = 0; ICDClientInfo clientInfo; @@ -183,6 +206,59 @@ TEST_F(TestDefaultICDClientStorage, TestClientInfoCountMultipleFabric) } EXPECT_FALSE(count); + + EXPECT_EQ(manager.DeleteEntry(ScopedNodeId(nodeId3, fabricId2)), CHIP_NO_ERROR); +} + +TEST_F(TestDefaultICDClientStorage, TestClientInfoCountMultipleFabricWithRemovingFabric) +{ + + FabricIndex fabricId1 = 1; + FabricIndex fabricId2 = 2; + NodeId nodeId1 = 6666; + NodeId nodeId2 = 6667; + NodeId nodeId3 = 6668; + DefaultICDClientStorage manager; + TestPersistentStorageDelegate clientInfoStorage; + TestSessionKeystoreImpl keystore; + EXPECT_EQ(manager.Init(&clientInfoStorage, &keystore), CHIP_NO_ERROR); + EXPECT_EQ(manager.UpdateFabricList(fabricId1), CHIP_NO_ERROR); + EXPECT_EQ(manager.UpdateFabricList(fabricId2), CHIP_NO_ERROR); + + // Write some ClientInfos and see the counts are correct + ICDClientInfo clientInfo1; + clientInfo1.peer_node = ScopedNodeId(nodeId1, fabricId1); + ICDClientInfo clientInfo2; + clientInfo2.peer_node = ScopedNodeId(nodeId2, fabricId1); + ICDClientInfo clientInfo3; + clientInfo3.peer_node = ScopedNodeId(nodeId3, fabricId2); + + EXPECT_EQ(manager.SetKey(clientInfo1, ByteSpan(kKeyBuffer1)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfo1), CHIP_NO_ERROR); + + EXPECT_EQ(manager.SetKey(clientInfo2, ByteSpan(kKeyBuffer2)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfo2), CHIP_NO_ERROR); + + EXPECT_EQ(manager.SetKey(clientInfo3, ByteSpan(kKeyBuffer3)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfo3), CHIP_NO_ERROR); + // Make sure iterator counts correctly + auto * iterator = manager.IterateICDClientInfo(); + EXPECT_EQ(iterator->Count(), 3u); + iterator->Release(); + + EXPECT_EQ(manager.DeleteAllEntries(fabricId1), CHIP_NO_ERROR); + + iterator = manager.IterateICDClientInfo(); + ASSERT_NE(iterator, nullptr); + DefaultICDClientStorage::ICDClientInfoIteratorWrapper clientInfoIteratorWrapper(iterator); + EXPECT_EQ(iterator->Count(), 1u); + + EXPECT_EQ(manager.DeleteAllEntries(fabricId2), CHIP_NO_ERROR); + EXPECT_EQ(iterator->Count(), 0u); + + EXPECT_EQ(manager.StoreEntry(clientInfo1), CHIP_ERROR_INVALID_FABRIC_INDEX); + EXPECT_EQ(manager.StoreEntry(clientInfo2), CHIP_ERROR_INVALID_FABRIC_INDEX); + EXPECT_EQ(manager.StoreEntry(clientInfo3), CHIP_ERROR_INVALID_FABRIC_INDEX); } TEST_F(TestDefaultICDClientStorage, TestProcessCheckInPayload) @@ -214,9 +290,103 @@ TEST_F(TestDefaultICDClientStorage, TestProcessCheckInPayload) uint32_t checkInCounter = 0; ByteSpan payload{ buffer->Start(), buffer->DataLength() }; EXPECT_EQ(manager.ProcessCheckInPayload(payload, decodeClientInfo, checkInCounter), CHIP_NO_ERROR); + EXPECT_EQ(checkInCounter, counter); + + // Validate second check-in message with increased counter + counter++; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfo.aes_key_handle, clientInfo.hmac_key_handle, counter, ByteSpan(), output), + CHIP_NO_ERROR); + buffer->SetDataLength(static_cast(output.size())); + ByteSpan payload1{ buffer->Start(), buffer->DataLength() }; + EXPECT_EQ(manager.ProcessCheckInPayload(payload1, decodeClientInfo, checkInCounter), CHIP_NO_ERROR); + EXPECT_EQ(checkInCounter, counter); - // 2. Use a key not available in the storage for encoding + // Use a key not available in the storage for encoding EXPECT_EQ(manager.SetKey(clientInfo, ByteSpan(kKeyBuffer2)), CHIP_NO_ERROR); + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfo.aes_key_handle, clientInfo.hmac_key_handle, counter, ByteSpan(), output), + CHIP_NO_ERROR); + + buffer->SetDataLength(static_cast(output.size())); + ByteSpan payload2{ buffer->Start(), buffer->DataLength() }; + EXPECT_EQ(manager.ProcessCheckInPayload(payload2, decodeClientInfo, checkInCounter), CHIP_ERROR_NOT_FOUND); +} + +TEST_F(TestDefaultICDClientStorage, TestProcessCheckInPayloadWithRemovedKey) +{ + FabricIndex fabricId = 1; + NodeId nodeId = 6666; + TestPersistentStorageDelegate clientInfoStorage; + TestSessionKeystoreImpl keystore; + + DefaultICDClientStorage manager; + EXPECT_EQ(manager.Init(&clientInfoStorage, &keystore), CHIP_NO_ERROR); + EXPECT_EQ(manager.UpdateFabricList(fabricId), CHIP_NO_ERROR); + // Populate clientInfo + ICDClientInfo clientInfo; + clientInfo.peer_node = ScopedNodeId(nodeId, fabricId); + + EXPECT_EQ(manager.SetKey(clientInfo, ByteSpan(kKeyBuffer1)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfo), CHIP_NO_ERROR); + + uint32_t counter = 1; + System::PacketBufferHandle buffer = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output{ buffer->Start(), buffer->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfo.aes_key_handle, clientInfo.hmac_key_handle, counter, ByteSpan(), output), + CHIP_NO_ERROR); + + buffer->SetDataLength(static_cast(output.size())); + ICDClientInfo decodeClientInfo; + uint32_t checkInCounter = 0; + ByteSpan payload{ buffer->Start(), buffer->DataLength() }; + EXPECT_EQ(manager.ProcessCheckInPayload(payload, decodeClientInfo, checkInCounter), CHIP_NO_ERROR); + EXPECT_EQ(checkInCounter, counter); + + // Use a removed key in the storage for encoding + manager.RemoveKey(clientInfo), CHIP_NO_ERROR; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfo.aes_key_handle, clientInfo.hmac_key_handle, counter, ByteSpan(), output), + CHIP_NO_ERROR); + + buffer->SetDataLength(static_cast(output.size())); + ByteSpan payload1{ buffer->Start(), buffer->DataLength() }; + EXPECT_EQ(manager.ProcessCheckInPayload(payload1, decodeClientInfo, checkInCounter), CHIP_ERROR_NOT_FOUND); +} + +TEST_F(TestDefaultICDClientStorage, TestProcessCheckInPayloadWithEmptyIcdStorage) +{ + FabricIndex fabricId = 1; + NodeId nodeId = 6666; + TestPersistentStorageDelegate clientInfoStorage; + TestSessionKeystoreImpl keystore; + + DefaultICDClientStorage manager; + EXPECT_EQ(manager.Init(&clientInfoStorage, &keystore), CHIP_NO_ERROR); + EXPECT_EQ(manager.UpdateFabricList(fabricId), CHIP_NO_ERROR); + // Populate clientInfo + ICDClientInfo clientInfo; + clientInfo.peer_node = ScopedNodeId(nodeId, fabricId); + + EXPECT_EQ(manager.SetKey(clientInfo, ByteSpan(kKeyBuffer1)), CHIP_NO_ERROR); + EXPECT_EQ(manager.StoreEntry(clientInfo), CHIP_NO_ERROR); + + uint32_t counter = 1; + System::PacketBufferHandle buffer = MessagePacketBuffer::New(chip::Protocols::SecureChannel::CheckinMessage::kMinPayloadSize); + MutableByteSpan output{ buffer->Start(), buffer->MaxDataLength() }; + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( + clientInfo.aes_key_handle, clientInfo.hmac_key_handle, counter, ByteSpan(), output), + CHIP_NO_ERROR); + + buffer->SetDataLength(static_cast(output.size())); + ICDClientInfo decodeClientInfo; + uint32_t checkInCounter = 0; + ByteSpan payload{ buffer->Start(), buffer->DataLength() }; + EXPECT_EQ(manager.ProcessCheckInPayload(payload, decodeClientInfo, checkInCounter), CHIP_NO_ERROR); + EXPECT_EQ(checkInCounter, counter); + manager.DeleteAllEntries(fabricId); + EXPECT_EQ(chip::Protocols::SecureChannel::CheckinMessage::GenerateCheckinMessagePayload( clientInfo.aes_key_handle, clientInfo.hmac_key_handle, counter, ByteSpan(), output), CHIP_NO_ERROR);