Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

float64 support in FIL core #4646

Merged
merged 51 commits into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
86bc6b5
templatized node, forest and storage types
levsnv Feb 9, 2022
abfb605
compiles?
levsnv Feb 11, 2022
a5d970f
Merge branch 'branch-22.04' of github.com:rapidsai/cuml into fp64
levsnv Feb 11, 2022
96e2ae6
simplify leaf_output_t
levsnv Feb 11, 2022
ead832d
Merge branch 'branch-22.04' of github.com:rapidsai/cuml into fp64
levsnv Feb 12, 2022
35ad5d9
draft
levsnv Feb 16, 2022
df001e7
Merge branch 'branch-22.04' of github.com:rapidsai/cuml into fp64
levsnv Feb 16, 2022
f6efb88
fixed extra/missing instantiations
levsnv Feb 16, 2022
4bc4f99
Merge branch 'fp64' into func64
levsnv Feb 16, 2022
1bebcca
style
levsnv Feb 16, 2022
aab158b
style
levsnv Feb 16, 2022
c3f2d4f
Merge branch 'fp64' into func64
levsnv Feb 16, 2022
1236c52
removed ML::fil::init templatization, added KeyValuePair templatization
levsnv Feb 18, 2022
cc946cd
style
levsnv Feb 18, 2022
91ecb20
add the old instantiations back
levsnv Feb 18, 2022
67f900e
fixed
levsnv Feb 18, 2022
c3c4deb
style
levsnv Feb 18, 2022
68e7f65
base_node::output() now compiles.
canonizer Mar 9, 2022
fe2cf92
Merge branch 'branch-22.04' into dev-fil-fp64
canonizer Mar 9, 2022
b31069d
Fixed style.
canonizer Mar 10, 2022
c250653
F -> real_t.
canonizer Mar 10, 2022
2f658b5
Small fixes.
canonizer Mar 10, 2022
6fb105b
Updated alignment.
canonizer Mar 10, 2022
f1a10be
static_assert(real_t == float) in a number of places.
canonizer Mar 10, 2022
ce3624e
noinline -> forceinline.
canonizer Mar 10, 2022
63fadd1
Updated comment.
canonizer Mar 11, 2022
986a50d
Merge branch 'dev-fil-fp64' into enh-fil-func64
canonizer Mar 11, 2022
fda7669
Fixed many compiler errors.
canonizer Mar 12, 2022
40f7a23
Multiple changes.
canonizer Mar 14, 2022
12cc051
Fixed compilation errors; now it compiles.
canonizer Mar 14, 2022
2bc6b6e
float -> void in predict().
canonizer Mar 16, 2022
138e1dd
Some templating.
canonizer Mar 16, 2022
532686f
template_forest<real_t> for type-dependent forest members.
canonizer Mar 17, 2022
ce689f7
Instantiate forests with double.
canonizer Mar 17, 2022
98b997a
Small changes.
canonizer Mar 17, 2022
8c84cf7
Templatized BaseFilTest.
canonizer Mar 18, 2022
05de38d
Templatized child_index tests, added float64-only tests.
canonizer Mar 18, 2022
316e99a
float64 versions of multi-sum and FIL predict tests.
canonizer Mar 19, 2022
2ba5eed
compute_smem_footprint() uses float or double, based on sizeof_real.
canonizer Mar 19, 2022
0395979
Merge branch 'branch-22.04' into dev-fil64
canonizer Mar 22, 2022
ac92be7
Removed stray static_asserts.
canonizer Mar 22, 2022
2d51762
Merge branch 'branch-22.06' into dev-fil64
canonizer Mar 31, 2022
0db5e37
Merge branch 'branch-22.06' into dev-fil64
canonizer Apr 4, 2022
92c44af
Finish merge.
canonizer Apr 4, 2022
938e02a
Fixed compilation errors.
canonizer Apr 4, 2022
c665bbf
Fixed endless recursion in forest::free().
canonizer Apr 4, 2022
175837a
Removed changes to fil.h.
canonizer Apr 4, 2022
886c649
Refactored tests.
canonizer Apr 4, 2022
1426c14
noinline -> forceinline.
canonizer Apr 4, 2022
d1fe2e0
Merge branch 'branch-22.06' into dev-fil64
canonizer Apr 6, 2022
56ecd51
Addressed review comments.
canonizer Apr 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,9 @@ struct sparse_forest : forest<typename node_t::real_type> {

void free(const raft::handle_t& h) override
{
forest<real_type>::free(h);
trees_.release();
nodes_.release();
forest<real_type>::free(h);
}

int num_nodes_ = 0;
Expand Down Expand Up @@ -616,13 +616,27 @@ template void init<dense_node<float>, float>(const raft::handle_t& h,
const int* trees,
const dense_node<float>* nodes,
const forest_params_t* params);
template void init<dense_node<double>, double>(const raft::handle_t& h,
forest_t<double>* pf,
const categorical_sets& cat_sets,
const std::vector<double>& vector_leaf,
const int* trees,
const dense_node<double>* nodes,
const forest_params_t* params);
template void init<sparse_node16<float>, float>(const raft::handle_t& h,
forest_t<float>* pf,
const categorical_sets& cat_sets,
const std::vector<float>& vector_leaf,
const int* trees,
const sparse_node16<float>* nodes,
const forest_params_t* params);
template void init<sparse_node16<double>, double>(const raft::handle_t& h,
forest_t<double>* pf,
const categorical_sets& cat_sets,
const std::vector<double>& vector_leaf,
const int* trees,
const sparse_node16<double>* nodes,
const forest_params_t* params);
template void init<sparse_node8, float>(const raft::handle_t& h,
forest_t<float>* pf,
const categorical_sets& cat_sets,
Expand All @@ -639,6 +653,7 @@ void free(const raft::handle_t& h, forest_t<real_t> f)
}

template void free<float>(const raft::handle_t& h, forest_t<float> f);
template void free<double>(const raft::handle_t& h, forest_t<double> f);

template <typename real_t>
void predict(const raft::handle_t& h,
Expand All @@ -657,6 +672,12 @@ template void predict<float>(const raft::handle_t& h,
const float* data,
size_t num_rows,
bool predict_proba);
template void predict<double>(const raft::handle_t& h,
forest_t<double> f,
double* preds,
const double* data,
size_t num_rows,
bool predict_proba);

} // namespace fil
} // namespace ML
6 changes: 6 additions & 0 deletions cpp/src/fil/infer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -917,9 +917,15 @@ void infer(storage_type forest, predict_params params, cudaStream_t stream)
template void infer<dense_storage_f32>(dense_storage_f32 forest,
predict_params params,
cudaStream_t stream);
template void infer<dense_storage_f64>(dense_storage_f64 forest,
predict_params params,
cudaStream_t stream);
template void infer<sparse_storage16_f32>(sparse_storage16_f32 forest,
predict_params params,
cudaStream_t stream);
template void infer<sparse_storage16_f64>(sparse_storage16_f64 forest,
predict_params params,
cudaStream_t stream);
template void infer<sparse_storage8>(sparse_storage8 forest,
predict_params params,
cudaStream_t stream);
Expand Down
Loading