Learn heuristic to pick fastest select_k algorithm #1523
This uses the select_k dataset from rapidsai#1497 to learn a heuristic for the fastest select_k variant based on the rows/cols/k of the input. The heuristic is modelled as a DecisionTree, which is automatically exported as C++ code that is compiled into RAFT. This lets us learn a function to pick the fastest select_k method that requires only a few if statements in C++ code, making it very cheap to evaluate.
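To illustrate what "a decision tree exported as a few if statements" can look like, here is a minimal hand-written sketch. It is not the generated code from this PR: the `Algo` enumerators, the function name `choose_select_k_sketch`, and every threshold are hypothetical and chosen only to show the shape of such a function.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical algorithm labels; the real enum in RAFT may differ.
enum class Algo { kRadix, kWarpSort, kOther };

// Hand-written stand-in for an exported decision tree. Every threshold below
// is made up for illustration and is not taken from the trained model.
inline Algo choose_select_k_sketch(std::size_t rows, std::size_t cols, int k)
{
  assert(k > 0);  // k is the number of elements to select per row
  if (k > 128) {
    // Large k: split on the row length (cols).
    if (cols > 100000) { return Algo::kRadix; }
    return Algo::kOther;
  }
  // Small k: split on the batch size (rows).
  if (rows > 1000) { return Algo::kWarpSort; }
  return Algo::kRadix;
}
```

Each leaf of the tree becomes a `return`, and each internal node becomes one comparison, so evaluating the heuristic costs at most the depth of the tree in branches.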
LGTM. Just one tiny suggestion.
* on different values of rows/cols/k. The decision tree is converted to c++
* code, which is cut and paste below.
*
* The code to generate is in cpp/scripts/heuristics/select_k, running the
Just a tiny nitpick:
- * The code to generate is in cpp/scripts/heuristics/select_k, running the
+ * NOTE: The code to generate is in cpp/scripts/heuristics/select_k, running the
/merge
*/
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
if (k > 134) {
I think we should use log_2(k) instead of k when constructing the heuristic, so that the values of k fall on powers of two. For all warp-based algorithms, performance for a non-power-of-two k is equal to that of k rounded up to the next power of two (the queue capacity parameter).
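A minimal sketch of the feature transform suggested in this comment, assuming it means the base-2 logarithm of k rounded up to an integer. The helper name `ceil_log2` is hypothetical and is not part of this PR or of RAFT.

```cpp
#include <cassert>

// Returns the smallest exponent e such that 2^e >= k.
// Hypothetical helper, shown only to illustrate the suggested feature.
inline int ceil_log2(int k)
{
  assert(k > 0);
  int e = 0;
  // 1LL avoids overflow when the shift reaches the top of the int range.
  while ((1LL << e) < k) { ++e; }
  return e;
}
```

With this transform, every k from 129 through 256 maps to the same feature value, 8, so a tree trained on it can only split at power-of-two boundaries rather than at a value such as 134.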