
TFOPTForCausalLM Attention mask size mismatch exception #24637

Closed
1 of 4 tasks
abb128 opened this issue Jul 3, 2023 · 6 comments · Fixed by #25238


@abb128

abb128 commented Jul 3, 2023

System Info

  • transformers version: 4.30.2
  • Platform: Linux-5.15.107+-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • Huggingface_hub version: 0.15.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.1+cu118 (False)
  • Tensorflow version (GPU?): 2.12.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.6.11 (cpu)
  • Jax version: 0.4.10
  • JaxLib version: 0.4.10
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm trying to write my own decoding logic so I can export to TFLite. The app runs the decoding loop itself, calling into the TFLite model with past_key_values and input_ids, but the code for that is a little more involved.

I'm not sure if I'm missing something important here, but I have previously exported Whisper successfully using this sort of pattern.

I've reduced the problem to this example:

Colab Link

import tensorflow as tf
from transformers import AutoTokenizer, TFOPTForCausalLM, TFGPT2LMHeadModel

def decoding_example(model, tokenizer):
  input_ids = tf.convert_to_tensor([[1]]) * int(tokenizer.bos_token_id)
  outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=None)

  past_key_values = outputs.past_key_values
  max_new_tokens = 8
  for i in range(max_new_tokens):
    print(i)
    decoded_next_token = 123  # just an example; in practice this would depend on outputs.logits

    input_ids = tf.convert_to_tensor([[1]]) * decoded_next_token

    outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=past_key_values)
    past_key_values = outputs.past_key_values
  
  print("Finished, all OK")

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")

decoding_example(model, tokenizer) # fails
Output
0
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-5-07105bf5f115> in <cell line: 4>()
      2 model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")
      3 
----> 4 decoding_example(model, tokenizer) # fails

<ipython-input-3-94ad2e4e3e50> in decoding_example(model, tokenizer)
     11     input_ids = tf.convert_to_tensor([[1]]) * decoded_next_token
     12 
---> 13     outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=past_key_values)
     14     past_key_values = outputs.past_key_values
     15 

/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
    440 
    441         unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442         return func(self, **unpacked_inputs)
    443 
    444     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, past_key_values, attention_mask, position_ids, head_mask, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    956         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    957 
--> 958         outputs = self.model(
    959             input_ids=input_ids,
    960             past_key_values=past_key_values,

/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
    440 
    441         unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442         return func(self, **unpacked_inputs)
    443 
    444     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    730         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    731 
--> 732         outputs = self.decoder(
    733             input_ids,
    734             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
    440 
    441         unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442         return func(self, **unpacked_inputs)
    443 
    444     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, inputs_embeds, attention_mask, head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, training)
    657             past_key_value = past_key_values[idx] if past_key_values is not None else None
    658 
--> 659             hidden_states, layer_self_attn, present_key_value = decoder_layer(
    660                 hidden_states,
    661                 attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, hidden_states, attention_mask, layer_head_mask, past_key_value, training, output_attentions, use_cache)
    323 
    324         # add present self-attn cache to positions 1,2 of present_key_value tuple
--> 325         hidden_states, self_attn_weights, present_key_value = self.self_attn(
    326             hidden_states=hidden_states,
    327             past_key_value=self_attn_past_key_value,

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, training)
    217 
    218         if attention_mask is not None:
--> 219             tf.debugging.assert_equal(
    220                 shape_list(attention_mask),
    221                 [bsz, 1, tgt_len, src_len],

InvalidArgumentError: Exception encountered when calling layer 'self_attn' (type TFOPTAttention).

Attention mask should be of size (1, 1, 0, 1), but is [1, 1, 1, 2]
Condition x == y did not hold.
Indices of first 2 different values:
[[2]
 [3]]
Corresponding x values:
[1 2]
Corresponding y values:
[0 1]
First 3 elements of x:
[1 1 1]
First 3 elements of y:
[1 1 0]

Call arguments received by layer 'self_attn' (type TFOPTAttention):
  • hidden_states=tf.Tensor(shape=(1, 0, 768), dtype=float32)
  • key_value_states=None
  • past_key_value=('tf.Tensor(shape=(1, 12, 1, 64), dtype=float32)', 'tf.Tensor(shape=(1, 12, 1, 64), dtype=float32)')
  • attention_mask=tf.Tensor(shape=(1, 1, 1, 2), dtype=float32)
  • layer_head_mask=None
  • training=False
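The failing check compares the mask against [bsz, 1, tgt_len, src_len]. Judging from the shapes in the call arguments above, tgt_len is the sequence length of hidden_states and src_len adds the cached length from past_key_value. The sketch below reproduces that arithmetic to show why the check fails: the empty hidden_states makes the expected shape (1, 1, 0, 1), while the mask was built for one new token. The helper name `expected_mask_shape` is hypothetical, and the relation src_len = past_len + tgt_len is inferred from the error message rather than quoted from modeling_tf_opt.py.

```python
def expected_mask_shape(hidden_states_shape, past_key_shape):
    """Hypothetical helper: the [bsz, 1, tgt_len, src_len] shape the attention check expects.

    hidden_states_shape: (batch, tgt_len, hidden_size)
    past_key_shape:      (batch, num_heads, past_len, head_dim)
    """
    bsz, tgt_len, _ = hidden_states_shape
    past_len = past_key_shape[2]
    src_len = past_len + tgt_len  # keys = cached keys + keys for the new tokens
    return (bsz, 1, tgt_len, src_len)


# Shapes from the failing call: hidden_states has been sliced down to length 0.
print(expected_mask_shape((1, 0, 768), (1, 12, 1, 64)))  # (1, 1, 0, 1)
# With one new token, as intended, the expected shape matches the mask that was passed.
print(expected_mask_shape((1, 1, 768), (1, 12, 1, 64)))  # (1, 1, 1, 2)
```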

Expected behavior

I expect it to work as it does with GPT-2:

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")

decoding_example(model, tokenizer) # works
@amyeroberts
Collaborator

cc @Rocketknight1

@Rocketknight1
Member

Rocketknight1 commented Jul 10, 2023

Yep, something is clearly being mangled in here. The hidden_states shape of (1, 0, 768) is alarming: there's obviously some incorrect array slicing happening somewhere. I'll investigate as soon as I get a chance, but if you want to take a look before then, the relevant code is all in modeling_tf_opt.py. To debug it yourself, I'd advise:

  1. Clone transformers yourself: git clone https://github.com/huggingface/transformers.git
  2. Make an editable install from that local repo: cd transformers && pip install -e .
  3. Start putting breakpoint() or tests in the modeling_tf_opt.py file and seeing if you can find where the arrays get sliced down to length 0.

That's a lot of work, though - if you can wait, I'll get around to it in a few days!
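One way a single new token can end up as a (1, 0, 768) hidden state is sketched below. The sketch assumes OPT-style learned position ids computed as a cumulative sum over the attention mask, minus one, with the first past_key_values_length entries then dropped. If the mask only covers the new token (length 1) while one token is already cached, that slice leaves nothing, and broadcasting the (1, 1, 768) token embeddings against (1, 0, 768) position embeddings silently produces an empty result. This is an illustrative NumPy sketch of a plausible mechanism, not code from modeling_tf_opt.py; `position_ids` and the embedding table are hypothetical.

```python
import numpy as np


def position_ids(attention_mask, past_key_values_length):
    """Hypothetical OPT-style position ids: cumsum over the mask, minus 1, cached positions dropped."""
    mask = attention_mask.astype(np.int64)
    positions = np.cumsum(mask, axis=1) * mask - 1
    return positions[:, past_key_values_length:]


past_len = 1  # one token already in the cache
hidden_size = 768
position_table = np.zeros((16, hidden_size))  # stand-in for a learned position-embedding table

# Mask built from the new tokens only, shape (batch_size, seq_length) = (1, 1)
short_mask = np.ones((1, 1))
short_positions = position_ids(short_mask, past_len)
print(short_positions.shape)  # (1, 0): the only position was sliced away

# Token embeddings (1, 1, 768) + empty position embeddings (1, 0, 768) broadcast to length 0
token_embeds = np.zeros((1, 1, hidden_size))
pos_embeds = position_table[short_positions]
print((token_embeds + pos_embeds).shape)  # (1, 0, 768)

# Mask covering cache + new token, shape (1, past_len + 1) = (1, 2)
full_mask = np.ones((1, past_len + 1))
print(position_ids(full_mask, past_len))  # [[1]]
```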

@Rocketknight1
Member

Unfortunately, I didn't manage to finish this before my holiday because of some more Falcon chaos. cc @gante if you get a chance; if not, I can take it when I get back!

I identified the core problem as some confusion in the code about what the actual seq_length is. The first problem is in how the default mask is built: if no attention_mask is provided, the code builds one from the sequence length of input_ids / inputs_embeds, giving shape (batch_size, seq_length), when the actual shape should be (batch_size, seq_length + past_key_values_length).
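The shape fix described above can be sketched in NumPy for illustration: read past_key_values_length from the cached key tensor and build the default mask over past + new tokens. The cache layout (batch, num_heads, past_len, head_dim) is an assumption based on the (1, 12, 1, 64) tensors in the traceback, and `default_attention_mask` is a hypothetical helper, not the library's function or the merged patch.

```python
import numpy as np


def default_attention_mask(input_ids, past_key_values=None):
    """Hypothetical helper: all-ones mask of shape (batch_size, past_key_values_length + seq_length)."""
    batch_size, seq_length = input_ids.shape
    # Cached keys are assumed to be laid out as (batch, num_heads, past_len, head_dim)
    past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]
    return np.ones((batch_size, past_key_values_length + seq_length), dtype=bool)


# First step: no cache yet, so the mask only covers the prompt token.
print(default_attention_mask(np.array([[2]])).shape)  # (1, 1)

# Later step: one cached token (only the first layer's cache is inspected) plus one new token.
past = ((np.zeros((1, 12, 1, 64)), np.zeros((1, 12, 1, 64))),)
print(default_attention_mask(np.array([[123]]), past).shape)  # (1, 2)
```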

However, fixing this led to other problems - the expanded/combined attention mask code also gets a bit confused when past_key_values is present. I'm not sure why generation tests don't pick this up, but possibly they explicitly pass an attention mask and avoid the issue!

This attention mask expansion code has been copied all around the codebase; I encountered it in PyTorch Falcon and BLOOM recently, where it also caused some problems. It might be worth a repo-wide refactor at some point, as I think the code is unclear and the variable names can be confusing, probably because it started as encoder-decoder code and is now being used to manage attention over past key-values.
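For reference, a combined mask with past key-values should end up with shape (bsz, 1, tgt_len, src_len), where src_len = past_key_values_length + tgt_len. This is the shape checked in the traceback. Each new query position may attend to every cached position plus the new positions up to and including its own. The minimal NumPy sketch below assumes additive masks in which blocked positions hold a large negative value. The function names and the -1e9 constant are hypothetical; this is not the transformers implementation.

```python
import numpy as np

LARGE_NEGATIVE = -1e9  # assumed value added to blocked attention scores


def causal_mask(bsz, tgt_len, past_key_values_length):
    """(bsz, 1, tgt_len, past_len + tgt_len): cached keys always visible, new keys causal."""
    src_len = past_key_values_length + tgt_len
    query_pos = np.arange(tgt_len)[:, None] + past_key_values_length  # absolute position of each query
    key_pos = np.arange(src_len)[None, :]
    mask = np.where(key_pos <= query_pos, 0.0, LARGE_NEGATIVE)
    return np.broadcast_to(mask, (bsz, 1, tgt_len, src_len))


def expand_padding_mask(attention_mask, tgt_len):
    """(bsz, src_len) 0/1 mask -> (bsz, 1, tgt_len, src_len) additive mask."""
    bsz, src_len = attention_mask.shape
    blocked = (attention_mask == 0)[:, None, None, :]
    return np.broadcast_to(np.where(blocked, LARGE_NEGATIVE, 0.0), (bsz, 1, tgt_len, src_len))


def combined_mask(attention_mask, tgt_len, past_key_values_length):
    """attention_mask must cover past + new tokens, i.e. shape (bsz, past_len + tgt_len)."""
    bsz, src_len = attention_mask.shape
    if src_len != past_key_values_length + tgt_len:
        raise ValueError("attention_mask must include the cached positions")
    return causal_mask(bsz, tgt_len, past_key_values_length) + expand_padding_mask(attention_mask, tgt_len)


m = combined_mask(np.ones((1, 2)), tgt_len=1, past_key_values_length=1)
print(m.shape)  # (1, 1, 1, 2)
print(m[0, 0])  # [[0. 0.]]: the single new token sees the cached token and itself
```

Raising on a mask that omits the cached positions surfaces the mismatch at the point where the mask is combined, rather than later in the attention layer's shape assertion.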

@abb128
Author

abb128 commented Jul 14, 2023

Unrelated to this issue, but for TFLite export I end up having to do something hacky anyway to pass a custom past_key_values_length value. The shape is dynamic, so the code cannot depend on it during TFLite export: past_key_values[0][0].shape[2] just resolves to None, which causes an exception later on when None is used as a number. It'd be nice if there were a built-in way to pass a past_key_values_length value.

@Rocketknight1
Member

Hi @abb128, good point! That might be a sign that we should be using tf.shape() instead, which will correctly allow the dynamic shape to be compiled. I'll investigate while I'm fixing the rest of this.
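A minimal sketch of the difference between static and dynamic shapes, assuming a tf.function traced with an input_signature whose cache-length axis is None, as with a dynamic past_key_values length during export. While tracing, the static `.shape[2]` is None, so arithmetic on it fails. `tf.shape(...)[2]` is a tensor that is only resolved at run time. `past_key_values_length` and `seen_static_lengths` are illustrative names, not transformers APIs.

```python
import tensorflow as tf

seen_static_lengths = []  # filled while tracing, to show what the static shape reports


@tf.function(input_signature=[tf.TensorSpec(shape=[1, 12, None, 64], dtype=tf.float32)])
def past_key_values_length(past_key):
    # Python side effect: runs only while tracing. The cache-length axis is unknown here,
    # so past_key.shape[2] is None and `past_key.shape[2] + 1` would raise a TypeError.
    seen_static_lengths.append(past_key.shape[2])
    # tf.shape returns a runtime int32 tensor, so this works for any cache length.
    return tf.shape(past_key)[2]


print(int(past_key_values_length(tf.zeros([1, 12, 3, 64]))))  # 3
print(seen_static_lengths[0])  # None
```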

@Rocketknight1
Member

@abb128 I've filed a patch - please try it and let me know if it works for you!
