Skip to content

Commit

Permalink
use concurrent queue to buffer query mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
Funatiq committed Oct 28, 2019
1 parent 91abf44 commit 003fd93
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions src/classification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,46 +664,45 @@ struct query_mapping
// matches_per_location allhits;
};

using query_mappings = std::vector<std::pair<query_id, query_mapping>>;


/*************************************************************************//**
*
* @brief redo the classification of all reads using only targets in tgtMatches
*
*****************************************************************************/
void redo_classification_batched(
std::unordered_map<query_id, query_mapping>& queryMappings,
moodycamel::ConcurrentQueue<query_mappings>& queryMappingsQueue,
const matches_per_target& tgtMatches,
const database& db, const query_options& opt,
classification_results& results,
taxon_count_map& allTaxCounts)
{
//parallel
std::vector<std::future<void>> threads;
std::atomic<std::unordered_map<query_id, query_mapping>::iterator> globalMappingIterator;
globalMappingIterator.store(queryMappings.begin());
std::mutex mtx;

for(int threadId = 0; threadId < opt.process.numThreads; ++threadId) {
threads.emplace_back(std::async(std::launch::async, [&, threadId] {
auto localMappingIterator = globalMappingIterator.load();
while(localMappingIterator != queryMappings.end()) {
auto batchEnd = size_t(std::distance(localMappingIterator, queryMappings.end())) > opt.process.batchSize ?
std::next(localMappingIterator, opt.process.batchSize) : queryMappings.end();
bool success = globalMappingIterator.compare_exchange_weak(localMappingIterator, batchEnd);
if(success) {

query_mappings mappings;

while(queryMappingsQueue.size_approx()) {
if(queryMappingsQueue.try_dequeue(mappings)) {
std::ostringstream bufout;
taxon_count_map taxCounts;

for(auto it = localMappingIterator; it != batchEnd; ++it) {
for(auto& mapping : mappings) {
// classify using only targets left in tgtMatches
update_classification(db, opt.classify, it->second.cls, tgtMatches);
update_classification(db, opt.classify, mapping.second.cls, tgtMatches);

evaluate_classification(db, opt.evaluate, it->second.query, it->second.cls, results.statistics);
evaluate_classification(db, opt.evaluate, mapping.second.query, mapping.second.cls, results.statistics);

show_query_mapping(bufout, db, opt.output, it->second.query, it->second.cls, match_locations{});
show_query_mapping(bufout, db, opt.output, mapping.second.query, mapping.second.cls, match_locations{});

if(opt.output.makeTaxCounts && it->second.cls.best) {
++taxCounts[it->second.cls.best];
if(opt.output.makeTaxCounts && mapping.second.cls.best) {
++taxCounts[mapping.second.cls.best];
}
}
std::lock_guard<std::mutex> lock(mtx);
Expand Down Expand Up @@ -734,7 +733,7 @@ void redo_classification_batched(
struct mappings_buffer
{
std::ostringstream out;
std::unordered_map<query_id, query_mapping> queryMappings;
query_mappings queryMappings;
matches_per_target hitsPerTarget;
taxon_count_map taxCounts;
};
Expand All @@ -754,7 +753,7 @@ void map_queries_to_targets_default(
//global target -> query_id/win:hits... list
matches_per_target tgtMatches;

std::unordered_map<query_id, query_mapping> queryMappings;
moodycamel::ConcurrentQueue<query_mappings> queryMappingsQueue;

//global taxon -> read count
taxon_count_map allTaxCounts;
Expand Down Expand Up @@ -802,7 +801,7 @@ void map_queries_to_targets_default(
sequence().swap(query.seq2);

//save query mapping for post processing
buf.queryMappings.emplace(query.id, query_mapping{std::move(query), std::move(cls)});
buf.queryMappings.emplace_back(query.id, query_mapping{std::move(query), std::move(cls)});
}
};

Expand All @@ -814,8 +813,8 @@ void map_queries_to_targets_default(
}
if(opt.classify.covPercentile > 0) {
//move mappings to global map
queryMappings.insert(std::make_move_iterator(buf.queryMappings.begin()),
std::make_move_iterator(buf.queryMappings.end()) );
queryMappingsQueue.enqueue(buf.queryMappings);
buf.queryMappings.clear();
}
else {
if(opt.output.makeTaxCounts) {
Expand All @@ -842,7 +841,7 @@ void map_queries_to_targets_default(
if(opt.classify.covPercentile > 0) {
filter_targets_by_coverage(tgtMatches, opt.classify.covPercentile);

redo_classification_batched(queryMappings, tgtMatches, db, opt, results, allTaxCounts);
redo_classification_batched(queryMappingsQueue, tgtMatches, db, opt, results, allTaxCounts);
}

if(opt.output.showHitsPerTargetList) {
Expand Down

0 comments on commit 003fd93

Please sign in to comment.