
torch.export fails for llama model #29190

Closed
2 of 4 tasks
mreso opened this issue Feb 21, 2024 · 7 comments · Fixed by #29211

Comments

mreso commented Feb 21, 2024

System Info

  • transformers version: 4.38.0.dev0
  • Platform: Linux-5.12.0-0_fbk15_zion_7536_ged335dd7211d-x86_64-with-glibc2.34
  • Python version: 3.11.7
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0.dev20240221+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

llama.eval()
inputs = tokenizer("How do you", return_tensors="pt")

exported_llama = torch.export.export(llama, args=(inputs["input_ids"],inputs["attention_mask"]))

Expected behavior

Successful export of the model. Instead, the export fails with:

torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1176, in forward
    outputs = self.model(
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 990, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1090, in _update_causal_mask
    if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
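
For reference, the failing pattern reduces to a data-dependent Python branch on a tensor. A minimal sketch (hypothetical module, same shape as the branch in _update_causal_mask, assuming a comparable PyTorch nightly) that is expected to fail with the same UserError:

import torch

class MaskGate(torch.nn.Module):
    def forward(self, x, attention_mask):
        # torch.any() yields a tensor whose value the exporter does not know,
        # so this Python `if` becomes dynamic control flow and strict export aborts.
        if attention_mask is not None and torch.any(attention_mask != 1):
            return x * 0.0
        return x

try:
    torch.export.export(MaskGate(), args=(torch.randn(2, 4), torch.ones(2, 4)))
except Exception as e:
    print(type(e).__name__, e)  # expected: torch._dynamo.exc.UserError about dynamic control flow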

I worked around this issue by guarding with torch._dynamo.is_compiling() in these two places. First, changing

is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)

to

is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) or (hasattr(torch,"_dynamo") and not torch._dynamo.is_compiling())

and second, guarding the rotary embedding cache update:

if hasattr(torch,"_dynamo") and not torch._dynamo.is_compiling():
    self._cos_cached = cos
    self._sin_cached = sin

Happy to create a PR if this is a viable solution.
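
(For context, torch._dynamo.is_compiling() is False in eager mode and is treated as a constant True while dynamo traces the code, e.g. under torch.compile or strict torch.export, which is what makes it usable as a guard here. A small sketch:)

import torch
import torch._dynamo

def report():
    # True while dynamo traces this function, False when it runs eagerly.
    return torch._dynamo.is_compiling()

print(report())                 # False
print(torch.compile(report)())  # True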

@ArthurZucker (Collaborator)

cc @fxmarty as well

fxmarty (Contributor) commented Feb 22, 2024

@mreso This should be fixed (at least the first part) in #29211.

I did not hit an issue with self._cos_cached = cos with your repro; maybe we are not using the same commit? Can you try on main again?

fxmarty (Contributor) commented Feb 22, 2024

Maybe related: #29109 and #29198

mreso (Author) commented Feb 22, 2024

Thanks @fxmarty for fixing the first part! Just reinstalled from source using commit 2cc8cf6 and rechecked. Still running into the second issue with the code snippet from above:

  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
    tracer.run()
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 793, in run
    and self.step()
        ^^^^^^^^^^^
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 756, in step
    getattr(self, inst.opname)(inst)
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1308, in STORE_ATTR
    assert (
AssertionError: Mutating module attribute _cos_cached during export.

from user code:
   File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1175, in forward
    outputs = self.model(
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1011, in forward
    layer_outputs = decoder_layer(
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 735, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 641, in forward
    cos, sin = self.rotary_emb(value_states, position_ids)
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mreso/.conda/envs/pippy/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 131, in forward
    self._cos_cached = cos

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
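
For reference, this restriction reproduces without transformers. A minimal sketch (hypothetical module, assuming the same nightly) that is expected to hit the same assertion:

import torch

class CachingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._cos_cached = torch.zeros(4)

    def forward(self, x):
        # Assigning a new tensor to a module attribute while dynamo is exporting
        # triggers "Mutating module attribute _cos_cached during export."
        self._cos_cached = x.cos()
        return x + self._cos_cached

try:
    torch.export.export(CachingModule(), args=(torch.randn(4),))
except Exception as e:
    print(type(e).__name__, e)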

@ArthurZucker (Collaborator)

This will be fixed by #29198; there is no longer a self._cos_cached = cos assignment.
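
(For context, the export-friendly pattern is to compute cos/sin inside forward and return them instead of writing them back onto the module; a rough sketch of that pattern, not the actual #29198 implementation:)

import torch

class StatelessRotaryEmbedding(torch.nn.Module):
    # Hypothetical sketch: forward mutates no attributes, so dynamo can export it.
    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        # position_ids: (batch, seq_len) -> freqs: (batch, seq_len, dim/2)
        freqs = position_ids[:, :, None].to(torch.float32) * self.inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)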

fxmarty reopened this Feb 23, 2024
tbaggu commented Feb 28, 2024

@ArthurZucker I have seen similar issues with the mistralai/Mixtral-8x7B-v0.1 model: there are 13 graph breaks caused by conditional statements and callbacks into Python.

Is there any plan to support full-graph dynamo capture for as many models as possible?
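
(For reference, one way to count graph breaks and inspect their reasons, assuming the torch._dynamo.explain API available in recent PyTorch 2.x releases; the checkpoint below is just the one mentioned above:)

import torch
import torch._dynamo
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("How do you", return_tensors="pt")

# Run dynamo in "explain" mode: no export, just trace and report graph breaks.
explanation = torch._dynamo.explain(model)(inputs["input_ids"], inputs["attention_mask"])
print(explanation.graph_count, explanation.graph_break_count)
for reason in explanation.break_reasons:
    print(reason)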

fxmarty (Contributor) commented Feb 28, 2024

Closing as #29198 is merged

@tbaggu that would be the goal indeed.

fxmarty closed this as completed Feb 28, 2024