Skip to content

Commit

Permalink
Use separate udp endpoint (with separate ephemeral port) for minmdns …
Browse files Browse the repository at this point in the history
…unicast queries (#12161)

* Separate out listening and querying ports for UDP

* Correct port typo and make query port do UDP listen as well

* Update unicast queries for  the resolver to use a separate unicast function (bound to a random port)

* Update minmdns client example to send unicast queries using ephemeral port as well

* Do not assume having a query in a reply is a bad packet (because unicast replies will include the query)

* Added comments to broadcast delegate implementations

* Address some review comments

* fix typo

* Added more comments

* More review comments

* Rename things to filter/picker to make code slightly clearer
  • Loading branch information
andy31415 authored and pull[bot] committed Jan 8, 2024
1 parent 1807f38 commit 22b613a
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 74 deletions.
17 changes: 14 additions & 3 deletions examples/minimal-mdns/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,21 @@ void BroadcastPacket(mdns::Minimal::ServerBase * server)
return;
}

if (server->BroadcastSend(builder.ReleasePacket(), gOptions.querySendPort) != CHIP_NO_ERROR)
if (gOptions.unicastAnswers)
{
printf("Error sending\n");
return;
if (server->BroadcastUnicastQuery(builder.ReleasePacket(), gOptions.querySendPort) != CHIP_NO_ERROR)
{
printf("Error sending\n");
return;
}
}
else
{
if (server->BroadcastSend(builder.ReleasePacket(), gOptions.querySendPort) != CHIP_NO_ERROR)
{
printf("Error sending\n");
return;
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ bool AdvertiserMinMdns::ShouldAdvertiseOn(const chip::Inet::InterfaceId id, cons
{
const ServerBase::EndpointInfo & info = server.GetEndpoints()[i];

if (info.udp == nullptr)
if (info.listen_udp == nullptr)
{
continue;
}
Expand Down
9 changes: 5 additions & 4 deletions src/lib/dnssd/Resolver_ImplMinimalMdns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ class PacketDataReporter : public ParserDelegate

void PacketDataReporter::OnQuery(const QueryData & data)
{
ChipLogError(Discovery, "Unexpected query packet being parsed as a response");
mValid = false;
// Ignore queries:
// - unicast answers will include the corresponding query in the answer
// packet, however that is not interesting for the resolver.
}

void PacketDataReporter::OnHeader(ConstHeaderRef & header)
Expand Down Expand Up @@ -433,7 +434,7 @@ CHIP_ERROR MinMdnsResolver::SendQuery(mdns::Minimal::FullQName qname, mdns::Mini

ReturnErrorCodeIf(!builder.Ok(), CHIP_ERROR_INTERNAL);

return GlobalMinimalMdnsServer::Server().BroadcastSend(builder.ReleasePacket(), kMdnsPort);
return GlobalMinimalMdnsServer::Server().BroadcastUnicastQuery(builder.ReleasePacket(), kMdnsPort);
}

CHIP_ERROR MinMdnsResolver::FindCommissionableNodes(DiscoveryFilter filter)
Expand Down Expand Up @@ -577,7 +578,7 @@ CHIP_ERROR MinMdnsResolver::SendPendingResolveQueries()

ReturnErrorCodeIf(!builder.Ok(), CHIP_ERROR_INTERNAL);

ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastSend(builder.ReleasePacket(), kMdnsPort));
ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastUnicastQuery(builder.ReleasePacket(), kMdnsPort));
}

return ScheduleResolveRetries();
Expand Down
179 changes: 115 additions & 64 deletions src/lib/dnssd/minimal_mdns/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,65 @@ class ShutdownOnError
ServerBase * mServer;
};

/**
* Extracts the Listening UDP Endpoint from an underlying ServerBase::EndpointInfo
*/
class ListenSocketPickerDelegate : public ServerBase::BroadcastSendDelegate
{
public:
chip::Inet::UDPEndPoint * Accept(ServerBase::EndpointInfo * info) override { return info->listen_udp; }
};

/**
* Extracts the Querying UDP Endpoint from an underlying ServerBase::EndpointInfo
*/
class QuerySocketPickerDelegate : public ServerBase::BroadcastSendDelegate
{
public:
chip::Inet::UDPEndPoint * Accept(ServerBase::EndpointInfo * info) override { return info->unicast_query_udp; }
};

/**
* Validates that an endpoint belongs to a specific interface/ip address type before forwarding the
* endpoint accept logic to another BroadcastSendDelegate.
*
* Usage like:
*
* SomeDelegate *child = ....;
* InterfaceTypeFilterDelegate filter(interfaceId, IPAddressType::IPv6, child);
*
* UDPEndPoint *udp = filter.Accept(endpointInfo);
*/
class InterfaceTypeFilterDelegate : public ServerBase::BroadcastSendDelegate
{
public:
InterfaceTypeFilterDelegate(chip::Inet::InterfaceId interface, chip::Inet::IPAddressType type,
ServerBase::BroadcastSendDelegate * child) :
mInterface(interface),
mAddressType(type), mChild(child)
{}

chip::Inet::UDPEndPoint * Accept(ServerBase::EndpointInfo * info) override
{
if ((info->interfaceId != mInterface) && (info->interfaceId != chip::Inet::InterfaceId::Null()))
{
return nullptr;
}

if ((mAddressType != chip::Inet::IPAddressType::kAny) && (info->addressType != mAddressType))
{
return nullptr;
}

return mChild->Accept(info);
}

private:
chip::Inet::InterfaceId mInterface;
chip::Inet::IPAddressType mAddressType;
ServerBase::BroadcastSendDelegate * mChild = nullptr;
};

} // namespace

namespace BroadcastIpAddresses {
Expand Down Expand Up @@ -115,8 +174,17 @@ const char * AddressTypeStr(chip::Inet::IPAddressType addressType)

void ShutdownEndpoint(mdns::Minimal::ServerBase::EndpointInfo & aEndpoint)
{
aEndpoint.udp->Free();
aEndpoint.udp = nullptr;
if (aEndpoint.listen_udp != nullptr)
{
aEndpoint.listen_udp->Free();
aEndpoint.listen_udp = nullptr;
}

if (aEndpoint.unicast_query_udp != nullptr)
{
aEndpoint.unicast_query_udp->Free();
aEndpoint.unicast_query_udp = nullptr;
}
}

} // namespace
Expand All @@ -130,18 +198,15 @@ void ServerBase::Shutdown()
{
for (size_t i = 0; i < mEndpointCount; i++)
{
if (mEndpoints[i].udp != nullptr)
{
ShutdownEndpoint(mEndpoints[i]);
}
ShutdownEndpoint(mEndpoints[i]);
}
}

bool ServerBase::IsListening() const
{
for (size_t i = 0; i < mEndpointCount; i++)
{
if (mEndpoints[i].udp != nullptr)
if (mEndpoints[i].listen_udp != nullptr)
{
return true;
}
Expand All @@ -167,13 +232,13 @@ CHIP_ERROR ServerBase::Listen(chip::Inet::InetLayer * inetLayer, ListenIterator
info->addressType = addressType;
info->interfaceId = interfaceId;

ReturnErrorOnFailure(inetLayer->NewUDPEndPoint(&info->udp));
ReturnErrorOnFailure(inetLayer->NewUDPEndPoint(&info->listen_udp));

ReturnErrorOnFailure(info->udp->Bind(addressType, chip::Inet::IPAddress::Any, port, interfaceId));
ReturnErrorOnFailure(info->listen_udp->Bind(addressType, chip::Inet::IPAddress::Any, port, interfaceId));

ReturnErrorOnFailure(info->udp->Listen(OnUdpPacketReceived, nullptr /*OnReceiveError*/, this));
ReturnErrorOnFailure(info->listen_udp->Listen(OnUdpPacketReceived, nullptr /*OnReceiveError*/, this));

CHIP_ERROR err = JoinMulticastGroup(interfaceId, info->udp, addressType);
CHIP_ERROR err = JoinMulticastGroup(interfaceId, info->listen_udp, addressType);
if (err != CHIP_NO_ERROR)
{
char interfaceName[chip::Inet::InterfaceId::kMaxIfNameLength];
Expand All @@ -188,6 +253,14 @@ CHIP_ERROR ServerBase::Listen(chip::Inet::InetLayer * inetLayer, ListenIterator
{
endpointIndex++;
}

// Separate UDP endpoint for unicast queries, bound to 0 (i.e. pick random ephemeral port)
// - helps in not having conflicts on port 5353, will receive unicast replies directly
// - has a *DRAWBACK* of unicast queries being considered LEGACY by mdns since they do
// not originate from 5353 and the answers will include a query section.
ReturnErrorOnFailure(inetLayer->NewUDPEndPoint(&info->unicast_query_udp));
ReturnErrorOnFailure(info->unicast_query_udp->Bind(addressType, chip::Inet::IPAddress::Any, 0, interfaceId));
ReturnErrorOnFailure(info->unicast_query_udp->Listen(OnUdpPacketReceived, nullptr /*OnReceiveError*/, this));
}

return autoShutdown.ReturnSuccess();
Expand All @@ -199,7 +272,7 @@ CHIP_ERROR ServerBase::DirectSend(chip::System::PacketBufferHandle && data, cons
for (size_t i = 0; i < mEndpointCount; i++)
{
EndpointInfo * info = &mEndpoints[i];
if (info->udp == nullptr)
if (info->listen_udp == nullptr)
{
continue;
}
Expand All @@ -209,73 +282,50 @@ CHIP_ERROR ServerBase::DirectSend(chip::System::PacketBufferHandle && data, cons
continue;
}

chip::Inet::InterfaceId boundIf = info->udp->GetBoundInterface();
chip::Inet::InterfaceId boundIf = info->listen_udp->GetBoundInterface();

if ((boundIf.IsPresent()) && (boundIf != interface))
{
continue;
}

return info->udp->SendTo(addr, port, std::move(data));
return info->listen_udp->SendTo(addr, port, std::move(data));
}

return CHIP_ERROR_NOT_CONNECTED;
}

CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, uint16_t port, chip::Inet::InterfaceId interface,
chip::Inet::IPAddressType addressType)
CHIP_ERROR ServerBase::BroadcastUnicastQuery(chip::System::PacketBufferHandle && data, uint16_t port)
{
for (size_t i = 0; i < mEndpointCount; i++)
{
EndpointInfo * info = &mEndpoints[i];

if (info->udp == nullptr)
{
continue;
}

if ((info->interfaceId != interface) && (info->interfaceId != chip::Inet::InterfaceId::Null()))
{
continue;
}

if ((addressType != chip::Inet::IPAddressType::kAny) && (info->addressType != addressType))
{
continue;
}
QuerySocketPickerDelegate socketPicker;
return BroadcastImpl(std::move(data), port, &socketPicker);
}

CHIP_ERROR err;
CHIP_ERROR ServerBase::BroadcastUnicastQuery(chip::System::PacketBufferHandle && data, uint16_t port,
chip::Inet::InterfaceId interface, chip::Inet::IPAddressType addressType)
{
QuerySocketPickerDelegate socketPicker;
InterfaceTypeFilterDelegate filter(interface, addressType, &socketPicker);

/// The same packet needs to be sent over potentially multiple interfaces.
/// LWIP does not like having a pbuf sent over serparate interfaces, hence we create a copy
/// for sending via `CloneData`
///
/// TODO: this wastes one copy of the data and that could be optimized away
if (info->addressType == chip::Inet::IPAddressType::kIPv6)
{
err = info->udp->SendTo(mIpv6BroadcastAddress, port, data.CloneData(), info->udp->GetBoundInterface());
}
#if INET_CONFIG_ENABLE_IPV4
else if (info->addressType == chip::Inet::IPAddressType::kIPv4)
{
err = info->udp->SendTo(mIpv4BroadcastAddress, port, data.CloneData(), info->udp->GetBoundInterface());
}
#endif
else
{
return CHIP_ERROR_INCORRECT_STATE;
}
return BroadcastImpl(std::move(data), port, &filter);
}

if (err != CHIP_NO_ERROR)
{
return err;
}
}
CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, uint16_t port, chip::Inet::InterfaceId interface,
chip::Inet::IPAddressType addressType)
{
ListenSocketPickerDelegate socketPicker;
InterfaceTypeFilterDelegate filter(interface, addressType, &socketPicker);

return CHIP_NO_ERROR;
return BroadcastImpl(std::move(data), port, &filter);
}

CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, uint16_t port)
{
ListenSocketPickerDelegate socketPicker;
return BroadcastImpl(std::move(data), port, &socketPicker);
}

CHIP_ERROR ServerBase::BroadcastImpl(chip::System::PacketBufferHandle && data, uint16_t port, BroadcastSendDelegate * delegate)
{
// Broadcast requires sending data multiple times, each of which may error
// out, yet broadcast only has a single error code.
Expand All @@ -290,9 +340,10 @@ CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, u

for (size_t i = 0; i < mEndpointCount; i++)
{
EndpointInfo * info = &mEndpoints[i];
EndpointInfo * info = &mEndpoints[i];
chip::Inet::UDPEndPoint * udp = delegate->Accept(info);

if (info->udp == nullptr)
if (udp == nullptr)
{
continue;
}
Expand All @@ -306,12 +357,12 @@ CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, u
/// TODO: this wastes one copy of the data and that could be optimized away
if (info->addressType == chip::Inet::IPAddressType::kIPv6)
{
err = info->udp->SendTo(mIpv6BroadcastAddress, port, data.CloneData(), info->udp->GetBoundInterface());
err = udp->SendTo(mIpv6BroadcastAddress, port, data.CloneData(), udp->GetBoundInterface());
}
#if INET_CONFIG_ENABLE_IPV4
else if (info->addressType == chip::Inet::IPAddressType::kIPv4)
{
err = info->udp->SendTo(mIpv4BroadcastAddress, port, data.CloneData(), info->udp->GetBoundInterface());
err = udp->SendTo(mIpv4BroadcastAddress, port, data.CloneData(), udp->GetBoundInterface());
}
#endif
else
Expand Down
Loading

0 comments on commit 22b613a

Please sign in to comment.