From 3689594b3ed3f8afb7ec2af0e3fce3696858289b Mon Sep 17 00:00:00 2001 From: Boris Zbarsky Date: Wed, 9 Jun 2021 15:54:26 -0400 Subject: [PATCH] Fix data race on mTestIndex. (#7494) We could end up sending a message and getting a response to it before we ever incremented mTestIndex (if our call into NextTest() was on a thread other than the message thread). If that happened, we would end up running some subtest twice, and then later whenever we incrememented mTestIndex would end up skipping some subtest. Fixes https://github.com/project-chip/connectedhomeip/issues/7493 --- examples/chip-tool/commands/tests/Commands.h | 26 ++++++++++++------- .../templates/partials/test_cluster.zapt | 13 ++++++---- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/examples/chip-tool/commands/tests/Commands.h b/examples/chip-tool/commands/tests/Commands.h index 807a9bbe4e4f37..c8e5b1cd4aef23 100644 --- a/examples/chip-tool/commands/tests/Commands.h +++ b/examples/chip-tool/commands/tests/Commands.h @@ -24,7 +24,7 @@ class TestCluster : public TestCommand { public: - TestCluster() : TestCommand("TestCluster") {} + TestCluster() : TestCommand("TestCluster"), mTestIndex(0) {} /////////// TestCommand Interface ///////// CHIP_ERROR NextTest() override @@ -37,7 +37,11 @@ class TestCluster : public TestCommand SetCommandExitStatus(true); } - switch (mTestIndex) + // Ensure we increment mTestIndex before we start running the relevant + // command. That way if we lose the timeslice after we send the message + // but before our function call returns, we won't end up with an + // incorrect mTestIndex value observed when we get the response. + switch (mTestIndex++) { case 0: err = TestSendClusterTestClusterCommandTest_0(); @@ -55,7 +59,6 @@ class TestCluster : public TestCommand err = TestSendClusterTestClusterCommandReadAttribute_4(); break; } - mTestIndex++; if (CHIP_NO_ERROR != err) { @@ -67,8 +70,8 @@ class TestCluster : public TestCommand } private: - uint16_t mTestIndex = 0; - uint16_t mTestCount = 5; + std::atomic_uint16_t mTestIndex; + const uint16_t mTestCount = 5; // // Tests methods @@ -437,7 +440,7 @@ class TestCluster : public TestCommand class OnOffCluster : public TestCommand { public: - OnOffCluster() : TestCommand("OnOffCluster") {} + OnOffCluster() : TestCommand("OnOffCluster"), mTestIndex(0) {} /////////// TestCommand Interface ///////// CHIP_ERROR NextTest() override @@ -450,7 +453,11 @@ class OnOffCluster : public TestCommand SetCommandExitStatus(true); } - switch (mTestIndex) + // Ensure we increment mTestIndex before we start running the relevant + // command. That way if we lose the timeslice after we send the message + // but before our function call returns, we won't end up with an + // incorrect mTestIndex value observed when we get the response. + switch (mTestIndex++) { case 0: err = TestSendClusterOnOffCommandReadAttribute_0(); @@ -468,7 +475,6 @@ class OnOffCluster : public TestCommand err = TestSendClusterOnOffCommandReadAttribute_4(); break; } - mTestIndex++; if (CHIP_NO_ERROR != err) { @@ -480,8 +486,8 @@ class OnOffCluster : public TestCommand } private: - uint16_t mTestIndex = 0; - uint16_t mTestCount = 5; + std::atomic_uint16_t mTestIndex; + const uint16_t mTestCount = 5; // // Tests methods diff --git a/examples/chip-tool/templates/partials/test_cluster.zapt b/examples/chip-tool/templates/partials/test_cluster.zapt index 2a201787bb4525..5d461c438c043a 100644 --- a/examples/chip-tool/templates/partials/test_cluster.zapt +++ b/examples/chip-tool/templates/partials/test_cluster.zapt @@ -2,7 +2,7 @@ class {{asCamelCased filename false}}: public TestCommand { public: - {{asCamelCased filename false}}(): TestCommand("{{filename}}") {} + {{asCamelCased filename false}}(): TestCommand("{{filename}}"), mTestIndex(0) {} /////////// TestCommand Interface ///////// CHIP_ERROR NextTest() override @@ -15,7 +15,11 @@ class {{asCamelCased filename false}}: public TestCommand SetCommandExitStatus(true); } - switch (mTestIndex) + // Ensure we increment mTestIndex before we start running the relevant + // command. That way if we lose the timeslice after we send the message + // but before our function call returns, we won't end up with an + // incorrect mTestIndex value observed when we get the response. + switch (mTestIndex++) { {{#chip_tests_items}} case {{index}}: @@ -23,7 +27,6 @@ class {{asCamelCased filename false}}: public TestCommand break; {{/chip_tests_items}} } - mTestIndex++; if (CHIP_NO_ERROR != err) { @@ -36,8 +39,8 @@ class {{asCamelCased filename false}}: public TestCommand private: - uint16_t mTestIndex = 0; - uint16_t mTestCount = {{totalTests}}; + std::atomic_uint16_t mTestIndex; + const uint16_t mTestCount = {{totalTests}}; // // Tests methods