Add external KV to LLaMA 3 #734

Open · wants to merge 1 commit into base: master
train_llama3.py: 49 changes (29 additions & 20 deletions)
@@ -22,11 +22,9 @@
import inspect
from contextlib import nullcontext
from dataclasses import dataclass
import json
from pathlib import Path
from typing import (
AbstractSet,
Callable,
Collection,
Dict,
Iterator,
@@ -58,6 +56,25 @@
# using a global to toggle flash-attention
FLASH = 0

class KVCache:
def __init__(self, num_layers, max_batch_size, block_size, n_kv_head, hd, device, dtype):
self.kv = [KV(max_batch_size, block_size, n_kv_head, hd, device, dtype) for _ in range(num_layers)]

def clear(self):
self.kv = []

class KV:
    # Static KV cache: we preallocate the memory for the key and value tensors
def __init__(self, max_batch_size, block_size, n_kv_head, hd, device, dtype):
self.k = torch.zeros((max_batch_size, block_size, n_kv_head, hd), dtype=dtype).to(device)
self.v = torch.zeros((max_batch_size, block_size, n_kv_head, hd), dtype=dtype).to(device)

def add(self, k, v, B, T, start_pos):
assert B == k.shape[0] and T == k.shape[1] and self.k.shape[2] == k.shape[2] and self.k.shape[3] == k.shape[3]
self.k[:B, start_pos : start_pos + T] = k
self.v[:B, start_pos : start_pos + T] = v
return self.k[:B, : start_pos + T], self.v[:B, : start_pos + T]

# Used in Grouped Query Attention (GQA): broadcasts the key and value tensors
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
@@ -156,17 +173,11 @@ def __init__(self, config):
self.n_kv_head = config.n_kv_head
self.n_rep = self.n_head // self.n_kv_head
self.hd = config.n_embd // config.n_head
self.use_kv = config.use_kv

self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection

# static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed
if self.use_kv:
self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))

def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
def forward(self, x, freqs_cis=None, kv_cache=None, start_pos=None, mask=None):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
qkv = self.c_attn(x)
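
Note (hypothetical reconstruction, not part of this diff): the lines collapsed after this point split qkv by head count before RoPE is applied; with GQA the query keeps n_head heads while key and value get n_kv_head heads each, roughly:

        # illustrative only; the exact collapsed code may differ
        q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1)
        q = q.view(B, T, self.n_head, self.hd)
        k = k.view(B, T, self.n_kv_head, self.hd)
        v = v.view(B, T, self.n_kv_head, self.hd)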
@@ -175,11 +186,8 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None):

q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2

if self.use_kv and not self.training and start_pos >= 0: # use kv-caching during inference
self.cache_k[:B, start_pos : start_pos + T] = k
self.cache_v[:B, start_pos : start_pos + T] = v
k = self.cache_k[:B, : start_pos + T]
v = self.cache_v[:B, : start_pos + T]
if kv_cache and not self.training: # use kv-caching during inference
k, v = kv_cache.add(k, v, B, T, start_pos)

k = repeat_kv(k, self.n_rep) # GQA <-- 2. difference compared to GPT-2
v = repeat_kv(v, self.n_rep)
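
Note (toy example, not part of this diff): KV.add writes the incoming keys/values at start_pos and returns everything cached up to start_pos + T, so a prefill followed by a single-token decode step behaves roughly as below (sizes are made up; KV is the class added above):

import torch

cache = KV(max_batch_size=4, block_size=64, n_kv_head=8, hd=16, device="cpu", dtype=torch.float32)
B = 2
k0, v0 = torch.randn(B, 5, 8, 16), torch.randn(B, 5, 8, 16)
k, v = cache.add(k0, v0, B, 5, start_pos=0)  # prefill: k, v have shape [2, 5, 8, 16]
k1, v1 = torch.randn(B, 1, 8, 16), torch.randn(B, 1, 8, 16)
k, v = cache.add(k1, v1, B, 1, start_pos=5)  # decode step: k, v have shape [2, 6, 8, 16]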
@@ -233,8 +241,8 @@ def __init__(self, config):
self.ln_2 = RMSNorm(config.n_embd, config.norm_eps)
self.mlp = MLP(config)

def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask)
def forward(self, x, freqs_cis=None, kv_cache=None, start_pos=None, mask=None):
x = x + self.attn(self.ln_1(x), freqs_cis, kv_cache, start_pos, mask)
x = x + self.mlp(self.ln_2(x))
return x

@@ -256,7 +264,6 @@ class LlamaConfig:
rope_theta: float = 500000.0
use_scaled_rope: bool = True
max_gen_batch_size: int = 4
use_kv: bool = True

def __init__(self, **kwargs):
for k, v in kwargs.items():
@@ -290,7 +297,7 @@ def __init__(self, config):
config.use_scaled_rope,
)

def forward(self, idx, targets=None, return_logits=True, start_pos=0):
def forward(self, idx, targets=None, return_logits=True, kv_cache=None, start_pos=0):
_, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

@@ -301,7 +308,7 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0):
mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1)

for i, block in enumerate(self.transformer.h):
x = block(x, freqs_cis, start_pos, mask)
x = block(x, freqs_cis, kv_cache.kv[i] if kv_cache is not None else None, start_pos, mask)
x = self.transformer.ln_f(x)

if targets is not None:
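
Note (illustration, not part of this diff): with the cache owned by the caller, the training path is untouched (kv_cache defaults to None, so every block receives None and skips caching), while inference threads one KV entry per layer through the loop above. Hypothetical call sites:

# training step: attention behaves exactly as before
logits, loss = model(idx, targets=targets)
# inference prefill: block i receives kv_cache.kv[i]
logits, _ = model(prompt_tokens, kv_cache=kv_cache, start_pos=0)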
@@ -530,8 +537,10 @@ def generate(

stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device)

dtype = self.transformer.wte.weight.dtype
kv_cache = KVCache(self.config.n_layer, bsz, self.config.block_size, self.config.n_kv_head, self.config.n_embd // self.config.n_head, device, dtype)
for cur_pos in range(min_prompt_len, total_len):
logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos)
logits, _ = self.forward(tokens[:, prev_pos:cur_pos], kv_cache=kv_cache, start_pos=prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
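
Note (back-of-the-envelope, not part of this diff): KVCache preallocates full block_size buffers for every layer, so the cache can dominate inference memory. With LLaMA 3 8B-style dimensions (assumed here: 32 layers, 8 KV heads, head dim 128, block_size 8192), the default max_gen_batch_size of 4, and 2 bytes per element (bf16, assumed from wte.weight.dtype):

n_layer, bsz, block_size, n_kv_head, hd, bytes_per_elem = 32, 4, 8192, 8, 128, 2
cache_bytes = 2 * n_layer * bsz * block_size * n_kv_head * hd * bytes_per_elem  # k and v buffers
print(cache_bytes / 2**30, "GiB")  # -> 4.0 GiB

This is presumably why KVCache.clear() is provided: dropping the per-layer buffers after generation lets that memory be reclaimed.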