
Commit

Register the buffers.
LaurentMazare committed Jan 7, 2025
1 parent 5888f38 commit d305f40
Showing 1 changed file with 28 additions and 22 deletions.
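The change moves the streaming KV cache and the step offsets out of the per-stream dataclasses and into tensors registered on the modules with register_buffer, so they keep a fixed storage, follow the module across devices, and show up in state_dict(). The toy module below is not taken from the diff; it only illustrates what buffer registration adds over a plain tensor attribute.

import torch
from torch import nn

# Toy module, not from this repository, contrasting a plain tensor attribute
# with a registered buffer (what this commit switches the cache/offsets to).
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.zeros(1, dtype=torch.long)                     # ordinary attribute
        self.register_buffer("offset", torch.zeros(1, dtype=torch.long))  # tracked by nn.Module

m = Toy()
print(list(m.state_dict()))     # ['offset']: the plain attribute is not saved
print(dict(m.named_buffers()))  # {'offset': tensor([0])}
# m.to("cuda") would also move 'offset' automatically; 'plain' would stay on CPU.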
50 changes: 28 additions & 22 deletions moshi/moshi/modules/transformer.py
@@ -208,7 +208,7 @@ def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
         return KVCacheResult(keys, values, positions)
 
 
-class RingKVCache:
+class RingKVCache(nn.Module):
     """Efficient streaming KVCache to be compatible with Cuda Graph.
 
     Args:
@@ -228,13 +228,14 @@ def __init__(
         device: torch.device = torch.device("cuda"),
         dtype: torch.dtype = torch.bfloat16,
     ):
+        super().__init__()
         self.capacity = capacity
-        self.cache = torch.zeros(
+        self.register_buffer("cache", torch.zeros(
             (2, batch_size, num_heads, capacity, dim_per_head),
             device=device,
             dtype=dtype,
-        )
-        self.end_offset = torch.zeros(1, device=device, dtype=torch.long)
+        ))
+        self.register_buffer("end_offset", torch.zeros(1, device=device, dtype=torch.long))
 
     def reset(self):
         self.end_offset.zero_()
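Because RingKVCache is now an nn.Module whose tensors are registered buffers, assigning an instance to an attribute of the attention module (as done further down with self.kv_cache) registers it as a submodule, so the parent's .to(), .state_dict(), and .named_buffers() all see the cache tensors. A simplified stand-in, with shapes and names invented for illustration:

import torch
from torch import nn

# Simplified stand-in for the pattern above; shapes and names are illustrative only.
class MiniCache(nn.Module):
    def __init__(self, capacity: int, dim: int):
        super().__init__()
        self.register_buffer("cache", torch.zeros(2, capacity, dim))
        self.register_buffer("end_offset", torch.zeros(1, dtype=torch.long))

class MiniAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.kv_cache = MiniCache(capacity=8, dim=4)  # registered as a submodule

attn = MiniAttention()
print([name for name, _ in attn.named_buffers()])
# ['kv_cache.cache', 'kv_cache.end_offset']: moved and saved together with the module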
@@ -280,13 +281,9 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
 
 @dataclass
 class _MHAState:
-    kv_cache: RingKVCache
-    offset: torch.Tensor
     offset_cpu: int
 
     def reset(self):
-        self.kv_cache.reset()
-        self.offset.zero_()
         self.offset_cpu = 0
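With kv_cache and offset gone from _MHAState, the per-step offset lives in a module buffer that is advanced in place (the self.offset.add_(T) further down), so its storage address never changes between steps; that stability is what CUDA graph replay needs. A minimal sketch of the pattern, with invented names:

import torch
from torch import nn

# Minimal sketch with invented names: the offset is a buffer updated in place,
# so the same storage is reused on every step (CUDA-graph friendly).
class StepCounter(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("offset", torch.zeros(1, dtype=torch.long))

    def step(self, num_new_tokens: int):
        self.offset.add_(num_new_tokens)  # in place: same tensor, same data_ptr
        return self.offset

c = StepCounter()
before = c.offset.data_ptr()
c.step(3)
c.step(5)
assert c.offset.data_ptr() == before and int(c.offset) == 8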


@@ -342,6 +339,21 @@ def __init__(
         self.out_proj = nn.Linear(
             embed_dim, mult * embed_dim, bias=False, **factory_kwargs
         )
+        self.register_buffer("offset", torch.zeros(1, device=in_proj.weight.device, dtype=torch.long))
+        dim_per_head = self.embed_dim // self.num_heads
+        dtype = self.in_proj_weight.dtype
+        if self.context is None:
+            if self.weights_per_step:
+                capacity = self.weights_per_step
+            else:
+                raise RuntimeError(
+                    "Cannot create a streaming KVCache without a context to estimate capacity."
+                )
+        else:
+            capacity = self.context
+        self.kv_cache = RingKVCache(
+            1, self.num_heads, dim_per_head, capacity, device, dtype
+        )
 
     def _init_streaming_state(self, batch_size: int) -> _MHAState:
         if self.context is None:
@@ -357,12 +369,7 @@ def _init_streaming_state(self, batch_size: int) -> _MHAState:
         # TODO: the following estimation will not work great with FSDP.
         dtype = self.in_proj_weight.dtype
         dim_per_head = self.embed_dim // self.num_heads
-        kv_cache = RingKVCache(
-            batch_size, self.num_heads, dim_per_head, capacity, device, dtype
-        )
         return _MHAState(
-            kv_cache,
-            offset=torch.zeros(1, device=device, dtype=torch.long),
            offset_cpu=0,
        )

@@ -371,7 +378,7 @@ def _complete_kv(self, k, v) -> KVCacheResult:
         if state is None:
             return KVCacheResult.from_kv(k, v)
         else:
-            return state.kv_cache.complete(k, v)
+            return self.kv_cache.complete(k, v)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
         state = self._streaming_state
@@ -382,7 +389,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
             offset_cpu = 0
         else:
             assert self.causal, "Streaming only available for causal"
-            offset = state.offset
+            offset = self.offset
             offset_cpu = state.offset_cpu
 
         if self.weights_per_step:
@@ -418,7 +425,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
         else:
             x = self.out_proj(x)
         if state is not None:
-            state.offset.add_(T)
+            self.offset.add_(T)
             state.offset_cpu += T
         return x

@@ -591,10 +598,8 @@ def forward(self, x: torch.Tensor):
 
 @dataclass
 class _TransformerState:
-    offset: torch.Tensor
-
     def reset(self):
-        self.offset.zero_()
+        pass
 
 
 class StreamingTransformer(StreamingModule[_TransformerState]):
@@ -663,10 +668,11 @@ def __init__(
                     **kwargs,
                 )
             )
+        self.register_buffer("offset", torch.zeros(1, device=device, dtype=torch.long))
 
     def _init_streaming_state(self, batch_size: int) -> _TransformerState:
         device = next(self.parameters()).device
-        return _TransformerState(offset=torch.zeros(1, device=device, dtype=torch.long))
+        return _TransformerState()
 
     def forward(self, x: torch.Tensor, *args, **kwargs):
         B, T, C = x.shape
@@ -675,7 +681,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs):
         if state is None:
             offset = torch.zeros(1, dtype=torch.long, device=x.device)
         else:
-            offset = state.offset
+            offset = self.offset
 
         if self.positional_embedding in {"sin", "sin_rope"}:
             positions = torch.arange(T, device=x.device).view(1, -1, 1)
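In streaming mode the registered offset buffer supplies the running position count and is advanced in place at the end of each call (the add_(T) in the last hunk). A stripped-down sketch of that bookkeeping, independent of the repository's exact reshaping of positions:

import torch

# Stripped-down sketch of the running-offset bookkeeping; the real code also
# moves positions to the right device and reshapes them for the embedding.
offset = torch.zeros(1, dtype=torch.long)  # the registered buffer
for T in (4, 4, 2):                        # frames processed per streaming step
    positions = torch.arange(T) + offset   # absolute positions for this step
    offset.add_(T)                         # advance in place for the next step
print(offset)                              # tensor([10])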
@@ -689,7 +695,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs):
             x = layer(x, *args, **kwargs)
 
         if state is not None:
-            state.offset.add_(T)
+            self.offset.add_(T)
         return x
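One practical consequence of registering these buffers: they become part of state_dict(), so a strict load of a state dict saved before this commit would report the new keys as missing. A self-contained illustration with stand-in modules, not the repository's classes:

import torch
from torch import nn

# Stand-in modules (not the repository's classes) showing the checkpoint effect
# of newly registered buffers: older state_dicts lack the new keys.
class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class New(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.register_buffer("offset", torch.zeros(1, dtype=torch.long))

missing, unexpected = New().load_state_dict(Old().state_dict(), strict=False)
print(missing)  # ['offset']: keys added by buffer registration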


