Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Gemma2 Attention Args #11365

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class Gemma2DotProductAttention(MegatronModule):
Region where selective activation recomputation is applied.
This region is memory intensive but less compute intensive which
makes activation checkpointing more efficient for LLMs (20B+).
See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
See Reducing Activation Recomputation in Large Transformer Models:
https://arxiv.org/abs/2205.05198 for more details.

We use the following notation:
h: hidden size
Expand Down Expand Up @@ -126,7 +127,12 @@ def forward(
attention_mask: Tensor,
attn_mask_type: AttnMaskType = None,
packed_seq_params: PackedSeqParams = None,
**kwargs,
):
"""Forward.
Modified from mcore.transformer.dot_product_attention to support Gemma2-specific
final_logit_softcapping.
"""
assert packed_seq_params is None, (
"Packed sequence is not supported by DotProductAttention." "Please use TEDotProductAttention instead."
)
Expand Down Expand Up @@ -243,6 +249,8 @@ def forward(


class TERowParallelLinearLayerNorm(TERowParallelLinear):
"""Modified From TERowParallelLinear with an additional Post-LN."""

def __init__(
self,
input_size: int,
Expand Down Expand Up @@ -270,12 +278,16 @@ def __init__(
self.post_layernorm = TENorm(config, output_size)

def forward(self, x):
"""Forward with additional Post LN on output"""
output, bias = super().forward(x)
return self.post_layernorm(output), bias


class Gemma2OutputLayer(ColumnParallelLinear):
"""Extends from ColumnParallelLinear with logit soft capping."""

def forward(self, *args, **kwargs):
"""Forward with logit soft capping."""
output, bias = super().forward(*args, **kwargs)
output = logit_softcapping(output, self.config.final_logit_softcapping)
return output, bias
Loading