
Commit

Fix potential random layout inconsistency issues in sparse attention modules (microsoft#534)

* 1) Register the layout as a buffer of the module so that it is saved/loaded with checkpoints; 2) Add a broadcast of the layout at the beginning to ensure different processes have a consistent layout during distributed training (see the sketch below).

* Add docstring for max_seq_length argument in SparseSelfAttention

Co-authored-by: Zhun Liu <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
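
The fix described above combines two standard PyTorch mechanisms: register_buffer, which makes a tensor part of the module's state_dict so it is saved and restored with checkpoints, and a one-time torch.distributed broadcast from global rank 0 so that every process ends up with the same (randomly generated) layout. Below is a minimal sketch of the pattern using a toy module and a stand-in layout tensor; it is an illustration of the idea, not the DeepSpeed code itself:

    import torch
    import torch.nn as nn
    from torch import distributed as dist

    class ToyModule(nn.Module):
        def __init__(self, num_blocks=4):
            super().__init__()
            # register_buffer puts the tensor into state_dict(), so it is
            # saved/loaded with checkpoints but is not a trainable parameter
            layout = torch.randint(0, 2, (num_blocks, num_blocks))
            self.register_buffer("master_layout", layout)
            self._need_sync = True

        def get_layout(self):
            # one-time broadcast from global rank 0 so all processes agree
            if self._need_sync and dist.is_initialized():
                dist.broadcast(self.master_layout, src=0)
                self._need_sync = False
            return self.master_layout

    m = ToyModule()
    print("master_layout" in m.state_dict())  # True: the buffer is checkpointed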
3 people authored Dec 4, 2020
1 parent ff58fa7 commit 1e44d48
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions deepspeed/ops/sparse_attention/sparse_self_attention.py
@@ -5,6 +5,7 @@
import torch.nn as nn
from torch.nn.functional import *
import torch
from torch import distributed as dist
from collections import namedtuple
from deepspeed.ops.sparse_attention import MatMul, Softmax, SparsityConfig
import sys
@@ -22,29 +23,50 @@ def __init__(
# SparsityConfig parameters need to be set accordingly
sparsity_config=SparsityConfig(num_heads=4),
key_padding_mask_mode='add',
attn_mask_mode='mul'):
attn_mask_mode='mul',
max_seq_length=2048):
"""Initialize the sparse self attention layer.
Arguments:
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on SparsityConfig class.
key_padding_mask_mode: optional: a string determining if key padding mask needs to be added, `add`, or be multiplied, `mul`.
attn_mask_mode: optional: a string determining if attention mask needs to be added, `add`, or be multiplied, `mul`.
max_seq_length: optional: the maximum sequence length this sparse attention module will be applied to; it controls the size of the master_layout.
"""
super().__init__()

# sparsity information
self.sparsity_config = sparsity_config

# initialize sparse layout and register as buffer
master_layout = self.sparsity_config.make_layout(max_seq_length)
self.register_buffer("master_layout", master_layout)
self._need_layout_synchronization = True

# mask modes
self.key_padding_mask_mode = key_padding_mask_mode
self.attn_mask_mode = attn_mask_mode

ops = dict()

def get_layout(self, L):
# if layout is never synchronized across GPUs, broadcast the layout from global rank 0
if self._need_layout_synchronization and dist.is_initialized():
dist.broadcast(self.master_layout, src=0)
self._need_layout_synchronization = False

if (L % self.sparsity_config.block != 0):
raise ValueError(
f'Sequence Length, {L}, needs to be divisible by Block size {self.sparsity_config.block}!'
)

num_blocks = L // self.sparsity_config.block
return self.master_layout[..., :num_blocks, :num_blocks].cpu() # layout needs to be a CPU tensor

# add to cache
def get_ops(self, H, L):
import sys
if L not in SparseSelfAttention.ops:
sparsity_layout = self.sparsity_config.make_layout(L)
sparsity_layout = self.get_layout(L)
sparse_dot_sdd_nt = MatMul(sparsity_layout,
self.sparsity_config.block,
'sdd',
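
For reference, the new get_layout path slices the pre-built master_layout down to the number of blocks needed for the current sequence length and returns it as a CPU tensor. Below is a rough sketch of that slicing with plain tensors, using an example block size and a random stand-in layout rather than one built by SparsityConfig:

    import torch

    block = 16                    # example sparsity block size
    max_seq_length = 2048         # length the master layout was built for
    blocks = max_seq_length // block
    master_layout = torch.randint(0, 2, (blocks, blocks))

    L = 512                       # actual sequence length at run time
    assert L % block == 0, "sequence length must be divisible by the block size"
    num_blocks = L // block
    layout = master_layout[..., :num_blocks, :num_blocks].cpu()
    print(layout.shape)           # torch.Size([32, 32])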
