From bb9ac689e272f2f6138b3c0763d4f7524a206135 Mon Sep 17 00:00:00 2001 From: Michael Sandstedt Date: Wed, 19 Jan 2022 16:02:32 -0600 Subject: [PATCH] Fix src/lib/support/StateMachine pattern matching (#13664) The Enter method can rewrite the state objects during pattern matching traversal. We need a one-shot bool to ensure we only execute Enter once. And although state changes from Exit, LogTransition and GetName aren't expected, we should check in these too. In all cases, the pattern match should only execute once. This commit adds the necessary check and a test case to verify this is working correctly. --- src/lib/support/StateMachine.h | 32 +++++++----- src/lib/support/tests/TestStateMachine.cpp | 61 ++++++++++++++++++++-- 2 files changed, 77 insertions(+), 16 deletions(-) diff --git a/src/lib/support/StateMachine.h b/src/lib/support/StateMachine.h index b9b5152512e032..7a9e255b440719 100644 --- a/src/lib/support/StateMachine.h +++ b/src/lib/support/StateMachine.h @@ -48,19 +48,21 @@ struct VariantState : Variant private: template - void Enter() + void Enter(bool & ever) { - if (chip::Variant::template Is()) + if (!ever && chip::Variant::template Is()) { + ever = true; chip::Variant::template Get().Enter(); } } template - void Exit() + void Exit(bool & ever) { - if (chip::Variant::template Is()) + if (!ever && chip::Variant::template Is()) { + ever = true; chip::Variant::template Get().Exit(); } } @@ -68,17 +70,18 @@ struct VariantState : Variant template void GetName(const char ** name) { - if (name && chip::Variant::template Is()) + if (name && !*name && chip::Variant::template Is()) { *name = chip::Variant::template Get().GetName(); } } template - void LogTransition(const char * previous) + void LogTransition(bool & ever, const char * previous) { - if (chip::Variant::template Is()) + if (!ever && chip::Variant::template Is()) { + ever = true; chip::Variant::template Get().LogTransition(previous); } } @@ -94,12 +97,14 @@ struct VariantState : Variant void Enter() { - [](...) {}((this->template Enter(), 0)...); + bool ever = false; + [](...) {}((this->template Enter(ever), 0)...); } void Exit() { - [](...) {}((this->template Exit(), 0)...); + bool ever = false; + [](...) {}((this->template Exit(ever), 0)...); } const char * GetName() @@ -111,7 +116,8 @@ struct VariantState : Variant void LogTransition(const char * previous) { - [](...) {}((this->template LogTransition(previous), 0)...); + bool ever = false; + [](...) {}((this->template LogTransition(ever, previous), 0)...); } }; @@ -215,10 +221,10 @@ class StateMachine : public Context auto newState = mTransitions(mCurrentState, evt); if (newState.HasValue()) { - auto oldState = mCurrentState; - oldState.Exit(); + auto oldState = mCurrentState.GetName(); + mCurrentState.Exit(); mCurrentState = newState.Value(); - mCurrentState.LogTransition(oldState.GetName()); + mCurrentState.LogTransition(oldState); // It is impermissible to dispatch events from Exit() or // LogTransition(), or from the transitions table when a transition // has also been returned. Verify that this hasn't occurred. diff --git a/src/lib/support/tests/TestStateMachine.cpp b/src/lib/support/tests/TestStateMachine.cpp index 37433dd64a0b9d..47baef33737049 100644 --- a/src/lib/support/tests/TestStateMachine.cpp +++ b/src/lib/support/tests/TestStateMachine.cpp @@ -34,8 +34,11 @@ struct Event3 struct Event4 { }; +struct Event5 +{ +}; -using Event = chip::Variant; +using Event = chip::Variant; using Context = chip::StateMachine::Context; struct MockState @@ -76,19 +79,32 @@ struct State2 : public BaseState State2(Context & ctx, MockState & mock) : BaseState{ ctx, "State2", mock } {} }; -using State = chip::StateMachine::VariantState; +struct State3 : public BaseState +{ + State3(Context & ctx, MockState & mock) : BaseState{ ctx, "State3", mock } {} + void Enter() + { + BaseState::Enter(); + this->mCtx.Dispatch(Event::Create()); + } +}; + +// Place State3 first in the variant. This can evoke the behavior that +// TestNestedDispatch is looking for. +using State = chip::StateMachine::VariantState; struct StateFactory { Context & mCtx; MockState ms1{ 0, 0, 0, nullptr }; MockState ms2{ 0, 0, 0, nullptr }; + MockState ms3{ 0, 0, 0, nullptr }; StateFactory(Context & ctx) : mCtx(ctx) {} auto CreateState1() { return State::Create(mCtx, ms1); } - auto CreateState2() { return State::Create(mCtx, ms2); } + auto CreateState3() { return State::Create(mCtx, ms3); } }; struct Transitions @@ -120,6 +136,14 @@ struct Transitions // mCtx.Dispatch(Event::Create()); // dsipatching an event and returning a transition would be illegal return mFactory.CreateState1(); } + else if (state.Is() && event.Is()) + { + return mFactory.CreateState3(); + } + else if (state.Is() && event.Is()) + { + return mFactory.CreateState2(); + } else { return {}; @@ -190,6 +214,36 @@ void TestTransitionsDispatch(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); } +void TestNestedDispatch(nlTestSuite * inSuite, void * inContext) +{ + // in State1 + SimpleStateMachine fsm; + // Dispatch Event5, which places us into State3, which will dispatch + // Event5 again from its Enter method to place us into State2. + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // Make sure that Enter methods execute the correct number of times. + // This helps verify that pattern matching is working correctly. + // Specifically, we need to verify this case: State3 creates State2 + // by dispatching Event5 from its Enter method. This means that the + // Dispatch call from State3 also destructs State3. If the State3 + // Enter method pattern matching triggers Enter more than once, this + // is use-after-destruction. What can appear to happen is that the + // State2 Enter method will execute twice, as State2 will already have + // been constructed when the State3 Enter method executes a second + // time. The state machine pattern matching has code to explicitly + // prevent this double-execution. This is testing that. + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mEntered == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mExited == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mLogged == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mEntered == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mExited == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mLogged == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mEntered == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mExited == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mLogged == 1); +} + void TestMethodExec(nlTestSuite * inSuite, void * inContext) { // in State1 @@ -237,6 +291,7 @@ static const nlTest sTests[] = { NL_TEST_DEF("TestIgnoredEvents", TestIgnoredEvents), NL_TEST_DEF("TestTransitions", TestTransitions), NL_TEST_DEF("TestTransitionsDispatch", TestTransitionsDispatch), + NL_TEST_DEF("TestNestedDispatch", TestNestedDispatch), NL_TEST_DEF("TestMethodExec", TestMethodExec), NL_TEST_SENTINEL(), };