Fix potential random layout inconsistency issues in sparse attention modules #534

Merged
merged 7 commits into from Dec 4, 2020
Changes from all commits
26 changes: 24 additions & 2 deletions deepspeed/ops/sparse_attention/sparse_self_attention.py
@@ -5,6 +5,7 @@
import torch.nn as nn
from torch.nn.functional import *
import torch
from torch import distributed as dist
from collections import namedtuple
from deepspeed.ops.sparse_attention import MatMul, Softmax, SparsityConfig
import sys
@@ -22,29 +23,50 @@ def __init__(
# SparsityConfig parameters needs to be set accordingly
sparsity_config=SparsityConfig(num_heads=4),
key_padding_mask_mode='add',
attn_mask_mode='mul'):
attn_mask_mode='mul',
max_seq_length=2048):
Contributor
Could you please add a docstring for the new parameter as well?

Contributor Author
Sure, just added.

"""Initialize the sparse self attention layer.
Arguments:
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class.
key_padding_mask_mode: optional: a string determining if key padding mask needs to be added, `add`, or be multiplied, `mul`.
attn_mask_mode: optional: a string determining if attention mask needs to be added, `add`, or be multiplied, `mul`.
max_seq_length: optional: the maximum sequence length this sparse attention module will be applied to; it controls the size of the master_layout.
"""
super().__init__()

# sparsity information
self.sparsity_config = sparsity_config

# initialize sparse layout and register as buffer
master_layout = self.sparsity_config.make_layout(max_seq_length)
self.register_buffer("master_layout", master_layout)
self._need_layout_synchronization = True

# mask modes
self.key_padding_mask_mode = key_padding_mask_mode
self.attn_mask_mode = attn_mask_mode

ops = dict()

def get_layout(self, L):
# if layout is never synchronized across GPUs, broadcast the layout from global rank 0
if self._need_layout_synchronization and dist.is_initialized():
dist.broadcast(self.master_layout, src=0)
Collaborator
@jeffra jeffra Nov 18, 2020
This might break with model parallelism (e.g., megatron-style or pipeline parallelism). However, it might be tricky to get the correct process group and rank inside the op, since we can't easily communicate with the deepspeed engine to get this info here. /cc @ShadenSmith, @samyam

Contributor
That's a good point @jeffra. I think we want to only broadcast along the data parallel group, similar to our weight initialization? But getting the group is tricky, as you pointed out. We could add a data_parallel_group=None parameter to the constructor, and if present broadcast along that torch.distributed group? It'll be up to the modeling side of things to ensure that the data parallel group is created/provided. Alternatively, I think we'd need a reference to the training engine.

Collaborator
Yep, that makes sense. The training engine hasn't been created yet at this point, so that's a bit tricky. However, for now let's just have the data_parallel_group passed into the constructor and use it if it's not None in this broadcast. Then it at least allows the option for this.

Contributor Author
@Justin1904 Justin1904 Nov 18, 2020
If we have multiple data_parallel_groups (i.e. in a model parallel scenario), does that mean we would also require passing in the source rank to broadcast from within that process group? Do you think we would also need an optional argument for broadcast_src_rank in the constructor?

Collaborator
Yep, we could also add the broadcast_src_rank parameter. This just means the caller has to do this translation instead of us, which sounds fine.
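
To make the suggestion above concrete, here is a minimal sketch of how a data_parallel_group (and optional broadcast_src_rank) constructor argument could be threaded into this broadcast. The attributes self.data_parallel_group and self.broadcast_src_rank are hypothetical names following the reviewers' proposal, not code from this PR; dist is the torch.distributed alias imported at the top of the file in the diff above.

    # Sketch only -- illustrates the reviewers' proposal, not code from this PR.
    # `self.data_parallel_group` and `self.broadcast_src_rank` are assumed
    # constructor arguments (defaulting to None and 0, respectively).
    def get_layout(self, L):
        if self._need_layout_synchronization and dist.is_initialized():
            if self.data_parallel_group is not None:
                # restrict the broadcast to the caller-provided data parallel
                # group; broadcast_src_rank must be a global rank in that group
                dist.broadcast(self.master_layout,
                               src=self.broadcast_src_rank,
                               group=self.data_parallel_group)
            else:
                # default behavior: broadcast from global rank 0
                dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False
        ...
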

self._need_layout_synchronization = False

if (L % self.sparsity_config.block != 0):
raise ValueError(
f'Sequence Length, {L}, needs to be divisible by Block size {self.sparsity_config.block}!'
)

num_blocks = L // self.sparsity_config.block
return self.master_layout[..., :num_blocks, :num_blocks].cpu() # layout needs to be a CPU tensor

# add to cache
def get_ops(self, H, L):
import sys
if L not in SparseSelfAttention.ops:
sparsity_layout = self.sparsity_config.make_layout(L)
sparsity_layout = self.get_layout(L)
sparse_dot_sdd_nt = MatMul(sparsity_layout,
self.sparsity_config.block,
'sdd',
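
For reference, a small usage sketch (not part of the diff) of how the new max_seq_length argument interacts with get_layout. FixedSparsityConfig and the concrete numbers are illustrative choices, and running it requires a DeepSpeed build with sparse attention support.

    from deepspeed.ops.sparse_attention import FixedSparsityConfig, SparseSelfAttention

    # master_layout is built once for max_seq_length tokens:
    # 2048 / block(16) = 128 blocks per dimension.
    config = FixedSparsityConfig(num_heads=4, block=16)
    attn = SparseSelfAttention(sparsity_config=config, max_seq_length=2048)

    # Per-call layouts are slices of master_layout, so every rank sees the
    # same (rank-0) layout. L = 512 uses the leading 512 / 16 = 32 blocks.
    layout_512 = attn.get_layout(512)   # CPU tensor of shape [num_heads, 32, 32]

    # L must be a multiple of the block size (otherwise ValueError is raised)
    # and should not exceed max_seq_length, since the slice cannot grow past
    # the master layout.
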