[Feat] Compatibility with TensorFlow 2.11.0 and later.
It also adds automatic detection of the C++ standard TensorFlow was compiled with.
MoFHeka committed Jun 8, 2023
1 parent 884bab5 commit 27b6cb2
Showing 37 changed files with 633 additions and 213 deletions.
2 changes: 2 additions & 0 deletions build_deps/tf_dependency/build_defs.bzl.tpl
@@ -2,6 +2,8 @@

D_GLIBCXX_USE_CXX11_ABI = "%{tf_cx11_abi}"

TF_CXX_STANDARD = "%{tf_cxx_standard}"

DTF_VERSION_INTEGER = "%{tf_version_integer}"

FOR_TF_SERVING = "%{for_tf_serving}"
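
For orientation: tf_configure substitutes these placeholders with values detected from the local TensorFlow install, so the generated build_defs.bzl would plausibly read as follows (illustrative values only, assuming a TF 2.11.0 wheel built with C++17 and the new ABI):

# Hypothetical generated build_defs.bzl -- actual values depend on the installed TF wheel.
D_GLIBCXX_USE_CXX11_ABI = "-D_GLIBCXX_USE_CXX11_ABI=1"
TF_CXX_STANDARD = "c++17"
DTF_VERSION_INTEGER = "-DTF_VERSION_INTEGER=2110"
FOR_TF_SERVING = "False"
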
5 changes: 5 additions & 0 deletions build_deps/tf_dependency/tf_configure.bzl
@@ -8,6 +8,8 @@ _TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"

_TF_CXX11_ABI_FLAG = "TF_CXX11_ABI_FLAG"

_TF_CXX_STANDARD = "TF_CXX_STANDARD"

_FOR_TF_SERVING = "FOR_TF_SERVING"

_TF_VERSION_INTEGER = "TF_VERSION_INTEGER"
@@ -208,6 +210,7 @@ def _tf_pip_impl(repository_ctx):
tf_shared_library_name = repository_ctx.os.environ[_TF_SHARED_LIBRARY_NAME]
tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
tf_cx11_abi = "-D_GLIBCXX_USE_CXX11_ABI=%s" % (repository_ctx.os.environ[_TF_CXX11_ABI_FLAG])
tf_cxx_standard = "%s" % (repository_ctx.os.environ[_TF_CXX_STANDARD])
tf_version_integer = "-DTF_VERSION_INTEGER=%s" % (repository_ctx.os.environ[_TF_VERSION_INTEGER])
for_tf_serving = repository_ctx.os.environ[_FOR_TF_SERVING]

@@ -231,6 +234,7 @@
"build_defs.bzl",
{
"%{tf_cx11_abi}": tf_cx11_abi,
"%{tf_cxx_standard}": tf_cxx_standard,
"%{tf_version_integer}": tf_version_integer,
"%{for_tf_serving}": for_tf_serving,
},
@@ -242,6 +246,7 @@ tf_configure = repository_rule(
_TF_SHARED_LIBRARY_DIR,
_TF_SHARED_LIBRARY_NAME,
_TF_CXX11_ABI_FLAG,
_TF_CXX_STANDARD,
_FOR_TF_SERVING,
],
implementation = _tf_pip_impl,
@@ -178,7 +178,7 @@ def InvokeNvcc(argv, log=False):
undefines = ''.join([' -U' + define for define in undefines])
std_options = GetOptionValue(argv, '-std')
# Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang.
nvcc_allowed_std_options = ["c++03", "c++11", "c++14"]
nvcc_allowed_std_options = ["c++03", "c++11", "c++14", "c++17"]
std_options = ''.join([' -std=' + define
for define in std_options if define in nvcc_allowed_std_options][-1:])
fatbin_options = ''.join([' --fatbin-options=' + option
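
The [-1:] slice above keeps only the last accepted -std value, mimicking gcc/clang, where the last flag on the command line wins. A standalone sketch of that filtering (function name is illustrative):

# Reproduces the "keep last -std" filtering from the wrapper above.
nvcc_allowed_std_options = ["c++03", "c++11", "c++14", "c++17"]

def keep_last_std(std_values):
    # Drop unsupported values, then keep only the last survivor.
    kept = [v for v in std_values if v in nvcc_allowed_std_options]
    return ''.join(' -std=' + v for v in kept[-1:])

print(keep_last_std(["c++14", "c++17"]))  # " -std=c++17"
print(keep_last_std(["gnu++17"]))         # "" (unsupported values are dropped)
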
@@ -167,7 +167,7 @@ def InvokeNvcc(argv, log=False):
undefines = ''.join([' -U' + define for define in undefines])
std_options = GetOptionValue(argv, 'std')
# currently only c++14 is supported by Cuda 10.0 std argument
nvcc_allowed_std_options = ["c++14"]
nvcc_allowed_std_options = ["c++03", "c++11", "c++14", "c++17"]
std_options = ''.join([' -std=' + define
for define in std_options if define in nvcc_allowed_std_options])

25 changes: 23 additions & 2 deletions build_deps/toolchains/gpu/cuda_configure.bzl
@@ -74,7 +74,7 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES.update(
"7.5",
"8.0",
"8.6",
] for v in range(1, 8)},
] for v in range(0, 8)},
)

_DEFAULT_CUDA_COMPUTE_CAPABILITIES.update(
@@ -86,9 +86,23 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES.update(
"8.0",
"8.6",
"8.9",
"9.0",
] for v in range(8, 9)},
)

_DEFAULT_CUDA_COMPUTE_CAPABILITIES.update(
{"12.{}".format(v): [
"6.0",
"6.1",
"7.0",
"7.5",
"8.0",
"8.6",
"8.9",
"9.0",
] for v in range(0, 8)},
)

def _get_python_bin(repository_ctx):
"""Gets the python bin path."""
python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
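
Each block above registers one default capability list for a range of CUDA minor releases: the new 12.x entry covers 12.0 through 12.7, and changing range(1, 8) to range(0, 8) in the first block brings the x.0 release into the defaults as well. A quick illustration of the comprehension (values copied from the 12.x block):

# The dict comprehension maps every minor release to the same default list.
caps_12x = {"12.{}".format(v): ["6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
            for v in range(0, 8)}
print(sorted(caps_12x))  # ['12.0', '12.1', '12.2', '12.3', '12.4', '12.5', '12.6', '12.7']
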
@@ -562,10 +576,17 @@ def _find_cuda_lib(
Returns the path to the library.
"""
file_name = lib_name(lib, cpu_value, version, static)
paths = ["%s/%s" % (basedir, file_name)]
if version:
        # In CUDA 12.1 the library is no longer named libcupti.so.12.1 but libcupti.so.2023.1.0.
        # A libcupti.so.12 symlink is still shipped, so search for that major-version name rather than "*.12.1".
major_version = version.split(".")[0]
file_name_major = lib_name(lib, cpu_value, major_version, static)
paths.append("%s/%s" % (basedir, file_name_major))

return find_lib(
repository_ctx,
["%s/%s" % (basedir, file_name)],
paths,
check_soname = version and not static,
)
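
The new fallback builds an ordered list of candidate file names and searches them in turn. A minimal Python sketch of the same logic (lib_name here is a simplified stand-in for the real Starlark helper, Linux naming assumed):

def lib_name(lib, version=None):
    return "lib{}.so.{}".format(lib, version) if version else "lib{}.so".format(lib)

def candidate_paths(basedir, lib, version):
    paths = ["{}/{}".format(basedir, lib_name(lib, version))]
    if version:
        # Fall back to the major-version name, e.g. libcupti.so.12.
        paths.append("{}/{}".format(basedir, lib_name(lib, version.split(".")[0])))
    return paths

print(candidate_paths("/usr/local/cuda/lib64", "cupti", "12.1"))
# ['/usr/local/cuda/lib64/libcupti.so.12.1', '/usr/local/cuda/lib64/libcupti.so.12']
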

7 changes: 4 additions & 3 deletions build_deps/toolchains/redis/redis-plus-plus.BUILD
@@ -2,6 +2,7 @@ load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake")
load(
"@local_config_tf//:build_defs.bzl",
"D_GLIBCXX_USE_CXX11_ABI",
"TF_CXX_STANDARD",
)

package(
@@ -30,10 +31,10 @@ cmake(
generate_args = [
"-DCMAKE_BUILD_TYPE=Release",
"-DREDIS_PLUS_PLUS_BUILD_TEST=OFF",
"-DREDIS_PLUS_PLUS_CXX_STANDARD=11",
"-DCMAKE_CXX_FLAGS="+D_GLIBCXX_USE_CXX11_ABI,
"-DREDIS_PLUS_PLUS_CXX_STANDARD=" + TF_CXX_STANDARD.split("c++")[-1],
"-DCMAKE_CXX_FLAGS=" + D_GLIBCXX_USE_CXX11_ABI,
],
lib_source = "@redis-plus-plus//:all_srcs",
out_static_libs = ["libredis++.a"],
deps = ["@hiredis"],
)
)
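
TF_CXX_STANDARD arrives as a string such as "c++17", while CMake's REDIS_PLUS_PLUS_CXX_STANDARD option expects the bare number, hence the split:

TF_CXX_STANDARD = "c++17"  # as generated into build_defs.bzl by configure.py
print("-DREDIS_PLUS_PLUS_CXX_STANDARD=" + TF_CXX_STANDARD.split("c++")[-1])
# prints: -DREDIS_PLUS_PLUS_CXX_STANDARD=17
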
12 changes: 12 additions & 0 deletions configure.py
@@ -219,6 +219,18 @@ def create_build_configuration():
write_action_env("TF_SHARED_LIBRARY_DIR", get_tf_shared_lib_dir())
write_action_env("TF_SHARED_LIBRARY_NAME", get_shared_lib_name())
write_action_env("TF_CXX11_ABI_FLAG", tf.sysconfig.CXX11_ABI_FLAG)
tf_cxx_standard_compile_flags = [
flag for flag in tf.sysconfig.get_compile_flags() if "-std=" in flag
]
if len(tf_cxx_standard_compile_flags) > 0:
tf_cxx_standard_compile_flag = tf_cxx_standard_compile_flags[-1]
else:
tf_cxx_standard_compile_flag = None
if tf_cxx_standard_compile_flag is None:
tf_cxx_standard = "c++14"
else:
tf_cxx_standard = tf_cxx_standard_compile_flag.split("-std=")[-1]
write_action_env("TF_CXX_STANDARD", tf_cxx_standard)

tf_version_integer = get_tf_version_integer()
# This is used to trace the difference between Tensorflow versions.
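
The detection above keys off tf.sysconfig.get_compile_flags() and falls back to c++14 when no -std flag is reported. On a stock TF 2.11 Linux wheel the flag list plausibly contains an entry like "--std=c++17" (exact flags vary by build), which the substring match picks up:

# Hypothetical flag list; real output depends on the installed TF wheel.
compile_flags = [
    "-I/usr/lib/python3/site-packages/tensorflow/include",
    "-D_GLIBCXX_USE_CXX11_ABI=1",
    "--std=c++17",
]
std_flags = [flag for flag in compile_flags if "-std=" in flag]
tf_cxx_standard = std_flags[-1].split("-std=")[-1] if std_flags else "c++14"
print(tf_cxx_standard)  # c++17
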
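A note on the C++ changes that follow: recent TensorFlow releases deprecate Status::OK() in favor of OkStatus(), so the commit replaces every Status::OK() with a TFOkStatus symbol. Judging by the utils.h include added below, TFOkStatus is presumably a version-dependent alias defined there; its definition is not part of this excerpt.
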
@@ -219,7 +219,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
LaunchTensorsFind<CPUDevice, K, V> launcher(value_dim);
launcher.launch(ctx, table_, key, value, default_value);

return Status::OK();
return TFOkStatus;
}

Status FindWithExists(OpKernelContext* ctx, const Tensor& key, Tensor* value,
@@ -229,7 +229,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
LaunchTensorsFindWithExists<CPUDevice, K, V> launcher(value_dim);
launcher.launch(ctx, table_, key, value, default_value, exists);

return Status::OK();
return TFOkStatus;
}

Status DoInsert(bool clear, OpKernelContext* ctx, const Tensor& keys,
@@ -243,7 +243,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
LaunchTensorsInsert<CPUDevice, K, V> launcher(value_dim);
launcher.launch(ctx, table_, keys, values);

return Status::OK();
return TFOkStatus;
}

Status DoAccum(bool clear, OpKernelContext* ctx, const Tensor& keys,
@@ -257,7 +257,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
LaunchTensorsAccum<CPUDevice, K, V> launcher(value_dim);
launcher.launch(ctx, table_, keys, values_or_deltas, exists);

return Status::OK();
return TFOkStatus;
}

Status Insert(OpKernelContext* ctx, const Tensor& keys,
@@ -272,12 +272,12 @@ class CuckooHashTableOfTensors final : public LookupInterface {
for (int64 i = 0; i < key_flat.size(); ++i) {
table_->erase(tensorflow::lookup::SubtleMustCopyIfIntegral(key_flat(i)));
}
return Status::OK();
return TFOkStatus;
}

Status Clear(OpKernelContext* ctx) {
table_->clear();
return Status::OK();
return TFOkStatus;
}

Status Accum(OpKernelContext* ctx, const Tensor& keys,
@@ -304,7 +304,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
table_->dump((K*)keys->tensor_data().data(),
(V*)values->tensor_data().data(), 0, table_size);

return Status::OK();
return TFOkStatus;
}

Status SaveToFileSystemImpl(FileSystem* fs, const size_t value_dim,
@@ -319,7 +319,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
bool has_atomic_move = false;
auto has_atomic_move_ret = fs->HasAtomicMove(filepath, &has_atomic_move);
bool need_tmp_file =
(has_atomic_move == false) || (has_atomic_move_ret != Status::OK());
(has_atomic_move == false) || (has_atomic_move_ret != TFOkStatus);
if (!need_tmp_file) {
key_tmpfilepath = key_filepath;
value_tmpfilepath = value_filepath;
@@ -387,7 +387,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
TF_RETURN_IF_ERROR(fs->RenameFile(value_tmpfilepath, value_filepath));
}

return Status::OK();
return TFOkStatus;
}

Status SaveToFileSystem(OpKernelContext* ctx, const string& dirpath,
@@ -461,7 +461,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
LOG(INFO) << "Finish loading " << key_size << " keys and values from "
<< key_filepath << " and " << value_filepath << " in total.";

return Status::OK();
return TFOkStatus;
}

Status LoadFromFileSystem(OpKernelContext* ctx, const string& dirpath,
@@ -500,7 +500,7 @@ class CuckooHashTableOfTensors final : public LookupInterface {
string filepath = io::JoinPath(dirpath, file_name);
return LoadFromFileSystemImpl(fs, value_dim, filepath, buffer_size);
}
return Status::OK();
return TFOkStatus;
}

DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
@@ -557,7 +557,7 @@ class HashTableOpKernel : public OpKernel {
*container = h(0);
*table_handle = h(1);
}
return Status::OK();
return TFOkStatus;
}

Status GetResourceHashTable(StringPiece input_name, OpKernelContext* ctx,
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/types.h"
#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"

namespace tensorflow {
namespace recommenders_addons {
@@ -76,7 +77,7 @@ class HashTableOp : public OpKernel {
table_.AllocatedBytes());
}
*ret = container;
return Status::OK();
return TFOkStatus;
};

LookupInterface* table = nullptr;