diff --git a/examples/minimal-mdns/server.cpp b/examples/minimal-mdns/server.cpp index f1727455a5b284..d4456001c16875 100644 --- a/examples/minimal-mdns/server.cpp +++ b/examples/minimal-mdns/server.cpp @@ -51,7 +51,7 @@ struct Options const char * instanceName = "chip-mdns-demo"; } gOptions; -using namespace chip::ArgParser; +using namespace ArgParser; constexpr uint16_t kOptionEnableIpV4 = '4'; constexpr uint16_t kOptionListenPort = 'p'; @@ -111,7 +111,7 @@ class ReplyDelegate : public mdns::Minimal::ServerDelegate, public mdns::Minimal public: ReplyDelegate(mdns::Minimal::ResponseSender * responder) : mResponder(responder) {} - void OnQuery(const mdns::Minimal::BytesRange & data, const chip::Inet::IPPacketInfo * info) override + void OnQuery(const mdns::Minimal::BytesRange & data, const Inet::IPPacketInfo * info) override { char addr[INET6_ADDRSTRLEN]; info->SrcAddress.ToString(addr, sizeof(addr)); @@ -127,7 +127,7 @@ class ReplyDelegate : public mdns::Minimal::ServerDelegate, public mdns::Minimal mCurrentSource = nullptr; } - void OnResponse(const mdns::Minimal::BytesRange & data, const chip::Inet::IPPacketInfo * info) override + void OnResponse(const mdns::Minimal::BytesRange & data, const Inet::IPPacketInfo * info) override { char addr[INET6_ADDRSTRLEN]; info->SrcAddress.ToString(addr, sizeof(addr)); @@ -158,8 +158,8 @@ class ReplyDelegate : public mdns::Minimal::ServerDelegate, public mdns::Minimal } mdns::Minimal::ResponseSender * mResponder; - const chip::Inet::IPPacketInfo * mCurrentSource = nullptr; - uint32_t mMessageId = 0; + const Inet::IPPacketInfo * mCurrentSource = nullptr; + uint32_t mMessageId = 0; }; } // namespace @@ -178,7 +178,7 @@ int main(int argc, char ** args) return 1; } - if (!chip::ArgParser::ParseArgs(args[0], argc, args, allOptions)) + if (!ArgParser::ParseArgs(args[0], argc, args, allOptions)) { return 1; } @@ -188,27 +188,22 @@ int main(int argc, char ** args) mdns::Minimal::Server<10 /* endpoints */> mdnsServer; mdns::Minimal::QueryResponder<16 /* maxRecords */> queryResponder; - mdns::Minimal::QNamePart tcpServiceName[] = { chip::Mdns::kOperationalServiceName, chip::Mdns::kOperationalProtocol, - chip::Mdns::kLocalDomain }; - mdns::Minimal::QNamePart tcpServerServiceName[] = { gOptions.instanceName, chip::Mdns::kOperationalServiceName, - chip::Mdns::kOperationalProtocol, chip::Mdns::kLocalDomain }; - mdns::Minimal::QNamePart udpServiceName[] = { chip::Mdns::kCommissionableServiceName, chip::Mdns::kCommissionProtocol, - chip::Mdns::kLocalDomain }; - mdns::Minimal::QNamePart udpServerServiceName[] = { gOptions.instanceName, chip::Mdns::kCommissionableServiceName, - chip::Mdns::kCommissionProtocol, chip::Mdns::kLocalDomain }; + mdns::Minimal::QNamePart tcpServiceName[] = { Mdns::kOperationalServiceName, Mdns::kOperationalProtocol, Mdns::kLocalDomain }; + mdns::Minimal::QNamePart tcpServerServiceName[] = { gOptions.instanceName, Mdns::kOperationalServiceName, + Mdns::kOperationalProtocol, Mdns::kLocalDomain }; + mdns::Minimal::QNamePart udpServiceName[] = { Mdns::kCommissionableServiceName, Mdns::kCommissionProtocol, Mdns::kLocalDomain }; + mdns::Minimal::QNamePart udpServerServiceName[] = { gOptions.instanceName, Mdns::kCommissionableServiceName, + Mdns::kCommissionProtocol, Mdns::kLocalDomain }; // several UDP versions for discriminators - mdns::Minimal::QNamePart udpDiscriminator1[] = { "S52", chip::Mdns::kSubtypeServiceNamePart, - chip::Mdns::kCommissionableServiceName, chip::Mdns::kCommissionProtocol, - chip::Mdns::kLocalDomain }; - mdns::Minimal::QNamePart udpDiscriminator2[] = { "V123", chip::Mdns::kSubtypeServiceNamePart, - chip::Mdns::kCommissionableServiceName, chip::Mdns::kCommissionProtocol, - chip::Mdns::kLocalDomain }; - mdns::Minimal::QNamePart udpDiscriminator3[] = { "L840", chip::Mdns::kSubtypeServiceNamePart, - chip::Mdns::kCommissionableServiceName, chip::Mdns::kCommissionProtocol, - chip::Mdns::kLocalDomain }; - - mdns::Minimal::QNamePart serverName[] = { gOptions.instanceName, chip::Mdns::kLocalDomain }; + mdns::Minimal::QNamePart udpDiscriminator1[] = { "S52", Mdns::kSubtypeServiceNamePart, Mdns::kCommissionableServiceName, + Mdns::kCommissionProtocol, Mdns::kLocalDomain }; + mdns::Minimal::QNamePart udpDiscriminator2[] = { "V123", Mdns::kSubtypeServiceNamePart, Mdns::kCommissionableServiceName, + Mdns::kCommissionProtocol, Mdns::kLocalDomain }; + mdns::Minimal::QNamePart udpDiscriminator3[] = { "L840", Mdns::kSubtypeServiceNamePart, Mdns::kCommissionableServiceName, + Mdns::kCommissionProtocol, Mdns::kLocalDomain }; + + mdns::Minimal::QNamePart serverName[] = { gOptions.instanceName, Mdns::kLocalDomain }; mdns::Minimal::IPv4Responder ipv4Responder(serverName); mdns::Minimal::IPv6Responder ipv6Responder(serverName); @@ -256,7 +251,7 @@ int main(int argc, char ** args) { MdnsExample::AllInterfaces allInterfaces(gOptions.enableIpV4); - if (mdnsServer.Listen(&chip::DeviceLayer::InetLayer, &allInterfaces, gOptions.listenPort) != CHIP_NO_ERROR) + if (mdnsServer.Listen(&DeviceLayer::InetLayer, &allInterfaces, gOptions.listenPort) != CHIP_NO_ERROR) { printf("Server failed to listen on all interfaces\n"); return 1; diff --git a/examples/platform/linux/AppMain.cpp b/examples/platform/linux/AppMain.cpp index 89bcd3b6273778..fe00b509f05e72 100644 --- a/examples/platform/linux/AppMain.cpp +++ b/examples/platform/linux/AppMain.cpp @@ -184,11 +184,16 @@ CHIP_ERROR InitCommissioner() ReturnErrorOnFailure(gCommissioner.SetUdpListenPort(CHIP_PORT + 2)); ReturnErrorOnFailure(gCommissioner.SetUdcListenPort(CHIP_PORT + 3)); ReturnErrorOnFailure(gCommissioner.Init(localId, params)); - ReturnErrorOnFailure(gCommissioner.ServiceEvents()); return CHIP_NO_ERROR; } +CHIP_ERROR ShutdownCommissioner() +{ + gCommissioner.Shutdown(); + return CHIP_NO_ERROR; +} + #endif // CHIP_DEVICE_CONFIG_ENABLE_BOTH_COMMISSIONER_AND_COMMISSIONEE void ChipLinuxAppMainLoop() @@ -209,6 +214,11 @@ void ChipLinuxAppMainLoop() #endif // CHIP_DEVICE_CONFIG_ENABLE_BOTH_COMMISSIONER_AND_COMMISSIONEE chip::DeviceLayer::PlatformMgr().RunEventLoop(); + +#if CHIP_DEVICE_CONFIG_ENABLE_BOTH_COMMISSIONER_AND_COMMISSIONEE + ShutdownCommissioner(); +#endif // CHIP_DEVICE_CONFIG_ENABLE_BOTH_COMMISSIONER_AND_COMMISSIONEE + #if defined(ENABLE_CHIP_SHELL) shellThread.join(); #endif diff --git a/examples/platform/linux/CommissioneeShellCommands.cpp b/examples/platform/linux/CommissioneeShellCommands.cpp index 676822b6d688c7..88479997734b63 100644 --- a/examples/platform/linux/CommissioneeShellCommands.cpp +++ b/examples/platform/linux/CommissioneeShellCommands.cpp @@ -59,7 +59,7 @@ static CHIP_ERROR PrintAllCommands() streamer_printf(sout, " help Usage: commissionee \r\n"); #if CHIP_DEVICE_CONFIG_ENABLE_COMMISSIONER_DISCOVERY_CLIENT streamer_printf(sout, - " sendudc
Send UDC message to address. Usage: commissionee sendudc 127.0.0.1 11100\r\n"); + " sendudc
Send UDC message to address. Usage: commissionee sendudc 127.0.0.1 5543\r\n"); #endif // CHIP_DEVICE_CONFIG_ENABLE_COMMISSIONER_DISCOVERY_CLIENT streamer_printf(sout, "\r\n"); diff --git a/src/BUILD.gn b/src/BUILD.gn index 1a2cfe50c90d6d..8867dcbb59b17f 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -54,6 +54,7 @@ if (chip_build_tests) { "${chip_root}/src/messaging/tests", "${chip_root}/src/protocols/bdx/tests", "${chip_root}/src/protocols/secure_channel/tests", + "${chip_root}/src/protocols/user_directed_commissioning/tests", "${chip_root}/src/system/tests", "${chip_root}/src/transport/retransmit/tests", "${chip_root}/src/transport/tests", diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 1436e3af9d7089..a072fc256b9a0f 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -588,27 +588,21 @@ CHIP_ERROR SendUserDirectedCommissioningRequest(chip::Transport::PeerAddress com } ChipLogDetail(AppServer, "instanceName=%s", nameBuffer); - // send UDC message 5 times per spec (no ACK on this message) - for (unsigned int i = 0; i < 5; i++) + chip::System::PacketBufferHandle payloadBuf = chip::MessagePacketBuffer::NewWithData(nameBuffer, strlen(nameBuffer)); + if (payloadBuf.IsNull()) { - chip::System::PacketBufferHandle payloadBuf = chip::MessagePacketBuffer::NewWithData(nameBuffer, strlen(nameBuffer)); - if (payloadBuf.IsNull()) - { - ChipLogError(AppServer, "Unable to allocate packet buffer\n"); - return CHIP_ERROR_NO_MEMORY; - } - - err = gUDCClient.SendUDCMessage(&gTransports, std::move(payloadBuf), commissioner); - if (err == CHIP_NO_ERROR) - { - ChipLogDetail(AppServer, "Send UDC request success"); - } - else - { - ChipLogError(AppServer, "Send UDC request failed, err: %s\n", chip::ErrorStr(err)); - } + ChipLogError(AppServer, "Unable to allocate packet buffer\n"); + return CHIP_ERROR_NO_MEMORY; + } - sleep(1); + err = gUDCClient.SendUDCMessage(&gTransports, std::move(payloadBuf), commissioner); + if (err == CHIP_NO_ERROR) + { + ChipLogDetail(AppServer, "Send UDC request success"); + } + else + { + ChipLogError(AppServer, "Send UDC request failed, err: %s\n", chip::ErrorStr(err)); } return err; } diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 78956fd969bf2c..d894e89cec5339 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -864,17 +864,18 @@ CHIP_ERROR DeviceCommissioner::Shutdown() PersistDeviceList(); #if CHIP_DEVICE_CONFIG_ENABLE_COMMISSIONER_DISCOVERY // make this commissioner discoverable + if (mUdcTransportMgr != nullptr) + { + chip::Platform::Delete(mUdcTransportMgr); + mUdcTransportMgr = nullptr; + } if (mUdcServer != nullptr) { mUdcServer->SetInstanceNameResolver(nullptr); mUdcServer->SetUserConfirmationProvider(nullptr); + chip::Platform::Delete(mUdcServer); mUdcServer = nullptr; } - if (mUdcTransportMgr != nullptr) - { - chip::Platform::Delete(mUdcTransportMgr); - mUdcTransportMgr = nullptr; - } #endif // CHIP_DEVICE_CONFIG_ENABLE_COMMISSIONER_DISCOVERY DeviceController::Shutdown(); diff --git a/src/controller/python/ChipDeviceController-ScriptBinding.cpp b/src/controller/python/ChipDeviceController-ScriptBinding.cpp index b00da403104292..63366aeae0d4f4 100644 --- a/src/controller/python/ChipDeviceController-ScriptBinding.cpp +++ b/src/controller/python/ChipDeviceController-ScriptBinding.cpp @@ -284,7 +284,7 @@ ChipError::StorageType pychip_DeviceController_CloseSession(chip::Controller::De ChipError::StorageType pychip_DeviceController_DiscoverAllCommissionableNodes(chip::Controller::DeviceCommissioner * devCtrl) { - Mdns::DiscoveryFilter filter(Mdns::DiscoveryFilterType::kNone, (uint16_t) 0); + Mdns::DiscoveryFilter filter(Mdns::DiscoveryFilterType::kNone, static_cast(0)); return devCtrl->DiscoverCommissionableNodes(filter).AsInteger(); } diff --git a/src/lib/mdns/Advertiser.h b/src/lib/mdns/Advertiser.h index f421922484360a..642fa7f5cf854d 100644 --- a/src/lib/mdns/Advertiser.h +++ b/src/lib/mdns/Advertiser.h @@ -197,7 +197,7 @@ class CommissionAdvertisingParameters : public BaseAdvertisingParams GetUDCClients() { return mUdcClients; } + private: InstanceNameResolver * mInstanceNameResolver = nullptr; UserConfirmationProvider * mUserConfirmationProvider = nullptr; void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; - // Cache contains 16 clients. This may need to be tweaked. - UDCClients<16> mUdcClients; // < Active UDC clients + UDCClients mUdcClients; // < Active UDC clients }; } // namespace UserDirectedCommissioning diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp index 3ef7ca8ca1f75f..6b67389653aceb 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp @@ -32,8 +32,30 @@ namespace UserDirectedCommissioning { CHIP_ERROR UserDirectedCommissioningClient::SendUDCMessage(TransportMgrBase * transportMgr, System::PacketBufferHandle && payload, chip::Transport::PeerAddress peerAddress) { - CHIP_ERROR err; + CHIP_ERROR err = EncodeUDCMessage(std::move(payload)); + if (err != CHIP_NO_ERROR) + { + return err; + } + ChipLogProgress(Inet, "Sending UDC msg"); + + // send UDC message 5 times per spec (no ACK on this message) + for (unsigned int i = 0; i < 5; i++) + { + err = transportMgr->SendMessage(peerAddress, std::move(payload)); + if (err != CHIP_NO_ERROR) + { + ChipLogError(AppServer, "UDC SendMessage failed, err: %s\n", chip::ErrorStr(err)); + return err; + } + sleep(1); + } + ChipLogProgress(Inet, "UDC msg send status %s", ErrorStr(err)); + return err; +} +CHIP_ERROR UserDirectedCommissioningClient::EncodeUDCMessage(System::PacketBufferHandle && payload) +{ PayloadHeader payloadHeader; PacketHeader packetHeader; @@ -49,11 +71,7 @@ CHIP_ERROR UserDirectedCommissioningClient::SendUDCMessage(TransportMgrBase * tr ReturnErrorOnFailure(packetHeader.EncodeBeforeData(payload)); - ChipLogProgress(Inet, "Sending UDC msg"); - err = transportMgr->SendMessage(peerAddress, std::move(payload)); - - ChipLogProgress(Inet, "UDC msg send status %s", ErrorStr(err)); - return err; + return CHIP_NO_ERROR; } } // namespace UserDirectedCommissioning diff --git a/src/protocols/user_directed_commissioning/tests/BUILD.gn b/src/protocols/user_directed_commissioning/tests/BUILD.gn new file mode 100644 index 00000000000000..4d3a51f2b95a42 --- /dev/null +++ b/src/protocols/user_directed_commissioning/tests/BUILD.gn @@ -0,0 +1,34 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import("//build_overrides/build.gni") +import("//build_overrides/chip.gni") +import("//build_overrides/nlio.gni") +import("//build_overrides/nlunit_test.gni") + +import("${chip_root}/build/chip/chip_test_suite.gni") + +chip_test_suite("tests") { + output_name = "libUserDirectedCommissioningTests" + + test_sources = [ "TestUdcMessages.cpp" ] + + public_deps = [ + "${chip_root}/src/lib/core", + "${chip_root}/src/lib/support", + "${chip_root}/src/protocols", + "${nlio_root}:nlio", + "${nlunit_test_root}:nlunit-test", + ] + + cflags = [ "-Wconversion" ] +} diff --git a/src/protocols/user_directed_commissioning/tests/TestUdcMessages.cpp b/src/protocols/user_directed_commissioning/tests/TestUdcMessages.cpp new file mode 100644 index 00000000000000..467847523550d2 --- /dev/null +++ b/src/protocols/user_directed_commissioning/tests/TestUdcMessages.cpp @@ -0,0 +1,306 @@ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace chip; +using namespace chip::Protocols::UserDirectedCommissioning; + +class DLL_EXPORT TestCallback : public UserConfirmationProvider, public InstanceNameResolver +{ +public: + void OnUserDirectedCommissioningRequest(const Mdns::DiscoveredNodeData & nodeData) + { + mOnUserDirectedCommissioningRequestCalled = true; + mNodeData = nodeData; + } + + void FindCommissionableNode(char * instanceName) + { + mFindCommissionableNodeCalled = true; + mInstanceName = instanceName; + } + + // virtual ~UserConfirmationProvider() = default; + Mdns::DiscoveredNodeData mNodeData; + char * mInstanceName; + + bool mOnUserDirectedCommissioningRequestCalled = false; + bool mFindCommissionableNodeCalled = false; +}; + +using DeviceTransportMgr = TransportMgr; + +void TestUDCServerClients(nlTestSuite * inSuite, void * inContext) +{ + UserDirectedCommissioningServer udcServer; + const char * instanceName1 = "servertest1"; + + // test setting UDC Clients + NL_TEST_ASSERT(inSuite, nullptr == udcServer.GetUDCClients().FindUDCClientState(instanceName1)); + udcServer.SetUDCClientProcessingState((char *) instanceName1, UDCClientProcessingState::kUserDeclined); + UDCClientState * state = udcServer.GetUDCClients().FindUDCClientState(instanceName1); + NL_TEST_ASSERT(inSuite, nullptr != state); + NL_TEST_ASSERT(inSuite, UDCClientProcessingState::kUserDeclined == state->GetUDCClientProcessingState()); +} + +void TestUDCServerUserConfirmationProvider(nlTestSuite * inSuite, void * inContext) +{ + UserDirectedCommissioningServer udcServer; + TestCallback testCallback; + const char * instanceName1 = "servertest1"; + const char * instanceName2 = "servertest2"; + UDCClientState * state; + + // setup for tests + udcServer.SetUDCClientProcessingState((char *) instanceName1, UDCClientProcessingState::kUserDeclined); + + // test empty UserConfirmationProvider + Mdns::DiscoveredNodeData nodeData; + strncpy((char *) nodeData.instanceName, instanceName2, sizeof(nodeData.instanceName)); + udcServer.OnCommissionableNodeFound(nodeData); + strncpy((char *) nodeData.instanceName, instanceName1, sizeof(nodeData.instanceName)); + udcServer.OnCommissionableNodeFound(nodeData); + state = udcServer.GetUDCClients().FindUDCClientState(instanceName1); + NL_TEST_ASSERT(inSuite, nullptr != state); + NL_TEST_ASSERT(inSuite, UDCClientProcessingState::kUserDeclined == state->GetUDCClientProcessingState()); + state = udcServer.GetUDCClients().FindUDCClientState(instanceName2); + NL_TEST_ASSERT(inSuite, nullptr == state); + + // test current state check + udcServer.SetUDCClientProcessingState((char *) instanceName1, UDCClientProcessingState::kUserDeclined); + udcServer.SetUDCClientProcessingState((char *) instanceName2, UDCClientProcessingState::kDiscoveringNode); + strncpy((char *) nodeData.instanceName, instanceName2, sizeof(nodeData.instanceName)); + udcServer.OnCommissionableNodeFound(nodeData); + strncpy((char *) nodeData.instanceName, instanceName1, sizeof(nodeData.instanceName)); + udcServer.OnCommissionableNodeFound(nodeData); + state = udcServer.GetUDCClients().FindUDCClientState(instanceName1); + NL_TEST_ASSERT(inSuite, nullptr != state); + NL_TEST_ASSERT(inSuite, UDCClientProcessingState::kUserDeclined == state->GetUDCClientProcessingState()); + state = udcServer.GetUDCClients().FindUDCClientState(instanceName2); + NL_TEST_ASSERT(inSuite, nullptr != state); + NL_TEST_ASSERT(inSuite, UDCClientProcessingState::kPromptingUser == state->GetUDCClientProcessingState()); + + // test non-empty UserConfirmationProvider + udcServer.SetUserConfirmationProvider(&testCallback); + udcServer.SetUDCClientProcessingState((char *) instanceName1, UDCClientProcessingState::kUserDeclined); + udcServer.SetUDCClientProcessingState((char *) instanceName2, UDCClientProcessingState::kDiscoveringNode); + strncpy((char *) nodeData.instanceName, instanceName1, sizeof(nodeData.instanceName)); + udcServer.OnCommissionableNodeFound(nodeData); + NL_TEST_ASSERT(inSuite, !testCallback.mOnUserDirectedCommissioningRequestCalled); + strncpy((char *) nodeData.instanceName, instanceName2, sizeof(nodeData.instanceName)); + udcServer.OnCommissionableNodeFound(nodeData); + NL_TEST_ASSERT(inSuite, testCallback.mOnUserDirectedCommissioningRequestCalled); + NL_TEST_ASSERT(inSuite, 0 == strcmp(testCallback.mNodeData.instanceName, instanceName2)); +} + +void TestUDCServerInstanceNameResolver(nlTestSuite * inSuite, void * inContext) +{ + UserDirectedCommissioningServer udcServer; + UserDirectedCommissioningClient udcClient; + TestCallback testCallback; + UDCClientState * state; + const char * instanceName1 = "servertest1"; + + // setup for tests + DeviceTransportMgr * mUdcTransportMgr = chip::Platform::New(); + mUdcTransportMgr->SetSecureSessionMgr(&udcServer); + udcServer.SetInstanceNameResolver(&testCallback); + + // set state for instance1 + udcServer.SetUDCClientProcessingState((char *) instanceName1, UDCClientProcessingState::kUserDeclined); + + // encode our client message + char nameBuffer[Mdns::kMaxInstanceNameSize + 1] = "Chris"; + System::PacketBufferHandle payloadBuf = MessagePacketBuffer::NewWithData(nameBuffer, strlen(nameBuffer)); + udcClient.EncodeUDCMessage(std::move(payloadBuf)); + + // prepare peerAddress for handleMessage + Inet::IPAddress commissioner; + Inet::IPAddress::FromString("127.0.0.1", commissioner); + uint16_t port = 11100; + Transport::PeerAddress peerAddress = Transport::PeerAddress::UDP(commissioner, port); + + // test OnMessageReceived + mUdcTransportMgr->HandleMessageReceived(peerAddress, std::move(payloadBuf)); + + // check if the state is set for the instance name sent + state = udcServer.GetUDCClients().FindUDCClientState(nameBuffer); + NL_TEST_ASSERT(inSuite, nullptr != state); + NL_TEST_ASSERT(inSuite, UDCClientProcessingState::kDiscoveringNode == state->GetUDCClientProcessingState()); + + // check if a callback happened + NL_TEST_ASSERT(inSuite, testCallback.mFindCommissionableNodeCalled); + + // reset callback tracker so we can confirm that when the + // same instance name is received, there is no callback + testCallback.mFindCommissionableNodeCalled = false; + + // reset the UDC message + udcClient.EncodeUDCMessage(std::move(payloadBuf)); + + // test OnMessageReceived again + mUdcTransportMgr->HandleMessageReceived(peerAddress, std::move(payloadBuf)); + + // verify it was not called + NL_TEST_ASSERT(inSuite, !testCallback.mFindCommissionableNodeCalled); + + // next, reset the cache state and confirm the callback + udcServer.ResetUDCClientProcessingStates(); + + // reset the UDC message + udcClient.EncodeUDCMessage(std::move(payloadBuf)); + + // test OnMessageReceived again + mUdcTransportMgr->HandleMessageReceived(peerAddress, std::move(payloadBuf)); + + // verify it was called + NL_TEST_ASSERT(inSuite, testCallback.mFindCommissionableNodeCalled); +} + +void TestUserDirectedCommissioningClientMessage(nlTestSuite * inSuite, void * inContext) +{ + char nameBuffer[Mdns::kMaxInstanceNameSize + 1] = "Chris"; + System::PacketBufferHandle payloadBuf = MessagePacketBuffer::NewWithData(nameBuffer, strlen(nameBuffer)); + UserDirectedCommissioningClient udcClient; + + // obtain the UDC message + CHIP_ERROR err = udcClient.EncodeUDCMessage(std::move(payloadBuf)); + + // check the packet header fields + PacketHeader packetHeader; + packetHeader.DecodeAndConsume(payloadBuf); + NL_TEST_ASSERT(inSuite, !packetHeader.GetFlags().Has(Header::FlagValues::kEncryptedMessage)); + + // check the payload header fields + PayloadHeader payloadHeader; + payloadHeader.DecodeAndConsume(payloadBuf); + NL_TEST_ASSERT(inSuite, payloadHeader.GetMessageType() == to_underlying(MsgType::IdentificationDeclaration)); + NL_TEST_ASSERT(inSuite, payloadHeader.GetProtocolID() == Protocols::UserDirectedCommissioning::Id); + NL_TEST_ASSERT(inSuite, !payloadHeader.NeedsAck()); + NL_TEST_ASSERT(inSuite, payloadHeader.IsInitiator()); + + // check the payload + char instanceName[chip::Mdns::kMaxInstanceNameSize + 1]; + size_t instanceNameLength = (payloadBuf->DataLength() > (chip::Mdns::kMaxInstanceNameSize)) ? chip::Mdns::kMaxInstanceNameSize + : payloadBuf->DataLength(); + payloadBuf->Read(Uint8::from_char(instanceName), instanceNameLength); + instanceName[instanceNameLength] = '\0'; + ChipLogProgress(Inet, "UDC instance=%s", instanceName); + NL_TEST_ASSERT(inSuite, strcmp(instanceName, nameBuffer) == 0); + + // verify no errors + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); +} + +void TestUDCClients(nlTestSuite * inSuite, void * inContext) +{ + UDCClients<3> mUdcClients; + const char * instanceName1 = "test1"; + const char * instanceName2 = "test2"; + const char * instanceName3 = "test3"; + const char * instanceName4 = "test4"; + + // test base case + UDCClientState * state = mUdcClients.FindUDCClientState(instanceName1); + NL_TEST_ASSERT(inSuite, state == nullptr); + + // test max size + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == mUdcClients.CreateNewUDCClientState(instanceName1, &state)); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == mUdcClients.CreateNewUDCClientState(instanceName2, &state)); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == mUdcClients.CreateNewUDCClientState(instanceName3, &state)); + NL_TEST_ASSERT(inSuite, CHIP_ERROR_NO_MEMORY == mUdcClients.CreateNewUDCClientState(instanceName4, &state)); + + // test reset + mUdcClients.ResetUDCClientStates(); + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == mUdcClients.CreateNewUDCClientState(instanceName4, &state)); + + // test find + NL_TEST_ASSERT(inSuite, nullptr == mUdcClients.FindUDCClientState(instanceName1)); + NL_TEST_ASSERT(inSuite, nullptr == mUdcClients.FindUDCClientState(instanceName2)); + NL_TEST_ASSERT(inSuite, nullptr == mUdcClients.FindUDCClientState(instanceName3)); + state = mUdcClients.FindUDCClientState(instanceName4); + NL_TEST_ASSERT(inSuite, nullptr != state); + + // test expiry + state->Reset(); + NL_TEST_ASSERT(inSuite, nullptr == mUdcClients.FindUDCClientState(instanceName4)); + + // test re-activation + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == mUdcClients.CreateNewUDCClientState(instanceName4, &state)); + uint64_t expirationTime = state->GetExpirationTimeMs(); + state->SetExpirationTimeMs(expirationTime - 1); + NL_TEST_ASSERT(inSuite, (expirationTime - 1) == state->GetExpirationTimeMs()); + mUdcClients.MarkUDCClientActive(state); + NL_TEST_ASSERT(inSuite, (expirationTime - 1) < state->GetExpirationTimeMs()); +} + +// Test Suite + +/** + * Test Suite that lists all the test functions. + */ +// clang-format off +static const nlTest sTests[] = +{ + NL_TEST_DEF("TestUDCServerClients", TestUDCServerClients), + NL_TEST_DEF("TestUDCServerUserConfirmationProvider", TestUDCServerUserConfirmationProvider), + NL_TEST_DEF("TestUDCServerInstanceNameResolver", TestUDCServerInstanceNameResolver), + NL_TEST_DEF("TestUserDirectedCommissioningClientMessage", TestUserDirectedCommissioningClientMessage), + NL_TEST_DEF("TestUDCClients", TestUDCClients), + + NL_TEST_SENTINEL() +}; +// clang-format on + +/** + * Set up the test suite. + */ +static int TestSetup(void * inContext) +{ + CHIP_ERROR error = chip::Platform::MemoryInit(); + if (error != CHIP_NO_ERROR) + return FAILURE; + return SUCCESS; +} + +/** + * Tear down the test suite. + */ +static int TestTeardown(void * inContext) +{ + chip::Platform::MemoryShutdown(); + return SUCCESS; +} + +// clang-format off +static nlTestSuite sSuite = +{ + "Test-CHIP-UdcMessages", + &sTests[0], + TestSetup, + TestTeardown, +}; +// clang-format on + +/** + * Main + */ +int TestUdcMessages() +{ + // Run test suit against one context + nlTestRunner(&sSuite, nullptr); + + return (nlTestRunnerStats(&sSuite)); +} + +CHIP_REGISTER_TEST_SUITE(TestUdcMessages)