Skip to content

Commit

Permalink
Fix AttributeCache forwarding callback (#16566)
Browse files Browse the repository at this point in the history
* Fix forwarded OnAttributeData callback losing data by
  copying the TLV reader before its state changes for cache update.
  • Loading branch information
kpark-apple authored and pull[bot] committed Jun 23, 2023
1 parent 155abf0 commit 4e342cb
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 7 deletions.
9 changes: 8 additions & 1 deletion src/app/AttributeCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,19 @@ void AttributeCache::OnAttributeData(const ConcreteDataAttributePath & aPath, TL
//
VerifyOrDie(!aPath.IsListItemOperation());

// Copy the reader for forwarding
TLV::TLVReader dataSnapshot;
if (apData)
{
dataSnapshot.Init(*apData);
}

UpdateCache(aPath, apData, aStatus);

//
// Forward the call through.
//
mCallback.OnAttributeData(aPath, apData, aStatus);
mCallback.OnAttributeData(aPath, apData ? &dataSnapshot : nullptr, aStatus);
}

CHIP_ERROR AttributeCache::Get(const ConcreteAttributePath & path, TLV::TLVReader & reader)
Expand Down
96 changes: 90 additions & 6 deletions src/app/tests/TestAttributeCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,76 @@ uint8_t AttributeInstruction::sInstructionId = 0;

using AttributeInstructionListType = std::vector<AttributeInstruction>;

class ForwardedDataCallbackValidator final
{
public:
void SetExpectation(TLV::TLVReader & aData, EndpointId endpointId, AttributeInstruction::AttributeType attributeType)
{
auto length = aData.GetRemainingLength();
std::vector<uint8_t> buffer(aData.GetReadPoint(), aData.GetReadPoint() + length);
if (!mExpectedBuffers.empty() && endpointId == mLastEndpointId && attributeType == mLastAttributeType)
{
// For overriding test, the last buffered data is removed.
mExpectedBuffers.pop_back();
}
mExpectedBuffers.push_back(buffer);
mLastEndpointId = endpointId;
mLastAttributeType = attributeType;
}

void SetExpectation() { mExpectedBuffers.clear(); }

void ValidateData(TLV::TLVReader & aData, bool isListOperation)
{
NL_TEST_ASSERT(gSuite, !mExpectedBuffers.empty());
if (!mExpectedBuffers.empty() > 0)
{
auto buffer = mExpectedBuffers.front();
mExpectedBuffers.erase(mExpectedBuffers.begin());
uint32_t length = static_cast<uint32_t>(buffer.size());
if (isListOperation)
{
// List operation will attach end of container
NL_TEST_ASSERT(gSuite, length < aData.GetRemainingLength());
}
else
{
NL_TEST_ASSERT(gSuite, length == aData.GetRemainingLength());
}
if (length <= aData.GetRemainingLength() && length > 0)
{
NL_TEST_ASSERT(gSuite, memcmp(aData.GetReadPoint(), buffer.data(), length) == 0);
if (memcmp(aData.GetReadPoint(), buffer.data(), length) != 0)
{
ChipLogProgress(DataManagement, "Failed");
}
}
}
}

void ValidateNoData() { NL_TEST_ASSERT(gSuite, mExpectedBuffers.empty()); }

private:
std::vector<std::vector<uint8_t>> mExpectedBuffers;
EndpointId mLastEndpointId;
AttributeInstruction::AttributeType mLastAttributeType;
};

class DataSeriesGenerator
{
public:
DataSeriesGenerator(ReadClient::Callback * readCallback, AttributeInstructionListType & instructionList) :
mReadCallback(readCallback), mInstructionList(instructionList)
{}

void Generate();
void Generate(ForwardedDataCallbackValidator & dataCallbackValidator);

private:
ReadClient::Callback * mReadCallback;
AttributeInstructionListType & mInstructionList;
};

void DataSeriesGenerator::Generate()
void DataSeriesGenerator::Generate(ForwardedDataCallbackValidator & dataCallbackValidator)
{
System::PacketBufferHandle handle;
System::PacketBufferTLVWriter writer;
Expand Down Expand Up @@ -197,12 +252,14 @@ void DataSeriesGenerator::Generate()
writer.Finalize(&handle);
reader.Init(std::move(handle));
NL_TEST_ASSERT(gSuite, reader.Next() == CHIP_NO_ERROR);
dataCallbackValidator.SetExpectation(reader, instruction.mEndpointId, instruction.mAttributeType);
callback->OnAttributeData(path, &reader, status);
}
else
{
ChipLogProgress(DataManagement, "\t -- Generating Status");
status.mStatus = Protocols::InteractionModel::Status::Failure;
dataCallbackValidator.SetExpectation();
callback->OnAttributeData(path, nullptr, status);
}

Expand All @@ -215,12 +272,34 @@ void DataSeriesGenerator::Generate()
class CacheValidator : public AttributeCache::Callback
{
public:
CacheValidator(AttributeInstructionListType & instructionList);
CacheValidator(AttributeInstructionListType & instructionList, ForwardedDataCallbackValidator & dataCallbackValidator);

Clusters::TestCluster::Attributes::TypeInfo::DecodableType clusterValue;

private:
void OnDone() override {}
void OnAttributeData(const ConcreteDataAttributePath & aPath, TLV::TLVReader * apData, const StatusIB & aStatus) override
{
ChipLogProgress(DataManagement, "\t\t -- Validating OnAttributeData callback");
// Ensure that the provided path is one that we're expecting to find
auto iter = mExpectedAttributes.find(aPath);
NL_TEST_ASSERT(gSuite, iter != mExpectedAttributes.end());

if (aStatus.IsSuccess())
{
// Verify that the apData is passed as nonnull
NL_TEST_ASSERT(gSuite, apData != nullptr);
if (apData)
{
mDataCallbackValidator.ValidateData(*apData, aPath.IsListOperation());
}
}
else
{
mDataCallbackValidator.ValidateNoData();
}
}

void DecodeAttribute(const AttributeInstruction & instruction, const ConcreteAttributePath & path, AttributeCache * cache)
{
CHIP_ERROR err;
Expand Down Expand Up @@ -431,9 +510,13 @@ class CacheValidator : public AttributeCache::Callback
std::set<ConcreteAttributePath> mExpectedAttributes;
std::set<std::tuple<EndpointId, ClusterId>> mExpectedClusters;
std::set<EndpointId> mExpectedEndpoints;

ForwardedDataCallbackValidator & mDataCallbackValidator;
};

CacheValidator::CacheValidator(AttributeInstructionListType & instructionList)
CacheValidator::CacheValidator(AttributeInstructionListType & instructionList,
ForwardedDataCallbackValidator & dataCallbackValidator) :
mDataCallbackValidator(dataCallbackValidator)
{
for (auto & instruction : instructionList)
{
Expand All @@ -452,10 +535,11 @@ CacheValidator::CacheValidator(AttributeInstructionListType & instructionList)

void RunAndValidateSequence(AttributeInstructionListType list)
{
CacheValidator client(list);
ForwardedDataCallbackValidator dataCallbackValidator;
CacheValidator client(list, dataCallbackValidator);
AttributeCache cache(client);
DataSeriesGenerator generator(&cache.GetBufferedCallback(), list);
generator.Generate();
generator.Generate(dataCallbackValidator);
}

/*
Expand Down

0 comments on commit 4e342cb

Please sign in to comment.