HybridCache slow after reset #32313

Closed · sanchit-gandhi opened this issue Jul 30, 2024 · 3 comments · Fixed by #32227

sanchit-gandhi commented Jul 30, 2024

System Info

  • transformers version: 4.44.0.dev0
  • Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
  • Python version: 3.11.9
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.0
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 2
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: 0,1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@sanchit-gandhi @gante @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import time
from transformers import AutoTokenizer, Gemma2ForCausalLM
from transformers.cache_utils import HybridCache
import torch

torch.set_float32_matmul_precision("high")
# catch re-compilations
torch._logging.set_logs(graph_breaks=True, recompiles=True)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b", attn_implementation="eager")
model.to("cuda")

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
prompt_length = input_ids.input_ids.shape[1]

model.generation_config.min_new_tokens = model.generation_config.max_new_tokens = 32

past_key_values = HybridCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=prompt_length + 4 * model.generation_config.max_new_tokens,
    device=model.device,
    dtype=model.dtype
)

# enable passing kv cache
model._supports_cache_class = True
model.generation_config.cache_implementation = None

for i in range(3):
    # two warm-ups
    outputs_1 = model.generate(**input_ids, past_key_values=past_key_values, do_sample=True, temperature=1)
    outputs_2 = model.generate(outputs_1, past_key_values=past_key_values, do_sample=True, temperature=1)

    # one timed run
    torch.cuda.synchronize("cuda")
    start = time.time()
    outputs_3 = model.generate(outputs_2, past_key_values=past_key_values, do_sample=True, temperature=1)
    torch.cuda.synchronize("cuda")
    runtime = time.time() - start
    print(f"Run {i}: {model.generation_config.max_new_tokens / runtime} tok/s")

    past_key_values.reset()

Print Output:

V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function forward in /home/sanchit/transformers/src/transformers/models/gemma2/modeling_gemma2.py:891
V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 8, actual 1

V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function forward in /home/sanchit/transformers/src/transformers/models/gemma2/modeling_gemma2.py:891
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 1, actual 40
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 8, actual 40

Run 0: 28.8159080589 tok/s
Run 1: 0.878302057247666 tok/s
Run 2: 19.946942197324718 tok/s

=> We get only two recompilations (expected), but the inference speed of the second and third runs is significantly lower than the first. This slowdown only appears after calling past_key_values.reset(), which suggests a bug in how we're resetting the HybridCache.
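
For reference, an in-place reset of a static cache is not expected to change tensor shapes or storage, which is why the slowdown after reset() is surprising. Below is a minimal sketch of that idea (an illustration with made-up buffer shapes, not the actual HybridCache code):

import torch

class StaticKVCache:
    # Toy static cache: one fixed-size key/value buffer per layer.
    def __init__(self, num_layers, batch, heads, max_len, head_dim, device, dtype):
        self.key_cache = [
            torch.zeros(batch, heads, max_len, head_dim, device=device, dtype=dtype)
            for _ in range(num_layers)
        ]
        self.value_cache = [torch.zeros_like(k) for k in self.key_cache]

    def reset(self):
        # Zero the buffers in place: same storage, same shapes, so any
        # compiled graph that captured them should stay valid.
        for k, v in zip(self.key_cache, self.value_cache):
            k.zero_()
            v.zero_()

Because a reset like this only zeroes pre-allocated buffers, the reset itself should be cheap and compile-friendly.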

Expected behavior

Run 0: 28.8159080589 tok/s
Run 1: 28.8159080589 tok/s
Run 2: 28.8159080589 tok/s
sanchit-gandhi (Contributor, Author) commented:

Note @fxmarty that this issue also occurs when we only pass the input_ids to the model (and not the attention mask)

fxmarty commented Jul 31, 2024

@sanchit-gandhi Interesting, is this greedy search? With Llama, the input_ids stride is always the same during greedy search; it might be safer to call contiguous/clone anyway.
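
A minimal sketch of the mechanism being suggested (an assumption for illustration, not the actual fix in the PR): a stride difference alone can fail a torch.compile guard, and a contiguous copy gives every call the same layout.

import torch

torch._logging.set_logs(recompiles=True)

@torch.compile(fullgraph=True)
def forward(input_ids):
    return input_ids + 1  # stand-in for the compiled model forward

prompt = torch.zeros(1, 8, dtype=torch.long, device="cuda")  # contiguous, stride (8, 1)
forward(prompt)                                               # first compilation

# Same shape, but a view into a longer buffer: stride (40, 1) at index 0,
# similar to the guard failures logged above.
resumed = torch.zeros(1, 40, dtype=torch.long, device="cuda")[:, :8]
forward(resumed)               # stride guard may fail here and trigger a recompile
forward(resumed.contiguous())  # a contiguous copy restores stride (8, 1) and can reuse the first graph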

sanchit-gandhi (Contributor, Author) commented:

It's sampling (we set do_sample=True, temperature=1). Having played around with your PR, it looks like Gemma-2 is affected by the same issue as LLaMA, so I've pushed the changes for Gemma/Gemma-2 directly to your PR!
