Require setuptools>=70 and update deprecated api (#10659)
* Require setuptools>=70 and update deprecated api

Signed-off-by: Dong Hyuk Chang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: thomasdhc <[email protected]>

---------

Signed-off-by: Dong Hyuk Chang <[email protected]>
Signed-off-by: thomasdhc <[email protected]>
Co-authored-by: Dong Hyuk Chang <[email protected]>
Co-authored-by: thomasdhc <[email protected]>
3 people authored and monica-sekoyan committed Oct 11, 2024
1 parent 7ce3326 commit ae60b7f
Showing 6 changed files with 47 additions and 18 deletions.
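
Note: the substantive change across the Python files below is replacing the deprecated re-export "from pkg_resources import packaging" with a direct import of the standalone packaging distribution, alongside the setuptools floor bump in requirements.txt. A minimal, self-contained sketch of the version-gate idiom these imports typically serve (the checked distribution, torch, and the helper name are illustrative assumptions, not code from this commit):

# Minimal sketch of the version-gate idiom; the checked distribution ("torch")
# and the helper name are illustrative assumptions, and torch is assumed installed.
from importlib.metadata import version

from packaging.version import Version


def installed_at_least(distribution: str, minimum: str) -> bool:
    # The version string comes from installed package metadata; the comparison
    # uses the standalone packaging library instead of the deprecated
    # pkg_resources re-export ("from pkg_resources import packaging").
    return Version(version(distribution)) >= Version(minimum)


print(installed_at_least("torch", "2.0.0"))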
@@ -33,12 +33,12 @@
 from importlib.metadata import version
 from typing import Tuple
 
+import packaging
 import torch
 import torch.nn.functional as F
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.parallel_state import get_tensor_model_parallel_group
 from megatron.core.transformer import TransformerConfig
-from pkg_resources import packaging
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
 
@@ -18,11 +18,11 @@
 from typing import Any, Optional
 
 import numpy as np
+import packaging
 import torch
 import torch.nn.functional as F
 from einops import rearrange, reduce, repeat
 from omegaconf import DictConfig, ListConfig, OmegaConf
-from pkg_resources import packaging
 from pytorch_lightning.trainer.trainer import Trainer
 from transformers import CLIPVisionModel, SiglipVisionModel
 
@@ -15,8 +15,8 @@
 from importlib.metadata import version
 from typing import Any, Callable, Optional
 
+import packaging
 import torch
-from pkg_resources import packaging
 
 from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
 from nemo.collections.nlp.parts import utils_funcs
48 changes: 36 additions & 12 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -67,10 +67,9 @@
 
 try:
     # Flash Attention Triton
-    import pkg_resources
     from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
 
-except (ImportError, ModuleNotFoundError, pkg_resources.DistributionNotFound):
+except (ImportError, ModuleNotFoundError):
 
     flash_attn_func_triton = None
 
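With pkg_resources no longer imported inside the guard above, its DistributionNotFound exception class can no longer be referenced, and it is not needed: a missing optional dependency surfaces as ImportError (ModuleNotFoundError is its subclass). A minimal sketch of the simplified guard, using a hypothetical module name:

# Sketch of the simplified optional-dependency guard; "some_optional_extra" is
# a hypothetical module name, not one from this commit.
try:
    import some_optional_extra

    HAVE_EXTRA = True
except (ImportError, ModuleNotFoundError):
    # ModuleNotFoundError is a subclass of ImportError, so ImportError alone
    # would also cover the missing-package case.
    some_optional_extra = None
    HAVE_EXTRA = False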
@@ -202,7 +201,12 @@ def __init__(
         else:
             assert attention_type == AttnType.cross_attn
             self.query = tensor_parallel.ColumnParallelLinear(
-                hidden_size, projection_size, config=config, gather_output=False, init_method=init_method, bias=bias,
+                hidden_size,
+                projection_size,
+                config=config,
+                gather_output=False,
+                init_method=init_method,
+                bias=bias,
             )
 
             self.key_value = tensor_parallel.ColumnParallelLinear(
@@ -336,7 +340,7 @@ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
             """[s, b, num_splits * np * hn]
             -->(view) [s, b, num_splits, np, hn]
             -->(tranpose) [s, b, np, num_splits, hn]
-            -->(view) [s, b, np * num_splits * hn] """
+            -->(view) [s, b, np * num_splits * hn]"""
 
             intermediate_shape = input_shape[:-1] + (
                 num_splits,
@@ -350,7 +354,7 @@ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
             """[s, b, np * hn * num_splits]
             -->(view) [s, b, np, hn, num_splits]
             -->(tranpose) [s, b, np, num_splits, hn]
-            -->(view) [s, b, np * num_splits * hn] """
+            -->(view) [s, b, np * num_splits * hn]"""
 
             intermediate_shape = input_shape[:-1] + (
                 self.num_attention_heads_per_partition,
@@ -535,7 +539,10 @@ def forward(
             )
             v = _cast_if_autocast_enabled(rearrange(value_layer, 'sk b np hn -> b sk np hn'))
             context_layer = flash_attn_with_kvcache(
-                q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
+                q=q,
+                k_cache=k,
+                v_cache=v,
+                causal=self.attn_mask_type == AttnMaskType.causal,
             )
             context_layer = rearrange(context_layer, 'b sq np hn -> sq b (np hn)')
 
@@ -742,9 +749,9 @@ def forward(
 
 
 class CoreAttention(MegatronModule):
-    """ Region where selective activation recomputation is applied.
-    See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
-    https://arxiv.org/pdf/2205.05198.pdf for more details.
+    """Region where selective activation recomputation is applied.
+    See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
+    https://arxiv.org/pdf/2205.05198.pdf for more details.
     """
 
@@ -994,10 +1001,21 @@ def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, a
 
         if attention_bias is not None:
             return self.flash_attention_triton(
-                query_layer, key_layer, value_layer, attention_mask, attention_bias, is_causal,
+                query_layer,
+                key_layer,
+                value_layer,
+                attention_mask,
+                attention_bias,
+                is_causal,
             )
         else:
-            return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask, is_causal,)
+            return self.flash_attention_cuda(
+                query_layer,
+                key_layer,
+                value_layer,
+                attention_mask,
+                is_causal,
+            )
 
     def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask, is_causal):
         batch_size, seqlen, nheads, _ = query_layer.shape
@@ -1071,7 +1089,13 @@ def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_
             if attention_bias.shape[3] == attention_mask_kv.shape[3]:
                 attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min)
 
-        context_layer = flash_attn_func_triton(query_layer, key_layer, value_layer, attention_bias, is_causal,)
+        context_layer = flash_attn_func_triton(
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_bias,
+            is_causal,
+        )
 
         # [b, sq, np, hn] -> [b, np, sq, hn]
         context_layer = context_layer.permute(0, 2, 1, 3)
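The call-site hunks above change formatting only, consistent with the "Apply isort and black reformatting" step in the commit message: black's magic trailing comma keeps any call whose argument list ends in a comma exploded onto one argument per line. A hypothetical illustration:

# Hypothetical illustration of black's magic trailing comma; attend() is a
# stand-in function, not one from this commit.
def attend(query, key, value):
    return query, key, value


# Before black: out = attend("q", "k", "v",)  -- note the trailing comma.
# After black, the trailing comma forces one argument per line:
out = attend(
    "q",
    "k",
    "v",
)
print(out)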
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -5,7 +5,7 @@ onnx>=1.7.0
 python-dateutil
 ruamel.yaml
 scikit-learn
-setuptools>=65.5.1
+setuptools>=70.0.0
 tensorboard
 text-unidecode
 torch
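The bumped floor accompanies the pkg_resources-free imports above. A quick way to confirm an environment satisfies the new pin (a sketch, assuming the packaging distribution is installed):

# Quick environment check against the new pin; assumes the packaging
# distribution is importable in the current environment.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("setuptools"))
assert installed >= Version("70.0.0"), f"setuptools {installed} is older than 70.0.0"
print(f"setuptools {installed} satisfies the setuptools>=70.0.0 pin")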
9 changes: 7 additions & 2 deletions tests/collections/nlp/test_flash_attention.py
@@ -39,7 +39,6 @@
 HAVE_FA = False
 
 try:
-    import pkg_resources
     import triton
 
     HAVE_TRITON = True
@@ -80,7 +79,13 @@ def setup_class(cls):
         MB_SIZE = 4
         GB_SIZE = 8
         SEED = 1234
-        trainer = Trainer(strategy=NLPDDPStrategy(), devices=GPUS, accelerator='gpu', num_nodes=1, logger=None,)
+        trainer = Trainer(
+            strategy=NLPDDPStrategy(),
+            devices=GPUS,
+            accelerator='gpu',
+            num_nodes=1,
+            logger=None,
+        )
 
         initialize_model_parallel_for_nemo(
             world_size=trainer.world_size,
