diff --git a/WORKSPACE b/WORKSPACE index 9e07792b9..10894ce7e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -57,55 +57,9 @@ http_archive( http_archive( name = "hkv", build_file = "//build_deps/toolchains/hkv:hkv.BUILD", - # TODO(LinGeLin) remove this when update hkv - patch_cmds = [ - """sed -i.bak '1772i\\'$'\\n ThrustAllocator<uint8_t> thrust_allocator_;\\n' include/merlin_hashtable.cuh""", - """sed -i.bak '225i\\'$'\\n thrust_allocator_.set_allocator(allocator_);\\n' include/merlin_hashtable.cuh""", - "sed -i.bak 's/thrust::sort_by_key(thrust_par.on(stream)/thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream)/' include/merlin_hashtable.cuh", - "sed -i.bak 's/reduce(thrust_par.on(stream)/reduce(thrust_par(thrust_allocator_).on(stream)/' include/merlin_hashtable.cuh", - """sed -i.bak '125i\\'$'\\n template <typename T>\\n' include/merlin/allocator.cuh""", - """sed -i.bak '126i\\'$'\\n struct ThrustAllocator : thrust::device_malloc_allocator<T> {\\n' include/merlin/allocator.cuh""", - """sed -i.bak '127i\\'$'\\n public:\\n' include/merlin/allocator.cuh""", - """sed -i.bak '128i\\'$'\\n typedef thrust::device_malloc_allocator<T> super_t;\\n' include/merlin/allocator.cuh""", - """sed -i.bak '129i\\'$'\\n typedef typename super_t::pointer pointer;\\n' include/merlin/allocator.cuh""", - """sed -i.bak '130i\\'$'\\n typedef typename super_t::size_type size_type;\\n' include/merlin/allocator.cuh""", - """sed -i.bak '131i\\'$'\\n public:\\n' include/merlin/allocator.cuh""", - """sed -i.bak '132i\\'$'\\n pointer allocate(size_type n) {\\n' include/merlin/allocator.cuh""", - """sed -i.bak '133i\\'$'\\n void* ptr = nullptr;\\n' include/merlin/allocator.cuh""", - """sed -i.bak '134i\\'$'\\n MERLIN_CHECK(\\n' include/merlin/allocator.cuh""", - """sed -i.bak '135i\\'$'\\n allocator_ != nullptr,\\n' include/merlin/allocator.cuh""", - """sed -i.bak '136i\\'$'\\n "[ThrustAllocator] set_allocator should be called in advance!");\\n' include/merlin/allocator.cuh""", - """sed -i.bak '137i\\'$'\\n allocator_->alloc(MemoryType::Device, &ptr, sizeof(T) * n);\\n' include/merlin/allocator.cuh""", - """sed -i.bak '138i\\'$'\\n return pointer(reinterpret_cast<T*>(ptr));\\n' include/merlin/allocator.cuh""", - """sed -i.bak '139i\\'$'\\n }\\n' include/merlin/allocator.cuh""", - """sed -i.bak '140i\\'$'\\n void deallocate(pointer p, size_type n) {\\n' include/merlin/allocator.cuh""", - """sed -i.bak '141i\\'$'\\n MERLIN_CHECK(\\n' include/merlin/allocator.cuh""", - """sed -i.bak '142i\\'$'\\n allocator_ != nullptr,\\n' include/merlin/allocator.cuh""", - """sed -i.bak '143i\\'$'\\n "[ThrustAllocator] set_allocator should be called in advance!");\\n' include/merlin/allocator.cuh""", - """sed -i.bak '144i\\'$'\\n allocator_->free(MemoryType::Device, reinterpret_cast<void*>(p.get()));\\n' include/merlin/allocator.cuh""", - """sed -i.bak '145i\\'$'\\n }\\n' include/merlin/allocator.cuh""", - """sed -i.bak '146i\\'$'\\n void set_allocator(BaseAllocator* allocator) { allocator_ = allocator; }\\n' include/merlin/allocator.cuh""", - """sed -i.bak '147i\\'$'\\n public:\\n' include/merlin/allocator.cuh""", - """sed -i.bak '148i\\'$'\\n BaseAllocator* allocator_ = nullptr;\\n' include/merlin/allocator.cuh""", - """sed -i.bak '149i\\'$'\\n };\\n' include/merlin/allocator.cuh""", - """sed -i.bak '20i\\'$'\\n #include <thrust/device_malloc_allocator.h>\\n' include/merlin/allocator.cuh""", - """sed -i.bak '367i\\'$'\\n for (auto addr : (*table)->buckets_address) {\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '368i\\'$'\\n allocator->free(MemoryType::Device, addr);\\n'
include/merlin/core_kernels.cuh""", - """sed -i.bak '369i\\'$'\\n }\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '370i\\'$'\\n /*\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '382i\\'$'\\n */\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '224i\\'$'\\n uint8_t* address = nullptr;\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '225i\\'$'\\n allocator->alloc(MemoryType::Device, (void**)&(address), bucket_memory_size * (end - start));\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '226i\\'$'\\n (*table)->buckets_address.push_back(address);\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '228i\\'$'\\n allocate_bucket_others<<<1, 1>>>((*table)->buckets, i, address + (bucket_memory_size * (i-start)), reserve_size, bucket_max_size);\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '229i\\'$'\\n /*\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '235i\\'$'\\n */\\n' include/merlin/core_kernels.cuh""", - """sed -i.bak '22i\\'$'\\n#include <vector>\\n' include/merlin/types.cuh""", - """sed -i.bak '143i\\'$'\\n std::vector<uint8_t*> buckets_address;\\n' include/merlin/types.cuh""", - ], - sha256 = "f8179c445a06a558262946cda4d8ae7252d313e73f792586be9b1bc0c993b1cf", - strip_prefix = "HierarchicalKV-0.1.0-beta.6", - url = "https://github.com/NVIDIA-Merlin/HierarchicalKV/archive/refs/tags/v0.1.0-beta.6.tar.gz", + sha256 = "841be4cfb4059e5745838a23a32c776cfff1d38306b95b1ac2659df0d4d9709b", + strip_prefix = "HierarchicalKV-master", + url = "https://github.com/NVIDIA-Merlin/HierarchicalKV/archive/refs/heads/master.zip", ) tf_configure(
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/__init__.py b/tensorflow_recommenders_addons/dynamic_embedding/__init__.py index 3269c9171..e85aa8b41 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/__init__.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/__init__.py @@ -18,6 +18,7 @@ 'CuckooHashTable', 'CuckooHashTableConfig', 'CuckooHashTableCreator', + 'HkvEvictStrategy', 'HkvHashTable', 'HkvHashTableConfig', 'HkvHashTableCreator', @@ -55,7 +56,7 @@ from tensorflow_recommenders_addons.dynamic_embedding.python.ops import data_flow_ops as data_flow from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_creator import ( KVCreator, CuckooHashTableConfig, CuckooHashTableCreator, - HkvHashTableConfig, HkvHashTableCreator, RedisTableConfig, + HkvHashTableConfig, HkvHashTableCreator, HkvEvictStrategy, RedisTableConfig, RedisTableCreator, FileSystemSaver) from tensorflow_recommenders_addons.dynamic_embedding.python.ops.cuckoo_hashtable_ops import ( CuckooHashTable,)
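For orientation, a minimal sketch of how the newly exported `HkvEvictStrategy` surfaces to users (assumes a TFRA build that includes this patch; the GPU ops themselves only load in a GPU build):

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

# The IntEnum values line up with the `strategy` int attr parsed by the GPU
# kernel below: LRU=0, LFU=1, EPOCHLRU=2, EPOCHLFU=3, CUSTOMIZED=4.
for strategy in de.HkvEvictStrategy:
    print(strategy.name, int(strategy))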
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc index 86b8cfb56..b91220b5d 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc @@ -45,6 +45,7 @@ limitations under the License.
namespace tensorflow { using GPUDevice = Eigen::GpuDevice; +using HkvEvictStrategy = nv::merlin::EvictStrategy; namespace recommenders_addons { namespace lookup { @@ -74,19 +75,30 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface { int64 init_capacity_i64 = 0; int64 max_capacity_i64 = 0; int64 max_hbm_for_vectors_i64 = 0; + int64 evict_global_epoch = 0; + int strategy = 0; OP_REQUIRES_OK( ctx, GetNodeAttr(kernel->def(), "init_capacity", &init_capacity_i64)); OP_REQUIRES_OK( ctx, GetNodeAttr(kernel->def(), "max_capacity", &max_capacity_i64)); OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "max_hbm_for_vectors", &max_hbm_for_vectors_i64)); + OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "strategy", &strategy)); OP_REQUIRES( ctx, (max_hbm_for_vectors_i64 >= 0), errors::InvalidArgument("params max_hbm_for_vectors less than 0")); + OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "evict_global_epoch", + &evict_global_epoch)); + + OP_REQUIRES( + ctx, (evict_global_epoch >= 0), + errors::InvalidArgument("params evict_global_epoch less than 0")); + options.init_capacity = static_cast<size_t>(init_capacity_i64); options.max_capacity = static_cast<size_t>(max_capacity_i64); options.max_hbm_for_vectors = static_cast<size_t>(max_hbm_for_vectors_i64); + options.evict_global_epoch = static_cast<size_t>(evict_global_epoch); if (options.max_capacity == 0) { char* env_max_capacity_str = @@ -118,8 +130,8 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface { return; } allocator_ptr_ = std::make_unique<TFOrDefaultAllocator>(ctx); - OP_REQUIRES_OK(ctx, - this->CreateTable(options, allocator_ptr_.get(), &table_)); + OP_REQUIRES_OK(ctx, this->CreateTable(options, allocator_ptr_.get(), + &table_, strategy)); OP_REQUIRES(ctx, (table_ != nullptr), errors::InvalidArgument("HashTable on GPU is created failed!")); LOG(INFO) << "GPU table max capacity was created on max_capacity: " @@ -139,8 +151,9 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface { Status CreateTable(gpu::TableWrapperInitOptions& options, nv::merlin::BaseAllocator* allocator, - gpu::TableWrapper<K, V>** pptable) { - return gpu::CreateTableImpl(pptable, options, allocator, runtime_dim_); + gpu::TableWrapper<K, V>** pptable, int strategy) { + return gpu::CreateTableImpl(pptable, options, allocator, runtime_dim_, + strategy); } size_t size() const override { @@ -241,8 +254,29 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface { { mutex_lock l(mu_); try { - table_->upsert((const K*)keys.tensor_data().data(), - (const V*)(values.tensor_data().data()), len, stream); + table_->upsert((const K*)(keys.tensor_data().data()), + (const V*)(values.tensor_data().data()), nullptr, len, + stream); } catch (std::runtime_error& e) { return gpu::ReturnInternalErrorStatus(e.what()); } } CUDA_CHECK(cudaStreamSynchronize(stream)); + + return TFOkStatus; + } + + Status Insert(OpKernelContext* ctx, const Tensor& keys, const Tensor& values, + const Tensor& scores) { + size_t len = keys.flat<K>().size(); + auto stream = ctx->eigen_device<GPUDevice>().stream(); + { + mutex_lock l(mu_); + try { + table_->upsert((const K*)(keys.tensor_data().data()), + (const V*)(values.tensor_data().data()), + (const uint64_t*)(scores.tensor_data().data()), len, + stream); } catch (std::runtime_error& e) { return gpu::ReturnInternalErrorStatus(e.what()); } } CUDA_CHECK(cudaStreamSynchronize(stream)); @@ -259,9 +293,32 @@ { mutex_lock l(mu_); try { - table_->accum((const K*)keys.tensor_data().data(), + table_->accum((const K*)(keys.tensor_data().data()), + (const
V*)(values_or_deltas.tensor_data().data()), + (const bool*)(exists.tensor_data().data()), nullptr, len, + stream); } catch (std::runtime_error& e) { return gpu::ReturnInternalErrorStatus(e.what()); } } CUDA_CHECK(cudaStreamSynchronize(stream)); + + return TFOkStatus; + } + + Status Accum(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values_or_deltas, const Tensor& exists, + const Tensor& scores) { + size_t len = keys.flat<K>().size(); + auto stream = ctx->eigen_device<GPUDevice>().stream(); + { + mutex_lock l(mu_); + try { + table_->accum((const K*)(keys.tensor_data().data()), (const V*)(values_or_deltas.tensor_data().data()), - (const bool*)exists.tensor_data().data(), len, stream); + (const bool*)(exists.tensor_data().data()), + (const uint64_t*)(scores.tensor_data().data()), len, + stream); } catch (std::runtime_error& e) { return gpu::ReturnInternalErrorStatus(e.what()); } } CUDA_CHECK(cudaStreamSynchronize(stream)); @@ -345,7 +402,8 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface { mutex_lock l(mu_); try { table_->clear(stream); - table_->upsert((const K*)d_keys, (const V*)d_values, len, stream); + table_->upsert((const K*)d_keys, (const V*)d_values, nullptr, len, + stream); CUDA_CHECK(cudaStreamSynchronize(stream)); } catch (std::runtime_error& e) { return gpu::ReturnInternalErrorStatus(e.what()); } @@ -628,10 +686,6 @@ class HashTableFindGpuOp : public OpKernel { } }; -// REGISTER_KERNEL_BUILDER( -// Name(PREFIX_OP_NAME(HkvHashTableFind)).Device(DEVICE_GPU), -// HashTableFindGpuOp); - // Table lookup op. Perform the lookup operation on the given table. template <class K, class V> class HashTableFindWithExistsGpuOp : public OpKernel { @@ -676,6 +730,7 @@ class HashTableFindWithExistsGpuOp : public OpKernel { }; // Table insert op. +template <class K, class V> class HashTableInsertGpuOp : public OpKernel { public: explicit HashTableInsertGpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -684,23 +739,27 @@ lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); + lookup::HkvHashTableOfTensorsGpu<K, V>* table_hkv = + (lookup::HkvHashTableOfTensorsGpu<K, V>*)table; DataType expected_input_0 = DT_RESOURCE; DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), - table->value_dtype()}; + table->value_dtype(), table->key_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); const Tensor& keys = ctx->input(1); const Tensor& values = ctx->input(2); + const Tensor& scores = ctx->input(3); + OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForInsert(keys, values)); - OP_REQUIRES_OK(ctx, table->Insert(ctx, keys, values)); + if (scores.NumElements() == 0) { + OP_REQUIRES_OK(ctx, table_hkv->Insert(ctx, keys, values)); + } else { + OP_REQUIRES_OK(ctx, table_hkv->Insert(ctx, keys, values, scores)); + } } }; -REGISTER_KERNEL_BUILDER( - Name(PREFIX_OP_NAME(HkvHashTableInsert)).Device(DEVICE_GPU), - HashTableInsertGpuOp); -
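The two `Insert` overloads above are selected in `HashTableInsertGpuOp::Compute` by whether the `scores` input is empty. A hedged sketch of the two calling conventions at the tensor level (values illustrative; only the shapes and dtypes matter):

import tensorflow as tf

keys = tf.constant([1, 2, 3], tf.int64)
values = tf.zeros([3, 8], tf.float32)

# Empty scores tensor: the kernel takes the score-less overload and passes
# nullptr to table_->upsert, so HKV derives scores from the evict strategy.
empty_scores = tf.constant([], tf.int64)

# One score per key: the kernel forwards them as the uint64_t* d_scores.
explicit_scores = tf.constant([10, 20, 30], tf.int64)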
// Table accum op. template <class K, class V> class HashTableAccumGpuOp : public OpKernel { @@ -715,22 +774,28 @@ class HashTableAccumGpuOp : public OpKernel { (lookup::HkvHashTableOfTensorsGpu<K, V>*)table; DataType expected_input_0 = DT_RESOURCE; - DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), - table->value_dtype(), - DataTypeToEnum<bool>::v()}; + DataTypeVector expected_inputs = { + expected_input_0, table->key_dtype(), table->value_dtype(), + DataTypeToEnum<bool>::v(), table->key_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); const Tensor& keys = ctx->input(1); const Tensor& values_or_deltas = ctx->input(2); const Tensor& exists = ctx->input(3); + const Tensor& scores = ctx->input(4); OP_REQUIRES_OK( ctx, table->CheckKeyAndValueTensorsForInsert(keys, values_or_deltas)); - OP_REQUIRES_OK(ctx, table_hkv->Accum(ctx, keys, values_or_deltas, exists)); + if (scores.NumElements() == 0) { + OP_REQUIRES_OK(ctx, + table_hkv->Accum(ctx, keys, values_or_deltas, exists)); + } else { + OP_REQUIRES_OK( + ctx, table_hkv->Accum(ctx, keys, values_or_deltas, exists, scores)); + } } }; // Table remove op. -// template <class K, class V> class HashTableRemoveGpuOp : public OpKernel { public: explicit HashTableRemoveGpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -795,11 +860,8 @@ class HashTableSizeGpuOp : public OpKernel { } }; -// REGISTER_KERNEL_BUILDER( -// Name(PREFIX_OP_NAME(HkvHashTableSize)).Device(DEVICE_GPU), -// HashTableSizeGpuOp); - // Op that outputs tensors of all keys and all values. +template <class K, class V> class HashTableExportGpuOp : public OpKernel { public: explicit HashTableExportGpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -808,15 +870,13 @@ lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); + lookup::HkvHashTableOfTensorsGpu<K, V>* table_hkv = + (lookup::HkvHashTableOfTensorsGpu<K, V>*)table; - OP_REQUIRES_OK(ctx, table->ExportValues(ctx)); + OP_REQUIRES_OK(ctx, table_hkv->ExportValues(ctx)); } }; -REGISTER_KERNEL_BUILDER( - Name(PREFIX_OP_NAME(HkvHashTableExport)).Device(DEVICE_GPU), - HashTableExportGpuOp); - // Op that export all keys and values to file. template <class K, class V> class HashTableExportWithScoresGpuOp : public OpKernel { @@ -857,6 +917,7 @@ class HashTableExportKeysAndScoresGpuOp : public OpKernel { }; // Clear the table and insert data. +template <class K, class V> class HashTableImportGpuOp : public OpKernel { public: explicit HashTableImportGpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -865,6 +926,8 @@ lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); + lookup::HkvHashTableOfTensorsGpu<K, V>* table_hkv = + (lookup::HkvHashTableOfTensorsGpu<K, V>*)table; DataType expected_input_0 = DT_RESOURCE; DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), @@ -874,14 +937,10 @@ const Tensor& keys = ctx->input(1); const Tensor& values = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values)); - OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); + OP_REQUIRES_OK(ctx, table_hkv->ImportValues(ctx, keys, values)); } }; -REGISTER_KERNEL_BUILDER( - Name(PREFIX_OP_NAME(HkvHashTableImport)).Device(DEVICE_GPU), - HashTableImportGpuOp); - // Op that export all keys and values to FileSystem.
template <class K, class V> class HashTableSaveToFileSystemGpuOp : public OpKernel { @@ -989,8 +1048,13 @@ class HashTableLoadFromFileSystemGpuOp : public OpKernel { size_t buffer_size_; }; +#define CONCAT_QUADRA_STRING(X, Y, Z, S) (#X #Y #Z #S) + +#define PREFIX_OP_NAME_X_IMPL(N, S) CONCAT_QUADRA_STRING(TFRA, >, N, S) +#define PREFIX_OP_NAME_X(N, ...) PREFIX_OP_NAME_X_IMPL(N, __VA_ARGS__) + // Register the HkvHashTableOfTensors op. -#define REGISTER_KERNEL(key_dtype, value_dtype) \ +#define REGISTER_HKV_TABLE(key_dtype, value_dtype) \ REGISTER_KERNEL_BUILDER( \ Name(PREFIX_OP_NAME(HkvHashTableOfTensors)) \ .Device(DEVICE_GPU) \ .TypeConstraint<key_dtype>("key_dtype") \ @@ -1008,6 +1072,21 @@ class HashTableLoadFromFileSystemGpuOp : public OpKernel { .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ HashTableSizeGpuOp); \ + REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(HkvHashTableInsert)) \ + .Device(DEVICE_GPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + HashTableInsertGpuOp<key_dtype, value_dtype>); \ + REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(HkvHashTableExport)) \ + .Device(DEVICE_GPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + HashTableExportGpuOp<key_dtype, value_dtype>); \ + REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(HkvHashTableImport)) \ + .Device(DEVICE_GPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + HashTableImportGpuOp<key_dtype, value_dtype>); \ REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(HkvHashTableAccum)) \ .Device(DEVICE_GPU) \ .TypeConstraint<key_dtype>("key_dtype") \ @@ -1021,14 +1100,14 @@ class HashTableLoadFromFileSystemGpuOp : public OpKernel { HashTableExportWithScoresGpuOp<key_dtype, value_dtype>); \ REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(HkvHashTableFind)) \ .Device(DEVICE_GPU) \ - .TypeConstraint<key_dtype>("Tin") \ - .TypeConstraint<value_dtype>("Tout"), \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ HashTableFindGpuOp<key_dtype, value_dtype>); \ REGISTER_KERNEL_BUILDER( \ Name(PREFIX_OP_NAME(HkvHashTableFindWithExists)) \ .Device(DEVICE_GPU) \ - .TypeConstraint<key_dtype>("Tin") \ - .TypeConstraint<value_dtype>("Tout"), \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ HashTableFindWithExistsGpuOp<key_dtype, value_dtype>); \ REGISTER_KERNEL_BUILDER( \ Name(PREFIX_OP_NAME(HkvHashTableSaveToFileSystem)) \ @@ -1041,27 +1120,22 @@ class HashTableLoadFromFileSystemGpuOp : public OpKernel { .Device(DEVICE_GPU) \ .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ - HashTableLoadFromFileSystemGpuOp<key_dtype, value_dtype>); -REGISTER_KERNEL(int64, float); -REGISTER_KERNEL(int64, int8); -REGISTER_KERNEL(int64, int32); -REGISTER_KERNEL(int64, int64); -REGISTER_KERNEL(int64, Eigen::half); -REGISTER_KERNEL(int64, Eigen::bfloat16); - -#undef REGISTER_KERNEL - -#define SINGLE_ATTR_REGISTER_KERNEL(key_dtype, value_type) \ - REGISTER_KERNEL_BUILDER( \ - Name(PREFIX_OP_NAME(HkvHashTableExportKeysAndScores)) \ - .Device(DEVICE_GPU) \ - .TypeConstraint<key_dtype>("Tkeys"), \ - HashTableExportKeysAndScoresGpuOp<key_dtype>); + HashTableLoadFromFileSystemGpuOp<key_dtype, value_dtype>); \ + REGISTER_KERNEL_BUILDER( \ + Name(PREFIX_OP_NAME(HkvHashTableExportKeysAndScores)) \ + .Device(DEVICE_GPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + HashTableExportKeysAndScoresGpuOp<key_dtype>); -SINGLE_ATTR_REGISTER_KERNEL(int64, float); +REGISTER_HKV_TABLE(int64, float); +REGISTER_HKV_TABLE(int64, int8); +REGISTER_HKV_TABLE(int64, int32); +REGISTER_HKV_TABLE(int64, int64); +REGISTER_HKV_TABLE(int64, Eigen::half); +REGISTER_HKV_TABLE(int64, bfloat16); -#undef SINGLE_ATTR_REGISTER_KERNEL +#undef REGISTER_HKV_TABLE } // namespace recommenders_addons } // namespace tensorflow
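The consolidated `REGISTER_HKV_TABLE` macro instantiates every HKV kernel for int64 keys against six value dtypes. A sketch of exercising that matrix from Python (variable names illustrative; needs a GPU build of this branch):

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

# Mirrors the REGISTER_HKV_TABLE(int64, ...) instantiations above.
for i, value_dtype in enumerate(
    [tf.float32, tf.int8, tf.int32, tf.int64, tf.half, tf.bfloat16]):
  de.get_variable("hkv_%d" % i,
                  key_dtype=tf.int64,
                  value_dtype=value_dtype,
                  initializer=0,
                  dim=8,
                  init_size=1024,
                  kv_creator=de.HkvHashTableCreator())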
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h index f5a746199..81564e05d 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h @@ -57,6 +57,7 @@ inline Status ReturnInternalErrorStatus(const char* const str) { return Status(tensorflow::error::INTERNAL, str); #endif } +using HkvEvictStrategy = nv::merlin::EvictStrategy; template <class K, class V> class KVOnlyFile : public nv::merlin::BaseKVFile<K, V, uint64_t> { @@ -299,6 +300,8 @@ struct TableWrapperInitOptions { size_t init_capacity; size_t max_hbm_for_vectors; size_t max_bucket_size; + size_t evict_global_epoch; + float max_load_factor; int block_size; int io_block_size; @@ -416,15 +419,14 @@ class TFOrDefaultAllocator : public nv::merlin::BaseAllocator { template <class K, class V> class TableWrapper { private: - // using S = uint64_t; - using Table = nv::merlin::HashTable<K, V, uint64_t>; + using Table = nv::merlin::HashTableBase<K, V, uint64_t>; nv::merlin::HashTableOptions mkv_options_; public: - TableWrapper(TableWrapperInitOptions& init_options, size_t dim) { + TableWrapper(TableWrapperInitOptions& init_options, size_t dim, + int strategy) { max_capacity_ = init_options.max_capacity; dim_ = dim; - // nv::merlin::HashTableOptions mkv_options_; mkv_options_.init_capacity = std::min(init_options.init_capacity, max_capacity_); mkv_options_.max_capacity = max_capacity_; @@ -436,11 +438,34 @@ class TableWrapper { mkv_options_.max_load_factor = 0.5; mkv_options_.block_size = nv::merlin::SAFE_GET_BLOCK_SIZE(128); mkv_options_.dim = dim; - // mkv_options_.evict_strategy = nv::merlin::EvictStrategy::kCustomized; - mkv_options_.evict_strategy = nv::merlin::EvictStrategy::kLru; block_size_ = mkv_options_.block_size; - table_ = new Table(); + switch (strategy) { + case HkvEvictStrategy::kLfu: + table_ = + new nv::merlin::HashTable<K, V, uint64_t, HkvEvictStrategy::kLfu>(); + break; + case HkvEvictStrategy::kEpochLru: + table_ = new nv::merlin::HashTable<K, V, uint64_t, + HkvEvictStrategy::kEpochLru>(); + break; + case HkvEvictStrategy::kEpochLfu: + table_ = new nv::merlin::HashTable<K, V, uint64_t, + HkvEvictStrategy::kEpochLfu>(); + break; + case HkvEvictStrategy::kCustomized: + table_ = new nv::merlin::HashTable<K, V, uint64_t, + HkvEvictStrategy::kCustomized>(); + break; + default: + table_ = + new nv::merlin::HashTable<K, V, uint64_t, HkvEvictStrategy::kLru>(); + break; + } + table_->set_global_epoch(init_options.evict_global_epoch); + LOG(INFO) << "Use Evict Strategy:" << strategy + << ", [0:LRU, 1:LFU, 2:EPOCHLRU, 3:EPOCHLFU, 4:CUSTOMIZED]"; + LOG(INFO) << "Use Evict Global Epoch:" << init_options.evict_global_epoch; } Status init(nv::merlin::BaseAllocator* allocator) { @@ -454,20 +479,20 @@ class TableWrapper { ~TableWrapper() { delete table_; } - void upsert(const K* d_keys, const V* d_vals, size_t len, - cudaStream_t stream) { + void upsert(const K* d_keys, const V* d_vals, const uint64_t* d_scores, + size_t len, cudaStream_t stream) { uint64_t t0 = (uint64_t)time(NULL); size_t grid_size = nv::merlin::SAFE_GET_GRID_SIZE(len, block_size_); - table_->insert_or_assign(len, d_keys, d_vals, /*d_scores=*/nullptr, stream); + table_->insert_or_assign(len, d_keys, d_vals, d_scores, stream); CUDA_CHECK(cudaStreamSynchronize(stream)); } void accum(const K* d_keys, const V* d_vals_or_deltas, const bool* d_exists, - size_t len, cudaStream_t stream) { + const uint64_t* d_scores, size_t len, cudaStream_t stream) { uint64_t t0 = (uint64_t)time(NULL); size_t grid_size = nv::merlin::SAFE_GET_GRID_SIZE(len, block_size_); - table_->accum_or_assign(len,
d_keys, d_vals_or_deltas, d_exists, - /*d_scores=*/nullptr, stream); + table_->accum_or_assign(len, d_keys, d_vals_or_deltas, d_exists, d_scores, + stream); CUDA_CHECK(cudaStreamSynchronize(stream)); } @@ -679,9 +704,9 @@ class TableWrapper { template <class K, class V> Status CreateTableImpl(TableWrapper<K, V>** pptable, TableWrapperInitOptions& options, - nv::merlin::BaseAllocator* allocator, - size_t runtime_dim) { - *pptable = new TableWrapper<K, V>(options, runtime_dim); + nv::merlin::BaseAllocator* allocator, size_t runtime_dim, + int strategy) { + *pptable = new TableWrapper<K, V>(options, runtime_dim, strategy); return (*pptable)->init(allocator); }
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc deleted file mode 100644 index 8f50afd6e..000000000 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h" - -namespace tensorflow { -namespace recommenders_addons { -namespace lookup { -namespace gpu { - -#define DEFINE_PURE_GPU_HASHTABLE(key_type, value_type) \ - template <> \ - class TableWrapper<key_type, value_type> -DEFINE_PURE_GPU_HASHTABLE(int64, float); -DEFINE_PURE_GPU_HASHTABLE(int64, int8); -DEFINE_PURE_GPU_HASHTABLE(int64, int32); -DEFINE_PURE_GPU_HASHTABLE(int64, int64); -DEFINE_PURE_GPU_HASHTABLE(int64, Eigen::half); -DEFINE_PURE_GPU_HASHTABLE(int64, Eigen::bfloat16); - -#undef DEFINE_PURE_GPU_HASHTABLE - -} // namespace gpu -} // namespace lookup -} // namespace recommenders_addons -} // namespace tensorflow
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/hkv_hashtable_ops.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/hkv_hashtable_ops.cc index 8a92eb0ef..45e111d62 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/hkv_hashtable_ops.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/hkv_hashtable_ops.cc @@ -131,61 +131,70 @@ Status HkvHashTableShape(InferenceContext* c, const ShapeHandle& key, return TFOkStatus; } -REGISTER_OP(PREFIX_OP_NAME(HkvHashTableFind)) +REGISTER_OP(PREFIX_OP_NAME(HkvHashTableRemove)) .Input("table_handle: resource") - .Input("keys: Tin") - .Input("default_value: Tout") - .Output("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") + .Input("keys: key_dtype") + .Attr("key_dtype: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle)); + return TFOkStatus; + });
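The `strategy` integer that `CreateTableImpl` threads into `TableWrapper`'s switch above comes straight from the Python enum added later in this diff. A short sketch of the mapping:

from tensorflow_recommenders_addons import dynamic_embedding as de

# TableWrapper's switch selects the HKV template instantiation:
#   0 -> HashTable<K, V, uint64_t, kLru>   (default branch)
#   1 -> HashTable<K, V, uint64_t, kLfu>
#   2 -> HashTable<K, V, uint64_t, kEpochLru>
#   3 -> HashTable<K, V, uint64_t, kEpochLfu>
#   4 -> HashTable<K, V, uint64_t, kCustomized>
strategy_attr = int(de.HkvEvictStrategy.EPOCHLFU)  # == 3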
+#define CONCAT_QUADRA_STRING(X, Y, Z, S) (#X #Y #Z #S) + +#define PREFIX_OP_NAME_X_IMPL(N, S) CONCAT_QUADRA_STRING(TFRA, >, N, S) +#define PREFIX_OP_NAME_X(N, ...) PREFIX_OP_NAME_X_IMPL(N, __VA_ARGS__) + +REGISTER_OP(PREFIX_OP_NAME(HkvHashTableFind)) + .Input("table_handle: resource") + .Input("keys: key_dtype") + .Input("default_value: value_dtype") + .Output("values: value_dtype") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); ShapeAndType value_shape_and_type; TF_RETURN_IF_ERROR(ValidateTableResourceHandle( - c, - /*keys=*/c->input(1), - /*key_dtype_attr=*/"Tin", - /*value_dtype_attr=*/"Tout", - /*is_lookup=*/true, &value_shape_and_type)); + c, /*keys=*/c->input(1), /*key_dtype_attr=*/"key_dtype", + /*value_dtype_attr=*/"value_dtype", /*is_lookup=*/true, + &value_shape_and_type)); c->set_output(0, value_shape_and_type.shape); return TFOkStatus; }); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableFindWithExists)) .Input("table_handle: resource") - .Input("keys: Tin") - .Input("default_value: Tout") - .Output("values: Tout") + .Input("keys: key_dtype") + .Input("default_value: value_dtype") + .Output("values: value_dtype") .Output("exists: bool") - .Attr("Tin: type") - .Attr("Tout: type") + .Attr("key_dtype: type") + .Attr("value_dtype: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); ShapeHandle keys = c->UnknownShapeOfRank(1); ShapeAndType value_shape_and_type; TF_RETURN_IF_ERROR(ValidateTableResourceHandle( - c, - /*keys=*/c->input(1), - /*key_dtype_attr=*/"Tin", - /*value_dtype_attr=*/"Tout", - /*is_lookup=*/true, &value_shape_and_type)); + c, /*keys=*/c->input(1), /*key_dtype_attr=*/"key_dtype", + /*value_dtype_attr=*/"value_dtype", /*is_lookup=*/true, + &value_shape_and_type)); c->set_output(0, value_shape_and_type.shape); c->set_output(1, keys); return TFOkStatus; }); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableInsert)) .Input("table_handle: resource") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") + .Input("keys: key_dtype") + .Input("values: value_dtype") + .Input("scores: int64") + .Attr("key_dtype: type") + .Attr("value_dtype: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); @@ -193,12 +202,12 @@ REGISTER_OP(PREFIX_OP_NAME(HkvHashTableInsert)) // TODO: Validate keys and values shape. return TFOkStatus; }); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableAccum)) .Input("table_handle: resource") .Input("keys: key_dtype") .Input("values_or_deltas: value_dtype") .Input("exists: bool") + .Input("scores: int64") .Attr("key_dtype: type") .Attr("value_dtype: type") .SetShapeFn([](InferenceContext* c) { @@ -209,53 +218,35 @@ return TFOkStatus; }); -REGISTER_OP(PREFIX_OP_NAME(HkvHashTableRemove)) - .Input("table_handle: resource") - .Input("keys: Tin") - .Attr("Tin: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle)); - - // TODO(turboale): Validate keys shape.
- return TFOkStatus; - }); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableClear)) .Input("table_handle: resource") .Attr("key_dtype: type") .Attr("value_dtype: type"); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableSize)) .Input("table_handle: resource") .Output("size: int64") .Attr("key_dtype: type") .Attr("value_dtype: type") .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableExport)) .Input("table_handle: resource") - .Output("keys: Tkeys") - .Output("values: Tvalues") - .Attr("Tkeys: type") - .Attr("Tvalues: type") + .Output("keys: key_dtype") + .Output("values: value_dtype") + .Attr("key_dtype: type") + .Attr("value_dtype: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); ShapeHandle keys = c->UnknownShapeOfRank(1); ShapeAndType value_shape_and_type; TF_RETURN_IF_ERROR(ValidateTableResourceHandle( - c, - /*keys=*/keys, - /*key_dtype_attr=*/"Tkeys", - /*value_dtype_attr=*/"Tvalues", - /*is_lookup=*/false, &value_shape_and_type)); + c, /*keys=*/keys, /*key_dtype_attr=*/"key_dtype", + /*value_dtype_attr=*/"value_dtype", /*is_lookup=*/false, + &value_shape_and_type)); c->set_output(0, keys); c->set_output(1, value_shape_and_type.shape); return TFOkStatus; }); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableSaveToFileSystem)) .Input("table_handle: resource") .Input("dirpath: string") @@ -265,12 +256,12 @@ REGISTER_OP(PREFIX_OP_NAME(HkvHashTableSaveToFileSystem)) .Attr("dirpath_env: string") .Attr("append_to_file: bool") .Attr("buffer_size: int >= 1"); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableExportKeysAndScores)) .Input("table_handle: resource") - .Output("keys: Tkeys") + .Output("keys: key_dtype") .Output("scores: int64") - .Attr("Tkeys: type") + .Attr("key_dtype: type") + .Attr("value_dtype: type") .Attr("split_size: int") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; @@ -281,23 +272,20 @@ REGISTER_OP(PREFIX_OP_NAME(HkvHashTableExportKeysAndScores)) c->set_output(1, scores); return TFOkStatus; }); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableImport)) .Input("table_handle: resource") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") + .Input("keys: key_dtype") + .Input("values: value_dtype") + .Attr("key_dtype: type") + .Attr("value_dtype: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - ShapeHandle keys; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); return TFOkStatus; }); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableLoadFromFileSystem)) .Input("table_handle: resource") .Input("dirpath: string") @@ -307,7 +295,6 @@ REGISTER_OP(PREFIX_OP_NAME(HkvHashTableLoadFromFileSystem)) .Attr("dirpath_env: string") .Attr("load_entire_dir: bool") .Attr("buffer_size: int >= 1"); - REGISTER_OP(PREFIX_OP_NAME(HkvHashTableOfTensors)) .Output("table_handle: resource") .Attr("container: string = ''") @@ -319,6 +306,8 @@ REGISTER_OP(PREFIX_OP_NAME(HkvHashTableOfTensors)) .Attr("init_capacity: int = 0") .Attr("max_capacity: int = 0") .Attr("max_hbm_for_vectors: int = 0") + .Attr("evict_global_epoch: int = 0") + .Attr("strategy: int = 0") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { PartialTensorShape value_p; @@ -327,4 +316,5 @@ REGISTER_OP(PREFIX_OP_NAME(HkvHashTableOfTensors)) TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s)); return HkvHashTableShape(c, /*key=*/c->Scalar(), 
/*value=*/value_s); }); + } // namespace tensorflow diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_evict_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_evict_test.py new file mode 100644 index 000000000..2b130adaf --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_evict_test.py @@ -0,0 +1,140 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""unit tests of hkv hashtable ops +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import os +import itertools +import numpy as np + +from tensorflow_recommenders_addons import dynamic_embedding as de +from tensorflow_recommenders_addons.utils.check_platform import is_windows, is_macos, is_arm64, is_linux, is_raspi_arm + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib +from tensorflow.python.client import session + +import tensorflow as tf +try: + import tensorflow_io +except: + print() + + +def _type_converter(tf_type): + mapper = { + dtypes.int32: np.int32, + dtypes.int64: np.int64, + dtypes.float32: float, + dtypes.float64: np.float64, + dtypes.string: str, + dtypes.half: np.float16, + dtypes.int8: np.int8, + dtypes.bool: bool, + } + return mapper[tf_type] + + +def _convert(v, t): + return np.array(v).astype(_type_converter(t)) + + +default_config = config_pb2.ConfigProto( + allow_soft_placement=False, + gpu_options=config_pb2.GPUOptions(allow_growth=True)) + + +def _get_devices(): + return ["/gpu:0" if test_util.is_gpu_available() else "/cpu:0"] + + +is_gpu_available = test_util.is_gpu_available() + + +def convert(v, t): + return np.array(v).astype(_type_converter(t)) + + +def gen_scores_fn(keys): + return tf.constant([1, 2, 3, 4], dtypes.int64) + + +class HkvHashtableTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_evict_strategy(self): + if not is_gpu_available: + self.skipTest('Only test when gpu is available.') + strategy_i = 0 + for strategy in de.HkvEvictStrategy: + with self.session(use_gpu=True, config=default_config): + with self.captureWritesToStream(sys.stderr) as printed: + table = de.get_variable( + str(strategy), + key_dtype=dtypes.int64, + value_dtype=dtypes.int32, + initializer=0, + dim=8, + init_size=1024, + kv_creator=de.HkvHashTableCreator( + config=de.HkvHashTableConfig(init_capacity=1024, + 
max_capacity=1024, + max_hbm_for_values=1024 * 4 * 8 * + 2, + evict_strategy=strategy, + gen_scores_fn=gen_scores_fn))) + self.evaluate(table.size()) + + content = "Use Evict Strategy:" + str(strategy_i) + # self.assertTrue(content in printed.contents()) + strategy_i = strategy_i + 1 + + key_dtype = dtypes.int64 + value_dtype = dtypes.int32 + dim = 8 + + keys = constant_op.constant( + np.array([0, 1, 2, 3]).astype(_type_converter(key_dtype)), + key_dtype) + values = constant_op.constant( + _convert([[0] * dim, [1] * dim, [2] * dim, [3] * dim], + value_dtype), value_dtype) + + self.evaluate(table.upsert(keys, values)) + + output = table.lookup(keys) + self.assertAllEqual(values, self.evaluate(output)) + + # exported_keys, exported_scores = self.evaluate(table.export_keys_and_scores()) + # print(exported_keys) + # print(exported_scores) + + del table + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_ops_test.py index 5efa574d2..83e5e1822 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_ops_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_ops_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py index e9d3256d2..c1be9234c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py @@ -94,6 +94,8 @@ def test_all_to_all_embedding_trainable(self): if (tf.__version__ == "2.11.0" or tf.__version__ == "2.11.1"): self.skipTest( "The save function doesn't work with TF 2.11, skip the test.") + if not is_gpu_available: + self.skipTest('Only test when gpu is available.') if (is_macos() and is_arm64()): self.skipTest( "Apple silicon devices don't support synchronous training based on Horovod." 
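A condensed, hedged version of the CUSTOMIZED path the new evict-strategy test drives (dtypes and capacities taken from the test; requires a GPU):

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow_recommenders_addons import dynamic_embedding as de

def gen_scores_fn(keys):
  # One int64 score per key; with CUSTOMIZED, HKV uses these values verbatim
  # when deciding which entries to evict.
  return tf.cast(keys, dtypes.int64) + 1

table = de.get_variable(
    "hkv_customized",
    key_dtype=dtypes.int64,
    value_dtype=dtypes.int32,
    initializer=0,
    dim=8,
    init_size=1024,
    kv_creator=de.HkvHashTableCreator(config=de.HkvHashTableConfig(
        init_capacity=1024,
        max_capacity=1024,
        max_hbm_for_values=1024 * 4 * 8 * 2,
        evict_strategy=de.HkvEvictStrategy.CUSTOMIZED,
        gen_scores_fn=gen_scores_fn)))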
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py index f634c9fea..5dec12c63 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py @@ -18,6 +18,7 @@ import sys import copy import functools +import tensorflow as tf from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -107,6 +108,7 @@ def __init__( self._max_hbm_for_values = sys.maxsize self._device_type = tf_device.DeviceSpec.from_string( self._device).device_type + self._default_scores = tf.constant([], dtypes.int64) self._shared_name = None if context.executing_eagerly(): @@ -367,7 +369,8 @@ def insert(self, keys, values, name=None): # pylint: disable=protected-access if self._device_type == "GPU": return hkv_ops.tfra_hkv_hash_table_insert(self.resource_handle, keys, - values) + values, + self._default_scores) else: return cuckoo_ops.tfra_cuckoo_hash_table_insert( self.resource_handle, keys, values) @@ -405,7 +408,8 @@ def accum(self, keys, values_or_deltas, exists, name=None): # pylint: disable=protected-access if self._device_type == "GPU": return hkv_ops.tfra_hkv_hash_table_accum(self.resource_handle, keys, - values_or_deltas, exists) + values_or_deltas, exists, + self._default_scores) else: return cuckoo_ops.tfra_cuckoo_hash_table_accum( self.resource_handle, keys, values_or_deltas, exists) @@ -424,6 +428,7 @@ def export(self, name=None): [self.resource_handle]): with ops.colocate_with(self.resource_handle): if self._device_type == "GPU": + keys, values = hkv_ops.tfra_hkv_hash_table_export( self.resource_handle, self._key_dtype, self._value_dtype) else: diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py index c8379b7e4..a01a6f58a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py @@ -15,6 +15,7 @@ # lint-as: python3 from abc import ABCMeta +from enum import IntEnum, unique from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -136,17 +137,34 @@ def get_config(self): return config +@unique +class HkvEvictStrategy(IntEnum): + LRU = 0 + LFU = 1 + EPOCHLRU = 2 + EPOCHLFU = 3 + CUSTOMIZED = 4 + + class HkvHashTableConfig(object): - def __init__(self, - init_capacity=KHkvHashTableInitCapacity, - max_capacity=KHkvHashTableMaxCapacity, - max_hbm_for_values=KHkvHashTableMaxHbmForValuesByBytes): + def __init__( + self, + init_capacity=KHkvHashTableInitCapacity, + max_capacity=KHkvHashTableMaxCapacity, + max_hbm_for_values=KHkvHashTableMaxHbmForValuesByBytes, + evict_strategy=HkvEvictStrategy.LRU, + evict_global_epoch=0, + gen_scores_fn=None, + ): """ CuckooHashTableConfig include nothing for parameter default satisfied. 
""" self.init_capacity = init_capacity self.max_capacity = max_capacity self.max_hbm_for_values = max_hbm_for_values + self.evict_strategy = evict_strategy + self.evict_global_epoch = evict_global_epoch + self.gen_scores_fn = gen_scores_fn class HkvHashTableCreator(KVCreator): @@ -171,10 +189,16 @@ def create( self.init_capacity = init_size self.max_capacity = KHkvHashTableMaxCapacity self.max_hbm_for_values = KHkvHashTableMaxHbmForValuesByBytes + self.evict_strategy = HkvEvictStrategy.LRU + self.evict_global_epoch = 0 + self.gen_scores_fn = None if self.config and isinstance(self.config, de.HkvHashTableConfig): self.init_capacity = self.config.init_capacity self.max_capacity = self.config.max_capacity self.max_hbm_for_values = self.config.max_hbm_for_values + self.evict_strategy = self.config.evict_strategy + self.evict_global_epoch = self.config.evict_global_epoch + self.gen_scores_fn = self.config.gen_scores_fn self.device = device self.shard_saveable_object_fn = shard_saveable_object_fn @@ -187,6 +211,9 @@ def create( init_capacity=self.init_capacity, max_capacity=self.max_capacity, max_hbm_for_values=self.max_hbm_for_values, + evict_strategy=self.evict_strategy, + evict_global_epoch=self.evict_global_epoch, + gen_scores_fn=self.gen_scores_fn, config=self.config, device=self.device, shard_saveable_object_fn=self.shard_saveable_object_fn) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/hkv_hashtable_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/hkv_hashtable_ops.py index efcee7b4d..ae8c7274f 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/hkv_hashtable_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/hkv_hashtable_ops.py @@ -20,6 +20,7 @@ import copy import functools +import tensorflow as tf from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -32,6 +33,8 @@ from tensorflow_recommenders_addons.utils.resource_loader import LazySO from tensorflow_recommenders_addons.utils.resource_loader import prefix_op_name +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_creator import HkvEvictStrategy + try: hkv_ops = LazySO("dynamic_embedding/core/_hkv_ops.so").ops except: @@ -73,6 +76,9 @@ def __init__( config=None, device='', shard_saveable_object_fn=None, + evict_strategy=HkvEvictStrategy.LRU, + evict_global_epoch=0, + gen_scores_fn=None, ): """Creates an empty `HkvHashTable` object. 
@@ -104,6 +110,7 @@ def __init__( self._checkpoint = checkpoint self._key_dtype = key_dtype self._value_dtype = value_dtype + self._scores_dtype = dtypes.int64 self._init_capacity = init_capacity self._max_capacity = max_capacity self._max_hbm_for_values = max_hbm_for_values @@ -113,11 +120,18 @@ def __init__( if not self._device or self._device == '': self._device = ['/GPU:0'] self._new_obj_trackable = None + self._evict_strategy = evict_strategy + self._evict_global_epoch = evict_global_epoch + self._gen_scores_fn = gen_scores_fn + self._default_scores = tf.constant([], dtypes.int64) if self._config: self._init_capacity = self._config.init_capacity self._max_capacity = self._config.max_capacity self._max_hbm_for_values = self._config.max_hbm_for_values + self._evict_strategy = self._config.evict_strategy + self._evict_global_epoch = self._config.evict_global_epoch + self._gen_scores_fn = self._config.gen_scores_fn self._shared_name = None if context.executing_eagerly(): @@ -166,6 +180,8 @@ def _create_resource(self): init_capacity=self._init_capacity, max_capacity=self._max_capacity, max_hbm_for_vectors=self._max_hbm_for_values, + strategy=self._evict_strategy.value, + evict_global_epoch=self._evict_global_epoch, name=self._name, ) @@ -330,14 +346,19 @@ def insert(self, keys, values, name=None): with ops.name_scope( name, "%s_lookup_table_insert" % self.name, - [self.resource_handle, keys, values], + [self.resource_handle, keys, values, keys], ): keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys") values = ops.convert_to_tensor(values, self._value_dtype, name="values") + scores = self._default_scores + if self._evict_strategy == HkvEvictStrategy.CUSTOMIZED: + assert self._gen_scores_fn != None, "You must set gen_scores_fn when set evict strategy to CUSTOMIZED" + scores = self._gen_scores_fn(keys) + elif self._evict_strategy == HkvEvictStrategy.LFU or self._evict_strategy == HkvEvictStrategy.EPOCHLFU: + scores = tf.ones(keys.shape, keys.dtype) with ops.colocate_with(self.resource_handle, ignore_existing=True): - # pylint: disable=protected-access op = hkv_ops.tfra_hkv_hash_table_insert(self.resource_handle, keys, - values) + values, scores) return op def accum(self, keys, values_or_deltas, exists, name=None): @@ -369,10 +390,15 @@ def accum(self, keys, values_or_deltas, exists, name=None): self._value_dtype, name="values_or_deltas") exists = ops.convert_to_tensor(exists, dtypes.bool, name="exists") + scores = self._default_scores + if self._evict_strategy == HkvEvictStrategy.CUSTOMIZED: + assert self._gen_scores_fn != None, "You must set gen_scores_fn when set evict strategy to CUSTOMIZED" + scores = self._gen_scores_fn(keys) + elif self._evict_strategy == HkvEvictStrategy.LFU or self._evict_strategy == HkvEvictStrategy.EPOCHLFU: + scores = tf.ones(keys.shape, keys.dtype) with ops.colocate_with(self.resource_handle, ignore_existing=True): - # pylint: disable=protected-access op = hkv_ops.tfra_hkv_hash_table_accum(self.resource_handle, keys, - values_or_deltas, exists) + values_or_deltas, exists, scores) return op def export(self, name=None): @@ -395,12 +421,16 @@ def export(self, name=None): def export_keys_and_scores(self, split_size, name=None): if not (split_size > 0 and isinstance(split_size, int)): raise ValueError(f'split_size must be positive integer.') + with ops.name_scope(name, "%s_lookup_table_export_keys_and_scores" % self.name, [self.resource_handle]): with ops.colocate_with(self.resource_handle): keys, scores = hkv_ops.tfra_hkv_hash_table_export_keys_and_scores( 
- self.resource_handle, Tkeys=self._key_dtype, split_size=split_size) + self.resource_handle, + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, + split_size=split_size) return keys, scores def save_to_file_system(self, diff --git a/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-multipython.Dockerfile b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-multipython.Dockerfile new file mode 100644 index 000000000..629faaa32 --- /dev/null +++ b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-multipython.Dockerfile @@ -0,0 +1,107 @@ +# Dockerfile to build a manylinux 2010 compliant cross-compiler. +# +# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible +# glibc (2.12) and system libstdc++ (4.4). +# +# To push a new version, run: +# $ docker build -f cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-multipython.Dockerfile . \ +# --tag "tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-multipython" +# $ docker push tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-multipython + +FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as devtoolset + +RUN chmod 777 /tmp/ +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && apt-get install -y \ + cpio \ + file \ + flex \ + g++ \ + make \ + patch \ + rpm2cpio \ + unar \ + wget \ + xz-utils \ + libjpeg-dev \ + zlib1g-dev \ + libgflags-dev \ + libsnappy-dev \ + libbz2-dev \ + liblz4-dev \ + libzstd-dev \ + openssh-client \ + && \ + rm -rf /var/lib/apt/lists/* + +ADD devtoolset/fixlinks.sh fixlinks.sh +ADD devtoolset/build_devtoolset.sh build_devtoolset.sh +ADD devtoolset/rpm-patch.sh rpm-patch.sh + +# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8. +RUN /build_devtoolset.sh devtoolset-8 /dt8 + +# TODO(klimek): Split up into two different docker images. +FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 +COPY --from=devtoolset /dt8 /dt8 + +# Install TensorRT. +RUN echo \ + deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 / \ + > /etc/apt/sources.list.d/nvidia-ml.list \ + && \ + apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 && \ + apt-get update && \ + rm -rf /var/lib/apt/lists/* + +# Copy and run the install scripts. 
+ARG DEBIAN_FRONTEND=noninteractive + +COPY install/install_bootstrap_deb_packages.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh + +COPY install/install_deb_packages.sh /install/ +RUN /install/install_deb_packages.sh + +# Install additional packages needed for this image: +# - dependencies to build Python from source +# - patchelf, as it is required by auditwheel +RUN apt-get update && apt-get install -y \ + libbz2-dev \ + libffi-dev \ + libgdbm-dev \ + libncurses5-dev \ + libnss3-dev \ + libreadline-dev \ + patchelf \ + gcc-multilib \ + && \ + rm -rf /var/lib/apt/lists/* + +RUN chmod 777 /tmp/ +WORKDIR /tmp/ + +COPY install/install_nccl.sh /install/ +RUN /install/install_nccl.sh "2.8.4-1+cuda11.2" + +COPY install/install_bazel.sh /install/ +RUN /install/install_bazel.sh "5.1.1" + +COPY install/build_and_install_python.sh /install/ +RUN /install/build_and_install_python.sh "3.7.7" +RUN /install/build_and_install_python.sh "3.8.2" +RUN /install/build_and_install_python.sh "3.9.7" + +COPY install/install_pip_packages_by_version.sh /install/ +RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.9" +RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.8" +RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.7" + +COPY install/use_devtoolset_8.sh /install/ +RUN /install/use_devtoolset_8.sh + +COPY install/install_openmpi.sh /install/ +RUN /install/install_openmpi.sh "4.1.1" + +# clean +RUN rm -rf /tmp/* diff --git a/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.10.Dockerfile b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.10.Dockerfile new file mode 100644 index 000000000..be8fee2c3 --- /dev/null +++ b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.10.Dockerfile @@ -0,0 +1,103 @@ +# Dockerfile to build a manylinux 2010 compliant cross-compiler. +# +# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible +# glibc (2.12) and system libstdc++ (4.4). +# +# To push a new version, run: +# $ docker build -f cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.10.Dockerfile . \ +# --tag "tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.10" +# $ docker push tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.10 + +FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as devtoolset + +RUN chmod 777 /tmp/ +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && apt-get install -y \ + cpio \ + file \ + flex \ + g++ \ + make \ + patch \ + rpm2cpio \ + unar \ + wget \ + xz-utils \ + libjpeg-dev \ + zlib1g-dev \ + libgflags-dev \ + libsnappy-dev \ + libbz2-dev \ + liblz4-dev \ + libzstd-dev \ + openssh-client \ + && \ + rm -rf /var/lib/apt/lists/* + +ADD devtoolset/fixlinks.sh fixlinks.sh +ADD devtoolset/build_devtoolset.sh build_devtoolset.sh +ADD devtoolset/rpm-patch.sh rpm-patch.sh + +# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8. +RUN /build_devtoolset.sh devtoolset-8 /dt8 + +# TODO(klimek): Split up into two different docker images. +FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 +COPY --from=devtoolset /dt8 /dt8 + +# Install TensorRT. +RUN echo \ + deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 / \ + > /etc/apt/sources.list.d/nvidia-ml.list \ + && \ + apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 && \ + apt-get update && \ + rm -rf /var/lib/apt/lists/* + +# Copy and run the install scripts. 
+ARG DEBIAN_FRONTEND=noninteractive + +COPY install/install_bootstrap_deb_packages.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh + +COPY install/install_deb_packages.sh /install/ +RUN /install/install_deb_packages.sh + +# Install additional packages needed for this image: +# - dependencies to build Python from source +# - patchelf, as it is required by auditwheel +RUN apt-get update && apt-get install -y \ + libbz2-dev \ + libffi-dev \ + libgdbm-dev \ + libncurses5-dev \ + libnss3-dev \ + libreadline-dev \ + patchelf \ + gcc-multilib \ + && \ + rm -rf /var/lib/apt/lists/* + +RUN chmod 777 /tmp/ +WORKDIR /tmp/ + +COPY install/install_nccl.sh /install/ +RUN /install/install_nccl.sh "2.8.4-1+cuda11.2" + +COPY install/install_bazel.sh /install/ +RUN /install/install_bazel.sh "5.1.1" + +COPY install/build_and_install_python.sh /install/ +RUN /install/build_and_install_python.sh "3.10.6" + +COPY install/install_pip_packages_by_version.sh /install/ +RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.10" + +COPY install/use_devtoolset_8.sh /install/ +RUN /install/use_devtoolset_8.sh + +COPY install/install_openmpi.sh /install/ +RUN /install/install_openmpi.sh "4.1.1" + +# clean +RUN rm -rf /tmp/* diff --git a/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.7.Dockerfile b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.7.Dockerfile new file mode 100644 index 000000000..84581ca09 --- /dev/null +++ b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.7.Dockerfile @@ -0,0 +1,103 @@ +# Dockerfile to build a manylinux 2010 compliant cross-compiler. +# +# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible +# glibc (2.12) and system libstdc++ (4.4). +# +# To push a new version, run: +# $ docker build -f cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.7.Dockerfile . \ +# --tag "tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.7" +# $ docker push tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.7 + +FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as devtoolset + +RUN chmod 777 /tmp/ +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && apt-get install -y \ + cpio \ + file \ + flex \ + g++ \ + make \ + patch \ + rpm2cpio \ + unar \ + wget \ + xz-utils \ + libjpeg-dev \ + zlib1g-dev \ + libgflags-dev \ + libsnappy-dev \ + libbz2-dev \ + liblz4-dev \ + libzstd-dev \ + openssh-client \ + && \ + rm -rf /var/lib/apt/lists/* + +ADD devtoolset/fixlinks.sh fixlinks.sh +ADD devtoolset/build_devtoolset.sh build_devtoolset.sh +ADD devtoolset/rpm-patch.sh rpm-patch.sh + +# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8. +RUN /build_devtoolset.sh devtoolset-8 /dt8 + +# TODO(klimek): Split up into two different docker images. +FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 +COPY --from=devtoolset /dt8 /dt8 + +# Install TensorRT. +RUN echo \ + deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 / \ + > /etc/apt/sources.list.d/nvidia-ml.list \ + && \ + apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 && \ + apt-get update && \ + rm -rf /var/lib/apt/lists/* + +# Copy and run the install scripts. 
+ARG DEBIAN_FRONTEND=noninteractive
+
+COPY install/install_bootstrap_deb_packages.sh /install/
+RUN /install/install_bootstrap_deb_packages.sh
+
+COPY install/install_deb_packages.sh /install/
+RUN /install/install_deb_packages.sh
+
+# Install additional packages needed for this image:
+# - dependencies to build Python from source
+# - patchelf, as it is required by auditwheel
+RUN apt-get update && apt-get install -y \
+    libbz2-dev \
+    libffi-dev \
+    libgdbm-dev \
+    libncurses5-dev \
+    libnss3-dev \
+    libreadline-dev \
+    patchelf \
+    gcc-multilib \
+    && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN chmod 777 /tmp/
+WORKDIR /tmp/
+
+COPY install/install_nccl.sh /install/
+RUN /install/install_nccl.sh "2.8.4-1+cuda11.2"
+
+COPY install/install_bazel.sh /install/
+RUN /install/install_bazel.sh "5.1.1"
+
+COPY install/build_and_install_python.sh /install/
+RUN /install/build_and_install_python.sh "3.7.7"
+
+COPY install/install_pip_packages_by_version.sh /install/
+RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.7"
+
+COPY install/use_devtoolset_8.sh /install/
+RUN /install/use_devtoolset_8.sh
+
+COPY install/install_openmpi.sh /install/
+RUN /install/install_openmpi.sh "4.1.1"
+
+# Clean up.
+RUN rm -rf /tmp/*
diff --git a/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.8.Dockerfile b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.8.Dockerfile
new file mode 100644
index 000000000..748d60e40
--- /dev/null
+++ b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.8.Dockerfile
@@ -0,0 +1,103 @@
+# Dockerfile to build a manylinux 2014 compliant cross-compiler.
+#
+# Builds a devtoolset gcc/libstdc++ that targets manylinux 2014 compatible
+# glibc (2.17) and system libstdc++ (4.8).
+#
+# To push a new version, run:
+# $ docker build -f cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.8.Dockerfile . \
+#     --tag "tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.8"
+# $ docker push tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.8
+
+FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as devtoolset
+
+RUN chmod 777 /tmp/
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y \
+    cpio \
+    file \
+    flex \
+    g++ \
+    make \
+    patch \
+    rpm2cpio \
+    unar \
+    wget \
+    xz-utils \
+    libjpeg-dev \
+    zlib1g-dev \
+    libgflags-dev \
+    libsnappy-dev \
+    libbz2-dev \
+    liblz4-dev \
+    libzstd-dev \
+    openssh-client \
+    && \
+    rm -rf /var/lib/apt/lists/*
+
+ADD devtoolset/fixlinks.sh fixlinks.sh
+ADD devtoolset/build_devtoolset.sh build_devtoolset.sh
+ADD devtoolset/rpm-patch.sh rpm-patch.sh
+
+# Set up a sysroot for glibc 2.17 / libstdc++ 4.8 / devtoolset-8 in /dt8.
+RUN /build_devtoolset.sh devtoolset-8 /dt8
+
+# TODO(klimek): Split up into two different docker images.
+FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
+COPY --from=devtoolset /dt8 /dt8
+
+# Set up the NVIDIA machine-learning apt repository (source of the TensorRT packages).
+RUN echo \
+    deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 / \
+    > /etc/apt/sources.list.d/nvidia-ml.list \
+    && \
+    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 && \
+    apt-get update && \
+    rm -rf /var/lib/apt/lists/*
+
+# Copy and run the install scripts.
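+# Illustrative aside, not part of the build: a quick smoke test of the pinned
+# toolchain versions once the image is built (tag taken from the header above):
+#
+#   docker run --rm tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.8 \
+#       bash -c 'bazel --version && dpkg -l | grep libnccl'
+#   # expect bazel 5.1.1 and libnccl packages at 2.8.4-1+cuda11.2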
+ARG DEBIAN_FRONTEND=noninteractive
+
+COPY install/install_bootstrap_deb_packages.sh /install/
+RUN /install/install_bootstrap_deb_packages.sh
+
+COPY install/install_deb_packages.sh /install/
+RUN /install/install_deb_packages.sh
+
+# Install additional packages needed for this image:
+# - dependencies to build Python from source
+# - patchelf, as it is required by auditwheel
+RUN apt-get update && apt-get install -y \
+    libbz2-dev \
+    libffi-dev \
+    libgdbm-dev \
+    libncurses5-dev \
+    libnss3-dev \
+    libreadline-dev \
+    patchelf \
+    gcc-multilib \
+    && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN chmod 777 /tmp/
+WORKDIR /tmp/
+
+COPY install/install_nccl.sh /install/
+RUN /install/install_nccl.sh "2.8.4-1+cuda11.2"
+
+COPY install/install_bazel.sh /install/
+RUN /install/install_bazel.sh "5.1.1"
+
+COPY install/build_and_install_python.sh /install/
+RUN /install/build_and_install_python.sh "3.8.2"
+
+COPY install/install_pip_packages_by_version.sh /install/
+RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.8"
+
+COPY install/use_devtoolset_8.sh /install/
+RUN /install/use_devtoolset_8.sh
+
+COPY install/install_openmpi.sh /install/
+RUN /install/install_openmpi.sh "4.1.1"
+
+# Clean up.
+RUN rm -rf /tmp/*
diff --git a/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.9.Dockerfile b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.9.Dockerfile
new file mode 100644
index 000000000..2c9fb4e81
--- /dev/null
+++ b/tools/docker/cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.9.Dockerfile
@@ -0,0 +1,103 @@
+# Dockerfile to build a manylinux 2014 compliant cross-compiler.
+#
+# Builds a devtoolset gcc/libstdc++ that targets manylinux 2014 compatible
+# glibc (2.17) and system libstdc++ (4.8).
+#
+# To push a new version, run:
+# $ docker build -f cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.9.Dockerfile . \
+#     --tag "tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.9"
+# $ docker push tfra/nosla-cuda11.2.2-cudnn8-ubuntu20.04-manylinux2014-python3.9
+
+FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as devtoolset
+
+RUN chmod 777 /tmp/
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y \
+    cpio \
+    file \
+    flex \
+    g++ \
+    make \
+    patch \
+    rpm2cpio \
+    unar \
+    wget \
+    xz-utils \
+    libjpeg-dev \
+    zlib1g-dev \
+    libgflags-dev \
+    libsnappy-dev \
+    libbz2-dev \
+    liblz4-dev \
+    libzstd-dev \
+    openssh-client \
+    && \
+    rm -rf /var/lib/apt/lists/*
+
+ADD devtoolset/fixlinks.sh fixlinks.sh
+ADD devtoolset/build_devtoolset.sh build_devtoolset.sh
+ADD devtoolset/rpm-patch.sh rpm-patch.sh
+
+# Set up a sysroot for glibc 2.17 / libstdc++ 4.8 / devtoolset-8 in /dt8.
+RUN /build_devtoolset.sh devtoolset-8 /dt8
+
+# TODO(klimek): Split up into two different docker images.
+FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
+COPY --from=devtoolset /dt8 /dt8
+
+# Set up the NVIDIA machine-learning apt repository (source of the TensorRT packages).
+RUN echo \
+    deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 / \
+    > /etc/apt/sources.list.d/nvidia-ml.list \
+    && \
+    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F60F4B3D7FA2AF80 && \
+    apt-get update && \
+    rm -rf /var/lib/apt/lists/*
+
+# Copy and run the install scripts.
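+# Illustrative aside, not part of the build: the -dev packages below are what
+# lets the from-source CPython link its optional stdlib modules. A minimal
+# post-build import check (the module list is an assumption):
+#
+#   python3.9 -c "import bz2, ctypes, curses, dbm.gnu, readline"
+#   # an ImportError here usually means a -dev package was missing at build time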
+ARG DEBIAN_FRONTEND=noninteractive
+
+COPY install/install_bootstrap_deb_packages.sh /install/
+RUN /install/install_bootstrap_deb_packages.sh
+
+COPY install/install_deb_packages.sh /install/
+RUN /install/install_deb_packages.sh
+
+# Install additional packages needed for this image:
+# - dependencies to build Python from source
+# - patchelf, as it is required by auditwheel
+RUN apt-get update && apt-get install -y \
+    libbz2-dev \
+    libffi-dev \
+    libgdbm-dev \
+    libncurses5-dev \
+    libnss3-dev \
+    libreadline-dev \
+    patchelf \
+    gcc-multilib \
+    && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN chmod 777 /tmp/
+WORKDIR /tmp/
+
+COPY install/install_nccl.sh /install/
+RUN /install/install_nccl.sh "2.8.4-1+cuda11.2"
+
+COPY install/install_bazel.sh /install/
+RUN /install/install_bazel.sh "5.1.1"
+
+COPY install/build_and_install_python.sh /install/
+RUN /install/build_and_install_python.sh "3.9.7"
+
+COPY install/install_pip_packages_by_version.sh /install/
+RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.9"
+
+COPY install/use_devtoolset_8.sh /install/
+RUN /install/use_devtoolset_8.sh
+
+COPY install/install_openmpi.sh /install/
+RUN /install/install_openmpi.sh "4.1.1"
+
+# Clean up.
+RUN rm -rf /tmp/*
diff --git a/tools/docker/install/use_devtoolset_8.sh b/tools/docker/install/use_devtoolset_8.sh
index 8f7d85c95..0a6b3cbfd 100755
--- a/tools/docker/install/use_devtoolset_8.sh
+++ b/tools/docker/install/use_devtoolset_8.sh
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 
-# Use devtoolset-7 as tool chain
+# Use devtoolset-8 as tool chain
 rm -r /usr/bin/gcc*
 export PATH=/dt8/usr/bin:${PATH}
 export PATH=/usr/bin/:/usr/local/bin/:${PATH}
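A minimal sketch (illustrative, not part of the patch) of why use_devtoolset_8.sh removes /usr/bin/gcc* before adjusting PATH: the second export puts /usr/bin back in front of /dt8/usr/bin, so lookup only falls through to the devtoolset compiler because the system one is gone:

    export PATH=/dt8/usr/bin:${PATH}                 # /dt8 first, briefly
    export PATH=/usr/bin/:/usr/local/bin/:${PATH}    # /usr/bin shadows /dt8 again
    command -v gcc                                   # resolves to /dt8/usr/bin/gcc only
                                                     # because /usr/bin/gcc* was deleted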