diff --git a/roofit/roofitcore/inc/RooFitDriver.h b/roofit/roofitcore/inc/RooFitDriver.h index 4df96a1beab2af..a40c2c4e499715 100644 --- a/roofit/roofitcore/inc/RooFitDriver.h +++ b/roofit/roofitcore/inc/RooFitDriver.h @@ -61,6 +61,8 @@ class RooFitDriver { const RooNLLVarNew& _topNode; const RooAbsData* const _data = nullptr; const size_t _nEvents; + + std::vector _nodes; std::unordered_map _nodeInfos; //used for preserving resources diff --git a/roofit/roofitcore/src/RooFitDriver.cxx b/roofit/roofitcore/src/RooFitDriver.cxx index cf66b9307077a7..688c5a7d7fb80c 100644 --- a/roofit/roofitcore/src/RooFitDriver.cxx +++ b/roofit/roofitcore/src/RooFitDriver.cxx @@ -64,6 +64,8 @@ RooFitDriver::RooFitDriver(const RooAbsData& data, const RooNLLVarNew& topNode, } else //this node needs evaluation, mark it's clients { + _nodes.push_back(pAbsReal); + // If the node doesn't depend on any observables, there is no need to // loop over events and we don't need to use the batched evaluation. RooArgSet observablesForNode; @@ -155,7 +157,8 @@ double RooFitDriver::getVal() for (const auto& it:_nodeInfos) if (it.second.remServers==0 && it.second.computeInGPU) assignToGPU(it.first); - + + int nNodes = _nodeInfos.size(); while (nNodes) { @@ -171,22 +174,22 @@ double RooFitDriver::getVal() } // find next cpu node - auto it=_nodeInfos.begin(); - for ( ; it!=_nodeInfos.end(); it++) - if (it->second.remServers==0 && !it->second.computeInGPU) break; + auto it = std::find_if(_nodes.begin(), _nodes.end(), [&](const RooAbsReal* a){ + auto const& info = _nodeInfos[a]; return info.remServers==0 && !info.computeInGPU; }); // if no cpu node available sleep for a while to save cpu usage - if (it==_nodeInfos.end()) + if (it==_nodes.end()) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); continue; } - + // compute next cpu node - const RooAbsReal* node = it->first; - NodeInfo& info = it->second; + const RooAbsReal* node = *it; + NodeInfo& info = _nodeInfos[*it]; info.remServers=-2; //so that it doesn't get picked again nNodes--; + if (info.computeInScalarMode) { _nonDerivedValues.push_back(node->getVal(_data->get()));