Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce basic UDP packet filtering scheme #23957

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/inet/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ static_library("inet") {

if (chip_inet_config_enable_udp_endpoint) {
sources += [
"BasicPacketFilters.h",
"EndpointQueueFilter.h",
"UDPEndPoint.cpp",
"UDPEndPoint.h",
"UDPEndPointImpl${chip_system_config_inet}.cpp",
Expand Down
176 changes: 176 additions & 0 deletions src/inet/BasicPacketFilters.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright (c) 2022 Project CHIP Authors
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <atomic>
#include <inet/EndpointQueueFilter.h>
#include <inet/IPPacketInfo.h>
#include <lib/support/CodeUtils.h>
#include <system/SystemPacketBuffer.h>

namespace chip {
namespace Inet {

/**
* @brief Basic filter that counts how many pending (not yet dequeued) packets
* are accumulated that match a predicate function, and drops those that
* would cause crossing of the threshold.
*/
class DropIfTooManyQueuedPacketsFilter : public chip::Inet::EndpointQueueFilter
{
public:
typedef bool (*PacketMatchPredicateFunc)(void * context, const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo,
const chip::System::PacketBufferHandle & pktPayload);

/**
* @brief Initialize the packet filter with a starting limit
*
* @param maxAllowedQueuedPackets - max number of pending-in-queue DNS packets not yet processed that matched predicate
*/
DropIfTooManyQueuedPacketsFilter(size_t maxAllowedQueuedPackets) : mMaxAllowedQueuedPackets(maxAllowedQueuedPackets) {}

/**
* @brief Set the predicate to use for filtering
*
* @warning DO NOT modify at runtime while the filter is being called.
*
* @param predicateFunc - Predicate function to apply. If nullptr, no filtering will take place
* @param context - Pointer to predicate-specific context that will be provided to predicate at every call. May be nullptr.
*/
void SetPredicate(PacketMatchPredicateFunc predicateFunc, void * context)
{
mPredicate = predicateFunc;
mContext = context;
}

/**
* @brief Set the ceiling for max allowed packets queued up that matched the predicate.
*
* @param maxAllowedQueuedPackets - number of packets currently pending allowed.
*/
void SetMaxQueuedPacketsLimit(int maxAllowedQueuedPackets) { mMaxAllowedQueuedPackets = maxAllowedQueuedPackets; }

/**
* @return the total number of packets dropped so far by the filter
*/
size_t GetNumDroppedPackets() const { return mNumDroppedPackets.load(); }

/**
* @brief Reset the counter of dropped packets.
*/
void ClearNumDroppedPackets() { mNumDroppedPackets.exchange(0); }

/**
* @brief Template method called when a packet is dropped due to high watermark getting reached, based on predicate.
*
* This is called once for every dropped packet. If there is no filter predicate, this is not called.
*
* @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void)
* @param pktInfo - info about source/dest of packet
* @param pktPayload - payload content of packet
*/
virtual void OnDropped(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo,
const chip::System::PacketBufferHandle & pktPayload)
{}

/**
* @brief Template method called whenever queue of accumulated packets is now empty, based on predicate.
*
* This is possibly called repeatedly in a row, if the queue actually never gets above one.
*
* This is only called for packets that had matched the filtering rule, where they had
* been explicitly allowed in the past. If there is no filter predicate, this is not called.
*
* @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void)
* @param pktInfo - info about source/dest of packet
* @param pktPayload - payload content of packet
*/
virtual void OnQueueEmpty(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo,
const chip::System::PacketBufferHandle & pktPayload)
{}

/**
* @brief Implementation of filtering before queueing that applies the predicate.
*
* See base class for arguments
*/
FilterOutcome FilterBeforeEnqueue(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo,
const chip::System::PacketBufferHandle & pktPayload) override
{
// WARNING: This is likely called in a different context than `FilterAfterDequeue`. We use an atomic for the counter.

// Non-matching is never accounted, always allowed. Lack of predicate is equivalent to non-matching.
if ((mPredicate == nullptr) || !mPredicate(mContext, endpoint, pktInfo, pktPayload))
{
return FilterOutcome::kAllowPacket;
}

if (mNumQueuedPackets.load() >= mMaxAllowedQueuedPackets)
{
++mNumDroppedPackets;
OnDropped(endpoint, pktInfo, pktPayload);
return FilterOutcome::kDropPacket;
}

++mNumQueuedPackets;

return FilterOutcome::kAllowPacket;
}

/**
* @brief Implementation of filtering after dequeueing that applies the predicate.
*
* See base class for arguments
*/
FilterOutcome FilterAfterDequeue(const void * endpoint, const chip::Inet::IPPacketInfo & pktInfo,
const chip::System::PacketBufferHandle & pktPayload) override
{
// WARNING: This is likely called in a different context than `FilterBeforeEnqueue`. We use an atomic for the counter.
// NOTE: This is always called from Matter platform event loop

// Non-matching is never accounted, always allowed. Lack of predicate is equivalent to non-matching.
if ((mPredicate == nullptr) || !mPredicate(mContext, endpoint, pktInfo, pktPayload))
{
return FilterOutcome::kAllowPacket;
}

--mNumQueuedPackets;
int numQueuedPackets = mNumQueuedPackets.load();
if (numQueuedPackets == 0)
{
OnQueueEmpty(endpoint, pktInfo, pktPayload);
}

// If we ever go negative, we have mismatch ingress/egress filter via predicate and
// device may eventually starve.
VerifyOrDie(numQueuedPackets >= 0);

// We always allow the packet and just do accounting, since all dropping is prior to queue entry.
return FilterOutcome::kAllowPacket;
}

protected:
PacketMatchPredicateFunc mPredicate = nullptr;
void * mContext = nullptr;
std::atomic_int mNumQueuedPackets{ 0 };
std::atomic_int mMaxAllowedQueuedPackets{ 0 };
std::atomic_size_t mNumDroppedPackets{ 0u };
};

} // namespace Inet
} // namespace chip
88 changes: 88 additions & 0 deletions src/inet/EndpointQueueFilter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2022 Project CHIP Authors
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <inet/IPPacketInfo.h>
#include <system/SystemPacketBuffer.h>

namespace chip {
namespace Inet {

/**
* @brief Filter for UDP Packets going into and out of UDPEndPoint queue.
*
* NOTE: This is only used by some low-level implementations of UDPEndPoint
*/
class EndpointQueueFilter
{
public:
enum FilterOutcome : int
{
kAllowPacket = 0,
kDropPacket = 1,
};

virtual ~EndpointQueueFilter() {}

/**
* @brief Run filter prior to inserting in queue.
*
* If filter returns `kAllowPacket`, packet will be enqueued, and `FilterAfterDequeue` will
* be called when it gets dequeued. If filter returns `kDropPacket`, packet will be dropped
* rather than enqueued and `FilterAfterDequeue` method will not be called.
*
* WARNING: This is likely called from non-Matter-eventloop context, from network layer code.
* Be extremely careful about accessing any system data which may belong to Matter
* stack from this method.
*
* @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void)
* @param pktInfo - info about source/dest of packet
* @param pktPayload - payload content of packet
*
* @return kAllowPacket to allow packet to enqueue or kDropPacket to drop the packet
*/
virtual FilterOutcome FilterBeforeEnqueue(const void * endpoint, const IPPacketInfo & pktInfo,
const chip::System::PacketBufferHandle & pktPayload)
{
return FilterOutcome::kAllowPacket;
}

/**
* @brief Run filter after dequeuing, prior to processing.
*
* If filter returns `kAllowPacket`, packet will be processed after dequeuing. If filter returns
* `kDropPacket`, packet will be dropped and not processed, even though it was dequeued.
*
* WARNING: This is called from Matter thread context. Be extremely careful about accessing any
* data which may belong to different threads from this method.
*
* @param endpoint - pointer to endpoint instance (platform-dependent, which is why it's void)
* @param pktInfo - info about source/dest of packet
* @param pktPayload - payload content of packet
*
* @return kAllowPacket to allow packet to be processed or kDropPacket to drop the packet
*/
virtual FilterOutcome FilterAfterDequeue(const void * endpoint, const IPPacketInfo & pktInfo,
const chip::System::PacketBufferHandle & pktPayload)
{
return FilterOutcome::kAllowPacket;
}
};

} // namespace Inet
} // namespace chip
37 changes: 35 additions & 2 deletions src/inet/UDPEndPointImplLwIP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ struct Deleter<struct pbuf>
namespace chip {
namespace Inet {

EndpointQueueFilter * UDPEndPointImplLwIP::sQueueFilter = nullptr;

CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddress & address, uint16_t port,
InterfaceId interfaceId)
{
Expand Down Expand Up @@ -287,7 +289,16 @@ void UDPEndPointImplLwIP::Free()

void UDPEndPointImplLwIP::HandleDataReceived(System::PacketBufferHandle && msg, IPPacketInfo * pktInfo)
{
if ((mState == State::kListening) && (OnMessageReceived != nullptr))
// Process packet filter if needed. May cause packet to get dropped before processing.
bool dropPacket = false;
if ((pktInfo != nullptr) && (sQueueFilter != nullptr))
{
auto outcome = sQueueFilter->FilterAfterDequeue(this, *pktInfo, msg);
dropPacket = (outcome == EndpointQueueFilter::FilterOutcome::kDropPacket);
}

// Process actual packet if allowed
if ((mState == State::kListening) && (OnMessageReceived != nullptr) && !dropPacket)
{
if (pktInfo != nullptr)
{
Expand Down Expand Up @@ -424,14 +435,28 @@ void UDPEndPointImplLwIP::LwIPReceiveUDPMessage(void * arg, struct udp_pcb * pcb
pktInfo->SrcPort = port;
pktInfo->DestPort = pcb->local_port;

auto filterOutcome = EndpointQueueFilter::FilterOutcome::kAllowPacket;
if (sQueueFilter != nullptr)
{
filterOutcome = sQueueFilter->FilterBeforeEnqueue(ep, *(pktInfo.get()), buf);
}

if (filterOutcome != EndpointQueueFilter::FilterOutcome::kAllowPacket)
{
// Logging, if any, should be at the choice of the filter impl at time of filtering.
return;
}

// Increase mDelayReleaseCount to delay release of this UDP EndPoint while the HandleDataReceived call is
// pending on it.
ep->mDelayReleaseCount++;

CHIP_ERROR err = ep->GetSystemLayer().ScheduleLambda(
[ep, p = System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(buf), pktInfo = pktInfo.get()] {
ep->mDelayReleaseCount--;
ep->HandleDataReceived(System::PacketBufferHandle::Adopt(p), pktInfo);

auto handle = System::PacketBufferHandle::Adopt(p);
ep->HandleDataReceived(std::move(handle), pktInfo);
});

if (err == CHIP_NO_ERROR)
Expand All @@ -443,6 +468,14 @@ void UDPEndPointImplLwIP::LwIPReceiveUDPMessage(void * arg, struct udp_pcb * pcb
}
else
{
// On failure to enqueue the processing, we have to tell the filter that
// the packet is basically dequeued, if it tries to keep track of the lifecycle.
if (sQueueFilter != nullptr)
{
(void) sQueueFilter->FilterAfterDequeue(ep, *(pktInfo.get()), buf);
ChipLogError(Inet, "Dequeue ERROR err = %" CHIP_ERROR_FORMAT, err.Format());
}

ep->mDelayReleaseCount--;
}
}
Expand Down
17 changes: 17 additions & 0 deletions src/inet/UDPEndPointImplLwIP.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#pragma once

#include <inet/EndPointStateLwIP.h>
#include <inet/EndpointQueueFilter.h>
#include <inet/UDPEndPoint.h>

namespace chip {
Expand All @@ -40,6 +41,20 @@ class UDPEndPointImplLwIP : public UDPEndPoint, public EndPointStateLwIP
uint16_t GetBoundPort() const override;
void Free() override;

/**
* @brief Set the queue filter for all UDP endpoints
*
* Responsibility is on the caller to avoid changing the filter while packets are being
* processed. Setting the queue filter to `nullptr` disables the filtering.
*
* NOTE: There is only one EndpointQueueFilter instance settable. However it's possible
* to create an instance of EndpointQueueFilter that combines several other
* EndpointQueueFilter by composition to achieve the effect of multiple filters.
*
* @param queueFilter - queue filter instance to set, owned by caller
*/
static void SetQueueFilter(EndpointQueueFilter * queueFilter) { sQueueFilter = queueFilter; }

private:
// UDPEndPoint overrides.
#if INET_CONFIG_ENABLE_IPV4
Expand All @@ -62,6 +77,8 @@ class UDPEndPointImplLwIP : public UDPEndPoint, public EndPointStateLwIP

udp_pcb * mUDP; // LwIP User datagram protocol (UDP) control block.
std::atomic_int mDelayReleaseCount{ 0 };

static EndpointQueueFilter * sQueueFilter;
};

using UDPEndPointImpl = UDPEndPointImplLwIP;
Expand Down
1 change: 1 addition & 0 deletions src/inet/tests/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ chip_test_suite("tests") {
]

test_sources = [
"TestBasicPacketFilters.cpp",
"TestInetAddress.cpp",
"TestInetErrorStr.cpp",
]
Expand Down
Loading