Skip to content

Commit

Permalink
Fix the oversized list handling in ClusterStateCache (#18187)
Browse files Browse the repository at this point in the history
* Fix the oversized list handling in ClusterStateCache
  • Loading branch information
yunhanw-google authored and pull[bot] committed Nov 6, 2023
1 parent 4aba70a commit f749290
Show file tree
Hide file tree
Showing 6 changed files with 625 additions and 63 deletions.
60 changes: 32 additions & 28 deletions src/app/ClusterStateCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,25 @@
namespace chip {
namespace app {

CHIP_ERROR ClusterStateCache::GetElementTLVSize(TLV::TLVReader * apData, size_t & aSize)
{
Platform::ScopedMemoryBufferWithSize<uint8_t> backingBuffer;
TLV::TLVReader reader;
reader.Init(*apData);
size_t totalBufSize = reader.GetTotalLength();
backingBuffer.Calloc(totalBufSize);
VerifyOrReturnError(backingBuffer.Get() != nullptr, CHIP_ERROR_NO_MEMORY);
TLV::ScopedBufferTLVWriter writer(std::move(backingBuffer), totalBufSize);
ReturnErrorOnFailure(writer.CopyElement(TLV::AnonymousTag(), reader));
aSize = writer.GetLengthWritten();
ReturnErrorOnFailure(writer.Finalize(backingBuffer));
return CHIP_NO_ERROR;
}

CHIP_ERROR ClusterStateCache::UpdateCache(const ConcreteDataAttributePath & aPath, TLV::TLVReader * apData,
const StatusIB & aStatus)
{
AttributeState state;
System::PacketBufferHandle handle;
System::PacketBufferTLVWriter writer;
bool endpointIsNew = false;

if (mCache.find(aPath.mEndpointId) == mCache.end())
Expand All @@ -44,21 +57,16 @@ CHIP_ERROR ClusterStateCache::UpdateCache(const ConcreteDataAttributePath & aPat

if (apData)
{
handle = System::PacketBufferHandle::New(chip::app::kMaxSecureSduLengthBytes);

writer.Init(std::move(handle), false);

size_t elementSize = 0;
ReturnErrorOnFailure(GetElementTLVSize(apData, elementSize));
Platform::ScopedMemoryBufferWithSize<uint8_t> backingBuffer;
backingBuffer.Calloc(elementSize);
VerifyOrReturnError(backingBuffer.Get() != nullptr, CHIP_ERROR_NO_MEMORY);
TLV::ScopedBufferTLVWriter writer(std::move(backingBuffer), elementSize);
ReturnErrorOnFailure(writer.CopyElement(TLV::AnonymousTag(), *apData));
ReturnErrorOnFailure(writer.Finalize(&handle));

//
// Compact the buffer down to a more reasonably sized packet buffer
// if we can.
//
handle.RightSize();

state.Set<System::PacketBufferHandle>(std::move(handle));
ReturnErrorOnFailure(writer.Finalize(backingBuffer));

state.Set<Platform::ScopedMemoryBufferWithSize<uint8_t>>(std::move(backingBuffer));
//
// Clear out the committed data version and only set it again once we have received all data for this cluster.
// Otherwise, we may have incomplete data that looks like it's complete since it has a valid data version.
Expand Down Expand Up @@ -204,22 +212,16 @@ void ClusterStateCache::OnReportEnd()
CHIP_ERROR ClusterStateCache::Get(const ConcreteAttributePath & path, TLV::TLVReader & reader)
{
CHIP_ERROR err;

auto attributeState = GetAttributeState(path.mEndpointId, path.mClusterId, path.mAttributeId, err);
ReturnErrorOnFailure(err);

if (attributeState->Is<StatusIB>())
{
return CHIP_ERROR_IM_STATUS_CODE_RECEIVED;
}

System::PacketBufferTLVReader bufReader;

bufReader.Init(attributeState->Get<System::PacketBufferHandle>().Retain());
ReturnErrorOnFailure(bufReader.Next());

reader.Init(bufReader);
return CHIP_NO_ERROR;
reader.Init(attributeState->Get<Platform::ScopedMemoryBufferWithSize<uint8_t>>().Get(),
attributeState->Get<Platform::ScopedMemoryBufferWithSize<uint8_t>>().BufferByteSize());
return reader.Next();
}

CHIP_ERROR ClusterStateCache::Get(EventNumber eventNumber, TLV::TLVReader & reader)
Expand Down Expand Up @@ -336,10 +338,11 @@ void ClusterStateCache::OnAttributeData(const ConcreteDataAttributePath & aPath,
mCallback.OnAttributeData(aPath, apData ? &dataSnapshot : nullptr, aStatus);
}

CHIP_ERROR ClusterStateCache::GetVersion(EndpointId mEndpointId, ClusterId mClusterId, Optional<DataVersion> & aVersion)
CHIP_ERROR ClusterStateCache::GetVersion(const ConcreteClusterPath & aPath, Optional<DataVersion> & aVersion)
{
VerifyOrReturnError(aPath.IsValidConcreteClusterPath(), CHIP_ERROR_INVALID_ARGUMENT);
CHIP_ERROR err;
auto clusterState = GetClusterState(mEndpointId, mClusterId, err);
auto clusterState = GetClusterState(aPath.mEndpointId, aPath.mClusterId, err);
ReturnErrorOnFailure(err);
aVersion = clusterState->mCommittedDataVersion;
return CHIP_NO_ERROR;
Expand Down Expand Up @@ -417,8 +420,9 @@ void ClusterStateCache::GetSortedFilters(std::vector<std::pair<DataVersionFilter
}
else
{
System::PacketBufferTLVReader bufReader;
bufReader.Init(attributeIter.second.Get<System::PacketBufferHandle>().Retain());
TLV::TLVReader bufReader;
bufReader.Init(attributeIter.second.Get<Platform::ScopedMemoryBufferWithSize<uint8_t>>().Get(),
attributeIter.second.Get<Platform::ScopedMemoryBufferWithSize<uint8_t>>().BufferByteSize());
ReturnOnFailure(bufReader.Next());
// Skip to the end of the element.
ReturnOnFailure(bufReader.Skip());
Expand Down
6 changes: 4 additions & 2 deletions src/app/ClusterStateCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class ClusterStateCache : protected ReadClient::Callback
* current data version for the cluster (which may have no value if we don't have a known data version
* for it, for example because none of our paths were wildcards that covered the whole cluster).
*/
CHIP_ERROR GetVersion(EndpointId mEndpointId, ClusterId mClusterId, Optional<DataVersion> & aVersion);
CHIP_ERROR GetVersion(const ConcreteClusterPath & path, Optional<DataVersion> & aVersion);

/*
* Get highest received event number.
Expand Down Expand Up @@ -483,7 +483,7 @@ class ClusterStateCache : protected ReadClient::Callback
}

private:
using AttributeState = Variant<System::PacketBufferHandle, StatusIB>;
using AttributeState = Variant<Platform::ScopedMemoryBufferWithSize<uint8_t>, StatusIB>;
// mPendingDataVersion represents a tentative data version for a cluster that we have gotten some reports for.
//
// mCurrentDataVersion represents a known data version for a cluster. In order for this to have a
Expand Down Expand Up @@ -591,6 +591,8 @@ class ClusterStateCache : protected ReadClient::Callback
// on the wire if not all filters can be applied.
void GetSortedFilters(std::vector<std::pair<DataVersionFilter, size_t>> & aVector);

CHIP_ERROR GetElementTLVSize(TLV::TLVReader * apData, size_t & aSize);

Callback & mCallback;
NodeState mCache;
std::set<ConcreteAttributePath> mChangedAttributeSet;
Expand Down
18 changes: 10 additions & 8 deletions src/app/tests/TestClusterStateCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,6 @@ class DataSeriesGenerator

void DataSeriesGenerator::Generate(ForwardedDataCallbackValidator & dataCallbackValidator)
{
System::PacketBufferHandle handle;
System::PacketBufferTLVWriter writer;
System::PacketBufferTLVReader reader;
ReadClient::Callback * callback = mReadCallback;
StatusIB status;
callback->OnReportBegin();
Expand All @@ -187,8 +184,10 @@ void DataSeriesGenerator::Generate(ForwardedDataCallbackValidator & dataCallback
for (auto & instruction : mInstructionList)
{
ConcreteDataAttributePath path(instruction.mEndpointId, Clusters::TestCluster::Id, 0);
handle = System::PacketBufferHandle::New(1000);
writer.Init(std::move(handle), true);
Platform::ScopedMemoryBufferWithSize<uint8_t> handle;
handle.Calloc(3000);
TLV::ScopedBufferTLVWriter writer(std::move(handle), 3000);

status = StatusIB();
path.mAttributeId = instruction.GetAttributeId();
path.mDataVersion.SetValue(1);
Expand Down Expand Up @@ -231,7 +230,8 @@ void DataSeriesGenerator::Generate(ForwardedDataCallbackValidator & dataCallback
case AttributeInstruction::kAttributeD: {
ChipLogProgress(DataManagement, "\t -- Generating D");

Clusters::TestCluster::Structs::TestListStructOctet::Type buf[4];
// buf[200] is 1.6k
Clusters::TestCluster::Structs::TestListStructOctet::Type buf[200];

for (auto & i : buf)
{
Expand All @@ -250,8 +250,10 @@ void DataSeriesGenerator::Generate(ForwardedDataCallbackValidator & dataCallback
break;
}

writer.Finalize(&handle);
reader.Init(std::move(handle));
uint32_t writtenLength = writer.GetLengthWritten();
writer.Finalize(handle);
TLV::ScopedBufferTLVReader reader;
reader.Init(std::move(handle), writtenLength);
NL_TEST_ASSERT(gSuite, reader.Next() == CHIP_NO_ERROR);
dataCallbackValidator.SetExpectation(reader, instruction.mEndpointId, instruction.mAttributeType);
callback->OnAttributeData(path, &reader, status);
Expand Down
Loading

0 comments on commit f749290

Please sign in to comment.