Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix src/lib/support/StateMachine pattern matching #13664

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions src/lib/support/StateMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,37 +48,40 @@ struct VariantState : Variant<Ts...>

private:
template <typename T>
void Enter()
void Enter(bool & ever)
{
if (chip::Variant<Ts...>::template Is<T>())
if (!ever && chip::Variant<Ts...>::template Is<T>())
{
ever = true;
chip::Variant<Ts...>::template Get<T>().Enter();
}
}

template <typename T>
void Exit()
void Exit(bool & ever)
{
if (chip::Variant<Ts...>::template Is<T>())
if (!ever && chip::Variant<Ts...>::template Is<T>())
{
ever = true;
chip::Variant<Ts...>::template Get<T>().Exit();
}
}

template <typename T>
void GetName(const char ** name)
{
if (name && chip::Variant<Ts...>::template Is<T>())
if (name && !*name && chip::Variant<Ts...>::template Is<T>())
{
*name = chip::Variant<Ts...>::template Get<T>().GetName();
}
}

template <typename T>
void LogTransition(const char * previous)
void LogTransition(bool & ever, const char * previous)
{
if (chip::Variant<Ts...>::template Is<T>())
if (!ever && chip::Variant<Ts...>::template Is<T>())
{
ever = true;
chip::Variant<Ts...>::template Get<T>().LogTransition(previous);
}
}
Expand All @@ -94,12 +97,14 @@ struct VariantState : Variant<Ts...>

void Enter()
{
[](...) {}((this->template Enter<Ts>(), 0)...);
bool ever = false;
[](...) {}((this->template Enter<Ts>(ever), 0)...);
}

void Exit()
{
[](...) {}((this->template Exit<Ts>(), 0)...);
bool ever = false;
[](...) {}((this->template Exit<Ts>(ever), 0)...);
}

const char * GetName()
Expand All @@ -111,7 +116,8 @@ struct VariantState : Variant<Ts...>

void LogTransition(const char * previous)
{
[](...) {}((this->template LogTransition<Ts>(previous), 0)...);
bool ever = false;
[](...) {}((this->template LogTransition<Ts>(ever, previous), 0)...);
}
};

Expand Down Expand Up @@ -215,10 +221,10 @@ class StateMachine : public Context<TEvent>
auto newState = mTransitions(mCurrentState, evt);
if (newState.HasValue())
{
auto oldState = mCurrentState;
oldState.Exit();
auto oldState = mCurrentState.GetName();
mCurrentState.Exit();
mCurrentState = newState.Value();
mCurrentState.LogTransition(oldState.GetName());
mCurrentState.LogTransition(oldState);
// It is impermissible to dispatch events from Exit() or
// LogTransition(), or from the transitions table when a transition
// has also been returned. Verify that this hasn't occurred.
Expand Down
61 changes: 58 additions & 3 deletions src/lib/support/tests/TestStateMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ struct Event3
struct Event4
{
};
struct Event5
{
};

using Event = chip::Variant<Event1, Event2, Event3, Event4>;
using Event = chip::Variant<Event1, Event2, Event3, Event4, Event5>;
using Context = chip::StateMachine::Context<Event>;

struct MockState
Expand Down Expand Up @@ -76,19 +79,32 @@ struct State2 : public BaseState
State2(Context & ctx, MockState & mock) : BaseState{ ctx, "State2", mock } {}
};

using State = chip::StateMachine::VariantState<State1, State2>;
struct State3 : public BaseState
{
State3(Context & ctx, MockState & mock) : BaseState{ ctx, "State3", mock } {}
void Enter()
{
BaseState::Enter();
this->mCtx.Dispatch(Event::Create<Event5>());
}
};

// Place State3 first in the variant. This can evoke the behavior that
// TestNestedDispatch is looking for.
using State = chip::StateMachine::VariantState<State3, State2, State1>;

struct StateFactory
{
Context & mCtx;
MockState ms1{ 0, 0, 0, nullptr };
MockState ms2{ 0, 0, 0, nullptr };
MockState ms3{ 0, 0, 0, nullptr };

StateFactory(Context & ctx) : mCtx(ctx) {}

auto CreateState1() { return State::Create<State1>(mCtx, ms1); }

auto CreateState2() { return State::Create<State2>(mCtx, ms2); }
auto CreateState3() { return State::Create<State3>(mCtx, ms3); }
};

struct Transitions
Expand Down Expand Up @@ -120,6 +136,14 @@ struct Transitions
// mCtx.Dispatch(Event::Create<Event2>()); // dsipatching an event and returning a transition would be illegal
return mFactory.CreateState1();
}
else if (state.Is<State1>() && event.Is<Event5>())
{
return mFactory.CreateState3();
}
else if (state.Is<State3>() && event.Is<Event5>())
{
return mFactory.CreateState2();
}
else
{
return {};
Expand Down Expand Up @@ -190,6 +214,36 @@ void TestTransitionsDispatch(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is<State2>());
}

void TestNestedDispatch(nlTestSuite * inSuite, void * inContext)
{
// in State1
SimpleStateMachine fsm;
// Dispatch Event5, which places us into State3, which will dispatch
// Event5 again from its Enter method to place us into State2.
fsm.mStateMachine.Dispatch(Event::Create<Event5>());
NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is<State2>());
// Make sure that Enter methods execute the correct number of times.
// This helps verify that pattern matching is working correctly.
// Specifically, we need to verify this case: State3 creates State2
// by dispatching Event5 from its Enter method. This means that the
// Dispatch call from State3 also destructs State3. If the State3
// Enter method pattern matching triggers Enter more than once, this
// is use-after-destruction. What can appear to happen is that the
// State2 Enter method will execute twice, as State2 will already have
// been constructed when the State3 Enter method executes a second
// time. The state machine pattern matching has code to explicitly
// prevent this double-execution. This is testing that.
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mEntered == 0);
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mExited == 1);
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mLogged == 0);
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mEntered == 1);
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mExited == 0);
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mLogged == 1);
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mEntered == 1);
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mExited == 1);
NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms3.mLogged == 1);
}

void TestMethodExec(nlTestSuite * inSuite, void * inContext)
{
// in State1
Expand Down Expand Up @@ -237,6 +291,7 @@ static const nlTest sTests[] = {
NL_TEST_DEF("TestIgnoredEvents", TestIgnoredEvents),
NL_TEST_DEF("TestTransitions", TestTransitions),
NL_TEST_DEF("TestTransitionsDispatch", TestTransitionsDispatch),
NL_TEST_DEF("TestNestedDispatch", TestNestedDispatch),
NL_TEST_DEF("TestMethodExec", TestMethodExec),
NL_TEST_SENTINEL(),
};
Expand Down