From 972e22ecf17183f03790983207918b4f77349b30 Mon Sep 17 00:00:00 2001 From: Karsten Sperling Date: Tue, 11 Jun 2024 01:53:54 +1200 Subject: [PATCH] Fix OperationalSessionSetup notifying success with inactive sessions When a success callback marks the session defunct for some reason, other succcess callbacks should not be called. Implement a GroupedCallbackList to make this logic possible. This specific solution is based on the assumption that we don't want to change the OperationalSessionSetup API which takes success and two variants of failure callbacks as separate, client-provided Callback objects, with all callbacks being optional to provide. We also don't want to introduce additional dynamic allocation within OperationalSessionSetup e.g. to allocate a struct holding the related callbacks. The GroupedCallbackList class makes use of the existing prev/next pointers within the client-allocated Callback objects to capture the grouping relationship between them. Co-authored-by: Boris Zbarsky --- src/app/OperationalSessionSetup.cpp | 120 +++------ src/app/OperationalSessionSetup.h | 12 +- src/lib/core/BUILD.gn | 1 + src/lib/core/GroupedCallbackList.h | 246 ++++++++++++++++++ src/lib/core/tests/BUILD.gn | 1 + .../core/tests/TestGroupedCallbackList.cpp | 226 ++++++++++++++++ 6 files changed, 522 insertions(+), 84 deletions(-) create mode 100644 src/lib/core/GroupedCallbackList.h create mode 100644 src/lib/core/tests/TestGroupedCallbackList.cpp diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp index 9197a2edbddf70..5b2f00ed0a3798 100644 --- a/src/app/OperationalSessionSetup.cpp +++ b/src/app/OperationalSessionSetup.cpp @@ -325,35 +325,14 @@ void OperationalSessionSetup::EnqueueConnectionCallbacks(Callback::Callback * onFailure, Callback::Callback * onSetupFailure) { - if (onConnection != nullptr) - { - mConnectionSuccess.Enqueue(onConnection->Cancel()); - } - - if (onFailure != nullptr) - { - mConnectionFailure.Enqueue(onFailure->Cancel()); - } - - if (onSetupFailure != nullptr) - { - mSetupFailure.Enqueue(onSetupFailure->Cancel()); - } + mCallbacks.Enqueue(onConnection, onFailure, onSetupFailure); } void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, SessionEstablishmentStage stage, ReleaseBehavior releaseBehavior) { - Cancelable failureReady, setupFailureReady, successReady; - - // - // Dequeue both failure and success callback lists into temporary stack args before invoking either of them. - // We do this since we may not have a valid 'this' pointer anymore upon invoking any of those callbacks - // since the callee may destroy this object as part of that callback. - // - mConnectionFailure.DequeueAll(failureReady); - mSetupFailure.DequeueAll(setupFailureReady); - mConnectionSuccess.DequeueAll(successReady); + // We expect that we only have callbacks if we are not performing just address update. + VerifyOrDie(!mPerformingAddressUpdate || mCallbacks.IsEmpty()); #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES // Clear out mConnectionRetry, so that those cancelables are not holding @@ -365,7 +344,8 @@ void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, Sessi #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES // Gather up state we will need for our notifications. - bool performingAddressUpdate = mPerformingAddressUpdate; + SuccessFailureCallbackList readyCallbacks; + readyCallbacks.EnqueueTakeAll(mCallbacks); auto * exchangeMgr = mInitParams.exchangeMgr; Optional optionalSessionHandle = mSecureSession.Get(); ScopedNodeId peerId = mPeerId; @@ -383,71 +363,57 @@ void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, Sessi } // DO NOT touch any members of this object after this point. It's dead. - - NotifyConnectionCallbacks(failureReady, setupFailureReady, successReady, error, stage, peerId, performingAddressUpdate, - exchangeMgr, optionalSessionHandle, requestedBusyDelay); + NotifyConnectionCallbacks(readyCallbacks, error, stage, peerId, exchangeMgr, optionalSessionHandle, requestedBusyDelay); } -void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureReady, Cancelable & setupFailureReady, - Cancelable & successReady, CHIP_ERROR error, +void OperationalSessionSetup::NotifyConnectionCallbacks(SuccessFailureCallbackList & ready, CHIP_ERROR error, SessionEstablishmentStage stage, const ScopedNodeId & peerId, - bool performingAddressUpdate, Messaging::ExchangeManager * exchangeMgr, + Messaging::ExchangeManager * exchangeMgr, const Optional & optionalSessionHandle, System::Clock::Milliseconds16 requestedBusyDelay) { - // - // If we encountered no error, go ahead and call all success callbacks. Otherwise, - // call the failure callbacks. - // - while (failureReady.mNext != &failureReady) + Callback::Callback * onConnected; + Callback::Callback * onConnectionFailure; + Callback::Callback * onSetupFailure; + while (ready.Take(onConnected, onConnectionFailure, onSetupFailure)) { - // We expect that we only have callbacks if we are not performing just address update. - VerifyOrDie(!performingAddressUpdate); - Callback::Callback * cb = - Callback::Callback::FromCancelable(failureReady.mNext); - - cb->Cancel(); - - if (error != CHIP_NO_ERROR) + if (error == CHIP_NO_ERROR) { - cb->mCall(cb->mContext, peerId, error); + VerifyOrDie(exchangeMgr); + VerifyOrDie(optionalSessionHandle.Value()->AsSecureSession()->IsActiveSession()); + if (onConnected != nullptr) + { + onConnected->mCall(onConnected->mContext, *exchangeMgr, optionalSessionHandle.Value()); + + // That sucessful call might have made the session inactive. If it did, then we should + // not call any more success callbacks, since we do not in fact have an active session + // for them, and if they try to put the session in a holder that will fail, and then + // trying to use the holder as if it has a session will crash. + if (!optionalSessionHandle.Value()->AsSecureSession()->IsActiveSession()) + { + ChipLogError(Discovery, "Success callback for connection to " ChipLogFormatScopedNodeId " tore down session", + ChipLogValueScopedNodeId(peerId)); + error = CHIP_ERROR_CONNECTION_ABORTED; + } + } } - } - - while (setupFailureReady.mNext != &setupFailureReady) - { - // We expect that we only have callbacks if we are not performing just address update. - VerifyOrDie(!performingAddressUpdate); - Callback::Callback * cb = Callback::Callback::FromCancelable(setupFailureReady.mNext); - - cb->Cancel(); - - if (error != CHIP_NO_ERROR) + else // error { - // Initialize the ConnnectionFailureInfo object - ConnnectionFailureInfo failureInfo(peerId, error, stage); -#if CHIP_CONFIG_ENABLE_BUSY_HANDLING_FOR_OPERATIONAL_SESSION_SETUP - if (error == CHIP_ERROR_BUSY) + if (onConnectionFailure != nullptr) { - failureInfo.requestedBusyDelay.Emplace(requestedBusyDelay); + onConnectionFailure->mCall(onConnectionFailure->mContext, peerId, error); } + if (onSetupFailure != nullptr) + { + ConnnectionFailureInfo failureInfo(peerId, error, stage); +#if CHIP_CONFIG_ENABLE_BUSY_HANDLING_FOR_OPERATIONAL_SESSION_SETUP + if (error == CHIP_ERROR_BUSY) + { + failureInfo.requestedBusyDelay.Emplace(requestedBusyDelay); + } #endif // CHIP_CONFIG_ENABLE_BUSY_HANDLING_FOR_OPERATIONAL_SESSION_SETUP - cb->mCall(cb->mContext, failureInfo); - } - } - - while (successReady.mNext != &successReady) - { - // We expect that we only have callbacks if we are not performing just address update. - VerifyOrDie(!performingAddressUpdate); - Callback::Callback * cb = Callback::Callback::FromCancelable(successReady.mNext); - - cb->Cancel(); - if (error == CHIP_NO_ERROR) - { - VerifyOrDie(exchangeMgr); - // We know that we for sure have the SessionHandle in the successful case. - cb->mCall(cb->mContext, *exchangeMgr, optionalSessionHandle.Value()); + onSetupFailure->mCall(onSetupFailure->mContext, failureInfo); + } } } } diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h index 5955dbab0713bd..508c778923fd49 100644 --- a/src/app/OperationalSessionSetup.h +++ b/src/app/OperationalSessionSetup.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -309,9 +310,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, SessionHolder mSecureSession; - Callback::CallbackDeque mConnectionSuccess; - Callback::CallbackDeque mConnectionFailure; - Callback::CallbackDeque mSetupFailure; + typedef Callback::GroupedCallbackList SuccessFailureCallbackList; + SuccessFailureCallbackList mCallbacks; OperationalSessionReleaseDelegate * mReleaseDelegate; @@ -402,10 +402,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * notifications. This happens after the object has been released, if it's * being released. */ - static void NotifyConnectionCallbacks(Callback::Cancelable & failureReady, Callback::Cancelable & setupFailureReady, - Callback::Cancelable & successReady, CHIP_ERROR error, SessionEstablishmentStage stage, - const ScopedNodeId & peerId, bool performingAddressUpdate, - Messaging::ExchangeManager * exchangeMgr, + static void NotifyConnectionCallbacks(SuccessFailureCallbackList & ready, CHIP_ERROR error, SessionEstablishmentStage stage, + const ScopedNodeId & peerId, Messaging::ExchangeManager * exchangeMgr, const Optional & optionalSessionHandle, // requestedBusyDelay will be 0 if not // CHIP_CONFIG_ENABLE_BUSY_HANDLING_FOR_OPERATIONAL_SESSION_SETUP, diff --git a/src/lib/core/BUILD.gn b/src/lib/core/BUILD.gn index ea09ddd01bbaad..70680640b025b2 100644 --- a/src/lib/core/BUILD.gn +++ b/src/lib/core/BUILD.gn @@ -157,6 +157,7 @@ static_library("core") { "CHIPKeyIds.h", "CHIPPersistentStorageDelegate.h", "ClusterEnums.h", + "GroupedCallbackList.h", "OTAImageHeader.cpp", "OTAImageHeader.h", "PeerId.h", diff --git a/src/lib/core/GroupedCallbackList.h b/src/lib/core/GroupedCallbackList.h new file mode 100644 index 00000000000000..100b7df15628ca --- /dev/null +++ b/src/lib/core/GroupedCallbackList.h @@ -0,0 +1,246 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +namespace chip { +namespace Callback { + +namespace detail { +// Internal helper functions +template +void TaggedDequeueGroup(Cancelable * cancelable); +void EnqueueWithGroup(Cancelable * cancelable, Cancelable *& group, Cancelable * anchor, void (*cancelFn)(Cancelable *)); +void LinkGroup(Cancelable * prev, Cancelable * next); +inline Cancelable * ClearCancelable(Cancelable * cancelable); +} // namespace detail + +/** + * A GroupedCallbackList manages a list of Callback objects (see CHIPCallback.h). + * The state of the list is maintained using the prev/next pointers of each Callback. + * Unlike a normal linked list where entries are managed individually, this class + * manages a number of related callbacks as a group, with the callback function types + * given as template parameters. + * + * For example, a `GroupedCallbackList` manages groups of a + * `Callback` and a `Callback`. + * + * Groups of callbacks are enqueued and dequeued (or cancelled) as a unit. + * Within a group each callback is optional (i.e. can be null), however attempting + * to enqueue a group where all callbacks are null has no effect. + */ +template +class GroupedCallbackList : protected Cancelable +{ +public: + GroupedCallbackList() = default; + ~GroupedCallbackList() { Clear(); } + + GroupedCallbackList(GroupedCallbackList const &) = delete; + GroupedCallbackList & operator=(GroupedCallbackList const &) = delete; + + bool IsEmpty() { return mNext == this; } + + /** + * Enqueues the specified group of callbacks, any of which may be null. + */ + void Enqueue(Callback *... callback) { Enqueue(std::index_sequence_for{}, callback...); } + + /** + * If the list is non-empty, populates the reference arguments with the first + * group of callbacks and returns true. Returns false if the list is empty. + */ + bool Peek(Callback *&... callback) const { return Peek(std::index_sequence_for{}, callback...); } + + /** + * Like Peek(), but additionally removes the first group of callbacks from the list. + */ + bool Take(Callback *&... callback) + { + VerifyOrReturnValue(Peek(callback...), false); + mNext->Cancel(); + return true; + } + + /** + * Moves all elements of the source list into this list, leaving the source list empty. + */ + void EnqueueTakeAll(GroupedCallbackList & source) + { + VerifyOrReturn(!source.IsEmpty() && this != &source); + detail::LinkGroup(mPrev, source.mNext); + source.mPrev->mNext = this; + mPrev = source.mPrev; + + source.mPrev = source.mNext = &source; + } + + void Clear() + { + Cancelable * next = mNext; + while (next != this) + { + next = detail::ClearCancelable(next); + } + mPrev = mNext = this; + } + +private: + /* + * The grouped list structure is similar to a normal doubly linked list, + * with the list object itself (via inheriting from Cancelable) acting as + * an external "anchor" node that is both the head and tail of the list. + * + * However we have the additional requirement of representing node grouping. + * Due to the requirement so support sparse groups (one or more callbacks may + * not be present in a particular group) we cannot rely on a fixed group size. + * This problem is solved by having the "prev" pointer for all nodes in a group + * point to the node before the group, as illustrated below: + * + * |Anchor| |Grp 1| |====== Group 2 ======| + * _______________________________________________ + * / \ + * \ +---+ +---+ +---+ +---+ +---+ / + * ->|###|----->| |----->| |-->| |-->| |-- + * |###| | | | | | | | | + * --|###|<-----| |<-----| | -| | -| |<- + * / +---+ +---+ \ +---+ / +---+ / +---+ \ + * | \_______/ / | + * | \_____________/ | + * \_______________________________________________/ + * + * This allows the start of a group to be reached from any group member via + * ->prev->next. Nodes in a group can be enumerated by via the "next" chain, + * inspecting the "prev" pointers to detect the end of the group. The price + * for encoding grouping in this way is that upon removal of a group we have + * to update not just the "prev" pointer of the following node, but of all + * nodes in the following group. + * + * When retrieving a (sparse) group from the list, we also need to be able + * to tell which callbacks are present: In a grouped list with types (A, B) + * both (a, nullptr) and (nullptr, b) are by necessity represented by only + * a single node in the list. To be able to recover this information, we use + * distinct trampolines that tag the "cancel" function pointer stored in each + * node with the index of the callback type within the argument type tuple. + */ + + template + void Enqueue(std::index_sequence, Callback *... callback) + { + Cancelable * group = nullptr; + ( + [&] { + VerifyOrReturn(callback != nullptr); + detail::EnqueueWithGroup(callback->Cancel(), group, this, &detail::TaggedDequeueGroup); + }(), + ...); + } + + template + bool Peek(std::index_sequence, Callback *&... callback) const + { + Cancelable * cancelable = mNext; + VerifyOrReturnValue(cancelable != this, false); + Cancelable * groupPrev = cancelable->mPrev; + ( + [&] { + if (cancelable->mPrev == groupPrev && cancelable->mCancel == &detail::TaggedDequeueGroup) + { + callback = CallbackmCall)>::FromCancelable(cancelable); + cancelable = cancelable->mNext; + } + else + { + callback = nullptr; + } + }(), + ...); + return true; + } +}; + +namespace detail { + +// Inserts `cancelable` before `anchor`, either starting a new `group` +// (populating the passed pointer if it is null) or adding to it. +inline void EnqueueWithGroup(Cancelable * cancelable, Cancelable *& group, Cancelable * anchor, void (*cancelFn)(Cancelable *)) +{ + cancelable->mCancel = cancelFn; + cancelable->mNext = anchor; + if (!group) + { + group = cancelable; + cancelable->mPrev = anchor->mPrev; + } + else + { + cancelable->mPrev = group->mPrev; + } + anchor->mPrev->mNext = cancelable; + anchor->mPrev = cancelable; +} + +// Establish prev/next links between `prev` and the group starting at `cancelable`. +inline void LinkGroup(Cancelable * prev, Cancelable * cancelable) +{ + prev->mNext = cancelable; + + Cancelable * groupPrev = cancelable->mPrev; + do + { + cancelable->mPrev = prev; + cancelable = cancelable->mNext; + } while (cancelable->mPrev == groupPrev); +} + +// Clears the state of a cancelable and returns the following one. +// Does NOT touch the state of adjacent nodes. +inline Cancelable * ClearCancelable(Cancelable * cancelable) +{ + auto * next = cancelable->mNext; + cancelable->mPrev = cancelable->mNext = cancelable; + cancelable->mCancel = nullptr; + return next; +} + +// Dequeues `cancelable` and all otehr nodes in the same group. +inline void DequeueGroup(Cancelable * cancelable) +{ + Cancelable * prev = cancelable->mPrev; + Cancelable * next = prev->mNext; + do + { + next = ClearCancelable(next); + } while (next->mPrev == prev); + LinkGroup(prev, next); +} + +template +void TaggedDequeueGroup(Cancelable * cancelable) +{ + (void) Index; // not used, we only care that instantiations have unique addresses + DequeueGroup(cancelable); +} + +} // namespace detail +} // namespace Callback +} // namespace chip diff --git a/src/lib/core/tests/BUILD.gn b/src/lib/core/tests/BUILD.gn index 001078019a7e83..108d5112e8819b 100644 --- a/src/lib/core/tests/BUILD.gn +++ b/src/lib/core/tests/BUILD.gn @@ -26,6 +26,7 @@ chip_test_suite("tests") { "TestCHIPCallback.cpp", "TestCHIPError.cpp", "TestCHIPErrorStr.cpp", + "TestGroupedCallbackList.cpp", "TestOTAImageHeader.cpp", "TestOptional.cpp", "TestReferenceCounted.cpp", diff --git a/src/lib/core/tests/TestGroupedCallbackList.cpp b/src/lib/core/tests/TestGroupedCallbackList.cpp new file mode 100644 index 00000000000000..da09d162ce1d00 --- /dev/null +++ b/src/lib/core/tests/TestGroupedCallbackList.cpp @@ -0,0 +1,226 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include + +using namespace chip::Callback; + +// Expose Cancelable anchor for testing +template +struct TestGroupedCallbackList : public GroupedCallbackList +{ + Cancelable * Anchor() { return this; } +}; + +static void CallbackFn(void *) {} +static void CallbackWithIntFn(void *, int) {} + +typedef void (*CallWithIntFn)(void *, int); + +static void * StringContext(char const * string) +{ + return const_cast(static_cast(string)); +} + +static void ValidateList(Cancelable const * anchor) +{ +#if 0 // for manual debugging + { + ChipLogDetail(NotSpecified, "ANCHOR: %p", anchor); + Cancelable * ca = anchor->mNext; + while (ca != anchor) + { + auto * cb = Callback<>::FromCancelable(ca); + ChipLogDetail(NotSpecified, "%s%p (prev=%p, cancel=%p) %s", (ca->mPrev->mNext == ca) ? "> " : " ", ca, ca->mPrev, + ca->mCancel, static_cast(cb->mContext)); + ca = ca->mNext; + } + } +#endif + { + EXPECT_TRUE(anchor->mPrev->mNext == anchor); + EXPECT_TRUE(anchor->mNext->mPrev == anchor); + + std::unordered_map index; + index[anchor] = 0; + + size_t lastPrevIndex = 0; + Cancelable * ca = anchor->mNext; + for (size_t i = 1; ca != anchor; i++, ca = ca->mNext) + { + EXPECT_TRUE(index.find(ca) == index.end()); // cycle? + index[ca] = i; + + Cancelable * prev = ca->mPrev; + auto search = index.find(prev); + EXPECT_TRUE(search != index.end()); // prev should point backwards + EXPECT_GE(search->second, lastPrevIndex); // should be monotonic + lastPrevIndex = search->second; + } + } +} + +TEST(GroupedCallbackListTest, Trivial) +{ + TestGroupedCallbackList list; + Callback<> * out = nullptr; + EXPECT_TRUE(list.IsEmpty()); + EXPECT_FALSE(list.Peek(out)); + + Callback cbOne(CallbackFn, StringContext("cbOne")); + list.Enqueue(&cbOne); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_TRUE(list.Peek(out)); + EXPECT_TRUE(out == &cbOne); + + cbOne.Cancel(); + EXPECT_TRUE(list.IsEmpty()); +} + +TEST(GroupedCallbackListTest, EnqueueAllAndPeek) +{ + TestGroupedCallbackList list; + Callback cbOne(CallbackFn, StringContext("cbOne")); + Callback cbTwo(CallbackWithIntFn, StringContext("cbTwo")); + list.Enqueue(&cbOne, &cbTwo); + ValidateList(list.Anchor()); + Callback * outOne = nullptr; + Callback * outTwo = nullptr; + EXPECT_TRUE(list.Peek(outOne, outTwo)); + EXPECT_TRUE(outOne == &cbOne); + EXPECT_TRUE(outTwo == &cbTwo); +} + +TEST(GroupedCallbackListTest, EnqueueSparseAndPeek) +{ + TestGroupedCallbackList list; + Callback cbTwo(CallbackFn, StringContext("cbTwo")); + list.Enqueue(nullptr, &cbTwo); + ValidateList(list.Anchor()); + Callback<> * outOne = &cbTwo; // poison + Callback<> * outTwo = nullptr; + EXPECT_TRUE(list.Peek(outOne, outTwo)); + EXPECT_TRUE(outOne == nullptr); + EXPECT_TRUE(outTwo == &cbTwo); +} + +TEST(GroupedCallbackListTest, EnqueueAndClear) +{ + TestGroupedCallbackList list; + Callback cbOne(CallbackFn, StringContext("cbOne")); + Callback cbTwo(CallbackWithIntFn, StringContext("cbTwo")); + list.Enqueue(&cbOne, &cbTwo); + Callback cbThree(CallbackFn, StringContext("cbThree")); + list.Enqueue(&cbThree, nullptr); + ValidateList(list.Anchor()); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_TRUE(cbOne.IsRegistered()); + EXPECT_TRUE(cbTwo.IsRegistered()); + EXPECT_TRUE(cbThree.IsRegistered()); + + list.Clear(); + ValidateList(list.Anchor()); + EXPECT_TRUE(list.IsEmpty()); + EXPECT_FALSE(cbOne.IsRegistered()); + EXPECT_FALSE(cbTwo.IsRegistered()); + EXPECT_FALSE(cbThree.IsRegistered()); +} + +TEST(GroupedCallbackListTest, Complex) +{ + TestGroupedCallbackList list; + ValidateList(list.Anchor()); + EXPECT_TRUE(list.IsEmpty()); + + Callback cbZero(CallbackFn, StringContext("cbZero")); + list.Enqueue(&cbZero, nullptr); + ValidateList(list.Anchor()); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_TRUE(cbZero.IsRegistered()); + + Callback cbOne(CallbackFn, StringContext("cbOne")); + Callback cbTwo(CallbackFn, StringContext("cbTwo")); + list.Enqueue(&cbOne, &cbTwo); + ValidateList(list.Anchor()); + EXPECT_TRUE(cbOne.IsRegistered()); + EXPECT_TRUE(cbTwo.IsRegistered()); + + cbZero.Cancel(); + ValidateList(list.Anchor()); + EXPECT_FALSE(cbZero.IsRegistered()); + + Callback cbThree(CallbackFn, StringContext("cbThree")); + list.Enqueue(&cbThree, nullptr); + ValidateList(list.Anchor()); + + Callback cbFour(CallbackFn, StringContext("cbFour")); + list.Enqueue(nullptr, &cbFour); + ValidateList(list.Anchor()); + + cbOne.Cancel(); // also cancels cbTwo + ValidateList(list.Anchor()); + EXPECT_FALSE(cbOne.IsRegistered()); + EXPECT_FALSE(cbTwo.IsRegistered()); + + Callback<> * outA = &cbZero; + Callback<> * outB = &cbZero; + EXPECT_TRUE(list.Take(outA, outB)); + ValidateList(list.Anchor()); + EXPECT_TRUE(outA == &cbThree); + EXPECT_TRUE(outB == nullptr); + + EXPECT_TRUE(list.Take(outA, outB)); + ValidateList(list.Anchor()); + EXPECT_TRUE(outA == nullptr); + EXPECT_TRUE(outB == &cbFour); + + EXPECT_TRUE(list.IsEmpty()); +} + +TEST(GroupedCallbackListTest, EnqueueTakeAll) +{ + TestGroupedCallbackList listA; + Callback cbOne(CallbackFn, StringContext("cbOne")); + Callback cbTwo(CallbackFn, StringContext("cbTwo")); + listA.Enqueue(&cbOne, &cbTwo); + ValidateList(listA.Anchor()); + EXPECT_FALSE(listA.IsEmpty()); + EXPECT_TRUE(cbOne.IsRegistered()); + EXPECT_TRUE(cbTwo.IsRegistered()); + + TestGroupedCallbackList listB; + Callback cbThree(CallbackFn, StringContext("cbThree")); + listB.Enqueue(&cbThree, nullptr); + ValidateList(listB.Anchor()); + EXPECT_FALSE(listB.IsEmpty()); + EXPECT_TRUE(cbThree.IsRegistered()); + + listB.EnqueueTakeAll(listA); + ValidateList(listA.Anchor()); + ValidateList(listB.Anchor()); + EXPECT_TRUE(cbThree.IsRegistered()); + EXPECT_TRUE(cbOne.IsRegistered()); + EXPECT_TRUE(cbTwo.IsRegistered()); + EXPECT_FALSE(listB.IsEmpty()); + EXPECT_TRUE(listA.IsEmpty()); +}