[Feat] Save and load cuckoo hashtable from hdfs
Format code and fix bracket issue

[fix] Fixed a compile failure caused by the hiredis update.

add hadoop

fix undefined symbol and skip hdfs test

fix test error

fix "variable-sized object may not be initialized" build error

fix build for tf version >= 2.7.0

save to tmp file

convert values to const

load_from_hdfs adopts a zero-copy method
luliyucoordinate authored and rhdong committed Jul 18, 2022
1 parent c8c73c7 commit 64631be
Showing 12 changed files with 1,070 additions and 2 deletions.
10 changes: 10 additions & 0 deletions WORKSPACE
@@ -54,6 +54,16 @@ http_archive(
url = "https://github.com/sewenew/redis-plus-plus/archive/refs/tags/1.2.3.zip",
)

http_archive(
name = "hadoop",
build_file = "//third_party:hadoop.BUILD",
sha256 = "fa9d0587d06c36838e778081bcf8271a9c63060af00b3bf456423c1777a62043",
strip_prefix = "hadoop-rel-release-3.3.0",
urls = [
"https://github.com/apache/hadoop/archive/refs/tags/rel/release-3.3.0.tar.gz",
],
)

tf_configure(
name = "local_config_tf",
)
5 changes: 4 additions & 1 deletion tensorflow_recommenders_addons/dynamic_embedding/core/BUILD
@@ -29,7 +29,10 @@ custom_op_library(
"utils/utils.h",
"utils/types.h",
] + glob(["kernels/lookup_impl/lookup_table_op_gpu*"])),
deps = ["//tensorflow_recommenders_addons/dynamic_embedding/core/lib/cuckoo:cuckoohash"],
deps = [
"//tensorflow_recommenders_addons/dynamic_embedding/core/lib/cuckoo:cuckoohash",
"//tensorflow_recommenders_addons/dynamic_embedding/core/lib/hadoop_file_system",
],
)

custom_op_library(
@@ -304,6 +304,18 @@ class CuckooHashTableOfTensors final : public LookupInterface {
return table_->export_values(ctx, value_dim);
}

Status SaveToHDFS(OpKernelContext* ctx, const string& filepath,
const size_t buffer_size) {
int64 value_dim = value_shape_.dim_size(0);
return table_->save_to_hdfs(ctx, value_dim, filepath, buffer_size);
}

Status LoadFromHDFS(OpKernelContext* ctx, const string& filepath,
const size_t buffer_size) {
int64 value_dim = value_shape_.dim_size(0);
return table_->load_from_hdfs(ctx, value_dim, filepath, buffer_size);
}

DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
@@ -607,6 +619,36 @@ class HashTableExportOp : public HashTableOpKernel {
}
};

// Op that exports all keys and values to HDFS.
template <class K, class V>
class HashTableSaveToHDFSOp : public HashTableOpKernel {
public:
explicit HashTableSaveToHDFSOp(OpKernelConstruction* ctx)
: HashTableOpKernel(ctx) {
int64 signed_buffer_size = 0;
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &signed_buffer_size));
buffer_size_ = static_cast<size_t>(signed_buffer_size);
}

void Compute(OpKernelContext* ctx) override {
LookupInterface* table;
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
core::ScopedUnref unref_me(table);

const Tensor& ftensor = ctx->input(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
errors::InvalidArgument("filepath must be scalar."));
string filepath = string(ftensor.scalar<tstring>()().data());

lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
(lookup::CuckooHashTableOfTensors<K, V>*)table;
OP_REQUIRES_OK(ctx, table_cuckoo->SaveToHDFS(ctx, filepath, buffer_size_));
}

private:
size_t buffer_size_;
};

// Clear the table and insert data.
class HashTableImportOp : public HashTableOpKernel {
public:
@@ -637,6 +679,37 @@ class HashTableImportOp : public HashTableOpKernel {
}
};

// Clear the table and insert data from HDFS.
template <class K, class V>
class HashTableLoadFromHDFSOp : public HashTableOpKernel {
public:
explicit HashTableLoadFromHDFSOp(OpKernelConstruction* ctx)
: HashTableOpKernel(ctx) {
int64 signed_buffer_size = 0;
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &signed_buffer_size));
buffer_size_ = static_cast<size_t>(signed_buffer_size);
}

void Compute(OpKernelContext* ctx) override {
LookupInterface* table;
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
core::ScopedUnref unref_me(table);

const Tensor& ftensor = ctx->input(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
errors::InvalidArgument("filepath must be scalar."));
string filepath = string(ftensor.scalar<tstring>()().data());

lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
(lookup::CuckooHashTableOfTensors<K, V>*)table;
OP_REQUIRES_OK(ctx,
table_cuckoo->LoadFromHDFS(ctx, filepath, buffer_size_));
}

private:
size_t buffer_size_;
};
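
(Note: the op definitions corresponding to these two kernels live in a separate ops file among the 12 changed files and are not shown in this excerpt. As a rough, hypothetical sketch only, the interface implied by the kernels above — a table handle, a scalar string filepath input, key/value dtype constraints, and an integer buffer_size attribute — might be registered roughly as follows. The default buffer_size, shape function, and include paths are assumptions, not code from this commit.)

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"  // assumed home of PREFIX_OP_NAME

namespace tensorflow {

// Hypothetical op definition implied by HashTableSaveToHDFSOp above; the real
// registration is in an ops file not included in this excerpt.
REGISTER_OP(PREFIX_OP_NAME(CuckooHashTableSaveToHDFS))
    .Input("table_handle: resource")
    .Input("filepath: string")
    .Attr("key_dtype: type")
    .Attr("value_dtype: type")
    .Attr("buffer_size: int = 4194304")  // default value is an assumption
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      return Status::OK();
    });

// CuckooHashTableLoadFromHDFS would mirror the same inputs and attributes.

}  // namespace tensorflow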

REGISTER_KERNEL_BUILDER(
Name(PREFIX_OP_NAME(CuckooHashTableFind)).Device(DEVICE_CPU),
HashTableFindOp);
@@ -679,7 +752,17 @@ REGISTER_KERNEL_BUILDER(
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("Tin") \
.TypeConstraint<value_dtype>("Tout"), \
HashTableFindWithExistsOp<key_dtype, value_dtype>);
HashTableFindWithExistsOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableSaveToHDFS)) \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableSaveToHDFSOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableLoadFromHDFS)) \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableLoadFromHDFSOp<key_dtype, value_dtype>);

REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
@@ -27,7 +27,10 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow_recommenders_addons/dynamic_embedding/core/lib/cuckoo/cuckoohash_map.hh"
#include "tensorflow_recommenders_addons/dynamic_embedding/core/lib/hadoop_file_system/hadoop_file_system.h"
#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/types.h"

namespace tensorflow {
@@ -135,6 +138,12 @@ class TableWrapperBase {
virtual Status export_values(OpKernelContext* ctx, int64 value_dim) {
return Status::OK();
}
  virtual Status save_to_hdfs(OpKernelContext* ctx, int64 value_dim,
                              const string& filepath,
                              const size_t buffer_size) {
    return Status::OK();
  }
  virtual Status load_from_hdfs(OpKernelContext* ctx, int64 value_dim,
                                const string& filepath,
                                const size_t buffer_size) {
    return Status::OK();
  }
};

template <class K, class V, size_t DIM>
@@ -238,6 +247,79 @@ class TableWrapperOptimized final : public TableWrapperBase<K, V> {
return Status::OK();
}

Status save_to_hdfs(OpKernelContext* ctx, int64 value_dim,
const string& filepath,
const size_t buffer_size) override {
size_t dim = static_cast<size_t>(value_dim);
auto lt = table_->lock_table();

HadoopFileSystem hdfs;
std::unique_ptr<WritableFile> writer;
const string tmp_file = filepath + ".tmp";
TF_RETURN_IF_ERROR(hdfs.NewWritableFile(tmp_file, &writer));

const uint32 value_len = sizeof(V) * dim;
const uint32 record_len = sizeof(K) + value_len;
uint64 pos = 0;
uint8 content[buffer_size + record_len];

for (auto it = lt.begin(); it != lt.end(); ++it) {
K k = it->first;
std::memcpy(content + pos, reinterpret_cast<uint8*>(&k), sizeof(K));

const auto& jt = it->second.data();
      std::memcpy(content + pos + sizeof(K),
                  reinterpret_cast<const uint8*>(jt), value_len);

pos += record_len;
if (pos > buffer_size) {
TF_RETURN_IF_ERROR(
writer->Append(StringPiece(reinterpret_cast<char*>(content), pos)));
pos = 0;
}
}

if (pos > 0) {
TF_RETURN_IF_ERROR(
writer->Append(StringPiece(reinterpret_cast<char*>(content), pos)));
}

TF_RETURN_IF_ERROR(writer->Close());
TF_RETURN_IF_ERROR(hdfs.RenameFile(tmp_file, filepath));
return Status::OK();
}
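
(For concreteness, a tiny standalone sketch of the fixed-width record format save_to_hdfs writes above: each record is the raw key bytes immediately followed by the raw value bytes, and the in-memory buffer is flushed once pos exceeds buffer_size. K = int64_t, V = float, and dim = 8 are placeholder choices, not values fixed by this commit.)

// Standalone illustration of the record layout used by save_to_hdfs above.
// K = int64_t, V = float, and dim = 8 are assumptions for concreteness only.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  using K = int64_t;
  using V = float;
  const size_t dim = 8;

  const size_t value_len = sizeof(V) * dim;         // 4 * 8 = 32 bytes
  const size_t record_len = sizeof(K) + value_len;  // 8 + 32 = 40 bytes

  K key = 42;
  std::vector<V> values(dim, 1.0f);

  // One serialized record: key bytes first, then the dim value bytes.
  std::vector<uint8_t> record(record_len);
  std::memcpy(record.data(), &key, sizeof(K));
  std::memcpy(record.data() + sizeof(K), values.data(), value_len);
  return 0;
}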

Status load_from_hdfs(OpKernelContext* ctx, int64 value_dim,
const string& filepath,
const size_t buffer_size) override {
size_t dim = static_cast<size_t>(value_dim);

HadoopFileSystem hdfs;
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(hdfs.NewRandomAccessFile(filepath, &file));
std::unique_ptr<io::RandomAccessInputStream> input_stream(
new io::RandomAccessInputStream(file.get()));
io::BufferedInputStream reader(input_stream.get(), buffer_size);

uint64 file_size = 0;
TF_RETURN_IF_ERROR(hdfs.GetFileSize(filepath, &file_size));

tstring content;
const uint32 value_len = sizeof(V) * dim;
const uint32 record_len = sizeof(K) + value_len;
uint64 i = 0;

while (i < file_size) {
TF_RETURN_IF_ERROR(reader.ReadNBytes(record_len, &content));
      const K* k = reinterpret_cast<const K*>(content.data());
      const ValueType* value_vec =
          reinterpret_cast<const ValueType*>(content.data() + sizeof(K));
table_->insert_or_assign(*k, *value_vec);
i += record_len;
}
return Status::OK();
}
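
(The "zero-copy" wording in the commit message refers to the loop above: each record's bytes are reinterpreted in place as the key and the fixed-size value array before insertion, rather than being copied into temporaries first. A minimal sketch of that pattern follows; Table and ValueType are stand-ins for the cuckoohash_map specialization and the std::array-like value type used by TableWrapperOptimized.)

// Sketch of the in-place ("zero-copy") decode performed by load_from_hdfs:
// the buffered bytes are viewed directly as K and the fixed-size value array.
// Table and ValueType are placeholders mirroring TableWrapperOptimized.
template <class K, class ValueType, class Table>
void InsertRecordInPlace(const char* record_bytes, Table* table) {
  const K* key = reinterpret_cast<const K*>(record_bytes);
  const ValueType* values =
      reinterpret_cast<const ValueType*>(record_bytes + sizeof(K));
  table->insert_or_assign(*key, *values);
}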

private:
size_t init_size_;
Table* table_;
@@ -0,0 +1,19 @@
load("@local_config_tf//:build_defs.bzl", "DTF_VERSION_INTEGER", "D_GLIBCXX_USE_CXX11_ABI")

package(default_visibility = ["//visibility:public"])

cc_library(
name = "hadoop_file_system",
srcs = ["hadoop_file_system.cc"],
hdrs = ["hadoop_file_system.h"],
copts = [
D_GLIBCXX_USE_CXX11_ABI,
DTF_VERSION_INTEGER,
],
deps = [
"@hadoop",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
alwayslink = 1,
)
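
(The @hadoop archive added to the WORKSPACE supplies the libhdfs headers this library builds against, and the HadoopFileSystem wrapper talks to HDFS through that C API at runtime. For orientation only, a minimal libhdfs write sketch follows; it is not code from this commit, the include path and connection settings are assumptions, and it presumes a reachable HDFS namenode plus the usual Hadoop CLASSPATH/JVM setup.)

// Minimal libhdfs sketch of the C API that HadoopFileSystem wraps.
// Not part of this commit; host/port, path, and include path are placeholders.
#include <fcntl.h>
#include <cstring>
#include "hdfs.h"  // from the @hadoop dependency; exact include path may vary

int main() {
  hdfsFS fs = hdfsConnect("default", 0);  // use fs.defaultFS from the config
  if (fs == nullptr) return 1;

  hdfsFile file = hdfsOpenFile(fs, "/tmp/table_dump.bin", O_WRONLY, 0, 0, 0);
  if (file == nullptr) return 1;

  const char* payload = "record bytes";
  hdfsWrite(fs, file, payload, static_cast<tSize>(std::strlen(payload)));

  hdfsCloseFile(fs, file);
  hdfsDisconnect(fs);
  return 0;
}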
