Skip to content

Commit

Permalink
Dnssd changes to browse and resolve using open thread domain along wi… (
Browse files Browse the repository at this point in the history
project-chip#32631)

* Dnssd changes to browse and resolve using open thread domain along with the local domain

* Add checks for empty domain

* Restyled by clang-format

* Update src/platform/Darwin/DnssdImpl.cpp

Co-authored-by: Karsten Sperling <[email protected]>

* Addressed review comments

* Restyled by clang-format

---------

Co-authored-by: Restyled.io <[email protected]>
Co-authored-by: Karsten Sperling <[email protected]>
  • Loading branch information
3 people authored Mar 21, 2024
1 parent 03531f7 commit cc576dc
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 40 deletions.
52 changes: 32 additions & 20 deletions src/platform/Darwin/DnssdContexts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ namespace {

constexpr uint8_t kDnssdKeyMaxSize = 32;
constexpr uint8_t kDnssdTxtRecordMaxEntries = 20;
constexpr char kLocalDot[] = "local.";

bool IsLocalDomain(const char * domain)
{
return strcmp(kLocalDot, domain) == 0;
}

std::string GetHostNameWithoutDomain(const char * hostnameWithDomain)
{
Expand Down Expand Up @@ -252,6 +246,7 @@ void MdnsContexts::Delete(GenericContext * context)
{
DNSServiceRefDeallocate(context->serviceRef);
}

chip::Platform::Delete(context);
}

Expand Down Expand Up @@ -388,7 +383,6 @@ void BrowseContext::OnBrowseAdd(const char * name, const char * type, const char
ChipLogProgress(Discovery, "Mdns: %s name: %s, type: %s, domain: %s, interface: %" PRIu32, __func__, StringOrNullMarker(name),
StringOrNullMarker(type), StringOrNullMarker(domain), interfaceId);

VerifyOrReturn(IsLocalDomain(domain));
auto service = GetService(name, type, protocol, interfaceId);
services.push_back(service);
}
Expand All @@ -399,7 +393,6 @@ void BrowseContext::OnBrowseRemove(const char * name, const char * type, const c
StringOrNullMarker(type), StringOrNullMarker(domain), interfaceId);

VerifyOrReturn(name != nullptr);
VerifyOrReturn(IsLocalDomain(domain));

services.erase(std::remove_if(services.begin(), services.end(),
[name, type, interfaceId](const DnssdService & service) {
Expand Down Expand Up @@ -443,8 +436,6 @@ void BrowseWithDelegateContext::OnBrowseAdd(const char * name, const char * type
ChipLogProgress(Discovery, "Mdns: %s name: %s, type: %s, domain: %s, interface: %" PRIu32, __func__, StringOrNullMarker(name),
StringOrNullMarker(type), StringOrNullMarker(domain), interfaceId);

VerifyOrReturn(IsLocalDomain(domain));

auto delegate = static_cast<DnssdBrowseDelegate *>(context);
auto service = GetService(name, type, protocol, interfaceId);
delegate->OnBrowseAdd(service);
Expand All @@ -456,7 +447,6 @@ void BrowseWithDelegateContext::OnBrowseRemove(const char * name, const char * t
StringOrNullMarker(type), StringOrNullMarker(domain), interfaceId);

VerifyOrReturn(name != nullptr);
VerifyOrReturn(IsLocalDomain(domain));

auto delegate = static_cast<DnssdBrowseDelegate *>(context);
auto service = GetService(name, type, protocol, interfaceId);
Expand Down Expand Up @@ -536,7 +526,17 @@ void ResolveContext::DispatchSuccess()

for (auto interfaceIndex : priorityInterfaceIndices)
{
if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex)))
// Try finding interfaces for domains kLocalDot and kOpenThreadDot and delete them.
if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kLocalDot)))
{
if (needDelete)
{
MdnsContexts::GetInstance().Delete(this);
}
return;
}

if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kOpenThreadDot)))
{
if (needDelete)
{
Expand All @@ -548,7 +548,7 @@ void ResolveContext::DispatchSuccess()

for (auto & interface : interfaces)
{
if (TryReportingResultsForInterfaceIndex(interface.first))
if (TryReportingResultsForInterfaceIndex(interface.first.first, interface.first.second))
{
break;
}
Expand All @@ -560,16 +560,17 @@ void ResolveContext::DispatchSuccess()
}
}

bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex)
bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName)
{
if (interfaceIndex == 0)
{
// Not actually an interface we have.
return false;
}

auto & interface = interfaces[interfaceIndex];
auto & ips = interface.addresses;
std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceIndex, domainName);
auto & interface = interfaces[interfaceKey];
auto & ips = interface.addresses;

// Some interface may not have any ips, just ignore them.
if (ips.size() == 0)
Expand All @@ -596,15 +597,17 @@ bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceInde
return true;
}

CHIP_ERROR ResolveContext::OnNewAddress(uint32_t interfaceId, const struct sockaddr * address)
CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address)
{
// If we don't have any information about this interfaceId, just ignore the
// address, since it won't be usable anyway without things like the port.
// This can happen if "local" is set up as a search domain in the DNS setup
// on the system, because the hostnames we are looking up all end in
// ".local". In other words, we can get regular DNS results in here, not
// just DNS-SD ones.
if (interfaces.find(interfaceId) == interfaces.end())
uint32_t interfaceId = interfaceKey.first;

if (interfaces.find(interfaceKey) == interfaces.end())
{
return CHIP_NO_ERROR;
}
Expand All @@ -627,7 +630,7 @@ CHIP_ERROR ResolveContext::OnNewAddress(uint32_t interfaceId, const struct socka
return CHIP_NO_ERROR;
}

interfaces[interfaceId].addresses.push_back(ip);
interfaces[interfaceKey].addresses.push_back(ip);

return CHIP_NO_ERROR;
}
Expand Down Expand Up @@ -709,7 +712,16 @@ void ResolveContext::OnNewInterface(uint32_t interfaceId, const char * fullname,
// resolving.
interface.fullyQualifiedDomainName = hostnameWithDomain;

interfaces.insert(std::make_pair(interfaceId, std::move(interface)));
std::string domainFromHostname = GetDomainFromHostName(hostnameWithDomain);
if (domainFromHostname.empty())
{
ChipLogError(Discovery, "Mdns: No domain set in hostname %s", hostnameWithDomain);
return;
}

std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceId, domainFromHostname);

interfaces.insert(std::make_pair(interfaceKey, std::move(interface)));
}

bool ResolveContext::HasInterface()
Expand Down
145 changes: 128 additions & 17 deletions src/platform/Darwin/DnssdImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@
#include <lib/support/logging/CHIPLogging.h>
#include <platform/CHIPDeviceLayer.h>

using namespace chip;
using namespace chip::Dnssd;
using namespace chip::Dnssd::Internal;

namespace {

constexpr char kLocalDot[] = "local.";
// The extra time in milliseconds that we will wait for the resolution on the open thread domain to complete.
constexpr uint16_t kOpenThreadTimeoutInMsec = 250;

constexpr DNSServiceFlags kRegisterFlags = kDNSServiceFlagsNoAutoRename;
constexpr DNSServiceFlags kBrowseFlags = 0;
constexpr DNSServiceFlags kBrowseFlags = kDNSServiceFlagsShareConnection;
constexpr DNSServiceFlags kGetAddrInfoFlags = kDNSServiceFlagsTimeout | kDNSServiceFlagsShareConnection;
constexpr DNSServiceFlags kResolveFlags = kDNSServiceFlagsShareConnection;
constexpr DNSServiceFlags kReconfirmRecordFlags = 0;
Expand All @@ -49,7 +51,7 @@ uint32_t GetInterfaceId(chip::Inet::InterfaceId interfaceId)
return interfaceId.IsPresent() ? interfaceId.GetPlatformInterface() : kDNSServiceInterfaceIndexAny;
}

std::string GetHostNameWithDomain(const char * hostname)
std::string GetHostNameWithLocalDomain(const char * hostname)
{
return std::string(hostname) + '.' + kLocalDot;
}
Expand Down Expand Up @@ -131,10 +133,70 @@ std::shared_ptr<uint32_t> GetCounterHolder(const char * name)
namespace chip {
namespace Dnssd {

/**
* @brief Returns the domain name from a given hostname with domain.
* The assumption here is that the hostname comprises of "hostnameWithoutDomain.<domain>."
* The domainName returned from this API is "<domain>."
*
* @param[in] hostname The hostname with domain.
*/
std::string GetDomainFromHostName(const char * hostnameWithDomain)
{
std::string hostname = std::string(hostnameWithDomain);

// Find the last occurence of '.'
size_t last_pos = hostname.find_last_of(".");
if (last_pos != std::string::npos)
{
// Get a substring without last '.'
std::string substring = hostname.substr(0, last_pos);

// Find the last occurence of '.' in the substring created above.
size_t pos = substring.find_last_of(".");
if (pos != std::string::npos)
{
// Return the domain name between the last 2 occurences of '.' including the trailing dot'.'.
return std::string(hostname.substr(pos + 1, last_pos));
}
}
return std::string();
}

Global<MdnsContexts> MdnsContexts::sInstance;

namespace {

/**
* @brief Callback that is called when the timeout for resolving on the kOpenThreadDot domain has expired.
*
* @param[in] systemLayer The system layer.
* @param[in] callbackContext The context passed to the timer callback.
*/
void OpenThreadTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
{
ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the open thread domain.");
auto sdCtx = static_cast<ResolveContext *>(callbackContext);
VerifyOrDie(sdCtx != nullptr);

if (sdCtx->hasOpenThreadTimerStarted)
{
sdCtx->Finalize();
}
}

/**
* @brief Starts a timer to wait for the resolution on the kOpenThreadDot domain to happen.
*
* @param[in] timeoutSeconds The timeout in seconds.
* @param[in] ResolveContext The resolve context.
*/
void StartOpenThreadTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
{
VerifyOrReturn(ctx != nullptr, ChipLogError(Discovery, "Can't schedule open thread timer since context is null"));
DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), OpenThreadTimerExpiredCallback,
reinterpret_cast<void *>(ctx));
}

static void OnRegister(DNSServiceRef sdRef, DNSServiceFlags flags, DNSServiceErrorType err, const char * name, const char * type,
const char * domain, void * context)
{
Expand Down Expand Up @@ -183,14 +245,24 @@ static void OnBrowse(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t interf

CHIP_ERROR Browse(BrowseHandler * sdCtx, uint32_t interfaceId, const char * type)
{
ChipLogProgress(Discovery, "Browsing for: %s", StringOrNullMarker(type));
DNSServiceRef sdRef;
auto err = DNSServiceBrowse(&sdRef, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx);
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

return MdnsContexts::GetInstance().Add(sdCtx, sdRef);
}
// We will browse on both the local domain and the open thread domain.
ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kLocalDot);

auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceBrowse(&sdRefLocal, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kOpenThreadDot);

DNSServiceRef sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceBrowse(&sdRefOpenThread, kBrowseFlags, interfaceId, type, kOpenThreadDot, OnBrowse, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

return MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
}
CHIP_ERROR Browse(void * context, DnssdBrowseCallback callback, uint32_t interfaceId, const char * type,
DnssdServiceProtocol protocol, intptr_t * browseIdentifier)
{
Expand Down Expand Up @@ -219,25 +291,52 @@ static void OnGetAddrInfo(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t i
ReturnOnFailure(MdnsContexts::GetInstance().Has(sdCtx));
LogOnFailure(__func__, err);

std::string domainName = GetDomainFromHostName(hostname);
if (domainName.empty())
{
ChipLogError(Discovery, "Mdns: Domain name is not set in hostname %s", hostname);
return;
}
if (kDNSServiceErr_NoError == err)
{
sdCtx->OnNewAddress(interfaceId, address);
std::pair<uint32_t, std::string> key = std::make_pair(interfaceId, domainName);
sdCtx->OnNewAddress(key, address);
}

if (!(flags & kDNSServiceFlagsMoreComing))
{
VerifyOrReturn(sdCtx->HasAddress(), sdCtx->Finalize(kDNSServiceErr_BadState));
sdCtx->Finalize();

if (domainName.compare(kOpenThreadDot) == 0)
{
ChipLogProgress(Discovery, "Mdns: Resolve completed on the open thread domain.");
sdCtx->Finalize();
}
else if (domainName.compare(kLocalDot) == 0)
{
ChipLogProgress(
Discovery,
"Mdns: Resolve completed on the local domain. Starting a timer for the open thread resolve to come back");

// Usually the resolution on the local domain is quicker than on the open thread domain. We would like to give the
// resolution on the open thread domain around 250 millisecs more to give it a chance to resolve before finalizing
// the resolution.
if (!sdCtx->hasOpenThreadTimerStarted)
{
// Schedule a timer to allow the resolve on OpenThread domain to complete.
StartOpenThreadTimer(kOpenThreadTimeoutInMsec, sdCtx);
sdCtx->hasOpenThreadTimerStarted = true;
}
}
}
}

static void GetAddrInfo(ResolveContext * sdCtx)
{
auto protocol = sdCtx->protocol;

for (auto & interface : sdCtx->interfaces)
{
auto interfaceId = interface.first;
auto interfaceId = interface.first.first;
auto hostname = interface.second.fullyQualifiedDomainName.c_str();
auto sdRefCopy = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
auto err = DNSServiceGetAddrInfo(&sdRefCopy, kGetAddrInfoFlags, interfaceId, protocol, hostname, OnGetAddrInfo, sdCtx);
Expand All @@ -263,7 +362,14 @@ static void OnResolve(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t inter
if (!(flags & kDNSServiceFlagsMoreComing))
{
VerifyOrReturn(sdCtx->HasInterface(), sdCtx->Finalize(kDNSServiceErr_BadState));
GetAddrInfo(sdCtx);

// If a resolve was not requested on this context, call GetAddrInfo and set the isResolveRequested flag to true.
if (!sdCtx->isResolveRequested)
{
GetAddrInfo(sdCtx);
sdCtx->isResolveRequested = true;
sdCtx->hasOpenThreadTimerStarted = false;
}
}
}

Expand All @@ -276,8 +382,13 @@ static CHIP_ERROR Resolve(ResolveContext * sdCtx, uint32_t interfaceId, chip::In
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

auto sdRefCopy = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefCopy, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx);
// Similar to browse, will try to resolve using both the local domain and the open thread domain.
auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefLocal, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

auto sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefOpenThread, kResolveFlags, interfaceId, name, type, kOpenThreadDot, OnResolve, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

auto retval = MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
Expand Down Expand Up @@ -339,7 +450,7 @@ CHIP_ERROR ChipDnssdPublishService(const DnssdService * service, DnssdPublishCal

auto regtype = GetFullTypeWithSubTypes(service);
auto interfaceId = GetInterfaceId(service->mInterface);
auto hostname = GetHostNameWithDomain(service->mHostName);
auto hostname = GetHostNameWithLocalDomain(service->mHostName);

return Register(context, callback, interfaceId, regtype.c_str(), service->mName, service->mPort, record, service->mAddressType,
hostname.c_str());
Expand Down Expand Up @@ -485,7 +596,7 @@ CHIP_ERROR ChipDnssdReconfirmRecord(const char * hostname, chip::Inet::IPAddress

auto interfaceId = interface.GetPlatformInterface();
auto rrclass = kDNSServiceClass_IN;
auto fullname = GetHostNameWithDomain(hostname);
auto fullname = GetHostNameWithLocalDomain(hostname);

uint16_t rrtype;
uint16_t rdlen;
Expand Down
Loading

0 comments on commit cc576dc

Please sign in to comment.