add flash_attention on model chatglm_v2 (PaddlePaddle#9296)
* add flash_attention on model chatglm_v2

* add flash attn and consider sequence parallel

---------

Co-authored-by: huxinye <2392038429>
Mangodadada authored Oct 28, 2024
1 parent 7e37028 commit 2993974
Showing 1 changed file with 45 additions and 13 deletions.
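For context: the new code path is gated by the model config's use_flash_attention flag. A minimal usage sketch for turning it on (the checkpoint name and the Auto* loading calls are illustrative assumptions, not part of this commit; only the use_flash_attention flag comes from the diff below):

# Hypothetical usage sketch -- checkpoint name and loading API are assumptions.
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b")
config.use_flash_attention = True  # opt into the scaled_dot_product_attention branch added below
model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b", config=config)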
58 changes: 45 additions & 13 deletions paddlenlp/transformers/chatglm_v2/modeling.py
@@ -27,6 +27,7 @@
 from paddle.utils import map_structure
 
 from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
+from paddlenlp.utils.log import logger
 
 from ...utils.converter import StateDictNameMapping, init_name_mappings
 from .. import PretrainedModel, linear_utils, register_base_model
@@ -395,25 +396,56 @@ def forward(
         # ==================================
         # core attention computation
         # ==================================
-        attention_fuc = self._core_attention
-
-        has_gradient = (
-            (not query_layer.stop_gradient) or (not key_layer.stop_gradient) or (not value_layer.stop_gradient)
-        )
-        if self.enable_recompute and self.config.recompute_granularity == "core_attn" and has_gradient:
-            context_layer = recompute(
-                attention_fuc,
+        version = paddle.version.full_version
+        version_check = True
+        if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
+            logger.warning(
+                "PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
+            )
+            version_check = False
+        if self.config.use_flash_attention and version_check:
+            query_layer = query_layer.transpose([1, 0, 2, 3])
+            key_layer = key_layer.transpose([1, 0, 2, 3])
+            value_layer = value_layer.transpose([1, 0, 2, 3])
+            # attention_mask = attention_mask
+            attn_output = F.scaled_dot_product_attention(
                 query_layer,
                 key_layer,
                 value_layer,
-                attention_mask,
-                output_attentions,
-                use_reentrant=False,
+                attn_mask=attention_mask,
+                dropout_p=self.config.attention_dropout,
+                training=self.training,
+                is_causal=False,
             )
+            batch_size, q_length, _, _ = query_layer.shape
+            if self.config.sequence_parallel:
+                context_layer = attn_output.reshape([batch_size * q_length, -1])
+            else:
+                context_layer = attn_output.reshape([q_length, batch_size, -1])
         else:
-            context_layer = attention_fuc(
-                query_layer, key_layer, value_layer, attention_mask=attention_mask, output_attentions=output_attentions
-            )
+            attention_fuc = self._core_attention
+
+            has_gradient = (
+                (not query_layer.stop_gradient) or (not key_layer.stop_gradient) or (not value_layer.stop_gradient)
+            )
+            if self.enable_recompute and self.config.recompute_granularity == "core_attn" and has_gradient:
+                context_layer = recompute(
+                    attention_fuc,
+                    query_layer,
+                    key_layer,
+                    value_layer,
+                    attention_mask,
+                    output_attentions,
+                    use_reentrant=False,
+                )
+            else:
+                context_layer = attention_fuc(
+                    query_layer,
+                    key_layer,
+                    value_layer,
+                    attention_mask=attention_mask,
+                    output_attentions=output_attentions,
+                )
         # =================
         # Output. [seq_length, b, h]
         # =================
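In isolation, the new flash-attention branch transposes ChatGLM-v2's internal [seq_len, batch, num_heads, head_dim] tensors into the batch-first layout that paddle.nn.functional.scaled_dot_product_attention expects, then reshapes the result either to [batch * seq_len, hidden] when sequence parallelism is on or to the model's usual [seq_len, batch, hidden] shape. A minimal standalone sketch of that logic (the function name and keyword defaults are illustrative; the tensor operations mirror the committed code):

import paddle.nn.functional as F


def sdpa_core_attention(query_layer, key_layer, value_layer, attention_mask=None,
                        dropout_p=0.0, training=True, sequence_parallel=False):
    # Inputs follow ChatGLM-v2's [seq_len, batch, num_heads, head_dim] layout;
    # scaled_dot_product_attention is batch-first, so swap the leading two axes.
    query_layer = query_layer.transpose([1, 0, 2, 3])
    key_layer = key_layer.transpose([1, 0, 2, 3])
    value_layer = value_layer.transpose([1, 0, 2, 3])
    attn_output = F.scaled_dot_product_attention(
        query_layer,
        key_layer,
        value_layer,
        attn_mask=attention_mask,  # explicit mask is passed, so is_causal stays False
        dropout_p=dropout_p,
        training=training,
        is_causal=False,
    )
    batch_size, q_length, _, _ = query_layer.shape
    if sequence_parallel:
        # sequence-parallel layers downstream consume a flat [batch * seq_len, hidden] tensor
        return attn_output.reshape([batch_size * q_length, -1])
    # otherwise reshape to the [seq_len, batch, hidden] shape the rest of the model expects
    return attn_output.reshape([q_length, batch_size, -1])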
