Use tree-parallel prediction in GTIL, if batch size is small #367

Merged 1 commit on Feb 17, 2022
69 changes: 61 additions & 8 deletions src/gtil/predict.cc
@@ -265,7 +265,6 @@ void PredValueByOneTree(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
  }
}

template <typename OutputLogic, typename ThresholdType, typename LeafOutputType>
void PredictByAllTrees(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
                       float* output, const std::size_t batch_offset,
@@ -290,7 +289,6 @@ void PredictByAllTrees(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
  }
}

template <std::size_t block_of_rows_size, typename OutputLogic, typename ThresholdType,
          typename LeafOutputType, typename DMatrixType>
void PredictBatchByBlockOfRowsKernel(
@@ -323,23 +321,78 @@ void PredictBatchByBlockOfRowsKernel(
  });
}

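// New in this PR: a tree-parallel prediction kernel. For each row, the trees are
// partitioned among the worker threads, each thread accumulates its partial sums
// into its own slice of sum_tloc, and the slices are then reduced into the output
// row. This keeps all threads busy even when the batch has fewer rows than threads.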
template <typename OutputLogic, typename ThresholdType, typename LeafOutputType,
          typename DMatrixType>
void PredictBatchTreeParallelKernel(
    const treelite::ModelImpl<ThresholdType, LeafOutputType>& model, const DMatrixType* input,
    float* output, const ThreadConfig& thread_config) {
  const std::size_t num_row = input->GetNumRow();
  const std::size_t num_tree = model.GetNumTree();
  const int num_feature = model.num_feature;
  const auto num_class = model.task_param.num_class;

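  // feats is a dense feature buffer reused across rows; sum_tloc holds one
  // num_class-wide partial-sum slot per thread, so threads never write to the
  // same accumulator.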
  FVec feats;
  feats.Init(num_feature);
  std::vector<float> sum_tloc(num_class * thread_config.nthread);
  auto sched = treelite::threading_utils::ParallelSchedule::Static();
  for (std::size_t row_id = 0; row_id < num_row; ++row_id) {
    std::fill(sum_tloc.begin(), sum_tloc.end(), 0.0f);
    feats.Fill(input, row_id);
    treelite::threading_utils::ParallelFor(std::size_t(0), num_tree, thread_config, sched,
        [&](std::size_t tree_id, int thread_id) {
          const treelite::Tree<ThresholdType, LeafOutputType>& tree = model.trees[tree_id];
          auto has_categorical = tree.HasCategoricalSplit();
          if (has_categorical) {
            PredValueByOneTree<true, OutputLogic>(tree, tree_id, feats,
                &sum_tloc[thread_id * num_class], num_class);
          } else {
            PredValueByOneTree<false, OutputLogic>(tree, tree_id, feats,
                &sum_tloc[thread_id * num_class], num_class);
          }
        });
    feats.Clear(input, row_id);
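    // Reduce the per-thread partial sums into the final output for this row.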
    for (std::uint32_t thread_id = 0; thread_id < thread_config.nthread; ++thread_id) {
      for (unsigned i = 0; i < num_class; ++i) {
        output[row_id * num_class + i] += sum_tloc[thread_id * num_class + i];
      }
    }
  }
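  // If the model averages tree outputs, apply the averaging factor once per row.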
  if (model.average_tree_output) {
    for (std::size_t row_id = 0; row_id < num_row; ++row_id) {
      OutputLogic::ApplyAverageFactor(model.task_param, num_tree, output, row_id, 1);
    }
  }
}

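// Also new in this PR: dispatch between the two kernels, using kBlockOfRowsSize as
// the cutoff for what counts as a small batch.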
template <typename OutputLogic, typename ThresholdType, typename LeafOutputType,
          typename DMatrixType>
void PredictBatchDispatch(
    const treelite::ModelImpl<ThresholdType, LeafOutputType>& model, const DMatrixType* input,
    float* output, const ThreadConfig& thread_config) {
  if (input->GetNumRow() < kBlockOfRowsSize) {
    // Small batch size => tree parallel method
    PredictBatchTreeParallelKernel<OutputLogic>(model, input, output, thread_config);
  } else {
    // Sufficiently large batch size => row parallel method
    PredictBatchByBlockOfRowsKernel<kBlockOfRowsSize, OutputLogic>(
        model, input, output, thread_config);
  }
}

template <typename ThresholdType, typename LeafOutputType, typename DMatrixType>
inline void PredictRaw(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
                       const DMatrixType* input, float* output, const ThreadConfig& thread_config) {
  InitOutPredictions(model, input, output);

  switch (model.task_type) {
    case treelite::TaskType::kBinaryClfRegr:
      PredictBatchDispatch<BinaryClfRegrOutputLogic>(model, input, output, thread_config);
      break;
    case treelite::TaskType::kMultiClfGrovePerClass:
      PredictBatchDispatch<MultiClfGrovePerClassOutputLogic>(model, input, output, thread_config);
      break;
    case treelite::TaskType::kMultiClfProbDistLeaf:
      PredictBatchDispatch<MultiClfProbDistLeafOutputLogic>(model, input, output, thread_config);
      break;
    case treelite::TaskType::kMultiClfCategLeaf:
    default: