Skip to content

Commit

Permalink
Adapt to different TF versions in Custom OP.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lifann authored and rhdong committed May 12, 2021
1 parent 2f1b624 commit 5bf4bd8
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 25 deletions.
4 changes: 1 addition & 3 deletions build_deps/tf_dependency/build_defs.bzl.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,4 @@

D_GLIBCXX_USE_CXX11_ABI = "%{tf_cx11_abi}"

DTF_MAJOR_VERSION = "%{tf_major_version}"
DTF_MINOR_VERSION = "%{tf_minor_version}"
DTF_PATCH_VERSION = "%{tf_patch_version}"
DTF_VERSION = "%{tf_version}"
14 changes: 3 additions & 11 deletions build_deps/tf_dependency/tf_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ _TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"

_TF_CXX11_ABI_FLAG = "TF_CXX11_ABI_FLAG"

TF_MAJOR_VERSION = "TF_MAJOR_VERSION"

TF_MINOR_VERSION = "TF_MINOR_VERSION"

TF_PATCH_VERSION = "TF_PATCH_VERSION"
TF_VERSION = "TF_VERSION"

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
if not out:
Expand Down Expand Up @@ -210,9 +206,7 @@ def _tf_pip_impl(repository_ctx):
tf_shared_library_name = repository_ctx.os.environ[_TF_SHARED_LIBRARY_NAME]
tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
tf_cx11_abi = "-D_GLIBCXX_USE_CXX11_ABI=%s" % (repository_ctx.os.environ[_TF_CXX11_ABI_FLAG])
tf_major_version = "-DTF_MAJOR_VERSION=%s" % (repository_ctx.os.environ[TF_MAJOR_VERSION])
tf_minor_version = "-DTF_MINOR_VERSION=%s" % (repository_ctx.os.environ[TF_MINOR_VERSION])
tf_patch_version = "-DTF_PATCH_VERSION=%s" % (repository_ctx.os.environ[TF_PATCH_VERSION])
tf_version = "-DTF_VERSION=%s" % (repository_ctx.os.environ[TF_VERSION])

tf_shared_library_rule = _symlink_genrule_for_dir(
repository_ctx,
Expand All @@ -234,9 +228,7 @@ def _tf_pip_impl(repository_ctx):
"build_defs.bzl",
{
"%{tf_cx11_abi}": tf_cx11_abi,
"%{tf_major_version}": tf_major_version,
"%{tf_minor_version}": tf_minor_version,
"%{tf_patch_version}": tf_patch_version,
"%{tf_version}": tf_version,
},
)

Expand Down
38 changes: 33 additions & 5 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,42 @@ def get_shared_lib_name():


def get_tf_version():
"""
Get Tensorflow version as a 4 digits string.
For example:
1.15.2 get 1152
2.4.1 get 2041
The 4-digits-string will be passed to C macro to discriminate different
Tensorflow versions.
We assume that major version has 1 digit, minor version has 2 digits. And
patch version has 1 digit.
"""
version = tf.__version__
try:
major, minor, patch = version.split('.')
assert len(
major
) == 1, "Tensorflow major version must be length of 1. Version: {}".format(
version)
assert len(
minor
) <= 2, "Tensorflow minor version must be less or equal to 2. Version: {}".format(
version)
assert len(
patch
) == 1, "Tensorflow patch version must be length of 1. Version: {}".format(
version)
except:
raise ValueError('got wrong tf.__version__: {}'.format(version))
return version, major, minor, patch
tf_version_num = str(int(major) * 1000 + int(minor) * 10 + int(patch))
if len(tf_version_num) != 4:
raise ValueError('Tensorflow version flag must be length of 4 (major'
' version: 1, minor version: 2, patch_version: 1). But'
' get: {}'.format(tf_version_num))
return tf_version_num


def create_build_configuration():
Expand All @@ -115,12 +145,10 @@ def create_build_configuration():
write_action_env("TF_SHARED_LIBRARY_NAME", get_shared_lib_name())
write_action_env("TF_CXX11_ABI_FLAG", tf.sysconfig.CXX11_ABI_FLAG)

_, tf_major_version, tf_minor_version, tf_patch_version = get_tf_version()
tf_version = get_tf_version()
# This is used to trace the difference between Tensorflow versions.
# TODO(Lifann) write them to enviroment variables.
write_action_env("TF_MAJOR_VERSION", tf_major_version)
write_action_env("TF_MINOR_VERSION", tf_minor_version)
write_action_env("TF_PATCH_VERSION", tf_patch_version)
write_action_env("TF_VERSION", tf_version)

write("build --spawn_strategy=standalone")
write("build --strategy=Genrule=standalone")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ limitations under the License.
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#if TF_VERSION >= 2040 // 2.4.0
#include "tensorflow/core/util/cuda_solvers.h"
#else
#include "tensorflow/core/kernels/cuda_solvers.h"
#endif // TF_VERSION >= 2040
#include "tensorflow/stream_executor/cuda/cuda_activation.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
load(
"@local_config_tf//:build_defs.bzl",
"DTF_MAJOR_VERSION",
"DTF_MINOR_VERSION",
"DTF_PATCH_VERSION",
"DTF_VERSION",
"D_GLIBCXX_USE_CXX11_ABI",
)
load(
Expand Down Expand Up @@ -65,9 +63,7 @@ def custom_op_library(
"-pthread",
"-std=c++14",
D_GLIBCXX_USE_CXX11_ABI,
DTF_MAJOR_VERSION,
DTF_MINOR_VERSION,
DTF_PATCH_VERSION,
DTF_VERSION,
],
})

Expand Down

0 comments on commit 5bf4bd8

Please sign in to comment.