[feat] add Hashtable On GPU #74
Conversation
rhdong commented on May 10, 2021:
- add Hashtable on GPU
- switch the CUDA toolchain to CUDA 11
- update the STYLE GUIDE for clang-format on macOS
- update the Bazel version to 3.7.2
size_t default_value_num =
    is_full_default ? default_value.shape().dim_size(0) : 1;
CUDA_CHECK(cudaStreamCreate(&_stream));
CUDA_CHECK(cudaMalloc((void**)&d_status, sizeof(bool) * len));
Is it possible to use cudaMallocManaged here? It allocates memory that is automatically managed by the Unified Memory system.
I tried that but it failed; I will consider your advice for the next version, since this is a stable implementation for now.
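For context, here is a minimal sketch of the two allocation strategies under discussion, assuming a standalone CUDA 11 program; `len` is illustrative, and the managed variant reflects the reviewer's suggestion rather than what the PR ships:

#include <cstdio>
#include <cuda_runtime.h>

// Simple error check in the spirit of the PR's CUDA_CHECK macro.
#define CUDA_CHECK(call)                                              \
  do {                                                                \
    cudaError_t err = (call);                                         \
    if (err != cudaSuccess) {                                         \
      fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));   \
    }                                                                 \
  } while (0)

int main() {
  const size_t len = 1024;

  // Explicit device allocation, as in the PR: the buffer lives on the
  // device only, and host access requires an explicit cudaMemcpy.
  bool* d_status = nullptr;
  CUDA_CHECK(cudaMalloc((void**)&d_status, sizeof(bool) * len));
  CUDA_CHECK(cudaFree(d_status));

  // Managed (Unified Memory) alternative the reviewer suggests: the same
  // pointer is valid on host and device, and the driver migrates pages
  // on demand, so no explicit copies are needed.
  bool* m_status = nullptr;
  CUDA_CHECK(cudaMallocManaged((void**)&m_status, sizeof(bool) * len));
  m_status[0] = true;  // directly accessible from host code
  CUDA_CHECK(cudaFree(m_status));
  return 0;
}

The trade-off is that managed memory leaves page migration to the driver, whose performance under heavy access patterns is less predictable; that is one plausible reading of the "tried but failed" result above.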
3.7.2
Is Bazel version 3.7.2 necessary? If it is, please also mention it in the README.
Accepted.
~CuckooHashTableOfTensorsGpu() { delete table_; }

void CreateTable(size_t max_size, gpu::TableWrapperBase<K, V>** pptable) {
  if (runtime_dim_ <= 50) {
I think it is possible to use mod 50 and a switch statement to achieve better performance.
This only runs once, and mod 50 plus a switch would have worse readability.
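For illustration, a hedged sketch of the dispatch pattern being debated; TableWrapper, the thresholds, and the DIM values below are hypothetical stand-ins for the PR's real types:

#include <cstddef>

template <class K, class V>
struct TableWrapperBase {
  virtual ~TableWrapperBase() = default;
};

// DIM is fixed at compile time so each instantiation can size its GPU
// storage statically; the runtime dimension is rounded up to the nearest
// supported DIM.
template <class K, class V, size_t DIM>
struct TableWrapper : TableWrapperBase<K, V> {};

// If-chain dispatch as in the PR. It runs once per table creation, which
// is why the author argues a mod-50 switch rewrite would not repay the
// lost readability.
template <class K, class V>
void CreateTable(size_t runtime_dim, TableWrapperBase<K, V>** pptable) {
  if (runtime_dim <= 50) {
    *pptable = new TableWrapper<K, V, 50>();
  } else if (runtime_dim <= 100) {
    *pptable = new TableWrapper<K, V, 100>();
  } else {
    *pptable = new TableWrapper<K, V, 200>();
  }
}

The reviewer's alternative would compute (runtime_dim - 1) / 50 once and switch on the resulting bucket index, trading the branch chain for a jump table.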
CUDA_CHECK(cudaStreamCreate(&_stream));
CUDA_CHECK(cudaMalloc((void**)&d_keys, sizeof(K) * len));
CUDA_CHECK(cudaMalloc((void**)&d_values, sizeof(V) * runtime_dim_ * len));
CUDA_CHECK(cudaMemcpy((void*)d_keys, (void*)keys.tensor_data().data(),
If the operation is registered on DEVICE_GPU, the input Tensor is already in GPU device memory, so the copy is not strictly needed.
I tested it; sometimes the inputs were in CPU memory, which caused crashes, so I used memcpy.
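One possible way to cover inputs that may live on either side, assuming the PR's CUDA 11 toolchain; the helper names below are illustrative, not from the PR:

#include <cuda_runtime.h>

// With Unified Virtual Addressing (standard on 64-bit CUDA builds),
// cudaMemcpyDefault lets the driver infer whether src points at host or
// device memory, so one copy path handles both cases without branching.
cudaError_t CopyKeysToDevice(void* d_keys, const void* src, size_t bytes,
                             cudaStream_t stream) {
  return cudaMemcpyAsync(d_keys, src, bytes, cudaMemcpyDefault, stream);
}

// Alternatively, query where a pointer actually lives before deciding
// whether a copy is needed at all.
bool IsDevicePointer(const void* p) {
  cudaPointerAttributes attr;
  if (cudaPointerGetAttributes(&attr, p) != cudaSuccess) return false;
  return attr.type == cudaMemoryTypeDevice;
}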
I made several review comments. Please check them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM