T5 compile compatibility #34089

Merged: 34 commits merged into main from t5-compile on Oct 22, 2024
Commits
15abc14  this worked in normal generation, needs more tests (zucchini-nlp, Sep 19, 2024)
06d9d62  fix almost all tests in t5 (zucchini-nlp, Sep 30, 2024)
51c689c  nit (zucchini-nlp, Sep 30, 2024)
9e5244c  longt5, umt5, mt5 (zucchini-nlp, Sep 30, 2024)
417dd6d  style (zucchini-nlp, Sep 30, 2024)
814a405  udop, pix2struct (zucchini-nlp, Sep 30, 2024)
0bc8b54  more models (zucchini-nlp, Sep 30, 2024)
7c5925b  fix some tests (zucchini-nlp, Sep 30, 2024)
038bb1e  fix onnx tests (zucchini-nlp, Sep 30, 2024)
df98842  tracing tests fixed (zucchini-nlp, Oct 1, 2024)
0544b65  compile enabled and tested for t5 models (zucchini-nlp, Oct 1, 2024)
1063971  fix small bug in slow tests (zucchini-nlp, Oct 1, 2024)
0e7fb50  [run-slow] t5 (zucchini-nlp, Oct 1, 2024)
c4ccdea  uncomment (zucchini-nlp, Oct 1, 2024)
11065c9  Merge remote-tracking branch 'upstream/main' into t5-compile (zucchini-nlp, Oct 1, 2024)
993f318  style (zucchini-nlp, Oct 1, 2024)
41911b7  update with new generation refactoring (zucchini-nlp, Oct 1, 2024)
2449e32  nit (zucchini-nlp, Oct 1, 2024)
df0a05c  fix copies (zucchini-nlp, Oct 1, 2024)
c98e541  this is the fix, had to change t5 to fix copies (zucchini-nlp, Oct 1, 2024)
4f16856  update (zucchini-nlp, Oct 8, 2024)
d7260d3  [run-slow] t5 (zucchini-nlp, Oct 8, 2024)
5f5f66f  [run-slow] t5 (zucchini-nlp, Oct 9, 2024)
e404063  update (zucchini-nlp, Oct 11, 2024)
47d70c5  add test for encoder only T5 (zucchini-nlp, Oct 11, 2024)
042101f  Merge remote-tracking branch 'upstream/main' into t5-compile (zucchini-nlp, Oct 14, 2024)
3048ab8  clean up after rebase (zucchini-nlp, Oct 14, 2024)
2c805f2  fix pop2piano (zucchini-nlp, Oct 14, 2024)
9e1fefa  add comment (zucchini-nlp, Oct 18, 2024)
0cb6036  Merge branch 'main' into t5-compile (zucchini-nlp, Oct 21, 2024)
56d036c  style (zucchini-nlp, Oct 21, 2024)
c25a8a4  fix copies after rebase (zucchini-nlp, Oct 22, 2024)
3086178  Merge remote-tracking branch 'upstream/main' into t5-compile (zucchini-nlp, Oct 22, 2024)
befe2d8  fix copies missed this one (zucchini-nlp, Oct 22, 2024)
6 changes: 1 addition & 5 deletions src/transformers/cache_utils.py
@@ -1475,11 +1475,7 @@ def from_legacy_cache(
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
-        if self.self_attention_cache.key_cache == []:
-            return 0
-        if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
-            return 0
-        return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+        return self.self_attention_cache.get_seq_length(layer_idx)

     def reset(self):
         if hasattr(self.self_attention_cache, "reset"):
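
The method now defers to the inner self-attention cache instead of inspecting key_cache tensors and special-casing static caches itself. Below is a toy illustration of that delegation pattern, using simplified stand-ins rather than the real EncoderDecoderCache and StaticCache classes:

from typing import List, Optional

import torch


class ToySelfAttentionCache:
    """Simplified stand-in for a per-layer key/value cache."""

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []  # one (batch, heads, seq, head_dim) tensor per layer

    def update(self, layer_idx: int, key_states: torch.Tensor) -> None:
        while len(self.key_cache) <= layer_idx:
            self.key_cache.append(torch.empty(0))
        self.key_cache[layer_idx] = key_states

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
            return 0
        return self.key_cache[layer_idx].shape[-2]


class ToyEncoderDecoderCache:
    """Wrapper that defers the sequence-length question to the inner cache."""

    def __init__(self, self_attention_cache: ToySelfAttentionCache) -> None:
        self.self_attention_cache = self_attention_cache

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # Single source of truth: whatever the inner cache reports.
        return self.self_attention_cache.get_seq_length(layer_idx)


cache = ToyEncoderDecoderCache(ToySelfAttentionCache())
assert cache.get_seq_length() == 0
cache.self_attention_cache.update(0, torch.zeros(1, 8, 5, 64))
assert cache.get_seq_length() == 5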
8 changes: 6 additions & 2 deletions src/transformers/generation/utils.py
@@ -1535,8 +1535,12 @@ def _prepare_generation_config(
     def _get_initial_cache_position(self, input_ids, model_kwargs):
         """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
         # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
-        if "inputs_embeds" in model_kwargs:
+        if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
             cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+        elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
+            cache_position = (
+                torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+            )
         else:
             cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1

@@ -1633,7 +1637,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):

         cache_kwargs = {
             "config": self.config.get_text_config(),
-            "max_batch_size": batch_size,
+            "batch_size": batch_size,
             "max_cache_len": max_cache_len,
             "device": device,
             "dtype": cache_dtype,
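
Both branches build cache_position with the ones_like(...).cumsum(0) - 1 idiom rather than torch.arange(seq_len), so the length is derived from a tensor shape instead of a Python integer, which keeps the graph torch.compile-friendly; the new elif extends this to the decoder_inputs_embeds that encoder-decoder models pass at pre-fill. The second hunk is only a kwarg rename (max_batch_size to batch_size). A tiny self-contained check of the idiom, with a plain tensor standing in for model_kwargs["decoder_inputs_embeds"] and an arbitrary hidden size:

import torch

# Stand-in for model_kwargs["decoder_inputs_embeds"]: (batch, seq_len, hidden)
decoder_inputs_embeds = torch.randn(2, 5, 512)

# Compile-friendly "arange from a shape": 0, 1, ..., seq_len - 1
cache_position = (
    torch.ones_like(decoder_inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) - 1
)

assert torch.equal(cache_position, torch.arange(5))
print(cache_position)  # tensor([0, 1, 2, 3, 4])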
7 changes: 6 additions & 1 deletion src/transformers/models/longt5/configuration_longt5.py
@@ -79,7 +79,12 @@ class LongT5Config(PretrainedConfig):

     model_type = "longt5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }

     def __init__(
         self,
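
The substantive change here is the new "head_dim": "d_kv" alias (the dict is also reflowed onto multiple lines); the same entry is added to MT5Config, T5Config, and UMT5Config further down. PretrainedConfig resolves attribute_map aliases on attribute access, so generic code such as static-cache allocation can read config.head_dim and get T5's d_kv instead of deriving hidden_size // num_attention_heads, which is wrong for checkpoints like the original t5-11b where d_kv is decoupled from d_model. A quick check, assuming a transformers version that includes this change:

from transformers import LongT5Config, T5Config

config = LongT5Config()  # default d_kv is 64
assert config.head_dim == config.d_kv == 64  # resolved through the attribute_map alias

# t5-11b-style geometry: 128 heads of size 128 on a 1024-wide model, so
# hidden_size // num_attention_heads (1024 // 128 = 8) is not the head size.
big = T5Config(d_model=1024, num_heads=128, d_kv=128)
assert big.head_dim == 128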
455 changes: 299 additions & 156 deletions src/transformers/models/longt5/modeling_longt5.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion src/transformers/models/mt5/configuration_mt5.py
@@ -72,7 +72,12 @@ class MT5Config(PretrainedConfig):

     model_type = "mt5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }

     def __init__(
         self,
449 changes: 299 additions & 150 deletions src/transformers/models/mt5/modeling_mt5.py

Large diffs are not rendered by default.

407 changes: 275 additions & 132 deletions src/transformers/models/pix2struct/modeling_pix2struct.py

Large diffs are not rendered by default.

451 changes: 298 additions & 153 deletions src/transformers/models/pop2piano/modeling_pop2piano.py

Large diffs are not rendered by default.


7 changes: 6 additions & 1 deletion src/transformers/models/t5/configuration_t5.py
@@ -73,7 +73,12 @@ class T5Config(PretrainedConfig):

     model_type = "t5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }

     def __init__(
         self,
452 changes: 302 additions & 150 deletions src/transformers/models/t5/modeling_t5.py

Large diffs are not rendered by default.

434 changes: 289 additions & 145 deletions src/transformers/models/udop/modeling_udop.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion src/transformers/models/umt5/configuration_umt5.py
@@ -72,7 +72,12 @@ class UMT5Config(PretrainedConfig):

     model_type = "umt5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }

     def __init__(
         self,