diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py index b7b55bc3e..b33d0f7ec 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py @@ -44,6 +44,8 @@ from tensorflow.python.training import optimizer from tensorflow.python.training import slot_creator +_PARTITION_SHAPE = 'partition_shape' + class _DenseDynamicEmbeddingTrainableProcessor(optimizer._OptimizableVariable): """Processor for dense DynamicEmbedding.""" @@ -238,6 +240,13 @@ def device_function(self, op): return worker_device.to_string() +def _assert_float_dtype(dtype): + dtype = dtypes.as_dtype(dtype) + if not dtype.is_floating: + raise ValueError("Expected floating point type, got %s." % dtype) + return dtype + + def _compute_fans_for_keras_init_v1_v2(shape): """ Making keras VarianceScaling initializers v1 & v2 support dynamic shape. """ @@ -300,14 +309,22 @@ def __call__for_keras_init_v1(self, shape, dtype=None, partition_info=None): def __call__for_keras_init_v2(self, shape, dtype=None, **kwargs): """ Making keras VarianceScaling initializers v2 support dynamic shape. """ - kinit2._validate_kwargs(self.__class__.__name__, kwargs) - dtype = kinit2._assert_float_dtype(kinit2._get_dtype(dtype)) + if hasattr(kinit2, "_validate_kwargs"): + kinit2._validate_kwargs(self.__class__.__name__, kwargs) + elif hasattr(self, "_validate_kwargs"): + self._validate_kwargs(kwargs) + + if hasattr(kinit2, "_get_dtype"): + dtype = _assert_float_dtype(kinit2._get_dtype(dtype)) + else: + dtype = _assert_float_dtype(dtype) + scale = self.scale fan_in, fan_out = _compute_fans_for_keras_init_v1_v2(shape) fan_in = math_ops.cast(fan_in, dtype=dtype) fan_out = math_ops.cast(fan_out, dtype=dtype) - if kinit2._PARTITION_SHAPE in kwargs: - shape = kwargs[kinit2._PARTITION_SHAPE] + if _PARTITION_SHAPE in kwargs: + shape = kwargs[_PARTITION_SHAPE] if self.mode == 'fan_in': scale /= math_ops.maximum(1., fan_in) elif self.mode == 'fan_out':