Use tree-parallel prediction in GTIL, if batch size is small
hcho3 committed Feb 17, 2022
1 parent e7fe6ee commit 7647887
Showing 1 changed file with 61 additions and 8 deletions.
src/gtil/predict.cc: 69 changes (61 additions, 8 deletions)
@@ -265,7 +265,6 @@ void PredValueByOneTree(const treelite::Tree<ThresholdType, LeafOutputType>& tre
}
}


template <typename OutputLogic, typename ThresholdType, typename LeafOutputType>
void PredictByAllTrees(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
    float* output, const std::size_t batch_offset,
@@ -290,7 +289,6 @@ void PredictByAllTrees(const treelite::ModelImpl<ThresholdType, LeafOutputType>&
}
}


template <std::size_t block_of_rows_size, typename OutputLogic, typename ThresholdType,
    typename LeafOutputType, typename DMatrixType>
void PredictBatchByBlockOfRowsKernel(
@@ -323,23 +321,78 @@ void PredictBatchByBlockOfRowsKernel(
});
}

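// Tree-parallel prediction kernel: rows are processed one at a time, and the trees of
// the ensemble are distributed across threads. Each thread accumulates partial sums in
// its own slice of a thread-local buffer, which is then reduced into the output row.
// This keeps all threads busy even when the batch contains only a few rows.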
template <typename OutputLogic, typename ThresholdType, typename LeafOutputType,
    typename DMatrixType>
void PredictBatchTreeParallelKernel(
    const treelite::ModelImpl<ThresholdType, LeafOutputType>& model, const DMatrixType* input,
    float* output, const ThreadConfig& thread_config) {
  const std::size_t num_row = input->GetNumRow();
  const std::size_t num_tree = model.GetNumTree();
  const int num_feature = model.num_feature;
  const auto num_class = model.task_param.num_class;

  FVec feats;
  feats.Init(num_feature);
  std::vector<float> sum_tloc(num_class * thread_config.nthread);
  auto sched = treelite::threading_utils::ParallelSchedule::Static();
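  // For each row: reset the per-thread accumulation buffer (one num_class-wide slice per
  // thread), gather the row's features into the dense buffer feats, and walk all trees
  // in parallel.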
  for (std::size_t row_id = 0; row_id < num_row; ++row_id) {
    std::fill(sum_tloc.begin(), sum_tloc.end(), 0.0f);
    feats.Fill(input, row_id);
    treelite::threading_utils::ParallelFor(std::size_t(0), num_tree, thread_config, sched,
        [&](std::size_t tree_id, int thread_id) {
          const treelite::Tree<ThresholdType, LeafOutputType>& tree = model.trees[tree_id];
          auto has_categorical = tree.HasCategoricalSplit();
          if (has_categorical) {
            PredValueByOneTree<true, OutputLogic>(tree, tree_id, feats,
                &sum_tloc[thread_id * num_class], num_class);
          } else {
            PredValueByOneTree<false, OutputLogic>(tree, tree_id, feats,
                &sum_tloc[thread_id * num_class], num_class);
          }
        });
    feats.Clear(input, row_id);
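    // Reduce the per-thread partial sums into the output slot for this row.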
    for (std::uint32_t thread_id = 0; thread_id < thread_config.nthread; ++thread_id) {
      for (unsigned i = 0; i < num_class; ++i) {
        output[row_id * num_class + i] += sum_tloc[thread_id * num_class + i];
      }
    }
  }
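  // When the model averages tree outputs (e.g. random forests), apply the averaging
  // factor once per row after all trees have been accumulated.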
  if (model.average_tree_output) {
    for (std::size_t row_id = 0; row_id < num_row; ++row_id) {
      OutputLogic::ApplyAverageFactor(model.task_param, num_tree, output, row_id, 1);
    }
  }
}

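// Choose the parallelization strategy from the batch size: a batch with fewer rows than
// kBlockOfRowsSize cannot keep every thread busy when parallelized over rows, so it is
// routed to the tree-parallel kernel instead.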
template <typename OutputLogic, typename ThresholdType, typename LeafOutputType,
    typename DMatrixType>
void PredictBatchDispatch(
    const treelite::ModelImpl<ThresholdType, LeafOutputType>& model, const DMatrixType* input,
    float* output, const ThreadConfig& thread_config) {
  if (input->GetNumRow() < kBlockOfRowsSize) {
    // Small batch size => tree parallel method
    PredictBatchTreeParallelKernel<OutputLogic>(model, input, output, thread_config);
  } else {
    // Sufficiently large batch size => row parallel method
    PredictBatchByBlockOfRowsKernel<kBlockOfRowsSize, OutputLogic>(
        model, input, output, thread_config);
  }
}

template <typename ThresholdType, typename LeafOutputType, typename DMatrixType>
inline void PredictRaw(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
    const DMatrixType* input, float* output, const ThreadConfig& thread_config) {
  InitOutPredictions(model, input, output);

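  // Each task type now routes through PredictBatchDispatch, which picks the tree-parallel
  // or row-parallel kernel based on the batch size.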
  switch (model.task_type) {
  case treelite::TaskType::kBinaryClfRegr:
-   PredictBatchByBlockOfRowsKernel<kBlockOfRowsSize, BinaryClfRegrOutputLogic>(
-       model, input, output, thread_config);
+   PredictBatchDispatch<BinaryClfRegrOutputLogic>(model, input, output, thread_config);
    break;
  case treelite::TaskType::kMultiClfGrovePerClass:
-   PredictBatchByBlockOfRowsKernel<kBlockOfRowsSize, MultiClfGrovePerClassOutputLogic>(
-       model, input, output, thread_config);
+   PredictBatchDispatch<MultiClfGrovePerClassOutputLogic>(model, input, output, thread_config);
    break;
  case treelite::TaskType::kMultiClfProbDistLeaf:
-   PredictBatchByBlockOfRowsKernel<kBlockOfRowsSize, MultiClfProbDistLeafOutputLogic>(
-       model, input, output, thread_config);
+   PredictBatchDispatch<MultiClfProbDistLeafOutputLogic>(model, input, output, thread_config);
    break;
  case treelite::TaskType::kMultiClfCategLeaf:
  default:
