class Cache must not be a subclass of torch.nn.Module #32681

Closed
jiwoong-choi opened this issue Aug 14, 2024 · 11 comments

@jiwoong-choi
Contributor

jiwoong-choi commented Aug 14, 2024

System Info

I'm using transformers==4.44.0.

  • The script that I used for collecting my system info is as follows:
$ curl -OL https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
$ python3 collect_env.py
  • The collected system info is as follows:
Collecting environment information...
PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000

Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      43 bits physical, 48 bits virtual
CPU(s):                             32
On-line CPU(s) list:                0-31
Thread(s) per core:                 2
Core(s) per socket:                 16
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          AuthenticAMD
CPU family:                         23
Model:                              49
Model name:                         AMD Ryzen Threadripper PRO 3955WX 16-Cores
Stepping:                           0
Frequency boost:                    enabled
CPU MHz:                            2082.992
CPU max MHz:                        4402.7339
CPU min MHz:                        2200.0000
BogoMIPS:                           7785.71
Virtualization:                     AMD-V
L1d cache:                          512 KiB
L1i cache:                          512 KiB
L2 cache:                           8 MiB
L3 cache:                           64 MiB
NUMA node0 CPU(s):                  0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] mypy==1.11.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorchvideo==0.1.5
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] triton==3.0.0
[conda] Could not collect

Who can help?

People who have been involved in #32168
@gante @guangy10 @amyeroberts @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce

  1. Create an environment with Python 3.10, pytorch-2.4.0, and transformers-4.44.0 installed.
    For example, if you don't mind using conda:
conda create -n repro python=3.10
conda activate repro
conda install pytorch==2.4.0 pytorch-cuda=12.4 -c pytorch -c nvidia
pip install transformers==4.44.0
  2. Copy and paste the example code under the Torch export for static cache section of the transformers v4.44.0 release page (assume it is saved as example.py in your working directory). The code is as follows:
import os, torch, copy
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
device = "cuda"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"

INITIAL_PROMPT = "From now on, you are going to answer all my questions with historical details. Make sure to always add a bit of french here and there, for style."

model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

prompt_cache = DynamicCache()
inputs = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
prompt_cache = model(**inputs, past_key_values=prompt_cache).past_key_values

prompt = "Why are french people obsessed with french?"
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
past_key_values = copy.deepcopy(prompt_cache)
outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20)
response = tokenizer.batch_decode(outputs)[0]
print(response)

prompt = "What is the best city to swim in?"
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**new_inputs, past_key_values=copy.deepcopy(prompt_cache), max_new_tokens=20)
response = tokenizer.batch_decode(outputs)[0]
print(response)
  3. Run the code: python example.py

The error message

You should see the error below, which complains that the example code calls copy.deepcopy on prompt_cache, a torch.nn.Module instance.

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 12.45it/s]
Traceback (most recent call last):
  File "/path/to/your/working/directory/example.py", line 18, in <module>
    past_key_values = copy.deepcopy(prompt_cache)
  File "/path/to/your/repro/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/path/to/your/repro/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/path/to/your/repro/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/path/to/your/repro/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/path/to/your/repro/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/path/to/your/repro/lib/python3.10/copy.py", line 206, in _deepcopy_list
    append(deepcopy(a, memo))
  File "/path/to/your/repro/lib/python3.10/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/path/to/your/repro/lib/python3.10/site-packages/torch/_tensor.py", line 87, in __deepcopy__
    raise RuntimeError(
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https:/github.com/pytorch/pytorch/pull/103001
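For reference, the same RuntimeError can be triggered without downloading a checkpoint. This is a minimal sketch under stated assumptions: ToyCache is a hypothetical stand-in (not a transformers class) that, like the Cache classes in 4.44, subclasses torch.nn.Module and keeps per-layer tensors in plain Python lists (the traceback above passes through _deepcopy_list, which is consistent with that), and a single nn.Linear plays the role of the model.

```python
import copy

import torch
import torch.nn as nn


class ToyCache(nn.Module):
    """Hypothetical stand-in for a KV cache that subclasses nn.Module.

    Per-layer key/value tensors are kept in plain Python lists.
    """

    def __init__(self):
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    def update(self, key_states, value_states):
        # Append one layer's key/value tensors, as a prefill forward pass would
        self.key_cache.append(key_states)
        self.value_cache.append(value_states)


# Stand-in for the model: its weight and bias require grad by default
proj = nn.Linear(4, 4)
hidden = torch.randn(1, 3, 4)

cache = ToyCache()
out = proj(hidden)  # computed with autograd enabled, so out has a grad_fn
cache.update(out, out)

try:
    copy.deepcopy(cache)
    error = None
except RuntimeError as exc:
    error = str(exc)

print(out.is_leaf)  # False
print(error is not None and "graph leaves" in error)  # True
```

In this sketch the error is raised by Tensor.__deepcopy__ when it reaches a non-leaf tensor inside the list, which is where the traceback above ends as well.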

Comments

This is caused by the change made in #32168, which made the class Cache a subclass of torch.nn.Module (see the comment I wrote on that PR).

Since the class Cache (and its subclasses, such as DynamicCache) represents the KV cache generated by a model, which is itself a torch.nn.Module, it is unnatural to define Cache as a subclass of torch.nn.Module.
It looks like the purpose was to enable copy.deepcopy for Cache objects, but PyTorch 2.4 apparently does not allow it.

Expected behavior

The example code from the release page runs without the error, demonstrating support for prompt reuse introduced in transformers-4.44.0.

@ArthurZucker
Collaborator

Hey! We should probably mention that deepcopy can likely only work for StaticCache for now, as DynamicCache creates new leaves 🫠 cc @gante: we could potentially change the way we init the dynamic cache so that it already holds empty tensors, and just replace them.

@ArthurZucker
Collaborator

@jiwoong-choi the previous code was otherwise just not usable for static cache, so maybe only StaticCache should be a module.
Also, cf. our internal talks about testing: we need this to be tested!

@zucchini-nlp
Member

I found why the copy wasn't working: tensors computed with grad enabled are not leaf tensors, so the forward pass should be run under no_grad. With that change, it worked for me with both dynamic and static cache:

import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# DynamicCache is used here; a StaticCache initialized with a big enough max length (1024 tokens for the example below) also works
prompt_cache = DynamicCache()
INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
# Cache the common prompt. The forward pass must run without grad so that the cache can be deep-copied
with torch.no_grad():
    prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values

prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
responses = []
for prompt in prompts:
    new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
    past_key_values = copy.deepcopy(prompt_cache)
    outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20)
    response = tokenizer.batch_decode(outputs)[0]
    responses.append(response)

print(responses)
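A torch-only sketch of the fix described above, under stated assumptions: ToyCache is a hypothetical stand-in (not a transformers class) that subclasses torch.nn.Module and keeps tensors in Python lists, and an nn.Linear plays the role of the model. Running the prefill forward pass under torch.no_grad() leaves the cached tensors as graph leaves, so each follow-up prompt can deep-copy the prompt cache and modify its own copy without touching the original.

```python
import copy

import torch
import torch.nn as nn


class ToyCache(nn.Module):
    """Hypothetical stand-in for a KV cache: an nn.Module holding tensors in lists."""

    def __init__(self):
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    def update(self, key_states, value_states):
        self.key_cache.append(key_states)
        self.value_cache.append(value_states)


proj = nn.Linear(4, 4)  # stand-in for the model
hidden = torch.randn(1, 3, 4)

prompt_cache = ToyCache()
with torch.no_grad():  # run the prefill forward pass without recording a graph
    out = proj(hidden)
prompt_cache.update(out, out)
original = prompt_cache.key_cache[0].clone()

# One independent copy per follow-up prompt, as in the loop above
copies = [copy.deepcopy(prompt_cache) for _ in range(2)]

# Mutating a copy in place leaves the original cache intact
copies[0].key_cache[0].zero_()

print(out.is_leaf)  # True
print(torch.equal(prompt_cache.key_cache[0], original))  # True
```

ToyCache still subclasses torch.nn.Module here; in this sketch, what decides whether deepcopy succeeds is whether the cached tensors are graph leaves, not the Module base class.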

@ArthurZucker
Collaborator

Shouldn’t creating the cache be done without grad (in generate?) as well?

@zucchini-nlp
Member

@ArthurZucker Yes, but in the example code we run a forward pass to get the initial cache, and that runs with grad enabled by default.
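To illustrate the distinction: transformers runs generate() under torch.no_grad(), so tensors produced inside it are graph leaves, while a bare forward call records a graph whenever the weights require grad. In this sketch an nn.Linear stands in for the model, and the two helper functions are hypothetical. Freezing the weights with requires_grad_(False) is shown as one more option; it is an assumption of this sketch, not something proposed in the thread.

```python
import torch
import torch.nn as nn

# Stand-in for the model: its weight and bias require grad by default
proj = nn.Linear(4, 4)
hidden = torch.randn(1, 3, 4)


@torch.no_grad()
def prefill_like_generate(x):
    # Hypothetical helper decorated the way generate() is: no autograd graph is recorded
    return proj(x)


def prefill_manual(x):
    # Hypothetical helper: a bare forward call, as in the release-page example
    return proj(x)


a = prefill_like_generate(hidden)
b = prefill_manual(hidden)
print(a.requires_grad, a.is_leaf)  # False True
print(b.requires_grad, b.is_leaf)  # True False

# Another option: freeze the weights so forward outputs never require grad
proj.requires_grad_(False)
c = prefill_manual(hidden)
print(c.requires_grad, c.is_leaf)  # False True
```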


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker
Collaborator

I think our llama-recipe needs an update for this!

@Rocketknight1
Member

Quick ping @gante @zucchini-nlp - does anything need to be done here?

@zucchini-nlp
Member

Will be resolved for all cache classes in #33297

@huggingface huggingface deleted a comment from github-actions bot Nov 25, 2024