Fix for Neuron #30259
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The attention layer also uses the `...` notation, but if that's not a problem, good. Thanks for fixing!
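For context, the `...` (Ellipsis) indexing that the diff replaces is just shorthand for spelling out every leading dimension explicitly. A tiny illustrative sketch (not part of the PR):

```python
import torch

# A 4D mask shaped like the causal mask in _update_causal_mask:
# (batch, num_heads, query_length, key_length)
causal_mask = torch.randn(2, 1, 5, 5)
mask_length = 3

# Ellipsis indexing and fully explicit slicing select exactly the same elements.
assert torch.equal(causal_mask[..., :mask_length], causal_mask[:, :, :, :mask_length])
```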
Thanks!
@michaelbenayoun There were a bunch of fp16 issues we had to solve last week because of the new PyTorch release. Could you try rebasing on main? This should resolve the currently failing tests.
Some tests are failing because of the PR itself. Working on that today.
@@ -1021,8 +1021,11 @@ def _update_causal_mask(
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
You risk overflows here, no? If you have `torch_min + torch_min`, it overflows to `-inf`.
+1 - we should try to avoid this. I think it's technically OK because of the masked_fill below
Even if you overflow, it's OK, because on the next line you create a boolean tensor: `padding_mask = padding_mask == 0`. So in the end you get the same result, as long as the elements in the tensor that need to be zero end up having this value.

This code, which overflows in the context of the function:

```python
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask *= 2
padding_mask = padding_mask == 0
```

produces the same output as this code, which does not overflow:

```python
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
```
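To make the argument concrete, here is a small fp16 check (an illustrative sketch, not part of the PR) showing that the overflow to `-inf` does not change the `== 0` mask:

```python
import torch

min_dtype = torch.finfo(torch.float16).min            # -65504.0
a = torch.tensor([0.0, min_dtype], dtype=torch.float16)
b = torch.tensor([0.0, min_dtype], dtype=torch.float16)

summed = a + b        # second element overflows: -65504 + -65504 -> -inf
print(summed)         # tensor([0., -inf], dtype=torch.float16)
print(summed == 0)    # tensor([ True, False])
# The boolean mask only cares whether an element is exactly zero, so the
# overflow to -inf gives the same mask as the exact (non-overflowed) sum would.
```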
Thanks for working on and fixing this!
src/transformers/utils/fx.py
"`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to " | ||
"unexpected behavior." | ||
) | ||
if "past_key_values" not in input_names and hasattr(model.config, "use_cache") and model.config.use_cache: |
For my own understanding - what's the reason for the different approach for getting the attr here - does `model.config.use_cache` default to `None`?
When tracing with FX, you transform the input tensors into `Proxy` objects that record the flow of operations. So here we check for multiple cases (a rough sketch of both follows below):

- `past_key_values` is in the requested inputs, meaning that the user wants to trace the model with the use of the cache. In this case, we warn the user if `model.config.use_cache` is `False`, because no cache-related operations will be recorded.
- `past_key_values` is not in the requested inputs, but `model.config.use_cache` is `True`. In this setting we would create a `DynamicCache` (with `past_key_values=None`), but this operation would be "hardcoded" in the graph because no proxy input is provided (since `past_key_values=None`), resulting in failures. So if the user does not request `past_key_values` as an input, we disable the use of cache altogether.
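A rough sketch of the two cases from the user's side (the tiny checkpoint name is an assumption, and the exact tracing behavior depends on the transformers version):

```python
from transformers import AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# Case 1: the user explicitly asks for `past_key_values` as a traced input.
# Cache-related operations are recorded in the graph; a warning is emitted if
# model.config.use_cache is False.
traced_with_cache = symbolic_trace(
    model, input_names=["input_ids", "attention_mask", "past_key_values"]
)

# Case 2: `past_key_values` is not requested. With this PR, use_cache is disabled
# during tracing, so no DynamicCache gets hardcoded into the graph.
traced_without_cache = symbolic_trace(
    model, input_names=["input_ids", "attention_mask"]
)
```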
Ah, sorry, what I meant was that there seem to be different assumptions about the default for `model.config.use_cache` between these two branches, i.e. why don't we do `getattr(model.config, "use_cache", False)` here too?
I can do that, changing it!
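For reference, a minimal illustration of why the two attribute-access styles being discussed are interchangeable (hypothetical config classes, not the actual diff):

```python
class ConfigWithCache:
    use_cache = True

class ConfigWithoutCache:
    pass  # no use_cache attribute at all

for config in (ConfigWithCache(), ConfigWithoutCache()):
    explicit = hasattr(config, "use_cache") and config.use_cache
    with_default = getattr(config, "use_cache", False)
    assert explicit == with_default  # both treat a missing attribute as False
```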
@amyeroberts about this comment, you can check my answer to @ArthurZucker. To give a little bit more context, I'm doing this because the way it is currently implemented produces a compiler error on AWS Trainium / Inferentia devices, preventing us from using these models with the latest Transformers version.
Thanks for iterating!
What does this PR do?
This PR fixes things so that Transformers models can run on Trainium instances.
It fixes:
- `symbolic_trace`: before, no metadata was traced when a user defined custom leaf modules; it should be working now (see the sketch after this list).
- The way `transformers.cache_utils.Cache` classes are handled: they can now be symbolically traced.
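A sketch of the custom-leaf-module use case, assuming `symbolic_trace` accepts a `tracer_cls` argument in this version (otherwise `HFTracer` can be driven directly); the custom block and the tiny checkpoint name are placeholders for illustration:

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM
from transformers.utils.fx import HFTracer, symbolic_trace

class MyCustomBlock(nn.Module):
    """A user-defined module that should stay opaque (a leaf) during tracing."""
    def forward(self, hidden_states):
        return hidden_states * 2

class TracerWithCustomLeaves(HFTracer):
    def is_leaf_module(self, module, module_qualified_name):
        # Treat the custom block as a leaf so its internals are not traced through;
        # with this PR its metadata should still be recorded.
        if isinstance(module, MyCustomBlock):
            return True
        return super().is_leaf_module(module, module_qualified_name)

# In a real use case the model would contain MyCustomBlock instances.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask"],
    tracer_cls=TracerWithCustomLeaves,
)
```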