diff --git a/doc/ot_api_doc.h b/doc/ot_api_doc.h index 834280707..2513e092d 100644 --- a/doc/ot_api_doc.h +++ b/doc/ot_api_doc.h @@ -58,6 +58,7 @@ * @defgroup api-dnssd-server DNS-SD Server * @defgroup api-icmp6 ICMPv6 * @defgroup api-ip6 IPv6 + * @defgroup api-mdns Multicast DNS * @defgroup api-nat64 NAT64 * @defgroup api-srp SRP * @defgroup api-ping-sender Ping Sender @@ -181,6 +182,7 @@ * @defgroup plat-memory Memory * @defgroup plat-messagepool Message Pool * @defgroup plat-misc Miscellaneous + * @defgroup plat-mdns Multicast DNS * @defgroup plat-multipan Multipan * @defgroup plat-otns Network Simulator * @defgroup plat-radio Radio diff --git a/doc/ot_config_doc.h b/doc/ot_config_doc.h index 87af22965..c61b92e35 100644 --- a/doc/ot_config_doc.h +++ b/doc/ot_config_doc.h @@ -69,6 +69,7 @@ * @defgroup config-mesh-forwarder Mesh Forwarder * @defgroup config-misc Miscellaneous Constants * @defgroup config-mle MLE Service + * @defgroup config-mdns Multicast DNS * @defgroup config-nat64 NAT64 * @defgroup config-netdata-publisher Network Data Publisher * @defgroup config-network-diagnostic Network Diagnostics diff --git a/etc/cmake/options.cmake b/etc/cmake/options.cmake index bc18aea32..87b57dac6 100644 --- a/etc/cmake/options.cmake +++ b/etc/cmake/options.cmake @@ -215,6 +215,7 @@ ot_option(OT_LINK_METRICS_SUBJECT OPENTHREAD_CONFIG_MLE_LINK_METRICS_SUBJECT_ENA ot_option(OT_LINK_RAW OPENTHREAD_CONFIG_LINK_RAW_ENABLE "link raw service") ot_option(OT_LOG_LEVEL_DYNAMIC OPENTHREAD_CONFIG_LOG_LEVEL_DYNAMIC_ENABLE "dynamic log level control") ot_option(OT_MAC_FILTER OPENTHREAD_CONFIG_MAC_FILTER_ENABLE "mac filter") +ot_option(OT_MDNS OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE "multicast DNS (mDNS)") ot_option(OT_MESH_DIAG OPENTHREAD_CONFIG_MESH_DIAG_ENABLE "mesh diag") ot_option(OT_MESSAGE_USE_HEAP OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE "heap allocator for message buffers") ot_option(OT_MLE_LONG_ROUTES OPENTHREAD_CONFIG_MLE_LONG_ROUTES_ENABLE "MLE long routes extension (experimental)") diff --git a/examples/config/ot-core-config-check-size-br.h b/examples/config/ot-core-config-check-size-br.h index e01bc9ca8..d8b2b130e 100644 --- a/examples/config/ot-core-config-check-size-br.h +++ b/examples/config/ot-core-config-check-size-br.h @@ -79,6 +79,7 @@ #define OPENTHREAD_CONFIG_MLE_LINK_METRICS_INITIATOR_ENABLE 1 #define OPENTHREAD_CONFIG_MLE_LINK_METRICS_SUBJECT_ENABLE 1 #define OPENTHREAD_CONFIG_MLR_ENABLE 1 +#define OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE 1 #define OPENTHREAD_CONFIG_MULTIPLE_INSTANCE_ENABLE 0 #define OPENTHREAD_CONFIG_NAT64_BORDER_ROUTING_ENABLE 1 #define OPENTHREAD_CONFIG_NAT64_TRANSLATOR_ENABLE 1 @@ -99,5 +100,6 @@ #define OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE 1 #define OPENTHREAD_CONFIG_BORDER_ROUTING_MOCK_PLAT_APIS_ENABLE 1 #define OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_MOCK_PLAT_APIS_ENABLE 1 +#define OPENTHREAD_CONFIG_MULTICAST_DNS_MOCK_PLAT_APIS_ENABLE 1 #endif // OT_CORE_CONFIG_CHECK_SIZE_BR_H_ diff --git a/examples/platforms/simulation/CMakeLists.txt b/examples/platforms/simulation/CMakeLists.txt index 12580e2d1..4e4ffbae7 100644 --- a/examples/platforms/simulation/CMakeLists.txt +++ b/examples/platforms/simulation/CMakeLists.txt @@ -69,6 +69,7 @@ add_library(openthread-simulation flash.c infra_if.c logging.c + mdns_socket.c misc.c multipan.c radio.c diff --git a/examples/platforms/simulation/mdns_socket.c b/examples/platforms/simulation/mdns_socket.c new file mode 100644 index 000000000..88abcb50c --- /dev/null +++ b/examples/platforms/simulation/mdns_socket.c @@ -0,0 +1,569 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "platform-simulation.h" + +#include +#include + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +//--------------------------------------------------------------------------------------------------------------------- +#if OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX + +// Provide a simplified POSIX based implementation of `otPlatMdns` +// platform APIs. This is intended for testing. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "simul_utils.h" +#include "utils/code_utils.h" + +#define MAX_BUFFER_SIZE 1600 + +#define MDNS_PORT 5353 + +static bool sEnabled = false; +static uint32_t sInfraIfIndex; +static int sMdnsFd4 = -1; +static int sMdnsFd6 = -1; + +/* this is a portability hack */ +#ifndef IPV6_ADD_MEMBERSHIP +#ifdef IPV6_JOIN_GROUP +#define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP +#endif +#endif + +#ifndef IPV6_DROP_MEMBERSHIP +#ifdef IPV6_LEAVE_GROUP +#define IPV6_DROP_MEMBERSHIP IPV6_LEAVE_GROUP +#endif +#endif + +#define VerifyOrDie(aCondition, aErrMsg) \ + do \ + { \ + if (!(aCondition)) \ + { \ + fprintf(stderr, "\n\r" aErrMsg ". errono:%s\n\r", strerror(errno)); \ + exit(1); \ + } \ + } while (false) + +static void SetReuseAddrPort(int aFd) +{ + int ret; + int yes = 1; + + ret = setsockopt(aFd, SOL_SOCKET, SO_REUSEADDR, (char *)&yes, sizeof(yes)); + VerifyOrDie(ret >= 0, "setsocketopt(SO_REUSEADDR) failed"); + + ret = setsockopt(aFd, SOL_SOCKET, SO_REUSEPORT, (char *)&yes, sizeof(yes)); + VerifyOrDie(ret >= 0, "setsocketopt(SO_REUSEPORT) failed"); +} + +static void OpenIp4Socket(uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInfraIfIndex); + + struct sockaddr_in addr; + int fd; + int ret; + uint8_t u8; + int value; + + fd = socket(AF_INET, SOCK_DGRAM, 0); + VerifyOrDie(fd >= 0, "socket() failed"); + +#ifdef __linux__ + { + char nameBuffer[IF_NAMESIZE]; + const char *ifname; + + ifname = if_indextoname(aInfraIfIndex, nameBuffer); + VerifyOrDie(ifname != NULL, "if_indextoname() failed"); + + ret = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, ifname, strlen(ifname)); + VerifyOrDie(ret >= 0, "setsocketopt(SO_BINDTODEVICE) failed"); + } +#else + value = aInfraIfIndex; + ret = setsockopt(fd, IPPROTO_IP, IP_BOUND_IF, &value, sizeof(value)); +#endif + + u8 = 255; + ret = setsockopt(fd, IPPROTO_IP, IP_MULTICAST_TTL, &u8, sizeof(u8)); + VerifyOrDie(ret >= 0, "setsocketopt(IP_MULTICAST_TTL) failed"); + + value = 255; + ret = setsockopt(fd, IPPROTO_IP, IP_TTL, &value, sizeof(value)); + VerifyOrDie(ret >= 0, "setsocketopt(IP_TTL) failed"); + + u8 = 1; + ret = setsockopt(fd, IPPROTO_IP, IP_MULTICAST_LOOP, &u8, sizeof(u8)); + VerifyOrDie(ret >= 0, "setsocketopt(IP_MULTICAST_LOOP) failed"); + + SetReuseAddrPort(fd); + + { + struct ip_mreqn mreqn; + + memset(&mreqn, 0, sizeof(mreqn)); + mreqn.imr_multiaddr.s_addr = inet_addr("224.0.0.251"); + mreqn.imr_ifindex = aInfraIfIndex; + + ret = setsockopt(fd, IPPROTO_IP, IP_MULTICAST_IF, &mreqn, sizeof(mreqn)); + VerifyOrDie(ret >= 0, "setsocketopt(IP_MULTICAST_IF) failed"); + } + + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_ANY); + addr.sin_port = htons(MDNS_PORT); + + ret = bind(fd, (struct sockaddr *)&addr, sizeof(addr)); + VerifyOrDie(ret >= 0, "bind() failed"); + + sMdnsFd4 = fd; +} + +static void JoinOrLeaveIp4MulticastGroup(bool aJoin, uint32_t aInfraIfIndex) +{ + struct ip_mreqn mreqn; + int ret; + + memset(&mreqn, 0, sizeof(mreqn)); + mreqn.imr_multiaddr.s_addr = inet_addr("224.0.0.251"); + mreqn.imr_ifindex = aInfraIfIndex; + + if (aJoin) + { + // Suggested workaround for netif not dropping + // a previous multicast membership. + setsockopt(sMdnsFd4, IPPROTO_IP, IP_DROP_MEMBERSHIP, &mreqn, sizeof(mreqn)); + } + + ret = setsockopt(sMdnsFd4, IPPROTO_IP, aJoin ? IP_ADD_MEMBERSHIP : IP_DROP_MEMBERSHIP, &mreqn, sizeof(mreqn)); + VerifyOrDie(ret >= 0, "setsocketopt(IP_ADD/DROP_MEMBERSHIP) failed"); +} + +static void OpenIp6Socket(uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInfraIfIndex); + + struct sockaddr_in6 addr6; + int fd; + int ret; + int value; + + fd = socket(AF_INET6, SOCK_DGRAM, 0); + VerifyOrDie(fd >= 0, "socket() failed"); + +#ifdef __linux__ + { + char nameBuffer[IF_NAMESIZE]; + const char *ifname; + + ifname = if_indextoname(aInfraIfIndex, nameBuffer); + VerifyOrDie(ifname != NULL, "if_indextoname() failed"); + + ret = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, ifname, strlen(ifname)); + VerifyOrDie(ret >= 0, "setsocketopt(SO_BINDTODEVICE) failed"); + } +#else + value = aInfraIfIndex; + ret = setsockopt(fd, IPPROTO_IPV6, IPV6_BOUND_IF, &value, sizeof(value)); +#endif + + value = 255; + ret = setsockopt(fd, IPPROTO_IPV6, IPV6_MULTICAST_HOPS, &value, sizeof(value)); + VerifyOrDie(ret >= 0, "setsocketopt(IPV6_MULTICAST_HOPS) failed"); + + value = 255; + ret = setsockopt(fd, IPPROTO_IPV6, IPV6_UNICAST_HOPS, &value, sizeof(value)); + VerifyOrDie(ret >= 0, "setsocketopt(IPV6_UNICAST_HOPS) failed"); + + value = 1; + ret = setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &value, sizeof(value)); + VerifyOrDie(ret >= 0, "setsocketopt(IPV6_V6ONLY) failed"); + + value = aInfraIfIndex; + ret = setsockopt(fd, IPPROTO_IPV6, IPV6_MULTICAST_IF, &value, sizeof(value)); + VerifyOrDie(ret >= 0, "setsocketopt(IPV6_MULTICAST_IF) failed"); + + value = 1; + ret = setsockopt(fd, IPPROTO_IPV6, IPV6_MULTICAST_LOOP, &value, sizeof(value)); + VerifyOrDie(ret >= 0, "setsocketopt(IPV6_MULTICAST_LOOP) failed"); + + SetReuseAddrPort(fd); + + memset(&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + addr6.sin6_port = htons(MDNS_PORT); + + ret = bind(fd, (struct sockaddr *)&addr6, sizeof(addr6)); + VerifyOrDie(ret >= 0, "bind() failed"); + + sMdnsFd6 = fd; +} + +static void JoinOrLeaveIp6MulticastGroup(bool aJoin, uint32_t aInfraIfIndex) +{ + struct ipv6_mreq mreq6; + int ret; + + memset(&mreq6, 0, sizeof(mreq6)); + + inet_pton(AF_INET6, "ff02::fb", &mreq6.ipv6mr_multiaddr); + mreq6.ipv6mr_interface = (int)aInfraIfIndex; + + if (aJoin) + { + // Suggested workaround for netif not dropping + // a previous multicast membership. + setsockopt(sMdnsFd6, IPPROTO_IPV6, IPV6_DROP_MEMBERSHIP, &mreq6, sizeof(mreq6)); + } + + ret = setsockopt(sMdnsFd6, IPPROTO_IPV6, aJoin ? IPV6_ADD_MEMBERSHIP : IPV6_DROP_MEMBERSHIP, &mreq6, sizeof(mreq6)); + VerifyOrDie(ret >= 0, "setsocketopt(IP6_ADD/DROP_MEMBERSHIP) failed"); +} + +otError otPlatMdnsSetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + + if (aEnable) + { + otEXPECT(!sEnabled); + + OpenIp4Socket(aInfraIfIndex); + JoinOrLeaveIp4MulticastGroup(/* aJoin */ true, aInfraIfIndex); + OpenIp6Socket(aInfraIfIndex); + JoinOrLeaveIp6MulticastGroup(/* aJoin */ true, aInfraIfIndex); + + sEnabled = true; + sInfraIfIndex = aInfraIfIndex; + } + else + { + otEXPECT(sEnabled); + + JoinOrLeaveIp4MulticastGroup(/* aJoin */ false, aInfraIfIndex); + JoinOrLeaveIp6MulticastGroup(/* aJoin */ false, aInfraIfIndex); + close(sMdnsFd4); + close(sMdnsFd6); + sEnabled = false; + } + +exit: + return OT_ERROR_NONE; +} + +void otPlatMdnsSendMulticast(otInstance *aInstance, otMessage *aMessage, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aInfraIfIndex); + + uint8_t buffer[MAX_BUFFER_SIZE]; + uint16_t length; + int bytes; + + otEXPECT(sEnabled); + + length = otMessageRead(aMessage, 0, buffer, sizeof(buffer)); + otMessageFree(aMessage); + + { + struct sockaddr_in addr; + + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr("224.0.0.251"); + addr.sin_port = htons(MDNS_PORT); + + bytes = sendto(sMdnsFd4, buffer, length, 0, (struct sockaddr *)&addr, sizeof(addr)); + + VerifyOrDie((bytes == length), "sendTo(sMdnsFd4) failed"); + } + + { + struct sockaddr_in6 addr6; + + memset(&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + addr6.sin6_port = htons(MDNS_PORT); + inet_pton(AF_INET6, "ff02::fb", &addr6.sin6_addr); + + bytes = sendto(sMdnsFd6, buffer, length, 0, (struct sockaddr *)&addr6, sizeof(addr6)); + + VerifyOrDie((bytes == length), "sendTo(sMdnsFd6) failed"); + } + +exit: + return; +} + +void otPlatMdnsSendUnicast(otInstance *aInstance, otMessage *aMessage, const otPlatMdnsAddressInfo *aAddress) +{ + OT_UNUSED_VARIABLE(aInstance); + + otIp4Address ip4Addr; + uint8_t buffer[MAX_BUFFER_SIZE]; + uint16_t length; + int bytes; + + otEXPECT(sEnabled); + + length = otMessageRead(aMessage, 0, buffer, sizeof(buffer)); + otMessageFree(aMessage); + + if (otIp4FromIp4MappedIp6Address(&aAddress->mAddress, &ip4Addr) == OT_ERROR_NONE) + { + struct sockaddr_in addr; + + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + memcpy(&addr.sin_addr.s_addr, &ip4Addr, sizeof(otIp4Address)); + addr.sin_port = htons(MDNS_PORT); + + bytes = sendto(sMdnsFd4, buffer, length, 0, (struct sockaddr *)&addr, sizeof(addr)); + + VerifyOrDie((bytes == length), "sendTo(sMdnsFd4) failed"); + } + else + { + struct sockaddr_in6 addr6; + + memset(&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + addr6.sin6_port = htons(MDNS_PORT); + memcpy(&addr6.sin6_addr, &aAddress->mAddress, sizeof(otIp6Address)); + + bytes = sendto(sMdnsFd6, buffer, length, 0, (struct sockaddr *)&addr6, sizeof(addr6)); + + VerifyOrDie((bytes == length), "sendTo(sMdnsFd6) failed"); + } + +exit: + return; +} + +void platformMdnsSocketUpdateFdSet(fd_set *aReadFdSet, int *aMaxFd) +{ + otEXPECT(sEnabled); + + utilsAddFdToFdSet(sMdnsFd4, aReadFdSet, aMaxFd); + utilsAddFdToFdSet(sMdnsFd6, aReadFdSet, aMaxFd); + +exit: + return; +} + +void platformMdnsSocketProcess(otInstance *aInstance, const fd_set *aReadFdSet) +{ + otEXPECT(sEnabled); + + if (FD_ISSET(sMdnsFd4, aReadFdSet)) + { + uint8_t buffer[MAX_BUFFER_SIZE]; + struct sockaddr_in sockaddr; + otPlatMdnsAddressInfo addrInfo; + otMessage *message; + socklen_t len = sizeof(sockaddr); + ssize_t rval; + + memset(&sockaddr, 0, sizeof(sockaddr)); + rval = recvfrom(sMdnsFd4, (char *)&buffer, sizeof(buffer), 0, (struct sockaddr *)&sockaddr, &len); + + VerifyOrDie(rval >= 0, "recvfrom() failed"); + + message = otIp6NewMessage(aInstance, NULL); + VerifyOrDie(message != NULL, "otIp6NewMessage() failed"); + + VerifyOrDie(otMessageAppend(message, buffer, (uint16_t)rval) == OT_ERROR_NONE, "otMessageAppend() failed"); + + memset(&addrInfo, 0, sizeof(addrInfo)); + otIp4ToIp4MappedIp6Address((otIp4Address *)(&sockaddr.sin_addr.s_addr), &addrInfo.mAddress); + addrInfo.mPort = MDNS_PORT; + addrInfo.mInfraIfIndex = sInfraIfIndex; + + otPlatMdnsHandleReceive(aInstance, message, /* aInUnicast */ false, &addrInfo); + } + + if (FD_ISSET(sMdnsFd6, aReadFdSet)) + { + uint8_t buffer[MAX_BUFFER_SIZE]; + struct sockaddr_in6 sockaddr6; + otPlatMdnsAddressInfo addrInfo; + otMessage *message; + socklen_t len = sizeof(sockaddr6); + ssize_t rval; + + memset(&sockaddr6, 0, sizeof(sockaddr6)); + rval = recvfrom(sMdnsFd6, (char *)&buffer, sizeof(buffer), 0, (struct sockaddr *)&sockaddr6, &len); + VerifyOrDie(rval >= 0, "recvfrom(sMdnsFd6) failed"); + + message = otIp6NewMessage(aInstance, NULL); + VerifyOrDie(message != NULL, "otIp6NewMessage() failed"); + + VerifyOrDie(otMessageAppend(message, buffer, (uint16_t)rval) == OT_ERROR_NONE, "otMessageAppend() failed"); + + memset(&addrInfo, 0, sizeof(addrInfo)); + memcpy(&addrInfo.mAddress, &sockaddr6.sin6_addr, sizeof(otIp6Address)); + addrInfo.mPort = MDNS_PORT; + addrInfo.mInfraIfIndex = sInfraIfIndex; + + otPlatMdnsHandleReceive(aInstance, message, /* aInUnicast */ false, &addrInfo); + } + +exit: + return; +} + +//- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +// Add weak implementation of `ot` APIs for RCP build. Note that +// `simulation` platform does not get `OPENTHREAD_RADIO` config) + +OT_TOOL_WEAK uint16_t otMessageRead(const otMessage *aMessage, uint16_t aOffset, void *aBuf, uint16_t aLength) +{ + OT_UNUSED_VARIABLE(aMessage); + OT_UNUSED_VARIABLE(aOffset); + OT_UNUSED_VARIABLE(aBuf); + OT_UNUSED_VARIABLE(aLength); + + fprintf(stderr, "\n\rWeak otMessageRead() is incorrectly used\n\r"); + exit(1); + + return 0; +} + +OT_TOOL_WEAK void otMessageFree(otMessage *aMessage) +{ + OT_UNUSED_VARIABLE(aMessage); + fprintf(stderr, "\n\rWeak otMessageFree() is incorrectly used\n\r"); + exit(1); +} + +OT_TOOL_WEAK otMessage *otIp6NewMessage(otInstance *aInstance, const otMessageSettings *aSettings) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aSettings); + + fprintf(stderr, "\n\rWeak otIp6NewMessage() is incorrectly used\n\r"); + exit(1); + + return NULL; +} + +OT_TOOL_WEAK otError otMessageAppend(otMessage *aMessage, const void *aBuf, uint16_t aLength) +{ + OT_UNUSED_VARIABLE(aMessage); + OT_UNUSED_VARIABLE(aBuf); + OT_UNUSED_VARIABLE(aLength); + + fprintf(stderr, "\n\rWeak otMessageFree() is incorrectly used\n\r"); + exit(1); + + return OT_ERROR_NOT_IMPLEMENTED; +} + +OT_TOOL_WEAK void otIp4ToIp4MappedIp6Address(const otIp4Address *aIp4Address, otIp6Address *aIp6Address) +{ + OT_UNUSED_VARIABLE(aIp4Address); + OT_UNUSED_VARIABLE(aIp6Address); + + fprintf(stderr, "\n\rWeak otIp4ToIp4MappedIp6Address() is incorrectly used\n\r"); + exit(1); +} + +OT_TOOL_WEAK otError otIp4FromIp4MappedIp6Address(const otIp6Address *aIp6Address, otIp4Address *aIp4Address) +{ + OT_UNUSED_VARIABLE(aIp6Address); + OT_UNUSED_VARIABLE(aIp4Address); + + fprintf(stderr, "\n\rWeak otIp4FromIp4MappedIp6Address() is incorrectly used\n\r"); + exit(1); + + return OT_ERROR_NOT_IMPLEMENTED; +} + +OT_TOOL_WEAK void otPlatMdnsHandleReceive(otInstance *aInstance, + otMessage *aMessage, + bool aIsUnicast, + const otPlatMdnsAddressInfo *aAddress) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aMessage); + OT_UNUSED_VARIABLE(aIsUnicast); + OT_UNUSED_VARIABLE(aAddress); + + fprintf(stderr, "\n\rWeak otPlatMdnsHandleReceive() is incorrectly used\n\r"); + exit(1); +} + +//--------------------------------------------------------------------------------------------------------------------- +#else // OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX + +otError otPlatMdnsSetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aEnable); + OT_UNUSED_VARIABLE(aInfraIfIndex); + + return OT_ERROR_NOT_IMPLEMENTED; +} + +void otPlatMdnsSendMulticast(otInstance *aInstance, otMessage *aMessage, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aInfraIfIndex); + + otMessageFree(aMessage); +} + +void otPlatMdnsSendUnicast(otInstance *aInstance, otMessage *aMessage, const otPlatMdnsAddressInfo *aAddress) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aAddress); + otMessageFree(aMessage); +} + +#endif // OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE diff --git a/examples/platforms/simulation/platform-config.h b/examples/platforms/simulation/platform-config.h index 6591268d4..79caa4a4d 100644 --- a/examples/platforms/simulation/platform-config.h +++ b/examples/platforms/simulation/platform-config.h @@ -103,3 +103,16 @@ #ifndef OPENTHREAD_SIMULATION_MAX_NETWORK_SIZE #define OPENTHREAD_SIMULATION_MAX_NETWORK_SIZE 33 #endif + +/** + * @def OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX + * + * Define as 1 for the simulation platform to provide a simplified implementation of `otPlatMdns` APIs using posix + * socket. + * + * This is intended for testing of the OpenThread Multicast DNS (mDNS) module. + * + */ +#ifndef OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX +#define OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX 0 +#endif diff --git a/examples/platforms/simulation/platform-simulation.h b/examples/platforms/simulation/platform-simulation.h index d4407686a..fcdd1c1eb 100644 --- a/examples/platforms/simulation/platform-simulation.h +++ b/examples/platforms/simulation/platform-simulation.h @@ -341,6 +341,28 @@ void platformInfraIfProcess(otInstance *aInstance, const fd_set *aReadFdSet, con #endif // OPENTHREAD_CONFIG_BORDER_ROUTING_ENABLE +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX + +/** + * Updates the file descriptor sets with file descriptors used by the mDNS socket. + * + * @param[in,out] aReadFdSet A pointer to the read file descriptors. + * @param[in,out] aMaxFd A pointer to the max file descriptor. + * + */ +void platformMdnsSocketUpdateFdSet(fd_set *aReadFdSet, int *aMaxFd); + +/** + * Performs mDNs Socket processing. + * + * @param[in] aInstance The OpenThread instance structure. + * @param[in] aReadFdSet A pointer to the read file descriptors. + * + */ +void platformMdnsSocketProcess(otInstance *aInstance, const fd_set *aReadFdSet); + +#endif + /** * Shuts down the BLE service used by OpenThread. * diff --git a/examples/platforms/simulation/system.c b/examples/platforms/simulation/system.c index 2abe892e3..ddf1bb210 100644 --- a/examples/platforms/simulation/system.c +++ b/examples/platforms/simulation/system.c @@ -295,6 +295,9 @@ void otSysProcessDrivers(otInstance *aInstance) #if OPENTHREAD_CONFIG_BORDER_ROUTING_ENABLE platformInfraIfUpdateFdSet(&read_fds, &write_fds, &max_fd); #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX + platformMdnsSocketUpdateFdSet(&read_fds, &max_fd); +#endif #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE platformBleUpdateFdSet(&read_fds, &write_fds, &timeout, &max_fd); @@ -329,6 +332,9 @@ void otSysProcessDrivers(otInstance *aInstance) #if OPENTHREAD_CONFIG_BORDER_ROUTING_ENABLE platformInfraIfProcess(aInstance, &read_fds, &write_fds); #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX + platformMdnsSocketProcess(aInstance, &read_fds); +#endif if (gTerminate) { diff --git a/include/openthread/BUILD.gn b/include/openthread/BUILD.gn index 453639710..f27b1d19c 100644 --- a/include/openthread/BUILD.gn +++ b/include/openthread/BUILD.gn @@ -73,6 +73,7 @@ source_set("openthread") { "link_metrics.h", "link_raw.h", "logging.h", + "mdns.h", "mesh_diag.h", "message.h", "multi_radio.h", @@ -97,6 +98,7 @@ source_set("openthread") { "platform/flash.h", "platform/infra_if.h", "platform/logging.h", + "platform/mdns_socket.h", "platform/memory.h", "platform/messagepool.h", "platform/misc.h", diff --git a/include/openthread/instance.h b/include/openthread/instance.h index 470ce012c..c3538cdfa 100644 --- a/include/openthread/instance.h +++ b/include/openthread/instance.h @@ -53,7 +53,7 @@ extern "C" { * @note This number versions both OpenThread platform and user APIs. * */ -#define OPENTHREAD_API_VERSION (403) +#define OPENTHREAD_API_VERSION (404) /** * @addtogroup api-instance diff --git a/include/openthread/mdns.h b/include/openthread/mdns.h new file mode 100644 index 000000000..a0a3045d9 --- /dev/null +++ b/include/openthread/mdns.h @@ -0,0 +1,747 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/** + * @file + * @brief + * This file includes the mDNS related APIs. + * + */ + +#ifndef OPENTHREAD_MULTICAST_DNS_H_ +#define OPENTHREAD_MULTICAST_DNS_H_ + +#include + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @addtogroup api-mdns + * + * @brief + * This module includes APIs for Multicast DNS (mDNS). + * + * @{ + * + * The mDNS APIs are available when the mDNS support `OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE` is enabled and the + * `OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE` is also enabled. + * + */ + +/** + * Represents a request ID (`uint32_t` value) for registering a host, a service, or a key service. + * + */ +typedef otPlatDnssdRequestId otMdnsRequestId; + +/** + * Represents the callback function to report the outcome of a host, service, or key registration request. + * + * The outcome of a registration request is reported back by invoking this callback with one of the following `aError` + * inputs: + * + * - `OT_ERROR_NONE` indicates registration was successful. + * - `OT_ERROR_DUPLICATED` indicates a name conflict while probing, i.e., name is claimed by another mDNS responder. + * + * See `otMdnsRegisterHost()`, `otMdnsRegisterService()`, and `otMdnsRegisterKey()` for more details about when + * the callback will be invoked. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aRequestId The request ID. + * @param[in] aError Error indicating the outcome of request. + * + */ +typedef otPlatDnssdRegisterCallback otMdnsRegisterCallback; + +/** + * Represents the callback function to report a detected name conflict after successful registration of an entry. + * + * If a conflict is detected while registering an entry, it is reported through the provided `otMdnsRegisterCallback`. + * The `otMdnsConflictCallback` is used only when a name conflict is detected after an entry has been successfully + * registered. + * + * A non-NULL @p aServiceType indicates that conflict is for a service entry. In this case @p aName specifies the + * service instance label (treated as as a single DNS label and can potentially include dot `.` character). + * + * A NULL @p aServiceType indicates that conflict is for a host entry. In this case @p Name specifies the host name. It + * does not include the domain name. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aName The host name or the service instance label. + * @param[in] aServiceType The service type (e.g., `_tst._udp`). + * + */ +typedef void (*otMdnsConflictCallback)(otInstance *aInstance, const char *aName, const char *aServiceType); + +/** + * Represents an mDNS host. + * + * This type is used to register or unregister a host (`otMdnsRegisterHost()` and `otMdnsUnregisterHost()`). + * + * See the description of each function for more details on how different fields are used in each case. + * + */ +typedef otPlatDnssdHost otMdnsHost; + +/** + * Represents an mDNS service. + * + * This type is used to register or unregister a service (`otMdnsRegisterService()` and `otMdnsUnregisterService()`). + * + * See the description of each function for more details on how different fields are used in each case. + * + */ +typedef otPlatDnssdService otMdnsService; + +/** + * Represents an mDNS key record. + * + * See `otMdnsRegisterKey()`, `otMdnsUnregisterKey()` for more details about fields in each case. + * + */ +typedef otPlatDnssdKey otMdnsKey; + +/** + * Enables or disables the mDNS module. + * + * The mDNS module should be enabled before registration any host, service, or key entries. Disabling mDNS will + * immediately stop all operations and any communication (multicast or unicast tx) and remove any previously registered + * entries without sending any "goodbye" announcements or invoking their callback. Once disabled, all currently active + * browsers and resolvers are stopped. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aEnable Boolean to indicate whether to enable (on `TRUE`) or disable (on `FALSE`). + * @param[in] aInfraIfIndex The network interface index for mDNS operation. Value is ignored when disabling + * + * @retval OT_ERROR_NONE Enabled or disabled the mDNS module successfully. + * @retval OT_ERROR_ALREADY mDNS is already enabled on an enable request or is already disabled on a disable request. + * + */ +otError otMdnsSetEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex); + +/** + * Indicates whether the mDNS module is enabled. + * + * @param[in] aInstance The OpenThread instance. + * + * @retval TRUE The mDNS module is enabled + * @retval FALSE The mDNS module is disabled. + * + */ +bool otMdnsIsEnabled(otInstance *aInstance); + +/** + * Sets whether the mDNS module is allowed to send questions requesting unicast responses referred to as "QU" questions. + * + * The "QU" questions request unicast responses, in contrast to "QM" questions which request multicast responses. + * + * When allowed, the first probe will be sent as a "QU" question. This API can be used to address platform limitation + * where platform socket cannot accept unicast response received on mDNS port (due to it being already bound). + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aAllow Indicates whether or not to allow "QU" questions. + * + */ +void otMdnsSetQuestionUnicastAllowed(otInstance *aInstance, bool aAllow); + +/** + * Indicates whether mDNS module is allowed to send "QU" questions requesting unicast response. + * + * @retval TRUE The mDNS module is allowed to send "QU" questions. + * @retval FALSE The mDNS module is not allowed to send "QU" questions. + * + */ +bool otMdnsIsQuestionUnicastAllowed(otInstance *aInstance); + +/** + * Sets the post-registration conflict callback. + * + * If a conflict is detected while registering an entry, it is reported through the provided `otMdnsRegisterCallback`. + * The `otMdnsConflictCallback` is used only when a name conflict is detected after an entry has been successfully + * registered. + * + * @p aCallback can be set to `NULL` if not needed. Subsequent calls will replace any previously set callback. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aCallback The conflict callback. + * + */ +void otMdnsSetConflictCallback(otInstance *aInstance, otMdnsConflictCallback aCallback); + +/** + * Registers or updates a host on mDNS. + * + * The fields in @p aHost follow these rules: + * + * - The `mHostName` field specifies the host name to register (e.g., "myhost"). MUST NOT contain the domain name. + * - The `mAddresses` is array of IPv6 addresses to register with the host. `mAddressesLength` provides the number of + * entries in `mAddresses` array. + * - The `mAddresses` array can be empty with zero `mAddressesLength`. In this case, mDNS will treat it as if host is + * unregistered and stops advertising any addresses for this the host name. + * - The `mTtl` specifies the TTL if non-zero. If zero, the mDNS core will choose the default TTL of 120 seconds. + * - Other fields in @p aHost structure are ignored in an `otMdnsRegisterHost()` call. + * + * This function can be called again for the same `mHostName` to update a previously registered host entry, for example, + * to change the list of addresses of the host. In this case, the mDNS module will send "goodbye" announcements for any + * previously registered and now removed addresses and announce any newly added addresses. + * + * The outcome of the registration request is reported back by invoking the provided @p aCallback with @p aRequestId + * as its input and one of the following `aError` inputs: + * + * - `OT_ERROR_NONE` indicates registration was successful. + * - `OT_ERROR_DULICATED` indicates a name conflict while probing, i.e., name is claimed by another mDNS responder. + * + * For caller convenience, the OpenThread mDNS module guarantees that the callback will be invoked after this function + * returns, even in cases of immediate registration success. The @p aCallback can be `NULL` if caller does not want to + * be notified of the outcome. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aHost Information about the host to register. + * @param[in] aRequestId The ID associated with this request. + * @param[in] aCallback The callback function pointer to report the outcome (can be NULL if not needed). + * + * @retval OT_ERROR_NONE Successfully started registration. @p aCallback will report the outcome. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsRegisterHost(otInstance *aInstance, + const otMdnsHost *aHost, + otMdnsRequestId aRequestId, + otMdnsRegisterCallback aCallback); + +/** + * Unregisters a host on mDNS. + * + * The fields in @p aHost follow these rules: + * + * - The `mHostName` field specifies the host name to unregister (e.g., "myhost"). MUST NOT contain the domain name. + * - Other fields in @p aHost structure are ignored in an `otMdnsUnregisterHost()` call. + * + * If there is no previously registered host with the same name, no action is performed. + * + * If there is a previously registered host with the same name, the mDNS module will send "goodbye" announcement for + * all previously advertised address records. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aHost Information about the host to unregister. + * + * @retval OT_ERROR_NONE Successfully unregistered host. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsUnregisterHost(otInstance *aInstance, const otMdnsHost *aHost); + +/** + * Registers or updates a service on mDNS. + * + * The fields in @p aService follow these rules: + * + * - The `mServiceInstance` specifies the service instance label. It is treated as a single DNS name label. It may + * contain dot `.` character which is allowed in a service instance label. + * - The `mServiceType` specifies the service type (e.g., "_tst._udp"). It is treated as multiple dot `.` separated + * labels. It MUST NOT contain the domain name. + * - The `mHostName` field specifies the host name of the service. MUST NOT contain the domain name. + * - The `mSubTypeLabels` is an array of strings representing sub-types associated with the service. Each array entry + * is a sub-type label. The `mSubTypeLabels can be NULL if there is no sub-type. Otherwise, the array length is + * specified by `mSubTypeLabelsLength`. + * - The `mTxtData` and `mTxtDataLength` specify the encoded TXT data. The `mTxtData` can be NULL or `mTxtDataLength` + * can be zero to specify an empty TXT data. In this case mDNS module will use a single zero byte `[ 0 ]` as the + * TXT data. + * - The `mPort`, `mWeight`, and `mPriority` specify the service's parameters as specified in DNS SRV record. + * - The `mTtl` specifies the TTL if non-zero. If zero, the mDNS module will use the default TTL of 120 seconds. + * - Other fields in @p aService structure are ignored in an `otMdnsRegisterService()` call. + * + * This function can be called again for the same `mServiceInstance` and `mServiceType` to update a previously + * registered service entry, for example, to change the sub-types list, or update any parameter such as port, weight, + * priority, TTL, or host name. The mDNS module will send announcements for any changed info, e.g., will send "goodbye" + * announcements for any removed sub-types and announce any newly added sub-types. + * + * Regarding the invocation of the @p aCallback, this function behaves in the same way as described in + * `otMdnsRegisterHost()`. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aService Information about the service to register. + * @param[in] aRequestId The ID associated with this request. + * @param[in] aCallback The callback function pointer to report the outcome (can be NULL if not needed). + * + * @retval OT_ERROR_NONE Successfully started registration. @p aCallback will report the outcome. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsRegisterService(otInstance *aInstance, + const otMdnsService *aService, + otMdnsRequestId aRequestId, + otMdnsRegisterCallback aCallback); + +/** + * Unregisters a service on mDNS module. + * + * The fields in @p aService follow these rules: + + * - The `mServiceInstance` specifies the service instance label. It is treated as a single DNS name label. It may + * contain dot `.` character which is allowed in a service instance label. + * - The `mServiceType` specifies the service type (e.g., "_tst._udp"). It is treated as multiple dot `.` separated + * labels. It MUST NOT contain the domain name. + * - Other fields in @p aService structure are ignored in an `otMdnsUnregisterService()` call. + * + * If there is no previously registered service with the same name, no action is performed. + * + * If there is a previously registered service with the same name, the mDNS module will send "goodbye" announcements + * for all related records. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aService Information about the service to unregister. + * + * @retval OT_ERROR_NONE Successfully unregistered service. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsUnregisterService(otInstance *aInstance, const otMdnsService *aService); + +/** + * Registers or updates a key record on mDNS module. + * + * The fields in @p aKey follow these rules: + * + * - If the key is associated with a host entry, `mName` specifies the host name and `mServcieType` MUST be NULL. + * - If the key is associated with a service entry, `mName` specifies the service instance label (always treated as + * a single label) and `mServiceType` specifies the service type (e.g., "_tst._udp"). In this case the DNS name for + * key record is `.`. + * - The `mKeyData` field contains the key record's data with `mKeyDataLength` as its length in byes. + * - The `mTtl` specifies the TTL if non-zero. If zero, the mDNS module will use the default TTL of 120 seconds. + * - Other fields in @p aKey structure are ignored in an `otMdnsRegisterKey()` call. + * + * This function can be called again for the same name to updated a previously registered key entry, for example, to + * change the key data or TTL. + * + * Regarding the invocation of the @p aCallback, this function behaves in the same way as described in + * `otMdnsRegisterHost()`. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aKey Information about the key record to register. + * @param[in] aRequestId The ID associated with this request. + * @param[in] aCallback The callback function pointer to report the outcome (can be NULL if not needed). + * + * @retval OT_ERROR_NONE Successfully started registration. @p aCallback will report the outcome. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsRegisterKey(otInstance *aInstance, + const otMdnsKey *aKey, + otMdnsRequestId aRequestId, + otMdnsRegisterCallback aCallback); + +/** + * Unregisters a key record on mDNS. + * + * The fields in @p aKey follow these rules: + * + * - If the key is associated with a host entry, `mName` specifies the host name and `mServcieType` MUST be NULL. + * - If the key is associated with a service entry, `mName` specifies the service instance label (always treated as + * a single label) and `mServiceType` specifies the service type (e.g., "_tst._udp"). In this case the DNS name for + * key record is `.`. + * - Other fields in @p aKey structure are ignored in an `otMdnsUnregisterKey()` call. + * + * If there is no previously registered key with the same name, no action is performed. + * + * If there is a previously registered key with the same name, the mDNS module will send "goodbye" announcements for + * the key record. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aKey Information about the key to unregister. + * + * @retval OT_ERROR_NONE Successfully unregistered key + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsUnregisterKey(otInstance *aInstance, const otMdnsKey *aKey); + +typedef struct otMdnsBrowseResult otMdnsBrowseResult; +typedef struct otMdnsSrvResult otMdnsSrvResult; +typedef struct otMdnsTxtResult otMdnsTxtResult; +typedef struct otMdnsAddressResult otMdnsAddressResult; + +/** + * Represents the callback function used to report a browse result. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResult The browse result. + * + */ +typedef void (*otMdnsBrowseCallback)(otInstance *aInstance, const otMdnsBrowseResult *aResult); + +/** + * Represents the callback function used to report an SRV resolve result. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResult The SRV resolve result. + * + */ +typedef void (*otMdnsSrvCallback)(otInstance *aInstance, const otMdnsSrvResult *aResult); + +/** + * Represents the callback function used to report a TXT resolve result. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResult The TXT resolve result. + * + */ +typedef void (*otMdnsTxtCallback)(otInstance *aInstance, const otMdnsTxtResult *aResult); + +/** + * Represents the callback function use to report a IPv6/IPv4 address resolve result. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResult The address resolve result. + * + */ +typedef void (*otMdnsAddressCallback)(otInstance *aInstance, const otMdnsAddressResult *aResult); + +/** + * Represents a service browser. + * + */ +typedef struct otMdnsBrowser +{ + const char *mServiceType; ///< The service type (e.g., "_mt._udp"). MUST NOT include domain name. + const char *mSubTypeLabel; ///< The sub-type label if browsing for sub-type, NULL otherwise. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. + otMdnsBrowseCallback mCallback; ///< The callback to report result. +} otMdnsBrowser; + +/** + * Represents a browse result. + * + */ +struct otMdnsBrowseResult +{ + const char *mServiceType; ///< The service type (e.g., "_mt._udp"). + const char *mSubTypeLabel; ///< The sub-type label if browsing for sub-type, NULL otherwise. + const char *mServiceInstance; ///< Service instance label. + uint32_t mTtl; ///< TTL in seconds. Zero TTL indicates that service is removed. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. +}; + +/** + * Represents an SRV service resolver. + * + */ +typedef struct otMdnsSrvResolver +{ + const char *mServiceInstance; ///< The service instance label. + const char *mServiceType; ///< The service type. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. + otMdnsSrvCallback mCallback; ///< The callback to report result. +} otMdnsSrvResolver; + +/** + * Represents an SRV resolver result. + * + */ +struct otMdnsSrvResult +{ + const char *mServiceInstance; ///< The service instance name label. + const char *mServiceType; ///< The service type. + const char *mHostName; ///< The host name (e.g., "myhost"). Can be NULL when `mTtl` is zero. + uint16_t mPort; ///< The service port number. + uint16_t mPriority; ///< The service priority. + uint16_t mWeight; ///< The service weight. + uint32_t mTtl; ///< The service TTL in seconds. Zero TTL indicates SRV record is removed. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. +}; + +/** + * Represents a TXT service resolver. + * + */ +typedef struct otMdnsTxtResolver +{ + const char *mServiceInstance; ///< Service instance label. + const char *mServiceType; ///< Service type. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. + otMdnsTxtCallback mCallback; +} otMdnsTxtResolver; + +/** + * Represents a TXT resolver result. + * + */ +struct otMdnsTxtResult +{ + const char *mServiceInstance; ///< The service instance name label. + const char *mServiceType; ///< The service type. + const uint8_t *mTxtData; ///< Encoded TXT data bytes. Can be NULL when `mTtl` is zero. + uint16_t mTxtDataLength; ///< Length of TXT data. + uint32_t mTtl; ///< The TXT data TTL in seconds. Zero TTL indicates record is removed. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. +}; + +/** + * Represents an address resolver. + * + */ +typedef struct otMdnsAddressResolver +{ + const char *mHostName; ///< The host name (e.g., "myhost"). MUST NOT contain domain name. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. + otMdnsAddressCallback mCallback; ///< The callback to report result. +} otMdnsAddressResolver; + +/** + * Represents a discovered host address and its TTL. + * + */ +typedef struct otMdnsAddressAndTtl +{ + otIp6Address mAddress; ///< The IPv6 address. For IPv4 address the IPv4-mapped IPv6 address format is used. + uint32_t mTtl; ///< The TTL in seconds. +} otMdnsAddressAndTtl; + +/** + * Represents address resolver result. + * + */ +struct otMdnsAddressResult +{ + const char *mHostName; ///< The host name. + const otMdnsAddressAndTtl *mAddresses; ///< Array of host addresses and their TTL. Can be NULL if empty. + uint16_t mAddressesLength; ///< Number of entries in `mAddresses` array. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. +}; + +/** + * Starts a service browser. + * + * Initiates a continuous search for the specified `mServiceType` in @p aBrowser. For sub-type services, use + * `mSubTypeLabel` to define the sub-type, for base services, set `mSubTypeLabel` to NULL. + * + * Discovered services are reported through the `mCallback` function in @p aBrowser. Services that have been removed + * are reported with a TTL value of zero. The callback may be invoked immediately with cached information (if available) + * and potentially before this function returns. When cached results are used, the reported TTL value will reflect + * the original TTL from the last received response. + * + * Multiple browsers can be started for the same service, provided they use different callback functions. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aBrowser The browser to be started. + * + * @retval OT_ERROR_NONE Browser started successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * @retval OT_ERROR_ALREADY An identical browser (same service and callback) is already active. + * + */ +otError otMdnsStartBrowser(otInstance *aInstance, const otMdnsBrowser *aBrowser); + +/** + * Stops a service browser. + * + * No action is performed if no matching browser with the same service and callback is currently active. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aBrowser The browser to stop. + * + * @retval OT_ERROR_NONE Browser stopped successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsStopBrowser(otInstance *aInstance, const otMdnsBrowser *aBroswer); + +/** + * Starts an SRV record resolver. + * + * Initiates a continuous SRV record resolver for the specified service in @p aResolver. + * + * Discovered information is reported through the `mCallback` function in @p aResolver. When the service is removed + * it is reported with a TTL value of zero. In this case, `mHostName` may be NULL and other result fields (such as + * `mPort`) should be ignored. + * + * The callback may be invoked immediately with cached information (if available) and potentially before this function + * returns. When cached result is used, the reported TTL value will reflect the original TTL from the last received + * response. + * + * Multiple resolvers can be started for the same service, provided they use different callback functions. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResolver The resolver to be started. + * + * @retval OT_ERROR_NONE Resolver started successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * @retval OT_ERROR_ALREADY An identical resolver (same service and callback) is already active. + * + */ +otError otMdnsStartSrvResolver(otInstance *aInstance, const otMdnsSrvResolver *aResolver); + +/** + * Stops an SRV record resolver. + * + * No action is performed if no matching resolver with the same service and callback is currently active. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResolver The resolver to stop. + * + * @retval OT_ERROR_NONE Resolver stopped successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsStopSrvResolver(otInstance *aInstance, const otMdnsSrvResolver *aResolver); + +/** + * Starts a TXT record resolver. + * + * Initiates a continuous TXT record resolver for the specified service in @p aResolver. + * + * Discovered information is reported through the `mCallback` function in @p aResolver. When the TXT record is removed + * it is reported with a TTL value of zero. In this case, `mTxtData` may be NULL, and other result fields (such as + * `mTxtDataLength`) should be ignored. + * + * The callback may be invoked immediately with cached information (if available) and potentially before this function + * returns. When cached result is used, the reported TTL value will reflect the original TTL from the last received + * response. + * + * Multiple resolvers can be started for the same service, provided they use different callback functions. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResolver The resolver to be started. + * + * @retval OT_ERROR_NONE Resolver started successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * @retval OT_ERROR_ALREADY An identical resolver (same service and callback) is already active. + * + */ +otError otMdnsStartTxtResolver(otInstance *aInstance, const otMdnsTxtResolver *aResolver); + +/** + * Stops a TXT record resolver. + * + * No action is performed if no matching resolver with the same service and callback is currently active. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResolver The resolver to stop. + * + * @retval OT_ERROR_NONE Resolver stopped successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsStopTxtResolver(otInstance *aInstance, const otMdnsTxtResolver *aResolver); + +/** + * Starts an IPv6 address resolver. + * + * Initiates a continuous IPv6 address resolver for the specified host name in @p aResolver. + * + * Discovered addresses are reported through the `mCallback` function in @ p aResolver. The callback is invoked + * whenever addresses are added or removed, providing an updated list. If all addresses are removed, the callback is + * invoked with an empty list (`mAddresses` will be NULL, and `mAddressesLength` will be zero). + * + * The callback may be invoked immediately with cached information (if available) and potentially before this function + * returns. When cached result is used, the reported TTL values will reflect the original TTL from the last received + * response. + * + * Multiple resolvers can be started for the same host name, provided they use different callback functions. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResolver The resolver to be started. + * + * @retval OT_ERROR_NONE Resolver started successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * @retval OT_ERROR_ALREADY An identical resolver (same host and callback) is already active. + * + */ +otError otMdnsStartIp6AddressResolver(otInstance *aInstance, const otMdnsAddressResolver *aResolver); + +/** + * Stops an IPv6 address resolver. + * + * No action is performed if no matching resolver with the same host name and callback is currently active. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResolver The resolver to stop. + * + * @retval OT_ERROR_NONE Resolver stopped successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsStopIp6AddressResolver(otInstance *aInstance, const otMdnsAddressResolver *aResolver); + +/** + * Starts an IPv4 address resolver. + * + * Initiates a continuous IPv4 address resolver for the specified host name in @p aResolver. + * + * Discovered addresses are reported through the `mCallback` function in @ p aResolver. The IPv4 addresses are + * represented using the IPv4-mapped IPv6 address format in `mAddresses` array. The callback is invoked whenever + * addresses are added or removed, providing an updated list. If all addresses are removed, the callback is invoked + * with an empty list (`mAddresses` will be NULL, and `mAddressesLength` will be zero). + * + * The callback may be invoked immediately with cached information (if available) and potentially before this function + * returns. When cached result is used, the reported TTL values will reflect the original TTL from the last received + * response. + * + * Multiple resolvers can be started for the same host name, provided they use different callback functions. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResolver The resolver to be started. + * + * @retval OT_ERROR_NONE Resolver started successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * @retval OT_ERROR_ALREADY An identical resolver (same host and callback) is already active. + * + */ +otError otMdnsStartIp4AddressResolver(otInstance *aInstance, const otMdnsAddressResolver *aResolver); + +/** + * Stops an IPv4 address resolver. + * + * No action is performed if no matching resolver with the same host name and callback is currently active. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResolver The resolver to stop. + * + * @retval OT_ERROR_NONE Resolver stopped successfully. + * @retval OT_ERROR_INVALID_STATE mDNS module is not enabled. + * + */ +otError otMdnsStopIp4AddressResolver(otInstance *aInstance, const otMdnsAddressResolver *aResolver); + +/** + * @} + * + */ + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // OPENTHREAD_MULTICAST_DNS_H_ diff --git a/include/openthread/platform/mdns_socket.h b/include/openthread/platform/mdns_socket.h new file mode 100644 index 000000000..cc2bed97a --- /dev/null +++ b/include/openthread/platform/mdns_socket.h @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/** + * @file + * @brief + * This file includes the platform abstraction for mDNS socket. + * + */ + +#ifndef OPENTHREAD_PLATFORM_MULTICAST_DNS_SOCKET_H_ +#define OPENTHREAD_PLATFORM_MULTICAST_DNS_SOCKET_H_ + +#include + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @addtogroup plat-mdns + * + * @brief + * This module defines platform APIs for Multicast DNS (mDNS) socket. + * + * @{ + * + */ + +/** + * Represents a socket address info. + * + */ +typedef struct otPlatMdnsAddressInfo +{ + otIp6Address mAddress; ///< IP address. IPv4-mapped IPv6 format should be used to represent IPv4 address. + uint16_t mPort; ///< Port number. + uint32_t mInfraIfIndex; ///< Interface index. +} otPlatMdnsAddressInfo; + +/** + * Enables or disables listening for mDNS messages sent to mDNS port 5353. + * + * When listening is enabled, the platform MUST listen for multicast messages sent to UDP destination port 5353 at the + * mDNS link-local multicast address `224.0.0.251` and its IPv6 equivalent `ff02::fb`. + * + * The platform SHOULD also listen for any unicast messages sent to UDP destination port 5353. If this is not possible, + * then OpenThread mDNS module can be configured to not use any "QU" questions requesting unicast response. + * + * While enabled, all received messages MUST be reported back using `otPlatMdnsHandleReceive()` callback. + * + * @param[in] aInstance The OpernThread instance. + * @param[in] aEnable Indicate whether to enable or disable. + * @param[in] aInfraInfIndex The infrastructure network interface index. + * + * @retval OT_ERROR_NONE Successfully enabled/disabled listening for mDNS messages. + * @retval OT_ERROR_FAILED Failed to enable/disable listening for mDNS messages. + * + */ +otError otPlatMdnsSetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex); + +/** + * Sends an mDNS message as multicast. + * + * The platform MUST multicast the prepared mDNS message in @p aMessage as a UDP message using the mDNS well-known port + * number 5353 for both source and destination ports. The message MUST be sent to the mDNS link-local multicast + * address `224.0.0.251` and/or its IPv6 equivalent `ff02::fb`. + * + * @p aMessage contains the mDNS message starting with DNS header at offset zero. It does not include IP or UDP headers. + * This function passes the ownership of @p aMessage to the platform layer and platform implementation MUST free + * @p aMessage once sent and no longer needed. + * + * The platform MUST allow multicast loopback, i.e., the multicast message @p aMessage MUST also be received and + * passed back to OpenThread stack using `otPlatMdnsHandleReceive()` callback. This behavior is essential for the + * OpenThread mDNS stack to process and potentially respond to its own queries, while allowing other mDNS receivers + * to also receive the query and its response. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aMessage The mDNS message to multicast. Ownership is transferred to the platform layer. + * @param[in] aInfraIfIndex The infrastructure network interface index. + * + */ +void otPlatMdnsSendMulticast(otInstance *aInstance, otMessage *aMessage, uint32_t aInfraIfIndex); + +/** + * Sends an mDNS message as unicast. + * + * The platform MUST send the prepared mDNS message in @p aMessage as a UDP message using source UDP port 5353 to + * the destination address and port number specified by @p aAddress. + * + * @p aMessage contains the DNS message starting with the DNS header at offset zero. It does not include IP or UDP + * headers. This function passes the ownership of @p aMessage to the platform layer and platform implementation + * MUST free @p aMessage once sent and no longer needed. + * + * The @p aAddress fields are as follows: + * + * - `mAddress` specifies the destination address. IPv4-mapped IPv6 format is used to represent an IPv4 destination. + * - `mPort` specifies the destination port. + * - `mInfraIndex` specifies the interface index. + * + * If the @aAddress matches this devices address, the platform MUST ensure to receive and pass the message back to + * the OpenThread stack using `otPlatMdnsHandleReceive()` for processing. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aMessage The mDNS message to multicast. Ownership is transferred to platform layer. + * @param[in] aAddress The destination address info. + * + */ +void otPlatMdnsSendUnicast(otInstance *aInstance, otMessage *aMessage, const otPlatMdnsAddressInfo *aAddress); + +/** + * Callback to notify OpenThread mDNS module of a received message on UDP port 5353. + * + * @p aMessage MUST contain DNS message starting with the DNS header at offset zero. This function passes the + * ownership of @p aMessage from the platform layer to the OpenThread stack. The OpenThread stack will free the + * message once processed. + * + * The @p aAddress fields are as follows: + * + * - `mAddress` specifies the sender's address. IPv4-mapped IPv6 format is used to represent an IPv4 destination. + * - `mPort` specifies the sender's port. + * - `mInfraIndex` specifies the interface index. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aMessage The received mDNS message. Ownership is transferred to the OpenThread stack. + * @param[in] aIsUnicast Indicates whether the received message is unicast or multicast. + * @param[in] aAddress The sender's address info. + * + */ +extern void otPlatMdnsHandleReceive(otInstance *aInstance, + otMessage *aMessage, + bool aIsUnicast, + const otPlatMdnsAddressInfo *aAddress); + +/** + * @} + * + */ + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // OPENTHREAD_PLATFORM_MULTICAST_DNS_SOCKET_H_ diff --git a/src/cli/BUILD.gn b/src/cli/BUILD.gn index b74577af1..2c214d87b 100644 --- a/src/cli/BUILD.gn +++ b/src/cli/BUILD.gn @@ -53,6 +53,8 @@ openthread_cli_sources = [ "cli_link_metrics.hpp", "cli_mac_filter.cpp", "cli_mac_filter.hpp", + "cli_mdns.cpp", + "cli_mdns.hpp", "cli_network_data.cpp", "cli_network_data.hpp", "cli_ping.cpp", diff --git a/src/cli/CMakeLists.txt b/src/cli/CMakeLists.txt index 9dfee4030..3d9a32a21 100644 --- a/src/cli/CMakeLists.txt +++ b/src/cli/CMakeLists.txt @@ -44,6 +44,7 @@ set(COMMON_SOURCES cli_joiner.cpp cli_link_metrics.cpp cli_mac_filter.cpp + cli_mdns.cpp cli_network_data.cpp cli_ping.cpp cli_srp_client.cpp diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp index e6132df8f..2a98ee0ec 100644 --- a/src/cli/cli.cpp +++ b/src/cli/cli.cpp @@ -119,6 +119,9 @@ Interpreter::Interpreter(Instance *aInstance, otCliOutputCallback aCallback, voi #if OPENTHREAD_CLI_DNS_ENABLE , mDns(aInstance, *this) #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE + , mMdns(aInstance, *this) +#endif #if (OPENTHREAD_CONFIG_THREAD_VERSION >= OT_THREAD_VERSION_1_2) , mBbr(aInstance, *this) #endif @@ -2763,6 +2766,10 @@ template <> otError Interpreter::Process(Arg aArgs[]) template <> otError Interpreter::Process(Arg aArgs[]) { return mDns.Process(aArgs); } #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE +template <> otError Interpreter::Process(Arg aArgs[]) { return mMdns.Process(aArgs); } +#endif + #if OPENTHREAD_FTD void Interpreter::OutputEidCacheEntry(const otCacheEntryInfo &aEntry) { @@ -8579,6 +8586,9 @@ otError Interpreter::ProcessCommand(Arg aArgs[]) #if OPENTHREAD_CONFIG_MAC_FILTER_ENABLE CmdEntry("macfilter"), #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE + CmdEntry("mdns"), +#endif #if OPENTHREAD_CONFIG_MESH_DIAG_ENABLE && OPENTHREAD_FTD CmdEntry("meshdiag"), #endif diff --git a/src/cli/cli.hpp b/src/cli/cli.hpp index 75ea62975..b808f7455 100644 --- a/src/cli/cli.hpp +++ b/src/cli/cli.hpp @@ -61,12 +61,14 @@ #include "cli/cli_bbr.hpp" #include "cli/cli_br.hpp" #include "cli/cli_commissioner.hpp" +#include "cli/cli_config.h" #include "cli/cli_dataset.hpp" #include "cli/cli_dns.hpp" #include "cli/cli_history.hpp" #include "cli/cli_joiner.hpp" #include "cli/cli_link_metrics.hpp" #include "cli/cli_mac_filter.hpp" +#include "cli/cli_mdns.hpp" #include "cli/cli_network_data.hpp" #include "cli/cli_ping.hpp" #include "cli/cli_srp_client.hpp" @@ -117,6 +119,7 @@ class Interpreter : public OutputImplementer, public Utils friend class Dns; friend class Joiner; friend class LinkMetrics; + friend class Mdns; friend class NetworkData; friend class PingSender; friend class SrpClient; @@ -446,6 +449,10 @@ class Interpreter : public OutputImplementer, public Utils Dns mDns; #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE + Mdns mMdns; +#endif + #if (OPENTHREAD_CONFIG_THREAD_VERSION >= OT_THREAD_VERSION_1_2) Bbr mBbr; #endif diff --git a/src/cli/cli_mdns.cpp b/src/cli/cli_mdns.cpp new file mode 100644 index 000000000..44ae0c838 --- /dev/null +++ b/src/cli/cli_mdns.cpp @@ -0,0 +1,776 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/** + * @file + * This file implements CLI for mDNS. + */ + +#include + +#include "cli_mdns.hpp" + +#include +#include "cli/cli.hpp" + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE + +namespace ot { +namespace Cli { + +template <> otError Mdns::Process(Arg aArgs[]) +{ + otError error; + uint32_t infraIfIndex; + + SuccessOrExit(error = aArgs[0].ParseAsUint32(infraIfIndex)); + VerifyOrExit(aArgs[1].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + + SuccessOrExit(error = otMdnsSetEnabled(GetInstancePtr(), true, infraIfIndex)); + + mInfraIfIndex = infraIfIndex; + +exit: + return error; +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + otError error = OT_ERROR_NONE; + + VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + error = otMdnsSetEnabled(GetInstancePtr(), false, /* aInfraIfIndex */ 0); + +exit: + return error; +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + otError error = OT_ERROR_NONE; + + VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + OutputEnabledDisabledStatus(otMdnsIsEnabled(GetInstancePtr())); + +exit: + return error; +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + return ProcessEnableDisable(aArgs, otMdnsIsQuestionUnicastAllowed, otMdnsSetQuestionUnicastAllowed); +} + +void Mdns::OutputHost(const otMdnsHost &aHost) +{ + OutputLine("Host %s", aHost.mHostName); + OutputLine(kIndentSize, "%u address:", aHost.mAddressesLength); + + for (uint16_t index = 0; index < aHost.mAddressesLength; index++) + { + OutputFormat(kIndentSize, " "); + OutputIp6AddressLine(aHost.mAddresses[index]); + } + + OutputLine(kIndentSize, "ttl: %lu", ToUlong(aHost.mTtl)); +} + +void Mdns::OutputService(const otMdnsService &aService) +{ + OutputLine("Service %s for %s", aService.mServiceInstance, aService.mServiceType); + OutputLine(kIndentSize, "host: %s", aService.mHostName); + + if (aService.mSubTypeLabelsLength > 0) + { + OutputLine(kIndentSize, "%u sub-type:", aService.mSubTypeLabelsLength); + + for (uint16_t index = 0; index < aService.mSubTypeLabelsLength; index++) + { + OutputLine(kIndentSize * 2, "%s", aService.mSubTypeLabels[index]); + } + } + + OutputLine(kIndentSize, "port: %u", aService.mPort); + OutputLine(kIndentSize, "priority: %u", aService.mPriority); + OutputLine(kIndentSize, "weight: %u", aService.mWeight); + OutputLine(kIndentSize, "ttl: %lu", ToUlong(aService.mTtl)); + + if ((aService.mTxtData == nullptr) || (aService.mTxtDataLength == 0)) + { + OutputLine(kIndentSize, "txt-data: (empty)"); + } + else + { + OutputFormat(kIndentSize, "txt-data: "); + OutputBytesLine(aService.mTxtData, aService.mTxtDataLength); + } +} + +void Mdns::OutputKey(const otMdnsKey &aKey) +{ + if (aKey.mServiceType != nullptr) + { + OutputLine("Key %s for %s (service)", aKey.mName, aKey.mServiceType); + } + else + { + OutputLine("Key %s (host)", aKey.mName); + } + + OutputFormat(kIndentSize, "key-data: "); + OutputBytesLine(aKey.mKeyData, aKey.mKeyDataLength); + + OutputLine(kIndentSize, "ttl: %lu", ToUlong(aKey.mTtl)); +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + // mdns [async] [host|service|key] + + otError error = OT_ERROR_NONE; + bool isAsync = false; + + if (aArgs[0] == "async") + { + isAsync = true; + aArgs++; + } + + if (aArgs[0] == "host") + { + SuccessOrExit(error = ProcessRegisterHost(aArgs + 1)); + } + else if (aArgs[0] == "service") + { + SuccessOrExit(error = ProcessRegisterService(aArgs + 1)); + } + else if (aArgs[0] == "key") + { + SuccessOrExit(error = ProcessRegisterKey(aArgs + 1)); + } + else + { + ExitNow(error = OT_ERROR_INVALID_ARGS); + } + + if (isAsync) + { + OutputLine("mDNS request id: %lu", ToUlong(mRequestId)); + } + else + { + error = OT_ERROR_PENDING; + mWaitingForCallback = true; + } + +exit: + return error; +} + +otError Mdns::ProcessRegisterHost(Arg aArgs[]) +{ + // register host [] [] + + otError error = OT_ERROR_NONE; + otMdnsHost host; + otIp6Address addresses[kMaxAddresses]; + + memset(&host, 0, sizeof(host)); + + VerifyOrExit(!aArgs->IsEmpty(), error = OT_ERROR_INVALID_ARGS); + host.mHostName = aArgs->GetCString(); + aArgs++; + + host.mAddresses = addresses; + + for (; !aArgs->IsEmpty(); aArgs++) + { + otIp6Address address; + uint32_t ttl; + + if (aArgs->ParseAsIp6Address(address) == OT_ERROR_NONE) + { + VerifyOrExit(host.mAddressesLength < kMaxAddresses, error = OT_ERROR_NO_BUFS); + addresses[host.mAddressesLength] = address; + host.mAddressesLength++; + } + else if (aArgs->ParseAsUint32(ttl) == OT_ERROR_NONE) + { + host.mTtl = ttl; + VerifyOrExit(aArgs[1].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + } + else + { + ExitNow(error = OT_ERROR_INVALID_ARGS); + } + } + + OutputHost(host); + + mRequestId++; + error = otMdnsRegisterHost(GetInstancePtr(), &host, mRequestId, HandleRegisterationDone); + +exit: + return error; +} + +otError Mdns::ProcessRegisterService(Arg aArgs[]) +{ + otError error; + otMdnsService service; + Buffers buffers; + + SuccessOrExit(error = ParseServiceArgs(aArgs, service, buffers)); + + OutputService(service); + + mRequestId++; + error = otMdnsRegisterService(GetInstancePtr(), &service, mRequestId, HandleRegisterationDone); + +exit: + return error; +} + +otError Mdns::ParseServiceArgs(Arg aArgs[], otMdnsService &aService, Buffers &aBuffers) +{ + // mdns register service [] [] [] + // [] + + otError error = OT_ERROR_INVALID_ARGS; + char *label; + uint16_t len; + + memset(&aService, 0, sizeof(aService)); + + VerifyOrExit(!aArgs->IsEmpty()); + aService.mServiceInstance = aArgs->GetCString(); + aArgs++; + + // Copy service type into `aBuffer.mString`, then search for + // `,` in the string to parse the list of sub-types (if any). + + VerifyOrExit(!aArgs->IsEmpty()); + len = aArgs->GetLength(); + VerifyOrExit(len + 1 < kStringSize, error = OT_ERROR_NO_BUFS); + memcpy(aBuffers.mString, aArgs->GetCString(), len + 1); + + aService.mServiceType = aBuffers.mString; + aService.mSubTypeLabels = aBuffers.mSubTypeLabels; + + label = strchr(aBuffers.mString, ','); + + if (label != nullptr) + { + while (true) + { + *label++ = '\0'; + + VerifyOrExit(aService.mSubTypeLabelsLength < kMaxSubTypes, error = OT_ERROR_NO_BUFS); + aBuffers.mSubTypeLabels[aService.mSubTypeLabelsLength] = label; + aService.mSubTypeLabelsLength++; + + label = strchr(label, ','); + + if (label == nullptr) + { + break; + } + } + } + + aArgs++; + VerifyOrExit(!aArgs->IsEmpty()); + aService.mHostName = aArgs->GetCString(); + + aArgs++; + SuccessOrExit(aArgs->ParseAsUint16(aService.mPort)); + + // The rest of `Args` are optional. + + error = OT_ERROR_NONE; + + aArgs++; + VerifyOrExit(!aArgs->IsEmpty()); + SuccessOrExit(error = aArgs->ParseAsUint16(aService.mPriority)); + + aArgs++; + VerifyOrExit(!aArgs->IsEmpty()); + SuccessOrExit(error = aArgs->ParseAsUint16(aService.mWeight)); + + aArgs++; + VerifyOrExit(!aArgs->IsEmpty()); + SuccessOrExit(error = aArgs->ParseAsUint32(aService.mTtl)); + + aArgs++; + VerifyOrExit(!aArgs->IsEmpty()); + len = kMaxTxtDataSize; + SuccessOrExit(error = aArgs->ParseAsHexString(len, aBuffers.mTxtData)); + aService.mTxtData = aBuffers.mTxtData; + aService.mTxtDataLength = len; + + aArgs++; + VerifyOrExit(aArgs->IsEmpty(), error = OT_ERROR_INVALID_ARGS); + +exit: + return error; +} + +otError Mdns::ProcessRegisterKey(Arg aArgs[]) +{ + otError error = OT_ERROR_INVALID_ARGS; + otMdnsKey key; + uint16_t len; + uint8_t data[kMaxKeyDataSize]; + + memset(&key, 0, sizeof(key)); + + VerifyOrExit(!aArgs->IsEmpty()); + key.mName = aArgs->GetCString(); + + aArgs++; + VerifyOrExit(!aArgs->IsEmpty()); + + if (aArgs->GetCString()[0] == '_') + { + key.mServiceType = aArgs->GetCString(); + aArgs++; + VerifyOrExit(!aArgs->IsEmpty()); + } + + len = kMaxKeyDataSize; + SuccessOrExit(error = aArgs->ParseAsHexString(len, data)); + + key.mKeyData = data; + key.mKeyDataLength = len; + + // ttl is optional + + aArgs++; + + if (!aArgs->IsEmpty()) + { + SuccessOrExit(error = aArgs->ParseAsUint32(key.mTtl)); + aArgs++; + VerifyOrExit(aArgs->IsEmpty(), error = kErrorInvalidArgs); + } + + OutputKey(key); + + mRequestId++; + error = otMdnsRegisterKey(GetInstancePtr(), &key, mRequestId, HandleRegisterationDone); + +exit: + return error; +} + +void Mdns::HandleRegisterationDone(otInstance *aInstance, otMdnsRequestId aRequestId, otError aError) +{ + OT_UNUSED_VARIABLE(aInstance); + + Interpreter::GetInterpreter().mMdns.HandleRegisterationDone(aRequestId, aError); +} + +void Mdns::HandleRegisterationDone(otMdnsRequestId aRequestId, otError aError) +{ + if (mWaitingForCallback && (aRequestId == mRequestId)) + { + mWaitingForCallback = false; + Interpreter::GetInterpreter().OutputResult(aError); + } + else + { + OutputLine("mDNS registration for request id %lu outcome: %s", ToUlong(aRequestId), + otThreadErrorToString(aError)); + } +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + otError error = OT_ERROR_INVALID_ARGS; + + if (aArgs[0] == "host") + { + otMdnsHost host; + + memset(&host, 0, sizeof(host)); + VerifyOrExit(!aArgs[1].IsEmpty()); + host.mHostName = aArgs[1].GetCString(); + VerifyOrExit(aArgs[2].IsEmpty()); + + error = otMdnsUnregisterHost(GetInstancePtr(), &host); + } + else if (aArgs[0] == "service") + { + otMdnsService service; + + memset(&service, 0, sizeof(service)); + VerifyOrExit(!aArgs[1].IsEmpty()); + service.mServiceInstance = aArgs[1].GetCString(); + VerifyOrExit(!aArgs[2].IsEmpty()); + service.mServiceType = aArgs[2].GetCString(); + VerifyOrExit(aArgs[3].IsEmpty()); + + error = otMdnsUnregisterService(GetInstancePtr(), &service); + } + else if (aArgs[0] == "key") + { + otMdnsKey key; + + memset(&key, 0, sizeof(key)); + VerifyOrExit(!aArgs[1].IsEmpty()); + key.mName = aArgs[1].GetCString(); + + if (!aArgs[2].IsEmpty()) + { + key.mServiceType = aArgs[2].GetCString(); + VerifyOrExit(aArgs[3].IsEmpty()); + } + + error = otMdnsUnregisterKey(GetInstancePtr(), &key); + } + +exit: + return error; +} + +otError Mdns::ParseStartOrStop(const Arg &aArg, bool &aIsStart) +{ + otError error = OT_ERROR_NONE; + + if (aArg == "start") + { + aIsStart = true; + } + else if (aArg == "stop") + { + aIsStart = false; + } + else + { + error = OT_ERROR_INVALID_ARGS; + } + + return error; +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + // mdns browser start|stop [] + + otError error; + otMdnsBrowser browser; + bool isStart; + + ClearAllBytes(browser); + + SuccessOrExit(error = ParseStartOrStop(aArgs[0], isStart)); + VerifyOrExit(!aArgs[1].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + + browser.mServiceType = aArgs[1].GetCString(); + + if (!aArgs[2].IsEmpty()) + { + browser.mSubTypeLabel = aArgs[2].GetCString(); + VerifyOrExit(aArgs[3].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + } + + browser.mInfraIfIndex = mInfraIfIndex; + browser.mCallback = HandleBrowseResult; + + if (isStart) + { + error = otMdnsStartBrowser(GetInstancePtr(), &browser); + } + else + { + error = otMdnsStopBrowser(GetInstancePtr(), &browser); + } + +exit: + return error; +} + +void Mdns::HandleBrowseResult(otInstance *aInstance, const otMdnsBrowseResult *aResult) +{ + OT_UNUSED_VARIABLE(aInstance); + + Interpreter::GetInterpreter().mMdns.HandleBrowseResult(*aResult); +} + +void Mdns::HandleBrowseResult(const otMdnsBrowseResult &aResult) +{ + OutputFormat("mDNS browse result for %s", aResult.mServiceType); + + if (aResult.mSubTypeLabel) + { + OutputLine(" sub-type %s", aResult.mSubTypeLabel); + } + else + { + OutputNewLine(); + } + + OutputLine(kIndentSize, "instance: %s", aResult.mServiceInstance); + OutputLine(kIndentSize, "ttl: %lu", ToUlong(aResult.mTtl)); + OutputLine(kIndentSize, "if-index: %lu", ToUlong(aResult.mInfraIfIndex)); +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + // mdns srvresolver start|stop + + otError error; + otMdnsSrvResolver resolver; + bool isStart; + + ClearAllBytes(resolver); + + SuccessOrExit(error = ParseStartOrStop(aArgs[0], isStart)); + VerifyOrExit(!aArgs[2].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + + resolver.mServiceInstance = aArgs[1].GetCString(); + resolver.mServiceType = aArgs[2].GetCString(); + resolver.mInfraIfIndex = mInfraIfIndex; + resolver.mCallback = HandleSrvResult; + + if (isStart) + { + error = otMdnsStartSrvResolver(GetInstancePtr(), &resolver); + } + else + { + error = otMdnsStopSrvResolver(GetInstancePtr(), &resolver); + } + +exit: + return error; +} + +void Mdns::HandleSrvResult(otInstance *aInstance, const otMdnsSrvResult *aResult) +{ + OT_UNUSED_VARIABLE(aInstance); + + Interpreter::GetInterpreter().mMdns.HandleSrvResult(*aResult); +} + +void Mdns::HandleSrvResult(const otMdnsSrvResult &aResult) +{ + OutputLine("mDNS SRV result for %s for %s", aResult.mServiceInstance, aResult.mServiceType); + + if (aResult.mTtl != 0) + { + OutputLine(kIndentSize, "host: %s", aResult.mHostName); + OutputLine(kIndentSize, "port: %u", aResult.mPort); + OutputLine(kIndentSize, "priority: %u", aResult.mPriority); + OutputLine(kIndentSize, "weight: %u", aResult.mWeight); + } + + OutputLine(kIndentSize, "ttl: %lu", ToUlong(aResult.mTtl)); + OutputLine(kIndentSize, "if-index: %lu", ToUlong(aResult.mInfraIfIndex)); +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + // mdns txtresolver start|stop + + otError error; + otMdnsTxtResolver resolver; + bool isStart; + + ClearAllBytes(resolver); + + SuccessOrExit(error = ParseStartOrStop(aArgs[0], isStart)); + VerifyOrExit(!aArgs[2].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + + resolver.mServiceInstance = aArgs[1].GetCString(); + resolver.mServiceType = aArgs[2].GetCString(); + resolver.mInfraIfIndex = mInfraIfIndex; + resolver.mCallback = HandleTxtResult; + + if (isStart) + { + error = otMdnsStartTxtResolver(GetInstancePtr(), &resolver); + } + else + { + error = otMdnsStopTxtResolver(GetInstancePtr(), &resolver); + } + +exit: + return error; +} + +void Mdns::HandleTxtResult(otInstance *aInstance, const otMdnsTxtResult *aResult) +{ + OT_UNUSED_VARIABLE(aInstance); + + Interpreter::GetInterpreter().mMdns.HandleTxtResult(*aResult); +} + +void Mdns::HandleTxtResult(const otMdnsTxtResult &aResult) +{ + OutputLine("mDNS TXT result for %s for %s", aResult.mServiceInstance, aResult.mServiceType); + + if (aResult.mTtl != 0) + { + OutputFormat(kIndentSize, "txt-data: "); + OutputBytesLine(aResult.mTxtData, aResult.mTxtDataLength); + } + + OutputLine(kIndentSize, "ttl: %lu", ToUlong(aResult.mTtl)); + OutputLine(kIndentSize, "if-index: %lu", ToUlong(aResult.mInfraIfIndex)); +} +template <> otError Mdns::Process(Arg aArgs[]) +{ + // mdns ip6resolver start|stop + + otError error; + otMdnsAddressResolver resolver; + bool isStart; + + ClearAllBytes(resolver); + + SuccessOrExit(error = ParseStartOrStop(aArgs[0], isStart)); + VerifyOrExit(!aArgs[1].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + + resolver.mHostName = aArgs[1].GetCString(); + resolver.mInfraIfIndex = mInfraIfIndex; + resolver.mCallback = HandleIp6AddressResult; + + if (isStart) + { + error = otMdnsStartIp6AddressResolver(GetInstancePtr(), &resolver); + } + else + { + error = otMdnsStopIp6AddressResolver(GetInstancePtr(), &resolver); + } + +exit: + return error; +} + +void Mdns::HandleIp6AddressResult(otInstance *aInstance, const otMdnsAddressResult *aResult) +{ + OT_UNUSED_VARIABLE(aInstance); + + Interpreter::GetInterpreter().mMdns.HandleAddressResult(*aResult, kIp6Address); +} + +void Mdns::HandleAddressResult(const otMdnsAddressResult &aResult, IpAddressType aType) +{ + OutputLine("mDNS %s address result for %s", aType == kIp6Address ? "IPv6" : "IPv4", aResult.mHostName); + + OutputLine(kIndentSize, "%u address:", aResult.mAddressesLength); + + for (uint16_t index = 0; index < aResult.mAddressesLength; index++) + { + OutputFormat(kIndentSize, " "); + OutputIp6Address(aResult.mAddresses[index].mAddress); + OutputLine(" ttl:%lu", ToUlong(aResult.mAddresses[index].mTtl)); + } + + OutputLine(kIndentSize, "if-index: %lu", ToUlong(aResult.mInfraIfIndex)); +} + +template <> otError Mdns::Process(Arg aArgs[]) +{ + // mdns ip4resolver start|stop + + otError error; + otMdnsAddressResolver resolver; + bool isStart; + + ClearAllBytes(resolver); + + SuccessOrExit(error = ParseStartOrStop(aArgs[0], isStart)); + VerifyOrExit(!aArgs[1].IsEmpty(), error = OT_ERROR_INVALID_ARGS); + + resolver.mHostName = aArgs[1].GetCString(); + resolver.mInfraIfIndex = mInfraIfIndex; + resolver.mCallback = HandleIp4AddressResult; + + if (isStart) + { + error = otMdnsStartIp4AddressResolver(GetInstancePtr(), &resolver); + } + else + { + error = otMdnsStopIp4AddressResolver(GetInstancePtr(), &resolver); + } + +exit: + return error; +} + +void Mdns::HandleIp4AddressResult(otInstance *aInstance, const otMdnsAddressResult *aResult) +{ + OT_UNUSED_VARIABLE(aInstance); + + Interpreter::GetInterpreter().mMdns.HandleAddressResult(*aResult, kIp4Address); +} + +otError Mdns::Process(Arg aArgs[]) +{ +#define CmdEntry(aCommandString) \ + { \ + aCommandString, &Mdns::Process \ + } + + static constexpr Command kCommands[] = { + CmdEntry("browser"), CmdEntry("disable"), CmdEntry("enable"), CmdEntry("ip4resolver"), + CmdEntry("ip6resolver"), CmdEntry("register"), CmdEntry("srvresolver"), CmdEntry("state"), + CmdEntry("txtresolver"), CmdEntry("unicastquestion"), CmdEntry("unregister"), + }; + +#undef CmdEntry + + static_assert(BinarySearch::IsSorted(kCommands), "kCommands is not sorted"); + + otError error = OT_ERROR_INVALID_COMMAND; + const Command *command; + + if (aArgs[0].IsEmpty() || (aArgs[0] == "help")) + { + OutputCommandTable(kCommands); + ExitNow(error = aArgs[0].IsEmpty() ? error : OT_ERROR_NONE); + } + + command = BinarySearch::Find(aArgs[0].GetCString(), kCommands); + VerifyOrExit(command != nullptr); + + error = (this->*command->mHandler)(aArgs + 1); + +exit: + return error; +} + +} // namespace Cli +} // namespace ot + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE diff --git a/src/cli/cli_mdns.hpp b/src/cli/cli_mdns.hpp new file mode 100644 index 000000000..0c189dc8b --- /dev/null +++ b/src/cli/cli_mdns.hpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/** + * @file + * This file contains definitions for CLI to DNS (client and resolver). + */ + +#ifndef CLI_MDNS_HPP_ +#define CLI_MDNS_HPP_ + +#include "openthread-core-config.h" + +#include + +#include "cli/cli_config.h" +#include "cli/cli_utils.hpp" + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE + +namespace ot { +namespace Cli { + +/** + * Implements the mDNS CLI interpreter. + * + */ +class Mdns : private Utils +{ +public: + /** + * Constructor. + * + * @param[in] aInstance The OpenThread Instance. + * @param[in] aOutputImplementer An `OutputImplementer`. + * + */ + Mdns(otInstance *aInstance, OutputImplementer &aOutputImplementer) + : Utils(aInstance, aOutputImplementer) + , mInfraIfIndex(0) + , mRequestId(0) + , mWaitingForCallback(false) + { + } + + /** + * Processes a CLI sub-command. + * + * @param[in] aArgs An array of command line arguments. + * + * @retval OT_ERROR_NONE Successfully executed the CLI command. + * @retval OT_ERROR_PENDING The CLI command was successfully started but final result is pending. + * @retval OT_ERROR_INVALID_COMMAND Invalid or unknown CLI command. + * @retval OT_ERROR_INVALID_ARGS Invalid arguments. + * @retval ... Error during execution of the CLI command. + * + */ + otError Process(Arg aArgs[]); + +private: + using Command = CommandEntry; + + static constexpr uint8_t kIndentSize = 4; + static constexpr uint16_t kMaxAddresses = 16; + static constexpr uint16_t kStringSize = 400; + static constexpr uint16_t kMaxSubTypes = 8; + static constexpr uint16_t kMaxTxtDataSize = 200; + static constexpr uint16_t kMaxKeyDataSize = 200; + + enum IpAddressType : uint8_t + { + kIp6Address, + kIp4Address, + }; + + struct Buffers // Used to populate `otMdnsService` field + { + char mString[kStringSize]; + const char *mSubTypeLabels[kMaxSubTypes]; + uint8_t mTxtData[kMaxTxtDataSize]; + }; + + template otError Process(Arg aArgs[]); + + void OutputHost(const otMdnsHost &aHost); + void OutputService(const otMdnsService &aService); + void OutputKey(const otMdnsKey &aKey); + otError ProcessRegisterHost(Arg aArgs[]); + otError ProcessRegisterService(Arg aArgs[]); + otError ProcessRegisterKey(Arg aArgs[]); + void HandleRegisterationDone(otMdnsRequestId aRequestId, otError aError); + void HandleBrowseResult(const otMdnsBrowseResult &aResult); + void HandleSrvResult(const otMdnsSrvResult &aResult); + void HandleTxtResult(const otMdnsTxtResult &aResult); + void HandleAddressResult(const otMdnsAddressResult &aResult, IpAddressType aType); + + static otError ParseStartOrStop(const Arg &aArg, bool &aIsStart); + static void HandleRegisterationDone(otInstance *aInstance, otMdnsRequestId aRequestId, otError aError); + static void HandleBrowseResult(otInstance *aInstance, const otMdnsBrowseResult *aResult); + static void HandleSrvResult(otInstance *aInstance, const otMdnsSrvResult *aResult); + static void HandleTxtResult(otInstance *aInstance, const otMdnsTxtResult *aResult); + static void HandleIp6AddressResult(otInstance *aInstance, const otMdnsAddressResult *aResult); + static void HandleIp4AddressResult(otInstance *aInstance, const otMdnsAddressResult *aResult); + + static otError ParseServiceArgs(Arg aArgs[], otMdnsService &aService, Buffers &aBuffers); + + uint32_t mInfraIfIndex; + otMdnsRequestId mRequestId; + bool mWaitingForCallback; +}; + +} // namespace Cli +} // namespace ot + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE + +#endif // CLI_MDNS_HPP_ diff --git a/src/core/BUILD.gn b/src/core/BUILD.gn index 08fa6db98..dd137f73d 100644 --- a/src/core/BUILD.gn +++ b/src/core/BUILD.gn @@ -336,6 +336,7 @@ openthread_core_files = [ "api/link_metrics_api.cpp", "api/link_raw_api.cpp", "api/logging_api.cpp", + "api/mdns_api.cpp", "api/mesh_diag_api.cpp", "api/message_api.cpp", "api/multi_radio_api.cpp", @@ -575,6 +576,8 @@ openthread_core_files = [ "net/ip6_mpl.cpp", "net/ip6_mpl.hpp", "net/ip6_types.hpp", + "net/mdns.cpp", + "net/mdns.hpp", "net/nat64_translator.cpp", "net/nat64_translator.hpp", "net/nd6.cpp", @@ -815,6 +818,7 @@ source_set("libopenthread_core_config") { "config/link_raw.h", "config/logging.h", "config/mac.h", + "config/mdns.h", "config/mesh_diag.h", "config/mesh_forwarder.h", "config/misc.h", diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 937de5c5a..8bc006e70 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -62,6 +62,7 @@ set(COMMON_SOURCES api/link_metrics_api.cpp api/link_raw_api.cpp api/logging_api.cpp + api/mdns_api.cpp api/mesh_diag_api.cpp api/message_api.cpp api/multi_radio_api.cpp @@ -178,6 +179,7 @@ set(COMMON_SOURCES net/ip6_filter.cpp net/ip6_headers.cpp net/ip6_mpl.cpp + net/mdns.cpp net/nat64_translator.cpp net/nd6.cpp net/nd_agent.cpp diff --git a/src/core/api/mdns_api.cpp b/src/core/api/mdns_api.cpp new file mode 100644 index 000000000..797153947 --- /dev/null +++ b/src/core/api/mdns_api.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/** + * @file + * This file implements the OpenThread mDNS API. + */ + +#include "openthread-core-config.h" + +#include + +#include "instance/instance.hpp" +#include "net/mdns.hpp" + +using namespace ot; + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE + +otError otMdnsSetEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex) +{ + return AsCoreType(aInstance).Get().SetEnabled(aEnable, aInfraIfIndex); +} + +bool otMdnsIsEnabled(otInstance *aInstance) { return AsCoreType(aInstance).Get().IsEnabled(); } + +void otMdnsSetQuestionUnicastAllowed(otInstance *aInstance, bool aAllow) +{ + AsCoreType(aInstance).Get().SetQuestionUnicastAllowed(aAllow); +} + +bool otMdnsIsQuestionUnicastAllowed(otInstance *aInstance) +{ + return AsCoreType(aInstance).Get().IsQuestionUnicastAllowed(); +} + +void otMdnsSetConflictCallback(otInstance *aInstance, otMdnsConflictCallback aCallback) +{ + AsCoreType(aInstance).Get().SetConflictCallback(aCallback); +} + +otError otMdnsRegisterHost(otInstance *aInstance, + const otMdnsHost *aHost, + otMdnsRequestId aRequestId, + otMdnsRegisterCallback aCallback) +{ + AssertPointerIsNotNull(aHost); + + return AsCoreType(aInstance).Get().RegisterHost(*aHost, aRequestId, aCallback); +} + +otError otMdnsUnregisterHost(otInstance *aInstance, const otMdnsHost *aHost) +{ + AssertPointerIsNotNull(aHost); + + return AsCoreType(aInstance).Get().UnregisterHost(*aHost); +} + +otError otMdnsRegisterService(otInstance *aInstance, + const otMdnsService *aService, + otMdnsRequestId aRequestId, + otMdnsRegisterCallback aCallback) +{ + AssertPointerIsNotNull(aService); + + return AsCoreType(aInstance).Get().RegisterService(*aService, aRequestId, aCallback); +} + +otError otMdnsUnregisterService(otInstance *aInstance, const otMdnsService *aService) +{ + AssertPointerIsNotNull(aService); + + return AsCoreType(aInstance).Get().UnregisterService(*aService); +} + +otError otMdnsRegisterKey(otInstance *aInstance, + const otMdnsKey *aKey, + otMdnsRequestId aRequestId, + otMdnsRegisterCallback aCallback) +{ + AssertPointerIsNotNull(aKey); + + return AsCoreType(aInstance).Get().RegisterKey(*aKey, aRequestId, aCallback); +} + +otError otMdnsUnregisterKey(otInstance *aInstance, const otMdnsKey *aKey) +{ + AssertPointerIsNotNull(aKey); + + return AsCoreType(aInstance).Get().UnregisterKey(*aKey); +} + +otError otMdnsStartBrowser(otInstance *aInstance, const otMdnsBrowser *aBroswer) +{ + AssertPointerIsNotNull(aBroswer); + + return AsCoreType(aInstance).Get().StartBrowser(*aBroswer); +} + +otError otMdnsStopBrowser(otInstance *aInstance, const otMdnsBrowser *aBroswer) +{ + AssertPointerIsNotNull(aBroswer); + + return AsCoreType(aInstance).Get().StopBrowser(*aBroswer); +} + +otError otMdnsStartSrvResolver(otInstance *aInstance, const otMdnsSrvResolver *aResolver) +{ + AssertPointerIsNotNull(aResolver); + + return AsCoreType(aInstance).Get().StartSrvResolver(*aResolver); +} + +otError otMdnsStopSrvResolver(otInstance *aInstance, const otMdnsSrvResolver *aResolver) +{ + AssertPointerIsNotNull(aResolver); + + return AsCoreType(aInstance).Get().StopSrvResolver(*aResolver); +} + +otError otMdnsStartTxtResolver(otInstance *aInstance, const otMdnsTxtResolver *aResolver) +{ + AssertPointerIsNotNull(aResolver); + + return AsCoreType(aInstance).Get().StartTxtResolver(*aResolver); +} + +otError otMdnsStopTxtResolver(otInstance *aInstance, const otMdnsTxtResolver *aResolver) +{ + AssertPointerIsNotNull(aResolver); + + return AsCoreType(aInstance).Get().StopTxtResolver(*aResolver); +} + +otError otMdnsStartIp6AddressResolver(otInstance *aInstance, const otMdnsAddressResolver *aResolver) +{ + AssertPointerIsNotNull(aResolver); + + return AsCoreType(aInstance).Get().StartIp6AddressResolver(*aResolver); +} + +otError otMdnsStopIp6AddressResolver(otInstance *aInstance, const otMdnsAddressResolver *aResolver) +{ + AssertPointerIsNotNull(aResolver); + + return AsCoreType(aInstance).Get().StopIp6AddressResolver(*aResolver); +} + +otError otMdnsStartIp4AddressResolver(otInstance *aInstance, const otMdnsAddressResolver *aResolver) +{ + AssertPointerIsNotNull(aResolver); + + return AsCoreType(aInstance).Get().StartIp4AddressResolver(*aResolver); +} + +otError otMdnsStopIp4AddressResolver(otInstance *aInstance, const otMdnsAddressResolver *aResolver) +{ + AssertPointerIsNotNull(aResolver); + + return AsCoreType(aInstance).Get().StopIp4AddressResolver(*aResolver); +} + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE && OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE diff --git a/src/core/config/mdns.h b/src/core/config/mdns.h new file mode 100644 index 000000000..04141b559 --- /dev/null +++ b/src/core/config/mdns.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/** + * @file + * This file includes compile-time configurations for the Multicast DNS (mDNS). + * + */ + +#ifndef CONFIG_MULTICAST_DNS_H_ +#define CONFIG_MULTICAST_DNS_H_ + +/** + * @addtogroup config-mdns + * + * @brief + * This module includes configuration variables for the Multicast DNS (mDNS). + * + * @{ + * + */ + +/** + * @def OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + * + * Define to 1 to enable Multicast DNS (mDNS) support. + * + */ +#ifndef OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE +#define OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE 0 +#endif + +/** + * @def OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE + * + * Define to 1 to allow public OpenThread APIs to be defined for Multicast DNS (mDNS) module. + * + * The OpenThread mDNS module is mainly intended for use by other OT core modules, so the public APIs are by default + * not provided. + * + */ +#ifndef OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE +#define OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE 0 +#endif + +/** + * @def OPENTHREAD_CONFIG_MULTICAST_DNS_DEFAULT_QUESTION_UNICAST_ALLOWED + * + * Specified the default value for `otMdnsIsQuestionUnicastAllowed()` which indicates whether mDNS core is allowed to + * send "QU" questions (questions requesting unicast response). When allowed, the first probe will be sent as "QU" + * question. The `otMdnsSetQuestionUnicastAllowed()` can be used to change the default value at run-time. + * + */ +#ifndef OPENTHREAD_CONFIG_MULTICAST_DNS_DEFAULT_QUESTION_UNICAST_ALLOWED +#define OPENTHREAD_CONFIG_MULTICAST_DNS_DEFAULT_QUESTION_UNICAST_ALLOWED 1 +#endif + +/** + * @def OPENTHREAD_CONFIG_MULTICAST_DNS_MOCK_PLAT_APIS_ENABLE + * + * Define to 1 to add mock (empty) implementation of mDNS platform APIs. + * + * This is intended for generating code size report only and should not be used otherwise. + * + */ +#ifndef OPENTHREAD_CONFIG_MULTICAST_DNS_MOCK_PLAT_APIS_ENABLE +#define OPENTHREAD_CONFIG_MULTICAST_DNS_MOCK_PLAT_APIS_ENABLE 0 +#endif + +/** + * @} + * + */ + +#endif // CONFIG_MULTICAST_DNS_H_ diff --git a/src/core/instance/instance.cpp b/src/core/instance/instance.cpp index 398dc8eb2..3519fd57f 100644 --- a/src/core/instance/instance.cpp +++ b/src/core/instance/instance.cpp @@ -122,6 +122,9 @@ Instance::Instance(void) #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE , mDnsDso(*this) #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + , mMdnsCore(*this) +#endif #if OPENTHREAD_CONFIG_SNTP_CLIENT_ENABLE , mSntpClient(*this) #endif diff --git a/src/core/instance/instance.hpp b/src/core/instance/instance.hpp index 39dea9d5f..3050b075b 100644 --- a/src/core/instance/instance.hpp +++ b/src/core/instance/instance.hpp @@ -92,6 +92,7 @@ #include "net/dnssd_server.hpp" #include "net/ip6.hpp" #include "net/ip6_filter.hpp" +#include "net/mdns.hpp" #include "net/nat64_translator.hpp" #include "net/nd_agent.hpp" #include "net/netif.hpp" @@ -526,6 +527,10 @@ class Instance : public otInstance, private NonCopyable Dns::Dso mDnsDso; #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + Dns::Multicast::Core mMdnsCore; +#endif + #if OPENTHREAD_CONFIG_SNTP_CLIENT_ENABLE Sntp::Client mSntpClient; #endif @@ -925,6 +930,10 @@ template <> inline Dns::ServiceDiscovery::Server &Instance::Get(void) { return m template <> inline Dns::Dso &Instance::Get(void) { return mDnsDso; } #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE +template <> inline Dns::Multicast::Core &Instance::Get(void) { return mMdnsCore; } +#endif + template <> inline NetworkDiagnostic::Server &Instance::Get(void) { return mNetworkDiagnosticServer; } #if OPENTHREAD_CONFIG_TMF_NETDIAG_CLIENT_ENABLE diff --git a/src/core/net/dns_types.cpp b/src/core/net/dns_types.cpp index eecc9ff29..b8a9e37e9 100644 --- a/src/core/net/dns_types.cpp +++ b/src/core/net/dns_types.cpp @@ -36,6 +36,7 @@ #include "common/code_utils.hpp" #include "common/debug.hpp" #include "common/num_utils.hpp" +#include "common/numeric_limits.hpp" #include "common/random.hpp" #include "common/string.hpp" #include "instance/instance.hpp" @@ -1343,5 +1344,35 @@ bool TxtRecord::VerifyTxtData(const uint8_t *aTxtData, uint16_t aTxtLength, bool return valid; } +void NsecRecord::TypeBitMap::AddType(uint16_t aType) +{ + if ((aType >> 8) == mBlockNumber) + { + uint8_t type = static_cast(aType & 0xff); + uint8_t index = (type / kBitsPerByte); + uint16_t mask = (0x80 >> (type % kBitsPerByte)); + + mBitmaps[index] |= mask; + mBitmapLength = Max(mBitmapLength, index + 1); + } +} + +bool NsecRecord::TypeBitMap::ContainsType(uint16_t aType) const +{ + bool contains = false; + uint8_t type = static_cast(aType & 0xff); + uint8_t index = (type / kBitsPerByte); + uint16_t mask = (0x80 >> (type % kBitsPerByte)); + + VerifyOrExit((aType >> 8) == mBlockNumber); + + VerifyOrExit(index < mBitmapLength); + + contains = (mBitmaps[index] & mask); + +exit: + return contains; +} + } // namespace Dns } // namespace ot diff --git a/src/core/net/dns_types.hpp b/src/core/net/dns_types.hpp index e590a4831..8ca93e13f 100644 --- a/src/core/net/dns_types.hpp +++ b/src/core/net/dns_types.hpp @@ -1336,6 +1336,7 @@ class ResourceRecord static constexpr uint16_t kTypeAaaa = 28; ///< IPv6 address record. static constexpr uint16_t kTypeSrv = 33; ///< SRV locator record. static constexpr uint16_t kTypeOpt = 41; ///< Option record. + static constexpr uint16_t kTypeNsec = 47; ///< NSEC record. static constexpr uint16_t kTypeAny = 255; ///< ANY record. // Resource Record Class Codes. @@ -2745,6 +2746,104 @@ class LeaseOption : public Option uint32_t mKeyLeaseInterval; } OT_TOOL_PACKED_END; +/** + * Implements body format of NSEC record (RFC 3845) for use with mDNS. + * + */ +OT_TOOL_PACKED_BEGIN +class NsecRecord : public ResourceRecord +{ +public: + static constexpr uint16_t kType = kTypeNsec; ///< The NSEC record type. + + /** + * Represents NSEC Type Bit Map field (RFC 3845 - section 2.1.2) + * + */ + OT_TOOL_PACKED_BEGIN + class TypeBitMap : public Clearable + { + public: + static constexpr uint8_t kMinSize = 2; ///< Minimum size of a valid `TypeBitMap` (with zero length). + + static constexpr uint8_t kMaxLength = 32; ///< Maximum BitmapLength value. + + /** + * Gets the Window Block Number + * + * @returns The Window Block Number. + * + */ + uint8_t GetBlockNumber(void) const { return mBlockNumber; } + + /** + * Sets the Window Block Number + * + * @param[in] aBlockNumber The Window Block Number. + * + */ + void SetBlockNumber(uint8_t aBlockNumber) { mBlockNumber = aBlockNumber; } + + /** + * Gets the Bitmap length + * + * @returns The Bitmap length + * + */ + uint8_t GetBitmapLength(void) { return mBitmapLength; } + + /** + * Gets the total size (number of bytes) of the `TypeBitMap` field. + * + * @returns The size of the `TypeBitMap` + * + */ + uint16_t GetSize(void) const { return (sizeof(mBlockNumber) + sizeof(mBitmapLength) + mBitmapLength); } + + /** + * Adds a resource record type to the Bitmap. + * + * As the types are added to the Bitmap the Bitmap length gets updated accordingly. + * + * The type space is split into 256 window blocks, each representing the low-order 8 bits of the 16-bit type + * value. If @p aType does not match the currently set Window Block Number, no action is performed. + * + * @param[in] aType The resource record type to add. + * + */ + void AddType(uint16_t aType); + + /** + * Indicates whether a given resource record type is present in the Bitmap. + * + * If @p aType does not match the currently set Window Block Number, this method returns `false`.. + * + * @param[in] aType The resource record type to check. + * + * @retval TRUE The @p aType is present in the Bitmap. + * @retval FALSE The @p aType is not present in the Bitmap. + * + */ + bool ContainsType(uint16_t aType) const; + + private: + uint8_t mBlockNumber; + uint8_t mBitmapLength; + uint8_t mBitmaps[kMaxLength]; + } OT_TOOL_PACKED_END; + + /** + * Initializes the NSEC Resource Record by setting its type and class. + * + * Other record fields (TTL, length remain unchanged/uninitialized. + * + * @param[in] aClass The class of the resource record (default is `kClassInternet`). + * + */ + void Init(uint16_t aClass = kClassInternet) { ResourceRecord::Init(kTypeNsec, aClass); } + +} OT_TOOL_PACKED_END; + /** * Implements Question format. * diff --git a/src/core/net/mdns.cpp b/src/core/net/mdns.cpp new file mode 100644 index 000000000..fb28eda59 --- /dev/null +++ b/src/core/net/mdns.cpp @@ -0,0 +1,6011 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "mdns.hpp" + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +#include "common/code_utils.hpp" +#include "common/locator_getters.hpp" +#include "common/log.hpp" +#include "common/numeric_limits.hpp" +#include "common/type_traits.hpp" +#include "instance/instance.hpp" + +/** + * @file + * This file implements the Multicast DNS (mDNS) per RFC 6762. + */ + +namespace ot { +namespace Dns { +namespace Multicast { + +RegisterLogModule("MulticastDns"); + +//--------------------------------------------------------------------------------------------------------------------- +// otPlatMdns callbacks + +extern "C" void otPlatMdnsHandleReceive(otInstance *aInstance, + otMessage *aMessage, + bool aIsUnicast, + const otPlatMdnsAddressInfo *aAddress) +{ + AsCoreType(aInstance).Get().HandleMessage(AsCoreType(aMessage), aIsUnicast, AsCoreType(aAddress)); +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core + +const char Core::kLocalDomain[] = "local."; +const char Core::kUdpServiceLabel[] = "_udp"; +const char Core::kTcpServiceLabel[] = "_tcp"; +const char Core::kSubServiceLabel[] = "_sub"; +const char Core::kServicesDnssdLabels[] = "_services._dns-sd._udp"; + +Core::Core(Instance &aInstance) + : InstanceLocator(aInstance) + , mIsEnabled(false) + , mIsQuestionUnicastAllowed(kDefaultQuAllowed) + , mMaxMessageSize(kMaxMessageSize) + , mInfraIfIndex(0) + , mMultiPacketRxMessages(aInstance) + , mNextProbeTxTime(TimerMilli::GetNow() - 1) + , mEntryTimer(aInstance) + , mEntryTask(aInstance) + , mTxMessageHistory(aInstance) + , mConflictCallback(nullptr) + , mNextQueryTxTime(TimerMilli::GetNow() - 1) + , mCacheTimer(aInstance) + , mCacheTask(aInstance) +{ +} + +Error Core::SetEnabled(bool aEnable, uint32_t aInfraIfIndex) +{ + Error error = kErrorNone; + + VerifyOrExit(aEnable != mIsEnabled, error = kErrorAlready); + SuccessOrExit(error = otPlatMdnsSetListeningEnabled(&GetInstance(), aEnable, aInfraIfIndex)); + + mIsEnabled = aEnable; + mInfraIfIndex = aInfraIfIndex; + + if (mIsEnabled) + { + LogInfo("Enabling on infra-if-index %lu", ToUlong(mInfraIfIndex)); + } + else + { + LogInfo("Disabling"); + } + + if (!mIsEnabled) + { + mHostEntries.Clear(); + mServiceEntries.Clear(); + mServiceTypes.Clear(); + mMultiPacketRxMessages.Clear(); + mTxMessageHistory.Clear(); + mEntryTimer.Stop(); + + mBrowseCacheList.Clear(); + mSrvCacheList.Clear(); + mTxtCacheList.Clear(); + mIp6AddrCacheList.Clear(); + mIp4AddrCacheList.Clear(); + mCacheTimer.Stop(); + } + +exit: + return error; +} + +template +Error Core::Register(const ItemInfo &aItemInfo, RequestId aRequestId, RegisterCallback aCallback) +{ + Error error = kErrorNone; + EntryType *entry; + + VerifyOrExit(mIsEnabled, error = kErrorInvalidState); + + entry = GetEntryList().FindMatching(aItemInfo); + + if (entry == nullptr) + { + entry = EntryType::AllocateAndInit(GetInstance(), aItemInfo); + OT_ASSERT(entry != nullptr); + GetEntryList().Push(*entry); + } + + entry->Register(aItemInfo, Callback(aRequestId, aCallback)); + +exit: + return error; +} + +template Error Core::Unregister(const ItemInfo &aItemInfo) +{ + Error error = kErrorNone; + EntryType *entry; + + VerifyOrExit(mIsEnabled, error = kErrorInvalidState); + + entry = GetEntryList().FindMatching(aItemInfo); + + if (entry != nullptr) + { + entry->Unregister(aItemInfo); + } + +exit: + return error; +} + +Error Core::RegisterHost(const Host &aHost, RequestId aRequestId, RegisterCallback aCallback) +{ + return Register(aHost, aRequestId, aCallback); +} + +Error Core::UnregisterHost(const Host &aHost) { return Unregister(aHost); } + +Error Core::RegisterService(const Service &aService, RequestId aRequestId, RegisterCallback aCallback) +{ + return Register(aService, aRequestId, aCallback); +} + +Error Core::UnregisterService(const Service &aService) { return Unregister(aService); } + +Error Core::RegisterKey(const Key &aKey, RequestId aRequestId, RegisterCallback aCallback) +{ + return IsKeyForService(aKey) ? Register(aKey, aRequestId, aCallback) + : Register(aKey, aRequestId, aCallback); +} + +Error Core::UnregisterKey(const Key &aKey) +{ + return IsKeyForService(aKey) ? Unregister(aKey) : Unregister(aKey); +} + +void Core::InvokeConflictCallback(const char *aName, const char *aServiceType) +{ + if (mConflictCallback != nullptr) + { + mConflictCallback(&GetInstance(), aName, aServiceType); + } +} +void Core::HandleMessage(Message &aMessage, bool aIsUnicast, const AddressInfo &aSenderAddress) +{ + OwnedPtr messagePtr(&aMessage); + OwnedPtr rxMessagePtr; + + VerifyOrExit(mIsEnabled); + + rxMessagePtr.Reset(RxMessage::AllocateAndInit(GetInstance(), messagePtr, aIsUnicast, aSenderAddress)); + VerifyOrExit(!rxMessagePtr.IsNull()); + + if (rxMessagePtr->IsQuery()) + { + // Check if this is a continuation of a multi-packet query. + // Initial query message sets the "Truncated" flag. + // Subsequent messages from the same sender contain no + // question and only known-answer records. + + if ((rxMessagePtr->GetRecordCounts().GetFor(kQuestionSection) == 0) && + (rxMessagePtr->GetRecordCounts().GetFor(kAnswerSection) > 0)) + { + mMultiPacketRxMessages.AddToExisting(rxMessagePtr); + ExitNow(); + } + + switch (rxMessagePtr->ProcessQuery(/* aShouldProcessTruncated */ false)) + { + case RxMessage::kProcessed: + break; + + case RxMessage::kSaveAsMultiPacket: + // This is a truncated multi-packet query and we can + // answer some questions in this query. We save it in + // `mMultiPacketRxMessages` list and defer its response + // for a random time waiting to receive next messages + // containing additional known-answer records. + + mMultiPacketRxMessages.AddNew(rxMessagePtr); + break; + } + } + else + { + rxMessagePtr->ProcessResponse(); + } + +exit: + return; +} + +void Core::HandleEntryTimer(void) +{ + EntryTimerContext context(GetInstance()); + + // We process host entries before service entries. This order + // ensures we can determine whether host addresses have already + // been appended to the Answer section (when processing service entries), + // preventing duplicates. + + for (HostEntry &entry : mHostEntries) + { + entry.HandleTimer(context); + } + + for (ServiceEntry &entry : mServiceEntries) + { + entry.HandleTimer(context); + } + + for (ServiceType &serviceType : mServiceTypes) + { + serviceType.HandleTimer(context); + } + + context.GetProbeMessage().Send(); + context.GetResponseMessage().Send(); + + RemoveEmptyEntries(); + + if (context.GetNextTime() != context.GetNow().GetDistantFuture()) + { + mEntryTimer.FireAtIfEarlier(context.GetNextTime()); + } +} + +void Core::RemoveEmptyEntries(void) +{ + OwningList removedHosts; + OwningList removedServices; + + mHostEntries.RemoveAllMatching(Entry::kRemoving, removedHosts); + mServiceEntries.RemoveAllMatching(Entry::kRemoving, removedServices); +} + +void Core::HandleEntryTask(void) +{ + // `mEntryTask` serves two purposes: + // + // Invoking callbacks: This ensures `Register()` calls will always + // return before invoking the callback, even when entry is + // already in `kRegistered` state and registration is immediately + // successful. + // + // Removing empty entries after `Unregister()` calls: This + // prevents modification of `mHostEntries` and `mServiceEntries` + // during callback execution while we are iterating over these + // lists. Allows us to safely call `Register()` or `Unregister()` + // from callbacks without iterator invalidation. + + for (HostEntry &entry : mHostEntries) + { + entry.InvokeCallbacks(); + } + + for (ServiceEntry &entry : mServiceEntries) + { + entry.InvokeCallbacks(); + } + + RemoveEmptyEntries(); +} + +uint32_t Core::DetermineTtl(uint32_t aTtl, uint32_t aDefaultTtl) +{ + return (aTtl == kUnspecifiedTtl) ? aDefaultTtl : aTtl; +} + +bool Core::NameMatch(const Heap::String &aHeapString, const char *aName) +{ + // Compares a DNS name given as a `Heap::String` with a + // `aName` C string. + + return !aHeapString.IsNull() && StringMatch(aHeapString.AsCString(), aName, kStringCaseInsensitiveMatch); +} + +bool Core::NameMatch(const Heap::String &aFirst, const Heap::String &aSecond) +{ + // Compares two DNS names given as `Heap::String`. + + return !aSecond.IsNull() && NameMatch(aFirst, aSecond.AsCString()); +} + +void Core::UpdateCacheFlushFlagIn(ResourceRecord &aResourceRecord, Section aSection) +{ + // Do not set the cache-flush flag is the record is + // appended in Authority Section in a probe message. + + if (aSection != kAuthoritySection) + { + aResourceRecord.SetClass(aResourceRecord.GetClass() | kClassCacheFlushFlag); + } +} + +void Core::UpdateRecordLengthInMessage(ResourceRecord &aRecord, Message &aMessage, uint16_t aOffset) +{ + // Determines the records DATA length and updates it in a message. + // Should be called immediately after all the fields in the + // record are appended to the message. `aOffset` gives the offset + // in the message to the start of the record. + + aRecord.SetLength(aMessage.GetLength() - aOffset - sizeof(ResourceRecord)); + aMessage.Write(aOffset, aRecord); +} + +void Core::UpdateCompressOffset(uint16_t &aOffset, uint16_t aNewOffset) +{ + if ((aOffset == kUnspecifiedOffset) && (aNewOffset != kUnspecifiedOffset)) + { + aOffset = aNewOffset; + } +} + +bool Core::QuestionMatches(uint16_t aQuestionRrType, uint16_t aRrType) +{ + return (aQuestionRrType == aRrType) || (aQuestionRrType == ResourceRecord::kTypeAny); +} + +bool Core::RrClassIsInternetOrAny(uint16_t aRrClass) +{ + aRrClass &= kClassMask; + + return (aRrClass == ResourceRecord::kClassInternet) || (aRrClass == ResourceRecord::kClassAny); +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::Callback + +Core::Callback::Callback(RequestId aRequestId, RegisterCallback aCallback) + : mRequestId(aRequestId) + , mCallback(aCallback) +{ +} + +void Core::Callback::InvokeAndClear(Instance &aInstance, Error aError) +{ + if (mCallback != nullptr) + { + RegisterCallback callback = mCallback; + RequestId requestId = mRequestId; + + Clear(); + + callback(&aInstance, requestId, aError); + } +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::RecordCounts + +void Core::RecordCounts::ReadFrom(const Header &aHeader) +{ + mCounts[kQuestionSection] = aHeader.GetQuestionCount(); + mCounts[kAnswerSection] = aHeader.GetAnswerCount(); + mCounts[kAuthoritySection] = aHeader.GetAuthorityRecordCount(); + mCounts[kAdditionalDataSection] = aHeader.GetAdditionalRecordCount(); +} + +void Core::RecordCounts::WriteTo(Header &aHeader) const +{ + aHeader.SetQuestionCount(mCounts[kQuestionSection]); + aHeader.SetAnswerCount(mCounts[kAnswerSection]); + aHeader.SetAuthorityRecordCount(mCounts[kAuthoritySection]); + aHeader.SetAdditionalRecordCount(mCounts[kAdditionalDataSection]); +} + +bool Core::RecordCounts::IsEmpty(void) const +{ + // Indicates whether or not all counts are zero. + + bool isEmpty = true; + + for (uint16_t count : mCounts) + { + if (count != 0) + { + isEmpty = false; + break; + } + } + + return isEmpty; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::AddressArray + +bool Core::AddressArray::Matches(const Ip6::Address *aAddresses, uint16_t aNumAddresses) const +{ + bool matches = false; + + VerifyOrExit(aNumAddresses == GetLength()); + + for (uint16_t i = 0; i < aNumAddresses; i++) + { + VerifyOrExit(Contains(aAddresses[i])); + } + + matches = true; + +exit: + return matches; +} + +void Core::AddressArray::SetFrom(const Ip6::Address *aAddresses, uint16_t aNumAddresses) +{ + Free(); + SuccessOrAssert(ReserveCapacity(aNumAddresses)); + + for (uint16_t i = 0; i < aNumAddresses; i++) + { + IgnoreError(PushBack(aAddresses[i])); + } +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::RecordInfo + +template void Core::RecordInfo::UpdateProperty(UintType &aProperty, UintType aValue) +{ + // Updates a property variable associated with this record. The + // `aProperty` is updated if the record is empty (has no value + // yet) or if its current value differs from the new `aValue`. If + // the property is changed, we prepare the record to be announced. + + // This template version works with `UintType` properties. There + // are similar overloads for `Heap::Data` and `Heap::String` and + // `AddressArray` property types below. + + static_assert(TypeTraits::IsSame::kValue || TypeTraits::IsSame::kValue || + TypeTraits::IsSame::kValue || TypeTraits::IsSame::kValue, + "UintType must be `uint8_t`, `uint16_t`, `uint32_t`, or `uint64_t`"); + + if (!mIsPresent || (aProperty != aValue)) + { + mIsPresent = true; + aProperty = aValue; + StartAnnouncing(); + } +} + +void Core::RecordInfo::UpdateProperty(Heap::String &aStringProperty, const char *aString) +{ + if (!mIsPresent || !NameMatch(aStringProperty, aString)) + { + mIsPresent = true; + SuccessOrAssert(aStringProperty.Set(aString)); + StartAnnouncing(); + } +} + +void Core::RecordInfo::UpdateProperty(Heap::Data &aDataProperty, const uint8_t *aData, uint16_t aLength) +{ + if (!mIsPresent || !aDataProperty.Matches(aData, aLength)) + { + mIsPresent = true; + SuccessOrAssert(aDataProperty.SetFrom(aData, aLength)); + StartAnnouncing(); + } +} + +void Core::RecordInfo::UpdateProperty(AddressArray &aAddrProperty, const Ip6::Address *aAddrs, uint16_t aNumAddrs) +{ + if (!mIsPresent || !aAddrProperty.Matches(aAddrs, aNumAddrs)) + { + mIsPresent = true; + aAddrProperty.SetFrom(aAddrs, aNumAddrs); + StartAnnouncing(); + } +} + +void Core::RecordInfo::UpdateTtl(uint32_t aTtl) { return UpdateProperty(mTtl, aTtl); } + +void Core::RecordInfo::StartAnnouncing(void) +{ + if (mIsPresent) + { + mAnnounceCounter = 0; + mAnnounceTime = TimerMilli::GetNow(); + } +} + +bool Core::RecordInfo::CanAnswer(void) const { return (mIsPresent && (mTtl > 0)); } + +void Core::RecordInfo::ScheduleAnswer(const AnswerInfo &aInfo) +{ + VerifyOrExit(CanAnswer()); + + if (aInfo.mUnicastResponse) + { + mUnicastAnswerPending = true; + ExitNow(); + } + + if (!aInfo.mIsProbe) + { + // Rate-limiting multicasts to prevent excessive packet flooding + // (RFC 6762 section 6): We enforce a minimum interval of one + // second (`kMinIntervalBetweenMulticast`) between multicast + // transmissions of the same record. Skip the new request if the + // answer time is too close to the last multicast time. A querier + // that did not receive and cache the previous transmission will + // retry its request. + + VerifyOrExit(GetDurationSinceLastMulticast(aInfo.mAnswerTime) >= kMinIntervalBetweenMulticast); + } + + if (mMulticastAnswerPending) + { + VerifyOrExit(aInfo.mAnswerTime < mAnswerTime); + } + + mMulticastAnswerPending = true; + mAnswerTime = aInfo.mAnswerTime; + +exit: + return; +} + +bool Core::RecordInfo::ShouldAppendTo(TxMessage &aResponse, TimeMilli aNow) const +{ + bool shouldAppend = false; + + VerifyOrExit(mIsPresent); + + switch (aResponse.GetType()) + { + case TxMessage::kMulticastResponse: + + if ((mAnnounceCounter < kNumberOfAnnounces) && (mAnnounceTime <= aNow)) + { + shouldAppend = true; + ExitNow(); + } + + shouldAppend = mMulticastAnswerPending && (mAnswerTime <= aNow); + break; + + case TxMessage::kUnicastResponse: + shouldAppend = mUnicastAnswerPending; + break; + + default: + break; + } + +exit: + return shouldAppend; +} + +void Core::RecordInfo::UpdateStateAfterAnswer(const TxMessage &aResponse) +{ + // Updates the state after a unicast or multicast response is + // prepared containing the record in the Answer section. + + VerifyOrExit(mIsPresent); + + switch (aResponse.GetType()) + { + case TxMessage::kMulticastResponse: + VerifyOrExit(mAppendState == kAppendedInMulticastMsg); + VerifyOrExit(mAppendSection == kAnswerSection); + + mMulticastAnswerPending = false; + + if (mAnnounceCounter < kNumberOfAnnounces) + { + mAnnounceCounter++; + + if (mAnnounceCounter < kNumberOfAnnounces) + { + uint32_t delay = (1U << (mAnnounceCounter - 1)) * kAnnounceInterval; + + mAnnounceTime = TimerMilli::GetNow() + delay; + } + else if (mTtl == 0) + { + // We are done announcing the removed record with zero TTL. + mIsPresent = false; + } + } + + break; + + case TxMessage::kUnicastResponse: + VerifyOrExit(IsAppended()); + VerifyOrExit(mAppendSection == kAnswerSection); + mUnicastAnswerPending = false; + break; + + default: + break; + } + +exit: + return; +} + +void Core::RecordInfo::UpdateFireTimeOn(FireTime &aFireTime) +{ + VerifyOrExit(mIsPresent); + + if (mAnnounceCounter < kNumberOfAnnounces) + { + aFireTime.SetFireTime(mAnnounceTime); + } + + if (mMulticastAnswerPending) + { + aFireTime.SetFireTime(mAnswerTime); + } + + if (mIsLastMulticastValid) + { + // `mLastMulticastTime` tracks the timestamp of the last + // multicast of this record. To handle potential 32-bit + // `TimeMilli` rollover, an aging mechanism is implemented. + // If the record isn't multicast again within a given age + // interval `kLastMulticastTimeAge`, `mIsLastMulticastValid` + // is cleared, indicating outdated multicast information. + + TimeMilli lastMulticastAgeTime = mLastMulticastTime + kLastMulticastTimeAge; + + if (lastMulticastAgeTime <= TimerMilli::GetNow()) + { + mIsLastMulticastValid = false; + } + else + { + aFireTime.SetFireTime(lastMulticastAgeTime); + } + } + +exit: + return; +} + +void Core::RecordInfo::MarkAsAppended(TxMessage &aTxMessage, Section aSection) +{ + mAppendSection = aSection; + + switch (aTxMessage.GetType()) + { + case TxMessage::kMulticastResponse: + case TxMessage::kMulticastProbe: + + mAppendState = kAppendedInMulticastMsg; + + if ((aSection == kAnswerSection) || (aSection == kAdditionalDataSection)) + { + mLastMulticastTime = TimerMilli::GetNow(); + mIsLastMulticastValid = true; + } + + break; + + case TxMessage::kUnicastResponse: + mAppendState = kAppendedInUnicastMsg; + break; + + case TxMessage::kMulticastQuery: + break; + } +} + +void Core::RecordInfo::MarkToAppendInAdditionalData(void) +{ + if (mAppendState == kNotAppended) + { + mAppendState = kToAppendInAdditionalData; + } +} + +bool Core::RecordInfo::IsAppended(void) const +{ + bool isAppended = false; + + switch (mAppendState) + { + case kNotAppended: + case kToAppendInAdditionalData: + break; + case kAppendedInMulticastMsg: + case kAppendedInUnicastMsg: + isAppended = true; + break; + } + + return isAppended; +} + +bool Core::RecordInfo::CanAppend(void) const { return mIsPresent && !IsAppended(); } + +Error Core::RecordInfo::GetLastMulticastTime(TimeMilli &aLastMulticastTime) const +{ + Error error = kErrorNotFound; + + VerifyOrExit(mIsPresent && mIsLastMulticastValid); + aLastMulticastTime = mLastMulticastTime; + +exit: + return error; +} + +uint32_t Core::RecordInfo::GetDurationSinceLastMulticast(TimeMilli aTime) const +{ + uint32_t duration = NumericLimits::kMax; + + VerifyOrExit(mIsPresent && mIsLastMulticastValid); + VerifyOrExit(aTime > mLastMulticastTime, duration = 0); + duration = aTime - mLastMulticastTime; + +exit: + return duration; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::FireTime + +void Core::FireTime::SetFireTime(TimeMilli aFireTime) +{ + if (mHasFireTime) + { + VerifyOrExit(aFireTime < mFireTime); + } + + mFireTime = aFireTime; + mHasFireTime = true; + +exit: + return; +} + +void Core::FireTime::ScheduleFireTimeOn(TimerMilli &aTimer) +{ + if (mHasFireTime) + { + aTimer.FireAtIfEarlier(mFireTime); + } +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::Entry + +Core::Entry::Entry(void) + : mState(kProbing) + , mProbeCount(0) + , mMulticastNsecPending(false) + , mUnicastNsecPending(false) + , mAppendedNsec(false) +{ +} + +void Core::Entry::Init(Instance &aInstance) +{ + // Initializes a newly allocated entry (host or service) + // and starts it in `kProbing` state. + + InstanceLocatorInit::Init(aInstance); + StartProbing(); +} + +void Core::Entry::SetState(State aState) +{ + mState = aState; + ScheduleCallbackTask(); +} + +void Core::Entry::Register(const Key &aKey, const Callback &aCallback) +{ + if (GetState() == kRemoving) + { + StartProbing(); + } + + mKeyRecord.UpdateTtl(DetermineTtl(aKey.mTtl, kDefaultKeyTtl)); + mKeyRecord.UpdateProperty(mKeyData, aKey.mKeyData, aKey.mKeyDataLength); + + mKeyCallback = aCallback; + ScheduleCallbackTask(); +} + +void Core::Entry::Unregister(const Key &aKey) +{ + OT_UNUSED_VARIABLE(aKey); + + VerifyOrExit(mKeyRecord.IsPresent()); + + mKeyCallback.Clear(); + + switch (GetState()) + { + case kRegistered: + mKeyRecord.UpdateTtl(0); + break; + + case kProbing: + case kConflict: + ClearKey(); + break; + + case kRemoving: + break; + } + +exit: + return; +} + +void Core::Entry::ClearKey(void) +{ + mKeyRecord.Clear(); + mKeyData.Free(); +} + +void Core::Entry::SetCallback(const Callback &aCallback) +{ + mCallback = aCallback; + ScheduleCallbackTask(); +} + +void Core::Entry::ScheduleCallbackTask(void) +{ + switch (GetState()) + { + case kRegistered: + case kConflict: + VerifyOrExit(!mCallback.IsEmpty() || !mKeyCallback.IsEmpty()); + Get().mEntryTask.Post(); + break; + + case kProbing: + case kRemoving: + break; + } + +exit: + return; +} + +void Core::Entry::InvokeCallbacks(void) +{ + Error error = kErrorNone; + + switch (GetState()) + { + case kConflict: + error = kErrorDuplicated; + OT_FALL_THROUGH; + + case kRegistered: + mKeyCallback.InvokeAndClear(GetInstance(), error); + mCallback.InvokeAndClear(GetInstance(), error); + break; + + case kProbing: + case kRemoving: + break; + } +} + +void Core::Entry::StartProbing(void) +{ + SetState(kProbing); + mProbeCount = 0; + SetFireTime(Get().RandomizeFirstProbeTxTime()); + ScheduleTimer(); +} + +void Core::Entry::SetStateToConflict(void) +{ + switch (GetState()) + { + case kProbing: + case kRegistered: + SetState(kConflict); + break; + case kConflict: + case kRemoving: + break; + } +} + +void Core::Entry::SetStateToRemoving(void) +{ + VerifyOrExit(GetState() != kRemoving); + SetState(kRemoving); + +exit: + return; +} + +void Core::Entry::ClearAppendState(void) +{ + mKeyRecord.MarkAsNotAppended(); + mAppendedNsec = false; +} + +void Core::Entry::UpdateRecordsState(const TxMessage &aResponse) +{ + mKeyRecord.UpdateStateAfterAnswer(aResponse); + + if (mAppendedNsec) + { + switch (aResponse.GetType()) + { + case TxMessage::kMulticastResponse: + mMulticastNsecPending = false; + break; + case TxMessage::kUnicastResponse: + mUnicastNsecPending = false; + break; + default: + break; + } + } +} + +void Core::Entry::ScheduleNsecAnswer(const AnswerInfo &aInfo) +{ + // Schedules NSEC record to be included in a response message. + // Used to answer to query for a record that is not present. + + VerifyOrExit(GetState() == kRegistered); + + if (aInfo.mUnicastResponse) + { + mUnicastNsecPending = true; + } + else + { + if (mMulticastNsecPending) + { + VerifyOrExit(aInfo.mAnswerTime < mNsecAnswerTime); + } + + mMulticastNsecPending = true; + mNsecAnswerTime = aInfo.mAnswerTime; + } + +exit: + return; +} + +bool Core::Entry::ShouldAnswerNsec(TimeMilli aNow) const { return mMulticastNsecPending && (mNsecAnswerTime <= aNow); } + +void Core::Entry::AnswerNonProbe(const AnswerInfo &aInfo, RecordAndType *aRecords, uint16_t aRecordsLength) +{ + // Schedule answers for all matching records in `aRecords` array + // to a given non-probe question. + + bool allEmptyOrZeroTtl = true; + bool answerNsec = true; + + for (uint16_t index = 0; index < aRecordsLength; index++) + { + RecordInfo &record = aRecords[index].mRecord; + + if (!record.CanAnswer()) + { + // Cannot answer if record is not present or has zero TTL. + continue; + } + + allEmptyOrZeroTtl = false; + + if (QuestionMatches(aInfo.mQuestionRrType, aRecords[index].mType)) + { + answerNsec = false; + record.ScheduleAnswer(aInfo); + } + } + + // If all records are removed or have zero TTL (we are still + // sending "Goodbye" announces), we should not provide any answer + // even NSEC. + + if (!allEmptyOrZeroTtl && answerNsec) + { + ScheduleNsecAnswer(aInfo); + } +} + +void Core::Entry::AnswerProbe(const AnswerInfo &aInfo, RecordAndType *aRecords, uint16_t aRecordsLength) +{ + bool allEmptyOrZeroTtl = true; + bool shouldDelay = false; + TimeMilli now = TimerMilli::GetNow(); + AnswerInfo info = aInfo; + + info.mAnswerTime = now; + + OT_ASSERT(info.mIsProbe); + + for (uint16_t index = 0; index < aRecordsLength; index++) + { + RecordInfo &record = aRecords[index].mRecord; + TimeMilli lastMulticastTime; + + if (!record.CanAnswer()) + { + continue; + } + + allEmptyOrZeroTtl = false; + + if (!info.mUnicastResponse) + { + // Rate limiting multicast probe responses + // + // We delay the response if all records were multicast + // recently within an interval `kMinIntervalProbeResponse` + // (250 msec). + + if (record.GetDurationSinceLastMulticast(now) >= kMinIntervalProbeResponse) + { + shouldDelay = false; + } + else if (record.GetLastMulticastTime(lastMulticastTime) == kErrorNone) + { + info.mAnswerTime = Max(info.mAnswerTime, lastMulticastTime + kMinIntervalProbeResponse); + } + } + } + + if (allEmptyOrZeroTtl) + { + // All records are removed or being removed. + + // Enhancement for future: If someone is probing for + // our name, we can stop announcement of removed records + // to let the new probe requester take over the name. + + ExitNow(); + } + + if (!shouldDelay) + { + info.mAnswerTime = now; + } + + for (uint16_t index = 0; index < aRecordsLength; index++) + { + aRecords[index].mRecord.ScheduleAnswer(info); + } + +exit: + return; +} + +void Core::Entry::DetermineNextFireTime(void) +{ + mKeyRecord.UpdateFireTimeOn(*this); + + if (mMulticastNsecPending) + { + SetFireTime(mNsecAnswerTime); + } +} + +void Core::Entry::ScheduleTimer(void) { ScheduleFireTimeOn(Get().mEntryTimer); } + +template void Core::Entry::HandleTimer(EntryTimerContext &aContext) +{ + EntryType *thisAsEntryType = static_cast(this); + + thisAsEntryType->ClearAppendState(); + + VerifyOrExit(HasFireTime()); + VerifyOrExit(GetFireTime() <= aContext.GetNow()); + ClearFireTime(); + + switch (GetState()) + { + case kProbing: + if (mProbeCount < kNumberOfProbes) + { + mProbeCount++; + SetFireTime(aContext.GetNow() + kProbeWaitTime); + thisAsEntryType->PrepareProbe(aContext.GetProbeMessage()); + break; + } + + SetState(kRegistered); + thisAsEntryType->StartAnnouncing(); + + OT_FALL_THROUGH; + + case kRegistered: + thisAsEntryType->PrepareResponse(aContext.GetResponseMessage(), aContext.GetNow()); + break; + + case kConflict: + case kRemoving: + ExitNow(); + } + + thisAsEntryType->DetermineNextFireTime(); + +exit: + if (HasFireTime()) + { + aContext.UpdateNextTime(GetFireTime()); + } +} + +void Core::Entry::AppendQuestionTo(TxMessage &aTxMessage) const +{ + Message &message = aTxMessage.SelectMessageFor(kQuestionSection); + uint16_t rrClass = ResourceRecord::kClassInternet; + Question question; + + if ((mProbeCount == 1) && Get().IsQuestionUnicastAllowed()) + { + rrClass |= kClassQuestionUnicastFlag; + } + + question.SetType(ResourceRecord::kTypeAny); + question.SetClass(rrClass); + SuccessOrAssert(message.Append(question)); + + aTxMessage.IncrementRecordCount(kQuestionSection); +} + +void Core::Entry::AppendKeyRecordTo(TxMessage &aTxMessage, Section aSection, NameAppender aNameAppender) +{ + Message *message; + ResourceRecord record; + + VerifyOrExit(mKeyRecord.CanAppend()); + mKeyRecord.MarkAsAppended(aTxMessage, aSection); + + message = &aTxMessage.SelectMessageFor(aSection); + + // Use the `aNameAppender` function to allow sub-class + // to append the proper name. + + aNameAppender(*this, aTxMessage, aSection); + + record.Init(ResourceRecord::kTypeKey); + record.SetTtl(mKeyRecord.GetTtl()); + record.SetLength(mKeyData.GetLength()); + UpdateCacheFlushFlagIn(record, aSection); + + SuccessOrAssert(message->Append(record)); + SuccessOrAssert(message->AppendBytes(mKeyData.GetBytes(), mKeyData.GetLength())); + + aTxMessage.IncrementRecordCount(aSection); + +exit: + return; +} + +void Core::Entry::AppendNsecRecordTo(TxMessage &aTxMessage, + Section aSection, + const TypeArray &aTypes, + NameAppender aNameAppender) +{ + Message &message = aTxMessage.SelectMessageFor(aSection); + NsecRecord nsec; + NsecRecord::TypeBitMap bitmap; + uint16_t offset; + + nsec.Init(); + nsec.SetTtl(kNsecTtl); + UpdateCacheFlushFlagIn(nsec, aSection); + + bitmap.Clear(); + + for (uint16_t type : aTypes) + { + bitmap.AddType(type); + } + + aNameAppender(*this, aTxMessage, aSection); + + offset = message.GetLength(); + SuccessOrAssert(message.Append(nsec)); + + // Next Domain Name (should be same as record name). + aNameAppender(*this, aTxMessage, aSection); + + SuccessOrAssert(message.AppendBytes(&bitmap, bitmap.GetSize())); + + UpdateRecordLengthInMessage(nsec, message, offset); + aTxMessage.IncrementRecordCount(aSection); + + mAppendedNsec = true; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::HostEntry + +Core::HostEntry::HostEntry(void) + : mNext(nullptr) + , mNameOffset(kUnspecifiedOffset) +{ +} + +Error Core::HostEntry::Init(Instance &aInstance, const char *aName) +{ + Entry::Init(aInstance); + + return mName.Set(aName); +} + +bool Core::HostEntry::Matches(const Name &aName) const +{ + return aName.Matches(/* aFirstLabel */ nullptr, mName.AsCString(), kLocalDomain); +} + +bool Core::HostEntry::Matches(const Host &aHost) const { return NameMatch(mName, aHost.mHostName); } + +bool Core::HostEntry::Matches(const Key &aKey) const { return !IsKeyForService(aKey) && NameMatch(mName, aKey.mName); } + +bool Core::HostEntry::Matches(const Heap::String &aName) const { return NameMatch(mName, aName); } + +bool Core::HostEntry::IsEmpty(void) const { return !mAddrRecord.IsPresent() && !mKeyRecord.IsPresent(); } + +void Core::HostEntry::Register(const Host &aHost, const Callback &aCallback) +{ + if (GetState() == kRemoving) + { + StartProbing(); + } + + SetCallback(aCallback); + + if (aHost.mAddressesLength == 0) + { + // If host is registered with no addresses, treat it + // as host being unregistered and announce removal of + // the old addresses. + Unregister(aHost); + ExitNow(); + } + + mAddrRecord.UpdateTtl(DetermineTtl(aHost.mTtl, kDefaultTtl)); + mAddrRecord.UpdateProperty(mAddresses, AsCoreTypePtr(aHost.mAddresses), aHost.mAddressesLength); + + DetermineNextFireTime(); + ScheduleTimer(); + +exit: + return; +} + +void Core::HostEntry::Register(const Key &aKey, const Callback &aCallback) +{ + Entry::Register(aKey, aCallback); + + DetermineNextFireTime(); + ScheduleTimer(); +} + +void Core::HostEntry::Unregister(const Host &aHost) +{ + OT_UNUSED_VARIABLE(aHost); + + VerifyOrExit(mAddrRecord.IsPresent()); + + ClearCallback(); + + switch (GetState()) + { + case kRegistered: + mAddrRecord.UpdateTtl(0); + DetermineNextFireTime(); + ScheduleTimer(); + break; + + case kProbing: + case kConflict: + ClearHost(); + ScheduleToRemoveIfEmpty(); + break; + + case kRemoving: + break; + } + +exit: + return; +} + +void Core::HostEntry::Unregister(const Key &aKey) +{ + Entry::Unregister(aKey); + + DetermineNextFireTime(); + ScheduleTimer(); + + ScheduleToRemoveIfEmpty(); +} + +void Core::HostEntry::ClearHost(void) +{ + mAddrRecord.Clear(); + mAddresses.Free(); +} + +void Core::HostEntry::ScheduleToRemoveIfEmpty(void) +{ + if (IsEmpty()) + { + SetStateToRemoving(); + Get().mEntryTask.Post(); + } +} + +void Core::HostEntry::HandleConflict(void) +{ + State oldState = GetState(); + + SetStateToConflict(); + VerifyOrExit(oldState == kRegistered); + Get().InvokeConflictCallback(mName.AsCString(), nullptr); + +exit: + return; +} + +void Core::HostEntry::AnswerQuestion(const AnswerInfo &aInfo) +{ + RecordAndType records[] = { + {mAddrRecord, ResourceRecord::kTypeAaaa}, + {mKeyRecord, ResourceRecord::kTypeKey}, + }; + + VerifyOrExit(GetState() == kRegistered); + + if (aInfo.mIsProbe) + { + AnswerProbe(aInfo, records, GetArrayLength(records)); + } + else + { + AnswerNonProbe(aInfo, records, GetArrayLength(records)); + } + + DetermineNextFireTime(); + ScheduleTimer(); + +exit: + return; +} + +void Core::HostEntry::HandleTimer(EntryTimerContext &aContext) { Entry::HandleTimer(aContext); } + +void Core::HostEntry::ClearAppendState(void) +{ + // Clears `HostEntry` records and all tracked saved name + // compression offsets. + + Entry::ClearAppendState(); + + mAddrRecord.MarkAsNotAppended(); + + mNameOffset = kUnspecifiedOffset; +} + +void Core::HostEntry::PrepareProbe(TxMessage &aProbe) +{ + bool prepareAgain = false; + + do + { + aProbe.SaveCurrentState(); + + AppendNameTo(aProbe, kQuestionSection); + AppendQuestionTo(aProbe); + + AppendAddressRecordsTo(aProbe, kAuthoritySection); + AppendKeyRecordTo(aProbe, kAuthoritySection); + + aProbe.CheckSizeLimitToPrepareAgain(prepareAgain); + + } while (prepareAgain); +} + +void Core::HostEntry::StartAnnouncing(void) +{ + mAddrRecord.StartAnnouncing(); + mKeyRecord.StartAnnouncing(); +} + +void Core::HostEntry::PrepareResponse(TxMessage &aResponse, TimeMilli aNow) +{ + bool prepareAgain = false; + + do + { + aResponse.SaveCurrentState(); + PrepareResponseRecords(aResponse, aNow); + aResponse.CheckSizeLimitToPrepareAgain(prepareAgain); + + } while (prepareAgain); + + UpdateRecordsState(aResponse); +} + +void Core::HostEntry::PrepareResponseRecords(TxMessage &aResponse, TimeMilli aNow) +{ + bool appendNsec = false; + + if (mAddrRecord.ShouldAppendTo(aResponse, aNow)) + { + AppendAddressRecordsTo(aResponse, kAnswerSection); + appendNsec = true; + } + + if (mKeyRecord.ShouldAppendTo(aResponse, aNow)) + { + AppendKeyRecordTo(aResponse, kAnswerSection); + appendNsec = true; + } + + if (appendNsec || ShouldAnswerNsec(aNow)) + { + AppendNsecRecordTo(aResponse, kAdditionalDataSection); + } +} + +void Core::HostEntry::UpdateRecordsState(const TxMessage &aResponse) +{ + // Updates state after a response is prepared. + + Entry::UpdateRecordsState(aResponse); + mAddrRecord.UpdateStateAfterAnswer(aResponse); + + if (IsEmpty()) + { + SetStateToRemoving(); + } +} + +void Core::HostEntry::DetermineNextFireTime(void) +{ + VerifyOrExit(GetState() == kRegistered); + + Entry::DetermineNextFireTime(); + mAddrRecord.UpdateFireTimeOn(*this); + +exit: + return; +} + +void Core::HostEntry::AppendAddressRecordsTo(TxMessage &aTxMessage, Section aSection) +{ + Message *message; + + VerifyOrExit(mAddrRecord.CanAppend()); + mAddrRecord.MarkAsAppended(aTxMessage, aSection); + + message = &aTxMessage.SelectMessageFor(aSection); + + for (const Ip6::Address &address : mAddresses) + { + AaaaRecord aaaaRecord; + + aaaaRecord.Init(); + aaaaRecord.SetTtl(mAddrRecord.GetTtl()); + aaaaRecord.SetAddress(address); + UpdateCacheFlushFlagIn(aaaaRecord, aSection); + + AppendNameTo(aTxMessage, aSection); + SuccessOrAssert(message->Append(aaaaRecord)); + + aTxMessage.IncrementRecordCount(aSection); + } + +exit: + return; +} + +void Core::HostEntry::AppendKeyRecordTo(TxMessage &aTxMessage, Section aSection) +{ + Entry::AppendKeyRecordTo(aTxMessage, aSection, &AppendEntryName); +} + +void Core::HostEntry::AppendNsecRecordTo(TxMessage &aTxMessage, Section aSection) +{ + TypeArray types; + + if (mAddrRecord.IsPresent() && (mAddrRecord.GetTtl() > 0)) + { + types.Add(ResourceRecord::kTypeAaaa); + } + + if (mKeyRecord.IsPresent() && (mKeyRecord.GetTtl() > 0)) + { + types.Add(ResourceRecord::kTypeKey); + } + + if (!types.IsEmpty()) + { + Entry::AppendNsecRecordTo(aTxMessage, aSection, types, &AppendEntryName); + } +} + +void Core::HostEntry::AppendEntryName(Entry &aEntry, TxMessage &aTxMessage, Section aSection) +{ + static_cast(aEntry).AppendNameTo(aTxMessage, aSection); +} + +void Core::HostEntry::AppendNameTo(TxMessage &aTxMessage, Section aSection) +{ + AppendOutcome outcome; + + outcome = aTxMessage.AppendMultipleLabels(aSection, mName.AsCString(), mNameOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + aTxMessage.AppendDomainName(aSection); + +exit: + return; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::ServiceEntry + +const uint8_t Core::ServiceEntry::kEmptyTxtData[] = {0}; + +Core::ServiceEntry::ServiceEntry(void) + : mNext(nullptr) + , mPriority(0) + , mWeight(0) + , mPort(0) + , mServiceNameOffset(kUnspecifiedOffset) + , mServiceTypeOffset(kUnspecifiedOffset) + , mSubServiceTypeOffset(kUnspecifiedOffset) + , mHostNameOffset(kUnspecifiedOffset) + , mIsAddedInServiceTypes(false) +{ +} + +Error Core::ServiceEntry::Init(Instance &aInstance, const char *aServiceInstance, const char *aServiceType) +{ + Error error; + + Entry::Init(aInstance); + + SuccessOrExit(error = mServiceInstance.Set(aServiceInstance)); + SuccessOrExit(error = mServiceType.Set(aServiceType)); + +exit: + return error; +} + +Error Core::ServiceEntry::Init(Instance &aInstance, const Service &aService) +{ + return Init(aInstance, aService.mServiceInstance, aService.mServiceType); +} + +Error Core::ServiceEntry::Init(Instance &aInstance, const Key &aKey) +{ + return Init(aInstance, aKey.mName, aKey.mServiceType); +} + +bool Core::ServiceEntry::Matches(const Name &aFullName) const +{ + return aFullName.Matches(mServiceInstance.AsCString(), mServiceType.AsCString(), kLocalDomain); +} + +bool Core::ServiceEntry::MatchesServiceType(const Name &aServiceType) const +{ + // When matching service type, PTR record should be + // present with non-zero TTL (checked by `CanAnswer()`). + + return mPtrRecord.CanAnswer() && aServiceType.Matches(nullptr, mServiceType.AsCString(), kLocalDomain); +} + +bool Core::ServiceEntry::Matches(const Service &aService) const +{ + return NameMatch(mServiceInstance, aService.mServiceInstance) && NameMatch(mServiceType, aService.mServiceType); +} + +bool Core::ServiceEntry::Matches(const Key &aKey) const +{ + return IsKeyForService(aKey) && NameMatch(mServiceInstance, aKey.mName) && + NameMatch(mServiceType, aKey.mServiceType); +} + +bool Core::ServiceEntry::IsEmpty(void) const { return !mPtrRecord.IsPresent() && !mKeyRecord.IsPresent(); } + +bool Core::ServiceEntry::CanAnswerSubType(const char *aSubLabel) const +{ + bool canAnswer = false; + const SubType *subType; + + VerifyOrExit(mPtrRecord.CanAnswer()); + + subType = mSubTypes.FindMatching(aSubLabel); + VerifyOrExit(subType != nullptr); + + canAnswer = subType->mPtrRecord.CanAnswer(); + +exit: + return canAnswer; +} + +void Core::ServiceEntry::Register(const Service &aService, const Callback &aCallback) +{ + uint32_t ttl = DetermineTtl(aService.mTtl, kDefaultTtl); + + if (GetState() == kRemoving) + { + StartProbing(); + } + + SetCallback(aCallback); + + // Register sub-types PTRs. + + // First we check for any removed sub-types. We keep removed + // sub-types marked with zero TTL so to announce their removal + // before fully removing them from the list. + + for (SubType &subType : mSubTypes) + { + uint32_t subTypeTtl = subType.IsContainedIn(aService) ? ttl : 0; + + subType.mPtrRecord.UpdateTtl(subTypeTtl); + } + + // Next we add any new sub-types in `aService`. + + for (uint16_t i = 0; i < aService.mSubTypeLabelsLength; i++) + { + const char *label = aService.mSubTypeLabels[i]; + + if (!mSubTypes.ContainsMatching(label)) + { + SubType *newSubType = SubType::AllocateAndInit(label); + + OT_ASSERT(newSubType != nullptr); + mSubTypes.Push(*newSubType); + + newSubType->mPtrRecord.UpdateTtl(ttl); + } + } + + // Register base PTR service. + + mPtrRecord.UpdateTtl(ttl); + + // Register SRV record info. + + mSrvRecord.UpdateTtl(ttl); + mSrvRecord.UpdateProperty(mHostName, aService.mHostName); + mSrvRecord.UpdateProperty(mPriority, aService.mPriority); + mSrvRecord.UpdateProperty(mWeight, aService.mWeight); + mSrvRecord.UpdateProperty(mPort, aService.mPort); + + // Register TXT record info. + + mTxtRecord.UpdateTtl(ttl); + + if ((aService.mTxtData == nullptr) || (aService.mTxtDataLength == 0)) + { + mTxtRecord.UpdateProperty(mTxtData, kEmptyTxtData, sizeof(kEmptyTxtData)); + } + else + { + mTxtRecord.UpdateProperty(mTxtData, aService.mTxtData, aService.mTxtDataLength); + } + + UpdateServiceTypes(); + + DetermineNextFireTime(); + ScheduleTimer(); +} + +void Core::ServiceEntry::Register(const Key &aKey, const Callback &aCallback) +{ + Entry::Register(aKey, aCallback); + + DetermineNextFireTime(); + ScheduleTimer(); +} + +void Core::ServiceEntry::Unregister(const Service &aService) +{ + OT_UNUSED_VARIABLE(aService); + + VerifyOrExit(mPtrRecord.IsPresent()); + + ClearCallback(); + + switch (GetState()) + { + case kRegistered: + for (SubType &subType : mSubTypes) + { + subType.mPtrRecord.UpdateTtl(0); + } + + mPtrRecord.UpdateTtl(0); + mSrvRecord.UpdateTtl(0); + mTxtRecord.UpdateTtl(0); + DetermineNextFireTime(); + ScheduleTimer(); + break; + + case kProbing: + case kConflict: + ClearService(); + ScheduleToRemoveIfEmpty(); + break; + + case kRemoving: + break; + } + + UpdateServiceTypes(); + +exit: + return; +} + +void Core::ServiceEntry::Unregister(const Key &aKey) +{ + Entry::Unregister(aKey); + + DetermineNextFireTime(); + ScheduleTimer(); + + ScheduleToRemoveIfEmpty(); +} + +void Core::ServiceEntry::ClearService(void) +{ + mPtrRecord.Clear(); + mSrvRecord.Clear(); + mTxtRecord.Clear(); + mSubTypes.Free(); + mHostName.Free(); + mTxtData.Free(); +} + +void Core::ServiceEntry::ScheduleToRemoveIfEmpty(void) +{ + OwningList removedSubTypes; + + mSubTypes.RemoveAllMatching(EmptyChecker(), removedSubTypes); + + if (IsEmpty()) + { + SetStateToRemoving(); + Get().mEntryTask.Post(); + } +} + +void Core::ServiceEntry::HandleConflict(void) +{ + State oldState = GetState(); + + SetStateToConflict(); + UpdateServiceTypes(); + + VerifyOrExit(oldState == kRegistered); + Get().InvokeConflictCallback(mServiceInstance.AsCString(), mServiceType.AsCString()); + +exit: + return; +} + +void Core::ServiceEntry::AnswerServiceNameQuestion(const AnswerInfo &aInfo) +{ + RecordAndType records[] = { + {mSrvRecord, ResourceRecord::kTypeSrv}, + {mTxtRecord, ResourceRecord::kTypeTxt}, + {mKeyRecord, ResourceRecord::kTypeKey}, + }; + + VerifyOrExit(GetState() == kRegistered); + + if (aInfo.mIsProbe) + { + AnswerProbe(aInfo, records, GetArrayLength(records)); + } + else + { + AnswerNonProbe(aInfo, records, GetArrayLength(records)); + } + + DetermineNextFireTime(); + ScheduleTimer(); + +exit: + return; +} + +void Core::ServiceEntry::AnswerServiceTypeQuestion(const AnswerInfo &aInfo, const char *aSubLabel) +{ + VerifyOrExit(GetState() == kRegistered); + + if (aSubLabel == nullptr) + { + mPtrRecord.ScheduleAnswer(aInfo); + } + else + { + SubType *subType = mSubTypes.FindMatching(aSubLabel); + + VerifyOrExit(subType != nullptr); + subType->mPtrRecord.ScheduleAnswer(aInfo); + } + + DetermineNextFireTime(); + ScheduleTimer(); + +exit: + return; +} + +bool Core::ServiceEntry::ShouldSuppressKnownAnswer(uint32_t aTtl, const char *aSubLabel) const +{ + // Check `aTtl` of a matching record in known-answer section of + // a query with the corresponding PTR record's TTL and suppress + // answer if it is at least at least half the correct value. + + bool shouldSuppress = false; + uint32_t ttl; + + if (aSubLabel == nullptr) + { + ttl = mPtrRecord.GetTtl(); + } + else + { + const SubType *subType = mSubTypes.FindMatching(aSubLabel); + + VerifyOrExit(subType != nullptr); + ttl = subType->mPtrRecord.GetTtl(); + } + + shouldSuppress = (aTtl > ttl / 2); + +exit: + return shouldSuppress; +} + +void Core::ServiceEntry::HandleTimer(EntryTimerContext &aContext) { Entry::HandleTimer(aContext); } + +void Core::ServiceEntry::ClearAppendState(void) +{ + // Clear the append state for all `ServiceEntry` records, + // along with all tracked name compression offsets. + + Entry::ClearAppendState(); + + mPtrRecord.MarkAsNotAppended(); + mSrvRecord.MarkAsNotAppended(); + mTxtRecord.MarkAsNotAppended(); + + mServiceNameOffset = kUnspecifiedOffset; + mServiceTypeOffset = kUnspecifiedOffset; + mSubServiceTypeOffset = kUnspecifiedOffset; + mHostNameOffset = kUnspecifiedOffset; + + for (SubType &subType : mSubTypes) + { + subType.mPtrRecord.MarkAsNotAppended(); + subType.mSubServiceNameOffset = kUnspecifiedOffset; + } +} + +void Core::ServiceEntry::PrepareProbe(TxMessage &aProbe) +{ + bool prepareAgain = false; + + do + { + HostEntry *hostEntry = nullptr; + + aProbe.SaveCurrentState(); + + DiscoverOffsetsAndHost(hostEntry); + + AppendServiceNameTo(aProbe, kQuestionSection); + AppendQuestionTo(aProbe); + + // Append records (if present) in authority section + + AppendSrvRecordTo(aProbe, kAuthoritySection); + AppendTxtRecordTo(aProbe, kAuthoritySection); + AppendKeyRecordTo(aProbe, kAuthoritySection); + + aProbe.CheckSizeLimitToPrepareAgain(prepareAgain); + + } while (prepareAgain); +} + +void Core::ServiceEntry::StartAnnouncing(void) +{ + for (SubType &subType : mSubTypes) + { + subType.mPtrRecord.StartAnnouncing(); + } + + mPtrRecord.StartAnnouncing(); + mSrvRecord.StartAnnouncing(); + mTxtRecord.StartAnnouncing(); + mKeyRecord.StartAnnouncing(); + + UpdateServiceTypes(); +} + +void Core::ServiceEntry::PrepareResponse(TxMessage &aResponse, TimeMilli aNow) +{ + bool prepareAgain = false; + + do + { + aResponse.SaveCurrentState(); + PrepareResponseRecords(aResponse, aNow); + aResponse.CheckSizeLimitToPrepareAgain(prepareAgain); + + } while (prepareAgain); + + UpdateRecordsState(aResponse); +} + +void Core::ServiceEntry::PrepareResponseRecords(TxMessage &aResponse, TimeMilli aNow) +{ + bool appendNsec = false; + HostEntry *hostEntry = nullptr; + + DiscoverOffsetsAndHost(hostEntry); + + // We determine records to include in Additional Data section + // per RFC 6763 section 12: + // + // - For base PTR, we include SRV, TXT, and host addresses. + // - For SRV, we include host addresses only (TXT record not + // recommended). + // + // Records already appended in Answer section are excluded from + // Additional Data. Host Entries are processed before Service + // Entries which ensures address inclusion accuracy. + // `MarkToAppendInAdditionalData()` marks a record for potential + // Additional Data inclusion, but this is skipped if the record + // is already appended in the Answer section. + + if (mPtrRecord.ShouldAppendTo(aResponse, aNow)) + { + AppendPtrRecordTo(aResponse, kAnswerSection); + + if (mPtrRecord.GetTtl() > 0) + { + mSrvRecord.MarkToAppendInAdditionalData(); + mTxtRecord.MarkToAppendInAdditionalData(); + + if (hostEntry != nullptr) + { + hostEntry->mAddrRecord.MarkToAppendInAdditionalData(); + } + } + } + + for (SubType &subType : mSubTypes) + { + if (subType.mPtrRecord.ShouldAppendTo(aResponse, aNow)) + { + AppendPtrRecordTo(aResponse, kAnswerSection, &subType); + } + } + + if (mSrvRecord.ShouldAppendTo(aResponse, aNow)) + { + AppendSrvRecordTo(aResponse, kAnswerSection); + appendNsec = true; + + if ((mSrvRecord.GetTtl() > 0) && (hostEntry != nullptr)) + { + hostEntry->mAddrRecord.MarkToAppendInAdditionalData(); + } + } + + if (mTxtRecord.ShouldAppendTo(aResponse, aNow)) + { + AppendTxtRecordTo(aResponse, kAnswerSection); + appendNsec = true; + } + + if (mKeyRecord.ShouldAppendTo(aResponse, aNow)) + { + AppendKeyRecordTo(aResponse, kAnswerSection); + appendNsec = true; + } + + // Append records in Additional Data section + + if (mSrvRecord.ShouldAppendInAdditionalDataSection()) + { + AppendSrvRecordTo(aResponse, kAdditionalDataSection); + } + + if (mTxtRecord.ShouldAppendInAdditionalDataSection()) + { + AppendTxtRecordTo(aResponse, kAdditionalDataSection); + } + + if ((hostEntry != nullptr) && (hostEntry->mAddrRecord.ShouldAppendInAdditionalDataSection())) + { + hostEntry->AppendAddressRecordsTo(aResponse, kAdditionalDataSection); + } + + if (appendNsec || ShouldAnswerNsec(aNow)) + { + AppendNsecRecordTo(aResponse, kAdditionalDataSection); + } +} + +void Core::ServiceEntry::UpdateRecordsState(const TxMessage &aResponse) +{ + OwningList removedSubTypes; + + Entry::UpdateRecordsState(aResponse); + + mPtrRecord.UpdateStateAfterAnswer(aResponse); + mSrvRecord.UpdateStateAfterAnswer(aResponse); + mTxtRecord.UpdateStateAfterAnswer(aResponse); + + for (SubType &subType : mSubTypes) + { + subType.mPtrRecord.UpdateStateAfterAnswer(aResponse); + } + + mSubTypes.RemoveAllMatching(EmptyChecker(), removedSubTypes); + + if (IsEmpty()) + { + SetStateToRemoving(); + } +} + +void Core::ServiceEntry::DetermineNextFireTime(void) +{ + VerifyOrExit(GetState() == kRegistered); + + Entry::DetermineNextFireTime(); + + mPtrRecord.UpdateFireTimeOn(*this); + mSrvRecord.UpdateFireTimeOn(*this); + mTxtRecord.UpdateFireTimeOn(*this); + + for (SubType &subType : mSubTypes) + { + subType.mPtrRecord.UpdateFireTimeOn(*this); + } + +exit: + return; +} + +void Core::ServiceEntry::DiscoverOffsetsAndHost(HostEntry *&aHostEntry) +{ + // Discovers the `HostEntry` associated with this `ServiceEntry` + // and name compression offsets from the previously appended + // entries. + + aHostEntry = Get().mHostEntries.FindMatching(mHostName); + + if ((aHostEntry != nullptr) && (aHostEntry->GetState() != GetState())) + { + aHostEntry = nullptr; + } + + if (aHostEntry != nullptr) + { + UpdateCompressOffset(mHostNameOffset, aHostEntry->mNameOffset); + } + + for (ServiceEntry &other : Get().mServiceEntries) + { + // We only need to search up to `this` entry in the list, + // since entries after `this` are not yet processed and not + // yet appended in the response or the probe message. + + if (&other == this) + { + break; + } + + if (other.GetState() != GetState()) + { + // Validate that both entries are in the same state, + // ensuring their records are appended in the same + // message, i.e., a probe or a response message. + + continue; + } + + if (NameMatch(mHostName, other.mHostName)) + { + UpdateCompressOffset(mHostNameOffset, other.mHostNameOffset); + } + + if (NameMatch(mServiceType, other.mServiceType)) + { + UpdateCompressOffset(mServiceTypeOffset, other.mServiceTypeOffset); + + if (GetState() == kProbing) + { + // No need to search for sub-type service offsets when + // we are still probing. + + continue; + } + + UpdateCompressOffset(mSubServiceTypeOffset, other.mSubServiceTypeOffset); + + for (SubType &subType : mSubTypes) + { + const SubType *otherSubType = other.mSubTypes.FindMatching(subType.mLabel.AsCString()); + + if (otherSubType != nullptr) + { + UpdateCompressOffset(subType.mSubServiceNameOffset, otherSubType->mSubServiceNameOffset); + } + } + } + } +} + +void Core::ServiceEntry::UpdateServiceTypes(void) +{ + // This method updates the `mServiceTypes` list adding or + // removing this `ServiceEntry` info. + // + // It is called whenever `ServcieEntry` state gets changed or an + // PTR record is added or removed. The service is valid when + // entry is registered and we have a PTR with non-zero TTL. + + bool shouldAdd = (GetState() == kRegistered) && mPtrRecord.CanAnswer(); + ServiceType *serviceType; + + VerifyOrExit(shouldAdd != mIsAddedInServiceTypes); + + mIsAddedInServiceTypes = shouldAdd; + + serviceType = Get().mServiceTypes.FindMatching(mServiceType); + + if (shouldAdd && (serviceType == nullptr)) + { + serviceType = ServiceType::AllocateAndInit(GetInstance(), mServiceType.AsCString()); + OT_ASSERT(serviceType != nullptr); + Get().mServiceTypes.Push(*serviceType); + } + + VerifyOrExit(serviceType != nullptr); + + if (shouldAdd) + { + serviceType->IncrementNumEntries(); + } + else + { + serviceType->DecrementNumEntries(); + + if (serviceType->GetNumEntries() == 0) + { + // If there are no more `ServiceEntry` with + // this service type, we remove the it from + // the `mServiceTypes` list. It is safe to + // remove here as this method will never be + // called while we are iterating over the + // `mServcieTypes` list. + + Get().mServiceTypes.RemoveMatching(*serviceType); + } + } + +exit: + return; +} + +void Core::ServiceEntry::AppendSrvRecordTo(TxMessage &aTxMessage, Section aSection) +{ + Message *message; + SrvRecord srv; + uint16_t offset; + + VerifyOrExit(mSrvRecord.CanAppend()); + mSrvRecord.MarkAsAppended(aTxMessage, aSection); + + message = &aTxMessage.SelectMessageFor(aSection); + + srv.Init(); + srv.SetTtl(mSrvRecord.GetTtl()); + srv.SetPriority(mPriority); + srv.SetWeight(mWeight); + srv.SetPort(mPort); + UpdateCacheFlushFlagIn(srv, aSection); + + AppendServiceNameTo(aTxMessage, aSection); + offset = message->GetLength(); + SuccessOrAssert(message->Append(srv)); + AppendHostNameTo(aTxMessage, aSection); + UpdateRecordLengthInMessage(srv, *message, offset); + + aTxMessage.IncrementRecordCount(aSection); + +exit: + return; +} + +void Core::ServiceEntry::AppendTxtRecordTo(TxMessage &aTxMessage, Section aSection) +{ + Message *message; + TxtRecord txt; + + VerifyOrExit(mTxtRecord.CanAppend()); + mTxtRecord.MarkAsAppended(aTxMessage, aSection); + + message = &aTxMessage.SelectMessageFor(aSection); + + txt.Init(); + txt.SetTtl(mTxtRecord.GetTtl()); + txt.SetLength(mTxtData.GetLength()); + UpdateCacheFlushFlagIn(txt, aSection); + + AppendServiceNameTo(aTxMessage, aSection); + SuccessOrAssert(message->Append(txt)); + SuccessOrAssert(message->AppendBytes(mTxtData.GetBytes(), mTxtData.GetLength())); + + aTxMessage.IncrementRecordCount(aSection); + +exit: + return; +} + +void Core::ServiceEntry::AppendPtrRecordTo(TxMessage &aTxMessage, Section aSection, SubType *aSubType) +{ + // Appends PTR record for base service (when `aSubType == nullptr`) or + // for the given `aSubType`. + + Message *message; + RecordInfo &ptrRecord = (aSubType == nullptr) ? mPtrRecord : aSubType->mPtrRecord; + PtrRecord ptr; + uint16_t offset; + + VerifyOrExit(ptrRecord.CanAppend()); + ptrRecord.MarkAsAppended(aTxMessage, aSection); + + message = &aTxMessage.SelectMessageFor(aSection); + + ptr.Init(); + ptr.SetTtl(ptrRecord.GetTtl()); + + if (aSubType == nullptr) + { + AppendServiceTypeTo(aTxMessage, aSection); + } + else + { + AppendSubServiceNameTo(aTxMessage, aSection, *aSubType); + } + + offset = message->GetLength(); + SuccessOrAssert(message->Append(ptr)); + AppendServiceNameTo(aTxMessage, aSection); + UpdateRecordLengthInMessage(ptr, *message, offset); + + aTxMessage.IncrementRecordCount(aSection); + +exit: + return; +} + +void Core::ServiceEntry::AppendKeyRecordTo(TxMessage &aTxMessage, Section aSection) +{ + Entry::AppendKeyRecordTo(aTxMessage, aSection, &AppendEntryName); +} + +void Core::ServiceEntry::AppendNsecRecordTo(TxMessage &aTxMessage, Section aSection) +{ + TypeArray types; + + if (mSrvRecord.IsPresent() && (mSrvRecord.GetTtl() > 0)) + { + types.Add(ResourceRecord::kTypeSrv); + } + + if (mTxtRecord.IsPresent() && (mTxtRecord.GetTtl() > 0)) + { + types.Add(ResourceRecord::kTypeTxt); + } + + if (mKeyRecord.IsPresent() && (mKeyRecord.GetTtl() > 0)) + { + types.Add(ResourceRecord::kTypeKey); + } + + if (!types.IsEmpty()) + { + Entry::AppendNsecRecordTo(aTxMessage, aSection, types, &AppendEntryName); + } +} + +void Core::ServiceEntry::AppendEntryName(Entry &aEntry, TxMessage &aTxMessage, Section aSection) +{ + static_cast(aEntry).AppendServiceNameTo(aTxMessage, aSection); +} + +void Core::ServiceEntry::AppendServiceNameTo(TxMessage &aTxMessage, Section aSection) +{ + AppendOutcome outcome; + + outcome = aTxMessage.AppendLabel(aSection, mServiceInstance.AsCString(), mServiceNameOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + AppendServiceTypeTo(aTxMessage, aSection); + +exit: + return; +} + +void Core::ServiceEntry::AppendServiceTypeTo(TxMessage &aTxMessage, Section aSection) +{ + aTxMessage.AppendServiceType(aSection, mServiceType.AsCString(), mServiceTypeOffset); +} + +void Core::ServiceEntry::AppendSubServiceTypeTo(TxMessage &aTxMessage, Section aSection) +{ + AppendOutcome outcome; + + outcome = aTxMessage.AppendLabel(aSection, kSubServiceLabel, mSubServiceTypeOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + AppendServiceTypeTo(aTxMessage, aSection); + +exit: + return; +} + +void Core::ServiceEntry::AppendSubServiceNameTo(TxMessage &aTxMessage, Section aSection, SubType &aSubType) +{ + AppendOutcome outcome; + + outcome = aTxMessage.AppendLabel(aSection, aSubType.mLabel.AsCString(), aSubType.mSubServiceNameOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + AppendSubServiceTypeTo(aTxMessage, aSection); + +exit: + return; +} + +void Core::ServiceEntry::AppendHostNameTo(TxMessage &aTxMessage, Section aSection) +{ + AppendOutcome outcome; + + outcome = aTxMessage.AppendMultipleLabels(aSection, mHostName.AsCString(), mHostNameOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + aTxMessage.AppendDomainName(aSection); + +exit: + return; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::ServiceEntry::SubType + +Error Core::ServiceEntry::SubType::Init(const char *aLabel) +{ + mSubServiceNameOffset = kUnspecifiedOffset; + + return mLabel.Set(aLabel); +} + +bool Core::ServiceEntry::SubType::Matches(const EmptyChecker &aChecker) const +{ + OT_UNUSED_VARIABLE(aChecker); + + return !mPtrRecord.IsPresent(); +} + +bool Core::ServiceEntry::SubType::IsContainedIn(const Service &aService) const +{ + bool contains = false; + + for (uint16_t i = 0; i < aService.mSubTypeLabelsLength; i++) + { + if (NameMatch(mLabel, aService.mSubTypeLabels[i])) + { + contains = true; + break; + } + } + + return contains; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::ServiceType + +Error Core::ServiceType::Init(Instance &aInstance, const char *aServiceType) +{ + Error error; + + InstanceLocatorInit::Init(aInstance); + + mNext = nullptr; + mNumEntries = 0; + SuccessOrExit(error = mServiceType.Set(aServiceType)); + + mServicesPtr.UpdateTtl(kServicesPtrTtl); + mServicesPtr.StartAnnouncing(); + + mServicesPtr.UpdateFireTimeOn(*this); + ScheduleFireTimeOn(Get().mEntryTimer); + +exit: + return error; +} + +bool Core::ServiceType::Matches(const Name &aServcieTypeName) const +{ + return aServcieTypeName.Matches(/* aFirstLabel */ nullptr, mServiceType.AsCString(), kLocalDomain); +} + +bool Core::ServiceType::Matches(const Heap::String &aServiceType) const +{ + return NameMatch(aServiceType, mServiceType); +} + +void Core::ServiceType::ClearAppendState(void) { mServicesPtr.MarkAsNotAppended(); } + +void Core::ServiceType::AnswerQuestion(const AnswerInfo &aInfo) +{ + VerifyOrExit(mServicesPtr.CanAnswer()); + mServicesPtr.ScheduleAnswer(aInfo); + mServicesPtr.UpdateFireTimeOn(*this); + ScheduleFireTimeOn(Get().mEntryTimer); + +exit: + return; +} + +bool Core::ServiceType::ShouldSuppressKnownAnswer(uint32_t aTtl) const +{ + // Check `aTtl` of a matching record in known-answer section of + // a query with the corresponding PTR record's TTL and suppress + // answer if it is at least at least half the correct value. + + return (aTtl > mServicesPtr.GetTtl() / 2); +} + +void Core::ServiceType::HandleTimer(EntryTimerContext &aContext) +{ + ClearAppendState(); + + VerifyOrExit(HasFireTime()); + VerifyOrExit(GetFireTime() <= aContext.GetNow()); + ClearFireTime(); + + PrepareResponse(aContext.GetResponseMessage(), aContext.GetNow()); + + mServicesPtr.UpdateFireTimeOn(*this); + +exit: + if (HasFireTime()) + { + aContext.UpdateNextTime(GetFireTime()); + } +} + +void Core::ServiceType::PrepareResponse(TxMessage &aResponse, TimeMilli aNow) +{ + bool prepareAgain = false; + + do + { + aResponse.SaveCurrentState(); + PrepareResponseRecords(aResponse, aNow); + aResponse.CheckSizeLimitToPrepareAgain(prepareAgain); + + } while (prepareAgain); + + mServicesPtr.UpdateStateAfterAnswer(aResponse); +} + +void Core::ServiceType::PrepareResponseRecords(TxMessage &aResponse, TimeMilli aNow) +{ + uint16_t serviceTypeOffset = kUnspecifiedOffset; + + VerifyOrExit(mServicesPtr.ShouldAppendTo(aResponse, aNow)); + + // Discover compress offset for `mServiceType` if previously + // appended from any `ServiceEntry`. + + for (const ServiceEntry &serviceEntry : Get().mServiceEntries) + { + if (serviceEntry.GetState() != Entry::kRegistered) + { + continue; + } + + if (NameMatch(mServiceType, serviceEntry.mServiceType)) + { + UpdateCompressOffset(serviceTypeOffset, serviceEntry.mServiceTypeOffset); + + if (serviceTypeOffset != kUnspecifiedOffset) + { + break; + } + } + } + + AppendPtrRecordTo(aResponse, serviceTypeOffset); + +exit: + return; +} + +void Core::ServiceType::AppendPtrRecordTo(TxMessage &aResponse, uint16_t aServiceTypeOffset) +{ + Message *message; + PtrRecord ptr; + uint16_t offset; + + VerifyOrExit(mServicesPtr.CanAppend()); + mServicesPtr.MarkAsAppended(aResponse, kAnswerSection); + + message = &aResponse.SelectMessageFor(kAnswerSection); + + ptr.Init(); + ptr.SetTtl(mServicesPtr.GetTtl()); + + aResponse.AppendServicesDnssdName(kAnswerSection); + offset = message->GetLength(); + SuccessOrAssert(message->Append(ptr)); + aResponse.AppendServiceType(kAnswerSection, mServiceType.AsCString(), aServiceTypeOffset); + UpdateRecordLengthInMessage(ptr, *message, offset); + + aResponse.IncrementRecordCount(kAnswerSection); + +exit: + return; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::TxMessage + +Core::TxMessage::TxMessage(Instance &aInstance, Type aType) + : InstanceLocator(aInstance) +{ + Init(aType); +} + +Core::TxMessage::TxMessage(Instance &aInstance, Type aType, const AddressInfo &aUnicastDest) + : TxMessage(aInstance, aType) +{ + mUnicastDest = aUnicastDest; +} + +void Core::TxMessage::Init(Type aType) +{ + Header header; + + mRecordCounts.Clear(); + mSavedRecordCounts.Clear(); + mSavedMsgLength = 0; + mSavedExtraMsgLength = 0; + mDomainOffset = kUnspecifiedOffset; + mUdpOffset = kUnspecifiedOffset; + mTcpOffset = kUnspecifiedOffset; + mServicesDnssdOffset = kUnspecifiedOffset; + mType = aType; + + // Allocate messages. The main `mMsgPtr` is always allocated. + // The Authority and Addition section messages are allocated + // the first time they are used. + + mMsgPtr.Reset(Get().Allocate(Message::kTypeOther)); + OT_ASSERT(!mMsgPtr.IsNull()); + + mExtraMsgPtr.Reset(); + + header.Clear(); + + switch (aType) + { + case kMulticastProbe: + case kMulticastQuery: + header.SetType(Header::kTypeQuery); + break; + case kMulticastResponse: + case kUnicastResponse: + header.SetType(Header::kTypeResponse); + break; + } + + SuccessOrAssert(mMsgPtr->Append(header)); +} + +Message &Core::TxMessage::SelectMessageFor(Section aSection) +{ + // Selects the `Message` to use for a given `aSection` based + // the message type. + + Message *message = nullptr; + Section mainSection = kAnswerSection; + Section extraSection = kAdditionalDataSection; + + switch (mType) + { + case kMulticastProbe: + mainSection = kQuestionSection; + extraSection = kAuthoritySection; + break; + + case kMulticastQuery: + mainSection = kQuestionSection; + extraSection = kAnswerSection; + break; + + case kMulticastResponse: + case kUnicastResponse: + break; + } + + if (aSection == mainSection) + { + message = mMsgPtr.Get(); + } + else if (aSection == extraSection) + { + if (mExtraMsgPtr.IsNull()) + { + mExtraMsgPtr.Reset(Get().Allocate(Message::kTypeOther)); + OT_ASSERT(!mExtraMsgPtr.IsNull()); + } + + message = mExtraMsgPtr.Get(); + } + + OT_ASSERT(message != nullptr); + + return *message; +} + +Core::AppendOutcome Core::TxMessage::AppendLabel(Section aSection, const char *aLabel, uint16_t &aCompressOffset) +{ + return AppendLabels(aSection, aLabel, kIsSingleLabel, aCompressOffset); +} + +Core::AppendOutcome Core::TxMessage::AppendMultipleLabels(Section aSection, + const char *aLabels, + uint16_t &aCompressOffset) +{ + return AppendLabels(aSection, aLabels, !kIsSingleLabel, aCompressOffset); +} + +Core::AppendOutcome Core::TxMessage::AppendLabels(Section aSection, + const char *aLabels, + bool aIsSingleLabel, + uint16_t &aCompressOffset) +{ + // Appends DNS name label(s) to the message in the specified section, + // using compression if possible. + // + // - If a valid `aCompressOffset` is given (indicating name was appended before) + // a compressed pointer label is used, and `kAppendedFullNameAsCompressed` + // is returned. + // - Otherwise, `aLabels` is appended, `aCompressOffset` is also updated for + // future compression, and `kAppendedLabels` is returned. + // + // `aIsSingleLabel` indicates that `aLabels` string should be appended + // as a single label. This is useful for service instance label which + // can itself contain the dot `.` character. + + AppendOutcome outcome = kAppendedLabels; + Message &message = SelectMessageFor(aSection); + + if (aCompressOffset != kUnspecifiedOffset) + { + SuccessOrAssert(Name::AppendPointerLabel(aCompressOffset, message)); + outcome = kAppendedFullNameAsCompressed; + ExitNow(); + } + + SaveOffset(aCompressOffset, message, aSection); + + if (aIsSingleLabel) + { + SuccessOrAssert(Name::AppendLabel(aLabels, message)); + } + else + { + SuccessOrAssert(Name::AppendMultipleLabels(aLabels, message)); + } + +exit: + return outcome; +} + +void Core::TxMessage::AppendServiceType(Section aSection, const char *aServiceType, uint16_t &aCompressOffset) +{ + // Appends DNS service type name to the message in the specified + // section, using compression if possible. + + const char *serviceLabels = aServiceType; + bool isUdp = false; + bool isTcp = false; + Name::Buffer labelsBuffer; + AppendOutcome outcome; + + if (Name::ExtractLabels(serviceLabels, kUdpServiceLabel, labelsBuffer) == kErrorNone) + { + isUdp = true; + serviceLabels = labelsBuffer; + } + else if (Name::ExtractLabels(serviceLabels, kTcpServiceLabel, labelsBuffer) == kErrorNone) + { + isTcp = true; + serviceLabels = labelsBuffer; + } + + outcome = AppendMultipleLabels(aSection, serviceLabels, aCompressOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + if (isUdp) + { + outcome = AppendLabel(aSection, kUdpServiceLabel, mUdpOffset); + } + else if (isTcp) + { + outcome = AppendLabel(aSection, kTcpServiceLabel, mTcpOffset); + } + + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + AppendDomainName(aSection); + +exit: + return; +} + +void Core::TxMessage::AppendDomainName(Section aSection) +{ + Message &message = SelectMessageFor(aSection); + + if (mDomainOffset != kUnspecifiedOffset) + { + SuccessOrAssert(Name::AppendPointerLabel(mDomainOffset, message)); + ExitNow(); + } + + SaveOffset(mDomainOffset, message, aSection); + SuccessOrAssert(Name::AppendName(kLocalDomain, message)); + +exit: + return; +} + +void Core::TxMessage::AppendServicesDnssdName(Section aSection) +{ + Message &message = SelectMessageFor(aSection); + + if (mServicesDnssdOffset != kUnspecifiedOffset) + { + SuccessOrAssert(Name::AppendPointerLabel(mServicesDnssdOffset, message)); + ExitNow(); + } + + SaveOffset(mServicesDnssdOffset, message, aSection); + SuccessOrAssert(Name::AppendMultipleLabels(kServicesDnssdLabels, message)); + AppendDomainName(aSection); + +exit: + return; +} + +void Core::TxMessage::SaveOffset(uint16_t &aCompressOffset, const Message &aMessage, Section aSection) +{ + // Saves the current message offset in `aCompressOffset` for name + // compression, but only when appending to the question or answer + // sections. + // + // This is necessary because other sections use separate message, + // and their offsets can shift when records are added to the main + // message. + // + // While current record types guarantee name inclusion in + // question/answer sections before their use in other sections, + // this check allows future extensions. + + switch (aSection) + { + case kQuestionSection: + case kAnswerSection: + aCompressOffset = aMessage.GetLength(); + break; + + case kAuthoritySection: + case kAdditionalDataSection: + break; + } +} + +bool Core::TxMessage::IsOverSizeLimit(void) const +{ + uint32_t size = mMsgPtr->GetLength(); + + if (!mExtraMsgPtr.IsNull()) + { + size += mExtraMsgPtr->GetLength(); + } + + return (size > Get().mMaxMessageSize); +} + +void Core::TxMessage::SaveCurrentState(void) +{ + mSavedRecordCounts = mRecordCounts; + mSavedMsgLength = mMsgPtr->GetLength(); + mSavedExtraMsgLength = mExtraMsgPtr.IsNull() ? 0 : mExtraMsgPtr->GetLength(); +} + +void Core::TxMessage::RestoreToSavedState(void) +{ + mRecordCounts = mSavedRecordCounts; + + IgnoreError(mMsgPtr->SetLength(mSavedMsgLength)); + + if (!mExtraMsgPtr.IsNull()) + { + IgnoreError(mExtraMsgPtr->SetLength(mSavedExtraMsgLength)); + } +} + +void Core::TxMessage::CheckSizeLimitToPrepareAgain(bool &aPrepareAgain) +{ + // Manages message size limits by re-preparing messages when + // necessary: + // - Checks if `TxMessage` exceeds the size limit. + // - If so, restores the `TxMessage` to its previously saved + // state, sends it, and re-initializes it which will also + // clear the "AppendState" of the related host and service + // entries to ensure correct re-processing. + // - Sets `aPrepareAgain` to `true` to signal that records should + // be prepared and added to the new message. + // + // We allow the `aPrepareAgain` to happen once. The very unlikely + // case where the `Entry` itself has so many records that its + // contents exceed the message size limit, is not handled, i.e. + // we always include all records of a single `Entry` within the same + // message. In future, the code can be updated to allow truncated + // messages. + + if (aPrepareAgain) + { + aPrepareAgain = false; + ExitNow(); + } + + VerifyOrExit(IsOverSizeLimit()); + + aPrepareAgain = true; + + RestoreToSavedState(); + Send(); + Reinit(); + +exit: + return; +} + +void Core::TxMessage::Send(void) +{ + static constexpr uint16_t kHeaderOffset = 0; + + Header header; + + VerifyOrExit(!mRecordCounts.IsEmpty()); + + SuccessOrAssert(mMsgPtr->Read(kHeaderOffset, header)); + mRecordCounts.WriteTo(header); + mMsgPtr->Write(kHeaderOffset, header); + + if (!mExtraMsgPtr.IsNull()) + { + SuccessOrAssert(mMsgPtr->AppendBytesFromMessage(*mExtraMsgPtr, 0, mExtraMsgPtr->GetLength())); + } + + Get().mTxMessageHistory.Add(*mMsgPtr); + + // We pass ownership of message to the platform layer. + + switch (mType) + { + case kMulticastProbe: + case kMulticastQuery: + case kMulticastResponse: + otPlatMdnsSendMulticast(&GetInstance(), mMsgPtr.Release(), Get().mInfraIfIndex); + break; + + case kUnicastResponse: + otPlatMdnsSendUnicast(&GetInstance(), mMsgPtr.Release(), &mUnicastDest); + break; + } + +exit: + return; +} + +void Core::TxMessage::Reinit(void) +{ + Init(GetType()); + + // After re-initializing `TxMessage`, we clear the "AppendState" + // on all related host and service entries, and service types + // or all cache entries (depending on the `GetType()`). + + switch (GetType()) + { + case kMulticastProbe: + case kMulticastResponse: + case kUnicastResponse: + for (HostEntry &entry : Get().mHostEntries) + { + if (ShouldClearAppendStateOnReinit(entry)) + { + entry.ClearAppendState(); + } + } + + for (ServiceEntry &entry : Get().mServiceEntries) + { + if (ShouldClearAppendStateOnReinit(entry)) + { + entry.ClearAppendState(); + } + } + + for (ServiceType &serviceType : Get().mServiceTypes) + { + if ((GetType() == kMulticastResponse) || (GetType() == kUnicastResponse)) + { + serviceType.ClearAppendState(); + } + } + + break; + + case kMulticastQuery: + + for (BrowseCache &browseCache : Get().mBrowseCacheList) + { + browseCache.ClearCompressOffsets(); + } + + for (SrvCache &srvCache : Get().mSrvCacheList) + { + srvCache.ClearCompressOffsets(); + } + + for (TxtCache &txtCache : Get().mTxtCacheList) + { + txtCache.ClearCompressOffsets(); + } + + // `Ip6AddrCache` entries do not track any append state or + // compress offset since the host name should not be used + // in any other query question. + + break; + } +} + +bool Core::TxMessage::ShouldClearAppendStateOnReinit(const Entry &aEntry) const +{ + // Determines whether we should clear "append state" on `aEntry` + // when re-initializing the `TxMessage`. If message is a probe, we + // check that entry is in `kProbing` state, if message is a + // unicast/multicast response, we check for `kRegistered` state. + + bool shouldClear = false; + + switch (aEntry.GetState()) + { + case Entry::kProbing: + shouldClear = (GetType() == kMulticastProbe); + break; + + case Entry::kRegistered: + shouldClear = (GetType() == kMulticastResponse) || (GetType() == kUnicastResponse); + break; + + case Entry::kConflict: + case Entry::kRemoving: + shouldClear = true; + break; + } + + return shouldClear; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::TimerContext + +Core::TimerContext::TimerContext(Instance &aInstance) + : InstanceLocator(aInstance) + , mNow(TimerMilli::GetNow()) + , mNextTime(mNow.GetDistantFuture()) +{ +} + +void Core::TimerContext::UpdateNextTime(TimeMilli aTime) +{ + if (aTime <= mNow) + { + mNextTime = mNow; + } + else + { + mNextTime = Min(mNextTime, aTime); + } +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::EntryTimerContext + +Core::EntryTimerContext::EntryTimerContext(Instance &aInstance) + : TimerContext(aInstance) + , mProbeMessage(aInstance, TxMessage::kMulticastProbe) + , mResponseMessage(aInstance, TxMessage::kMulticastResponse) +{ +} + +//---------------------------------------------------------------------------------------------------------------------- +// Core::RxMessage + +Error Core::RxMessage::Init(Instance &aInstance, + OwnedPtr &aMessagePtr, + bool aIsUnicast, + const AddressInfo &aSenderAddress) +{ + static const Section kSections[] = {kAnswerSection, kAuthoritySection, kAdditionalDataSection}; + + Error error = kErrorNone; + Header header; + uint16_t offset; + uint16_t numRecords; + + InstanceLocatorInit::Init(aInstance); + + mNext = nullptr; + + VerifyOrExit(!aMessagePtr.IsNull(), error = kErrorInvalidArgs); + + offset = aMessagePtr->GetOffset(); + + SuccessOrExit(error = aMessagePtr->Read(offset, header)); + offset += sizeof(Header); + + // RFC 6762 Section 18: Query type (OPCODE) must be zero + // (standard query). All other flags must be ignored. Messages + // with non-zero RCODE MUST be silently ignored. + + VerifyOrExit(header.GetQueryType() == Header::kQueryTypeStandard, error = kErrorParse); + VerifyOrExit(header.GetResponseCode() == Header::kResponseSuccess, error = kErrorParse); + + mIsQuery = (header.GetType() == Header::kTypeQuery); + mIsUnicast = aIsUnicast; + mTruncated = header.IsTruncationFlagSet(); + mSenderAddress = aSenderAddress; + + if (aSenderAddress.mPort != kUdpPort) + { + if (mIsQuery) + { + // Section 6.7 Legacy Unicast + LogInfo("We do not yet support legacy unicast message (source port not matching mDNS port)"); + ExitNow(error = kErrorNotCapable); + } + else + { + // The source port in a response MUST be mDNS port. + // Otherwise response message MUST be silently ignored. + + ExitNow(error = kErrorParse); + } + } + + if (mIsUnicast && mIsQuery) + { + // Direct Unicast Queries to Port 5353 (RFC 6762 - section 5.5). + // Responders SHOULD check that the source address in the query + // packet matches the local subnet for that link and silently ignore + // the packet if not. + + LogInfo("We do not yet support unicast query to mDNS port"); + ExitNow(error = kErrorNotCapable); + } + + mRecordCounts.ReadFrom(header); + + // Parse questions + + mStartOffset[kQuestionSection] = offset; + + SuccessOrAssert(mQuestions.ReserveCapacity(mRecordCounts.GetFor(kQuestionSection))); + + for (numRecords = mRecordCounts.GetFor(kQuestionSection); numRecords > 0; numRecords--) + { + Question *question = mQuestions.PushBack(); + ot::Dns::Question record; + uint16_t rrClass; + + OT_ASSERT(question != nullptr); + + question->mNameOffset = offset; + + SuccessOrExit(error = Name::ParseName(*aMessagePtr, offset)); + SuccessOrExit(error = aMessagePtr->Read(offset, record)); + offset += sizeof(record); + + question->mRrType = record.GetType(); + + rrClass = record.GetClass(); + question->mUnicastResponse = rrClass & kClassQuestionUnicastFlag; + question->mIsRrClassInternet = RrClassIsInternetOrAny(rrClass); + } + + // Parse and validate records in Answer, Authority and Additional + // Data sections. + + for (Section section : kSections) + { + mStartOffset[section] = offset; + SuccessOrExit(error = ResourceRecord::ParseRecords(*aMessagePtr, offset, mRecordCounts.GetFor(section))); + } + + // Determine which questions are probes by searching in the + // Authority section for records matching the question name. + + for (Question &question : mQuestions) + { + Name name(*aMessagePtr, question.mNameOffset); + + offset = mStartOffset[kAuthoritySection]; + numRecords = mRecordCounts.GetFor(kAuthoritySection); + + if (ResourceRecord::FindRecord(*aMessagePtr, offset, numRecords, name) == kErrorNone) + { + question.mIsProbe = true; + } + } + + mIsSelfOriginating = Get().mTxMessageHistory.Contains(*aMessagePtr); + + mMessagePtr = aMessagePtr.PassOwnership(); + +exit: + if (error != kErrorNone) + { + LogInfo("Failed to parse message from %s, error:%s", aSenderAddress.GetAddress().ToString().AsCString(), + ErrorToString(error)); + } + + return error; +} + +void Core::RxMessage::ClearProcessState(void) +{ + for (Question &question : mQuestions) + { + question.ClearProcessState(); + } +} + +Core::RxMessage::ProcessOutcome Core::RxMessage::ProcessQuery(bool aShouldProcessTruncated) +{ + ProcessOutcome outcome = kProcessed; + bool shouldDelay = false; + bool canAnswer = false; + bool needUnicastResponse = false; + TimeMilli answerTime; + + for (Question &question : mQuestions) + { + question.ClearProcessState(); + + ProcessQuestion(question); + + // Check if we can answer every question in the query and all + // answers are for unique records (where we own the name). This + // determines whether we need to add any random delay before + // responding. + + if (!question.mCanAnswer || !question.mIsUnique) + { + shouldDelay = true; + } + + if (question.mCanAnswer) + { + canAnswer = true; + + if (question.mUnicastResponse) + { + needUnicastResponse = true; + } + } + } + + VerifyOrExit(canAnswer); + + if (mTruncated && !aShouldProcessTruncated) + { + outcome = kSaveAsMultiPacket; + ExitNow(); + } + + answerTime = TimerMilli::GetNow(); + + if (shouldDelay) + { + answerTime += Random::NonCrypto::GetUint32InRange(kMinResponseDelay, kMaxResponseDelay); + } + + for (const Question &question : mQuestions) + { + AnswerQuestion(question, answerTime); + } + + if (needUnicastResponse) + { + SendUnicastResponse(mSenderAddress); + } + +exit: + return outcome; +} + +void Core::RxMessage::ProcessQuestion(Question &aQuestion) +{ + Name name(*mMessagePtr, aQuestion.mNameOffset); + + VerifyOrExit(aQuestion.mIsRrClassInternet); + + // Check if question name matches "_services._dns-sd._udp" (all services) + + if (name.Matches(/* aFirstLabel */ nullptr, kServicesDnssdLabels, kLocalDomain)) + { + VerifyOrExit(QuestionMatches(aQuestion.mRrType, ResourceRecord::kTypePtr)); + VerifyOrExit(!Get().mServiceTypes.IsEmpty()); + + aQuestion.mCanAnswer = true; + aQuestion.mIsForAllServicesDnssd = true; + + ExitNow(); + } + + // Check if question name matches a `HostEntry` or a `ServiceEntry` + + aQuestion.mEntry = Get().mHostEntries.FindMatching(name); + + if (aQuestion.mEntry == nullptr) + { + aQuestion.mEntry = Get().mServiceEntries.FindMatching(name); + aQuestion.mIsForService = (aQuestion.mEntry != nullptr); + } + + if (aQuestion.mEntry != nullptr) + { + switch (aQuestion.mEntry->GetState()) + { + case Entry::kProbing: + if (aQuestion.mIsProbe) + { + // Handling probe conflicts deviates from RFC 6762. + // We allow the conflict to happen and report it + // let the caller handle it. In future, TSR can + // help select the winner. + } + break; + + case Entry::kRegistered: + aQuestion.mCanAnswer = true; + aQuestion.mIsUnique = true; + break; + + case Entry::kConflict: + case Entry::kRemoving: + break; + } + } + else + { + // Check if question matches a service type or sub-type. We + // can answer PTR or ANY questions. There may be multiple + // service entries matching the question. We find and save + // the first match. `AnswerServiceTypeQuestion()` will start + // from the saved entry and finds all the other matches. + + bool isSubType; + Name::LabelBuffer subLabel; + Name baseType; + + VerifyOrExit(QuestionMatches(aQuestion.mRrType, ResourceRecord::kTypePtr)); + + isSubType = ParseQuestionNameAsSubType(aQuestion, subLabel, baseType); + + if (!isSubType) + { + baseType = name; + } + + for (ServiceEntry &serviceEntry : Get().mServiceEntries) + { + if ((serviceEntry.GetState() != Entry::kRegistered) || !serviceEntry.MatchesServiceType(baseType)) + { + continue; + } + + if (isSubType && !serviceEntry.CanAnswerSubType(subLabel)) + { + continue; + } + + aQuestion.mCanAnswer = true; + aQuestion.mEntry = &serviceEntry; + aQuestion.mIsForService = true; + aQuestion.mIsServiceType = true; + ExitNow(); + } + } + +exit: + return; +} + +void Core::RxMessage::AnswerQuestion(const Question &aQuestion, TimeMilli aAnswerTime) +{ + HostEntry *hostEntry; + ServiceEntry *serviceEntry; + AnswerInfo answerInfo; + + VerifyOrExit(aQuestion.mCanAnswer); + + answerInfo.mQuestionRrType = aQuestion.mRrType; + answerInfo.mAnswerTime = aAnswerTime; + answerInfo.mIsProbe = aQuestion.mIsProbe; + answerInfo.mUnicastResponse = aQuestion.mUnicastResponse; + + if (aQuestion.mIsForAllServicesDnssd) + { + AnswerAllServicesQuestion(aQuestion, answerInfo); + ExitNow(); + } + + hostEntry = aQuestion.mIsForService ? nullptr : static_cast(aQuestion.mEntry); + serviceEntry = aQuestion.mIsForService ? static_cast(aQuestion.mEntry) : nullptr; + + if (hostEntry != nullptr) + { + hostEntry->AnswerQuestion(answerInfo); + ExitNow(); + } + + // Question is for `ServiceEntry` + + if (!aQuestion.mIsServiceType) + { + serviceEntry->AnswerServiceNameQuestion(answerInfo); + } + else + { + AnswerServiceTypeQuestion(aQuestion, answerInfo, *serviceEntry); + } + +exit: + return; +} + +void Core::RxMessage::AnswerServiceTypeQuestion(const Question &aQuestion, + const AnswerInfo &aInfo, + ServiceEntry &aFirstEntry) +{ + Name serviceType(*mMessagePtr, aQuestion.mNameOffset); + Name baseType; + Name::LabelBuffer labelBuffer; + const char *subLabel; + + if (ParseQuestionNameAsSubType(aQuestion, labelBuffer, baseType)) + { + subLabel = labelBuffer; + } + else + { + baseType = serviceType; + subLabel = nullptr; + } + + for (ServiceEntry *serviceEntry = &aFirstEntry; serviceEntry != nullptr; serviceEntry = serviceEntry->GetNext()) + { + bool shouldSuppress = false; + + if ((serviceEntry->GetState() != Entry::kRegistered) || !serviceEntry->MatchesServiceType(baseType)) + { + continue; + } + + if ((subLabel != nullptr) && !serviceEntry->CanAnswerSubType(subLabel)) + { + continue; + } + + // Check for known-answer in this `RxMessage` and all its + // related messages in case it is multi-packet query. + + for (const RxMessage *rxMessage = this; rxMessage != nullptr; rxMessage = rxMessage->GetNext()) + { + if (rxMessage->ShouldSuppressKnownAnswer(serviceType, subLabel, *serviceEntry)) + { + shouldSuppress = true; + break; + } + } + + if (!shouldSuppress) + { + serviceEntry->AnswerServiceTypeQuestion(aInfo, subLabel); + } + } +} + +bool Core::RxMessage::ShouldSuppressKnownAnswer(const Name &aServiceType, + const char *aSubLabel, + const ServiceEntry &aServiceEntry) const +{ + bool shouldSuppress = false; + uint16_t offset = mStartOffset[kAnswerSection]; + uint16_t numRecords = mRecordCounts.GetFor(kAnswerSection); + + while (ResourceRecord::FindRecord(*mMessagePtr, offset, numRecords, aServiceType) == kErrorNone) + { + Error error; + PtrRecord ptr; + + error = ResourceRecord::ReadRecord(*mMessagePtr, offset, ptr); + + if (error == kErrorNotFound) + { + // `ReadRecord()` will update the `offset` to skip over + // the entire record if it does not match the expected + // record type (PTR in this case). + continue; + } + + SuccessOrExit(error); + + // `offset` is now pointing to PTR name + + if (aServiceEntry.Matches(Name(*mMessagePtr, offset))) + { + shouldSuppress = aServiceEntry.ShouldSuppressKnownAnswer(ptr.GetTtl(), aSubLabel); + ExitNow(); + } + + // Parse the name and skip over it and update `offset` + // to the start of the next record. + + SuccessOrExit(Name::ParseName(*mMessagePtr, offset)); + } + +exit: + return shouldSuppress; +} + +bool Core::RxMessage::ParseQuestionNameAsSubType(const Question &aQuestion, + Name::LabelBuffer &aSubLabel, + Name &aServiceType) const +{ + bool isSubType = false; + uint16_t offset = aQuestion.mNameOffset; + uint8_t length = sizeof(aSubLabel); + + SuccessOrExit(Name::ReadLabel(*mMessagePtr, offset, aSubLabel, length)); + SuccessOrExit(Name::CompareLabel(*mMessagePtr, offset, kSubServiceLabel)); + aServiceType.SetFromMessage(*mMessagePtr, offset); + isSubType = true; + +exit: + return isSubType; +} + +void Core::RxMessage::AnswerAllServicesQuestion(const Question &aQuestion, const AnswerInfo &aInfo) +{ + for (ServiceType &serviceType : Get().mServiceTypes) + { + bool shouldSuppress = false; + + // Check for known-answer in this `RxMessage` and all its + // related messages in case it is multi-packet query. + + for (const RxMessage *rxMessage = this; rxMessage != nullptr; rxMessage = rxMessage->GetNext()) + { + if (rxMessage->ShouldSuppressKnownAnswer(aQuestion, serviceType)) + { + shouldSuppress = true; + break; + } + } + + if (!shouldSuppress) + { + serviceType.AnswerQuestion(aInfo); + } + } +} + +bool Core::RxMessage::ShouldSuppressKnownAnswer(const Question &aQuestion, const ServiceType &aServiceType) const +{ + // Check answer section to determine whether to suppress answering + // to "_services._dns-sd._udp" query with `aServiceType` + + bool shouldSuppress = false; + uint16_t offset = mStartOffset[kAnswerSection]; + uint16_t numRecords = mRecordCounts.GetFor(kAnswerSection); + Name name(*mMessagePtr, aQuestion.mNameOffset); + + while (ResourceRecord::FindRecord(*mMessagePtr, offset, numRecords, name) == kErrorNone) + { + Error error; + PtrRecord ptr; + + error = ResourceRecord::ReadRecord(*mMessagePtr, offset, ptr); + + if (error == kErrorNotFound) + { + // `ReadRecord()` will update the `offset` to skip over + // the entire record if it does not match the expected + // record type (PTR in this case). + continue; + } + + SuccessOrExit(error); + + // `offset` is now pointing to PTR name + + if (aServiceType.Matches(Name(*mMessagePtr, offset))) + { + shouldSuppress = aServiceType.ShouldSuppressKnownAnswer(ptr.GetTtl()); + ExitNow(); + } + + // Parse the name and skip over it and update `offset` + // to the start of the next record. + + SuccessOrExit(Name::ParseName(*mMessagePtr, offset)); + } + +exit: + return shouldSuppress; +} + +void Core::RxMessage::SendUnicastResponse(const AddressInfo &aUnicastDest) +{ + TxMessage response(GetInstance(), TxMessage::kUnicastResponse, aUnicastDest); + TimeMilli now = TimerMilli::GetNow(); + + for (HostEntry &entry : Get().mHostEntries) + { + entry.ClearAppendState(); + entry.PrepareResponse(response, now); + } + + for (ServiceEntry &entry : Get().mServiceEntries) + { + entry.ClearAppendState(); + entry.PrepareResponse(response, now); + } + + for (ServiceType &serviceType : Get().mServiceTypes) + { + serviceType.ClearAppendState(); + serviceType.PrepareResponse(response, now); + } + + response.Send(); +} + +void Core::RxMessage::ProcessResponse(void) +{ + if (!IsSelfOriginating()) + { + IterateOnAllRecordsInResponse(&RxMessage::ProcessRecordForConflict); + } + + // We process record types in a specific order to ensure correct + // passive cache creation: First PTR records are processed, which + // may create passive SRV/TXT cache entries for discovered + // services. Next SRV records are processed which may create TXT + // cache entries for service names and IPv6 address cache entries + // for associated host name. + + if (!Get().mBrowseCacheList.IsEmpty()) + { + IterateOnAllRecordsInResponse(&RxMessage::ProcessPtrRecord); + } + + if (!Get().mSrvCacheList.IsEmpty()) + { + IterateOnAllRecordsInResponse(&RxMessage::ProcessSrvRecord); + } + + if (!Get().mTxtCacheList.IsEmpty()) + { + IterateOnAllRecordsInResponse(&RxMessage::ProcessTxtRecord); + } + + if (!Get().mIp6AddrCacheList.IsEmpty()) + { + IterateOnAllRecordsInResponse(&RxMessage::ProcessAaaaRecord); + + for (Ip6AddrCache &addrCache : Get().mIp6AddrCacheList) + { + addrCache.CommitNewResponseEntries(); + } + } + + if (!Get().mIp4AddrCacheList.IsEmpty()) + { + IterateOnAllRecordsInResponse(&RxMessage::ProcessARecord); + + for (Ip4AddrCache &addrCache : Get().mIp4AddrCacheList) + { + addrCache.CommitNewResponseEntries(); + } + } +} + +void Core::RxMessage::IterateOnAllRecordsInResponse(RecordProcessor aRecordProcessor) +{ + // Iterates over all records in the response, calling + // `aRecordProcessor` for each. + + static const Section kSections[] = {kAnswerSection, kAdditionalDataSection}; + + for (Section section : kSections) + { + uint16_t offset = mStartOffset[section]; + + for (uint16_t numRecords = mRecordCounts.GetFor(section); numRecords > 0; numRecords--) + { + Name name(*mMessagePtr, offset); + ResourceRecord record; + + IgnoreError(Name::ParseName(*mMessagePtr, offset)); + IgnoreError(mMessagePtr->Read(offset, record)); + + if (!RrClassIsInternetOrAny(record.GetClass())) + { + continue; + } + + (this->*aRecordProcessor)(name, record, offset); + + offset += static_cast(record.GetSize()); + } + } +} + +void Core::RxMessage::ProcessRecordForConflict(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset) +{ + HostEntry *hostEntry; + ServiceEntry *serviceEntry; + + VerifyOrExit(aRecord.GetTtl() > 0); + + hostEntry = Get().mHostEntries.FindMatching(aName); + + if (hostEntry != nullptr) + { + hostEntry->HandleConflict(); + } + + serviceEntry = Get().mServiceEntries.FindMatching(aName); + + if (serviceEntry != nullptr) + { + serviceEntry->HandleConflict(); + } + +exit: + OT_UNUSED_VARIABLE(aRecordOffset); +} + +void Core::RxMessage::ProcessPtrRecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset) +{ + BrowseCache *browseCache; + + VerifyOrExit(aRecord.GetType() == ResourceRecord::kTypePtr); + + browseCache = Get().mBrowseCacheList.FindMatching(aName); + VerifyOrExit(browseCache != nullptr); + + browseCache->ProcessResponseRecord(*mMessagePtr, aRecordOffset); + +exit: + return; +} + +void Core::RxMessage::ProcessSrvRecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset) +{ + SrvCache *srvCache; + + VerifyOrExit(aRecord.GetType() == ResourceRecord::kTypeSrv); + + srvCache = Get().mSrvCacheList.FindMatching(aName); + VerifyOrExit(srvCache != nullptr); + + srvCache->ProcessResponseRecord(*mMessagePtr, aRecordOffset); + +exit: + return; +} + +void Core::RxMessage::ProcessTxtRecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset) +{ + TxtCache *txtCache; + + VerifyOrExit(aRecord.GetType() == ResourceRecord::kTypeTxt); + + txtCache = Get().mTxtCacheList.FindMatching(aName); + VerifyOrExit(txtCache != nullptr); + + txtCache->ProcessResponseRecord(*mMessagePtr, aRecordOffset); + +exit: + return; +} + +void Core::RxMessage::ProcessAaaaRecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset) +{ + Ip6AddrCache *ip6AddrCache; + + VerifyOrExit(aRecord.GetType() == ResourceRecord::kTypeAaaa); + + ip6AddrCache = Get().mIp6AddrCacheList.FindMatching(aName); + VerifyOrExit(ip6AddrCache != nullptr); + + ip6AddrCache->ProcessResponseRecord(*mMessagePtr, aRecordOffset); + +exit: + return; +} + +void Core::RxMessage::ProcessARecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset) +{ + Ip4AddrCache *ip4AddrCache; + + VerifyOrExit(aRecord.GetType() == ResourceRecord::kTypeA); + + ip4AddrCache = Get().mIp4AddrCacheList.FindMatching(aName); + VerifyOrExit(ip4AddrCache != nullptr); + + ip4AddrCache->ProcessResponseRecord(*mMessagePtr, aRecordOffset); + +exit: + return; +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::RxMessage::Question + +void Core::RxMessage::Question::ClearProcessState(void) +{ + mCanAnswer = false; + mIsUnique = false; + mIsForService = false; + mIsServiceType = false; + mIsForAllServicesDnssd = false; + mEntry = nullptr; +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::MultiPacketRxMessages + +Core::MultiPacketRxMessages::MultiPacketRxMessages(Instance &aInstance) + : InstanceLocator(aInstance) + , mTimer(aInstance) +{ +} + +void Core::MultiPacketRxMessages::AddToExisting(OwnedPtr &aRxMessagePtr) +{ + RxMsgEntry *msgEntry = mRxMsgEntries.FindMatching(aRxMessagePtr->GetSenderAddress()); + + VerifyOrExit(msgEntry != nullptr); + msgEntry->Add(aRxMessagePtr); + +exit: + return; +} + +void Core::MultiPacketRxMessages::AddNew(OwnedPtr &aRxMessagePtr) +{ + RxMsgEntry *newEntry = RxMsgEntry::Allocate(GetInstance()); + + OT_ASSERT(newEntry != nullptr); + newEntry->Add(aRxMessagePtr); + + // First remove an existing entries matching same sender + // before adding the new entry to the list. + + mRxMsgEntries.RemoveMatching(aRxMessagePtr->GetSenderAddress()); + mRxMsgEntries.Push(*newEntry); +} + +void Core::MultiPacketRxMessages::HandleTimer(void) +{ + TimeMilli now = TimerMilli::GetNow(); + TimeMilli nextTime = now.GetDistantFuture(); + OwningList expiredEntries; + + mRxMsgEntries.RemoveAllMatching(ExpireChecker(now), expiredEntries); + + for (RxMsgEntry &expiredEntry : expiredEntries) + { + expiredEntry.mRxMessages.GetHead()->ProcessQuery(/* aShouldProcessTruncated */ true); + } + + for (const RxMsgEntry &msgEntry : mRxMsgEntries) + { + nextTime = Min(nextTime, msgEntry.mProcessTime); + } + + if (nextTime != now.GetDistantFuture()) + { + mTimer.FireAtIfEarlier(nextTime); + } +} + +void Core::MultiPacketRxMessages::Clear(void) +{ + mTimer.Stop(); + mRxMsgEntries.Clear(); +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::MultiPacketRxMessage::RxMsgEntry + +Core::MultiPacketRxMessages::RxMsgEntry::RxMsgEntry(Instance &aInstance) + : InstanceLocator(aInstance) + , mNext(nullptr) +{ +} + +bool Core::MultiPacketRxMessages::RxMsgEntry::Matches(const AddressInfo &aAddress) const +{ + bool matches = false; + + VerifyOrExit(!mRxMessages.IsEmpty()); + matches = (mRxMessages.GetHead()->GetSenderAddress() == aAddress); + +exit: + return matches; +} + +bool Core::MultiPacketRxMessages::RxMsgEntry::Matches(const ExpireChecker &aExpireChecker) const +{ + return (mProcessTime <= aExpireChecker.mNow); +} + +void Core::MultiPacketRxMessages::RxMsgEntry::Add(OwnedPtr &aRxMessagePtr) +{ + uint16_t numMsgs = 0; + + for (const RxMessage &rxMsg : mRxMessages) + { + // If a subsequent received `RxMessage` is also marked as + // truncated, we again delay the process time. To avoid + // continuous delay and piling up of messages in the list, + // we limit the number of messages. + + numMsgs++; + VerifyOrExit(numMsgs < kMaxNumMessages); + + OT_UNUSED_VARIABLE(rxMsg); + } + + mProcessTime = TimerMilli::GetNow(); + + if (aRxMessagePtr->IsTruncated()) + { + mProcessTime += Random::NonCrypto::GetUint32InRange(kMinProcessDelay, kMaxProcessDelay); + } + + // We push the new `RxMessage` at tail of the list to keep the + // first query containing questions at the head of the list. + + mRxMessages.PushAfterTail(*aRxMessagePtr.Release()); + + Get().mMultiPacketRxMessages.mTimer.FireAtIfEarlier(mProcessTime); + +exit: + return; +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::TxMessageHistory + +Core::TxMessageHistory::TxMessageHistory(Instance &aInstance) + : InstanceLocator(aInstance) + , mTimer(aInstance) +{ +} + +void Core::TxMessageHistory::Clear(void) +{ + mHashEntries.Clear(); + mTimer.Stop(); +} + +void Core::TxMessageHistory::Add(const Message &aMessage) +{ + Hash hash; + HashEntry *entry; + + CalculateHash(aMessage, hash); + + entry = mHashEntries.FindMatching(hash); + + if (entry == nullptr) + { + entry = HashEntry::Allocate(); + OT_ASSERT(entry != nullptr); + entry->mHash = hash; + mHashEntries.Push(*entry); + } + + entry->mExpireTime = TimerMilli::GetNow() + kExpireInterval; + mTimer.FireAtIfEarlier(entry->mExpireTime); +} + +bool Core::TxMessageHistory::Contains(const Message &aMessage) const +{ + Hash hash; + + CalculateHash(aMessage, hash); + return mHashEntries.ContainsMatching(hash); +} + +void Core::TxMessageHistory::CalculateHash(const Message &aMessage, Hash &aHash) +{ + Crypto::Sha256 sha256; + + sha256.Start(); + sha256.Update(aMessage, /* aOffset */ 0, aMessage.GetLength()); + sha256.Finish(aHash); +} + +void Core::TxMessageHistory::HandleTimer(void) +{ + TimeMilli now = TimerMilli::GetNow(); + TimeMilli nextTime = now.GetDistantFuture(); + OwningList expiredEntries; + + mHashEntries.RemoveAllMatching(ExpireChecker(now), expiredEntries); + + for (const HashEntry &entry : mHashEntries) + { + nextTime = Min(nextTime, entry.mExpireTime); + } + + if (nextTime != now.GetDistantFuture()) + { + mTimer.FireAtIfEarlier(nextTime); + } +} + +template +Error Core::Start(const BrowserResolverType &aBrowserOrResolver) +{ + Error error = kErrorNone; + CacheType *cacheEntry; + + VerifyOrExit(mIsEnabled, error = kErrorInvalidState); + VerifyOrExit(aBrowserOrResolver.mCallback != nullptr, error = kErrorInvalidArgs); + + cacheEntry = GetCacheList().FindMatching(aBrowserOrResolver); + + if (cacheEntry == nullptr) + { + cacheEntry = CacheType::AllocateAndInit(GetInstance(), aBrowserOrResolver); + OT_ASSERT(cacheEntry != nullptr); + + GetCacheList().Push(*cacheEntry); + } + + error = cacheEntry->Add(aBrowserOrResolver); + +exit: + return error; +} + +template +Error Core::Stop(const BrowserResolverType &aBrowserOrResolver) +{ + Error error = kErrorNone; + CacheType *cacheEntry; + + VerifyOrExit(mIsEnabled, error = kErrorInvalidState); + VerifyOrExit(aBrowserOrResolver.mCallback != nullptr, error = kErrorInvalidArgs); + + cacheEntry = GetCacheList().FindMatching(aBrowserOrResolver); + VerifyOrExit(cacheEntry != nullptr); + + cacheEntry->Remove(aBrowserOrResolver); + +exit: + return error; +} + +Error Core::StartBrowser(const Browser &aBrowser) { return Start(aBrowser); } + +Error Core::StopBrowser(const Browser &aBrowser) { return Stop(aBrowser); } + +Error Core::StartSrvResolver(const SrvResolver &aResolver) { return Start(aResolver); } + +Error Core::StopSrvResolver(const SrvResolver &aResolver) { return Stop(aResolver); } + +Error Core::StartTxtResolver(const TxtResolver &aResolver) { return Start(aResolver); } + +Error Core::StopTxtResolver(const TxtResolver &aResolver) { return Stop(aResolver); } + +Error Core::StartIp6AddressResolver(const AddressResolver &aResolver) +{ + return Start(aResolver); +} + +Error Core::StopIp6AddressResolver(const AddressResolver &aResolver) +{ + return Stop(aResolver); +} + +Error Core::StartIp4AddressResolver(const AddressResolver &aResolver) +{ + return Start(aResolver); +} + +Error Core::StopIp4AddressResolver(const AddressResolver &aResolver) +{ + return Stop(aResolver); +} + +void Core::AddPassiveSrvTxtCache(const char *aServiceInstance, const char *aServiceType) +{ + ServiceName serviceName(aServiceInstance, aServiceType); + + if (!mSrvCacheList.ContainsMatching(serviceName)) + { + SrvCache *srvCache = SrvCache::AllocateAndInit(GetInstance(), serviceName); + + OT_ASSERT(srvCache != nullptr); + mSrvCacheList.Push(*srvCache); + } + + if (!mTxtCacheList.ContainsMatching(serviceName)) + { + TxtCache *txtCache = TxtCache::AllocateAndInit(GetInstance(), serviceName); + + OT_ASSERT(txtCache != nullptr); + mTxtCacheList.Push(*txtCache); + } +} + +void Core::AddPassiveIp6AddrCache(const char *aHostName) +{ + if (!mIp6AddrCacheList.ContainsMatching(aHostName)) + { + Ip6AddrCache *ip6AddrCache = Ip6AddrCache::AllocateAndInit(GetInstance(), aHostName); + + OT_ASSERT(ip6AddrCache != nullptr); + mIp6AddrCacheList.Push(*ip6AddrCache); + } +} + +void Core::HandleCacheTimer(void) +{ + CacheTimerContext context(GetInstance()); + ExpireChecker expireChecker(context.GetNow()); + OwningList expiredBrowseList; + OwningList expiredSrvList; + OwningList expiredTxtList; + OwningList expiredIp6AddrList; + OwningList expiredIp4AddrList; + + // First remove all expired entries. + + mBrowseCacheList.RemoveAllMatching(expireChecker, expiredBrowseList); + mSrvCacheList.RemoveAllMatching(expireChecker, expiredSrvList); + mTxtCacheList.RemoveAllMatching(expireChecker, expiredTxtList); + mIp6AddrCacheList.RemoveAllMatching(expireChecker, expiredIp6AddrList); + mIp4AddrCacheList.RemoveAllMatching(expireChecker, expiredIp4AddrList); + + // Process cache types in a specific order to optimize name + // compression when constructing query messages. + + for (SrvCache &srvCache : mSrvCacheList) + { + srvCache.HandleTimer(context); + } + + for (TxtCache &txtCache : mTxtCacheList) + { + txtCache.HandleTimer(context); + } + + for (BrowseCache &browseCache : mBrowseCacheList) + { + browseCache.HandleTimer(context); + } + + for (Ip6AddrCache &addrCache : mIp6AddrCacheList) + { + addrCache.HandleTimer(context); + } + + for (Ip4AddrCache &addrCache : mIp4AddrCacheList) + { + addrCache.HandleTimer(context); + } + + context.GetQueryMessage().Send(); + + if (context.GetNextTime() != context.GetNow().GetDistantFuture()) + { + mCacheTimer.FireAtIfEarlier(context.GetNextTime()); + } +} + +void Core::HandleCacheTask(void) +{ + // `CacheTask` is used to remove empty/null callbacks + // from cache entries. and also removing "passive" + // cache entries that timed out. + + for (BrowseCache &browseCache : mBrowseCacheList) + { + browseCache.ClearEmptyCallbacks(); + } + + for (SrvCache &srvCache : mSrvCacheList) + { + srvCache.ClearEmptyCallbacks(); + } + + for (TxtCache &txtCache : mTxtCacheList) + { + txtCache.ClearEmptyCallbacks(); + } + + for (Ip6AddrCache &addrCache : mIp6AddrCacheList) + { + addrCache.ClearEmptyCallbacks(); + } + + for (Ip4AddrCache &addrCache : mIp4AddrCacheList) + { + addrCache.ClearEmptyCallbacks(); + } +} + +TimeMilli Core::RandomizeFirstProbeTxTime(void) +{ + // Randomizes the transmission time of the first probe, adding a + // delay between 20-250 msec. Subsequent probes within a short + // window reuse the same delay for efficient aggregation. + + TimeMilli now = TimerMilli::GetNow(); + + // The comparison using `(mNextProbeTxTime - now)` will work + // correctly even in the unlikely case that `now` has wrapped + // (49 days has passed) since `mNextProbeTxTime` was last set. + + if ((mNextProbeTxTime - now) >= kMaxProbeDelay) + { + mNextProbeTxTime = now + Random::NonCrypto::GetUint32InRange(kMinProbeDelay, kMaxProbeDelay); + } + + return mNextProbeTxTime; +} + +TimeMilli Core::RandomizeInitialQueryTxTime(void) +{ + TimeMilli now = TimerMilli::GetNow(); + + if ((mNextQueryTxTime - now) >= kMaxInitialQueryDelay) + { + mNextQueryTxTime = now + Random::NonCrypto::GetUint32InRange(kMinInitialQueryDelay, kMaxInitialQueryDelay); + } + + return mNextQueryTxTime; +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::ResultCallback + +void Core::ResultCallback::Invoke(Instance &aInstance, const BrowseResult &aResult) const +{ + if (mSharedCallback.mBrowse != nullptr) + { + mSharedCallback.mBrowse(&aInstance, &aResult); + } +} + +void Core::ResultCallback::Invoke(Instance &aInstance, const SrvResult &aResult) const +{ + if (mSharedCallback.mSrv != nullptr) + { + mSharedCallback.mSrv(&aInstance, &aResult); + } +} + +void Core::ResultCallback::Invoke(Instance &aInstance, const TxtResult &aResult) const +{ + if (mSharedCallback.mTxt != nullptr) + { + mSharedCallback.mTxt(&aInstance, &aResult); + } +} + +void Core::ResultCallback::Invoke(Instance &aInstance, const AddressResult &aResult) const +{ + if (mSharedCallback.mAddress != nullptr) + { + mSharedCallback.mAddress(&aInstance, &aResult); + } +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::CacheTimerContext + +Core::CacheTimerContext::CacheTimerContext(Instance &aInstance) + : TimerContext(aInstance) + , mQueryMessage(aInstance, TxMessage::kMulticastQuery) +{ +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::CacheRecordInfo + +Core::CacheRecordInfo::CacheRecordInfo(void) + : mTtl(0) + , mQueryCount(0) +{ +} + +bool Core::CacheRecordInfo::RefreshTtl(uint32_t aTtl) +{ + // Called when cached record is refreshed. + // Returns a boolean to indicate if TTL value + // was changed or not. + + bool changed = (aTtl != mTtl); + + mLastRxTime = TimerMilli::GetNow(); + mTtl = aTtl; + mQueryCount = 0; + + return changed; +} + +bool Core::CacheRecordInfo::ShouldExpire(TimeMilli aNow) const { return IsPresent() && (GetExpireTime() <= aNow); } + +void Core::CacheRecordInfo::UpdateStateAfterQuery(TimeMilli aNow) +{ + VerifyOrExit(IsPresent()); + + // If the less than half TTL remains, then this record would not + // be included as "Known-Answer" in the send query, so we can + // count it towards queries to refresh this record. + + VerifyOrExit(LessThanHalfTtlRemains(aNow)); + + if (mQueryCount < kNumberOfQueries) + { + mQueryCount++; + } + +exit: + return; +} + +void Core::CacheRecordInfo::UpdateQueryAndFireTimeOn(CacheEntry &aCacheEntry) +{ + TimeMilli now; + TimeMilli expireTime; + + VerifyOrExit(IsPresent()); + + now = TimerMilli::GetNow(); + expireTime = GetExpireTime(); + + aCacheEntry.SetFireTime(expireTime); + + // Determine next query time + + for (uint8_t attemptIndex = mQueryCount; attemptIndex < kNumberOfQueries; attemptIndex++) + { + TimeMilli queryTime = GetQueryTime(attemptIndex); + + if (queryTime > now) + { + queryTime += Random::NonCrypto::GetUint32InRange(0, GetClampedTtl() * kQueryTtlVariation); + aCacheEntry.ScheduleQuery(queryTime); + break; + } + } + +exit: + return; +} + +bool Core::CacheRecordInfo::LessThanHalfTtlRemains(TimeMilli aNow) const +{ + return IsPresent() && ((aNow - mLastRxTime) > TimeMilli::SecToMsec(GetClampedTtl()) / 2); +} + +uint32_t Core::CacheRecordInfo::GetRemainingTtl(TimeMilli aNow) const +{ + uint32_t remainingTtl = 0; + TimeMilli expireTime; + + VerifyOrExit(IsPresent()); + + expireTime = GetExpireTime(); + VerifyOrExit(aNow < expireTime); + + remainingTtl = TimeMilli::MsecToSec(expireTime - aNow); + +exit: + return remainingTtl; +} + +uint32_t Core::CacheRecordInfo::GetClampedTtl(void) const +{ + // We clamp TTL to `kMaxTtl` (one day) to prevent `TimeMilli` + // calculation overflow. + + return Min(mTtl, kMaxTtl); +} + +TimeMilli Core::CacheRecordInfo::GetExpireTime(void) const +{ + return mLastRxTime + TimeMilli::SecToMsec(GetClampedTtl()); +} + +TimeMilli Core::CacheRecordInfo::GetQueryTime(uint8_t aAttemptIndex) const +{ + // Queries are sent at 80%, 85%, 90% and 95% of TTL plus a random + // variation of 2% (added when sceduling) + + static const uint32_t kTtlFactors[kNumberOfQueries] = { + 80 * 1000 / 100, + 85 * 1000 / 100, + 90 * 1000 / 100, + 95 * 1000 / 100, + }; + + OT_ASSERT(aAttemptIndex < kNumberOfQueries); + + return mLastRxTime + kTtlFactors[aAttemptIndex] * GetClampedTtl(); +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::CacheEntry + +void Core::CacheEntry::Init(Instance &aInstance, Type aType) +{ + InstanceLocatorInit::Init(aInstance); + + mType = aType; + mInitalQueries = 0; + mQueryPending = false; + mLastQueryTimeValid = false; + mIsActive = false; + mDeleteTime = TimerMilli::GetNow() + kNonActiveDeleteTimeout; +} + +void Core::CacheEntry::SetIsActive(bool aIsActive) +{ + // Sets the active/passive state of a cache entry. An entry is + // considered "active" when associated with at least one + // resolver/browser. "Passive" entries (without a resolver/browser) + // continue to process mDNS responses for updates but will not send + // queries. Passive entries are deleted after `kNonActiveDeleteTimeout` + // if no resolver/browser is added. + + mIsActive = aIsActive; + + if (!mIsActive) + { + mQueryPending = false; + mDeleteTime = TimerMilli::GetNow() + kNonActiveDeleteTimeout; + SetFireTime(mDeleteTime); + } +} + +bool Core::CacheEntry::ShouldDelete(TimeMilli aNow) const { return !mIsActive && (mDeleteTime <= aNow); } + +void Core::CacheEntry::StartInitialQueries(void) +{ + mInitalQueries = 0; + mLastQueryTimeValid = false; + mLastQueryTime = Get().RandomizeInitialQueryTxTime(); + + ScheduleQuery(mLastQueryTime); +} + +bool Core::CacheEntry::ShouldQuery(TimeMilli aNow) { return mQueryPending && (mNextQueryTime <= aNow); } + +void Core::CacheEntry::ScheduleQuery(TimeMilli aQueryTime) +{ + VerifyOrExit(mIsActive); + + if (mQueryPending) + { + VerifyOrExit(aQueryTime < mNextQueryTime); + } + + if (mLastQueryTimeValid) + { + aQueryTime = Max(aQueryTime, mLastQueryTime + kMinIntervalBetweenQueries); + } + + mQueryPending = true; + mNextQueryTime = aQueryTime; + SetFireTime(mNextQueryTime); + +exit: + return; +} + +Error Core::CacheEntry::Add(const ResultCallback &aCallback) +{ + Error error = kErrorNone; + bool isFirst; + ResultCallback *callback; + + callback = FindCallbackMatching(aCallback); + VerifyOrExit(callback == nullptr, error = kErrorAlready); + + isFirst = mCallbacks.IsEmpty(); + + callback = ResultCallback::Allocate(aCallback); + OT_ASSERT(callback != nullptr); + + mCallbacks.Push(*callback); + + // If this is the first active resolver/browser for this cache + // entry, we mark it as active which allows queries to be sent We + // decide whether or not to send initial queries (e.g., for + // SRV/TXT cache entries we send initial queries if there is no + // record, or less than half TTL remains). + + if (isFirst) + { + bool shouldStart = false; + + SetIsActive(true); + + switch (mType) + { + case kBrowseCache: + shouldStart = true; + break; + case kSrvCache: + case kTxtCache: + shouldStart = As().ShouldStartInitialQueries(); + break; + case kIp6AddrCache: + case kIp4AddrCache: + shouldStart = As().ShouldStartInitialQueries(); + break; + } + + if (shouldStart) + { + StartInitialQueries(); + } + + DetermineNextFireTime(); + ScheduleTimer(); + } + + // Report any discovered/cached result to the newly added + // callback. + + switch (mType) + { + case kBrowseCache: + As().ReportResultsTo(*callback); + break; + case kSrvCache: + As().ReportResultTo(*callback); + break; + case kTxtCache: + As().ReportResultTo(*callback); + break; + case kIp6AddrCache: + case kIp4AddrCache: + As().ReportResultsTo(*callback); + break; + } + +exit: + return error; +} + +void Core::CacheEntry::Remove(const ResultCallback &aCallback) +{ + ResultCallback *callback = FindCallbackMatching(aCallback); + + VerifyOrExit(callback != nullptr); + + // We clear the callback setting it to `nullptr` without removing + // it from the list here, since the `Remove()` method may be + // called from a callback while we are iterating over the list. + // Removal from the list is deferred to `mCacheTask` which will + // later call `ClearEmptyCallbacks()`. + + callback->ClearCallback(); + Get().mCacheTask.Post(); + +exit: + return; +} + +void Core::CacheEntry::ClearEmptyCallbacks(void) +{ + CallbackList emptyCallbacks; + + mCallbacks.RemoveAllMatching(EmptyChecker(), emptyCallbacks); + + if (mCallbacks.IsEmpty()) + { + SetIsActive(false); + DetermineNextFireTime(); + ScheduleTimer(); + } +} + +void Core::CacheEntry::HandleTimer(CacheTimerContext &aContext) +{ + switch (mType) + { + case kBrowseCache: + As().ClearCompressOffsets(); + break; + + case kSrvCache: + case kTxtCache: + As().ClearCompressOffsets(); + break; + + case kIp6AddrCache: + case kIp4AddrCache: + // `AddrCache` entries do not track any append state or + // compress offset since the host name would not be used + // in any other query question. + break; + } + + VerifyOrExit(HasFireTime()); + VerifyOrExit(GetFireTime() <= aContext.GetNow()); + ClearFireTime(); + + if (ShouldDelete(aContext.GetNow())) + { + ExitNow(); + } + + if (ShouldQuery(aContext.GetNow())) + { + mQueryPending = false; + PrepareQuery(aContext); + } + + switch (mType) + { + case kBrowseCache: + As().ProcessExpiredRecords(aContext.GetNow()); + break; + case kSrvCache: + As().ProcessExpiredRecords(aContext.GetNow()); + break; + case kTxtCache: + As().ProcessExpiredRecords(aContext.GetNow()); + break; + case kIp6AddrCache: + case kIp4AddrCache: + As().ProcessExpiredRecords(aContext.GetNow()); + break; + } + + DetermineNextFireTime(); + +exit: + if (HasFireTime()) + { + aContext.UpdateNextTime(GetFireTime()); + } +} + +Core::ResultCallback *Core::CacheEntry::FindCallbackMatching(const ResultCallback &aCallback) +{ + ResultCallback *callback = nullptr; + + switch (mType) + { + case kBrowseCache: + callback = mCallbacks.FindMatching(aCallback.mSharedCallback.mBrowse); + break; + case kSrvCache: + callback = mCallbacks.FindMatching(aCallback.mSharedCallback.mSrv); + break; + case kTxtCache: + callback = mCallbacks.FindMatching(aCallback.mSharedCallback.mTxt); + break; + case kIp6AddrCache: + case kIp4AddrCache: + callback = mCallbacks.FindMatching(aCallback.mSharedCallback.mAddress); + break; + } + + return callback; +} + +void Core::CacheEntry::DetermineNextFireTime(void) +{ + mQueryPending = false; + + if (mInitalQueries < kNumberOfInitalQueries) + { + uint32_t interval = (mInitalQueries == 0) ? 0 : (1U << (mInitalQueries - 1)) * kInitialQueryInterval; + + ScheduleQuery(mLastQueryTime + interval); + } + + if (!mIsActive) + { + SetFireTime(mDeleteTime); + } + + // Let each cache entry type schedule query and + // fire times based on the state of its discovered + // records. + + switch (mType) + { + case kBrowseCache: + As().DetermineRecordFireTime(); + break; + case kSrvCache: + case kTxtCache: + As().DetermineRecordFireTime(); + break; + case kIp6AddrCache: + case kIp4AddrCache: + As().DetermineRecordFireTime(); + break; + } +} + +void Core::CacheEntry::ScheduleTimer(void) { ScheduleFireTimeOn(Get().mCacheTimer); } + +void Core::CacheEntry::PrepareQuery(CacheTimerContext &aContext) +{ + bool prepareAgain = false; + + do + { + TxMessage &query = aContext.GetQueryMessage(); + + query.SaveCurrentState(); + + switch (mType) + { + case kBrowseCache: + As().PreparePtrQuestion(query, aContext.GetNow()); + break; + case kSrvCache: + As().PrepareSrvQuestion(query); + break; + case kTxtCache: + As().PrepareTxtQuestion(query); + break; + case kIp6AddrCache: + As().PrepareAaaaQuestion(query); + break; + case kIp4AddrCache: + As().PrepareAQuestion(query); + break; + } + + query.CheckSizeLimitToPrepareAgain(prepareAgain); + + } while (prepareAgain); + + mLastQueryTimeValid = true; + mLastQueryTime = aContext.GetNow(); + + if (mInitalQueries < kNumberOfInitalQueries) + { + mInitalQueries++; + } + + // Let the cache entry super-classes update their state + // after query was sent. + + switch (mType) + { + case kBrowseCache: + As().UpdateRecordStateAfterQuery(aContext.GetNow()); + break; + case kSrvCache: + case kTxtCache: + As().UpdateRecordStateAfterQuery(aContext.GetNow()); + break; + case kIp6AddrCache: + case kIp4AddrCache: + As().UpdateRecordStateAfterQuery(aContext.GetNow()); + break; + } +} + +template void Core::CacheEntry::InvokeCallbacks(const ResultType &aResult) +{ + for (const ResultCallback &callback : mCallbacks) + { + callback.Invoke(GetInstance(), aResult); + } +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::BrowseCache + +Error Core::BrowseCache::Init(Instance &aInstance, const char *aServiceType, const char *aSubTypeLabel) +{ + Error error = kErrorNone; + + CacheEntry::Init(aInstance, kBrowseCache); + mNext = nullptr; + + ClearCompressOffsets(); + SuccessOrExit(error = mServiceType.Set(aServiceType)); + SuccessOrExit(error = mSubTypeLabel.Set(aSubTypeLabel)); + +exit: + return error; +} + +Error Core::BrowseCache::Init(Instance &aInstance, const Browser &aBrowser) +{ + return Init(aInstance, aBrowser.mServiceType, aBrowser.mSubTypeLabel); +} + +void Core::BrowseCache::ClearCompressOffsets(void) +{ + mServiceTypeOffset = kUnspecifiedOffset; + mSubServiceTypeOffset = kUnspecifiedOffset; + mSubServiceNameOffset = kUnspecifiedOffset; +} + +bool Core::BrowseCache::Matches(const Name &aFullName) const +{ + bool matches = false; + bool isSubType = !mSubTypeLabel.IsNull(); + Name name = aFullName; + + OT_ASSERT(name.IsFromMessage()); + + if (isSubType) + { + uint16_t offset; + const Message &message = name.GetAsMessage(offset); + + SuccessOrExit(Name::CompareLabel(message, offset, mSubTypeLabel.AsCString())); + name.SetFromMessage(message, offset); + } + + matches = name.Matches(isSubType ? kSubServiceLabel : nullptr, mServiceType.AsCString(), kLocalDomain); + +exit: + return matches; +} + +bool Core::BrowseCache::Matches(const char *aServiceType, const char *aSubTypeLabel) const +{ + bool matches = false; + + if (aSubTypeLabel == nullptr) + { + VerifyOrExit(mSubTypeLabel.IsNull()); + } + else + { + VerifyOrExit(NameMatch(mSubTypeLabel, aSubTypeLabel)); + } + + matches = NameMatch(mServiceType, aServiceType); + +exit: + return matches; +} + +bool Core::BrowseCache::Matches(const Browser &aBrowser) const +{ + return Matches(aBrowser.mServiceType, aBrowser.mSubTypeLabel); +} + +bool Core::BrowseCache::Matches(const ExpireChecker &aExpireChecker) const { return ShouldDelete(aExpireChecker.mNow); } + +Error Core::BrowseCache::Add(const Browser &aBrowser) { return CacheEntry::Add(ResultCallback(aBrowser.mCallback)); } + +void Core::BrowseCache::Remove(const Browser &aBrowser) { CacheEntry::Remove(ResultCallback(aBrowser.mCallback)); } + +void Core::BrowseCache::ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset) +{ + // Name and record type in `aMessage` are already matched. + + uint16_t offset = aRecordOffset; + PtrRecord ptr; + Name::Buffer fullServiceType; + Name::Buffer serviceInstance; + BrowseResult result; + PtrEntry *ptrEntry; + bool changed = false; + + // Read the PTR record. `ReadPtrName()` validates that + // PTR record is well-formed. + + SuccessOrExit(aMessage.Read(offset, ptr)); + offset += sizeof(ptr); + SuccessOrExit(ptr.ReadPtrName(aMessage, offset, serviceInstance, fullServiceType)); + + VerifyOrExit(Name(fullServiceType).Matches(nullptr, mServiceType.AsCString(), kLocalDomain)); + + ptrEntry = mPtrEntries.FindMatching(serviceInstance); + + if (ptr.GetTtl() == 0) + { + VerifyOrExit((ptrEntry != nullptr) && ptrEntry->mRecord.IsPresent()); + + ptrEntry->mRecord.RefreshTtl(0); + changed = true; + } + else + { + if (ptrEntry == nullptr) + { + ptrEntry = PtrEntry::AllocateAndInit(serviceInstance); + VerifyOrExit(ptrEntry != nullptr); + mPtrEntries.Push(*ptrEntry); + } + + if (ptrEntry->mRecord.RefreshTtl(ptr.GetTtl())) + { + changed = true; + } + } + + VerifyOrExit(changed); + + if (ptrEntry->mRecord.IsPresent() && IsActive()) + { + Get().AddPassiveSrvTxtCache(ptrEntry->mServiceInstance.AsCString(), mServiceType.AsCString()); + } + + ptrEntry->ConvertTo(result, *this); + InvokeCallbacks(result); + +exit: + DetermineNextFireTime(); + ScheduleTimer(); +} + +void Core::BrowseCache::PreparePtrQuestion(TxMessage &aQuery, TimeMilli aNow) +{ + Question question; + + DiscoverCompressOffsets(); + + question.SetType(ResourceRecord::kTypePtr); + question.SetClass(ResourceRecord::kClassInternet); + + AppendServiceTypeOrSubTypeTo(aQuery, kQuestionSection); + SuccessOrAssert(aQuery.SelectMessageFor(kQuestionSection).Append(question)); + + aQuery.IncrementRecordCount(kQuestionSection); + + for (const PtrEntry &ptrEntry : mPtrEntries) + { + if (!ptrEntry.mRecord.IsPresent() || ptrEntry.mRecord.LessThanHalfTtlRemains(aNow)) + { + continue; + } + + AppendKnownAnswer(aQuery, ptrEntry, aNow); + } +} + +void Core::BrowseCache::DiscoverCompressOffsets(void) +{ + for (const BrowseCache &browseCache : Get().mBrowseCacheList) + { + if (&browseCache == this) + { + break; + } + + if (NameMatch(browseCache.mServiceType, mServiceType)) + { + UpdateCompressOffset(mServiceTypeOffset, browseCache.mServiceTypeOffset); + UpdateCompressOffset(mSubServiceTypeOffset, browseCache.mSubServiceTypeOffset); + VerifyOrExit(mSubServiceTypeOffset == kUnspecifiedOffset); + } + } + + VerifyOrExit(mServiceTypeOffset == kUnspecifiedOffset); + + for (const SrvCache &srvCache : Get().mSrvCacheList) + { + if (NameMatch(srvCache.mServiceType, mServiceType)) + { + UpdateCompressOffset(mServiceTypeOffset, srvCache.mServiceTypeOffset); + VerifyOrExit(mServiceTypeOffset == kUnspecifiedOffset); + } + } + + for (const TxtCache &txtCache : Get().mTxtCacheList) + { + if (NameMatch(txtCache.mServiceType, mServiceType)) + { + UpdateCompressOffset(mServiceTypeOffset, txtCache.mServiceTypeOffset); + VerifyOrExit(mServiceTypeOffset == kUnspecifiedOffset); + } + } + +exit: + return; +} + +void Core::BrowseCache::AppendServiceTypeOrSubTypeTo(TxMessage &aTxMessage, Section aSection) +{ + if (!mSubTypeLabel.IsNull()) + { + AppendOutcome outcome; + + outcome = aTxMessage.AppendLabel(aSection, mSubTypeLabel.AsCString(), mSubServiceNameOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + outcome = aTxMessage.AppendLabel(aSection, kSubServiceLabel, mSubServiceTypeOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + } + + aTxMessage.AppendServiceType(aSection, mServiceType.AsCString(), mServiceTypeOffset); + +exit: + return; +} + +void Core::BrowseCache::AppendKnownAnswer(TxMessage &aTxMessage, const PtrEntry &aPtrEntry, TimeMilli aNow) +{ + Message &message = aTxMessage.SelectMessageFor(kAnswerSection); + PtrRecord ptr; + uint16_t offset; + + ptr.Init(); + ptr.SetTtl(aPtrEntry.mRecord.GetRemainingTtl(aNow)); + + AppendServiceTypeOrSubTypeTo(aTxMessage, kAnswerSection); + + offset = message.GetLength(); + SuccessOrAssert(message.Append(ptr)); + + SuccessOrAssert(Name::AppendLabel(aPtrEntry.mServiceInstance.AsCString(), message)); + aTxMessage.AppendServiceType(kAnswerSection, mServiceType.AsCString(), mServiceTypeOffset); + + UpdateRecordLengthInMessage(ptr, message, offset); + + aTxMessage.IncrementRecordCount(kAnswerSection); +} + +void Core::BrowseCache::UpdateRecordStateAfterQuery(TimeMilli aNow) +{ + for (PtrEntry &ptrEntry : mPtrEntries) + { + ptrEntry.mRecord.UpdateStateAfterQuery(aNow); + } +} + +void Core::BrowseCache::DetermineRecordFireTime(void) +{ + for (PtrEntry &ptrEntry : mPtrEntries) + { + ptrEntry.mRecord.UpdateQueryAndFireTimeOn(*this); + } +} + +void Core::BrowseCache::ProcessExpiredRecords(TimeMilli aNow) +{ + OwningList expiredEntries; + + mPtrEntries.RemoveAllMatching(ExpireChecker(aNow), expiredEntries); + + for (PtrEntry &exiredEntry : expiredEntries) + { + BrowseResult result; + + exiredEntry.mRecord.RefreshTtl(0); + + exiredEntry.ConvertTo(result, *this); + InvokeCallbacks(result); + } +} + +void Core::BrowseCache::ReportResultsTo(ResultCallback &aCallback) const +{ + for (const PtrEntry &ptrEntry : mPtrEntries) + { + if (ptrEntry.mRecord.IsPresent()) + { + BrowseResult result; + + ptrEntry.ConvertTo(result, *this); + aCallback.Invoke(GetInstance(), result); + } + } +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::BrowseCache::PtrEntry + +Error Core::BrowseCache::PtrEntry::Init(const char *aServiceInstance) +{ + mNext = nullptr; + + return mServiceInstance.Set(aServiceInstance); +} + +bool Core::BrowseCache::PtrEntry::Matches(const ExpireChecker &aExpireChecker) const +{ + return mRecord.ShouldExpire(aExpireChecker.mNow); +} + +void Core::BrowseCache::PtrEntry::ConvertTo(BrowseResult &aResult, const BrowseCache &aBrowseCache) const +{ + ClearAllBytes(aResult); + + aResult.mServiceType = aBrowseCache.mServiceType.AsCString(); + aResult.mSubTypeLabel = aBrowseCache.mSubTypeLabel.AsCString(); + aResult.mServiceInstance = mServiceInstance.AsCString(); + aResult.mTtl = mRecord.GetTtl(); + aResult.mInfraIfIndex = aBrowseCache.Get().mInfraIfIndex; +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::ServiceCache + +Error Core::ServiceCache::Init(Instance &aInstance, Type aType, const char *aServiceInstance, const char *aServiceType) +{ + Error error = kErrorNone; + + CacheEntry::Init(aInstance, aType); + ClearCompressOffsets(); + SuccessOrExit(error = mServiceInstance.Set(aServiceInstance)); + SuccessOrExit(error = mServiceType.Set(aServiceType)); + +exit: + return error; +} + +void Core::ServiceCache::ClearCompressOffsets(void) +{ + mServiceNameOffset = kUnspecifiedOffset; + mServiceTypeOffset = kUnspecifiedOffset; +} + +bool Core::ServiceCache::Matches(const Name &aFullName) const +{ + return aFullName.Matches(mServiceInstance.AsCString(), mServiceType.AsCString(), kLocalDomain); +} + +bool Core::ServiceCache::Matches(const char *aServiceInstance, const char *aServiceType) const +{ + return NameMatch(mServiceInstance, aServiceInstance) && NameMatch(mServiceType, aServiceType); +} + +void Core::ServiceCache::PrepareQueryQuestion(TxMessage &aQuery, uint16_t aRrType) +{ + Message &message = aQuery.SelectMessageFor(kQuestionSection); + Question question; + + question.SetType(aRrType); + question.SetClass(ResourceRecord::kClassInternet); + + AppendServiceNameTo(aQuery, kQuestionSection); + SuccessOrAssert(message.Append(question)); + + aQuery.IncrementRecordCount(kQuestionSection); +} + +void Core::ServiceCache::AppendServiceNameTo(TxMessage &aTxMessage, Section aSection) +{ + AppendOutcome outcome; + + outcome = aTxMessage.AppendLabel(aSection, mServiceInstance.AsCString(), mServiceNameOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + aTxMessage.AppendServiceType(aSection, mServiceType.AsCString(), mServiceTypeOffset); + +exit: + return; +} + +void Core::ServiceCache::UpdateRecordStateAfterQuery(TimeMilli aNow) { mRecord.UpdateStateAfterQuery(aNow); } + +void Core::ServiceCache::DetermineRecordFireTime(void) { mRecord.UpdateQueryAndFireTimeOn(*this); } + +bool Core::ServiceCache::ShouldStartInitialQueries(void) const +{ + // This is called when the first active resolver is added + // for this cache entry to determine whether we should + // send initial queries. It is possible that we were passively + // monitoring and have some cached record for this entry. + // We send initial queries if there is no record or less than + // half of the original TTL remains. + + return !mRecord.IsPresent() || mRecord.LessThanHalfTtlRemains(TimerMilli::GetNow()); +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::SrvCache + +Error Core::SrvCache::Init(Instance &aInstance, const char *aServiceInstance, const char *aServiceType) +{ + mNext = nullptr; + mPort = 0; + mPriority = 0; + mWeight = 0; + + return ServiceCache::Init(aInstance, kSrvCache, aServiceInstance, aServiceType); +} + +Error Core::SrvCache::Init(Instance &aInstance, const ServiceName &aServiceName) +{ + return Init(aInstance, aServiceName.mServiceInstance, aServiceName.mServiceType); +} + +Error Core::SrvCache::Init(Instance &aInstance, const SrvResolver &aResolver) +{ + return Init(aInstance, aResolver.mServiceInstance, aResolver.mServiceType); +} + +bool Core::SrvCache::Matches(const Name &aFullName) const { return ServiceCache::Matches(aFullName); } + +bool Core::SrvCache::Matches(const ServiceName &aServiceName) const +{ + return ServiceCache::Matches(aServiceName.mServiceInstance, aServiceName.mServiceType); +} + +bool Core::SrvCache::Matches(const SrvResolver &aResolver) const +{ + return ServiceCache::Matches(aResolver.mServiceInstance, aResolver.mServiceType); +} + +bool Core::SrvCache::Matches(const ExpireChecker &aExpireChecker) const { return ShouldDelete(aExpireChecker.mNow); } + +Error Core::SrvCache::Add(const SrvResolver &aResolver) { return CacheEntry::Add(ResultCallback(aResolver.mCallback)); } + +void Core::SrvCache::Remove(const SrvResolver &aResolver) { CacheEntry::Remove(ResultCallback(aResolver.mCallback)); } + +void Core::SrvCache::ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset) +{ + // Name and record type in `aMessage` are already matched. + + uint16_t offset = aRecordOffset; + SrvRecord srv; + Name::Buffer hostFullName; + Name::Buffer hostName; + SrvResult result; + bool changed = false; + + // Read the SRV record. `ReadTargetHostName()` validates that + // SRV record is well-formed. + + SuccessOrExit(aMessage.Read(offset, srv)); + offset += sizeof(srv); + SuccessOrExit(srv.ReadTargetHostName(aMessage, offset, hostFullName)); + + SuccessOrExit(Name::ExtractLabels(hostFullName, kLocalDomain, hostName)); + + if (srv.GetTtl() == 0) + { + VerifyOrExit(mRecord.IsPresent()); + + mHostName.Free(); + mRecord.RefreshTtl(0); + changed = true; + } + else + { + if (!mRecord.IsPresent() || !NameMatch(mHostName, hostName)) + { + SuccessOrAssert(mHostName.Set(hostName)); + changed = true; + } + + if (!mRecord.IsPresent() || (mPort != srv.GetPort())) + { + mPort = srv.GetPort(); + changed = true; + } + + if (!mRecord.IsPresent() || (mPriority != srv.GetPriority())) + { + mPriority = srv.GetPriority(); + changed = true; + } + + if (!mRecord.IsPresent() || (mWeight != srv.GetWeight())) + { + mWeight = srv.GetWeight(); + changed = true; + } + + if (mRecord.RefreshTtl(srv.GetTtl())) + { + changed = true; + } + } + + VerifyOrExit(changed); + + if (mRecord.IsPresent()) + { + StopInitialQueries(); + + // If not present already, we add a passive `TxtCache` for the + // same service name, and an `Ip6AddrCache` for the host name. + + Get().AddPassiveSrvTxtCache(mServiceInstance.AsCString(), mServiceType.AsCString()); + Get().AddPassiveIp6AddrCache(mHostName.AsCString()); + } + + ConvertTo(result); + InvokeCallbacks(result); + +exit: + DetermineNextFireTime(); + ScheduleTimer(); +} + +void Core::SrvCache::PrepareSrvQuestion(TxMessage &aQuery) +{ + DiscoverCompressOffsets(); + PrepareQueryQuestion(aQuery, ResourceRecord::kTypeSrv); +} + +void Core::SrvCache::DiscoverCompressOffsets(void) +{ + for (const SrvCache &srvCache : Get().mSrvCacheList) + { + if (&srvCache == this) + { + break; + } + + if (NameMatch(srvCache.mServiceType, mServiceType)) + { + UpdateCompressOffset(mServiceTypeOffset, srvCache.mServiceTypeOffset); + } + + if (mServiceTypeOffset != kUnspecifiedOffset) + { + break; + } + } +} + +void Core::SrvCache::ProcessExpiredRecords(TimeMilli aNow) +{ + if (mRecord.ShouldExpire(aNow)) + { + SrvResult result; + + mRecord.RefreshTtl(0); + + ConvertTo(result); + InvokeCallbacks(result); + } +} + +void Core::SrvCache::ReportResultTo(ResultCallback &aCallback) const +{ + if (mRecord.IsPresent()) + { + SrvResult result; + + ConvertTo(result); + aCallback.Invoke(GetInstance(), result); + } +} + +void Core::SrvCache::ConvertTo(SrvResult &aResult) const +{ + ClearAllBytes(aResult); + + aResult.mServiceInstance = mServiceInstance.AsCString(); + aResult.mServiceType = mServiceType.AsCString(); + aResult.mHostName = mHostName.AsCString(); + aResult.mPort = mPort; + aResult.mPriority = mPriority; + aResult.mWeight = mWeight; + aResult.mTtl = mRecord.GetTtl(); + aResult.mInfraIfIndex = Get().mInfraIfIndex; +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::TxtCache + +Error Core::TxtCache::Init(Instance &aInstance, const char *aServiceInstance, const char *aServiceType) +{ + mNext = nullptr; + + return ServiceCache::Init(aInstance, kTxtCache, aServiceInstance, aServiceType); +} + +Error Core::TxtCache::Init(Instance &aInstance, const ServiceName &aServiceName) +{ + return Init(aInstance, aServiceName.mServiceInstance, aServiceName.mServiceType); +} + +Error Core::TxtCache::Init(Instance &aInstance, const TxtResolver &aResolver) +{ + return Init(aInstance, aResolver.mServiceInstance, aResolver.mServiceType); +} + +bool Core::TxtCache::Matches(const Name &aFullName) const { return ServiceCache::Matches(aFullName); } + +bool Core::TxtCache::Matches(const ServiceName &aServiceName) const +{ + return ServiceCache::Matches(aServiceName.mServiceInstance, aServiceName.mServiceType); +} + +bool Core::TxtCache::Matches(const TxtResolver &aResolver) const +{ + return ServiceCache::Matches(aResolver.mServiceInstance, aResolver.mServiceType); +} + +bool Core::TxtCache::Matches(const ExpireChecker &aExpireChecker) const { return ShouldDelete(aExpireChecker.mNow); } + +Error Core::TxtCache::Add(const TxtResolver &aResolver) { return CacheEntry::Add(ResultCallback(aResolver.mCallback)); } + +void Core::TxtCache::Remove(const TxtResolver &aResolver) { CacheEntry::Remove(ResultCallback(aResolver.mCallback)); } + +void Core::TxtCache::ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset) +{ + // Name and record type are already matched. + + uint16_t offset = aRecordOffset; + TxtRecord txt; + TxtResult result; + bool changed = false; + + SuccessOrExit(aMessage.Read(offset, txt)); + offset += sizeof(txt); + + if (txt.GetTtl() == 0) + { + VerifyOrExit(mRecord.IsPresent()); + + mTxtData.Free(); + mRecord.RefreshTtl(0); + changed = true; + } + else + { + VerifyOrExit(txt.GetLength() > 0); + VerifyOrExit(aMessage.GetLength() >= offset + txt.GetLength()); + + if (!mRecord.IsPresent() || (mTxtData.GetLength() != txt.GetLength()) || + !aMessage.CompareBytes(offset, mTxtData.GetBytes(), mTxtData.GetLength())) + { + SuccessOrAssert(mTxtData.SetFrom(aMessage, offset, txt.GetLength())); + changed = true; + } + + if (mRecord.RefreshTtl(txt.GetTtl())) + { + changed = true; + } + } + + VerifyOrExit(changed); + + if (mRecord.IsPresent()) + { + StopInitialQueries(); + } + + ConvertTo(result); + InvokeCallbacks(result); + +exit: + DetermineNextFireTime(); + ScheduleTimer(); +} + +void Core::TxtCache::PrepareTxtQuestion(TxMessage &aQuery) +{ + DiscoverCompressOffsets(); + PrepareQueryQuestion(aQuery, ResourceRecord::kTypeTxt); +} + +void Core::TxtCache::DiscoverCompressOffsets(void) +{ + for (const SrvCache &srvCache : Get().mSrvCacheList) + { + if (!NameMatch(srvCache.mServiceType, mServiceType)) + { + continue; + } + + UpdateCompressOffset(mServiceTypeOffset, srvCache.mServiceTypeOffset); + + if (NameMatch(srvCache.mServiceInstance, mServiceInstance)) + { + UpdateCompressOffset(mServiceNameOffset, srvCache.mServiceNameOffset); + } + + VerifyOrExit(mServiceNameOffset == kUnspecifiedOffset); + } + + for (const TxtCache &txtCache : Get().mTxtCacheList) + { + if (&txtCache == this) + { + break; + } + + if (NameMatch(txtCache.mServiceType, mServiceType)) + { + UpdateCompressOffset(mServiceTypeOffset, txtCache.mServiceTypeOffset); + } + + VerifyOrExit(mServiceTypeOffset == kUnspecifiedOffset); + } + +exit: + return; +} + +void Core::TxtCache::ProcessExpiredRecords(TimeMilli aNow) +{ + if (mRecord.ShouldExpire(aNow)) + { + TxtResult result; + + mRecord.RefreshTtl(0); + + ConvertTo(result); + InvokeCallbacks(result); + } +} + +void Core::TxtCache::ReportResultTo(ResultCallback &aCallback) const +{ + if (mRecord.IsPresent()) + { + TxtResult result; + + ConvertTo(result); + aCallback.Invoke(GetInstance(), result); + } +} + +void Core::TxtCache::ConvertTo(TxtResult &aResult) const +{ + ClearAllBytes(aResult); + + aResult.mServiceInstance = mServiceInstance.AsCString(); + aResult.mServiceType = mServiceType.AsCString(); + aResult.mTxtData = mTxtData.GetBytes(); + aResult.mTxtDataLength = mTxtData.GetLength(); + aResult.mTtl = mRecord.GetTtl(); + aResult.mInfraIfIndex = Get().mInfraIfIndex; +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::AddrCache + +Error Core::AddrCache::Init(Instance &aInstance, Type aType, const char *aHostName) +{ + CacheEntry::Init(aInstance, aType); + + mNext = nullptr; + mShouldFlush = false; + + return mName.Set(aHostName); +} + +Error Core::AddrCache::Init(Instance &aInstance, Type aType, const AddressResolver &aResolver) +{ + return Init(aInstance, aType, aResolver.mHostName); +} + +bool Core::AddrCache::Matches(const Name &aFullName) const +{ + return aFullName.Matches(nullptr, mName.AsCString(), kLocalDomain); +} + +bool Core::AddrCache::Matches(const char *aName) const { return NameMatch(mName, aName); } + +bool Core::AddrCache::Matches(const AddressResolver &aResolver) const { return Matches(aResolver.mHostName); } + +bool Core::AddrCache::Matches(const ExpireChecker &aExpireChecker) const { return ShouldDelete(aExpireChecker.mNow); } + +Error Core::AddrCache::Add(const AddressResolver &aResolver) +{ + return CacheEntry::Add(ResultCallback(aResolver.mCallback)); +} + +void Core::AddrCache::Remove(const AddressResolver &aResolver) +{ + CacheEntry::Remove(ResultCallback(aResolver.mCallback)); +} + +void Core::AddrCache::PrepareQueryQuestion(TxMessage &aQuery, uint16_t aRrType) +{ + Question question; + + question.SetType(aRrType); + question.SetClass(ResourceRecord::kClassInternet); + + AppendNameTo(aQuery, kQuestionSection); + SuccessOrAssert(aQuery.SelectMessageFor(kQuestionSection).Append(question)); + + aQuery.IncrementRecordCount(kQuestionSection); +} + +void Core::AddrCache::AppendNameTo(TxMessage &aTxMessage, Section aSection) +{ + uint16_t compressOffset = kUnspecifiedOffset; + + AppendOutcome outcome; + + outcome = aTxMessage.AppendMultipleLabels(aSection, mName.AsCString(), compressOffset); + VerifyOrExit(outcome != kAppendedFullNameAsCompressed); + + aTxMessage.AppendDomainName(aSection); + +exit: + return; +} + +void Core::AddrCache::UpdateRecordStateAfterQuery(TimeMilli aNow) +{ + for (AddrEntry &entry : mCommittedEntries) + { + entry.mRecord.UpdateStateAfterQuery(aNow); + } +} + +void Core::AddrCache::DetermineRecordFireTime(void) +{ + for (AddrEntry &entry : mCommittedEntries) + { + entry.mRecord.UpdateQueryAndFireTimeOn(*this); + } +} + +void Core::AddrCache::ProcessExpiredRecords(TimeMilli aNow) +{ + OwningList expiredEntries; + Heap::Array addrArray; + AddressResult result; + + mCommittedEntries.RemoveAllMatching(ExpireChecker(aNow), expiredEntries); + + VerifyOrExit(!expiredEntries.IsEmpty()); + + ConstructResult(result, addrArray); + InvokeCallbacks(result); + +exit: + return; +} + +void Core::AddrCache::ReportResultsTo(ResultCallback &aCallback) const +{ + Heap::Array addrArray; + AddressResult result; + + ConstructResult(result, addrArray); + + if (result.mAddressesLength > 0) + { + aCallback.Invoke(GetInstance(), result); + } +} + +void Core::AddrCache::ConstructResult(AddressResult &aResult, Heap::Array &aAddrArray) const +{ + // Prepares an `AddressResult` populating it with discovered + // addresses from the `AddrCache` entry. Uses a caller-provided + // `Heap::Array` reference (`aAddrArray`) to ensure that the + // allocated array for `aResult.mAddresses` remains valid until + // after the `aResult` is used (passed as input to + // `ResultCallback`). + + uint16_t addrCount = 0; + + ClearAllBytes(aResult); + aAddrArray.Free(); + + for (const AddrEntry &entry : mCommittedEntries) + { + if (entry.mRecord.IsPresent()) + { + addrCount++; + } + } + + if (addrCount > 0) + { + SuccessOrAssert(aAddrArray.ReserveCapacity(addrCount)); + + for (const AddrEntry &entry : mCommittedEntries) + { + AddressAndTtl *addr; + + if (!entry.mRecord.IsPresent()) + { + continue; + } + + addr = aAddrArray.PushBack(); + OT_ASSERT(addr != nullptr); + + addr->mAddress = entry.mAddress; + addr->mTtl = entry.mRecord.GetTtl(); + } + } + + aResult.mHostName = mName.AsCString(); + aResult.mAddresses = aAddrArray.AsCArray(); + aResult.mAddressesLength = aAddrArray.GetLength(); + aResult.mInfraIfIndex = Get().mInfraIfIndex; +} + +bool Core::AddrCache::ShouldStartInitialQueries(void) const +{ + // This is called when the first active resolver is added + // for this cache entry to determine whether we should + // send initial queries. It is possible that we were passively + // monitoring and has some cached records for this entry. + // We send initial queries if there is no record or less than + // half of original TTL remains on any record. + + bool shouldStart = false; + TimeMilli now = TimerMilli::GetNow(); + + if (mCommittedEntries.IsEmpty()) + { + shouldStart = true; + ExitNow(); + } + + for (const AddrEntry &entry : mCommittedEntries) + { + if (entry.mRecord.LessThanHalfTtlRemains(now)) + { + shouldStart = true; + ExitNow(); + } + } + +exit: + return shouldStart; +} + +void Core::AddrCache::AddNewResponseAddress(const Ip6::Address &aAddress, uint32_t aTtl, bool aCacheFlush) +{ + // Adds a new address record to `mNewEntries` list. This called as + // the records in a received response are processed one by one. + // Once all records are processed `CommitNewResponseEntries()` is + // called to update the list of addresses. + + AddrEntry *entry; + + if (aCacheFlush) + { + mShouldFlush = true; + } + + // Check for duplicate addresses in the same response. + + entry = mNewEntries.FindMatching(aAddress); + + if (entry == nullptr) + { + entry = AddrEntry::Allocate(aAddress); + OT_ASSERT(entry != nullptr); + mNewEntries.Push(*entry); + } + + entry->mRecord.RefreshTtl(aTtl); +} + +void Core::AddrCache::CommitNewResponseEntries(void) +{ + bool changed = false; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // Determine whether there is any changes to the list of addresses + // between the `mNewEntries` and `mCommittedEntries` lists. + // + // First, we verify if all new entries are present in the + // `mCommittedEntries` list with the same TTL value. Next, if we + // need to flush the old cache list, we check if any existing + // `mCommittedEntries` is absent in `mNewEntries` list. + + for (const AddrEntry &newEntry : mNewEntries) + { + AddrEntry *exitingEntry = mCommittedEntries.FindMatching(newEntry.mAddress); + + if (newEntry.GetTtl() == 0) + { + // New entry has zero TTL, removing the address. If we + // have a matching `exitingEntry` we set its TTL to zero + // so to remove it in the next step when updating the + // `mCommittedEntries` list. + + if (exitingEntry != nullptr) + { + exitingEntry->mRecord.RefreshTtl(0); + changed = true; + } + } + else if ((exitingEntry == nullptr) || (exitingEntry->GetTtl() != newEntry.GetTtl())) + { + changed = true; + } + } + + if (mShouldFlush && !changed) + { + for (const AddrEntry &exitingEntry : mCommittedEntries) + { + if ((exitingEntry.GetTtl() > 0) && !mNewEntries.ContainsMatching(exitingEntry.mAddress)) + { + changed = true; + break; + } + } + } + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // Update the `mCommittedEntries` list. + + // First remove entries, if we need to flush clear everything, + // otherwise remove the ones with zero TTL marked in previous + // step. Then, add or update new entries to `mCommittedEntries` + + if (mShouldFlush) + { + mCommittedEntries.Clear(); + mShouldFlush = false; + } + else + { + OwningList removedEntries; + + mCommittedEntries.RemoveAllMatching(EmptyChecker(), removedEntries); + } + + while (!mNewEntries.IsEmpty()) + { + OwnedPtr newEntry = mNewEntries.Pop(); + AddrEntry *entry; + + if (newEntry->GetTtl() == 0) + { + continue; + } + + entry = mCommittedEntries.FindMatching(newEntry->mAddress); + + if (entry != nullptr) + { + entry->mRecord.RefreshTtl(newEntry->GetTtl()); + } + else + { + mCommittedEntries.Push(*newEntry.Release()); + } + } + + StopInitialQueries(); + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // Invoke callbacks if there is any change. + + if (changed) + { + Heap::Array addrArray; + AddressResult result; + + ConstructResult(result, addrArray); + InvokeCallbacks(result); + } + + DetermineNextFireTime(); + ScheduleTimer(); +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::AddrCache::AddrEntry + +Core::AddrCache::AddrEntry::AddrEntry(const Ip6::Address &aAddress) + : mNext(nullptr) + , mAddress(aAddress) +{ +} + +bool Core::AddrCache::AddrEntry::Matches(const ExpireChecker &aExpireChecker) const +{ + return mRecord.ShouldExpire(aExpireChecker.mNow); +} + +bool Core::AddrCache::AddrEntry::Matches(EmptyChecker aChecker) const +{ + OT_UNUSED_VARIABLE(aChecker); + + return !mRecord.IsPresent(); +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::Ip6AddrCache + +Error Core::Ip6AddrCache::Init(Instance &aInstance, const char *aHostName) +{ + return AddrCache::Init(aInstance, kIp6AddrCache, aHostName); +} + +Error Core::Ip6AddrCache::Init(Instance &aInstance, const AddressResolver &aResolver) +{ + return AddrCache::Init(aInstance, kIp6AddrCache, aResolver); +} + +void Core::Ip6AddrCache::ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset) +{ + // Name and record type in `aMessage` are already matched. + + AaaaRecord aaaaRecord; + + SuccessOrExit(aMessage.Read(aRecordOffset, aaaaRecord)); + VerifyOrExit(aaaaRecord.GetLength() >= sizeof(Ip6::Address)); + + AddNewResponseAddress(aaaaRecord.GetAddress(), aaaaRecord.GetTtl(), aaaaRecord.GetClass() & kClassCacheFlushFlag); + +exit: + return; +} + +void Core::Ip6AddrCache::PrepareAaaaQuestion(TxMessage &aQuery) +{ + PrepareQueryQuestion(aQuery, ResourceRecord::kTypeAaaa); +} + +//--------------------------------------------------------------------------------------------------------------------- +// Core::Ip4AddrCache + +Error Core::Ip4AddrCache::Init(Instance &aInstance, const char *aHostName) +{ + return AddrCache::Init(aInstance, kIp4AddrCache, aHostName); +} + +Error Core::Ip4AddrCache::Init(Instance &aInstance, const AddressResolver &aResolver) +{ + return AddrCache::Init(aInstance, kIp4AddrCache, aResolver); +} + +void Core::Ip4AddrCache::ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset) +{ + // Name and record type in `aMessage` are already matched. + + ARecord aRecord; + Ip6::Address address; + + SuccessOrExit(aMessage.Read(aRecordOffset, aRecord)); + VerifyOrExit(aRecord.GetLength() >= sizeof(Ip4::Address)); + + address.SetToIp4Mapped(aRecord.GetAddress()); + + AddNewResponseAddress(address, aRecord.GetTtl(), aRecord.GetClass() & kClassCacheFlushFlag); + +exit: + return; +} + +void Core::Ip4AddrCache::PrepareAQuestion(TxMessage &aQuery) { PrepareQueryQuestion(aQuery, ResourceRecord::kTypeA); } + +} // namespace Multicast +} // namespace Dns +} // namespace ot + +//--------------------------------------------------------------------------------------------------------------------- + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_MOCK_PLAT_APIS_ENABLE + +OT_TOOL_WEAK otError otPlatMdnsSetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aEnable); + OT_UNUSED_VARIABLE(aInfraIfIndex); + + return OT_ERROR_FAILED; +} + +OT_TOOL_WEAK void otPlatMdnsSendMulticast(otInstance *aInstance, otMessage *aMessage, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aMessage); + OT_UNUSED_VARIABLE(aInfraIfIndex); +} + +OT_TOOL_WEAK void otPlatMdnsSendUnicast(otInstance *aInstance, + otMessage *aMessage, + const otPlatMdnsAddressInfo *aAddress) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aMessage); + OT_UNUSED_VARIABLE(aAddress); +} + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_MOCK_PLAT_APIS_ENABLE + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE diff --git a/src/core/net/mdns.hpp b/src/core/net/mdns.hpp new file mode 100644 index 000000000..f158328cf --- /dev/null +++ b/src/core/net/mdns.hpp @@ -0,0 +1,1812 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef MULTICAST_DNS_HPP_ +#define MULTICAST_DNS_HPP_ + +#include "openthread-core-config.h" + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +#include +#include + +#include "common/as_core_type.hpp" +#include "common/clearable.hpp" +#include "common/debug.hpp" +#include "common/equatable.hpp" +#include "common/error.hpp" +#include "common/heap_allocatable.hpp" +#include "common/heap_array.hpp" +#include "common/heap_data.hpp" +#include "common/heap_string.hpp" +#include "common/linked_list.hpp" +#include "common/owned_ptr.hpp" +#include "common/owning_list.hpp" +#include "common/retain_ptr.hpp" +#include "common/timer.hpp" +#include "crypto/sha256.hpp" +#include "net/dns_types.hpp" + +/** + * @file + * This file includes definitions for the Multicast DNS per RFC 6762. + * + */ + +namespace ot { +namespace Dns { +namespace Multicast { + +extern "C" void otPlatMdnsHandleReceive(otInstance *aInstance, + otMessage *aMessage, + bool aIsUnicast, + const otPlatMdnsAddressInfo *aAddress); + +/** + * Implements Multicast DNS (mDNS) core. + * + */ +class Core : public InstanceLocator, private NonCopyable +{ + friend void otPlatMdnsHandleReceive(otInstance *aInstance, + otMessage *aMessage, + bool aIsUnicast, + const otPlatMdnsAddressInfo *aAddress); + +public: + /** + * Initializes a `Core` instance. + * + * @param[in] aInstance The OpenThread instance. + * + */ + explicit Core(Instance &aInstance); + + typedef otMdnsRequestId RequestId; ///< A request Identifier. + typedef otMdnsRegisterCallback RegisterCallback; ///< Registration callback. + typedef otMdnsConflictCallback ConflictCallback; ///< Conflict callback. + typedef otMdnsHost Host; ///< Host information. + typedef otMdnsService Service; ///< Service information. + typedef otMdnsKey Key; ///< Key information. + typedef otMdnsBrowser Browser; ///< Browser. + typedef otMdnsBrowseCallback BrowseCallback; ///< Browser callback. + typedef otMdnsBrowseResult BrowseResult; ///< Browser result. + typedef otMdnsSrvResolver SrvResolver; ///< SRV resolver. + typedef otMdnsSrvCallback SrvCallback; ///< SRV callback. + typedef otMdnsSrvResult SrvResult; ///< SRV result. + typedef otMdnsTxtResolver TxtResolver; ///< TXT resolver. + typedef otMdnsTxtCallback TxtCallback; ///< TXT callback. + typedef otMdnsTxtResult TxtResult; ///< TXT result. + typedef otMdnsAddressResolver AddressResolver; ///< Address resolver. + typedef otMdnsAddressCallback AddressCallback; ///< Address callback + typedef otMdnsAddressResult AddressResult; ///< Address result. + typedef otMdnsAddressAndTtl AddressAndTtl; ///< Address and TTL. + + /** + * Represents a socket address info. + * + */ + class AddressInfo : public otPlatMdnsAddressInfo, public Clearable, public Equatable + { + public: + /** + * Initializes the `AddressInfo` clearing all the fields. + * + */ + AddressInfo(void) { Clear(); } + + /** + * Gets the IPv6 address. + * + * @returns the IPv6 address. + * + */ + const Ip6::Address &GetAddress(void) const { return AsCoreType(&mAddress); } + }; + + /** + * Enables or disables the mDNS module. + * + * mDNS module should be enabled before registration any host, service, or key entries. Disabling mDNS will + * immediately stop all operations and any communication (multicast or unicast tx) and remove any previously + * registered entries without sending any "goodbye" announcements or invoking their callback. When disabled, + * all browsers and resolvers are stopped and all cached information is cleared. + * + * @param[in] aEnable Whether to enable or disable. + * @param[in] aInfraIfIndex The network interface index for mDNS operation. Value is ignored when disabling. + * + * @retval kErrorNone Enabled or disabled the mDNS module successfully. + * @retval kErrorAlready mDNS is already enabled on an enable request, or is already disabled on a disable request. + * @retval kErrorFailed Failed to enable/disable mDNS. + * + */ + Error SetEnabled(bool aEnable, uint32_t aInfraIfIndex); + + /** + * Indicates whether or not mDNS module is enabled. + * + * @retval TRUE The mDNS module is enabled. + * @retval FALSE The mDNS module is disabled. + * + */ + bool IsEnabled(void) const { return mIsEnabled; } + + /** + * Sets whether mDNS module is allowed to send questions requesting unicast responses referred to as "QU" questions. + * + * The "QU" question request unicast response in contrast to "QM" questions which request multicast responses. + * When allowed, the first probe will be sent as a "QU" question. + * + * This can be used to address platform limitation where platform cannot accept unicast response received on mDNS + * port. + * + * @param[in] aAllow Indicates whether or not to allow "QU" questions. + * + */ + void SetQuestionUnicastAllowed(bool aAllow) { mIsQuestionUnicastAllowed = aAllow; } + + /** + * Indicates whether mDNS module is allowed to send "QU" questions requesting unicast response. + * + * @retval TRUE The mDNS module is allowed to send "QU" questions. + * @retval FALSE The mDNS module is not allowed to send "QU" questions. + * + */ + bool IsQuestionUnicastAllowed(void) const { return mIsQuestionUnicastAllowed; } + + /** + * Sets the conflict callback. + * + * @param[in] aCallback The conflict callback. Can be `nullptr` is not needed. + * + */ + void SetConflictCallback(ConflictCallback aCallback) { mConflictCallback = aCallback; } + + /** + * Registers or updates a host. + * + * The fields in @p aHost follow these rules: + * + * - The `mHostName` field specifies the host name to register (e.g., "myhost"). MUST NOT contain the domain name. + * - The `mAddresses` is array of IPv6 addresses to register with the host. `mAddressesLength` provides the number + * of entries in `mAddresses` array. + * - The `mAddresses` array can be empty with zero `mAddressesLength`. In this case, mDNS will treat it as if host + * is unregistered and stop advertising any addresses for this the host name. + * - The `mTtl` specifies the TTL if non-zero. If zero, the mDNS core will choose a default TTL to use. + * + * This method can be called again for the same `mHostName` to update a previously registered host entry, for + * example, to change the list of addresses of the host. In this case, the mDNS module will send "goodbye" + * announcements for any previously registered and now removed addresses and announce any newly added addresses. + * + * The outcome of the registration request is reported back by invoking the provided @p aCallback with + * @p aRequestId as its input and one of the following `aError` inputs: + * + * - `kErrorNone` indicates registration was successful + * - `kErrorDuplicated` indicates a name conflict, i.e., the name is already claimed by another mDNS responder. + * + * For caller convenience, the OpenThread mDNS module guarantees that the callback will be invoked after this + * method returns, even in cases of immediate registration success. The @p aCallback can be `nullptr` if caller + * does not want to be notified of the outcome. + * + * @param[in] aHost The host to register. + * @param[in] aRequestId The ID associated with this request. + * @param[in] aCallback The callback function pointer to report the outcome (can be `nullptr` if not needed). + * + * @retval kErrorNone Successfully started registration. @p aCallback will report the outcome. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error RegisterHost(const Host &aHost, RequestId aRequestId, RegisterCallback aCallback); + + /** + * Unregisters a host. + * + * The fields in @p aHost follow these rules: + * + * - The `mHostName` field specifies the host name to unregister (e.g., "myhost"). MUST NOT contain the domain name. + * - The rest of the fields in @p aHost structure are ignored in an `UnregisterHost()` call. + * + * If there is no previously registered host with the same name, no action is performed. + * + * If there is a previously registered host with the same name, the mDNS module will send "goodbye" announcement + * for all previously advertised address records. + * + * @param[in] aHost The host to unregister. + * + * @retval kErrorNone Successfully unregistered host. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error UnregisterHost(const Host &aHost); + + /** + * Registers or updates a service. + * + * The fields in @p aService follow these rules: + * + * - The `mServiceInstance` specifies the service instance label. It is treated as a single DNS label. It may + * contain dot `.` character which is allowed in a service instance label. + * - The `mServiceType` specifies the service type (e.g., "_tst._udp"). It is treated as multiple dot `.` separated + * labels. It MUST NOT contain the domain name. + * - The `mHostName` field specifies the host name of the service. MUST NOT contain the domain name. + * - The `mSubTypeLabels` is an array of strings representing sub-types associated with the service. Each array + * entry is a sub-type label. The `mSubTypeLabels can be `nullptr` if there are no sub-types. Otherwise, the + * array length is specified by `mSubTypeLabelsLength`. + * - The `mTxtData` and `mTxtDataLength` specify the encoded TXT data. The `mTxtData` can be `nullptr` or + * `mTxtDataLength` can be zero to specify an empty TXT data. In this case mDNS module will use a single zero + * byte `[ 0 ]` as empty TXT data. + * - The `mPort`, `mWeight`, and `mPriority` specify the service's parameters (as specified in DNS SRV record). + * - The `mTtl` specifies the TTL if non-zero. If zero, the mDNS module will use default TTL for service entry. + * + * This method can be called again for the same `mServiceInstance` and `mServiceType` to update a previously + * registered service entry, for example, to change the sub-types list or update any parameter such as port, weight, + * priority, TTL, or host name. The mDNS module will send announcements for any changed info, e.g., will send + * "goodbye" announcements for any removed sub-types and announce any newly added sub-types. + * + * Regarding the invocation of the @p aCallback, this method behaves in the same way as described in + * `RegisterHost()`. + * + * @param[in] aService The service to register. + * @param[in] aRequestId The ID associated with this request. + * @param[in] aCallback The callback function pointer to report the outcome (can be `nullptr` if not needed). + * + * @retval kErrorNone Successfully started registration. @p aCallback will report the outcome. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error RegisterService(const Service &aService, RequestId aRequestId, RegisterCallback aCallback); + + /** + * Unregisters a service. + * + * The fields in @p aService follow these rules: + + * - The `mServiceInstance` specifies the service instance label. It is treated as a single DNS label. It may + * contain dot `.` character which is allowed in a service instance label. + * - The `mServiceType` specifies the service type (e.g., "_tst._udp"). It is treated as multiple dot `.` separated + * labels. It MUST NOT contain the domain name. + * - The rest of the fields in @p aService structure are ignored in a`otMdnsUnregisterService()` call. + * + * If there is no previously registered service with the same name, no action is performed. + * + * If there is a previously registered service with the same name, the mDNS module will send "goodbye" + * announcements for all related records. + * + * @param[in] aService The service to unregister. + * + * @retval kErrorNone Successfully unregistered service. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error UnregisterService(const Service &aService); + + /** + * Registers or updates a key record. + * + * The fields in @p aKey follow these rules: + * + * - If the key is associated with a host entry, `mName` specifies the host name & `mServcieType` MUST be `nullptr`. + * - If the key is associated with a service entry, `mName` specifies the service instance label (always treated as + * a single label) and `mServiceType` specifies the service type (e.g. "_tst._udp"). In this case the DNS name + * for key record is `.`. + * - The `mKeyData` field contains the key record's data with `mKeyDataLength` as its length in byes. + * - The `mTtl` specifies the TTL if non-zero. If zero, the mDNS module will use default TTL for the key entry. + * + * This method can be called again for the same name to updated a previously registered key entry, for example, + * to change the key data or TTL. + * + * Regarding the invocation of the @p aCallback, this method behaves in the same way as described in + * `RegisterHost()`. + * + * @param[in] aKey The key record to register. + * @param[in] aRequestId The ID associated with this request. + * @param[in] aCallback The callback function pointer to report the outcome (can be `nullptr` if not needed). + * + * @retval kErrorNone Successfully started registration. @p aCallback will report the outcome. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error RegisterKey(const Key &aKey, RequestId aRequestId, RegisterCallback aCallback); + + /** + * Unregisters a key record on mDNS. + * + * The fields in @p aKey follow these rules: + * + * - If the key is associated with a host entry, `mName` specifies the host name & `mServcieType` MUST be `nullptr`. + * - If the key is associated with a service entry, `mName` specifies the service instance label (always treated as + * a single label) and `mServiceType` specifies the service type (e.g. "_tst._udp"). In this case the DNS name + * for key record is `.`. + * - The rest of the fields in @p aKey structure are ignored in a`otMdnsUnregisterKey()` call. + * + * If there is no previously registered key with the same name, no action is performed. + * + * If there is a previously registered key with the same name, the mDNS module will send "goodbye" announcements + * for the key record. + * + * @param[in] aKey The key to unregister. + * + * @retval kErrorNone Successfully unregistered key + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error UnregisterKey(const Key &aKey); + + /** + * Starts a service browser. + * + * Initiates a continuous search for the specified `mServiceType` in @p aBrowser. For sub-type services, use + * `mSubTypeLabel` to define the sub-type, for base services, set `mSubTypeLabel` to NULL. + * + * Discovered services are reported through the `mCallback` function in @p aBrowser. Services that have been + * removed are reported with a TTL value of zero. The callback may be invoked immediately with cached information + * (if available) and potentially before this method returns. When cached results are used, the reported TTL value + * will reflect the original TTL from the last received response. + * + * Multiple browsers can be started for the same service, provided they use different callback functions. + * + * @param[in] aBrowser The browser to be started. + * + * @retval kErrorNone Browser started successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * @retval kErrorAlready An identical browser (same service and callback) is already active. + * + */ + Error StartBrowser(const Browser &aBrowser); + + /** + * Stops a service browser. + * + * No action is performed if no matching browser with the same service and callback is currently active. + * + * @param[in] aBrowser The browser to stop. + * + * @retval kErrorNone Browser stopped successfully. + * @retval kErrorInvalidSatet mDNS module is not enabled. + * + */ + Error StopBrowser(const Browser &aBrowser); + + /** + * Starts an SRV record resolver. + * + * Initiates a continuous SRV record resolver for the specified service in @p aResolver. + * + * Discovered information is reported through the `mCallback` function in @p aResolver. When the service is removed + * it is reported with a TTL value of zero. In this case, `mHostName` may be NULL and other result fields (such as + * `mPort`) should be ignored. + * + * The callback may be invoked immediately with cached information (if available) and potentially before this + * method returns. When cached result is used, the reported TTL value will reflect the original TTL from the last + * received response. + * + * Multiple resolvers can be started for the same service, provided they use different callback functions. + * + * @param[in] aResolver The resolver to be started. + * + * @retval kErrorNone Resolver started successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * @retval kErrorAlready An identical resolver (same service and callback) is already active. + * + */ + Error StartSrvResolver(const SrvResolver &aResolver); + + /** + * Stops an SRV record resolver. + * + * No action is performed if no matching resolver with the same service and callback is currently active. + * + * @param[in] aResolver The resolver to stop. + * + * @retval kErrorNone Resolver stopped successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error StopSrvResolver(const SrvResolver &aResolver); + + /** + * Starts a TXT record resolver. + * + * Initiates a continuous TXT record resolver for the specified service in @p aResolver. + * + * Discovered information is reported through the `mCallback` function in @p aResolver. When the TXT record is + * removed it is reported with a TTL value of zero. In this case, `mTxtData` may be NULL, and other result fields + * (such as `mTxtDataLength`) should be ignored. + * + * The callback may be invoked immediately with cached information (if available) and potentially before this + * method returns. When cached result is used, the reported TTL value will reflect the original TTL from the last + * received response. + * + * Multiple resolvers can be started for the same service, provided they use different callback functions. + * + * @param[in] aResolver The resolver to be started. + * + * @retval kErrorNone Resolver started successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * @retval kErrorAlready An identical resolver (same service and callback) is already active. + * + */ + Error StartTxtResolver(const TxtResolver &aResolver); + + /** + * Stops a TXT record resolver. + * + * No action is performed if no matching resolver with the same service and callback is currently active. + * + * @param[in] aResolver The resolver to stop. + * + * @retval kErrorNone Resolver stopped successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error StopTxtResolver(const TxtResolver &aResolver); + + /** + * Starts an IPv6 address resolver. + * + * Initiates a continuous IPv6 address resolver for the specified host name in @p aResolver. + * + * Discovered addresses are reported through the `mCallback` function in @ p aResolver. The callback is invoked + * whenever addresses are added or removed, providing an updated list. If all addresses are removed, the callback + * is invoked with an empty list (`mAddresses` will be NULL, and `mAddressesLength` will be zero). + * + * The callback may be invoked immediately with cached information (if available) and potentially before this + * method returns. When cached result is used, the reported TTL values will reflect the original TTL from the last + * received response. + * + * Multiple resolvers can be started for the same host name, provided they use different callback functions. + * + * @param[in] aResolver The resolver to be started. + * + * @retval kErrorNone Resolver started successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * @retval kErrorAlready An identical resolver (same host and callback) is already active. + * + */ + Error StartIp6AddressResolver(const AddressResolver &aResolver); + + /** + * Stops an IPv6 address resolver. + * + * No action is performed if no matching resolver with the same host name and callback is currently active. + * + * @param[in] aResolver The resolver to stop. + * + * @retval kErrorNone Resolver stopped successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error StopIp6AddressResolver(const AddressResolver &aResolver); + + /** + * Starts an IPv4 address resolver. + * + * Initiates a continuous IPv4 address resolver for the specified host name in @p aResolver. + * + * Discovered addresses are reported through the `mCallback` function in @ p aResolver. The IPv4 addresses are + * represented using the IPv4-mapped IPv6 address format in `mAddresses` array. The callback is invoked whenever + * addresses are added or removed, providing an updated list. If all addresses are removed, the callback is invoked + * with an empty list (`mAddresses` will be NULL, and `mAddressesLength` will be zero). + * + * The callback may be invoked immediately with cached information (if available) and potentially before this + * method returns. When cached result is used, the reported TTL values will reflect the original TTL from the last + * received response. + * + * Multiple resolvers can be started for the same host name, provided they use different callback functions. + * + * @param[in] aResolver The resolver to be started. + * + * @retval kErrorNone Resolver started successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * @retval kErrorAlready An identical resolver (same host and callback) is already active. + * + */ + Error StartIp4AddressResolver(const AddressResolver &aResolver); + + /** + * Stops an IPv4 address resolver. + * + * No action is performed if no matching resolver with the same host name and callback is currently active. + * + * @param[in] aResolver The resolver to stop. + * + * @retval kErrorNone Resolver stopped successfully. + * @retval kErrorInvalidState mDNS module is not enabled. + * + */ + Error StopIp4AddressResolver(const AddressResolver &aResolver); + + /** + * Sets the max size threshold for mDNS messages. + * + * This method is mainly intended for testing. The max size threshold is used to break larger messages. + * + * @param[in] aMaxSize The max message size threshold. + * + */ + void SetMaxMessageSize(uint16_t aMaxSize) { mMaxMessageSize = aMaxSize; } + +private: + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + static constexpr uint16_t kUdpPort = 5353; + + static constexpr bool kDefaultQuAllowed = OPENTHREAD_CONFIG_MULTICAST_DNS_DEFAULT_QUESTION_UNICAST_ALLOWED; + + static constexpr uint32_t kMaxMessageSize = 1200; + + static constexpr uint8_t kNumberOfProbes = 3; + static constexpr uint32_t kMinProbeDelay = 20; // In msec + static constexpr uint32_t kMaxProbeDelay = 250; // In msec + static constexpr uint32_t kProbeWaitTime = 250; // In msec + + static constexpr uint8_t kNumberOfAnnounces = 3; + static constexpr uint32_t kAnnounceInterval = 1000; // In msec - time between first two announces + + static constexpr uint8_t kNumberOfInitalQueries = 3; + static constexpr uint32_t kInitialQueryInterval = 1000; // In msec - time between first two queries + + static constexpr uint32_t kMinInitialQueryDelay = 20; // msec + static constexpr uint32_t kMaxInitialQueryDelay = 120; // msec + static constexpr uint32_t kRandomDelayReuseInterval = 2; // msec + + static constexpr uint32_t kUnspecifiedTtl = 0; + static constexpr uint32_t kDefaultTtl = 120; + static constexpr uint32_t kDefaultKeyTtl = kDefaultTtl; + static constexpr uint32_t kNsecTtl = 4500; + static constexpr uint32_t kServicesPtrTtl = 4500; + + static constexpr uint16_t kClassQuestionUnicastFlag = (1U << 15); + static constexpr uint16_t kClassCacheFlushFlag = (1U << 15); + static constexpr uint16_t kClassMask = (0x7fff); + + static constexpr uint16_t kUnspecifiedOffset = 0; + + static constexpr uint8_t kNumSections = 4; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + enum Section : uint8_t + { + kQuestionSection, + kAnswerSection, + kAuthoritySection, + kAdditionalDataSection, + }; + + enum AppendOutcome : uint8_t + { + kAppendedFullNameAsCompressed, + kAppendedLabels, + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // Forward declarations + + class EntryTimerContext; + class TxMessage; + class RxMessage; + class ServiceEntry; + class ServiceType; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + struct EmptyChecker + { + // Used in `Matches()` to find empty entries (with no record) to remove and free. + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + struct ExpireChecker + { + // Used in `Matches()` to find expired entries in a list. + + explicit ExpireChecker(TimeMilli aNow) { mNow = aNow; } + + TimeMilli mNow; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class Callback : public Clearable + { + public: + Callback(void) { Clear(); } + Callback(RequestId aRequestId, RegisterCallback aCallback); + + bool IsEmpty(void) const { return (mCallback == nullptr); } + void InvokeAndClear(Instance &aInstance, Error aError); + + private: + RequestId mRequestId; + RegisterCallback mCallback; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class RecordCounts : public Clearable + { + public: + RecordCounts(void) { Clear(); } + + uint16_t GetFor(Section aSection) const { return mCounts[aSection]; } + void Increment(Section aSection) { mCounts[aSection]++; } + void ReadFrom(const Header &aHeader); + void WriteTo(Header &aHeader) const; + bool IsEmpty(void) const; + + private: + uint16_t mCounts[kNumSections]; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + struct AnswerInfo + { + uint16_t mQuestionRrType; + TimeMilli mAnswerTime; + bool mIsProbe; + bool mUnicastResponse; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class AddressArray : public Heap::Array + { + public: + bool Matches(const Ip6::Address *aAddresses, uint16_t aNumAddresses) const; + void SetFrom(const Ip6::Address *aAddresses, uint16_t aNumAddresses); + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class FireTime + { + public: + FireTime(void) { ClearFireTime(); } + void ClearFireTime(void) { mHasFireTime = false; } + bool HasFireTime(void) const { return mHasFireTime; } + TimeMilli GetFireTime(void) const { return mFireTime; } + void SetFireTime(TimeMilli aFireTime); + + protected: + void ScheduleFireTimeOn(TimerMilli &aTimer); + + private: + TimeMilli mFireTime; + bool mHasFireTime; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class RecordInfo : public Clearable, private NonCopyable + { + public: + // Keeps track of record state and timings. + + RecordInfo(void) { Clear(); } + + bool IsPresent(void) const { return mIsPresent; } + uint32_t GetTtl(void) const { return mTtl; } + + template void UpdateProperty(UintType &aProperty, UintType aValue); + void UpdateProperty(AddressArray &aAddrProperty, const Ip6::Address *aAddrs, uint16_t aNumAddrs); + void UpdateProperty(Heap::String &aStringProperty, const char *aString); + void UpdateProperty(Heap::Data &aDataProperty, const uint8_t *aData, uint16_t aLength); + void UpdateTtl(uint32_t aTtl); + + void StartAnnouncing(void); + bool ShouldAppendTo(TxMessage &aResponse, TimeMilli aNow) const; + bool CanAnswer(void) const; + void ScheduleAnswer(const AnswerInfo &aInfo); + void UpdateStateAfterAnswer(const TxMessage &aResponse); + void UpdateFireTimeOn(FireTime &aFireTime); + uint32_t GetDurationSinceLastMulticast(TimeMilli aTime) const; + Error GetLastMulticastTime(TimeMilli &aLastMulticastTime) const; + + // `AppendState` methods: Used to track whether the record + // is appended in a message, or needs to be appended in + // Additional Data section. + + void MarkAsNotAppended(void) { mAppendState = kNotAppended; } + void MarkAsAppended(TxMessage &aTxMessage, Section aSection); + void MarkToAppendInAdditionalData(void); + bool IsAppended(void) const; + bool CanAppend(void) const; + bool ShouldAppendInAdditionalDataSection(void) const { return (mAppendState == kToAppendInAdditionalData); } + + private: + enum AppendState : uint8_t + { + kNotAppended, + kToAppendInAdditionalData, + kAppendedInMulticastMsg, + kAppendedInUnicastMsg, + }; + + static constexpr uint32_t kMinIntervalBetweenMulticast = 1000; // msec + static constexpr uint32_t kLastMulticastTimeAge = 10 * Time::kOneHourInMsec; + + static_assert(kNotAppended == 0, "kNotAppended MUST be zero, so `Clear()` works correctly"); + + bool mIsPresent : 1; + bool mMulticastAnswerPending : 1; + bool mUnicastAnswerPending : 1; + bool mIsLastMulticastValid : 1; + uint8_t mAnnounceCounter; + AppendState mAppendState; + Section mAppendSection; + uint32_t mTtl; + TimeMilli mAnnounceTime; + TimeMilli mAnswerTime; + TimeMilli mLastMulticastTime; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class Entry : public InstanceLocatorInit, public FireTime, private NonCopyable + { + // Base class for `HostEntry` and `ServiceEntry`. + + friend class ServiceType; + + public: + enum State : uint8_t + { + kProbing, + kRegistered, + kConflict, + kRemoving, + }; + + State GetState(void) const { return mState; } + void Register(const Key &aKey, const Callback &aCallback); + void Unregister(const Key &aKey); + void InvokeCallbacks(void); + void ClearAppendState(void); + + protected: + static constexpr uint32_t kMinIntervalProbeResponse = 250; // msec + static constexpr uint8_t kTypeArraySize = 8; // We can have SRV, TXT and KEY today. + + struct TypeArray : public Array // Array of record types for NSEC record + { + void Add(uint16_t aType) { SuccessOrAssert(PushBack(aType)); } + }; + + struct RecordAndType + { + RecordInfo &mRecord; + uint16_t mType; + }; + + typedef void (*NameAppender)(Entry &aEntry, TxMessage &aTxMessage, Section aSection); + + Entry(void); + void Init(Instance &aInstance); + void SetCallback(const Callback &aCallback); + void ClearCallback(void) { mCallback.Clear(); } + void StartProbing(void); + void SetStateToConflict(void); + void SetStateToRemoving(void); + void UpdateRecordsState(const TxMessage &aResponse); + void AppendQuestionTo(TxMessage &aTxMessage) const; + void AppendKeyRecordTo(TxMessage &aTxMessage, Section aSection, NameAppender aNameAppender); + void AppendNsecRecordTo(TxMessage &aTxMessage, + Section aSection, + const TypeArray &aTypes, + NameAppender aNameAppender); + bool ShouldAnswerNsec(TimeMilli aNow) const; + void DetermineNextFireTime(void); + void ScheduleTimer(void); + void AnswerProbe(const AnswerInfo &aInfo, RecordAndType *aRecords, uint16_t aRecordsLength); + void AnswerNonProbe(const AnswerInfo &aInfo, RecordAndType *aRecords, uint16_t aRecordsLength); + void ScheduleNsecAnswer(const AnswerInfo &aInfo); + + template void HandleTimer(EntryTimerContext &aContext); + + RecordInfo mKeyRecord; + + private: + void SetState(State aState); + void ClearKey(void); + void ScheduleCallbackTask(void); + void CheckMessageSizeLimitToPrepareAgain(TxMessage &aTxMessage, bool &aPrepareAgain); + + State mState; + uint8_t mProbeCount; + bool mMulticastNsecPending : 1; + bool mUnicastNsecPending : 1; + bool mAppendedNsec : 1; + TimeMilli mNsecAnswerTime; + Heap::Data mKeyData; + Callback mCallback; + Callback mKeyCallback; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class HostEntry : public Entry, public LinkedListEntry, public Heap::Allocatable + { + friend class LinkedListEntry; + friend class Entry; + friend class ServiceEntry; + + public: + HostEntry(void); + Error Init(Instance &aInstance, const Host &aHost) { return Init(aInstance, aHost.mHostName); } + Error Init(Instance &aInstance, const Key &aKey) { return Init(aInstance, aKey.mName); } + bool IsEmpty(void) const; + bool Matches(const Name &aName) const; + bool Matches(const Host &aHost) const; + bool Matches(const Key &aKey) const; + bool Matches(const Heap::String &aName) const; + bool Matches(State aState) const { return GetState() == aState; } + bool Matches(const HostEntry &aEntry) const { return (this == &aEntry); } + void Register(const Host &aHost, const Callback &aCallback); + void Register(const Key &aKey, const Callback &aCallback); + void Unregister(const Host &aHost); + void Unregister(const Key &aKey); + void AnswerQuestion(const AnswerInfo &aInfo); + void HandleTimer(EntryTimerContext &aContext); + void ClearAppendState(void); + void PrepareResponse(TxMessage &aResponse, TimeMilli aNow); + void HandleConflict(void); + + private: + Error Init(Instance &aInstance, const char *aName); + void ClearHost(void); + void ScheduleToRemoveIfEmpty(void); + void PrepareProbe(TxMessage &aProbe); + void StartAnnouncing(void); + void PrepareResponseRecords(TxMessage &aResponse, TimeMilli aNow); + void UpdateRecordsState(const TxMessage &aResponse); + void DetermineNextFireTime(void); + void AppendAddressRecordsTo(TxMessage &aTxMessage, Section aSection); + void AppendKeyRecordTo(TxMessage &aTxMessage, Section aSection); + void AppendNsecRecordTo(TxMessage &aTxMessage, Section aSection); + void AppendNameTo(TxMessage &aTxMessage, Section aSection); + + static void AppendEntryName(Entry &aEntry, TxMessage &aTxMessage, Section aSection); + + HostEntry *mNext; + Heap::String mName; + RecordInfo mAddrRecord; + AddressArray mAddresses; + uint16_t mNameOffset; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class ServiceEntry : public Entry, public LinkedListEntry, public Heap::Allocatable + { + friend class LinkedListEntry; + friend class Entry; + friend class ServiceType; + + public: + ServiceEntry(void); + Error Init(Instance &aInstance, const Service &aService); + Error Init(Instance &aInstance, const Key &aKey); + bool IsEmpty(void) const; + bool Matches(const Name &aName) const; + bool Matches(const Service &aService) const; + bool Matches(const Key &aKey) const; + bool Matches(State aState) const { return GetState() == aState; } + bool Matches(const ServiceEntry &aEntry) const { return (this == &aEntry); } + bool MatchesServiceType(const Name &aServiceType) const; + bool CanAnswerSubType(const char *aSubLabel) const; + void Register(const Service &aService, const Callback &aCallback); + void Register(const Key &aKey, const Callback &aCallback); + void Unregister(const Service &aService); + void Unregister(const Key &aKey); + void AnswerServiceNameQuestion(const AnswerInfo &aInfo); + void AnswerServiceTypeQuestion(const AnswerInfo &aInfo, const char *aSubLabel); + bool ShouldSuppressKnownAnswer(uint32_t aTtl, const char *aSubLabel) const; + void HandleTimer(EntryTimerContext &aContext); + void ClearAppendState(void); + void PrepareResponse(TxMessage &aResponse, TimeMilli aNow); + void HandleConflict(void); + + private: + class SubType : public LinkedListEntry, public Heap::Allocatable, private ot::NonCopyable + { + public: + Error Init(const char *aLabel); + bool Matches(const char *aLabel) const { return NameMatch(mLabel, aLabel); } + bool Matches(const EmptyChecker &aChecker) const; + bool IsContainedIn(const Service &aService) const; + + SubType *mNext; + Heap::String mLabel; + RecordInfo mPtrRecord; + uint16_t mSubServiceNameOffset; + }; + + Error Init(Instance &aInstance, const char *aServiceInstance, const char *aServiceType); + void ClearService(void); + void ScheduleToRemoveIfEmpty(void); + void PrepareProbe(TxMessage &aProbe); + void StartAnnouncing(void); + void PrepareResponseRecords(TxMessage &aResponse, TimeMilli aNow); + void UpdateRecordsState(const TxMessage &aResponse); + void DetermineNextFireTime(void); + void DiscoverOffsetsAndHost(HostEntry *&aHost); + void UpdateServiceTypes(void); + void AppendSrvRecordTo(TxMessage &aTxMessage, Section aSection); + void AppendTxtRecordTo(TxMessage &aTxMessage, Section aSection); + void AppendPtrRecordTo(TxMessage &aTxMessage, Section aSection, SubType *aSubType = nullptr); + void AppendKeyRecordTo(TxMessage &aTxMessage, Section aSection); + void AppendNsecRecordTo(TxMessage &aTxMessage, Section aSection); + void AppendServiceNameTo(TxMessage &TxMessage, Section aSection); + void AppendServiceTypeTo(TxMessage &aTxMessage, Section aSection); + void AppendSubServiceTypeTo(TxMessage &aTxMessage, Section aSection); + void AppendSubServiceNameTo(TxMessage &aTxMessage, Section aSection, SubType &aSubType); + void AppendHostNameTo(TxMessage &aTxMessage, Section aSection); + + static void AppendEntryName(Entry &aEntry, TxMessage &aTxMessage, Section aSection); + + static const uint8_t kEmptyTxtData[]; + + ServiceEntry *mNext; + Heap::String mServiceInstance; + Heap::String mServiceType; + RecordInfo mPtrRecord; + RecordInfo mSrvRecord; + RecordInfo mTxtRecord; + OwningList mSubTypes; + Heap::String mHostName; + Heap::Data mTxtData; + uint16_t mPriority; + uint16_t mWeight; + uint16_t mPort; + uint16_t mServiceNameOffset; + uint16_t mServiceTypeOffset; + uint16_t mSubServiceTypeOffset; + uint16_t mHostNameOffset; + bool mIsAddedInServiceTypes; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class ServiceType : public InstanceLocatorInit, + public FireTime, + public LinkedListEntry, + public Heap::Allocatable, + private NonCopyable + { + // Track a service type to answer to `_services._dns-sd._udp.local` + // queries. + + friend class LinkedListEntry; + + public: + Error Init(Instance &aInstance, const char *aServiceType); + bool Matches(const Name &aServcieTypeName) const; + bool Matches(const Heap::String &aServiceType) const; + bool Matches(const ServiceType &aServiceType) const { return (this == &aServiceType); } + void IncrementNumEntries(void) { mNumEntries++; } + void DecrementNumEntries(void) { mNumEntries--; } + uint16_t GetNumEntries(void) const { return mNumEntries; } + void ClearAppendState(void); + void AnswerQuestion(const AnswerInfo &aInfo); + bool ShouldSuppressKnownAnswer(uint32_t aTtl) const; + void HandleTimer(EntryTimerContext &aContext); + void PrepareResponse(TxMessage &aResponse, TimeMilli aNow); + + private: + void PrepareResponseRecords(TxMessage &aResponse, TimeMilli aNow); + void AppendPtrRecordTo(TxMessage &aResponse, uint16_t aServiceTypeOffset); + + ServiceType *mNext; + Heap::String mServiceType; + RecordInfo mServicesPtr; + uint16_t mNumEntries; // Number of service entries providing this service type. + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class TxMessage : public InstanceLocator + { + public: + enum Type : uint8_t + { + kMulticastProbe, + kMulticastQuery, + kMulticastResponse, + kUnicastResponse, + }; + + TxMessage(Instance &aInstance, Type aType); + TxMessage(Instance &aInstance, Type aType, const AddressInfo &aUnicastDest); + Type GetType(void) const { return mType; } + Message &SelectMessageFor(Section aSection); + AppendOutcome AppendLabel(Section aSection, const char *aLabel, uint16_t &aCompressOffset); + AppendOutcome AppendMultipleLabels(Section aSection, const char *aLabels, uint16_t &aCompressOffset); + void AppendServiceType(Section aSection, const char *aServiceType, uint16_t &aCompressOffset); + void AppendDomainName(Section aSection); + void AppendServicesDnssdName(Section aSection); + void IncrementRecordCount(Section aSection) { mRecordCounts.Increment(aSection); } + void CheckSizeLimitToPrepareAgain(bool &aPrepareAgain); + void SaveCurrentState(void); + void RestoreToSavedState(void); + void Send(void); + + private: + static constexpr bool kIsSingleLabel = true; + + void Init(Type aType); + void Reinit(void); + bool IsOverSizeLimit(void) const; + AppendOutcome AppendLabels(Section aSection, + const char *aLabels, + bool aIsSingleLabel, + uint16_t &aCompressOffset); + bool ShouldClearAppendStateOnReinit(const Entry &aEntry) const; + + static void SaveOffset(uint16_t &aCompressOffset, const Message &aMessage, Section aSection); + + RecordCounts mRecordCounts; + OwnedPtr mMsgPtr; + OwnedPtr mExtraMsgPtr; + RecordCounts mSavedRecordCounts; + uint16_t mSavedMsgLength; + uint16_t mSavedExtraMsgLength; + uint16_t mDomainOffset; // Offset for domain name `.local.` for name compression. + uint16_t mUdpOffset; // Offset to `_udp.local.` + uint16_t mTcpOffset; // Offset to `_tcp.local.` + uint16_t mServicesDnssdOffset; // Offset to `_services._dns-sd` + AddressInfo mUnicastDest; + Type mType; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class TimerContext : public InstanceLocator + { + public: + TimerContext(Instance &aInstance); + + TimeMilli GetNow(void) const { return mNow; } + TimeMilli GetNextTime(void) const { return mNextTime; } + void UpdateNextTime(TimeMilli aTime); + + private: + TimeMilli mNow; + TimeMilli mNextTime; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class EntryTimerContext : public TimerContext // Used by `HandleEntryTimer`. + { + public: + EntryTimerContext(Instance &aInstance); + TxMessage &GetProbeMessage(void) { return mProbeMessage; } + TxMessage &GetResponseMessage(void) { return mResponseMessage; } + + private: + TxMessage mProbeMessage; + TxMessage mResponseMessage; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class RxMessage : public InstanceLocatorInit, + public Heap::Allocatable, + public LinkedListEntry, + private NonCopyable + { + friend class LinkedListEntry; + + public: + enum ProcessOutcome : uint8_t + { + kProcessed, + kSaveAsMultiPacket, + }; + + Error Init(Instance &aInstance, + OwnedPtr &aMessagePtr, + bool aIsUnicast, + const AddressInfo &aSenderAddress); + bool IsQuery(void) const { return mIsQuery; } + bool IsTruncated(void) const { return mTruncated; } + bool IsSelfOriginating(void) const { return mIsSelfOriginating; } + const RecordCounts &GetRecordCounts(void) const { return mRecordCounts; } + const AddressInfo &GetSenderAddress(void) const { return mSenderAddress; } + void ClearProcessState(void); + ProcessOutcome ProcessQuery(bool aShouldProcessTruncated); + void ProcessResponse(void); + + private: + typedef void (RxMessage::*RecordProcessor)(const Name &aName, + const ResourceRecord &aRecord, + uint16_t aRecordOffset); + + struct Question : public Clearable + { + Question(void) { Clear(); } + void ClearProcessState(void); + + Entry *mEntry; // Entry which can provide answer (if any). + uint16_t mNameOffset; // Offset to start of question name. + uint16_t mRrType; // The question record type. + bool mIsRrClassInternet : 1; // Is the record class Internet or Any. + bool mIsProbe : 1; // Is a probe (contains a matching record in Authority section). + bool mUnicastResponse : 1; // Is QU flag set (requesting a unicast response). + bool mCanAnswer : 1; // Can provide answer for this question + bool mIsUnique : 1; // Is unique record (vs a shared record). + bool mIsForService : 1; // Is for a `ServiceEntry` (vs a `HostEntry`). + bool mIsServiceType : 1; // Is for service type or sub-type of a `ServiceEntry`. + bool mIsForAllServicesDnssd : 1; // Is for "_services._dns-sd._udp" (all service types). + }; + + static constexpr uint32_t kMinResponseDelay = 20; // msec + static constexpr uint32_t kMaxResponseDelay = 120; // msec + + void ProcessQuestion(Question &aQuestion); + void AnswerQuestion(const Question &aQuestion, TimeMilli aAnswerTime); + void AnswerServiceTypeQuestion(const Question &aQuestion, const AnswerInfo &aInfo, ServiceEntry &aFirstEntry); + bool ShouldSuppressKnownAnswer(const Name &aServiceType, + const char *aSubLabel, + const ServiceEntry &aServiceEntry) const; + bool ParseQuestionNameAsSubType(const Question &aQuestion, + Name::LabelBuffer &aSubLabel, + Name &aServiceType) const; + void AnswerAllServicesQuestion(const Question &aQuestion, const AnswerInfo &aInfo); + bool ShouldSuppressKnownAnswer(const Question &aQuestion, const ServiceType &aServiceType) const; + void SendUnicastResponse(const AddressInfo &aUnicastDest); + void IterateOnAllRecordsInResponse(RecordProcessor aRecordProcessor); + void ProcessRecordForConflict(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset); + void ProcessPtrRecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset); + void ProcessSrvRecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset); + void ProcessTxtRecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset); + void ProcessAaaaRecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset); + void ProcessARecord(const Name &aName, const ResourceRecord &aRecord, uint16_t aRecordOffset); + + RxMessage *mNext; + OwnedPtr mMessagePtr; + Heap::Array mQuestions; + AddressInfo mSenderAddress; + RecordCounts mRecordCounts; + uint16_t mStartOffset[kNumSections]; + bool mIsQuery : 1; + bool mIsUnicast : 1; + bool mTruncated : 1; + bool mIsSelfOriginating : 1; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + void HandleMultiPacketTimer(void) { mMultiPacketRxMessages.HandleTimer(); } + + class MultiPacketRxMessages : public InstanceLocator + { + public: + explicit MultiPacketRxMessages(Instance &aInstance); + + void AddToExisting(OwnedPtr &aRxMessagePtr); + void AddNew(OwnedPtr &aRxMessagePtr); + void HandleTimer(void); + void Clear(void); + + private: + static constexpr uint32_t kMinProcessDelay = 400; // msec + static constexpr uint32_t kMaxProcessDelay = 500; // msec + static constexpr uint16_t kMaxNumMessages = 10; + + struct RxMsgEntry : public InstanceLocator, + public LinkedListEntry, + public Heap::Allocatable, + private NonCopyable + { + explicit RxMsgEntry(Instance &aInstance); + + bool Matches(const AddressInfo &aAddress) const; + bool Matches(const ExpireChecker &aExpireChecker) const; + void Add(OwnedPtr &aRxMessagePtr); + + OwningList mRxMessages; + TimeMilli mProcessTime; + RxMsgEntry *mNext; + }; + + using MultiPacketTimer = TimerMilliIn; + + OwningList mRxMsgEntries; + MultiPacketTimer mTimer; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + void HandleTxMessageHistoryTimer(void) { mTxMessageHistory.HandleTimer(); } + + class TxMessageHistory : public InstanceLocator + { + // Keep track of messages sent by mDNS module to tell if + // a received message is self originating. + + public: + explicit TxMessageHistory(Instance &aInstance); + void Clear(void); + void Add(const Message &aMessage); + bool Contains(const Message &aMessage) const; + void HandleTimer(void); + + private: + static constexpr uint32_t kExpireInterval = TimeMilli::SecToMsec(10); // in msec + + typedef Crypto::Sha256::Hash Hash; + + struct HashEntry : public LinkedListEntry, public Heap::Allocatable + { + bool Matches(const Hash &aHash) const { return aHash == mHash; } + bool Matches(const ExpireChecker &aExpireChecker) const { return mExpireTime <= aExpireChecker.mNow; } + + HashEntry *mNext; + Hash mHash; + TimeMilli mExpireTime; + }; + + static void CalculateHash(const Message &aMessage, Hash &aHash); + + using TxMsgHistoryTimer = TimerMilliIn; + + OwningList mHashEntries; + TxMsgHistoryTimer mTimer; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class CacheEntry; + class TxtCache; + + class ResultCallback : public LinkedListEntry, public Heap::Allocatable + { + friend class Heap::Allocatable; + friend class LinkedListEntry; + friend class CacheEntry; + + public: + ResultCallback(const ResultCallback &aResultCallback) = default; + + template + explicit ResultCallback(CallbackType aCallback) + : mNext(nullptr) + , mSharedCallback(aCallback) + { + } + + bool Matches(BrowseCallback aCallback) const { return mSharedCallback.mBrowse == aCallback; } + bool Matches(SrvCallback aCallback) const { return mSharedCallback.mSrv == aCallback; } + bool Matches(TxtCallback aCallback) const { return mSharedCallback.mTxt == aCallback; } + bool Matches(AddressCallback aCallback) const { return mSharedCallback.mAddress == aCallback; } + bool Matches(EmptyChecker) const { return (mSharedCallback.mSrv == nullptr); } + + void Invoke(Instance &aInstance, const BrowseResult &aResult) const; + void Invoke(Instance &aInstance, const SrvResult &aResult) const; + void Invoke(Instance &aInstance, const TxtResult &aResult) const; + void Invoke(Instance &aInstance, const AddressResult &aResult) const; + + void ClearCallback(void) { mSharedCallback.Clear(); } + + private: + union SharedCallback + { + explicit SharedCallback(BrowseCallback aCallback) { mBrowse = aCallback; } + explicit SharedCallback(SrvCallback aCallback) { mSrv = aCallback; } + explicit SharedCallback(TxtCallback aCallback) { mTxt = aCallback; } + explicit SharedCallback(AddressCallback aCallback) { mAddress = aCallback; } + + void Clear(void) { mBrowse = nullptr; } + + BrowseCallback mBrowse; + SrvCallback mSrv; + TxtCallback mTxt; + AddressCallback mAddress; + }; + + ResultCallback *mNext; + SharedCallback mSharedCallback; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class CacheTimerContext : public TimerContext + { + public: + CacheTimerContext(Instance &aInstance); + TxMessage &GetQueryMessage(void) { return mQueryMessage; } + + private: + TxMessage mQueryMessage; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class CacheRecordInfo + { + public: + CacheRecordInfo(void); + + bool IsPresent(void) const { return (mTtl > 0); } + uint32_t GetTtl(void) const { return mTtl; } + bool RefreshTtl(uint32_t aTtl); + bool ShouldExpire(TimeMilli aNow) const; + void UpdateStateAfterQuery(TimeMilli aNow); + void UpdateQueryAndFireTimeOn(CacheEntry &aCacheEntry); + bool LessThanHalfTtlRemains(TimeMilli aNow) const; + uint32_t GetRemainingTtl(TimeMilli aNow) const; + + private: + static constexpr uint32_t kMaxTtl = (24 * 3600); // One day + static constexpr uint8_t kNumberOfQueries = 4; + static constexpr uint32_t kQueryTtlVariation = 1000 * 2 / 100; // 2% + + uint32_t GetClampedTtl(void) const; + TimeMilli GetExpireTime(void) const; + TimeMilli GetQueryTime(uint8_t aAttemptIndex) const; + + uint32_t mTtl; + TimeMilli mLastRxTime; + uint8_t mQueryCount; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class CacheEntry : public FireTime, public InstanceLocatorInit, private NonCopyable + { + // Base class for cache entries: `BrowseCache`, `mSrvCache`, + // `mTxtCache`, etc. Implements common behaviors: initial + // queries, query/timer scheduling, callback tracking, entry + // aging, and timer handling. Tracks entry type in `mType` and + // invokes sub-class method for type-specific behaviors + // (e.g., query message construction). + + public: + void HandleTimer(CacheTimerContext &aContext); + void ClearEmptyCallbacks(void); + void ScheduleQuery(TimeMilli aQueryTime); + + protected: + enum Type : uint8_t + { + kBrowseCache, + kSrvCache, + kTxtCache, + kIp6AddrCache, + kIp4AddrCache, + }; + + void Init(Instance &aInstance, Type aType); + bool IsActive(void) const { return mIsActive; } + bool ShouldDelete(TimeMilli aNow) const; + void StartInitialQueries(void); + void StopInitialQueries(void) { mInitalQueries = kNumberOfInitalQueries; } + Error Add(const ResultCallback &aCallback); + void Remove(const ResultCallback &aCallback); + void DetermineNextFireTime(void); + void ScheduleTimer(void); + + template void InvokeCallbacks(const ResultType &aResult); + + private: + static constexpr uint32_t kMinIntervalBetweenQueries = 1000; // In msec + static constexpr uint32_t kNonActiveDeleteTimeout = 7 * Time::kOneMinuteInMsec; + + typedef OwningList CallbackList; + + void SetIsActive(bool aIsActive); + bool ShouldQuery(TimeMilli aNow); + void PrepareQuery(CacheTimerContext &aContext); + void ProcessExpiredRecords(TimeMilli aNow); + void DetermineNextInitialQueryTime(void); + + ResultCallback *FindCallbackMatching(const ResultCallback &aCallback); + + template CacheType &As(void) { return *static_cast(this); } + template const CacheType &As(void) const { return *static_cast(this); } + + Type mType; // Cache entry type. + uint8_t mInitalQueries; // Number initial queries sent already. + bool mQueryPending : 1; // Whether a query tx request is pending. + bool mLastQueryTimeValid : 1; // Whether `mLastQueryTime` is valid. + bool mIsActive : 1; // Whether there is any active resolver/browser for this entry. + TimeMilli mNextQueryTime; // The next query tx time when `mQueryPending`. + TimeMilli mLastQueryTime; // The last query tx time or the upcoming tx time of first initial query. + TimeMilli mDeleteTime; // The time to delete the entry when not `mIsActive`. + CallbackList mCallbacks; // Resolver/Browser callbacks. + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class BrowseCache : public CacheEntry, public LinkedListEntry, public Heap::Allocatable + { + friend class LinkedListEntry; + friend class Heap::Allocatable; + friend class CacheEntry; + + public: + void ClearCompressOffsets(void); + bool Matches(const Name &aFullName) const; + bool Matches(const char *aServiceType, const char *aSubTypeLabel) const; + bool Matches(const Browser &aBrowser) const; + bool Matches(const ExpireChecker &aExpireChecker) const; + Error Add(const Browser &aBrowser); + void Remove(const Browser &aBrowser); + void ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset); + + private: + struct PtrEntry : public LinkedListEntry, public Heap::Allocatable + { + Error Init(const char *aServiceInstance); + bool Matches(const char *aServiceInstance) const { return NameMatch(mServiceInstance, aServiceInstance); } + bool Matches(const ExpireChecker &aExpireChecker) const; + void ConvertTo(BrowseResult &aResult, const BrowseCache &aBrowseCache) const; + + PtrEntry *mNext; + Heap::String mServiceInstance; + CacheRecordInfo mRecord; + }; + + // Called by base class `CacheEntry` + void PreparePtrQuestion(TxMessage &aQuery, TimeMilli aNow); + void UpdateRecordStateAfterQuery(TimeMilli aNow); + void DetermineRecordFireTime(void); + void ProcessExpiredRecords(TimeMilli aNow); + void ReportResultsTo(ResultCallback &aCallback) const; + + Error Init(Instance &aInstance, const char *aServiceType, const char *aSubTypeLabel); + Error Init(Instance &aInstance, const Browser &aBrowser); + void AppendServiceTypeOrSubTypeTo(TxMessage &aTxMessage, Section aSection); + void AppendKnownAnswer(TxMessage &aTxMessage, const PtrEntry &aPtrEntry, TimeMilli aNow); + void DiscoverCompressOffsets(void); + + BrowseCache *mNext; + Heap::String mServiceType; + Heap::String mSubTypeLabel; + OwningList mPtrEntries; + uint16_t mServiceTypeOffset; + uint16_t mSubServiceTypeOffset; + uint16_t mSubServiceNameOffset; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + struct ServiceName + { + ServiceName(const char *aServiceInstance, const char *aServiceType) + : mServiceInstance(aServiceInstance) + , mServiceType(aServiceType) + { + } + + const char *mServiceInstance; + const char *mServiceType; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class ServiceCache : public CacheEntry + { + // Base class for `SrvCache` and `TxtCache`, tracking common info + // shared between the two, e.g. service instance/type strings, + // record info, and append state and compression offsets. + + friend class CacheEntry; + + public: + void ClearCompressOffsets(void); + + protected: + ServiceCache(void) = default; + + Error Init(Instance &aInstance, Type aType, const char *aServiceInstance, const char *aServiceType); + bool Matches(const Name &aFullName) const; + bool Matches(const char *aServiceInstance, const char *aServiceType) const; + void PrepareQueryQuestion(TxMessage &aQuery, uint16_t aRrType); + void AppendServiceNameTo(TxMessage &aTxMessage, Section aSection); + void UpdateRecordStateAfterQuery(TimeMilli aNow); + void DetermineRecordFireTime(void); + bool ShouldStartInitialQueries(void) const; + + CacheRecordInfo mRecord; + Heap::String mServiceInstance; + Heap::String mServiceType; + uint16_t mServiceNameOffset; + uint16_t mServiceTypeOffset; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class SrvCache : public ServiceCache, public LinkedListEntry, public Heap::Allocatable + { + friend class LinkedListEntry; + friend class Heap::Allocatable; + friend class CacheEntry; + friend class TxtCache; + friend class BrowseCache; + + public: + bool Matches(const Name &aFullName) const; + bool Matches(const SrvResolver &aResolver) const; + bool Matches(const ServiceName &aServiceName) const; + bool Matches(const ExpireChecker &aExpireChecker) const; + Error Add(const SrvResolver &aResolver); + void Remove(const SrvResolver &aResolver); + void ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset); + + private: + Error Init(Instance &aInstance, const char *aServiceInstance, const char *aServiceType); + Error Init(Instance &aInstance, const ServiceName &aServiceName); + Error Init(Instance &aInstance, const SrvResolver &aResolver); + void PrepareSrvQuestion(TxMessage &aQuery); + void DiscoverCompressOffsets(void); + void ProcessExpiredRecords(TimeMilli aNow); + void ReportResultTo(ResultCallback &aCallback) const; + void ConvertTo(SrvResult &aResult) const; + + SrvCache *mNext; + Heap::String mHostName; + uint16_t mPort; + uint16_t mPriority; + uint16_t mWeight; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class TxtCache : public ServiceCache, public LinkedListEntry, public Heap::Allocatable + { + friend class LinkedListEntry; + friend class Heap::Allocatable; + friend class CacheEntry; + friend class BrowseCache; + + public: + bool Matches(const Name &aFullName) const; + bool Matches(const TxtResolver &aResolver) const; + bool Matches(const ServiceName &aServiceName) const; + bool Matches(const ExpireChecker &aExpireChecker) const; + Error Add(const TxtResolver &aResolver); + void Remove(const TxtResolver &aResolver); + void ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset); + + private: + Error Init(Instance &aInstance, const char *aServiceInstance, const char *aServiceType); + Error Init(Instance &aInstance, const ServiceName &aServiceName); + Error Init(Instance &aInstance, const TxtResolver &aResolver); + void PrepareTxtQuestion(TxMessage &aQuery); + void DiscoverCompressOffsets(void); + void ProcessExpiredRecords(TimeMilli aNow); + void ReportResultTo(ResultCallback &aCallback) const; + void ConvertTo(TxtResult &aResult) const; + + TxtCache *mNext; + Heap::Data mTxtData; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class AddrCache : public CacheEntry + { + // Base class for `Ip6AddrCache` and `Ip4AddrCache`, tracking common info + // shared between the two. + + friend class CacheEntry; + + public: + bool Matches(const Name &aFullName) const; + bool Matches(const char *aName) const; + bool Matches(const AddressResolver &aBrowser) const; + bool Matches(const ExpireChecker &aExpireChecker) const; + Error Add(const AddressResolver &aResolver); + void Remove(const AddressResolver &aResolver); + void CommitNewResponseEntries(void); + + protected: + struct AddrEntry : public LinkedListEntry, public Heap::Allocatable + { + explicit AddrEntry(const Ip6::Address &aAddress); + bool Matches(const Ip6::Address &aAddress) const { return (mAddress == aAddress); } + bool Matches(const ExpireChecker &aExpireChecker) const; + bool Matches(EmptyChecker aChecker) const; + uint32_t GetTtl(void) const { return mRecord.GetTtl(); } + + AddrEntry *mNext; + Ip6::Address mAddress; + CacheRecordInfo mRecord; + }; + + // Called by base class `CacheEntry` + void PrepareQueryQuestion(TxMessage &aQuery, uint16_t aRrType); + void UpdateRecordStateAfterQuery(TimeMilli aNow); + void DetermineRecordFireTime(void); + void ProcessExpiredRecords(TimeMilli aNow); + void ReportResultsTo(ResultCallback &aCallback) const; + bool ShouldStartInitialQueries(void) const; + + Error Init(Instance &aInstance, Type aType, const char *aHostName); + Error Init(Instance &aInstance, Type aType, const AddressResolver &aResolver); + void AppendNameTo(TxMessage &aTxMessage, Section aSection); + void ConstructResult(AddressResult &aResult, Heap::Array &aAddrArray) const; + void AddNewResponseAddress(const Ip6::Address &aAddress, uint32_t aTtl, bool aCacheFlush); + + AddrCache *mNext; + Heap::String mName; + OwningList mCommittedEntries; + OwningList mNewEntries; + bool mShouldFlush; + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class Ip6AddrCache : public AddrCache, public LinkedListEntry, public Heap::Allocatable + { + friend class CacheEntry; + friend class LinkedListEntry; + friend class Heap::Allocatable; + + public: + void ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset); + + private: + Error Init(Instance &aInstance, const char *aHostName); + Error Init(Instance &aInstance, const AddressResolver &aResolver); + void PrepareAaaaQuestion(TxMessage &aQuery); + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + class Ip4AddrCache : public AddrCache, public LinkedListEntry, public Heap::Allocatable + { + friend class CacheEntry; + friend class LinkedListEntry; + friend class Heap::Allocatable; + + public: + void ProcessResponseRecord(const Message &aMessage, uint16_t aRecordOffset); + + private: + Error Init(Instance &aInstance, const char *aHostName); + Error Init(Instance &aInstance, const AddressResolver &aResolver); + void PrepareAQuestion(TxMessage &aQuery); + }; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + template OwningList &GetEntryList(void); + template + Error Register(const ItemInfo &aItemInfo, RequestId aRequestId, RegisterCallback aCallback); + template Error Unregister(const ItemInfo &aItemInfo); + + template OwningList &GetCacheList(void); + template + Error Start(const BrowserResolverType &aBrowserOrResolver); + template + Error Stop(const BrowserResolverType &aBrowserOrResolver); + + void InvokeConflictCallback(const char *aName, const char *aServiceType); + void HandleMessage(Message &aMessage, bool aIsUnicast, const AddressInfo &aSenderAddress); + void AddPassiveSrvTxtCache(const char *aServiceInstance, const char *aServcieType); + void AddPassiveIp6AddrCache(const char *aHostName); + TimeMilli RandomizeFirstProbeTxTime(void); + TimeMilli RandomizeInitialQueryTxTime(void); + void RemoveEmptyEntries(void); + void HandleEntryTimer(void); + void HandleEntryTask(void); + void HandleCacheTimer(void); + void HandleCacheTask(void); + + static bool IsKeyForService(const Key &aKey) { return aKey.mServiceType != nullptr; } + static uint32_t DetermineTtl(uint32_t aTtl, uint32_t aDefaultTtl); + static bool NameMatch(const Heap::String &aHeapString, const char *aName); + static bool NameMatch(const Heap::String &aFirst, const Heap::String &aSecond); + static void UpdateCacheFlushFlagIn(ResourceRecord &aResourceRecord, Section aSection); + static void UpdateRecordLengthInMessage(ResourceRecord &aRecord, Message &aMessage, uint16_t aOffset); + static void UpdateCompressOffset(uint16_t &aOffset, uint16_t aNewOffse); + static bool QuestionMatches(uint16_t aQuestionRrType, uint16_t aRrType); + static bool RrClassIsInternetOrAny(uint16_t aRrClass); + + using EntryTimer = TimerMilliIn; + using CacheTimer = TimerMilliIn; + using EntryTask = TaskletIn; + using CacheTask = TaskletIn; + + static const char kLocalDomain[]; // "local." + static const char kUdpServiceLabel[]; // "_udp" + static const char kTcpServiceLabel[]; // "_tcp" + static const char kSubServiceLabel[]; // "_sub" + static const char kServicesDnssdLabels[]; // "_services._dns-sd._udp" + + bool mIsEnabled; + bool mIsQuestionUnicastAllowed; + uint16_t mMaxMessageSize; + uint32_t mInfraIfIndex; + OwningList mHostEntries; + OwningList mServiceEntries; + OwningList mServiceTypes; + MultiPacketRxMessages mMultiPacketRxMessages; + TimeMilli mNextProbeTxTime; + EntryTimer mEntryTimer; + EntryTask mEntryTask; + TxMessageHistory mTxMessageHistory; + ConflictCallback mConflictCallback; + + OwningList mBrowseCacheList; + OwningList mSrvCacheList; + OwningList mTxtCacheList; + OwningList mIp6AddrCacheList; + OwningList mIp4AddrCacheList; + TimeMilli mNextQueryTxTime; + CacheTimer mCacheTimer; + CacheTask mCacheTask; +}; + +// Specializations of `Core::GetEntryList()` for `HostEntry` and `ServcieEntry`: + +template <> inline OwningList &Core::GetEntryList(void) { return mHostEntries; } + +template <> inline OwningList &Core::GetEntryList(void) +{ + return mServiceEntries; +} + +// Specializations of `Core::GetCacheList()`: + +template <> inline OwningList &Core::GetCacheList(void) +{ + return mBrowseCacheList; +} + +template <> inline OwningList &Core::GetCacheList(void) { return mSrvCacheList; } + +template <> inline OwningList &Core::GetCacheList(void) { return mTxtCacheList; } + +template <> inline OwningList &Core::GetCacheList(void) +{ + return mIp6AddrCacheList; +} + +template <> inline OwningList &Core::GetCacheList(void) +{ + return mIp4AddrCacheList; +} + +} // namespace Multicast +} // namespace Dns + +DefineCoreType(otPlatMdnsAddressInfo, Dns::Multicast::Core::AddressInfo); + +} // namespace ot + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +#endif // MULTICAST_DNS_HPP_ diff --git a/src/core/openthread-core-config.h b/src/core/openthread-core-config.h index 009b671f9..7465915e3 100644 --- a/src/core/openthread-core-config.h +++ b/src/core/openthread-core-config.h @@ -98,6 +98,7 @@ #include "config/link_raw.h" #include "config/logging.h" #include "config/mac.h" +#include "config/mdns.h" #include "config/mesh_diag.h" #include "config/mesh_forwarder.h" #include "config/misc.h" diff --git a/src/posix/platform/CMakeLists.txt b/src/posix/platform/CMakeLists.txt index e4795834f..825d106dc 100644 --- a/src/posix/platform/CMakeLists.txt +++ b/src/posix/platform/CMakeLists.txt @@ -135,6 +135,7 @@ add_library(openthread-posix infra_if.cpp logging.cpp mainloop.cpp + mdns_socket.cpp memory.cpp misc.cpp multicast_routing.cpp diff --git a/src/posix/platform/mdns_socket.cpp b/src/posix/platform/mdns_socket.cpp new file mode 100644 index 000000000..2f96ee796 --- /dev/null +++ b/src/posix/platform/mdns_socket.cpp @@ -0,0 +1,710 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "mdns_socket.hpp" + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ip6_utils.hpp" +#include "platform-posix.h" +#include "common/code_utils.hpp" + +extern "C" otError otPlatMdnsSetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex) +{ + return ot::Posix::MdnsSocket::Get().SetListeningEnabled(aInstance, aEnable, aInfraIfIndex); +} + +extern "C" void otPlatMdnsSendMulticast(otInstance *aInstance, otMessage *aMessage, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + return ot::Posix::MdnsSocket::Get().SendMulticast(aMessage, aInfraIfIndex); +} + +extern "C" void otPlatMdnsSendUnicast(otInstance *aInstance, otMessage *aMessage, const otPlatMdnsAddressInfo *aAddress) +{ + OT_UNUSED_VARIABLE(aInstance); + return ot::Posix::MdnsSocket::Get().SendUnicast(aMessage, aAddress); +} + +namespace ot { +namespace Posix { + +using namespace ot::Posix::Ip6Utils; + +const char MdnsSocket::kLogModuleName[] = "MdnsSocket"; + +MdnsSocket &MdnsSocket::Get(void) +{ + static MdnsSocket sInstance; + + return sInstance; +} + +void MdnsSocket::Init(void) +{ + mEnabled = false; + mInfraIfIndex = 0; + mFd6 = -1; + mFd4 = -1; + mPendingIp6Tx = 0; + mPendingIp4Tx = 0; + + // mDNS multicast IPv6 address "ff02::fb" + memset(&mMulticastIp6Address, 0, sizeof(otIp6Address)); + mMulticastIp6Address.mFields.m8[0] = 0xff; + mMulticastIp6Address.mFields.m8[1] = 0x02; + mMulticastIp6Address.mFields.m8[15] = 0xfb; + + // mDNS multicast IPv4 address "224.0.0.251" + memset(&mMulticastIp4Address, 0, sizeof(otIp4Address)); + mMulticastIp4Address.mFields.m8[0] = 224; + mMulticastIp4Address.mFields.m8[3] = 251; + + memset(&mTxQueue, 0, sizeof(mTxQueue)); +} + +void MdnsSocket::SetUp(void) +{ + otMessageQueueInit(&mTxQueue); + Mainloop::Manager::Get().Add(*this); +} + +void MdnsSocket::TearDown(void) +{ + Mainloop::Manager::Get().Remove(*this); + + if (mEnabled) + { + ClearTxQueue(); + mEnabled = false; + } +} + +void MdnsSocket::Deinit(void) +{ + CloseIp4Socket(); + CloseIp6Socket(); +} + +void MdnsSocket::Update(otSysMainloopContext &aContext) +{ + VerifyOrExit(mEnabled); + + FD_SET(mFd6, &aContext.mReadFdSet); + FD_SET(mFd4, &aContext.mReadFdSet); + + if (mPendingIp6Tx > 0) + { + FD_SET(mFd6, &aContext.mWriteFdSet); + } + + if (mPendingIp4Tx > 0) + { + FD_SET(mFd4, &aContext.mWriteFdSet); + } + + if (aContext.mMaxFd < mFd6) + { + aContext.mMaxFd = mFd6; + } + + if (aContext.mMaxFd < mFd4) + { + aContext.mMaxFd = mFd4; + } + +exit: + return; +} + +void MdnsSocket::Process(const otSysMainloopContext &aContext) +{ + VerifyOrExit(mEnabled); + + if (FD_ISSET(mFd6, &aContext.mWriteFdSet)) + { + SendQueuedMessages(kIp6Msg); + } + + if (FD_ISSET(mFd4, &aContext.mWriteFdSet)) + { + SendQueuedMessages(kIp4Msg); + } + + if (FD_ISSET(mFd6, &aContext.mReadFdSet)) + { + ReceiveMessage(kIp6Msg); + } + + if (FD_ISSET(mFd4, &aContext.mReadFdSet)) + { + ReceiveMessage(kIp4Msg); + } + +exit: + return; +} + +otError MdnsSocket::SetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex) +{ + otError error = OT_ERROR_NONE; + + VerifyOrExit(aEnable != mEnabled); + mInstance = aInstance; + + if (aEnable) + { + error = Enable(aInfraIfIndex); + } + else + { + Disable(aInfraIfIndex); + } + +exit: + return error; +} + +otError MdnsSocket::Enable(uint32_t aInfraIfIndex) +{ + otError error; + + SuccessOrExit(error = OpenIp4Socket(aInfraIfIndex)); + SuccessOrExit(error = JoinOrLeaveIp4MulticastGroup(/* aJoin */ true, aInfraIfIndex)); + + SuccessOrExit(error = OpenIp6Socket(aInfraIfIndex)); + SuccessOrExit(error = JoinOrLeaveIp6MulticastGroup(/* aJoin */ true, aInfraIfIndex)); + + mEnabled = true; + mInfraIfIndex = aInfraIfIndex; + + LogInfo("Enabled"); + +exit: + if (error != OT_ERROR_NONE) + { + CloseIp4Socket(); + CloseIp6Socket(); + } + + return error; +} + +void MdnsSocket::Disable(uint32_t aInfraIfIndex) +{ + ClearTxQueue(); + + IgnoreError(JoinOrLeaveIp4MulticastGroup(/* aJoin */ false, aInfraIfIndex)); + IgnoreError(JoinOrLeaveIp6MulticastGroup(/* aJoin */ false, aInfraIfIndex)); + CloseIp4Socket(); + CloseIp6Socket(); + + mEnabled = false; + + LogInfo("Disabled"); +} + +void MdnsSocket::SendMulticast(otMessage *aMessage, uint32_t aInfraIfIndex) +{ + Metadata metadata; + uint16_t length; + + VerifyOrExit(mEnabled); + VerifyOrExit(aInfraIfIndex == mInfraIfIndex); + + length = otMessageGetLength(aMessage); + + if (length > kMaxMessageLength) + { + LogWarn("Multicast msg length %u is longer than max %u", length, kMaxMessageLength); + ExitNow(); + } + + metadata.mIp6Address = mMulticastIp6Address; + metadata.mIp6Port = kMdnsPort; + metadata.mIp4Address = mMulticastIp4Address; + metadata.mIp4Port = kMdnsPort; + + SuccessOrExit(otMessageAppend(aMessage, &metadata, sizeof(Metadata))); + + mPendingIp4Tx++; + mPendingIp6Tx++; + + otMessageQueueEnqueue(&mTxQueue, aMessage); + aMessage = NULL; + +exit: + if (aMessage != NULL) + { + otMessageFree(aMessage); + } +} + +void MdnsSocket::SendUnicast(otMessage *aMessage, const otPlatMdnsAddressInfo *aAddress) +{ + bool isIp4 = false; + Metadata metadata; + uint16_t length; + + VerifyOrExit(mEnabled); + VerifyOrExit(aAddress->mInfraIfIndex == mInfraIfIndex); + + length = otMessageGetLength(aMessage); + + if (length > kMaxMessageLength) + { + LogWarn("Unicast msg length %u is longer than max %u", length, kMaxMessageLength); + ExitNow(); + } + + memset(&metadata, 0, sizeof(Metadata)); + + if (otIp4FromIp4MappedIp6Address(&aAddress->mAddress, &metadata.mIp4Address) == OT_ERROR_NONE) + { + isIp4 = true; + metadata.mIp4Port = aAddress->mPort; + metadata.mIp6Port = 0; + } + else + { + metadata.mIp6Address = aAddress->mAddress; + metadata.mIp4Port = 0; + metadata.mIp6Port = aAddress->mPort; + } + + SuccessOrExit(otMessageAppend(aMessage, &metadata, sizeof(Metadata))); + + if (isIp4) + { + mPendingIp4Tx++; + } + else + { + mPendingIp6Tx++; + } + + otMessageQueueEnqueue(&mTxQueue, aMessage); + aMessage = NULL; + +exit: + if (aMessage != NULL) + { + otMessageFree(aMessage); + } +} + +void MdnsSocket::ClearTxQueue(void) +{ + otMessage *message; + + while ((message = otMessageQueueGetHead(&mTxQueue)) != NULL) + { + otMessageQueueDequeue(&mTxQueue, message); + otMessageFree(message); + } + + mPendingIp4Tx = 0; + mPendingIp6Tx = 0; +} + +void MdnsSocket::SendQueuedMessages(MsgType aMsgType) +{ + switch (aMsgType) + { + case kIp6Msg: + VerifyOrExit(mPendingIp6Tx > 0); + break; + case kIp4Msg: + VerifyOrExit(mPendingIp4Tx > 0); + break; + } + + for (otMessage *message = otMessageQueueGetHead(&mTxQueue); message != NULL; + message = otMessageQueueGetNext(&mTxQueue, message)) + { + bool isTxPending = false; + uint16_t length; + uint16_t offset; + int bytesSent; + Metadata metadata; + uint8_t buffer[kMaxMessageLength]; + struct sockaddr_in6 addr6; + struct sockaddr_in addr; + + length = otMessageGetLength(message); + + offset = length - sizeof(Metadata); + length -= sizeof(Metadata); + + otMessageRead(message, offset, &metadata, sizeof(Metadata)); + + switch (aMsgType) + { + case kIp6Msg: + isTxPending = (metadata.mIp6Port != 0); + break; + case kIp4Msg: + isTxPending = (metadata.mIp4Port != 0); + break; + } + + if (!isTxPending) + { + continue; + } + + otMessageRead(message, 0, buffer, length); + + switch (aMsgType) + { + case kIp6Msg: + memset(&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + addr6.sin6_port = htons(metadata.mIp6Port); + CopyIp6AddressTo(metadata.mIp6Address, &addr6.sin6_addr); + bytesSent = sendto(mFd6, buffer, length, 0, reinterpret_cast(&addr6), sizeof(addr6)); + VerifyOrExit(bytesSent == length); + metadata.mIp6Port = 0; + mPendingIp6Tx--; + break; + + case kIp4Msg: + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(metadata.mIp4Port); + memcpy(&addr.sin_addr.s_addr, &metadata.mIp4Address, sizeof(otIp4Address)); + bytesSent = sendto(mFd4, buffer, length, 0, reinterpret_cast(&addr), sizeof(addr)); + VerifyOrExit(bytesSent == length); + metadata.mIp4Port = 0; + mPendingIp4Tx--; + break; + } + + if (metadata.CanFreeMessage()) + { + otMessageQueueDequeue(&mTxQueue, message); + otMessageFree(message); + } + else + { + otMessageWrite(message, offset, &metadata, sizeof(Metadata)); + } + } + +exit: + return; +} + +void MdnsSocket::ReceiveMessage(MsgType aMsgType) +{ + otMessage *message = nullptr; + uint8_t buffer[kMaxMessageLength]; + otPlatMdnsAddressInfo addrInfo; + uint16_t length = 0; + struct sockaddr_in6 sockaddr6; + struct sockaddr_in sockaddr; + socklen_t len = sizeof(sockaddr6); + ssize_t rval; + + memset(&addrInfo, 0, sizeof(addrInfo)); + + switch (aMsgType) + { + case kIp6Msg: + len = sizeof(sockaddr6); + memset(&sockaddr6, 0, sizeof(sockaddr6)); + rval = recvfrom(mFd6, reinterpret_cast(&buffer), sizeof(buffer), 0, + reinterpret_cast(&sockaddr6), &len); + VerifyOrExit(rval >= 0, LogCrit("recvfrom() for IPv6 socket failed, errno: %s", strerror(errno))); + length = static_cast(rval); + ReadIp6AddressFrom(&sockaddr6.sin6_addr, addrInfo.mAddress); + break; + + case kIp4Msg: + len = sizeof(sockaddr); + memset(&sockaddr, 0, sizeof(sockaddr)); + rval = recvfrom(mFd4, reinterpret_cast(&buffer), sizeof(buffer), 0, + reinterpret_cast(&sockaddr), &len); + VerifyOrExit(rval >= 0, LogCrit("recvfrom() for IPv4 socket failed, errno: %s", strerror(errno))); + length = static_cast(rval); + otIp4ToIp4MappedIp6Address((otIp4Address *)(&sockaddr.sin_addr.s_addr), &addrInfo.mAddress); + break; + } + + VerifyOrExit(length > 0); + + message = otIp6NewMessage(mInstance, nullptr); + VerifyOrExit(message != nullptr); + SuccessOrExit(otMessageAppend(message, buffer, length)); + + addrInfo.mPort = kMdnsPort; + addrInfo.mInfraIfIndex = mInfraIfIndex; + + otPlatMdnsHandleReceive(mInstance, message, /* aInUnicast */ false, &addrInfo); + message = nullptr; + +exit: + if (message != nullptr) + { + otMessageFree(message); + } +} + +//--------------------------------------------------------------------------------------------------------------------- +// Socket helpers + +otError MdnsSocket::OpenIp4Socket(uint32_t aInfraIfIndex) +{ + otError error = OT_ERROR_FAILED; + struct sockaddr_in addr; + int fd; + + fd = socket(AF_INET, SOCK_DGRAM, 0); + VerifyOrExit(fd >= 0, LogCrit("Failed to create IPv4 socket")); + +#ifdef __linux__ + { + char nameBuffer[IF_NAMESIZE]; + const char *ifname; + + ifname = if_indextoname(aInfraIfIndex, nameBuffer); + VerifyOrExit(ifname != NULL, LogCrit("if_indextoname() failed")); + + error = SetSocketOptionValue(fd, SOL_SOCKET, SO_BINDTODEVICE, ifname, strlen(ifname), "SO_BINDTODEVICE"); + SuccessOrExit(error); + } +#else + { + int ifindex = static_cast(aInfraIfIndex); + + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IP, IP_BOUND_IF, ifindex, "IP_BOUND_IF")); + } +#endif + + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IP, IP_MULTICAST_TTL, 255, "IP_MULTICAST_TTL")); + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IP, IP_TTL, 255, "IP_TTL")); + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IP, IP_MULTICAST_LOOP, 1, "IP_MULTICAST_LOOP")); + SuccessOrExit(error = SetReuseAddrPortOptions(fd)); + + { + struct ip_mreqn mreqn; + + memset(&mreqn, 0, sizeof(mreqn)); + mreqn.imr_multiaddr.s_addr = inet_addr("224.0.0.251"); + mreqn.imr_ifindex = aInfraIfIndex; + + SuccessOrExit( + error = SetSocketOptionValue(fd, IPPROTO_IP, IP_MULTICAST_IF, &mreqn, sizeof(mreqn), "IP_MULTICAST_IF")); + } + + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_ANY); + addr.sin_port = htons(kMdnsPort); + + if (bind(fd, reinterpret_cast(&addr), sizeof(addr)) < 0) + { + LogCrit("bind() to mDNS port for IPv4 socket failed, errno: %s", strerror(errno)); + error = OT_ERROR_FAILED; + ExitNow(); + } + + mFd4 = fd; + + LogInfo("Successfully opened IPv4 socket"); + +exit: + return error; +} + +otError MdnsSocket::JoinOrLeaveIp4MulticastGroup(bool aJoin, uint32_t aInfraIfIndex) +{ + struct ip_mreqn mreqn; + + memset(&mreqn, 0, sizeof(mreqn)); + memcpy(&mreqn.imr_multiaddr.s_addr, &mMulticastIp4Address, sizeof(otIp4Address)); + mreqn.imr_ifindex = aInfraIfIndex; + + if (aJoin) + { + // Suggested workaround for netif not dropping + // a previous multicast membership. + setsockopt(mFd4, IPPROTO_IP, IP_DROP_MEMBERSHIP, &mreqn, sizeof(mreqn)); + } + + return SetSocketOption(mFd4, IPPROTO_IP, aJoin ? IP_ADD_MEMBERSHIP : IP_DROP_MEMBERSHIP, mreqn, + "IP_ADD/DROP_MEMBERSHIP"); +} + +void MdnsSocket::CloseIp4Socket(void) +{ + if (mFd4 >= 0) + { + close(mFd4); + mFd4 = -1; + } +} + +otError MdnsSocket::OpenIp6Socket(uint32_t aInfraIfIndex) +{ + otError error = OT_ERROR_FAILED; + struct sockaddr_in6 addr6; + int fd; + int ifindex = static_cast(aInfraIfIndex); + + fd = socket(AF_INET6, SOCK_DGRAM, 0); + VerifyOrExit(fd >= 0, LogCrit("Failed to create IPv4 socket")); + +#ifdef __linux__ + { + char nameBuffer[IF_NAMESIZE]; + const char *ifname; + + ifname = if_indextoname(aInfraIfIndex, nameBuffer); + VerifyOrExit(ifname != NULL, LogCrit("if_indextoname() failed")); + + error = SetSocketOptionValue(fd, SOL_SOCKET, SO_BINDTODEVICE, ifname, strlen(ifname), "SO_BINDTODEVICE"); + SuccessOrExit(error); + } +#else + { + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IPV6, IPV6_BOUND_IF, ifindex, "IPV6_BOUND_IF")); + } +#endif + + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IPV6, IPV6_MULTICAST_HOPS, 255, "IPV6_MULTICAST_HOPS")); + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IPV6, IPV6_UNICAST_HOPS, 255, "IPV6_UNICAST_HOPS")); + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IPV6, IPV6_V6ONLY, 1, "IPV6_V6ONLY")); + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IPV6, IPV6_MULTICAST_IF, ifindex, "IPV6_MULTICAST_IF")); + SuccessOrExit(error = SetSocketOption(fd, IPPROTO_IPV6, IPV6_MULTICAST_LOOP, 1, "IPV6_MULTICAST_LOOP")); + SuccessOrExit(error = SetReuseAddrPortOptions(fd)); + + memset(&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + addr6.sin6_port = htons(kMdnsPort); + + if (bind(fd, reinterpret_cast(&addr6), sizeof(addr6)) < 0) + { + LogCrit("bind() to mDNS port for IPv6 socket failed, errno: %s", strerror(errno)); + error = OT_ERROR_FAILED; + ExitNow(); + } + + mFd6 = fd; + + LogInfo("Successfully opened IPv6 socket"); + +exit: + return error; +} + +#ifndef IPV6_ADD_MEMBERSHIP +#ifdef IPV6_JOIN_GROUP +#define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP +#endif +#endif + +#ifndef IPV6_DROP_MEMBERSHIP +#ifdef IPV6_LEAVE_GROUP +#define IPV6_DROP_MEMBERSHIP IPV6_LEAVE_GROUP +#endif +#endif + +otError MdnsSocket::JoinOrLeaveIp6MulticastGroup(bool aJoin, uint32_t aInfraIfIndex) +{ + struct ipv6_mreq mreq6; + + memset(&mreq6, 0, sizeof(mreq6)); + Ip6Utils::CopyIp6AddressTo(mMulticastIp6Address, &mreq6.ipv6mr_multiaddr); + + mreq6.ipv6mr_interface = static_cast(aInfraIfIndex); + + if (aJoin) + { + // Suggested workaround for netif not dropping + // a previous multicast membership. + setsockopt(mFd6, IPPROTO_IPV6, IPV6_DROP_MEMBERSHIP, &mreq6, sizeof(mreq6)); + } + + return SetSocketOptionValue(mFd6, IPPROTO_IPV6, aJoin ? IPV6_ADD_MEMBERSHIP : IPV6_DROP_MEMBERSHIP, &mreq6, + sizeof(mreq6), "IP6_ADD/DROP_MEMBERSHIP"); +} + +void MdnsSocket::CloseIp6Socket(void) +{ + if (mFd6 >= 0) + { + close(mFd6); + mFd6 = -1; + } +} + +otError MdnsSocket::SetReuseAddrPortOptions(int aFd) +{ + otError error; + + SuccessOrExit(error = SetSocketOption(aFd, SOL_SOCKET, SO_REUSEADDR, 1, "SO_REUSEADDR")); + SuccessOrExit(error = SetSocketOption(aFd, SOL_SOCKET, SO_REUSEPORT, 1, "SO_REUSEPORT")); + +exit: + return error; +} + +otError MdnsSocket::SetSocketOptionValue(int aFd, + int aLevel, + int aOption, + const void *aValue, + uint32_t aValueLength, + const char *aOptionName) +{ + otError error = OT_ERROR_NONE; + + if (setsockopt(aFd, aLevel, aOption, aValue, aValueLength) != 0) + { + error = OT_ERROR_FAILED; + LogCrit("Failed to setsockopt(%s) - errno: %s", aOptionName, strerror(errno)); + } + + return error; +} + +} // namespace Posix +} // namespace ot + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE diff --git a/src/posix/platform/mdns_socket.hpp b/src/posix/platform/mdns_socket.hpp new file mode 100644 index 000000000..af60733be --- /dev/null +++ b/src/posix/platform/mdns_socket.hpp @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +#ifndef OT_POSIX_PLATFORM_MDNS_SOCKET_HPP_ +#define OT_POSIX_PLATFORM_MDNS_SOCKET_HPP_ + +#include "openthread-posix-config.h" + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +#include +#include + +#include "logger.hpp" +#include "mainloop.hpp" + +#include "core/common/non_copyable.hpp" + +namespace ot { +namespace Posix { + +/** + * Implements platform mDNS socket APIs. + * + */ +class MdnsSocket : public Mainloop::Source, public Logger, private NonCopyable +{ +public: + static const char kLogModuleName[]; ///< Module name used for logging. + + /** + * Gets the `MdnsSocket` singleton. + * + * @returns The singleton object. + * + */ + static MdnsSocket &Get(void); + + /** + * Initializes the `MdnsSocket`. + * + * Called before OpenThread instance is created. + * + */ + void Init(void); + + /** + * Sets up the `MdnsSocket`. + * + * Called after OpenThread instance is created. + * + */ + void SetUp(void); + + /** + * Tears down the `MdnsSocket`. + * + * Called before OpenThread instance is destructed. + * + */ + void TearDown(void); + + /** + * Deinitializes the `MdnsSocket`. + * + * Called after OpenThread instance is destructed. + * + */ + void Deinit(void); + + /** + * Updates the fd_set and timeout for mainloop. + * + * @param[in,out] aContext A reference to the mainloop context. + * + */ + void Update(otSysMainloopContext &aContext) override; + + /** + * Performs `MdnsSocket` processing. + * + * @param[in] aContext A reference to the mainloop context. + * + */ + void Process(const otSysMainloopContext &aContext) override; + + // otPlatMdns APIs + otError SetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex); + void SendMulticast(otMessage *aMessage, uint32_t aInfraIfIndex); + void SendUnicast(otMessage *aMessage, const otPlatMdnsAddressInfo *aAddress); + +private: + static constexpr uint16_t kMaxMessageLength = 2000; + static constexpr uint16_t kMdnsPort = 5353; + + enum MsgType : uint8_t + { + kIp6Msg, + kIp4Msg, + }; + + struct Metadata + { + bool CanFreeMessage(void) const { return (mIp6Port == 0) && (mIp4Port == 0); } + + otIp6Address mIp6Address; + uint16_t mIp6Port; + otIp4Address mIp4Address; + uint16_t mIp4Port; + }; + + bool mEnabled; + uint32_t mInfraIfIndex; + int mFd4; + int mFd6; + uint32_t mPendingIp6Tx; + uint32_t mPendingIp4Tx; + otMessageQueue mTxQueue; + otIp6Address mMulticastIp6Address; + otIp4Address mMulticastIp4Address; + otInstance *mInstance; + + otError Enable(uint32_t aInfraIfIndex); + void Disable(uint32_t aInfraIfIndex); + void ClearTxQueue(void); + void SendQueuedMessages(MsgType aMsgType); + void ReceiveMessage(MsgType aMsgType); + + otError OpenIp4Socket(uint32_t aInfraIfIndex); + otError JoinOrLeaveIp4MulticastGroup(bool aJoin, uint32_t aInfraIfIndex); + void CloseIp4Socket(void); + otError OpenIp6Socket(uint32_t aInfraIfIndex); + otError JoinOrLeaveIp6MulticastGroup(bool aJoin, uint32_t aInfraIfIndex); + void CloseIp6Socket(void); + + static otError SetReuseAddrPortOptions(int aFd); + + template + static otError SetSocketOption(int aFd, int aLevel, int aOption, const ValueType &aValue, const char *aOptionName) + { + return SetSocketOptionValue(aFd, aLevel, aOption, &aValue, sizeof(ValueType), aOptionName); + } + + static otError SetSocketOptionValue(int aFd, + int aLevel, + int aOption, + const void *aValue, + uint32_t aValueLength, + const char *aOptionName); +}; + +} // namespace Posix +} // namespace ot + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +#endif // OT_POSIX_PLATFORM_MDNS_SOCKET_HPP_ diff --git a/src/posix/platform/system.cpp b/src/posix/platform/system.cpp index 64b2644c9..c209d63aa 100644 --- a/src/posix/platform/system.cpp +++ b/src/posix/platform/system.cpp @@ -54,6 +54,7 @@ #include "posix/platform/firewall.hpp" #include "posix/platform/infra_if.hpp" #include "posix/platform/mainloop.hpp" +#include "posix/platform/mdns_socket.hpp" #include "posix/platform/radio_url.hpp" #include "posix/platform/udp.hpp" @@ -145,7 +146,10 @@ void platformInit(otPlatformConfig *aPlatformConfig) #if OPENTHREAD_POSIX_CONFIG_INFRA_IF_ENABLE ot::Posix::InfraNetif::Get().Init(); +#endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + ot::Posix::MdnsSocket::Get().Init(); #endif gNetifName[0] = '\0'; @@ -197,6 +201,10 @@ void platformSetUp(otPlatformConfig *aPlatformConfig) ot::Posix::Udp::Get().SetUp(); #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + ot::Posix::MdnsSocket::Get().SetUp(); +#endif + #if OPENTHREAD_POSIX_CONFIG_DAEMON_ENABLE ot::Posix::Daemon::Get().SetUp(); #endif @@ -244,6 +252,10 @@ void platformTearDown(void) ot::Posix::InfraNetif::Get().TearDown(); #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + ot::Posix::MdnsSocket::Get().TearDown(); +#endif + exit: return; } @@ -272,6 +284,10 @@ void platformDeinit(void) ot::Posix::InfraNetif::Get().Deinit(); #endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + ot::Posix::MdnsSocket::Get().Deinit(); +#endif + exit: return; } diff --git a/tests/toranj/openthread-core-toranj-config-posix.h b/tests/toranj/openthread-core-toranj-config-posix.h index 2fced0032..e30e887d9 100644 --- a/tests/toranj/openthread-core-toranj-config-posix.h +++ b/tests/toranj/openthread-core-toranj-config-posix.h @@ -39,12 +39,14 @@ #define OPENTHREAD_CONFIG_PLATFORM_INFO "POSIX-toranj" +#define OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE 1 + #ifdef __linux__ #define OPENTHREAD_CONFIG_BACKBONE_ROUTER_ENABLE 1 #endif #ifndef OPENTHREAD_CONFIG_PLATFORM_UDP_ENABLE -#define OPENTHREAD_CONFIG_PLATFORM_UDP_ENABLE 1 +#define OPENTHREAD_CONFIG_PLATFORM_UDP_ENABLE 0 #endif #define OPENTHREAD_CONFIG_PLATFORM_NETIF_ENABLE 1 diff --git a/tests/toranj/openthread-core-toranj-config-simulation.h b/tests/toranj/openthread-core-toranj-config-simulation.h index f564e0bc9..7db783d47 100644 --- a/tests/toranj/openthread-core-toranj-config-simulation.h +++ b/tests/toranj/openthread-core-toranj-config-simulation.h @@ -42,13 +42,14 @@ #define OPENTHREAD_CONFIG_PLATFORM_INFO "SIMULATION-RCP-toranj" #else #define OPENTHREAD_CONFIG_PLATFORM_INFO "SIMULATION-toranj" - #endif #define OPENTHREAD_CONFIG_COAP_API_ENABLE 1 #define OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE 1 +#define OPENTHREAD_SIMULATION_MDNS_SOCKET_IMPLEMENT_POSIX 1 + #define OPENTHREAD_CONFIG_PLATFORM_USEC_TIMER_ENABLE 0 #define OPENTHREAD_CONFIG_PLATFORM_FLASH_API_ENABLE 1 diff --git a/tests/toranj/openthread-core-toranj-config.h b/tests/toranj/openthread-core-toranj-config.h index c1a968916..75d71ab96 100644 --- a/tests/toranj/openthread-core-toranj-config.h +++ b/tests/toranj/openthread-core-toranj-config.h @@ -178,6 +178,10 @@ #define OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE 1 +#define OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE 1 + +#define OPENTHREAD_CONFIG_MULTICAST_DNS_PUBLIC_API_ENABLE 1 + #define OPENTHREAD_CONFIG_DELAY_AWARE_QUEUE_MANAGEMENT_ENABLE 1 #define OPENTHREAD_CONFIG_CLI_REGISTER_IP6_RECV_CALLBACK 1 diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 92337b30c..3b491bfb0 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -723,6 +723,27 @@ target_link_libraries(ot-test-macros add_test(NAME ot-test-macros COMMAND ot-test-macros) +add_executable(ot-test-mdns + test_mdns.cpp +) + +target_include_directories(ot-test-mdns + PRIVATE + ${COMMON_INCLUDES} +) + +target_compile_options(ot-test-mdns + PRIVATE + ${COMMON_COMPILE_OPTIONS} +) + +target_link_libraries(ot-test-mdns + PRIVATE + ${COMMON_LIBS} +) + +add_test(NAME ot-test-mdns COMMAND ot-test-mdns) + add_executable(ot-test-message test_message.cpp ) diff --git a/tests/unit/test_mdns.cpp b/tests/unit/test_mdns.cpp new file mode 100644 index 000000000..d0dbcfa86 --- /dev/null +++ b/tests/unit/test_mdns.cpp @@ -0,0 +1,6841 @@ +/* + * Copyright (c) 2024, The OpenThread Authors. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include + +#include "test_platform.h" +#include "test_util.hpp" + +#include "common/arg_macros.hpp" +#include "common/array.hpp" +#include "common/as_core_type.hpp" +#include "common/num_utils.hpp" +#include "common/owning_list.hpp" +#include "common/string.hpp" +#include "common/time.hpp" +#include "instance/instance.hpp" +#include "net/dns_dso.hpp" +#include "net/mdns.hpp" + +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +namespace ot { +namespace Dns { +namespace Multicast { + +#define ENABLE_TEST_LOG 1 // Enable to get logs from unit test. + +// Logs a message and adds current time (sNow) as "::." +#if ENABLE_TEST_LOG +#define Log(...) \ + printf("%02u:%02u:%02u.%03u " OT_FIRST_ARG(__VA_ARGS__) "\n", (sNow / 3600000), (sNow / 60000) % 60, \ + (sNow / 1000) % 60, sNow % 1000 OT_REST_ARGS(__VA_ARGS__)) +#else +#define Log(...) +#endif + +//--------------------------------------------------------------------------------------------------------------------- +// Constants + +static constexpr uint16_t kClassQueryUnicastFlag = (1U << 15); +static constexpr uint16_t kClassCacheFlushFlag = (1U << 15); +static constexpr uint16_t kClassMask = 0x7fff; +static constexpr uint16_t kStringSize = 300; +static constexpr uint16_t kMaxDataSize = 400; +static constexpr uint16_t kNumAnnounces = 3; +static constexpr uint16_t kNumInitalQueries = 3; +static constexpr uint16_t kNumRefreshQueries = 4; +static constexpr bool kCacheFlush = true; +static constexpr uint16_t kMdnsPort = 5353; +static constexpr uint32_t kInfraIfIndex = 1; + +static const char kDeviceIp6Address[] = "fd01::1"; + +class DnsMessage; + +//--------------------------------------------------------------------------------------------------------------------- +// Variables + +static Instance *sInstance; + +static uint32_t sNow = 0; +static uint32_t sAlarmTime; +static bool sAlarmOn = false; + +OwningList sDnsMessages; +uint32_t sInfraIfIndex; + +//--------------------------------------------------------------------------------------------------------------------- +// Prototypes + +static const char *RecordTypeToString(uint16_t aType); + +//--------------------------------------------------------------------------------------------------------------------- +// Types + +template class Allocatable +{ +public: + static Type *Allocate(void) + { + void *buf = calloc(1, sizeof(Type)); + + VerifyOrQuit(buf != nullptr); + return new (buf) Type(); + } + + void Free(void) + { + static_cast(this)->~Type(); + free(this); + } +}; + +struct DnsName +{ + Name::Buffer mName; + + void ParseFrom(const Message &aMessage, uint16_t &aOffset) + { + SuccessOrQuit(Name::ReadName(aMessage, aOffset, mName)); + } + + void CopyFrom(const char *aName) + { + if (aName == nullptr) + { + mName[0] = '\0'; + } + else + { + uint16_t len = StringLength(aName, sizeof(mName)); + + VerifyOrQuit(len < sizeof(mName)); + memcpy(mName, aName, len + 1); + } + } + + const char *AsCString(void) const { return mName; } + bool Matches(const char *aName) const { return StringMatch(mName, aName, kStringCaseInsensitiveMatch); } +}; + +typedef String DnsNameString; + +struct AddrAndTtl +{ + bool operator==(const AddrAndTtl &aOther) const { return (mTtl == aOther.mTtl) && (mAddress == aOther.mAddress); } + + Ip6::Address mAddress; + uint32_t mTtl; +}; + +struct DnsQuestion : public Allocatable, public LinkedListEntry +{ + DnsQuestion *mNext; + DnsName mName; + uint16_t mType; + uint16_t mClass; + bool mUnicastResponse; + + void ParseFrom(const Message &aMessage, uint16_t &aOffset) + { + Question question; + + mName.ParseFrom(aMessage, aOffset); + SuccessOrQuit(aMessage.Read(aOffset, question)); + aOffset += sizeof(Question); + + mNext = nullptr; + mType = question.GetType(); + mClass = question.GetClass() & kClassMask; + mUnicastResponse = question.GetClass() & kClassQueryUnicastFlag; + + Log(" %s %s %s class:%u", mName.AsCString(), RecordTypeToString(mType), mUnicastResponse ? "QU" : "QM", + mClass); + } + + bool Matches(const char *aName) const { return mName.Matches(aName); } +}; + +struct DnsQuestions : public OwningList +{ + bool Contains(uint16_t aRrType, const DnsNameString &aFullName, bool aUnicastResponse = false) const + { + bool contains = false; + const DnsQuestion *question = FindMatching(aFullName.AsCString()); + + VerifyOrExit(question != nullptr); + VerifyOrExit(question->mType == aRrType); + VerifyOrExit(question->mClass == ResourceRecord::kClassInternet); + VerifyOrExit(question->mUnicastResponse == aUnicastResponse); + contains = true; + + exit: + return contains; + } + + bool Contains(const DnsNameString &aFullName, bool aUnicastResponse) const + { + return Contains(ResourceRecord::kTypeAny, aFullName, aUnicastResponse); + } +}; + +enum TtlCheckMode : uint8_t +{ + kZeroTtl, + kNonZeroTtl, +}; + +enum Section : uint8_t +{ + kInAnswerSection, + kInAdditionalSection, +}; + +struct Data : public ot::Data +{ + Data(const void *aBuffer, uint16_t aLength) { Init(aBuffer, aLength); } + + bool Matches(const Array &aDataArray) const + { + return (aDataArray.GetLength() == GetLength()) && MatchesBytesIn(aDataArray.GetArrayBuffer()); + } +}; + +struct DnsRecord : public Allocatable, public LinkedListEntry +{ + struct SrvData + { + uint16_t mPriority; + uint16_t mWeight; + uint16_t mPort; + DnsName mHostName; + }; + + union RecordData + { + RecordData(void) { memset(this, 0, sizeof(*this)); } + + Ip6::Address mIp6Address; // For AAAAA (or A) + SrvData mSrv; // For SRV + Array mData; // For TXT or KEY + DnsName mPtrName; // For PTR + NsecRecord::TypeBitMap mNsecBitmap; // For NSEC + }; + + DnsRecord *mNext; + DnsName mName; + uint16_t mType; + uint16_t mClass; + uint32_t mTtl; + bool mCacheFlush; + RecordData mData; + + bool Matches(const char *aName) const { return mName.Matches(aName); } + + void ParseFrom(const Message &aMessage, uint16_t &aOffset) + { + String logStr; + ResourceRecord record; + uint16_t offset; + + mName.ParseFrom(aMessage, aOffset); + SuccessOrQuit(aMessage.Read(aOffset, record)); + aOffset += sizeof(ResourceRecord); + + mNext = nullptr; + mType = record.GetType(); + mClass = record.GetClass() & kClassMask; + mCacheFlush = record.GetClass() & kClassCacheFlushFlag; + mTtl = record.GetTtl(); + + logStr.Append("%s %s%s cls:%u ttl:%u", mName.AsCString(), RecordTypeToString(mType), + mCacheFlush ? " cache-flush" : "", mClass, mTtl); + + offset = aOffset; + + switch (mType) + { + case ResourceRecord::kTypeAaaa: + VerifyOrQuit(record.GetLength() == sizeof(Ip6::Address)); + SuccessOrQuit(aMessage.Read(offset, mData.mIp6Address)); + logStr.Append(" %s", mData.mIp6Address.ToString().AsCString()); + break; + + case ResourceRecord::kTypeKey: + case ResourceRecord::kTypeTxt: + VerifyOrQuit(record.GetLength() > 0); + VerifyOrQuit(record.GetLength() < kMaxDataSize); + mData.mData.SetLength(record.GetLength()); + SuccessOrQuit(aMessage.Read(offset, mData.mData.GetArrayBuffer(), record.GetLength())); + logStr.Append(" data-len:%u", record.GetLength()); + break; + + case ResourceRecord::kTypePtr: + mData.mPtrName.ParseFrom(aMessage, offset); + VerifyOrQuit(offset - aOffset == record.GetLength()); + logStr.Append(" %s", mData.mPtrName.AsCString()); + break; + + case ResourceRecord::kTypeSrv: + { + SrvRecord srv; + + offset -= sizeof(ResourceRecord); + SuccessOrQuit(aMessage.Read(offset, srv)); + offset += sizeof(srv); + mData.mSrv.mHostName.ParseFrom(aMessage, offset); + VerifyOrQuit(offset - aOffset == record.GetLength()); + mData.mSrv.mPriority = srv.GetPriority(); + mData.mSrv.mWeight = srv.GetWeight(); + mData.mSrv.mPort = srv.GetPort(); + logStr.Append(" port:%u w:%u prio:%u host:%s", mData.mSrv.mPort, mData.mSrv.mWeight, mData.mSrv.mPriority, + mData.mSrv.mHostName.AsCString()); + break; + } + + case ResourceRecord::kTypeNsec: + { + NsecRecord::TypeBitMap &bitmap = mData.mNsecBitmap; + + SuccessOrQuit(Name::CompareName(aMessage, offset, mName.AsCString())); + SuccessOrQuit(aMessage.Read(offset, &bitmap, NsecRecord::TypeBitMap::kMinSize)); + VerifyOrQuit(bitmap.GetBlockNumber() == 0); + VerifyOrQuit(bitmap.GetBitmapLength() <= NsecRecord::TypeBitMap::kMaxLength); + SuccessOrQuit(aMessage.Read(offset, &bitmap, bitmap.GetSize())); + + offset += bitmap.GetSize(); + VerifyOrQuit(offset - aOffset == record.GetLength()); + + logStr.Append(" [ "); + + for (uint16_t type = 0; type < bitmap.GetBitmapLength() * kBitsPerByte; type++) + { + if (bitmap.ContainsType(type)) + { + logStr.Append("%s ", RecordTypeToString(type)); + } + } + + logStr.Append("]"); + break; + } + + default: + break; + } + + Log(" %s", logStr.AsCString()); + + aOffset += record.GetLength(); + } + + bool MatchesTtl(TtlCheckMode aTtlCheckMode, uint32_t aTtl) const + { + bool matches = false; + + switch (aTtlCheckMode) + { + case kZeroTtl: + VerifyOrExit(mTtl == 0); + break; + case kNonZeroTtl: + if (aTtl > 0) + { + VerifyOrQuit(mTtl == aTtl); + } + + VerifyOrExit(mTtl > 0); + break; + } + + matches = true; + + exit: + return matches; + } +}; + +struct DnsRecords : public OwningList +{ + bool ContainsAaaa(const DnsNameString &aFullName, + const Ip6::Address &aAddress, + bool aCacheFlush, + TtlCheckMode aTtlCheckMode, + uint32_t aTtl = 0) const + { + bool contains = false; + + for (const DnsRecord &record : *this) + { + if (record.Matches(aFullName.AsCString()) && (record.mType == ResourceRecord::kTypeAaaa) && + (record.mData.mIp6Address == aAddress)) + { + VerifyOrExit(record.mClass == ResourceRecord::kClassInternet); + VerifyOrExit(record.mCacheFlush == aCacheFlush); + VerifyOrExit(record.MatchesTtl(aTtlCheckMode, aTtl)); + contains = true; + ExitNow(); + } + } + + exit: + return contains; + } + + bool ContainsKey(const DnsNameString &aFullName, + const Data &aKeyData, + bool aCacheFlush, + TtlCheckMode aTtlCheckMode, + uint32_t aTtl = 0) const + { + bool contains = false; + + for (const DnsRecord &record : *this) + { + if (record.Matches(aFullName.AsCString()) && (record.mType == ResourceRecord::kTypeKey) && + aKeyData.Matches(record.mData.mData)) + { + VerifyOrExit(record.mClass == ResourceRecord::kClassInternet); + VerifyOrExit(record.mCacheFlush == aCacheFlush); + VerifyOrExit(record.MatchesTtl(aTtlCheckMode, aTtl)); + contains = true; + ExitNow(); + } + } + + exit: + return contains; + } + + bool ContainsSrv(const DnsNameString &aFullName, + const Core::Service &aService, + bool aCacheFlush, + TtlCheckMode aTtlCheckMode, + uint32_t aTtl = 0) const + { + bool contains = false; + DnsNameString hostName; + + hostName.Append("%s.local.", aService.mHostName); + + for (const DnsRecord &record : *this) + { + if (record.Matches(aFullName.AsCString()) && (record.mType == ResourceRecord::kTypeSrv)) + { + VerifyOrExit(record.mClass == ResourceRecord::kClassInternet); + VerifyOrExit(record.mCacheFlush == aCacheFlush); + VerifyOrExit(record.MatchesTtl(aTtlCheckMode, aTtl)); + VerifyOrExit(record.mData.mSrv.mPort == aService.mPort); + VerifyOrExit(record.mData.mSrv.mPriority == aService.mPriority); + VerifyOrExit(record.mData.mSrv.mWeight == aService.mWeight); + VerifyOrExit(record.mData.mSrv.mHostName.Matches(hostName.AsCString())); + contains = true; + ExitNow(); + } + } + + exit: + return contains; + } + + bool ContainsTxt(const DnsNameString &aFullName, + const Core::Service &aService, + bool aCacheFlush, + TtlCheckMode aTtlCheckMode, + uint32_t aTtl = 0) const + { + static const uint8_t kEmptyTxtData[1] = {0}; + + bool contains = false; + Data txtData(aService.mTxtData, aService.mTxtDataLength); + + if ((aService.mTxtData == nullptr) || (aService.mTxtDataLength == 0)) + { + txtData.Init(kEmptyTxtData, sizeof(kEmptyTxtData)); + } + + for (const DnsRecord &record : *this) + { + if (record.Matches(aFullName.AsCString()) && (record.mType == ResourceRecord::kTypeTxt) && + txtData.Matches(record.mData.mData)) + { + VerifyOrExit(record.mClass == ResourceRecord::kClassInternet); + VerifyOrExit(record.mCacheFlush == aCacheFlush); + VerifyOrExit(record.MatchesTtl(aTtlCheckMode, aTtl)); + contains = true; + ExitNow(); + } + } + + exit: + return contains; + } + + bool ContainsPtr(const DnsNameString &aFullName, + const DnsNameString &aPtrName, + TtlCheckMode aTtlCheckMode, + uint32_t aTtl = 0) const + { + bool contains = false; + + for (const DnsRecord &record : *this) + { + if (record.Matches(aFullName.AsCString()) && (record.mType == ResourceRecord::kTypePtr) && + (record.mData.mPtrName.Matches(aPtrName.AsCString()))) + { + VerifyOrExit(record.mClass == ResourceRecord::kClassInternet); + VerifyOrExit(!record.mCacheFlush); // PTR should never use cache-flush + VerifyOrExit(record.MatchesTtl(aTtlCheckMode, aTtl)); + contains = true; + ExitNow(); + } + } + + exit: + return contains; + } + + bool ContainsServicesPtr(const DnsNameString &aServiceType) const + { + DnsNameString allServices; + + allServices.Append("_services._dns-sd._udp.local."); + + return ContainsPtr(allServices, aServiceType, kNonZeroTtl, 0); + } + + bool ContainsNsec(const DnsNameString &aFullName, uint16_t aRecordType) const + { + bool contains = false; + + for (const DnsRecord &record : *this) + { + if (record.Matches(aFullName.AsCString()) && (record.mType == ResourceRecord::kTypeNsec)) + { + VerifyOrQuit(!contains); // Ensure only one NSEC record + VerifyOrExit(record.mData.mNsecBitmap.ContainsType(aRecordType)); + contains = true; + } + } + + exit: + return contains; + } +}; + +// Bit-flags used in `Validate()` with a `Service` +// to specify which records should be checked in the announce +// message. + +typedef uint8_t AnnounceCheckFlags; + +static constexpr uint8_t kCheckSrv = (1 << 0); +static constexpr uint8_t kCheckTxt = (1 << 1); +static constexpr uint8_t kCheckPtr = (1 << 2); +static constexpr uint8_t kCheckServicesPtr = (1 << 3); + +enum GoodBye : bool // Used to indicate "goodbye" records (with zero TTL) +{ + kNotGoodBye = false, + kGoodBye = true, +}; + +enum DnsMessageType : uint8_t +{ + kMulticastQuery, + kMulticastResponse, + kUnicastResponse, +}; + +struct DnsMessage : public Allocatable, public LinkedListEntry +{ + DnsMessage *mNext; + uint32_t mTimestamp; + DnsMessageType mType; + Core::AddressInfo mUnicastDest; + Header mHeader; + DnsQuestions mQuestions; + DnsRecords mAnswerRecords; + DnsRecords mAuthRecords; + DnsRecords mAdditionalRecords; + + DnsMessage(void) + : mNext(nullptr) + , mTimestamp(sNow) + { + } + + const DnsRecords &RecordsFor(Section aSection) const + { + const DnsRecords *records = nullptr; + + switch (aSection) + { + case kInAnswerSection: + records = &mAnswerRecords; + break; + case kInAdditionalSection: + records = &mAdditionalRecords; + break; + } + + VerifyOrQuit(records != nullptr); + + return *records; + } + + void ParseRecords(const Message &aMessage, + uint16_t &aOffset, + uint16_t aNumRecords, + OwningList &aRecords, + const char *aSectionName) + { + if (aNumRecords > 0) + { + Log(" %s", aSectionName); + } + + for (; aNumRecords > 0; aNumRecords--) + { + DnsRecord *record = DnsRecord::Allocate(); + + record->ParseFrom(aMessage, aOffset); + aRecords.PushAfterTail(*record); + } + } + + void ParseFrom(const Message &aMessage) + { + uint16_t offset = 0; + + SuccessOrQuit(aMessage.Read(offset, mHeader)); + offset += sizeof(Header); + + Log(" %s id:%u qt:%u t:%u rcode:%u [q:%u ans:%u auth:%u addn:%u]", + mHeader.GetType() == Header::kTypeQuery ? "Query" : "Response", mHeader.GetMessageId(), + mHeader.GetQueryType(), mHeader.IsTruncationFlagSet(), mHeader.GetResponseCode(), + mHeader.GetQuestionCount(), mHeader.GetAnswerCount(), mHeader.GetAuthorityRecordCount(), + mHeader.GetAdditionalRecordCount()); + + if (mHeader.GetQuestionCount() > 0) + { + Log(" Question"); + } + + for (uint16_t num = mHeader.GetQuestionCount(); num > 0; num--) + { + DnsQuestion *question = DnsQuestion::Allocate(); + + question->ParseFrom(aMessage, offset); + mQuestions.PushAfterTail(*question); + } + + ParseRecords(aMessage, offset, mHeader.GetAnswerCount(), mAnswerRecords, "Answer"); + ParseRecords(aMessage, offset, mHeader.GetAuthorityRecordCount(), mAuthRecords, "Authority"); + ParseRecords(aMessage, offset, mHeader.GetAdditionalRecordCount(), mAdditionalRecords, "Additional"); + } + + void ValidateHeader(DnsMessageType aType, + uint16_t aQuestionCount, + uint16_t aAnswerCount, + uint16_t aAuthCount, + uint16_t aAdditionalCount) const + { + VerifyOrQuit(mType == aType); + VerifyOrQuit(mHeader.GetQuestionCount() == aQuestionCount); + VerifyOrQuit(mHeader.GetAnswerCount() == aAnswerCount); + VerifyOrQuit(mHeader.GetAuthorityRecordCount() == aAuthCount); + VerifyOrQuit(mHeader.GetAdditionalRecordCount() == aAdditionalCount); + + if (aType == kUnicastResponse) + { + Ip6::Address ip6Address; + + SuccessOrQuit(ip6Address.FromString(kDeviceIp6Address)); + + VerifyOrQuit(mUnicastDest.mPort == kMdnsPort); + VerifyOrQuit(mUnicastDest.GetAddress() == ip6Address); + } + } + + static void DetemineFullNameForKey(const Core::Key &aKey, DnsNameString &aFullName) + { + if (aKey.mServiceType != nullptr) + { + aFullName.Append("%s.%s.local.", aKey.mName, aKey.mServiceType); + } + else + { + aFullName.Append("%s.local.", aKey.mName); + } + } + + void ValidateAsProbeFor(const Core::Host &aHost, bool aUnicastResponse) const + { + DnsNameString fullName; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeQuery); + VerifyOrQuit(!mHeader.IsTruncationFlagSet()); + + fullName.Append("%s.local.", aHost.mHostName); + VerifyOrQuit(mQuestions.Contains(fullName, aUnicastResponse)); + + for (uint16_t index = 0; index < aHost.mAddressesLength; index++) + { + VerifyOrQuit(mAuthRecords.ContainsAaaa(fullName, AsCoreType(&aHost.mAddresses[index]), !kCacheFlush, + kNonZeroTtl, aHost.mTtl)); + } + } + + void ValidateAsProbeFor(const Core::Service &aService, bool aUnicastResponse) const + { + DnsNameString serviceName; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeQuery); + VerifyOrQuit(!mHeader.IsTruncationFlagSet()); + + serviceName.Append("%s.%s.local.", aService.mServiceInstance, aService.mServiceType); + + VerifyOrQuit(mQuestions.Contains(serviceName, aUnicastResponse)); + + VerifyOrQuit(mAuthRecords.ContainsSrv(serviceName, aService, !kCacheFlush, kNonZeroTtl, aService.mTtl)); + VerifyOrQuit(mAuthRecords.ContainsTxt(serviceName, aService, !kCacheFlush, kNonZeroTtl, aService.mTtl)); + } + + void ValidateAsProbeFor(const Core::Key &aKey, bool aUnicastResponse) const + { + DnsNameString fullName; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeQuery); + VerifyOrQuit(!mHeader.IsTruncationFlagSet()); + + DetemineFullNameForKey(aKey, fullName); + + VerifyOrQuit(mQuestions.Contains(fullName, aUnicastResponse)); + VerifyOrQuit(mAuthRecords.ContainsKey(fullName, Data(aKey.mKeyData, aKey.mKeyDataLength), !kCacheFlush, + kNonZeroTtl, aKey.mTtl)); + } + + void Validate(const Core::Host &aHost, Section aSection, GoodBye aIsGoodBye = kNotGoodBye) const + { + DnsNameString fullName; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeResponse); + + fullName.Append("%s.local.", aHost.mHostName); + + for (uint16_t index = 0; index < aHost.mAddressesLength; index++) + { + VerifyOrQuit(RecordsFor(aSection).ContainsAaaa(fullName, AsCoreType(&aHost.mAddresses[index]), kCacheFlush, + aIsGoodBye ? kZeroTtl : kNonZeroTtl, aHost.mTtl)); + } + + if (!aIsGoodBye && (aSection == kInAnswerSection)) + { + VerifyOrQuit(mAdditionalRecords.ContainsNsec(fullName, ResourceRecord::kTypeAaaa)); + } + } + + void Validate(const Core::Service &aService, + Section aSection, + AnnounceCheckFlags aCheckFlags, + GoodBye aIsGoodBye = kNotGoodBye) const + { + DnsNameString serviceName; + DnsNameString serviceType; + bool checkNsec = false; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeResponse); + + serviceName.Append("%s.%s.local.", aService.mServiceInstance, aService.mServiceType); + serviceType.Append("%s.local.", aService.mServiceType); + + if (aCheckFlags & kCheckSrv) + { + VerifyOrQuit(RecordsFor(aSection).ContainsSrv(serviceName, aService, kCacheFlush, + aIsGoodBye ? kZeroTtl : kNonZeroTtl, aService.mTtl)); + checkNsec = true; + } + + if (aCheckFlags & kCheckTxt) + { + VerifyOrQuit(RecordsFor(aSection).ContainsTxt(serviceName, aService, kCacheFlush, + aIsGoodBye ? kZeroTtl : kNonZeroTtl, aService.mTtl)); + checkNsec = true; + } + + if (aCheckFlags & kCheckPtr) + { + VerifyOrQuit(RecordsFor(aSection).ContainsPtr(serviceType, serviceName, aIsGoodBye ? kZeroTtl : kNonZeroTtl, + aService.mTtl)); + } + + if (aCheckFlags & kCheckServicesPtr) + { + VerifyOrQuit(RecordsFor(aSection).ContainsServicesPtr(serviceType)); + } + + if (!aIsGoodBye && checkNsec && (aSection == kInAnswerSection)) + { + VerifyOrQuit(mAdditionalRecords.ContainsNsec(serviceName, ResourceRecord::kTypeSrv)); + VerifyOrQuit(mAdditionalRecords.ContainsNsec(serviceName, ResourceRecord::kTypeTxt)); + } + } + + void Validate(const Core::Key &aKey, Section aSection, GoodBye aIsGoodBye = kNotGoodBye) const + { + DnsNameString fullName; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeResponse); + + DetemineFullNameForKey(aKey, fullName); + VerifyOrQuit(RecordsFor(aSection).ContainsKey(fullName, Data(aKey.mKeyData, aKey.mKeyDataLength), kCacheFlush, + aIsGoodBye ? kZeroTtl : kNonZeroTtl, aKey.mTtl)); + + if (!aIsGoodBye && (aSection == kInAnswerSection)) + { + VerifyOrQuit(mAdditionalRecords.ContainsNsec(fullName, ResourceRecord::kTypeKey)); + } + } + + void ValidateSubType(const char *aSubLabel, const Core::Service &aService, GoodBye aIsGoodBye = kNotGoodBye) const + { + DnsNameString serviceName; + DnsNameString subServiceType; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeResponse); + + serviceName.Append("%s.%s.local.", aService.mServiceInstance, aService.mServiceType); + subServiceType.Append("%s._sub.%s.local.", aSubLabel, aService.mServiceType); + + VerifyOrQuit(mAnswerRecords.ContainsPtr(subServiceType, serviceName, aIsGoodBye ? kZeroTtl : kNonZeroTtl, + aService.mTtl)); + } + + void ValidateAsQueryFor(const Core::Browser &aBrowser) const + { + DnsNameString fullServiceType; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeQuery); + VerifyOrQuit(!mHeader.IsTruncationFlagSet()); + + if (aBrowser.mSubTypeLabel == nullptr) + { + fullServiceType.Append("%s.local.", aBrowser.mServiceType); + } + else + { + fullServiceType.Append("%s._sub.%s.local", aBrowser.mSubTypeLabel, aBrowser.mServiceType); + } + + VerifyOrQuit(mQuestions.Contains(ResourceRecord::kTypePtr, fullServiceType)); + } + + void ValidateAsQueryFor(const Core::SrvResolver &aResolver) const + { + DnsNameString fullName; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeQuery); + VerifyOrQuit(!mHeader.IsTruncationFlagSet()); + + fullName.Append("%s.%s.local.", aResolver.mServiceInstance, aResolver.mServiceType); + + VerifyOrQuit(mQuestions.Contains(ResourceRecord::kTypeSrv, fullName)); + } + + void ValidateAsQueryFor(const Core::TxtResolver &aResolver) const + { + DnsNameString fullName; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeQuery); + VerifyOrQuit(!mHeader.IsTruncationFlagSet()); + + fullName.Append("%s.%s.local.", aResolver.mServiceInstance, aResolver.mServiceType); + + VerifyOrQuit(mQuestions.Contains(ResourceRecord::kTypeTxt, fullName)); + } + + void ValidateAsQueryFor(const Core::AddressResolver &aResolver) const + { + DnsNameString fullName; + + VerifyOrQuit(mHeader.GetType() == Header::kTypeQuery); + VerifyOrQuit(!mHeader.IsTruncationFlagSet()); + + fullName.Append("%s.local.", aResolver.mHostName); + + VerifyOrQuit(mQuestions.Contains(ResourceRecord::kTypeAaaa, fullName)); + } +}; + +struct RegCallback +{ + void Reset(void) { mWasCalled = false; } + + bool mWasCalled; + Error mError; +}; + +static constexpr uint16_t kMaxCallbacks = 8; + +static RegCallback sRegCallbacks[kMaxCallbacks]; + +static void HandleCallback(otInstance *aInstance, otMdnsRequestId aRequestId, otError aError) +{ + Log("Register callback - ResuestId:%u Error:%s", aRequestId, otThreadErrorToString(aError)); + + VerifyOrQuit(aInstance == sInstance); + VerifyOrQuit(aRequestId < kMaxCallbacks); + + VerifyOrQuit(!sRegCallbacks[aRequestId].mWasCalled); + + sRegCallbacks[aRequestId].mWasCalled = true; + sRegCallbacks[aRequestId].mError = aError; +} + +static void HandleSuccessCallback(otInstance *aInstance, otMdnsRequestId aRequestId, otError aError) +{ + HandleCallback(aInstance, aRequestId, aError); + SuccessOrQuit(aError); +} + +struct ConflictCallback +{ + void Reset(void) { mWasCalled = false; } + + void Handle(const char *aName, const char *aServiceType) + { + VerifyOrQuit(!mWasCalled); + + mWasCalled = true; + mName.Clear(); + mName.Append("%s", aName); + + mHasServiceType = (aServiceType != nullptr); + VerifyOrExit(mHasServiceType); + mServiceType.Clear(); + mServiceType.Append("%s", aServiceType); + + exit: + return; + } + + bool mWasCalled; + bool mHasServiceType; + DnsNameString mName; + DnsNameString mServiceType; +}; + +static ConflictCallback sConflictCallback; + +static void HandleConflict(otInstance *aInstance, const char *aName, const char *aServiceType) +{ + Log("Conflict callback - %s %s", aName, (aServiceType == nullptr) ? "" : aServiceType); + + VerifyOrQuit(aInstance == sInstance); + sConflictCallback.Handle(aName, aServiceType); +} + +//--------------------------------------------------------------------------------------------------------------------- +// Helper functions and methods + +static const char *RecordTypeToString(uint16_t aType) +{ + const char *str = "Other"; + + switch (aType) + { + case ResourceRecord::kTypeZero: + str = "ZERO"; + break; + case ResourceRecord::kTypeA: + str = "A"; + break; + case ResourceRecord::kTypeSoa: + str = "SOA"; + break; + case ResourceRecord::kTypeCname: + str = "CNAME"; + break; + case ResourceRecord::kTypePtr: + str = "PTR"; + break; + case ResourceRecord::kTypeTxt: + str = "TXT"; + break; + case ResourceRecord::kTypeSig: + str = "SIG"; + break; + case ResourceRecord::kTypeKey: + str = "KEY"; + break; + case ResourceRecord::kTypeAaaa: + str = "AAAA"; + break; + case ResourceRecord::kTypeSrv: + str = "SRV"; + break; + case ResourceRecord::kTypeOpt: + str = "OPT"; + break; + case ResourceRecord::kTypeNsec: + str = "NSEC"; + break; + case ResourceRecord::kTypeAny: + str = "ANY"; + break; + } + + return str; +} + +static void ParseMessage(const Message &aMessage, const Core::AddressInfo *aUnicastDest) +{ + DnsMessage *msg = DnsMessage::Allocate(); + + msg->ParseFrom(aMessage); + + switch (msg->mHeader.GetType()) + { + case Header::kTypeQuery: + msg->mType = kMulticastQuery; + VerifyOrQuit(aUnicastDest == nullptr); + break; + + case Header::kTypeResponse: + if (aUnicastDest == nullptr) + { + msg->mType = kMulticastResponse; + } + else + { + msg->mType = kUnicastResponse; + msg->mUnicastDest = *aUnicastDest; + } + } + + sDnsMessages.PushAfterTail(*msg); +} + +static void SendQuery(const char *aName, + uint16_t aRecordType, + uint16_t aRecordClass = ResourceRecord::kClassInternet, + bool aTruncated = false) +{ + Message *message; + Header header; + Core::AddressInfo senderAddrInfo; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeQuery); + header.SetQuestionCount(1); + + if (aTruncated) + { + header.SetTruncationFlag(); + } + + SuccessOrQuit(message->Append(header)); + SuccessOrQuit(Name::AppendName(aName, *message)); + SuccessOrQuit(message->Append(Question(aRecordType, aRecordClass))); + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("Sending query for %s %s", aName, RecordTypeToString(aRecordType)); + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +static void SendQueryForTwo(const char *aName1, uint16_t aRecordType1, const char *aName2, uint16_t aRecordType2) +{ + // Send query with two questions. + + Message *message; + Header header; + Core::AddressInfo senderAddrInfo; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeQuery); + header.SetQuestionCount(2); + + SuccessOrQuit(message->Append(header)); + SuccessOrQuit(Name::AppendName(aName1, *message)); + SuccessOrQuit(message->Append(Question(aRecordType1, ResourceRecord::kClassInternet))); + SuccessOrQuit(Name::AppendName(aName2, *message)); + SuccessOrQuit(message->Append(Question(aRecordType2, ResourceRecord::kClassInternet))); + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("Sending query for %s %s and %s %s", aName1, RecordTypeToString(aRecordType1), aName2, + RecordTypeToString(aRecordType2)); + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +static void SendPtrResponse(const char *aName, const char *aPtrName, uint32_t aTtl, Section aSection) +{ + Message *message; + Header header; + PtrRecord ptr; + Core::AddressInfo senderAddrInfo; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeResponse); + + switch (aSection) + { + case kInAnswerSection: + header.SetAnswerCount(1); + break; + case kInAdditionalSection: + header.SetAdditionalRecordCount(1); + break; + } + + SuccessOrQuit(message->Append(header)); + SuccessOrQuit(Name::AppendName(aName, *message)); + + ptr.Init(); + ptr.SetTtl(aTtl); + ptr.SetLength(StringLength(aPtrName, Name::kMaxNameSize) + 1); + SuccessOrQuit(message->Append(ptr)); + SuccessOrQuit(Name::AppendName(aPtrName, *message)); + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("Sending PTR response for %s with %s, ttl:%lu", aName, aPtrName, ToUlong(aTtl)); + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +static void SendSrvResponse(const char *aServiceName, + const char *aHostName, + uint16_t aPort, + uint16_t aPriority, + uint16_t aWeight, + uint32_t aTtl, + Section aSection) +{ + Message *message; + Header header; + SrvRecord srv; + Core::AddressInfo senderAddrInfo; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeResponse); + + switch (aSection) + { + case kInAnswerSection: + header.SetAnswerCount(1); + break; + case kInAdditionalSection: + header.SetAdditionalRecordCount(1); + break; + } + + SuccessOrQuit(message->Append(header)); + SuccessOrQuit(Name::AppendName(aServiceName, *message)); + + srv.Init(); + srv.SetTtl(aTtl); + srv.SetPort(aPort); + srv.SetPriority(aPriority); + srv.SetWeight(aWeight); + srv.SetLength(sizeof(srv) - sizeof(ResourceRecord) + StringLength(aHostName, Name::kMaxNameSize) + 1); + SuccessOrQuit(message->Append(srv)); + SuccessOrQuit(Name::AppendName(aHostName, *message)); + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("Sending SRV response for %s, host:%s, port:%u, ttl:%lu", aServiceName, aHostName, aPort, ToUlong(aTtl)); + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +static void SendTxtResponse(const char *aServiceName, + const uint8_t *aTxtData, + uint16_t aTxtDataLength, + uint32_t aTtl, + Section aSection) +{ + Message *message; + Header header; + TxtRecord txt; + Core::AddressInfo senderAddrInfo; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeResponse); + + switch (aSection) + { + case kInAnswerSection: + header.SetAnswerCount(1); + break; + case kInAdditionalSection: + header.SetAdditionalRecordCount(1); + break; + } + + SuccessOrQuit(message->Append(header)); + SuccessOrQuit(Name::AppendName(aServiceName, *message)); + + txt.Init(); + txt.SetTtl(aTtl); + txt.SetLength(aTxtDataLength); + SuccessOrQuit(message->Append(txt)); + SuccessOrQuit(message->AppendBytes(aTxtData, aTxtDataLength)); + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("Sending TXT response for %s, len:%u, ttl:%lu", aServiceName, aTxtDataLength, ToUlong(aTtl)); + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +static void SendHostAddrResponse(const char *aHostName, + AddrAndTtl *aAddrAndTtls, + uint32_t aNumAddrs, + bool aCacheFlush, + Section aSection) +{ + Message *message; + Header header; + AaaaRecord record; + Core::AddressInfo senderAddrInfo; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeResponse); + + switch (aSection) + { + case kInAnswerSection: + header.SetAnswerCount(aNumAddrs); + break; + case kInAdditionalSection: + header.SetAdditionalRecordCount(aNumAddrs); + break; + } + + SuccessOrQuit(message->Append(header)); + + record.Init(); + + if (aCacheFlush) + { + record.SetClass(record.GetClass() | kClassCacheFlushFlag); + } + + Log("Sending AAAA response for %s numAddrs:%u, cach-flush:%u", aHostName, aNumAddrs, aCacheFlush); + + for (uint16_t index = 0; index < aNumAddrs; index++) + { + record.SetTtl(aAddrAndTtls[index].mTtl); + record.SetAddress(aAddrAndTtls[index].mAddress); + + SuccessOrQuit(Name::AppendName(aHostName, *message)); + SuccessOrQuit(message->Append(record)); + + Log(" - %s, ttl:%lu", aAddrAndTtls[index].mAddress.ToString().AsCString(), ToUlong(aAddrAndTtls[index].mTtl)); + } + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +static void SendResponseWithEmptyKey(const char *aName, Section aSection) +{ + Message *message; + Header header; + ResourceRecord record; + Core::AddressInfo senderAddrInfo; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeResponse); + + switch (aSection) + { + case kInAnswerSection: + header.SetAnswerCount(1); + break; + case kInAdditionalSection: + header.SetAdditionalRecordCount(1); + break; + } + + SuccessOrQuit(message->Append(header)); + SuccessOrQuit(Name::AppendName(aName, *message)); + + record.Init(ResourceRecord::kTypeKey); + record.SetTtl(4500); + record.SetLength(0); + SuccessOrQuit(message->Append(record)); + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("Sending response with empty key for %s", aName); + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +struct KnownAnswer +{ + const char *mPtrAnswer; + uint32_t mTtl; +}; + +static void SendPtrQueryWithKnownAnswers(const char *aName, const KnownAnswer *aKnownAnswers, uint16_t aNumAnswers) +{ + Message *message; + Header header; + Core::AddressInfo senderAddrInfo; + uint16_t nameOffset; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeQuery); + header.SetQuestionCount(1); + header.SetAnswerCount(aNumAnswers); + + SuccessOrQuit(message->Append(header)); + nameOffset = message->GetLength(); + SuccessOrQuit(Name::AppendName(aName, *message)); + SuccessOrQuit(message->Append(Question(ResourceRecord::kTypePtr, ResourceRecord::kClassInternet))); + + for (uint16_t index = 0; index < aNumAnswers; index++) + { + PtrRecord ptr; + + ptr.Init(); + ptr.SetTtl(aKnownAnswers[index].mTtl); + ptr.SetLength(StringLength(aKnownAnswers[index].mPtrAnswer, Name::kMaxNameSize) + 1); + + SuccessOrQuit(Name::AppendPointerLabel(nameOffset, *message)); + SuccessOrQuit(message->Append(ptr)); + SuccessOrQuit(Name::AppendName(aKnownAnswers[index].mPtrAnswer, *message)); + } + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("Sending query for %s PTR with %u known-answers", aName, aNumAnswers); + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +static void SendEmtryPtrQueryWithKnownAnswers(const char *aName, const KnownAnswer *aKnownAnswers, uint16_t aNumAnswers) +{ + Message *message; + Header header; + Core::AddressInfo senderAddrInfo; + uint16_t nameOffset = 0; + + message = sInstance->Get().Allocate(Message::kTypeOther); + VerifyOrQuit(message != nullptr); + + header.Clear(); + header.SetType(Header::kTypeQuery); + header.SetAnswerCount(aNumAnswers); + + SuccessOrQuit(message->Append(header)); + + for (uint16_t index = 0; index < aNumAnswers; index++) + { + PtrRecord ptr; + + ptr.Init(); + ptr.SetTtl(aKnownAnswers[index].mTtl); + ptr.SetLength(StringLength(aKnownAnswers[index].mPtrAnswer, Name::kMaxNameSize) + 1); + + if (nameOffset == 0) + { + nameOffset = message->GetLength(); + SuccessOrQuit(Name::AppendName(aName, *message)); + } + else + { + SuccessOrQuit(Name::AppendPointerLabel(nameOffset, *message)); + } + + SuccessOrQuit(message->Append(ptr)); + SuccessOrQuit(Name::AppendName(aKnownAnswers[index].mPtrAnswer, *message)); + } + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("Sending empty query with %u known-answers for %s", aNumAnswers, aName); + + otPlatMdnsHandleReceive(sInstance, message, /* aIsUnicast */ false, &senderAddrInfo); +} + +//---------------------------------------------------------------------------------------------------------------------- +// `otPlatLog` + +extern "C" { + +#if OPENTHREAD_CONFIG_LOG_OUTPUT == OPENTHREAD_CONFIG_LOG_OUTPUT_PLATFORM_DEFINED +void otPlatLog(otLogLevel aLogLevel, otLogRegion aLogRegion, const char *aFormat, ...) +{ + OT_UNUSED_VARIABLE(aLogLevel); + OT_UNUSED_VARIABLE(aLogRegion); + OT_UNUSED_VARIABLE(aFormat); + +#if ENABLE_TEST_LOG + va_list args; + + printf(" "); + va_start(args, aFormat); + vprintf(aFormat, args); + va_end(args); + + printf("\n"); +#endif +} + +#endif + +//---------------------------------------------------------------------------------------------------------------------- +// `otPlatAlarm` + +void otPlatAlarmMilliStop(otInstance *) { sAlarmOn = false; } + +void otPlatAlarmMilliStartAt(otInstance *, uint32_t aT0, uint32_t aDt) +{ + sAlarmOn = true; + sAlarmTime = aT0 + aDt; +} + +uint32_t otPlatAlarmMilliGetNow(void) { return sNow; } + +//---------------------------------------------------------------------------------------------------------------------- +// Heap allocation + +Array sHeapAllocatedPtrs; + +#if OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE + +void *otPlatCAlloc(size_t aNum, size_t aSize) +{ + void *ptr = calloc(aNum, aSize); + + SuccessOrQuit(sHeapAllocatedPtrs.PushBack(ptr)); + + return ptr; +} + +void otPlatFree(void *aPtr) +{ + if (aPtr != nullptr) + { + void **entry = sHeapAllocatedPtrs.Find(aPtr); + + VerifyOrQuit(entry != nullptr, "A heap allocated item is freed twice"); + sHeapAllocatedPtrs.Remove(*entry); + } + + free(aPtr); +} + +#endif + +//---------------------------------------------------------------------------------------------------------------------- +// `otPlatMdns` + +otError otPlatMdnsSetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex) +{ + VerifyOrQuit(aInstance == sInstance); + sInfraIfIndex = aInfraIfIndex; + + Log("otPlatMdnsSetListeningEnabled(%s)", aEnable ? "true" : "false"); + + return kErrorNone; +} + +void otPlatMdnsSendMulticast(otInstance *aInstance, otMessage *aMessage, uint32_t aInfraIfIndex) +{ + Message &message = AsCoreType(aMessage); + Core::AddressInfo senderAddrInfo; + + VerifyOrQuit(aInfraIfIndex == sInfraIfIndex); + + Log("otPlatMdnsSendMulticast(msg-len:%u)", message.GetLength()); + ParseMessage(message, nullptr); + + // Pass the multicast message back. + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + otPlatMdnsHandleReceive(sInstance, aMessage, /* aIsUnicast */ false, &senderAddrInfo); +} + +void otPlatMdnsSendUnicast(otInstance *aInstance, otMessage *aMessage, const otPlatMdnsAddressInfo *aAddress) +{ + Message &message = AsCoreType(aMessage); + const Core::AddressInfo &address = AsCoreType(aAddress); + Ip6::Address deviceAddress; + + Log("otPlatMdnsSendUnicast() - [%s]:%u", address.GetAddress().ToString().AsCString(), address.mPort); + ParseMessage(message, AsCoreTypePtr(aAddress)); + + SuccessOrQuit(deviceAddress.FromString(kDeviceIp6Address)); + + if ((address.GetAddress() == deviceAddress) && (address.mPort == kMdnsPort)) + { + Core::AddressInfo senderAddrInfo; + + SuccessOrQuit(AsCoreType(&senderAddrInfo.mAddress).FromString(kDeviceIp6Address)); + senderAddrInfo.mPort = kMdnsPort; + senderAddrInfo.mInfraIfIndex = 0; + + Log("otPlatMdnsSendUnicast() - unicast msg matches this device address, passing it back"); + otPlatMdnsHandleReceive(sInstance, &message, /* aIsUnicast */ true, &senderAddrInfo); + } + else + { + message.Free(); + } +} + +} // extern "C" + +//--------------------------------------------------------------------------------------------------------------------- + +void ProcessTasklets(void) +{ + while (otTaskletsArePending(sInstance)) + { + otTaskletsProcess(sInstance); + } +} + +void AdvanceTime(uint32_t aDuration) +{ + uint32_t time = sNow + aDuration; + + Log("AdvanceTime for %u.%03u", aDuration / 1000, aDuration % 1000); + + while (TimeMilli(sAlarmTime) <= TimeMilli(time)) + { + ProcessTasklets(); + sNow = sAlarmTime; + otPlatAlarmMilliFired(sInstance); + } + + ProcessTasklets(); + sNow = time; +} + +Core *InitTest(void) +{ + sNow = 0; + sAlarmOn = false; + + sDnsMessages.Clear(); + + for (RegCallback ®Callbck : sRegCallbacks) + { + regCallbck.Reset(); + } + + sConflictCallback.Reset(); + + sInstance = testInitInstance(); + + VerifyOrQuit(sInstance != nullptr); + + return &sInstance->Get(); +} + +//---------------------------------------------------------------------------------------------------------------------- + +static const uint8_t kKey1[] = {0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}; +static const uint8_t kKey2[] = {0x12, 0x34, 0x56}; +static const uint8_t kTxtData1[] = {3, 'a', '=', '1', 0}; +static const uint8_t kTxtData2[] = {1, 'b', 0}; +static const uint8_t kEmptyTxtData[] = {0}; + +//--------------------------------------------------------------------------------------------------------------------- + +void TestHostReg(void) +{ + Core *mdns = InitTest(); + Core::Host host; + Ip6::Address hostAddresses[3]; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString hostFullName; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestHostReg"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + SuccessOrQuit(hostAddresses[0].FromString("fd00::aaaa")); + SuccessOrQuit(hostAddresses[1].FromString("fd00::bbbb")); + SuccessOrQuit(hostAddresses[2].FromString("fd00::cccc")); + + host.mHostName = "myhost"; + host.mAddresses = hostAddresses; + host.mAddressesLength = 3; + host.mTtl = 1500; + + hostFullName.Append("%s.local.", host.mHostName); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `HostEntry`, check probes and announcements"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleSuccessCallback)); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 3, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(host, /* aUnicastRequest */ (probeCount == 0)); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for AAAA record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(hostFullName.AsCString(), ResourceRecord::kTypeAaaa); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for ANY record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(hostFullName.AsCString(), ResourceRecord::kTypeAny); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for non-existing record and validate the response with NSEC"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(hostFullName.AsCString(), ResourceRecord::kTypeA); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 1); + VerifyOrQuit(dnsMsg->mAdditionalRecords.ContainsNsec(hostFullName, ResourceRecord::kTypeAaaa)); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Update number of host addresses and validate new announcements"); + + host.mAddressesLength = 2; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterHost(host, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Change the addresses and validate the first announce"); + + host.mAddressesLength = 3; + + sRegCallbacks[0].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleSuccessCallback)); + + AdvanceTime(300); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + Log("Change the address list again before second announce"); + + host.mAddressesLength = 1; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterHost(host, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Change `HostEntry` TTL and validate announcements"); + + host.mTtl = 120; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterHost(host, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for AAAA record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(hostFullName.AsCString(), ResourceRecord::kTypeAaaa); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the host and validate the goodbye announces"); + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->UnregisterHost(host)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(host, kInAnswerSection, kGoodBye); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestKeyReg(void) +{ + Core *mdns = InitTest(); + Core::Key key; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestKeyReg"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + // Run all tests twice. first with key for a host name, followed + // by key for service instance name. + + for (uint8_t iter = 0; iter < 2; iter++) + { + DnsNameString fullName; + + if (iter == 0) + { + Log("= = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = ="); + Log("Registering key for 'myhost' host name"); + key.mName = "myhost"; + key.mServiceType = nullptr; + + fullName.Append("%s.local.", key.mName); + } + else + { + Log("= = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = ="); + Log("Registering key for 'mysrv._srv._udo' service name"); + + key.mName = "mysrv"; + key.mServiceType = "_srv._udp"; + + fullName.Append("%s.%s.local.", key.mName, key.mServiceType); + } + + key.mKeyData = kKey1; + key.mKeyDataLength = sizeof(kKey1); + key.mTtl = 8000; + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a key record and check probes and announcements"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterKey(key, 0, HandleSuccessCallback)); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 1, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(key, /* aUnicastRequest */ (probeCount == 0)); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(key, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for KEY record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullName.AsCString(), ResourceRecord::kTypeKey); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(key, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for ANY record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullName.AsCString(), ResourceRecord::kTypeAny); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(key, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for non-existing record and validate the response with NSEC"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullName.AsCString(), ResourceRecord::kTypeA); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 1); + VerifyOrQuit(dnsMsg->mAdditionalRecords.ContainsNsec(fullName, ResourceRecord::kTypeKey)); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Change the TTL"); + + key.mTtl = 0; // Use default + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterKey(key, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(key, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Change the key"); + + key.mKeyData = kKey2; + key.mKeyDataLength = sizeof(kKey2); + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterKey(key, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(key, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the key and validate the goodbye announces"); + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->UnregisterKey(key)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(key, kInAnswerSection, kGoodBye); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + } + } + + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestServiceReg(void) +{ + Core *mdns = InitTest(); + Core::Service service; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString fullServiceName; + DnsNameString fullServiceType; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestServiceReg"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + service.mHostName = "myhost"; + service.mServiceInstance = "myservice"; + service.mServiceType = "_srv._udp"; + service.mSubTypeLabels = nullptr; + service.mSubTypeLabelsLength = 0; + service.mTxtData = kTxtData1; + service.mTxtDataLength = sizeof(kTxtData1); + service.mPort = 1234; + service.mPriority = 1; + service.mWeight = 2; + service.mTtl = 1000; + + fullServiceName.Append("%s.%s.local.", service.mServiceInstance, service.mServiceType); + fullServiceType.Append("%s.local.", service.mServiceType); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `ServiceEntry`, check probes and announcements"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterService(service, 0, HandleSuccessCallback)); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ (probeCount == 0)); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 4, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for SRV record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceName.AsCString(), ResourceRecord::kTypeSrv); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for TXT record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceName.AsCString(), ResourceRecord::kTypeTxt); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for ANY record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceName.AsCString(), ResourceRecord::kTypeAny); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for PTR record for service type and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceType.AsCString(), ResourceRecord::kTypePtr); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 2); + dnsMsg->Validate(service, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service, kInAdditionalSection, kCheckSrv | kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for PTR record for `services._dns-sd` and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery("_services._dns-sd._udp.local.", ResourceRecord::kTypePtr); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service, kInAnswerSection, kCheckServicesPtr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Update service port number and validate new announcements of SRV record"); + + service.mPort = 4567; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Update TXT data and validate new announcements of TXT record"); + + service.mTxtData = nullptr; + service.mTxtDataLength = 0; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckTxt); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Update both service and TXT data and validate new announcements of both records"); + + service.mTxtData = kTxtData2; + service.mTxtDataLength = sizeof(kTxtData2); + service.mWeight = 0; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Update service host name and validate new announcements of SRV record"); + + service.mHostName = "newhost"; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Update TTL and validate new announcements of SRV, TXT and PTR records"); + + service.mTtl = 0; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the service and validate the goodbye announces"); + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->UnregisterService(service)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr, kGoodBye); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestUnregisterBeforeProbeFinished(void) +{ + const uint8_t kKey1[] = {0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}; + + Core *mdns = InitTest(); + Core::Host host; + Core::Service service; + Core::Key key; + Ip6::Address hostAddresses[3]; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestUnregisterBeforeProbeFinished"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + SuccessOrQuit(hostAddresses[0].FromString("fd00::aaaa")); + SuccessOrQuit(hostAddresses[1].FromString("fd00::bbbb")); + SuccessOrQuit(hostAddresses[2].FromString("fd00::cccc")); + + host.mHostName = "myhost"; + host.mAddresses = hostAddresses; + host.mAddressesLength = 3; + host.mTtl = 1500; + + service.mHostName = "myhost"; + service.mServiceInstance = "myservice"; + service.mServiceType = "_srv._udp"; + service.mSubTypeLabels = nullptr; + service.mSubTypeLabelsLength = 0; + service.mTxtData = kTxtData1; + service.mTxtDataLength = sizeof(kTxtData1); + service.mPort = 1234; + service.mPriority = 1; + service.mWeight = 2; + service.mTtl = 1000; + + key.mName = "mysrv"; + key.mServiceType = "_srv._udp"; + key.mKeyData = kKey1; + key.mKeyDataLength = sizeof(kKey1); + key.mTtl = 8000; + + // Repeat the same test 3 times for host and service and key registration. + + for (uint8_t iter = 0; iter < 3; iter++) + { + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register an entry, check for the first two probes"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + + switch (iter) + { + case 0: + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleSuccessCallback)); + break; + case 1: + SuccessOrQuit(mdns->RegisterService(service, 0, HandleSuccessCallback)); + break; + case 2: + SuccessOrQuit(mdns->RegisterKey(key, 0, HandleSuccessCallback)); + break; + } + + for (uint8_t probeCount = 0; probeCount < 2; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + switch (iter) + { + case 0: + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 3, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(host, /* aUnicastRequest */ (probeCount == 0)); + break; + case 1: + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ (probeCount == 0)); + break; + case 2: + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 1, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(key, /* aUnicastRequest */ (probeCount == 0)); + break; + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + sDnsMessages.Clear(); + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the entry before the last probe and make sure probing stops"); + + switch (iter) + { + case 0: + SuccessOrQuit(mdns->UnregisterHost(host)); + break; + case 1: + SuccessOrQuit(mdns->UnregisterService(service)); + break; + case 2: + SuccessOrQuit(mdns->UnregisterKey(key)); + break; + } + + AdvanceTime(20 * 1000); + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + } + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestServiceSubTypeReg(void) +{ + static const char *const kSubTypes1[] = {"_s1", "_r2", "_vXy", "_last"}; + static const char *const kSubTypes2[] = {"_vxy", "_r1", "_r2", "_zzz"}; + + Core *mdns = InitTest(); + Core::Service service; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString fullServiceName; + DnsNameString fullServiceType; + DnsNameString fullSubServiceType; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestServiceSubTypeReg"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + service.mHostName = "tarnished"; + service.mServiceInstance = "elden"; + service.mServiceType = "_ring._udp"; + service.mSubTypeLabels = kSubTypes1; + service.mSubTypeLabelsLength = 3; + service.mTxtData = kTxtData1; + service.mTxtDataLength = sizeof(kTxtData1); + service.mPort = 1234; + service.mPriority = 1; + service.mWeight = 2; + service.mTtl = 6000; + + fullServiceName.Append("%s.%s.local.", service.mServiceInstance, service.mServiceType); + fullServiceType.Append("%s.local.", service.mServiceType); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `ServiceEntry` with sub-types, check probes and announcements"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterService(service, 0, HandleSuccessCallback)); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ (probeCount == 0)); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 7, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + + for (uint8_t index = 0; index < service.mSubTypeLabelsLength; index++) + { + dnsMsg->ValidateSubType(service.mSubTypeLabels[index], service); + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for SRV record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceName.AsCString(), ResourceRecord::kTypeSrv); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for TXT record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceName.AsCString(), ResourceRecord::kTypeTxt); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for ANY record and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceName.AsCString(), ResourceRecord::kTypeAny); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for PTR record for service type and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceType.AsCString(), ResourceRecord::kTypePtr); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 2); + dnsMsg->Validate(service, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service, kInAdditionalSection, kCheckSrv | kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for PTR record for `services._dns-sd` and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery("_services._dns-sd._udp.local.", ResourceRecord::kTypePtr); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service, kInAnswerSection, kCheckServicesPtr); + + for (uint8_t index = 0; index < service.mSubTypeLabelsLength; index++) + { + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query for sub-type `%s` and validate the response", service.mSubTypeLabels[index]); + + fullSubServiceType.Clear(); + fullSubServiceType.Append("%s._sub.%s", service.mSubTypeLabels[index], fullServiceType.AsCString()); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullSubServiceType.AsCString(), ResourceRecord::kTypePtr); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateSubType(service.mSubTypeLabels[index], service); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query for non-existing sub-type and validate there is no response"); + + AdvanceTime(2000); + + fullSubServiceType.Clear(); + fullSubServiceType.Append("_none._sub.%s", fullServiceType.AsCString()); + + sDnsMessages.Clear(); + SendQuery(fullSubServiceType.AsCString(), ResourceRecord::kTypePtr); + + AdvanceTime(2000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a new sub-type and validate announcements of PTR record for it"); + + service.mSubTypeLabelsLength = 4; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateSubType(service.mSubTypeLabels[3], service); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Remove a previous sub-type and validate announcements of its removal"); + + service.mSubTypeLabels++; + service.mSubTypeLabelsLength = 3; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateSubType(kSubTypes1[0], service, kGoodBye); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Update TTL and validate announcement of all records"); + + service.mTtl = 0; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 6, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr); + + for (uint8_t index = 0; index < service.mSubTypeLabelsLength; index++) + { + dnsMsg->ValidateSubType(service.mSubTypeLabels[index], service); + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Add and remove sub-types at the same time and check proper announcements"); + + // Registered sub-types: _r2, _vXy, _last + // New sub-types list : _vxy, _r1, _r2, _zzz + // + // Should announce removal of `_last` and addition of + // `_r1` and `_zzz`. The `_vxy` should match with `_vXy`. + + service.mSubTypeLabels = kSubTypes2; + service.mSubTypeLabelsLength = 4; + + sRegCallbacks[1].Reset(); + sDnsMessages.Clear(); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 0); + + dnsMsg->ValidateSubType(kSubTypes1[3], service, kGoodBye); + dnsMsg->ValidateSubType(kSubTypes2[1], service); + dnsMsg->ValidateSubType(kSubTypes2[3], service); + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the service and validate the goodbye announces for service and its sub-types"); + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->UnregisterService(service)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 7, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr, kGoodBye); + + for (uint8_t index = 0; index < service.mSubTypeLabelsLength; index++) + { + dnsMsg->ValidateSubType(service.mSubTypeLabels[index], service, kGoodBye); + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +void TestHostOrServiceAndKeyReg(void) +{ + Core *mdns = InitTest(); + Core::Host host; + Core::Service service; + Core::Key key; + Ip6::Address hostAddresses[2]; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestHostOrServiceAndKeyReg"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + SuccessOrQuit(hostAddresses[0].FromString("fd00::1")); + SuccessOrQuit(hostAddresses[1].FromString("fd00::2")); + + host.mHostName = "myhost"; + host.mAddresses = hostAddresses; + host.mAddressesLength = 2; + host.mTtl = 5000; + + key.mKeyData = kKey1; + key.mKeyDataLength = sizeof(kKey1); + key.mTtl = 80000; + + service.mHostName = "myhost"; + service.mServiceInstance = "myservice"; + service.mServiceType = "_srv._udp"; + service.mSubTypeLabels = nullptr; + service.mSubTypeLabelsLength = 0; + service.mTxtData = kTxtData1; + service.mTxtDataLength = sizeof(kTxtData1); + service.mPort = 1234; + service.mPriority = 1; + service.mWeight = 2; + service.mTtl = 1000; + + // Run all test step twice, first time registering host and key, + // second time registering service and key. + + for (uint8_t iter = 0; iter < 2; iter++) + { + if (iter == 0) + { + key.mName = host.mHostName; + key.mServiceType = nullptr; + } + else + { + key.mName = service.mServiceInstance; + key.mServiceType = service.mServiceType; + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a %s entry, check the first probe is sent", iter == 0 ? "host" : "service"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + + if (iter == 0) + { + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleSuccessCallback)); + } + else + { + SuccessOrQuit(mdns->RegisterService(service, 0, HandleSuccessCallback)); + } + + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + + if (iter == 0) + { + dnsMsg->ValidateAsProbeFor(host, /* aUnicastRequest */ true); + } + else + { + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ true); + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `KeyEntry` for same name, check that probes continue"); + + sRegCallbacks[1].Reset(); + SuccessOrQuit(mdns->RegisterKey(key, 1, HandleSuccessCallback)); + + for (uint8_t probeCount = 1; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + VerifyOrQuit(!sRegCallbacks[1].mWasCalled); + + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 3, /* Addnl */ 0); + + if (iter == 0) + { + dnsMsg->ValidateAsProbeFor(host, /* aUnicastRequest */ false); + } + else + { + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ false); + } + + dnsMsg->ValidateAsProbeFor(key, /* aUnicastRequest */ (probeCount == 0)); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate Announces for both entry and key"); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + if (iter == 0) + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + } + else + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 5, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + } + + dnsMsg->Validate(key, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the entry and validate its goodbye announces"); + + sDnsMessages.Clear(); + + if (iter == 0) + { + SuccessOrQuit(mdns->UnregisterHost(host)); + } + else + { + SuccessOrQuit(mdns->UnregisterService(service)); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + if (iter == 0) + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection, kGoodBye); + } + else + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr, kGoodBye); + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register the entry again, validate its announcements"); + + sDnsMessages.Clear(); + + sRegCallbacks[2].Reset(); + + if (iter == 0) + { + SuccessOrQuit(mdns->RegisterHost(host, 2, HandleSuccessCallback)); + } + else + { + SuccessOrQuit(mdns->RegisterService(service, 2, HandleSuccessCallback)); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[2].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + if (iter == 0) + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + } + else + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 4, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the key and validate its goodbye announcements"); + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->UnregisterKey(key)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + AdvanceTime((anncCount == 0) ? 0 : (1U << (anncCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(key, kInAnswerSection, kGoodBye); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + sDnsMessages.Clear(); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register the key again, validate its announcements"); + + sDnsMessages.Clear(); + + sRegCallbacks[3].Reset(); + SuccessOrQuit(mdns->RegisterKey(key, 3, HandleSuccessCallback)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[3].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(key, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + sDnsMessages.Clear(); + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister key first, validate two of its goodbye announcements"); + + sDnsMessages.Clear(); + + SuccessOrQuit(mdns->UnregisterKey(key)); + + for (uint8_t anncCount = 0; anncCount < 2; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 1 : (1U << (anncCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(key, kInAnswerSection, kGoodBye); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("Unregister entry as well"); + + if (iter == 0) + { + SuccessOrQuit(mdns->UnregisterHost(host)); + } + else + { + SuccessOrQuit(mdns->UnregisterService(service)); + } + + AdvanceTime(15000); + + for (uint16_t anncCount = 0; anncCount < 4; anncCount++) + { + dnsMsg = dnsMsg->GetNext(); + VerifyOrQuit(dnsMsg != nullptr); + + if (anncCount == 2) + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(key, kInAnswerSection, kGoodBye); + } + else if (iter == 0) + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(host, kInAnswerSection, kGoodBye); + } + else + { + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 3, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr, kGoodBye); + } + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + } + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestQuery(void) +{ + static const char *const kSubTypes[] = {"_s", "_r"}; + + Core *mdns = InitTest(); + Core::Host host1; + Core::Host host2; + Core::Service service1; + Core::Service service2; + Core::Service service3; + Core::Key key1; + Core::Key key2; + Ip6::Address host1Addresses[3]; + Ip6::Address host2Addresses[2]; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString host1FullName; + DnsNameString host2FullName; + DnsNameString service1FullName; + DnsNameString service2FullName; + DnsNameString service3FullName; + KnownAnswer knownAnswers[2]; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestQuery"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + SuccessOrQuit(host1Addresses[0].FromString("fd00::1:aaaa")); + SuccessOrQuit(host1Addresses[1].FromString("fd00::1:bbbb")); + SuccessOrQuit(host1Addresses[2].FromString("fd00::1:cccc")); + host1.mHostName = "host1"; + host1.mAddresses = host1Addresses; + host1.mAddressesLength = 3; + host1.mTtl = 1500; + host1FullName.Append("%s.local.", host1.mHostName); + + SuccessOrQuit(host2Addresses[0].FromString("fd00::2:eeee")); + SuccessOrQuit(host2Addresses[1].FromString("fd00::2:ffff")); + host2.mHostName = "host2"; + host2.mAddresses = host2Addresses; + host2.mAddressesLength = 2; + host2.mTtl = 1500; + host2FullName.Append("%s.local.", host2.mHostName); + + service1.mHostName = host1.mHostName; + service1.mServiceInstance = "srv1"; + service1.mServiceType = "_srv._udp"; + service1.mSubTypeLabels = kSubTypes; + service1.mSubTypeLabelsLength = 2; + service1.mTxtData = kTxtData1; + service1.mTxtDataLength = sizeof(kTxtData1); + service1.mPort = 1111; + service1.mPriority = 0; + service1.mWeight = 0; + service1.mTtl = 1500; + service1FullName.Append("%s.%s.local.", service1.mServiceInstance, service1.mServiceType); + + service2.mHostName = host1.mHostName; + service2.mServiceInstance = "srv2"; + service2.mServiceType = "_tst._tcp"; + service2.mSubTypeLabels = nullptr; + service2.mSubTypeLabelsLength = 0; + service2.mTxtData = nullptr; + service2.mTxtDataLength = 0; + service2.mPort = 2222; + service2.mPriority = 2; + service2.mWeight = 2; + service2.mTtl = 1500; + service2FullName.Append("%s.%s.local.", service2.mServiceInstance, service2.mServiceType); + + service3.mHostName = host2.mHostName; + service3.mServiceInstance = "srv3"; + service3.mServiceType = "_srv._udp"; + service3.mSubTypeLabels = kSubTypes; + service3.mSubTypeLabelsLength = 1; + service3.mTxtData = kTxtData2; + service3.mTxtDataLength = sizeof(kTxtData2); + service3.mPort = 3333; + service3.mPriority = 3; + service3.mWeight = 3; + service3.mTtl = 1500; + service3FullName.Append("%s.%s.local.", service3.mServiceInstance, service3.mServiceType); + + key1.mName = host2.mHostName; + key1.mServiceType = nullptr; + key1.mKeyData = kKey1; + key1.mKeyDataLength = sizeof(kKey1); + key1.mTtl = 8000; + + key2.mName = service3.mServiceInstance; + key2.mServiceType = service3.mServiceType; + key2.mKeyData = kKey1; + key2.mKeyDataLength = sizeof(kKey1); + key2.mTtl = 8000; + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register 2 hosts and 3 services and 2 keys"); + + sDnsMessages.Clear(); + + for (RegCallback ®Callbck : sRegCallbacks) + { + regCallbck.Reset(); + } + + SuccessOrQuit(mdns->RegisterHost(host1, 0, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterHost(host2, 1, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterService(service1, 2, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterService(service2, 3, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterService(service3, 4, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterKey(key1, 5, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterKey(key2, 6, HandleSuccessCallback)); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate probes for all entries"); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + for (uint16_t index = 0; index < 7; index++) + { + VerifyOrQuit(!sRegCallbacks[index].mWasCalled); + } + + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 5, /* Ans */ 0, /* Auth */ 13, /* Addnl */ 0); + + dnsMsg->ValidateAsProbeFor(host1, /* aUnicastRequest */ (probeCount == 0)); + dnsMsg->ValidateAsProbeFor(host2, /* aUnicastRequest */ (probeCount == 0)); + dnsMsg->ValidateAsProbeFor(service1, /* aUnicastRequest */ (probeCount == 0)); + dnsMsg->ValidateAsProbeFor(service2, /* aUnicastRequest */ (probeCount == 0)); + dnsMsg->ValidateAsProbeFor(service3, /* aUnicastRequest */ (probeCount == 0)); + dnsMsg->ValidateAsProbeFor(key1, /* aUnicastRequest */ (probeCount == 0)); + dnsMsg->ValidateAsProbeFor(key2, /* aUnicastRequest */ (probeCount == 0)); + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate announcements for all entries"); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + + for (uint16_t index = 0; index < 7; index++) + { + VerifyOrQuit(sRegCallbacks[index].mWasCalled); + } + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 21, /* Auth */ 0, /* Addnl */ 5); + + dnsMsg->Validate(host1, kInAnswerSection); + dnsMsg->Validate(host2, kInAnswerSection); + dnsMsg->Validate(service1, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + dnsMsg->Validate(service2, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + dnsMsg->Validate(service2, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + dnsMsg->Validate(key1, kInAnswerSection); + dnsMsg->Validate(key2, kInAnswerSection); + + for (uint8_t index = 0; index < service1.mSubTypeLabelsLength; index++) + { + dnsMsg->ValidateSubType(service1.mSubTypeLabels[index], service1); + } + + for (uint8_t index = 0; index < service3.mSubTypeLabelsLength; index++) + { + dnsMsg->ValidateSubType(service3.mSubTypeLabels[index], service3); + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + sDnsMessages.Clear(); + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query (browse) for `_srv._udp` and validate two answers and additional data"); + + AdvanceTime(2000); + sDnsMessages.Clear(); + + SendQuery("_srv._udp.local.", ResourceRecord::kTypePtr); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 9); + + dnsMsg->Validate(service1, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service3, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service1, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(service3, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(host1, kInAdditionalSection); + dnsMsg->Validate(host2, kInAdditionalSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Resend the same query but request a unicast response, validate the response"); + + sDnsMessages.Clear(); + SendQuery("_srv._udp.local.", ResourceRecord::kTypePtr, ResourceRecord::kClassInternet | kClassQueryUnicastFlag); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + dnsMsg->ValidateHeader(kUnicastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 9); + + dnsMsg->Validate(service1, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service3, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service1, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(service3, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(host1, kInAdditionalSection); + dnsMsg->Validate(host2, kInAdditionalSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Resend the same multicast query and validate that response is not emitted (rate limit)"); + + sDnsMessages.Clear(); + SendQuery("_srv._udp.local.", ResourceRecord::kTypePtr); + + AdvanceTime(1000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Wait for > 1 second and resend the query and validate that now a response is emitted"); + + SendQuery("_srv._udp.local.", ResourceRecord::kTypePtr); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 9); + + dnsMsg->Validate(service1, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service3, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service1, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(service3, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(host1, kInAdditionalSection); + dnsMsg->Validate(host2, kInAdditionalSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Browse for sub-type `_s._sub._srv._udp` and validate two answers"); + + sDnsMessages.Clear(); + SendQuery("_s._sub._srv._udp.local.", ResourceRecord::kTypePtr); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 0); + + dnsMsg->ValidateSubType("_s", service1); + dnsMsg->ValidateSubType("_s", service3); + + // Send same query again and make sure it is ignored (rate limit). + + sDnsMessages.Clear(); + SendQuery("_s._sub._srv._udp.local.", ResourceRecord::kTypePtr); + + AdvanceTime(1000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate that query with `ANY class` instead of `IN class` is responded"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery("_r._sub._srv._udp.local.", ResourceRecord::kTypePtr, ResourceRecord::kClassAny); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateSubType("_r", service1); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate that query with other `class` is ignored"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery("_r._sub._srv._udp.local.", ResourceRecord::kTypePtr, ResourceRecord::kClassNone); + + AdvanceTime(2000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate that query for non-registered name is ignored"); + + sDnsMessages.Clear(); + SendQuery("_u._sub._srv._udp.local.", ResourceRecord::kTypeAny); + SendQuery("host3.local.", ResourceRecord::kTypeAny); + + AdvanceTime(2000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Query for SRV for `srv1._srv._udp` and validate answer and additional data"); + + sDnsMessages.Clear(); + + SendQuery("srv1._srv._udp.local.", ResourceRecord::kTypeSrv); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 4); + + dnsMsg->Validate(service1, kInAnswerSection, kCheckSrv); + dnsMsg->Validate(host1, kInAdditionalSection); + + //--- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- + // Query with multiple questions + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query with two questions (SRV for service1 and AAAA for host1). Validate response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQueryForTwo("srv1._srv._udp.local.", ResourceRecord::kTypeSrv, "host1.local.", ResourceRecord::kTypeAaaa); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + // Since AAAA record are already present in Answer they should not be appended + // in Additional anymore (for the SRV query). + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 4, /* Auth */ 0, /* Addnl */ 2); + + dnsMsg->Validate(service1, kInAnswerSection, kCheckSrv); + dnsMsg->Validate(host1, kInAnswerSection); + + //--- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- + // Known-answer suppression + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query for `_srv._udp` and include `srv1` as known-answer and validate response"); + + knownAnswers[0].mPtrAnswer = "srv1._srv._udp.local."; + knownAnswers[0].mTtl = 1500; + + AdvanceTime(1000); + + sDnsMessages.Clear(); + SendPtrQueryWithKnownAnswers("_srv._udp.local.", knownAnswers, 1); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + // Response should include `service3` only + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 4); + dnsMsg->Validate(service3, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service3, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(host2, kInAdditionalSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query again with both services as known-answer, validate no response is emitted"); + + knownAnswers[1].mPtrAnswer = "srv3._srv._udp.local."; + knownAnswers[1].mTtl = 1500; + + AdvanceTime(1000); + + sDnsMessages.Clear(); + SendPtrQueryWithKnownAnswers("_srv._udp.local.", knownAnswers, 2); + + AdvanceTime(2000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query for `_srv._udp` and include `srv1` as known-answer and validate response"); + + knownAnswers[0].mPtrAnswer = "srv1._srv._udp.local."; + knownAnswers[0].mTtl = 1500; + + AdvanceTime(1000); + + sDnsMessages.Clear(); + SendPtrQueryWithKnownAnswers("_srv._udp.local.", knownAnswers, 1); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + // Response should include `service3` only + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 4); + dnsMsg->Validate(service3, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service3, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(host2, kInAdditionalSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Change the TTL for known-answer to less than half of record TTL and validate response"); + + knownAnswers[1].mTtl = 1500 / 2 - 1; + + AdvanceTime(1000); + + sDnsMessages.Clear(); + SendPtrQueryWithKnownAnswers("_srv._udp.local.", knownAnswers, 2); + + AdvanceTime(200); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + // Response should include `service3` only since anwer TTL + // is less than half of registered TTL + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 4); + dnsMsg->Validate(service3, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service3, kInAdditionalSection, kCheckSrv | kCheckTxt); + dnsMsg->Validate(host2, kInAdditionalSection); + + //--- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- + // Query during Goodbye announcements + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister `service1` and wait for its two announcements and validate them"); + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->UnregisterService(service1)); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces - 1; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 5, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service1, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr, kGoodBye); + + for (uint8_t index = 0; index < service1.mSubTypeLabelsLength; index++) + { + dnsMsg->ValidateSubType(service1.mSubTypeLabels[index], service1, kGoodBye); + } + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for removed `service1` before its final announcement, validate no response"); + + sDnsMessages.Clear(); + + AdvanceTime(1100); + SendQuery("srv1._srv._udp.local.", ResourceRecord::kTypeSrv); + + AdvanceTime(200); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + // Wait for final announcement and validate it + + AdvanceTime(2000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 5, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service1, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr, kGoodBye); + + for (uint8_t index = 0; index < service1.mSubTypeLabelsLength; index++) + { + dnsMsg->ValidateSubType(service1.mSubTypeLabels[index], service1, kGoodBye); + } + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//---------------------------------------------------------------------------------------------------------------------- + +void TestMultiPacket(void) +{ + static const char *const kSubTypes[] = {"_s1", "_r2", "vxy"}; + + Core *mdns = InitTest(); + Core::Service service; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString fullServiceName; + DnsNameString fullServiceType; + KnownAnswer knownAnswers[2]; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestMultiPacket"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + service.mHostName = "myhost"; + service.mServiceInstance = "mysrv"; + service.mServiceType = "_tst._udp"; + service.mSubTypeLabels = kSubTypes; + service.mSubTypeLabelsLength = 3; + service.mTxtData = kTxtData1; + service.mTxtDataLength = sizeof(kTxtData1); + service.mPort = 2222; + service.mPriority = 3; + service.mWeight = 4; + service.mTtl = 2000; + + fullServiceName.Append("%s.%s.local.", service.mServiceInstance, service.mServiceType); + fullServiceType.Append("%s.local.", service.mServiceType); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `ServiceEntry` with sub-types, check probes and announcements"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterService(service, 0, HandleSuccessCallback)); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ (probeCount == 0)); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 7, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + + for (uint8_t index = 0; index < service.mSubTypeLabelsLength; index++) + { + dnsMsg->ValidateSubType(service.mSubTypeLabels[index], service); + } + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a query for PTR record for service type and validate the response"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceType.AsCString(), ResourceRecord::kTypePtr); + + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 2); + dnsMsg->Validate(service, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service, kInAdditionalSection, kCheckSrv | kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query again but mark it as truncated"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceType.AsCString(), ResourceRecord::kTypePtr, ResourceRecord::kClassInternet, + /* aTruncated */ true); + + Log("Since message is marked as `truncated`, mDNS should wait at least 400 msec"); + + AdvanceTime(400); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + AdvanceTime(2000); + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 2); + dnsMsg->Validate(service, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service, kInAdditionalSection, kCheckSrv | kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query again as truncated followed-up by a non-matching answer"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceType.AsCString(), ResourceRecord::kTypePtr, ResourceRecord::kClassInternet, + /* aTruncated */ true); + AdvanceTime(10); + + knownAnswers[0].mPtrAnswer = "other._tst._udp.local."; + knownAnswers[0].mTtl = 1500; + + SendEmtryPtrQueryWithKnownAnswers(fullServiceType.AsCString(), knownAnswers, 1); + + AdvanceTime(1000); + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 2); + dnsMsg->Validate(service, kInAnswerSection, kCheckPtr); + dnsMsg->Validate(service, kInAdditionalSection, kCheckSrv | kCheckTxt); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a PTR query again as truncated now followed-up by matching known-answer"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery(fullServiceType.AsCString(), ResourceRecord::kTypePtr, ResourceRecord::kClassInternet, + /* aTruncated */ true); + AdvanceTime(10); + + knownAnswers[1].mPtrAnswer = "mysrv._tst._udp.local."; + knownAnswers[1].mTtl = 1500; + + SendEmtryPtrQueryWithKnownAnswers(fullServiceType.AsCString(), knownAnswers, 2); + + Log("We expect no response since the followed-up message contains a matching known-answer"); + AdvanceTime(5000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a truncated query for PTR record for `services._dns-sd`"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery("_services._dns-sd._udp.local.", ResourceRecord::kTypePtr, ResourceRecord::kClassInternet, + /* aTruncated */ true); + + Log("Response should be sent after longer wait time"); + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service, kInAnswerSection, kCheckServicesPtr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a truncated query for PTR record for `services._dns-sd` folloed by known-aswer"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery("_services._dns-sd._udp.local.", ResourceRecord::kTypePtr, ResourceRecord::kClassInternet, + /* aTruncated */ true); + + AdvanceTime(20); + knownAnswers[0].mPtrAnswer = "_other._udp.local."; + knownAnswers[0].mTtl = 4500; + + SendEmtryPtrQueryWithKnownAnswers("_services._dns-sd._udp.local.", knownAnswers, 1); + + Log("Response should be sent again due to answer not matching"); + AdvanceTime(1000); + + dnsMsg = sDnsMessages.GetHead(); + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service, kInAnswerSection, kCheckServicesPtr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send the same truncated query again but follow-up with a matching known-answer message"); + + AdvanceTime(2000); + + sDnsMessages.Clear(); + SendQuery("_services._dns-sd._udp.local.", ResourceRecord::kTypePtr, ResourceRecord::kClassInternet, + /* aTruncated */ true); + + AdvanceTime(20); + knownAnswers[1].mPtrAnswer = "_tst._udp.local."; + knownAnswers[1].mTtl = 4500; + + SendEmtryPtrQueryWithKnownAnswers("_services._dns-sd._udp.local.", knownAnswers, 2); + + Log("We expect no response since the followed-up message contains a matching known-answer"); + AdvanceTime(5000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestQuestionUnicastDisallowed(void) +{ + Core *mdns = InitTest(); + Core::Host host; + Ip6::Address hostAddresses[1]; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString hostFullName; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestQuestionUnicastDisallowed"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + SuccessOrQuit(hostAddresses[0].FromString("fd00::1234")); + + host.mHostName = "myhost"; + host.mAddresses = hostAddresses; + host.mAddressesLength = 1; + host.mTtl = 1500; + + mdns->SetQuestionUnicastAllowed(false); + VerifyOrQuit(!mdns->IsQuestionUnicastAllowed()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `HostEntry`, check probes and announcements"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleSuccessCallback)); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 1, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(host, /* aUnicastRequest */ false); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + sDnsMessages.Clear(); + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestTxMessageSizeLimit(void) +{ + Core *mdns = InitTest(); + Core::Host host; + Core::Service service; + Core::Key hostKey; + Core::Key serviceKey; + Ip6::Address hostAddresses[3]; + uint8_t keyData[300]; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString hostFullName; + DnsNameString serviceFullName; + + memset(keyData, 1, sizeof(keyData)); + + Log("-------------------------------------------------------------------------------------------"); + Log("TestTxMessageSizeLimit"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + SuccessOrQuit(hostAddresses[0].FromString("fd00::1:aaaa")); + SuccessOrQuit(hostAddresses[1].FromString("fd00::1:bbbb")); + SuccessOrQuit(hostAddresses[2].FromString("fd00::1:cccc")); + host.mHostName = "myhost"; + host.mAddresses = hostAddresses; + host.mAddressesLength = 3; + host.mTtl = 1500; + hostFullName.Append("%s.local.", host.mHostName); + + service.mHostName = host.mHostName; + service.mServiceInstance = "mysrv"; + service.mServiceType = "_srv._udp"; + service.mSubTypeLabels = nullptr; + service.mSubTypeLabelsLength = 0; + service.mTxtData = kTxtData1; + service.mTxtDataLength = sizeof(kTxtData1); + service.mPort = 1111; + service.mPriority = 0; + service.mWeight = 0; + service.mTtl = 1500; + serviceFullName.Append("%s.%s.local.", service.mServiceInstance, service.mServiceType); + + hostKey.mName = host.mHostName; + hostKey.mServiceType = nullptr; + hostKey.mKeyData = keyData; + hostKey.mKeyDataLength = 300; + hostKey.mTtl = 8000; + + serviceKey.mName = service.mServiceInstance; + serviceKey.mServiceType = service.mServiceType; + serviceKey.mKeyData = keyData; + serviceKey.mKeyDataLength = 300; + serviceKey.mTtl = 8000; + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Set `MaxMessageSize` to 340 and use large key record data to trigger size limit behavior"); + + mdns->SetMaxMessageSize(340); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register host and service and keys for each"); + + sDnsMessages.Clear(); + + for (RegCallback ®Callbck : sRegCallbacks) + { + regCallbck.Reset(); + } + + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterService(service, 1, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterKey(hostKey, 2, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterKey(serviceKey, 3, HandleSuccessCallback)); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate probes for all entries"); + Log("Probes for host and service should be broken into separate message due to size limit"); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + for (uint16_t index = 0; index < 4; index++) + { + VerifyOrQuit(!sRegCallbacks[index].mWasCalled); + } + + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 4, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(host, /* aUnicastRequest */ (probeCount == 0)); + dnsMsg->ValidateAsProbeFor(hostKey, /* aUnicastRequest */ (probeCount == 0)); + + dnsMsg = dnsMsg->GetNext(); + VerifyOrQuit(dnsMsg != nullptr); + + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 3, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ (probeCount == 0)); + dnsMsg->ValidateAsProbeFor(serviceKey, /* aUnicastRequest */ (probeCount == 0)); + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate announcements for all entries"); + Log("Announces should also be broken into separate message due to size limit"); + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + + for (uint16_t index = 0; index < 4; index++) + { + VerifyOrQuit(sRegCallbacks[index].mWasCalled); + } + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 4, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + dnsMsg->Validate(hostKey, kInAnswerSection); + + dnsMsg = dnsMsg->GetNext(); + VerifyOrQuit(dnsMsg != nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 4, /* Auth */ 0, /* Addnl */ 4); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr); + dnsMsg->Validate(serviceKey, kInAnswerSection); + + dnsMsg = dnsMsg->GetNext(); + VerifyOrQuit(dnsMsg != nullptr); + + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->Validate(service, kInAnswerSection, kCheckServicesPtr); + + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestHostConflict(void) +{ + Core *mdns = InitTest(); + Core::Host host; + Ip6::Address hostAddresses[2]; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString hostFullName; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestHostConflict"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + SuccessOrQuit(hostAddresses[0].FromString("fd00::1")); + SuccessOrQuit(hostAddresses[1].FromString("fd00::2")); + + host.mHostName = "myhost"; + host.mAddresses = hostAddresses; + host.mAddressesLength = 2; + host.mTtl = 1500; + + hostFullName.Append("%s.local.", host.mHostName); + + // Run the test twice, first run send response with record in Answer section, + // section run in Additional Data section. + + sConflictCallback.Reset(); + mdns->SetConflictCallback(HandleConflict); + + for (uint8_t iter = 0; iter < 2; iter++) + { + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `HostEntry`, wait for first probe"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleCallback)); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(host, /* aUnicastRequest */ true); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response claiming the name with record in %s section", (iter == 0) ? "answer" : "additional"); + + SendResponseWithEmptyKey(hostFullName.AsCString(), (iter == 0) ? kInAnswerSection : kInAdditionalSection); + AdvanceTime(1); + + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + VerifyOrQuit(sRegCallbacks[0].mError == kErrorDuplicated); + + VerifyOrQuit(!sConflictCallback.mWasCalled); + + sDnsMessages.Clear(); + + SuccessOrQuit(mdns->UnregisterHost(host)); + + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `HostEntry` and respond to probe to trigger conflict"); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleCallback)); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + + SendResponseWithEmptyKey(hostFullName.AsCString(), kInAnswerSection); + AdvanceTime(1); + + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + VerifyOrQuit(sRegCallbacks[0].mError == kErrorDuplicated); + VerifyOrQuit(!sConflictCallback.mWasCalled); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register the conflicted `HostEntry` again, and make sure no probes are sent"); + + sRegCallbacks[1].Reset(); + sConflictCallback.Reset(); + sDnsMessages.Clear(); + + SuccessOrQuit(mdns->RegisterHost(host, 1, HandleCallback)); + AdvanceTime(5000); + + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + VerifyOrQuit(sRegCallbacks[1].mError == kErrorDuplicated); + VerifyOrQuit(!sConflictCallback.mWasCalled); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the conflicted host and register it again immediately, make sure we see probes"); + + SuccessOrQuit(mdns->UnregisterHost(host)); + + sConflictCallback.Reset(); + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterHost(host, 0, HandleSuccessCallback)); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(host, /* aUnicastRequest */ (probeCount == 0)); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 2, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(host, kInAnswerSection); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + VerifyOrQuit(!sConflictCallback.mWasCalled); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response for host name and validate that conflict is detected and callback is called"); + + SendResponseWithEmptyKey(hostFullName.AsCString(), kInAnswerSection); + AdvanceTime(1); + + VerifyOrQuit(sConflictCallback.mWasCalled); + VerifyOrQuit(StringMatch(sConflictCallback.mName.AsCString(), host.mHostName, kStringCaseInsensitiveMatch)); + VerifyOrQuit(!sConflictCallback.mHasServiceType); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestServiceConflict(void) +{ + Core *mdns = InitTest(); + Core::Service service; + const DnsMessage *dnsMsg; + uint16_t heapAllocations; + DnsNameString fullServiceName; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestServiceConflict"); + + service.mHostName = "myhost"; + service.mServiceInstance = "myservice"; + service.mServiceType = "_srv._udp"; + service.mSubTypeLabels = nullptr; + service.mSubTypeLabelsLength = 0; + service.mTxtData = kTxtData1; + service.mTxtDataLength = sizeof(kTxtData1); + service.mPort = 1234; + service.mPriority = 1; + service.mWeight = 2; + service.mTtl = 1000; + + fullServiceName.Append("%s.%s.local.", service.mServiceInstance, service.mServiceType); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + // Run the test twice, first run send response with record in Answer section, + // section run in Additional Data section. + + sConflictCallback.Reset(); + mdns->SetConflictCallback(HandleConflict); + + for (uint8_t iter = 0; iter < 2; iter++) + { + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `ServiceEntry`, wait for first probe"); + + sDnsMessages.Clear(); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterService(service, 0, HandleCallback)); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ true); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response claiming the name with record in %s section", (iter == 0) ? "answer" : "additional"); + + SendResponseWithEmptyKey(fullServiceName.AsCString(), (iter == 0) ? kInAnswerSection : kInAdditionalSection); + AdvanceTime(1); + + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + VerifyOrQuit(sRegCallbacks[0].mError == kErrorDuplicated); + + VerifyOrQuit(!sConflictCallback.mWasCalled); + + sDnsMessages.Clear(); + + SuccessOrQuit(mdns->UnregisterService(service)); + + AdvanceTime(15000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register a `ServiceEntry` and respond to probe to trigger conflict"); + + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterService(service, 0, HandleCallback)); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + + SendResponseWithEmptyKey(fullServiceName.AsCString(), kInAnswerSection); + AdvanceTime(1); + + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + VerifyOrQuit(sRegCallbacks[0].mError == kErrorDuplicated); + VerifyOrQuit(!sConflictCallback.mWasCalled); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register the conflicted `ServiceEntry` again, and make sure no probes are sent"); + + sRegCallbacks[1].Reset(); + sConflictCallback.Reset(); + sDnsMessages.Clear(); + + SuccessOrQuit(mdns->RegisterService(service, 1, HandleCallback)); + AdvanceTime(5000); + + VerifyOrQuit(sRegCallbacks[1].mWasCalled); + VerifyOrQuit(sRegCallbacks[1].mError == kErrorDuplicated); + VerifyOrQuit(!sConflictCallback.mWasCalled); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister the conflicted host and register it again immediately, make sure we see probes"); + + SuccessOrQuit(mdns->UnregisterService(service)); + + sConflictCallback.Reset(); + sRegCallbacks[0].Reset(); + SuccessOrQuit(mdns->RegisterService(service, 0, HandleSuccessCallback)); + + for (uint8_t probeCount = 0; probeCount < 3; probeCount++) + { + sDnsMessages.Clear(); + + VerifyOrQuit(!sRegCallbacks[0].mWasCalled); + AdvanceTime(250); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 2, /* Addnl */ 0); + dnsMsg->ValidateAsProbeFor(service, /* aUnicastRequest */ (probeCount == 0)); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + for (uint8_t anncCount = 0; anncCount < kNumAnnounces; anncCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((anncCount == 0) ? 250 : (1U << (anncCount - 1)) * 1000); + VerifyOrQuit(sRegCallbacks[0].mWasCalled); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastResponse, /* Q */ 0, /* Ans */ 4, /* Auth */ 0, /* Addnl */ 1); + dnsMsg->Validate(service, kInAnswerSection, kCheckSrv | kCheckTxt | kCheckPtr | kCheckServicesPtr); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + VerifyOrQuit(!sConflictCallback.mWasCalled); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response for service name and validate that conflict is detected and callback is called"); + + SendResponseWithEmptyKey(fullServiceName.AsCString(), kInAnswerSection); + AdvanceTime(1); + + VerifyOrQuit(sConflictCallback.mWasCalled); + VerifyOrQuit( + StringMatch(sConflictCallback.mName.AsCString(), service.mServiceInstance, kStringCaseInsensitiveMatch)); + VerifyOrQuit(sConflictCallback.mHasServiceType); + VerifyOrQuit( + StringMatch(sConflictCallback.mServiceType.AsCString(), service.mServiceType, kStringCaseInsensitiveMatch)); + + sDnsMessages.Clear(); + AdvanceTime(20000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +//===================================================================================================================== +// Browser/Resolver tests + +struct BrowseCallback : public Allocatable, public LinkedListEntry +{ + BrowseCallback *mNext; + DnsName mServiceType; + DnsName mSubTypeLabel; + DnsName mServiceInstance; + uint32_t mTtl; + bool mIsSubType; +}; + +struct SrvCallback : public Allocatable, public LinkedListEntry +{ + SrvCallback *mNext; + DnsName mServiceInstance; + DnsName mServiceType; + DnsName mHostName; + uint16_t mPort; + uint16_t mPriority; + uint16_t mWeight; + uint32_t mTtl; +}; + +struct TxtCallback : public Allocatable, public LinkedListEntry +{ + static constexpr uint16_t kMaxTxtDataLength = 100; + + template bool Matches(const uint8_t (&aData)[kSize]) const + { + return (mTxtDataLength == kSize) && (memcmp(mTxtData, aData, kSize) == 0); + } + + TxtCallback *mNext; + DnsName mServiceInstance; + DnsName mServiceType; + uint8_t mTxtData[kMaxTxtDataLength]; + uint16_t mTxtDataLength; + uint32_t mTtl; +}; + +struct AddrCallback : public Allocatable, public LinkedListEntry +{ + static constexpr uint16_t kMaxNumAddrs = 16; + + bool Contains(const AddrAndTtl &aAddrAndTtl) const + { + bool contains = false; + + for (uint16_t index = 0; index < mNumAddrs; index++) + { + if (mAddrAndTtls[index] == aAddrAndTtl) + { + contains = true; + break; + } + } + + return contains; + } + + bool Matches(const AddrAndTtl *aAddrAndTtls, uint16_t aNumAddrs) const + { + bool matches = true; + + VerifyOrExit(aNumAddrs == mNumAddrs, matches = false); + + for (uint16_t index = 0; index < mNumAddrs; index++) + { + if (!Contains(aAddrAndTtls[index])) + { + ExitNow(matches = false); + } + } + + exit: + return matches; + } + + AddrCallback *mNext; + DnsName mHostName; + AddrAndTtl mAddrAndTtls[kMaxNumAddrs]; + uint16_t mNumAddrs; +}; + +OwningList sBrowseCallbacks; +OwningList sSrvCallbacks; +OwningList sTxtCallbacks; +OwningList sAddrCallbacks; + +void HandleBrowseResult(otInstance *aInstance, const otMdnsBrowseResult *aResult) +{ + BrowseCallback *entry; + + VerifyOrQuit(aInstance == sInstance); + VerifyOrQuit(aResult != nullptr); + VerifyOrQuit(aResult->mServiceType != nullptr); + VerifyOrQuit(aResult->mServiceInstance != nullptr); + VerifyOrQuit(aResult->mInfraIfIndex == kInfraIfIndex); + + Log("Browse callback: %s (subtype:%s) -> %s ttl:%lu", aResult->mServiceType, + aResult->mSubTypeLabel == nullptr ? "(null)" : aResult->mSubTypeLabel, aResult->mServiceInstance, + ToUlong(aResult->mTtl)); + + entry = BrowseCallback::Allocate(); + VerifyOrQuit(entry != nullptr); + + entry->mServiceType.CopyFrom(aResult->mServiceType); + entry->mSubTypeLabel.CopyFrom(aResult->mSubTypeLabel); + entry->mServiceInstance.CopyFrom(aResult->mServiceInstance); + entry->mTtl = aResult->mTtl; + entry->mIsSubType = (aResult->mSubTypeLabel != nullptr); + + sBrowseCallbacks.PushAfterTail(*entry); +} + +void HandleBrowseResultAlternate(otInstance *aInstance, const otMdnsBrowseResult *aResult) +{ + Log("Alternate browse callback is called"); + HandleBrowseResult(aInstance, aResult); +} + +void HandleSrvResult(otInstance *aInstance, const otMdnsSrvResult *aResult) +{ + SrvCallback *entry; + + VerifyOrQuit(aInstance == sInstance); + VerifyOrQuit(aResult != nullptr); + VerifyOrQuit(aResult->mServiceInstance != nullptr); + VerifyOrQuit(aResult->mServiceType != nullptr); + VerifyOrQuit(aResult->mInfraIfIndex == kInfraIfIndex); + + if (aResult->mTtl != 0) + { + VerifyOrQuit(aResult->mHostName != nullptr); + + Log("SRV callback: %s %s, host:%s port:%u, prio:%u, weight:%u, ttl:%lu", aResult->mServiceInstance, + aResult->mServiceType, aResult->mHostName, aResult->mPort, aResult->mPriority, aResult->mWeight, + ToUlong(aResult->mTtl)); + } + else + { + Log("SRV callback: %s %s, ttl:%lu", aResult->mServiceInstance, aResult->mServiceType, ToUlong(aResult->mTtl)); + } + + entry = SrvCallback::Allocate(); + VerifyOrQuit(entry != nullptr); + + entry->mServiceInstance.CopyFrom(aResult->mServiceInstance); + entry->mServiceType.CopyFrom(aResult->mServiceType); + entry->mHostName.CopyFrom(aResult->mHostName); + entry->mPort = aResult->mPort; + entry->mPriority = aResult->mPriority; + entry->mWeight = aResult->mWeight; + entry->mTtl = aResult->mTtl; + + sSrvCallbacks.PushAfterTail(*entry); +} + +void HandleSrvResultAlternate(otInstance *aInstance, const otMdnsSrvResult *aResult) +{ + Log("Alternate SRV callback is called"); + HandleSrvResult(aInstance, aResult); +} + +void HandleTxtResult(otInstance *aInstance, const otMdnsTxtResult *aResult) +{ + TxtCallback *entry; + + VerifyOrQuit(aInstance == sInstance); + VerifyOrQuit(aResult != nullptr); + VerifyOrQuit(aResult->mServiceInstance != nullptr); + VerifyOrQuit(aResult->mServiceType != nullptr); + VerifyOrQuit(aResult->mInfraIfIndex == kInfraIfIndex); + + VerifyOrQuit(aResult->mTxtDataLength <= TxtCallback::kMaxTxtDataLength); + + if (aResult->mTtl != 0) + { + VerifyOrQuit(aResult->mTxtData != nullptr); + + Log("TXT callback: %s %s, len:%u, ttl:%lu", aResult->mServiceInstance, aResult->mServiceType, + aResult->mTxtDataLength, ToUlong(aResult->mTtl)); + } + else + { + Log("TXT callback: %s %s, ttl:%lu", aResult->mServiceInstance, aResult->mServiceType, ToUlong(aResult->mTtl)); + } + + entry = TxtCallback::Allocate(); + VerifyOrQuit(entry != nullptr); + + entry->mServiceInstance.CopyFrom(aResult->mServiceInstance); + entry->mServiceType.CopyFrom(aResult->mServiceType); + entry->mTxtDataLength = aResult->mTxtDataLength; + memcpy(entry->mTxtData, aResult->mTxtData, aResult->mTxtDataLength); + entry->mTtl = aResult->mTtl; + + sTxtCallbacks.PushAfterTail(*entry); +} + +void HandleTxtResultAlternate(otInstance *aInstance, const otMdnsTxtResult *aResult) +{ + Log("Alternate TXT callback is called"); + HandleTxtResult(aInstance, aResult); +} + +void HandleAddrResult(otInstance *aInstance, const otMdnsAddressResult *aResult) +{ + AddrCallback *entry; + + VerifyOrQuit(aInstance == sInstance); + VerifyOrQuit(aResult != nullptr); + VerifyOrQuit(aResult->mHostName != nullptr); + VerifyOrQuit(aResult->mInfraIfIndex == kInfraIfIndex); + + VerifyOrQuit(aResult->mAddressesLength <= AddrCallback::kMaxNumAddrs); + + entry = AddrCallback::Allocate(); + VerifyOrQuit(entry != nullptr); + + entry->mHostName.CopyFrom(aResult->mHostName); + entry->mNumAddrs = aResult->mAddressesLength; + + Log("Addr callback: %s, num:%u", aResult->mHostName, aResult->mAddressesLength); + + for (uint16_t index = 0; index < aResult->mAddressesLength; index++) + { + entry->mAddrAndTtls[index].mAddress = AsCoreType(&aResult->mAddresses[index].mAddress); + entry->mAddrAndTtls[index].mTtl = aResult->mAddresses[index].mTtl; + + Log(" - %s, ttl:%lu", entry->mAddrAndTtls[index].mAddress.ToString().AsCString(), + ToUlong(entry->mAddrAndTtls[index].mTtl)); + } + + sAddrCallbacks.PushAfterTail(*entry); +} + +void HandleAddrResultAlternate(otInstance *aInstance, const otMdnsAddressResult *aResult) +{ + Log("Alternate addr callback is called"); + HandleAddrResult(aInstance, aResult); +} + +//--------------------------------------------------------------------------------------------------------------------- + +void TestBrowser(void) +{ + Core *mdns = InitTest(); + Core::Browser browser; + Core::Browser browser2; + const DnsMessage *dnsMsg; + const BrowseCallback *browseCallback; + uint16_t heapAllocations; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestBrowser"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start a browser. Validate initial queries."); + + ClearAllBytes(browser); + + browser.mServiceType = "_srv._udp"; + browser.mSubTypeLabel = nullptr; + browser.mInfraIfIndex = kInfraIfIndex; + browser.mCallback = HandleBrowseResult; + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->StartBrowser(browser)); + + for (uint8_t queryCount = 0; queryCount < kNumInitalQueries; queryCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((queryCount == 0) ? 125 : (1U << (queryCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(browser); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + sDnsMessages.Clear(); + + AdvanceTime(20000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response. Validate callback result."); + + sBrowseCallbacks.Clear(); + + SendPtrResponse("_srv._udp.local.", "mysrv._srv._udp.local.", 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sBrowseCallbacks.IsEmpty()); + browseCallback = sBrowseCallbacks.GetHead(); + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(!browseCallback->mIsSubType); + VerifyOrQuit(browseCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(browseCallback->mTtl == 120); + VerifyOrQuit(browseCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send another response. Validate callback result."); + + AdvanceTime(10000); + + sBrowseCallbacks.Clear(); + + SendPtrResponse("_srv._udp.local.", "awesome._srv._udp.local.", 500, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sBrowseCallbacks.IsEmpty()); + browseCallback = sBrowseCallbacks.GetHead(); + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(!browseCallback->mIsSubType); + VerifyOrQuit(browseCallback->mServiceInstance.Matches("awesome")); + VerifyOrQuit(browseCallback->mTtl == 500); + VerifyOrQuit(browseCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start another browser for the same service and different callback. Validate results."); + + AdvanceTime(5000); + + browser2.mServiceType = "_srv._udp"; + browser2.mSubTypeLabel = nullptr; + browser2.mInfraIfIndex = kInfraIfIndex; + browser2.mCallback = HandleBrowseResultAlternate; + + sBrowseCallbacks.Clear(); + + SuccessOrQuit(mdns->StartBrowser(browser2)); + + browseCallback = sBrowseCallbacks.GetHead(); + + for (uint8_t iter = 0; iter < 2; iter++) + { + VerifyOrQuit(browseCallback != nullptr); + + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(!browseCallback->mIsSubType); + + if (browseCallback->mServiceInstance.Matches("awesome")) + { + VerifyOrQuit(browseCallback->mTtl == 500); + } + else if (browseCallback->mServiceInstance.Matches("mysrv")) + { + VerifyOrQuit(browseCallback->mTtl == 120); + } + else + { + VerifyOrQuit(false); + } + + browseCallback = browseCallback->GetNext(); + } + + VerifyOrQuit(browseCallback == nullptr); + + AdvanceTime(5000); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start same browser again and check the returned error."); + + sBrowseCallbacks.Clear(); + + VerifyOrQuit(mdns->StartBrowser(browser2) == kErrorAlready); + + AdvanceTime(5000); + + VerifyOrQuit(sBrowseCallbacks.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a goodbye response. Validate result callback for both browsers."); + + SendPtrResponse("_srv._udp.local.", "awesome._srv._udp.local.", 0, kInAnswerSection); + + AdvanceTime(1); + + browseCallback = sBrowseCallbacks.GetHead(); + + for (uint8_t iter = 0; iter < 2; iter++) + { + VerifyOrQuit(browseCallback != nullptr); + + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(!browseCallback->mIsSubType); + VerifyOrQuit(browseCallback->mServiceInstance.Matches("awesome")); + VerifyOrQuit(browseCallback->mTtl == 0); + + browseCallback = browseCallback->GetNext(); + } + + VerifyOrQuit(browseCallback == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response with no changes, validate that no callback is invoked."); + + sBrowseCallbacks.Clear(); + + SendPtrResponse("_srv._udp.local.", "mysrv._srv._udp.local.", 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(sBrowseCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop the second browser."); + + sBrowseCallbacks.Clear(); + + SuccessOrQuit(mdns->StopBrowser(browser2)); + + AdvanceTime(5000); + + VerifyOrQuit(sBrowseCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check query is sent at 80 percentage of TTL and then respond to it."); + + // First query should be sent at 80-82% of TTL of 120 second (96.0-98.4 sec). + // We wait for 100 second. Note that 5 seconds already passed in the + // previous step. + + AdvanceTime(91 * 1000 - 1); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + AdvanceTime(4 * 1000 + 1); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(browser); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + VerifyOrQuit(sBrowseCallbacks.IsEmpty()); + + AdvanceTime(10); + + SendPtrResponse("_srv._udp.local.", "mysrv._srv._udp.local.", 120, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check queries are sent at 80, 85, 90, 95 percentages of TTL."); + + for (uint8_t queryCount = 0; queryCount < kNumRefreshQueries; queryCount++) + { + if (queryCount == 0) + { + // First query is expected in 80-82% of TTL, so + // 80% of 120 = 96.0, 82% of 120 = 98.4 + + AdvanceTime(96 * 1000 - 1); + } + else + { + // Next query should happen within 3%-5% of TTL + // from previous query. We wait 3% of TTL here. + AdvanceTime(3600 - 1); + } + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + // Wait for 2% of TTL of 120 which is 2.4 sec. + + AdvanceTime(2400 + 1); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(browser); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + VerifyOrQuit(sBrowseCallbacks.IsEmpty()); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check TTL timeout and callback result."); + + AdvanceTime(6 * 1000); + + VerifyOrQuit(!sBrowseCallbacks.IsEmpty()); + + browseCallback = sBrowseCallbacks.GetHead(); + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(!browseCallback->mIsSubType); + VerifyOrQuit(browseCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(browseCallback->mTtl == 0); + VerifyOrQuit(browseCallback->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + sBrowseCallbacks.Clear(); + sDnsMessages.Clear(); + + AdvanceTime(200 * 1000); + + VerifyOrQuit(sBrowseCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a new response and make sure result callback is invoked"); + + SendPtrResponse("_srv._udp.local.", "great._srv._udp.local.", 200, kInAdditionalSection); + + AdvanceTime(1); + + browseCallback = sBrowseCallbacks.GetHead(); + + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(!browseCallback->mIsSubType); + VerifyOrQuit(browseCallback->mServiceInstance.Matches("great")); + VerifyOrQuit(browseCallback->mTtl == 200); + VerifyOrQuit(browseCallback->GetNext() == nullptr); + + sBrowseCallbacks.Clear(); + + AdvanceTime(150 * 1000); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + VerifyOrQuit(sBrowseCallbacks.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop the browser. There is no active browser for this service. Ensure no queries are sent"); + + sBrowseCallbacks.Clear(); + + SuccessOrQuit(mdns->StopBrowser(browser)); + + AdvanceTime(100 * 1000); + + VerifyOrQuit(sBrowseCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start browser again. Validate that initial queries are sent again"); + + SuccessOrQuit(mdns->StartBrowser(browser)); + + AdvanceTime(125); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(browser); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response after the first initial query"); + + sDnsMessages.Clear(); + + SendPtrResponse("_srv._udp.local.", "mysrv._srv._udp.local.", 120, kInAnswerSection); + + AdvanceTime(1); + + browseCallback = sBrowseCallbacks.GetHead(); + + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(!browseCallback->mIsSubType); + VerifyOrQuit(browseCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(browseCallback->mTtl == 120); + VerifyOrQuit(browseCallback->GetNext() == nullptr); + + sBrowseCallbacks.Clear(); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Validate initial esquires are still sent and include known-answer"); + + for (uint8_t queryCount = 1; queryCount < kNumInitalQueries; queryCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((1U << (queryCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 1, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(browser); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + sDnsMessages.Clear(); + AdvanceTime(50 * 1000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +void TestSrvResolver(void) +{ + Core *mdns = InitTest(); + Core::SrvResolver resolver; + Core::SrvResolver resolver2; + const DnsMessage *dnsMsg; + const SrvCallback *srvCallback; + uint16_t heapAllocations; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestSrvResolver"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start a SRV resolver. Validate initial queries."); + + ClearAllBytes(resolver); + + resolver.mServiceInstance = "mysrv"; + resolver.mServiceType = "_srv._udp"; + resolver.mInfraIfIndex = kInfraIfIndex; + resolver.mCallback = HandleSrvResult; + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->StartSrvResolver(resolver)); + + for (uint8_t queryCount = 0; queryCount < kNumInitalQueries; queryCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((queryCount == 0) ? 125 : (1U << (queryCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + sDnsMessages.Clear(); + + AdvanceTime(20 * 1000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response. Validate callback result."); + + sSrvCallbacks.Clear(); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost.local.", 1234, 0, 1, 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("myhost")); + VerifyOrQuit(srvCallback->mPort == 1234); + VerifyOrQuit(srvCallback->mPriority == 0); + VerifyOrQuit(srvCallback->mWeight == 1); + VerifyOrQuit(srvCallback->mTtl == 120); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response changing host name. Validate callback result."); + + AdvanceTime(1000); + + sSrvCallbacks.Clear(); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost2.local.", 1234, 0, 1, 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("myhost2")); + VerifyOrQuit(srvCallback->mPort == 1234); + VerifyOrQuit(srvCallback->mPriority == 0); + VerifyOrQuit(srvCallback->mWeight == 1); + VerifyOrQuit(srvCallback->mTtl == 120); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response changing port. Validate callback result."); + + AdvanceTime(1000); + + sSrvCallbacks.Clear(); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost2.local.", 4567, 0, 1, 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("myhost2")); + VerifyOrQuit(srvCallback->mPort == 4567); + VerifyOrQuit(srvCallback->mPriority == 0); + VerifyOrQuit(srvCallback->mWeight == 1); + VerifyOrQuit(srvCallback->mTtl == 120); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response changing TTL. Validate callback result."); + + AdvanceTime(1000); + + sSrvCallbacks.Clear(); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost2.local.", 4567, 0, 1, 0, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("")); + VerifyOrQuit(srvCallback->mPort == 4567); + VerifyOrQuit(srvCallback->mPriority == 0); + VerifyOrQuit(srvCallback->mWeight == 1); + VerifyOrQuit(srvCallback->mTtl == 0); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response changing a bunch of things. Validate callback result."); + + AdvanceTime(1000); + + sSrvCallbacks.Clear(); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost.local.", 1234, 2, 3, 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("myhost")); + VerifyOrQuit(srvCallback->mPort == 1234); + VerifyOrQuit(srvCallback->mPriority == 2); + VerifyOrQuit(srvCallback->mWeight == 3); + VerifyOrQuit(srvCallback->mTtl == 120); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response with no changes. Validate callback is not invoked."); + + AdvanceTime(1000); + + sSrvCallbacks.Clear(); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost.local.", 1234, 2, 3, 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(sSrvCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start another resolver for the same service and different callback. Validate results."); + + ClearAllBytes(resolver2); + + resolver2.mServiceInstance = "mysrv"; + resolver2.mServiceType = "_srv._udp"; + resolver2.mInfraIfIndex = kInfraIfIndex; + resolver2.mCallback = HandleSrvResultAlternate; + + sSrvCallbacks.Clear(); + + SuccessOrQuit(mdns->StartSrvResolver(resolver2)); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("myhost")); + VerifyOrQuit(srvCallback->mPort == 1234); + VerifyOrQuit(srvCallback->mPriority == 2); + VerifyOrQuit(srvCallback->mWeight == 3); + VerifyOrQuit(srvCallback->mTtl == 120); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start same resolver again and check the returned error."); + + sSrvCallbacks.Clear(); + + VerifyOrQuit(mdns->StartSrvResolver(resolver2) == kErrorAlready); + + AdvanceTime(5000); + + VerifyOrQuit(sSrvCallbacks.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check query is sent at 80 percentage of TTL and then respond to it."); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost.local.", 1234, 2, 3, 120, kInAnswerSection); + + // First query should be sent at 80-82% of TTL of 120 second (96.0-98.4 sec). + // We wait for 100 second. Note that 5 seconds already passed in the + // previous step. + + AdvanceTime(96 * 1000 - 1); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + AdvanceTime(4 * 1000 + 1); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + VerifyOrQuit(sSrvCallbacks.IsEmpty()); + + AdvanceTime(10); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost.local.", 1234, 2, 3, 120, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check queries are sent at 80, 85, 90, 95 percentages of TTL."); + + for (uint8_t queryCount = 0; queryCount < kNumRefreshQueries; queryCount++) + { + if (queryCount == 0) + { + // First query is expected in 80-82% of TTL, so + // 80% of 120 = 96.0, 82% of 120 = 98.4 + + AdvanceTime(96 * 1000 - 1); + } + else + { + // Next query should happen within 3%-5% of TTL + // from previous query. We wait 3% of TTL here. + AdvanceTime(3600 - 1); + } + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + // Wait for 2% of TTL of 120 which is 2.4 sec. + + AdvanceTime(2400 + 1); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + VerifyOrQuit(sSrvCallbacks.IsEmpty()); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check TTL timeout and callback result."); + + AdvanceTime(6 * 1000); + + srvCallback = sSrvCallbacks.GetHead(); + + for (uint8_t iter = 0; iter < 2; iter++) + { + VerifyOrQuit(srvCallback != nullptr); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mTtl == 0); + srvCallback = srvCallback->GetNext(); + } + + VerifyOrQuit(srvCallback == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + sSrvCallbacks.Clear(); + sDnsMessages.Clear(); + + AdvanceTime(200 * 1000); + + VerifyOrQuit(sSrvCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop the second resolver"); + + sSrvCallbacks.Clear(); + + SuccessOrQuit(mdns->StopSrvResolver(resolver2)); + + AdvanceTime(100 * 1000); + + VerifyOrQuit(sSrvCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a new response and make sure result callback is invoked"); + + SendSrvResponse("mysrv._srv._udp.local.", "myhost.local.", 1234, 2, 3, 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("myhost")); + VerifyOrQuit(srvCallback->mPort == 1234); + VerifyOrQuit(srvCallback->mPriority == 2); + VerifyOrQuit(srvCallback->mWeight == 3); + VerifyOrQuit(srvCallback->mTtl == 120); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop the resolver. There is no active resolver. Ensure no queries are sent"); + + sSrvCallbacks.Clear(); + + SuccessOrQuit(mdns->StopSrvResolver(resolver)); + + AdvanceTime(20 * 1000); + + VerifyOrQuit(sSrvCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Restart the resolver with more than half of TTL remaining."); + Log("Ensure cached entry is reported in the result callback and no queries are sent."); + + SuccessOrQuit(mdns->StartSrvResolver(resolver)); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("myhost")); + VerifyOrQuit(srvCallback->mPort == 1234); + VerifyOrQuit(srvCallback->mPriority == 2); + VerifyOrQuit(srvCallback->mWeight == 3); + VerifyOrQuit(srvCallback->mTtl == 120); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + AdvanceTime(20 * 1000); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop and start the resolver again after less than half TTL remaining."); + Log("Ensure cached entry is still reported in the result callback but queries should be sent"); + + sSrvCallbacks.Clear(); + + SuccessOrQuit(mdns->StopSrvResolver(resolver)); + + AdvanceTime(25 * 1000); + + SuccessOrQuit(mdns->StartSrvResolver(resolver)); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("myhost")); + VerifyOrQuit(srvCallback->mPort == 1234); + VerifyOrQuit(srvCallback->mPriority == 2); + VerifyOrQuit(srvCallback->mWeight == 3); + VerifyOrQuit(srvCallback->mTtl == 120); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + sSrvCallbacks.Clear(); + + AdvanceTime(15 * 1000); + + dnsMsg = sDnsMessages.GetHead(); + + for (uint8_t queryCount = 0; queryCount < kNumInitalQueries; queryCount++) + { + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + dnsMsg = dnsMsg->GetNext(); + } + + VerifyOrQuit(dnsMsg == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +void TestTxtResolver(void) +{ + Core *mdns = InitTest(); + Core::TxtResolver resolver; + Core::TxtResolver resolver2; + const DnsMessage *dnsMsg; + const TxtCallback *txtCallback; + uint16_t heapAllocations; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestTxtResolver"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start a TXT resolver. Validate initial queries."); + + ClearAllBytes(resolver); + + resolver.mServiceInstance = "mysrv"; + resolver.mServiceType = "_srv._udp"; + resolver.mInfraIfIndex = kInfraIfIndex; + resolver.mCallback = HandleTxtResult; + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->StartTxtResolver(resolver)); + + for (uint8_t queryCount = 0; queryCount < kNumInitalQueries; queryCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((queryCount == 0) ? 125 : (1U << (queryCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + sDnsMessages.Clear(); + + AdvanceTime(20 * 1000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response. Validate callback result."); + + sTxtCallbacks.Clear(); + + SendTxtResponse("mysrv._srv._udp.local.", kTxtData1, sizeof(kTxtData1), 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kTxtData1)); + VerifyOrQuit(txtCallback->mTtl == 120); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response changing TXT data. Validate callback result."); + + AdvanceTime(1000); + + sTxtCallbacks.Clear(); + + SendTxtResponse("mysrv._srv._udp.local.", kTxtData2, sizeof(kTxtData2), 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kTxtData2)); + VerifyOrQuit(txtCallback->mTtl == 120); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response changing TXT data to empty. Validate callback result."); + + AdvanceTime(1000); + + sTxtCallbacks.Clear(); + + SendTxtResponse("mysrv._srv._udp.local.", kEmptyTxtData, sizeof(kEmptyTxtData), 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kEmptyTxtData)); + VerifyOrQuit(txtCallback->mTtl == 120); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response changing TTL. Validate callback result."); + + AdvanceTime(1000); + + sTxtCallbacks.Clear(); + + SendTxtResponse("mysrv._srv._udp.local.", kEmptyTxtData, sizeof(kEmptyTxtData), 500, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kEmptyTxtData)); + VerifyOrQuit(txtCallback->mTtl == 500); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response with zero TTL. Validate callback result."); + + AdvanceTime(1000); + + sTxtCallbacks.Clear(); + + SendTxtResponse("mysrv._srv._udp.local.", kEmptyTxtData, sizeof(kEmptyTxtData), 0, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->mTtl == 0); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response. Validate callback result."); + + sTxtCallbacks.Clear(); + AdvanceTime(100 * 1000); + + SendTxtResponse("mysrv._srv._udp.local.", kTxtData1, sizeof(kTxtData1), 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kTxtData1)); + VerifyOrQuit(txtCallback->mTtl == 120); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response with no changes. Validate callback is not invoked."); + + AdvanceTime(1000); + + sTxtCallbacks.Clear(); + + SendTxtResponse("mysrv._srv._udp.local.", kTxtData1, sizeof(kTxtData1), 120, kInAnswerSection); + + AdvanceTime(100); + + VerifyOrQuit(sTxtCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start another resolver for the same service and different callback. Validate results."); + + resolver2.mServiceInstance = "mysrv"; + resolver2.mServiceType = "_srv._udp"; + resolver2.mInfraIfIndex = kInfraIfIndex; + resolver2.mCallback = HandleTxtResultAlternate; + + sTxtCallbacks.Clear(); + + SuccessOrQuit(mdns->StartTxtResolver(resolver2)); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kTxtData1)); + VerifyOrQuit(txtCallback->mTtl == 120); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start same resolver again and check the returned error."); + + sTxtCallbacks.Clear(); + + VerifyOrQuit(mdns->StartTxtResolver(resolver2) == kErrorAlready); + + AdvanceTime(5000); + + VerifyOrQuit(sTxtCallbacks.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check query is sent at 80 percentage of TTL and then respond to it."); + + SendTxtResponse("mysrv._srv._udp.local.", kTxtData1, sizeof(kTxtData1), 120, kInAnswerSection); + + // First query should be sent at 80-82% of TTL of 120 second (96.0-98.4 sec). + // We wait for 100 second. Note that 5 seconds already passed in the + // previous step. + + AdvanceTime(96 * 1000 - 1); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + AdvanceTime(4 * 1000 + 1); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + VerifyOrQuit(sTxtCallbacks.IsEmpty()); + + AdvanceTime(10); + + SendTxtResponse("mysrv._srv._udp.local.", kTxtData1, sizeof(kTxtData1), 120, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check queries are sent at 80, 85, 90, 95 percentages of TTL."); + + for (uint8_t queryCount = 0; queryCount < kNumRefreshQueries; queryCount++) + { + if (queryCount == 0) + { + // First query is expected in 80-82% of TTL, so + // 80% of 120 = 96.0, 82% of 120 = 98.4 + + AdvanceTime(96 * 1000 - 1); + } + else + { + // Next query should happen within 3%-5% of TTL + // from previous query. We wait 3% of TTL here. + AdvanceTime(3600 - 1); + } + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + // Wait for 2% of TTL of 120 which is 2.4 sec. + + AdvanceTime(2400 + 1); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + VerifyOrQuit(sTxtCallbacks.IsEmpty()); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check TTL timeout and callback result."); + + AdvanceTime(6 * 1000); + + txtCallback = sTxtCallbacks.GetHead(); + + for (uint8_t iter = 0; iter < 2; iter++) + { + VerifyOrQuit(txtCallback != nullptr); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->mTtl == 0); + txtCallback = txtCallback->GetNext(); + } + + VerifyOrQuit(txtCallback == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + sTxtCallbacks.Clear(); + sDnsMessages.Clear(); + + AdvanceTime(200 * 1000); + + VerifyOrQuit(sTxtCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop the second resolver"); + + sTxtCallbacks.Clear(); + + SuccessOrQuit(mdns->StopTxtResolver(resolver2)); + + AdvanceTime(100 * 1000); + + VerifyOrQuit(sTxtCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a new response and make sure result callback is invoked"); + + SendTxtResponse("mysrv._srv._udp.local.", kTxtData1, sizeof(kTxtData1), 120, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kTxtData1)); + VerifyOrQuit(txtCallback->mTtl == 120); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop the resolver. There is no active resolver. Ensure no queries are sent"); + + sTxtCallbacks.Clear(); + + SuccessOrQuit(mdns->StopTxtResolver(resolver)); + + AdvanceTime(20 * 1000); + + VerifyOrQuit(sTxtCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Restart the resolver with more than half of TTL remaining."); + Log("Ensure cached entry is reported in the result callback and no queries are sent."); + + SuccessOrQuit(mdns->StartTxtResolver(resolver)); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kTxtData1)); + VerifyOrQuit(txtCallback->mTtl == 120); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + AdvanceTime(20 * 1000); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop and start the resolver again after less than half TTL remaining."); + Log("Ensure cached entry is still reported in the result callback but queries should be sent"); + + sTxtCallbacks.Clear(); + + SuccessOrQuit(mdns->StopTxtResolver(resolver)); + + AdvanceTime(25 * 1000); + + SuccessOrQuit(mdns->StartTxtResolver(resolver)); + + AdvanceTime(1); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("mysrv")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kTxtData1)); + VerifyOrQuit(txtCallback->mTtl == 120); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + sTxtCallbacks.Clear(); + + AdvanceTime(15 * 1000); + + dnsMsg = sDnsMessages.GetHead(); + + for (uint8_t queryCount = 0; queryCount < kNumInitalQueries; queryCount++) + { + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + dnsMsg = dnsMsg->GetNext(); + } + + VerifyOrQuit(dnsMsg == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +void TestIp6AddrResolver(void) +{ + Core *mdns = InitTest(); + Core::AddressResolver resolver; + Core::AddressResolver resolver2; + AddrAndTtl addrs[5]; + const DnsMessage *dnsMsg; + const AddrCallback *addrCallback; + uint16_t heapAllocations; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestIp6AddrResolver"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start an IPv6 address resolver. Validate initial queries."); + + ClearAllBytes(resolver); + + resolver.mHostName = "myhost"; + resolver.mInfraIfIndex = kInfraIfIndex; + resolver.mCallback = HandleAddrResult; + + sDnsMessages.Clear(); + SuccessOrQuit(mdns->StartIp6AddressResolver(resolver)); + + for (uint8_t queryCount = 0; queryCount < kNumInitalQueries; queryCount++) + { + sDnsMessages.Clear(); + + AdvanceTime((queryCount == 0) ? 125 : (1U << (queryCount - 1)) * 1000); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + } + + sDnsMessages.Clear(); + + AdvanceTime(20 * 1000); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response. Validate callback result."); + + sAddrCallbacks.Clear(); + + SuccessOrQuit(addrs[0].mAddress.FromString("fd00::1")); + addrs[0].mTtl = 120; + + SendHostAddrResponse("myhost.local.", addrs, 1, /* aCachFlush */ true, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 1)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response adding a new address. Validate callback result."); + + SuccessOrQuit(addrs[1].mAddress.FromString("fd00::2")); + addrs[1].mTtl = 120; + + AdvanceTime(1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", addrs, 2, /* aCachFlush */ true, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 2)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send an updated response adding and removing addresses. Validate callback result."); + + SuccessOrQuit(addrs[0].mAddress.FromString("fd00::2")); + SuccessOrQuit(addrs[1].mAddress.FromString("fd00::aa")); + SuccessOrQuit(addrs[2].mAddress.FromString("fe80::bb")); + addrs[0].mTtl = 120; + addrs[1].mTtl = 120; + addrs[2].mTtl = 120; + + AdvanceTime(1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", addrs, 3, /* aCachFlush */ true, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 3)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response without cache flush adding an address. Validate callback result."); + + SuccessOrQuit(addrs[3].mAddress.FromString("fd00::3")); + addrs[3].mTtl = 500; + + AdvanceTime(1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", &addrs[3], 1, /* aCachFlush */ false, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 4)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response without cache flush with existing addresses. Validate that callback is not called."); + + AdvanceTime(1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", &addrs[2], 2, /* aCachFlush */ false, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(sAddrCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response without no changes to the list. Validate that callback is not called"); + + AdvanceTime(1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", addrs, 4, /* aCachFlush */ true, kInAdditionalSection); + + AdvanceTime(1); + + VerifyOrQuit(sAddrCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response without cache flush updating TTL of existing address. Validate callback result."); + + addrs[3].mTtl = 200; + + AdvanceTime(1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", &addrs[3], 1, /* aCachFlush */ false, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 4)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response without cache flush removing an address (zero TTL). Validate callback result."); + + addrs[3].mTtl = 0; + + AdvanceTime(1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", &addrs[3], 1, /* aCachFlush */ false, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 3)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response with cache flush removing all addresses. Validate callback result."); + + addrs[0].mTtl = 0; + + AdvanceTime(1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", addrs, 1, /* aCachFlush */ true, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 0)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a response with addresses with different TTL. Validate callback result"); + + SuccessOrQuit(addrs[0].mAddress.FromString("fd00::00")); + SuccessOrQuit(addrs[1].mAddress.FromString("fd00::11")); + SuccessOrQuit(addrs[2].mAddress.FromString("fe80::22")); + SuccessOrQuit(addrs[3].mAddress.FromString("fe80::33")); + addrs[0].mTtl = 120; + addrs[1].mTtl = 800; + addrs[2].mTtl = 2000; + addrs[3].mTtl = 8000; + + AdvanceTime(5 * 1000); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", addrs, 4, /* aCachFlush */ true, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 4)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start another resolver for the same host and different callback. Validate results."); + + resolver2.mHostName = "myhost"; + resolver2.mInfraIfIndex = kInfraIfIndex; + resolver2.mCallback = HandleAddrResultAlternate; + + sAddrCallbacks.Clear(); + + SuccessOrQuit(mdns->StartIp6AddressResolver(resolver2)); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 4)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start same resolver again and check the returned error."); + + sAddrCallbacks.Clear(); + + VerifyOrQuit(mdns->StartIp6AddressResolver(resolver2) == kErrorAlready); + + AdvanceTime(5000); + + VerifyOrQuit(sAddrCallbacks.IsEmpty()); + sDnsMessages.Clear(); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check query is sent at 80 percentage of TTL and then respond to it."); + + SendHostAddrResponse("myhost.local.", addrs, 4, /* aCachFlush */ true, kInAnswerSection); + + // First query should be sent at 80-82% of TTL of 120 second (96.0-98.4 sec). + // We wait for 100 second. Note that 5 seconds already passed in the + // previous step. + + AdvanceTime(96 * 1000 - 1); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + AdvanceTime(4 * 1000 + 1); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + VerifyOrQuit(sAddrCallbacks.IsEmpty()); + + AdvanceTime(10); + + SendHostAddrResponse("myhost.local.", addrs, 4, /* aCachFlush */ true, kInAnswerSection); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check queries are sent at 80, 85, 90, 95 percentages of TTL."); + + for (uint8_t queryCount = 0; queryCount < kNumRefreshQueries; queryCount++) + { + if (queryCount == 0) + { + // First query is expected in 80-82% of TTL, so + // 80% of 120 = 96.0, 82% of 120 = 98.4 + + AdvanceTime(96 * 1000 - 1); + } + else + { + // Next query should happen within 3%-5% of TTL + // from previous query. We wait 3% of TTL here. + AdvanceTime(3600 - 1); + } + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + // Wait for 2% of TTL of 120 which is 2.4 sec. + + AdvanceTime(2400 + 1); + + VerifyOrQuit(!sDnsMessages.IsEmpty()); + dnsMsg = sDnsMessages.GetHead(); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + VerifyOrQuit(dnsMsg->GetNext() == nullptr); + + sDnsMessages.Clear(); + VerifyOrQuit(sAddrCallbacks.IsEmpty()); + } + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check TTL timeout of first address (TTL 120) and callback result."); + + AdvanceTime(6 * 1000); + + addrCallback = sAddrCallbacks.GetHead(); + + for (uint8_t iter = 0; iter < 2; iter++) + { + VerifyOrQuit(addrCallback != nullptr); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(&addrs[1], 3)); + addrCallback = addrCallback->GetNext(); + } + + VerifyOrQuit(addrCallback == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Check TTL timeout of next address (TTL 800) and callback result."); + + sAddrCallbacks.Clear(); + + AdvanceTime((800 - 120) * 1000); + + addrCallback = sAddrCallbacks.GetHead(); + + for (uint8_t iter = 0; iter < 2; iter++) + { + VerifyOrQuit(addrCallback != nullptr); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(&addrs[2], 2)); + addrCallback = addrCallback->GetNext(); + } + + VerifyOrQuit(addrCallback == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + sAddrCallbacks.Clear(); + sDnsMessages.Clear(); + + AdvanceTime(200 * 1000); + + VerifyOrQuit(sAddrCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop the second resolver"); + + sAddrCallbacks.Clear(); + + SuccessOrQuit(mdns->StopIp6AddressResolver(resolver2)); + + AdvanceTime(100 * 1000); + + VerifyOrQuit(sAddrCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Send a new response and make sure result callback is invoked"); + + sAddrCallbacks.Clear(); + + SendHostAddrResponse("myhost.local.", addrs, 1, /* aCachFlush */ true, kInAnswerSection); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 1)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop the resolver. There is no active resolver. Ensure no queries are sent"); + + sAddrCallbacks.Clear(); + + SuccessOrQuit(mdns->StopIp6AddressResolver(resolver)); + + AdvanceTime(20 * 1000); + + VerifyOrQuit(sAddrCallbacks.IsEmpty()); + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Restart the resolver with more than half of TTL remaining."); + Log("Ensure cached entry is reported in the result callback and no queries are sent."); + + SuccessOrQuit(mdns->StartIp6AddressResolver(resolver)); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 1)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + AdvanceTime(20 * 1000); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Stop and start the resolver again after less than half TTL remaining."); + Log("Ensure cached entry is still reported in the result callback but queries should be sent"); + + sAddrCallbacks.Clear(); + + SuccessOrQuit(mdns->StopIp6AddressResolver(resolver)); + + AdvanceTime(25 * 1000); + + SuccessOrQuit(mdns->StartIp6AddressResolver(resolver)); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("myhost")); + VerifyOrQuit(addrCallback->Matches(addrs, 1)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + sAddrCallbacks.Clear(); + + AdvanceTime(15 * 1000); + + dnsMsg = sDnsMessages.GetHead(); + + for (uint8_t queryCount = 0; queryCount < kNumInitalQueries; queryCount++) + { + VerifyOrQuit(dnsMsg != nullptr); + dnsMsg->ValidateHeader(kMulticastQuery, /* Q */ 1, /* Ans */ 0, /* Auth */ 0, /* Addnl */ 0); + dnsMsg->ValidateAsQueryFor(resolver); + dnsMsg = dnsMsg->GetNext(); + } + + VerifyOrQuit(dnsMsg == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +void TestPassiveCache(void) +{ + static const char *const kSubTypes[] = {"_sub1", "_xyzw"}; + + Core *mdns = InitTest(); + Core::Browser browser; + Core::SrvResolver srvResolver; + Core::TxtResolver txtResolver; + Core::AddressResolver addrResolver; + Core::Host host1; + Core::Host host2; + Core::Service service1; + Core::Service service2; + Core::Service service3; + Ip6::Address host1Addresses[3]; + Ip6::Address host2Addresses[2]; + AddrAndTtl host1AddrTtls[3]; + AddrAndTtl host2AddrTtls[2]; + const DnsMessage *dnsMsg; + BrowseCallback *browseCallback; + SrvCallback *srvCallback; + TxtCallback *txtCallback; + AddrCallback *addrCallback; + uint16_t heapAllocations; + + Log("-------------------------------------------------------------------------------------------"); + Log("TestPassiveCache"); + + AdvanceTime(1); + + heapAllocations = sHeapAllocatedPtrs.GetLength(); + SuccessOrQuit(mdns->SetEnabled(true, kInfraIfIndex)); + + SuccessOrQuit(host1Addresses[0].FromString("fd00::1:aaaa")); + SuccessOrQuit(host1Addresses[1].FromString("fd00::1:bbbb")); + SuccessOrQuit(host1Addresses[2].FromString("fd00::1:cccc")); + host1.mHostName = "host1"; + host1.mAddresses = host1Addresses; + host1.mAddressesLength = 3; + host1.mTtl = 1500; + + host1AddrTtls[0].mAddress = host1Addresses[0]; + host1AddrTtls[1].mAddress = host1Addresses[1]; + host1AddrTtls[2].mAddress = host1Addresses[2]; + host1AddrTtls[0].mTtl = host1.mTtl; + host1AddrTtls[1].mTtl = host1.mTtl; + host1AddrTtls[2].mTtl = host1.mTtl; + + SuccessOrQuit(host2Addresses[0].FromString("fd00::2:eeee")); + SuccessOrQuit(host2Addresses[1].FromString("fd00::2:ffff")); + host2.mHostName = "host2"; + host2.mAddresses = host2Addresses; + host2.mAddressesLength = 2; + host2.mTtl = 1500; + + host2AddrTtls[0].mAddress = host2Addresses[0]; + host2AddrTtls[1].mAddress = host2Addresses[1]; + host2AddrTtls[0].mTtl = host2.mTtl; + host2AddrTtls[1].mTtl = host2.mTtl; + + service1.mHostName = host1.mHostName; + service1.mServiceInstance = "srv1"; + service1.mServiceType = "_srv._udp"; + service1.mSubTypeLabels = kSubTypes; + service1.mSubTypeLabelsLength = 2; + service1.mTxtData = kTxtData1; + service1.mTxtDataLength = sizeof(kTxtData1); + service1.mPort = 1111; + service1.mPriority = 0; + service1.mWeight = 0; + service1.mTtl = 1500; + + service2.mHostName = host1.mHostName; + service2.mServiceInstance = "srv2"; + service2.mServiceType = "_tst._tcp"; + service2.mSubTypeLabels = nullptr; + service2.mSubTypeLabelsLength = 0; + service2.mTxtData = nullptr; + service2.mTxtDataLength = 0; + service2.mPort = 2222; + service2.mPriority = 2; + service2.mWeight = 2; + service2.mTtl = 1500; + + service3.mHostName = host2.mHostName; + service3.mServiceInstance = "srv3"; + service3.mServiceType = "_srv._udp"; + service3.mSubTypeLabels = kSubTypes; + service3.mSubTypeLabelsLength = 1; + service3.mTxtData = kTxtData2; + service3.mTxtDataLength = sizeof(kTxtData2); + service3.mPort = 3333; + service3.mPriority = 3; + service3.mWeight = 3; + service3.mTtl = 1500; + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Register 2 hosts and 3 services"); + + SuccessOrQuit(mdns->RegisterHost(host1, 0, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterHost(host2, 1, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterService(service1, 2, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterService(service2, 3, HandleSuccessCallback)); + SuccessOrQuit(mdns->RegisterService(service3, 4, HandleSuccessCallback)); + + AdvanceTime(10 * 1000); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start a browser for `_srv._udp`, validate callback result"); + + browser.mServiceType = "_srv._udp"; + browser.mSubTypeLabel = nullptr; + browser.mInfraIfIndex = kInfraIfIndex; + browser.mCallback = HandleBrowseResult; + + sBrowseCallbacks.Clear(); + + SuccessOrQuit(mdns->StartBrowser(browser)); + + AdvanceTime(350); + + browseCallback = sBrowseCallbacks.GetHead(); + + for (uint8_t iter = 0; iter < 2; iter++) + { + VerifyOrQuit(browseCallback != nullptr); + + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(!browseCallback->mIsSubType); + VerifyOrQuit(browseCallback->mServiceInstance.Matches("srv1") || + browseCallback->mServiceInstance.Matches("srv3")); + VerifyOrQuit(browseCallback->mTtl == 1500); + + browseCallback = browseCallback->GetNext(); + } + + VerifyOrQuit(browseCallback == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start SRV and TXT resolvers for the srv1 and for its host name."); + Log("Ensure all results are immediately provided from cache."); + + srvResolver.mServiceInstance = "srv1"; + srvResolver.mServiceType = "_srv._udp"; + srvResolver.mInfraIfIndex = kInfraIfIndex; + srvResolver.mCallback = HandleSrvResult; + + txtResolver.mServiceInstance = "srv1"; + txtResolver.mServiceType = "_srv._udp"; + txtResolver.mInfraIfIndex = kInfraIfIndex; + txtResolver.mCallback = HandleTxtResult; + + addrResolver.mHostName = "host1"; + addrResolver.mInfraIfIndex = kInfraIfIndex; + addrResolver.mCallback = HandleAddrResult; + + sSrvCallbacks.Clear(); + sTxtCallbacks.Clear(); + sAddrCallbacks.Clear(); + sDnsMessages.Clear(); + + SuccessOrQuit(mdns->StartSrvResolver(srvResolver)); + SuccessOrQuit(mdns->StartTxtResolver(txtResolver)); + SuccessOrQuit(mdns->StartIp6AddressResolver(addrResolver)); + + AdvanceTime(1); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("srv1")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("host1")); + VerifyOrQuit(srvCallback->mPort == 1111); + VerifyOrQuit(srvCallback->mPriority == 0); + VerifyOrQuit(srvCallback->mWeight == 0); + VerifyOrQuit(srvCallback->mTtl == 1500); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("srv1")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(txtCallback->Matches(kTxtData1)); + VerifyOrQuit(txtCallback->mTtl == 1500); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("host1")); + VerifyOrQuit(addrCallback->Matches(host1AddrTtls, 3)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + AdvanceTime(400); + + VerifyOrQuit(sDnsMessages.IsEmpty()); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start a browser for sub-type service, validate callback result"); + + browser.mServiceType = "_srv._udp"; + browser.mSubTypeLabel = "_xyzw"; + browser.mInfraIfIndex = kInfraIfIndex; + browser.mCallback = HandleBrowseResult; + + sBrowseCallbacks.Clear(); + + SuccessOrQuit(mdns->StartBrowser(browser)); + + AdvanceTime(350); + + browseCallback = sBrowseCallbacks.GetHead(); + VerifyOrQuit(browseCallback != nullptr); + + VerifyOrQuit(browseCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(browseCallback->mIsSubType); + VerifyOrQuit(browseCallback->mSubTypeLabel.Matches("_xyzw")); + VerifyOrQuit(browseCallback->mServiceInstance.Matches("srv1")); + VerifyOrQuit(browseCallback->mTtl == 1500); + VerifyOrQuit(browseCallback->GetNext() == nullptr); + + AdvanceTime(5 * 1000); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start SRV and TXT resolvers for `srv2._tst._tcp` service and validate callback result"); + + srvResolver.mServiceInstance = "srv2"; + srvResolver.mServiceType = "_tst._tcp"; + srvResolver.mInfraIfIndex = kInfraIfIndex; + srvResolver.mCallback = HandleSrvResult; + + txtResolver.mServiceInstance = "srv2"; + txtResolver.mServiceType = "_tst._tcp"; + txtResolver.mInfraIfIndex = kInfraIfIndex; + txtResolver.mCallback = HandleTxtResult; + + sSrvCallbacks.Clear(); + sTxtCallbacks.Clear(); + + SuccessOrQuit(mdns->StartSrvResolver(srvResolver)); + SuccessOrQuit(mdns->StartTxtResolver(txtResolver)); + + AdvanceTime(350); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("srv2")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_tst._tcp")); + VerifyOrQuit(srvCallback->mHostName.Matches("host1")); + VerifyOrQuit(srvCallback->mPort == 2222); + VerifyOrQuit(srvCallback->mPriority == 2); + VerifyOrQuit(srvCallback->mWeight == 2); + VerifyOrQuit(srvCallback->mTtl == 1500); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("srv2")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_tst._tcp")); + VerifyOrQuit(txtCallback->Matches(kEmptyTxtData)); + VerifyOrQuit(txtCallback->mTtl == 1500); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Unregister `srv2._tst._tcp` and validate callback results"); + + sSrvCallbacks.Clear(); + sTxtCallbacks.Clear(); + + SuccessOrQuit(mdns->UnregisterService(service2)); + + AdvanceTime(350); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("srv2")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_tst._tcp")); + VerifyOrQuit(srvCallback->mTtl == 0); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + VerifyOrQuit(!sTxtCallbacks.IsEmpty()); + txtCallback = sTxtCallbacks.GetHead(); + VerifyOrQuit(txtCallback->mServiceInstance.Matches("srv2")); + VerifyOrQuit(txtCallback->mServiceType.Matches("_tst._tcp")); + VerifyOrQuit(txtCallback->mTtl == 0); + VerifyOrQuit(txtCallback->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start an SRV resolver for `srv3._srv._udp` service and validate callback result"); + + srvResolver.mServiceInstance = "srv3"; + srvResolver.mServiceType = "_srv._udp"; + srvResolver.mInfraIfIndex = kInfraIfIndex; + srvResolver.mCallback = HandleSrvResult; + + sSrvCallbacks.Clear(); + + SuccessOrQuit(mdns->StartSrvResolver(srvResolver)); + + AdvanceTime(350); + + VerifyOrQuit(!sSrvCallbacks.IsEmpty()); + srvCallback = sSrvCallbacks.GetHead(); + VerifyOrQuit(srvCallback->mServiceInstance.Matches("srv3")); + VerifyOrQuit(srvCallback->mServiceType.Matches("_srv._udp")); + VerifyOrQuit(srvCallback->mHostName.Matches("host2")); + VerifyOrQuit(srvCallback->mPort == 3333); + VerifyOrQuit(srvCallback->mPriority == 3); + VerifyOrQuit(srvCallback->mWeight == 3); + VerifyOrQuit(srvCallback->mTtl == 1500); + VerifyOrQuit(srvCallback->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + Log("Start an address resolver for host2 and validate result is immediately reported from cache"); + + addrResolver.mHostName = "host2"; + addrResolver.mInfraIfIndex = kInfraIfIndex; + addrResolver.mCallback = HandleAddrResult; + + sAddrCallbacks.Clear(); + SuccessOrQuit(mdns->StartIp6AddressResolver(addrResolver)); + + AdvanceTime(1); + + VerifyOrQuit(!sAddrCallbacks.IsEmpty()); + addrCallback = sAddrCallbacks.GetHead(); + VerifyOrQuit(addrCallback->mHostName.Matches("host2")); + VerifyOrQuit(addrCallback->Matches(host2AddrTtls, 2)); + VerifyOrQuit(addrCallback->GetNext() == nullptr); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -"); + + SuccessOrQuit(mdns->SetEnabled(false, kInfraIfIndex)); + VerifyOrQuit(sHeapAllocatedPtrs.GetLength() <= heapAllocations); + + Log("End of test"); + + testFreeInstance(sInstance); +} + +} // namespace Multicast +} // namespace Dns +} // namespace ot + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +int main(void) +{ +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + ot::Dns::Multicast::TestHostReg(); + ot::Dns::Multicast::TestKeyReg(); + ot::Dns::Multicast::TestServiceReg(); + ot::Dns::Multicast::TestUnregisterBeforeProbeFinished(); + ot::Dns::Multicast::TestServiceSubTypeReg(); + ot::Dns::Multicast::TestHostOrServiceAndKeyReg(); + ot::Dns::Multicast::TestQuery(); + ot::Dns::Multicast::TestMultiPacket(); + ot::Dns::Multicast::TestQuestionUnicastDisallowed(); + ot::Dns::Multicast::TestTxMessageSizeLimit(); + ot::Dns::Multicast::TestHostConflict(); + ot::Dns::Multicast::TestServiceConflict(); + + ot::Dns::Multicast::TestBrowser(); + ot::Dns::Multicast::TestSrvResolver(); + ot::Dns::Multicast::TestTxtResolver(); + ot::Dns::Multicast::TestIp6AddrResolver(); + ot::Dns::Multicast::TestPassiveCache(); + + printf("All tests passed\n"); +#else + printf("mDNS feature is not enabled\n"); +#endif + + return 0; +} diff --git a/tests/unit/test_platform.cpp b/tests/unit/test_platform.cpp index 2904f6019..496dc4979 100644 --- a/tests/unit/test_platform.cpp +++ b/tests/unit/test_platform.cpp @@ -555,6 +555,35 @@ otError otPlatRadioSetCcaEnergyDetectThreshold(otInstance *aInstance, int8_t aTh return OT_ERROR_NONE; } +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + +OT_TOOL_WEAK otError otPlatMdnsSetListeningEnabled(otInstance *aInstance, bool aEnable, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aEnable); + OT_UNUSED_VARIABLE(aInfraIfIndex); + + return OT_ERROR_NOT_IMPLEMENTED; +} + +OT_TOOL_WEAK void otPlatMdnsSendMulticast(otInstance *aInstance, otMessage *aMessage, uint32_t aInfraIfIndex) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aMessage); + OT_UNUSED_VARIABLE(aInfraIfIndex); +} + +OT_TOOL_WEAK void otPlatMdnsSendUnicast(otInstance *aInstance, + otMessage *aMessage, + const otPlatMdnsAddressInfo *aAddress) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aMessage); + OT_UNUSED_VARIABLE(aAddress); +} + +#endif // OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE OT_TOOL_WEAK void otPlatDsoEnableListening(otInstance *aInstance, bool aEnable) diff --git a/tests/unit/test_platform.h b/tests/unit/test_platform.h index f1c7eb4f4..8f9d2917a 100644 --- a/tests/unit/test_platform.h +++ b/tests/unit/test_platform.h @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include