
Inputs left-padded passed to Instruct-Mistral-7B, with FlashAttention-2, causes garbage outputs for the padded sequences #29075

Closed
millicentli opened this issue Feb 17, 2024 · 5 comments


millicentli commented Feb 17, 2024

System Info

transformers version: 4.36.2
Pytorch version: 2.2.0
Platform: Rocky Linux release 8.8 (Green Obsidian), 4.18.0-477.27.1.el8_8.x86_64
Python version: Python 3.9.18
Accelerate version: 0.26.1
FlashAttention-2 version: 2.5.3

Who can help?

@ArthurZucker, @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

With FlashAttention-2, inference on Mistral-7B varies wildly depending on padding: a left-padded input produces very different output from the same input without padding.

The behavior with FA-2 seems to depend on the complexity of the task. In my case I'm doing multi-document summarization, and the example below is a multi-document input. I didn't try very hard to find a simpler example, because simple input texts didn't seem to show the same issue.

To reproduce, I've attached the input text I use (the examples below read it from this file):
text.txt

Example (minimal reproduction):

With FlashAttention-2

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

model_name = "mistralai/Mistral-7B-Instruct-v0.1"
torch_dtype = torch.bfloat16

tokenizer_kwargs = {
    "add_bos_token": False,
    "add_eos_token": False,
    "padding_side": "left"
}
config = AutoConfig.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
    device_map="balanced"
)

tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.sep_token = "[END DOCUMENT]"

with open("text.txt", "r") as f:
    inputs = f.readlines()
inputs = tokenizer(inputs, return_tensors="pt")

# Mask out the pad tokens (pad_token_id was set to eos_token_id, i.e. 2)
inputs['attention_mask'][inputs['input_ids'] == tokenizer.pad_token_id] = 0

inputs = {k: inputs[k].cuda() for k in inputs}

outputs = model.generate(
    **inputs,
    num_beams=2,
    no_repeat_ngram_size=3,
    max_new_tokens=256,
    pad_token_id=tokenizer.pad_token_id
)

# Decode only the newly generated tokens
print(tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:]))

The output:

['The']

Without FlashAttention-2

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

model_name = "mistralai/Mistral-7B-Instruct-v0.1"
torch_dtype = torch.bfloat16

tokenizer_kwargs = {
    "add_bos_token": False,
    "add_eos_token": False,
    "padding_side": "left"
}
config = AutoConfig.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    device_map="balanced"
)

tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.sep_token = "[END DOCUMENT]"

with open("text.txt", "r") as f:
    inputs = f.readlines()
inputs = tokenizer(inputs, return_tensors="pt")

# Mask out the pad tokens (pad_token_id was set to eos_token_id, i.e. 2)
inputs['attention_mask'][inputs['input_ids'] == tokenizer.pad_token_id] = 0

inputs = {k: inputs[k].cuda() for k in inputs}

outputs = model.generate(
    **inputs,
    num_beams=2,
    no_repeat_ngram_size=3,
    max_new_tokens=256,
    pad_token_id=tokenizer.pad_token_id
)

# Decode only the newly generated tokens
print(tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:]))

The output:

['The above documents discuss various studies and research related to the effects of breast- feeding on the health and development of newborn babies. Some studies suggest that breast- fed babies have a lower risk of certain health problems, such as infections and allergies, while others find no significant differences between breast- and formula-fed babies.\n\nOne study found that exclusive breast -feeding for at least six months was associated with a reduced risk of SIDS (Sudden Infant Death Syndrome) in infants aged 6-12 months. Another study found no significant association between breast - feeding and SIDS risk in infancy, but did find that breast - fed babies had a lower incidence and severity of respitory infections in the early months of life compared to formula- fed infancy.\nIn addition, some studies have found that breast feeding may have a positive effect on the cognitive and emotional development of babies. For example, one study found a positive correlation between breast feeding duration and cognitive development in infancies.\nOverall, the evidence suggests that breastfeeds is beneficial for newborn health and well-being, but more research is needed to fully understand the effects and to identify the optimal duration and frequency of feeding for different babies']
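Selecting pad positions with `nonzero()` is easy to get wrong when the result is used directly as an index. `torch.nonzero` returns one (row, column) pair per match, and an assignment such as `mask[:, pairs] = 0` treats every number in those pairs, row numbers included, as a column index and applies it to every row. With a single left-padded row this happens to be harmless (the only row number is 0, and column 0 is already a pad), but in a batch it masks real tokens. The pure-Python sketch below simulates both that pattern and a boolean-mask assignment (`mask[ids == pad_id] = 0`); the helper names and token ids are made up for illustration.

```python
# Token id 2 stands in for the pad token, which in these scripts is the
# same as the EOS token (pad_token_id = eos_token_id).
PAD_ID = 2


def nonzero_pairs(ids, value):
    """Mimic torch.nonzero(ids == value): a list of (row, col) pairs."""
    return [(r, c) for r, row in enumerate(ids)
            for c, tok in enumerate(row) if tok == value]


def mask_by_flattened_pairs(ids, value):
    """Simulate mask[:, pairs] = 0: every number in every (row, col) pair,
    row numbers included, is used as a column index in every row."""
    mask = [[1] * len(row) for row in ids]
    cols = {n for pair in nonzero_pairs(ids, value) for n in pair}
    for row in mask:
        for c in cols:
            row[c] = 0
    return mask


def mask_by_equality(ids, value):
    """Simulate mask[ids == value] = 0: zero exactly the matching positions."""
    return [[0 if tok == value else 1 for tok in row] for row in ids]


batch = [
    [3, 4, 5, 6, 7],  # unpadded row
    [2, 4, 5, 6, 7],  # row with one left pad
]
print(mask_by_flattened_pairs(batch, PAD_ID))  # [[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]]
print(mask_by_equality(batch, PAD_ID))         # [[1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]
```

Because the pad token here is also the EOS token, either approach also masks any genuine EOS tokens inside the text; building the mask from sequence lengths, as the tokenizer does with `padding=True`, avoids that.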

I found this while debugging: if I remove the padding from the beginning of the sequence, the outputs with and without FA-2 are similar. Other things I tried: upgrading FA-2 to the latest version and PyTorch to 2.2.0, but neither fixed the problem. I upgraded PyTorch because of pytorch/pytorch#112577, but that didn't seem to be the cause. I also upgraded transformers to 4.37.2, and the problem persisted there.
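The workaround described here (removing the left padding so each sequence runs on its own) can be sketched as a small helper that drops the leading positions whose attention-mask value is 0. `strip_left_padding` is a hypothetical name, not a transformers API, and it works on plain Python lists for a single sequence.

```python
def strip_left_padding(input_ids, attention_mask):
    """Return (ids, mask) for one sequence with leading pad positions removed.

    A position counts as padding if its attention-mask value is 0; only the
    leading run is removed, so masked positions later in the sequence stay.
    """
    start = 0
    while start < len(attention_mask) and attention_mask[start] == 0:
        start += 1
    return input_ids[start:], attention_mask[start:]


ids, mask = strip_left_padding([2, 2, 11, 12, 13], [0, 0, 1, 1, 1])
print(ids, mask)  # [11, 12, 13] [1, 1, 1]
```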

Expected behavior

Inference with FA-2 should behave similarly to inference without FA-2, but the outputs are wildly different.

@millicentli millicentli changed the title Inputs left-padded with FlashAttention-2, on Mistral-7B, causes garbage outputs for the padded sequences Inputs left-padded passed to Instruct-Mistral-7B, with FlashAttention-2, causes garbage outputs for the padded sequences Feb 17, 2024
millicentli (Author) commented

Update: downgrading to transformers 4.34.0 fixed this issue. See: https://discuss.huggingface.co/t/fine-tuned-mistral-7b-inference-issue-for-4k-context-length-token-with-transformer-4-35/65295

This is still a problem with the transformers versions noted above, though, so I'd like a fix for the most recent release if possible (I'll keep this issue open).

ArthurZucker (Collaborator) commented

Having a look right now, but the padding and the attention mask should not be changed manually; the tokenizer is supposed to take care of that.
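Conceptually, calling the tokenizer with `padding=True` and `padding_side="left"` prepends pad ids to the shorter sequences and sets their attention-mask entries to 0. A minimal pure-Python sketch of that idea follows; `left_pad` is a hypothetical helper and the token ids are made up. The mask is built from sequence lengths rather than by searching for the pad id, so it stays correct even when the pad id equals the EOS id.

```python
def left_pad(batch, pad_id):
    """Left-pad lists of token ids to the length of the longest one.

    Returns (input_ids, attention_mask); pad positions get mask value 0.
    """
    width = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = width - len(seq)
        input_ids.append([pad_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


ids, mask = left_pad([[1, 10, 11, 12], [1, 20]], pad_id=2)
print(ids)   # [[1, 10, 11, 12], [2, 2, 1, 20]]
print(mask)  # [[1, 1, 1, 1], [0, 0, 1, 1]]
```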

millicentli (Author) commented Feb 20, 2024

Yes, of course. This is just an example to replicate what's happening and to make the bug visible to you; in my actual code, the tokenizer takes care of the padding and the attention mask. (In my own code the batch size is > 1; this is one example I've scoped down to showcase the issue. The general trend is that, in any batched input, the only sample with coherent output is the one without padding.)
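When triaging a batched run like this one, it can help to log how much left padding each row received. With left padding, only the longest row (or rows) in a batch have none, so if exactly those rows are the coherent ones, that points towards padding handling. A small pure-Python sketch working from the attention mask follows; `left_pad_lengths` is a hypothetical helper, and with a real batch one would pass `inputs["attention_mask"].tolist()`.

```python
def left_pad_lengths(attention_mask):
    """Count the leading 0 entries (left padding) in each row of the mask."""
    lengths = []
    for row in attention_mask:
        n = 0
        while n < len(row) and row[n] == 0:
            n += 1
        lengths.append(n)
    return lengths


mask = [[0, 0, 1, 1], [1, 1, 1, 1], [0, 1, 1, 1]]
lengths = left_pad_lengths(mask)
unpadded = [i for i, n in enumerate(lengths) if n == 0]
print(lengths, unpadded)  # [2, 0, 1] [1]
```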

ArthurZucker (Collaborator) commented

That would be very surprising, as we try to make sure padding has as little influence as possible. Some influence is expected, but it should be limited, and we test left-padded generation.
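One reason left padding should have only a limited effect is that, during generation, position ids are derived from the attention mask rather than from raw column indices, so a real token gets the same position whether or not padding precedes it. In transformers, `prepare_inputs_for_generation` for Llama-style models (Mistral included) appears to compute `attention_mask.long().cumsum(-1) - 1` and fill pad positions with 1; the pure-Python sketch below mirrors that idea and is not the library code.

```python
def position_ids_from_mask(attention_mask):
    """Per-row position ids that start at 0 on the first unmasked token.

    Mirrors cumsum(mask) - 1, with masked (pad) positions filled with 1;
    those positions are ignored by attention anyway.
    """
    result = []
    for row in attention_mask:
        running = 0
        positions = []
        for m in row:
            running += m
            positions.append(running - 1 if m else 1)
        result.append(positions)
    return result


padded, unpadded = position_ids_from_mask([[0, 0, 1, 1, 1], [1, 1, 1]])
print(padded)    # [1, 1, 0, 1, 2]
print(unpadded)  # [0, 1, 2]
```

The real tokens of the padded row get positions 0, 1, 2, exactly as in the unpadded row.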

ArthurZucker (Collaborator) commented

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

tokenizer_kwargs = {
    "add_bos_token": True,
    "add_eos_token": False,
    "padding_side": "left"
}
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
    device_map="balanced",
)
tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer(
    ["Hey! How are you doing?", "My favorite condiment is definitely:"],
    return_tensors="pt",
    padding=True,
).to(model.device)

outputs = model.generate(
    **inputs,
    num_beams=2,
    no_repeat_ngram_size=3,
    max_new_tokens=256,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.batch_decode(outputs))

The output:
["<s> Hey! How are you doing?\n\nI'm doing well, thanks for asking! I've been keeping busy with work and other projects, but I'm always happy to chat with you. What have you been up to lately?</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>", "<s> My favorite condiment is definitely: MAYONNAISE!\n\nMayonnaise is a versatile condiment that can be used in a variety of dishes. It is a classic ingredient in many salads, sandwiches, and dips. Mayonnaise can also be used as a base for many sauces and dressings.\n\nOne of my favorite ways to use mayonnaise in the kitchen is to make a classic Caesar salad. The creamy, tangy flavor of the mayonnaises pairs perfectly with the crisp lettuce, crunchy croutons, and salty parmesan cheese.\nI also love using mayonnaize as a spread for my sandwiches. It adds a rich, creamy texture that elevates the flavor of any sandwich. And, it's a great base for dips like tzatziki or aioli.\nOverall, I think mayonnais is a must-have condiment in any kitchen. It's versatile, flavorful, and adds a touch of luxury to any dish.</s>"]

I used the reproducible snippet above, which generates good, coherent text for the padded sequence.
I'll close this for now, but feel free to re-open if you are not satisfied.
