-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test categorical features with column-split gpu quantile #9595
Changes from 6 commits
2f2ac15
1847628
69daaf3
3915bbb
89fb13e
db9493a
80ed92c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -634,12 +634,25 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) { | |
}); | ||
CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1); | ||
max_values.resize(d_in_columns_ptr.size() - 1); | ||
|
||
// In some cases (e.g. column-wise data split), we may have empty columns, so we need to keep | ||
// track of the unique keys (feature indices) after the thrust::reduce_by_key` call. | ||
dh::caching_device_vector<size_t> d_max_keys(d_in_columns_ptr.size() - 1); | ||
dh::caching_device_vector<SketchEntry> d_max_values(d_in_columns_ptr.size() - 1); | ||
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it, | ||
thrust::make_discard_iterator(), d_max_values.begin(), | ||
thrust::equal_to<bst_feature_t>{}, | ||
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; }); | ||
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_values)); | ||
auto new_end = thrust::reduce_by_key( | ||
thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(), | ||
d_max_values.begin(), thrust::equal_to<bst_feature_t>{}, | ||
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; }); | ||
d_max_keys.erase(new_end.first, d_max_keys.end()); | ||
d_max_values.erase(new_end.second, d_max_values.end()); | ||
|
||
// The device vector needs to be initialized explicitly since we may have some missing columns. | ||
SketchEntry default_entry{}; | ||
dh::caching_device_vector<SketchEntry> d_max_results(d_in_columns_ptr.size() - 1, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure the caching device vector does initialize the value? (call constructor) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also verified it in debugger. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it, I think we ran into trouble with it before as commented in the |
||
default_entry); | ||
thrust::scatter(d_max_values.begin(), d_max_values.end(), d_max_keys.begin(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. exec policy? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
d_max_results.begin()); | ||
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_results)); | ||
auto max_it = MakeIndexTransformIter([&](auto i) { | ||
if (IsCat(h_feature_types, i)) { | ||
return max_values[i].value; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit confused by these two erases, what are they doing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shrink the two vectors to actual size. If we have missing columns, they won't be fully populated.