Whisper + Torch.Compile: torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(EncoderDecoderCache) #31987

Closed

kadirnar opened this issue Jul 15, 2024 · 7 comments


kadirnar commented Jul 15, 2024

System Info

- `transformers` version: 4.43.0.dev0
- Platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.23.4
- Safetensors version: 0.4.3
- Accelerate version: 0.32.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA GeForce RTX 4090

Who can help?

@sanchit-gandhi, @Narsil, @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch

import torch._dynamo
torch._dynamo.config.suppress_errors = True  # fall back to eager instead of raising on graph breaks

# Load a small dummy ASR dataset and take one audio sample
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_sample = ds[0]["audio"]

processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v3", attn_implementation="sdpa")

# Use a static KV cache and compile the forward pass
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

import time

start = time.time()
input_features = processor(
    audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
).input_features

_ = model.generate(input_features)  # warm-up call, triggers compilation

predicted_ids = model.generate(input_features)

transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

end = time.time()
print(f"Elapsed time: {end - start} seconds")

Expected behavior

I want to optimize the Whisper model with torch.compile. I would also like to use the pipeline function together with torch.compile. Could you also add sample code for .mp3 input to the docs?
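This is roughly what I have in mind (an untested sketch, not from the docs; "sample.mp3" is a placeholder file name):

from transformers import pipeline
import torch

# ASR pipeline on GPU with the same distil-whisper checkpoint as above
pipe = pipeline(
    "automatic-speech-recognition",
    model="distil-whisper/distil-large-v3",
    torch_dtype=torch.float16,
    device="cuda",
)

# Same idea as the script above: static KV cache + compiled forward
pipe.model.generation_config.cache_implementation = "static"
pipe.model.forward = torch.compile(pipe.model.forward, mode="reduce-overhead", fullgraph=True)

# The ASR pipeline accepts a file path and decodes it internally
result = pipe("sample.mp3")
print(result["text"])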

Error Message:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 161, in __call__
    output.extend(value.reconstruct(self))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/base.py", line 277, in reconstruct
    raise NotImplementedError()
NotImplementedError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/whisper-plus/test.py", line 23, in <module>
    _ = model.generate(input_features)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/generation_whisper.py", line 587, in generate
    outputs = super().generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1969, in generate
    result = self._sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2912, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2243, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 927, in compile_subgraph
    pass1.restore_stack(stack_values)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 68, in restore_stack
    self.foreach(stack_values)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 193, in foreach
    self(i)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 161, in __call__
    output.extend(value.reconstruct(self))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/dicts.py", line 701, in reconstruct
    codegen(self.items[key])
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 163, in __call__
    unimplemented(f"reconstruct: {value}")
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 193, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(EncoderDecoderCache)

from user code:
   File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1751, in forward
    return Seq2SeqLMOutput(

[2024-07-15 21:02:30,139] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-07-15 21:02:30,139] torch._dynamo.utils: [INFO] Function                           Runtimes (s)
[2024-07-15 21:02:30,139] torch._dynamo.utils: [INFO] -------------------------------  --------------
[2024-07-15 21:02:30,139] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner 
@kadirnar

#31166


ArthurZucker commented Jul 16, 2024

Hey! I think you need torch 2.3! Can you try with it?

@kadirnar

> Hey! I think you need torch 2.3! Can you try with it?

After upgrading to torch 2.3.1 I get the following output:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, None], [2, 50360]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
skipping cudagraphs due to skipping cudagraphs due to cpu device. Found from : 
   File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1720, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1592, in forward
    decoder_outputs = self.decoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1263, in forward
    positions = self.embed_positions(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 219, in forward
    return self.weight[position_ids]

skipping cudagraphs due to skipping cudagraphs due to cpu device. Found from : 
   File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1720, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1592, in forward
    decoder_outputs = self.decoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1263, in forward
    positions = self.embed_positions(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 219, in forward
    return self.weight[position_ids]

Elapsed time: 33.72781229019165 seconds

Env:

- `transformers` version: 4.43.0.dev0
- Platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.31
- Python version: 3.10.12
- Huggingface_hub version: 0.23.5
- Safetensors version: 0.4.3
- Accelerate version: 0.32.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.3.1+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA GeForce RTX 4090

@kadirnar

The error was resolved when I added the device parameter.
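Concretely, something like the following sketch (moving the model and inputs onto the GPU; the rest of the script is unchanged):

device = "cuda"
model = model.to(device)

input_features = processor(
    audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
).input_features.to(device)

_ = model.generate(input_features)             # warm-up, triggers compilation
predicted_ids = model.generate(input_features)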

@kadirnar

@ArthurZucker,

I tested this code and it works. But I want to pass an .mp3 file as input. I looked through the load_dataset documentation and couldn't find how to do that. Can you share sample code?

https://gist.github.com/ArthurZucker/a79018e7642e7ddefe06531407ef8401

@ArthurZucker

Glad that it was resolved!
I don't have an example at hand, but you should be able to pass an mp3 file directly instead of going through a dataset.
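Something along these lines should work (untested sketch; librosa is just one way to decode the mp3, and "sample.mp3" is a placeholder):

import librosa

# Whisper feature extractors expect 16 kHz mono audio
audio_array, sampling_rate = librosa.load("sample.mp3", sr=16000)

input_features = processor(
    audio_array, sampling_rate=sampling_rate, return_tensors="pt"
).input_features.to("cuda")

predicted_ids = model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)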


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
