Add torch.compile for Whisper #30949

Closed
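No description is shown for this PR here, so as orientation only: a minimal sketch of the kind of usage the change targets, i.e. Whisper generation with a static self-attention cache and a compiled forward pass. The checkpoint name, compile mode, dummy audio, and generation settings below are illustrative assumptions, not taken from the PR.

```python
# Hedged sketch (checkpoint, compile mode, and dummy audio are assumptions):
# torch.compile-friendly Whisper generation with a static cache.
import numpy as np
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to("cuda")

# Static-shaped key/value caches keep the decode loop free of dynamic shapes,
# which is what makes the forward pass worth compiling.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead")

audio = np.zeros(16_000, dtype=np.float32)  # placeholder for one second of real audio
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt").to("cuda")
generated_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```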
36 changes: 36 additions & 0 deletions src/transformers/cache_utils.py
@@ -935,3 +935,39 @@ def get_max_length(self) -> Optional[int]:
def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()


class OneShotStaticCache(StaticCache):
"""
OneShotStaticCache is for cases where the cache is updated only once and then remains constant. It is useful in
encoder-decoder models, where the key and value states of the cross-attention layers depend only on the encoder output
and therefore only need to be computed and cached in the first generation step.

Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
super().__init__(config, max_batch_size, max_cache_len, device, dtype)
self.cache_filled = [False for _ in range(config.num_hidden_layers)]

def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
if self.cache_filled[layer_idx]:
return self.key_cache[layer_idx], self.value_cache[layer_idx]
self.cache_filled[layer_idx] = True
return super().update(key_states, value_states, layer_idx, cache_kwargs)

def query_cache_filled_status(self, layer_idx: int) -> bool:
return self.cache_filled[layer_idx]

def reset(self):
super().reset()
self.cache_filled = [False for _ in range(len(self.cache_filled))]
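For context on why `query_cache_filled_status` exists, here is a rough sketch (not part of this diff; the function and projection names are assumptions) of how a cross-attention layer could consult the cache so that the encoder key/value projections run only on the first decoding step:

```python
# Hedged sketch, not from this PR: skip the cross-attention K/V projections
# after the first decoding step by consulting OneShotStaticCache.
import torch
from torch import nn
from transformers.cache_utils import OneShotStaticCache

def cross_attn_key_values(k_proj: nn.Linear, v_proj: nn.Linear, encoder_hidden_states: torch.Tensor,
                          cross_attn_cache: OneShotStaticCache, layer_idx: int, cache_kwargs=None):
    if cross_attn_cache.query_cache_filled_status(layer_idx):
        # Filled on the first step and constant afterwards: skip the projections entirely.
        return cross_attn_cache.key_cache[layer_idx], cross_attn_cache.value_cache[layer_idx]
    key_states = k_proj(encoder_hidden_states)    # reshape to (bsz, heads, src_len, head_dim) omitted
    value_states = v_proj(encoder_hidden_states)  # reshape omitted for brevity
    # First (and only) write; later update() calls return the cached tensors unchanged.
    return cross_attn_cache.update(key_states, value_states, layer_idx, cache_kwargs)
```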
62 changes: 37 additions & 25 deletions src/transformers/generation/utils.py
@@ -32,6 +32,7 @@
QuantoQuantizedCache,
SlidingWindowCache,
StaticCache,
OneShotStaticCache
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
@@ -111,7 +112,7 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "one_shot": OneShotStaticCache}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}


@@ -1348,53 +1349,60 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):

past_length = 0
if "past_key_values" in model_kwargs:
if isinstance(model_kwargs["past_key_values"], Cache):
past_length = model_kwargs["past_key_values"].get_seq_length()
past_key_values = model_kwargs["past_key_values"]
# double-cache case in encoder-decoder models: (self-attention cache, cross-attention cache)
if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], Cache):
past_key_values = past_key_values[0]

if isinstance(past_key_values, Cache):
past_length = past_key_values.get_seq_length()
else:
past_length = model_kwargs["past_key_values"][0][0].shape[2]
past_length = past_key_values[0][0].shape[2]
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
cur_len = input_ids.shape[-1]
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs

def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, cache_name: str = "_cache") -> Cache:
"""
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized if a
new `generate` call requires a larger cache.

Returns the resulting cache object.
"""
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(self._cache, cache_cls))
or self._cache.max_batch_size < max_batch_size
)
if cache_implementation == "sliding_window":
need_new_cache = need_new_cache or (
self._cache.sliding_window_size < self._cache.model_sliding_window_size
and max_cache_len > self._cache.max_cache_len
)
elif cache_implementation == "static":
need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len
need_new_cache = not hasattr(self, cache_name)

if not need_new_cache:
current_cache: Cache = getattr(self, cache_name)
need_new_cache = not isinstance(current_cache, cache_cls) or current_cache.max_batch_size != max_batch_size
if cache_implementation == "sliding_window":
need_new_cache = need_new_cache or (
current_cache.sliding_window_size < current_cache.model_sliding_window_size
and max_cache_len > current_cache.max_cache_len
)
elif cache_implementation == "static":
need_new_cache = need_new_cache or current_cache.max_cache_len < max_cache_len
elif cache_implementation == "one_shot":
need_new_cache = need_new_cache or current_cache.max_cache_len != max_cache_len

if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
self._cache = cache_cls(
setattr(self, cache_name, cache_cls(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.device,
dtype=cache_dtype,
)
))
else:
self._cache.reset()
return self._cache
current_cache.reset()
return getattr(self, cache_name)

def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
@@ -1681,9 +1689,14 @@ def generate(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
)
model_kwargs["past_key_values"] = self._get_cache(generation_config.cache_implementation, batch_size, generation_config.max_length)
if self.config.is_encoder_decoder:
# manually set another cache for cross attention
encoder_outputs = model_kwargs["encoder_outputs"][0]
model_kwargs["past_key_values"] = (
model_kwargs["past_key_values"],
self._get_cache("one_shot", encoder_outputs.shape[0], encoder_outputs.shape[1], '_cross_attn_cache')
)
Comment on lines +1693 to +1699 (Contributor Author):
I wonder if there is a better way to do this: in the encoder-decoder scenario we need a tuple of two caches here under the current design, but this seems hardcoded and easy to break.
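To make the concern concrete, this is the unpacking pattern the tuple-of-caches design pushes onto every consumer of `model_kwargs["past_key_values"]`, mirroring the `_get_initial_cache_position` change above. `split_caches` is a hypothetical helper for illustration only.

```python
from typing import Optional, Tuple
from transformers.cache_utils import Cache

def split_caches(past_key_values) -> Tuple[Cache, Optional[Cache]]:
    # Hypothetical helper: every caller now has to distinguish the plain cache
    # from the encoder-decoder pair (self-attention cache, cross-attention cache).
    if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], Cache):
        self_attn_cache, cross_attn_cache = past_key_values
        return self_attn_cache, cross_attn_cache
    return past_key_values, None
```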

elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
@@ -2454,7 +2467,6 @@ def _sample(
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2 changes: 1 addition & 1 deletion src/transformers/models/whisper/configuration_whisper.py
@@ -189,7 +189,7 @@ class WhisperConfig(PretrainedConfig):

model_type = "whisper"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
attribute_map = {"num_key_value_heads": "encoder_attention_heads", "num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

def __init__(
self,
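A likely rationale for the `attribute_map` addition above (inferred, not stated in the diff): the static cache sizes its key/value buffers from `config.num_key_value_heads`, an attribute `WhisperConfig` did not previously expose. With the mapping in place the lookup resolves instead of raising, as this small check illustrates:

```python
# Sketch of the effect of the new attribute_map entry (post-PR behaviour).
from transformers import WhisperConfig

config = WhisperConfig()
# Resolves to config.encoder_attention_heads instead of raising AttributeError,
# which is what a static K/V cache needs when sizing its buffers.
assert config.num_key_value_heads == config.encoder_attention_heads
```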
2 changes: 1 addition & 1 deletion src/transformers/models/whisper/generation_whisper.py
@@ -221,7 +221,7 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
else:
# num_frames is of shape (batch_size,) whereas batch_size here is really batch_size * num_return_sequences
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
num_frames = np.repeat(num_frames, repeat_time)
num_frames = np.repeat(num_frames.cpu() if torch.is_tensor(num_frames) else num_frames, repeat_time)

if num_frames is None or isinstance(num_frames, int):
# Normalize and smoothen the weights.