Fix potential random layout inconsistency issues in sparse attention modules #534
Conversation
There are two changes made to the SparseSelfAttention module in this PR (@arashashari): …point; 2) Add a broadcast of the layout at the beginning to ensure that different processes have a consistent layout during distributed training.
@@ -22,7 +23,8 @@ def __init__(
     # SparsityConfig parameters needs to be set accordingly
     sparsity_config=SparsityConfig(num_heads=4),
     key_padding_mask_mode='add',
-    attn_mask_mode='mul'):
+    attn_mask_mode='mul',
+    max_seq_length=2048):
Could you please add a docstring for the new parameter as well?
Sure, just added.
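For context, the requested documentation sits alongside the existing constructor arguments. A minimal sketch of what such an entry could look like follows; the wording and the argument descriptions here are illustrative, not the exact text added in the PR:

```python
import torch.nn as nn


class SparseSelfAttention(nn.Module):
    """Implements an efficient sparse self-attention layer.

    Arguments (excerpt, illustrative wording):
        sparsity_config: required: a SparsityConfig object describing the
            block-sparse attention pattern.
        key_padding_mask_mode: optional: 'add' or 'mul'; how the key padding
            mask is applied to the attention scores.
        attn_mask_mode: optional: 'add' or 'mul'; how the attention mask is
            applied to the attention scores.
        max_seq_length: optional: an integer giving the maximum sequence
            length the pre-built master sparsity layout must cover;
            defaults to 2048.
    """
```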
def get_layout(self, L):
    # if layout is never synchronized across GPUs, broadcast the layout from global rank 0
    if self._need_layout_synchronization and dist.is_initialized():
        dist.broadcast(self.master_layout, src=0)
This might break with model parallelism (e.g., megatron-style or pipeline parallelism). However, it might be tricky to get the correct process group and rank inside the op since we can't easily communicate with the deepspeed engine to get this info here. /cc @ShadenSmith, @samyam
That's a good point @jeffra. I think we want to only broadcast along the data parallel group, similar to our weight initialization? But getting the group is tricky, as you pointed out. We could add a data_parallel_group=None parameter to the constructor and, if present, broadcast along that torch.distributed group? It'll be up to the modeling side of things to ensure that the data parallel group is created/provided. Alternatively, I think we'd need a reference to the training engine.
Yep, that makes sense. The training engine hasn't been created yet at this point, so that's a bit tricky. However, for now let's just have the data_parallel_group passed into the constructor and use it if it's not None in this broadcast. That at least keeps the option open.
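As a concrete sketch of this direction, assuming a data_parallel_group keyword is threaded into the module; the class name and attribute handling below are stand-ins for illustration, not DeepSpeed's actual implementation:

```python
import torch
import torch.distributed as dist


class LayoutHolder:
    """Stand-in for SparseSelfAttention's layout bookkeeping."""

    def __init__(self, master_layout: torch.Tensor, data_parallel_group=None):
        self.master_layout = master_layout              # block-sparsity layout tensor
        self.data_parallel_group = data_parallel_group  # optional torch.distributed group
        self._need_layout_synchronization = True

    def get_layout(self, L):
        # Broadcast the layout once so every replica builds identical sparse kernels.
        if self._need_layout_synchronization and dist.is_initialized():
            if self.data_parallel_group is not None:
                # Restrict the broadcast to the caller-provided data-parallel group.
                # Note: src is still a *global* rank; choosing the right one is the
                # broadcast_src_rank question discussed below.
                dist.broadcast(self.master_layout, src=0, group=self.data_parallel_group)
            else:
                # Fall back to the default (world) group, as in the current PR.
                dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False
        # ... slice the first L blocks of the master layout as usual ...
        return self.master_layout
```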
If we have multiple data_parallel_groups (i.e., in a model-parallel scenario), does that mean we would also need to pass in the source rank to broadcast from within that process group? Do you think we would also need an optional broadcast_src_rank argument in the constructor?
Yep, we could also add the broadcast_src_rank parameter. This just means the caller has to do this translation instead of us, which sounds fine.
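To illustrate that caller-side translation, here is a sketch using a hypothetical Megatron-style rank layout; the group construction is an assumption about the model-parallel setup, not something DeepSpeed provides here, and it assumes torch.distributed has already been initialized:

```python
import torch.distributed as dist

# Hypothetical setup: consecutive blocks of `mp_size` ranks form one
# model-parallel group; ranks sharing a position within their block form
# one data-parallel group.
mp_size = 2
world_size = dist.get_world_size()
rank = dist.get_rank()

data_parallel_group = None
broadcast_src_rank = 0
for mp_rank in range(mp_size):
    ranks = list(range(mp_rank, world_size, mp_size))
    group = dist.new_group(ranks)   # every process must create every group
    if rank in ranks:
        data_parallel_group = group
        # The first member of this data-parallel group, expressed as a global
        # rank, is what would be passed as broadcast_src_rank.
        broadcast_src_rank = ranks[0]

# Hypothetical constructor call once both arguments exist:
# attn = SparseSelfAttention(sparsity_config=SparsityConfig(num_heads=4),
#                            data_parallel_group=data_parallel_group,
#                            broadcast_src_rank=broadcast_src_rank)
```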