NVTX Markers for RF and RF-backend (#3014)
* This PR adds NVTX markers to the major time-consuming function calls of the RF and DecisionTree regressors and classifiers.
* The markers span both the RandomForest and DecisionTree code bases.

Authors:
  - Venkat (@venkywonka)
  - John Zedlewski (@JohnZed)

Approvers:
  - Thejaswi. N. S (@teju85)
  - AJ Schmidt (@ajschmidt8)
  - John Zedlewski (@JohnZed)

URL: #3014
venkywonka authored Feb 25, 2021
1 parent 5e874e9 commit 8fa2b90
Showing 13 changed files with 164 additions and 15 deletions.
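
For context, the ML::PUSH_RANGE / ML::POP_RANGE calls added throughout this diff follow the pattern sketched below. This is a minimal sketch only, assuming the helpers in <common/nvtx.hpp> wrap the NVTX v2 C API (nvtxRangePushA / nvtxRangePop); cuML's actual header may differ, for example by compiling the markers out when NVTX support is disabled at build time.

// Sketch only, not part of this PR: a minimal version of the marker helpers
// assumed to live in <common/nvtx.hpp>, built on the NVTX v2 C API.
#include <nvToolsExt.h>

namespace ML {
// Open a named range on the calling thread; it shows up as a labeled
// interval on the profiler timeline.
inline void PUSH_RANGE(const char* name) { nvtxRangePushA(name); }
// Close the innermost open range; every PUSH_RANGE needs a matching POP_RANGE.
inline void POP_RANGE() { nvtxRangePop(); }
}  // namespace ML

The resulting named ranges can then be inspected in a timeline profiler such as Nsight Systems, e.g. via nsys profile --trace=nvtx,cuda ./app.
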
4 changes: 4 additions & 0 deletions cpp/src/decisiontree/batched-levelalgo/builder.cuh
@@ -21,6 +21,8 @@

#include "builder_base.cuh"

#include <common/nvtx.hpp>

namespace ML {
namespace DecisionTree {

@@ -51,6 +53,7 @@ void grow_tree(std::shared_ptr<MLCommon::deviceAllocator> d_allocator,
const DecisionTreeParams& params, cudaStream_t stream,
std::vector<SparseTreeNode<DataT, LabelT>>& sparsetree,
IdxT& num_leaves, IdxT& depth) {
ML::PUSH_RANGE("DecisionTree::grow_tree in batched-levelalgo @builder.cuh");
Builder<Traits> builder;
size_t d_wsize, h_wsize;
builder.workspaceSize(d_wsize, h_wsize, treeid, seed, params, data, labels,
@@ -68,6 +71,7 @@ void grow_tree(std::shared_ptr<MLCommon::deviceAllocator> d_allocator,
d_buff.release(stream);
h_buff.release(stream);
convertToSparse<Traits>(builder, h_nodes.data(), sparsetree);
ML::POP_RANGE();
}

/**
28 changes: 27 additions & 1 deletion cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
@@ -25,6 +25,8 @@
#include "node.cuh"
#include "split.cuh"

#include <common/nvtx.hpp>

namespace ML {
namespace DecisionTree {

@@ -139,6 +141,8 @@ struct Builder {
const DataT* data, const LabelT* labels, IdxT totalRows,
IdxT totalCols, IdxT sampledRows, IdxT sampledCols,
IdxT* rowids, IdxT nclasses, const DataT* quantiles) {
ML::PUSH_RANGE(
"Builder::workspaceSize @builder_base.cuh [batched-levelalgo]");
ASSERT(quantiles != nullptr,
"Currently quantiles need to be computed before this call!");
params = p;
@@ -200,6 +204,7 @@ struct Builder {
calculateAlignedBytes(sizeof(NodeT) * 2 * max_batch); // next_nodes
// all nodes in the tree
h_wsize = calculateAlignedBytes(sizeof(IdxT)); // h_n_nodes
ML::POP_RANGE();
}

/**
@@ -210,6 +215,8 @@
* @param[in] h_wspace pinned host buffer needed to store the learned nodes
*/
void assignWorkspace(char* d_wspace, char* h_wspace) {
ML::PUSH_RANGE(
"Builder::assignWorkspace @builder_base.cuh [batched-levelalgo]");
auto max_batch = params.max_batch_size;
auto n_col_blks = n_blks_for_cols;
// device
@@ -245,6 +252,7 @@ struct Builder {
next_nodes = reinterpret_cast<NodeT*>(d_wspace);
// host
h_n_nodes = reinterpret_cast<IdxT*>(h_wspace);
ML::POP_RANGE();
}

/**
@@ -258,6 +266,7 @@
*/
void train(std::vector<Node<DataT, LabelT, IdxT>>& h_nodes, IdxT& num_leaves,
IdxT& depth, cudaStream_t s) {
ML::PUSH_RANGE("Builder::train @builder_base.cuh [batched-levelalgo]");
init(h_nodes, s);
while (true) {
IdxT new_nodes = doSplit(h_nodes, s);
@@ -267,6 +276,7 @@
}
raft::update_host(&num_leaves, n_leaves, 1, s);
raft::update_host(&depth, n_depth, 1, s);
ML::POP_RANGE();
}

private:
@@ -322,6 +332,7 @@
*/
IdxT doSplit(std::vector<Node<DataT, LabelT, IdxT>>& h_nodes,
cudaStream_t s) {
ML::PUSH_RANGE("Builder::doSplit @bulder_base.cuh [batched-levelalgo]");
auto batchSize = node_end - node_start;
// start fresh on the number of *new* nodes created in this batch
CUDA_CHECK(cudaMemsetAsync(n_nodes, 0, sizeof(IdxT), s));
@@ -338,20 +349,23 @@
}
// create child nodes (or make the current ones leaf)
auto smemSize = Traits::nodeSplitSmemSize(*this);
ML::PUSH_RANGE("nodeSplitKernel @builder_base.cuh [batched-levelalgo]");
nodeSplitKernel<DataT, LabelT, IdxT, typename Traits::DevTraits,
Traits::TPB_SPLIT>
<<<batchSize, Traits::TPB_SPLIT, smemSize, s>>>(
params.max_depth, params.min_samples_leaf, params.min_samples_split,
params.max_leaves, params.min_impurity_decrease, input, curr_nodes,
next_nodes, n_nodes, splits, n_leaves, h_total_nodes, n_depth);
CUDA_CHECK(cudaGetLastError());
ML::POP_RANGE();
// copy the updated (due to leaf creation) and newly created child nodes
raft::update_host(h_n_nodes, n_nodes, 1, s);
CUDA_CHECK(cudaStreamSynchronize(s));
h_nodes.resize(h_nodes.size() + batchSize + *h_n_nodes);
raft::update_host(h_nodes.data() + node_start, curr_nodes, batchSize, s);
raft::update_host(h_nodes.data() + h_total_nodes, next_nodes, *h_n_nodes,
s);
ML::POP_RANGE();
return *h_n_nodes;
}
}; // end Builder
@@ -391,6 +405,8 @@ struct ClsTraits {
static void computeSplit(Builder<ClsTraits<DataT, LabelT, IdxT>>& b, IdxT col,
IdxT batchSize, CRITERION splitType,
cudaStream_t s) {
ML::PUSH_RANGE(
"Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
auto nbins = b.params.n_bins;
auto nclasses = b.input.nclasses;
auto binSize = nbins * 2 * nclasses;
@@ -403,12 +419,16 @@
smemSize += 2 * sizeof(DataT) + 1 * sizeof(int);

CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * b.nHistBins, s));
ML::PUSH_RANGE(
"computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]");
computeSplitClassificationKernel<DataT, LabelT, IdxT, TPB_DEFAULT>
<<<grid, TPB_DEFAULT, smemSize, s>>>(
b.hist, b.params.n_bins, b.params.max_depth, b.params.min_samples_split,
b.params.min_samples_leaf, b.params.min_impurity_decrease,
b.params.max_leaves, b.input, b.curr_nodes, col, b.done_count, b.mutex,
b.n_leaves, b.splits, splitType, b.treeid, b.seed);
ML::POP_RANGE(); //computeSplitClassificationKernel
ML::POP_RANGE(); //Builder::computeSplit
}

/**
@@ -460,8 +480,9 @@ struct RegTraits {
static void computeSplit(Builder<RegTraits<DataT, IdxT>>& b, IdxT col,
IdxT batchSize, CRITERION splitType,
cudaStream_t s) {
ML::PUSH_RANGE(
"Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
auto n_col_blks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col);

dim3 grid(b.n_blks_for_rows, n_col_blks, batchSize);
auto nbins = b.params.n_bins;
size_t smemSize = 7 * nbins * sizeof(DataT) + nbins * sizeof(int);
@@ -478,13 +499,18 @@
CUDA_CHECK(cudaMemsetAsync(b.pred2P, 0, sizeof(DataT) * b.nPredCounts, s));
CUDA_CHECK(
cudaMemsetAsync(b.pred_count, 0, sizeof(IdxT) * b.nPredCounts, s));

ML::PUSH_RANGE(
"computeSplitRegressionKernel @builder_base.cuh [batched-levelalgo]");
computeSplitRegressionKernel<DataT, DataT, IdxT, TPB_DEFAULT>
<<<grid, TPB_DEFAULT, smemSize, s>>>(
b.pred, b.pred2, b.pred2P, b.pred_count, b.params.n_bins,
b.params.max_depth, b.params.min_samples_split,
b.params.min_samples_leaf, b.params.min_impurity_decrease,
b.params.max_leaves, b.input, b.curr_nodes, col, b.done_count, b.mutex,
b.n_leaves, b.splits, b.block_sync, splitType, b.treeid, b.seed);
ML::POP_RANGE(); //computeSplitRegressionKernel
ML::POP_RANGE(); //Builder::computeSplit
}

/**
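
Two notes on the ranges above. First, the computeSplit functions push an outer function-level range and an inner kernel-level range; NVTX ranges form a per-thread stack, so the pops must occur in reverse order, as they do here. Second, kernel launches are asynchronous, so a host-side range closed right after a launch measures only launch overhead, not GPU execution time. A sketch of an alternative (not what this PR does) that makes a range span the kernel's runtime, at the cost of an extra synchronization:

// Sketch only, not part of this PR: synchronizing before the pop makes the
// range cover GPU execution instead of just the asynchronous launch.
ML::PUSH_RANGE("nodeSplitKernel @builder_base.cuh [batched-levelalgo]");
nodeSplitKernel<DataT, LabelT, IdxT, typename Traits::DevTraits,
                Traits::TPB_SPLIT>
  <<<batchSize, Traits::TPB_SPLIT, smemSize, s>>>(/* same args as above */);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaStreamSynchronize(s));  // wait so the range includes GPU time
ML::POP_RANGE();

In practice the launch-side ranges are usually sufficient, since Nsight Systems records the CUDA kernel trace alongside NVTX and correlates the two.
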
13 changes: 12 additions & 1 deletion cpp/src/decisiontree/decisiontree_impl.cuh
@@ -33,6 +33,8 @@
#include "quantile/quantile.h"
#include "treelite_util.h"

#include <common/nvtx.hpp>

namespace ML {

bool is_dev_ptr(const void *p) {
@@ -262,6 +264,7 @@ void DecisionTreeBase<T, L>::plant(
const int nrows, const L *labels, unsigned int *rowids,
const int n_sampled_rows, int unique_labels, const int treeid,
uint64_t seed) {
ML::PUSH_RANGE("DecisionTreeBase::plant @decisiontree_impl.cuh");
dinfo.NLocalrows = nrows;
dinfo.NGlobalrows = nrows;
dinfo.Ncols = ncols;
@@ -274,7 +277,7 @@
}
CUDA_CHECK(cudaStreamSynchronize(
tempmem->stream)); // added to ensure accurate measurement

ML::PUSH_RANGE("DecisionTreeBase::plant::bootstrapping features");
//Bootstrap features
unsigned int *h_colids = tempmem->h_colids->data();
if (tree_params.bootstrap_features) {
@@ -285,6 +288,7 @@
} else {
std::iota(h_colids, h_colids + dinfo.Ncols, 0);
}
ML::POP_RANGE();
prepare_time = prepare_fit_timer.getElapsedSeconds();

total_temp_mem = tempmem->totalmem;
@@ -304,6 +308,7 @@
treeid, tempmem);
}
train_time = timer.getElapsedSeconds();
ML::POP_RANGE();
}

template <typename T, typename L>
@@ -484,6 +489,8 @@ void DecisionTreeClassifier<T>::grow_deep_tree(
const int n_sampled_rows, const int ncols, const float colper,
const int nrows, std::vector<SparseTreeNode<T, int>> &sparsetree,
const int treeid, std::shared_ptr<TemporaryMemory<T, int>> tempmem) {
ML::PUSH_RANGE(
"DecisionTreeClassifier::grow_deep_tree @decisiontree_impl.cuh");
int leaf_cnt = 0;
int depth_cnt = 0;
grow_deep_tree_classification(data, labels, rowids, ncols, colper,
@@ -492,6 +499,7 @@
sparsetree, treeid, tempmem);
this->depth_counter = depth_cnt;
this->leaf_counter = leaf_cnt;
ML::POP_RANGE();
}

template <typename T>
@@ -500,13 +508,16 @@ void DecisionTreeRegressor<T>::grow_deep_tree(
const int n_sampled_rows, const int ncols, const float colper,
const int nrows, std::vector<SparseTreeNode<T, T>> &sparsetree,
const int treeid, std::shared_ptr<TemporaryMemory<T, T>> tempmem) {
ML::PUSH_RANGE(
"DecisionTreeRegressor::grow_deep_tree @decisiontree_impl.cuh");
int leaf_cnt = 0;
int depth_cnt = 0;
grow_deep_tree_regression(data, labels, rowids, ncols, colper, n_sampled_rows,
nrows, this->tree_params, depth_cnt, leaf_cnt,
sparsetree, treeid, tempmem);
this->depth_counter = depth_cnt;
this->leaf_counter = leaf_cnt;
ML::POP_RANGE();
}

//Class specializations
9 changes: 8 additions & 1 deletion cpp/src/decisiontree/levelalgo/common_helper.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,6 +23,8 @@
#include <stats/minmax.cuh>
#include "common_kernel.cuh"

#include <common/nvtx.hpp>

namespace ML {
namespace DecisionTree {

@@ -38,6 +40,8 @@ void update_feature_sampling(unsigned int *h_colids, unsigned int *d_colids,
std::vector<unsigned int> &feature_selector,
std::shared_ptr<TemporaryMemory<T, L>> tempmem,
raft::random::Rng &d_rng) {
ML::PUSH_RANGE(
"update_feature_sampling @common_helper.cuh (does feature subsampling)");
if (h_colstart != nullptr) {
if (Ncols != ncols_sampled) {
std::shuffle(h_colids, h_colids + Ncols, rng);
@@ -63,6 +67,7 @@
raft::update_device(d_colids, h_colids, ncols_sampled * n_nodes,
tempmem->stream);
}
ML::POP_RANGE();
}

//This function calculates min/max from the samples that belong to a given node. This is done for all the nodes at a given level
@@ -104,6 +109,7 @@
void setup_sampling(unsigned int *flagsptr, unsigned int *sample_cnt,
const unsigned int *rowids, const int nrows,
const int n_sampled_rows, cudaStream_t &stream) {
ML::PUSH_RANGE("DecisionTree::setup_sampling @common_helper.cuh");
CUDA_CHECK(cudaMemsetAsync(sample_cnt, 0, nrows * sizeof(int), stream));
int threads = 256;
int blocks = raft::ceildiv(n_sampled_rows, threads);
@@ -114,6 +120,7 @@
setup_flags_kernel<<<blocks, threads, 0, stream>>>(sample_cnt, flagsptr,
nrows);
CUDA_CHECK(cudaGetLastError());
ML::POP_RANGE(); //setup_sampling @common_helper.cuh
}

//This function calls the split kernel
19 changes: 17 additions & 2 deletions cpp/src/decisiontree/levelalgo/levelfunc_classifier.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,6 +25,8 @@
#include "levelhelper_classifier.cuh"
#include "metric.cuh"

#include <common/nvtx.hpp>

namespace ML {
namespace DecisionTree {

@@ -45,6 +47,8 @@ void grow_deep_tree_classification(
const ML::DecisionTree::DecisionTreeParams& tree_params, int& depth_cnt,
int& leaf_cnt, std::vector<SparseTreeNode<T, int>>& sparsetree,
const int treeid, std::shared_ptr<TemporaryMemory<T, int>> tempmem) {
ML::PUSH_RANGE(
"DecisionTree::grow_deep_tree_classification @levelfunc_classifier.cuh");
const int ncols_sampled = (int)(colper * Ncols);
unsigned int* flagsptr = tempmem->d_flags->data();
unsigned int* sample_cnt = tempmem->d_sample_cnt->data();
@@ -111,6 +115,7 @@ void grow_deep_tree_classification(

int scatter_algo_depth =
std::min(tempmem->swap_depth, tree_params.max_depth + 1);
ML::PUSH_RANGE("scatter phase @levelfunc_classifier");
for (int depth = 0; (depth < scatter_algo_depth) && (n_nodes_nextitr != 0);
depth++) {
depth_cnt = depth;
@@ -166,12 +171,19 @@
2 * n_nodes * n_unique_labels * sizeof(unsigned int));
}
}
ML::POP_RANGE(); //scatter phase @levelfunc_classifier.cuh

ML::PUSH_RANGE("gather phase @levelfunc_classifier.cuh");
// Start of gather algorithm
//Convertor
CUML_LOG_DEBUG("begin gather ");
int lastsize = sparsetree.size() - sparsesize_nextitr;
n_nodes = n_nodes_nextitr;
if (n_nodes == 0) return;
if (n_nodes == 0) {
ML::POP_RANGE(); //gather phase ended
ML::POP_RANGE(); //grow_deep_tree_classification end
return;
}
unsigned int *d_nodecount, *d_samplelist, *d_nodestart;
SparseTreeNode<T, int>* d_sparsenodes;
SparseTreeNode<T, int>* h_sparsenodes;
@@ -250,6 +262,9 @@
sparsetree.insert(sparsetree.end(), h_sparsenodes,
h_sparsenodes + lastsize);
}

ML::POP_RANGE(); //gather phase @levelfunc_classifier.cuh
ML::POP_RANGE(); //grow_deep_tree_classification @levelfunc_classifier.cuh
}

} // namespace DecisionTree
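
A final observation on the control flow above: because POP_RANGE must run on every exit path, the early return in grow_deep_tree_classification had to grow an explicit block that pops both open ranges before returning. A hypothetical RAII guard (not part of this PR, and the name below is invented) would tie the pop to scope exit instead:

// Sketch only, hypothetical helper, not part of this PR.
#include <common/nvtx.hpp>

namespace ML {
struct NvtxRangeGuard {
  explicit NvtxRangeGuard(const char* name) { PUSH_RANGE(name); }
  ~NvtxRangeGuard() { POP_RANGE(); }  // runs on return, including early ones
  NvtxRangeGuard(const NvtxRangeGuard&) = delete;
  NvtxRangeGuard& operator=(const NvtxRangeGuard&) = delete;
};
}  // namespace ML

// Usage sketch: with a guard, the n_nodes == 0 early return above would need
// no manual pops.
// ML::NvtxRangeGuard range("grow_deep_tree_classification");
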