From 11283120243522e3ba892b1c400641f63f49aeff Mon Sep 17 00:00:00 2001 From: C Freeman Date: Mon, 14 Jun 2021 10:27:59 -0400 Subject: [PATCH] Minimal mdns: move responders into allocator class. (#7528) * Minimal mdns: move responders into allocator class. The code is basically the same, but moved into a subclass. This is done for 2 reasons: 1) Makes the allocator class more testable (tests added) 2) Right now the commissionable and operational records are mixed in a single query responder. This is not ideal and we need to separate these into different query responders in a single ResponseSender. This will let us allocate a query responder for each record easily and separate all the record and qname allocations so they can easily be cleared. (upcoming PR) * Restyled by gn Co-authored-by: Restyled.io --- src/lib/mdns/Advertiser_ImplMinimalMdns.cpp | 209 +++----------- .../Advertiser_ImplMinimalMdnsAllocator.h | 167 +++++++++++ src/lib/mdns/tests/BUILD.gn | 1 + .../mdns/tests/TestMinimalMdnsAllocator.cpp | 271 ++++++++++++++++++ 4 files changed, 486 insertions(+), 162 deletions(-) create mode 100644 src/lib/mdns/Advertiser_ImplMinimalMdnsAllocator.h create mode 100644 src/lib/mdns/tests/TestMinimalMdnsAllocator.cpp diff --git a/src/lib/mdns/Advertiser_ImplMinimalMdns.cpp b/src/lib/mdns/Advertiser_ImplMinimalMdns.cpp index 83374ac6788ead..3ff0169be7daf8 100644 --- a/src/lib/mdns/Advertiser_ImplMinimalMdns.cpp +++ b/src/lib/mdns/Advertiser_ImplMinimalMdns.cpp @@ -23,6 +23,7 @@ #include "MinimalMdnsServer.h" #include "ServiceNaming.h" +#include #include #include #include @@ -101,20 +102,11 @@ class AdvertiserMinMdns : public ServiceAdvertiser, public ParserDelegate // parses queries { public: - AdvertiserMinMdns() : mResponseSender(&GlobalMinimalMdnsServer::Server(), &mQueryResponder) + AdvertiserMinMdns() : mResponseSender(&GlobalMinimalMdnsServer::Server(), mQueryResponderAllocator.GetQueryResponder()) { GlobalMinimalMdnsServer::Instance().SetQueryDelegate(this); - - for (size_t i = 0; i < kMaxAllocatedResponders; i++) - { - mAllocatedResponders[i] = nullptr; - } - for (size_t i = 0; i < kMaxAllocatedQNameData; i++) - { - mAllocatedQNameParts[i] = nullptr; - } } - ~AdvertiserMinMdns() { Clear(); } + ~AdvertiserMinMdns() {} // Service advertiser CHIP_ERROR Start(chip::Inet::InetLayer * inetLayer, uint16_t port) override; @@ -131,10 +123,6 @@ class AdvertiserMinMdns : public ServiceAdvertiser, void OnQuery(const QueryData & data) override; private: - /// Sets the query responder to a blank state and frees up any - /// allocated memory. - void Clear(); - /// Advertise available records configured within the server /// /// Usable as boot-time advertisement of available SRV records. @@ -144,95 +132,17 @@ class AdvertiserMinMdns : public ServiceAdvertiser, /// interfaces on which the mDNS server is listening bool ShouldAdvertiseOn(const chip::Inet::InterfaceId id, const chip::Inet::IPAddress & addr); - QueryResponderSettings AddAllocatedResponder(RecordResponder * responder) - { - if (responder == nullptr) - { - ChipLogError(Discovery, "Responder memory allocation failed"); - return QueryResponderSettings(); // failed - } - - for (size_t i = 0; i < kMaxAllocatedResponders; i++) - { - if (mAllocatedResponders[i] != nullptr) - { - continue; - } - - mAllocatedResponders[i] = responder; - return mQueryResponder.AddResponder(mAllocatedResponders[i]); - } - - Platform::Delete(responder); - ChipLogError(Discovery, "Failed to find free slot for adding a responder"); - return QueryResponderSettings(); - } - - /// Appends another responder to the internal replies. - template - QueryResponderSettings AddResponder(Args &&... args) - { - return AddAllocatedResponder(chip::Platform::New(std::forward(args)...)); - } - - void * AllocateQNameSpace(size_t size) - { - for (size_t i = 0; i < kMaxAllocatedQNameData; i++) - { - if (mAllocatedQNameParts[i] != nullptr) - { - continue; - } - - mAllocatedQNameParts[i] = chip::Platform::MemoryAlloc(size); - if (mAllocatedQNameParts[i] == nullptr) - { - ChipLogError(Discovery, "QName memory allocation failed"); - } - return mAllocatedQNameParts[i]; - } - ChipLogError(Discovery, "Failed to find free slot for adding a qname"); - return nullptr; - } - - template - FullQName AllocateQName(Args &&... names) - { - void * storage = AllocateQNameSpace(FlatAllocatedQName::RequiredStorageSize(std::forward(names)...)); - if (storage == nullptr) - { - return FullQName(); - } - return FlatAllocatedQName::Build(storage, std::forward(names)...); - } - - FullQName AllocateQNameFromArray(char const * const * names, size_t num) - { - void * storage = AllocateQNameSpace(FlatAllocatedQName::RequiredStorageSizeFromArray(names, num)); - if (storage == nullptr) - { - return FullQName(); - } - return FlatAllocatedQName::BuildFromArray(storage, names, num); - } - FullQName GetCommisioningTextEntries(const CommissionAdvertisingParameters & params); - static constexpr size_t kMaxRecords = 32; - static constexpr size_t kMaxAllocatedResponders = 64; - static constexpr size_t kMaxAllocatedQNameData = 32; + static constexpr size_t kMaxRecords = 32; + QueryResponderAllocator mQueryResponderAllocator; - QueryResponder mQueryResponder; ResponseSender mResponseSender; // current request handling const chip::Inet::IPPacketInfo * mCurrentSource = nullptr; uint32_t mMessageId = 0; - // dynamically allocated items - RecordResponder * mAllocatedResponders[kMaxAllocatedResponders]; - void * mAllocatedQNameParts[kMaxAllocatedQNameData]; - const char * mEmptyTextEntries[1] = { "=", }; @@ -285,48 +195,23 @@ CHIP_ERROR AdvertiserMinMdns::Start(chip::Inet::InetLayer * inetLayer, uint16_t /// Stops the advertiser. CHIP_ERROR AdvertiserMinMdns::StopPublishDevice() { - Clear(); + mQueryResponderAllocator.Clear(); return CHIP_NO_ERROR; } -void AdvertiserMinMdns::Clear() -{ - // Init clears all responders, so that data can be freed - mQueryResponder.Init(); - - // Free all allocated data - for (size_t i = 0; i < kMaxAllocatedResponders; i++) - { - if (mAllocatedResponders[i] != nullptr) - { - chip::Platform::Delete(mAllocatedResponders[i]); - mAllocatedResponders[i] = nullptr; - } - } - - for (size_t i = 0; i < kMaxAllocatedQNameData; i++) - { - if (mAllocatedQNameParts[i] != nullptr) - { - chip::Platform::MemoryFree(mAllocatedQNameParts[i]); - mAllocatedQNameParts[i] = nullptr; - } - } -} - CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & params) { - Clear(); + mQueryResponderAllocator.Clear(); char nameBuffer[64] = ""; /// need to set server name ReturnErrorOnFailure(MakeInstanceName(nameBuffer, sizeof(nameBuffer), params.GetPeerId())); - FullQName operationalServiceName = AllocateQName("_chip", "_tcp", "local"); - FullQName operationalServerName = AllocateQName(nameBuffer, "_chip", "_tcp", "local"); + FullQName operationalServiceName = mQueryResponderAllocator.AllocateQName("_chip", "_tcp", "local"); + FullQName operationalServerName = mQueryResponderAllocator.AllocateQName(nameBuffer, "_chip", "_tcp", "local"); ReturnErrorOnFailure(MakeHostName(nameBuffer, sizeof(nameBuffer), params.GetMac())); - FullQName serverName = AllocateQName(nameBuffer, "local"); + FullQName serverName = mQueryResponderAllocator.AllocateQName(nameBuffer, "local"); if ((operationalServiceName.nameCount == 0) || (operationalServerName.nameCount == 0) || (serverName.nameCount == 0)) { @@ -334,7 +219,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!AddResponder(operationalServiceName, operationalServerName) + if (!mQueryResponderAllocator.AddResponder(operationalServiceName, operationalServerName) .SetReportAdditional(operationalServerName) .SetReportInServiceListing(true) .IsValid()) @@ -343,14 +228,14 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!AddResponder(SrvResourceRecord(operationalServerName, serverName, params.GetPort())) + if (!mQueryResponderAllocator.AddResponder(SrvResourceRecord(operationalServerName, serverName, params.GetPort())) .SetReportAdditional(serverName) .IsValid()) { ChipLogError(Discovery, "Failed to add SRV record mDNS responder"); return CHIP_ERROR_NO_MEMORY; } - if (!AddResponder(TxtResourceRecord(operationalServerName, mEmptyTextEntries)) + if (!mQueryResponderAllocator.AddResponder(TxtResourceRecord(operationalServerName, mEmptyTextEntries)) .SetReportAdditional(serverName) .IsValid()) { @@ -358,7 +243,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!AddResponder(serverName).IsValid()) + if (!mQueryResponderAllocator.AddResponder(serverName).IsValid()) { ChipLogError(Discovery, "Failed to add IPv6 mDNS responder"); return CHIP_ERROR_NO_MEMORY; @@ -366,7 +251,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & if (params.IsIPv4Enabled()) { - if (!AddResponder(serverName).IsValid()) + if (!mQueryResponderAllocator.AddResponder(serverName).IsValid()) { ChipLogError(Discovery, "Failed to add IPv4 mDNS responder"); return CHIP_ERROR_NO_MEMORY; @@ -380,7 +265,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const OperationalAdvertisingParameters & CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & params) { - Clear(); + mQueryResponderAllocator.Clear(); // TODO: need to detect colisions here char nameBuffer[64] = ""; size_t len = snprintf(nameBuffer, sizeof(nameBuffer), ChipLogFormatX64, GetRandU32(), GetRandU32()); @@ -391,11 +276,11 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & const char * serviceType = params.GetCommissionAdvertiseMode() == CommssionAdvertiseMode::kCommissionableNode ? kCommissionableServiceName : kCommissionerServiceName; - FullQName serviceName = AllocateQName(serviceType, kCommissionProtocol, kLocalDomain); - FullQName instanceName = AllocateQName(nameBuffer, serviceType, kCommissionProtocol, kLocalDomain); + FullQName serviceName = mQueryResponderAllocator.AllocateQName(serviceType, kCommissionProtocol, kLocalDomain); + FullQName instanceName = mQueryResponderAllocator.AllocateQName(nameBuffer, serviceType, kCommissionProtocol, kLocalDomain); ReturnErrorOnFailure(MakeHostName(nameBuffer, sizeof(nameBuffer), params.GetMac())); - FullQName hostName = AllocateQName(nameBuffer, kLocalDomain); + FullQName hostName = mQueryResponderAllocator.AllocateQName(nameBuffer, kLocalDomain); if ((serviceName.nameCount == 0) || (instanceName.nameCount == 0) || (hostName.nameCount == 0)) { @@ -403,7 +288,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!AddResponder(serviceName, instanceName) + if (!mQueryResponderAllocator.AddResponder(serviceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -412,14 +297,14 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & return CHIP_ERROR_NO_MEMORY; } - if (!AddResponder(SrvResourceRecord(instanceName, hostName, params.GetPort())) + if (!mQueryResponderAllocator.AddResponder(SrvResourceRecord(instanceName, hostName, params.GetPort())) .SetReportAdditional(hostName) .IsValid()) { ChipLogError(Discovery, "Failed to add SRV record mDNS responder"); return CHIP_ERROR_NO_MEMORY; } - if (!AddResponder(hostName).IsValid()) + if (!mQueryResponderAllocator.AddResponder(hostName).IsValid()) { ChipLogError(Discovery, "Failed to add IPv6 mDNS responder"); return CHIP_ERROR_NO_MEMORY; @@ -427,7 +312,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & if (params.IsIPv4Enabled()) { - if (!AddResponder(hostName).IsValid()) + if (!mQueryResponderAllocator.AddResponder(hostName).IsValid()) { ChipLogError(Discovery, "Failed to add IPv4 mDNS responder"); return CHIP_ERROR_NO_MEMORY; @@ -438,11 +323,11 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kVendor, params.GetVendorId().Value())); - FullQName vendorServiceName = - AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); + FullQName vendorServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, + kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(vendorServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!AddResponder(vendorServiceName, instanceName) + if (!mQueryResponderAllocator.AddResponder(vendorServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -456,11 +341,11 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kDeviceType, params.GetDeviceType().Value())); - FullQName vendorServiceName = - AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); + FullQName vendorServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, + kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(vendorServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!AddResponder(vendorServiceName, instanceName) + if (!mQueryResponderAllocator.AddResponder(vendorServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -476,11 +361,11 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kShort, params.GetShortDiscriminator())); - FullQName shortServiceName = - AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); + FullQName shortServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, + kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(shortServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!AddResponder(shortServiceName, instanceName) + if (!mQueryResponderAllocator.AddResponder(shortServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -493,10 +378,10 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kLong, params.GetLongDiscriminator())); - FullQName longServiceName = - AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); + FullQName longServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, + kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(longServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!AddResponder(longServiceName, instanceName) + if (!mQueryResponderAllocator.AddResponder(longServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -509,10 +394,10 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kCommissioningMode, params.GetCommissioningMode() ? 1 : 0)); - FullQName longServiceName = - AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); + FullQName longServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, + kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(longServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!AddResponder(longServiceName, instanceName) + if (!mQueryResponderAllocator.AddResponder(longServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -526,10 +411,10 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & { MakeServiceSubtype(nameBuffer, sizeof(nameBuffer), DiscoveryFilter(DiscoveryFilterType::kCommissioningModeFromCommand, 1)); - FullQName longServiceName = - AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, kCommissionProtocol, kLocalDomain); + FullQName longServiceName = mQueryResponderAllocator.AllocateQName(nameBuffer, kSubtypeServiceNamePart, serviceType, + kCommissionProtocol, kLocalDomain); ReturnErrorCodeIf(longServiceName.nameCount == 0, CHIP_ERROR_NO_MEMORY); - if (!AddResponder(longServiceName, instanceName) + if (!mQueryResponderAllocator.AddResponder(longServiceName, instanceName) .SetReportAdditional(instanceName) .SetReportInServiceListing(true) .IsValid()) @@ -540,7 +425,7 @@ CHIP_ERROR AdvertiserMinMdns::Advertise(const CommissionAdvertisingParameters & } } - if (!AddResponder(TxtResourceRecord(instanceName, GetCommisioningTextEntries(params))) + if (!mQueryResponderAllocator.AddResponder(TxtResourceRecord(instanceName, GetCommisioningTextEntries(params))) .SetReportAdditional(hostName) .IsValid()) { @@ -603,7 +488,7 @@ FullQName AdvertiserMinMdns::GetCommisioningTextEntries(const CommissionAdvertis if (!params.GetVendorId().HasValue()) { - return AllocateQName(txtDiscriminator); + return mQueryResponderAllocator.AllocateQName(txtDiscriminator); } char txtCommissioningMode[chip::Mdns::kKeyCommissioningModeMaxLength + 4]; @@ -640,11 +525,11 @@ FullQName AdvertiserMinMdns::GetCommisioningTextEntries(const CommissionAdvertis } if (numTxtFields == 0) { - return AllocateQNameFromArray(mEmptyTextEntries, 1); + return mQueryResponderAllocator.AllocateQNameFromArray(mEmptyTextEntries, 1); } else { - return AllocateQNameFromArray(txtFields, numTxtFields); + return mQueryResponderAllocator.AllocateQNameFromArray(txtFields, numTxtFields); } } // namespace @@ -716,7 +601,7 @@ void AdvertiserMinMdns::AdvertiseRecords() QueryData queryData(QType::PTR, QClass::IN, false /* unicast */); queryData.SetIsBootAdvertising(true); - mQueryResponder.ClearBroadcastThrottle(); + mQueryResponderAllocator.GetQueryResponder()->ClearBroadcastThrottle(); CHIP_ERROR err = mResponseSender.Respond(0, queryData, &packetInfo); if (err != CHIP_NO_ERROR) @@ -726,7 +611,7 @@ void AdvertiserMinMdns::AdvertiseRecords() } // Once all automatic broadcasts are done, allow immediate replies once. - mQueryResponder.ClearBroadcastThrottle(); + mQueryResponderAllocator.GetQueryResponder()->ClearBroadcastThrottle(); } AdvertiserMinMdns gAdvertiser; diff --git a/src/lib/mdns/Advertiser_ImplMinimalMdnsAllocator.h b/src/lib/mdns/Advertiser_ImplMinimalMdnsAllocator.h new file mode 100644 index 00000000000000..fa198c04cd43cc --- /dev/null +++ b/src/lib/mdns/Advertiser_ImplMinimalMdnsAllocator.h @@ -0,0 +1,167 @@ +/* + * + * Copyright (c) 2020 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace chip { +namespace Mdns { + +template +class QueryResponderAllocator +{ +public: + QueryResponderAllocator() + { + for (size_t i = 0; i < kMaxRecords; i++) + { + mAllocatedResponders[i] = nullptr; + } + for (size_t i = 0; i < kMaxAllocatedQNameData; i++) + { + mAllocatedQNameParts[i] = nullptr; + } + } + ~QueryResponderAllocator() { Clear(); } + + /// Appends another responder to the internal replies. + template + mdns::Minimal::QueryResponderSettings AddResponder(Args &&... args) + { + return AddAllocatedResponder(chip::Platform::New(std::forward(args)...)); + } + + template + mdns::Minimal::FullQName AllocateQName(Args &&... names) + { + void * storage = AllocateQNameSpace(mdns::Minimal::FlatAllocatedQName::RequiredStorageSize(std::forward(names)...)); + if (storage == nullptr) + { + return mdns::Minimal::FullQName(); + } + return mdns::Minimal::FlatAllocatedQName::Build(storage, std::forward(names)...); + } + + mdns::Minimal::FullQName AllocateQNameFromArray(char const * const * names, size_t num) + { + void * storage = AllocateQNameSpace(mdns::Minimal::FlatAllocatedQName::RequiredStorageSizeFromArray(names, num)); + if (storage == nullptr) + { + return mdns::Minimal::FullQName(); + } + return mdns::Minimal::FlatAllocatedQName::BuildFromArray(storage, names, num); + } + + /// Sets the query responder to a blank state and frees up any + /// allocated memory. + void Clear() + { + // Init clears all responders, so that data can be freed + mQueryResponder.Init(); + + // Free all allocated data + for (size_t i = 0; i < kMaxRecords; i++) + { + if (mAllocatedResponders[i] != nullptr) + { + chip::Platform::Delete(mAllocatedResponders[i]); + mAllocatedResponders[i] = nullptr; + } + } + + for (size_t i = 0; i < kMaxAllocatedQNameData; i++) + { + if (mAllocatedQNameParts[i] != nullptr) + { + chip::Platform::MemoryFree(mAllocatedQNameParts[i]); + mAllocatedQNameParts[i] = nullptr; + } + } + } + mdns::Minimal::QueryResponder * GetQueryResponder() { return &mQueryResponder; } + +protected: + // For testing. + size_t GetMaxAllocatedQNames() { return kMaxAllocatedQNameData; } + void * GetQNamePart(size_t idx) { return mAllocatedQNameParts[idx]; } + mdns::Minimal::RecordResponder * GetRecordResponder(size_t idx) { return mAllocatedResponders[idx]; } + +private: + static constexpr size_t kMaxAllocatedQNameData = 32; + // dynamically allocated items + mdns::Minimal::RecordResponder * mAllocatedResponders[kMaxRecords]; + void * mAllocatedQNameParts[kMaxAllocatedQNameData]; + // The QueryResponder needs 1 extra space to hold the record for itself. + mdns::Minimal::QueryResponder mQueryResponder; + + mdns::Minimal::QueryResponderSettings AddAllocatedResponder(mdns::Minimal::RecordResponder * responder) + { + if (responder == nullptr) + { + ChipLogError(Discovery, "Responder memory allocation failed"); + return mdns::Minimal::QueryResponderSettings(); // failed + } + + for (size_t i = 0; i < kMaxRecords; i++) + { + if (mAllocatedResponders[i] != nullptr) + { + continue; + } + + mAllocatedResponders[i] = responder; + return mQueryResponder.AddResponder(mAllocatedResponders[i]); + } + + Platform::Delete(responder); + ChipLogError(Discovery, "Failed to find free slot for adding a responder"); + return mdns::Minimal::QueryResponderSettings(); + } + + void * AllocateQNameSpace(size_t size) + { + for (size_t i = 0; i < kMaxAllocatedQNameData; i++) + { + if (mAllocatedQNameParts[i] != nullptr) + { + continue; + } + + mAllocatedQNameParts[i] = chip::Platform::MemoryAlloc(size); + if (mAllocatedQNameParts[i] == nullptr) + { + ChipLogError(Discovery, "QName memory allocation failed"); + } + return mAllocatedQNameParts[i]; + } + ChipLogError(Discovery, "Failed to find free slot for adding a qname"); + return nullptr; + } +}; + +} // namespace Mdns +} // namespace chip diff --git a/src/lib/mdns/tests/BUILD.gn b/src/lib/mdns/tests/BUILD.gn index e7147a576bb4c0..cbbe880d6ef12d 100644 --- a/src/lib/mdns/tests/BUILD.gn +++ b/src/lib/mdns/tests/BUILD.gn @@ -22,6 +22,7 @@ chip_test_suite("tests") { output_name = "libMdnsTests" test_sources = [ + "TestMinimalMdnsAllocator.cpp", "TestServiceNaming.cpp", "TestTxtFields.cpp", ] diff --git a/src/lib/mdns/tests/TestMinimalMdnsAllocator.cpp b/src/lib/mdns/tests/TestMinimalMdnsAllocator.cpp new file mode 100644 index 00000000000000..406f3dd4cf9f23 --- /dev/null +++ b/src/lib/mdns/tests/TestMinimalMdnsAllocator.cpp @@ -0,0 +1,271 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include + +using namespace chip; +using namespace chip::Mdns; +using namespace mdns::Minimal; + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC +#include +#endif // CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + +namespace { + +constexpr size_t kMaxRecords = 10; +class TestAllocator : public QueryResponderAllocator +{ +public: + TestAllocator() : QueryResponderAllocator() + { +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // void dmalloc_track(const dmalloc_track_t track_func) +#endif + } + void TestAllQNamesAreNull(nlTestSuite * inSuite) + { + for (size_t i = 0; i < GetMaxAllocatedQNames(); ++i) + { + NL_TEST_ASSERT(inSuite, GetQNamePart(i) == nullptr); + } + } + void TestAllRecordRespondersAreNull(nlTestSuite * inSuite) + { + for (size_t i = 0; i < kMaxRecords; ++i) + { + NL_TEST_ASSERT(inSuite, GetRecordResponder(i) == nullptr); + } + } + void TestRecordRespondersMatchQuery(nlTestSuite * inSuite) + { + mdns::Minimal::QueryResponderRecordFilter noFilter; + auto queryResponder = GetQueryResponder(); + size_t idx = 0; + for (auto it = queryResponder->begin(&noFilter); it != queryResponder->end(); it++, idx++) + { + // TODO: Once the responders are exposed in the query responder, check that they match. + NL_TEST_ASSERT(inSuite, idx < kMaxRecords); + } + } + size_t GetMaxAllocatedQNames() { return QueryResponderAllocator::GetMaxAllocatedQNames(); } +}; + +void TestQueryAllocatorQName(nlTestSuite * inSuite, void * inContext) +{ + TestAllocator test; +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + unsigned long mark = dmalloc_mark(); +#endif + // Start empty. + test.TestAllRecordRespondersAreNull(inSuite); + test.TestAllQNamesAreNull(inSuite); + + // We should be able to add up to GetMaxAllocatedQNames QNames + for (size_t i = 0; i < test.GetMaxAllocatedQNames(); ++i) + { + NL_TEST_ASSERT(inSuite, test.AllocateQName("test", "testy", "udp") != FullQName()); + test.TestAllRecordRespondersAreNull(inSuite); + } + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // Count the memory that has not been freed at this point (since mark) + unsigned long nAllocated = dmalloc_count_changed(mark, 1, 0); + NL_TEST_ASSERT(inSuite, nAllocated != 0); +#endif + + // Adding one more should fail. + NL_TEST_ASSERT(inSuite, test.AllocateQName("test", "testy", "udp") == FullQName()); + test.TestAllRecordRespondersAreNull(inSuite); + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // We should not have allocated any more memory + NL_TEST_ASSERT(inSuite, nAllocated == dmalloc_count_changed(mark, 1, 0)); +#endif + + // Clear should take us back to all empty. + test.Clear(); + test.TestAllQNamesAreNull(inSuite); + test.TestAllRecordRespondersAreNull(inSuite); + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // The amount of unfreed pointers should be 0. + NL_TEST_ASSERT(inSuite, dmalloc_count_changed(mark, 1, 0) == 0); +#endif +} + +void TestQueryAllocatorQNameArray(nlTestSuite * inSuite, void * inContext) +{ + TestAllocator test; + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + unsigned long mark = dmalloc_mark(); +#endif + + constexpr size_t kNumParts = 4; + const char * kArray[kNumParts] = { "this", "is", "a", "test" }; + + // Start empty. + test.TestAllRecordRespondersAreNull(inSuite); + test.TestAllQNamesAreNull(inSuite); + + // We should be able to add up to GetMaxAllocatedQNames QNames + for (size_t i = 0; i < test.GetMaxAllocatedQNames(); ++i) + { + NL_TEST_ASSERT(inSuite, test.AllocateQNameFromArray(kArray, kNumParts) != FullQName()); + test.TestAllRecordRespondersAreNull(inSuite); + } + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // Count the memory that has not been freed at this point (since mark) + unsigned long nAllocated = dmalloc_count_changed(mark, 1, 0); + NL_TEST_ASSERT(inSuite, nAllocated != 0); +#endif + + // Adding one more should fail. + NL_TEST_ASSERT(inSuite, test.AllocateQNameFromArray(kArray, kNumParts) == FullQName()); + test.TestAllRecordRespondersAreNull(inSuite); + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // We should not have allocated any more memory + NL_TEST_ASSERT(inSuite, nAllocated == dmalloc_count_changed(mark, 1, 0)); +#endif + + // Clear should take us back to all empty. + test.Clear(); + test.TestAllQNamesAreNull(inSuite); + test.TestAllRecordRespondersAreNull(inSuite); + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // The amount of unfreed pointers should be 0. + NL_TEST_ASSERT(inSuite, dmalloc_count_changed(mark, 1, 0) == 0); +#endif +} + +void TestQueryAllocatorRecordResponder(nlTestSuite * inSuite, void * inContext) +{ + TestAllocator test; + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + unsigned long mark = dmalloc_mark(); +#endif + // Start empty. + test.TestAllRecordRespondersAreNull(inSuite); + test.TestAllQNamesAreNull(inSuite); + + FullQName serviceName = test.AllocateQName("test", "service"); + FullQName instanceName = test.AllocateQName("test", "instance"); + + for (size_t i = 0; i < kMaxRecords; ++i) + { + NL_TEST_ASSERT(inSuite, test.AddResponder(serviceName, instanceName).IsValid()); + } + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // Count the memory that has not been freed at this point (since mark) + unsigned long nAllocated = dmalloc_count_changed(mark, 1, 0); + NL_TEST_ASSERT(inSuite, nAllocated != 0); +#endif + + // Adding one more should fail. + NL_TEST_ASSERT(inSuite, !test.AddResponder(serviceName, instanceName).IsValid()); +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // We should not have allocated any more memory + NL_TEST_ASSERT(inSuite, nAllocated == dmalloc_count_changed(mark, 1, 0)); +#endif + + // Clear should take us back to all empty. + test.Clear(); + test.TestAllQNamesAreNull(inSuite); + test.TestAllRecordRespondersAreNull(inSuite); + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // The amount of unfreed pointers should be 0. + NL_TEST_ASSERT(inSuite, dmalloc_count_changed(mark, 1, 0) == 0); +#endif +} + +void TestQueryAllocatorRecordResponderTypes(nlTestSuite * inSuite, void * inContext) +{ + TestAllocator test; + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + unsigned long mark = dmalloc_mark(); +#endif + // Start empty. + test.TestAllRecordRespondersAreNull(inSuite); + test.TestAllQNamesAreNull(inSuite); + + FullQName serviceName = test.AllocateQName("test", "service"); + FullQName instanceName = test.AllocateQName("test", "instance"); + FullQName hostName = test.AllocateQName("test", "host"); + FullQName someTxt = test.AllocateQName("L1=some text", "L2=some other text"); + + NL_TEST_ASSERT(inSuite, serviceName != FullQName()); + NL_TEST_ASSERT(inSuite, instanceName != FullQName()); + NL_TEST_ASSERT(inSuite, hostName != FullQName()); + NL_TEST_ASSERT(inSuite, someTxt != FullQName()); + + // Test that we can add all types + NL_TEST_ASSERT(inSuite, test.AddResponder(serviceName, instanceName).IsValid()); + NL_TEST_ASSERT(inSuite, test.AddResponder(SrvResourceRecord(instanceName, hostName, 57)).IsValid()); + NL_TEST_ASSERT(inSuite, test.AddResponder(TxtResourceRecord(instanceName, someTxt)).IsValid()); + NL_TEST_ASSERT(inSuite, test.AddResponder(hostName).IsValid()); + NL_TEST_ASSERT(inSuite, test.AddResponder(hostName).IsValid()); + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // Count the memory that has not been freed at this point (since mark) + unsigned long nAllocated = dmalloc_count_changed(mark, 1, 0); + NL_TEST_ASSERT(inSuite, nAllocated != 0); +#endif + + // Clear should take us back to all empty. + test.Clear(); + test.TestAllQNamesAreNull(inSuite); + test.TestAllRecordRespondersAreNull(inSuite); + +#if CHIP_CONFIG_MEMORY_DEBUG_DMALLOC + // The amount of unfreed pointers should be 0. + NL_TEST_ASSERT(inSuite, dmalloc_count_changed(mark, 1, 0) == 0); +#endif +} + +const nlTest sTests[] = { + NL_TEST_DEF("TestQueryAllocatorQName", TestQueryAllocatorQName), // + NL_TEST_DEF("TestQueryAllocatorQNameArray", TestQueryAllocatorQNameArray), // + NL_TEST_DEF("TestQueryAllocatorRecordResponder", TestQueryAllocatorRecordResponder), // + NL_TEST_DEF("TestQueryAllocatorRecordResponderTypes", TestQueryAllocatorRecordResponderTypes), // + + NL_TEST_SENTINEL() // +}; + +} // namespace + +int TestMinimalMdnsAllocator(void) +{ + chip::Platform::MemoryInit(); + nlTestSuite theSuite = { "MinimalMdnsAllocator", &sTests[0], nullptr, nullptr }; + nlTestRunner(&theSuite, nullptr); + return nlTestRunnerStats(&theSuite); +} + +CHIP_REGISTER_TEST_SUITE(TestMinimalMdnsAllocator);