Fix dtype and assign* in AutocastVariable.
The `dtype` property would return the true dtype of the variable, instead of the dtype of the value that you get explicitly via `.value()` or implicitly by doing any operation.

This would cause seemingly correct things like this to fail with a dtype mismatch:
```
y = variable * tf.cast(x, variable.dtype)
```

This forced users to write workarounds like:
```
v = variable.value()
y = variable * tf.cast(x, v.dtype)
```

Additionally, `assign`, `assign_add`, and `assign_sub` expected the value to be of the true dtype, not the cast dtype.

This would cause seemingly correct things like this to fail with a dtype mismatch:
```
variable.assign(variable * factor)
```
(This is a common use case for non-trainable variables.)

This forced users to write workarounds like:
```
variable.assign(tf.cast(variable * factor, variable.dtype))
```

This change fixes these issues to make autocasting fully transparent:
- `dtype` returns the cast dtype if applicable
- `assign*` accept the cast dtype for the value if applicable

Note that this is consistent with how autocasting works in Keras 3.
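
As a rough illustration of the two behaviors listed above, here is a toy model in plain Python. It is a sketch under stated assumptions, not the TensorFlow implementation: `ToyAutocastVariable` is a hypothetical name, dtypes are plain strings, and values are Python floats. Only `dtype`, `true_dtype`, `value()` and `assign` mirror names used in this commit.

```python
# Toy model only: ToyAutocastVariable is a hypothetical stand-in, dtypes are
# strings, and values are floats. It is not the real AutocastVariable.


class ToyAutocastVariable:
    def __init__(self, value, true_dtype, cast_dtype=None):
        self._value = float(value)
        self.true_dtype = true_dtype   # dtype the value is stored in
        self._cast_dtype = cast_dtype  # dtype seen under autocasting, if any

    @property
    def dtype(self):
        # Post-fix behavior: report the cast dtype when autocasting applies.
        return self._cast_dtype or self.true_dtype

    def value(self):
        # Reads are tagged with the same dtype that `dtype` reports.
        return self._value, self.dtype

    def assign(self, value, value_dtype):
        # Post-fix behavior: a value in the cast dtype is accepted.
        if value_dtype != self.dtype:
            raise TypeError(
                f"expected a {self.dtype} value, got {value_dtype}"
            )
        self._value = float(value)
        return self


v = ToyAutocastVariable(2.0, true_dtype="float32", cast_dtype="float16")
read_value, read_dtype = v.value()

# No explicit cast back to the storage dtype is needed before assigning.
v.assign(read_value * 3.0, read_dtype)
final_value, final_dtype = v.value()
```

In this toy, `v.dtype` and the dtype of `v.value()` always agree, which is the property that makes the `tf.cast(x, variable.dtype)` pattern from the message above work without a workaround.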

PiperOrigin-RevId: 647135376
hertschuh authored and tensorflower-gardener committed Jul 8, 2024
1 parent 909a2a4 commit af20f6f
Showing 1 changed file with 8 additions and 2 deletions.
@@ -233,19 +233,25 @@ def build(self, input_shape):

     # For each of the prunable weights, add mask and threshold variables
     for weight in self.prunable_weights:
+      # Under a mixed precision policy, variables report their "cast" dtype.
+      # However, we want to use the original dtype for mask and threshold.
+      if hasattr(weight, 'true_dtype'):
+        dtype = weight.true_dtype
+      else:
+        dtype = weight.dtype
       mask = self.add_weight(
           'mask',
           shape=weight.shape,
           initializer=keras.initializers.get('ones'),
-          dtype=weight.dtype,
+          dtype=dtype,
           trainable=False,
           aggregation=tf.VariableAggregation.MEAN,
       )
       threshold = self.add_weight(
           'threshold',
           shape=[],
           initializer=keras.initializers.get('zeros'),
-          dtype=weight.dtype,
+          dtype=dtype,
           trainable=False,
           aggregation=tf.VariableAggregation.MEAN,
       )
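
The dtype selection in the hunk above can be read as a small fallback rule: prefer `true_dtype` when the weight exposes it, otherwise use `dtype`. The sketch below isolates that rule so it can be tried without TensorFlow. The helper name `storage_dtype` and the `SimpleNamespace` stand-ins are assumptions made for illustration and are not part of the commit.

```python
from types import SimpleNamespace


def storage_dtype(weight):
    """Return the dtype the weight is stored in (hypothetical helper).

    Mirrors the check in the diff: autocast-style variables expose
    `true_dtype`, while plain variables only have `dtype`.
    """
    if hasattr(weight, "true_dtype"):
        return weight.true_dtype
    return weight.dtype


# Stand-ins for real variables; dtypes are plain strings here.
autocast_like = SimpleNamespace(dtype="float16", true_dtype="float32")
plain_like = SimpleNamespace(dtype="float32")

mask_dtype_autocast = storage_dtype(autocast_like)
mask_dtype_plain = storage_dtype(plain_like)
```

Using `hasattr` rather than a type check keeps the rule working for any variable class that does not define `true_dtype`.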