-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add torch.compile for Mistral #30642
Add torch.compile for Mistral #30642
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
06fb2d6
to
57842ab
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this!
Also the list of models that support static cache in the doc probably need an update |
BTW mistral will nee a |
from my understanding, SlidingWindowCache is for memory efficiency, actually I wonder how SlidingWindowCache would address the issue of get_seq_length where the current solution is relying on performing a non-zero check on all previous tokens, another solution is to use cache_position so that we don't have to call past_key_values.get_seq_length(), but this would require cache position is always passed in model.forward right? |
57842ab
to
315becb
Compare
Yes, we need to rely on cache positions that have to be passed to the model's forward. They can be initialized like in Llama. NOte that it's not breaking for the dyunamic cache because it will ignore the extra kwargs. def _update_cache(self, key_states, value_states, **cache_kwargs):
"""
torch.compile compatible sliding window.
Computes the `indices` based on `cache_position >= self.config.attention_window_size - 1`.
The `to_shift` is only true once we are above attention_window_size. Thus with `attention_window_size==64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.attention_window_size
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
We overwrite the cache using these, then we always write at cache_position (clamped to `attention_window_size`)
"""
cache_position = cache_kwargs.get("cache_position")
if cache_position.shape[0] > self.config.attention_window_size:
# int indexing -> device sync? in compile, use tensor
k_out = key_states[:, :, -self.config.attention_window_size :, :]
v_out = value_states[:, :, -self.config.attention_window_size :, :]
else:
slicing = torch.ones(
self.config.attention_window_size, dtype=torch.long, device=value_states.device
).cumsum(0)
cache_position = cache_position.clamp(0, self.config.attention_window_size - 1)
to_shift = cache_position >= self.config.attention_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_states, self.value_states = k_out, v_out
return k_out, v_out updating the cache should be like this |
YEah it's hard to debug. merge from main and try to add a print to check where it's failing! Most probably a copied from that is not well placed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in general looks like it's going in the right direction! 💪
Related PR: #30688
315becb
to
e0e9968
Compare
Hi Aurthur, I have add support for Sliding Window Cache, and please take a look at its implementation and also the _update_causal_mask implementation, I have add my thoughts as comments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good work
- for all the copied from that were removed, we need to use on of the model as the new base (mixtral for example)
- as @gante said, let's add phi to the lsit of model slow tested
- let's revert some styling on setup.py and examples run object detection
src/transformers/generation/utils.py
Outdated
@@ -1342,6 +1340,33 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa | |||
self._static_cache.reset() # reset the cache for a new generation | |||
return self._static_cache | |||
|
|||
# maybe a better way is to use a single factory function to set caches for models ? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, you should be able to use a mapping from the cache_implementation
to the cache_class
, and replace _get_static_cache by _get_cache. Since most of the args are gonna come from the config.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The need_new_cache
is not a great naming IMO. If we need to check that for each cache class it makes sense to implement it, but would rather have minimal number of function in the cache class that are not directly related to the cache (here it's a generate trick)
@gante how much gain do we have from not allocating and just reseting? (IMO should be fairly efficient in general so we could just forget this needs new cache arg
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still keep the logic but keep them in _get_cache
src/transformers/cache_utils.py
Outdated
self.key_cache: List[torch.Tensor] = [] | ||
self.value_cache: List[torch.Tensor] = [] | ||
cache_shape = ( | ||
config.num_hidden_layers, | ||
max_batch_size, | ||
self.num_key_value_heads, | ||
self.sliding_window_size, | ||
self.head_dim, | ||
) | ||
|
||
self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) | ||
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) | ||
|
||
torch._dynamo.mark_static_address(self.key_cache) | ||
torch._dynamo.mark_static_address(self.value_cache) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what is the best between list of tensor and a full concatenated tensor?
But I think for everything related to generate, assisted decoding etc, list might be better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also in the current configuration:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
is useless!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried using lists, cudagraph seems to complain about it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting as that works for the fully static cache! But alright 👍🏻
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's different because in static cache the tensor we keep in self.key_cache[layer_idx]
never changes in terms of address, however in here we have the needs of assigning whole self.key_cache[layer_idx]
to the new tensor passed in as key states
e0e9968
to
ff65b81
Compare
Hi Authur, I made some modifications, and I still can't quite get what we should do to solve this copy consistency issue. I have changed the base to Could you please explain more about the second point about phi model? I don't quite get it. |
2b7d873
to
dd7ff33
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the copied from issue, indeed mixtral is not necessarily the best, the idea is just to do what we can to not break the copied from chain. Not everything has to be from mixtral
I think A10 workflow only trigger on push on main, not on PR ? I mean my commits on this pr are not triggering the test,
|
You could run workflow manually https://github.com/huggingface/transformers/actions/workflows/ssh-runner.yml See DM on slaick |
the interesting fact is, |
8: [ | ||
"My favourite condiment is 100% ketchup. I love it on everything. " | ||
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles" | ||
], | ||
7: [ | ||
"My favourite condiment is 100% ketchup. I love it on everything. " | ||
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles" | ||
], | ||
} | ||
|
||
prompts = ["My favourite condiment is "] | ||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) | ||
tokenizer.pad_token = tokenizer.eos_token | ||
model = MistralForCausalLM.from_pretrained( | ||
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16 | ||
) | ||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) | ||
|
||
# Dynamic Cache | ||
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) | ||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) | ||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text) | ||
|
||
# Static Cache | ||
generated_ids = model.generate( | ||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" | ||
) | ||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) | ||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) | ||
|
||
# Sliding Window Cache | ||
generated_ids = model.generate( | ||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" | ||
) | ||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) | ||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) | ||
|
||
# Static Cache + compile | ||
forward_function = model.forward | ||
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) | ||
generated_ids = model.generate( | ||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" | ||
) | ||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) | ||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) | ||
|
||
# Sliding Window Cache + compile | ||
torch._dynamo.reset() | ||
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) | ||
generated_ids = model.generate( | ||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" | ||
) | ||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) | ||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) | ||
|
||
del model | ||
backend_empty_cache(torch_device) | ||
gc.collect() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I merged two tests into one, and now it passes!
Yep accelerate does not support compile yet |
I merged from current main and again did slow tests on my dev and aws a10 machine, I believe this PR is good to merge now @ArthurZucker @ydshieh |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just need to resolve the merge conflicts, 1 nit and feel free to merge!
@property | ||
def sin_cached(self): | ||
logger.warning_once( | ||
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is still the wrong version, it should now be 4.43!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just remove this like in llama
Please merge this when available @ArthurZucker , for I don't have write access to the library |
As suggested by the title, this PR attempts to add torch.compile support for mistral, and this is a not-ready-to-merge PR, it tries to replicate what has been done in Llama to Mistral considering the similar arch