From 42ad7d251e3383fb00d0db84189a4c14067a6d59 Mon Sep 17 00:00:00 2001 From: MoFHeka Date: Wed, 31 Jan 2024 14:40:24 +0800 Subject: [PATCH] [feat] Add bfloat16 value type support to HKV so it can take advantage of the Ampere GPU BF16 training feature. --- .../dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc | 1 + .../core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc | 1 + .../python/kernel_tests/dynamic_embedding_variable_test.py | 2 +- .../dynamic_embedding/python/ops/dynamic_embedding_variable.py | 1 + 4 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc index 54724ead5..86b8cfb56 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc @@ -1048,6 +1048,7 @@ REGISTER_KERNEL(int64, int8); REGISTER_KERNEL(int64, int32); REGISTER_KERNEL(int64, int64); REGISTER_KERNEL(int64, Eigen::half); +REGISTER_KERNEL(int64, Eigen::bfloat16); #undef REGISTER_KERNEL diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc index 3f529f5eb..8f50afd6e 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc @@ -29,6 +29,7 @@ DEFINE_PURE_GPU_HASHTABLE(int64, int8); DEFINE_PURE_GPU_HASHTABLE(int64, int32); DEFINE_PURE_GPU_HASHTABLE(int64, int64); DEFINE_PURE_GPU_HASHTABLE(int64, Eigen::half); +DEFINE_PURE_GPU_HASHTABLE(int64, Eigen::bfloat16); #undef DEFINE_PURE_GPU_HASHTABLE diff --git 
a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py index cb3d1c477..1cef73b9e 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py @@ -384,7 +384,7 @@ def test_variable(self): dim_list = [1, 2, 4, 8, 10, 16, 32, 64, 100, 200] kv_list = [[dtypes.int64, dtypes.float32], [dtypes.int64, dtypes.int32], [dtypes.int64, dtypes.half], [dtypes.int64, dtypes.int8], - [dtypes.int64, dtypes.int64]] + [dtypes.int64, dtypes.int64], [dtypes.int64, dtypes.bfloat16]] else: dim_list = [1, 8, 16, 128] kv_list = [[dtypes.int32, dtypes.double], [dtypes.int32, dtypes.float32], diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py index ddbb40dc8..30a5c6543 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py @@ -589,6 +589,7 @@ def _get_default_devices(): [dtypes.int64, dtypes.int32], [dtypes.int64, dtypes.int64], [dtypes.int64, dtypes.half], + [dtypes.int64, dtypes.bfloat16], ] if is_macos() and is_arm64(): if value_dtype == dtypes.half or value_dtype == dtypes.bfloat16: