diff --git a/src/system/SystemLayerImplSelect.cpp b/src/system/SystemLayerImplSelect.cpp index 98f4fe71547826..9f967a516ecd22 100644 --- a/src/system/SystemLayerImplSelect.cpp +++ b/src/system/SystemLayerImplSelect.cpp @@ -170,6 +170,12 @@ void LayerImplSelect::CancelTimer(TimerCompleteCallback onComplete, void * appSt VerifyOrReturn(mLayerState.IsInitialized()); TimerList::Node * timer = mTimerList.Remove(onComplete, appState); + + if (timer == nullptr) + { + // Check if the timer is maybe currently being processed + timer = mExpiredTimersBeingProcessed.Remove(onComplete, appState); + } VerifyOrReturn(timer != nullptr); #if CHIP_SYSTEM_CONFIG_USE_DISPATCH @@ -469,9 +475,9 @@ void LayerImplSelect::HandleEvents() // Obtain the list of currently expired timers. Any new timers added by timer callback are NOT handled on this pass, // since that could result in infinite handling of new timers blocking any other progress. - TimerList expiredTimers = mTimerList.ExtractEarlier(Clock::Timeout(1) + SystemClock().GetMonotonicTimestamp()); - TimerList::Node * timer = nullptr; - while ((timer = expiredTimers.PopEarliest()) != nullptr) + mExpiredTimersBeingProcessed = mTimerList.ExtractEarlier(Clock::Timeout(1) + SystemClock().GetMonotonicTimestamp()); + TimerList::Node * timer = nullptr; + while ((timer = mExpiredTimersBeingProcessed.PopEarliest()) != nullptr) { mTimerPool.Invoke(timer); } diff --git a/src/system/SystemLayerImplSelect.h b/src/system/SystemLayerImplSelect.h index f193a8860d961c..9e15a2855f123e 100644 --- a/src/system/SystemLayerImplSelect.h +++ b/src/system/SystemLayerImplSelect.h @@ -101,6 +101,7 @@ class LayerImplSelect : public LayerSocketsLoop TimerPool mTimerPool; TimerList mTimerList; + TimerList mExpiredTimersBeingProcessed; // timers handled by HandleEvents timeval mNextTimeout; // Members for select loop diff --git a/src/system/tests/TestSystemTimer.cpp b/src/system/tests/TestSystemTimer.cpp index 58ff87e85cd8dd..4052b767a7690c 100644 --- a/src/system/tests/TestSystemTimer.cpp +++ b/src/system/tests/TestSystemTimer.cpp @@ -126,6 +126,15 @@ class TestContext TestContext() : mGreedyTimer(GreedyTimer, this), mNumTimersHandled(0) {} }; +static TestContext * gCurrentTestContext = nullptr; + +class ScopedGlobalTestContext +{ +public: + ScopedGlobalTestContext(TestContext * ctx) { gCurrentTestContext = ctx; } + ~ScopedGlobalTestContext() { gCurrentTestContext = nullptr; } +}; + // Test input data. static volatile bool sOverflowTestDone; @@ -257,6 +266,109 @@ void CheckOrder(nlTestSuite * inSuite, void * aContext) Clock::Internal::SetSystemClockForTesting(savedClock); } +namespace { + +namespace CancelTimerTest { + +// A bit lower than maximum system timers just in case, for systems that +// have some form of limit +constexpr unsigned kCancelTimerCount = CHIP_SYSTEM_CONFIG_NUM_TIMERS - 4; +int gCallbackProcessed[kCancelTimerCount]; + +/// Validates that gCallbackProcessed has valid values (0 or 1) +void ValidateExecutedTimerCounts(nlTestSuite * suite) +{ + for (unsigned i = 0; i < kCancelTimerCount; i++) + { + NL_TEST_ASSERT(suite, (gCallbackProcessed[i] == 0) || (gCallbackProcessed[i] == 1)); + } +} + +unsigned ExecutedTimerCount() +{ + unsigned count = 0; + for (unsigned i = 0; i < kCancelTimerCount; i++) + { + if (gCallbackProcessed[i] != 0) + { + count++; + } + } + return count; +} + +void Callback(Layer * layer, void * state) +{ + unsigned idx = static_cast(reinterpret_cast(state)); + if (gCallbackProcessed[idx] != 0) + { + ChipLogError(Test, "UNEXPECTED EXECUTION at index %u", idx); + } + + gCallbackProcessed[idx]++; + + if (ExecutedTimerCount() == kCancelTimerCount / 2) + { + ChipLogProgress(Test, "Cancelling timers"); + for (unsigned i = 0; i < kCancelTimerCount; i++) + { + if (gCallbackProcessed[i] != 0) + { + continue; + } + ChipLogProgress(Test, "Timer %u is being cancelled", i); + gCurrentTestContext->mLayer->CancelTimer(Callback, reinterpret_cast(static_cast(i))); + gCallbackProcessed[i]++; // pretend executed. + } + } +} + +void Test(nlTestSuite * inSuite, void * aContext) +{ + // Validates that timers can cancel other timers. Generally the test will + // do the following: + // - schedule several timers to start at the same time + // - within each timers, after half of them have run, make one timer + // cancel all the other ones + // - assert that: + // - timers will run if scheduled + // - once cancelled, timers will NOT run (i.e. a timer can cancel + // other timers, even if they are expiring at the same time) + memset(gCallbackProcessed, 0, sizeof(gCallbackProcessed)); + + TestContext & testContext = *static_cast(aContext); + ScopedGlobalTestContext testScope(&testContext); + + Layer & systemLayer = *testContext.mLayer; + nlTestSuite * const suite = testContext.mTestSuite; + + Clock::ClockBase * const savedClock = &SystemClock(); + Clock::Internal::MockClock mockClock; + Clock::Internal::SetSystemClockForTesting(&mockClock); + using namespace Clock::Literals; + + for (unsigned i = 0; i < kCancelTimerCount; i++) + { + NL_TEST_ASSERT( + suite, systemLayer.StartTimer(10_ms, Callback, reinterpret_cast(static_cast(i))) == CHIP_NO_ERROR); + } + + LayerEvents::ServiceEvents(systemLayer); + ValidateExecutedTimerCounts(suite); + NL_TEST_ASSERT(suite, ExecutedTimerCount() == 0); + + mockClock.AdvanceMonotonic(20_ms); + LayerEvents::ServiceEvents(systemLayer); + + ValidateExecutedTimerCounts(suite); + NL_TEST_ASSERT(suite, ExecutedTimerCount() == kCancelTimerCount); + + Clock::Internal::SetSystemClockForTesting(savedClock); +} + +} // namespace CancelTimerTest +} // namespace + // Test the implementation helper classes TimerPool, TimerList, and TimerData. namespace chip { namespace System { @@ -417,6 +529,7 @@ static const nlTest sTests[] = NL_TEST_DEF("Timer::TestTimerStarvation", CheckStarvation), NL_TEST_DEF("Timer::TestTimerOrder", CheckOrder), NL_TEST_DEF("Timer::TestTimerPool", chip::System::TestTimer::CheckTimerPool), + NL_TEST_DEF("Timer::TestCancelTimer", CancelTimerTest::Test), NL_TEST_SENTINEL() }; // clang-format on diff --git a/src/transport/raw/tests/NetworkTestHelpers.h b/src/transport/raw/tests/NetworkTestHelpers.h index de459cbeed9090..7e88d9a7ec042d 100644 --- a/src/transport/raw/tests/NetworkTestHelpers.h +++ b/src/transport/raw/tests/NetworkTestHelpers.h @@ -68,9 +68,16 @@ class LoopbackTransportDelegate // configurable allowed number of messages (mNumMessagesToAllowBeforeDropping) virtual void OnMessageDropped() {} }; - class LoopbackTransport : public Transport::Base { +private: + // Use unique pointers for work callbacks, so that one callback does not cancel another. + struct LoopbackWork + { + LoopbackTransport * self; + LoopbackWork(LoopbackTransport * transport) : self(transport) {} + }; + public: void InitLoopbackTransport(System::Layer * systemLayer) { mSystemLayer = systemLayer; } void ShutdownLoopbackTransport() @@ -89,7 +96,9 @@ class LoopbackTransport : public Transport::Base static void OnMessageReceived(System::Layer * aSystemLayer, void * aAppState) { - LoopbackTransport * _this = static_cast(aAppState); + LoopbackWork * work = static_cast(aAppState); + LoopbackTransport * _this = work->self; + delete work; while (!_this->mPendingMessageQueue.empty()) { @@ -129,7 +138,7 @@ class LoopbackTransport : public Transport::Base { System::PacketBufferHandle receivedMessage = msgBuf.CloneData(); mPendingMessageQueue.push(PendingMessageItem(address, std::move(receivedMessage))); - mSystemLayer->ScheduleWork(OnMessageReceived, this); + mSystemLayer->ScheduleWork(OnMessageReceived, new LoopbackWork(this)); } return CHIP_NO_ERROR;