DeepSpeed sequence parallelism (aka Ulysses) integration with HF transformer #32305
base: main
Conversation
cc @SunMarc
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This PR enhances the capabilities of [DeepSpeed long sequence (context) parallelism (aka DS Ulysses)](https://dl.acm.org/doi/10.1145/3662158.3662806) with support for HuggingFace (and, by extension, other frameworks') models. With the HF integration, users can use sequence parallelism for model pre/mid/post-training, fine-tuning, etc. Usage requires both _torch >= 2.2.2 and flash-attention_. ZeRO-1 and ZeRO-2 are supported; ZeRO-3 and SDPA support is in progress. The corresponding PR in HF is [PR32305](huggingface/transformers#32305). --------- Co-authored-by: Logan Adams <[email protected]>
Make torch a requirement for deepspeed sp
Thanks! This looks great to me, well done. Made a few comments/questions, next let's get the accelerate version up and running 🔥
```python
from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal

if is_deepspeed_available():
    from deepspeed.sequence.layer import _SeqAllToAll
```
Not sure I enjoy the fact that we rely on the private _SeqAllToAll; anything we should worry about there in the future?
Nothing to worry about!
```python
import torch
from deepspeed import initialize
```
These both need to be guarded under their respective is_xxx_available checks; that's why the current tests are failing.
The test (in the top-level test_deepspeed.py) is already guarded with @require_deepspeed and @require_torch_accelerator. I am thinking that should take care of it, no?
Nope, as it's being imported we need to guard it from regular import checks. (This comes from our test discoverer, I'm pretty sure.)
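For reference, a minimal sketch of the guarded-import pattern being asked for, assuming the standard transformers availability helpers (exact import paths may vary by version):

```python
# Guard heavyweight imports so test discovery can import the file even when
# torch/deepspeed are not installed in the environment.
from transformers.utils import is_torch_available
from transformers.integrations import is_deepspeed_available

if is_torch_available():
    import torch

if is_deepspeed_available():
    from deepspeed import initialize
```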
```python
model, _, _, _ = deepspeed.initialize(model=model,
                                      model_parameters=model.parameters(),
                                      config=ds_config,
                                      dist_init_required=True,)

spg = model.get_sequence_parallel_group()
seq_parallel_world_size = dist.get_world_size(spg)
seq_parallel_rank = dist.get_rank(spg)

for n, batch in enumerate(data_loader):
    seq_length = batch["input_ids"].size(1)
    assert seq_length % seq_parallel_world_size == 0
    sub_seq_length = seq_length // seq_parallel_world_size
    sub_seq_start = seq_parallel_rank * sub_seq_length
    sub_seq_end = (seq_parallel_rank + 1) * sub_seq_length

    batch["input_ids"] = batch["input_ids"][:, sub_seq_start:sub_seq_end]
    batch["labels"] = batch["labels"][:, sub_seq_start:sub_seq_end]
```
Now that we have this in the Trainer, it'd be good if we can have accelerate do this via its init in the config + advertise it here.
It'd be even better if we could modify the dataloaders to do the sequence-parallel batching automatically, though I know this first pass was a transformers PR :)
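For illustration, a hypothetical sketch of what "the dataloader does the sequence-parallel batching automatically" could look like; make_sp_collate_fn is an invented name (not an existing transformers/accelerate API), and spg is the sequence-parallel group from the snippet above:

```python
import torch.distributed as dist

def make_sp_collate_fn(base_collate_fn, spg):
    """Wrap a collate_fn so each sequence-parallel rank keeps only its sequence shard."""
    sp_world_size = dist.get_world_size(spg)
    sp_rank = dist.get_rank(spg)

    def collate_fn(examples):
        batch = base_collate_fn(examples)
        seq_length = batch["input_ids"].size(1)
        assert seq_length % sp_world_size == 0, "sequence length must be divisible by sp size"
        sub_seq_length = seq_length // sp_world_size
        start = sp_rank * sub_seq_length
        end = (sp_rank + 1) * sub_seq_length
        # Mirror the manual sharding from the example above.
        batch["input_ids"] = batch["input_ids"][:, start:end]
        batch["labels"] = batch["labels"][:, start:end]
        return batch

    return collate_fn
```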
Perhaps for now we include an "accelerate version coming soon" note?
Nice, thanks for adding this @samadejacobs ! Just a nit
```python
if is_deepspeed_available():
    from deepspeed.sequence.layer import _SeqAllToAll
    from deepspeed.utils import groups as ds_comm_groups
```
Can we only import these functions in the block where we use them? We don't want users to implicitly import them when they don't use them, and it will make the transformers import slower.
Thanks for pointing this out @SunMarc. I was torn between placing the import at the top or within the block where it is needed. I decided to go with the former for performance reasons. Recall that flash_attention_forward would be called multiple times; this could lead to repeated imports, which could (possibly) degrade performance. Code readability and maintainability are other (minor) reasons for my choice.
Oh indeed! Thanks for explaining. A potential solution could be to import under the if is_deepspeed_sp_enabled() condition instead, since we only use these imports under that condition. WDYT? Could you also add a comment on why we decided to put the import at the top?
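For concreteness, a sketch of the deferred-import placement being suggested (illustrative only; after the first call Python serves repeated imports from sys.modules, so the per-call cost is a cached lookup rather than a full re-import):

```python
if is_deepspeed_sp_enabled():
    # Imported lazily so users who never enable DeepSpeed sequence parallelism
    # do not pay for the deepspeed import, and `import transformers` stays fast.
    from deepspeed.sequence.layer import _SeqAllToAll
    from deepspeed.utils import groups as ds_comm_groups

    spg = ds_comm_groups._get_sequence_parallel_group()
```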
Thanks, let's make sure this does not impact any other models in terms of speed!
```python
if is_deepspeed_sp_enabled():
    spg = ds_comm_groups._get_sequence_parallel_group()
    # qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim)
    scatter_idx = 2    # Scatter on num_heads dimension
    gather_idx = 1     # Gather on seq_len dimension
    batch_dim_idx = 0  # Synonymous with batch_first == True
    query_states = _SeqAllToAll.apply(spg, query_states, scatter_idx, gather_idx, batch_dim_idx)
    key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, gather_idx, batch_dim_idx)
    value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx, batch_dim_idx)
```
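To make the scatter/gather semantics concrete, here is a single-process, pure-PyTorch simulation of what the all-to-all does to the q/k/v shapes (an illustrative sketch with an invented helper name, not DeepSpeed's actual communication kernel):

```python
import torch

def ulysses_all_to_all_sim(local_shards, scatter_idx=2, gather_idx=1):
    """Simulate the head-scatter / sequence-gather on a list of per-rank tensors,
    each of shape (batch, seq_len // P, num_heads, head_dim)."""
    world_size = len(local_shards)
    outputs = []
    for rank in range(world_size):
        # Each rank keeps its slice of the heads from every rank's sequence shard...
        pieces = [shard.chunk(world_size, dim=scatter_idx)[rank] for shard in local_shards]
        # ...and stitches the sequence shards back together along seq_len.
        outputs.append(torch.cat(pieces, dim=gather_idx))
    return outputs

P, B, S, H, D = 2, 1, 8, 4, 16
shards = [torch.randn(B, S // P, H, D) for _ in range(P)]
out = ulysses_all_to_all_sim(shards)
assert out[0].shape == (B, S, H // P, D)  # full sequence, sharded heads
```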
I am a tad bit unconvinced by checking this for each model that runs flash attention.
We could:
- either have self.use_deep_speed and check for that instead, or
- separate the DeepSpeed code, by changing the forward to a DeepSpeed forward when DeepSpeed is available.
For such changes it is important to make sure we don't introduce regressions in performance. For a single forward pass it's going to be small, but for generation it might add up!
@ArthurZucker, as to your first bullet point, the check for "use_deep_speed" is already included in is_deepspeed_sp_enabled. Generally speaking, I see your point of view but am also concerned about duplicated code (your second bullet point). Is bullet point 1, without the burden on the user of setting another use_deep_speed flag, sufficient for you? From the DeepSpeed perspective, we assume that if the user enables DeepSpeed and sets sp_size > 1 during initialization, they intend to use deepspeed_sp.
> if the user enables deepspeed and sets sp_size > 1 during initialization, it means they intend to use deepspeed_sp

Yeah, IMO this is enough!
TBH I think we need to keep both code paths separated.
I don't know exactly how, but let's work together on a solution like this:

```python
@wrap_deepspeed
def _flash_attention_forward():
    ...
```

and @wrap_deepspeed would just replace the function with one that, if is_deepspeed_sp_enabled(), first calls:

```python
spg = ds_comm_groups._get_sequence_parallel_group()
# qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim)
scatter_idx = 2    # Scatter on num_heads dimension
gather_idx = 1     # Gather on seq_len dimension
batch_dim_idx = 0  # Synonymous with batch_first == True
query_states = _SeqAllToAll.apply(spg, query_states, scatter_idx, gather_idx, batch_dim_idx)
key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, gather_idx, batch_dim_idx)
value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx, batch_dim_idx)
```

If possible, we would only do this once, on the first call to the function, saving the result of is_deepspeed_sp_enabled somewhere (unless it's cached somehow by Python).
I really want to make sure this does not introduce bugs for others, and that we don't increase the maintenance burden, hence keeping the code paths separate!
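To make the shape of that proposal concrete, here is a minimal sketch, assuming is_deepspeed_sp_enabled() and the q/k/v-first argument order from this PR (the decorator name and signature are illustrative, not a final implementation):

```python
import functools

def wrap_deepspeed(flash_attention_forward):
    """Run the Ulysses all-to-all only when DeepSpeed sequence parallelism is
    enabled; otherwise the wrapper is a transparent pass-through."""

    @functools.wraps(flash_attention_forward)
    def wrapper(query_states, key_states, value_states, *args, **kwargs):
        if is_deepspeed_sp_enabled():
            # Lazy imports: only pulled in when DeepSpeed SP is actually used.
            from deepspeed.sequence.layer import _SeqAllToAll
            from deepspeed.utils import groups as ds_comm_groups

            spg = ds_comm_groups._get_sequence_parallel_group()
            # Scatter heads (dim 2), gather sequence (dim 1), batch-first layout (dim 0).
            query_states = _SeqAllToAll.apply(spg, query_states, 2, 1, 0)
            key_states = _SeqAllToAll.apply(spg, key_states, 2, 1, 0)
            value_states = _SeqAllToAll.apply(spg, value_states, 2, 1, 0)
        return flash_attention_forward(query_states, key_states, value_states, *args, **kwargs)

    return wrapper
```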
WDYT?
@ArthurZucker, apologies for the late response, I was OOF! Your suggestion LGTM; please feel free to update this PR (or refactor as a follow-up PR).
Good to merge; I just wanted the docs to explain the benefits a bit more (when is perf better with this!).
@samadejacobs I'm glad to see this PR will be merged soon. When are you going to support SDPA in the future? It's useful for me.
Actually it would be better if we can go with the second bullet point: create a separate deep_speed_forward, which can call flash_attn_forward inside, but add the DeepSpeed function on top of it.
Yes, but I want to run it on the NPU, and it doesn't support flash2, only SDPA.
@samadejacobs anything I can do to help get this merged?
Provide more clarity on the need for sequence parallelism.
@glowwormX, future support will be extended to SDPA.
@ArthurZucker, many thanks, please see my earlier response.
Hey @samadejacobs!
Hi, do I understand correctly that this PR deals only with attention support?
@samadejacobs so how is this going now?
I have been working on this for my team at Bloomberg. I think I may have a PR that merges changes from everybody that works. Will share shortly, and hopefully we can get this feature in.
cc @XuehaiPan
What does this PR do?
This PR enhances the capabilities of DeepSpeed long sequence (context) parallelism (aka DS Ulysses) with support for HF models. Support is currently enabled when both DeepSpeed and flash attention are enabled. Future support will be extended to SDPA. All current and future HF models (such as Llama, OPT, etc.) that use the refactored flash_attention_utils are supported.
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@muellerzr