diff --git a/.github/workflows/darwin.yaml b/.github/workflows/darwin.yaml index ead897aa3e924e..3880c6bbe8f8fe 100644 --- a/.github/workflows/darwin.yaml +++ b/.github/workflows/darwin.yaml @@ -114,7 +114,7 @@ jobs: run: | mkdir -p /tmp/darwin/framework-tests ../../../out/debug/chip-all-clusters-app > >(tee /tmp/darwin/framework-tests/all-cluster-app.log) 2> >(tee /tmp/darwin/framework-tests/all-cluster-app-err.log >&2) & - xcodebuild test -target "CHIP" -scheme "CHIP Framework Tests" -sdk macosx > >(tee /tmp/darwin/framework-tests/darwin-tests.log) 2> >(tee /tmp/darwin/framework-tests/darwin-tests-err.log >&2) + xcodebuild test -target "CHIP" -scheme "CHIP Framework Tests" -sdk macosx OTHER_CFLAGS='${inherited} -Werror -Wno-documentation -Wno-conditional-uninitialized -Wno-incomplete-umbrella' > >(tee /tmp/darwin/framework-tests/darwin-tests.log) 2> >(tee /tmp/darwin/framework-tests/darwin-tests-err.log >&2) working-directory: src/darwin/Framework - name: Uploading log files uses: actions/upload-artifact@v2 diff --git a/.vscode/settings.json b/.vscode/settings.json index c4290296d8315f..6ff631080c2465 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -156,5 +156,6 @@ "clang-format.fallbackStyle": "WebKit", "files.trimFinalNewlines": true, "C_Cpp.default.cppStandard": "gnu++14", - "C_Cpp.default.cStandard": "gnu11" + "C_Cpp.default.cStandard": "gnu11", + "cmake.configureOnOpen": false } diff --git a/.vscode/tasks.json b/.vscode/tasks.json index fda87757b21a9c..a0f49412f87aed 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -170,6 +170,17 @@ "${workspaceFolder}/src/test_driver/mbed/unit_tests/build" ] } + }, + { + "label": "Flash EFR32 board", + "type": "shell", + "command": "python3", + "args": [ + "${workspaceFolder}/out/${input:exampleTarget}/chip-efr32-*.flash.py" + ], + "problemMatcher": { + "base": "$gcc" + } } ], "inputs": [ @@ -229,10 +240,45 @@ "android-x64-chip-tool", "android-x86-chip-tool", "efr32-brd4161a-light", + "efr32-brd4163a-light", + "efr32-brd4164a-light", + "efr32-brd4166a-light", + "efr32-brd4170a-light", + "efr32-brd4186a-light", + "efr32-brd4187a-light", + "efr32-brd4304a-light", "efr32-brd4161a-light-rpc", + "efr32-brd4163a-light-rpc", + "efr32-brd4164a-light-rpc", + "efr32-brd4166a-light-rpc", + "efr32-brd4170a-light-rpc", + "efr32-brd4186a-light-rpc", + "efr32-brd4187a-light-rpc", + "efr32-brd4304a-light-rpc", "efr32-brd4161a-lock", + "efr32-brd4163a-lock", + "efr32-brd4164a-lock", + "efr32-brd4166a-lock", + "efr32-brd4170a-lock", + "efr32-brd4186a-lock", + "efr32-brd4187a-lock", + "efr32-brd4304a-lock", "efr32-brd4161a-unit-test", + "efr32-brd4163a-unit-test", + "efr32-brd4164a-unit-test", + "efr32-brd4166a-unit-test", + "efr32-brd4170a-unit-test", + "efr32-brd4186a-unit-test", + "efr32-brd4187a-unit-test", + "efr32-brd4304a-unit-test", "efr32-brd4161a-window-covering", + "efr32-brd4163a-window-covering", + "efr32-brd4164a-window-covering", + "efr32-brd4166a-window-covering", + "efr32-brd4170a-window-covering", + "efr32-brd4186a-window-covering", + "efr32-brd4187a-window-covering", + "efr32-brd4304a-window-covering", "esp32-c3devkit-all-clusters", "esp32-devkitc-all-clusters", "esp32-devkitc-all-clusters-ipv6only", diff --git a/examples/all-clusters-app/esp32/main/CMakeLists.txt b/examples/all-clusters-app/esp32/main/CMakeLists.txt index 29430a67c079a6..804eb4c365185b 100644 --- a/examples/all-clusters-app/esp32/main/CMakeLists.txt +++ b/examples/all-clusters-app/esp32/main/CMakeLists.txt @@ -225,6 +225,7 @@ target_compile_options(${COMPONENT_LIB} PRIVATE "-DPW_RPC_BUTTON_SERVICE=1" "-DPW_RPC_DEVICE_SERVICE=1" "-DPW_RPC_LIGHTING_SERVICE=1" - "-DPW_RPC_LOCKING_SERVICE=1") + "-DPW_RPC_LOCKING_SERVICE=1" + "-DPW_RPC_TRACING_SERVICE=1") endif (CONFIG_ENABLE_PW_RPC) diff --git a/examples/platform/esp32/Rpc.cpp b/examples/platform/esp32/Rpc.cpp index 3cfe0070e815b6..9286191dc4eeca 100644 --- a/examples/platform/esp32/Rpc.cpp +++ b/examples/platform/esp32/Rpc.cpp @@ -52,6 +52,23 @@ #include "pigweed/rpc_services/Locking.h" #endif // defined(PW_RPC_LOCKING_SERVICE) && PW_RPC_LOCKING_SERVICE +#if defined(PW_RPC_TRACING_SERVICE) && PW_RPC_TRACING_SERVICE +#include "pw_trace/trace.h" +#include "pw_trace_tokenized/trace_rpc_service_nanopb.h" + +// Define trace time for pw_trace +PW_TRACE_TIME_TYPE pw_trace_GetTraceTime() +{ + return (PW_TRACE_TIME_TYPE) chip::System::SystemClock().GetMonotonicMicroseconds64().count(); +} +// Microsecond time source +size_t pw_trace_GetTraceTimeTicksPerSecond() +{ + return 1000000; +} + +#endif // defined(PW_RPC_TRACING_SERVICE) && PW_RPC_TRACING_SERVICE + namespace chip { namespace rpc { @@ -122,6 +139,10 @@ Lighting lighting_service; Locking locking; #endif // defined(PW_RPC_LOCKING_SERVICE) && PW_RPC_LOCKING_SERVICE +#if defined(PW_RPC_TRACING_SERVICE) && PW_RPC_TRACING_SERVICE +pw::trace::TraceService trace_service; +#endif // defined(PW_RPC_TRACING_SERVICE) && PW_RPC_TRACING_SERVICE + void RegisterServices(pw::rpc::Server & server) { #if defined(PW_RPC_ATTRIBUTE_SERVICE) && PW_RPC_ATTRIBUTE_SERVICE @@ -143,6 +164,10 @@ void RegisterServices(pw::rpc::Server & server) #if defined(PW_RPC_LOCKING_SERVICE) && PW_RPC_LOCKING_SERVICE server.RegisterService(locking); #endif // defined(PW_RPC_LOCKING_SERVICE) && PW_RPC_LOCKING_SERVICE + +#if defined(PW_RPC_TRACING_SERVICE) && PW_RPC_TRACING_SERVICE + server.RegisterService(trace_service); +#endif // defined(PW_RPC_TRACING_SERVICE) && PW_RPC_TRACING_SERVICE } } // namespace diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 7570976129905c..a0a8f2e7c615ae 100644 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -38,11 +38,17 @@ _bootstrap_or_activate() { local _CHIP_BANNER="$( cat <(apContext); TLVReader reader(aReader); - if (aReader.GetTag() == TLV::ContextTag(to_underlying(EventDataIB::Tag::kPath))) - { - err = - ctx->mpWriter->Put(TLV::ContextTag(to_underlying(EventDataIB::Tag::kEventNumber)), ctx->mpContext->mCurrentEventNumber); - } - if (aReader.GetTag() == TLV::ContextTag(to_underlying(EventDataIB::Tag::kDeltaSystemTimestamp)) || aReader.GetTag() == TLV::ContextTag(to_underlying(EventDataIB::Tag::kDeltaEpochTimestamp))) { @@ -490,6 +484,11 @@ CHIP_ERROR EventManagement::CopyAndAdjustDeltaTime(const TLVReader & aReader, si err = ctx->mpWriter->CopyElement(reader); } + if (aReader.GetTag() == TLV::ContextTag(to_underlying(EventDataIB::Tag::kPath))) + { + err = + ctx->mpWriter->Put(TLV::ContextTag(to_underlying(EventDataIB::Tag::kEventNumber)), ctx->mpContext->mCurrentEventNumber); + } return err; } @@ -762,7 +761,6 @@ CHIP_ERROR EventManagement::FetchEventsSince(TLVWriter & aWriter, ClusterInfo * } exit: - ChipLogProgress(EventLogging, "Debug log, err: %s\n", chip::ErrorStr(err)); aEventNumber = context.mCurrentEventNumber; aEventCount += context.mEventCount; return err; diff --git a/src/app/MessageDef/ReportDataMessage.h b/src/app/MessageDef/ReportDataMessage.h index 3b8d098925469d..fa4aeb2977b019 100644 --- a/src/app/MessageDef/ReportDataMessage.h +++ b/src/app/MessageDef/ReportDataMessage.h @@ -40,11 +40,11 @@ namespace app { namespace ReportDataMessage { enum { - kCsTag_SuppressResponse = 0, - kCsTag_SubscriptionId = 1, - kCsTag_AttributeReportIBs = 2, - kCsTag_EventReports = 3, - kCsTag_MoreChunkedMessages = 4, + kCsTag_SubscriptionId = 0, + kCsTag_AttributeReportIBs = 1, + kCsTag_EventReports = 2, + kCsTag_MoreChunkedMessages = 3, + kCsTag_SuppressResponse = 4, }; class Parser : public StructParser diff --git a/src/app/MessageDef/StructParser.cpp b/src/app/MessageDef/StructParser.cpp index 3e78b631c59727..c1886939ce5993 100644 --- a/src/app/MessageDef/StructParser.cpp +++ b/src/app/MessageDef/StructParser.cpp @@ -23,7 +23,37 @@ CHIP_ERROR StructParser::Init(const TLV::TLVReader & aReader) mReader.Init(aReader); VerifyOrReturnError(TLV::kTLVType_Structure == mReader.GetType(), CHIP_ERROR_WRONG_TLV_TYPE); ReturnErrorOnFailure(mReader.EnterContainer(mOuterContainerType)); + ReturnErrorOnFailure(CheckSchemaOrdering()); return CHIP_NO_ERROR; } + +CHIP_ERROR StructParser::CheckSchemaOrdering() const +{ + CHIP_ERROR err = CHIP_NO_ERROR; + TLV::TLVReader reader; + reader.Init(mReader); + uint32_t preTagNum = 0; + bool first = true; + while (CHIP_NO_ERROR == (err = reader.Next())) + { + VerifyOrReturnError(TLV::IsContextTag(reader.GetTag()), CHIP_ERROR_INVALID_TLV_TAG); + uint32_t tagNum = TLV::TagNumFromTag(reader.GetTag()); + if (first || (preTagNum < tagNum)) + { + preTagNum = tagNum; + } + else + { + return CHIP_ERROR_INVALID_TLV_TAG; + } + first = false; + } + if (CHIP_END_OF_TLV == err) + { + err = CHIP_NO_ERROR; + } + ReturnErrorOnFailure(err); + return reader.ExitContainer(mOuterContainerType); +} } // namespace app } // namespace chip diff --git a/src/app/MessageDef/StructParser.h b/src/app/MessageDef/StructParser.h index 6206ec1bd29543..acf4d94c8d20a8 100644 --- a/src/app/MessageDef/StructParser.h +++ b/src/app/MessageDef/StructParser.h @@ -33,6 +33,8 @@ class StructParser : public Parser * @return #CHIP_NO_ERROR on success */ CHIP_ERROR Init(const TLV::TLVReader & aReader); + + CHIP_ERROR CheckSchemaOrdering() const; }; } // namespace app } // namespace chip diff --git a/src/app/ReadClient.cpp b/src/app/ReadClient.cpp index 52dd3749ed6a11..297a418d29fc57 100644 --- a/src/app/ReadClient.cpp +++ b/src/app/ReadClient.cpp @@ -140,6 +140,15 @@ CHIP_ERROR ReadClient::SendReadRequest(ReadPrepareParams & aReadPrepareParams) err = request.Init(&writer); SuccessOrExit(err); + if (aReadPrepareParams.mAttributePathParamsListSize != 0 && aReadPrepareParams.mpAttributePathParamsList != nullptr) + { + AttributePathIBs::Builder attributePathListBuilder = request.CreateAttributeRequests(); + SuccessOrExit(err = attributePathListBuilder.GetError()); + err = GenerateAttributePathList(attributePathListBuilder, aReadPrepareParams.mpAttributePathParamsList, + aReadPrepareParams.mAttributePathParamsListSize); + SuccessOrExit(err); + } + if (aReadPrepareParams.mEventPathParamsListSize != 0 && aReadPrepareParams.mpEventPathParamsList != nullptr) { EventPathIBs::Builder & eventPathListBuilder = request.CreateEventRequests(); @@ -160,15 +169,6 @@ CHIP_ERROR ReadClient::SendReadRequest(ReadPrepareParams & aReadPrepareParams) } } - if (aReadPrepareParams.mAttributePathParamsListSize != 0 && aReadPrepareParams.mpAttributePathParamsList != nullptr) - { - AttributePathIBs::Builder attributePathListBuilder = request.CreateAttributeRequests(); - SuccessOrExit(err = attributePathListBuilder.GetError()); - err = GenerateAttributePathList(attributePathListBuilder, aReadPrepareParams.mpAttributePathParamsList, - aReadPrepareParams.mAttributePathParamsListSize); - SuccessOrExit(err); - } - request.IsFabricFiltered(false).EndOfReadRequestMessage(); SuccessOrExit(err = request.GetError()); @@ -632,12 +632,27 @@ CHIP_ERROR ReadClient::SendSubscribeRequest(ReadPrepareParams & aReadPreparePara VerifyOrExit(mpCallback != nullptr, err = CHIP_ERROR_INCORRECT_STATE); msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes); VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_NO_MEMORY); + VerifyOrExit(aReadPrepareParams.mMinIntervalFloorSeconds < aReadPrepareParams.mMaxIntervalCeilingSeconds, + err = CHIP_ERROR_INVALID_ARGUMENT); writer.Init(std::move(msgBuf)); err = request.Init(&writer); SuccessOrExit(err); + request.KeepSubscriptions(aReadPrepareParams.mKeepSubscriptions) + .MinIntervalFloorSeconds(aReadPrepareParams.mMinIntervalFloorSeconds) + .MaxIntervalCeilingSeconds(aReadPrepareParams.mMaxIntervalCeilingSeconds); + + if (aReadPrepareParams.mAttributePathParamsListSize != 0 && aReadPrepareParams.mpAttributePathParamsList != nullptr) + { + AttributePathIBs::Builder & attributePathListBuilder = request.CreateAttributeRequests(); + SuccessOrExit(err = attributePathListBuilder.GetError()); + err = GenerateAttributePathList(attributePathListBuilder, aReadPrepareParams.mpAttributePathParamsList, + aReadPrepareParams.mAttributePathParamsListSize); + SuccessOrExit(err); + } + if (aReadPrepareParams.mEventPathParamsListSize != 0 && aReadPrepareParams.mpEventPathParamsList != nullptr) { EventPathIBs::Builder & eventPathListBuilder = request.CreateEventRequests(); @@ -659,22 +674,7 @@ CHIP_ERROR ReadClient::SendSubscribeRequest(ReadPrepareParams & aReadPreparePara } } - if (aReadPrepareParams.mAttributePathParamsListSize != 0 && aReadPrepareParams.mpAttributePathParamsList != nullptr) - { - AttributePathIBs::Builder & attributePathListBuilder = request.CreateAttributeRequests(); - SuccessOrExit(err = attributePathListBuilder.GetError()); - err = GenerateAttributePathList(attributePathListBuilder, aReadPrepareParams.mpAttributePathParamsList, - aReadPrepareParams.mAttributePathParamsListSize); - SuccessOrExit(err); - } - - VerifyOrExit(aReadPrepareParams.mMinIntervalFloorSeconds < aReadPrepareParams.mMaxIntervalCeilingSeconds, - err = CHIP_ERROR_INVALID_ARGUMENT); - request.MinIntervalFloorSeconds(aReadPrepareParams.mMinIntervalFloorSeconds) - .MaxIntervalCeilingSeconds(aReadPrepareParams.mMaxIntervalCeilingSeconds) - .KeepSubscriptions(aReadPrepareParams.mKeepSubscriptions) - .IsFabricFiltered(false) - .EndOfSubscribeRequestMessage(); + request.IsFabricFiltered(false).EndOfSubscribeRequestMessage(); SuccessOrExit(err = request.GetError()); err = writer.Finalize(&msgBuf); diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp index 52842c52b2aac7..50d6f71d29c9cd 100644 --- a/src/app/WriteClient.cpp +++ b/src/app/WriteClient.cpp @@ -43,11 +43,10 @@ CHIP_ERROR WriteClient::Init(Messaging::ExchangeManager * apExchangeMgr, Callbac mMessageWriter.Init(std::move(packet)); ReturnErrorOnFailure(mWriteRequestBuilder.Init(&mMessageWriter)); - mWriteRequestBuilder.TimedRequest(false).IsFabricFiltered(false); + mWriteRequestBuilder.TimedRequest(false); ReturnErrorOnFailure(mWriteRequestBuilder.GetError()); attributeDataIBsBuilder = mWriteRequestBuilder.CreateWriteRequests(); ReturnErrorOnFailure(attributeDataIBsBuilder.GetError()); - ClearExistingExchangeContext(); mpExchangeMgr = apExchangeMgr; mpCallback = apCallback; @@ -139,6 +138,9 @@ CHIP_ERROR WriteClient::PrepareAttribute(const AttributePathParams & attributePa VerifyOrReturnError(attributePathParams.IsValidAttributePath(), CHIP_ERROR_INVALID_PATH_LIST); AttributeDataIB::Builder attributeDataIB = mWriteRequestBuilder.GetWriteRequests().CreateAttributeDataIBBuilder(); ReturnErrorOnFailure(attributeDataIB.GetError()); + // TODO: Add attribute version support + attributeDataIB.DataVersion(0); + ReturnErrorOnFailure(attributeDataIB.GetError()); ReturnErrorOnFailure(attributeDataIB.CreatePath().Encode(attributePathParams)); return CHIP_NO_ERROR; } @@ -148,9 +150,6 @@ CHIP_ERROR WriteClient::FinishAttribute() CHIP_ERROR err = CHIP_NO_ERROR; AttributeDataIB::Builder AttributeDataIB = mWriteRequestBuilder.GetWriteRequests().GetAttributeDataIBBuilder(); - - // TODO: Add attribute version support - AttributeDataIB.DataVersion(0); AttributeDataIB.EndOfAttributeDataIB(); SuccessOrExit(err = AttributeDataIB.GetError()); MoveToState(State::AddAttribute); @@ -173,7 +172,7 @@ CHIP_ERROR WriteClient::FinalizeMessage(System::PacketBufferHandle & aPacket) err = AttributeDataIBsBuilder.GetError(); SuccessOrExit(err); - mWriteRequestBuilder.EndOfWriteRequestMessage(); + mWriteRequestBuilder.IsFabricFiltered(false).EndOfWriteRequestMessage(); err = mWriteRequestBuilder.GetError(); SuccessOrExit(err); diff --git a/src/app/clusters/general_diagnostics_server/general_diagnostics_server.cpp b/src/app/clusters/general_diagnostics_server/general_diagnostics_server.cpp index 688c3ba1ef6722..d8495dc6df6593 100644 --- a/src/app/clusters/general_diagnostics_server/general_diagnostics_server.cpp +++ b/src/app/clusters/general_diagnostics_server/general_diagnostics_server.cpp @@ -28,6 +28,7 @@ using namespace chip; using namespace chip::app; using namespace chip::app::Clusters; +using namespace chip::app::Clusters::GeneralDiagnostics; using namespace chip::app::Clusters::GeneralDiagnostics::Attributes; using namespace chip::DeviceLayer; using chip::DeviceLayer::ConnectivityMgr; @@ -215,7 +216,8 @@ class GeneralDiagnosticsDelegate : public DeviceLayer::ConnectivityManagerDelega } // Get called when the Node detects a hardware fault has been raised. - void OnHardwareFaultsDetected() override + void OnHardwareFaultsDetected(GeneralFaults & previous, + GeneralFaults & current) override { ChipLogProgress(Zcl, "GeneralDiagnosticsDelegate: OnHardwareFaultsDetected"); @@ -236,7 +238,7 @@ class GeneralDiagnosticsDelegate : public DeviceLayer::ConnectivityManagerDelega } // Get called when the Node detects a radio fault has been raised. - void OnRadioFaultsDetected() override + void OnRadioFaultsDetected(GeneralFaults & previous, GeneralFaults & current) override { ChipLogProgress(Zcl, "GeneralDiagnosticsDelegate: OnHardwareFaultsDetected"); @@ -257,7 +259,7 @@ class GeneralDiagnosticsDelegate : public DeviceLayer::ConnectivityManagerDelega } // Get called when the Node detects a network fault has been raised. - void OnNetworkFaultsDetected() override + void OnNetworkFaultsDetected(GeneralFaults & previous, GeneralFaults & current) override { ChipLogProgress(Zcl, "GeneralDiagnosticsDelegate: OnHardwareFaultsDetected"); diff --git a/src/app/tests/TestMessageDef.cpp b/src/app/tests/TestMessageDef.cpp index db547413706b3f..b984621e2f7e1b 100644 --- a/src/app/tests/TestMessageDef.cpp +++ b/src/app/tests/TestMessageDef.cpp @@ -323,7 +323,7 @@ void BuildEventDataIB(nlTestSuite * apSuite, EventDataIB::Builder & aEventDataIB NL_TEST_ASSERT(apSuite, eventPathBuilder.GetError() == CHIP_NO_ERROR); BuildEventPath(apSuite, eventPathBuilder); - aEventDataIBBuilder.Priority(2).EventNumber(3).EpochTimestamp(4).SystemTimestamp(5).DeltaEpochTimestamp(6).DeltaSystemTimestamp( + aEventDataIBBuilder.EventNumber(2).Priority(3).EpochTimestamp(4).SystemTimestamp(5).DeltaEpochTimestamp(6).DeltaSystemTimestamp( 7); err = aEventDataIBBuilder.GetError(); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); @@ -365,10 +365,10 @@ void ParseEventDataIB(nlTestSuite * apSuite, EventDataIB::Parser & aEventDataIBP err = aEventDataIBParser.GetPath(&eventPath); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); } - err = aEventDataIBParser.GetPriority(&priorityLevel); - NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR && priorityLevel == 2); err = aEventDataIBParser.GetEventNumber(&number); - NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR && number == 3); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR && number == 2); + err = aEventDataIBParser.GetPriority(&priorityLevel); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR && priorityLevel == 3); err = aEventDataIBParser.GetEpochTimestamp(&EpochTimestamp); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR && EpochTimestamp == 4); err = aEventDataIBParser.GetSystemTimestamp(&systemTimestamp); @@ -535,6 +535,7 @@ void BuildAttributeDataIB(nlTestSuite * apSuite, AttributeDataIB::Builder & aAtt { CHIP_ERROR err = CHIP_NO_ERROR; + aAttributeDataIBBuilder.DataVersion(2); AttributePathIB::Builder attributePathBuilder = aAttributeDataIBBuilder.CreatePath(); NL_TEST_ASSERT(apSuite, aAttributeDataIBBuilder.GetError() == CHIP_NO_ERROR); BuildAttributePathIB(apSuite, attributePathBuilder); @@ -553,7 +554,7 @@ void BuildAttributeDataIB(nlTestSuite * apSuite, AttributeDataIB::Builder & aAtt err = pWriter->EndContainer(dummyType); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); } - aAttributeDataIBBuilder.DataVersion(2); + err = aAttributeDataIBBuilder.GetError(); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); aAttributeDataIBBuilder.EndOfAttributeDataIB(); @@ -963,7 +964,7 @@ void BuildReportDataMessage(nlTestSuite * apSuite, chip::TLV::TLVWriter & aWrite err = reportDataMessageBuilder.Init(&aWriter); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - reportDataMessageBuilder.SuppressResponse(true).SubscriptionId(2); + reportDataMessageBuilder.SubscriptionId(2); NL_TEST_ASSERT(apSuite, reportDataMessageBuilder.GetError() == CHIP_NO_ERROR); AttributeReportIBs::Builder AttributeReportIBs = reportDataMessageBuilder.CreateAttributeReportIBs(); @@ -974,7 +975,7 @@ void BuildReportDataMessage(nlTestSuite * apSuite, chip::TLV::TLVWriter & aWrite NL_TEST_ASSERT(apSuite, reportDataMessageBuilder.GetError() == CHIP_NO_ERROR); BuildEventReports(apSuite, EventReportIBs); - reportDataMessageBuilder.MoreChunkedMessages(true); + reportDataMessageBuilder.MoreChunkedMessages(true).SuppressResponse(true); NL_TEST_ASSERT(apSuite, reportDataMessageBuilder.GetError() == CHIP_NO_ERROR); reportDataMessageBuilder.EndOfReportDataMessage(); @@ -1170,26 +1171,26 @@ void BuildSubscribeRequestMessage(nlTestSuite * apSuite, chip::TLV::TLVWriter & err = subscribeRequestBuilder.Init(&aWriter); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - AttributePathIBs::Builder attributePathIBs = subscribeRequestBuilder.CreateAttributeRequests(); + subscribeRequestBuilder.KeepSubscriptions(true); NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); - BuildAttributePathList(apSuite, attributePathIBs); - EventPathIBs::Builder eventPathList = subscribeRequestBuilder.CreateEventRequests(); + subscribeRequestBuilder.MinIntervalFloorSeconds(2); NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); - BuildEventPaths(apSuite, eventPathList); - EventFilterIBs::Builder eventFilters = subscribeRequestBuilder.CreateEventFilters(); + subscribeRequestBuilder.MaxIntervalCeilingSeconds(3); NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); - BuildEventFilters(apSuite, eventFilters); - subscribeRequestBuilder.MinIntervalFloorSeconds(2); + AttributePathIBs::Builder attributePathIBs = subscribeRequestBuilder.CreateAttributeRequests(); NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); + BuildAttributePathList(apSuite, attributePathIBs); - subscribeRequestBuilder.MaxIntervalCeilingSeconds(3); + EventPathIBs::Builder eventPathList = subscribeRequestBuilder.CreateEventRequests(); NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); + BuildEventPaths(apSuite, eventPathList); - subscribeRequestBuilder.KeepSubscriptions(true); + EventFilterIBs::Builder eventFilters = subscribeRequestBuilder.CreateEventFilters(); NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); + BuildEventFilters(apSuite, eventFilters); subscribeRequestBuilder.IsProxy(true); NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp index 3617cf01abbcff..5f20fabdb8f8d6 100644 --- a/src/app/tests/TestReadInteraction.cpp +++ b/src/app/tests/TestReadInteraction.cpp @@ -248,11 +248,13 @@ CHIP_ERROR ReadSingleClusterData(FabricIndex aAccessingFabricIndex, const Concre } attributeData = aAttributeReport.CreateAttributeData(); + attributeData.DataVersion(0); + ReturnErrorOnFailure(attributeData.GetError()); attributePath = attributeData.CreatePath(); attributePath.Endpoint(aPath.mEndpointId).Cluster(aPath.mClusterId).Attribute(aPath.mAttributeId).EndOfAttributePathIB(); ReturnErrorOnFailure(attributePath.GetError()); ReturnErrorOnFailure(AttributeValueEncoder(attributeData.GetWriter(), 0).Encode(kTestFieldValue1)); - attributeData.DataVersion(0).EndOfAttributeDataIB(); + attributeData.EndOfAttributeDataIB(); ReturnErrorOnFailure(attributeData.GetError()); return CHIP_NO_ERROR; } @@ -298,9 +300,6 @@ void TestReadInteraction::GenerateReportData(nlTestSuite * apSuite, void * apCon err = reportDataMessageBuilder.Init(&writer); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - reportDataMessageBuilder.SuppressResponse(aSuppressResponse); - NL_TEST_ASSERT(apSuite, reportDataMessageBuilder.GetError() == CHIP_NO_ERROR); - AttributeReportIBs::Builder attributeReportIBsBuilder = reportDataMessageBuilder.CreateAttributeReportIBs(); NL_TEST_ASSERT(apSuite, reportDataMessageBuilder.GetError() == CHIP_NO_ERROR); @@ -310,6 +309,10 @@ void TestReadInteraction::GenerateReportData(nlTestSuite * apSuite, void * apCon AttributeDataIB::Builder attributeDataIBBuilder = attributeReportIBBuilder.CreateAttributeData(); NL_TEST_ASSERT(apSuite, attributeReportIBBuilder.GetError() == CHIP_NO_ERROR); + attributeDataIBBuilder.DataVersion(2); + err = attributeDataIBBuilder.GetError(); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + AttributePathIB::Builder attributePathBuilder = attributeDataIBBuilder.CreatePath(); NL_TEST_ASSERT(apSuite, attributeDataIBBuilder.GetError() == CHIP_NO_ERROR); @@ -325,10 +328,6 @@ void TestReadInteraction::GenerateReportData(nlTestSuite * apSuite, void * apCon err = attributePathBuilder.GetError(); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - attributeDataIBBuilder.DataVersion(2); - err = attributeDataIBBuilder.GetError(); - NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - // Construct attribute data { chip::TLV::TLVWriter * pWriter = attributeDataIBBuilder.GetWriter(); @@ -356,6 +355,9 @@ void TestReadInteraction::GenerateReportData(nlTestSuite * apSuite, void * apCon reportDataMessageBuilder.MoreChunkedMessages(false); NL_TEST_ASSERT(apSuite, reportDataMessageBuilder.GetError() == CHIP_NO_ERROR); + reportDataMessageBuilder.SuppressResponse(aSuppressResponse); + NL_TEST_ASSERT(apSuite, reportDataMessageBuilder.GetError() == CHIP_NO_ERROR); + reportDataMessageBuilder.EndOfReportDataMessage(); NL_TEST_ASSERT(apSuite, reportDataMessageBuilder.GetError() == CHIP_NO_ERROR); @@ -930,6 +932,15 @@ void TestReadInteraction::TestProcessSubscribeRequest(nlTestSuite * apSuite, voi err = subscribeRequestBuilder.Init(&writer); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + subscribeRequestBuilder.KeepSubscriptions(true); + NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); + + subscribeRequestBuilder.MinIntervalFloorSeconds(2); + NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); + + subscribeRequestBuilder.MaxIntervalCeilingSeconds(3); + NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); + AttributePathIBs::Builder attributePathListBuilder = subscribeRequestBuilder.CreateAttributeRequests(); NL_TEST_ASSERT(apSuite, attributePathListBuilder.GetError() == CHIP_NO_ERROR); @@ -944,15 +955,6 @@ void TestReadInteraction::TestProcessSubscribeRequest(nlTestSuite * apSuite, voi err = attributePathListBuilder.GetError(); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - subscribeRequestBuilder.MinIntervalFloorSeconds(2); - NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); - - subscribeRequestBuilder.MaxIntervalCeilingSeconds(3); - NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); - - subscribeRequestBuilder.KeepSubscriptions(true); - NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); - subscribeRequestBuilder.IsProxy(true); NL_TEST_ASSERT(apSuite, subscribeRequestBuilder.GetError() == CHIP_NO_ERROR); diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp index 05960015d2fba1..c850a4b40a38d0 100644 --- a/src/app/tests/TestWriteInteraction.cpp +++ b/src/app/tests/TestWriteInteraction.cpp @@ -126,11 +126,15 @@ void TestWriteInteraction::GenerateWriteRequest(nlTestSuite * apSuite, void * ap WriteRequestMessage::Builder writeRequestBuilder; err = writeRequestBuilder.Init(&writer); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + writeRequestBuilder.TimedRequest(aIsTimedWrite); + NL_TEST_ASSERT(apSuite, writeRequestBuilder.GetError() == CHIP_NO_ERROR); AttributeDataIBs::Builder attributeDataIBsBuilder = writeRequestBuilder.CreateWriteRequests(); NL_TEST_ASSERT(apSuite, writeRequestBuilder.GetError() == CHIP_NO_ERROR); AttributeDataIB::Builder attributeDataIBBuilder = attributeDataIBsBuilder.CreateAttributeDataIBBuilder(); NL_TEST_ASSERT(apSuite, attributeDataIBsBuilder.GetError() == CHIP_NO_ERROR); + attributeDataIBBuilder.DataVersion(0); + NL_TEST_ASSERT(apSuite, attributeDataIBBuilder.GetError() == CHIP_NO_ERROR); AttributePathIB::Builder attributePathBuilder = attributeDataIBBuilder.CreatePath(); NL_TEST_ASSERT(apSuite, attributePathBuilder.GetError() == CHIP_NO_ERROR); attributePathBuilder.Node(1).Endpoint(2).Cluster(3).Attribute(4).ListIndex(5).EndOfAttributePathIB(); @@ -152,12 +156,12 @@ void TestWriteInteraction::GenerateWriteRequest(nlTestSuite * apSuite, void * ap NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); } - attributeDataIBBuilder.DataVersion(0).EndOfAttributeDataIB(); + attributeDataIBBuilder.EndOfAttributeDataIB(); NL_TEST_ASSERT(apSuite, attributeDataIBBuilder.GetError() == CHIP_NO_ERROR); attributeDataIBsBuilder.EndOfAttributeDataIBs(); NL_TEST_ASSERT(apSuite, attributeDataIBsBuilder.GetError() == CHIP_NO_ERROR); - writeRequestBuilder.TimedRequest(aIsTimedWrite).IsFabricFiltered(false).EndOfWriteRequestMessage(); + writeRequestBuilder.IsFabricFiltered(false).EndOfWriteRequestMessage(); NL_TEST_ASSERT(apSuite, writeRequestBuilder.GetError() == CHIP_NO_ERROR); err = writer.Finalize(&aPayload); diff --git a/src/app/tests/integration/chip_im_responder.cpp b/src/app/tests/integration/chip_im_responder.cpp index eddf1916a5fd69..3d51ca2c93391a 100644 --- a/src/app/tests/integration/chip_im_responder.cpp +++ b/src/app/tests/integration/chip_im_responder.cpp @@ -114,12 +114,14 @@ CHIP_ERROR ReadSingleClusterData(FabricIndex aAccessingFabricIndex, const Concre AttributeReportIB::Builder & aAttributeReport) { AttributeDataIB::Builder attributeData = aAttributeReport.CreateAttributeData(); + attributeData.DataVersion(0); + ReturnErrorOnFailure(attributeData.GetError()); AttributePathIB::Builder attributePath = attributeData.CreatePath(); VerifyOrReturnError(aPath.mClusterId == kTestClusterId && aPath.mEndpointId == kTestEndpointId, CHIP_ERROR_INVALID_ARGUMENT); attributePath.Endpoint(aPath.mEndpointId).Cluster(aPath.mClusterId).Attribute(aPath.mAttributeId).EndOfAttributePathIB(); ReturnErrorOnFailure(attributePath.GetError()); ReturnErrorOnFailure(AttributeValueEncoder(attributeData.GetWriter(), 0).Encode(kTestFieldValue1)); - attributeData.DataVersion(0).EndOfAttributeDataIB(); + attributeData.EndOfAttributeDataIB(); ReturnErrorOnFailure(attributeData.GetError()); return CHIP_NO_ERROR; } diff --git a/src/app/util/ember-compatibility-functions.cpp b/src/app/util/ember-compatibility-functions.cpp index 76bb3629aa818f..96542cdd006cfc 100644 --- a/src/app/util/ember-compatibility-functions.cpp +++ b/src/app/util/ember-compatibility-functions.cpp @@ -215,7 +215,7 @@ namespace { CHIP_ERROR SendSuccessStatus(AttributeDataIB::Builder & aAttributeDataIBBuilder) { - aAttributeDataIBBuilder.DataVersion(kTemporaryDataVersion).EndOfAttributeDataIB(); + aAttributeDataIBBuilder.EndOfAttributeDataIB(); return aAttributeDataIBBuilder.GetError(); } @@ -252,7 +252,11 @@ CHIP_ERROR ReadSingleClusterData(FabricIndex aAccessingFabricIndex, const Concre AttributeDataIB::Builder attributeDataIBBuilder = aAttributeReport.CreateAttributeData(); ReturnErrorOnFailure(attributeDataIBBuilder.GetError()); + attributeDataIBBuilder.DataVersion(kTemporaryDataVersion); + ReturnErrorOnFailure(attributeDataIBBuilder.GetError()); + AttributePathIB::Builder attributePathIBBuilder = attributeDataIBBuilder.CreatePath(); + ReturnErrorOnFailure(attributeDataIBBuilder.GetError()); attributePathIBBuilder.Endpoint(aPath.mEndpointId) .Cluster(aPath.mClusterId) .Attribute(aPath.mAttributeId) diff --git a/src/app/util/mock/attribute-storage.cpp b/src/app/util/mock/attribute-storage.cpp index 57c3bad233b2c5..471ab2bdadecb8 100644 --- a/src/app/util/mock/attribute-storage.cpp +++ b/src/app/util/mock/attribute-storage.cpp @@ -239,6 +239,8 @@ CHIP_ERROR ReadSingleMockClusterData(FabricIndex aAccessingFabricIndex, const Co } attributeData = aAttributeReport.CreateAttributeData(); + attributeData.DataVersion(0); + ReturnErrorOnFailure(attributeData.GetError()); attributePath = attributeData.CreatePath(); attributePath.Endpoint(aPath.mEndpointId).Cluster(aPath.mClusterId).Attribute(aPath.mAttributeId).EndOfAttributePathIB(); ReturnErrorOnFailure(attributePath.GetError()); @@ -271,7 +273,7 @@ CHIP_ERROR ReadSingleMockClusterData(FabricIndex aAccessingFabricIndex, const Co return CHIP_ERROR_KEY_NOT_FOUND; } - attributeData.DataVersion(0).EndOfAttributeDataIB(); + attributeData.EndOfAttributeDataIB(); return attributeData.GetError(); } diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 1157992b489109..6d4857e02ba217 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -761,6 +761,20 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, const char * se CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParameters & params) { + CommissioningParameters commissioningParams; + return PairDevice(remoteDeviceId, params, commissioningParams); +} + +CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParameters & rendezvousParams, + CommissioningParameters & commissioningParams) +{ + ReturnErrorOnFailure(EstablishPASEConnection(remoteDeviceId, rendezvousParams)); + return Commission(remoteDeviceId, commissioningParams); +} + +CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, RendezvousParameters & params) +{ + CHIP_ERROR err = CHIP_NO_ERROR; CommissioneeDeviceProxy * device = nullptr; Transport::PeerAddress peerAddress = Transport::PeerAddress::UDP(Inet::IPAddress::Any); @@ -804,30 +818,6 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam mDeviceBeingCommissioned = device; - // If the CSRNonce is passed in, using that else using a random one.. - if (params.HasCSRNonce()) - { - ReturnErrorOnFailure(device->SetCSRNonce(params.GetCSRNonce().Value())); - } - else - { - uint8_t mCSRNonce[kOpCSRNonceLength]; - Crypto::DRBG_get_bytes(mCSRNonce, sizeof(mCSRNonce)); - ReturnErrorOnFailure(device->SetCSRNonce(ByteSpan(mCSRNonce))); - } - - // If the AttestationNonce is passed in, using that else using a random one.. - if (params.HasAttestationNonce()) - { - ReturnErrorOnFailure(device->SetAttestationNonce(params.GetAttestationNonce().Value())); - } - else - { - uint8_t mAttestationNonce[kAttestationNonceLength]; - Crypto::DRBG_get_bytes(mAttestationNonce, sizeof(mAttestationNonce)); - ReturnErrorOnFailure(device->SetAttestationNonce(ByteSpan(mAttestationNonce))); - } - mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle); device->Init(GetControllerDeviceInitParams(), remoteDeviceId, peerAddress, fabric->GetFabricIndex()); @@ -835,8 +825,6 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam err = device->GetPairing().MessageDispatch().Init(mSystemState->SessionMgr()); SuccessOrExit(err); - mSystemState->SystemLayer()->StartTimer(chip::System::Clock::Milliseconds32(kSessionEstablishmentTimeout), - OnSessionEstablishmentTimeoutCallback, this); if (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle) { device->SetAddress(params.GetPeerAddress().GetIPAddress()); @@ -874,9 +862,10 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID, exchangeCtxt, this); SuccessOrExit(err); - // Immediately persist the updted mNextKeyID value + // Immediately persist the updated mNextKeyID value // TODO maybe remove FreeRendezvousSession() since mNextKeyID is always persisted immediately PersistNextKeyId(); + mCommissioningStage = kSecurePairing; exit: if (err != CHIP_NO_ERROR) @@ -897,6 +886,58 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam return err; } +CHIP_ERROR DeviceCommissioner::Commission(NodeId remoteDeviceId, CommissioningParameters & params) +{ + // TODO(cecille): Can we get rid of mDeviceBeingCommissioned and use the remote id instead? Would require storing the + // commissioning stage in the device. + CommissioneeDeviceProxy * device = mDeviceBeingCommissioned; + if (device->GetDeviceId() != remoteDeviceId || (!device->IsSecureConnected() && !device->IsSessionSetupInProgress())) + { + ChipLogError(Controller, "Invalid device for commissioning" ChipLogFormatX64, ChipLogValueX64(remoteDeviceId)); + return CHIP_ERROR_INCORRECT_STATE; + } + if (mCommissioningStage != CommissioningStage::kSecurePairing) + { + ChipLogError(Controller, "Commissioning already in progress - not restarting"); + return CHIP_ERROR_INCORRECT_STATE; + } + // If the CSRNonce is passed in, using that else using a random one.. + if (params.HasCSRNonce()) + { + ReturnErrorOnFailure(device->SetCSRNonce(params.GetCSRNonce().Value())); + } + else + { + uint8_t mCSRNonce[kOpCSRNonceLength]; + Crypto::DRBG_get_bytes(mCSRNonce, sizeof(mCSRNonce)); + ReturnErrorOnFailure(device->SetCSRNonce(ByteSpan(mCSRNonce))); + } + + // If the AttestationNonce is passed in, using that else using a random one.. + if (params.HasAttestationNonce()) + { + ReturnErrorOnFailure(device->SetAttestationNonce(params.GetAttestationNonce().Value())); + } + else + { + uint8_t mAttestationNonce[kAttestationNonceLength]; + Crypto::DRBG_get_bytes(mAttestationNonce, sizeof(mAttestationNonce)); + ReturnErrorOnFailure(device->SetAttestationNonce(ByteSpan(mAttestationNonce))); + } + + mSystemState->SystemLayer()->StartTimer(chip::System::Clock::Milliseconds32(kSessionEstablishmentTimeout), + OnSessionEstablishmentTimeoutCallback, this); + if (device->IsSecureConnected()) + { + AdvanceCommissioningStage(CHIP_NO_ERROR); + } + else + { + mRunCommissioningAfterConnection = true; + } + return CHIP_NO_ERROR; +} + CHIP_ERROR DeviceCommissioner::StopPairing(NodeId remoteDeviceId) { VerifyOrReturnError(mState == State::Initialized, CHIP_ERROR_INCORRECT_STATE); @@ -981,21 +1022,29 @@ void DeviceCommissioner::OnSessionEstablished() // TODO: Add code to receive OpCSR from the device, and process the signing request // For IP rendezvous, this is sent as part of the state machine. - bool usingLegacyFlowWithImmediateStart = !mIsIPRendezvous; - - if (usingLegacyFlowWithImmediateStart) + if (mRunCommissioningAfterConnection) { - err = SendCertificateChainRequestCommand(mDeviceBeingCommissioned, CertificateType::kPAI); - if (err != CHIP_NO_ERROR) + mRunCommissioningAfterConnection = false; + bool usingLegacyFlowWithImmediateStart = !mIsIPRendezvous; + if (usingLegacyFlowWithImmediateStart) { - ChipLogError(Ble, "Failed in sending 'Certificate Chain request' command to the device: err %s", ErrorStr(err)); - OnSessionEstablishmentError(err); - return; + err = SendCertificateChainRequestCommand(mDeviceBeingCommissioned, CertificateType::kPAI); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Ble, "Failed in sending 'Certificate Chain request' command to the device: err %s", ErrorStr(err)); + OnSessionEstablishmentError(err); + return; + } + } + else + { + AdvanceCommissioningStage(CHIP_NO_ERROR); } } else { - AdvanceCommissioningStage(CHIP_NO_ERROR); + ChipLogProgress(Controller, "OnPairingComplete"); + mPairingDelegate->OnPairingComplete(CHIP_NO_ERROR); } } diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index 32606dc11a4dcd..998bfd9b5ffa75 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -48,6 +48,7 @@ #include #include #include +#include #include #include #include @@ -490,9 +491,46 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController, * in the Init() call. * * @param[in] remoteDeviceId The remote device Id. - * @param[in] params The Rendezvous connection parameters + * @param[in] rendezvousParams The Rendezvous connection parameters + * @param[in] commssioningParams The commissioning parameters (uses defualt if not supplied) */ - CHIP_ERROR PairDevice(NodeId remoteDeviceId, RendezvousParameters & params); + CHIP_ERROR PairDevice(NodeId remoteDeviceId, RendezvousParameters & rendezvousParams); + CHIP_ERROR PairDevice(NodeId remoteDeviceId, RendezvousParameters & rendezvousParams, + CommissioningParameters & commissioningParams); + + /** + * @brief + * Start establishing a PASE connection with a node for the purposes of commissioning. + * Commissioners that wish to use the auto-commissioning functions should use the + * supplied "PairDevice" functions above to automatically establish a connection then + * perform commissioning. This function is intended to be use by commissioners that + * are not using the supplied auto-commissioner. + * + * This function is non-blocking. PASE is established once the DevicePairingDelegate + * receives the OnPairingComplete call. + * + * PASE connections can only be established with nodes that have their commissioning + * window open. The PASE connection will fail if this window is not open and the + * OnPairingComplete will be called with an error. + * + * @param[in] remoteDeviceId The remote device Id. + * @param[in] rendezvousParams The Rendezvous connection parameters + */ + CHIP_ERROR EstablishPASEConnection(NodeId remoteDeviceId, RendezvousParameters & params); + + /** + * @brief + * Start the auto-commissioning process on a node after establishing a PASE connection. + * This function is intended to be used in conjunction with the EstablishPASEConnection + * function. It can be called either before or after the DevicePairingDelegate receives + * the OnPairingComplete call. Commissioners that want to perform simple auto-commissioning + * should use the supplied "PairDevice" functions above, which will establish the PASE + * connection and commission automatically. + * + * @param[in] remoteDeviceId The remote device Id. + * @param[in] params The commissioning parameters + */ + CHIP_ERROR Commission(NodeId remoteDeviceId, CommissioningParameters & params); CHIP_ERROR GetDeviceBeingCommissioned(NodeId deviceId, CommissioneeDeviceProxy ** device); @@ -628,6 +666,7 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController, bool mPairedDevicesUpdated; CommissioningStage mCommissioningStage = CommissioningStage::kSecurePairing; + bool mRunCommissioningAfterConnection = false; BitMapObjectPool mCommissioneeDevicePool; diff --git a/src/controller/java/CHIPDeviceController-JNI.cpp b/src/controller/java/CHIPDeviceController-JNI.cpp index 2c5a8d7af593c4..5364aceebd9698 100644 --- a/src/controller/java/CHIPDeviceController-JNI.cpp +++ b/src/controller/java/CHIPDeviceController-JNI.cpp @@ -185,18 +185,19 @@ JNI_METHOD(void, pairDevice) ChipLogProgress(Controller, "pairDevice() called with device ID, connection object, and pincode"); - RendezvousParameters params = RendezvousParameters() - .SetSetupPINCode(pinCode) + RendezvousParameters rendezvousParams = RendezvousParameters() + .SetSetupPINCode(pinCode) #if CONFIG_NETWORK_LAYER_BLE - .SetConnectionObject(reinterpret_cast(connObj)) + .SetConnectionObject(reinterpret_cast(connObj)) #endif - .SetPeerAddress(Transport::PeerAddress::BLE()); + .SetPeerAddress(Transport::PeerAddress::BLE()); + CommissioningParameters commissioningParams = CommissioningParameters(); if (csrNonce != nullptr) { JniByteArray jniCsrNonce(env, csrNonce); - params.SetCSRNonce(jniCsrNonce.byteSpan()); + commissioningParams.SetCSRNonce(jniCsrNonce.byteSpan()); } - err = wrapper->Controller()->PairDevice(deviceId, params); + err = wrapper->Controller()->PairDevice(deviceId, rendezvousParams, commissioningParams); if (err != CHIP_NO_ERROR) { @@ -221,16 +222,17 @@ JNI_METHOD(void, pairDeviceWithAddress) ChipLogError(Controller, "Failed to parse IP address."), JniReferences::GetInstance().ThrowError(env, sChipDeviceControllerExceptionCls, CHIP_ERROR_INVALID_ARGUMENT)); - RendezvousParameters params = RendezvousParameters() - .SetDiscriminator(discriminator) - .SetSetupPINCode(pinCode) - .SetPeerAddress(Transport::PeerAddress::UDP(addr, port)); + RendezvousParameters rendezvousParams = RendezvousParameters() + .SetDiscriminator(discriminator) + .SetSetupPINCode(pinCode) + .SetPeerAddress(Transport::PeerAddress::UDP(addr, port)); + CommissioningParameters commissioningParams = CommissioningParameters(); if (csrNonce != nullptr) { JniByteArray jniCsrNonce(env, csrNonce); - params.SetCSRNonce(jniCsrNonce.byteSpan()); + commissioningParams.SetCSRNonce(jniCsrNonce.byteSpan()); } - err = wrapper->Controller()->PairDevice(deviceId, params); + err = wrapper->Controller()->PairDevice(deviceId, rendezvousParams, commissioningParams); if (err != CHIP_NO_ERROR) { diff --git a/src/controller/python/ChipDeviceController-ScriptBinding.cpp b/src/controller/python/ChipDeviceController-ScriptBinding.cpp index 789e5855d92289..c16ea817b4c092 100644 --- a/src/controller/python/ChipDeviceController-ScriptBinding.cpp +++ b/src/controller/python/ChipDeviceController-ScriptBinding.cpp @@ -112,6 +112,10 @@ ChipError::StorageType pychip_DeviceController_ConnectBLE(chip::Controller::Devi ChipError::StorageType pychip_DeviceController_ConnectIP(chip::Controller::DeviceCommissioner * devCtrl, const char * peerAddrStr, uint32_t setupPINCode, chip::NodeId nodeid); ChipError::StorageType pychip_DeviceController_CloseSession(chip::Controller::DeviceCommissioner * devCtrl, chip::NodeId nodeid); +ChipError::StorageType pychip_DeviceController_EstablishPASESessionIP(chip::Controller::DeviceCommissioner * devCtrl, + const char * peerAddrStr, uint32_t setupPINCode, + chip::NodeId nodeid); +ChipError::StorageType pychip_DeviceController_Commission(chip::Controller::DeviceCommissioner * devCtrl, chip::NodeId nodeid); ChipError::StorageType pychip_DeviceController_DiscoverCommissionableNodesLongDiscriminator(chip::Controller::DeviceCommissioner * devCtrl, @@ -346,6 +350,23 @@ ChipError::StorageType pychip_DeviceController_CloseSession(chip::Controller::De { return pychip_GetConnectedDeviceByNodeId(devCtrl, nodeid, CloseSessionCallback); } +ChipError::StorageType pychip_DeviceController_EstablishPASESessionIP(chip::Controller::DeviceCommissioner * devCtrl, + const char * peerAddrStr, uint32_t setupPINCode, + chip::NodeId nodeid) +{ + chip::Inet::IPAddress peerAddr; + chip::Transport::PeerAddress addr; + RendezvousParameters params = chip::RendezvousParameters().SetSetupPINCode(setupPINCode); + VerifyOrReturnError(chip::Inet::IPAddress::FromString(peerAddrStr, peerAddr), CHIP_ERROR_INVALID_ARGUMENT.AsInteger()); + addr.SetTransportType(chip::Transport::Type::kUdp).SetIPAddress(peerAddr); + params.SetPeerAddress(addr).SetDiscriminator(0); + return devCtrl->EstablishPASEConnection(nodeid, params).AsInteger(); +} +ChipError::StorageType pychip_DeviceController_Commission(chip::Controller::DeviceCommissioner * devCtrl, chip::NodeId nodeid) +{ + CommissioningParameters params; + return devCtrl->Commission(nodeid, params).AsInteger(); +} ChipError::StorageType pychip_DeviceController_DiscoverAllCommissionableNodes(chip::Controller::DeviceCommissioner * devCtrl) { diff --git a/src/controller/python/chip-device-ctrl.py b/src/controller/python/chip-device-ctrl.py index 03b31548d1c2e0..54bfef54b96b42 100755 --- a/src/controller/python/chip-device-ctrl.py +++ b/src/controller/python/chip-device-ctrl.py @@ -200,6 +200,8 @@ def __init__(self, rendezvousAddr=None, controllerNodeId=0, bluetoothAdapter=Non "close-ble", "close-session", "resolve", + "paseonly", + "commission", "zcl", "zclread", "zclsubscribe", @@ -491,6 +493,58 @@ def ConnectFromSetupPayload(self, setupPayload, nodeid): print(f"Unable to connect: {ex}") return -1 + def do_paseonly(self, line): + """ + paseonly -ip [] + + TODO: Add more methods to connect to device (like cert for auth, and IP + for connection) + """ + + try: + args = shlex.split(line) + if len(args) <= 1: + print("Usage:") + self.do_help("paseonly") + return + + nodeid = random.randint(1, 1000000) # Just a random number + if len(args) == 4: + nodeid = int(args[3]) + print("Device is assigned with nodeid = {}".format(nodeid)) + + if args[0] == "-ip" and len(args) >= 3: + self.devCtrl.EstablishPASESessionIP(args[1].encode( + "utf-8"), int(args[2]), nodeid) + else: + print("Usage:") + self.do_help("paseonly") + return + print( + "Device temporary node id (**this does not match spec**): {}".format(nodeid)) + except Exception as ex: + print(str(ex)) + return + + def do_commission(self, line): + """ + commission nodeid + + Runs commissioning on a device that has been connected with paseonly + """ + try: + args = shlex.split(line) + if len(args) != 1: + print("Usage:") + self.do_help("commission") + return + + nodeid = int(args[0]) + self.devCtrl.Commission(nodeid) + except Exception as ex: + print(str(ex)) + return + def do_connect(self, line): """ connect -ip [] diff --git a/src/controller/python/chip/ChipDeviceCtrl.py b/src/controller/python/chip/ChipDeviceCtrl.py index a993c2e2d4760e..b0fed9d9a99f04 100644 --- a/src/controller/python/chip/ChipDeviceCtrl.py +++ b/src/controller/python/chip/ChipDeviceCtrl.py @@ -199,6 +199,26 @@ def CloseSession(self, nodeid): self.devCtrl, nodeid) ) + def EstablishPASESessionIP(self, ipaddr, setupPinCode, nodeid): + self.state = DCState.RENDEZVOUS_ONGOING + return self._ChipStack.CallAsync( + lambda: self._dmLib.pychip_DeviceController_EstablishPASESessionIP( + self.devCtrl, ipaddr, setupPinCode, nodeid) + ) + + def Commission(self, nodeid): + self._ChipStack.CallAsync( + lambda: self._dmLib.pychip_DeviceController_Commission( + self.devCtrl, nodeid) + ) + # Wait up to 5 additional seconds for the commissioning complete event + if not self._ChipStack.commissioningCompleteEvent.isSet(): + self._ChipStack.commissioningCompleteEvent.wait(5.0) + if not self._ChipStack.commissioningCompleteEvent.isSet(): + # Error 50 is a timeout + return False + return self._ChipStack.commissioningEventRes == 0 + def ConnectIP(self, ipaddr, setupPinCode, nodeid): # IP connection will run through full commissioning, so we need to wait # for the commissioning complete event, not just any callback. @@ -653,6 +673,11 @@ def _InitLib(self): self._dmLib.pychip_DeviceController_ConnectIP.argtypes = [ c_void_p, c_char_p, c_uint32, c_uint64] + + self._dmLib.pychip_DeviceController_Commission.argtypes = [ + c_void_p, c_uint64] + self._dmLib.pychip_DeviceController_Commission.restype = c_uint32 + self._dmLib.pychip_DeviceController_DiscoverAllCommissionableNodes.argtypes = [ c_void_p] self._dmLib.pychip_DeviceController_DiscoverAllCommissionableNodes.restype = c_uint32 @@ -677,6 +702,12 @@ def _InitLib(self): c_void_p] self._dmLib.pychip_DeviceController_DiscoverCommissionableNodesCommissioningEnabled.restype = c_uint32 + self._dmLib.pychip_DeviceController_EstablishPASESessionIP.argtypes = [ + c_void_p, c_char_p, c_uint32, c_uint64] + self._dmLib.pychip_DeviceController_EstablishPASESessionIP.restype = c_uint32 + + self._dmLib.pychip_DeviceController_DiscoverAllCommissionableNodes.argtypes = [ + c_void_p] self._dmLib.pychip_DeviceController_PrintDiscoveredDevices.argtypes = [ c_void_p] diff --git a/src/controller/tests/data_model/TestRead.cpp b/src/controller/tests/data_model/TestRead.cpp index de7ff36ff47580..631e16025f0454 100644 --- a/src/controller/tests/data_model/TestRead.cpp +++ b/src/controller/tests/data_model/TestRead.cpp @@ -77,13 +77,14 @@ CHIP_ERROR ReadSingleClusterData(FabricIndex aAccessingFabricIndex, const Concre i++; } + attributeData.DataVersion(0); AttributePathIB::Builder attributePath = attributeData.CreatePath(); attributePath.Endpoint(aPath.mEndpointId).Cluster(aPath.mClusterId).Attribute(aPath.mAttributeId).EndOfAttributePathIB(); ReturnErrorOnFailure(attributePath.GetError()); ReturnErrorOnFailure(DataModel::Encode(*(attributeData.GetWriter()), chip::TLV::ContextTag(chip::to_underlying(AttributeDataIB::Tag::kData)), value)); - attributeData.DataVersion(0).EndOfAttributeDataIB(); + attributeData.EndOfAttributeDataIB(); return CHIP_NO_ERROR; } else diff --git a/src/darwin/Framework/CHIP/templates/helper.js b/src/darwin/Framework/CHIP/templates/helper.js index 1383e9f02231d4..1aaa3e0c1ba4b6 100644 --- a/src/darwin/Framework/CHIP/templates/helper.js +++ b/src/darwin/Framework/CHIP/templates/helper.js @@ -56,6 +56,8 @@ function asTestValue() return '[@"Test" dataUsingEncoding:NSUTF8StringEncoding]'; } else if (StringHelper.isCharString(this.type)) { return '@"Test"'; + } else if (this.isArray) { + return '[NSArray array]'; } else { return `@(${this.min || this.max || 0})`; } diff --git a/src/include/platform/DiagnosticDataProvider.h b/src/include/platform/DiagnosticDataProvider.h index f28e33477fdf52..e02a96201db182 100644 --- a/src/include/platform/DiagnosticDataProvider.h +++ b/src/include/platform/DiagnosticDataProvider.h @@ -67,19 +67,20 @@ class GeneralDiagnosticsDelegate * @brief * Called when the Node detects a hardware fault has been raised. */ - virtual void OnHardwareFaultsDetected() {} + virtual void OnHardwareFaultsDetected(GeneralFaults & previous, GeneralFaults & current) + {} /** * @brief * Called when the Node detects a radio fault has been raised. */ - virtual void OnRadioFaultsDetected() {} + virtual void OnRadioFaultsDetected(GeneralFaults & previous, GeneralFaults & current) {} /** * @brief * Called when the Node detects a network fault has been raised. */ - virtual void OnNetworkFaultsDetected() {} + virtual void OnNetworkFaultsDetected(GeneralFaults & previous, GeneralFaults & current) {} }; /** diff --git a/src/lib/dnssd/minimal_mdns/Parser.cpp b/src/lib/dnssd/minimal_mdns/Parser.cpp index b3ae34b98e2b6e..8bc3e9e1af3bb4 100644 --- a/src/lib/dnssd/minimal_mdns/Parser.cpp +++ b/src/lib/dnssd/minimal_mdns/Parser.cpp @@ -66,23 +66,24 @@ bool QueryData::Parse(const BytesRange & validData, const uint8_t ** start) return true; } -bool QueryData::Append(HeaderRef & hdr, chip::Encoding::BigEndian::BufferWriter & out) const +bool QueryData::Append(HeaderRef & hdr, RecordWriter & out) const { if ((hdr.GetAdditionalCount() != 0) || (hdr.GetAnswerCount() != 0) || (hdr.GetAuthorityCount() != 0)) { return false; } - GetName().Put(out); - out.Put16(static_cast(mType)); - out.Put16(static_cast(mClass) | (mAnswerViaUnicast ? kQClassUnicastAnswerFlag : 0)); + out.WriteQName(GetName()) + .Put16(static_cast(mType)) + .Put16(static_cast(mClass) | (mAnswerViaUnicast ? kQClassUnicastAnswerFlag : 0)); - if (out.Fit()) + if (!out.Fit()) { - hdr.SetQueryCount(static_cast(hdr.GetQueryCount() + 1)); + return false; } - return out.Fit(); + hdr.SetQueryCount(static_cast(hdr.GetQueryCount() + 1)); + return true; } bool ResourceData::Parse(const BytesRange & validData, const uint8_t ** start) diff --git a/src/lib/dnssd/minimal_mdns/Parser.h b/src/lib/dnssd/minimal_mdns/Parser.h index 2c82b99d5a4555..1b2b208c7dcc89 100644 --- a/src/lib/dnssd/minimal_mdns/Parser.h +++ b/src/lib/dnssd/minimal_mdns/Parser.h @@ -20,6 +20,7 @@ #include #include #include +#include namespace mdns { namespace Minimal { @@ -56,7 +57,7 @@ class QueryData bool Parse(const BytesRange & validData, const uint8_t ** start); /// Write out this query data back into an output buffer. - bool Append(HeaderRef & hdr, chip::Encoding::BigEndian::BufferWriter & out) const; + bool Append(HeaderRef & hdr, RecordWriter & out) const; private: QType mType = QType::ANY; diff --git a/src/lib/dnssd/minimal_mdns/Query.h b/src/lib/dnssd/minimal_mdns/Query.h index a6687311567999..cfdd1b1120fe1e 100644 --- a/src/lib/dnssd/minimal_mdns/Query.h +++ b/src/lib/dnssd/minimal_mdns/Query.h @@ -21,6 +21,7 @@ #include #include +#include namespace mdns { namespace Minimal { @@ -56,7 +57,7 @@ class Query /// /// @param hdr will be updated with a query count /// @param out where to write the query data - bool Append(HeaderRef & hdr, chip::Encoding::BigEndian::BufferWriter & out) const + bool Append(HeaderRef & hdr, RecordWriter & out) const { // Questions can only be appended before any other data is added if ((hdr.GetAdditionalCount() != 0) || (hdr.GetAnswerCount() != 0) || (hdr.GetAuthorityCount() != 0)) @@ -64,17 +65,17 @@ class Query return false; } - mQName.Output(out); + out.WriteQName(mQName) + .Put16(static_cast(mType)) + .Put16(static_cast(static_cast(mClass) | (mAnswerViaUnicast ? kQClassUnicastAnswerFlag : 0))); - out.Put16(static_cast(mType)); - out.Put16(static_cast(static_cast(mClass) | (mAnswerViaUnicast ? kQClassUnicastAnswerFlag : 0))); - - if (out.Fit()) + if (!out.Fit()) { - hdr.SetQueryCount(static_cast(hdr.GetQueryCount() + 1)); + return false; } - return out.Fit(); + hdr.SetQueryCount(static_cast(hdr.GetQueryCount() + 1)); + return true; } private: diff --git a/src/lib/dnssd/minimal_mdns/QueryBuilder.h b/src/lib/dnssd/minimal_mdns/QueryBuilder.h index 11b6c9097792ca..c6c964a0483cd3 100644 --- a/src/lib/dnssd/minimal_mdns/QueryBuilder.h +++ b/src/lib/dnssd/minimal_mdns/QueryBuilder.h @@ -68,8 +68,9 @@ class QueryBuilder } chip::Encoding::BigEndian::BufferWriter out(mPacket->Start() + mPacket->DataLength(), mPacket->AvailableDataLength()); + RecordWriter writer(&out); - if (!query.Append(mHeader, out)) + if (!query.Append(mHeader, writer)) { mQueryBuildOk = false; } diff --git a/src/lib/dnssd/minimal_mdns/ResponseBuilder.h b/src/lib/dnssd/minimal_mdns/ResponseBuilder.h index 73735d57b0191a..b2ce31ed60110f 100644 --- a/src/lib/dnssd/minimal_mdns/ResponseBuilder.h +++ b/src/lib/dnssd/minimal_mdns/ResponseBuilder.h @@ -29,8 +29,12 @@ namespace Minimal { class ResponseBuilder { public: - ResponseBuilder() : mHeader(nullptr) {} - ResponseBuilder(chip::System::PacketBufferHandle && packet) : mHeader(nullptr) { Reset(std::move(packet)); } + ResponseBuilder() : mHeader(nullptr), mEndianOutput(nullptr, 0), mWriter(&mEndianOutput) {} + ResponseBuilder(chip::System::PacketBufferHandle && packet) : + mHeader(nullptr), mEndianOutput(nullptr, 0), mWriter(&mEndianOutput) + { + Reset(std::move(packet)); + } ResponseBuilder & Reset(chip::System::PacketBufferHandle && packet) { @@ -49,6 +53,13 @@ class ResponseBuilder } mHeader.SetFlags(mHeader.GetFlags().SetResponse()); + + mEndianOutput = + chip::Encoding::BigEndian::BufferWriter(mPacket->Start(), mPacket->DataLength() + mPacket->AvailableDataLength()); + mEndianOutput.Skip(mPacket->DataLength()); + + mWriter.Reset(); + return *this; } @@ -77,16 +88,16 @@ class ResponseBuilder return *this; } - chip::Encoding::BigEndian::BufferWriter out(mPacket->Start() + mPacket->DataLength(), mPacket->AvailableDataLength()); - - if (!record.Append(mHeader, type, out)) + if (!record.Append(mHeader, type, mWriter)) { mBuildOk = false; } else { - mPacket->SetDataLength(static_cast(mPacket->DataLength() + out.Needed())); + VerifyOrDie(mEndianOutput.Fit()); // should be guaranteed because record Append succeeded + mPacket->SetDataLength(static_cast(mEndianOutput.Needed())); } + return *this; } @@ -97,15 +108,13 @@ class ResponseBuilder return *this; } - chip::Encoding::BigEndian::BufferWriter out(mPacket->Start() + mPacket->DataLength(), mPacket->AvailableDataLength()); - - if (!query.Append(mHeader, out)) + if (!query.Append(mHeader, mWriter)) { mBuildOk = false; } else { - mPacket->SetDataLength(static_cast(mPacket->DataLength() + out.Needed())); + mPacket->SetDataLength(static_cast(mEndianOutput.Needed())); } return *this; } @@ -116,6 +125,8 @@ class ResponseBuilder private: chip::System::PacketBufferHandle mPacket; HeaderRef mHeader; + chip::Encoding::BigEndian::BufferWriter mEndianOutput; + RecordWriter mWriter; bool mBuildOk = false; }; diff --git a/src/lib/dnssd/minimal_mdns/core/BUILD.gn b/src/lib/dnssd/minimal_mdns/core/BUILD.gn index 8ba686a63657a1..0fc83eaddfdd15 100644 --- a/src/lib/dnssd/minimal_mdns/core/BUILD.gn +++ b/src/lib/dnssd/minimal_mdns/core/BUILD.gn @@ -21,6 +21,8 @@ static_library("core") { "DnsHeader.h", "QName.cpp", "QName.h", + "RecordWriter.cpp", + "RecordWriter.h", ] public_deps = [ diff --git a/src/lib/dnssd/minimal_mdns/core/BytesRange.h b/src/lib/dnssd/minimal_mdns/core/BytesRange.h index 73e4170fd1347b..9b50664cdbdee0 100644 --- a/src/lib/dnssd/minimal_mdns/core/BytesRange.h +++ b/src/lib/dnssd/minimal_mdns/core/BytesRange.h @@ -44,6 +44,11 @@ class BytesRange size_t Size() const { return static_cast(mEnd - mStart); } + inline static BytesRange BufferWithSize(const void * buff, size_t len) + { + return BytesRange(static_cast(buff), static_cast(buff) + len); + } + private: const uint8_t * mStart = nullptr; const uint8_t * mEnd = nullptr; diff --git a/src/lib/dnssd/minimal_mdns/core/Constants.h b/src/lib/dnssd/minimal_mdns/core/Constants.h index e081edf5e62b79..30f2d4b0ddd0b9 100644 --- a/src/lib/dnssd/minimal_mdns/core/Constants.h +++ b/src/lib/dnssd/minimal_mdns/core/Constants.h @@ -66,7 +66,5 @@ enum class ResourceType kAdditional, }; -constexpr size_t kMaxQNamePartLength = 255; - } // namespace Minimal } // namespace mdns diff --git a/src/lib/dnssd/minimal_mdns/core/QName.cpp b/src/lib/dnssd/minimal_mdns/core/QName.cpp index f970b1f0a590f6..8fca9f47d8993a 100644 --- a/src/lib/dnssd/minimal_mdns/core/QName.cpp +++ b/src/lib/dnssd/minimal_mdns/core/QName.cpp @@ -152,6 +152,40 @@ bool SerializedQNameIterator::operator==(const FullQName & other) const return ((idx == other.nameCount) && !self.Next()); } +bool SerializedQNameIterator::operator==(const SerializedQNameIterator & other) const +{ + SerializedQNameIterator a = *this; // allow iteration + SerializedQNameIterator b = other; + + while (true) + { + bool hasA = a.Next(); + bool hasB = b.Next(); + + if (hasA ^ hasB) + { + return false; /// one is longer than the other + } + + if (!a.IsValid() || !b.IsValid()) + { + return false; // invalid data + } + + if (!hasA || !hasB) + { + break; + } + + if (strcasecmp(a.Value(), b.Value()) != 0) + { + return false; + } + } + + return a.IsValid() && b.IsValid(); +} + bool FullQName::operator==(const FullQName & other) const { if (nameCount != other.nameCount) diff --git a/src/lib/dnssd/minimal_mdns/core/QName.h b/src/lib/dnssd/minimal_mdns/core/QName.h index fc829f8ad2971d..e2dc6f54716f85 100644 --- a/src/lib/dnssd/minimal_mdns/core/QName.h +++ b/src/lib/dnssd/minimal_mdns/core/QName.h @@ -53,17 +53,6 @@ struct FullQName FullQName(const QNamePart (&data)[N]) : names(data), nameCount(N) {} - void Output(chip::Encoding::BigEndian::BufferWriter & out) const - { - for (uint16_t i = 0; i < nameCount; i++) - { - - out.Put8(static_cast(strlen(names[i]))); - out.Put(names[i]); - } - out.Put8(0); // end of qnames - } - bool operator==(const FullQName & other) const; bool operator!=(const FullQName & other) const { return !(*this == other); } }; @@ -108,17 +97,10 @@ class SerializedQNameIterator bool operator==(const FullQName & other) const; bool operator!=(const FullQName & other) const { return !(*this == other); } - void Put(chip::Encoding::BigEndian::BufferWriter & out) const - { - SerializedQNameIterator copy = *this; - while (copy.Next()) - { - - out.Put8(static_cast(strlen(copy.Value()))); - out.Put(copy.Value()); - } - out.Put8(0); // end of qnames - } + bool operator==(const SerializedQNameIterator & other) const; + bool operator!=(const SerializedQNameIterator & other) const { return !(*this == other); } + + size_t OffsetInCurrentValidData() const { return static_cast(mCurrentPosition - mValidData.Start()); } private: static constexpr size_t kMaxValueSize = 63; diff --git a/src/lib/dnssd/minimal_mdns/core/RecordWriter.cpp b/src/lib/dnssd/minimal_mdns/core/RecordWriter.cpp new file mode 100644 index 00000000000000..1cd8b68cc71dd4 --- /dev/null +++ b/src/lib/dnssd/minimal_mdns/core/RecordWriter.cpp @@ -0,0 +1,145 @@ +/* + * + * Copyright (c) 2020 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RecordWriter.h" + +namespace mdns { +namespace Minimal { + +SerializedQNameIterator RecordWriter::PreviousName(size_t index) const +{ + if (index >= kMaxCachedReferences) + { + return SerializedQNameIterator(); + } + + uint16_t offset = mPreviousQNames[index]; + if (offset == kInvalidOffset) + { + return SerializedQNameIterator(); + } + + return SerializedQNameIterator(BytesRange(mOutput->Buffer(), mOutput->Buffer() + mOutput->WritePos()), + mOutput->Buffer() + offset); +} + +RecordWriter & RecordWriter::WriteQName(const FullQName & qname) +{ + size_t qNameWriteStart = mOutput->WritePos(); + bool isFullyCompressed = true; + + for (uint16_t i = 0; i < qname.nameCount; i++) + { + + // find out if the record part remaining already is located somewhere + FullQName remaining; + remaining.names = qname.names + i; + remaining.nameCount = qname.nameCount - i; + + // Try to find a valid offset + chip::Optional offset = FindPreviousName(remaining); + + if (offset.HasValue()) + { + // Pointer to offset: set the highest 2 bits + mOutput->Put16(offset.Value() | 0xC000); + + if (mOutput->Fit() && !isFullyCompressed) + { + RememberWrittenQnameOffset(qNameWriteStart); + } + return *this; + } + + mOutput->Put8(static_cast(strlen(qname.names[i]))); + mOutput->Put(qname.names[i]); + isFullyCompressed = false; + } + mOutput->Put8(0); // end of qnames + + if (mOutput->Fit()) + { + RememberWrittenQnameOffset(qNameWriteStart); + } + return *this; +} + +RecordWriter & RecordWriter::WriteQName(const SerializedQNameIterator & qname) +{ + size_t qNameWriteStart = mOutput->WritePos(); + bool isFullyCompressed = true; + + SerializedQNameIterator copy = qname; + while (true) + { + chip::Optional offset = FindPreviousName(copy); + + if (offset.HasValue()) + { + // Pointer to offset: set the highest 2 bits + // We guarantee that offsets saved are <= kMaxReuseOffset + mOutput->Put16(offset.Value() | 0xC000); + + if (mOutput->Fit() && !isFullyCompressed) + { + RememberWrittenQnameOffset(qNameWriteStart); + } + return *this; + } + + if (!copy.Next()) + { + break; + } + + if (!copy.IsValid()) + { + break; + } + + mOutput->Put8(static_cast(strlen(copy.Value()))); + mOutput->Put(copy.Value()); + isFullyCompressed = false; + } + mOutput->Put8(0); // end of qnames + + if (mOutput->Fit()) + { + RememberWrittenQnameOffset(qNameWriteStart); + } + return *this; +} + +void RecordWriter::RememberWrittenQnameOffset(size_t offset) +{ + if (offset > kMaxReuseOffset) + { + // cannot represent this offset properly + return; + } + + for (size_t i = 0; i < kMaxCachedReferences; i++) + { + if (mPreviousQNames[i] == kInvalidOffset) + { + mPreviousQNames[i] = offset; + return; + } + } +} + +} // namespace Minimal +} // namespace mdns diff --git a/src/lib/dnssd/minimal_mdns/core/RecordWriter.h b/src/lib/dnssd/minimal_mdns/core/RecordWriter.h new file mode 100644 index 00000000000000..0356f7e4b2643e --- /dev/null +++ b/src/lib/dnssd/minimal_mdns/core/RecordWriter.h @@ -0,0 +1,138 @@ +/* + * + * Copyright (c) 2020 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace mdns { +namespace Minimal { + +/** + * Handles writing into mdns packets. + * + * Generally the same as a binary data writer, but can handle qname writing with + * compression. + */ +class RecordWriter +{ +public: + RecordWriter(chip::Encoding::BigEndian::BufferWriter * output) : mOutput(output) { Reset(); } + + void Reset() + { + for (size_t i = 0; i < kMaxCachedReferences; i++) + { + mPreviousQNames[i] = kInvalidOffset; + } + } + + chip::Encoding::BigEndian::BufferWriter & Writer() { return *mOutput; } + + /// Writes the given qname into the underlying buffer, applying + /// compression if possible + RecordWriter & WriteQName(const FullQName & qname); + + /// Writes the given qname into the underlying buffer, applying + /// compression if possible + RecordWriter & WriteQName(const SerializedQNameIterator & qname); + + inline RecordWriter & Put8(uint8_t value) + { + mOutput->Put8(value); + return *this; + } + + inline RecordWriter & Put16(uint16_t value) + { + mOutput->Put16(value); + return *this; + } + + inline RecordWriter & Put32(uint32_t value) + { + mOutput->Put32(value); + return *this; + } + + inline RecordWriter & PutString(const char * value) + { + mOutput->Put(value); + return *this; + } + + inline RecordWriter & Put(const BytesRange & range) + { + mOutput->Put(range.Start(), range.Size()); + return *this; + } + + inline bool Fit() const { return mOutput->Fit(); } + +private: + // How many paths to remember as 'previously written' + // and make use of them + static constexpr size_t kMaxCachedReferences = 8; + static constexpr uint16_t kInvalidOffset = 0xFFFF; + static constexpr uint16_t kMaxReuseOffset = 0x3FFF; + + // Where the data is being outputted + chip::Encoding::BigEndian::BufferWriter * mOutput; + uint16_t mPreviousQNames[kMaxCachedReferences]; + + /// Find the offset at which this qname was previously seen (if any) + /// works with QName and SerializedQNameIterator + template + chip::Optional FindPreviousName(const T & name) const + { + for (size_t i = 0; i < kMaxCachedReferences; i++) + { + SerializedQNameIterator previous = PreviousName(i); + + // Any of the sub-segments may match + while (previous.IsValid()) + { + if (previous == name) + { + return chip::Optional::Value(previous.OffsetInCurrentValidData()); + } + + if (!previous.Next()) + { + break; + } + } + } + + return chip::Optional::Missing(); + } + + /// Gets the iterator corresponding to the previous name + /// with the given index. + /// + /// Will return an iterator that is not valid if + /// lookbehind index is not valid + SerializedQNameIterator PreviousName(size_t index) const; + + /// Keep track that a qname was written at the given offset + void RememberWrittenQnameOffset(size_t offset); +}; + +} // namespace Minimal +} // namespace mdns diff --git a/src/lib/dnssd/minimal_mdns/core/tests/BUILD.gn b/src/lib/dnssd/minimal_mdns/core/tests/BUILD.gn index 9484a04c8aa817..a3f6141c16d029 100644 --- a/src/lib/dnssd/minimal_mdns/core/tests/BUILD.gn +++ b/src/lib/dnssd/minimal_mdns/core/tests/BUILD.gn @@ -24,6 +24,7 @@ chip_test_suite("tests") { test_sources = [ "TestFlatAllocatedQName.cpp", "TestQName.cpp", + "TestRecordWriter.cpp", ] cflags = [ "-Wconversion" ] diff --git a/src/lib/dnssd/minimal_mdns/core/tests/TestQName.cpp b/src/lib/dnssd/minimal_mdns/core/tests/TestQName.cpp index 055176cc69efd8..d8f741badb9c32 100644 --- a/src/lib/dnssd/minimal_mdns/core/tests/TestQName.cpp +++ b/src/lib/dnssd/minimal_mdns/core/tests/TestQName.cpp @@ -25,11 +25,26 @@ namespace { using namespace mdns::Minimal; +/// Convenience method to have a serialized QName: +/// +/// static const uint8_t kData[] = "datahere\00"; +/// AsSerializedQName(kData); +/// +/// NOTE: this MUST be using the string "" format to add an extra NULL +/// terminator that this method discards. +template +static SerializedQNameIterator AsSerializedQName(const uint8_t (&v)[N]) +{ + // NOTE: the -1 is because we format these items as STRINGS and that + // appends an extra NULL terminator + return SerializedQNameIterator(BytesRange(v, v + N - 1), v); +} + void IteratorTest(nlTestSuite * inSuite, void * inContext) { { static const uint8_t kOneItem[] = "\04test\00"; - SerializedQNameIterator it(BytesRange(kOneItem, kOneItem + sizeof(kOneItem)), kOneItem); + SerializedQNameIterator it = AsSerializedQName(kOneItem); NL_TEST_ASSERT(inSuite, it.Next()); NL_TEST_ASSERT(inSuite, strcmp(it.Value(), "test") == 0); @@ -39,7 +54,7 @@ void IteratorTest(nlTestSuite * inSuite, void * inContext) { static const uint8_t kManyItems[] = "\04this\02is\01a\04test\00"; - SerializedQNameIterator it(BytesRange(kManyItems, kManyItems + sizeof(kManyItems)), kManyItems); + SerializedQNameIterator it = AsSerializedQName(kManyItems); NL_TEST_ASSERT(inSuite, it.Next()); NL_TEST_ASSERT(inSuite, strcmp(it.Value(), "this") == 0); @@ -82,7 +97,7 @@ void ErrorTest(nlTestSuite * inSuite, void * inContext) { // Truncated before the end static const uint8_t kData[] = "\04test"; - SerializedQNameIterator it(BytesRange(kData, kData + 5), kData); + SerializedQNameIterator it = AsSerializedQName(kData); NL_TEST_ASSERT(inSuite, !it.Next()); NL_TEST_ASSERT(inSuite, !it.IsValid()); @@ -91,7 +106,7 @@ void ErrorTest(nlTestSuite * inSuite, void * inContext) { // Truncated before the end static const uint8_t kData[] = "\02"; - SerializedQNameIterator it(BytesRange(kData, kData + 1), kData); + SerializedQNameIterator it = AsSerializedQName(kData); NL_TEST_ASSERT(inSuite, !it.Next()); NL_TEST_ASSERT(inSuite, !it.IsValid()); @@ -100,7 +115,7 @@ void ErrorTest(nlTestSuite * inSuite, void * inContext) { // Truncated before the end static const uint8_t kData[] = "\xc0"; - SerializedQNameIterator it(BytesRange(kData, kData + 1), kData); + SerializedQNameIterator it = AsSerializedQName(kData); NL_TEST_ASSERT(inSuite, !it.Next()); NL_TEST_ASSERT(inSuite, !it.IsValid()); @@ -108,6 +123,7 @@ void ErrorTest(nlTestSuite * inSuite, void * inContext) { // Truncated before the end (but seemingly valid in case of error) + // does NOT use AsSerializedQName (because out of range) static const uint8_t kData[] = "\00\xc0\x00"; SerializedQNameIterator it(BytesRange(kData, kData + 2), kData + 1); @@ -117,7 +133,7 @@ void ErrorTest(nlTestSuite * inSuite, void * inContext) { // Infinite recursion static const uint8_t kData[] = "\03test\xc0\x00"; - SerializedQNameIterator it(BytesRange(kData, kData + 7), kData); + SerializedQNameIterator it = AsSerializedQName(kData); NL_TEST_ASSERT(inSuite, it.Next()); NL_TEST_ASSERT(inSuite, !it.Next()); @@ -131,44 +147,32 @@ void Comparison(nlTestSuite * inSuite, void * inContext) { const QNamePart kTestName[] = { "this" }; - NL_TEST_ASSERT(inSuite, - SerializedQNameIterator(BytesRange(kManyItems, kManyItems + sizeof(kManyItems)), kManyItems) != - FullQName(kTestName)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kManyItems) != FullQName(kTestName)); } { const QNamePart kTestName[] = { "this", "is" }; - NL_TEST_ASSERT(inSuite, - SerializedQNameIterator(BytesRange(kManyItems, kManyItems + sizeof(kManyItems)), kManyItems) != - FullQName(kTestName)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kManyItems) != FullQName(kTestName)); } { const QNamePart kTestName[] = { "is", "a", "test" }; - NL_TEST_ASSERT(inSuite, - SerializedQNameIterator(BytesRange(kManyItems, kManyItems + sizeof(kManyItems)), kManyItems) != - FullQName(kTestName)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kManyItems) != FullQName(kTestName)); } { const QNamePart kTestName[] = { "this", "is", "a", "test" }; - NL_TEST_ASSERT(inSuite, - SerializedQNameIterator(BytesRange(kManyItems, kManyItems + sizeof(kManyItems)), kManyItems) == - FullQName(kTestName)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kManyItems) == FullQName(kTestName)); } { const QNamePart kTestName[] = { "this", "is", "a", "test", "suffix" }; - NL_TEST_ASSERT(inSuite, - SerializedQNameIterator(BytesRange(kManyItems, kManyItems + sizeof(kManyItems)), kManyItems) != - FullQName(kTestName)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kManyItems) != FullQName(kTestName)); } { const QNamePart kTestName[] = { "prefix", "this", "is", "a", "test" }; - NL_TEST_ASSERT(inSuite, - SerializedQNameIterator(BytesRange(kManyItems, kManyItems + sizeof(kManyItems)), kManyItems) != - FullQName(kTestName)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kManyItems) != FullQName(kTestName)); } } @@ -237,6 +241,33 @@ void CaseInsensitiveFullQNameCompare(nlTestSuite * inSuite, void * inContext) } } +void SerializedCompare(nlTestSuite * inSuite, void * inContext) +{ + static const uint8_t kThisIsATest1[] = "\04this\02is\01a\04test\00"; + static const uint8_t kThisIsATest2[] = "\04ThIs\02is\01A\04tESt\00"; + static const uint8_t kThisIsDifferent[] = "\04this\02is\09different\00"; + static const uint8_t kThisIs[] = "\04this\02is"; + + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIsATest1) == AsSerializedQName(kThisIsATest1)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIsATest2) == AsSerializedQName(kThisIsATest2)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIsATest1) == AsSerializedQName(kThisIsATest2)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIsATest1) != AsSerializedQName(kThisIsDifferent)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIsDifferent) != AsSerializedQName(kThisIsATest1)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIsDifferent) != AsSerializedQName(kThisIs)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIs) != AsSerializedQName(kThisIsDifferent)); + + // These items have back references and are "this.is.a.test" + static const uint8_t kPtrItems[] = "\03abc\02is\01a\04test\00\04this\xc0\04"; + SerializedQNameIterator thisIsATestPtr(BytesRange(kPtrItems, kPtrItems + sizeof(kPtrItems)), kPtrItems + 15); + + NL_TEST_ASSERT(inSuite, thisIsATestPtr == AsSerializedQName(kThisIsATest1)); + NL_TEST_ASSERT(inSuite, thisIsATestPtr == AsSerializedQName(kThisIsATest2)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIsATest1) == thisIsATestPtr); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIsATest2) == thisIsATestPtr); + NL_TEST_ASSERT(inSuite, thisIsATestPtr != AsSerializedQName(kThisIs)); + NL_TEST_ASSERT(inSuite, AsSerializedQName(kThisIs) != thisIsATestPtr); +} + } // namespace // clang-format off @@ -247,6 +278,7 @@ static const nlTest sTests[] = NL_TEST_DEF("Comparison", Comparison), NL_TEST_DEF("CaseInsensitiveSerializedCompare", CaseInsensitiveSerializedCompare), NL_TEST_DEF("CaseInsensitiveFullQNameCompare", CaseInsensitiveFullQNameCompare), + NL_TEST_DEF("SerializedCompare", SerializedCompare), NL_TEST_SENTINEL() }; diff --git a/src/lib/dnssd/minimal_mdns/core/tests/TestRecordWriter.cpp b/src/lib/dnssd/minimal_mdns/core/tests/TestRecordWriter.cpp new file mode 100644 index 00000000000000..6323417a1c7992 --- /dev/null +++ b/src/lib/dnssd/minimal_mdns/core/tests/TestRecordWriter.cpp @@ -0,0 +1,197 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +namespace { + +using namespace mdns::Minimal; +using namespace chip::Encoding::BigEndian; + +void BasicWriteTest(nlTestSuite * inSuite, void * inContext) +{ + const QNamePart kName1[] = { "some", "name" }; + const QNamePart kName2[] = { "abc", "xyz", "here" }; + + uint8_t dataBuffer[128]; + + BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + + writer.WriteQName(FullQName(kName1)); + writer.WriteQName(FullQName(kName2)); + + // clang-format off + const uint8_t expectedOutput[] = { + // + 4, 's', 'o', 'm', 'e', // QNAME part: some + 4, 'n', 'a', 'm', 'e', // QNAME part: name + 0, // QNAME ends + 3, 'a', 'b', 'c', // QNAME part: abc + 3, 'x', 'y', 'z', // QNAME part: xyz + 4, 'h', 'e', 'r', 'e', // QNAME part: here + 0, // QNAME ends + }; + // clang-format on + + NL_TEST_ASSERT(inSuite, output.Needed() == sizeof(expectedOutput)); + NL_TEST_ASSERT(inSuite, memcmp(dataBuffer, expectedOutput, sizeof(expectedOutput)) == 0); +} + +void SimpleDedup(nlTestSuite * inSuite, void * inContext) +{ + const QNamePart kName1[] = { "some", "name" }; + const QNamePart kName2[] = { "other", "name" }; + + uint8_t dataBuffer[128]; + + BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + + writer.WriteQName(FullQName(kName1)); + writer.WriteQName(FullQName(kName2)); + + // clang-format off + const uint8_t expectedOutput[] = { + // + 4, 's', 'o', 'm', 'e', // QNAME part: some + 4, 'n', 'a', 'm', 'e', // QNAME part: name + 0, // QNAME ends + 5, 'o', 't', 'h', 'e', 'r', // QNAME part: other + 0xC0, 5 // POINTER: "name" is at offset 5 + }; + // clang-format on + + NL_TEST_ASSERT(inSuite, output.Needed() == sizeof(expectedOutput)); + NL_TEST_ASSERT(inSuite, memcmp(dataBuffer, expectedOutput, sizeof(expectedOutput)) == 0); +} + +void ComplexDedup(nlTestSuite * inSuite, void * inContext) +{ + const QNamePart kName1[] = { "some", "name" }; + const QNamePart kName2[] = { "other", "name" }; + const QNamePart kName3[] = { "prefix", "of", "other", "name" }; + const QNamePart kName4[] = { "some", "name", "suffix" }; + const QNamePart kName5[] = { "more", "suffix" }; + + uint8_t dataBuffer[128]; + + BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + + writer.WriteQName(FullQName(kName1)); + writer.WriteQName(FullQName(kName2)); + writer.WriteQName(FullQName(kName3)); + writer.Writer().Put("xyz"); // inject something that is NOT a qname + writer.WriteQName(FullQName(kName4)); + writer.WriteQName(FullQName(kName5)); + + // clang-format off + const uint8_t expectedOutput[] = { + // + 4, 's', 'o', 'm', 'e', // QNAME part: some + 4, 'n', 'a', 'm', 'e', // QNAME part: name + 0, // QNAME ends + 5, 'o', 't', 'h', 'e', 'r', // QNAME part: other + 0xC0, 5, // POINTER: "name" is at offset 5 + 6, 'p', 'r', 'e', 'f', 'i', 'x', + 2, 'o', 'f', + 0xC0, 11, // POINTER: "other.name" is at offset 11 + 'x', 'y', 'z', + 4, 's', 'o', 'm', 'e', // QNAME part: some + 4, 'n', 'a', 'm', 'e', // QNAME part: name + 6, 's', 'u', 'f', 'f', 'i', 'x', // suffix which prevents reuse + 0, + 4, 'm', 'o', 'r', 'e', + 0xC0, 44 + }; + // clang-format on + + NL_TEST_ASSERT(inSuite, output.Needed() == sizeof(expectedOutput)); + NL_TEST_ASSERT(inSuite, memcmp(dataBuffer, expectedOutput, sizeof(expectedOutput)) == 0); +} + +void TonsOfReferences(nlTestSuite * inSuite, void * inContext) +{ + const QNamePart kName1[] = { "some", "name" }; + const QNamePart kName2[] = { "different", "name" }; + + uint8_t dataBuffer[512]; + + BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + + // First name is 11 bytes (2*4 bytes + null terminator) + // all other entires are 2 bytes (back - references) + // + // TOTAL: 211 bytes written + for (int i = 0; i < 101; i++) + { + writer.WriteQName(FullQName(kName1)); + } + + // Extra size: 10 for "different" and 2 for "name" link + // TOTAL: 211 + 12 = 223 + writer.WriteQName(FullQName(kName2)); + + // Another 200 bytes for references + // TOTAL: 423 + for (int i = 0; i < 100; i++) + { + writer.WriteQName(FullQName(kName2)); + } + + NL_TEST_ASSERT(inSuite, output.Fit()); + NL_TEST_ASSERT(inSuite, output.Needed() == 423); +} + +} // namespace + +// clang-format off +static const nlTest sTests[] = +{ + NL_TEST_DEF("BasicWriteTest", BasicWriteTest), + NL_TEST_DEF("SimpleDedup", SimpleDedup), + NL_TEST_DEF("ComplexDedup", ComplexDedup), + NL_TEST_DEF("TonsOfReferences", TonsOfReferences), + + NL_TEST_SENTINEL() +}; +// clang-format on + +int TestRecordWriter(void) +{ + // clang-format off + nlTestSuite theSuite = + { + "RecordWriter", + &sTests[0], + nullptr, + nullptr + }; + // clang-format on + + nlTestRunner(&theSuite, nullptr); + + return (nlTestRunnerStats(&theSuite)); +} + +CHIP_REGISTER_TEST_SUITE(TestRecordWriter) diff --git a/src/lib/dnssd/minimal_mdns/records/IP.cpp b/src/lib/dnssd/minimal_mdns/records/IP.cpp index 1312f185e4cd1b..43c996a501cfcf 100644 --- a/src/lib/dnssd/minimal_mdns/records/IP.cpp +++ b/src/lib/dnssd/minimal_mdns/records/IP.cpp @@ -20,16 +20,16 @@ namespace mdns { namespace Minimal { -bool IPResourceRecord::WriteData(chip::Encoding::BigEndian::BufferWriter & out) const +bool IPResourceRecord::WriteData(RecordWriter & out) const { // IP address is already stored in network byte order, hence raw bytes put if (mIPAddress.IsIPv6()) { - out.Put(mIPAddress.Addr, 16); + out.Put(BytesRange::BufferWithSize(mIPAddress.Addr, 16)); } else { - out.Put(mIPAddress.Addr + 3, 4); + out.Put(BytesRange::BufferWithSize(mIPAddress.Addr + 3, 4)); } return out.Fit(); diff --git a/src/lib/dnssd/minimal_mdns/records/IP.h b/src/lib/dnssd/minimal_mdns/records/IP.h index 36f4e02a8eb98c..fc5efb87a51069 100644 --- a/src/lib/dnssd/minimal_mdns/records/IP.h +++ b/src/lib/dnssd/minimal_mdns/records/IP.h @@ -32,7 +32,7 @@ class IPResourceRecord : public ResourceRecord {} protected: - bool WriteData(chip::Encoding::BigEndian::BufferWriter & out) const override; + bool WriteData(RecordWriter & out) const override; private: const chip::Inet::IPAddress mIPAddress; diff --git a/src/lib/dnssd/minimal_mdns/records/Ptr.h b/src/lib/dnssd/minimal_mdns/records/Ptr.h index 3186b0e00f31d5..a6622d946c94f7 100644 --- a/src/lib/dnssd/minimal_mdns/records/Ptr.h +++ b/src/lib/dnssd/minimal_mdns/records/Ptr.h @@ -30,11 +30,7 @@ class PtrResourceRecord : public ResourceRecord const FullQName & GetPtr() const { return mPtrName; } protected: - bool WriteData(chip::Encoding::BigEndian::BufferWriter & out) const override - { - mPtrName.Output(out); - return out.Fit(); - } + bool WriteData(RecordWriter & out) const override { return out.WriteQName(mPtrName).Fit(); } private: const FullQName mPtrName; diff --git a/src/lib/dnssd/minimal_mdns/records/ResourceRecord.cpp b/src/lib/dnssd/minimal_mdns/records/ResourceRecord.cpp index 1b6dac4897b83e..ce0b7e8e58e2e0 100644 --- a/src/lib/dnssd/minimal_mdns/records/ResourceRecord.cpp +++ b/src/lib/dnssd/minimal_mdns/records/ResourceRecord.cpp @@ -20,7 +20,7 @@ namespace mdns { namespace Minimal { -bool ResourceRecord::Append(HeaderRef & hdr, ResourceType asType, chip::Encoding::BigEndian::BufferWriter & out) const +bool ResourceRecord::Append(HeaderRef & hdr, ResourceType asType, RecordWriter & out) const { // order is important based on resource type. First come answers, then authorityAnswers // and then additional: @@ -33,22 +33,22 @@ bool ResourceRecord::Append(HeaderRef & hdr, ResourceType asType, chip::Encoding return false; } - mQName.Output(out); + out.WriteQName(mQName); - out // + out.Writer() // .Put16(static_cast(GetType())) // .Put16(static_cast(GetClass())) // .Put32(static_cast(GetTtl())) // ; - chip::Encoding::BigEndian::BufferWriter sizeOutput(out); // copy to re-output size - out.Put16(0); // dummy, will be replaced later + chip::Encoding::BigEndian::BufferWriter sizeOutput(out.Writer()); // copy to re-output size + out.Put16(0); // dummy, will be replaced later if (!WriteData(out)) { return false; } - sizeOutput.Put16(static_cast(out.Needed() - sizeOutput.Needed() - 2)); + sizeOutput.Put16(static_cast(out.Writer().Needed() - sizeOutput.Needed() - 2)); // This MUST be final and separated out: record count is only updated on success. if (out.Fit()) diff --git a/src/lib/dnssd/minimal_mdns/records/ResourceRecord.h b/src/lib/dnssd/minimal_mdns/records/ResourceRecord.h index 51aa2b127cf582..67bae7ece63327 100644 --- a/src/lib/dnssd/minimal_mdns/records/ResourceRecord.h +++ b/src/lib/dnssd/minimal_mdns/records/ResourceRecord.h @@ -21,6 +21,7 @@ #include #include +#include #include @@ -57,11 +58,11 @@ class ResourceRecord /// Append the given record to the underlying output. /// Updates header item count on success, does NOT update header on failure. - bool Append(HeaderRef & hdr, ResourceType asType, chip::Encoding::BigEndian::BufferWriter & out) const; + bool Append(HeaderRef & hdr, ResourceType asType, RecordWriter & out) const; protected: /// Output the data portion of the resource record. - virtual bool WriteData(chip::Encoding::BigEndian::BufferWriter & out) const = 0; + virtual bool WriteData(RecordWriter & out) const = 0; ResourceRecord(QType type, FullQName name) : mType(type), mQName(name) {} diff --git a/src/lib/dnssd/minimal_mdns/records/Srv.h b/src/lib/dnssd/minimal_mdns/records/Srv.h index ede53db30637fa..7587805313c141 100644 --- a/src/lib/dnssd/minimal_mdns/records/Srv.h +++ b/src/lib/dnssd/minimal_mdns/records/Srv.h @@ -40,14 +40,9 @@ class SrvResourceRecord : public ResourceRecord void SetWeight(uint16_t value) { mWeight = value; } protected: - bool WriteData(chip::Encoding::BigEndian::BufferWriter & out) const override + bool WriteData(RecordWriter & out) const override { - out.Put16(mPriority); - out.Put16(mWeight); - out.Put16(mPort); - mServerName.Output(out); - - return out.Fit(); + return out.Put16(mPriority).Put16(mWeight).Put16(mPort).WriteQName(mServerName).Fit(); } private: diff --git a/src/lib/dnssd/minimal_mdns/records/Txt.h b/src/lib/dnssd/minimal_mdns/records/Txt.h index 54dcb806fa3f15..881e703f7cc164 100644 --- a/src/lib/dnssd/minimal_mdns/records/Txt.h +++ b/src/lib/dnssd/minimal_mdns/records/Txt.h @@ -52,18 +52,17 @@ class TxtResourceRecord : public ResourceRecord const char * const * GetEntries() const { return mEntries; } protected: - bool WriteData(chip::Encoding::BigEndian::BufferWriter & out) const override + bool WriteData(RecordWriter & out) const override { for (size_t i = 0; i < mEntryCount; i++) { size_t len = strlen(mEntries[i]); - if (len > kMaxQNamePartLength) + if (len > kMaxTxtRecordLength) { return false; } - out.Put8(static_cast(len)); - out.Put(mEntries[i]); + out.Put8(static_cast(len)).PutString(mEntries[i]); } return out.Fit(); } @@ -71,6 +70,8 @@ class TxtResourceRecord : public ResourceRecord private: const char * const * mEntries; const size_t mEntryCount; + + static constexpr size_t kMaxTxtRecordLength = 63; }; } // namespace Minimal diff --git a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecord.cpp b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecord.cpp index 522c2ac3c7145a..4269d73854ea4c 100644 --- a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecord.cpp +++ b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecord.cpp @@ -35,11 +35,7 @@ class FakeResourceRecord : public ResourceRecord FakeResourceRecord(const char * data) : ResourceRecord(QType::ANY, kNames), mData(data) {} protected: - bool WriteData(BufferWriter & out) const override - { - out.Put(mData); - return out.Fit(); - } + bool WriteData(RecordWriter & out) const override { return out.PutString(mData).Fit(); } private: const char * mData; @@ -54,6 +50,8 @@ void CanWriteSimpleRecord(nlTestSuite * inSuite, void * inContext) header.Clear(); BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + FakeResourceRecord record("somedata"); record.SetTtl(0x11223344); @@ -70,7 +68,7 @@ void CanWriteSimpleRecord(nlTestSuite * inSuite, void * inContext) 's', 'o', 'm', 'e', 'd', 'a', 't', 'a', }; - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, output)); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); @@ -87,6 +85,8 @@ void CanWriteMultipleRecords(nlTestSuite * inSuite, void * inContext) header.Clear(); BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + FakeResourceRecord record1("somedata"); FakeResourceRecord record2("moredata"); FakeResourceRecord record3("xyz"); @@ -104,17 +104,13 @@ void CanWriteMultipleRecords(nlTestSuite * inSuite, void * inContext) 0x11, 0x22, 0x33, 0x44, // TTL 0, 8, // data size 's', 'o', 'm', 'e', 'd', 'a', 't', 'a', // - 3, 'f', 'o', 'o', // QNAME part: foo - 3, 'b', 'a', 'r', // QNAME part: bar - 0, // QNAME ends + 0xC0, 0x00, // PTR: foo.bar 0, 255, // QType ANY (totally fake) 0, 1, // QClass IN 0, 0, 0, 0, // TTL 0, 8, // data size 'm', 'o', 'r', 'e', 'd', 'a', 't', 'a', // - 3, 'f', 'o', 'o', // QNAME part: foo - 3, 'b', 'a', 'r', // QNAME part: bar - 0, // QNAME ends + 0xC0, 0x00, // PTR: foo.bar 0, 255, // QType ANY (totally fake) 0, 1, // QClass IN 0, 0, 0, 0xFF, // TTL @@ -122,17 +118,17 @@ void CanWriteMultipleRecords(nlTestSuite * inSuite, void * inContext) 'x', 'y', 'z', }; - NL_TEST_ASSERT(inSuite, record1.Append(header, ResourceType::kAnswer, output)); + NL_TEST_ASSERT(inSuite, record1.Append(header, ResourceType::kAnswer, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); - NL_TEST_ASSERT(inSuite, record2.Append(header, ResourceType::kAuthority, output)); + NL_TEST_ASSERT(inSuite, record2.Append(header, ResourceType::kAuthority, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); - NL_TEST_ASSERT(inSuite, record3.Append(header, ResourceType::kAdditional, output)); + NL_TEST_ASSERT(inSuite, record3.Append(header, ResourceType::kAdditional, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 1); @@ -149,16 +145,18 @@ void RecordOrderIsEnforced(nlTestSuite * inSuite, void * inContext) HeaderRef header(headerBuffer); BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + FakeResourceRecord record("somedata"); header.Clear(); header.SetAuthorityCount(1); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, output) == false); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, writer) == false); header.Clear(); header.SetAdditionalCount(1); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, output) == false); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAuthority, output) == false); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, writer) == false); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAuthority, writer) == false); } void ErrorsOutOnSmallBuffers(nlTestSuite * inSuite, void * inContext) @@ -191,8 +189,9 @@ void ErrorsOutOnSmallBuffers(nlTestSuite * inSuite, void * inContext) { memset(dataBuffer, 0, sizeof(dataBuffer)); BufferWriter output(dataBuffer, i); + RecordWriter writer(&output); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, output) == false); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, writer) == false); // header untouched NL_TEST_ASSERT(inSuite, memcmp(headerBuffer, clearHeader, HeaderRef::kSizeBytes) == 0); @@ -200,8 +199,9 @@ void ErrorsOutOnSmallBuffers(nlTestSuite * inSuite, void * inContext) memset(dataBuffer, 0, sizeof(dataBuffer)); BufferWriter output(dataBuffer, sizeof(expectedOutput)); + RecordWriter writer(&output); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, output)); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, writer)); NL_TEST_ASSERT(inSuite, output.Needed() == sizeof(expectedOutput)); NL_TEST_ASSERT(inSuite, memcmp(dataBuffer, expectedOutput, sizeof(expectedOutput)) == 0); NL_TEST_ASSERT(inSuite, memcmp(headerBuffer, clearHeader, HeaderRef::kSizeBytes) != 0); @@ -221,7 +221,9 @@ void RecordCount(nlTestSuite * inSuite, void * inContext) for (int i = 0; i < kAppendCount; i++) { BufferWriter output(dataBuffer, sizeof(dataBuffer)); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, output)); + RecordWriter writer(&output); + + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == i + 1); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); @@ -230,7 +232,9 @@ void RecordCount(nlTestSuite * inSuite, void * inContext) for (int i = 0; i < kAppendCount; i++) { BufferWriter output(dataBuffer, sizeof(dataBuffer)); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAuthority, output)); + RecordWriter writer(&output); + + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAuthority, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == kAppendCount); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == i + 1); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); @@ -239,7 +243,9 @@ void RecordCount(nlTestSuite * inSuite, void * inContext) for (int i = 0; i < kAppendCount; i++) { BufferWriter output(dataBuffer, sizeof(dataBuffer)); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAdditional, output)); + RecordWriter writer(&output); + + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAdditional, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == kAppendCount); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == kAppendCount); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == i + 1); diff --git a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordIP.cpp b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordIP.cpp index c1262adda27daa..ce66500b47e742 100644 --- a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordIP.cpp +++ b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordIP.cpp @@ -43,6 +43,8 @@ void WriteIPv4(nlTestSuite * inSuite, void * inContext) { BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + IPResourceRecord ipResourceRecord(kNames, ipAddress); ipResourceRecord.SetTtl(123); @@ -61,7 +63,7 @@ void WriteIPv4(nlTestSuite * inSuite, void * inContext) 10, 20, 30, 40 // IP Address }; - NL_TEST_ASSERT(inSuite, ipResourceRecord.Append(header, ResourceType::kAnswer, output)); + NL_TEST_ASSERT(inSuite, ipResourceRecord.Append(header, ResourceType::kAnswer, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); @@ -71,6 +73,7 @@ void WriteIPv4(nlTestSuite * inSuite, void * inContext) { BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); IPResourceRecord ipResourceRecord(kNames, ipAddress); @@ -90,7 +93,7 @@ void WriteIPv4(nlTestSuite * inSuite, void * inContext) 10, 20, 30, 40 // IP Address }; - NL_TEST_ASSERT(inSuite, ipResourceRecord.Append(header, ResourceType::kAuthority, output)); + NL_TEST_ASSERT(inSuite, ipResourceRecord.Append(header, ResourceType::kAuthority, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); @@ -100,6 +103,7 @@ void WriteIPv4(nlTestSuite * inSuite, void * inContext) { BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); IPResourceRecord ipResourceRecord(kNames, ipAddress); @@ -119,7 +123,7 @@ void WriteIPv4(nlTestSuite * inSuite, void * inContext) 10, 20, 30, 40 // IP Address }; - NL_TEST_ASSERT(inSuite, ipResourceRecord.Append(header, ResourceType::kAdditional, output)); + NL_TEST_ASSERT(inSuite, ipResourceRecord.Append(header, ResourceType::kAdditional, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 1); @@ -141,6 +145,7 @@ void WriteIPv6(nlTestSuite * inSuite, void * inContext) HeaderRef header(headerBuffer); BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); IPResourceRecord ipResourceRecord(kNames, ipAddress); ipResourceRecord.SetTtl(0x12345678); @@ -162,7 +167,7 @@ void WriteIPv6(nlTestSuite * inSuite, void * inContext) 0xfe, 0x19, 0x35, 0x9b }; - NL_TEST_ASSERT(inSuite, ipResourceRecord.Append(header, ResourceType::kAnswer, output)); + NL_TEST_ASSERT(inSuite, ipResourceRecord.Append(header, ResourceType::kAnswer, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); diff --git a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordPtr.cpp b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordPtr.cpp index aa388cd2cbfa2f..02f9d651752e0c 100644 --- a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordPtr.cpp +++ b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordPtr.cpp @@ -37,6 +37,8 @@ void TestPtrResourceRecord(nlTestSuite * inSuite, void * inContext) HeaderRef header(headerBuffer); BigEndian::BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); + PtrResourceRecord record(kName, kPtr); record.SetTtl(123); @@ -57,7 +59,7 @@ void TestPtrResourceRecord(nlTestSuite * inSuite, void * inContext) 0 // QNAME ends }; - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, output)); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAnswer, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 1); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 0); diff --git a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordSrv.cpp b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordSrv.cpp index d294741c3d342a..11fc163566074c 100644 --- a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordSrv.cpp +++ b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordSrv.cpp @@ -38,13 +38,14 @@ void TestSrv(nlTestSuite * inSuite, void * inContext) HeaderRef header(headerBuffer); BigEndian::BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); SrvResourceRecord record(kName, kServerName, kPort); record.SetTtl(128); header.Clear(); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAdditional, output)); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAdditional, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 1); diff --git a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordTxt.cpp b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordTxt.cpp index c7537a863d3a1d..4ed533aaacdec7 100644 --- a/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordTxt.cpp +++ b/src/lib/dnssd/minimal_mdns/records/tests/TestResourceRecordTxt.cpp @@ -37,6 +37,7 @@ void TestTxt(nlTestSuite * inSuite, void * inContext) HeaderRef header(headerBuffer); BigEndian::BufferWriter output(dataBuffer, sizeof(dataBuffer)); + RecordWriter writer(&output); TxtResourceRecord record(kName, kData); record.SetTtl(128); @@ -44,7 +45,7 @@ void TestTxt(nlTestSuite * inSuite, void * inContext) header.Clear(); - NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAdditional, output)); + NL_TEST_ASSERT(inSuite, record.Append(header, ResourceType::kAdditional, writer)); NL_TEST_ASSERT(inSuite, header.GetAnswerCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAuthorityCount() == 0); NL_TEST_ASSERT(inSuite, header.GetAdditionalCount() == 1); diff --git a/src/lib/dnssd/minimal_mdns/responders/tests/TestPtrResponder.cpp b/src/lib/dnssd/minimal_mdns/responders/tests/TestPtrResponder.cpp index f0cdd7fd79d37a..d9bf42a7383b8c 100644 --- a/src/lib/dnssd/minimal_mdns/responders/tests/TestPtrResponder.cpp +++ b/src/lib/dnssd/minimal_mdns/responders/tests/TestPtrResponder.cpp @@ -53,11 +53,12 @@ class PtrResponseAccumulator : public ResponderDelegate uint8_t buffer[128]; BigEndian::BufferWriter out(buffer, sizeof(buffer)); + RecordWriter writer(&out); HeaderRef hdr(headerBuffer); hdr.Clear(); - NL_TEST_ASSERT(mSuite, record.Append(hdr, ResourceType::kAnswer, out)); + NL_TEST_ASSERT(mSuite, record.Append(hdr, ResourceType::kAnswer, writer)); ResourceData data; SerializedQNameIterator target; diff --git a/src/lib/dnssd/minimal_mdns/tests/CheckOnlyServer.h b/src/lib/dnssd/minimal_mdns/tests/CheckOnlyServer.h index 250f304807bf80..6af8a16227fa4f 100644 --- a/src/lib/dnssd/minimal_mdns/tests/CheckOnlyServer.h +++ b/src/lib/dnssd/minimal_mdns/tests/CheckOnlyServer.h @@ -108,11 +108,11 @@ class CheckOnlyServer : private chip::PoolImplStart(), data->Start() + data->TotalLength()), this); + mPacketData = BytesRange(data->Start(), data->Start() + data->TotalLength()); + ParsePacket(mPacketData, this); if (mHeaderFound) { TestGotAllExpectedPackets(); @@ -317,6 +318,7 @@ class CheckOnlyServer : private chip::PoolImpl #include +#include #include #include #include @@ -46,6 +47,7 @@ struct CommonTestElements uint8_t * requestNameStart = requestStorage + ConstHeaderRef::kSizeBytes; Encoding::BigEndian::BufferWriter requestBufferWriter = Encoding::BigEndian::BufferWriter(requestNameStart, sizeof(requestStorage) - HeaderRef::kSizeBytes); + RecordWriter recordWriter; uint8_t dnsSdServiceStorage[64]; uint8_t serviceNameStorage[64]; @@ -71,6 +73,7 @@ struct CommonTestElements Inet::IPPacketInfo packetInfo; CommonTestElements(nlTestSuite * inSuite, const char * tag) : + recordWriter(&requestBufferWriter), dnsSd(FlatAllocatedQName::Build(dnsSdServiceStorage, "_services", "_dns-sd", "_udp", "local")), service(FlatAllocatedQName::Build(serviceNameStorage, tag, "service")), instance(FlatAllocatedQName::Build(instanceNameStorage, tag, "instance")), @@ -90,7 +93,7 @@ void SrvAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) common.queryResponder.AddResponder(&common.srvResponder); // Build a query for our srv record - common.instance.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.instance); QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); @@ -110,7 +113,7 @@ void SrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) common.queryResponder.AddResponder(&common.txtResponder); // Build a query for the instance name - common.instance.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.instance); QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); @@ -133,7 +136,7 @@ void PtrSrvTxtAnyResponseToServiceName(nlTestSuite * inSuite, void * inContext) common.queryResponder.AddResponder(&common.txtResponder); // Build a query for the service name - common.service.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.service); QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); @@ -158,7 +161,7 @@ void PtrSrvTxtAnyResponseToInstance(nlTestSuite * inSuite, void * inContext) common.queryResponder.AddResponder(&common.txtResponder); // Build a query for the instance name - common.instance.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.instance); QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); @@ -182,7 +185,7 @@ void PtrSrvTxtSrvResponseToInstance(nlTestSuite * inSuite, void * inContext) common.queryResponder.AddResponder(&common.txtResponder); // Build a query for the instance - common.instance.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.instance); QueryData queryData = QueryData(QType::SRV, QClass::IN, false, common.requestNameStart, common.requestBytesRange); @@ -205,7 +208,7 @@ void PtrSrvTxtAnyResponseToServiceListing(nlTestSuite * inSuite, void * inContex common.queryResponder.AddResponder(&common.txtResponder); // Build a query for the dns-sd services listing. - common.dnsSd.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.dnsSd); QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); @@ -226,15 +229,15 @@ void NoQueryResponder(nlTestSuite * inSuite, void * inContext) QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common.requestNameStart, common.requestBytesRange); - common.dnsSd.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.dnsSd); responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); - common.service.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.service); responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); - common.instance.Output(common.requestBufferWriter); + common.recordWriter.WriteQName(common.instance); responseSender.Respond(1, queryData, &common.packetInfo); NL_TEST_ASSERT(inSuite, !common.server.GetSendCalled()); } @@ -289,7 +292,7 @@ void PtrSrvTxtMultipleRespondersToInstance(nlTestSuite * inSuite, void * inConte common2.queryResponder.AddResponder(&common2.txtResponder); // Build a query for the second instance. - common2.instance.Output(common2.requestBufferWriter); + common2.recordWriter.WriteQName(common2.instance); QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common2.requestNameStart, common2.requestBytesRange); // Should get back answers from second instance only. @@ -321,7 +324,7 @@ void PtrSrvTxtMultipleRespondersToServiceListing(nlTestSuite * inSuite, void * i common2.queryResponder.AddResponder(&common2.txtResponder); // Build a query for the instance - common1.dnsSd.Output(common1.requestBufferWriter); + common1.recordWriter.WriteQName(common1.dnsSd); QueryData queryData = QueryData(QType::ANY, QClass::IN, false, common1.requestNameStart, common1.requestBytesRange); // Should get service listing from both. diff --git a/src/lib/support/BufferWriter.h b/src/lib/support/BufferWriter.h index b608d54ceda69c..a8afee2a7305b9 100644 --- a/src/lib/support/BufferWriter.h +++ b/src/lib/support/BufferWriter.h @@ -59,7 +59,10 @@ class BufferWriter } /// Number of bytes required to satisfy all calls to Put() so far - size_t Needed() const { return mNeeded; } + inline size_t Needed() const { return mNeeded; } + + /// Alias to Needed() for code clarity: current writing position for the buffer. + inline size_t WritePos() const { return Needed(); } /// Number of bytes still available for writing size_t Available() const { return mSize < mNeeded ? 0 : mSize - mNeeded; } diff --git a/src/platform/Linux/PlatformManagerImpl.cpp b/src/platform/Linux/PlatformManagerImpl.cpp index e1623903140b9f..76ac2fa9679f55 100644 --- a/src/platform/Linux/PlatformManagerImpl.cpp +++ b/src/platform/Linux/PlatformManagerImpl.cpp @@ -275,22 +275,64 @@ void PlatformManagerImpl::HandleGeneralFault(uint32_t EventId) { GeneralDiagnosticsDelegate * delegate = GetDiagnosticDataProvider().GetGeneralDiagnosticsDelegate(); - if (delegate != nullptr) + if (delegate == nullptr) + { + ChipLogError(DeviceLayer, "No delegate registered to handle General Diagnostics event"); + return; + } + + if (EventId == GeneralDiagnostics::Events::HardwareFaultChange::kEventId) + { + GeneralFaults previous; + GeneralFaults current; + +#if CHIP_CONFIG_TEST + // On Linux Simulation, set following hardware faults statically. + ReturnOnFailure(previous.add(EMBER_ZCL_HARDWARE_FAULT_TYPE_RADIO)); + ReturnOnFailure(previous.add(EMBER_ZCL_HARDWARE_FAULT_TYPE_POWER_SOURCE)); + + ReturnOnFailure(current.add(EMBER_ZCL_HARDWARE_FAULT_TYPE_RADIO)); + ReturnOnFailure(current.add(EMBER_ZCL_HARDWARE_FAULT_TYPE_SENSOR)); + ReturnOnFailure(current.add(EMBER_ZCL_HARDWARE_FAULT_TYPE_POWER_SOURCE)); + ReturnOnFailure(current.add(EMBER_ZCL_HARDWARE_FAULT_TYPE_USER_INTERFACE_FAULT)); +#endif + delegate->OnHardwareFaultsDetected(previous, current); + } + else if (EventId == GeneralDiagnostics::Events::RadioFaultChange::kEventId) + { + GeneralFaults previous; + GeneralFaults current; + +#if CHIP_CONFIG_TEST + // On Linux Simulation, set following radio faults statically. + ReturnOnFailure(previous.add(EMBER_ZCL_RADIO_FAULT_TYPE_WI_FI_FAULT)); + ReturnOnFailure(previous.add(EMBER_ZCL_RADIO_FAULT_TYPE_THREAD_FAULT)); + + ReturnOnFailure(current.add(EMBER_ZCL_RADIO_FAULT_TYPE_WI_FI_FAULT)); + ReturnOnFailure(current.add(EMBER_ZCL_RADIO_FAULT_TYPE_CELLULAR_FAULT)); + ReturnOnFailure(current.add(EMBER_ZCL_RADIO_FAULT_TYPE_THREAD_FAULT)); + ReturnOnFailure(current.add(EMBER_ZCL_RADIO_FAULT_TYPE_NFC_FAULT)); +#endif + delegate->OnRadioFaultsDetected(previous, current); + } + else if (EventId == GeneralDiagnostics::Events::NetworkFaultChange::kEventId) + { + GeneralFaults previous; + GeneralFaults current; + +#if CHIP_CONFIG_TEST + // On Linux Simulation, set following radio faults statically. + ReturnOnFailure(previous.add(EMBER_ZCL_NETWORK_FAULT_TYPE_HARDWARE_FAILURE)); + ReturnOnFailure(previous.add(EMBER_ZCL_NETWORK_FAULT_TYPE_NETWORK_JAMMED)); + + ReturnOnFailure(current.add(EMBER_ZCL_NETWORK_FAULT_TYPE_HARDWARE_FAILURE)); + ReturnOnFailure(current.add(EMBER_ZCL_NETWORK_FAULT_TYPE_NETWORK_JAMMED)); + ReturnOnFailure(current.add(EMBER_ZCL_NETWORK_FAULT_TYPE_CONNECTION_FAILED)); +#endif + delegate->OnNetworkFaultsDetected(previous, current); + } + else { - switch (EventId) - { - case GeneralDiagnostics::Events::HardwareFaultChange::kEventId: - delegate->OnHardwareFaultsDetected(); - break; - case GeneralDiagnostics::Events::RadioFaultChange::kEventId: - delegate->OnRadioFaultsDetected(); - break; - case GeneralDiagnostics::Events::NetworkFaultChange::kEventId: - delegate->OnNetworkFaultsDetected(); - break; - default: - break; - } } } diff --git a/src/protocols/secure_channel/RendezvousParameters.h b/src/protocols/secure_channel/RendezvousParameters.h index 295daff2644ff5..043ad223626059 100644 --- a/src/protocols/secure_channel/RendezvousParameters.h +++ b/src/protocols/secure_channel/RendezvousParameters.h @@ -48,28 +48,12 @@ class RendezvousParameters bool HasPeerAddress() const { return mPeerAddress.IsInitialized(); } Transport::PeerAddress GetPeerAddress() const { return mPeerAddress; } - const Optional GetCSRNonce() const { return mCSRNonce; } - const Optional GetAttestationNonce() const { return mAttestationNonce; } RendezvousParameters & SetPeerAddress(const Transport::PeerAddress & peerAddress) { mPeerAddress = peerAddress; return *this; } - // The lifetime of the buffer csrNonce is pointing to, should exceed the lifetime of RendezvousParameter object. - RendezvousParameters & SetCSRNonce(ByteSpan csrNonce) - { - mCSRNonce.SetValue(csrNonce); - return *this; - } - - // The lifetime of the buffer attestationNonce is pointing to, should exceed the lifetime of RendezvousParameter object. - RendezvousParameters & SetAttestationNonce(ByteSpan attestationNonce) - { - mAttestationNonce.SetValue(attestationNonce); - return *this; - } - bool HasDiscriminator() const { return mDiscriminator <= kMaxRendezvousDiscriminatorValue; } uint16_t GetDiscriminator() const { return mDiscriminator; } RendezvousParameters & SetDiscriminator(uint16_t discriminator) @@ -79,8 +63,6 @@ class RendezvousParameters } bool HasPASEVerifier() const { return mHasPASEVerifier; } - bool HasCSRNonce() const { return mCSRNonce.HasValue(); } - bool HasAttestationNonce() const { return mAttestationNonce.HasValue(); } const PASEVerifier & GetPASEVerifier() const { return mPASEVerifier; } RendezvousParameters & SetPASEVerifier(PASEVerifier & verifier) { @@ -113,8 +95,6 @@ class RendezvousParameters Transport::PeerAddress mPeerAddress; ///< the peer node address uint32_t mSetupPINCode = 0; ///< the target peripheral setup PIN Code uint16_t mDiscriminator = UINT16_MAX; ///< the target peripheral discriminator - Optional mCSRNonce; ///< CSR Nonce passed by the commissioner - Optional mAttestationNonce; ///< Attestation Nonce passed by the commissioner PASEVerifier mPASEVerifier; bool mHasPASEVerifier = false; @@ -125,4 +105,31 @@ class RendezvousParameters #endif // CONFIG_NETWORK_LAYER_BLE }; +class CommissioningParameters +{ +public: + bool HasCSRNonce() const { return mCSRNonce.HasValue(); } + bool HasAttestationNonce() const { return mAttestationNonce.HasValue(); } + const Optional GetCSRNonce() const { return mCSRNonce; } + const Optional GetAttestationNonce() const { return mAttestationNonce; } + + // The lifetime of the buffer csrNonce is pointing to, should exceed the lifetime of CommissioningParameters object. + CommissioningParameters & SetCSRNonce(ByteSpan csrNonce) + { + mCSRNonce.SetValue(csrNonce); + return *this; + } + + // The lifetime of the buffer attestationNonce is pointing to, should exceed the lifetime of CommissioningParameters object. + CommissioningParameters & SetAttestationNonce(ByteSpan attestationNonce) + { + mAttestationNonce.SetValue(attestationNonce); + return *this; + } + +private: + Optional mCSRNonce; ///< CSR Nonce passed by the commissioner + Optional mAttestationNonce; ///< Attestation Nonce passed by the commissioner +}; + } // namespace chip