Adding cached KVs (#266)
Co-authored-by: Carlos Mocholí <[email protected]>
gkroiz and carmocca authored May 22, 2023
1 parent 206f316 commit a24fc5e
Showing 10 changed files with 339 additions and 221 deletions.
44 changes: 22 additions & 22 deletions generate.py
@@ -17,10 +17,11 @@

 @torch.no_grad()
 def generate(
-    model: torch.nn.Module,
+    model: LLaMA,
     idx: torch.Tensor,
     max_new_tokens: int,
-    max_seq_length: int,
+    *,
+    max_seq_length: Optional[int] = None,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     eos_id: Optional[int] = None,
@@ -41,44 +42,49 @@ def generate(
-    # create an empty tensor of the expected final shape and fill in the current tokens
     T = idx.size(0)
     T_new = T + max_new_tokens
-    empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
+    if max_seq_length is None:
+        max_seq_length = min(T_new, model.config.block_size)
+
+    device, dtype = idx.device, idx.dtype
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    empty = torch.empty(T_new, dtype=dtype, device=device)
     empty[:T] = idx
     idx = empty
+    input_pos = torch.arange(0, T, device=device)

     if idx.device.type == "xla":
         import torch_xla.core.xla_model as xm

         xm.mark_step()

     # generate max_new_tokens tokens
-    for t in range(T, T_new):
-        # ignore the not-filled-yet tokens
-        idx_cond = idx[:t]
-        # if the sequence context is growing too long we must crop it at max_seq_length
-        idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
+    for _ in range(max_new_tokens):
+        x = idx.index_select(0, input_pos).view(1, -1)

         # forward
-        logits = model(idx_cond.view(1, -1))
+        logits = model(x, max_seq_length, input_pos)
         logits = logits[0, -1] / temperature

         # optionally crop the logits to only the top k options
         if top_k is not None:
             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-            logits[logits < v[[-1]]] = -float("Inf")
+            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

         probs = torch.nn.functional.softmax(logits, dim=-1)
-        idx_next = torch.multinomial(probs, num_samples=1)
+        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

+        # advance
+        input_pos = input_pos[-1:] + 1
+
         if idx.device.type == "xla":
             xm.mark_step()

         # concatenate the new generation
         # https://github.com/pytorch/pytorch/issues/101936
-        idx[t] = idx_next.item() if idx.device.type == "mps" else idx_next
+        idx = idx.index_copy(0, input_pos, idx_next)

         # if <eos> token is triggered, return the output (stop generation)
         if idx_next == eos_id:
-            return idx[:t + 1]  # include the EOS token
+            return idx[:input_pos]  # include the EOS token

     return idx

@@ -138,16 +144,10 @@ def main(
     L.seed_everything(1234)
     for i in range(num_samples):
         t0 = time.perf_counter()
-        y = generate(
-            model,
-            encoded,
-            max_new_tokens,
-            model.config.block_size,  # type: ignore[union-attr,arg-type]
-            temperature=temperature,
-            top_k=top_k,
-        )
+        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
         t = time.perf_counter() - t0

+        model.reset_cache()
         print(tokenizer.decode(y))
         tokens_generated = y.size(0) - prompt_length
         print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
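
The heart of the change is visible in the new loop above: instead of re-running the model over the whole growing sequence (idx[:t]), each step feeds only the positions in input_pos, relies on the model's cached keys and values for everything already processed, and writes the sampled token back into a preallocated buffer with index_copy. The following is a self-contained toy sketch of that indexing pattern; the "model" is just a random-logits stand-in (so there is no real cache here), and names such as toy_decode and logits_fn are illustrative, not part of lit-llama:

import torch

def toy_decode(logits_fn, prompt: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    """Toy version of the position-indexed decode loop used with a kv cache."""
    T = prompt.size(0)
    T_new = T + max_new_tokens
    device, dtype = prompt.device, prompt.dtype

    # preallocate the output buffer; positions past the prompt hold garbage
    # until written, but they are never selected before that
    idx = torch.empty(T_new, dtype=dtype, device=device)
    idx[:T] = prompt
    # positions to feed on the next forward pass (the whole prompt at first)
    input_pos = torch.arange(0, T, device=device)

    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)
        logits = logits_fn(x, input_pos)  # stand-in for model(x, max_seq_length, input_pos)
        probs = torch.softmax(logits[0, -1], dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # after the first step, only the single new position is fed to the model
        input_pos = input_pos[-1:] + 1
        idx = idx.index_copy(0, input_pos, idx_next)

    return idx

# usage with a dummy "model" that returns random logits over a 100-token vocab
dummy = lambda x, input_pos: torch.randn(1, x.size(1), 100)
out = toy_decode(dummy, prompt=torch.tensor([1, 2, 3]), max_new_tokens=5)
print(out)

In the real script the stand-in is replaced by model(x, max_seq_length, input_pos), and model.reset_cache() is called between samples so one prompt's cache does not leak into the next.
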
11 changes: 2 additions & 9 deletions generate/adapter.py
@@ -82,17 +82,10 @@ def main(
     prompt_length = encoded.size(0)

     t0 = time.perf_counter()
-    y = generate(
-        model,
-        idx=encoded,
-        max_seq_length=max_new_tokens,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_k=top_k,
-        eos_id=tokenizer.eos_id
-    )
+    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
     t = time.perf_counter() - t0

+    model.reset_cache()
     output = tokenizer.decode(y)
     output = output.split("### Response:")[1].strip()
     print(output)
12 changes: 2 additions & 10 deletions generate/adapter_v2.py
@@ -6,7 +6,6 @@

 import lightning as L
 import torch
-import torch.nn as nn

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -85,17 +84,10 @@ def main(
     prompt_length = encoded.size(0)

     t0 = time.perf_counter()
-    y = generate(
-        model,
-        idx=encoded,
-        max_seq_length=max_new_tokens,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_k=top_k,
-        eos_id=tokenizer.eos_id
-    )
+    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
     t = time.perf_counter() - t0

+    model.reset_cache()
     output = tokenizer.decode(y)
     output = output.split("### Response:")[1].strip()
     print(output)
74 changes: 7 additions & 67 deletions generate/full.py
@@ -7,68 +7,14 @@
 import lightning as L
 import torch

 # support running without installing as a package
 wd = Path(__file__).absolute().parent.parent
 sys.path.append(str(wd))

 from lit_llama import LLaMA, Tokenizer
 from lit_llama.utils import EmptyInitOnDevice
 from scripts.prepare_alpaca import generate_prompt

-@torch.no_grad()
-def generate(
-    model: torch.nn.Module,
-    idx: torch.Tensor,
-    max_new_tokens: int,
-    max_seq_length: int,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None,
-    eos_id: Optional[int] = None,
-) -> torch.Tensor:
-    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-
-    The implementation of this function is modified from A. Karpathy's nanoGPT.
-
-    Args:
-        model: The model to use.
-        idx: Tensor of shape (T) with indices of the prompt sequence.
-        max_new_tokens: The number of new tokens to generate.
-        max_seq_length: The maximum sequence length allowed.
-        temperature: Scales the predicted logits by 1 / temperature
-        top_k: If specified, only sample among the tokens with the k highest probabilities
-        eos_id: If specified, stop generating any more token once the <eos> token is triggered
-    """
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    T = idx.size(0)
-    T_new = T + max_new_tokens
-    empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
-    empty[:T] = idx
-    idx = empty
-
-    # generate max_new_tokens tokens
-    for t in range(T, T_new):
-        # ignore the not-filled-yet tokens
-        idx_cond = idx[:t]
-        # if the sequence context is growing too long we must crop it at max_seq_length
-        idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
-
-        # forward
-        logits = model(idx_cond.view(1, -1))
-        logits = logits[0, -1] / temperature
-
-        # optionally crop the logits to only the top k options
-        if top_k is not None:
-            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-            logits[logits < v[[-1]]] = -float("Inf")
-
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        idx_next = torch.multinomial(probs, num_samples=1)
-
-        # concatenate the new generation
-        # https://github.com/pytorch/pytorch/issues/101936
-        idx[t] = idx_next.item() if idx.device.type == "mps" else idx_next
-
-        # if <eos> token is triggered, return the output (stop generation)
-        if idx_next == eos_id:
-            return idx[:t + 1]  # include the EOS token
-
-    return idx
+from generate import generate


 def main(
@@ -130,16 +76,10 @@ def main(
     L.seed_everything(1234)
     for i in range(num_samples):
         t0 = time.perf_counter()
-        y = generate(
-            model,
-            encoded,
-            max_new_tokens,
-            model.config.block_size,  # type: ignore[union-attr,arg-type]
-            temperature=temperature,
-            top_k=top_k,
-        )
+        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
         t = time.perf_counter() - t0

+        model.reset_cache()
         print(tokenizer.decode(y))
         tokens_generated = y.size(0) - prompt_length
         print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
1 change: 0 additions & 1 deletion generate/lora.py
@@ -98,7 +98,6 @@ def main(
     output = generate(
         model,
         idx=encoded,
-        max_seq_length=max_new_tokens,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
         top_k=top_k,
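
The model-side half of the feature, the actual key/value buffers behind model(x, max_seq_length, input_pos) and model.reset_cache(), lives in changed files whose diffs are not loaded in this excerpt (presumably lit_llama/model.py among others). As a rough, generic sketch of the pattern rather than lit-llama's actual implementation, a cached single-head attention block could preallocate k/v buffers for the full context and index_copy new entries at their absolute positions:

import torch
import torch.nn as nn

class ToyCachedAttention(nn.Module):
    """Generic sketch of a kv-cache-enabled attention block (single head,
    causal masking during prefill omitted) -- illustrative only, not lit-llama's code."""

    def __init__(self, n_embd: int, block_size: int) -> None:
        super().__init__()
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.block_size = block_size  # maximum context length to cache
        self.kv_cache = None          # allocated lazily on the first forward

    def reset_cache(self) -> None:
        # what the scripts' model.reset_cache() calls accomplish: drop state between samples
        self.kv_cache = None

    def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)

        if self.kv_cache is None:
            # preallocate key/value buffers for the whole context once
            shape = (B, self.block_size, C)
            self.kv_cache = (x.new_zeros(shape), x.new_zeros(shape))
        cache_k, cache_v = self.kv_cache

        # write the new keys/values at their absolute positions
        # (mirrors idx.index_copy(0, input_pos, idx_next) in generate.py)
        cache_k = cache_k.index_copy(1, input_pos, k)
        cache_v = cache_v.index_copy(1, input_pos, v)
        self.kv_cache = (cache_k, cache_v)

        # attend over everything cached so far instead of recomputing it
        valid = int(input_pos.max()) + 1
        k, v = cache_k[:, :valid], cache_v[:, :valid]
        att = torch.softmax((q @ k.transpose(-2, -1)) / (C ** 0.5), dim=-1)
        return att @ v

Prefill passes the whole prompt with input_pos = torch.arange(T); every later step passes a single token with input_pos advanced by one, so each decode step runs the model over one position instead of the entire sequence.
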
