diff --git a/src/inet/BUILD.gn b/src/inet/BUILD.gn index cb86a45780df9c..f8f81f94e3c44a 100644 --- a/src/inet/BUILD.gn +++ b/src/inet/BUILD.gn @@ -122,6 +122,8 @@ static_library("inet") { if (chip_inet_config_enable_udp_endpoint) { sources += [ + "BasicPacketFilters.h", + "EndpointQueueFilter.h", "UDPEndPoint.cpp", "UDPEndPoint.h", "UDPEndPointImpl${chip_system_config_inet}.cpp", diff --git a/src/inet/BasicPacketFilters.h b/src/inet/BasicPacketFilters.h new file mode 100644 index 00000000000000..364c0c7c11f17d --- /dev/null +++ b/src/inet/BasicPacketFilters.h @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace chip { +namespace Inet { + +/** + * @brief Basic filter that counts how many pending (not yet dequeued) packets + * are accumulated that match a predicate function, and drops those that + * would cause crossing of the threshold. + */ +class DropIfTooManyQueuedPacketsFilter : public chip::Inet::EndpointQueueFilter +{ +public: + typedef bool (*PacketMatchPredicateFunc)(void * context, const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload); + + /** + * @brief Initialize the packet filter with a starting limit + * + * @param maxAllowedQueuedPackets - max number of pending-in-queue not yet processed predicate-matching packets + */ + DropIfTooManyQueuedPacketsFilter(size_t maxAllowedQueuedPackets) : mMaxAllowedQueuedPackets(maxAllowedQueuedPackets) {} + + /** + * @brief Set the predicate to use for filtering + * + * @warning DO NOT modify at runtime while the filter is being called. If you do so, the queue accounting could + * get out of sync, and cause the filtering to fail to properly work. + * + * @param predicateFunc - Predicate function to apply. If nullptr, no filtering will take place + * @param context - Pointer to predicate-specific context that will be provided to predicate at every call. May be nullptr. + */ + void SetPredicate(PacketMatchPredicateFunc predicateFunc, void * context) + { + mPredicate = predicateFunc; + mContext = context; + } + + /** + * @brief Set the ceiling for max allowed packets queued up that matched the predicate. + * + * @note Changing this at runtime while packets are coming only affects future dropping, and + * does not remove packets from the queue if the limit is lowered below the currently-in-queue + * count. + * + * @param maxAllowedQueuedPackets - number of packets currently pending allowed. + */ + void SetMaxQueuedPacketsLimit(int maxAllowedQueuedPackets) { mMaxAllowedQueuedPackets.store(maxAllowedQueuedPackets); } + + /** + * @return the total number of packets dropped so far by the filter + */ + size_t GetNumDroppedPackets() const { return mNumDroppedPackets.load(); } + + /** + * @brief Reset the counter of dropped packets. + */ + void ClearNumDroppedPackets() { mNumDroppedPackets.store(0); } + + /** + * @brief Method called when a packet is dropped due to high watermark getting reached, based on predicate. + * + * Subclasses may use this to implement additional behavior or diagnostics. + * + * This is called once for every dropped packet. If there is no filter predicate, this is not called. + * + * @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void) + * @param pktInfo - info about source/dest of packet + * @param pktPayload - payload content of packet + */ + virtual void OnDropped(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) + {} + + /** + * @brief Method called whenever queue of accumulated packets is now empty, based on predicate. + * + * Subclasses may use this to implement additional behavior or diagnostics. + * + * This is possibly called repeatedly in a row, if the queue actually never gets above one. + * + * This is only called for packets that had matched the filtering rule, where they had + * been explicitly allowed in the past. If there is no filter predicate, this is not called. + * + * @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void) + * @param pktInfo - info about source/dest of packet + * @param pktPayload - payload content of packet + */ + virtual void OnLastMatchDequeued(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) + {} + + /** + * @brief Implementation of filtering before queueing that applies the predicate. + * + * See base class for arguments + */ + FilterOutcome FilterBeforeEnqueue(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + // WARNING: This is likely called in a different context than `FilterAfterDequeue`. We use an atomic for the counter. + + // Non-matching is never accounted, always allowed. Lack of predicate is equivalent to non-matching. + if ((mPredicate == nullptr) || !mPredicate(mContext, endpoint, pktInfo, pktPayload)) + { + return FilterOutcome::kAllowPacket; + } + + if (mNumQueuedPackets.load() >= mMaxAllowedQueuedPackets) + { + ++mNumDroppedPackets; + OnDropped(endpoint, pktInfo, pktPayload); + return FilterOutcome::kDropPacket; + } + + ++mNumQueuedPackets; + + return FilterOutcome::kAllowPacket; + } + + /** + * @brief Implementation of filtering after dequeueing that applies the predicate. + * + * See base class for arguments + */ + FilterOutcome FilterAfterDequeue(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + // WARNING: This is likely called in a different context than `FilterBeforeEnqueue`. We use an atomic for the counter. + // NOTE: This is always called from Matter platform event loop + + // Non-matching is never accounted, always allowed. Lack of predicate is equivalent to non-matching. + if ((mPredicate == nullptr) || !mPredicate(mContext, endpoint, pktInfo, pktPayload)) + { + return FilterOutcome::kAllowPacket; + } + + --mNumQueuedPackets; + int numQueuedPackets = mNumQueuedPackets.load(); + if (numQueuedPackets == 0) + { + OnLastMatchDequeued(endpoint, pktInfo, pktPayload); + } + + // If we ever go negative, we have mismatch ingress/egress filter via predicate and + // device may eventually starve. + VerifyOrDie(numQueuedPackets >= 0); + + // We always allow the packet and just do accounting, since all dropping is prior to queue entry. + return FilterOutcome::kAllowPacket; + } + +protected: + PacketMatchPredicateFunc mPredicate = nullptr; + void * mContext = nullptr; + std::atomic_int mNumQueuedPackets{ 0 }; + std::atomic_int mMaxAllowedQueuedPackets{ 0 }; + std::atomic_size_t mNumDroppedPackets{ 0u }; +}; + +} // namespace Inet +} // namespace chip diff --git a/src/inet/EndpointQueueFilter.h b/src/inet/EndpointQueueFilter.h new file mode 100644 index 00000000000000..3239b973ac3f64 --- /dev/null +++ b/src/inet/EndpointQueueFilter.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace chip { +namespace Inet { + +/** + * @brief Filter for UDP Packets going into and out of UDPEndPoint queue. + * + * NOTE: This is only used by some low-level implementations of UDPEndPoint + */ +class EndpointQueueFilter +{ +public: + enum FilterOutcome + { + kAllowPacket = 0, + kDropPacket = 1, + }; + + virtual ~EndpointQueueFilter() {} + + /** + * @brief Run filter prior to inserting in queue. + * + * If filter returns `kAllowPacket`, packet will be enqueued, and `FilterAfterDequeue` will + * be called when it gets dequeued. If filter returns `kDropPacket`, packet will be dropped + * rather than enqueued and `FilterAfterDequeue` method will not be called. + * + * WARNING: This is likely called from non-Matter-eventloop context, from network layer code. + * Be extremely careful about accessing any system data which may belong to Matter + * stack from this method. + * + * @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void) + * @param pktInfo - info about source/dest of packet + * @param pktPayload - payload content of packet + * + * @return kAllowPacket to allow packet to enqueue or kDropPacket to drop the packet + */ + virtual FilterOutcome FilterBeforeEnqueue(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) + { + return FilterOutcome::kAllowPacket; + } + + /** + * @brief Run filter after dequeuing, prior to processing. + * + * If filter returns `kAllowPacket`, packet will be processed after dequeuing. If filter returns + * `kDropPacket`, packet will be dropped and not processed, even though it was dequeued. + * + * WARNING: This is called from Matter thread context. Be extremely careful about accessing any + * data which may belong to different threads from this method. + * + * @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void) + * @param pktInfo - info about source/dest of packet + * @param pktPayload - payload content of packet + * + * @return kAllowPacket to allow packet to be processed or kDropPacket to drop the packet + */ + virtual FilterOutcome FilterAfterDequeue(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) + { + return FilterOutcome::kAllowPacket; + } +}; + +} // namespace Inet +} // namespace chip diff --git a/src/inet/UDPEndPointImplLwIP.cpp b/src/inet/UDPEndPointImplLwIP.cpp index 07ce94c7d653e9..78355628420b8b 100644 --- a/src/inet/UDPEndPointImplLwIP.cpp +++ b/src/inet/UDPEndPointImplLwIP.cpp @@ -78,6 +78,8 @@ struct Deleter namespace chip { namespace Inet { +EndpointQueueFilter * UDPEndPointImplLwIP::sQueueFilter = nullptr; + CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddress & address, uint16_t port, InterfaceId interfaceId) { @@ -287,7 +289,16 @@ void UDPEndPointImplLwIP::Free() void UDPEndPointImplLwIP::HandleDataReceived(System::PacketBufferHandle && msg, IPPacketInfo * pktInfo) { - if ((mState == State::kListening) && (OnMessageReceived != nullptr)) + // Process packet filter if needed. May cause packet to get dropped before processing. + bool dropPacket = false; + if ((pktInfo != nullptr) && (sQueueFilter != nullptr)) + { + auto outcome = sQueueFilter->FilterAfterDequeue(this, *pktInfo, msg); + dropPacket = (outcome == EndpointQueueFilter::FilterOutcome::kDropPacket); + } + + // Process actual packet if allowed + if ((mState == State::kListening) && (OnMessageReceived != nullptr) && !dropPacket) { if (pktInfo != nullptr) { @@ -424,6 +435,18 @@ void UDPEndPointImplLwIP::LwIPReceiveUDPMessage(void * arg, struct udp_pcb * pcb pktInfo->SrcPort = port; pktInfo->DestPort = pcb->local_port; + auto filterOutcome = EndpointQueueFilter::FilterOutcome::kAllowPacket; + if (sQueueFilter != nullptr) + { + filterOutcome = sQueueFilter->FilterBeforeEnqueue(ep, *(pktInfo.get()), buf); + } + + if (filterOutcome != EndpointQueueFilter::FilterOutcome::kAllowPacket) + { + // Logging, if any, should be at the choice of the filter impl at time of filtering. + return; + } + // Increase mDelayReleaseCount to delay release of this UDP EndPoint while the HandleDataReceived call is // pending on it. ep->mDelayReleaseCount++; @@ -431,7 +454,9 @@ void UDPEndPointImplLwIP::LwIPReceiveUDPMessage(void * arg, struct udp_pcb * pcb CHIP_ERROR err = ep->GetSystemLayer().ScheduleLambda( [ep, p = System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(buf), pktInfo = pktInfo.get()] { ep->mDelayReleaseCount--; - ep->HandleDataReceived(System::PacketBufferHandle::Adopt(p), pktInfo); + + auto handle = System::PacketBufferHandle::Adopt(p); + ep->HandleDataReceived(std::move(handle), pktInfo); }); if (err == CHIP_NO_ERROR) @@ -443,6 +468,14 @@ void UDPEndPointImplLwIP::LwIPReceiveUDPMessage(void * arg, struct udp_pcb * pcb } else { + // On failure to enqueue the processing, we have to tell the filter that + // the packet is basically dequeued, if it tries to keep track of the lifecycle. + if (sQueueFilter != nullptr) + { + (void) sQueueFilter->FilterAfterDequeue(ep, *(pktInfo.get()), buf); + ChipLogError(Inet, "Dequeue ERROR err = %" CHIP_ERROR_FORMAT, err.Format()); + } + ep->mDelayReleaseCount--; } } diff --git a/src/inet/UDPEndPointImplLwIP.h b/src/inet/UDPEndPointImplLwIP.h index 3d8aebe1021f52..230c872743641c 100644 --- a/src/inet/UDPEndPointImplLwIP.h +++ b/src/inet/UDPEndPointImplLwIP.h @@ -24,6 +24,7 @@ #pragma once #include +#include #include namespace chip { @@ -40,6 +41,20 @@ class UDPEndPointImplLwIP : public UDPEndPoint, public EndPointStateLwIP uint16_t GetBoundPort() const override; void Free() override; + /** + * @brief Set the queue filter for all UDP endpoints + * + * Responsibility is on the caller to avoid changing the filter while packets are being + * processed. Setting the queue filter to `nullptr` disables the filtering. + * + * NOTE: There is only one EndpointQueueFilter instance settable. However it's possible + * to create an instance of EndpointQueueFilter that combines several other + * EndpointQueueFilter by composition to achieve the effect of multiple filters. + * + * @param queueFilter - queue filter instance to set, owned by caller + */ + static void SetQueueFilter(EndpointQueueFilter * queueFilter) { sQueueFilter = queueFilter; } + private: // UDPEndPoint overrides. #if INET_CONFIG_ENABLE_IPV4 @@ -62,6 +77,8 @@ class UDPEndPointImplLwIP : public UDPEndPoint, public EndPointStateLwIP udp_pcb * mUDP; // LwIP User datagram protocol (UDP) control block. std::atomic_int mDelayReleaseCount{ 0 }; + + static EndpointQueueFilter * sQueueFilter; }; using UDPEndPointImpl = UDPEndPointImplLwIP; diff --git a/src/inet/tests/BUILD.gn b/src/inet/tests/BUILD.gn index 1788410bb6c6d3..ce7fd4664096ba 100644 --- a/src/inet/tests/BUILD.gn +++ b/src/inet/tests/BUILD.gn @@ -74,6 +74,7 @@ chip_test_suite("tests") { ] test_sources = [ + "TestBasicPacketFilters.cpp", "TestInetAddress.cpp", "TestInetErrorStr.cpp", ] diff --git a/src/inet/tests/TestBasicPacketFilters.cpp b/src/inet/tests/TestBasicPacketFilters.cpp new file mode 100644 index 00000000000000..ef07cab057eb89 --- /dev/null +++ b/src/inet/tests/TestBasicPacketFilters.cpp @@ -0,0 +1,385 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +using namespace chip; +using namespace chip::Inet; + +class DropIfTooManyQueuedPacketsHarness : public DropIfTooManyQueuedPacketsFilter +{ +public: + DropIfTooManyQueuedPacketsHarness(size_t maxAllowedQueuedPackets) : DropIfTooManyQueuedPacketsFilter(maxAllowedQueuedPackets) {} + + void OnDropped(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + ++mNumOnDroppedCalled; + + // Log a hysteretic event + if (!mHitCeilingWantFloor) + { + mHitCeilingWantFloor = true; + ChipLogError(Inet, "Hit waterwark, will log as resolved once we get back to empty"); + } + } + + void OnLastMatchDequeued(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + ++mNumOnLastMatchDequeuedCalled; + + // Log a hysteretic event + if (mHitCeilingWantFloor) + { + mHitCeilingWantFloor = false; + ChipLogError(Inet, "Resolved burst, got back to fully empty."); + } + } + + // Public bits to make testing easier + nlTestSuite * mTestSuite = nullptr; + int mNumOnDroppedCalled = 0; + int mNumOnLastMatchDequeuedCalled = 0; + bool mHitCeilingWantFloor; +}; + +class FilterDriver +{ +public: + FilterDriver(EndpointQueueFilter * filter, const void * endpoint) : mFilter(filter), mEndpoint(endpoint) {} + + EndpointQueueFilter::FilterOutcome ProcessEnqueue(const IPAddress & srcAddr, uint16_t srcPort, const IPAddress & dstAddr, + uint16_t dstPort, ByteSpan payload) + { + VerifyOrDie(mFilter != nullptr); + + chip::Inet::IPPacketInfo pktInfo; + pktInfo.SrcAddress = srcAddr; + pktInfo.DestAddress = dstAddr; + pktInfo.SrcPort = srcPort; + pktInfo.DestPort = dstPort; + + auto pktPayload = chip::System::PacketBufferHandle::NewWithData(payload.data(), payload.size()); + + return mFilter->FilterBeforeEnqueue(mEndpoint, pktInfo, pktPayload); + } + + EndpointQueueFilter::FilterOutcome ProcessDequeue(const IPAddress & srcAddr, uint16_t srcPort, const IPAddress & dstAddr, + uint16_t dstPort, ByteSpan payload) + { + VerifyOrDie(mFilter != nullptr); + + chip::Inet::IPPacketInfo pktInfo; + pktInfo.SrcAddress = srcAddr; + pktInfo.DestAddress = dstAddr; + pktInfo.SrcPort = srcPort; + pktInfo.DestPort = dstPort; + + auto pktPayload = chip::System::PacketBufferHandle::NewWithData(payload.data(), payload.size()); + + return mFilter->FilterAfterDequeue(mEndpoint, pktInfo, pktPayload); + } + +protected: + EndpointQueueFilter * mFilter = nullptr; + const void * mEndpoint = nullptr; +}; + +DropIfTooManyQueuedPacketsHarness gFilter(0); +int gFakeEndpointForPointer = 0; + +void TestBasicPacketFilter(nlTestSuite * inSuite, void * inContext) +{ + constexpr uint16_t kMdnsPort = 5353u; + + gFilter.mTestSuite = inSuite; + + // Predicate for test is filter that destination port is 5353 (mDNS). + // NOTE: A non-capturing lambda is used, but a plain function could have been used as well... + auto predicate = [](void * context, const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) -> bool { + auto filter = reinterpret_cast(context); + auto testSuite = filter->mTestSuite; + auto expectedEndpoint = &gFakeEndpointForPointer; + + // Ensure we get called with context and expected endpoint pointer + NL_TEST_ASSERT(testSuite, context == &gFilter); + NL_TEST_ASSERT(testSuite, endpoint == expectedEndpoint); + + // Predicate filters destination port being 5353 + return (pktInfo.DestPort == kMdnsPort); + }; + gFilter.SetPredicate(predicate, &gFilter); + + FilterDriver fakeUdpEndpoint(&gFilter, &gFakeEndpointForPointer); + + IPAddress fakeSrc; + IPAddress fakeDest; + IPAddress fakeMdnsDest; + constexpr uint16_t kOtherPort = 43210u; + const uint8_t kFakePayloadData[] = { 1, 2, 3 }; + const ByteSpan kFakePayload{ kFakePayloadData }; + + NL_TEST_ASSERT(inSuite, IPAddress::FromString("fe80::aaaa:bbbb:cccc:dddd", fakeSrc)); + NL_TEST_ASSERT(inSuite, IPAddress::FromString("fe80::0000:1111:2222:3333", fakeDest)); + NL_TEST_ASSERT(inSuite, IPAddress::FromString("ff02::fb", fakeMdnsDest)); + + // Shorthands for simplifying asserts + constexpr EndpointQueueFilter::FilterOutcome kAllowPacket = EndpointQueueFilter::FilterOutcome::kAllowPacket; + constexpr EndpointQueueFilter::FilterOutcome kDropPacket = EndpointQueueFilter::FilterOutcome::kDropPacket; + + constexpr int kMaxQueuedPacketsLimit = 3; + gFilter.SetMaxQueuedPacketsLimit(kMaxQueuedPacketsLimit); + + { + // Enqueue some packets that don't match filter, all allowed, never hit the drop + for (int numPkt = 0; numPkt < (kMaxQueuedPacketsLimit + 1); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + + // Dequeue all packets + for (int numPkt = 0; numPkt < (kMaxQueuedPacketsLimit + 1); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + + // OnDroped/OnLastMatchDequeued only ever called for matching packets, never for non-matching + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + } + + { + // Enqueue packets that match filter, up to watermark. None dropped + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue packets that match filter, beyond watermark: all dropped. + for (int numPkt = 0; numPkt < 2; ++numPkt) + { + NL_TEST_ASSERT( + inSuite, kDropPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Dequeue 2 packets that were enqueued, matching filter + for (int numPkt = 0; numPkt < 2; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + // Number of dropped packets didn't change + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue packets that match filter, up to watermark again. None dropped. + for (int numPkt = 0; numPkt < 2; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + + // No change from prior state + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 2); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue two more packets, expect drop + for (int numPkt = 0; numPkt < 2; ++numPkt) + { + NL_TEST_ASSERT( + inSuite, kDropPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + + // Expect two more dropped total + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue non-matching packet, expect allowed. + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + + // Expect no more dropepd + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Dequeue non-matching packet, expect allowed. + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + + // Expect no change + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Dequeue all matching packets, expect allowed and one OnLastMatchDequeued on last one. + for (int numPkt = 0; numPkt < (kMaxQueuedPacketsLimit - 1); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 4); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 1); + } + + // Validate that clearing drop count works + { + gFilter.ClearNumDroppedPackets(); + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + + gFilter.mNumOnDroppedCalled = 0; + gFilter.mNumOnLastMatchDequeuedCalled = 0; + } + + // Validate that all packets pass when no predicate set + { + gFilter.SetPredicate(nullptr, nullptr); + + // Enqueue packets up to twice the watermark. None dropped. + for (int numPkt = 0; numPkt < (2 * kMaxQueuedPacketsLimit); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Works even if max number of packets allowed is zero + gFilter.SetMaxQueuedPacketsLimit(0); + + // Enqueue packets up to twice the watermark. None dropped. + for (int numPkt = 0; numPkt < (2 * kMaxQueuedPacketsLimit); ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == + fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 0); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + } + + // Validate that setting max packets to zero, with a matching predicate, drops all matching packets, none of the non-matching. + { + gFilter.SetPredicate(predicate, &gFilter); + gFilter.SetMaxQueuedPacketsLimit(0); + + // Enqueue packets that match filter, up to watermark. All dropped + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT( + inSuite, kDropPacket == fakeUdpEndpoint.ProcessEnqueue(fakeSrc, kOtherPort, fakeMdnsDest, kMdnsPort, kFakePayload)); + } + + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 3); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 3); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + + // Enqueue non-filter-matching, none dropped + for (int numPkt = 0; numPkt < kMaxQueuedPacketsLimit; ++numPkt) + { + NL_TEST_ASSERT(inSuite, + kAllowPacket == fakeUdpEndpoint.ProcessDequeue(fakeSrc, kOtherPort, fakeDest, kOtherPort, kFakePayload)); + } + + NL_TEST_ASSERT(inSuite, gFilter.GetNumDroppedPackets() == 3); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnDroppedCalled == 3); + NL_TEST_ASSERT(inSuite, gFilter.mNumOnLastMatchDequeuedCalled == 0); + } +} + +const nlTest sTests[] = { + NL_TEST_DEF("TestBasicPacketFilter", TestBasicPacketFilter), // + NL_TEST_SENTINEL() // +}; + +int TestSuiteSetup(void * inContext) +{ + CHIP_ERROR error = chip::Platform::MemoryInit(); + if (error != CHIP_NO_ERROR) + return FAILURE; + return SUCCESS; +} + +int TestSuiteTeardown(void * inContext) +{ + chip::Platform::MemoryShutdown(); + return SUCCESS; +} + +} // namespace + +int TestBasicPacketFilters() +{ + nlTestSuite theSuite = { "TestBasicPacketFilters", sTests, &TestSuiteSetup, &TestSuiteTeardown }; + nlTestRunner(&theSuite, nullptr); + return nlTestRunnerStats(&theSuite); +} + +CHIP_REGISTER_TEST_SUITE(TestBasicPacketFilters)