Fix for Neuron #30259

Merged · 16 commits · May 2, 2024
7 changes: 5 additions & 2 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -1003,8 +1003,11 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr
         causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
         if attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]

Collaborator

You risk overflows here, no? If you have torch_min + torch_min, that produces inf.

Collaborator

+1 - we should try to avoid this. I think it's technically OK because of the masked_fill below

Member Author

Even if you overflow, it's OK, because the next line creates a boolean tensor: padding_mask = padding_mask == 0. So in the end you get the same result, as long as the elements of the tensor that need to be zero still end up with that value.

This code, which overflows in the context of the function:

    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
    padding_mask *= 2
    padding_mask = padding_mask == 0

produces the same output as this code, which does not overflow:

    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
    padding_mask = padding_mask == 0

(A standalone sketch of this equivalence is included after the diff hunk below.)

+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
         elif attention_mask.dim() == 4:
             # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
             # cache. In that case, the 4D attention mask attends to the newest tokens only.
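A small standalone check of the equivalence discussed in the thread above. This is a sketch, not part of the PR: the toy tensor shapes and values (a 2x2 float16 causal mask and a two-token 0/1 attention mask) are illustrative assumptions, chosen so that doubling the sum overflows float16 and mirrors the author's padding_mask *= 2 example.

    import torch

    dtype = torch.float16
    min_dtype = torch.finfo(dtype).min  # -65504.0 for float16

    # Toy inputs: causal_mask holds 0.0 for "attend" and min_dtype for "masked";
    # the 2D attention_mask holds 1 for real tokens and 0 for padding.
    causal_mask = torch.tensor([[0.0, min_dtype], [0.0, 0.0]], dtype=dtype)
    attention_mask = torch.tensor([0, 1])

    padding_mask = causal_mask + attention_mask[None, :]

    # Deliberately overflow, as in the author's example: min_dtype * 2 is -inf in float16.
    overflowed = padding_mask * 2
    print(overflowed)  # contains -inf where causal_mask was min_dtype

    # The boolean masks still agree, because -inf != 0 just as min_dtype != 0;
    # only positions where both inputs were exactly 0 compare equal to 0.
    assert torch.equal(padding_mask == 0, overflowed == 0)

    # The addition-based mask also matches the old eq-based formulation it replaces.
    old_style = causal_mask.eq(0.0) * attention_mask[None, :].eq(0.0)
    assert torch.equal(old_style, padding_mask == 0)

Note that with a 0/1 attention mask the plain sum itself stays finite (min_dtype + 1 is still representable in float16); the deliberate doubling above only stands in for the hypothetical overflow the reviewers were concerned about.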
7 changes: 5 additions & 2 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -989,8 +989,11 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr
         causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
         if attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
         elif attention_mask.dim() == 4:
             # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
             # cache. In that case, the 4D attention mask attends to the newest tokens only.
7 changes: 5 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -1082,8 +1082,11 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr
         causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
         if attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
         elif attention_mask.dim() == 4:
             # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
             # cache. In that case, the 4D attention mask attends to the newest tokens only.
4 changes: 2 additions & 2 deletions src/transformers/training_args.py
@@ -84,12 +84,12 @@
 if os.environ.get("TORCHELASTIC_RUN_ID"):
     if is_optimum_neuron_available():
         logger.info(
-            "Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
+            "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
             "will fail otherwise."
         )
     else:
         logger.warning(
-            "Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
+            "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
             "training on AWS Trainium instances. More information here: "
             "https://github.com/huggingface/optimum-neuron"
         )