From 88afa3361bbb89b1999624016d79b4eb92245427 Mon Sep 17 00:00:00 2001 From: Nivi Sarkar <55898241+nivi-apple@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:31:18 -0700 Subject: [PATCH] Fix the dnssd code that browses on both the local and srp domains (#32675) * Fix the dnssd code that browses on both the local and srp domains - Fixes the UAF issue with the timer - Resolves critical comments from PR #32631 * Fix GetDomainFromHostName to get the domain name correctly * Restyled by clang-format --------- Co-authored-by: Restyled.io --- src/platform/Darwin/DnssdContexts.cpp | 105 ++++++++++++---------- src/platform/Darwin/DnssdImpl.cpp | 121 ++++++++++++++------------ src/platform/Darwin/DnssdImpl.h | 20 ++--- 3 files changed, 136 insertions(+), 110 deletions(-) diff --git a/src/platform/Darwin/DnssdContexts.cpp b/src/platform/Darwin/DnssdContexts.cpp index c6067a55168e12..a8ae75e34d2eb5 100644 --- a/src/platform/Darwin/DnssdContexts.cpp +++ b/src/platform/Darwin/DnssdContexts.cpp @@ -458,27 +458,35 @@ ResolveContext::ResolveContext(void * cbContext, DnssdResolveCallback cb, chip:: std::shared_ptr && consumerCounterToUse) : browseThatCausedResolve(browseCausingResolve) { - type = ContextType::Resolve; - context = cbContext; - callback = cb; - protocol = GetProtocol(cbAddressType); - instanceName = instanceNameToResolve; - consumerCounter = std::move(consumerCounterToUse); + type = ContextType::Resolve; + context = cbContext; + callback = cb; + protocol = GetProtocol(cbAddressType); + instanceName = instanceNameToResolve; + consumerCounter = std::move(consumerCounterToUse); + hasSrpTimerStarted = false; } ResolveContext::ResolveContext(CommissioningResolveDelegate * delegate, chip::Inet::IPAddressType cbAddressType, const char * instanceNameToResolve, std::shared_ptr && consumerCounterToUse) : browseThatCausedResolve(nullptr) { - type = ContextType::Resolve; - context = delegate; - callback = nullptr; - protocol = GetProtocol(cbAddressType); - instanceName = instanceNameToResolve; - consumerCounter = std::move(consumerCounterToUse); + type = ContextType::Resolve; + context = delegate; + callback = nullptr; + protocol = GetProtocol(cbAddressType); + instanceName = instanceNameToResolve; + consumerCounter = std::move(consumerCounterToUse); + hasSrpTimerStarted = false; } -ResolveContext::~ResolveContext() {} +ResolveContext::~ResolveContext() +{ + if (this->hasSrpTimerStarted) + { + CancelSrpTimer(this); + } +} void ResolveContext::DispatchFailure(const char * errorStr, CHIP_ERROR err) { @@ -526,8 +534,7 @@ void ResolveContext::DispatchSuccess() for (auto interfaceIndex : priorityInterfaceIndices) { - // Try finding interfaces for domains kLocalDot and kOpenThreadDot and delete them. - if (TryReportingResultsForInterfaceIndex(static_cast(interfaceIndex), std::string(kLocalDot))) + if (TryReportingResultsForInterfaceIndex(interfaceIndex)) { if (needDelete) { @@ -536,7 +543,7 @@ void ResolveContext::DispatchSuccess() return; } - if (TryReportingResultsForInterfaceIndex(static_cast(interfaceIndex), std::string(kOpenThreadDot))) + if (TryReportingResultsForInterfaceIndex(interfaceIndex)) { if (needDelete) { @@ -548,7 +555,8 @@ void ResolveContext::DispatchSuccess() for (auto & interface : interfaces) { - if (TryReportingResultsForInterfaceIndex(interface.first.first, interface.first.second)) + auto interfaceId = interface.first.first; + if (TryReportingResultsForInterfaceIndex(interfaceId)) { break; } @@ -560,7 +568,7 @@ void ResolveContext::DispatchSuccess() } } -bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName) +bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex) { if (interfaceIndex == 0) { @@ -568,36 +576,44 @@ bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceInde return false; } - std::pair interfaceKey = std::make_pair(interfaceIndex, domainName); - auto & interface = interfaces[interfaceKey]; - auto & ips = interface.addresses; - - // Some interface may not have any ips, just ignore them. - if (ips.size() == 0) + std::map, InterfaceInfo>::iterator iter = interfaces.begin(); + while (iter != interfaces.end()) { - return false; - } + std::pair key = iter->first; + if (key.first == interfaceIndex) + { + auto & interface = interfaces[key]; + auto & ips = interface.addresses; - ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex); + // Some interface may not have any ips, just ignore them. + if (ips.size() == 0) + { + return false; + } - auto & service = interface.service; - auto addresses = Span(ips.data(), ips.size()); - if (nullptr == callback) - { - auto delegate = static_cast(context); - DiscoveredNodeData nodeData; - service.ToDiscoveredNodeData(addresses, nodeData); - delegate->OnNodeDiscovered(nodeData); - } - else - { - callback(context, &service, addresses, CHIP_NO_ERROR); - } + ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex); - return true; + auto & service = interface.service; + auto addresses = Span(ips.data(), ips.size()); + if (nullptr == callback) + { + auto delegate = static_cast(context); + DiscoveredNodeData nodeData; + service.ToDiscoveredNodeData(addresses, nodeData); + delegate->OnNodeDiscovered(nodeData); + } + else + { + callback(context, &service, addresses, CHIP_NO_ERROR); + } + + return true; + } + } + return false; } -CHIP_ERROR ResolveContext::OnNewAddress(const std::pair interfaceKey, const struct sockaddr * address) +CHIP_ERROR ResolveContext::OnNewAddress(const std::pair & interfaceKey, const struct sockaddr * address) { // If we don't have any information about this interfaceId, just ignore the // address, since it won't be usable anyway without things like the port. @@ -605,7 +621,7 @@ CHIP_ERROR ResolveContext::OnNewAddress(const std::pair i // on the system, because the hostnames we are looking up all end in // ".local". In other words, we can get regular DNS results in here, not // just DNS-SD ones. - uint32_t interfaceId = interfaceKey.first; + auto interfaceId = interfaceKey.first; if (interfaces.find(interfaceKey) == interfaces.end()) { @@ -720,8 +736,7 @@ void ResolveContext::OnNewInterface(uint32_t interfaceId, const char * fullname, } std::pair interfaceKey = std::make_pair(interfaceId, domainFromHostname); - - interfaces.insert(std::make_pair(interfaceKey, std::move(interface))); + interfaces.insert(std::make_pair(std::move(interfaceKey), std::move(interface))); } bool ResolveContext::HasInterface() diff --git a/src/platform/Darwin/DnssdImpl.cpp b/src/platform/Darwin/DnssdImpl.cpp index bf0fc96dc7cc4f..94bbdc477d9b64 100644 --- a/src/platform/Darwin/DnssdImpl.cpp +++ b/src/platform/Darwin/DnssdImpl.cpp @@ -32,8 +32,12 @@ using namespace chip::Dnssd::Internal; namespace { -// The extra time in milliseconds that we will wait for the resolution on the open thread domain to complete. -constexpr uint16_t kOpenThreadTimeoutInMsec = 250; +constexpr char kLocalDot[] = "local."; + +constexpr char kSrpDot[] = "default.service.arpa."; + +// The extra time in milliseconds that we will wait for the resolution on the srp domain to complete. +constexpr uint16_t kSrpTimeoutInMsec = 250; constexpr DNSServiceFlags kRegisterFlags = kDNSServiceFlagsNoAutoRename; constexpr DNSServiceFlags kBrowseFlags = kDNSServiceFlagsShareConnection; @@ -144,59 +148,57 @@ std::string GetDomainFromHostName(const char * hostnameWithDomain) { std::string hostname = std::string(hostnameWithDomain); - // Find the last occurence of '.' - size_t last_pos = hostname.find_last_of("."); - if (last_pos != std::string::npos) - { - // Get a substring without last '.' - std::string substring = hostname.substr(0, last_pos); - - // Find the last occurence of '.' in the substring created above. - size_t pos = substring.find_last_of("."); - if (pos != std::string::npos) - { - // Return the domain name between the last 2 occurences of '.' including the trailing dot'.'. - return std::string(hostname.substr(pos + 1, last_pos)); - } - } - return std::string(); -} + // Find the first occurence of '.' + size_t first_pos = hostname.find("."); -Global MdnsContexts::sInstance; + // if not found, return empty string + VerifyOrReturnValue(first_pos != std::string::npos, std::string()); -namespace { + // Get a substring after the first occurence of '.' to the end of the string + return hostname.substr(first_pos + 1, hostname.size()); +} /** - * @brief Callback that is called when the timeout for resolving on the kOpenThreadDot domain has expired. + * @brief Callback that is called when the timeout for resolving on the kSrpDot domain has expired. * * @param[in] systemLayer The system layer. * @param[in] callbackContext The context passed to the timer callback. */ -void OpenThreadTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext) +void SrpTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext) { - ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the open thread domain."); + ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the srp domain."); auto sdCtx = static_cast(callbackContext); VerifyOrDie(sdCtx != nullptr); - - if (sdCtx->hasOpenThreadTimerStarted) - { - sdCtx->Finalize(); - } + sdCtx->Finalize(); } /** - * @brief Starts a timer to wait for the resolution on the kOpenThreadDot domain to happen. + * @brief Starts a timer to wait for the resolution on the kSrpDot domain to happen. * * @param[in] timeoutSeconds The timeout in seconds. * @param[in] ResolveContext The resolve context. */ -void StartOpenThreadTimer(uint16_t timeoutInMSecs, ResolveContext * ctx) +CHIP_ERROR StartSrpTimer(uint16_t timeoutInMSecs, ResolveContext * ctx) { - VerifyOrReturn(ctx != nullptr, ChipLogError(Discovery, "Can't schedule open thread timer since context is null")); - DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), OpenThreadTimerExpiredCallback, - reinterpret_cast(ctx)); + VerifyOrReturnValue(ctx != nullptr, CHIP_ERROR_INCORRECT_STATE); + return DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), SrpTimerExpiredCallback, + reinterpret_cast(ctx)); } +/** + * @brief Cancels the timer that was started to wait for the resolution on the kSrpDot domain to happen. + * + * @param[in] ResolveContext The resolve context. + */ +void CancelSrpTimer(ResolveContext * ctx) +{ + DeviceLayer::SystemLayer().CancelTimer(SrpTimerExpiredCallback, reinterpret_cast(ctx)); +} + +Global MdnsContexts::sInstance; + +namespace { + static void OnRegister(DNSServiceRef sdRef, DNSServiceFlags flags, DNSServiceErrorType err, const char * name, const char * type, const char * domain, void * context) { @@ -248,17 +250,17 @@ CHIP_ERROR Browse(BrowseHandler * sdCtx, uint32_t interfaceId, const char * type auto err = DNSServiceCreateConnection(&sdCtx->serviceRef); VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err)); - // We will browse on both the local domain and the open thread domain. + // We will browse on both the local domain and the srp domain. ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kLocalDot); auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection err = DNSServiceBrowse(&sdRefLocal, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx); VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err)); - ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kOpenThreadDot); + ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kSrpDot); - DNSServiceRef sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection - err = DNSServiceBrowse(&sdRefOpenThread, kBrowseFlags, interfaceId, type, kOpenThreadDot, OnBrowse, sdCtx); + auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection + err = DNSServiceBrowse(&sdRefSrp, kBrowseFlags, interfaceId, type, kSrpDot, OnBrowse, sdCtx); VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err)); return MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef); @@ -307,25 +309,37 @@ static void OnGetAddrInfo(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t i { VerifyOrReturn(sdCtx->HasAddress(), sdCtx->Finalize(kDNSServiceErr_BadState)); - if (domainName.compare(kOpenThreadDot) == 0) + if (domainName.compare(kSrpDot) == 0) { - ChipLogProgress(Discovery, "Mdns: Resolve completed on the open thread domain."); + ChipLogProgress(Discovery, "Mdns: Resolve completed on the srp domain."); + + // Cancel the timer if one has been started + if (sdCtx->hasSrpTimerStarted) + { + CancelSrpTimer(sdCtx); + } sdCtx->Finalize(); } else if (domainName.compare(kLocalDot) == 0) { - ChipLogProgress( - Discovery, - "Mdns: Resolve completed on the local domain. Starting a timer for the open thread resolve to come back"); + ChipLogProgress(Discovery, + "Mdns: Resolve completed on the local domain. Starting a timer for the srp resolve to come back"); - // Usually the resolution on the local domain is quicker than on the open thread domain. We would like to give the - // resolution on the open thread domain around 250 millisecs more to give it a chance to resolve before finalizing + // Usually the resolution on the local domain is quicker than on the srp domain. We would like to give the + // resolution on the srp domain around 250 millisecs more to give it a chance to resolve before finalizing // the resolution. - if (!sdCtx->hasOpenThreadTimerStarted) + if (!sdCtx->hasSrpTimerStarted) { - // Schedule a timer to allow the resolve on OpenThread domain to complete. - StartOpenThreadTimer(kOpenThreadTimeoutInMsec, sdCtx); - sdCtx->hasOpenThreadTimerStarted = true; + // Schedule a timer to allow the resolve on Srp domain to complete. + CHIP_ERROR error = StartSrpTimer(kSrpTimeoutInMsec, sdCtx); + + // If the timer fails to start, finalize the context and return. + if (error != CHIP_NO_ERROR) + { + sdCtx->Finalize(); + return; + } + sdCtx->hasSrpTimerStarted = true; } } } @@ -367,8 +381,7 @@ static void OnResolve(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t inter if (!sdCtx->isResolveRequested) { GetAddrInfo(sdCtx); - sdCtx->isResolveRequested = true; - sdCtx->hasOpenThreadTimerStarted = false; + sdCtx->isResolveRequested = true; } } } @@ -382,13 +395,13 @@ static CHIP_ERROR Resolve(ResolveContext * sdCtx, uint32_t interfaceId, chip::In auto err = DNSServiceCreateConnection(&sdCtx->serviceRef); VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err)); - // Similar to browse, will try to resolve using both the local domain and the open thread domain. + // Similar to browse, will try to resolve using both the local domain and the srp domain. auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection err = DNSServiceResolve(&sdRefLocal, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx); VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err)); - auto sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection - err = DNSServiceResolve(&sdRefOpenThread, kResolveFlags, interfaceId, name, type, kOpenThreadDot, OnResolve, sdCtx); + auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection + err = DNSServiceResolve(&sdRefSrp, kResolveFlags, interfaceId, name, type, kSrpDot, OnResolve, sdCtx); VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err)); auto retval = MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef); diff --git a/src/platform/Darwin/DnssdImpl.h b/src/platform/Darwin/DnssdImpl.h index 713a465c236afc..8a9853082b59ee 100644 --- a/src/platform/Darwin/DnssdImpl.h +++ b/src/platform/Darwin/DnssdImpl.h @@ -27,15 +27,17 @@ #include #include -constexpr char kLocalDot[] = "local."; - -constexpr char kOpenThreadDot[] = "default.service.arpa."; - namespace chip { namespace Dnssd { +struct BrowseWithDelegateContext; +struct RegisterContext; +struct ResolveContext; + std::string GetDomainFromHostName(const char * hostname); +void CancelSrpTimer(ResolveContext * ctx); + enum class ContextType { Register, @@ -62,10 +64,6 @@ struct GenericContext CHIP_ERROR FinalizeInternal(const char * errorStr, CHIP_ERROR err); }; -struct BrowseWithDelegateContext; -struct RegisterContext; -struct ResolveContext; - class MdnsContexts { public: @@ -239,7 +237,7 @@ struct ResolveContext : public GenericContext bool isResolveRequested = false; std::shared_ptr consumerCounter; BrowseContext * const browseThatCausedResolve; // Can be null - bool hasOpenThreadTimerStarted = false; + bool hasSrpTimerStarted = false; // browseCausingResolve can be null. ResolveContext(void * cbContext, DnssdResolveCallback cb, chip::Inet::IPAddressType cbAddressType, @@ -252,7 +250,7 @@ struct ResolveContext : public GenericContext void DispatchFailure(const char * errorStr, CHIP_ERROR err) override; void DispatchSuccess() override; - CHIP_ERROR OnNewAddress(const std::pair interfaceKey, const struct sockaddr * address); + CHIP_ERROR OnNewAddress(const std::pair & interfaceKey, const struct sockaddr * address); bool HasAddress(); void OnNewInterface(uint32_t interfaceId, const char * fullname, const char * hostname, uint16_t port, uint16_t txtLen, @@ -266,7 +264,7 @@ struct ResolveContext : public GenericContext * Returns true if information was reported, false if not (e.g. if there * were no IP addresses, etc). */ - bool TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName); + bool TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex); }; } // namespace Dnssd