Skip to content

Commit

Permalink
float64 support in multi-sum and child_index() (rapidsai#4648)
Browse files Browse the repository at this point in the history
- removed default `T = float` from multi-sum
- tests with float64 for multi-sum and `child_index()`
- refactored multi-sum tests to reduce shared memory usage
  - this is needed for the tests with float64 to compile

This is pull request 1 of 3 to integrate rapidsai#4646. This pull request is partly based on the work by @levsnv.

Authors:
  - Andy Adinets (https://github.com/canonizer)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - William Hicks (https://github.com/wphicks)

URL: rapidsai#4648
  • Loading branch information
canonizer authored Mar 31, 2022
1 parent 9916c29 commit 29057cb
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 76 deletions.
4 changes: 2 additions & 2 deletions cpp/include/cuml/fil/multi_sum.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,7 +34,7 @@
@data[] values are ordered such that the stride is 1 for values belonging
to the same group and @n_groups for values that are to be added together
*/
template <int R = 5, typename T = float>
template <int R = 5, typename T>
__device__ T multi_sum(T* data, int n_groups, int n_values)
{
T acc = threadIdx.x < n_groups * n_values ? data[threadIdx.x] : T();
Expand Down
8 changes: 5 additions & 3 deletions cpp/src/fil/internal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ struct categorical_sets {

// set count is due to tree_idx + node_within_tree_idx are both ints, hence uint32_t result
template <typename node_t>
__host__ __device__ __forceinline__ int category_matches(node_t node, float category) const
__host__ __device__ __forceinline__ int category_matches(node_t node,
typename node_t::real_t category) const
{
// standard boolean packing. This layout has better ILP
// node.set() is global across feature IDs and is an offset (as opposed
Expand All @@ -408,7 +409,8 @@ struct categorical_sets {
FIL will reject a model where an integer within [0, fid_num_cats] cannot be represented
precisely as a 32-bit float.
*/
return category < fid_num_cats[node.fid()] && category >= 0.0f &&
using real_t = typename node_t::real_t;
return category < static_cast<real_t>(fid_num_cats[node.fid()]) && category >= real_t(0) &&
fetch_bit(bits + node.set(), static_cast<uint32_t>(static_cast<int>(category)));
}
static int sizeof_mask_from_num_cats(int num_cats)
Expand All @@ -429,7 +431,7 @@ struct tree_base {
template <bool CATS_SUPPORTED, typename node_t>
__host__ __device__ __forceinline__ int child_index(const node_t& node,
int node_idx,
float val) const
typename node_t::real_t val) const
{
bool cond;

Expand Down
131 changes: 76 additions & 55 deletions cpp/test/sg/fil_child_index_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,33 @@ struct proto_inner_node {
bool is_categorical = false; // see base_node::is_categorical
int fid = 0; // feature id, see base_node::fid
int set = 0; // which bit set represents the matching category list
float thresh = 0.0f; // threshold, see base_node::thresh
double thresh = 0.0; // threshold, see base_node::thresh
int left = 1; // left child idx, see sparse_node*::left_index()
val_t<float> split()
template <typename real_t>
val_t<real_t> split()
{
val_t<float> split;
val_t<real_t> split;
if (is_categorical)
split.idx = set;
else if (std::isnan(thresh))
split.f = std::numeric_limits<real_t>::quiet_NaN();
else
split.f = thresh;
split.f = static_cast<real_t>(thresh);
return split;
}
operator sparse_node16<float>()
template <typename real_t>
operator dense_node<real_t>()
{
return sparse_node16<float>({}, split(), fid, def_left, false, is_categorical, left);
return dense_node<real_t>({}, split<real_t>(), fid, def_left, false, is_categorical);
}
operator sparse_node8()
template <typename real_t>
operator sparse_node16<real_t>()
{
return sparse_node8({}, split(), fid, def_left, false, is_categorical, left);
return sparse_node16<real_t>({}, split<real_t>(), fid, def_left, false, is_categorical, left);
}
operator dense_node<float>()
operator sparse_node8()
{
return dense_node<float>({}, split(), fid, def_left, false, is_categorical);
return sparse_node8({}, split<float>(), fid, def_left, false, is_categorical, left);
}
};

Expand Down Expand Up @@ -108,8 +113,9 @@ struct ChildIndexTestParams {
proto_inner_node node;
int parent_node_idx = 0;
cat_sets_owner cso;
float input = 0.0f;
int correct = INT_MAX;
double input = 0.0;
int correct = INT_MAX;
bool skip_f32 = false; // if true, the test only runs for float64
};

std::ostream& operator<<(std::ostream& os, const ChildIndexTestParams& ps)
Expand All @@ -136,29 +142,36 @@ std::ostream& operator<<(std::ostream& os, const ChildIndexTestParams& ps)

template <typename fil_node_t>
class ChildIndexTest : public testing::TestWithParam<ChildIndexTestParams> {
using real_t = typename fil_node_t::real_t;

protected:
void check()
{
ChildIndexTestParams param = GetParam();

// skip tests that require float64 to work correctly
if (std::is_same_v<real_t, float> && param.skip_f32) return;

tree_base tree{param.cso.accessor()};
if (!std::is_same<fil_node_t, fil::dense_node<float>>::value) {
if constexpr (!std::is_same_v<fil_node_t, fil::dense_node<real_t>>) {
// test that the logic uses node.left instead of parent_node_idx
param.node.left = param.parent_node_idx * 2 + 1;
param.parent_node_idx = INT_MIN;
}
real_t input = isnan(param.input) ? std::numeric_limits<real_t>::quiet_NaN()
: static_cast<real_t>(param.input);
// nan -> !def_left, categorical -> if matches, numerical -> input >= threshold
int test_idx =
tree.child_index<true>((fil_node_t)param.node, param.parent_node_idx, param.input);
ASSERT(test_idx == param.correct,
"child index test: actual %d != correct %d",
test_idx,
param.correct);
int test_idx = tree.child_index<true>((fil_node_t)param.node, param.parent_node_idx, input);
// On mismatch, report both the computed and the expected child index.
// The message is streamed into the gtest assertion, so values are inserted with
// operator<< and no printf-style format specifier may appear in the text.
ASSERT_EQ(test_idx, param.correct)
  << "child index test: actual " << test_idx << " != correct " << param.correct;
}
};

typedef ChildIndexTest<fil::dense_node<float>> ChildIndexTestDense;
typedef ChildIndexTest<fil::sparse_node16<float>> ChildIndexTestSparse16;
typedef ChildIndexTest<fil::sparse_node8> ChildIndexTestSparse8;
using ChildIndexTestDenseFloat32 = ChildIndexTest<fil::dense_node<float>>;
using ChildIndexTestDenseFloat64 = ChildIndexTest<fil::dense_node<double>>;
using ChildIndexTestSparse16Float32 = ChildIndexTest<fil::sparse_node16<float>>;
using ChildIndexTestSparse16Float64 = ChildIndexTest<fil::sparse_node16<double>>;
using ChildIndexTestSparse8 = ChildIndexTest<fil::sparse_node8>;

/* for dense nodes, left (false) == parent * 2 + 1, right (true) == parent * 2 + 2
E.g. see tree below:
Expand All @@ -168,48 +181,52 @@ typedef ChildIndexTest<fil::sparse_node8> ChildIndexTestSparse8;
3 -> 7, 8
4 -> 9, 10
*/
const float INF = std::numeric_limits<float>::infinity();
const double INF = std::numeric_limits<double>::infinity();
const double QNAN = std::numeric_limits<double>::quiet_NaN();

std::vector<ChildIndexTestParams> params = {
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0f), input = -INF, correct = 1), // val !>= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0f), input = 0.0f, correct = 2), // val >= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0f), input = +INF, correct = 2), // val >= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0), input = -INF, correct = 1), // val !>= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0), input = 0.0, correct = 2), // val >= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0), input = +INF, correct = 2), // val >= thresh
// the following two tests only work for float64
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0), input = -1e-50, correct = 1, skip_f32 = true),
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 1e-50), input = 0.0, correct = 1, skip_f32 = true),
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 1.0f), input = -3.141592f, correct = 1), // val !>= thresh
CHILD_INDEX_TEST_PARAMS( // val >= thresh (e**pi > pi**e)
node = NODE(thresh = 22.459158f),
input = 23.140693f,
node = NODE(thresh = 1.0), input = -3.141592, correct = 1), // val !>= thresh
CHILD_INDEX_TEST_PARAMS( // val >= thresh (e**pi > pi**e)
node = NODE(thresh = 22.459158),
input = 23.140693,
correct = 2),
CHILD_INDEX_TEST_PARAMS( // val >= thresh for both negative
node = NODE(thresh = -0.37f),
input = -0.36f,
correct = 2), // val >= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = -INF), input = 0.36f, correct = 2), // val >= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0f), input = NAN, correct = 2), // !def_left
CHILD_INDEX_TEST_PARAMS(node = NODE(def_left = true), input = NAN, correct = 1), // !def_left
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = NAN), input = NAN, correct = 2), // !def_left
node = NODE(thresh = -0.37),
input = -0.36,
correct = 2), // val >= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = -INF), input = 0.36, correct = 2), // val >= thresh
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = 0.0f), input = QNAN, correct = 2), // !def_left
CHILD_INDEX_TEST_PARAMS(node = NODE(def_left = true), input = QNAN, correct = 1), // !def_left
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = QNAN), input = QNAN, correct = 2), // !def_left
CHILD_INDEX_TEST_PARAMS(
node = NODE(def_left = true, thresh = NAN), input = NAN, correct = 1), // !def_left
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = NAN), input = 0.0f, correct = 1), // val !>= thresh
node = NODE(def_left = true, thresh = QNAN), input = QNAN, correct = 1), // !def_left
CHILD_INDEX_TEST_PARAMS(node = NODE(thresh = QNAN), input = 0.0, correct = 1), // val !>= thresh
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 0.0f), parent_node_idx = 1, input = -INF, correct = 3),
node = NODE(thresh = 0.0), parent_node_idx = 1, input = -INF, correct = 3),
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 0.0f), parent_node_idx = 1, input = 0.0f, correct = 4),
node = NODE(thresh = 0.0), parent_node_idx = 1, input = 0.0f, correct = 4),
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 0.0f), parent_node_idx = 2, input = -INF, correct = 5),
node = NODE(thresh = 0.0), parent_node_idx = 2, input = -INF, correct = 5),
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 0.0f), parent_node_idx = 2, input = 0.0f, correct = 6),
node = NODE(thresh = 0.0), parent_node_idx = 2, input = 0.0f, correct = 6),
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 0.0f), parent_node_idx = 3, input = -INF, correct = 7),
node = NODE(thresh = 0.0), parent_node_idx = 3, input = -INF, correct = 7),
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 0.0f), parent_node_idx = 3, input = 0.0f, correct = 8),
node = NODE(thresh = 0.0), parent_node_idx = 3, input = 0.0f, correct = 8),
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 0.0f), parent_node_idx = 4, input = -INF, correct = 9),
node = NODE(thresh = 0.0), parent_node_idx = 4, input = -INF, correct = 9),
CHILD_INDEX_TEST_PARAMS(
node = NODE(thresh = 0.0f), parent_node_idx = 4, input = 0.0f, correct = 10),
CHILD_INDEX_TEST_PARAMS(parent_node_idx = 4, input = NAN, correct = 10), // !def_left
node = NODE(thresh = 0.0), parent_node_idx = 4, input = 0.0, correct = 10),
CHILD_INDEX_TEST_PARAMS(parent_node_idx = 4, input = QNAN, correct = 10), // !def_left
CHILD_INDEX_TEST_PARAMS(
node = NODE(def_left = true), input = NAN, parent_node_idx = 4, correct = 9), // !def_left
node = NODE(def_left = true), input = QNAN, parent_node_idx = 4, correct = 9), // !def_left
// cannot match ( < 0 and realistic fid_num_cats)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {},
Expand Down Expand Up @@ -282,21 +299,25 @@ std::vector<ChildIndexTestParams> params = {
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true, def_left = true),
cso.bits = {0b0000'0101},
cso.fid_num_cats = {3.0f},
input = NAN,
input = QNAN,
correct = 1),
// default right
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true, def_left = false),
cso.bits = {0b0000'0101},
cso.fid_num_cats = {3.0f},
input = NAN,
input = QNAN,
correct = 2),
};

TEST_P(ChildIndexTestDense, Predict) { check(); }
TEST_P(ChildIndexTestSparse16, Predict) { check(); }
TEST_P(ChildIndexTestDenseFloat32, Predict) { check(); }
TEST_P(ChildIndexTestDenseFloat64, Predict) { check(); }
TEST_P(ChildIndexTestSparse16Float32, Predict) { check(); }
TEST_P(ChildIndexTestSparse16Float64, Predict) { check(); }
TEST_P(ChildIndexTestSparse8, Predict) { check(); }

INSTANTIATE_TEST_CASE_P(FilTests, ChildIndexTestDense, testing::ValuesIn(params));
INSTANTIATE_TEST_CASE_P(FilTests, ChildIndexTestSparse16, testing::ValuesIn(params));
INSTANTIATE_TEST_CASE_P(FilTests, ChildIndexTestDenseFloat32, testing::ValuesIn(params));
INSTANTIATE_TEST_CASE_P(FilTests, ChildIndexTestDenseFloat64, testing::ValuesIn(params));
INSTANTIATE_TEST_CASE_P(FilTests, ChildIndexTestSparse16Float32, testing::ValuesIn(params));
INSTANTIATE_TEST_CASE_P(FilTests, ChildIndexTestSparse16Float64, testing::ValuesIn(params));
INSTANTIATE_TEST_CASE_P(FilTests, ChildIndexTestSparse8, testing::ValuesIn(params));
} // namespace ML
53 changes: 37 additions & 16 deletions cpp/test/sg/multi_sum_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,36 +44,45 @@ __device__ void serial_multi_sum(const T* in, T* out, int n_groups, int n_values
}

// the most threads a block can have
const int max_threads = 1024;
const int MAX_THREADS = 1024;

// Parameters of a single multi-sum test case.
// test_multi_sum_k reads one entry per thread block (params[blockIdx.x]) and
// dispatches on `radix` to the matching test_single_radix<R> instantiation;
// only radix values 2 through 6 have a case in that switch.
struct MultiSumTestParams {
int radix; // number of elements summed to 1 at each stage of the sum
int n_groups; // number of independent sums
int n_values; // number of elements to add in each sum
};

// Scratch buffers for the multi-sum test kernel, bundled into one struct so that
// test_multi_sum_k can declare a single __shared__ instance and hand it by
// reference to whichever test_single_radix<R> instantiation it dispatches to.
// Each array has MAX_THREADS elements and is indexed by threadIdx.x.
template <typename T>
struct multi_sum_test_shmem {
// per-thread input values; passed as the data array to serial_multi_sum and multi_sum
T work[MAX_THREADS];
// reference sums written by serial_multi_sum, compared against the multi_sum result
T correct_result[MAX_THREADS];
};

template <int R, typename T>
__device__ void test_single_radix(T thread_value, MultiSumTestParams p, int* block_error_flag)
__device__ void test_single_radix(multi_sum_test_shmem<T>& s,
T thread_value,
MultiSumTestParams p,
int* block_error_flag)
{
__shared__ T work[max_threads], correct_result[max_threads];
work[threadIdx.x] = thread_value;
serial_multi_sum(work, correct_result, p.n_groups, p.n_values);
T sum = multi_sum<R>(work, p.n_groups, p.n_values);
if (threadIdx.x < p.n_groups && 1e-4 < fabsf(sum - correct_result[threadIdx.x])) {
s.work[threadIdx.x] = thread_value;
serial_multi_sum(s.work, s.correct_result, p.n_groups, p.n_values);
T sum = multi_sum<R>(s.work, p.n_groups, p.n_values);
if (threadIdx.x < p.n_groups && 1e-4 < fabsf(sum - s.correct_result[threadIdx.x])) {
atomicAdd(block_error_flag, 1);
}
}

template <typename T>
__global__ void test_multi_sum_k(T* data, MultiSumTestParams* params, int* error_flags)
{
__shared__ multi_sum_test_shmem<T> s;
MultiSumTestParams p = params[blockIdx.x];
switch (p.radix) {
case 2: test_single_radix<2>(data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 3: test_single_radix<3>(data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 4: test_single_radix<4>(data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 5: test_single_radix<5>(data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 6: test_single_radix<6>(data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 2: test_single_radix<2>(s, data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 3: test_single_radix<3>(s, data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 4: test_single_radix<4>(s, data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 5: test_single_radix<5>(s, data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
case 6: test_single_radix<6>(s, data[threadIdx.x], p, &error_flags[blockIdx.x]); break;
}
}

Expand Down Expand Up @@ -142,21 +151,33 @@ std::vector<int> block_sizes = []() {
std::vector<int> res;
for (int i = 2; i < 50; ++i)
res.push_back(i);
for (int i = max_threads - 50; i <= max_threads; ++i)
for (int i = MAX_THREADS - 50; i <= MAX_THREADS; ++i)
res.push_back(i);
return res;
}();

class MultiSumTestFloat : public MultiSumTest<float> {
class MultiSumTestFloat32 : public MultiSumTest<float> {
public:
void generate_data()
{
raft::random::Rng r(4321);
r.uniform(data_d.data().get(), data_d.size(), -1.0f, 1.0f, cudaStreamDefault);
}
};
TEST_P(MultiSumTestFloat, Import) { check(); }
INSTANTIATE_TEST_CASE_P(FilTests, MultiSumTestFloat, testing::ValuesIn(block_sizes));
TEST_P(MultiSumTestFloat32, Import) { check(); }
INSTANTIATE_TEST_CASE_P(FilTests, MultiSumTestFloat32, testing::ValuesIn(block_sizes));

// Multi-sum test fixture for float64 (double) data.
class MultiSumTestFloat64 : public MultiSumTest<double> {
 public:
  // Fills the device buffer data_d with doubles produced by raft's uniform
  // generator over the bounds -1.0 and 1.0; the generator is seeded with the
  // constant 4321 and the fill is issued on the default CUDA stream.
  void generate_data()
  {
    constexpr double lower_bound = -1.0;
    constexpr double upper_bound = 1.0;
    raft::random::Rng rng(4321);
    rng.uniform(
      data_d.data().get(), data_d.size(), lower_bound, upper_bound, cudaStreamDefault);
  }
};

TEST_P(MultiSumTestFloat64, Import) { check(); }
INSTANTIATE_TEST_CASE_P(FilTests, MultiSumTestFloat64, testing::ValuesIn(block_sizes));

class MultiSumTestInt : public MultiSumTest<int> {
public:
Expand Down

0 comments on commit 29057cb

Please sign in to comment.