
DeepSpeed sequence parallelism (aka Ulysses) integration with HF transformer #32305

Open · samadejacobs wants to merge 10 commits into base: main
Conversation


@samadejacobs samadejacobs commented Jul 29, 2024

What does this PR do?

This PR enhances the capabilities of DeepSpeed long sequence (context) parallelism (aka DS Ulysses) with support for HF models. Support is currently enabled when both DeepSpeed and Flash Attention are enabled; support will be extended to SDPA in the future. All current and future HF models (such as Llama, OPT, etc.) that use the refactored flash_attention_utils are supported.
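For orientation, a rough end-to-end usage sketch, condensed from the test snippet quoted later in this review. How the sequence-parallel degree is configured (the discussion below mentions setting sp_size > 1 at initialization) and the surrounding training-loop details are assumptions, not the PR's definitive interface:

    import deepspeed
    import torch.distributed as dist

    # ds_config is assumed to enable DeepSpeed sequence parallelism (sp_size > 1).
    model, _, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
        dist_init_required=True,
    )

    # Each sequence-parallel rank trains on its own slice of the sequence dimension;
    # the attention layers exchange heads/sequence internally via all-to-all.
    spg = model.get_sequence_parallel_group()
    sp_world_size = dist.get_world_size(spg)
    sp_rank = dist.get_rank(spg)

    for batch in data_loader:
        seq_len = batch["input_ids"].size(1)
        assert seq_len % sp_world_size == 0
        sub = seq_len // sp_world_size
        batch["input_ids"] = batch["input_ids"][:, sp_rank * sub:(sp_rank + 1) * sub]
        batch["labels"] = batch["labels"][:, sp_rank * sub:(sp_rank + 1) * sub]
        loss = model(**batch).loss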

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr

@LysandreJik
Member

cc @SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

github-merge-queue bot pushed a commit to deepspeedai/DeepSpeed that referenced this pull request Aug 21, 2024
This PR enhances capabilities of [DeepSpeed long sequence (context)
parallelism (aka DS
Ulysses)](https://dl.acm.org/doi/10.1145/3662158.3662806) with support
for HuggingFace (and by extension other frameworks) models. With HF
integration, users can use sequence parallelism for model
pre/mid/post-training, finetuning etc. Usage requires both _torch
>=2.2.2 and flash-attention_. ZeRO-1 and 2 are supported, ZeRO-3 and
SDPA support in progress. Corresponding PR in HF is
[PR32305](huggingface/transformers#32305).

---------

Co-authored-by: Logan Adams <[email protected]>
Make torch a requirement for deepspeed sp
Contributor

@muellerzr muellerzr left a comment


Thanks! This looks great to me, well done. Made a few comments/questions, next let's get the accelerate version up and running 🔥

from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal


if is_deepspeed_available():
    from deepspeed.sequence.layer import _SeqAllToAll
Contributor

Not sure I enjoy the fact that we rely on the private _SeqAllToAll; anything we should worry about there in the future?

Author

Nothing to worry about!

Comment on lines +18 to +19
import torch
from deepspeed import initialize
Contributor

These both need to be guarded under their respective is_xxx_available checks; that's why the current tests are failing.

Author

The test (in top-level test_deepspeed.py) is already guarded with @required_deepspeed and @required_torch_accelerator. I am thinking that should take care of it, no?

Contributor

@muellerzr muellerzr Aug 21, 2024


Nope, since it's being imported, we need to guard it from regular import checks. (This comes from our test discoverer, I'm pretty sure.)
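For illustration, the kind of guarding being asked for might look like the following (a sketch; the exact import paths of the availability helpers are assumptions):

    from transformers.utils import is_torch_available
    from transformers.integrations.deepspeed import is_deepspeed_available

    # Only import heavy/optional dependencies when they are actually installed,
    # so the test discoverer can import this module without torch or deepspeed.
    if is_torch_available():
        import torch

    if is_deepspeed_available():
        from deepspeed import initialize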

Comment on lines +1161 to +1179
model, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
    dist_init_required=True,
)


spg = model.get_sequence_parallel_group()
seq_parallel_world_size = dist.get_world_size(spg)
seq_parallel_rank = dist.get_rank(spg)

for n, batch in enumerate(data_loader):
    seq_length = batch["input_ids"].size(1)
    assert seq_length % seq_parallel_world_size == 0
    sub_seq_length = seq_length // seq_parallel_world_size
    sub_seq_start = seq_parallel_rank * sub_seq_length
    sub_seq_end = (seq_parallel_rank + 1) * sub_seq_length

    batch["input_ids"] = batch["input_ids"][:, sub_seq_start:sub_seq_end]
    batch["labels"] = batch["labels"][:, sub_seq_start:sub_seq_end]
Contributor

Now that we have this in the Trainer, it'd be good if we can have accelerate do this via its init in the config, and advertise it here.

It'd be even better if we could modify the dataloaders to do the sequence-parallel batching automatically, though I know this was a transformers PR first :)
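For reference, a rough sketch of what such automatic sequence-parallel batching could look like as a collator-style helper (hypothetical name and placement, simply repackaging the slicing from the snippet above):

    import torch.distributed as dist

    def shard_batch_for_sequence_parallel(batch, spg):
        # Give each sequence-parallel rank its contiguous slice of the sequence dimension.
        sp_world_size = dist.get_world_size(spg)
        sp_rank = dist.get_rank(spg)

        seq_length = batch["input_ids"].size(1)
        assert seq_length % sp_world_size == 0, "sequence length must be divisible by the SP degree"
        sub_seq_length = seq_length // sp_world_size
        start = sp_rank * sub_seq_length
        end = start + sub_seq_length

        batch["input_ids"] = batch["input_ids"][:, start:end]
        batch["labels"] = batch["labels"][:, start:end]
        return batch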

Contributor

Perhaps for now we include an "accelerate version coming soon"?

Member

@SunMarc SunMarc left a comment


Nice, thanks for adding this @samadejacobs ! Just a nit

Comment on lines +30 to +32
if is_deepspeed_available():
    from deepspeed.sequence.layer import _SeqAllToAll
    from deepspeed.utils import groups as ds_comm_groups
Member

Can we import these functions only in the block where we use them? We don't want users to implicitly import them when they don't use them, and it will make the transformers import slower.

Author

@samadejacobs samadejacobs Aug 22, 2024


Thanks for pointing this out @SunMarc. I was torn between placing the import at the top or within the block where it is needed. I decided to go with the former for performance reasons. Recall that flash_attention_forward would be called multiple times; this could lead to repeated imports which could (possibly) degrade performance. Code readability and maintainability are other (minor) reasons for my choice.

Member

@SunMarc SunMarc Aug 23, 2024


Oh indeed! Thanks for explaining. A potential solution could be to import under the condition is_deepspeed_sp_enabled() instead, since we are only using these imports under that condition. WDYT? Could you also add a comment on why we decided to put the import at the top?

@SunMarc SunMarc requested a review from ArthurZucker August 23, 2024 11:03
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks, let's make sure this does not impact any other models in terms of speed!

Comment on lines +232 to +240
if is_deepspeed_sp_enabled():
    spg = ds_comm_groups._get_sequence_parallel_group()
    # qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim)
    scatter_idx = 2  # Scatter on num_heads dimension
    gather_idx = 1  # Gather on seq_len dimension
    batch_dim_idx = 0  # Synonymous with batch_first==True
    query_states = _SeqAllToAll.apply(spg, query_states, scatter_idx, gather_idx, batch_dim_idx)
    key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, gather_idx, batch_dim_idx)
    value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx, batch_dim_idx)
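To make the exchange concrete, a small worked example of the shapes involved (numbers purely illustrative, not from the PR):

    # Suppose a sequence-parallel degree of 4, a full sequence of 8192 tokens,
    # and 32 attention heads with head_dim 128.
    sp = 4
    per_rank_input = (1, 8192 // sp, 32, 128)    # (batch, local seq_len, num_heads, head_dim) = (1, 2048, 32, 128)

    # After the pre-attention all-to-all (scatter on num_heads, gather on seq_len),
    # each rank holds the full sequence but only num_heads // sp heads:
    after_all_to_all = (1, 8192, 32 // sp, 128)  # (1, 8192, 8, 128)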
Collaborator

I am a tad bit unconvinced by checking this for each model that runs flash attention.
We could:

  • either have self.use_deep_speed, and check for that instead
  • separate the deepspeed code, by replacing the forward with a deepspeed forward when deepspeed is available.
    For such changes it is important to make sure we don't introduce regressions in performance. For a single forward pass it's going to be small, but for generation, it might add up!

Author

@ArthurZucker, as to your first bullet point, is_deepspeed_sp_enabled already includes the check for "use_deep_speed". Generally speaking, I see your point of view, but I am also concerned about duplicated code (your second bullet point). Is bullet point 1, without the burden on the user of setting another use_deep_speed flag, sufficient for you? From the deepspeed perspective, we assume that if the user enables deepspeed and sets sp_size > 1 during initialization, it means they intend to use deepspeed_sp.

Collaborator

if the user enables deepspeed and sets sp_size > 1 during initialization, it means they intend to use deepspeed_sp

yeah, IMO this is enough!

Collaborator

TBH I think we need to keep both code paths separated.
I don't know exactly how, but let's work together on a solution like this:

@wrap_deepspeed
def _flash_attention_forward():
    ...

and @wrap_deepspeed would just replace the function with one that first calls:

        spg = ds_comm_groups._get_sequence_parallel_group()
        # qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim)
        scatter_idx = 2  # Scatter on num_heads dimension
        gather_idx = 1  # Gather on seq_len dimension
        batch_dim_idx = 0  # Synonymous with the batch_first==true
        query_states = _SeqAllToAll.apply(spg, query_states, scatter_idx, gather_idx, batch_dim_idx)
        key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, gather_idx, batch_dim_idx)
        value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx, batch_dim_idx)

if is_deepspeed_sp_enabled().
If possible, we would only do this once, on the first call to the function, saving the result of is_deepspeed_sp_enabled somewhere (unless it's cached somehow by Python).

I really want to make sure this does not introduce bugs for others, and that we don't increase the maintenance burden, thus the separate code paths!
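For reference, a minimal sketch of what such a wrapper could look like (the decorator name, the signature, and the reverse all-to-all on the attention output are assumptions following this thread, not the PR's actual code; is_deepspeed_sp_enabled is assumed to be defined elsewhere in the PR):

    import functools

    # In practice these would live behind the is_deepspeed_available() guard discussed above.
    from deepspeed.sequence.layer import _SeqAllToAll
    from deepspeed.utils import groups as ds_comm_groups

    def wrap_deepspeed(attn_forward):
        # Hypothetical decorator: add the DS Ulysses all-to-all around a flash-attention forward.
        @functools.wraps(attn_forward)
        def wrapper(query_states, key_states, value_states, *args, **kwargs):
            if not is_deepspeed_sp_enabled():
                # Fast path: no DeepSpeed sequence parallelism, call the original forward untouched.
                return attn_forward(query_states, key_states, value_states, *args, **kwargs)

            spg = ds_comm_groups._get_sequence_parallel_group()
            # qkv tensors are (batch_size, seq_len, num_heads, head_dim):
            # scatter on num_heads, gather on seq_len before attention.
            scatter_idx, gather_idx, batch_dim_idx = 2, 1, 0
            query_states = _SeqAllToAll.apply(spg, query_states, scatter_idx, gather_idx, batch_dim_idx)
            key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, gather_idx, batch_dim_idx)
            value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx, batch_dim_idx)

            attn_output = attn_forward(query_states, key_states, value_states, *args, **kwargs)

            # Reverse all-to-all on the output (scatter seq_len, gather num_heads), as in DS Ulysses;
            # assumed here for completeness, not shown in the snippet above.
            return _SeqAllToAll.apply(spg, attn_output, 1, 2, batch_dim_idx)

        return wrapper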

Collaborator

WDYT?

Author

@ArthurZucker, apologies for the late response, I was OOF! Your suggestion LGTM, please feel free to update this PR (or refactor in a follow-up PR).

ArthurZucker previously approved these changes Sep 21, 2024
Collaborator

@ArthurZucker ArthurZucker left a comment


Good to merge, I just wanted the docs to explain the benefits a bit more (when is perf better with this!)

@glowwormX

@samadejacobs I'm glad to see this PR will be merged soon. When are you going to support SDPA in the future? It would be useful for me.

@ArthurZucker ArthurZucker dismissed their stale review September 21, 2024 15:28

Actually it would be better if we can go with the 2nd bullet point: create a separate deep_speed_forward, which can call flash_attn_forward inside, but add the deepspeed function on top of it.

@glowwormX

@samadejacobs I'm glad to see this PR will be merged soon. When are you going to support SDPA in the future? It would be useful for me.

Yes, but I want to run it on an NPU, which doesn't support FlashAttention-2, only SDPA.

@ArthurZucker
Collaborator

@samadejacobs anything I can do to help get this merged?

Provide more clarity on the need for sequence parallelism.
@samadejacobs
Author

@samadejacobs I'm glad to see this PR will be merged soon. When are you going to support SDPA in the future? It would be useful for me.

@glowwormX, support will be extended to SDPA in the future.

@samadejacobs
Author

@samadejacobs anything I can do to help get this merged?

@ArthurZucker, many thanks, please see my earlier response.

@ArthurZucker
Collaborator

Hey @samadejacobs !
Ah, I am not sure I have the bandwidth right now for #32305 (comment), but if you can't do it, I will see if I can ping someone or do it!

@pavelgein
Contributor

Hi, do I understand correctly that this PR deals only with attention support?
As far as I understand, sequence parallelism requires that two or more workers are given the same data, and this requires adjustments in trainers and dataloaders.

@jyshee

jyshee commented Dec 13, 2024

@samadejacobs so how is this going now?

@ronald-d-rogers

I have been working on this for my team at Bloomberg. I think I may have a working PR that merges changes from everybody. Will share shortly, and hopefully we can get this feature in.

@SunMarc
Member

SunMarc commented Dec 30, 2024

cc @XuehaiPan
