Skip to content

Commit

Permalink
#385 Thread queue deadlock with multiple consumers
Browse files Browse the repository at this point in the history
  • Loading branch information
fpagliughi committed Apr 29, 2022
1 parent 37d7616 commit 67b1e5e
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 54 deletions.
105 changes: 51 additions & 54 deletions src/mqtt/thread_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
/////////////////////////////////////////////////////////////////////////////

/*******************************************************************************
* Copyright (c) 2017-2021 Frank Pagliughi <[email protected]>
* Copyright (c) 2017-2022 Frank Pagliughi <[email protected]>
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
Expand Down Expand Up @@ -39,6 +39,7 @@ namespace mqtt {

/**
* A thread-safe queue for inter-thread communication.
*
* This is a lockinq queue with blocking operations. The get() operations
* can always block on an empty queue, but have variations for non-blocking
* (try_get) and bounded-time blocking (try_get_for, try_get_until).
Expand Down Expand Up @@ -149,14 +150,11 @@ class thread_queue
*/
void put(value_type val) {
unique_guard g(lock_);
if (que_.size() >= cap_)
notFullCond_.wait(g, [this]{return que_.size() < cap_;});
bool wasEmpty = que_.empty();
notFullCond_.wait(g, [this]{return que_.size() < cap_;});

que_.emplace(std::move(val));
if (wasEmpty) {
g.unlock();
notEmptyCond_.notify_one();
}
g.unlock();
notEmptyCond_.notify_one();
}
/**
* Non-blocking attempt to place an item into the queue.
Expand All @@ -166,14 +164,12 @@ class thread_queue
*/
bool try_put(value_type val) {
unique_guard g(lock_);
size_type n = que_.size();
if (n >= cap_)
if (que_.size() >= cap_)
return false;

que_.emplace(std::move(val));
if (n == 0) {
g.unlock();
notEmptyCond_.notify_one();
}
g.unlock();
notEmptyCond_.notify_one();
return true;
}
/**
Expand All @@ -186,16 +182,14 @@ class thread_queue
* timeout occurred.
*/
template <typename Rep, class Period>
bool try_put_for(value_type* val, const std::chrono::duration<Rep, Period>& relTime) {
bool try_put_for(value_type val, const std::chrono::duration<Rep, Period>& relTime) {
unique_guard g(lock_);
if (que_.size() >= cap_ && !notFullCond_.wait_for(g, relTime, [this]{return que_.size() < cap_;}))
if (!notFullCond_.wait_for(g, relTime, [this]{return que_.size() < cap_;}))
return false;
bool wasEmpty = que_.empty();

que_.emplace(std::move(val));
if (wasEmpty) {
g.unlock();
notEmptyCond_.notify_one();
}
g.unlock();
notEmptyCond_.notify_one();
return true;
}
/**
Expand All @@ -209,16 +203,14 @@ class thread_queue
* timeout occurred.
*/
template <class Clock, class Duration>
bool try_put_until(value_type* val, const std::chrono::time_point<Clock,Duration>& absTime) {
bool try_put_until(value_type val, const std::chrono::time_point<Clock,Duration>& absTime) {
unique_guard g(lock_);
if (que_.size() >= cap_ && !notFullCond_.wait_until(g, absTime, [this]{return que_.size() < cap_;}))
if (!notFullCond_.wait_until(g, absTime, [this]{return que_.size() < cap_;}))
return false;
bool wasEmpty = que_.empty();

que_.emplace(std::move(val));
if (wasEmpty) {
g.unlock();
notEmptyCond_.notify_one();
}
g.unlock();
notEmptyCond_.notify_one();
return true;
}
/**
Expand All @@ -228,15 +220,16 @@ class thread_queue
* @param val Pointer to a variable to receive the value.
*/
void get(value_type* val) {
if (!val)
return;

unique_guard g(lock_);
if (que_.empty())
notEmptyCond_.wait(g, [this]{return !que_.empty();});
notEmptyCond_.wait(g, [this]{return !que_.empty();});

*val = std::move(que_.front());
que_.pop();
if (que_.size() == cap_-1) {
g.unlock();
notFullCond_.notify_one();
}
g.unlock();
notFullCond_.notify_one();
}
/**
* Retrieve a value from the queue.
Expand All @@ -246,14 +239,12 @@ class thread_queue
*/
value_type get() {
unique_guard g(lock_);
if (que_.empty())
notEmptyCond_.wait(g, [this]{return !que_.empty();});
notEmptyCond_.wait(g, [this]{return !que_.empty();});

value_type val = std::move(que_.front());
que_.pop();
if (que_.size() == cap_-1) {
g.unlock();
notFullCond_.notify_one();
}
g.unlock();
notFullCond_.notify_one();
return val;
}
/**
Expand All @@ -265,15 +256,17 @@ class thread_queue
* the queue is empty.
*/
bool try_get(value_type* val) {
if (!val)
return false;

unique_guard g(lock_);
if (que_.empty())
return false;

*val = std::move(que_.front());
que_.pop();
if (que_.size() == cap_-1) {
g.unlock();
notFullCond_.notify_one();
}
g.unlock();
notFullCond_.notify_one();
return true;
}
/**
Expand All @@ -288,15 +281,17 @@ class thread_queue
*/
template <typename Rep, class Period>
bool try_get_for(value_type* val, const std::chrono::duration<Rep, Period>& relTime) {
if (!val)
return false;

unique_guard g(lock_);
if (que_.empty() && !notEmptyCond_.wait_for(g, relTime, [this]{return !que_.empty();}))
if (!notEmptyCond_.wait_for(g, relTime, [this]{return !que_.empty();}))
return false;

*val = std::move(que_.front());
que_.pop();
if (que_.size() == cap_-1) {
g.unlock();
notFullCond_.notify_one();
}
g.unlock();
notFullCond_.notify_one();
return true;
}
/**
Expand All @@ -311,15 +306,17 @@ class thread_queue
*/
template <class Clock, class Duration>
bool try_get_until(value_type* val, const std::chrono::time_point<Clock,Duration>& absTime) {
if (!val)
return false;

unique_guard g(lock_);
if (que_.empty() && !notEmptyCond_.wait_until(g, absTime, [this]{return !que_.empty();}))
if (!notEmptyCond_.wait_until(g, absTime, [this]{return !que_.empty();}))
return false;

*val = std::move(que_.front());
que_.pop();
if (que_.size() == cap_-1) {
g.unlock();
notFullCond_.notify_one();
}
g.unlock();
notFullCond_.notify_one();
return true;
}
};
Expand Down
1 change: 1 addition & 0 deletions test/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ add_executable(unit_tests unit_tests.cpp
test_properties.cpp
test_response_options.cpp
test_string_collection.cpp
test_thread_queue.cpp
test_token.cpp
test_topic.cpp
test_topic_matcher.cpp
Expand Down
89 changes: 89 additions & 0 deletions test/unit/test_thread_queue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// test_thread_queue.cpp
//
// Unit tests for the thread_queue class in the Paho MQTT C++ library.
//

/*******************************************************************************
* Copyright (c) 2022 Frank Pagliughi <[email protected]>
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Frank Pagliughi - Initial implementation
*******************************************************************************/

#define UNIT_TESTS

#include "catch2/catch.hpp"
#include "mqtt/types.h"
#include "mqtt/thread_queue.h"

#include <thread>
#include <future>
#include <chrono>
#include <vector>

using namespace mqtt;
using namespace std::chrono;

TEST_CASE("que put/get", "[thread_queue]")
{
thread_queue<int> que;

que.put(1);
que.put(2);
REQUIRE(que.get() == 1);

que.put(3);
REQUIRE(que.get() == 2);
REQUIRE(que.get() == 3);
}

TEST_CASE("que mt put/get", "[thread_queue]")
{
thread_queue<string> que;
const size_t N = 1000000;
const size_t N_THR = 2;

auto producer = [&que]() {
string s;
for (size_t i=0; i<512; ++i)
s.push_back('a' + i%26);

for (size_t i=0; i<N; ++i)
que.put(s);
};

auto consumer = [&que]() {
string s;
bool ok = true;
for (size_t i=0; i<N && ok; ++i) {
ok = que.try_get_for(&s, seconds{1});
}
return ok;
};

std::vector<std::thread> producers;
std::vector<std::future<bool>> consumers;

for (size_t i=0; i<N_THR; ++i)
producers.push_back(std::thread(producer));

for (size_t i=0; i<N_THR; ++i)
consumers.push_back(std::async(consumer));

for (size_t i=0; i<N_THR; ++i)
producers[i].join();

for (size_t i=0; i<N_THR; ++i) {
REQUIRE(consumers[i].get());
}
}

0 comments on commit 67b1e5e

Please sign in to comment.