Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Jan 7, 2022
1 parent 8907406 commit bb11575
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 28 deletions.
55 changes: 28 additions & 27 deletions k2/csrc/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -1021,9 +1021,13 @@ class Hash {
decide the number of buckets, when you create the hash, but you can resize
it (manually).
Note:
Each bucket contains a pair of key/value, each 64bit, key is stored at
data[2 * bucket_index] and value is stored at data[2 * bucket_index + 1].
Some constraints:
- You can store any (key,value) pair, except the pair where all the bits of
both and key and value are set [that is used to mean "nothing here"]
both key and value are set [that is used to mean "nothing here"]
- The number of buckets must always be a power of 2.
- When deleting values from the hash you must delete them all at
once (necessary because there is no concept of a "tombstone").
Expand Down Expand Up @@ -1058,10 +1062,10 @@ class Hash64 {
buckets, or the code will loop infinitely when you try to add
items; aim for less than 50% occupancy.
*/
Hash64(ContextPtr c, int32_t num_buckets) {
Hash64(ContextPtr c, int64_t num_buckets) {
K2_CHECK_GE(num_buckets, 128);
data_ = Array1<uint64_t>(c, num_buckets * 2, ~(uint64_t)0);
int32_t n = 2;
int64_t n = 2;
for (buckets_num_bitsm1_ = 0; n < num_buckets;
n *= 2, buckets_num_bitsm1_++) {
}
Expand All @@ -1072,7 +1076,7 @@ class Hash64 {
// Default constructor; data_ is default-constructed and no buckets are set up.
// Only to be used prior to assignment.
// NOTE(review): buckets_num_bitsm1_ has no in-class initializer in the visible
// member declaration, so it is presumably left uninitialized until a properly
// constructed Hash64 is assigned over this object -- confirm.
Hash64() = default;

int32_t NumBuckets() const { return data_.Dim() / 2; }
int64_t NumBuckets() const { return data_.Dim() / 2; }

// Returns the raw pointer to the underlying storage array; intended for
// testing. The key of bucket i is at index 2 * i and its value at 2 * i + 1.
uint64_t *Data() {
  return data_.Data();
}
Expand All @@ -1088,7 +1092,7 @@ class Hash64 {
public:
Accessor(Hash64 &hash)
: data_(hash.data_.Data()),
num_buckets_mask_(uint32_t(hash.NumBuckets()) - 1),
num_buckets_mask_(uint64_t(hash.NumBuckets()) - 1),
buckets_num_bitsm1_(hash.buckets_num_bitsm1_) {}

// Copy constructor
Expand Down Expand Up @@ -1116,7 +1120,7 @@ class Hash64 {
__forceinline__ __host__ __device__ bool Insert(
uint64_t key, uint64_t value, uint64_t *old_value = nullptr,
uint64_t **key_value_location = nullptr) const {
uint32_t cur_bucket = static_cast<uint32_t>(key) & num_buckets_mask_,
uint64_t cur_bucket = key & num_buckets_mask_,
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);

while (1) {
Expand Down Expand Up @@ -1175,7 +1179,7 @@ class Hash64 {
__forceinline__ __host__ __device__ bool Find(
uint64_t key, uint64_t *value_out,
uint64_t **key_value_location = nullptr) const {
uint32_t cur_bucket = key & num_buckets_mask_,
uint64_t cur_bucket = key & num_buckets_mask_,
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
while (1) {
uint64_t old_key = data_[2 * cur_bucket];
Expand All @@ -1198,16 +1202,13 @@ class Hash64 {
Find().
@param [in] key_value_location Location that was obtained from
a successful call to Find().
@param [in] key Required to be the same key that was provided to
Find(); it is an error otherwise.
@param [in] value Value to write;
Note: the const is with respect to the metadata only; it is required, to
avoid compilation errors.
*/
__forceinline__ __host__ __device__ void SetValue(
uint64_t *key_value_location, uint64_t key, uint64_t value) const {
*key_value_location = key;
uint64_t *key_value_location, uint64_t value) const {
*(key_value_location + 1) = value;
}

Expand All @@ -1224,7 +1225,7 @@ class Hash64 {
compilation errors.
*/
__forceinline__ __host__ __device__ void Delete(uint64_t key) const {
uint32_t cur_bucket = key & num_buckets_mask_,
uint64_t cur_bucket = key & num_buckets_mask_,
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
while (1) {
uint64_t old_key = data_[2 * cur_bucket];
Expand All @@ -1244,10 +1245,10 @@ class Hash64 {
// num_buckets_mask is num_buckets (i.e. half the size of the `data_` array,
// since each bucket holds a key and a value) minus one; num_buckets is a
// power of 2 so this can be used as a mask to get a number modulo num_buckets.
uint32_t num_buckets_mask_;
uint64_t num_buckets_mask_;
// A number satisfying num_buckets == 1 << (1+buckets_num_bitsm1_), i.e.
// log2(num_buckets) - 1: the number of bits in a bucket index, minus one.
uint32_t buckets_num_bitsm1_;
uint64_t buckets_num_bitsm1_;
};

/*
Expand All @@ -1260,25 +1261,25 @@ class Hash64 {
// contain values when it is destroyed, to bypass a check.
void Destroy() {
  // Replace the storage with a default-constructed array, dropping this
  // object's handle on the old contents without inspecting them.
  data_ = Array1<uint64_t>();
}

void CheckEmpty() {
void CheckEmpty() const {
if (data_.Dim() == 0) return;
ContextPtr c = Context();
Array1<int32_t> error(c, 1, -1);
int32_t *error_data = error.Data();
uint64_t *hash_data = data_.Data();
Array1<int64_t> error(c, 1, -1);
int64_t *error_data = error.Data();
const uint64_t *hash_data = data_.Data();

K2_EVAL(
Context(), data_.Dim(), lambda_check_data, (int32_t i)->void {
Context(), data_.Dim(), lambda_check_data, (int64_t i)->void {
if (~(hash_data[i]) != 0) error_data[0] = i;
});
int32_t i = error[0];
int64_t i = error[0];
if (i >= 0) { // there was an error; i is the index into the hash where
// there was an element.
int64_t elem = data_[i];
// We don't know the number of bits the user was using for the key vs.
// value, so print in hex, maybe they can figure it out.
K2_LOG(FATAL) << "Destroying hash: still contains values: position " << i
<< ", key,value = " << std::hex << elem;
<< ", content = " << std::hex << elem;
}
}

Expand All @@ -1292,7 +1293,7 @@ class Hash64 {
CAUTION: Resizing will invalidate any accessor objects you have; you need
to re-get the accessors before accessing the hash again.
*/
void Resize(int32_t new_num_buckets, bool copy_data = true) {
void Resize(int64_t new_num_buckets, bool copy_data = true) {
NVTX_RANGE(K2_FUNC);

K2_CHECK_GT(new_num_buckets, 0);
Expand All @@ -1315,19 +1316,19 @@ class Hash64 {
*/
void CopyDataFromSimple(Hash64 &src) {
NVTX_RANGE(K2_FUNC);
int32_t num_buckets = data_.Dim() / 2,
int64_t num_buckets = data_.Dim() / 2,
src_num_buckets = src.data_.Dim() / 2;
const uint64_t *src_data = src.data_.Data();
uint64_t *data = data_.Data();
size_t new_num_buckets_mask = static_cast<size_t>(num_buckets) - 1,
uint64_t new_num_buckets_mask = static_cast<uint64_t>(num_buckets) - 1,
new_buckets_num_bitsm1 = buckets_num_bitsm1_;
ContextPtr c = data_.Context();
K2_EVAL(c, src_num_buckets, lambda_copy_data, (int32_t i) -> void {
K2_EVAL(c, src_num_buckets, lambda_copy_data, (uint64_t i) -> void {
uint64_t key = src_data[2 * i];
uint64_t value = src_data[2 * i + 1];
if (~key == 0) return; // equals -1.. nothing there.
uint64_t bucket_inc = 1 | ((key >> new_buckets_num_bitsm1) ^ key);
size_t cur_bucket = key & new_num_buckets_mask;
uint64_t cur_bucket = key & new_num_buckets_mask;
while (1) {
uint64_t assumed = ~((uint64_t)0),
old_elem = AtomicCAS((unsigned long long*)(data + 2 * cur_bucket),
Expand All @@ -1354,7 +1355,7 @@ class Hash64 {
Array1<uint64_t> data_;

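// number satisfying data_.Dim() / 2 == 1 << (1+buckets_num_bitsm1_), i.e.
// NumBuckets() == 1 << (1+buckets_num_bitsm1_); data_ holds 2 elements
// (key, value) per bucket.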
int32_t buckets_num_bitsm1_;
uint64_t buckets_num_bitsm1_;
};

/*
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/hash_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ void TestHash64Construct() {
}
uint64_t keyval = *key_value_location;
if (success) {
acc.SetValue(key_value_location, key, value);
acc.SetValue(key_value_location, value);
K2_DCHECK_EQ(keyval, *key_value_location);
}
success_data[i] = success;
Expand Down

0 comments on commit bb11575

Please sign in to comment.