Fix potential random layout inconsistency issues in sparse attention modules #534

Merged (7 commits) on Dec 4, 2020

Conversation

Justin1904 (Contributor)

This PR makes two changes to the SparseSelfAttention module:

  1. The SparseSelfAttention module now creates a "master_layout" upfront and registers it as a buffer. This removes the need to create a new layout on the fly later (which can cause inconsistency if there is randomness in layout creation) and also makes it easy to save and load the layout from a checkpoint;
  2. A broadcast of the layout is added at the beginning so that different processes have a consistent layout during distributed training (a sketch of both changes follows below).
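For readers following along, here is a minimal, simplified sketch of the two changes described above; the make_layout call and attribute names are paraphrased from the module, not quoted from the diff:

```python
import torch
import torch.distributed as dist


class SparseSelfAttention(torch.nn.Module):
    def __init__(self, sparsity_config, max_seq_length=2048):
        super().__init__()
        self.sparsity_config = sparsity_config
        # Change 1: build the full layout once, up front, and register it as a
        # buffer so it travels with state_dict() (checkpoint save/load).
        master_layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer('master_layout', master_layout)
        self._need_layout_synchronization = True

    def get_layout(self, L):
        # Change 2: broadcast once from global rank 0 so every process uses the
        # same layout even if layout creation involves randomness.
        if self._need_layout_synchronization and dist.is_initialized():
            dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False
        # Slice the master layout down to the number of blocks needed for L.
        num_blocks = L // self.sparsity_config.block
        return self.master_layout[..., :num_blocks, :num_blocks]
```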

@arashashari

@@ -22,7 +23,8 @@ def __init__(
     # SparsityConfig parameters needs to be set accordingly
     sparsity_config=SparsityConfig(num_heads=4),
     key_padding_mask_mode='add',
-    attn_mask_mode='mul'):
+    attn_mask_mode='mul',
+    max_seq_length=2048):
Contributor

Could you please add a docstring for the new parameter as well?

Contributor Author

Sure, just added.
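For reference, the added documentation would read roughly along these lines (paraphrased, not quoted from the PR):

```python
# Excerpt from the constructor docstring (wording illustrative):
#     max_seq_length: optional: an integer determining the upper limit of the
#         supported sequence length; the master layout is built once for this
#         length and sliced down for shorter inputs.
```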

def get_layout(self, L):
    # if layout is never synchronized across GPUs, broadcast the layout from global rank 0
    if self._need_layout_synchronization and dist.is_initialized():
        dist.broadcast(self.master_layout, src=0)
@jeffra (Collaborator), Nov 18, 2020

This might break with model parallelism (e.g., megatron-style or pipeline parallelism). However, it might be tricky to get the correct process group and rank inside the op since we can't easily communicate with the deepspeed engine to get this info here. /cc @ShadenSmith, @samyam

Contributor

That's a good point @jeffra. I think we want to broadcast only along the data-parallel group, similar to our weight initialization? But getting the group is tricky, as you pointed out. We could add a data_parallel_group=None parameter to the constructor and, if present, broadcast along that torch.distributed group. It would be up to the modeling side of things to ensure that the data-parallel group is created/provided. Alternatively, I think we'd need a reference to the training engine.

Collaborator

Yep, that makes sense. The training engine hasn't been created yet at this point, so that's a bit tricky. However, for now let's just have data_parallel_group passed into the constructor and use it in this broadcast if it's not None. That at least leaves the option open.
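A hedged sketch of that option, assuming only standard torch.distributed calls (the helper name _sync_layout is illustrative, not the module's API):

```python
import torch.distributed as dist


def _sync_layout(master_layout, data_parallel_group=None):
    """Broadcast the layout so every rank in the group sees the same one.

    data_parallel_group=None keeps the current behavior and broadcasts over
    the default (world) group.
    """
    if dist.is_initialized():
        # Note: src is a *global* rank, so with a sub-group it must name a rank
        # that is actually a member of that group -- which is exactly what the
        # follow-up comments below discuss.
        dist.broadcast(master_layout, src=0, group=data_parallel_group)
```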

@Justin1904 (Contributor Author), Nov 18, 2020

If we have multiple data_parallel_groups (i.e., in a model-parallel scenario), does that mean we would also require passing in the source rank to broadcast from within that process group? Do you think we would also need an optional broadcast_src_rank argument in the constructor?

Collaborator

Yep, we could also add the broadcast_src_rank parameter. This just means the caller has to do this translation instead of us, which sounds fine.
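To illustrate the caller-side translation being discussed, here is a sketch with both optional arguments; the group construction, helper name, and constructor kwargs are illustrative assumptions, not the final API:

```python
import torch.distributed as dist


def make_layout_sync_args(rank, model_parallel_size=2, world_size=4):
    """Illustrative: build the data-parallel group and pick the broadcast source.

    With 2-way model parallelism over 4 ranks, data-parallel replicas are
    {0, 2} and {1, 3}; each rank broadcasts the layout from the lowest global
    rank in its own group.
    """
    # new_group must be called identically on every process, so create all
    # data-parallel groups and keep only the one this rank belongs to.
    groups = [dist.new_group(ranks=list(range(mp_rank, world_size, model_parallel_size)))
              for mp_rank in range(model_parallel_size)]
    my_group = groups[rank % model_parallel_size]
    broadcast_src_rank = rank % model_parallel_size  # lowest global rank in my group
    return my_group, broadcast_src_rank


# Hypothetical usage once both constructor arguments exist:
# group, src = make_layout_sync_args(dist.get_rank())
# attn = SparseSelfAttention(sparsity_config, max_seq_length=2048,
#                            data_parallel_group=group, broadcast_src_rank=src)
```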

@arashashari arashashari merged commit 1e44d48 into microsoft:master Dec 4, 2020