Make static cache compatible with torch.export
guangy10 committed Jul 23, 2024
1 parent c85510f commit c995601
Showing 1 changed file with 12 additions and 10 deletions.
src/transformers/cache_utils.py: 22 changes (12 additions & 10 deletions)
@@ -23,7 +23,6 @@
 logger = logging.get_logger(__name__)
 
 
-@dataclass
 class Cache:
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
@@ -785,9 +784,9 @@ def update(
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
 
-class StaticCache(Cache):
+class StaticCache(Cache, torch.nn.Module):
     """
-    Static Cache class to be used with `torch.compile(model)`.
+    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
 
     Parameters:
         config (`PretrainedConfig):
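
The switch to subclassing `torch.nn.Module` is what enables the buffer registration in the hunk below: `register_buffer` is an `nn.Module` method and needs `nn.Module.__init__` to have run first. A dataclass-generated `__init__` would not call it, which plausibly also explains the `@dataclass` removal from the `Cache` base class in the first hunk. A minimal illustration of that requirement (the `Buffered` class is a made-up example, not from this commit):

import torch

class Buffered(torch.nn.Module):
    def __init__(self):
        super().__init__()  # register_buffer raises without this call
        self.register_buffer("state", torch.zeros(4))

m = Buffered()
print(dict(m.named_buffers()))  # {'state': tensor([0., 0., 0., 0.])}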
@@ -819,15 +818,18 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-        for _ in range(config.num_hidden_layers):
+        for idx in range(config.num_hidden_layers):
+            # Note: `torch.export()` requires mutations to be registered as buffers.
+            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            self.register_buffer(f"val_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            key_cache = getattr(self, f"key_cache_{idx}")
+            val_cache = getattr(self, f"val_cache_{idx}")
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+            torch._dynamo.mark_static_address(key_cache)
+            torch._dynamo.mark_static_address(val_cache)
+            self.key_cache.append(key_cache)
+            self.value_cache.append(val_cache)
 
     def update(
         self,
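
Taken together, the change follows a pattern that can be sketched in isolation: hold each layer's K/V cache as a registered buffer, mutate it in place, and `torch.export()` records the writes as buffer mutations in the exported graph. The sketch below is illustrative only — `ToyStaticCache`, `ToyModel`, and all shapes are assumptions, not transformers code — and presumes a PyTorch version where `torch.export.export` supports buffer mutation:

import torch


class ToyStaticCache(torch.nn.Module):
    """Toy stand-in for StaticCache: fixed-size per-layer K/V buffers."""

    def __init__(self, num_layers=2, batch=1, heads=4, max_len=16, head_dim=8):
        super().__init__()
        self.key_cache = []
        self.value_cache = []
        cache_shape = (batch, heads, max_len, head_dim)
        for idx in range(num_layers):
            # As in the diff: mutated tensors must live in registered buffers
            # for torch.export() to trace them as graph state.
            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape))
            self.register_buffer(f"val_cache_{idx}", torch.zeros(cache_shape))
            self.key_cache.append(getattr(self, f"key_cache_{idx}"))
            self.value_cache.append(getattr(self, f"val_cache_{idx}"))

    def update(self, layer_idx, pos, k, v):
        # In-place writes along the sequence dimension (dim 2); these are the
        # mutations torch.export() records as buffer mutations.
        self.key_cache[layer_idx].index_copy_(2, pos, k)
        self.value_cache[layer_idx].index_copy_(2, pos, v)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cache = ToyStaticCache()

    def forward(self, k, v, pos):
        return self.cache.update(0, pos, k, v)


k = torch.randn(1, 4, 1, 8)   # one new token for layer 0
v = torch.randn(1, 4, 1, 8)
pos = torch.tensor([0])       # write position in the cache
exported = torch.export.export(ToyModel(), (k, v, pos))
print(exported.graph_signature.buffers_to_mutate)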
