class Cache must not be a subclass of torch.nn.Module
#32681
Comments
Hey! We should probably mention the fact that deepcopy can probably only work for
@jiwoong-choi the previous code was just not usable for static cache otherwise, so maybe only the static cache should be a module.
I found the reason why the copy wasn't working: tensors with grad are not leaf tensors, so the forward pass should be done within torch.no_grad():
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# init a DynamicCache to hold the common prompt
prompt_cache = DynamicCache()
INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")

# This is the common prompt to be cached; we need to run the forward pass without grad to be able to copy the cache
with torch.no_grad():
    prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values

prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
responses = []
for prompt in prompts:
    new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
    past_key_values = copy.deepcopy(prompt_cache)
    outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20)
    response = tokenizer.batch_decode(outputs)[0]
    responses.append(response)

print(responses)
Shouldn't creating the cache be done without grad (in generate) as well?
@ArthurZucker Yes, but in the example code we run a forward pass to get the initial cache, and that is usually run with grad by default.
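For context, here is a minimal standalone illustration (my own, not taken from the thread) of the PyTorch behavior under discussion: tensors produced by operations while grad is enabled are non-leaf tensors and copy.deepcopy refuses to copy them, whereas tensors produced under torch.no_grad() copy fine.
import copy
import torch

leaf = torch.ones(2, requires_grad=True)   # a leaf tensor, created directly by the user
non_leaf = leaf * 2                        # produced by an op while grad is enabled -> non-leaf

copy.deepcopy(leaf)                        # fine: leaf tensors support the deepcopy protocol
try:
    copy.deepcopy(non_leaf)
except RuntimeError as err:
    # "Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol ..."
    print(err)

with torch.no_grad():
    no_grad_result = leaf * 2              # computed without grad, so it is not part of an autograd graph
copy.deepcopy(no_grad_result)              # fine again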
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I think our llama-recipes need an update for this!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Quick ping @gante @zucchini-nlp - does anything need to be done here?
Will be resolved for all cache classes in #33297
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
I'm using transformers==4.44.0.
Who can help?
People who have been involved in #32168
@gante @guangy10 @amyeroberts @ArthurZucker
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Steps to reproduce
1. Set up a fresh environment with transformers==4.44.0 installed (for example, if you don't mind using conda); a quick version check is sketched below.
2. Save the example code from the "Torch export for static cache" section of the transformers v4.44.0 release page. (Say this code is saved as example.py in your working directory.)
3. Run python example.py.
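As a quick sanity check (my addition, not part of the original steps), you can confirm that the environment really has the affected version installed:
# verify the installed transformers version before running example.py
import transformers
print(transformers.__version__)  # expected: 4.44.0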
The error message
You'll be able to see the error message, basically complaining that the example code is calling copy.deepcopy on the torch.nn.Module instance prompt_cache.
Comments
This is due to the change made in #32168, where the class Cache has become a subclass of torch.nn.Module. See the comment that I wrote in this PR.
Considering that the class Cache (and its subclasses, such as DynamicCache) represents the KV cache generated by a model (which is a torch.nn.Module object itself), it is not natural to define Cache as a subclass of torch.nn.Module.
It looks like the purpose was to enable copy.deepcopy for Cache objects, but apparently, PyTorch 2.4 won't allow it.
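As a side note, here is a minimal workaround sketch of my own (not something proposed in the issue), assuming the DynamicCache API with its key_cache/value_cache lists and update() method: rebuild a copy of the prompt cache from detached tensors so that the copy never touches non-leaf tensors tracked by autograd.
# Hedged workaround sketch: copy a DynamicCache by detaching and cloning each cached key/value tensor
import torch
from transformers import DynamicCache

def copy_prompt_cache(cache: DynamicCache) -> DynamicCache:
    new_cache = DynamicCache()
    for layer_idx, (keys, values) in enumerate(zip(cache.key_cache, cache.value_cache)):
        # detach().clone() yields leaf tensors that are safe to store and reuse across generate() calls
        new_cache.update(keys.detach().clone(), values.detach().clone(), layer_idx)
    return new_cache
With such a helper, the loop in the example above could call past_key_values = copy_prompt_cache(prompt_cache) instead of copy.deepcopy(prompt_cache).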
Expected behavior
The example code from the release page runs without the error, demonstrating support for prompt reuse introduced in transformers-4.44.0.