diff --git a/src/channel/ChannelContext.cpp b/src/channel/ChannelContext.cpp index df550c23504e30..97798f604b72c7 100644 --- a/src/channel/ChannelContext.cpp +++ b/src/channel/ChannelContext.cpp @@ -38,7 +38,7 @@ void ChannelContext::Start(const ChannelBuilder & builder) ExchangeContext * ChannelContext::NewExchange(ExchangeDelegate * delegate) { assert(GetState() == ChannelState::kReady); - return mExchangeManager->NewContext(mStateVars.mReady.mSession, delegate); + return mExchangeManager->NewContext(GetReadyVars().mSession, delegate); } bool ChannelContext::MatchNodeId(NodeId nodeId) @@ -46,9 +46,9 @@ bool ChannelContext::MatchNodeId(NodeId nodeId) switch (mState) { case ChannelState::kPreparing: - return nodeId == mStateVars.mPreparing.mBuilder.GetPeerNodeId(); + return nodeId == GetPrepareVars().mBuilder.GetPeerNodeId(); case ChannelState::kReady: { - auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(mStateVars.mReady.mSession); + auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(GetReadyVars().mSession); if (state == nullptr) return false; return nodeId == state->GetPeerNodeId(); @@ -63,7 +63,7 @@ bool ChannelContext::MatchTransport(Transport::Type transport) switch (mState) { case ChannelState::kPreparing: - switch (mStateVars.mPreparing.mBuilder.GetTransportPreference()) + switch (GetPrepareVars().mBuilder.GetTransportPreference()) { case ChannelBuilder::TransportPreference::kPreferConnectionOriented: case ChannelBuilder::TransportPreference::kConnectionOriented: @@ -73,7 +73,7 @@ bool ChannelContext::MatchTransport(Transport::Type transport) } return false; case ChannelState::kReady: { - auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(mStateVars.mReady.mSession); + auto state = mExchangeManager->GetSessionMgr()->GetPeerConnectionState(GetReadyVars().mSession); if (state == nullptr) return false; return transport == state->GetPeerAddress().GetTransportType(); @@ -118,7 +118,7 @@ bool ChannelContext::MatchesBuilder(const ChannelBuilder & builder) bool ChannelContext::IsCasePairing() { - return mState == ChannelState::kPreparing && mStateVars.mPreparing.mState == PrepareState::kCasePairing; + return mState == ChannelState::kPreparing && GetPrepareVars().mState == PrepareState::kCasePairing; } bool ChannelContext::MatchesSession(SecureSessionHandle session, SecureSessionMgr * ssm) @@ -126,19 +126,19 @@ bool ChannelContext::MatchesSession(SecureSessionHandle session, SecureSessionMg switch (mState) { case ChannelState::kPreparing: { - switch (mStateVars.mPreparing.mState) + switch (GetPrepareVars().mState) { case PrepareState::kCasePairing: { auto state = ssm->GetPeerConnectionState(session); - return (state->GetPeerNodeId() == mStateVars.mPreparing.mBuilder.GetPeerNodeId() && - state->GetPeerKeyID() == mStateVars.mPreparing.mBuilder.GetPeerKeyID()); + return (state->GetPeerNodeId() == GetPrepareVars().mBuilder.GetPeerNodeId() && + state->GetPeerKeyID() == GetPrepareVars().mBuilder.GetPeerKeyID()); } default: return false; } } case ChannelState::kReady: - return mStateVars.mReady.mSession == session; + return GetReadyVars().mSession == session; default: return false; } @@ -146,8 +146,10 @@ bool ChannelContext::MatchesSession(SecureSessionHandle session, SecureSessionMg void ChannelContext::EnterPreparingState(const ChannelBuilder & builder) { - mState = ChannelState::kPreparing; - mStateVars.mPreparing.mBuilder = builder; + mState = ChannelState::kPreparing; + + mStateVars.Set(); + GetPrepareVars().mBuilder = builder; EnterAddressResolve(); } @@ -157,14 +159,14 @@ void ChannelContext::ExitPreparingState() {} // Address resolve void ChannelContext::EnterAddressResolve() { - mStateVars.mPreparing.mState = PrepareState::kAddressResolving; + GetPrepareVars().mState = PrepareState::kAddressResolving; // Skip address resolve if the address is provided { - auto addr = mStateVars.mPreparing.mBuilder.GetForcePeerAddress(); + auto addr = GetPrepareVars().mBuilder.GetForcePeerAddress(); if (addr.HasValue()) { - mStateVars.mPreparing.mAddress = addr.Value(); + GetPrepareVars().mAddress = addr.Value(); ExitAddressResolve(); // Only CASE session is supported EnterCasePairingState(); @@ -174,10 +176,10 @@ void ChannelContext::EnterAddressResolve() // TODO: call mDNS Scanner::SubscribeNode after PR #4459 is ready // Scanner::RegisterScannerDelegate(this) - // Scanner::SubscribeNode(mStateVars.mPreparing.mBuilder.GetPeerNodeId()) + // Scanner::SubscribeNode(GetPrepareVars().mBuilder.GetPeerNodeId()) // The HandleNodeIdResolve may already have been called, recheck the state here before set up the timer - if (mState == ChannelState::kPreparing && mStateVars.mPreparing.mState == PrepareState::kAddressResolving) + if (mState == ChannelState::kPreparing && GetPrepareVars().mState == PrepareState::kAddressResolving) { System::Layer * layer = mExchangeManager->GetSessionMgr()->SystemLayer(); layer->StartTimer(CHIP_CONFIG_NODE_ADDRESS_RESOLVE_TIMEOUT_MSECS, AddressResolveTimeout, this); @@ -196,7 +198,7 @@ void ChannelContext::AddressResolveTimeout() { if (mState != ChannelState::kPreparing) return; - if (mStateVars.mPreparing.mState != PrepareState::kAddressResolving) + if (GetPrepareVars().mState != PrepareState::kAddressResolving) return; ExitAddressResolve(); @@ -219,7 +221,7 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons return; } case ChannelState::kPreparing: { - switch (mStateVars.mPreparing.mState) + switch (GetPrepareVars().mState) { case PrepareState::kAddressResolving: { if (error != CHIP_NO_ERROR) @@ -232,8 +234,8 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons if (!address.mAddress.HasValue()) return; - mStateVars.mPreparing.mAddressType = address.mAddressType; - mStateVars.mPreparing.mAddress = address.mAddress.Value(); + GetPrepareVars().mAddressType = address.mAddressType; + GetPrepareVars().mAddress = address.mAddress.Value(); ExitAddressResolve(); EnterCasePairingState(); return; @@ -253,18 +255,18 @@ void ChannelContext::HandleNodeIdResolve(CHIP_ERROR error, uint64_t nodeId, cons void ChannelContext::EnterCasePairingState() { - mStateVars.mPreparing.mState = PrepareState::kCasePairing; - mStateVars.mPreparing.mCasePairingSession = Platform::New(); + auto & prepare = GetPrepareVars(); + prepare.mCasePairingSession = Platform::New(); - ExchangeContext * ctxt = mExchangeManager->NewContext(SecureSessionHandle(), mStateVars.mPreparing.mCasePairingSession); + ExchangeContext * ctxt = mExchangeManager->NewContext(SecureSessionHandle(), prepare.mCasePairingSession); VerifyOrReturn(ctxt != nullptr); // TODO: currently only supports IP/UDP paring Transport::PeerAddress addr; - addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(mStateVars.mPreparing.mAddress); - CHIP_ERROR err = mStateVars.mPreparing.mCasePairingSession->EstablishSession( - addr, &mStateVars.mPreparing.mBuilder.GetOperationalCredentialSet(), mStateVars.mPreparing.mBuilder.GetPeerNodeId(), - mExchangeManager->GetNextKeyId(), ctxt, this); + addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(prepare.mAddress); + CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession(addr, &prepare.mBuilder.GetOperationalCredentialSet(), + prepare.mBuilder.GetPeerNodeId(), + mExchangeManager->GetNextKeyId(), ctxt, this); if (err != CHIP_NO_ERROR) { ExitCasePairingState(); @@ -275,14 +277,14 @@ void ChannelContext::EnterCasePairingState() void ChannelContext::ExitCasePairingState() { - Platform::Delete(mStateVars.mPreparing.mCasePairingSession); + Platform::Delete(GetPrepareVars().mCasePairingSession); } void ChannelContext::OnSessionEstablishmentError(CHIP_ERROR error) { if (mState != ChannelState::kPreparing) return; - switch (mStateVars.mPreparing.mState) + switch (GetPrepareVars().mState) { case PrepareState::kCasePairing: ExitCasePairingState(); @@ -298,11 +300,11 @@ void ChannelContext::OnSessionEstablished() { if (mState != ChannelState::kPreparing) return; - switch (mStateVars.mPreparing.mState) + switch (GetPrepareVars().mState) { case PrepareState::kCasePairing: ExitCasePairingState(); - mStateVars.mPreparing.mState = PrepareState::kCasePairingDone; + GetPrepareVars().mState = PrepareState::kCasePairingDone; // TODO: current CASE paring session API doesn't show how to derive a secure session return; default: @@ -314,7 +316,7 @@ void ChannelContext::OnNewConnection(SecureSessionHandle session) { if (mState != ChannelState::kPreparing) return; - if (mStateVars.mPreparing.mState != PrepareState::kCasePairingDone) + if (GetPrepareVars().mState != PrepareState::kCasePairingDone) return; ExitPreparingState(); @@ -324,8 +326,7 @@ void ChannelContext::OnNewConnection(SecureSessionHandle session) void ChannelContext::EnterReadyState(SecureSessionHandle session) { mState = ChannelState::kReady; - - mStateVars.mReady.mSession = session; + mStateVars.Set(session); mChannelManager->NotifyChannelEvent(this, [](ChannelDelegate * delegate) { delegate->OnEstablished(); }); } @@ -344,7 +345,7 @@ void ChannelContext::ExitReadyState() // Currently SecureSessionManager doesn't provide an interface to close a session // TODO: call mDNS Scanner::UnubscribeNode after PR #4459 is ready - // Scanner::UnsubscribeNode(mStateVars.mPreparing.mBuilder.GetPeerNodeId()) + // Scanner::UnsubscribeNode(GetPrepareVars().mBuilder.GetPeerNodeId()) } void ChannelContext::EnterFailedState(CHIP_ERROR error) diff --git a/src/channel/ChannelContext.h b/src/channel/ChannelContext.h index 5a97fee0bde2e9..393c1cbad2c97a 100644 --- a/src/channel/ChannelContext.h +++ b/src/channel/ChannelContext.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -129,25 +130,28 @@ class ChannelContext : public ReferenceCounted mStateVars; + + PrepareVars & GetPrepareVars() { return mStateVars.Get(); } + ReadyVars & GetReadyVars() { return mStateVars.Get(); } // State machine functions void EnterPreparingState(const ChannelBuilder & builder); diff --git a/src/lib/support/BUILD.gn b/src/lib/support/BUILD.gn index 987dccb2f78982..f88bbc0603e18b 100644 --- a/src/lib/support/BUILD.gn +++ b/src/lib/support/BUILD.gn @@ -92,6 +92,7 @@ static_library("support") { "TimeUtils.h", "UnitTestRegistration.cpp", "UnitTestRegistration.h", + "Variant.h", "logging/CHIPLogging.cpp", "logging/CHIPLogging.h", "verhoeff/Verhoeff.cpp", diff --git a/src/lib/support/Variant.h b/src/lib/support/Variant.h new file mode 100644 index 00000000000000..8fe0e99babc697 --- /dev/null +++ b/src/lib/support/Variant.h @@ -0,0 +1,184 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace chip { + +namespace Internal { + +template +struct VariantCurry; + +template +struct VariantCurry +{ + inline static void Destroy(std::size_t id, void * data) + { + if (id == T::VariantId) + reinterpret_cast(data)->~T(); + else + VariantCurry::Destroy(id, data); + } + + inline static void Move(std::size_t that_t, void * that_v, void * this_v) + { + if (that_t == T::VariantId) + new (this_v) T(std::move(*reinterpret_cast(that_v))); + else + VariantCurry::Move(that_t, that_v, this_v); + } + + inline static void Copy(std::size_t that_t, const void * that_v, void * this_v) + { + if (that_t == T::VariantId) + new (this_v) T(*reinterpret_cast(that_v)); + else + VariantCurry::Copy(that_t, that_v, this_v); + } +}; + +template <> +struct VariantCurry<> +{ + inline static void Destroy(std::size_t id, void * data) {} + inline static void Move(std::size_t that_t, void * that_v, void * this_v) {} + inline static void Copy(std::size_t that_t, const void * that_v, void * this_v) {} +}; + +} // namespace Internal + +/** + * @brief + * Represents a type-safe union. An instance of Variant at any given time either holds a value of one of its + * alternative types, or no value. Each type must define a unique value of a static field named VariantId. + * + * Example: + * struct Type1 + * { + * static constexpr const std::size_t VariantId = 1; + * }; + * + * struct Type2 + * { + * static constexpr const std::size_t VariantId = 2; + * }; + * + * Variant v; + * v.Set(); // v contains Type1 + * Type1 o1 = v.Get(); + */ +template +struct Variant +{ +private: + static constexpr std::size_t kDataSize = std::max(sizeof(Ts)...); + static constexpr std::size_t kDataAlign = std::max(alignof(Ts)...); + static constexpr std::size_t kInvalidType = SIZE_MAX; + + using Data = typename std::aligned_storage::type; + using Curry = Internal::VariantCurry; + + std::size_t mTypeId; + Data mData; + +public: + Variant() : mTypeId(kInvalidType) {} + + Variant(const Variant & that) : mTypeId(that.mTypeId) { Curry::Copy(that.mTypeId, &that.mData, &mData); } + + Variant(Variant && that) : mTypeId(that.mTypeId) + { + Curry::Move(that.mTypeId, &that.mData, &mData); + Curry::Destroy(that.mTypeId, &that.mData); + that.mTypeId = kInvalidType; + } + + Variant & operator=(Variant & that) + { + Curry::Destroy(mTypeId, &mData); + mTypeId = that.mTypeId; + Curry::Copy(that.mTypeId, &that.mData, &mData); + return *this; + } + + Variant & operator=(Variant && that) + { + Curry::Destroy(mTypeId, &mData); + mTypeId = that.mTypeId; + Curry::Move(that.mTypeId, &that.mData, &mData); + return *this; + } + + template + bool Is() + { + return (mTypeId == T::VariantId); + } + + bool Valid() { return (mTypeId != kInvalidType); } + + template + void Set(Args &&... args) + { + Curry::Destroy(mTypeId, &mData); + new (&mData) T(std::forward(args)...); + mTypeId = T::VariantId; + } + + template + T & Get() + { + if (mTypeId == T::VariantId) + { + return *reinterpret_cast(&mData); + } + else + { + assert(false); + return *static_cast(nullptr); + } + } + + template + const T & Get() const + { + if (mTypeId == T::VariantId) + { + return *reinterpret_cast(&mData); + } + else + { + assert(false); + return *static_cast(nullptr); + } + } + + ~Variant() { Curry::Destroy(mTypeId, &mData); } +}; + +} // namespace chip diff --git a/src/lib/support/tests/BUILD.gn b/src/lib/support/tests/BUILD.gn index 5d1e4e4a9c37f7..78b43a90024b51 100644 --- a/src/lib/support/tests/BUILD.gn +++ b/src/lib/support/tests/BUILD.gn @@ -40,6 +40,7 @@ chip_test_suite("tests") { "TestStringBuilder.cpp", "TestThreadOperationalDataset.cpp", "TestTimeUtils.cpp", + "TestVariant.cpp", ] sources = [] diff --git a/src/lib/support/tests/TestVariant.cpp b/src/lib/support/tests/TestVariant.cpp new file mode 100644 index 00000000000000..611a7038e632e0 --- /dev/null +++ b/src/lib/support/tests/TestVariant.cpp @@ -0,0 +1,178 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +namespace { + +struct Simple +{ + static constexpr const std::size_t VariantId = 1; +}; + +struct Pod +{ + static constexpr const std::size_t VariantId = 2; + + Pod(int v1, int v2) : m1(v1), m2(v2) {} + + int m1; + int m2; +}; + +struct Movable +{ + static constexpr const std::size_t VariantId = 3; + + Movable(int v1, int v2) : m1(v1), m2(v2) {} + + Movable(Movable &) = delete; + Movable & operator=(Movable &) = delete; + + Movable(Movable &&) = default; + Movable & operator=(Movable &&) = default; + + int m1; + int m2; +}; + +struct Count +{ + static constexpr const std::size_t VariantId = 4; + + Count() { ++created; } + ~Count() { ++destroyed; } + + static int created; + static int destroyed; +}; + +int Count::created = 0; +int Count::destroyed = 0; + +using namespace chip; + +void TestVariantSimple(nlTestSuite * inSuite, void * inContext) +{ + Variant v; + NL_TEST_ASSERT(inSuite, !v.Valid()); + v.Set(5, 10); + NL_TEST_ASSERT(inSuite, v.Valid()); + NL_TEST_ASSERT(inSuite, v.Is()); + NL_TEST_ASSERT(inSuite, v.Get().m1 == 5); + NL_TEST_ASSERT(inSuite, v.Get().m2 == 10); +} + +void TestVariantMovable(nlTestSuite * inSuite, void * inContext) +{ + Variant v; + v.Set(); + v.Set(Movable{ 5, 10 }); + NL_TEST_ASSERT(inSuite, v.Get().m1 == 5); + NL_TEST_ASSERT(inSuite, v.Get().m2 == 10); + auto & m = v.Get(); + NL_TEST_ASSERT(inSuite, m.m1 == 5); + NL_TEST_ASSERT(inSuite, m.m2 == 10); + v.Set(); +} + +void TestVariantCtorDtor(nlTestSuite * inSuite, void * inContext) +{ + { + Variant v; + NL_TEST_ASSERT(inSuite, Count::created == 0); + v.Set(); + NL_TEST_ASSERT(inSuite, Count::created == 0); + v.Get(); + NL_TEST_ASSERT(inSuite, Count::created == 0); + } + { + Variant v; + NL_TEST_ASSERT(inSuite, Count::created == 0); + v.Set(); + NL_TEST_ASSERT(inSuite, Count::created == 0); + v.Set(); + NL_TEST_ASSERT(inSuite, Count::created == 1); + NL_TEST_ASSERT(inSuite, Count::destroyed == 0); + v.Get(); + NL_TEST_ASSERT(inSuite, Count::created == 1); + NL_TEST_ASSERT(inSuite, Count::destroyed == 0); + v.Set(); + NL_TEST_ASSERT(inSuite, Count::created == 1); + NL_TEST_ASSERT(inSuite, Count::destroyed == 1); + v.Set(); + NL_TEST_ASSERT(inSuite, Count::created == 2); + NL_TEST_ASSERT(inSuite, Count::destroyed == 1); + } + NL_TEST_ASSERT(inSuite, Count::destroyed == 2); +} + +void TestVariantCopy(nlTestSuite * inSuite, void * inContext) +{ + Variant v1; + v1.Set(5, 10); + Variant v2 = v1; + NL_TEST_ASSERT(inSuite, v1.Get().m1 == 5); + NL_TEST_ASSERT(inSuite, v1.Get().m2 == 10); + NL_TEST_ASSERT(inSuite, v2.Get().m1 == 5); + NL_TEST_ASSERT(inSuite, v2.Get().m2 == 10); +} + +void TestVariantMove(nlTestSuite * inSuite, void * inContext) +{ + Variant v1; + v1.Set(5, 10); + Variant v2 = std::move(v1); + NL_TEST_ASSERT(inSuite, !v1.Valid()); + NL_TEST_ASSERT(inSuite, v2.Get().m1 == 5); + NL_TEST_ASSERT(inSuite, v2.Get().m2 == 10); +} + +int Setup(void * inContext) +{ + return SUCCESS; +} + +int Teardown(void * inContext) +{ + return SUCCESS; +} + +} // namespace + +#define NL_TEST_DEF_FN(fn) NL_TEST_DEF("Test " #fn, fn) +/** + * Test Suite. It lists all the test functions. + */ +static const nlTest sTests[] = { NL_TEST_DEF_FN(TestVariantSimple), NL_TEST_DEF_FN(TestVariantMovable), + NL_TEST_DEF_FN(TestVariantCtorDtor), NL_TEST_DEF_FN(TestVariantCopy), + NL_TEST_DEF_FN(TestVariantMove), NL_TEST_SENTINEL() }; + +int TestVariant() +{ + nlTestSuite theSuite = { "CHIP Variant tests", &sTests[0], Setup, Teardown }; + + // Run test suit againt one context. + nlTestRunner(&theSuite, nullptr); + return nlTestRunnerStats(&theSuite); +} + +CHIP_REGISTER_TEST_SUITE(TestVariant);