Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] version check to improve tf version compatibility #416

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def DynamicEmbeddingOptimizer(self, bp_v2=False, synchronous=False, **kwargs):
if hasattr(self, 'add_variable_from_reference'):
original_add_variable_from_reference = self.add_variable_from_reference

# pylint: disable=protected-access
def _distributed_apply(distribution, grads_and_vars, name, apply_state):
"""`apply_gradients` using a `DistributionStrategy`."""

Expand Down Expand Up @@ -208,9 +209,8 @@ def apply_grad_to_update_var(var, grad):
args=(grad,),
group=False)
replica_context = distribute_ctx.get_replica_context()
# pylint: disable=protected-access
if (replica_context is None or replica_context is
distribute_ctx._get_default_replica_context()):
if (replica_context is None or replica_context
is distribute_ctx._get_default_replica_context()):
# In cross-replica context, extended.update returns a list of
# update ops from all replicas (group=False).
update_ops.extend(update_op)
Expand Down
34 changes: 23 additions & 11 deletions tensorflow_recommenders_addons/utils/resource_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pkg_resources
import tensorflow as tf
import warnings
from packaging.version import parse as parse_version

abi_warning_already_raised = False
SKIP_CUSTOM_OPS = False
Expand All @@ -37,13 +38,32 @@ def get_required_tf_version():
"TFRA installation.",
UserWarning,
)
return tf.__version__
return tf.__version__, tf.__version__

pkg_info = pkg.requires()
low_version, high_version = None, None

for x in pkg_info:
if x.name in ["tensorflow", "tensorflow-gpu"]:
return x.specs[0][1]
assert False, "Fail to get required TensorFlow version of TFRA!"
for spec in x.specs:
if spec[0] == ">=":
low_version = spec[1]
elif spec[0] == "<=":
high_version = spec[1]
if low_version and high_version:
return low_version, high_version

assert False, f"Fail to get required TensorFlow version of TFRA: {pkg_info[0]} {low_version} {high_version}"


def abi_is_compatible():
if "dev" in tf.__version__:
return False
low_version, high_version = get_required_tf_version()

current_version = parse_version(tf.__version__)
return parse_version(low_version) <= current_version <= parse_version(
high_version)


def get_devices(device_type="GPU"):
Expand Down Expand Up @@ -138,14 +158,6 @@ def display_warning_if_incompatible(self):
abi_warning_already_raised = True


def abi_is_compatible():
if "dev" in tf.__version__:
return False

required_tf_version = get_required_tf_version()
return tf.__version__ == required_tf_version


def prefix_op_name(op_name):
"""
In order to keep compatibility of existing models,
Expand Down
54 changes: 54 additions & 0 deletions tensorflow_recommenders_addons/utils/tests/test_resource_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
from unittest.mock import patch, Mock

import pkg_resources
import tensorflow as tf

from tensorflow_recommenders_addons.utils.resource_loader import abi_is_compatible


class TestTensorFlowCompatibility(unittest.TestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice job!


@patch('pkg_resources.get_distribution')
@patch('tensorflow.__version__', '2.12.0')
def test_compatible_version(self, mock_get_distribution):
mock_pkg = Mock()
mock_requirement = Mock()
mock_requirement.name = 'tensorflow'
mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')]
mock_pkg.requires.return_value = [mock_requirement]
mock_get_distribution.return_value = mock_pkg
self.assertTrue(abi_is_compatible())

@patch('pkg_resources.get_distribution')
@patch('tensorflow.__version__', '2.10.0')
def test_incompatible_version_below_range(self, mock_get_distribution):
mock_pkg = Mock()
mock_requirement = Mock()
mock_requirement.name = 'tensorflow'
mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')]
mock_pkg.requires.return_value = [mock_requirement]
mock_get_distribution.return_value = mock_pkg
self.assertFalse(abi_is_compatible())

@patch('pkg_resources.get_distribution')
@patch('tensorflow.__version__', '2.16.0')
def test_incompatible_version_above_range(self, mock_get_distribution):
mock_pkg = Mock()
mock_requirement = Mock()
mock_requirement.name = 'tensorflow'
mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')]
mock_pkg.requires.return_value = [mock_requirement]
mock_get_distribution.return_value = mock_pkg
self.assertFalse(abi_is_compatible())

@patch('pkg_resources.get_distribution')
@patch('tensorflow.__version__', '2.13.0-dev20240528')
def test_dev_version(self, mock_get_distribution):
mock_pkg = Mock()
mock_requirement = Mock()
mock_requirement.name = 'tensorflow'
mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')]
mock_pkg.requires.return_value = [mock_requirement]
mock_get_distribution.return_value = mock_pkg
self.assertFalse(abi_is_compatible())
31 changes: 26 additions & 5 deletions tensorflow_recommenders_addons/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,44 @@
# limitations under the License.
# ==============================================================================
"""Types for typing functions signatures."""
# pylint: disable=protected-access

from typing import Union, Callable, List

import numpy as np
import tensorflow as tf

Number = Union[float, int, np.float16, np.float32, np.float64, np.int8,
np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
np.uint64,]
Number = Union[
float,
int,
np.float16,
np.float32,
np.float64,
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
]

Initializer = Union[None, dict, str, Callable]
Regularizer = Union[None, dict, str, Callable]
Constraint = Union[None, dict, str, Callable]
Activation = Union[None, str, Callable]
Optimizer = Union[tf.keras.optimizers.Optimizer, str]

TensorLike = Union[List[Union[Number, list]], tuple, Number, np.ndarray,
tf.Tensor, tf.SparseTensor, tf.Variable,]
TensorLike = Union[
List[Union[Number, list]],
tuple,
Number,
np.ndarray,
tf.Tensor,
tf.SparseTensor,
tf.Variable,
]
FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]
AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None]
# pylint: enable=protected-access
2 changes: 1 addition & 1 deletion tools/install_deps/yapf.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
yapf == 0.30.0
yapf == 0.40.2
Loading