Generate: SinkCache can handle iterative prompts #27907
Conversation
```
@@ -209,8 +217,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        cache_length = self.key_cache[layer_idx].shape[-2]
        return min(cache_length, self.window_length - 1)
```
This line was a bit of a hack; `get_max_length` makes us no longer need the hack :) `get_seq_length` now always does what the fn name and the docstring say it does.
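For context, here's a rough sketch of how the two getters relate after this change. The class below is a simplified stand-in for illustration, not the actual `SinkCache` implementation:

```python
# Sketch only: simplified cache bookkeeping, not the exact transformers implementation.
from typing import Optional


class TinySinkCache:
    def __init__(self, window_length: int):
        self.window_length = window_length
        self.key_cache = []  # one tensor-like entry per layer

    def get_seq_length(self, layer_idx: int = 0) -> int:
        # Exactly what the name says: how many states are currently cached for this layer.
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        # The capacity of the cache; callers use this instead of the old
        # `min(cache_length, window_length - 1)` workaround.
        return self.window_length
```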
src/transformers/cache_utils.py (outdated)
```diff
@@ -239,8 +250,8 @@ def update(
         """
         # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
         # with partially rotated position embeddings, like Phi or Persimmon.
-        sin = cache_kwargs.get("sin")
-        cos = cache_kwargs.get("cos")
+        sin = cache_kwargs.get("sin")[: self.window_length]
```
Slicing here is needed if more than one token is fed at once, after the cache is full.
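To illustrate with made-up numbers: once the cache is full, the cached keys never hold more than `window_length` positions, but the `sin`/`cos` tables passed through `cache_kwargs` cover the full incoming sequence length, so they have to be cut down before computing the re-rotation. A rough sketch (all sizes below are illustrative):

```python
import torch

window_length = 256                 # cache capacity (illustrative)
num_new_tokens = 8                  # more than one token fed at once, after the cache is full
kv_seq_len = 1024 + num_new_tokens  # positions covered by the incoming RoPE tables

sin = torch.randn(kv_seq_len, 64)   # stand-in for the table passed via cache_kwargs
# The cached keys are re-rotated against at most `window_length` positions,
# so the table must be truncated to that size first:
sin = sin[:window_length]
assert sin.shape[0] == window_length
```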
```
        kv_max_length = past_key_value.get_max_length()
        if kv_max_length is not None and kv_seq_len > kv_max_length:
            kv_seq_len = kv_max_length
            attention_mask = attention_mask[:, :, :, -kv_seq_len:]
```
The attention mask must be sliced to match the length of `key_states`, which might have been sliced in `.update()` (for fixed-length caches).
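In terms of shapes, with toy numbers rather than the PR's code, the key-length dimension of the 4D mask has to agree with the possibly-truncated `key_states`:

```python
import torch

batch, heads, q_len = 1, 32, 8
kv_seq_len, kv_max_length = 1040, 1024  # illustrative: the query would attend past the cache capacity

attention_mask = torch.zeros(batch, 1, q_len, kv_seq_len)
key_states = torch.randn(batch, heads, kv_max_length, 128)  # already truncated by `.update()`

kv_seq_len = min(kv_seq_len, kv_max_length)
attention_mask = attention_mask[:, :, :, -kv_seq_len:]  # keep only the most recent positions
assert attention_mask.shape[-1] == key_states.shape[-2]
```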
```
@@ -187,3 +189,37 @@ def test_sink_cache_hard(self):
        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))

    @require_auto_gptq
    def test_sink_cache_iterative_prompts(self):
```
This test would fail on main
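For reference, the scenario being exercised looks roughly like the sketch below. The model and prompts are placeholders; the real test uses a GPTQ-quantized checkpoint, hence the `@require_auto_gptq` marker.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "mistralai/Mistral-7B-Instruct-v0.1"  # assumption: any RoPE chat model works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

cache = SinkCache(window_length=256, num_sink_tokens=4)
chat = [{"role": "user", "content": "Tell me a long story about attention sinks."}]
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(model.device)
out = model.generate(input_ids, do_sample=False, max_new_tokens=128, past_key_values=cache)

# Second, multi-token prompt on top of an already (nearly) full cache -- the case that used to fail:
chat += [
    {"role": "assistant", "content": tokenizer.decode(out[0, input_ids.shape[1]:], skip_special_tokens=True)},
    {"role": "user", "content": "Now summarize it in one sentence."},
]
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(model.device)
out = model.generate(input_ids, do_sample=False, max_new_tokens=64, past_key_values=cache)

assert cache.get_seq_length() <= 256  # the cache never grows beyond its window
```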
Some additional results @gante
This experiment is with calling the models with individual tokens exclusively. Additionally, I'm doing a test with calling Mistral-7B (in this case with this PR) using multiple tokens at once: I took indices 0 to 63k of a book from pg19, only kept 20% of all indices, and then fed the model with the tokens between subsequent indices. The running cache is also included. The NLL loss is then converted to perplexity. The same script crashes on `main`.
The perplexity diverges quite heavily between the SinkCache and non-SinkCache runs, which is not ideal. Perhaps this is indicative of some error/bug, or perhaps not; it's a bit hard to tell. Beyond that, the left- and right-sliced cache implementations behave identically (unless I made some measuring mistake), which is a bit odd. I don't have 100% confidence in this fix anymore, I'm afraid.
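For reference, the NLL-to-perplexity conversion used in these measurements is just the exponential of the (running mean) negative log-likelihood; a minimal illustration with made-up numbers:

```python
import torch

nlls = torch.tensor([2.31, 1.87, 2.05])  # illustrative per-token negative log-likelihoods
per_token_ppl = nlls.exp()               # perplexity of each individual prediction
overall_ppl = nlls.mean().exp().item()   # running perplexity reported over the sequence
```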
Thanks
src/transformers/cache_utils.py (outdated)
```diff
@@ -239,8 +250,8 @@ def update(
         """
         # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
         # with partially rotated position embeddings, like Phi or Persimmon.
-        sin = cache_kwargs.get("sin")
-        cos = cache_kwargs.get("cos")
+        sin = cache_kwargs.get("sin")[: self.window_length]
```
Suggested change:

```diff
-        sin = cache_kwargs.get("sin")[: self.window_length]
+        sin = cache_kwargs.get("sin")[-self.window_length:]
```
Would that not make more sense? Since that's the side we split the cache from, no?
I'll run some tests for this
Updated my previous message with my findings
@ArthurZucker @tomaarsen that's the beauty of sin and cos: the rerotation applied on the sink caches is based on the relative angles, and the relative angles are the same regardless of the side we slice 🙌
If you place a debugger, you can see that the sliced tensors are different, but the rerotation coefficients are exactly the same!
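A quick way to see this is a toy check with a single RoPE frequency (illustrative, not the actual `SinkCache` code): the re-rotation only ever combines `cos`/`sin` through `cos(a - b) = cos(a)cos(b) + sin(a)sin(b)`, and the pairwise position differences are identical whether you keep the first or the last `window_length` entries of the table.

```python
import torch

L, W = 12, 5  # full sin/cos table length and cache window length (illustrative)
theta = 0.3
pos = torch.arange(L, dtype=torch.float32)
cos, sin = torch.cos(pos * theta), torch.sin(pos * theta)


def relative_angle_cos(c, s):
    # cos(a - b) = cos(a)cos(b) + sin(a)sin(b), for every pair of cached positions
    return c[:, None] * c[None, :] + s[:, None] * s[None, :]


front = relative_angle_cos(cos[:W], sin[:W])   # slice from the left, as in the PR
back = relative_angle_cos(cos[-W:], sin[-W:])  # slice from the right, as suggested
print(torch.allclose(front, back, atol=1e-6))  # True -- the re-rotation coefficients match
```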
```
        kv_max_length = past_key_value.get_max_length()
        if kv_max_length is not None and kv_seq_len > kv_max_length:
            kv_seq_len = kv_max_length
            attention_mask = attention_mask[:, :, :, -kv_seq_len:]
```
Ouch, not a fan of that.
We have something similar for the Mistral sliding window, but I'm not in favor of keeping this. That should either go in the attention mask converter, which should slice it, or in the `cache_kwargs`, as it's clearly sink-cache and window-cache specific 😉
The alternative would be to make `prepare_inputs_for_generation` prepare the sliced mask in advance (which is how TF/JAX do it). Going to check if it is feasible.
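Roughly, the idea is to crop the mask where the inputs are prepared, before the forward pass. A simplified sketch of that approach (the helper name below is made up; `get_max_length` and `get_seq_length` are the cache getters discussed above):

```python
def prepare_mask_for_fixed_length_cache(attention_mask, input_ids, cache):
    # Sketch: crop the incoming 2D attention mask when the cache has a maximum length.
    max_cache_length = cache.get_max_length()  # None for unbounded caches
    cache_length = cache.get_seq_length()
    if (
        max_cache_length is not None
        and attention_mask is not None
        and cache_length + input_ids.shape[1] > max_cache_length
    ):
        # Keep only the most recent positions: they are the ones still present in the cache.
        attention_mask = attention_mask[:, -max_cache_length:]
    return attention_mask
```

This keeps the attention forward pass free of cache-specific mask surgery, which is the concern raised above.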
@ArthurZucker as per your suggestion, I've reworked the PR to avoid the post hoc attention mask slicing.

@tomaarsen the rework seems to have resulted in a qualitative result upgrade (e.g. see the test case), so I suspect that I've inadvertently fixed a bug 👀 Would you be able to rerun your benchmarks?
@gante I get:

@tomaarsen no, the script in the PR header runs endlessly without issues 🤔 LMK if you can find a reproducer

I have the same, that script works fine. Hmmm

Got a reproducer: change

@tomaarsen should be fixed

@gante Works great now! Red is the baseline, I can only run it to about ~15k seq length until my PC completely freezes.

🙌
Thanks for the prompt update
```
        # older attention values, as their corresponding values are not part of the input.
        if cache_length < past_length and attention_mask is not None:
            attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
```
okay better! thanks
```diff
@@ -268,7 +268,7 @@ def forward(
                 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                 "with a layer index."
             )
-        kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
```
`slice_length` or `window_length` might be better? But it's a nit, feel free to ignore.
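For reference, `get_usable_length` is essentially a thin wrapper over the two getters discussed above; a simplified sketch of its logic:

```python
def get_usable_length(self, new_seq_length: int, layer_idx: int = 0) -> int:
    # How many of the cached positions will still be visible to the new tokens,
    # given that a fixed-size cache may drop old entries to make room for them.
    max_length = self.get_max_length()
    previous_seq_length = self.get_seq_length(layer_idx)
    if max_length is not None and previous_seq_length + new_seq_length > max_length:
        return max_length - new_seq_length
    return previous_seq_length
```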
A later commit adding SDPA support referenced this pull request ("fix following #27907").
@tomaarsen @gante @ArthurZucker Hi, thanks for the work. Do you know how we can reproduce the benchmark test with the latest transformers repo implementation? Are there any scripts? Really appreciated!
@YJHMITWEB you can adapt the following scripts to your needs. Note -- Llama + attention sinks is probably not working on the latest version, Mistral should be. We are reworking how the cache works, so things may be a bit rough :) Worst case scenario, use

Script to gather perplexity

```python
"""
Adapted from https://github.com/mit-han-lab/streaming-llm
Note: Although this script measures latency, it is not optimized whatsoever!
The latency is only tracked to see the impact of speed over time.
Usage:
python benchmark/perplexity.py --experiment attention_sinks
python benchmark/perplexity.py --experiment transformers
python benchmark/perplexity.py --experiment windowed
"""
import argparse
import itertools
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import pandas as pd
import torch
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
def compute_perplexity(
    model,
    tokenizer,
    dataset,
    experiment: str,
    output_dir: str = "outputs",
    data_column: str = "text",
    num_samples: int = 1,
    num_tokens: Optional[int] = None,
    overwrite: bool = False,
) -> None:
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{experiment}.csv"
    if output_file.exists() and not overwrite:
        raise ValueError(
            f"The {output_file!r} output file already exists - if you really want to override it, then use `--overwrite`."
        )

    logs = defaultdict(list)
    loss_fn = CrossEntropyLoss(reduction="none")
    past_key_values = None
    num_processed_tokens = 0
    for text in itertools.islice(dataset, num_samples):
        encodings = tokenizer(text[data_column], return_tensors="pt")

        seq_len = encodings.input_ids.size(1)
        print(f"sequence length: {seq_len}")
        pbar = tqdm(range(0, seq_len - 1))

        for idx in pbar:
            start_t = time.time()
            input_ids = encodings.input_ids[:, idx : idx + 1].to(model.device)
            with torch.no_grad():
                outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
                logits = outputs.logits.view(-1, model.config.vocab_size)
                past_key_values = outputs.past_key_values
                label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
                neg_log_likelihood = loss_fn(logits, label)
                perplexity = neg_log_likelihood.exp()
            pbar.set_description(f"nll: {neg_log_likelihood.item():>5.2f}, ppl: {perplexity.item():>8.2f}")

            # Store data and save every 10 tokens
            logs["input_length"].append(idx + 1)
            logs["nll"].append(neg_log_likelihood.item())
            logs["ppl"].append(perplexity.item())
            logs["overall_ppl"].append(torch.tensor(logs["nll"]).mean().exp().item())
            logs["cuda_vram_allocated"].append(torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024)  # in GB
            logs["latency"].append(time.time() - start_t)
            if num_processed_tokens % 10 == 0:
                try:
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                except KeyboardInterrupt as ex:
                    # If there's a Keyboard Interrupt, still write the file, and then stop
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                    raise ex

            num_processed_tokens += 1
            if num_tokens and num_processed_tokens >= num_tokens:
                return


def main():
    parser = argparse.ArgumentParser()
    # How to call this experiment?
    parser.add_argument(
        "--experiment", type=str, default="main"
    )

    # Model args
    # parser.add_argument("--model_name_or_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-2-7b-hf")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--trust_remote_code", action="store_true")

    # Dataset args
    parser.add_argument("--dataset_name", type=str, default="emozilla/pg19-test")
    parser.add_argument("--data_column", type=str, default="text")
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
    # parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--num_tokens", type=int, default=5000)
    parser.add_argument("--dtype", type=str, default="fp16")

    # Where to log
    parser.add_argument("--output_dir", type=str, default="/home/joao/joao_scripts/perplexity/outputs")
    parser.add_argument("--overwrite", action="store_true")

    args = parser.parse_args()

    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "fp32":
        dtype = torch.float32
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unknown dtype: {args.dtype}")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=bool(args.trust_remote_code),
        attn_implementation="eager",
        torch_dtype=dtype,
        device_map="auto",
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))

    # Set up the dataset
    dataset = load_dataset(args.dataset_name, args.task, split=args.split, streaming=True)

    compute_perplexity(
        model,
        tokenizer,
        dataset,
        args.experiment,
        output_dir=args.output_dir,
        data_column=args.data_column,
        num_samples=1,  # <- No support for more than one instance now
        num_tokens=args.num_tokens,
        overwrite=args.overwrite,
    )


if __name__ == "__main__":
    main()
```

Script to plot perplexity (and other metrics)

```python
"""
First run `perplexity.py` to generate one or more `csv` files.
This script can plot those csv files.
Usage:
python benchmark/plot_perplexity.py
python benchmark/plot_perplexity.py --features perplexity latency --title "Log perplexity & latency of Llama 2 7B as a function of input lengths"
"""
import argparse
from pathlib import Path
from typing import List, Optional
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
FEATURE_DF_MAP = {
    "perplexity": "overall_ppl",
    "vram": "cuda_vram_allocated",
    "latency": "latency",
}
FEATURE_STYLE_MAP = {
    "perplexity": "-",
    "vram": "--",
    "latency": ":",
}
FEATURE_LABEL_MAP = {
    "perplexity": "Perplexity (log), lower is better",
    "vram": "CUDA VRAM Usage (GB), lower is better",
    "latency": "Time per token (sec), lower is better",
}


def plot(
    features: List[str],
    output_dir: str = "outputs",
    title: Optional[str] = None,
    perplexity_limit: Optional[float] = None,
    skip_first: int = 100,
):
    output_dir = Path(output_dir)

    fig, ax = plt.subplots()
    ax.set_xlabel("Input Sequence Length")
    for feature_i, feature in enumerate(features):
        # If we already plotted on this ax, make a new one
        if feature_i:
            ax = ax.twinx()

        for file in output_dir.glob("*.csv"):
            experiment = file.stem
            df = pd.read_csv(file)
            X = df["input_length"][skip_first:]
            Y = df[FEATURE_DF_MAP[feature]][skip_first:]
            if feature == "perplexity":
                Y = np.log(Y)
            if feature == "latency":
                poly = np.polyfit(X, Y, 20)
                poly_y = np.poly1d(poly)(X)
                ax.plot(X, poly_y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")
            else:
                ax.plot(X, Y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")

        ax.set_ylabel(FEATURE_LABEL_MAP[feature])
        if perplexity_limit and feature == "perplexity":
            ax.set_ylim(top=min(ax.get_ylim()[1], perplexity_limit))

        ax.legend(loc=[1, 2, 7][feature_i])  # upper right, upper left, center right

    ax.set_title(title.replace("\\n", "\n") if title else "Log perplexity as a function of input lengths")
    fig.tight_layout()

    return fig


def main():
    parser = argparse.ArgumentParser()
    # Where csv files have been logged
    parser.add_argument("--output_dir", type=str, default="/home/joao/joao_scripts/perplexity/outputs")
    parser.add_argument(
        "--features", choices=["perplexity", "vram", "latency"], nargs="+", default=["perplexity", "vram"]
    )
    parser.add_argument("--title", type=str, default=None)
    parser.add_argument("--log_perplexity_limit", type=float, default=5.0)
    # Perplexity starts a bit unstable, so we skip the start
    parser.add_argument("--skip_first", type=int, default=100)

    args = parser.parse_args()

    figure = plot(
        args.features,
        output_dir=args.output_dir,
        title=args.title,
        perplexity_limit=args.log_perplexity_limit,
        skip_first=args.skip_first,
    )

    # Add your own code here if you'd like to change the figure
    features = "_".join(args.features)
    save_path = f"/home/joao/joao_scripts/perplexity/outputs/plot_{features}.png"
    plt.savefig(save_path, dpi=600)
    print(f"plot saved to {save_path}")


if __name__ == "__main__":
    main()
```

Credits to @tomaarsen, who shared these with me :)
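Note that the perplexity script above runs with the default unbounded cache (`past_key_values = None`); to benchmark `SinkCache` with it, one would swap in a cache instance instead, for example (window sizes are illustrative):

```python
from transformers import SinkCache

# Instead of `past_key_values = None` in `compute_perplexity`:
past_key_values = SinkCache(window_length=1024, num_sink_tokens=4)
```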
@gante @tomaarsen Thanks!
What does this PR do?
Fixes the case where `SinkCache` is used in a chat bot, receiving new prompts after giving an answer. Fix developed with @tomaarsen.

Here's an example of a script that works after this PR:
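A minimal sketch of such a script, with an assumed chat model and illustrative generation settings (not the exact example from the PR header): the same `SinkCache` is reused across turns, so memory stays bounded no matter how long the conversation gets.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "HuggingFaceH4/zephyr-7b-beta"  # assumption: any chat model with a chat template
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

cache = SinkCache(window_length=1024, num_sink_tokens=4)
messages = []
while True:
    user_input = input("User: ")
    if not user_input:
        break
    messages.append({"role": "user", "content": user_input})
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # The cache object is passed in on every turn -- the iterative-prompt case this PR fixes.
    gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=256, past_key_values=cache)
    answer = tokenizer.decode(gen_out[0, input_ids.shape[1]:], skip_special_tokens=True)
    messages.append({"role": "assistant", "content": answer})
    print(f"Assistant: {answer}")
```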