Skip to content

Commit

Permalink
Address SessionResumption reviews (#17385)
Browse files Browse the repository at this point in the history
* Address SessionResumption reviews

* Update src/protocols/secure_channel/SessionResumptionStorage.cpp

Co-authored-by: Boris Zbarsky <[email protected]>

Co-authored-by: Andrei Litvin <[email protected]>
Co-authored-by: Boris Zbarsky <[email protected]>
  • Loading branch information
3 people authored and pull[bot] committed Nov 28, 2023
1 parent ab78f98 commit 3b97000
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 40 deletions.
4 changes: 2 additions & 2 deletions src/lib/support/DefaultStorageKeyAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class DefaultStorageKeyAllocator
{
return Format("f/%x/s/%08" PRIX32 "%08" PRIX32, fabric, static_cast<uint32_t>(nodeId >> 32), static_cast<uint32_t>(nodeId));
}
const char * SessionResumptionIndex() { return Format("f/sri"); }
const char * SessionResumption(const char * resumptionIdBase64) { return Format("s/%s", resumptionIdBase64); }
const char * SessionResumptionIndex() { return Format("g/sri"); }
const char * SessionResumption(const char * resumptionIdBase64) { return Format("g/s/%s", resumptionIdBase64); }

// Access Control
const char * AccessControlExtensionEntry(FabricIndex fabric) { return Format("f/%x/ac/1", fabric); }
Expand Down
6 changes: 4 additions & 2 deletions src/protocols/secure_channel/SessionResumptionStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,20 @@ CHIP_ERROR SessionResumptionStorage::Delete(const ScopedNodeId & node)
{
if (found)
{
// index.mSize was decreased by 1 when found was set to true.
// So the (i+1)th element isn't out of bounds.
index.mNodes[i] = index.mNodes[i + 1];
}
else
{
if (index.mNodes[i] == node)
{
found = true;
index.mSize -= 1;
if (i + 1 != index.mSize)
if (i + 1 < index.mSize)
{
index.mNodes[i] = index.mNodes[i + 1];
}
index.mSize -= 1;
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/protocols/secure_channel/SessionResumptionStorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
namespace chip {

/**
* @brief Stores assets for sessoin resumption. The resumption data are indexed by 2 indexes: ScopedNodeId and <FabricIndex,
* ResumptionId>. The index of ScopedNodeId is used when initiating a CASE session, it will look up the storage and check whether it
* is able to resume a previous session. The index of ResumptionId is used when receiving a Sigma1 with ResumptionId.
* @brief Stores assets for session resumption. The resumption data are indexed by 2 indexes: ScopedNodeId and ResumptionId. The
* index of ScopedNodeId is used when initiating a CASE session, it will look up the storage and check whether it is able to
* resume a previous session. The index of ResumptionId is used when receiving a Sigma1 with ResumptionId.
*
* The implementation saves 2 maps:
* * <FabricIndex, PeerNodeId> => <ResumptionId, ShareSecret, PeerCATs>
Expand Down
46 changes: 22 additions & 24 deletions src/protocols/secure_channel/SimpleSessionResumptionStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

namespace chip {

constexpr TLV::Tag SimpleSessionResumptionStorage::kIndexContentTag;
constexpr TLV::Tag SimpleSessionResumptionStorage::kFabricIndexTag;
constexpr TLV::Tag SimpleSessionResumptionStorage::kPeerNodeIdTag;
constexpr TLV::Tag SimpleSessionResumptionStorage::kResumptionIdTag;
Expand All @@ -51,12 +50,12 @@ const char * SimpleSessionResumptionStorage::StorageKey(DefaultStorageKeyAllocat

CHIP_ERROR SimpleSessionResumptionStorage::SaveIndex(const SessionIndex & index)
{
uint8_t buf[MaxIndexSize()];
std::array<uint8_t, MaxIndexSize()> buf;
TLV::TLVWriter writer;
writer.Init(buf);

TLV::TLVType arrayType;
ReturnErrorOnFailure(writer.StartContainer(kIndexContentTag, TLV::kTLVType_Array, arrayType));
ReturnErrorOnFailure(writer.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Array, arrayType));

for (size_t i = index.mSize; i < index.mSize; ++i)
{
Expand All @@ -73,23 +72,23 @@ CHIP_ERROR SimpleSessionResumptionStorage::SaveIndex(const SessionIndex & index)
VerifyOrReturnError(CanCastTo<uint16_t>(len), CHIP_ERROR_BUFFER_TOO_SMALL);

DefaultStorageKeyAllocator keyAlloc;
ReturnErrorOnFailure(mStorage->SyncSetKeyValue(keyAlloc.SessionResumptionIndex(), buf, static_cast<uint16_t>(len)));
ReturnErrorOnFailure(mStorage->SyncSetKeyValue(keyAlloc.SessionResumptionIndex(), buf.data(), static_cast<uint16_t>(len)));

return CHIP_NO_ERROR;
}

CHIP_ERROR SimpleSessionResumptionStorage::LoadIndex(SessionIndex & index)
{
uint8_t buf[MaxIndexSize()];
uint16_t len = static_cast<uint16_t>(MaxStateSize());
std::array<uint8_t, MaxIndexSize()> buf;
uint16_t len = static_cast<uint16_t>(buf.size());

DefaultStorageKeyAllocator keyAlloc;
ReturnErrorOnFailure(mStorage->SyncGetKeyValue(keyAlloc.SessionResumptionIndex(), buf, len));
ReturnErrorOnFailure(mStorage->SyncGetKeyValue(keyAlloc.SessionResumptionIndex(), buf.data(), len));

TLV::ContiguousBufferTLVReader reader;
reader.Init(buf, len);
reader.Init(buf.data(), len);

ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Array, kIndexContentTag));
ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Array, TLV::AnonymousTag()));
TLV::TLVType arrayType;
ReturnErrorOnFailure(reader.EnterContainer(arrayType));

Expand Down Expand Up @@ -133,9 +132,8 @@ CHIP_ERROR SimpleSessionResumptionStorage::LoadIndex(SessionIndex & index)

CHIP_ERROR SimpleSessionResumptionStorage::SaveLink(ConstResumptionIdView resumptionId, const ScopedNodeId & node)
{
// Save a link from resumptionId to node, in key: /f/<fabricIndex>/r/<resumptionId>
uint8_t buf[MaxScopedNodeIdSize()];

// Save a link from resumptionId to node, in key: /g/s/<resumptionId>
std::array<uint8_t, MaxScopedNodeIdSize()> buf;
TLV::TLVWriter writer;
writer.Init(buf);

Expand All @@ -149,20 +147,20 @@ CHIP_ERROR SimpleSessionResumptionStorage::SaveLink(ConstResumptionIdView resump
VerifyOrDie(CanCastTo<uint16_t>(len));

DefaultStorageKeyAllocator keyAlloc;
ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, resumptionId), buf, static_cast<uint16_t>(len)));
ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, resumptionId), buf.data(), static_cast<uint16_t>(len)));
return CHIP_NO_ERROR;
}

CHIP_ERROR SimpleSessionResumptionStorage::LoadLink(ConstResumptionIdView resumptionId, ScopedNodeId & node)
{
uint8_t buf[MaxScopedNodeIdSize()];
uint16_t len = static_cast<uint16_t>(MaxStateSize());
std::array<uint8_t, MaxScopedNodeIdSize()> buf;
uint16_t len = static_cast<uint16_t>(buf.size());

DefaultStorageKeyAllocator keyAlloc;
ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, resumptionId), buf, len));
ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, resumptionId), buf.data(), len));

TLV::ContiguousBufferTLVReader reader;
reader.Init(buf, len);
reader.Init(buf.data(), len);

ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag()));
TLV::TLVType containerType;
Expand Down Expand Up @@ -195,8 +193,7 @@ CHIP_ERROR SimpleSessionResumptionStorage::SaveState(const ScopedNodeId & node,
const Crypto::P256ECDHDerivedSecret & sharedSecret, const CATValues & peerCATs)
{
// Save session state into key: /f/<fabricIndex>/s/<nodeId>
uint8_t buf[MaxStateSize()];

std::array<uint8_t, MaxStateSize()> buf;
TLV::TLVWriter writer;
writer.Init(buf);

Expand All @@ -217,21 +214,21 @@ CHIP_ERROR SimpleSessionResumptionStorage::SaveState(const ScopedNodeId & node,
VerifyOrDie(CanCastTo<uint16_t>(len));

DefaultStorageKeyAllocator keyAlloc;
ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, node), buf, static_cast<uint16_t>(len)));
ReturnErrorOnFailure(mStorage->SyncSetKeyValue(StorageKey(keyAlloc, node), buf.data(), static_cast<uint16_t>(len)));
return CHIP_NO_ERROR;
}

CHIP_ERROR SimpleSessionResumptionStorage::LoadState(const ScopedNodeId & node, ResumptionIdStorage & resumptionId,
Crypto::P256ECDHDerivedSecret & sharedSecret, CATValues & peerCATs)
{
uint8_t buf[MaxStateSize()];
uint16_t len = static_cast<uint16_t>(MaxStateSize());
std::array<uint8_t, MaxStateSize()> buf;
uint16_t len = static_cast<uint16_t>(buf.size());

DefaultStorageKeyAllocator keyAlloc;
ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, node), buf, len));
ReturnErrorOnFailure(mStorage->SyncGetKeyValue(StorageKey(keyAlloc, node), buf.data(), len));

TLV::ContiguousBufferTLVReader reader;
reader.Init(buf, len);
reader.Init(buf.data(), len);

ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag()));
TLV::TLVType containerType;
Expand All @@ -240,6 +237,7 @@ CHIP_ERROR SimpleSessionResumptionStorage::LoadState(const ScopedNodeId & node,
ByteSpan resumptionIdSpan;
ReturnErrorOnFailure(reader.Next(kResumptionIdTag));
ReturnErrorOnFailure(reader.Get(resumptionIdSpan));
VerifyOrReturnError(resumptionIdSpan.size() == resumptionId.size(), CHIP_ERROR_KEY_NOT_FOUND);
std::copy(resumptionIdSpan.begin(), resumptionIdSpan.end(), resumptionId.begin());

ByteSpan sharedSecretSpan;
Expand Down
16 changes: 7 additions & 9 deletions src/protocols/secure_channel/SimpleSessionResumptionStorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,8 @@ class SimpleSessionResumptionStorage : public SessionResumptionStorage

static constexpr size_t MaxIndexSize()
{
// The max size of the list is (1 byte control + bytes for actual value) times max number of list items, plus one byte for
// the list terminator.
return TLV::EstimateStructOverhead((1 + MaxScopedNodeIdSize()) * CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE + 1);
// The max size of the list is (1 byte control + bytes for actual value) times max number of list items
return TLV::EstimateStructOverhead((1 + MaxScopedNodeIdSize()) * CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE);
}

static constexpr size_t MaxStateSize()
Expand All @@ -75,12 +74,11 @@ class SimpleSessionResumptionStorage : public SessionResumptionStorage
CATValues::kSerializedLength);
}

static constexpr TLV::Tag kIndexContentTag = TLV::ContextTag(1);
static constexpr TLV::Tag kFabricIndexTag = TLV::ContextTag(2);
static constexpr TLV::Tag kPeerNodeIdTag = TLV::ContextTag(3);
static constexpr TLV::Tag kResumptionIdTag = TLV::ContextTag(4);
static constexpr TLV::Tag kSharedSecretTag = TLV::ContextTag(5);
static constexpr TLV::Tag kCATTag = TLV::ContextTag(6);
static constexpr TLV::Tag kFabricIndexTag = TLV::ContextTag(1);
static constexpr TLV::Tag kPeerNodeIdTag = TLV::ContextTag(2);
static constexpr TLV::Tag kResumptionIdTag = TLV::ContextTag(3);
static constexpr TLV::Tag kSharedSecretTag = TLV::ContextTag(4);
static constexpr TLV::Tag kCATTag = TLV::ContextTag(5);

PersistentStorageDelegate * mStorage;
};
Expand Down

0 comments on commit 3b97000

Please sign in to comment.