diff --git a/src/lib/support/StateMachine.h b/src/lib/support/StateMachine.h index b9b5152512e032..7a9e255b440719 100644 --- a/src/lib/support/StateMachine.h +++ b/src/lib/support/StateMachine.h @@ -48,19 +48,21 @@ struct VariantState : Variant private: template - void Enter() + void Enter(bool & ever) { - if (chip::Variant::template Is()) + if (!ever && chip::Variant::template Is()) { + ever = true; chip::Variant::template Get().Enter(); } } template - void Exit() + void Exit(bool & ever) { - if (chip::Variant::template Is()) + if (!ever && chip::Variant::template Is()) { + ever = true; chip::Variant::template Get().Exit(); } } @@ -68,17 +70,18 @@ struct VariantState : Variant template void GetName(const char ** name) { - if (name && chip::Variant::template Is()) + if (name && !*name && chip::Variant::template Is()) { *name = chip::Variant::template Get().GetName(); } } template - void LogTransition(const char * previous) + void LogTransition(bool & ever, const char * previous) { - if (chip::Variant::template Is()) + if (!ever && chip::Variant::template Is()) { + ever = true; chip::Variant::template Get().LogTransition(previous); } } @@ -94,12 +97,14 @@ struct VariantState : Variant void Enter() { - [](...) {}((this->template Enter(), 0)...); + bool ever = false; + [](...) {}((this->template Enter(ever), 0)...); } void Exit() { - [](...) {}((this->template Exit(), 0)...); + bool ever = false; + [](...) {}((this->template Exit(ever), 0)...); } const char * GetName() @@ -111,7 +116,8 @@ struct VariantState : Variant void LogTransition(const char * previous) { - [](...) {}((this->template LogTransition(previous), 0)...); + bool ever = false; + [](...) {}((this->template LogTransition(ever, previous), 0)...); } }; @@ -215,10 +221,10 @@ class StateMachine : public Context auto newState = mTransitions(mCurrentState, evt); if (newState.HasValue()) { - auto oldState = mCurrentState; - oldState.Exit(); + auto oldState = mCurrentState.GetName(); + mCurrentState.Exit(); mCurrentState = newState.Value(); - mCurrentState.LogTransition(oldState.GetName()); + mCurrentState.LogTransition(oldState); // It is impermissible to dispatch events from Exit() or // LogTransition(), or from the transitions table when a transition // has also been returned. Verify that this hasn't occurred. diff --git a/src/lib/support/tests/TestStateMachine.cpp b/src/lib/support/tests/TestStateMachine.cpp index 37433dd64a0b9d..47baef33737049 100644 --- a/src/lib/support/tests/TestStateMachine.cpp +++ b/src/lib/support/tests/TestStateMachine.cpp @@ -34,8 +34,11 @@ struct Event3 struct Event4 { }; +struct Event5 +{ +}; -using Event = chip::Variant; +using Event = chip::Variant; using Context = chip::StateMachine::Context; struct MockState @@ -76,19 +79,32 @@ struct State2 : public BaseState State2(Context & ctx, MockState & mock) : BaseState{ ctx, "State2", mock } {} }; -using State = chip::StateMachine::VariantState; +struct State3 : public BaseState +{ + State3(Context & ctx, MockState & mock) : BaseState{ ctx, "State3", mock } {} + void Enter() + { + BaseState::Enter(); + this->mCtx.Dispatch(Event::Create()); + } +}; + +// Place State3 first in the variant. This can evoke the behavior that +// TestNestedDispatch is looking for. +using State = chip::StateMachine::VariantState; struct StateFactory { Context & mCtx; MockState ms1{ 0, 0, 0, nullptr }; MockState ms2{ 0, 0, 0, nullptr }; + MockState ms3{ 0, 0, 0, nullptr }; StateFactory(Context & ctx) : mCtx(ctx) {} auto CreateState1() { return State::Create(mCtx, ms1); } - auto CreateState2() { return State::Create(mCtx, ms2); } + auto CreateState3() { return State::Create(mCtx, ms3); } }; struct Transitions @@ -120,6 +136,14 @@ struct Transitions // mCtx.Dispatch(Event::Create()); // dsipatching an event and returning a transition would be illegal return mFactory.CreateState1(); } + else if (state.Is() && event.Is()) + { + return mFactory.CreateState3(); + } + else if (state.Is() && event.Is()) + { + return mFactory.CreateState2(); + } else { return {}; @@ -190,6 +214,36 @@ void TestTransitionsDispatch(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); } +void TestNestedDispatch(nlTestSuite * inSuite, void * inContext) +{ + // in State1 + SimpleStateMachine fsm; + // Dispatch Event5, which places us into State3, which will dispatch + // Event5 again from its Enter method to place us into State2. + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // Make sure that Enter methods execute the correct number of times. + // This helps verify that pattern matching is working correctly. + // Specifically, we need to verify this case: State3 creates State2 + // by dispatching Event5 from its Enter method. This means that the + // Dispatch call from State3 also destructs State3. If the State3 + // Enter method pattern matching triggers Enter more than once, this + // is use-after-destruction. What can appear to happen is that the + // State2 Enter method will execute twice, as State2 will already have + // been constructed when the State3 Enter method executes a second + // time. The state machine pattern matching has code to explicitly + // prevent this double-execution. This is testing that. + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mEntered == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mExited == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mLogged == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mEntered == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mExited == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mLogged == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mEntered == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mExited == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mLogged == 1); +} + void TestMethodExec(nlTestSuite * inSuite, void * inContext) { // in State1 @@ -237,6 +291,7 @@ static const nlTest sTests[] = { NL_TEST_DEF("TestIgnoredEvents", TestIgnoredEvents), NL_TEST_DEF("TestTransitions", TestTransitions), NL_TEST_DEF("TestTransitionsDispatch", TestTransitionsDispatch), + NL_TEST_DEF("TestNestedDispatch", TestNestedDispatch), NL_TEST_DEF("TestMethodExec", TestMethodExec), NL_TEST_SENTINEL(), };