Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReadClient: Truncate data version list during encoding if necessary #34111

Merged
merged 7 commits into from
Jul 2, 2024
Merged
49 changes: 35 additions & 14 deletions src/app/ReadClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,33 +423,54 @@ CHIP_ERROR ReadClient::BuildDataVersionFilterList(DataVersionFilterIBs::Builder
continue;
}

DataVersionFilterIB::Builder & filterIB = aDataVersionFilterIBsBuilder.CreateDataVersionFilter();
ReturnErrorOnFailure(aDataVersionFilterIBsBuilder.GetError());
ClusterPathIB::Builder & path = filterIB.CreatePath();
ReturnErrorOnFailure(filterIB.GetError());
ReturnErrorOnFailure(path.Endpoint(filter.mEndpointId).Cluster(filter.mClusterId).EndOfClusterPathIB());
VerifyOrReturnError(filter.mDataVersion.HasValue(), CHIP_ERROR_INVALID_ARGUMENT);
ReturnErrorOnFailure(filterIB.DataVersion(filter.mDataVersion.Value()).EndOfDataVersionFilterIB());
aEncodedDataVersionList = true;
TLV::TLVWriter backup;
aDataVersionFilterIBsBuilder.Checkpoint(backup);
woody-apple marked this conversation as resolved.
Show resolved Hide resolved
CHIP_ERROR err = EncodeDataVersionFilter(aDataVersionFilterIBsBuilder, filter);
if (err == CHIP_NO_ERROR)
{
aEncodedDataVersionList = true;
}
else if (err == CHIP_ERROR_NO_MEMORY || err == CHIP_ERROR_BUFFER_TOO_SMALL)
{
// Packet is full, ignore the rest of the list
aDataVersionFilterIBsBuilder.Rollback(backup);
return CHIP_NO_ERROR;
}
else
{
return err;
}
}
return CHIP_NO_ERROR;
}

CHIP_ERROR ReadClient::EncodeDataVersionFilter(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder,
DataVersionFilter const & aFilter)
{
// Caller has checked aFilter.IsValidDataVersionFilter()
DataVersionFilterIB::Builder & filterIB = aDataVersionFilterIBsBuilder.CreateDataVersionFilter();
ReturnErrorOnFailure(aDataVersionFilterIBsBuilder.GetError());
ClusterPathIB::Builder & path = filterIB.CreatePath();
ReturnErrorOnFailure(filterIB.GetError());
ReturnErrorOnFailure(path.Endpoint(aFilter.mEndpointId).Cluster(aFilter.mClusterId).EndOfClusterPathIB());
ReturnErrorOnFailure(filterIB.DataVersion(aFilter.mDataVersion.Value()).EndOfDataVersionFilterIB());
return CHIP_NO_ERROR;
}

CHIP_ERROR ReadClient::GenerateDataVersionFilterList(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder,
const Span<AttributePathParams> & aAttributePaths,
const Span<DataVersionFilter> & aDataVersionFilters,
bool & aEncodedDataVersionList)
{
if (!aDataVersionFilters.empty())
// Give the callback a chance first, otherwise use the list we have, if any.
ReturnErrorOnFailure(
mpCallback.OnUpdateDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aEncodedDataVersionList));

if (!aEncodedDataVersionList)
{
ReturnErrorOnFailure(BuildDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aDataVersionFilters,
aEncodedDataVersionList));
}
else
{
ReturnErrorOnFailure(
mpCallback.OnUpdateDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aEncodedDataVersionList));
}

return CHIP_NO_ERROR;
}
Expand Down
5 changes: 5 additions & 0 deletions src/app/ReadClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ class ReadClient : public Messaging::ExchangeDelegate
* This will send either a Read Request or a Subscribe Request depending on
* the InteractionType this read client was initialized with.
*
* If the params contain more data version filters than can fit in the request packet
* the list will be truncated as needed, i.e. filter inclusion is on a best effort basis.
*
* @retval #others fail to send read request
* @retval #CHIP_NO_ERROR On success.
*/
Expand Down Expand Up @@ -559,6 +562,8 @@ class ReadClient : public Messaging::ExchangeDelegate
CHIP_ERROR BuildDataVersionFilterList(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder,
const Span<AttributePathParams> & aAttributePaths,
const Span<DataVersionFilter> & aDataVersionFilters, bool & aEncodedDataVersionList);
CHIP_ERROR EncodeDataVersionFilter(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder,
DataVersionFilter const & aFilter);
CHIP_ERROR ReadICDOperatingModeFromAttributeDataIB(TLV::TLVReader && aReader, PeerType & aType);
CHIP_ERROR ProcessAttributeReportIBs(TLV::TLVReader & aAttributeDataIBsReader);
CHIP_ERROR ProcessEventReportIBs(TLV::TLVReader & aEventReportIBsReader);
Expand Down
1 change: 1 addition & 0 deletions src/controller/tests/data_model/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ chip_test_suite("data_model") {
"${chip_root}/src/app/tests:helpers",
"${chip_root}/src/app/util/mock:mock_ember",
"${chip_root}/src/controller",
"${chip_root}/src/lib/core:string-builder-adapters",
"${chip_root}/src/messaging/tests:helpers",
"${chip_root}/src/transport/raw/tests:helpers",
]
Expand Down
3 changes: 2 additions & 1 deletion src/controller/tests/data_model/TestCommands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
*
*/

#include <gtest/gtest.h>
#include <lib/core/StringBuilderAdapters.h>
#include <pw_unit_test/framework.h>

#include "app/data-model/NullObject.h"
#include <app-common/zap-generated/cluster-objects.h>
Expand Down
84 changes: 83 additions & 1 deletion src/controller/tests/data_model/TestRead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <lib/core/StringBuilderAdapters.h>
#include <pw_unit_test/framework.h>

#include "system/SystemClock.h"
#include "transport/SecureSession.h"
Expand All @@ -25,6 +26,7 @@
#include <app/ConcreteAttributePath.h>
#include <app/ConcreteEventPath.h>
#include <app/InteractionModelEngine.h>
#include <app/ReadClient.h>
#include <app/tests/AppTestContext.h>
#include <app/util/mock/Constants.h>
#include <app/util/mock/Functions.h>
Expand Down Expand Up @@ -3093,6 +3095,86 @@ TEST_F(TestRead, TestReadHandler_MultipleSubscriptionsWithDataVersionFilter)
EXPECT_EQ(mpContext->GetExchangeManager().GetNumActiveExchanges(), 0u);
}

TEST_F(TestRead, TestReadHandler_DataVersionFiltersTruncated)
{
struct : public chip::Test::LoopbackTransportDelegate
{
size_t requestSize = 0;
void WillSendMessage(const Transport::PeerAddress & peer, const System::PacketBufferHandle & message) override
{
// We only care about the messages we (Alice) send to Bob, not the responses.
// Assume the first message we see in an iteration is the request.
if (peer == mpContext->GetBobAddress() && requestSize == 0)
{
requestSize = message->TotalLength();
}
}
} loopbackDelegate;
mpContext->GetLoopback().SetLoopbackTransportDelegate(&loopbackDelegate);

// Note that on the server side, wildcard expansion does not actually work for kTestEndpointId due
// to lack of meta-data, but we don't care about the reports we get back in this test.
AttributePathParams wildcardPath(kTestEndpointId, kInvalidClusterId, kInvalidAttributeId);
constexpr size_t maxDataVersionFilterCount = 100;
DataVersionFilter dataVersionFilters[maxDataVersionFilterCount];
ClusterId nextClusterId = 0;
for (auto & dv : dataVersionFilters)
{
dv.mEndpointId = wildcardPath.mEndpointId;
dv.mClusterId = nextClusterId++;
dv.mDataVersion = MakeOptional(0x01000000u);
}

// Keep increasing the number of data version filters until we see truncation kick in.
size_t lastRequestSize;
for (size_t count = 1; count <= maxDataVersionFilterCount; count++)
{
lastRequestSize = loopbackDelegate.requestSize;
loopbackDelegate.requestSize = 0; // reset

ReadPrepareParams read(mpContext->GetSessionAliceToBob());
read.mpAttributePathParamsList = &wildcardPath;
read.mAttributePathParamsListSize = 1;
read.mpDataVersionFilterList = dataVersionFilters;
read.mDataVersionFilterListSize = count;

struct : public ReadClient::Callback
{
CHIP_ERROR error = CHIP_NO_ERROR;
bool done = false;
void OnError(CHIP_ERROR aError) override { error = aError; }
void OnDone(ReadClient * apReadClient) override { done = true; };

} readCallback;

ReadClient readClient(app::InteractionModelEngine::GetInstance(), &mpContext->GetExchangeManager(), readCallback,
ReadClient::InteractionType::Read);

EXPECT_EQ(readClient.SendRequest(read), CHIP_NO_ERROR);

mpContext->GetIOContext().DriveIOUntil(System::Clock::Seconds16(5), [&]() { return readCallback.done; });
EXPECT_EQ(readCallback.error, CHIP_NO_ERROR);
EXPECT_EQ(mpContext->GetExchangeManager().GetNumActiveExchanges(), 0u);

EXPECT_NE(loopbackDelegate.requestSize, 0u);
EXPECT_GE(loopbackDelegate.requestSize, lastRequestSize);
if (loopbackDelegate.requestSize == lastRequestSize)
{
ChipLogProgress(DataManagement, "Data Version truncation detected after %llu elements",
static_cast<unsigned long long>(count - 1));
// With the parameters used in this test and current encoding rules we can fit 68 data versions
// into a packet. If we're seeing substantially less then something is likely gone wrong.
EXPECT_GE(count, 60u);
ExitNow();
}
}
ChipLogProgress(DataManagement, "Unable to detect Data Version truncation, maxDataVersionFilterCount too small?");
ADD_FAILURE();

exit:
mpContext->GetLoopback().SetLoopbackTransportDelegate(nullptr);
}

TEST_F(TestRead, TestReadHandlerResourceExhaustion_MultipleReads)
{
auto sessionHandle = mpContext->GetSessionBobToAlice();
Expand Down
3 changes: 2 additions & 1 deletion src/controller/tests/data_model/TestWrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <lib/core/StringBuilderAdapters.h>
#include <pw_unit_test/framework.h>

#include "app-common/zap-generated/ids/Clusters.h"
#include <app-common/zap-generated/cluster-objects.h>
Expand Down
2 changes: 1 addition & 1 deletion src/messaging/tests/MessagingContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class MessagingContext : public PlatformMemoryUser

MessagingContext() :
mInitialized(false), mAliceAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT + 1)),
mBobAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT))
mBobAddress(LoopbackTransport::LoopbackPeer(mAliceAddress))
{}
// TODO Replace VerifyOrDie with Pigweed assert after transition app/tests to Pigweed.
// TODO Currently src/app/icd/server/tests is using MessagingConetext as dependency.
Expand Down
23 changes: 22 additions & 1 deletion src/transport/raw/tests/NetworkTestHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class LoopbackTransportDelegate
public:
virtual ~LoopbackTransportDelegate() {}

// Called by the loopback transport when a message is requested to be sent.
// This is called even if the message is subsequently rejected or dropped.
virtual void WillSendMessage(const Transport::PeerAddress & peer, const System::PacketBufferHandle & message) {}

// Called by the loopback transport when it drops one of a configurable number of messages (mDroppedMessageCount) after a
// configurable allowed number of messages (mNumMessagesToAllowBeforeDropping)
virtual void OnMessageDropped() {}
Expand All @@ -72,6 +76,18 @@ class LoopbackTransportDelegate
class LoopbackTransport : public Transport::Base
{
public:
// In test scenarios using the loopback transport, we're only ever given
// the address we're sending to, but we don't have any information about
// what our local address is. Assume our fake addresses come in pairs of
// even and odd port numbers, so we can calculate one from the other by
// flipping the LSB of the port number.
static Transport::PeerAddress LoopbackPeer(const Transport::PeerAddress & address)
{
Transport::PeerAddress other(address);
other.SetPort(address.GetPort() ^ 1);
return other;
}

void InitLoopbackTransport(System::Layer * systemLayer)
{
Reset();
Expand Down Expand Up @@ -100,14 +116,19 @@ class LoopbackTransport : public Transport::Base
{
auto item = std::move(_this->mPendingMessageQueue.front());
_this->mPendingMessageQueue.pop();
_this->HandleMessageReceived(item.mDestinationAddress, std::move(item.mPendingMessage));
_this->HandleMessageReceived(LoopbackPeer(item.mDestinationAddress), std::move(item.mPendingMessage));
}
}

static constexpr uint32_t kUnlimitedMessageCount = std::numeric_limits<uint32_t>::max();

CHIP_ERROR SendMessage(const Transport::PeerAddress & address, System::PacketBufferHandle && msgBuf) override
{
if (mDelegate != nullptr)
{
mDelegate->WillSendMessage(address, msgBuf);
}

if (mNumMessagesToAllowBeforeError == 0)
{
ReturnErrorOnFailure(mMessageSendError);
Expand Down
Loading