From 7f45aa242a00197c257f3797ef487c243f583e9f Mon Sep 17 00:00:00 2001 From: MoFHeka Date: Thu, 30 May 2024 11:24:37 +0800 Subject: [PATCH] [feat] Add bfloat16(bf16) value type support in Redis and CPU table backend. --- .../core/kernels/cuckoo_hashtable_op.cc | 1 + .../kernels/lookup_impl/lookup_table_op_cpu.h | 1 + ...lookup_table_op_cpu_impl_int32_bfloat16.cc | 25 +++++++++++++++++++ .../core/kernels/redis_table_op.cc | 3 +++ .../dynamic_embedding_variable_test.py | 11 ++++---- .../kernel_tests/redis_table_variable_test.py | 9 ++++--- .../python/ops/dynamic_embedding_variable.py | 1 + 7 files changed, 42 insertions(+), 9 deletions(-) create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu_impl_int32_bfloat16.cc diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc index 1290688e8..ab00a58db 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc @@ -995,6 +995,7 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL(int32, double); REGISTER_KERNEL(int32, float); REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int32, bfloat16); REGISTER_KERNEL(int64, double); REGISTER_KERNEL(int64, float); REGISTER_KERNEL(int64, int32); diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h index c54a2584b..b74285e1a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h @@ -474,6 +474,7 @@ void CreateTableImpl(TableWrapperBase** pptable, size_t init_size, DECLARE_CREATE_TABLE(int32, double); DECLARE_CREATE_TABLE(int32, float); DECLARE_CREATE_TABLE(int32, int32); +DECLARE_CREATE_TABLE(int32, bfloat16); DECLARE_CREATE_TABLE(int64, double); DECLARE_CREATE_TABLE(int64, float); DECLARE_CREATE_TABLE(int64, int32); diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu_impl_int32_bfloat16.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu_impl_int32_bfloat16.cc new file mode 100644 index 000000000..ce4b7bf2a --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu_impl_int32_bfloat16.cc @@ -0,0 +1,25 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h" +namespace tensorflow { +namespace recommenders_addons { +namespace lookup { +namespace cpu { +DEFINE_CREATE_TABLE(int32, bfloat16, 0, 0); +} // namespace cpu +} // namespace lookup +} // namespace recommenders_addons +} // namespace tensorflow diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_table_op.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_table_op.cc index 00a5c4688..dfa922402 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_table_op.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_table_op.cc @@ -1842,6 +1842,7 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL(int32, double); REGISTER_KERNEL(int32, float); REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int32, bfloat16); REGISTER_KERNEL(int64_t, double); REGISTER_KERNEL(int64_t, float); REGISTER_KERNEL(int64_t, int32); @@ -1849,6 +1850,7 @@ REGISTER_KERNEL(int64_t, int64_t); REGISTER_KERNEL(int64_t, tstring); REGISTER_KERNEL(int64_t, int8); REGISTER_KERNEL(int64_t, Eigen::half); +REGISTER_KERNEL(int64_t, bfloat16); REGISTER_KERNEL(tstring, bool); REGISTER_KERNEL(tstring, double); REGISTER_KERNEL(tstring, float); @@ -1856,6 +1858,7 @@ REGISTER_KERNEL(tstring, int32); REGISTER_KERNEL(tstring, int64_t); REGISTER_KERNEL(tstring, int8); REGISTER_KERNEL(tstring, Eigen::half); +REGISTER_KERNEL(tstring, bfloat16); #undef REGISTER_KERNEL diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py index dd0472e33..036525fae 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py @@ -466,11 +466,12 @@ def test_variable_find_with_exists_and_accum(self): else: dim_list = [1, 8, 16, 128] kv_list = [[dtypes.int32, dtypes.double], [dtypes.int32, dtypes.float32], - [dtypes.int32, dtypes.int32], [dtypes.int64, dtypes.double], - [dtypes.int64, dtypes.float32], [dtypes.int64, dtypes.int32], - [dtypes.int64, dtypes.int64], [dtypes.int64, dtypes.int8], - [dtypes.int64, dtypes.half], [dtypes.int64, dtypes.bfloat16], - [dtypes.string, dtypes.double], + [dtypes.int32, dtypes.int32], [dtypes.int32, dtypes.bfloat16], + [dtypes.int64, dtypes.double], [dtypes.int64, dtypes.float32], + [dtypes.int64, dtypes.int32], [dtypes.int64, dtypes.int64], + [dtypes.int64, dtypes.int8], [dtypes.int64, dtypes.half], + [dtypes.int64, + dtypes.bfloat16], [dtypes.string, dtypes.double], [dtypes.string, dtypes.float32], [dtypes.string, dtypes.int32], [dtypes.string, dtypes.int64], [dtypes.string, dtypes.int8], [dtypes.string, dtypes.half]] diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/redis_table_variable_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/redis_table_variable_test.py index c4e95c01f..a17fc1a7c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/redis_table_variable_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/redis_table_variable_test.py @@ -351,11 +351,12 @@ def test_variable(self): else: dim_list = [1, 8, 16, 128] kv_list = [[dtypes.int32, dtypes.double], [dtypes.int32, dtypes.float32], - [dtypes.int32, dtypes.int32], [dtypes.int64, dtypes.double], - [dtypes.int64, dtypes.float32], [dtypes.int64, dtypes.int32], - [dtypes.int64, dtypes.int64], [dtypes.int64, dtypes.string], + [dtypes.int32, dtypes.int32], [dtypes.int32, dtypes.bfloat16], + [dtypes.int64, dtypes.double], [dtypes.int64, dtypes.float32], + [dtypes.int64, dtypes.int32], [dtypes.int64, dtypes.int64], [dtypes.int64, dtypes.int8], [dtypes.int64, dtypes.half], - [dtypes.string, dtypes.double], + [dtypes.int64, + dtypes.bfloat16], [dtypes.string, dtypes.double], [dtypes.string, dtypes.float32], [dtypes.string, dtypes.int32], [dtypes.string, dtypes.int64], [dtypes.string, dtypes.int8], [dtypes.string, dtypes.half]] diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py index 02c77c9b7..02928517c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py @@ -585,6 +585,7 @@ def _get_default_devices(): [dtypes.int32, dtypes.float32], [dtypes.int32, dtypes.int32], [dtypes.int32, dtypes.float64], + [dtypes.int32, dtypes.bfloat16], [dtypes.string, dtypes.float32], [dtypes.string, dtypes.half], [dtypes.string, dtypes.bfloat16],