From 5bb4f70e7f6336ca06045b3461e62373ee3c8457 Mon Sep 17 00:00:00 2001 From: C Freeman Date: Wed, 30 Jun 2021 14:32:05 -0400 Subject: [PATCH] Mdns: separate query responders for advertisers (#7615) * Mdns: separate query responders for advertisers The current setup uses a single query responder for commissionable and operational and all the records are cleared together and added in a large group. This makes it challenging to run commissionable and operational at the same time since they can't be cleared individually and because there are duplicate records between operational and commissionable. This commit changes the code to use one query responder per advertisement type and gets the ResponseSender to loop through the query responders to filter for answers. * Apply suggestions from code review Co-authored-by: chrisdecenzo <61757564+chrisdecenzo@users.noreply.github.com> * Restyled by clang-format * Fix tests - they were not testing right. * Restyled by clang-format * Update number of responders Represents 5 multi-admin fabrics for operational, commissionable and commissioner. * Whoops - broke a test with my last "minor" change. * Lighten the stack load for the failing test. Co-authored-by: chrisdecenzo <61757564+chrisdecenzo@users.noreply.github.com> Co-authored-by: Restyled.io --- src/lib/mdns/Advertiser_ImplMinimalMdns.cpp | 123 ++++++---- src/lib/mdns/minimal/ResponseSender.cpp | 57 +++-- src/lib/mdns/minimal/ResponseSender.h | 8 +- .../mdns/minimal/tests/TestResponseSender.cpp | 214 +++++++++++++++--- 4 files changed, 313 insertions(+), 89 deletions(-) diff --git a/src/lib/mdns/Advertiser_ImplMinimalMdns.cpp b/src/lib/mdns/Advertiser_ImplMinimalMdns.cpp index 504984d9dc47e8..39911ede6fdd39 100644 --- a/src/lib/mdns/Advertiser_ImplMinimalMdns.cpp +++ b/src/lib/mdns/Advertiser_ImplMinimalMdns.cpp @@ -102,9 +102,12 @@ class AdvertiserMinMdns : public ServiceAdvertiser, public ParserDelegate // parses queries { public: - AdvertiserMinMdns() : mResponseSender(&GlobalMinimalMdnsServer::Server(), mQueryResponderAllocator.GetQueryResponder()) + AdvertiserMinMdns() : mResponseSender(&GlobalMinimalMdnsServer::Server()) { GlobalMinimalMdnsServer::Instance().SetQueryDelegate(this); + mResponseSender.AddQueryResponder(mQueryResponderAllocatorOperational.GetQueryResponder()); + mResponseSender.AddQueryResponder(mQueryResponderAllocatorCommissionable.GetQueryResponder()); + mResponseSender.AddQueryResponder(mQueryResponderAllocatorCommissioner.GetQueryResponder()); } ~AdvertiserMinMdns() {} @@ -134,8 +137,13 @@ class AdvertiserMinMdns : public ServiceAdvertiser, FullQName GetCommisioningTextEntries(const CommissionAdvertisingParameters & params); - static constexpr size_t kMaxRecords = 32; - QueryResponderAllocator mQueryResponderAllocator; + // Max number of records for operational = PTR, SRV, TXT, A, AAAA, no subtypes. + static constexpr size_t kMaxOperationalRecords = 5; + QueryResponderAllocator mQueryResponderAllocatorOperational; + // Max number of records for commissionable = 7 x PTR (base + 6 sub types - _S, _L, _D, _T, _C, _A), SRV, TXT, A, AAAA + static constexpr size_t kMaxCommissionRecords = 11; + QueryResponderAllocator mQueryResponderAllocatorCommissionable; + QueryResponderAllocator mQueryResponderAllocatorCommissioner; ResponseSender mResponseSender; @@ -195,25 +203,28 @@ CHIP_ERROR AdvertiserMinMdns::Start(chip::Inet::InetLayer * inetLayer, uint16_t /// Stops the advertiser. CHIP_ERROR AdvertiserMinMdns::StopPublishDevice() { - mQueryResponderAllocator.Clear(); + mQueryResponderAllocatorOperational.Clear(); + mQueryResponderAllocatorCommissionable.Clear(); + mQueryResponderAllocatorCommissioner.Clear(); return CHIP_NO_ERROR; } CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & params) { - mQueryResponderAllocator.Clear(); + // TODO: When multi-admin is enabled, commissionable does not need to be cleared here. + mQueryResponderAllocatorOperational.Clear(); char nameBuffer[64] = ""; /// need to set server name ReturnErrorOnFailure(MakeInstanceName(nameBuffer, sizeof(nameBuffer), params.GetPeerId())); FullQName operationalServiceName = - mQueryResponderAllocator.AllocateQName(kOperationalServiceName, kOperationalProtocol, kLocalDomain); + mQueryResponderAllocatorOperational.AllocateQName(kOperationalServiceName, kOperationalProtocol, kLocalDomain); FullQName operationalServerName = - mQueryResponderAllocator.AllocateQName(nameBuffer, kOperationalServiceName, kOperationalProtocol, kLocalDomain); + mQueryResponderAllocatorOperational.AllocateQName(nameBuffer, kOperationalServiceName, kOperationalProtocol, kLocalDomain); ReturnErrorOnFailure(MakeHostName(nameBuffer, sizeof(nameBuffer), params.GetMac())); - FullQName serverName = mQueryResponderAllocator.AllocateQName(nameBuffer, kLocalDomain); + FullQName serverName = mQueryResponderAllocatorOperational.AllocateQName(nameBuffer, kLocalDomain); if ((operationalServiceName.nameCount == 0) || (operationalServerName.nameCount == 0) || (serverName.nameCount == 0)) { @@ -221,7 +232,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!mQueryResponderAllocator.AddResponder(operationalServiceName, operationalServerName) + if (!mQueryResponderAllocatorOperational.AddResponder(operationalServiceName, operationalServerName) .SetReportAdditional(operationalServerName) .SetReportInServiceListing(true) .IsValid()) @@ -230,14 +241,15 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!mQueryResponderAllocator.AddResponder(SrvResourceRecord(operationalServerName, serverName, params.GetPort())) + if (!mQueryResponderAllocatorOperational + .AddResponder(SrvResourceRecord(operationalServerName, serverName, params.GetPort())) .SetReportAdditional(serverName) .IsValid()) { ChipLogError(Discovery, "Failed to add SRV record mDNS responder"); return CHIP_ERROR_NO_MEMORY; } - if (!mQueryResponderAllocator.AddResponder(TxtResourceRecord(operationalServerName, mEmptyTextEntries)) + if (!mQueryResponderAllocatorOperational.AddResponder(TxtResourceRecord(operationalServerName, mEmptyTextEntries)) .SetReportAdditional(serverName) .IsValid()) { @@ -245,7 +257,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!mQueryResponderAllocator.AddResponder(serverName).IsValid()) + if (!mQueryResponderAllocatorOperational.AddResponder(serverName).IsValid()) { ChipLogError(Discovery, "Failed to add IPv6 mDNS responder"); return CHIP_ERROR_NO_MEMORY; @@ -253,7 +265,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & if (params.IsIPv4Enabled()) { - if (!mQueryResponderAllocator.AddResponder(serverName).IsValid()) + if (!mQueryResponderAllocatorOperational.AddResponder(serverName).IsValid()) { ChipLogError(Discovery, "Failed to add IPv4 mDNS responder"); return CHIP_ERROR_NO_MEMORY; @@ -267,7 +279,16 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & params) { - mQueryResponderAllocator.Clear(); + // TODO: When multi-admin is enabled, operational does not need to be cleared here. + if (params.GetCommissionAdvertiseMode() == CommssionAdvertiseMode::kCommissionableNode) + { + mQueryResponderAllocatorCommissionable.Clear(); + } + else + { + mQueryResponderAllocatorCommissioner.Clear(); + } + // TODO: need to detect colisions here char nameBuffer[64] = ""; size_t len = snprintf(nameBuffer, sizeof(nameBuffer), ChipLogFormatX64, GetRandU32(), GetRandU32()); @@ -275,14 +296,18 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { return CHIP_ERROR_NO_MEMORY; } + QueryResponderAllocator * allocator = + params.GetCommissionAdvertiseMode() == CommssionAdvertiseMode::kCommissionableNode ? &mQueryResponderAllocatorCommissionable + : &mQueryResponderAllocatorCommissioner; const char * serviceType = params.GetCommissionAdvertiseMode() == CommssionAdvertiseMode::kCommissionableNode ? kCommissionableServiceName : kCommissionerServiceName; - FullQName serviceName = mQueryResponderAllocator.AllocateQName(serviceType, kCommissionProtocol, kLocalDomain); - FullQName instanceName = mQueryResponderAllocator.AllocateQName(nameBuffer, serviceType, kCommissionProtocol, kLocalDomain); + + FullQName serviceName = allocator->AllocateQName(serviceType, kCommissionProtocol, kLocalDomain); + FullQName instanceName = allocator->AllocateQName(nameBuffer, serviceType, kCommissionProtocol, kLocalDomain); ReturnErrorOnFailure(MakeHostName(nameBuffer, sizeof(nameBuffer), params.GetMac())); - FullQName hostName = mQueryResponderAllocator.AllocateQName(nameBuffer, kLocalDomain); + FullQName hostName = allocator->AllocateQName(nameBuffer, kLocalDomain); if ((serviceName.nameCount == 0) || (instanceName.nameCount == 0) || (hostName.nameCount == 0)) { @@ -290,7 +315,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!mQueryResponderAllocator.AddResponder(serviceName, instanceName) + if (!allocator->AddResponder(serviceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -299,14 +324,14 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!mQueryResponderAllocator.AddResponder(SrvResourceRecord(instanceName, hostName, params.GetPort())) + if (!allocator->AddResponder(SrvResourceRecord(instanceName, hostName, params.GetPort())) .SetReportAdditional(hostName) .IsValid()) { ChipLogError(Discovery, "Failed to add SRV record mDNS responder"); return CHIP_ERROR_NO_MEMORY; } - if (!mQueryResponderAllocator.AddResponder(hostName).IsValid()) + if (!allocator->AddResponder(hostName).IsValid()) { ChipLogError(Discovery, "Failed to add IPv6 mDNS responder"); return CHIP_ERROR_NO_MEMORY; @@ -314,7 +339,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & if (params.IsIPv4Enabled()) { - if (!mQueryResponderAllocator.AddResponder(hostName).IsValid()) + if (!allocator->AddResponder(hostName).IsValid()) { ChipLogError(Discovery, "Failed to add IPv4 mDNS responder"); return CHIP_ERROR_NO_MEMORY; @@ -325,11 +350,11 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kVendor, params.GetVendorId().Value())); - FullQName vendorServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, - kCommissionProtocol, kLocalDomain); + FullQName vendorServiceName = + allocator->AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(vendorServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!mQueryResponderAllocator.AddResponder(vendorServiceName, instanceName) + if (!allocator->AddResponder(vendorServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -343,11 +368,11 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kDeviceType, params.GetDeviceType().Value())); - FullQName vendorServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, - kCommissionProtocol, kLocalDomain); + FullQName vendorServiceName = + allocator->AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(vendorServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!mQueryResponderAllocator.AddResponder(vendorServiceName, instanceName) + if (!allocator->AddResponder(vendorServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -363,11 +388,11 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kShort, params.GetShortDiscriminator())); - FullQName shortServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, - kCommissionProtocol, kLocalDomain); + FullQName shortServiceName = + allocator->AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(shortServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!mQueryResponderAllocator.AddResponder(shortServiceName, instanceName) + if (!allocator->AddResponder(shortServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -380,10 +405,10 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kLong, params.GetLongDiscriminator())); - FullQName longServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, - kCommissionProtocol, kLocalDomain); + FullQName longServiceName = + allocator->AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(longServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!mQueryResponderAllocator.AddResponder(longServiceName, instanceName) + if (!allocator->AddResponder(longServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -396,10 +421,10 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kCommissioningMode, params.GetCommissioningMode() ? 1 : 0)); - FullQName longServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, - kCommissionProtocol, kLocalDomain); + FullQName longServiceName = + allocator->AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(longServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!mQueryResponderAllocator.AddResponder(longServiceName, instanceName) + if (!allocator->AddResponder(longServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -413,10 +438,10 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kCommissioningModeFromCommand, 1)); - FullQName longServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, - kCommissionProtocol, kLocalDomain); + FullQName longServiceName = + allocator->AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(longServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!mQueryResponderAllocator.AddResponder(longServiceName, instanceName) + if (!allocator->AddResponder(longServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -427,7 +452,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & } } - if (!mQueryResponderAllocator.AddResponder(TxtResourceRecord(instanceName, GetCommisioningTextEntries(params))) + if (!allocator->AddResponder(TxtResourceRecord(instanceName, GetCommisioningTextEntries(params))) .SetReportAdditional(hostName) .IsValid()) { @@ -454,6 +479,10 @@ FullQName AdvertiserMinMdns::GetCommisioningTextEntries(const CommissionAdvertis const char * txtFields[kMaxTxtFields]; size_t numTxtFields = 0; + QueryResponderAllocator * allocator = + params.GetCommissionAdvertiseMode() == CommssionAdvertiseMode::kCommissionableNode ? &mQueryResponderAllocatorCommissionable + : &mQueryResponderAllocatorCommissioner; + char txtVidPid[chip::Mdns::kKeyVendorProductMaxLength + 4]; if (params.GetProductId().HasValue() && params.GetVendorId().HasValue()) { @@ -490,7 +519,7 @@ FullQName AdvertiserMinMdns::GetCommisioningTextEntries(const CommissionAdvertis if (!params.GetVendorId().HasValue()) { - return mQueryResponderAllocator.AllocateQName(txtDiscriminator); + return allocator->AllocateQName(txtDiscriminator); } char txtCommissioningMode[chip::Mdns::kKeyCommissioningModeMaxLength + 4]; @@ -527,11 +556,11 @@ FullQName AdvertiserMinMdns::GetCommisioningTextEntries(const CommissionAdvertis } if (numTxtFields == 0) { - return mQueryResponderAllocator.AllocateQNameFromArray(mEmptyTextEntries, 1); + return allocator->AllocateQNameFromArray(mEmptyTextEntries, 1); } else { - return mQueryResponderAllocator.AllocateQNameFromArray(txtFields, numTxtFields); + return allocator->AllocateQNameFromArray(txtFields, numTxtFields); } } // namespace @@ -603,7 +632,9 @@ void AdvertiserMinMdns::AdvertiseRecords() QueryData queryData(QType::PTR, QClass::IN, false /* unicast */); queryData.SetIsBootAdvertising(true); - mQueryResponderAllocator.GetQueryResponder()->ClearBroadcastThrottle(); + mQueryResponderAllocatorOperational.GetQueryResponder()->ClearBroadcastThrottle(); + mQueryResponderAllocatorCommissionable.GetQueryResponder()->ClearBroadcastThrottle(); + mQueryResponderAllocatorCommissioner.GetQueryResponder()->ClearBroadcastThrottle(); CHIP_ERROR err = mResponseSender.Respond(0, queryData, &packetInfo); if (err != CHIP_NO_ERROR) @@ -613,7 +644,9 @@ void AdvertiserMinMdns::AdvertiseRecords() } // Once all automatic broadcasts are done, allow immediate replies once. - mQueryResponderAllocator.GetQueryResponder()->ClearBroadcastThrottle(); + mQueryResponderAllocatorOperational.GetQueryResponder()->ClearBroadcastThrottle(); + mQueryResponderAllocatorCommissionable.GetQueryResponder()->ClearBroadcastThrottle(); + mQueryResponderAllocatorCommissioner.GetQueryResponder()->ClearBroadcastThrottle(); } AdvertiserMinMdns gAdvertiser; diff --git a/src/lib/mdns/minimal/ResponseSender.cpp b/src/lib/mdns/minimal/ResponseSender.cpp index 51596fc47b62e2..06d408650dd29a 100644 --- a/src/lib/mdns/minimal/ResponseSender.cpp +++ b/src/lib/mdns/minimal/ResponseSender.cpp @@ -59,6 +59,19 @@ bool ResponseSendingState::IncludeQuery() const } // namespace Internal +CHIP_ERROR ResponseSender::AddQueryResponder(QueryResponderBase * queryResponder) +{ + for (size_t i = 0; i < kMaxQueryResponders; ++i) + { + if (mResponder[i] == nullptr || mResponder[i] == queryResponder) + { + mResponder[i] = queryResponder; + return CHIP_NO_ERROR; + } + } + return CHIP_ERROR_NO_MEMORY; +} + CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource) { mSendState.Reset(messageId, query, querySource); @@ -66,7 +79,13 @@ CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, // Responder has a stateful 'additional replies required' that is used within the response // loop. 'no additionals required' is set at the start and additionals are marked as the query // reply is built. - mResponder->ResetAdditionals(); + for (size_t i = 0; i < kMaxQueryResponders; ++i) + { + if (mResponder[i] != nullptr) + { + mResponder[i]->ResetAdditionals(); + } + } // send all 'Answer' replies { @@ -86,17 +105,23 @@ CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, constexpr uint64_t kOneSecondMs = 1000; responseFilter.SetIncludeOnlyMulticastBeforeMS(kTimeNowMs - kOneSecondMs); } - - for (auto it = mResponder->begin(&responseFilter); it != mResponder->end(); it++) + for (size_t i = 0; i < kMaxQueryResponders; ++i) { - it->responder->AddAllResponses(querySource, this); - ReturnErrorOnFailure(mSendState.GetError()); + if (mResponder[i] == nullptr) + { + continue; + } + for (auto it = mResponder[i]->begin(&responseFilter); it != mResponder[i]->end(); it++) + { + it->responder->AddAllResponses(querySource, this); + ReturnErrorOnFailure(mSendState.GetError()); - mResponder->MarkAdditionalRepliesFor(it); + mResponder[i]->MarkAdditionalRepliesFor(it); - if (!mSendState.SendUnicast()) - { - it->lastMulticastTime = kTimeNowMs; + if (!mSendState.SendUnicast()) + { + it->lastMulticastTime = kTimeNowMs; + } } } } @@ -113,11 +138,17 @@ CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, responseFilter .SetReplyFilter(&queryReplyFilter) // .SetIncludeAdditionalRepliesOnly(true); - - for (auto it = mResponder->begin(&responseFilter); it != mResponder->end(); it++) + for (size_t i = 0; i < kMaxQueryResponders; ++i) { - it->responder->AddAllResponses(querySource, this); - ReturnErrorOnFailure(mSendState.GetError()); + if (mResponder[i] == nullptr) + { + continue; + } + for (auto it = mResponder[i]->begin(&responseFilter); it != mResponder[i]->end(); it++) + { + it->responder->AddAllResponses(querySource, this); + ReturnErrorOnFailure(mSendState.GetError()); + } } } diff --git a/src/lib/mdns/minimal/ResponseSender.h b/src/lib/mdns/minimal/ResponseSender.h index a8df341d5c463a..31964801bfbb40 100644 --- a/src/lib/mdns/minimal/ResponseSender.h +++ b/src/lib/mdns/minimal/ResponseSender.h @@ -89,7 +89,11 @@ class ResponseSendingState class ResponseSender : public ResponderDelegate { public: - ResponseSender(ServerBase * server, QueryResponderBase * responder) : mServer(server), mResponder(responder) {} + // TODO(cecille): Template this and set appropriately. Please see issue #8000. + static constexpr size_t kMaxQueryResponders = 7; + ResponseSender(ServerBase * server) : mServer(server) {} + + CHIP_ERROR AddQueryResponder(QueryResponderBase * queryResponder); /// Send back the response to a particular query CHIP_ERROR Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource); @@ -102,7 +106,7 @@ class ResponseSender : public ResponderDelegate CHIP_ERROR PrepareNewReplyPacket(); ServerBase * mServer; - QueryResponderBase * mResponder; + QueryResponderBase * mResponder[kMaxQueryResponders] = {}; /// Current send state ResponseBuilder mResponseBuilder; // packet being built diff --git a/src/lib/mdns/minimal/tests/TestResponseSender.cpp b/src/lib/mdns/minimal/tests/TestResponseSender.cpp index 8afe0228ea01c8..7541eaf482092f 100644 --- a/src/lib/mdns/minimal/tests/TestResponseSender.cpp +++ b/src/lib/mdns/minimal/tests/TestResponseSender.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -61,9 +62,25 @@ class CheckOnlyServer : public ServerBase, public ParserDelegate // For now, types and names are sufficient for checking that the response sender is sending out the correct records. if (data.GetType() == expectedRecord[i]->GetType() && data.GetName() == expectedRecord[i]->GetName()) { - foundRecord[i] = true; - recordIsExpected = true; - break; + if (data.GetType() == QType::PTR) + { + // Check that the internal values are the same + SerializedQNameIterator dataTarget; + ParsePtrRecord(data.GetData(), data.GetData(), &dataTarget); + const PtrResourceRecord * expectedPtr = static_cast(expectedRecord[i]); + if (dataTarget == expectedPtr->GetPtr()) + { + foundRecord[i] = true; + recordIsExpected = true; + break; + } + } + else + { + foundRecord[i] = true; + recordIsExpected = true; + break; + } } } NL_TEST_ASSERT(mInSuite, recordIsExpected); @@ -89,6 +106,7 @@ class CheckOnlyServer : public ServerBase, public ParserDelegate if (expectedRecord[i] == nullptr) { expectedRecord[i] = record; + foundRecord[i] = false; return; } } @@ -148,11 +166,11 @@ struct CommonTestElements uint8_t instanceNameStorage[64]; uint8_t hostNameStorage[64]; uint8_t txtStorage[64]; - FullQName dnsSd = FlatAllocatedQName::Build(dnsSdServiceStorage, "_services", "_dns-sd", "_udp", "local"); - FullQName service = FlatAllocatedQName::Build(serviceNameStorage, "test", "service"); - FullQName instance = FlatAllocatedQName::Build(instanceNameStorage, "test", "instance"); - FullQName host = FlatAllocatedQName::Build(hostNameStorage, "test", "host"); - FullQName txt = FlatAllocatedQName::Build(txtStorage, "L1=something", "L2=other"); + FullQName dnsSd; + FullQName service; + FullQName instance; + FullQName host; + FullQName txt; static constexpr uint16_t kPort = 54; PtrResourceRecord ptrRecord = PtrResourceRecord(service, instance); @@ -164,10 +182,14 @@ struct CommonTestElements CheckOnlyServer server; QueryResponder<10> queryResponder; - ResponseSender responseSender; Inet::IPPacketInfo packetInfo; - CommonTestElements(nlTestSuite * inSuite) : server(inSuite), responseSender(&server, &queryResponder) + CommonTestElements(nlTestSuite * inSuite, const char * tag) : + dnsSd(FlatAllocatedQName::Build(dnsSdServiceStorage, "_services", "_dns-sd", "_udp", "local")), + service(FlatAllocatedQName::Build(serviceNameStorage, tag, "service")), + instance(FlatAllocatedQName::Build(instanceNameStorage, tag, "instance")), + host(FlatAllocatedQName::Build(hostNameStorage, tag, "host")), + txt(FlatAllocatedQName::Build(txtStorage, tag, "L1=something", "L2=other")), server(inSuite) { queryResponder.Init(); header.SetQueryCount(1); @@ -176,7 +198,9 @@ struct CommonTestElements void SrvAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) { - CommonTestElements common(inSuite); + CommonTestElements common(inSuite, "test"); + ResponseSender responseSender(&common.server); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common.queryResponder) == CHIP_NO_ERROR); common.queryResponder.AddResponder(&common.srvResponder); // Build a query for our srv record @@ -185,7 +209,7 @@ void SrvAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); common.server.AddExpectedRecord(&common.srvRecord); - common.responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -193,7 +217,9 @@ void SrvAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) void SrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) { - CommonTestElements common(inSuite); + CommonTestElements common(inSuite, "test"); + ResponseSender responseSender(&common.server); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common.queryResponder) == CHIP_NO_ERROR); common.queryResponder.AddResponder(&common.srvResponder); common.queryResponder.AddResponder(&common.txtResponder); @@ -205,7 +231,7 @@ void SrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) // We requested ANY on the host name, expect both back. common.server.AddExpectedRecord(&common.srvRecord); common.server.AddExpectedRecord(&common.txtRecord); - common.responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -213,7 +239,9 @@ void SrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) void PtrSrvTxtAnyResponseToServiceName(nlTestSuite * inSuite, void * inContext) { - CommonTestElements common(inSuite); + CommonTestElements common(inSuite, "test"); + ResponseSender responseSender(&common.server); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common.queryResponder) == CHIP_NO_ERROR); common.queryResponder.AddResponder(&common.ptrResponder).SetReportAdditional(common.instance); common.queryResponder.AddResponder(&common.srvResponder); common.queryResponder.AddResponder(&common.txtResponder); @@ -228,7 +256,7 @@ void PtrSrvTxtAnyResponseToServiceName(nlTestSuite * inSuite, void * inContext) common.server.AddExpectedRecord(&common.srvRecord); common.server.AddExpectedRecord(&common.txtRecord); - common.responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -236,7 +264,9 @@ void PtrSrvTxtAnyResponseToServiceName(nlTestSuite * inSuite, void * inContext) void PtrSrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) { - CommonTestElements common(inSuite); + CommonTestElements common(inSuite, "test"); + ResponseSender responseSender(&common.server); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common.queryResponder) == CHIP_NO_ERROR); common.queryResponder.AddResponder(&common.ptrResponder); common.queryResponder.AddResponder(&common.srvResponder); common.queryResponder.AddResponder(&common.txtResponder); @@ -250,7 +280,7 @@ void PtrSrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) common.server.AddExpectedRecord(&common.srvRecord); common.server.AddExpectedRecord(&common.txtRecord); - common.responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -258,7 +288,9 @@ void PtrSrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) void PtrSrvTxtSrvResponseToInstance(nlTestSuite * inSuite, void * inContext) { - CommonTestElements common(inSuite); + CommonTestElements common(inSuite, "test"); + ResponseSender responseSender(&common.server); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common.queryResponder) == CHIP_NO_ERROR); common.queryResponder.AddResponder(&common.ptrResponder).SetReportInServiceListing(true); common.queryResponder.AddResponder(&common.srvResponder); common.queryResponder.AddResponder(&common.txtResponder); @@ -271,7 +303,7 @@ void PtrSrvTxtSrvResponseToInstance(nlTestSuite * inSuite, void * inContext) // We didn't set the txt as an additional on the srv name so expect only srv. common.server.AddExpectedRecord(&common.srvRecord); - common.responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -279,7 +311,9 @@ void PtrSrvTxtSrvResponseToInstance(nlTestSuite * inSuite, void * inContext) void PtrSrvTxtAnyResponseToServiceListing(nlTestSuite * inSuite, void * inContext) { - CommonTestElements common(inSuite); + CommonTestElements common(inSuite, "test"); + ResponseSender responseSender(&common.server); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common.queryResponder) == CHIP_NO_ERROR); common.queryResponder.AddResponder(&common.ptrResponder).SetReportInServiceListing(true); common.queryResponder.AddResponder(&common.srvResponder); common.queryResponder.AddResponder(&common.txtResponder); @@ -293,20 +327,142 @@ void PtrSrvTxtAnyResponseToServiceListing(nlTestSuite * inSuite, void * inContex PtrResourceRecord serviceRecord = PtrResourceRecord(common.dnsSd, common.ptrRecord.GetName()); common.server.AddExpectedRecord(&serviceRecord); - common.responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); } +void NoQueryResponder(nlTestSuite * inSuite, void * inContext) +{ + CommonTestElements common(inSuite, "test"); + ResponseSender responseSender(&common.server); + + QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); + + common.dnsSd.Output(common.requestBufferWriter); + responseSender.Respond(1, queryData, &common.packetInfo); + NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); + + common.service.Output(common.requestBufferWriter); + responseSender.Respond(1, queryData, &common.packetInfo); + NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); + + common.instance.Output(common.requestBufferWriter); + responseSender.Respond(1, queryData, &common.packetInfo); + NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); +} + +void AddManyQueryResponders(nlTestSuite * inSuite, void * inContext) +{ + // TODO(cecille): Fix this test once #8000 gets resolved. + ResponseSender responseSender(nullptr); + QueryResponder<1> q1; + QueryResponder<1> q2; + QueryResponder<1> q3; + QueryResponder<1> q4; + QueryResponder<1> q5; + QueryResponder<1> q6; + QueryResponder<1> q7; + QueryResponder<1> q8; + + // We should be able to re-add the same query responder as many times as we want. + for (size_t i = 0; i < ResponseSender::kMaxQueryResponders + 1; ++i) + { + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&q1) == CHIP_NO_ERROR); + } + + // There are 7 total + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&q2) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&q3) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&q4) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&q5) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&q6) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&q7) == CHIP_NO_ERROR); + + // Last one should return a no memory error (no space) + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&q8) == CHIP_ERROR_NO_MEMORY); +} + +void PtrSrvTxtMultipleRespondersToInstance(nlTestSuite * inSuite, void * inContext) +{ + CommonTestElements common1(inSuite, "test1"); + CommonTestElements common2(inSuite, "test2"); + + // Just use the server from common1. + ResponseSender responseSender(&common1.server); + + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common1.queryResponder) == CHIP_NO_ERROR); + common1.queryResponder.AddResponder(&common1.ptrResponder).SetReportInServiceListing(true); + common1.queryResponder.AddResponder(&common1.srvResponder); + common1.queryResponder.AddResponder(&common1.txtResponder); + + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common2.queryResponder) == CHIP_NO_ERROR); + common2.queryResponder.AddResponder(&common2.ptrResponder).SetReportInServiceListing(true); + common2.queryResponder.AddResponder(&common2.srvResponder); + common2.queryResponder.AddResponder(&common2.txtResponder); + + // Build a query for the second instance. + common2.instance.Output(common2.requestBufferWriter); + QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common2.requestNameStart, common2.requestBytesRange); + + // Should get back answers from second instance only. + common1.server.AddExpectedRecord(&common2.srvRecord); + common1.server.AddExpectedRecord(&common2.txtRecord); + + responseSender.Respond(1, queryData, &common1.packetInfo); + + NL_TEST_ASSERT(inSuite, common1.server.GetSendCalled()); + NL_TEST_ASSERT(inSuite, common1.server.GetHeaderFound()); +} + +void PtrSrvTxtMultipleRespondersToServiceListing(nlTestSuite * inSuite, void * inContext) +{ + CommonTestElements common1(inSuite, "test1"); + CommonTestElements common2(inSuite, "test2"); + + // Just use the server from common1. + ResponseSender responseSender(&common1.server); + + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common1.queryResponder) == CHIP_NO_ERROR); + common1.queryResponder.AddResponder(&common1.ptrResponder).SetReportInServiceListing(true); + common1.queryResponder.AddResponder(&common1.srvResponder); + common1.queryResponder.AddResponder(&common1.txtResponder); + + NL_TEST_ASSERT(inSuite, responseSender.AddQueryResponder(&common2.queryResponder) == CHIP_NO_ERROR); + common2.queryResponder.AddResponder(&common2.ptrResponder).SetReportInServiceListing(true); + common2.queryResponder.AddResponder(&common2.srvResponder); + common2.queryResponder.AddResponder(&common2.txtResponder); + + // Build a query for the instance + common1.dnsSd.Output(common1.requestBufferWriter); + QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common1.requestNameStart, common1.requestBytesRange); + + // Should get service listing from both. + PtrResourceRecord serviceRecord1 = PtrResourceRecord(common1.dnsSd, common1.ptrRecord.GetName()); + common1.server.AddExpectedRecord(&serviceRecord1); + PtrResourceRecord serviceRecord2 = PtrResourceRecord(common2.dnsSd, common2.ptrRecord.GetName()); + common1.server.AddExpectedRecord(&serviceRecord2); + + responseSender.Respond(1, queryData, &common1.packetInfo); + + NL_TEST_ASSERT(inSuite, common1.server.GetSendCalled()); + NL_TEST_ASSERT(inSuite, common1.server.GetHeaderFound()); +} + const nlTest sTests[] = { - NL_TEST_DEF("SrvAnyResponseToInstance", SrvAnyResponseToInstance), // - NL_TEST_DEF("SrvTxtAnyResponseToInstance", SrvTxtAnyResponseToInstance), // - NL_TEST_DEF("PtrSrvTxtAnyResponseToServiceName", PtrSrvTxtAnyResponseToServiceName), // - NL_TEST_DEF("PtrSrvTxtAnyResponseToInstance", PtrSrvTxtAnyResponseToInstance), // - NL_TEST_DEF("PtrSrvTxtSrvResponseToInstance", PtrSrvTxtSrvResponseToInstance), // - NL_TEST_DEF("PtrSrvTxtAnyResponseToServiceListing", PtrSrvTxtAnyResponseToServiceListing), // - NL_TEST_SENTINEL() // + NL_TEST_DEF("SrvAnyResponseToInstance", SrvAnyResponseToInstance), // + NL_TEST_DEF("SrvTxtAnyResponseToInstance", SrvTxtAnyResponseToInstance), // + NL_TEST_DEF("PtrSrvTxtAnyResponseToServiceName", PtrSrvTxtAnyResponseToServiceName), // + NL_TEST_DEF("PtrSrvTxtAnyResponseToInstance", PtrSrvTxtAnyResponseToInstance), // + NL_TEST_DEF("PtrSrvTxtSrvResponseToInstance", PtrSrvTxtSrvResponseToInstance), // + NL_TEST_DEF("PtrSrvTxtAnyResponseToServiceListing", PtrSrvTxtAnyResponseToServiceListing), // + NL_TEST_DEF("NoQueryResponder", NoQueryResponder), // + NL_TEST_DEF("AddManyQueryResponders", AddManyQueryResponders), // + NL_TEST_DEF("PtrSrvTxtMultipleRespondersToInstance", PtrSrvTxtMultipleRespondersToInstance), // + NL_TEST_DEF("PtrSrvTxtMultipleRespondersToServiceListing", PtrSrvTxtMultipleRespondersToServiceListing), // + + NL_TEST_SENTINEL() // }; } // namespace