Skip to content

Commit

Permalink
[dnssd] Fixed OT DNS API usage for Thread platform (#26199)
Browse files Browse the repository at this point in the history
* [dnssd] Fixed OT DNS API usage for Thread platform

Introduced several fixes to the Thread platform DNS implementation:
* Added checking if memory allocation for DnsResult was successful
and dispatching dedicated methods to inform upper layer in case
of memory allocation failure
* Added checking if DNS response for DNS resolve includes AAAA
record. In case it doesn't the additional DNS query to obtain
IPv6 address will be sent.
* Added checking if DNS response for DNS browse includes SRV, TXT
and AAAA records. In case it doesn't the additional DNS queries
to obtain SRV + TXT, and AAAA records will be sent.

* Addressed review comments

* Fixed error handling and potential memory leaks
* Moved handling resolve after browse from Thread platform
to Discovery_ImplPlatform.

* Addressed second code review

* Fixed string copying by adding the exact size of data to copy
instead of relying on the max buffer size.
  • Loading branch information
kkasperczyk-no authored and pull[bot] committed Jan 4, 2024
1 parent e7597c5 commit 1104828
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 46 deletions.
4 changes: 3 additions & 1 deletion src/lib/dnssd/Discovery_ImplPlatform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ static void HandleNodeBrowse(void * context, DnssdService * services, size_t ser
{
proxy->Retain();
// For some platforms browsed services are already resolved, so verify if resolve is really needed or call resolve callback
if (!services[i].mAddress.HasValue())

// Check if SRV, TXT and AAAA records were received in DNS responses
if (strlen(services[i].mHostName) == 0 || services[i].mTextEntrySize == 0 || !services[i].mAddress.HasValue())
{
ChipDnssdResolve(&services[i], services[i].mInterface, HandleNodeResolve, context);
}
Expand Down
199 changes: 156 additions & 43 deletions src/platform/OpenThread/GenericThreadStackManagerImpl_OpenThread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2498,31 +2498,20 @@ CHIP_ERROR GenericThreadStackManagerImpl_OpenThread<ImplClass>::_SetSrpDnsCallba
template <class ImplClass>
CHIP_ERROR GenericThreadStackManagerImpl_OpenThread<ImplClass>::FromOtDnsResponseToMdnsData(
otDnsServiceInfo & serviceInfo, const char * serviceType, chip::Dnssd::DnssdService & mdnsService,
DnsServiceTxtEntries & serviceTxtEntries)
DnsServiceTxtEntries & serviceTxtEntries, otError error)
{
char protocol[chip::Dnssd::kDnssdProtocolTextMaxSize + 1];

if (strchr(serviceInfo.mHostNameBuffer, '.') == nullptr)
return CHIP_ERROR_INVALID_ARGUMENT;

// Extract from the <hostname>.<domain-name>. the <hostname> part.
size_t substringSize = strchr(serviceInfo.mHostNameBuffer, '.') - serviceInfo.mHostNameBuffer;
if (substringSize >= ArraySize(mdnsService.mHostName))
{
return CHIP_ERROR_INVALID_ARGUMENT;
}
Platform::CopyString(mdnsService.mHostName, serviceInfo.mHostNameBuffer);

if (strchr(serviceType, '.') == nullptr)
return CHIP_ERROR_INVALID_ARGUMENT;

// Extract from the <type>.<protocol>.<domain-name>. the <type> part.
substringSize = strchr(serviceType, '.') - serviceType;
size_t substringSize = strchr(serviceType, '.') - serviceType;
if (substringSize >= ArraySize(mdnsService.mType))
{
return CHIP_ERROR_INVALID_ARGUMENT;
}
Platform::CopyString(mdnsService.mType, serviceType);
Platform::CopyString(mdnsService.mType, substringSize + 1, serviceType);

// Extract from the <type>.<protocol>.<domain-name>. the <protocol> part.
const char * protocolSubstringStart = serviceType + substringSize + 1;
Expand All @@ -2535,7 +2524,7 @@ CHIP_ERROR GenericThreadStackManagerImpl_OpenThread<ImplClass>::FromOtDnsRespons
{
return CHIP_ERROR_INVALID_ARGUMENT;
}
Platform::CopyString(protocol, protocolSubstringStart);
Platform::CopyString(protocol, substringSize + 1, protocolSubstringStart);

if (strncmp(protocol, "_udp", chip::Dnssd::kDnssdProtocolTextMaxSize) == 0)
{
Expand All @@ -2549,37 +2538,97 @@ CHIP_ERROR GenericThreadStackManagerImpl_OpenThread<ImplClass>::FromOtDnsRespons
{
mdnsService.mProtocol = chip::Dnssd::DnssdServiceProtocol::kDnssdProtocolUnknown;
}
mdnsService.mPort = serviceInfo.mPort;
mdnsService.mInterface = Inet::InterfaceId::Null();
mdnsService.mAddressType = Inet::IPAddressType::kIPv6;
mdnsService.mAddress = chip::Optional<chip::Inet::IPAddress>(ToIPAddress(serviceInfo.mHostAddress));

otDnsTxtEntryIterator iterator;
otDnsInitTxtEntryIterator(&iterator, serviceInfo.mTxtData, serviceInfo.mTxtDataSize);
// Check if SRV record was included in DNS response.
if (error != OT_ERROR_NOT_FOUND)
{
if (strchr(serviceInfo.mHostNameBuffer, '.') == nullptr)
return CHIP_ERROR_INVALID_ARGUMENT;

otDnsTxtEntry txtEntry;
FixedBufferAllocator alloc(serviceTxtEntries.mBuffer);
// Extract from the <hostname>.<domain-name>. the <hostname> part.
substringSize = strchr(serviceInfo.mHostNameBuffer, '.') - serviceInfo.mHostNameBuffer;
if (substringSize >= ArraySize(mdnsService.mHostName))
{
return CHIP_ERROR_INVALID_ARGUMENT;
}
Platform::CopyString(mdnsService.mHostName, substringSize + 1, serviceInfo.mHostNameBuffer);

uint8_t entryIndex = 0;
while ((otDnsGetNextTxtEntry(&iterator, &txtEntry) == OT_ERROR_NONE) && entryIndex < kMaxDnsServiceTxtEntriesNumber)
{
if (txtEntry.mKey == nullptr || txtEntry.mValue == nullptr)
continue;
mdnsService.mPort = serviceInfo.mPort;
}

serviceTxtEntries.mTxtEntries[entryIndex].mKey = alloc.Clone(txtEntry.mKey);
serviceTxtEntries.mTxtEntries[entryIndex].mData = alloc.Clone(txtEntry.mValue, txtEntry.mValueLength);
serviceTxtEntries.mTxtEntries[entryIndex].mDataSize = txtEntry.mValueLength;
entryIndex++;
mdnsService.mInterface = Inet::InterfaceId::Null();

// Check if AAAA record was included in DNS response.

if (!otIp6IsAddressUnspecified(&serviceInfo.mHostAddress))
{
mdnsService.mAddressType = Inet::IPAddressType::kIPv6;
mdnsService.mAddress = MakeOptional(ToIPAddress(serviceInfo.mHostAddress));
}

ReturnErrorCodeIf(alloc.AnyAllocFailed(), CHIP_ERROR_BUFFER_TOO_SMALL);
// Check if TXT record was included in DNS response.
if (serviceInfo.mTxtDataSize != 0)
{
otDnsTxtEntryIterator iterator;
otDnsInitTxtEntryIterator(&iterator, serviceInfo.mTxtData, serviceInfo.mTxtDataSize);

otDnsTxtEntry txtEntry;
FixedBufferAllocator alloc(serviceTxtEntries.mBuffer);

mdnsService.mTextEntries = serviceTxtEntries.mTxtEntries;
mdnsService.mTextEntrySize = entryIndex;
uint8_t entryIndex = 0;
while ((otDnsGetNextTxtEntry(&iterator, &txtEntry) == OT_ERROR_NONE) && entryIndex < kMaxDnsServiceTxtEntriesNumber)
{
if (txtEntry.mKey == nullptr || txtEntry.mValue == nullptr)
continue;

serviceTxtEntries.mTxtEntries[entryIndex].mKey = alloc.Clone(txtEntry.mKey);
serviceTxtEntries.mTxtEntries[entryIndex].mData = alloc.Clone(txtEntry.mValue, txtEntry.mValueLength);
serviceTxtEntries.mTxtEntries[entryIndex].mDataSize = txtEntry.mValueLength;
entryIndex++;
}

ReturnErrorCodeIf(alloc.AnyAllocFailed(), CHIP_ERROR_BUFFER_TOO_SMALL);

mdnsService.mTextEntries = serviceTxtEntries.mTxtEntries;
mdnsService.mTextEntrySize = entryIndex;
}

return CHIP_NO_ERROR;
}

template <class ImplClass>
CHIP_ERROR GenericThreadStackManagerImpl_OpenThread<ImplClass>::ResolveAddress(intptr_t context, otDnsAddressCallback callback)
{
DnsResult * dnsResult = reinterpret_cast<DnsResult *>(context);

ThreadStackMgrImpl().LockThreadStack();

char fullHostName[chip::Dnssd::kHostNameMaxLength + 1 + SrpClient::kDefaultDomainNameSize + 1];
snprintf(fullHostName, sizeof(fullHostName), "%s.%s", dnsResult->mMdnsService.mHostName, SrpClient::kDefaultDomainName);

CHIP_ERROR error = MapOpenThreadError(otDnsClientResolveAddress(ThreadStackMgrImpl().OTInstance(), fullHostName, callback,
reinterpret_cast<void *>(dnsResult), NULL));

ThreadStackMgrImpl().UnlockThreadStack();

return error;
}

template <class ImplClass>
void GenericThreadStackManagerImpl_OpenThread<ImplClass>::DispatchAddressResolve(intptr_t context)
{
CHIP_ERROR error = ResolveAddress(context, OnDnsAddressResolveResult);

// In case of address resolve failure, fill the error code field and dispatch method to end resolve process.
if (error != CHIP_NO_ERROR)
{
DnsResult * dnsResult = reinterpret_cast<DnsResult *>(context);
dnsResult->error = error;

DeviceLayer::PlatformMgr().ScheduleWork(DispatchResolve, reinterpret_cast<intptr_t>(dnsResult));
}
}

template <class ImplClass>
void GenericThreadStackManagerImpl_OpenThread<ImplClass>::DispatchResolve(intptr_t context)
{
Expand All @@ -2596,6 +2645,13 @@ void GenericThreadStackManagerImpl_OpenThread<ImplClass>::DispatchResolve(intptr
Platform::Delete<DnsResult>(dnsResult);
}

template <class ImplClass>
void GenericThreadStackManagerImpl_OpenThread<ImplClass>::DispatchResolveNoMemory(intptr_t context)
{
Span<Inet::IPAddress> ipAddrs;
ThreadStackMgrImpl().mDnsResolveCallback(reinterpret_cast<void *>(context), nullptr, ipAddrs, CHIP_ERROR_NO_MEMORY);
}

template <class ImplClass>
void GenericThreadStackManagerImpl_OpenThread<ImplClass>::DispatchBrowseEmpty(intptr_t context)
{
Expand All @@ -2612,6 +2668,12 @@ void GenericThreadStackManagerImpl_OpenThread<ImplClass>::DispatchBrowse(intptr_
Platform::Delete<DnsResult>(dnsResult);
}

template <class ImplClass>
void GenericThreadStackManagerImpl_OpenThread<ImplClass>::DispatchBrowseNoMemory(intptr_t context)
{
ThreadStackMgrImpl().mDnsBrowseCallback(reinterpret_cast<void *>(context), nullptr, 0, true, CHIP_ERROR_NO_MEMORY);
}

template <class ImplClass>
void GenericThreadStackManagerImpl_OpenThread<ImplClass>::OnDnsBrowseResult(otError aError, const otDnsBrowseResponse * aResponse,
void * aContext)
Expand Down Expand Up @@ -2647,12 +2709,16 @@ void GenericThreadStackManagerImpl_OpenThread<ImplClass>::OnDnsBrowseResult(otEr
serviceInfo.mTxtData = txtBuffer;
serviceInfo.mTxtDataSize = sizeof(txtBuffer);

error = MapOpenThreadError(otDnsBrowseResponseGetServiceInfo(aResponse, serviceName, &serviceInfo));
otError err = otDnsBrowseResponseGetServiceInfo(aResponse, serviceName, &serviceInfo);
error = MapOpenThreadError(err);

VerifyOrExit(error == CHIP_NO_ERROR, );
VerifyOrExit(err == OT_ERROR_NOT_FOUND || err == OT_ERROR_NONE, );

DnsResult * dnsResult = Platform::New<DnsResult>(aContext, CHIP_NO_ERROR);
error = FromOtDnsResponseToMdnsData(serviceInfo, type, dnsResult->mMdnsService, dnsResult->mServiceTxtEntry);

VerifyOrExit(dnsResult != nullptr, error = CHIP_ERROR_NO_MEMORY);

error = FromOtDnsResponseToMdnsData(serviceInfo, type, dnsResult->mMdnsService, dnsResult->mServiceTxtEntry, err);
if (CHIP_NO_ERROR == error)
{
// Invoke callback for every service one by one instead of for the whole
Expand All @@ -2672,7 +2738,15 @@ void GenericThreadStackManagerImpl_OpenThread<ImplClass>::OnDnsBrowseResult(otEr
exit:
// Invoke callback to notify about end-of-browse or failure
DnsResult * dnsResult = Platform::New<DnsResult>(aContext, error);
DeviceLayer::PlatformMgr().ScheduleWork(DispatchBrowseEmpty, reinterpret_cast<intptr_t>(dnsResult));

if (dnsResult == nullptr)
{
DeviceLayer::PlatformMgr().ScheduleWork(DispatchBrowseNoMemory, reinterpret_cast<intptr_t>(aContext));
}
else
{
DeviceLayer::PlatformMgr().ScheduleWork(DispatchBrowseEmpty, reinterpret_cast<intptr_t>(dnsResult));
}
}

template <class ImplClass>
Expand Down Expand Up @@ -2701,12 +2775,36 @@ CHIP_ERROR GenericThreadStackManagerImpl_OpenThread<ImplClass>::_DnsBrowse(const
return error;
}

template <class ImplClass>
void GenericThreadStackManagerImpl_OpenThread<ImplClass>::OnDnsAddressResolveResult(otError aError,
const otDnsAddressResponse * aResponse,
void * aContext)
{
CHIP_ERROR error;
DnsResult * dnsResult = reinterpret_cast<DnsResult *>(aContext);
otIp6Address address;

error = MapOpenThreadError(otDnsAddressResponseGetAddress(aResponse, 0, &address, nullptr));
if (error == CHIP_NO_ERROR)
{
dnsResult->mMdnsService.mAddress = MakeOptional(ToIPAddress(address));
}

dnsResult->error = error;

DeviceLayer::PlatformMgr().ScheduleWork(DispatchResolve, reinterpret_cast<intptr_t>(dnsResult));
}

template <class ImplClass>
void GenericThreadStackManagerImpl_OpenThread<ImplClass>::OnDnsResolveResult(otError aError, const otDnsServiceResponse * aResponse,
void * aContext)
{
CHIP_ERROR error;
otError otErr;
DnsResult * dnsResult = Platform::New<DnsResult>(aContext, MapOpenThreadError(aError));

VerifyOrExit(dnsResult != nullptr, error = CHIP_ERROR_NO_MEMORY);

// type buffer size is kDnssdTypeAndProtocolMaxSize + . + kMaxDomainNameSize + . + termination character
char type[Dnssd::kDnssdTypeAndProtocolMaxSize + SrpClient::kMaxDomainNameSize + 3];
// hostname buffer size is kHostNameMaxLength + . + kMaxDomainNameSize + . + termination character
Expand All @@ -2718,7 +2816,7 @@ void GenericThreadStackManagerImpl_OpenThread<ImplClass>::OnDnsResolveResult(otE

if (ThreadStackMgrImpl().mDnsResolveCallback == nullptr)
{
ChipLogError(DeviceLayer, "Invalid dns browse callback");
ChipLogError(DeviceLayer, "Invalid dns resolve callback");
return;
}

Expand All @@ -2734,16 +2832,31 @@ void GenericThreadStackManagerImpl_OpenThread<ImplClass>::OnDnsResolveResult(otE
serviceInfo.mTxtData = txtBuffer;
serviceInfo.mTxtDataSize = sizeof(txtBuffer);

error = MapOpenThreadError(otDnsServiceResponseGetServiceInfo(aResponse, &serviceInfo));
otErr = otDnsServiceResponseGetServiceInfo(aResponse, &serviceInfo);
error = MapOpenThreadError(otErr);

VerifyOrExit(error == CHIP_NO_ERROR, );

error = FromOtDnsResponseToMdnsData(serviceInfo, type, dnsResult->mMdnsService, dnsResult->mServiceTxtEntry);
error = FromOtDnsResponseToMdnsData(serviceInfo, type, dnsResult->mMdnsService, dnsResult->mServiceTxtEntry, otErr);

exit:
if (dnsResult == nullptr)
{
DeviceLayer::PlatformMgr().ScheduleWork(DispatchResolveNoMemory, reinterpret_cast<intptr_t>(aContext));
return;
}

dnsResult->error = error;
DeviceLayer::PlatformMgr().ScheduleWork(DispatchResolve, reinterpret_cast<intptr_t>(dnsResult));

// If IPv6 address in unspecified (AAAA record not present), send additional DNS query to obtain IPv6 address.
if (otIp6IsAddressUnspecified(&serviceInfo.mHostAddress))
{
DeviceLayer::PlatformMgr().ScheduleWork(DispatchAddressResolve, reinterpret_cast<intptr_t>(dnsResult));
}
else
{
DeviceLayer::PlatformMgr().ScheduleWork(DispatchResolve, reinterpret_cast<intptr_t>(dnsResult));
}
}

template <class ImplClass>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,11 @@ class GenericThreadStackManagerImpl_OpenThread
CHIP_ERROR _DnsBrowse(const char * aServiceName, DnsBrowseCallback aCallback, void * aContext);
CHIP_ERROR _DnsResolve(const char * aServiceName, const char * aInstanceName, DnsResolveCallback aCallback, void * aContext);
static void DispatchResolve(intptr_t context);
static void DispatchResolveNoMemory(intptr_t context);
static void DispatchAddressResolve(intptr_t context);
static void DispatchBrowseEmpty(intptr_t context);
static void DispatchBrowse(intptr_t context);
static void DispatchBrowseNoMemory(intptr_t context);
#endif // CHIP_DEVICE_CONFIG_ENABLE_THREAD_DNS_CLIENT
#endif // CHIP_DEVICE_CONFIG_ENABLE_THREAD_SRP_CLIENT

Expand Down Expand Up @@ -261,9 +264,13 @@ class GenericThreadStackManagerImpl_OpenThread

static void OnDnsBrowseResult(otError aError, const otDnsBrowseResponse * aResponse, void * aContext);
static void OnDnsResolveResult(otError aError, const otDnsServiceResponse * aResponse, void * aContext);
static void OnDnsAddressResolveResult(otError aError, const otDnsAddressResponse * aResponse, void * aContext);

static CHIP_ERROR ResolveAddress(intptr_t context, otDnsAddressCallback callback);

static CHIP_ERROR FromOtDnsResponseToMdnsData(otDnsServiceInfo & serviceInfo, const char * serviceType,
chip::Dnssd::DnssdService & mdnsService,
DnsServiceTxtEntries & serviceTxtEntries);
chip::Dnssd::DnssdService & mdnsService, DnsServiceTxtEntries & serviceTxtEntries,
otError error);
#endif // CHIP_DEVICE_CONFIG_ENABLE_THREAD_DNS_CLIENT
#endif // CHIP_DEVICE_CONFIG_ENABLE_THREAD_SRP_CLIENT

Expand Down

0 comments on commit 1104828

Please sign in to comment.