diff --git a/src/lib/support/DefaultStorageKeyAllocator.h b/src/lib/support/DefaultStorageKeyAllocator.h index 6af5c614b5fc07..b45b5d92ff0ab7 100644 --- a/src/lib/support/DefaultStorageKeyAllocator.h +++ b/src/lib/support/DefaultStorageKeyAllocator.h @@ -58,8 +58,8 @@ class DefaultStorageKeyAllocator { return Format("f/%x/s/%08" PRIX32 "%08" PRIX32, fabric, static_cast(nodeId >> 32), static_cast(nodeId)); } - const char * SessionResumptionIndex() { return Format("f/sri"); } - const char * SessionResumption(const char * resumptionIdBase64) { return Format("s/%s", resumptionIdBase64); } + const char * SessionResumptionIndex() { return Format("g/sri"); } + const char * SessionResumption(const char * resumptionIdBase64) { return Format("g/s/%s", resumptionIdBase64); } // Access Control const char * AccessControlExtensionEntry(FabricIndex fabric) { return Format("f/%x/ac/1", fabric); } diff --git a/src/protocols/secure_channel/SessionResumptionStorage.cpp b/src/protocols/secure_channel/SessionResumptionStorage.cpp index 368cc1930fd186..a32b743a9fbff8 100644 --- a/src/protocols/secure_channel/SessionResumptionStorage.cpp +++ b/src/protocols/secure_channel/SessionResumptionStorage.cpp @@ -114,6 +114,8 @@ CHIP_ERROR SessionResumptionStorage::Delete(const ScopedNodeId & node) { if (found) { + // index.mSize was decreased by 1 when found was set to true. + // So the (i+1)th element isn't out of bounds. index.mNodes[i] = index.mNodes[i + 1]; } else @@ -121,11 +123,11 @@ CHIP_ERROR SessionResumptionStorage::Delete(const ScopedNodeId & node) if (index.mNodes[i] == node) { found = true; - index.mSize -= 1; - if (i + 1 != index.mSize) + if (i + 1 < index.mSize) { index.mNodes[i] = index.mNodes[i + 1]; } + index.mSize -= 1; } } } diff --git a/src/protocols/secure_channel/SessionResumptionStorage.h b/src/protocols/secure_channel/SessionResumptionStorage.h index f42fb6976385d6..cb7a680fd66e19 100644 --- a/src/protocols/secure_channel/SessionResumptionStorage.h +++ b/src/protocols/secure_channel/SessionResumptionStorage.h @@ -31,9 +31,9 @@ namespace chip { /** - * @brief Stores assets for sessoin resumption. The resumption data are indexed by 2 indexes: ScopedNodeId and . The index of ScopedNodeId is used when initiating a CASE session, it will look up the storage and check whether it - * is able to resume a previous session. The index of ResumptionId is used when receiving a Sigma1 with ResumptionId. + * @brief Stores assets for session resumption. The resumption data are indexed by 2 indexes: ScopedNodeId and ResumptionId. The + * index of ScopedNodeId is used when initiating a CASE session, it will look up the storage and check whether it is able to + * resume a previous session. The index of ResumptionId is used when receiving a Sigma1 with ResumptionId. * * The implementation saves 2 maps: * * => diff --git a/src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp b/src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp index 033a53675bda0e..5f7ad75e306a48 100644 --- a/src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp +++ b/src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp @@ -29,7 +29,6 @@ namespace chip { -constexpr TLV::Tag SimpleSessionResumptionStorage::kIndexContentTag; constexpr TLV::Tag SimpleSessionResumptionStorage::kFabricIndexTag; constexpr TLV::Tag SimpleSessionResumptionStorage::kPeerNodeIdTag; constexpr TLV::Tag SimpleSessionResumptionStorage::kResumptionIdTag; @@ -51,12 +50,12 @@ const char * SimpleSessionResumptionStorage::StorageKey(DefaultStorageKeyAllocat CHIP_ERROR SimpleSessionResumptionStorage::SaveIndex(const SessionIndex & index) { - uint8_t buf[MaxIndexSize()]; + std::array buf; TLV::TLVWriter writer; writer.Init(buf); TLV::TLVType arrayType; - ReturnErrorOnFailure(writer.StartContainer(kIndexContentTag, TLV::kTLVType_Array, arrayType)); + ReturnErrorOnFailure(writer.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Array, arrayType)); for (size_t i = index.mSize; i < index.mSize; ++i) { @@ -73,23 +72,23 @@ CHIP_ERROR SimpleSessionResumptionStorage::SaveIndex(const SessionIndex & index) VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_BUFFER_TOO_SMALL); DefaultStorageKeyAllocator keyAlloc; - ReturnErrorOnFailure(mStorage->SyncSetKeyValue(keyAlloc.SessionResumptionIndex(), buf, static_cast(len))); + ReturnErrorOnFailure(mStorage->SyncSetKeyValue(keyAlloc.SessionResumptionIndex(), buf.data(), static_cast(len))); return CHIP_NO_ERROR; } CHIP_ERROR SimpleSessionResumptionStorage::LoadIndex(SessionIndex & index) { - uint8_t buf[MaxIndexSize()]; - uint16_t len = static_cast(MaxStateSize()); + std::array buf; + uint16_t len = static_cast(buf.size()); DefaultStorageKeyAllocator keyAlloc; - ReturnErrorOnFailure(mStorage->SyncGetKeyValue(keyAlloc.SessionResumptionIndex(), buf, len)); + ReturnErrorOnFailure(mStorage->SyncGetKeyValue(keyAlloc.SessionResumptionIndex(), buf.data(), len)); TLV::ContiguousBufferTLVReader reader; - reader.Init(buf, len); + reader.Init(buf.data(), len); - ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Array, kIndexContentTag)); + ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Array, TLV::AnonymousTag())); TLV::TLVType arrayType; ReturnErrorOnFailure(reader.EnterContainer(arrayType)); @@ -133,9 +132,8 @@ CHIP_ERROR SimpleSessionResumptionStorage::LoadIndex(SessionIndex & index) CHIP_ERROR SimpleSessionResumptionStorage::SaveLink(ConstResumptionIdView resumptionId, const ScopedNodeId & node) { - // Save a link from resumptionId to node, in key: /f//r/ - uint8_t buf[MaxScopedNodeIdSize()]; - + // Save a link from resumptionId to node, in key: /g/s/ + std::array buf; TLV::TLVWriter writer; writer.Init(buf); @@ -149,20 +147,20 @@ CHIP_ERROR SimpleSessionResumptionStorage::SaveLink(ConstResumptionIdView resump VerifyOrDie(CanCastTo(len)); DefaultStorageKeyAllocator keyAlloc; - ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, resumptionId), buf, static_cast(len))); + ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, resumptionId), buf.data(), static_cast(len))); return CHIP_NO_ERROR; } CHIP_ERROR SimpleSessionResumptionStorage::LoadLink(ConstResumptionIdView resumptionId, ScopedNodeId & node) { - uint8_t buf[MaxScopedNodeIdSize()]; - uint16_t len = static_cast(MaxStateSize()); + std::array buf; + uint16_t len = static_cast(buf.size()); DefaultStorageKeyAllocator keyAlloc; - ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, resumptionId), buf, len)); + ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, resumptionId), buf.data(), len)); TLV::ContiguousBufferTLVReader reader; - reader.Init(buf, len); + reader.Init(buf.data(), len); ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())); TLV::TLVType containerType; @@ -195,8 +193,7 @@ CHIP_ERROR SimpleSessionResumptionStorage::SaveState(const ScopedNodeId & node, const Crypto::P256ECDHDerivedSecret & sharedSecret, const CATValues & peerCATs) { // Save session state into key: /f//s/ - uint8_t buf[MaxStateSize()]; - + std::array buf; TLV::TLVWriter writer; writer.Init(buf); @@ -217,21 +214,21 @@ CHIP_ERROR SimpleSessionResumptionStorage::SaveState(const ScopedNodeId & node, VerifyOrDie(CanCastTo(len)); DefaultStorageKeyAllocator keyAlloc; - ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, node), buf, static_cast(len))); + ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, node), buf.data(), static_cast(len))); return CHIP_NO_ERROR; } CHIP_ERROR SimpleSessionResumptionStorage::LoadState(const ScopedNodeId & node, ResumptionIdStorage & resumptionId, Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs) { - uint8_t buf[MaxStateSize()]; - uint16_t len = static_cast(MaxStateSize()); + std::array buf; + uint16_t len = static_cast(buf.size()); DefaultStorageKeyAllocator keyAlloc; - ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, node), buf, len)); + ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, node), buf.data(), len)); TLV::ContiguousBufferTLVReader reader; - reader.Init(buf, len); + reader.Init(buf.data(), len); ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())); TLV::TLVType containerType; @@ -240,6 +237,7 @@ CHIP_ERROR SimpleSessionResumptionStorage::LoadState(const ScopedNodeId & node, ByteSpan resumptionIdSpan; ReturnErrorOnFailure(reader.Next(kResumptionIdTag)); ReturnErrorOnFailure(reader.Get(resumptionIdSpan)); + VerifyOrReturnError(resumptionIdSpan.size() == resumptionId.size(), CHIP_ERROR_KEY_NOT_FOUND); std::copy(resumptionIdSpan.begin(), resumptionIdSpan.end(), resumptionId.begin()); ByteSpan sharedSecretSpan; diff --git a/src/protocols/secure_channel/SimpleSessionResumptionStorage.h b/src/protocols/secure_channel/SimpleSessionResumptionStorage.h index 68daeeb1206336..6a7752df46e07b 100644 --- a/src/protocols/secure_channel/SimpleSessionResumptionStorage.h +++ b/src/protocols/secure_channel/SimpleSessionResumptionStorage.h @@ -64,9 +64,8 @@ class SimpleSessionResumptionStorage : public SessionResumptionStorage static constexpr size_t MaxIndexSize() { - // The max size of the list is (1 byte control + bytes for actual value) times max number of list items, plus one byte for - // the list terminator. - return TLV::EstimateStructOverhead((1 + MaxScopedNodeIdSize()) * CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE + 1); + // The max size of the list is (1 byte control + bytes for actual value) times max number of list items + return TLV::EstimateStructOverhead((1 + MaxScopedNodeIdSize()) * CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE); } static constexpr size_t MaxStateSize() @@ -75,12 +74,11 @@ class SimpleSessionResumptionStorage : public SessionResumptionStorage CATValues::kSerializedLength); } - static constexpr TLV::Tag kIndexContentTag = TLV::ContextTag(1); - static constexpr TLV::Tag kFabricIndexTag = TLV::ContextTag(2); - static constexpr TLV::Tag kPeerNodeIdTag = TLV::ContextTag(3); - static constexpr TLV::Tag kResumptionIdTag = TLV::ContextTag(4); - static constexpr TLV::Tag kSharedSecretTag = TLV::ContextTag(5); - static constexpr TLV::Tag kCATTag = TLV::ContextTag(6); + static constexpr TLV::Tag kFabricIndexTag = TLV::ContextTag(1); + static constexpr TLV::Tag kPeerNodeIdTag = TLV::ContextTag(2); + static constexpr TLV::Tag kResumptionIdTag = TLV::ContextTag(3); + static constexpr TLV::Tag kSharedSecretTag = TLV::ContextTag(4); + static constexpr TLV::Tag kCATTag = TLV::ContextTag(5); PersistentStorageDelegate * mStorage; };