From 3d736347f5f92575b8eebd0b8b21b73b899f9ddf Mon Sep 17 00:00:00 2001 From: Andrei Litvin Date: Fri, 17 Jun 2022 16:54:36 -0400 Subject: [PATCH] Make minmdns advertiser send a TTL=0 record broadcast when services are removed. (#19692) * Prepare to have response configuration, specifically allow replies to contain an TTL override * Support the remove all type of advertisement (in theory) * Make minmdns advertiser send 0 ttl advertisements * make minmdns server compile and be able to shutdown more cleanly. Will NOT send removal broadcasts though (it does not send add either) * Make unit tests compile * Add a unit test showing TTL overrides * Code review comments --- examples/minimal-mdns/server.cpp | 22 ++++-- src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp | 70 +++++++++++++++---- src/lib/dnssd/minimal_mdns/Parser.h | 16 +++-- src/lib/dnssd/minimal_mdns/QueryReplyFilter.h | 2 +- src/lib/dnssd/minimal_mdns/ResponseSender.cpp | 7 +- src/lib/dnssd/minimal_mdns/ResponseSender.h | 3 +- src/lib/dnssd/minimal_mdns/responders/IP.cpp | 14 ++-- src/lib/dnssd/minimal_mdns/responders/IP.h | 6 +- src/lib/dnssd/minimal_mdns/responders/Ptr.h | 7 +- .../responders/QueryResponder.cpp | 7 +- .../minimal_mdns/responders/QueryResponder.h | 3 +- .../dnssd/minimal_mdns/responders/Responder.h | 36 +++++++++- src/lib/dnssd/minimal_mdns/responders/Srv.h | 7 +- src/lib/dnssd/minimal_mdns/responders/Txt.h | 7 +- .../responders/tests/TestIPResponder.cpp | 4 +- .../responders/tests/TestPtrResponder.cpp | 36 ++++++++-- .../responders/tests/TestQueryResponder.cpp | 6 +- .../minimal_mdns/tests/TestResponseSender.cpp | 22 +++--- 18 files changed, 206 insertions(+), 69 deletions(-) diff --git a/examples/minimal-mdns/server.cpp b/examples/minimal-mdns/server.cpp index afc84b75261716..38a68e100150d3 100644 --- a/examples/minimal-mdns/server.cpp +++ b/examples/minimal-mdns/server.cpp @@ -146,7 +146,7 @@ class ReplyDelegate : public mdns::Minimal::ServerDelegate, public mdns::Minimal void OnQuery(const mdns::Minimal::QueryData & data) override { - if (mResponder->Respond(mMessageId, data, mCurrentSource) != CHIP_NO_ERROR) + if (mResponder->Respond(mMessageId, data, mCurrentSource, mdns::Minimal::ResponseConfiguration()) != CHIP_NO_ERROR) { printf("FAILED to respond!\n"); } @@ -167,6 +167,16 @@ class ReplyDelegate : public mdns::Minimal::ServerDelegate, public mdns::Minimal uint32_t mMessageId = 0; }; +mdns::Minimal::Server<10 /* endpoints */> gMdnsServer; + +void StopSignalHandler(int signal) +{ + gMdnsServer.Shutdown(); + + DeviceLayer::PlatformMgr().StopEventLoopTask(); + DeviceLayer::PlatformMgr().Shutdown(); +} + } // namespace int main(int argc, char ** args) @@ -190,7 +200,6 @@ int main(int argc, char ** args) printf("Running on port %d using %s...\n", gOptions.listenPort, gOptions.enableIpV4 ? "IPv4 AND IPv6" : "IPv6 ONLY"); - mdns::Minimal::Server<10 /* endpoints */> mdnsServer; mdns::Minimal::QueryResponder<16 /* maxRecords */> queryResponder; mdns::Minimal::QNamePart tcpServiceName[] = { Dnssd::kOperationalServiceName, Dnssd::kOperationalProtocol, @@ -249,22 +258,25 @@ int main(int argc, char ** args) queryResponder.AddResponder(&ipv4Responder); } - mdns::Minimal::ResponseSender responseSender(&mdnsServer); + mdns::Minimal::ResponseSender responseSender(&gMdnsServer); responseSender.AddQueryResponder(&queryResponder); ReplyDelegate delegate(&responseSender); - mdnsServer.SetDelegate(&delegate); + gMdnsServer.SetDelegate(&delegate); { MdnsExample::AllInterfaces allInterfaces(gOptions.enableIpV4); - if (mdnsServer.Listen(DeviceLayer::UDPEndPointManager(), &allInterfaces, gOptions.listenPort) != CHIP_NO_ERROR) + if (gMdnsServer.Listen(DeviceLayer::UDPEndPointManager(), &allInterfaces, gOptions.listenPort) != CHIP_NO_ERROR) { printf("Server failed to listen on all interfaces\n"); return 1; } } + signal(SIGTERM, StopSignalHandler); + signal(SIGINT, StopSignalHandler); + DeviceLayer::PlatformMgr().RunEventLoop(); printf("Done...\n"); diff --git a/src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp b/src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp index a90a31a9f6a24e..4db3f0676c89c0 100644 --- a/src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp +++ b/src/lib/dnssd/Advertiser_ImplMinimalMdns.cpp @@ -147,6 +147,12 @@ class OperationalQueryAllocator : public chip::IntrusiveListNodeBase<> Allocator * mAllocator = nullptr; }; +enum BroadcastAdvertiseType +{ + kStarted, // Advertise at startup of all records added, as required by RFC 6762. + kRemovingAll, // sent a TTL 0 for all records, as records are removed +}; + class AdvertiserMinMdns : public ServiceAdvertiser, public MdnsPacketDelegate, // receive query packets public ParserDelegate // parses queries @@ -190,10 +196,12 @@ class AdvertiserMinMdns : public ServiceAdvertiser, void OnQuery(const QueryData & data) override; private: - /// Advertise available records configured within the server + /// Advertise available records configured within the server. /// - /// Usable as boot-time advertisement of available SRV records. - void AdvertiseRecords(); + /// Establishes a type of 'Advertise all currently configured items' + /// for a specific purpose (e.g. boot time advertises everything, shut-down + /// removes all records by advertising a 0 TTL) + void AdvertiseRecords(BroadcastAdvertiseType type); /// Determine if advertisement on the specified interface/address is ok given the /// interfaces on which the mDNS server is listening @@ -311,7 +319,8 @@ void AdvertiserMinMdns::OnQuery(const QueryData & data) LogQuery(data); - CHIP_ERROR err = mResponseSender.Respond(mMessageId, data, mCurrentSource); + const ResponseConfiguration defaultResponseConfiguration; + CHIP_ERROR err = mResponseSender.Respond(mMessageId, data, mCurrentSource, defaultResponseConfiguration); if (err != CHIP_NO_ERROR) { ChipLogError(Discovery, "Failed to reply to query: %s", ErrorStr(err)); @@ -339,7 +348,7 @@ CHIP_ERROR AdvertiserMinMdns::Init(chip::Inet::EndPointManagerClearBroadcastThrottle(); mQueryResponderAllocatorCommissioner.GetQueryResponder()->ClearBroadcastThrottle(); - CHIP_ERROR err = mResponseSender.Respond(0, queryData, &packetInfo); + CHIP_ERROR err = mResponseSender.Respond(0, queryData, &packetInfo, responseConfiguration); if (err != CHIP_NO_ERROR) { ChipLogError(Discovery, "Failed to advertise records: %s", ErrorStr(err)); diff --git a/src/lib/dnssd/minimal_mdns/Parser.h b/src/lib/dnssd/minimal_mdns/Parser.h index 8150f0770925a9..509633b9e192fa 100644 --- a/src/lib/dnssd/minimal_mdns/Parser.h +++ b/src/lib/dnssd/minimal_mdns/Parser.h @@ -42,10 +42,12 @@ class QueryData QClass GetClass() const { return mClass; } bool RequestedUnicastAnswer() const { return mAnswerViaUnicast; } - /// Boot advertisement is an internal query meant to advertise all available - /// services at device startup time. - bool IsBootAdvertising() const { return mIsBootAdvertising; } - void SetIsBootAdvertising(bool isBootAdvertising) { mIsBootAdvertising = isBootAdvertising; } + /// Internal broadcasts will advertise all available data and will not apply + /// any broadcast filtering. Intent is for paths such as: + /// - boot time advertisement: advertise all services available + /// - stop-time advertisement: advertise a TTL of 0 as services are removed + bool IsInternalBroadcast() const { return mIsInternalBroadcast; } + void SetIsInternalBroadcast(bool isInternalBroadcast) { mIsInternalBroadcast = isInternalBroadcast; } SerializedQNameIterator GetName() const { return mNameIterator; } @@ -65,9 +67,9 @@ class QueryData bool mAnswerViaUnicast = false; SerializedQNameIterator mNameIterator; - /// Flag as a boot-time internal query. This allows query replies - /// to be built accordingly. - bool mIsBootAdvertising = false; + /// Flag as an internal broadcast, controls reply construction (e.g. no + /// filtering applied) + bool mIsInternalBroadcast = false; }; class ResourceData diff --git a/src/lib/dnssd/minimal_mdns/QueryReplyFilter.h b/src/lib/dnssd/minimal_mdns/QueryReplyFilter.h index 8091e4da6d532f..c7eb0a98af0151 100644 --- a/src/lib/dnssd/minimal_mdns/QueryReplyFilter.h +++ b/src/lib/dnssd/minimal_mdns/QueryReplyFilter.h @@ -81,7 +81,7 @@ class QueryReplyFilter : public ReplyFilter bool AcceptablePath(FullQName qname) { - if (mIgnoreNameMatch || mQueryData.IsBootAdvertising()) + if (mIgnoreNameMatch || mQueryData.IsInternalBroadcast()) { return true; } diff --git a/src/lib/dnssd/minimal_mdns/ResponseSender.cpp b/src/lib/dnssd/minimal_mdns/ResponseSender.cpp index f3119a27ed8c13..9bef5b773defbf 100644 --- a/src/lib/dnssd/minimal_mdns/ResponseSender.cpp +++ b/src/lib/dnssd/minimal_mdns/ResponseSender.cpp @@ -100,7 +100,8 @@ bool ResponseSender::HasQueryResponders() const return false; } -CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource) +CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource, + const ResponseConfiguration & configuration) { mSendState.Reset(messageId, query, querySource); @@ -142,7 +143,7 @@ CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, } for (auto it = (*responder)->begin(&responseFilter); it != (*responder)->end(); it++) { - it->responder->AddAllResponses(querySource, this); + it->responder->AddAllResponses(querySource, this, configuration); ReturnErrorOnFailure(mSendState.GetError()); (*responder)->MarkAdditionalRepliesFor(it); @@ -175,7 +176,7 @@ CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, } for (auto it = (*responder)->begin(&responseFilter); it != (*responder)->end(); it++) { - it->responder->AddAllResponses(querySource, this); + it->responder->AddAllResponses(querySource, this, configuration); ReturnErrorOnFailure(mSendState.GetError()); } } diff --git a/src/lib/dnssd/minimal_mdns/ResponseSender.h b/src/lib/dnssd/minimal_mdns/ResponseSender.h index 8cba032bd9a771..cf1b80f5d97f65 100644 --- a/src/lib/dnssd/minimal_mdns/ResponseSender.h +++ b/src/lib/dnssd/minimal_mdns/ResponseSender.h @@ -112,7 +112,8 @@ class ResponseSender : public ResponderDelegate bool HasQueryResponders() const; /// Send back the response to a particular query - CHIP_ERROR Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource); + CHIP_ERROR Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource, + const ResponseConfiguration & configuration); // Implementation of ResponderDelegate void AddResponse(const ResourceRecord & record) override; diff --git a/src/lib/dnssd/minimal_mdns/responders/IP.cpp b/src/lib/dnssd/minimal_mdns/responders/IP.cpp index c235c8f57d3ed9..84c36bd1cb01cc 100644 --- a/src/lib/dnssd/minimal_mdns/responders/IP.cpp +++ b/src/lib/dnssd/minimal_mdns/responders/IP.cpp @@ -21,19 +21,23 @@ namespace mdns { namespace Minimal { -void IPv4Responder::AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) +void IPv4Responder::AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) { chip::Inet::IPAddress addr; for (chip::Inet::InterfaceAddressIterator it; it.HasCurrent(); it.Next()) { if ((it.GetInterfaceId() == source->Interface) && (it.GetAddress(addr) == CHIP_NO_ERROR) && addr.IsIPv4()) { - delegate->AddResponse(IPResourceRecord(GetQName(), addr)); + IPResourceRecord record(GetQName(), addr); + configuration.Adjust(record); + delegate->AddResponse(record); } } } -void IPv6Responder::AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) +void IPv6Responder::AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) { for (chip::Inet::InterfaceAddressIterator it; it.HasCurrent(); it.Next()) { @@ -45,7 +49,9 @@ void IPv6Responder::AddAllResponses(const chip::Inet::IPPacketInfo * source, Res chip::Inet::IPAddress addr; if ((it.GetInterfaceId() == source->Interface) && (it.GetAddress(addr) == CHIP_NO_ERROR) && addr.IsIPv6()) { - delegate->AddResponse(IPResourceRecord(GetQName(), addr)); + IPResourceRecord record(GetQName(), addr); + configuration.Adjust(record); + delegate->AddResponse(record); } } } diff --git a/src/lib/dnssd/minimal_mdns/responders/IP.h b/src/lib/dnssd/minimal_mdns/responders/IP.h index 71a0957f4e9b60..12d266ce50280d 100644 --- a/src/lib/dnssd/minimal_mdns/responders/IP.h +++ b/src/lib/dnssd/minimal_mdns/responders/IP.h @@ -27,7 +27,8 @@ class IPv4Responder : public RecordResponder public: IPv4Responder(const FullQName & qname) : RecordResponder(QType::A, qname) {} - void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) override; + void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) override; }; class IPv6Responder : public RecordResponder @@ -35,7 +36,8 @@ class IPv6Responder : public RecordResponder public: IPv6Responder(const FullQName & qname) : RecordResponder(QType::AAAA, qname) {} - void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) override; + void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) override; }; } // namespace Minimal diff --git a/src/lib/dnssd/minimal_mdns/responders/Ptr.h b/src/lib/dnssd/minimal_mdns/responders/Ptr.h index d2b72f47124c05..1fd50a0b5e142c 100644 --- a/src/lib/dnssd/minimal_mdns/responders/Ptr.h +++ b/src/lib/dnssd/minimal_mdns/responders/Ptr.h @@ -28,9 +28,12 @@ class PtrResponder : public RecordResponder public: PtrResponder(const FullQName & qname, const FullQName & target) : RecordResponder(QType::PTR, qname), mTarget(target) {} - void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) override + void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) override { - delegate->AddResponse(PtrResourceRecord(GetQName(), mTarget)); + PtrResourceRecord record(GetQName(), mTarget); + configuration.Adjust(record); + delegate->AddResponse(record); } private: diff --git a/src/lib/dnssd/minimal_mdns/responders/QueryResponder.cpp b/src/lib/dnssd/minimal_mdns/responders/QueryResponder.cpp index 5f14659dac50db..6873ed49bf2d7b 100644 --- a/src/lib/dnssd/minimal_mdns/responders/QueryResponder.cpp +++ b/src/lib/dnssd/minimal_mdns/responders/QueryResponder.cpp @@ -137,7 +137,8 @@ void QueryResponderBase::MarkAdditionalRepliesFor(QueryResponderIterator it) } } -void QueryResponderBase::AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) +void QueryResponderBase::AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) { // reply to dns-sd service list request for (size_t i = 0; i < mResponderInfoSize; i++) @@ -152,7 +153,9 @@ void QueryResponderBase::AddAllResponses(const chip::Inet::IPPacketInfo * source continue; } - delegate->AddResponse(PtrResourceRecord(GetQName(), mResponderInfos[i].responder->GetQName())); + PtrResourceRecord record(GetQName(), mResponderInfos[i].responder->GetQName()); + configuration.Adjust(record); + delegate->AddResponse(record); } } diff --git a/src/lib/dnssd/minimal_mdns/responders/QueryResponder.h b/src/lib/dnssd/minimal_mdns/responders/QueryResponder.h index c2f55096606f95..e39bd32287e6ac 100644 --- a/src/lib/dnssd/minimal_mdns/responders/QueryResponder.h +++ b/src/lib/dnssd/minimal_mdns/responders/QueryResponder.h @@ -252,7 +252,8 @@ class QueryResponderBase : public Responder // "_services._dns-sd._udp.local" /// Implementation of the responder delegate. /// /// Adds responses for all known _dns-sd services. - void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) override; + void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) override; QueryResponderIterator begin(QueryResponderRecordFilter * filter) { diff --git a/src/lib/dnssd/minimal_mdns/responders/Responder.h b/src/lib/dnssd/minimal_mdns/responders/Responder.h index 04c7ca1537dfdb..4a1b4e015dbe59 100644 --- a/src/lib/dnssd/minimal_mdns/responders/Responder.h +++ b/src/lib/dnssd/minimal_mdns/responders/Responder.h @@ -21,6 +21,7 @@ #include #include +#include namespace mdns { namespace Minimal { @@ -34,6 +35,38 @@ class ResponderDelegate virtual void AddResponse(const ResourceRecord & record) = 0; }; +/// Controls specific options for responding to mDNS queries +/// +class ResponseConfiguration +{ +public: + ResponseConfiguration() {} + ~ResponseConfiguration() = default; + + chip::Optional GetTtlSecondsOverride() const { return mTtlSecondsOverride; } + ResponseConfiguration & SetTtlSecondsOverride(chip::Optional override) + { + mTtlSecondsOverride = override; + return *this; + } + + ResponseConfiguration & SetTtlSecondsOverride(uint32_t value) { return SetTtlSecondsOverride(chip::MakeOptional(value)); } + ResponseConfiguration & ClearTtlSecondsOverride() { return SetTtlSecondsOverride(chip::NullOptional); } + + /// Applies any adjustments to resource records before they are being serialized + /// to some form of reply. + void Adjust(ResourceRecord & record) const + { + if (mTtlSecondsOverride.HasValue()) + { + record.SetTtl(mTtlSecondsOverride.Value()); + } + } + +private: + chip::Optional mTtlSecondsOverride; +}; + /// Adds ability to respond with specific types of data class Responder { @@ -51,7 +84,8 @@ class Responder /// Report all reponses maintained by this responder /// /// Responses are associated with the objects type/class/qname. - virtual void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) = 0; + virtual void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) = 0; private: const QType mQType; diff --git a/src/lib/dnssd/minimal_mdns/responders/Srv.h b/src/lib/dnssd/minimal_mdns/responders/Srv.h index c8cb724a127a27..d23226092b3e09 100644 --- a/src/lib/dnssd/minimal_mdns/responders/Srv.h +++ b/src/lib/dnssd/minimal_mdns/responders/Srv.h @@ -28,9 +28,12 @@ class SrvResponder : public RecordResponder public: SrvResponder(const SrvResourceRecord & record) : RecordResponder(QType::SRV, record.GetName()), mRecord(record) {} - void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) override + void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) override { - delegate->AddResponse(mRecord); + SrvResourceRecord record = mRecord; + configuration.Adjust(record); + delegate->AddResponse(record); } private: diff --git a/src/lib/dnssd/minimal_mdns/responders/Txt.h b/src/lib/dnssd/minimal_mdns/responders/Txt.h index 92302bda90c632..5b607265b8c55c 100644 --- a/src/lib/dnssd/minimal_mdns/responders/Txt.h +++ b/src/lib/dnssd/minimal_mdns/responders/Txt.h @@ -28,9 +28,12 @@ class TxtResponder : public RecordResponder public: TxtResponder(const TxtResourceRecord & record) : RecordResponder(QType::TXT, record.GetName()), mRecord(record) {} - void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate) override + void AddAllResponses(const chip::Inet::IPPacketInfo * source, ResponderDelegate * delegate, + const ResponseConfiguration & configuration) override { - delegate->AddResponse(mRecord); + TxtResourceRecord record = mRecord; + configuration.Adjust(record); + delegate->AddResponse(record); } private: diff --git a/src/lib/dnssd/minimal_mdns/responders/tests/TestIPResponder.cpp b/src/lib/dnssd/minimal_mdns/responders/tests/TestIPResponder.cpp index 7cafa4344b578c..35fc5f7d642bfa 100644 --- a/src/lib/dnssd/minimal_mdns/responders/tests/TestIPResponder.cpp +++ b/src/lib/dnssd/minimal_mdns/responders/tests/TestIPResponder.cpp @@ -81,7 +81,7 @@ void TestIPv4(nlTestSuite * inSuite, void * inContext) packetInfo.DestPort = kMdnsPort; packetInfo.Interface = FindValidInterfaceId(); - responder.AddAllResponses(&packetInfo, &acc); + responder.AddAllResponses(&packetInfo, &acc, ResponseConfiguration()); } #endif // INET_CONFIG_ENABLE_IPV4 @@ -105,7 +105,7 @@ void TestIPv6(nlTestSuite * inSuite, void * inContext) packetInfo.DestPort = kMdnsPort; packetInfo.Interface = FindValidInterfaceId(); - responder.AddAllResponses(&packetInfo, &acc); + responder.AddAllResponses(&packetInfo, &acc, ResponseConfiguration()); } const nlTest sTests[] = { diff --git a/src/lib/dnssd/minimal_mdns/responders/tests/TestPtrResponder.cpp b/src/lib/dnssd/minimal_mdns/responders/tests/TestPtrResponder.cpp index d9bf42a7383b8c..473103255e7c47 100644 --- a/src/lib/dnssd/minimal_mdns/responders/tests/TestPtrResponder.cpp +++ b/src/lib/dnssd/minimal_mdns/responders/tests/TestPtrResponder.cpp @@ -39,7 +39,7 @@ const QNamePart kTargetNames[] = { "point", "to", "this" }; class PtrResponseAccumulator : public ResponderDelegate { public: - PtrResponseAccumulator(nlTestSuite * suite) : mSuite(suite) {} + PtrResponseAccumulator(nlTestSuite * suite, const uint32_t expectedTtl) : mSuite(suite), mExpectedTtl(expectedTtl) {} void AddResponse(const ResourceRecord & record) override { @@ -72,6 +72,7 @@ class PtrResponseAccumulator : public ResponderDelegate NL_TEST_ASSERT(mSuite, start == (buffer + out.Needed())); NL_TEST_ASSERT(mSuite, data.GetName() == FullQName(kNames)); NL_TEST_ASSERT(mSuite, data.GetType() == QType::PTR); + NL_TEST_ASSERT(mSuite, data.GetTtlSeconds() == mExpectedTtl); NL_TEST_ASSERT(mSuite, ParsePtrRecord(data.GetData(), validDataRange, &target)); NL_TEST_ASSERT(mSuite, target == FullQName(kTargetNames)); @@ -80,6 +81,7 @@ class PtrResponseAccumulator : public ResponderDelegate private: nlTestSuite * mSuite; + const uint32_t mExpectedTtl; }; void TestPtrResponse(nlTestSuite * inSuite, void * inContext) @@ -93,7 +95,7 @@ void TestPtrResponse(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, responder.GetQType() == QType::PTR); NL_TEST_ASSERT(inSuite, responder.GetQName() == kNames); - PtrResponseAccumulator acc(inSuite); + PtrResponseAccumulator acc(inSuite, ResourceRecord::kDefaultTtl); chip::Inet::IPPacketInfo packetInfo; packetInfo.SrcAddress = ipAddress; @@ -102,12 +104,36 @@ void TestPtrResponse(nlTestSuite * inSuite, void * inContext) packetInfo.DestPort = kMdnsPort; packetInfo.Interface = InterfaceId::Null(); - responder.AddAllResponses(&packetInfo, &acc); + responder.AddAllResponses(&packetInfo, &acc, ResponseConfiguration()); +} + +void TestPtrResponseOverrideTtl(nlTestSuite * inSuite, void * inContext) +{ + IPAddress ipAddress; + NL_TEST_ASSERT(inSuite, IPAddress::FromString("2607:f8b0:4005:804::200e", ipAddress)); + + PtrResponder responder(kNames, kTargetNames); + + NL_TEST_ASSERT(inSuite, responder.GetQClass() == QClass::IN); + NL_TEST_ASSERT(inSuite, responder.GetQType() == QType::PTR); + NL_TEST_ASSERT(inSuite, responder.GetQName() == kNames); + + PtrResponseAccumulator acc(inSuite, 123); + chip::Inet::IPPacketInfo packetInfo; + + packetInfo.SrcAddress = ipAddress; + packetInfo.DestAddress = ipAddress; + packetInfo.SrcPort = kMdnsPort; + packetInfo.DestPort = kMdnsPort; + packetInfo.Interface = InterfaceId::Null(); + + responder.AddAllResponses(&packetInfo, &acc, ResponseConfiguration().SetTtlSecondsOverride(123)); } const nlTest sTests[] = { - NL_TEST_DEF("TestPtrResponse", TestPtrResponse), // - NL_TEST_SENTINEL() // + NL_TEST_DEF("TestPtrResponse", TestPtrResponse), // + NL_TEST_DEF("TestPtrResponseOverrideTtl", TestPtrResponseOverrideTtl), // + NL_TEST_SENTINEL() // }; } // namespace diff --git a/src/lib/dnssd/minimal_mdns/responders/tests/TestQueryResponder.cpp b/src/lib/dnssd/minimal_mdns/responders/tests/TestQueryResponder.cpp index c2db28a63627a1..aed19ef04cc8e7 100644 --- a/src/lib/dnssd/minimal_mdns/responders/tests/TestQueryResponder.cpp +++ b/src/lib/dnssd/minimal_mdns/responders/tests/TestQueryResponder.cpp @@ -38,7 +38,7 @@ class EmptyResponder : public RecordResponder { public: EmptyResponder(const FullQName & qName) : RecordResponder(QType::NULLVALUE, qName) {} - void AddAllResponses(const chip::Inet::IPPacketInfo *, ResponderDelegate *) override {} + void AddAllResponses(const chip::Inet::IPPacketInfo *, ResponderDelegate *, const ResponseConfiguration &) override {} }; class DnssdReplyAccumulator : public ResponderDelegate @@ -111,7 +111,7 @@ void RespondsToDnsSdQueries(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, responder.GetQName() == kDnsSdname); DnssdReplyAccumulator accumulator(inSuite); - responder.AddAllResponses(nullptr, &accumulator); + responder.AddAllResponses(nullptr, &accumulator, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, accumulator.Captures().size() == 2); if (accumulator.Captures().size() == 2) @@ -160,7 +160,7 @@ void NonDiscoverableService(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, responder.AddResponder(&empty2).SetReportInServiceListing(true).IsValid()); DnssdReplyAccumulator accumulator(inSuite); - responder.AddAllResponses(nullptr, &accumulator); + responder.AddAllResponses(nullptr, &accumulator, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, accumulator.Captures().size() == 1); if (accumulator.Captures().size() == 1) diff --git a/src/lib/dnssd/minimal_mdns/tests/TestResponseSender.cpp b/src/lib/dnssd/minimal_mdns/tests/TestResponseSender.cpp index c265bdab5a8604..0089566ad249ad 100644 --- a/src/lib/dnssd/minimal_mdns/tests/TestResponseSender.cpp +++ b/src/lib/dnssd/minimal_mdns/tests/TestResponseSender.cpp @@ -98,7 +98,7 @@ void SrvAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); common.server.AddExpectedRecord(&common.srvRecord); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -120,7 +120,7 @@ void SrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) // We requested ANY on the host name, expect both back. common.server.AddExpectedRecord(&common.srvRecord); common.server.AddExpectedRecord(&common.txtRecord); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -145,7 +145,7 @@ void PtrSrvTxtAnyResponseToServiceName(nlTestSuite * inSuite, void * inContext) common.server.AddExpectedRecord(&common.srvRecord); common.server.AddExpectedRecord(&common.txtRecord); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -169,7 +169,7 @@ void PtrSrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) common.server.AddExpectedRecord(&common.srvRecord); common.server.AddExpectedRecord(&common.txtRecord); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -192,7 +192,7 @@ void PtrSrvTxtSrvResponseToInstance(nlTestSuite * inSuite, void * inContext) // We didn't set the txt as an additional on the srv name so expect only srv. common.server.AddExpectedRecord(&common.srvRecord); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -216,7 +216,7 @@ void PtrSrvTxtAnyResponseToServiceListing(nlTestSuite * inSuite, void * inContex PtrResourceRecord serviceRecord = PtrResourceRecord(common.dnsSd, common.ptrRecord.GetName()); common.server.AddExpectedRecord(&serviceRecord); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, common.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common.server.GetHeaderFound()); @@ -230,15 +230,15 @@ void NoQueryResponder(nlTestSuite * inSuite, void * inContext) QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); common.recordWriter.WriteQName(common.dnsSd); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); common.recordWriter.WriteQName(common.service); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); common.recordWriter.WriteQName(common.instance); - responseSender.Respond(1, queryData, &common.packetInfo); + responseSender.Respond(1, queryData, &common.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); } @@ -304,7 +304,7 @@ void PtrSrvTxtMultipleRespondersToInstance(nlTestSuite * inSuite, void * inConte common1.server.AddExpectedRecord(&common2.srvRecord); common1.server.AddExpectedRecord(&common2.txtRecord); - responseSender.Respond(1, queryData, &common1.packetInfo); + responseSender.Respond(1, queryData, &common1.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, common1.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common1.server.GetHeaderFound()); @@ -338,7 +338,7 @@ void PtrSrvTxtMultipleRespondersToServiceListing(nlTestSuite * inSuite, void * i PtrResourceRecord serviceRecord2 = PtrResourceRecord(common2.dnsSd, common2.ptrRecord.GetName()); common1.server.AddExpectedRecord(&serviceRecord2); - responseSender.Respond(1, queryData, &common1.packetInfo); + responseSender.Respond(1, queryData, &common1.packetInfo, ResponseConfiguration()); NL_TEST_ASSERT(inSuite, common1.server.GetSendCalled()); NL_TEST_ASSERT(inSuite, common1.server.GetHeaderFound());