Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify dense and sparse tests #4417

Merged
merged 34 commits into from
Dec 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
701956e
unified tree2fil
levsnv Nov 4, 2021
97dc049
unified init_dense, init_sparse
levsnv Nov 10, 2021
84bdc1e
drafted tl2fil as class for from_treelite
levsnv Nov 11, 2021
e1219bb
Merge branch 'branch-21.12' of github.com:rapidsai/cuml into unify-de…
levsnv Nov 11, 2021
c1b5d54
fixed a bug
levsnv Nov 11, 2021
6e8e463
stray changes
levsnv Nov 11, 2021
d6d0ece
Apply suggestions from code review
levsnv Nov 12, 2021
42c1c2e
apply suggestions from code review
levsnv Nov 16, 2021
c002e26
Merge branch 'branch-21.12' of github.com:rapidsai/cuml into unify-de…
levsnv Nov 16, 2021
9d304e6
Merge branch 'unify-dense-sparse-import' of github.com:levsnv/cuml in…
levsnv Nov 16, 2021
6967714
style
levsnv Nov 17, 2021
dfdfff1
made tree2fil a method of tl2fil_t, misc comments
levsnv Nov 17, 2021
58044af
tl2fil_t:: init_object(), init_forest()
levsnv Nov 17, 2021
8377aee
tracking tokens
levsnv Nov 17, 2021
e170290
addressed review comments
levsnv Nov 18, 2021
06bba79
addressed review comments
levsnv Nov 19, 2021
0ef919f
Merge branch 'branch-21.12' of github.com:rapidsai/cuml into unify-de…
levsnv Nov 19, 2021
d52b896
fixed enum->bool bug
levsnv Nov 19, 2021
b495abb
Merge branch 'branch-21.12' of github.com:rapidsai/cuml into unify-de…
levsnv Nov 20, 2021
de1668c
Merge branch 'branch-22.02' of github.com:rapidsai/cuml into unify-de…
levsnv Nov 23, 2021
70cd427
style
levsnv Nov 23, 2021
5f5ce4e
typo
levsnv Nov 24, 2021
eea9bab
style
levsnv Nov 24, 2021
6e12723
unified dense adn sparse tests; test cases are almost entirely disjoint
levsnv Dec 2, 2021
32a5103
stray comment
levsnv Dec 2, 2021
e610175
Merge branch 'branch-22.02' of github.com:rapidsai/cuml into unify-de…
levsnv Dec 3, 2021
e4913ad
Merge branch 'branch-22.02' of github.com:rapidsai/cuml into unify-tests
levsnv Dec 16, 2021
fe38a39
addressed review comments
levsnv Dec 17, 2021
7ac36f2
Merge branch 'unify-dense-sparse-import' into unify-tests
levsnv Dec 17, 2021
d9d888f
moved node_traits from common.cuh to internal.cuh to use in fil_tests.cu
levsnv Dec 17, 2021
f36304d
types
levsnv Dec 17, 2021
8de22be
ref -> val
levsnv Dec 17, 2021
d200341
Merge branch 'branch-22.02' into unify-tests
levsnv Dec 17, 2021
5679fe1
fix conflict resolution
levsnv Dec 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions cpp/src/fil/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@

#include "internal.cuh"

// needed for node_traits<...>
namespace treelite {
template <typename, typename>
struct ModelImpl;
}

namespace ML {
namespace fil {

Expand Down Expand Up @@ -120,33 +114,6 @@ struct sparse_storage : storage_base {
typedef sparse_storage<sparse_node16> sparse_storage16;
typedef sparse_storage<sparse_node8> sparse_storage8;

struct dense_forest;
template <typename node_t>
struct sparse_forest;

template <typename node_t>
struct node_traits {
using storage = sparse_storage<node_t>;
using forest = sparse_forest<node_t>;
static const bool IS_DENSE = false;
static const storage_type_t storage_type_enum =
std::is_same<sparse_node16, node_t>() ? SPARSE : SPARSE8;
template <typename threshold_t, typename leaf_t>
static void check(const treelite::ModelImpl<threshold_t, leaf_t>& model);
};

template <>
struct node_traits<dense_node> {
using storage = dense_storage;
using forest = dense_forest;
static const bool IS_DENSE = true;
static const storage_type_t storage_type_enum = DENSE;
template <typename threshold_t, typename leaf_t>
static void check(const treelite::ModelImpl<threshold_t, leaf_t>& model)
{
}
};

/// all model parameters mostly required to compute shared memory footprint,
/// also the footprint itself
struct shmem_size_params {
Expand Down
37 changes: 37 additions & 0 deletions cpp/src/fil/internal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ namespace raft {
class handle_t;
}

// needed for node_traits<...>
namespace treelite {
template <typename, typename>
struct ModelImpl;
}

namespace ML {
namespace fil {

Expand Down Expand Up @@ -215,6 +221,37 @@ struct alignas(8) sparse_node8 : base_node {
__host__ __device__ int left(int curr) const { return left_index(); }
};

struct dense_forest;
template <typename node_t>
struct sparse_forest;

struct dense_storage;
template <typename node_t>
struct sparse_storage;

template <typename node_t>
struct node_traits {
using storage = sparse_storage<node_t>;
using forest = sparse_forest<node_t>;
static const bool IS_DENSE = false;
static const storage_type_t storage_type_enum =
std::is_same<sparse_node16, node_t>() ? SPARSE : SPARSE8;
template <typename threshold_t, typename leaf_t>
static void check(const treelite::ModelImpl<threshold_t, leaf_t>& model);
};

template <>
struct node_traits<dense_node> {
using storage = dense_storage;
using forest = dense_forest;
static const bool IS_DENSE = true;
static const storage_type_t storage_type_enum = DENSE;
template <typename threshold_t, typename leaf_t>
static void check(const treelite::ModelImpl<threshold_t, leaf_t>& model)
{
}
};

/** leaf_algo_t describes what the leaves in a FIL forest store (predict)
and how FIL aggregates them into class margins/regression result/best class
**/
Expand Down
73 changes: 31 additions & 42 deletions cpp/test/sg/fil_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -646,31 +646,8 @@ class BaseFilTest : public testing::TestWithParam<FilTestParams> {
FilTestParams ps;
};

class PredictDenseFilTest : public BaseFilTest {
protected:
void init_forest(fil::forest_t* pforest) override
{
// init FIL model
fil::forest_params_t fil_ps;
fil_ps.depth = ps.depth;
fil_ps.num_trees = ps.num_trees;
fil_ps.num_cols = ps.num_cols;
fil_ps.algo = ps.algo;
fil_ps.output = ps.output;
fil_ps.threshold = ps.threshold;
fil_ps.global_bias = ps.global_bias;
fil_ps.leaf_algo = ps.leaf_algo;
fil_ps.num_classes = ps.num_classes;
fil_ps.blocks_per_sm = ps.blocks_per_sm;
fil_ps.threads_per_tree = ps.threads_per_tree;
fil_ps.n_items = ps.n_items;

fil::init(handle, pforest, cat_sets_h.accessor(), vector_leaf, nullptr, nodes.data(), &fil_ps);
}
};

template <typename fil_node_t>
class BasePredictSparseFilTest : public BaseFilTest {
class BasePredictFilTest : public BaseFilTest {
protected:
void dense2sparse_node(const fil::dense_node* dense_root,
int i_dense,
Expand Down Expand Up @@ -717,36 +694,48 @@ class BasePredictSparseFilTest : public BaseFilTest {

void init_forest(fil::forest_t* pforest) override
{
constexpr bool IS_DENSE = node_traits<fil_node_t>::IS_DENSE;
std::vector<fil_node_t> init_nodes;
if constexpr (!IS_DENSE) {
dense2sparse();
init_nodes = sparse_nodes;
} else {
init_nodes = nodes;
}
ASSERT(init_nodes.size() < std::size_t(INT_MAX), "generated too many nodes");

// init FIL model
fil::forest_params_t fil_params;
fil_params.num_trees = ps.num_trees;
fil_params.num_cols = ps.num_cols;
fil_params.algo = ps.algo;
fil_params.output = ps.output;
fil_params.threshold = ps.threshold;
fil_params.global_bias = ps.global_bias;
fil_params.leaf_algo = ps.leaf_algo;
fil_params.num_classes = ps.num_classes;
fil_params.blocks_per_sm = ps.blocks_per_sm;
fil_params.threads_per_tree = ps.threads_per_tree;
fil_params.n_items = ps.n_items;

dense2sparse();
fil_params.num_nodes = sparse_nodes.size();
fil::forest_params_t fil_params = {
.num_nodes = static_cast<int>(init_nodes.size()),
.depth = ps.depth,
.num_trees = ps.num_trees,
.num_cols = ps.num_cols,
.leaf_algo = ps.leaf_algo,
.algo = ps.algo,
.output = ps.output,
.threshold = ps.threshold,
.global_bias = ps.global_bias,
.num_classes = ps.num_classes,
.blocks_per_sm = ps.blocks_per_sm,
.threads_per_tree = ps.threads_per_tree,
.n_items = ps.n_items,
};

fil::init(handle,
pforest,
cat_sets_h.accessor(),
vector_leaf,
trees.data(),
sparse_nodes.data(),
init_nodes.data(),
&fil_params);
}
std::vector<fil_node_t> sparse_nodes;
std::vector<int> trees;
};

typedef BasePredictSparseFilTest<fil::sparse_node16> PredictSparse16FilTest;
typedef BasePredictSparseFilTest<fil::sparse_node8> PredictSparse8FilTest;
typedef BasePredictFilTest<fil::dense_node> PredictDenseFilTest;
typedef BasePredictFilTest<fil::sparse_node16> PredictSparse16FilTest;
typedef BasePredictFilTest<fil::sparse_node8> PredictSparse8FilTest;

class TreeliteFilTest : public BaseFilTest {
protected:
Expand Down