V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function forward in /home/sanchit/transformers/src/transformers/models/gemma2/modeling_gemma2.py:891
V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] triggered by the following guard failure(s):
V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] - tensor 'L['input_ids']' stride mismatch at index 0. expected 8, actual 1
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function forward in /home/sanchit/transformers/src/transformers/models/gemma2/modeling_gemma2.py:891
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] triggered by the following guard failure(s):
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] - tensor 'L['input_ids']' stride mismatch at index 0. expected 1, actual 40
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] - tensor 'L['input_ids']' stride mismatch at index 0. expected 8, actual 40
Run 0: 28.8159080589 tok/s
Run 1: 0.878302057247666 tok/s
Run 2: 19.946942197324718 tok/s
=> we get only two recompilations (expected), but the inference speeds of the second and third runs are significantly lower than the first. This pattern happens only after calling `past_key_values.reset()`, which suggests a bug in how we're resetting the `HybridCache`.
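For illustration, here is a minimal, self-contained sketch (the shapes are arbitrary and not taken from the reproduction) of how two `input_ids` tensors with identical shape can carry different strides, which is exactly the kind of mismatch the guard failures above report:

```python
import torch

# Two tensors with identical shape but different strides along dim 0.
batch = torch.arange(80).reshape(2, 40)      # contiguous, stride (40, 1)
sliced = batch[:, -1:]                       # shape (2, 1), stride (40, 1)
fresh = torch.zeros(2, 1, dtype=torch.long)  # shape (2, 1), stride (1, 1)

print(sliced.stride(), fresh.stride())       # (40, 1) vs (1, 1)
# torch.compile (Dynamo) guards on strides as well as shapes, so feeding the
# compiled forward a tensor whose stride differs from the one it was traced
# with triggers a recompilation, matching the "stride mismatch at index 0"
# messages above.
```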
Expected behavior
Run 0: 28.8159080589 tok/s
Run 1: 28.8159080589 tok/s
Run 2: 28.8159080589 tok/s
@sanchit-gandhi Interesting, is this greedy search? With Llama greedy search the input_ids stride is always the same; it might be safer to call contiguous/clone anyway.
It's sampling (we set `do_sample=True, temperature=1`). Having played around with your PR, it looks like the same issue affects Gemma-2 as LLaMA, so I've pushed the changes for Gemma/Gemma-2 directly to your PR!
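For reference, the kind of workaround discussed here could look roughly like the following sketch (hypothetical placement inside the generation loop; the actual change lives in the PR mentioned above):

```python
# Hypothetical sketch: normalize input_ids to a canonical memory layout before
# the compiled forward sees it, so the stride guards always observe the same
# values regardless of how the tensor was produced (fresh allocation vs. slice).
input_ids = input_ids.contiguous()
```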
System Info
- `transformers` version: 4.44.0.dev0
- distributed_type: MULTI_GPU
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 2
- machine_rank: 0
- num_machines: 1
- gpu_ids: 0,1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
Who can help?
@sanchit-gandhi @gante @ArthurZucker
Information
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
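The original reproduction script is not included in this extract; below is a minimal sketch of the kind of benchmark that exhibits the pattern, assuming a single GPU. The checkpoint, cache length, prompt, and generation arguments are placeholders, not the values used in the report:

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

model_id = "google/gemma-2-9b"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")

# Compile the forward pass; the recompile logs above come from this step.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
past_key_values = HybridCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=256,
    device=model.device,
    dtype=model.dtype,
)

for run in range(3):
    start = time.perf_counter()
    out = model.generate(
        **inputs,
        do_sample=True,
        temperature=1.0,
        max_new_tokens=128,
        past_key_values=past_key_values,
    )
    elapsed = time.perf_counter() - start
    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
    print(f"Run {run}: {new_tokens / elapsed} tok/s")
    past_key_values.reset()  # the slowdown is observed only after this call
```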