I used the `aqt_einsum` function to quantize only the QK score (the query-key einsum) and then trained the model. However, after a certain number of steps (around 200), the loss decreases very slowly, which is quite different from the loss curve of the bfloat16 baseline. Am I missing something? For example, does the backward pass need some additional handling?
PS: I am training the model with jax==0.4.23 on a TPU v5p-8.
More generally, is there an example of int8 training with AQT in Pax?
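For reference, here is a minimal sketch of what quantizing only the QK product computes in the forward pass. It is written in plain NumPy and is a hypothetical illustration: it does not call `aqt_einsum`. It assumes symmetric per-tensor int8 scales with int32 accumulation, and the helper names `quantize_int8` and `int8_qk_scores` are made up for this example. AQT's actual scheme (for example per-channel scales or calibration) may differ, and the sketch says nothing about how gradients are handled in the backward pass.

```python
import numpy as np


def quantize_int8(x: np.ndarray):
    """Symmetric per-tensor int8 quantization (assumed scheme, not necessarily AQT's)."""
    max_abs = np.max(np.abs(x))
    # Map [-max_abs, max_abs] onto [-127, 127]; fall back to scale 1 for an all-zero tensor.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def int8_qk_scores(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Attention logits q . k computed from int8 operands, then rescaled to float.

    q: [batch, q_len, heads, dim], k: [batch, k_len, heads, dim]
    Returns: [batch, heads, q_len, k_len]
    """
    q_i8, q_scale = quantize_int8(q)
    k_i8, k_scale = quantize_int8(k)
    # Accumulate in int32 so the sum of int8 products cannot overflow.
    acc = np.einsum("bqhd,bkhd->bhqk", q_i8.astype(np.int32), k_i8.astype(np.int32))
    # Undo both quantization scales to get back to the float domain.
    return acc.astype(np.float32) * (q_scale * k_scale)
```

Comparing `int8_qk_scores(q, k)` with `np.einsum("bqhd,bkhd->bhqk", q, k)` on the same inputs is one way to check the forward-pass quantization error separately from the training dynamics.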