diff --git a/src/gtil/predict.cc b/src/gtil/predict.cc index ae094311..725efc4c 100644 --- a/src/gtil/predict.cc +++ b/src/gtil/predict.cc @@ -265,7 +265,6 @@ void PredValueByOneTree(const treelite::Tree& tre } } - template void PredictByAllTrees(const treelite::ModelImpl& model, float* output, const std::size_t batch_offset, @@ -290,7 +289,6 @@ void PredictByAllTrees(const treelite::ModelImpl& } } - template void PredictBatchByBlockOfRowsKernel( @@ -323,6 +321,64 @@ void PredictBatchByBlockOfRowsKernel( }); } +template +void PredictBatchTreeParallelKernel( + const treelite::ModelImpl& model, const DMatrixType* input, + float* output, const ThreadConfig& thread_config) { + const std::size_t num_row = input->GetNumRow(); + const std::size_t num_tree = model.GetNumTree(); + const int num_feature = model.num_feature; + const auto num_class = model.task_param.num_class; + + FVec feats; + feats.Init(num_feature); + std::vector sum_tloc(num_class * thread_config.nthread); + auto sched = treelite::threading_utils::ParallelSchedule::Static(); + for (std::size_t row_id = 0; row_id < num_row; ++row_id) { + std::fill(sum_tloc.begin(), sum_tloc.end(), 0.0f); + feats.Fill(input, row_id); + treelite::threading_utils::ParallelFor(std::size_t(0), num_tree, thread_config, sched, + [&](std::size_t tree_id, int thread_id) { + const treelite::Tree& tree = model.trees[tree_id]; + auto has_categorical = tree.HasCategoricalSplit(); + if (has_categorical) { + PredValueByOneTree(tree, tree_id, feats, + &sum_tloc[thread_id * num_class], num_class); + } else { + PredValueByOneTree(tree, tree_id, feats, + &sum_tloc[thread_id * num_class], num_class); + } + }); + feats.Clear(input, row_id); + for (std::uint32_t thread_id = 0; thread_id < thread_config.nthread; ++thread_id) { + for (unsigned i = 0; i < num_class; ++i) { + output[row_id * num_class + i] += sum_tloc[thread_id * num_class + i]; + } + } + } + if (model.average_tree_output) { + for (std::size_t row_id = 0; row_id < num_row; ++row_id) { + OutputLogic::ApplyAverageFactor(model.task_param, num_tree, output, row_id, 1); + } + } +} + +template +void PredictBatchDispatch( + const treelite::ModelImpl& model, const DMatrixType* input, + float* output, const ThreadConfig& thread_config) { + if (input->GetNumRow() < kBlockOfRowsSize) { + // Small batch size => tree parallel method + PredictBatchTreeParallelKernel(model, input, output, thread_config); + } else { + // Sufficiently large batch size => row parallel method + PredictBatchByBlockOfRowsKernel( + model, input, output, thread_config); + } +} + template inline void PredictRaw(const treelite::ModelImpl& model, const DMatrixType* input, float* output, const ThreadConfig& thread_config) { @@ -330,16 +386,13 @@ inline void PredictRaw(const treelite::ModelImpl& switch (model.task_type) { case treelite::TaskType::kBinaryClfRegr: - PredictBatchByBlockOfRowsKernel( - model, input, output, thread_config); + PredictBatchDispatch(model, input, output, thread_config); break; case treelite::TaskType::kMultiClfGrovePerClass: - PredictBatchByBlockOfRowsKernel( - model, input, output, thread_config); + PredictBatchDispatch(model, input, output, thread_config); break; case treelite::TaskType::kMultiClfProbDistLeaf: - PredictBatchByBlockOfRowsKernel( - model, input, output, thread_config); + PredictBatchDispatch(model, input, output, thread_config); break; case treelite::TaskType::kMultiClfCategLeaf: default: