Replace k-means++ CPU bottleneck with a random::discrete prim #1039
Conversation
Sorry for the absolute mess in the diff and automatic labels / review requests, but …
Reverting target branch for now to …
Could we do some little tricks here, like setting most of the weights to a very small value and only a few weights to much larger values, and then comparing the unique set of sampled indices against the expected histogram to check that all of the sampled indices are peaks in the expected histogram?
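A minimal sketch of that test idea, using std::discrete_distribution as a host-side stand-in for the new prim (the real test would call the GPU primitive; the weight values and peak positions here are illustrative):

```cpp
#include <cassert>
#include <random>
#include <set>
#include <vector>

int main()
{
  // Mostly tiny weights, with a few large "peak" weights.
  const int n_weights = 10000;
  std::vector<double> weights(n_weights, 1e-9);
  const std::set<int> peaks{42, 1337, 9000};
  for (int p : peaks) weights[p] = 1.0;

  // Draw a small number of samples (stand-in for the GPU prim).
  std::mt19937 gen(1234);
  std::discrete_distribution<int> dist(weights.begin(), weights.end());
  std::set<int> sampled;
  for (int i = 0; i < 100; ++i) sampled.insert(dist(gen));

  // Every index that was actually sampled should be one of the peaks.
  for (int idx : sampled) assert(peaks.count(idx) == 1);
  return 0;
}
```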
The changes look good, though aside from adding the test for the smaller sample sizes w/ larger weights, we should add a quick usage example for the docs.
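For what it's worth, a docs-style usage sketch for the new prim might look like the following; the exact raft::random::discrete signature, headers, and template parameters are an assumption here, not confirmed in this thread:

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/random/rng.cuh>

void discrete_example()
{
  raft::device_resources handle;
  raft::random::RngState rng(1234ULL);

  const int n_weights = 5;
  const int n_samples = 100;
  auto weights = raft::make_device_vector<float, int>(handle, n_weights);
  auto out     = raft::make_device_vector<int, int>(handle, n_samples);
  // ... fill `weights` with non-negative values ...

  // Assumed call shape: sample n_samples indices in [0, n_weights),
  // each index i drawn with probability weights[i] / sum(weights).
  raft::random::discrete(handle, rng, out.view(), weights.view());
}
```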
I have consolidated the tests with small …
LGTM. I'm pre-approving, but we will need to change the target branch to 23.02 before this is merged.
@gpucibot merge
Currently, k-means has a CPU bottleneck in its k-means++ initialization, especially when the dataset is tall and thin. At each step of the k-means++ initialization, the distances to the closest cluster centroid are copied to the host, and candidate centroids are selected using std::discrete_distribution. This distribution reduces and scans the weights, which is an expensive operation when the dataset has many rows and should be done on the GPU.

To test the new primitive, I use a small number of weights and a large number of samples and compare the actual vs. expected histogram, using a tolerance of 4*sigma, where sigma is the standard deviation computed from the number of samples and the smallest non-zero weight. I don't know a good way to correctly test the primitive with a very large number of weights and a small number of samples, but I'm open to suggestions.
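As a rough illustration of the GPU-side approach described above, here is a Thrust sketch of the general scan-and-search technique (not the prim's actual implementation): an inclusive scan turns the weights into a CDF, and each uniform draw is mapped to an index with a vectorized binary search.

```cpp
// Build with nvcc --extended-lambda (needed for the device lambda below).
#include <thrust/binary_search.h>
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#include <cstddef>

// Sample out.size() indices in [0, weights.size()) with probability
// proportional to the weights, entirely on the device.
void discrete_sample(const thrust::device_vector<float>& weights,
                     thrust::device_vector<int>& out,
                     unsigned long long seed)
{
  // 1. Scan the weights into a cumulative distribution (last entry = total).
  thrust::device_vector<float> cdf(weights.size());
  thrust::inclusive_scan(weights.begin(), weights.end(), cdf.begin());
  const float total = cdf.back();

  // 2. Draw one uniform value in [0, total) per output sample.
  thrust::device_vector<float> u(out.size());
  thrust::transform(thrust::counting_iterator<std::size_t>(0),
                    thrust::counting_iterator<std::size_t>(out.size()),
                    u.begin(),
                    [seed, total] __device__(std::size_t i) {
                      thrust::default_random_engine rng(seed);
                      rng.discard(i);
                      thrust::uniform_real_distribution<float> dist(0.0f, total);
                      return dist(rng);
                    });

  // 3. Binary-search each draw in the CDF: the index of the first entry
  //    greater than the draw is the sampled index.
  thrust::upper_bound(cdf.begin(), cdf.end(), u.begin(), u.end(), out.begin());
}
```

For the histogram check, if bin i has probability p_i and there are n samples, the expected count is n*p_i with binomial standard deviation sqrt(n*p_i*(1-p_i)), which is presumably where the 4*sigma tolerance above comes from.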