Skip to content

Commit

Permalink
float64 support in FIL functions (#4655)
Browse files Browse the repository at this point in the history
Templatized functions related to FIL inference in preparation of `float64` support.

Instantiations of templates with `float64`, or tests for `float64`, _are not included_; they will be provided in a future pull request.

This is pull request 2 of 3 to integrate #4646. This pull request is partly based on the work by @levsnv.

Authors:
  - Andy Adinets (https://github.com/canonizer)
  - Levs Dolgovs (https://github.com/levsnv)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - William Hicks (https://github.com/wphicks)

URL: #4655
  • Loading branch information
canonizer authored Apr 2, 2022
1 parent 52767a9 commit 4ee12db
Show file tree
Hide file tree
Showing 11 changed files with 335 additions and 280 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/sg/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class FIL : public RegressionFixture<float> {
}

private:
ML::fil::forest_t forest;
ML::fil::forest_t<float> forest;
ModelHandle model;
Params p_rest;
};
Expand Down
19 changes: 12 additions & 7 deletions cpp/include/cuml/fil/fil.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ enum storage_type_t {
};
static const char* storage_type_repr[] = {"AUTO", "DENSE", "SPARSE", "SPARSE8"};

template <typename real_t>
struct forest;

/** forest_t is the predictor handle */
typedef forest* forest_t;
template <typename real_t>
using forest_t = forest<real_t>*;

/** MAX_N_ITEMS determines the maximum allowed value for tl_params::n_items */
constexpr int MAX_N_ITEMS = 4;
Expand Down Expand Up @@ -112,33 +114,36 @@ struct treelite_params_t {
* @param model treelite model used to initialize the forest
* @param tl_params additional parameters for the forest
*/
// TODO (canonizer): use std::variant<forest_t<float> forest_t<double>>* for pforest
void from_treelite(const raft::handle_t& handle,
forest_t* pforest,
forest_t<float>* pforest,
ModelHandle model,
const treelite_params_t* tl_params);

/** free deletes forest and all resources held by it; after this, forest is no longer usable
* @param h cuML handle used by this function
* @param f the forest to free; not usable after the call to this function
*/
void free(const raft::handle_t& h, forest_t f);
template <typename real_t>
void free(const raft::handle_t& h, forest_t<real_t> f);

/** predict predicts on data (with n rows) using forest and writes results into preds;
* the number of columns is stored in forest, and both preds and data point to GPU memory
* @param h cuML handle used by this function
* @param f forest used for predictions
* @param preds array in GPU memory to store predictions into
size == predict_proba ? (2*num_rows) : num_rows
* size = predict_proba ? (2*num_rows) : num_rows
* @param data array of size n * cols (cols is the number of columns
* for the forest f) from which to predict
* @param num_rows number of data rows
* @param predict_proba for classifier models, this forces to output both class probabilities
* instead of binary class prediction. format matches scikit-learn API
*/
template <typename real_t>
void predict(const raft::handle_t& h,
forest_t f,
float* preds,
const float* data,
forest_t<real_t> f,
real_t* preds,
const real_t* data,
size_t num_rows,
bool predict_proba = false);

Expand Down
37 changes: 21 additions & 16 deletions cpp/src/fil/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct storage_base {
/** represents a dense tree */
template <typename real_t>
struct tree<dense_node<real_t>> : tree_base {
using real_type = real_t;
__host__ __device__ tree(categorical_sets cat_sets, dense_node<real_t>* nodes, int node_pitch)
: tree_base{cat_sets}, nodes_(nodes), node_pitch_(node_pitch)
{
Expand All @@ -61,10 +62,10 @@ struct tree<dense_node<real_t>> : tree_base {
};

/** partial specialization of storage. Stores the forest on GPU as a collection of dense nodes */
template <typename real_t_>
struct storage<dense_node<real_t_>> : storage_base<real_t_> {
using real_t = real_t_;
using node_t = dense_node<real_t>;
template <typename real_t>
struct storage<dense_node<real_t>> : storage_base<real_t> {
using real_type = real_t;
using node_t = dense_node<real_t>;
__host__ __device__ storage(categorical_sets cat_sets,
real_t* vector_leaf,
node_t* nodes,
Expand Down Expand Up @@ -93,6 +94,7 @@ struct storage<dense_node<real_t_>> : storage_base<real_t_> {
/** sparse tree */
template <typename node_t>
struct tree : tree_base {
using real_type = typename node_t::real_type;
__host__ __device__ tree(categorical_sets cat_sets, node_t* nodes)
: tree_base{cat_sets}, nodes_(nodes)
{
Expand All @@ -103,15 +105,15 @@ struct tree : tree_base {

/** storage stores the forest on GPU as a collection of sparse nodes */
template <typename node_t_>
struct storage : storage_base<typename node_t_::real_t> {
using node_t = node_t_;
using real_t = typename node_t::real_t;
int* trees_ = nullptr;
node_t* nodes_ = nullptr;
int num_trees_ = 0;
__host__ __device__
storage(categorical_sets cat_sets, real_t* vector_leaf, int* trees, node_t* nodes, int num_trees)
: storage_base<real_t>{cat_sets, vector_leaf},
struct storage : storage_base<typename node_t_::real_type> {
using node_t = node_t_;
using real_type = typename node_t::real_type;
int* trees_ = nullptr;
node_t* nodes_ = nullptr;
int num_trees_ = 0;
__host__ __device__ storage(
categorical_sets cat_sets, real_type* vector_leaf, int* trees, node_t* nodes, int num_trees)
: storage_base<real_type>{cat_sets, vector_leaf},
trees_(trees),
nodes_(nodes),
num_trees_(num_trees)
Expand All @@ -125,8 +127,11 @@ struct storage : storage_base<typename node_t_::real_t> {
}
};

typedef storage<sparse_node16<float>> sparse_storage16;
typedef storage<sparse_node8> sparse_storage8;
using dense_storage_f32 = storage<dense_node<float>>;
using dense_storage_f64 = storage<dense_node<double>>;
using sparse_storage16_f32 = storage<sparse_node16<float>>;
using sparse_storage16_f64 = storage<sparse_node16<double>>;
using sparse_storage8 = storage<sparse_node8>;

/// all model parameters mostly required to compute shared memory footprint,
/// also the footprint itself
Expand Down Expand Up @@ -168,7 +173,7 @@ struct shmem_size_params {
{
return cols_in_shmem ? sizeof_real * sdata_stride() * n_items << log2_threads_per_tree : 0;
}
template <int NITEMS, leaf_algo_t leaf_algo>
template <int NITEMS, typename real_t, leaf_algo_t leaf_algo>
size_t get_smem_footprint();
};

Expand Down
Loading

0 comments on commit 4ee12db

Please sign in to comment.