From 2416376b683233102ec61fac11f3dd92ded4cf4d Mon Sep 17 00:00:00 2001 From: Tennessee Carmel-Veilleux Date: Wed, 7 Dec 2022 15:25:29 -0500 Subject: [PATCH] Introduce basic UDP packet filtering scheme (#23957) * Introduce basic UDP packet filtering scheme - Embedded class devices on Wi-Fi networks can see large number of mDNS packets from normal network. If these arrive while device is processing CASE/PASE or any other long running system activity, the amount of queuing can overrun the packet buffer pools or event queues, which can cause very bad outcomes. - Filtering of packets can always be done by a product deep in its network stack. However, allowing some level of basic filtering for UDP packets (that cover most of the traffic that could be problematic and need queuing) for at least lwIP devices allows sharing code for filters that may be smarter. Follow-up PR will show an example hook-up to ESP32 Issue #23258 Issue #23180 This PR: - Adds generic EndpointQueueFilter.h, which is portable/testable on its own. - Adds hook-ups to UDPEndPointImplLwIP to run an EndpointQueueFilter on incoming packets Testing done: - Full unit test suite for the filtering framework and basic filter. * Address review comments --- src/inet/BUILD.gn | 2 + src/inet/BasicPacketFilters.h | 185 +++++++++++ src/inet/EndpointQueueFilter.h | 88 +++++ src/inet/UDPEndPointImplLwIP.cpp | 37 ++- src/inet/UDPEndPointImplLwIP.h | 17 + src/inet/tests/BUILD.gn | 1 + src/inet/tests/TestBasicPacketFilters.cpp | 385 ++++++++++++++++++++++ 7 files changed, 713 insertions(+), 2 deletions(-) create mode 100644 src/inet/BasicPacketFilters.h create mode 100644 src/inet/EndpointQueueFilter.h create mode 100644 src/inet/tests/TestBasicPacketFilters.cpp diff --git a/src/inet/BUILD.gn b/src/inet/BUILD.gn index cb86a45780df9c..f8f81f94e3c44a 100644 --- a/src/inet/BUILD.gn +++ b/src/inet/BUILD.gn @@ -122,6 +122,8 @@ static_library("inet") { if (chip_inet_config_enable_udp_endpoint) { sources += [ + "BasicPacketFilters.h", + "EndpointQueueFilter.h", "UDPEndPoint.cpp", "UDPEndPoint.h", "UDPEndPointImpl${chip_system_config_inet}.cpp", diff --git a/src/inet/BasicPacketFilters.h b/src/inet/BasicPacketFilters.h new file mode 100644 index 00000000000000..364c0c7c11f17d --- /dev/null +++ b/src/inet/BasicPacketFilters.h @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace chip { +namespace Inet { + +/** + * @brief Basic filter that counts how many pending (not yet dequeued) packets + * are accumulated that match a predicate function, and drops those that + * would cause crossing of the threshold. + */ +class DropIfTooManyQueuedPacketsFilter : public chip::Inet::EndpointQueueFilter +{ +public: + typedef bool (*PacketMatchPredicateFunc)(void * context, const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload); + + /** + * @brief Initialize the packet filter with a starting limit + * + * @param maxAllowedQueuedPackets - max number of pending-in-queue not yet processed predicate-matching packets + */ + DropIfTooManyQueuedPacketsFilter(size_t maxAllowedQueuedPackets) : mMaxAllowedQueuedPackets(maxAllowedQueuedPackets) {} + + /** + * @brief Set the predicate to use for filtering + * + * @warning DO NOT modify at runtime while the filter is being called. If you do so, the queue accounting could + * get out of sync, and cause the filtering to fail to properly work. + * + * @param predicateFunc - Predicate function to apply. If nullptr, no filtering will take place + * @param context - Pointer to predicate-specific context that will be provided to predicate at every call. May be nullptr. + */ + void SetPredicate(PacketMatchPredicateFunc predicateFunc, void * context) + { + mPredicate = predicateFunc; + mContext = context; + } + + /** + * @brief Set the ceiling for max allowed packets queued up that matched the predicate. + * + * @note Changing this at runtime while packets are coming only affects future dropping, and + * does not remove packets from the queue if the limit is lowered below the currently-in-queue + * count. + * + * @param maxAllowedQueuedPackets - number of packets currently pending allowed. + */ + void SetMaxQueuedPacketsLimit(int maxAllowedQueuedPackets) { mMaxAllowedQueuedPackets.store(maxAllowedQueuedPackets); } + + /** + * @return the total number of packets dropped so far by the filter + */ + size_t GetNumDroppedPackets() const { return mNumDroppedPackets.load(); } + + /** + * @brief Reset the counter of dropped packets. + */ + void ClearNumDroppedPackets() { mNumDroppedPackets.store(0); } + + /** + * @brief Method called when a packet is dropped due to high watermark getting reached, based on predicate. + * + * Subclasses may use this to implement additional behavior or diagnostics. + * + * This is called once for every dropped packet. If there is no filter predicate, this is not called. + * + * @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void) + * @param pktInfo - info about source/dest of packet + * @param pktPayload - payload content of packet + */ + virtual void OnDropped(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) + {} + + /** + * @brief Method called whenever queue of accumulated packets is now empty, based on predicate. + * + * Subclasses may use this to implement additional behavior or diagnostics. + * + * This is possibly called repeatedly in a row, if the queue actually never gets above one. + * + * This is only called for packets that had matched the filtering rule, where they had + * been explicitly allowed in the past. If there is no filter predicate, this is not called. + * + * @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void) + * @param pktInfo - info about source/dest of packet + * @param pktPayload - payload content of packet + */ + virtual void OnLastMatchDequeued(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) + {} + + /** + * @brief Implementation of filtering before queueing that applies the predicate. + * + * See base class for arguments + */ + FilterOutcome FilterBeforeEnqueue(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + // WARNING: This is likely called in a different context than `FilterAfterDequeue`. We use an atomic for the counter. + + // Non-matching is never accounted, always allowed. Lack of predicate is equivalent to non-matching. + if ((mPredicate == nullptr) || !mPredicate(mContext, endpoint, pktInfo, pktPayload)) + { + return FilterOutcome::kAllowPacket; + } + + if (mNumQueuedPackets.load() >= mMaxAllowedQueuedPackets) + { + ++mNumDroppedPackets; + OnDropped(endpoint, pktInfo, pktPayload); + return FilterOutcome::kDropPacket; + } + + ++mNumQueuedPackets; + + return FilterOutcome::kAllowPacket; + } + + /** + * @brief Implementation of filtering after dequeueing that applies the predicate. + * + * See base class for arguments + */ + FilterOutcome FilterAfterDequeue(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + // WARNING: This is likely called in a different context than `FilterBeforeEnqueue`. We use an atomic for the counter. + // NOTE: This is always called from Matter platform event loop + + // Non-matching is never accounted, always allowed. Lack of predicate is equivalent to non-matching. + if ((mPredicate == nullptr) || !mPredicate(mContext, endpoint, pktInfo, pktPayload)) + { + return FilterOutcome::kAllowPacket; + } + + --mNumQueuedPackets; + int numQueuedPackets = mNumQueuedPackets.load(); + if (numQueuedPackets == 0) + { + OnLastMatchDequeued(endpoint, pktInfo, pktPayload); + } + + // If we ever go negative, we have mismatch ingress/egress filter via predicate and + // device may eventually starve. + VerifyOrDie(numQueuedPackets >= 0); + + // We always allow the packet and just do accounting, since all dropping is prior to queue entry. + return FilterOutcome::kAllowPacket; + } + +protected: + PacketMatchPredicateFunc mPredicate = nullptr; + void * mContext = nullptr; + std::atomic_int mNumQueuedPackets{ 0 }; + std::atomic_int mMaxAllowedQueuedPackets{ 0 }; + std::atomic_size_t mNumDroppedPackets{ 0u }; +}; + +} // namespace Inet +} // namespace chip diff --git a/src/inet/EndpointQueueFilter.h b/src/inet/EndpointQueueFilter.h new file mode 100644 index 00000000000000..3239b973ac3f64 --- /dev/null +++ b/src/inet/EndpointQueueFilter.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace chip { +namespace Inet { + +/** + * @brief Filter for UDP Packets going into and out of UDPEndPoint queue. + * + * NOTE: This is only used by some low-level implementations of UDPEndPoint + */ +class EndpointQueueFilter +{ +public: + enum FilterOutcome + { + kAllowPacket = 0, + kDropPacket = 1, + }; + + virtual ~EndpointQueueFilter() {} + + /** + * @brief Run filter prior to inserting in queue. + * + * If filter returns `kAllowPacket`, packet will be enqueued, and `FilterAfterDequeue` will + * be called when it gets dequeued. If filter returns `kDropPacket`, packet will be dropped + * rather than enqueued and `FilterAfterDequeue` method will not be called. + * + * WARNING: This is likely called from non-Matter-eventloop context, from network layer code. + * Be extremely careful about accessing any system data which may belong to Matter + * stack from this method. + * + * @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void) + * @param pktInfo - info about source/dest of packet + * @param pktPayload - payload content of packet + * + * @return kAllowPacket to allow packet to enqueue or kDropPacket to drop the packet + */ + virtual FilterOutcome FilterBeforeEnqueue(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) + { + return FilterOutcome::kAllowPacket; + } + + /** + * @brief Run filter after dequeuing, prior to processing. + * + * If filter returns `kAllowPacket`, packet will be processed after dequeuing. If filter returns + * `kDropPacket`, packet will be dropped and not processed, even though it was dequeued. + * + * WARNING: This is called from Matter thread context. Be extremely careful about accessing any + * data which may belong to different threads from this method. + * + * @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void) + * @param pktInfo - info about source/dest of packet + * @param pktPayload - payload content of packet + * + * @return kAllowPacket to allow packet to be processed or kDropPacket to drop the packet + */ + virtual FilterOutcome FilterAfterDequeue(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) + { + return FilterOutcome::kAllowPacket; + } +}; + +} // namespace Inet +} // namespace chip diff --git a/src/inet/UDPEndPointImplLwIP.cpp b/src/inet/UDPEndPointImplLwIP.cpp index 07ce94c7d653e9..78355628420b8b 100644 --- a/src/inet/UDPEndPointImplLwIP.cpp +++ b/src/inet/UDPEndPointImplLwIP.cpp @@ -78,6 +78,8 @@ struct Deleter namespace chip { namespace Inet { +EndpointQueueFilter * UDPEndPointImplLwIP::sQueueFilter = nullptr; + CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddress & address, uint16_t port, InterfaceId interfaceId) { @@ -287,7 +289,16 @@ void UDPEndPointImplLwIP::Free() void UDPEndPointImplLwIP::HandleDataReceived(System::PacketBufferHandle && msg, IPPacketInfo * pktInfo) { - if ((mState == State::kListening) && (OnMessageReceived != nullptr)) + // Process packet filter if needed. May cause packet to get dropped before processing. + bool dropPacket = false; + if ((pktInfo != nullptr) && (sQueueFilter != nullptr)) + { + auto outcome = sQueueFilter->FilterAfterDequeue(this, *pktInfo, msg); + dropPacket = (outcome == EndpointQueueFilter::FilterOutcome::kDropPacket); + } + + // Process actual packet if allowed + if ((mState == State::kListening) && (OnMessageReceived != nullptr) && !dropPacket) { if (pktInfo != nullptr) { @@ -424,6 +435,18 @@ void UDPEndPointImplLwIP::LwIPReceiveUDPMessage(void * arg, struct udp_pcb * pcb pktInfo->SrcPort = port; pktInfo->DestPort = pcb->local_port; + auto filterOutcome = EndpointQueueFilter::FilterOutcome::kAllowPacket; + if (sQueueFilter != nullptr) + { + filterOutcome = sQueueFilter->FilterBeforeEnqueue(ep, *(pktInfo.get()), buf); + } + + if (filterOutcome != EndpointQueueFilter::FilterOutcome::kAllowPacket) + { + // Logging, if any, should be at the choice of the filter impl at time of filtering. + return; + } + // Increase mDelayReleaseCount to delay release of this UDP EndPoint while the HandleDataReceived call is // pending on it. ep->mDelayReleaseCount++; @@ -431,7 +454,9 @@ void UDPEndPointImplLwIP::LwIPReceiveUDPMessage(void * arg, struct udp_pcb * pcb CHIP_ERROR err = ep->GetSystemLayer().ScheduleLambda( [ep, p = System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(buf), pktInfo = pktInfo.get()] { ep->mDelayReleaseCount--; - ep->HandleDataReceived(System::PacketBufferHandle::Adopt(p), pktInfo); + + auto handle = System::PacketBufferHandle::Adopt(p); + ep->HandleDataReceived(std::move(handle), pktInfo); }); if (err == CHIP_NO_ERROR) @@ -443,6 +468,14 @@ void UDPEndPointImplLwIP::LwIPReceiveUDPMessage(void * arg, struct udp_pcb * pcb } else { + // On failure to enqueue the processing, we have to tell the filter that + // the packet is basically dequeued, if it tries to keep track of the lifecycle. + if (sQueueFilter != nullptr) + { + (void) sQueueFilter->FilterAfterDequeue(ep, *(pktInfo.get()), buf); + ChipLogError(Inet, "Dequeue ERROR err = %" CHIP_ERROR_FORMAT, err.Format()); + } + ep->mDelayReleaseCount--; } } diff --git a/src/inet/UDPEndPointImplLwIP.h b/src/inet/UDPEndPointImplLwIP.h index 3d8aebe1021f52..230c872743641c 100644 --- a/src/inet/UDPEndPointImplLwIP.h +++ b/src/inet/UDPEndPointImplLwIP.h @@ -24,6 +24,7 @@ #pragma once #include +#include #include namespace chip { @@ -40,6 +41,20 @@ class UDPEndPointImplLwIP : public UDPEndPoint, public EndPointStateLwIP uint16_t GetBoundPort() const override; void Free() override; + /** + * @brief Set the queue filter for all UDP endpoints + * + * Responsibility is on the caller to avoid changing the filter while packets are being + * processed. Setting the queue filter to `nullptr` disables the filtering. + * + * NOTE: There is only one EndpointQueueFilter instance settable. However it's possible + * to create an instance of EndpointQueueFilter that combines several other + * EndpointQueueFilter by composition to achieve the effect of multiple filters. + * + * @param queueFilter - queue filter instance to set, owned by caller + */ + static void SetQueueFilter(EndpointQueueFilter * queueFilter) { sQueueFilter = queueFilter; } + private: // UDPEndPoint overrides. #if INET_CONFIG_ENABLE_IPV4 @@ -62,6 +77,8 @@ class UDPEndPointImplLwIP : public UDPEndPoint, public EndPointStateLwIP udp_pcb * mUDP; // LwIP User datagram protocol (UDP) control block. std::atomic_int mDelayReleaseCount{ 0 }; + + static EndpointQueueFilter * sQueueFilter; }; using UDPEndPointImpl = UDPEndPointImplLwIP; diff --git a/src/inet/tests/BUILD.gn b/src/inet/tests/BUILD.gn index 1788410bb6c6d3..ce7fd4664096ba 100644 --- a/src/inet/tests/BUILD.gn +++ b/src/inet/tests/BUILD.gn @@ -74,6 +74,7 @@ chip_test_suite("tests") { ] test_sources = [ + "TestBasicPacketFilters.cpp", "TestInetAddress.cpp", "TestInetErrorStr.cpp", ] diff --git a/src/inet/tests/TestBasicPacketFilters.cpp b/src/inet/tests/TestBasicPacketFilters.cpp new file mode 100644 index 00000000000000..ef07cab057eb89 --- /dev/null +++ b/src/inet/tests/TestBasicPacketFilters.cpp @@ -0,0 +1,385 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +using namespace chip; +using namespace chip::Inet; + +class DropIfTooManyQueuedPacketsHarness : public DropIfTooManyQueuedPacketsFilter +{ +public: + DropIfTooManyQueuedPacketsHarness(size_t maxAllowedQueuedPackets) : DropIfTooManyQueuedPacketsFilter(maxAllowedQueuedPackets) {} + + void OnDropped(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + ++mNumOnDroppedCalled; + + // Log a hysteretic event + if (!mHitCeilingWantFloor) + { + mHitCeilingWantFloor = true; + ChipLogError(Inet, "Hit waterwark, will log as resolved once we get back to empty"); + } + } + + void OnLastMatchDequeued(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + ++mNumOnLastMatchDequeuedCalled; + + // Log a hysteretic event + if (mHitCeilingWantFloor) + { + mHitCeilingWantFloor = false; + ChipLogError(Inet, "Resolved burst, got back to fully empty."); + } + } + + // Public bits to make testing easier + nlTestSuite * mTestSuite = nullptr; + int mNumOnDroppedCalled = 0; + int mNumOnLastMatchDequeuedCalled = 0; + bool mHitCeilingWantFloor; +}; + +class FilterDriver +{ +public: + FilterDriver(EndpointQueueFilter * filter, const void * endpoint) : mFilter(filter), mEndpoint(endpoint) {} + + EndpointQueueFilter::FilterOutcome ProcessEnqueue(const IPAddress & srcAddr, uint16_t srcPort, const IPAddress & dstAddr, + uint16_t dstPort, ByteSpan payload) + { + VerifyOrDie(mFilter != nullptr); + + chip::Inet::IPPacketInfo pktInfo; + pktInfo.SrcAddress = srcAddr; + pktInfo.DestAddress = dstAddr; + pktInfo.SrcPort = srcPort; + pktInfo.DestPort = dstPort; + + auto pktPayload = chip::System::PacketBufferHandle::NewWithData(payload.data(), payload.size()); + + return mFilter->FilterBeforeEnqueue(mEndpoint, pktInfo, pktPayload); + } + + EndpointQueueFilter::FilterOutcome ProcessDequeue(const IPAddress & srcAddr, uint16_t srcPort, const IPAddress & dstAddr, + uint16_t dstPort, ByteSpan payload) + { + VerifyOrDie(mFilter != nullptr); + + chip::Inet::IPPacketInfo pktInfo; + pktInfo.SrcAddress = srcAddr; + pktInfo.DestAddress = dstAddr; + pktInfo.SrcPort = srcPort; + pktInfo.DestPort = dstPort; + + auto pktPayload = chip::System::PacketBufferHandle::NewWithData(payload.data(), payload.size()); + + return mFilter->FilterAfterDequeue(mEndpoint, pktInfo, pktPayload); + } + +protected: + EndpointQueueFilter * mFilter = nullptr; + const void * mEndpoint = nullptr; +}; + +DropIfTooManyQueuedPacketsHarness gFilter(0); +int gFakeEndpointForPointer = 0; + +void TestBasicPacketFilter(nlTestSuite * inSuite, void * inContext) +{ + constexpr uint16_t kMdnsPort = 5353u; + + gFilter.mTestSuite = inSuite; + + // Predicate for test is filter that destination port is 5353 (mDNS). + // NOTE: A non-capturing lambda is used, but a plain function could have been used as well... + auto predicate = [](void * context, const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) -> bool { + auto filter = reinterpret_cast(context); + auto testSuite = filter->mTestSuite; + auto expectedEndpoint = &gFakeEndpointForPointer; + + // Ensure we get called with context and expected endpoint pointer + NL_TEST_ASSERT(testSuite, context == &gFilter); + NL_TEST_ASSERT(testSuite, endpoint == expectedEndpoint); + + // Predicate filters destination port being 5353 + return (pktInfo.DestPort == kMdnsPort); + }; + gFilter.SetPredicate(predicate, &gFilter); + + FilterDriver fakeUdpEndpoint(&gFilter, &gFakeEndpointForPointer); + + IPAddress fakeSrc; + IPAddress fakeDest; + IPAddress fakeMdnsDest; + constexpr uint16_t kOtherPort = 43210u; + const uint8_t kFakePayloadData[] = { 1, 2, 3 }; + const ByteSpan kFakePayload{ kFakePayloadData }; + + NL_TEST_ASSERT(inSuite, IPAddress::FromString("fe80::aaaa:bbbb:cccc:dddd", fakeSrc)); + NL_TEST_ASSERT(inSuite, IPAddress::FromString("fe80::0000:1111:2222:3333", fakeDest)); + NL_TEST_ASSERT(inSuite, IPAddress::FromString("ff02::fb", fakeMdnsDest)); + + // Shorthands for simplifying asserts + constexpr EndpointQueueFilter::FilterOutcome kAllowPacket = EndpointQueueFilter::FilterOutcome::kAllowPacket; + constexpr EndpointQueueFilter::FilterOutcome kDropPacket = EndpointQueueFilter::FilterOutcome::kDropPacket; + + constexpr int kMaxQueuedPacketsLimit = 3; + gFilter.SetMaxQueuedPacketsLimit(kMaxQueuedPacketsLimit); + + { + // Enqueue some packets that don't match filter, all allowed, never hit the drop + for (int numPkt = 0; numPkt < (kMaxQueuedPacketsLimit + 1); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + + // Dequeue all packets + for (int numPkt = 0; numPkt < (kMaxQueuedPacketsLimit + 1); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + + // OnDroped/OnLastMatchDequeued only ever called for matching packets, never for non-matching + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + } + + { + // Enqueue packets that match filter, up to watermark. None dropped + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue packets that match filter, beyond watermark: all dropped. + for (int numPkt = 0; numPkt < 2; ++numPkt) + { + NL_TEST_ASSERT( + inSuite, kDropPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Dequeue 2 packets that were enqueued, matching filter + for (int numPkt = 0; numPkt < 2; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + // Number of dropped packets didn't change + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue packets that match filter, up to watermark again. None dropped. + for (int numPkt = 0; numPkt < 2; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + + // No change from prior state + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue two more packets, expect drop + for (int numPkt = 0; numPkt < 2; ++numPkt) + { + NL_TEST_ASSERT( + inSuite, kDropPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + + // Expect two more dropped total + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue non-matching packet, expect allowed. + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + + // Expect no more dropepd + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Dequeue non-matching packet, expect allowed. + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + + // Expect no change + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Dequeue all matching packets, expect allowed and one OnLastMatchDequeued on last one. + for (int numPkt = 0; numPkt < (kMaxQueuedPacketsLimit - 1); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 1); + } + + // Validate that clearing drop count works + { + gFilter.ClearNumDroppedPackets(); + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + + gFilter.mNumOnDroppedCalled = 0; + gFilter.mNumOnLastMatchDequeuedCalled = 0; + } + + // Validate that all packets pass when no predicate set + { + gFilter.SetPredicate(nullptr, nullptr); + + // Enqueue packets up to twice the watermark. None dropped. + for (int numPkt = 0; numPkt < (2 * kMaxQueuedPacketsLimit); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Works even if max number of packets allowed is zero + gFilter.SetMaxQueuedPacketsLimit(0); + + // Enqueue packets up to twice the watermark. None dropped. + for (int numPkt = 0; numPkt < (2 * kMaxQueuedPacketsLimit); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + } + + // Validate that setting max packets to zero, with a matching predicate, drops all matching packets, none of the non-matching. + { + gFilter.SetPredicate(predicate, &gFilter); + gFilter.SetMaxQueuedPacketsLimit(0); + + // Enqueue packets that match filter, up to watermark. All dropped + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT( + inSuite, kDropPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 3); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 3); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue non-filter-matching, none dropped + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 3); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 3); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + } +} + +const nlTest sTests[] = { + NL_TEST_DEF("TestBasicPacketFilter", TestBasicPacketFilter), // + NL_TEST_SENTINEL() // +}; + +int TestSuiteSetup(void * inContext) +{ + CHIP_ERROR error = chip::Platform::MemoryInit(); + if (error != CHIP_NO_ERROR) + return FAILURE; + return SUCCESS; +} + +int TestSuiteTeardown(void * inContext) +{ + chip::Platform::MemoryShutdown(); + return SUCCESS; +} + +} // namespace + +int TestBasicPacketFilters() +{ + nlTestSuite theSuite = { "TestBasicPacketFilters", sTests, &TestSuiteSetup, &TestSuiteTeardown }; + nlTestRunner(&theSuite, nullptr); + return nlTestRunnerStats(&theSuite); +} + +CHIP_REGISTER_TEST_SUITE(TestBasicPacketFilters)