From b865f67898eb643311cc64408ca563a291c735a6 Mon Sep 17 00:00:00 2001
From: Sagar Hathwar
Date: Thu, 12 Jan 2023 17:22:41 -0800
Subject: [PATCH] Fix QAT 2.0 TypeError for transformer network

Signed-off-by: Sagar Hathwar
---
 .../keras/quant_sim/quantsim_straight_through_grad.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/quantsim_straight_through_grad.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/quantsim_straight_through_grad.py
index 0c2f6c3272..562bfc0075 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/quantsim_straight_through_grad.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/quantsim_straight_through_grad.py
@@ -235,7 +235,7 @@ def _compute_dloss_by_dmin_dmax_and_dx(inputs: tf.Tensor, encoding_min: tf.Varia
     dloss_by_dmin = tf.cast(_compute_dloss_by_dmin_using_dmax(dloss_by_dmax), tf.float64)
 
     # Pass through gradient for skipped ops
-    dloss_by_dx = tf.cond(tf.equal(op_mode, 3), lambda: grad, lambda: dloss_by_dx)
+    dloss_by_dx = tf.cond(tf.equal(op_mode, 3), lambda: tf.convert_to_tensor(grad), lambda: dloss_by_dx)
 
     return dloss_by_dmin, dloss_by_dmax, dloss_by_dx
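
Why the one-liner works, as a minimal standalone sketch (an assumption from the
patch context, not from AIMET documentation: in transformer networks, gradients
coming out of embedding lookups arrive as tf.IndexedSlices, and tf.cond in graph
mode traces both branches and raises a TypeError when one branch returns
IndexedSlices while the other returns a Tensor; select_grad below is a
hypothetical stand-in for the patched line, and all shapes and values are
illustrative):

    import tensorflow as tf

    # Gradients of a tf.gather (embedding lookup) with respect to the looked-up
    # variable are tf.IndexedSlices, not tf.Tensor.
    params = tf.Variable(tf.ones([10, 4]))
    with tf.GradientTape() as tape:
        out = tf.gather(params, [1, 3])
    grad = tape.gradient(out, params)    # tf.IndexedSlices

    dloss_by_dx = tf.zeros([10, 4])      # stand-in for the quantizer gradient
    op_mode = tf.constant(3)

    @tf.function  # graph mode: tf.cond traces both branches
    def select_grad(op_mode, grad, dloss_by_dx):
        # Without tf.convert_to_tensor, the true branch would yield
        # IndexedSlices while the false branch yields a Tensor; tf.cond
        # requires matching branch output types, hence the TypeError.
        # Densifying grad first makes both branches return a tf.Tensor
        # of the same dtype and shape.
        return tf.cond(tf.equal(op_mode, 3),
                       lambda: tf.convert_to_tensor(grad),
                       lambda: dloss_by_dx)

    print(select_grad(op_mode, grad, dloss_by_dx).shape)  # (10, 4)

Note that tf.convert_to_tensor scatters the IndexedSlices into a dense tensor
(TensorFlow logs a warning when it does this), which is the intended behavior
here since the pass-through gradient must match dloss_by_dx in structure.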