Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix
dtype
and assign*
in AutocastVariable
.
The `dtype` property would return to true dtype of the variable, instead of the dtype of the value that you get explicitly via `.value()` or implicitly by doing any operation. This would cause seemingly correct things like this to fail with a dtype mismatch: ``` y = variable * tf.cast(x, variable.dtype) ``` Forcing users to write workarounds like: ``` v = variable.value() y = variable * tf.cast(x, v.dtype) ``` Additionally, `assign`, `assign_add`, `assign_sub` expected the value to be of the true dtype, not the cast dtype. This would cause seemingly correct things like this to fail with a dtype mismatch: ``` variable.assign(variable * factor) ``` (This is a common use case for non-trainable variables.) Forcing users to write workarounds like: ``` variable.assign(tf.cast(variable * factor, variable.dtype)) ``` This changes fixes these issues to make autocasting fully transparent: - `dtype` returns the cast dtype if applicable - `assign*` accept the cast dtype for the value if applicable Note that this is consistent with how autocasting works in Keras 3. PiperOrigin-RevId: 647135376
- Loading branch information