From 27776014caed7746189bb087c5e1a90907bf9793 Mon Sep 17 00:00:00 2001 From: Andrei Litvin Date: Wed, 24 Nov 2021 13:20:41 -0500 Subject: [PATCH] Use separate udp endpoint (with separate ephemeral port) for minmdns unicast queries (#12161) * Separate out listening and querying ports for UDP * Correct port typo and make query port do UDP listen as well * Update unicast queries for the resolver to use a separate unicast function (bound to a random port) * Update minmdns client example to send unicast queries using ephemeral port as well * Do not assume having a query in a reply is a bad packet (because unicast replies will include the query) * Added comments to broadcast delegate implementations * Address some review comments * fix typo * Added more comments * More review comments * Rename things to filter/picker to make code slightly clearer --- examples/minimal-mdns/client.cpp | 17 +- src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp | 2 +- src/lib/dnssd/Resolver_ImplMinimalMdns.cpp | 9 +- src/lib/dnssd/minimal_mdns/Server.cpp | 179 ++++++++++++------- src/lib/dnssd/minimal_mdns/Server.h | 34 +++- 5 files changed, 167 insertions(+), 74 deletions(-) diff --git a/examples/minimal-mdns/client.cpp b/examples/minimal-mdns/client.cpp index cacb5fcf438420..dcfef3f987e40d 100644 --- a/examples/minimal-mdns/client.cpp +++ b/examples/minimal-mdns/client.cpp @@ -283,10 +283,21 @@ void BroadcastPacket(mdns::Minimal::ServerBase * server) return; } - if (server->BroadcastSend(builder.ReleasePacket(), gOptions.querySendPort) != CHIP_NO_ERROR) + if (gOptions.unicastAnswers) { - printf("Error sending\n"); - return; + if (server->BroadcastUnicastQuery(builder.ReleasePacket(), gOptions.querySendPort) != CHIP_NO_ERROR) + { + printf("Error sending\n"); + return; + } + } + else + { + if (server->BroadcastSend(builder.ReleasePacket(), gOptions.querySendPort) != CHIP_NO_ERROR) + { + printf("Error sending\n"); + return; + } } } diff --git a/src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp b/src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp index 9dfa7ed332915f..4925bee60e33c1 100644 --- a/src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp +++ b/src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp @@ -733,7 +733,7 @@ bool AdvertiserMinMdns::ShouldAdvertiseOn(const chip::Inet::InterfaceId id, cons { const ServerBase::EndpointInfo & info = server.GetEndpoints()[i]; - if (info.udp == nullptr) + if (info.listen_udp == nullptr) { continue; } diff --git a/src/lib/dnssd/Resolver_ImplMinimalMdns.cpp b/src/lib/dnssd/Resolver_ImplMinimalMdns.cpp index 7b4a34a409fc26..739894d3a55e5f 100644 --- a/src/lib/dnssd/Resolver_ImplMinimalMdns.cpp +++ b/src/lib/dnssd/Resolver_ImplMinimalMdns.cpp @@ -108,8 +108,9 @@ class PacketDataReporter : public ParserDelegate void PacketDataReporter::OnQuery(const QueryData & data) { - ChipLogError(Discovery, "Unexpected query packet being parsed as a response"); - mValid = false; + // Ignore queries: + // - unicast answers will include the corresponding query in the answer + // packet, however that is not interesting for the resolver. } void PacketDataReporter::OnHeader(ConstHeaderRef & header) @@ -433,7 +434,7 @@ CHIP_ERROR MinMdnsResolver::SendQuery(mdns::Minimal::FullQName qname, mdns::Mini ReturnErrorCodeIf(!builder.Ok(), CHIP_ERROR_INTERNAL); - return GlobalMinimalMdnsServer::Server().BroadcastSend(builder.ReleasePacket(), kMdnsPort); + return GlobalMinimalMdnsServer::Server().BroadcastUnicastQuery(builder.ReleasePacket(), kMdnsPort); } CHIP_ERROR MinMdnsResolver::FindCommissionableNodes(DiscoveryFilter filter) @@ -577,7 +578,7 @@ CHIP_ERROR MinMdnsResolver::SendPendingResolveQueries() ReturnErrorCodeIf(!builder.Ok(), CHIP_ERROR_INTERNAL); - ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastSend(builder.ReleasePacket(), kMdnsPort)); + ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastUnicastQuery(builder.ReleasePacket(), kMdnsPort)); } return ScheduleResolveRetries(); diff --git a/src/lib/dnssd/minimal_mdns/Server.cpp b/src/lib/dnssd/minimal_mdns/Server.cpp index 1942a6f250be9c..bb06078fbc24a0 100644 --- a/src/lib/dnssd/minimal_mdns/Server.cpp +++ b/src/lib/dnssd/minimal_mdns/Server.cpp @@ -48,6 +48,65 @@ class ShutdownOnError ServerBase * mServer; }; +/** + * Extracts the Listening UDP Endpoint from an underlying ServerBase::EndpointInfo + */ +class ListenSocketPickerDelegate : public ServerBase::BroadcastSendDelegate +{ +public: + chip::Inet::UDPEndPoint * Accept(ServerBase::EndpointInfo * info) override { return info->listen_udp; } +}; + +/** + * Extracts the Querying UDP Endpoint from an underlying ServerBase::EndpointInfo + */ +class QuerySocketPickerDelegate : public ServerBase::BroadcastSendDelegate +{ +public: + chip::Inet::UDPEndPoint * Accept(ServerBase::EndpointInfo * info) override { return info->unicast_query_udp; } +}; + +/** + * Validates that an endpoint belongs to a specific interface/ip address type before forwarding the + * endpoint accept logic to another BroadcastSendDelegate. + * + * Usage like: + * + * SomeDelegate *child = ....; + * InterfaceTypeFilterDelegate filter(interfaceId, IPAddressType::IPv6, child); + * + * UDPEndPoint *udp = filter.Accept(endpointInfo); + */ +class InterfaceTypeFilterDelegate : public ServerBase::BroadcastSendDelegate +{ +public: + InterfaceTypeFilterDelegate(chip::Inet::InterfaceId interface, chip::Inet::IPAddressType type, + ServerBase::BroadcastSendDelegate * child) : + mInterface(interface), + mAddressType(type), mChild(child) + {} + + chip::Inet::UDPEndPoint * Accept(ServerBase::EndpointInfo * info) override + { + if ((info->interfaceId != mInterface) && (info->interfaceId != chip::Inet::InterfaceId::Null())) + { + return nullptr; + } + + if ((mAddressType != chip::Inet::IPAddressType::kAny) && (info->addressType != mAddressType)) + { + return nullptr; + } + + return mChild->Accept(info); + } + +private: + chip::Inet::InterfaceId mInterface; + chip::Inet::IPAddressType mAddressType; + ServerBase::BroadcastSendDelegate * mChild = nullptr; +}; + } // namespace namespace BroadcastIpAddresses { @@ -115,8 +174,17 @@ const char * AddressTypeStr(chip::Inet::IPAddressType addressType) void ShutdownEndpoint(mdns::Minimal::ServerBase::EndpointInfo & aEndpoint) { - aEndpoint.udp->Free(); - aEndpoint.udp = nullptr; + if (aEndpoint.listen_udp != nullptr) + { + aEndpoint.listen_udp->Free(); + aEndpoint.listen_udp = nullptr; + } + + if (aEndpoint.unicast_query_udp != nullptr) + { + aEndpoint.unicast_query_udp->Free(); + aEndpoint.unicast_query_udp = nullptr; + } } } // namespace @@ -130,10 +198,7 @@ void ServerBase::Shutdown() { for (size_t i = 0; i < mEndpointCount; i++) { - if (mEndpoints[i].udp != nullptr) - { - ShutdownEndpoint(mEndpoints[i]); - } + ShutdownEndpoint(mEndpoints[i]); } } @@ -141,7 +206,7 @@ bool ServerBase::IsListening() const { for (size_t i = 0; i < mEndpointCount; i++) { - if (mEndpoints[i].udp != nullptr) + if (mEndpoints[i].listen_udp != nullptr) { return true; } @@ -167,13 +232,13 @@ CHIP_ERROR ServerBase::Listen(chip::Inet::InetLayer * inetLayer, ListenIterator info->addressType = addressType; info->interfaceId = interfaceId; - ReturnErrorOnFailure(inetLayer->NewUDPEndPoint(&info->udp)); + ReturnErrorOnFailure(inetLayer->NewUDPEndPoint(&info->listen_udp)); - ReturnErrorOnFailure(info->udp->Bind(addressType, chip::Inet::IPAddress::Any, port, interfaceId)); + ReturnErrorOnFailure(info->listen_udp->Bind(addressType, chip::Inet::IPAddress::Any, port, interfaceId)); - ReturnErrorOnFailure(info->udp->Listen(OnUdpPacketReceived, nullptr /*OnReceiveError*/, this)); + ReturnErrorOnFailure(info->listen_udp->Listen(OnUdpPacketReceived, nullptr /*OnReceiveError*/, this)); - CHIP_ERROR err = JoinMulticastGroup(interfaceId, info->udp, addressType); + CHIP_ERROR err = JoinMulticastGroup(interfaceId, info->listen_udp, addressType); if (err != CHIP_NO_ERROR) { char interfaceName[chip::Inet::InterfaceId::kMaxIfNameLength]; @@ -188,6 +253,14 @@ CHIP_ERROR ServerBase::Listen(chip::Inet::InetLayer * inetLayer, ListenIterator { endpointIndex++; } + + // Separate UDP endpoint for unicast queries, bound to 0 (i.e. pick random ephemeral port) + // - helps in not having conflicts on port 5353, will receive unicast replies directly + // - has a *DRAWBACK* of unicast queries being considered LEGACY by mdns since they do + // not originate from 5353 and the answers will include a query section. + ReturnErrorOnFailure(inetLayer->NewUDPEndPoint(&info->unicast_query_udp)); + ReturnErrorOnFailure(info->unicast_query_udp->Bind(addressType, chip::Inet::IPAddress::Any, 0, interfaceId)); + ReturnErrorOnFailure(info->unicast_query_udp->Listen(OnUdpPacketReceived, nullptr /*OnReceiveError*/, this)); } return autoShutdown.ReturnSuccess(); @@ -199,7 +272,7 @@ CHIP_ERROR ServerBase::DirectSend(chip::System::PacketBufferHandle && data, cons for (size_t i = 0; i < mEndpointCount; i++) { EndpointInfo * info = &mEndpoints[i]; - if (info->udp == nullptr) + if (info->listen_udp == nullptr) { continue; } @@ -209,73 +282,50 @@ CHIP_ERROR ServerBase::DirectSend(chip::System::PacketBufferHandle && data, cons continue; } - chip::Inet::InterfaceId boundIf = info->udp->GetBoundInterface(); + chip::Inet::InterfaceId boundIf = info->listen_udp->GetBoundInterface(); if ((boundIf.IsPresent()) && (boundIf != interface)) { continue; } - return info->udp->SendTo(addr, port, std::move(data)); + return info->listen_udp->SendTo(addr, port, std::move(data)); } return CHIP_ERROR_NOT_CONNECTED; } -CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, uint16_t port, chip::Inet::InterfaceId interface, - chip::Inet::IPAddressType addressType) +CHIP_ERROR ServerBase::BroadcastUnicastQuery(chip::System::PacketBufferHandle && data, uint16_t port) { - for (size_t i = 0; i < mEndpointCount; i++) - { - EndpointInfo * info = &mEndpoints[i]; - - if (info->udp == nullptr) - { - continue; - } - - if ((info->interfaceId != interface) && (info->interfaceId != chip::Inet::InterfaceId::Null())) - { - continue; - } - - if ((addressType != chip::Inet::IPAddressType::kAny) && (info->addressType != addressType)) - { - continue; - } + QuerySocketPickerDelegate socketPicker; + return BroadcastImpl(std::move(data), port, &socketPicker); +} - CHIP_ERROR err; +CHIP_ERROR ServerBase::BroadcastUnicastQuery(chip::System::PacketBufferHandle && data, uint16_t port, + chip::Inet::InterfaceId interface, chip::Inet::IPAddressType addressType) +{ + QuerySocketPickerDelegate socketPicker; + InterfaceTypeFilterDelegate filter(interface, addressType, &socketPicker); - /// The same packet needs to be sent over potentially multiple interfaces. - /// LWIP does not like having a pbuf sent over serparate interfaces, hence we create a copy - /// for sending via `CloneData` - /// - /// TODO: this wastes one copy of the data and that could be optimized away - if (info->addressType == chip::Inet::IPAddressType::kIPv6) - { - err = info->udp->SendTo(mIpv6BroadcastAddress, port, data.CloneData(), info->udp->GetBoundInterface()); - } -#if INET_CONFIG_ENABLE_IPV4 - else if (info->addressType == chip::Inet::IPAddressType::kIPv4) - { - err = info->udp->SendTo(mIpv4BroadcastAddress, port, data.CloneData(), info->udp->GetBoundInterface()); - } -#endif - else - { - return CHIP_ERROR_INCORRECT_STATE; - } + return BroadcastImpl(std::move(data), port, &filter); +} - if (err != CHIP_NO_ERROR) - { - return err; - } - } +CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, uint16_t port, chip::Inet::InterfaceId interface, + chip::Inet::IPAddressType addressType) +{ + ListenSocketPickerDelegate socketPicker; + InterfaceTypeFilterDelegate filter(interface, addressType, &socketPicker); - return CHIP_NO_ERROR; + return BroadcastImpl(std::move(data), port, &filter); } CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, uint16_t port) +{ + ListenSocketPickerDelegate socketPicker; + return BroadcastImpl(std::move(data), port, &socketPicker); +} + +CHIP_ERROR ServerBase::BroadcastImpl(chip::System::PacketBufferHandle && data, uint16_t port, BroadcastSendDelegate * delegate) { // Broadcast requires sending data multiple times, each of which may error // out, yet broadcast only has a single error code. @@ -290,9 +340,10 @@ CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, u for (size_t i = 0; i < mEndpointCount; i++) { - EndpointInfo * info = &mEndpoints[i]; + EndpointInfo * info = &mEndpoints[i]; + chip::Inet::UDPEndPoint * udp = delegate->Accept(info); - if (info->udp == nullptr) + if (udp == nullptr) { continue; } @@ -306,12 +357,12 @@ CHIP_ERROR ServerBase::BroadcastSend(chip::System::PacketBufferHandle && data, u /// TODO: this wastes one copy of the data and that could be optimized away if (info->addressType == chip::Inet::IPAddressType::kIPv6) { - err = info->udp->SendTo(mIpv6BroadcastAddress, port, data.CloneData(), info->udp->GetBoundInterface()); + err = udp->SendTo(mIpv6BroadcastAddress, port, data.CloneData(), udp->GetBoundInterface()); } #if INET_CONFIG_ENABLE_IPV4 else if (info->addressType == chip::Inet::IPAddressType::kIPv4) { - err = info->udp->SendTo(mIpv4BroadcastAddress, port, data.CloneData(), info->udp->GetBoundInterface()); + err = udp->SendTo(mIpv4BroadcastAddress, port, data.CloneData(), udp->GetBoundInterface()); } #endif else diff --git a/src/lib/dnssd/minimal_mdns/Server.h b/src/lib/dnssd/minimal_mdns/Server.h index c1d07c5ba6f6aa..7c83a61a47c4f5 100644 --- a/src/lib/dnssd/minimal_mdns/Server.h +++ b/src/lib/dnssd/minimal_mdns/Server.h @@ -79,14 +79,32 @@ class ServerBase { chip::Inet::InterfaceId interfaceId = chip::Inet::InterfaceId::Null(); chip::Inet::IPAddressType addressType; - chip::Inet::UDPEndPoint * udp = nullptr; + chip::Inet::UDPEndPoint * listen_udp = nullptr; + chip::Inet::UDPEndPoint * unicast_query_udp = nullptr; + }; + + /** + * Helps implement a generic broadcast implementation: + * - provides the ability to determine what udp endpoint to use to broadcast + * a packet for the given endpoint info + */ + class BroadcastSendDelegate + { + public: + virtual ~BroadcastSendDelegate() = default; + + /** + * Returns non-null UDPEndpoint IFF a broadcast should be performed for the given EndpointInfo + */ + virtual chip::Inet::UDPEndPoint * Accept(ServerBase::EndpointInfo * info) = 0; }; ServerBase(EndpointInfo * endpointStorage, size_t kStorageSize) : mEndpoints(endpointStorage), mEndpointCount(kStorageSize) { for (size_t i = 0; i < mEndpointCount; i++) { - mEndpoints[i].udp = nullptr; + mEndpoints[i].listen_udp = nullptr; + mEndpoints[i].unicast_query_udp = nullptr; } BroadcastIpAddresses::GetIpv6Into(mIpv6BroadcastAddress); @@ -110,6 +128,16 @@ class ServerBase virtual CHIP_ERROR DirectSend(chip::System::PacketBufferHandle && data, const chip::Inet::IPAddress & addr, uint16_t port, chip::Inet::InterfaceId interface); + /// Send out a broadcast query, may use an ephemeral port to receive replies. + /// Ephemeral ports will make replies be marked as 'LEGACY' and replies will include a query secion. + virtual CHIP_ERROR BroadcastUnicastQuery(chip::System::PacketBufferHandle && data, uint16_t port); + + /// Send a specific packet broadcast to a specific interface using a specific address type + /// May use an ephemeral port to receive replies. + /// Ephemeral ports will make replies be marked as 'LEGACY' and replies will include a query secion. + virtual CHIP_ERROR BroadcastUnicastQuery(chip::System::PacketBufferHandle && data, uint16_t port, + chip::Inet::InterfaceId interface, chip::Inet::IPAddressType addressType); + /// Send a specific packet broadcast to all interfaces virtual CHIP_ERROR BroadcastSend(chip::System::PacketBufferHandle && data, uint16_t port); @@ -139,6 +167,8 @@ class ServerBase bool IsListening() const; private: + CHIP_ERROR BroadcastImpl(chip::System::PacketBufferHandle && data, uint16_t port, BroadcastSendDelegate * delegate); + static void OnUdpPacketReceived(chip::Inet::UDPEndPoint * endPoint, chip::System::PacketBufferHandle && buffer, const chip::Inet::IPPacketInfo * info);