Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace k-means++ CPU bottleneck with a random::discrete prim #1039

Merged
merged 6 commits into from
Nov 30, 2022

Conversation

Nyrio
Copy link
Contributor

@Nyrio Nyrio commented Nov 22, 2022

Currently, k-means has a CPU bottleneck in its k-means++ initialization, especially when the dataset is tall and thin. At each step of the k-means++ initialization, the distances to the closest cluster centroid are copied to the host, and candidates centroids are selected using std::discrete_distribution. This distribution reduces and scans the weights, which is an expensive operation if the dataset has many rows, and should be done on GPU.

For the test of this new primitive, I use a small number of weights and a large number of samples and compare the actual vs expected histogram using a tolerance of 4*sigma where sigma is the standard deviation computed from the number of samples and the smallest non-zero weight. I don't know any good way to correctly test the primitive with a very large number of weights and a small number of samples, but I'm open to suggestions.

@Nyrio Nyrio requested review from a team as code owners November 22, 2022 15:11
@Nyrio Nyrio added 3 - Ready for Review improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Nov 22, 2022
@Nyrio Nyrio changed the base branch from branch-22.12 to branch-23.02 November 22, 2022 15:22
@Nyrio Nyrio requested review from a team as code owners November 22, 2022 15:22
@Nyrio
Copy link
Contributor Author

Nyrio commented Nov 22, 2022

Sorry for the absolute mess in the diff and automatic labels / review requests, but branch-23.02 is behind branch-22.12 by 22 commits...

@Nyrio Nyrio changed the base branch from branch-23.02 to branch-22.12 November 22, 2022 15:29
@Nyrio
Copy link
Contributor Author

Nyrio commented Nov 22, 2022

Reverting target branch for now to branch-22.12 to make reviews easier.

@Nyrio
Copy link
Contributor Author

Nyrio commented Nov 22, 2022

This is a figure of the before/after overall performance of k-means, with k-means++ initialization and a tall and thin dataset (note that the raft benchmark uses k-means|| initialization, which is apparently the default, for which no significant perf difference should be observed).

2022-11-22_kmeansplusplus_perf_comp

@Nyrio
Copy link
Contributor Author

Nyrio commented Nov 22, 2022

Nsight Systems timelines for comparison.

before
2022-11-22_kmeansplusplus_before


after
2022-11-22_kmeansplusplus_after

@cjnolet
Copy link
Member

cjnolet commented Nov 23, 2022

I don't know any good way to correctly test the primitive with a very large number of weights and a small number of samples, but I'm open to suggestions.

Could we do some little tricks here like setting most of the weights to be a very small value and then setting only a few weights to be much larger values and then comparing the unique set of sampled indices against the expected histogram to see that the unique list of indices that were sampled are all peaks in the expected histogram?

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes look good, though aside from adding the test for the smaller sample sizes w/ larger weights, we should add a quick usage example for the docs.

cpp/include/raft/random/rng.cuh Show resolved Hide resolved
cpp/include/raft/random/detail/rng_device.cuh Show resolved Hide resolved
@Nyrio
Copy link
Contributor Author

Nyrio commented Nov 24, 2022

I have consolidated the tests with small sampled_len / large len.

@Nyrio Nyrio requested a review from cjnolet November 24, 2022 20:19
Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I'm pre-approving, but we will need to change the target branch to 23.02 before this is merged.

@cjnolet cjnolet changed the base branch from branch-22.12 to branch-23.02 November 29, 2022 23:39
@cjnolet
Copy link
Member

cjnolet commented Nov 30, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 63a1d94 into rapidsai:branch-23.02 Nov 30, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
3 - Ready for Review CMake cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
Development

Successfully merging this pull request may close these issues.

2 participants