Skip to content

Commit

Permalink
Audit ForEachActiveObject thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed Sep 24, 2021
1 parent 221296a commit 479af54
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/channel/Manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ChannelHandle ChannelManager::EstablishChannel(const ChannelBuilder & builder, C
ChannelContext * channelContext = nullptr;

// Find an existing Channel matching the builder
mChannelContexts.ForEachActiveObject([&](ChannelContext * context) {
mChannelContexts.ForEachActiveObjectMutableUnsafe([&](ChannelContext * context) {
if (context->MatchesBuilder(builder))
{
channelContext = context;
Expand Down
6 changes: 3 additions & 3 deletions src/channel/Manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class DLL_EXPORT ChannelManager : public ExchangeMgrDelegate
template <typename Event>
void NotifyChannelEvent(ChannelContext * channel, Event event)
{
mChannelHandles.ForEachActiveObject([&](ChannelContextHandleAssociation * association) {
mChannelHandles.ForEachActiveObjectMutableUnsafe([&](ChannelContextHandleAssociation * association) {
if (association->mChannelContext == channel)
event(association->mChannelDelegate);
return true;
Expand All @@ -60,7 +60,7 @@ class DLL_EXPORT ChannelManager : public ExchangeMgrDelegate

void OnNewConnection(SessionHandle session, ExchangeManager * mgr) override
{
mChannelContexts.ForEachActiveObject([&](ChannelContext * context) {
mChannelContexts.ForEachActiveObjectMutableUnsafe([&](ChannelContext * context) {
if (context->MatchesSession(session, mgr->GetSessionManager()))
{
context->OnNewConnection(session);
Expand All @@ -72,7 +72,7 @@ class DLL_EXPORT ChannelManager : public ExchangeMgrDelegate

void OnConnectionExpired(SessionHandle session, ExchangeManager * mgr) override
{
mChannelContexts.ForEachActiveObject([&](ChannelContext * context) {
mChannelContexts.ForEachActiveObjectMutableUnsafe([&](ChannelContext * context) {
if (context->MatchesSession(session, mgr->GetSessionManager()))
{
context->OnConnectionExpired(session);
Expand Down
12 changes: 6 additions & 6 deletions src/inet/InetLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ CHIP_ERROR InetLayer::Shutdown()
{
#if INET_CONFIG_ENABLE_DNS_RESOLVER
// Cancel all DNS resolution requests owned by this instance.
DNSResolver::sPool.ForEachActiveObject([&](DNSResolver * lResolver) {
DNSResolver::sPool.ForEachActiveObjectMutableUnsafe([&](DNSResolver * lResolver) {
if ((lResolver != nullptr) && lResolver->IsCreatedByInetLayer(*this))
{
lResolver->Cancel();
Expand All @@ -330,7 +330,7 @@ CHIP_ERROR InetLayer::Shutdown()

#if INET_CONFIG_ENABLE_TCP_ENDPOINT
// Abort all TCP endpoints owned by this instance.
TCPEndPoint::sPool.ForEachActiveObject([&](TCPEndPoint * lEndPoint) {
TCPEndPoint::sPool.ForEachActiveObjectMutableUnsafe([&](TCPEndPoint * lEndPoint) {
if ((lEndPoint != nullptr) && lEndPoint->IsCreatedByInetLayer(*this))
{
lEndPoint->Abort();
Expand All @@ -341,7 +341,7 @@ CHIP_ERROR InetLayer::Shutdown()

#if INET_CONFIG_ENABLE_UDP_ENDPOINT
// Close all UDP endpoints owned by this instance.
UDPEndPoint::sPool.ForEachActiveObject([&](UDPEndPoint * lEndPoint) {
UDPEndPoint::sPool.ForEachActiveObjectMutableUnsafe([&](UDPEndPoint * lEndPoint) {
if ((lEndPoint != nullptr) && lEndPoint->IsCreatedByInetLayer(*this))
{
lEndPoint->Close();
Expand Down Expand Up @@ -390,7 +390,7 @@ bool InetLayer::IsIdleTimerRunning()
bool timerRunning = false;

// See if there are any TCP connections with the idle timer check in use.
TCPEndPoint::sPool.ForEachActiveObject([&](TCPEndPoint * lEndPoint) {
TCPEndPoint::sPool.ForEachActiveObjectImmutable([&](TCPEndPoint * lEndPoint) {
if ((lEndPoint != nullptr) && (lEndPoint->mIdleTimeout != 0))
{
timerRunning = true;
Expand Down Expand Up @@ -812,7 +812,7 @@ void InetLayer::CancelResolveHostAddress(DNSResolveCompleteFunct onComplete, voi
if (State != kState_Initialized)
return;

DNSResolver::sPool.ForEachActiveObject([&](DNSResolver * lResolver) {
DNSResolver::sPool.ForEachActiveObjectMutableUnsafe([&](DNSResolver * lResolver) {
if (!lResolver->IsCreatedByInetLayer(*this))
{
return true;
Expand Down Expand Up @@ -916,7 +916,7 @@ void InetLayer::HandleTCPInactivityTimer(chip::System::Layer * aSystemLayer, voi
InetLayer & lInetLayer = *reinterpret_cast<InetLayer *>(aAppState);
bool lTimerRequired = lInetLayer.IsIdleTimerRunning();

TCPEndPoint::sPool.ForEachActiveObject([&](TCPEndPoint * lEndPoint) {
TCPEndPoint::sPool.ForEachActiveObjectMutableUnsafe([&](TCPEndPoint * lEndPoint) {
if (!lEndPoint->IsCreatedByInetLayer(lInetLayer))
return true;
if (!lEndPoint->IsConnected())
Expand Down
33 changes: 16 additions & 17 deletions src/messaging/ExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ CHIP_ERROR ExchangeManager::Shutdown()
{
mReliableMessageMgr.Shutdown();

mContextPool.ForEachActiveObject([](auto * ec) {
mContextPool.ForEachActiveObjectImmutable([](auto * ec) {
// There should be no active object in the pool
VerifyOrDie(false);
return true;
Expand Down Expand Up @@ -213,28 +213,27 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const
msgFlags.Set(MessageFlagValues::kDuplicateMessage);
}

// Search for an existing exchange that the message applies to. If a match is found...
bool found = false;
mContextPool.ForEachActiveObject([&](auto * ec) {
ExchangeContext * exchange = nullptr;
mContextPool.ForEachActiveObjectImmutable([&](auto * ec) {
if (ec->MatchExchange(session, packetHeader, payloadHeader))
{
// Found a matching exchange. Set flag for correct subsequent MRP
// retransmission timeout selection.
if (!ec->HasRcvdMsgFromPeer())
{
ec->SetMsgRcvdFromPeer(true);
}

// Matched ExchangeContext; send to message handler.
ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, source, msgFlags, std::move(msgBuf));
found = true;
exchange = ec;
return false;
}
return true;
});

if (found)
if (exchange != nullptr)
{
// Found a matching exchange. Set flag for correct subsequent MRP
// retransmission timeout selection.
if (!exchange->HasRcvdMsgFromPeer())
{
exchange->SetMsgRcvdFromPeer(true);
}

// Matched ExchangeContext; send to message handler.
exchange->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, source, msgFlags, std::move(msgBuf));
return;
}

Expand Down Expand Up @@ -324,7 +323,7 @@ void ExchangeManager::OnConnectionExpired(SessionHandle session)
mDelegate->OnConnectionExpired(session, this);
}

mContextPool.ForEachActiveObject([&](auto * ec) {
mContextPool.ForEachActiveObjectMutableUnsafe([&](auto * ec) {
if (ec->mSecureSession.HasValue() && ec->mSecureSession.Value() == session)
{
ec->OnConnectionExpired();
Expand All @@ -337,7 +336,7 @@ void ExchangeManager::OnConnectionExpired(SessionHandle session)

void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegate * delegate)
{
mContextPool.ForEachActiveObject([&](auto * ec) {
mContextPool.ForEachActiveObjectMutableUnsafe([&](auto * ec) {
if (ec->GetDelegate() == delegate)
{
// Make sure to null out the delegate before closing the context, so
Expand Down
2 changes: 1 addition & 1 deletion src/messaging/ReliableMessageMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ class ReliableMessageMgr
template <typename Function>
void ExecuteForAllContext(Function function)
{
mContextPool.ForEachActiveObject([&](auto * ec) {
mContextPool.ForEachActiveObjectMutableUnsafe([&](auto * ec) {
function(ec->GetReliableMessageContext());
return true;
});
Expand Down
34 changes: 25 additions & 9 deletions src/system/SystemPoolHeap.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,38 @@ class ObjectPoolHeap : public ObjectPoolStatistics

/**
* @brief
* Run a functor for each active object in the pool
* Run a functor for each active object in the pool. The object delivered to callback function can't be removed
* during the iterator, or else it will trigger a dead lock.
*
* @param function The functor of type `bool (*)(T*)`, return false to break the iteration
* @return bool Returns false if broke during iteration
*/
template <typename Function>
bool ForEachActiveObject(Function && function)
bool ForEachActiveObjectImmutable(Function && function)
{
// Create a new copy of original set, allowing add/remove elements while iterating in the same thread.
for (auto object : CopyObjectSet())
std::lock_guard<std::mutex> lock(mutex);
for (auto object : mObjects)
{
if (!function(object))
return false;
}
return true;
}

/**
* @brief
* Run a functor for each active object in the pool. This function is not thread-safe, the caller must ensure that
* extra synchronization model is used to prevent racing problems.
*
* @param function The functor of type `bool (*)(T*)`, return false to break the iteration
* @return bool Returns false if broke during iteration
*/
template <typename Function>
bool ForEachActiveObjectMutableUnsafe(Function && function)
{
// Create a new copy of original set, allowing add/remove elements while iterating in the same thread.
for (auto object : std::set<T*>(mObjects))
{
if (!function(object))
return false;
Expand All @@ -96,12 +118,6 @@ class ObjectPoolHeap : public ObjectPoolStatistics
private:
std::mutex mutex;
std::set<T *> mObjects;

std::set<T *> CopyObjectSet()
{
std::lock_guard<std::mutex> lock(mutex);
return mObjects;
}
};

} // namespace System
Expand Down
20 changes: 15 additions & 5 deletions src/system/SystemPoolNonHeap.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,27 @@ class ObjectPoolNonHeap : public StaticAllocatorBitmap

/**
* @brief
* Run a functor for each active object in the pool
* Run a functor for each active object in the pool.
*
* @param function The functor of type `bool (*)(T*)`, return false to break the iteration
* @return bool Returns false if broke during iteration
*/
template <typename Function>
bool ForEachActiveObjectImmutable(Function && function)
{
LambdaProxy<Function> proxy(std::forward<Function>(function));
return ForEachActiveObjectInner(&proxy, &LambdaProxy<Function>::Call);
}

/**
* @brief
* Run a functor for each active object in the pool.
*
* caution
* this function is not thread-safe, make sure all usage of the
* pool is protected by a lock, or else avoid using this function
* @param function The functor of type `bool (*)(T*)`, return false to break the iteration
* @return bool Returns false if broke during iteration
*/
template <typename Function>
bool ForEachActiveObject(Function && function)
bool ForEachActiveObjectMutableUnsafe(Function && function)
{
LambdaProxy<Function> proxy(std::forward<Function>(function));
return ForEachActiveObjectInner(&proxy, &LambdaProxy<Function>::Call);
Expand Down
26 changes: 13 additions & 13 deletions src/system/tests/TestSystemObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ void TestObject::CheckIteration(nlTestSuite * inSuite, void * aContext)
unsigned int i;

// Pool should be empty before tests.
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
NL_TEST_ASSERT(lContext.mTestSuite, false);
return true;
});
Expand All @@ -157,15 +157,15 @@ void TestObject::CheckIteration(nlTestSuite * inSuite, void * aContext)
}

i = 0;
sPool.ForEachActiveObject([&](TestObject * lCreated) {
sPool.ForEachActiveObjectMutableUnsafe([&](TestObject * lCreated) {
NL_TEST_ASSERT(lContext.mTestSuite, lCreated->GetReferenceCount() > 0);
i++;
return true;
});
NL_TEST_ASSERT(lContext.mTestSuite, i == kPoolSize);

i = 0;
sPool.ForEachActiveObject([&](TestObject * lCreated) {
sPool.ForEachActiveObjectMutableUnsafe([&](TestObject * lCreated) {
i++;
if (i == kPoolSize / 2)
return false;
Expand All @@ -175,7 +175,7 @@ void TestObject::CheckIteration(nlTestSuite * inSuite, void * aContext)
NL_TEST_ASSERT(lContext.mTestSuite, i == kPoolSize / 2);

// Clear the pool
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
object->Release();
return true;
});
Expand All @@ -189,7 +189,7 @@ void TestObject::CheckRetention(nlTestSuite * inSuite, void * aContext)
unsigned int i;

// Pool should be empty before tests.
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
NL_TEST_ASSERT(lContext.mTestSuite, false);
return true;
});
Expand All @@ -204,23 +204,23 @@ void TestObject::CheckRetention(nlTestSuite * inSuite, void * aContext)
}

i = 0;
TestObject::sPool.ForEachActiveObject([&](TestObject * lGotten) {
TestObject::sPool.ForEachActiveObjectMutableUnsafe([&](TestObject * lGotten) {
lGotten->Retain();
i++;
return true;
});
NL_TEST_ASSERT(lContext.mTestSuite, i == kPoolSize);

i = 0;
TestObject::sPool.ForEachActiveObject([&](TestObject * lGotten) {
TestObject::sPool.ForEachActiveObjectMutableUnsafe([&](TestObject * lGotten) {
lGotten->Release();
i++;
return true;
});
NL_TEST_ASSERT(lContext.mTestSuite, i == kPoolSize);

// Clear the pool
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
object->Release();
return true;
});
Expand Down Expand Up @@ -296,7 +296,7 @@ void TestObject::MultithreadedTest(nlTestSuite * inSuite, void * aContext, void
pthread_t lThread[kNumThreads];

// Pool should be empty before tests.
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
NL_TEST_ASSERT(lContext.mTestSuite, false);
return true;
});
Expand Down Expand Up @@ -324,7 +324,7 @@ void TestObject::CheckConcurrency(nlTestSuite * inSuite, void * aContext)
#endif // CHIP_SYSTEM_CONFIG_POSIX_LOCKING

// Clear the pool
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
object->Release();
return true;
});
Expand All @@ -335,7 +335,7 @@ void TestObject::CheckHighWatermark(nlTestSuite * inSuite, void * aContext)
TestContext & lContext = *static_cast<TestContext *>(aContext);

// Pool should be empty before tests.
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
NL_TEST_ASSERT(lContext.mTestSuite, false);
return true;
});
Expand Down Expand Up @@ -375,7 +375,7 @@ void TestObject::CheckHighWatermark(nlTestSuite * inSuite, void * aContext)
// change.
lObject->Release();
// Verify that lObject is not in the pool
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
NL_TEST_ASSERT(lContext.mTestSuite, lObject != object);
return true;
});
Expand All @@ -385,7 +385,7 @@ void TestObject::CheckHighWatermark(nlTestSuite * inSuite, void * aContext)
NL_TEST_ASSERT(lContext.mTestSuite, lHighWatermark == kNumObjects);

// Clear the pool
sPool.ForEachActiveObject([&](auto object) {
sPool.ForEachActiveObjectMutableUnsafe([&](auto object) {
object->Release();
return true;
});
Expand Down
2 changes: 1 addition & 1 deletion src/system/tests/TestSystemPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ template <class T>
size_t GetNumObjectsInUse(T & pool)
{
size_t count = 0;
pool.ForEachActiveObject([&count](void *) {
pool.ForEachActiveObjectMutableUnsafe([&count](void *) {
++count;
return true;
});
Expand Down
4 changes: 2 additions & 2 deletions src/transport/UnauthenticatedSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class UnauthenticatedSessionTable
UnauthenticatedSession * FindEntry(const PeerAddress & address)
{
UnauthenticatedSession * result = nullptr;
mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) {
mEntries.ForEachActiveObjectImmutable([&](UnauthenticatedSession * entry) {
if (MatchPeerAddress(entry->GetPeerAddress(), address))
{
result = entry;
Expand Down Expand Up @@ -160,7 +160,7 @@ class UnauthenticatedSessionTable
UnauthenticatedSession * result = nullptr;
uint64_t oldestTimeMs = std::numeric_limits<uint64_t>::max();

mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) {
mEntries.ForEachActiveObjectImmutable([&](UnauthenticatedSession * entry) {
if (entry->GetReferenceCount() == 0 && entry->GetLastActivityTimeMs() < oldestTimeMs)
{
result = entry;
Expand Down

0 comments on commit 479af54

Please sign in to comment.