Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torch.compile for Mistral #30642

Conversation

zhenglongjiepheonix
Copy link
Contributor

@zhenglongjiepheonix zhenglongjiepheonix commented May 3, 2024

As suggested by the title, this PR attempts to add torch.compile support for mistral, and this is a not-ready-to-merge PR, it tries to replicate what has been done in Llama to Mistral considering the similar arch

  • Add Sliding Window Cache
  • Use Static Cache/Sliding Window Cache for torch.compile
  • Moves attention mask related logic inside _update_causal_mask
  • Modify _prepare_input_for_generation to make _generate work

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zhenglongjiepheonix zhenglongjiepheonix force-pushed the longjie/add_torch_compile_for_mistral branch from 06fb2d6 to 57842ab Compare May 4, 2024 05:40
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this!

src/transformers/models/mistral/modeling_mistral.py Outdated Show resolved Hide resolved
src/transformers/models/mistral/modeling_mistral.py Outdated Show resolved Hide resolved
src/transformers/models/mistral/modeling_mistral.py Outdated Show resolved Hide resolved
src/transformers/models/mistral/modeling_mistral.py Outdated Show resolved Hide resolved
@ArthurZucker
Copy link
Collaborator

Also the list of models that support static cache in the doc probably need an update

@ArthurZucker
Copy link
Collaborator

BTW mistral will nee a SlidingWindowCache based on the implementation of RecurrentGemma!

@zhenglongjiepheonix
Copy link
Contributor Author

zhenglongjiepheonix commented May 7, 2024

BTW mistral will nee a SlidingWindowCache based on the implementation of RecurrentGemma!

from my understanding, SlidingWindowCache is for memory efficiency, actually I wonder how SlidingWindowCache would address the issue of get_seq_length where the current solution is relying on performing a non-zero check on all previous tokens, another solution is to use cache_position so that we don't have to call past_key_values.get_seq_length(), but this would require cache position is always passed in model.forward right?

@zhenglongjiepheonix zhenglongjiepheonix force-pushed the longjie/add_torch_compile_for_mistral branch from 57842ab to 315becb Compare May 7, 2024 04:29
@ArthurZucker
Copy link
Collaborator

Yes, we need to rely on cache positions that have to be passed to the model's forward. They can be initialized like in Llama. NOte that it's not breaking for the dyunamic cache because it will ignore the extra kwargs.

    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """
        torch.compile compatible sliding window.
        Computes the `indices` based on `cache_position >= self.config.attention_window_size - 1`.
        The `to_shift` is only true once we are above attention_window_size. Thus with `attention_window_size==64`:

        indices = (slicing + to_shift[-1].int()-1) % self.config.attention_window_size
        tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
                37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
                55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

        We overwrite the cache using these, then we always write at cache_position (clamped to `attention_window_size`)
        """
        cache_position = cache_kwargs.get("cache_position")
        if cache_position.shape[0] > self.config.attention_window_size:
            # int indexing -> device sync? in compile, use tensor
            k_out = key_states[:, :, -self.config.attention_window_size :, :]
            v_out = value_states[:, :, -self.config.attention_window_size :, :]
        else:
            slicing = torch.ones(
                self.config.attention_window_size, dtype=torch.long, device=value_states.device
            ).cumsum(0)
            cache_position = cache_position.clamp(0, self.config.attention_window_size - 1)
            to_shift = cache_position >= self.config.attention_window_size - 1
            indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size

            k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
            k_out = k_out[:, :, indices]
            v_out = v_out[:, :, indices]

            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

updating the cache should be like this

@ArthurZucker
Copy link
Collaborator

YEah it's hard to debug. merge from main and try to add a print to check where it's failing! Most probably a copied from that is not well placed

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in general looks like it's going in the right direction! 💪

Related PR: #30688

docs/source/en/llm_optims.md Outdated Show resolved Hide resolved
@zhenglongjiepheonix
Copy link
Contributor Author

Hi Aurthur, I have add support for Sliding Window Cache, and please take a look at its implementation and also the _update_causal_mask implementation, I have add my thoughts as comments

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good work

  • for all the copied from that were removed, we need to use on of the model as the new base (mixtral for example)
  • as @gante said, let's add phi to the lsit of model slow tested
  • let's revert some styling on setup.py and examples run object detection

src/transformers/cache_utils.py Outdated Show resolved Hide resolved
@@ -1342,6 +1340,33 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa
self._static_cache.reset() # reset the cache for a new generation
return self._static_cache

# maybe a better way is to use a single factory function to set caches for models ?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, you should be able to use a mapping from the cache_implementation to the cache_class, and replace _get_static_cache by _get_cache. Since most of the args are gonna come from the config.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The need_new_cache is not a great naming IMO. If we need to check that for each cache class it makes sense to implement it, but would rather have minimal number of function in the cache class that are not directly related to the cache (here it's a generate trick)

@gante how much gain do we have from not allocating and just reseting? (IMO should be fairly efficient in general so we could just forget this needs new cache arg

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still keep the logic but keep them in _get_cache

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/models/mistral/modeling_mistral.py Outdated Show resolved Hide resolved
src/transformers/models/mistral/modeling_mistral.py Outdated Show resolved Hide resolved
tests/models/mistral/test_modeling_mistral.py Outdated Show resolved Hide resolved
tests/models/mixtral/test_modeling_mixtral.py Outdated Show resolved Hide resolved
Comment on lines 488 to 520
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
cache_shape = (
config.num_hidden_layers,
max_batch_size,
self.num_key_value_heads,
self.sliding_window_size,
self.head_dim,
)

self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)

torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what is the best between list of tensor and a full concatenated tensor?
But I think for everything related to generate, assisted decoding etc, list might be better?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also in the current configuration:

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

is useless!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using lists, cudagraph seems to complain about it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting as that works for the fully static cache! But alright 👍🏻

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's different because in static cache the tensor we keep in self.key_cache[layer_idx] never changes in terms of address, however in here we have the needs of assigning whole self.key_cache[layer_idx] to the new tensor passed in as key states

src/transformers/cache_utils.py Show resolved Hide resolved
@zhenglongjiepheonix zhenglongjiepheonix force-pushed the longjie/add_torch_compile_for_mistral branch from e0e9968 to ff65b81 Compare May 10, 2024 01:56
@zhenglongjiepheonix
Copy link
Contributor Author

zhenglongjiepheonix commented May 10, 2024

Good work

  • for all the copied from that were removed, we need to use on of the model as the new base (mixtral for example)
  • as @gante said, let's add phi to the lsit of model slow tested
  • let's revert some styling on setup.py and examples run object detection

Hi Authur, I made some modifications, and I still can't quite get what we should do to solve this copy consistency issue. I have changed the base to Mixtral for most models that refered to Mistral, but there are still some cases where it doesn't quite fit because Mixtral has itself moe related logic, and I also see this Ignore copy bandaid-stuff everywhere, it becomes very unclear that what the exact thing I need to do is, is it add Ignore copy or is it make modifications to make sources match and which source should be the standard if they disagree ? and make fix-copies does not help much because you can't just use it without being really sure that you actually need exactly the same thing, otherwise it just break stuffs!

Could you please explain more about the second point about phi model? I don't quite get it.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the copied from issue, indeed mixtral is not necessarily the best, the idea is just to do what we can to not break the copied from chain. Not everything has to be from mixtral

setup.py Outdated Show resolved Hide resolved
examples/pytorch/object-detection/run_object_detection.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Show resolved Hide resolved
src/transformers/cache_utils.py Show resolved Hide resolved
src/transformers/generation/utils.py Show resolved Hide resolved
@zhenglongjiepheonix
Copy link
Contributor Author

zhenglongjiepheonix commented May 15, 2024

If except the 4 mentioned failing tests, all other are passing with this PR + test_compile_static_cache is passing on a A10 with torch 2.3, it's OK from my side (in terms of CI)

I don't see it running on the A10-based workflow, is it because it does not trigger on PR?

It's because the torch version there should use torch 2.3 but it is currently torch 2.2. I will update the docker files to use torch 2.3.

I think A10 workflow only trigger on push on main, not on PR ? I mean my commits on this pr are not triggering the test,
do I need to directly work on the transformers repo instead of my forked repo to run the test

name: Slow tests on important models (on Push - A10)

on:
  push:
    branches: [ main ]

@ydshieh
Copy link
Collaborator

ydshieh commented May 16, 2024

I think A10 workflow only trigger on push on main, not on PR ?
Yes.

You could run workflow manually

https://github.com/huggingface/transformers/actions/workflows/ssh-runner.yml

See DM on slaick

@zhenglongjiepheonix
Copy link
Contributor Author

zhenglongjiepheonix commented May 16, 2024

Ok, now the problem is both llama and mistral are failing the compile static cache tests on A10 because of the same error:

============================= FAILURES SHORT STACK =============================
________________ LlamaIntegrationTest.test_compile_static_cache ________________
msg = 'hasattr ConstDictVariable to'
    def unimplemented(msg: str) -> NoReturn:
        assert msg != os.environ.get("BREAK", False)
>       raise Unsupported(msg)
E       torch._dynamo.exc.Unsupported: hasattr ConstDictVariable to
E       
E       from user code:
E          File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
E           return fn(*args, **kwargs)
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 161, in new_forward
E           args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 356, in pre_forward
E           return send_to_device(args, self.execution_device), send_to_device(
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 148, in send_to_device
E           if is_torch_tensor(tensor) or hasattr(tensor, "to"):
E       
E       Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E       
E       
E       You can suppress this exception and fall back to eager by setting:
E           import torch._dynamo
E           torch._dynamo.config.suppress_errors = True

They work fine on my A100 dev and colab T4 env, but here this error seems unrelated with GPUs, rather it's something related with software, I can't reproduce the error, even on an A10 aws machine with python==3.8:
Screen Shot 2024-05-16 at 17 25 19

@zhenglongjiepheonix
Copy link
Contributor Author

zhenglongjiepheonix commented May 16, 2024

Ok, now the problem is both llama and mistral are failing the compile static cache tests on A10 because of the same error:

============================= FAILURES SHORT STACK =============================
________________ LlamaIntegrationTest.test_compile_static_cache ________________
msg = 'hasattr ConstDictVariable to'
    def unimplemented(msg: str) -> NoReturn:
        assert msg != os.environ.get("BREAK", False)
>       raise Unsupported(msg)
E       torch._dynamo.exc.Unsupported: hasattr ConstDictVariable to
E       
E       from user code:
E          File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
E           return fn(*args, **kwargs)
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 161, in new_forward
E           args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 356, in pre_forward
E           return send_to_device(args, self.execution_device), send_to_device(
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 148, in send_to_device
E           if is_torch_tensor(tensor) or hasattr(tensor, "to"):
E       
E       Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E       
E       
E       You can suppress this exception and fall back to eager by setting:
E           import torch._dynamo
E           torch._dynamo.config.suppress_errors = True

They work fine on my A100 dev and colab T4 env, but here this error seems unrelated with GPUs, rather it's something related with software, I can't reproduce the error, even on an A10 aws machine with python==3.8: Screen Shot 2024-05-16 at 17 25 19

the interesting fact is, test_compile_static_cache passes with itself along in CI, however if run after test_compile_sliding_window_cache, then it fails because some memory-related strategies in accelerate make dynamo unhappy @ArthurZucker @ydshieh

Comment on lines +640 to +697
8: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
],
7: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
],
}

prompts = ["My favourite condiment is "]
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text)

# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)

# Sliding Window Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)

# Static Cache + compile
forward_function = model.forward
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)

# Sliding Window Cache + compile
torch._dynamo.reset()
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)

del model
backend_empty_cache(torch_device)
gc.collect()
Copy link
Contributor Author

@zhenglongjiepheonix zhenglongjiepheonix May 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I merged two tests into one, and now it passes!

@ArthurZucker
Copy link
Collaborator

Yep accelerate does not support compile yet

@zhenglongjiepheonix
Copy link
Contributor Author

I merged from current main and again did slow tests on my dev and aws a10 machine, I believe this PR is good to merge now @ArthurZucker @ydshieh

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just need to resolve the merge conflicts, 1 nit and feel free to merge!

@property
def sin_cached(self):
logger.warning_once(
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is still the wrong version, it should now be 4.43!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just remove this like in llama

@zhenglongjiepheonix
Copy link
Contributor Author

Please merge this when available @ArthurZucker , for I don't have write access to the library

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants