Make static cache compatible with torch.export (#32168)
guangy10 authored Jul 29, 2024
1 parent 7f5d644 commit 811a9ca
Showing 2 changed files with 80 additions and 10 deletions.
32 changes: 22 additions & 10 deletions src/transformers/cache_utils.py
@@ -23,12 +23,14 @@
 logger = logging.get_logger(__name__)
 
 
-@dataclass
-class Cache:
+class Cache(torch.nn.Module):
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """
 
+    def __init__(self):
+        super().__init__()
+
     def update(
         self,
         key_states: torch.Tensor,
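
Note on the hunk above: making `Cache` a `torch.nn.Module` is what lets subclasses call `register_buffer`, so `torch.export()` can treat the cache tensors as module state and trace in-place updates to them. A minimal toy sketch of that idea (an illustrative assumption, not code from this diff):

import torch
from torch.export import export

class ToyCache(torch.nn.Module):
    # Registered buffers are module state that torch.export can functionalize,
    # even when the forward pass mutates them in place.
    def __init__(self, cache_shape):
        super().__init__()
        self.register_buffer("key_cache_0", torch.zeros(cache_shape))

    def forward(self, cache_position: torch.Tensor, key_states: torch.Tensor) -> torch.Tensor:
        # In-place write along the sequence dimension, similar in spirit to StaticCache.update.
        self.key_cache_0.index_copy_(2, cache_position, key_states)
        return self.key_cache_0.clone()

toy = ToyCache((1, 8, 16, 64))
exported = export(toy, (torch.tensor([0]), torch.ones(1, 8, 1, 64)))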
@@ -299,6 +301,7 @@ class DynamicCache(Cache):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
@@ -461,6 +464,7 @@ class QuantizedCache(DynamicCache):
     """
 
     def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        super().__init__()
         self._quantized_key_cache: List[torch.Tensor] = []
         self._quantized_value_cache: List[torch.Tensor] = []
 
@@ -634,6 +638,7 @@ class SinkCache(Cache):
     """
 
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+        super().__init__()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
@@ -786,7 +791,7 @@ def update(
 
 class StaticCache(Cache):
     """
-    Static Cache class to be used with `torch.compile(model)`.
+    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
 
     Parameters:
         config (`PretrainedConfig):
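
As a hedged usage note (the flags below follow the library's documented static-cache recipe, not this diff), the static cache is typically paired with a compiled forward for generation; after this commit the same cache can also go through `torch.export()`, as the new test demonstrates:

# `model` and `input_ids` are assumed to come from the usual AutoModelForCausalLM /
# AutoTokenizer setup; the compile flags are illustrative.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=20)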
@@ -817,18 +822,22 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
 
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
+        # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-        for _ in range(config.num_hidden_layers):
+        for idx in range(config.num_hidden_layers):
+            # Note: `torch.export()`` requires mutations to be registered as buffers.
+            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            key_cache = getattr(self, f"key_cache_{idx}")
+            value_cache = getattr(self, f"value_cache_{idx}")
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
             # it is not needed anyway)
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(new_layer_key_cache)
-                torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+            torch._dynamo.mark_static_address(key_cache)
+            torch._dynamo.mark_static_address(value_cache)
+            self.key_cache.append(key_cache)
+            self.value_cache.append(value_cache)
 
     def update(
         self,
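
Worth noting about the rewritten loop (an illustration under assumptions, not part of the diff): `getattr` returns the registered buffer object itself, so the `key_cache`/`value_cache` list entries alias the buffers; the per-layer tensors now appear in `named_buffers()`, and in-place updates through the lists mutate module state:

# `config` is assumed to be any decoder config, e.g. AutoConfig.from_pretrained("google/gemma-2b").
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)
print(sorted(name for name, _ in cache.named_buffers()))  # e.g. ['key_cache_0', 'key_cache_1', ..., 'value_cache_0', ...]
cache.key_cache[0][:, :, 0, :] = 1.0
print(bool((cache.key_cache_0[:, :, 0, :] == 1.0).all()))  # True: the list entry and the buffer share storage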
@@ -928,6 +937,7 @@ class SlidingWindowCache(StaticCache):
     """
 
     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
@@ -1005,6 +1015,7 @@ class EncoderDecoderCache(Cache):
     """
 
     def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+        super().__init__()
         self.self_attention_cache = self_attention_cache
         self.cross_attention_cache = cross_attention_cache
 
@@ -1148,6 +1159,7 @@ def batch_select_indices(self, indices: torch.Tensor):
 
 class HybridCache(Cache):
     def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
+        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
58 changes: 58 additions & 0 deletions tests/utils/test_cache_utils.py
@@ -15,12 +15,14 @@
 
 import unittest
 
+from packaging import version
 from parameterized import parameterized
 
 from transformers import set_seed
 from transformers.testing_utils import (
     is_torch_available,
     require_auto_gptq,
+    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -32,6 +34,7 @@
     import torch
 
     from transformers import (
+        AutoConfig,
         AutoModelForCausalLM,
         AutoTokenizer,
         DynamicCache,
@@ -164,6 +167,61 @@ def _random_kvs(config):
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
 
+    @slow
+    @require_read_token
+    def test_static_cache_exportability(self):
+        """
+        Tests that static cache works with `torch.export()`
+        """
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        device = "cpu"
+        dtype = torch.float32
+        max_batch_size = 1
+
+        config = AutoConfig.from_pretrained(
+            "google/gemma-2b",
+            torch_dtype=dtype,
+            use_cache=True,
+        )
+        m = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2b",
+            config=config,
+            torch_dtype=dtype,
+            attn_implementation="sdpa",  # Export and ExecuTorch only works for SdpaAttention
+        ).to(device)
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+        inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
+
+        class ExportatibleModelWithStaticCache(torch.nn.Module):
+            def __init__(self, config, model):
+                super().__init__()
+                self.config = config
+                self.model = model
+                self.static_cache = StaticCache(
+                    config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
+                )
+
+            def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
+                outs = self.model(
+                    input_ids=tokens,
+                    attention_mask=None,
+                    position_ids=input_pos.unsqueeze(0),
+                    cache_position=input_pos,
+                    past_key_values=self.static_cache,
+                    use_cache=True,
+                )
+                return outs.logits
+
+        set_seed(0)
+        with torch.no_grad():
+            from torch.export import ExportedProgram, export
+
+            model = ExportatibleModelWithStaticCache(config, m)
+            exported_program = export(model, args=(inputs,), kwargs={"input_pos": torch.arange(1)})
+            self.assertTrue(isinstance(exported_program, ExportedProgram))
+
 
 @require_torch_gpu
 @slow
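
A possible follow-up to the new test (a sketch; `ExportedProgram.module()` is standard `torch.export` API, and the variable names mirror the test above): the exported program can be re-run as a plain module, provided the argument shapes match those used at export time, since no dynamic shapes were declared.

exported_module = exported_program.module()
logits = exported_module(inputs, input_pos=torch.arange(1))
print(logits.shape)  # (batch_size, sequence_length, vocab_size)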
