From af20f6f94b5d0f5f73a32c194617b7fa85fc663c Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh Date: Wed, 26 Jun 2024 17:45:29 -0700 Subject: [PATCH] Fix `dtype` and `assign*` in `AutocastVariable`. The `dtype` property would return the true dtype of the variable, instead of the dtype of the value that you get explicitly via `.value()` or implicitly by doing any operation. This would cause seemingly correct things like this to fail with a dtype mismatch: ``` y = variable * tf.cast(x, variable.dtype) ``` Forcing users to write workarounds like: ``` v = variable.value() y = variable * tf.cast(x, v.dtype) ``` Additionally, `assign`, `assign_add`, `assign_sub` expected the value to be of the true dtype, not the cast dtype. This would cause seemingly correct things like this to fail with a dtype mismatch: ``` variable.assign(variable * factor) ``` (This is a common use case for non-trainable variables.) Forcing users to write workarounds like: ``` variable.assign(tf.cast(variable * factor, variable.dtype)) ``` This change fixes these issues to make autocasting fully transparent: - `dtype` returns the cast dtype if applicable - `assign*` accept the cast dtype for the value if applicable Note that this is consistent with how autocasting works in Keras 3. 
PiperOrigin-RevId: 647135376 --- .../python/core/sparsity/keras/pruning_wrapper.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py index d134640a3..02fa34a7e 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py @@ -233,11 +233,17 @@ def build(self, input_shape): # For each of the prunable weights, add mask and threshold variables for weight in self.prunable_weights: + # Under a mixed precision policy, variables report their "cast" dtype. + # However, we want to use the original dtype for mask and threshold. + if hasattr(weight, 'true_dtype'): + dtype = weight.true_dtype + else: + dtype = weight.dtype mask = self.add_weight( 'mask', shape=weight.shape, initializer=keras.initializers.get('ones'), - dtype=weight.dtype, + dtype=dtype, trainable=False, aggregation=tf.VariableAggregation.MEAN, ) @@ -245,7 +251,7 @@ def build(self, input_shape): 'threshold', shape=[], initializer=keras.initializers.get('zeros'), - dtype=weight.dtype, + dtype=dtype, trainable=False, aggregation=tf.VariableAggregation.MEAN, )