[Feat] return hidden states #3364

Open · wants to merge 34 commits into base: main
Changes from 27 commits (34 commits total)

Commits:
54b524d extract hidden states (Jackmin801, Feb 6, 2025)
f5a8a4d include meow.py (Jackmin801, Feb 6, 2025)
e6414cc allow cuda graph runner (Jackmin801, Feb 7, 2025)
73e5305 add return hidden states as engine arg (Jackmin801, Feb 7, 2025)
eb4f93a change meow script (Jackmin801, Feb 7, 2025)
5e9ce35 lint (Jackmin801, Feb 7, 2025)
52dc2cb add cli arg (Jackmin801, Feb 7, 2025)
371fe0e forward from detokenizer (Jackmin801, Feb 7, 2025)
9cb1111 fix: dont error on embedding model (Jackmin801, Feb 7, 2025)
64232cb remove testing script (Jackmin801, Feb 7, 2025)
7c73a30 style (Jackmin801, Feb 7, 2025)
b0e4765 Merge branch 'main' into feat-hidden_states (zhaochenyang20, Feb 7, 2025)
ef315c2 add docs in server args (Jackmin801, Feb 7, 2025)
fc64fdc add example (Jackmin801, Feb 7, 2025)
5ff6edc test: add test (Jackmin801, Feb 7, 2025)
9dfdbff add example to offline engine api (Jackmin801, Feb 7, 2025)
09be3af Revert "add example" (Jackmin801, Feb 7, 2025)
496b572 add comparison to hf [skip ci] (Jackmin801, Feb 7, 2025)
a061616 add 1 decode to test [skip ci] (Jackmin801, Feb 7, 2025)
0288b8f change to meta llama 3.1 8b I (Jackmin801, Feb 8, 2025)
3e49643 add test_hidden_states (Jackmin801, Feb 9, 2025)
c296ed9 Revert "change to meta llama 3.1 8b I" (Jackmin801, Feb 9, 2025)
01afdf8 Revert "add 1 decode to test [skip ci]" (Jackmin801, Feb 9, 2025)
45d64fe Revert "add comparison to hf [skip ci]" (Jackmin801, Feb 9, 2025)
bab6d20 Revert "test: add test" (Jackmin801, Feb 9, 2025)
329e3d0 lint (Jackmin801, Feb 9, 2025)
6efd6eb add to test suite (Jackmin801, Feb 9, 2025)
929cdb9 Merge branch 'main' into feat-hidden_states (zhaochenyang20, Feb 10, 2025)
be03127 fix: only output when return hidden states in server args (Jackmin801, Feb 10, 2025)
0afff9a increase ci timeout (Jackmin801, Feb 10, 2025)
6d5210e Merge branch 'main' into feat-hidden_states (zhaochenyang20, Feb 10, 2025)
6229a23 frontload failing hidden state test (Jackmin801, Feb 10, 2025)
d8b489f add prints to figure out why fail in CI (Jackmin801, Feb 10, 2025)
dc9b434 clone in tp_model_worker to avoid illegal memory access (Jackmin801, Feb 10, 2025)
53 changes: 53 additions & 0 deletions docs/backend/offline_engine_api.ipynb
@@ -179,6 +179,59 @@
"asyncio.run(main())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Return Hidden States"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"\n",
"llm = sgl.Engine(\n",
" model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Hello, my name is\",\n",
" \"The president of the United States is\",\n",
" \"The capital of France is\",\n",
" \"The future of AI is\",\n",
"]\n",
"\n",
"sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"max_new_tokens\": 10}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print(\"===============================\")\n",
" print(\n",
" f\"Prompt: {prompt}\\nGenerated text: {output['text']}\\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\\tCompletion_tokens: {output['meta_info']['completion_tokens']}\\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}\"\n",
" )\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
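For reference, a minimal sketch (not part of this PR's diff) of stitching the returned chunks back into one tensor. It assumes the layout that test_hidden_states.py below verifies: the first entry covers the whole prompt, and each later entry is a single decode step.

import torch

# `output` is assumed to be one element of the list returned by
# llm.generate(...) with return_hidden_states=True, as in the notebook above.
chunks = output["meta_info"]["hidden_states"]
# First chunk: [prompt_len, hidden_size]; each later chunk: [hidden_size].
full = torch.cat([c.unsqueeze(0) if c.dim() == 1 else c for c in chunks])
# Rows = prompt_tokens + completion_tokens - 1: the last generated token is
# never fed back through the model, so it has no hidden state.
print(full.shape)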
1 change: 1 addition & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -210,6 +210,7 @@ def event_loop(self):
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
output_hidden_states=recv_obj.output_hidden_states,
)
)

4 changes: 4 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]

output_hidden_states: List[List[float]]


@dataclass
class BatchStrOut:
@@ -397,6 +399,8 @@ class BatchStrOut:
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]

output_hidden_states: List[List[float]]


@dataclass
class BatchEmbeddingOut:
18 changes: 15 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
@@ -315,6 +315,7 @@ def __init__(
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None
self.hidden_states = []

# Logprobs (internal values)
# The tokens is prefilled but need to be considered as decode tokens
@@ -604,6 +605,9 @@ class ScheduleBatch:
# Enable custom logit processor
enable_custom_logit_processor: bool = False

# Return hidden states
return_hidden_states: bool = False

@classmethod
def init_new(
cls,
@@ -615,6 +619,7 @@ def init_new(
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -629,6 +634,7 @@ def init_new(
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
)

def batch_size(self):
@@ -1196,9 +1202,15 @@ def get_model_worker_batch(self):
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
capture_hidden_mode=(
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
if self.spec_info
else CaptureHiddenMode.NULL
CaptureHiddenMode.FULL
if self.return_hidden_states
else (
getattr(
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
if self.spec_info
else CaptureHiddenMode.NULL
)
),
)

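The nested conditional above reads more easily flattened. A minimal sketch (hypothetical helper, not part of the PR) of the precedence it encodes: the new flag forces full capture, otherwise any speculative-decoding request wins, otherwise nothing is captured.

# Import path assumed from sglang's layout at the time of this PR.
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode

def resolve_capture_mode(return_hidden_states, spec_info):
    # The new server flag takes priority: capture hidden states everywhere.
    if return_hidden_states:
        return CaptureHiddenMode.FULL
    # Otherwise speculative decoding may have requested its own mode.
    if spec_info is not None:
        return getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
    return CaptureHiddenMode.NULL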
24 changes: 24 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -997,6 +997,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
self.server_args.return_hidden_states,
)
new_batch.prepare_for_extend()

@@ -1156,6 +1157,16 @@ def process_batch_result_prefill(
logits_output.input_token_logprobs.tolist()
)

if logits_output.hidden_states is not None:
logits_output.hidden_states = logits_output.hidden_states.cpu()
offset = 0
# Cuts up the hidden states for each request
logits_output.hidden_states = [
logits_output.hidden_states[
offset : (offset := offset + len(req.origin_input_ids))
]
for req in batch.reqs
]
# Check finish conditions
logprob_pt = 0
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -1182,6 +1193,9 @@
i, req, logprob_pt, next_token_ids, logits_output
)

if logits_output.hidden_states is not None:
req.hidden_states.append(logits_output.hidden_states[i])

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
@@ -1245,6 +1259,8 @@ def process_batch_result_decode(

self.token_to_kv_pool.free_group_begin()

if logits_output.hidden_states is not None:
logits_output.hidden_states = logits_output.hidden_states.cpu()
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_retracted:
@@ -1275,6 +1291,9 @@
logits_output.next_token_top_logprobs_idx[i]
)

if logits_output.hidden_states is not None:
req.hidden_states.append(logits_output.hidden_states[i])

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
@@ -1398,6 +1417,7 @@ def stream_output(
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
hidden_states = []

if return_logprob:
input_token_logprobs_val = []
@@ -1464,6 +1484,8 @@
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)

hidden_states.append(req.hidden_states)

# Send to detokenizer
if rids:
self.send_to_detokenizer.send_pyobj(
@@ -1490,6 +1512,7 @@
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
hidden_states,
)
)
else: # embedding or reward model
@@ -1553,6 +1576,7 @@ def get_idle_batch(self):
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
self.server_args.return_hidden_states,
)
idle_batch.prepare_for_idle()
return idle_batch
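In process_batch_result_prefill above, one packed hidden-states tensor is split into per-request chunks via an assignment expression inside the slice bound. A self-contained sketch of that pattern (illustrative names and sizes only):

import torch

lens = [3, 5]  # stand-in for len(req.origin_input_ids) per request
packed = torch.randn(sum(lens), 4)  # stand-in for [num_tokens, hidden_size]
offset = 0
# (offset := offset + n) advances the cursor while producing the slice end.
chunks = [packed[offset : (offset := offset + n)] for n in lens]
assert [c.shape[0] for c in chunks] == lens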
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -796,6 +796,12 @@ def _handle_batch_output(
}
)

if (
hasattr(recv_obj, "output_hidden_states")
and len(recv_obj.output_hidden_states[i]) > 0
):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
8 changes: 7 additions & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -348,7 +348,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
CaptureHiddenMode.FULL
if self.model_runner.server_args.return_hidden_states
else (
spec_info.capture_hidden_mode
if spec_info
else CaptureHiddenMode.NULL
)
),
)

6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -160,6 +160,7 @@ class ServerArgs:
delete_ckpt_after_loading: bool = False
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
return_hidden_states: bool = False

# Custom logit processor
enable_custom_logit_processor: bool = False
@@ -896,6 +897,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
)
parser.add_argument(
"--return-hidden-states",
action="store_true",
help="Return hidden states in the response.",
)
# Function Calling
parser.add_argument(
"--tool-call-parser",
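With the argument registered above, the feature can also be enabled at launch time. A usage sketch, assuming the standard sglang server entry point is unchanged by this PR:

python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --return-hidden-states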
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -53,6 +53,7 @@
"test_vision_openai_server.py",
"test_w8a8_quantization.py",
"test_fp8_kernel.py",
"test_hidden_states.py",
],
"nightly": [
"test_nightly_gsm8k_eval.py",
72 changes: 72 additions & 0 deletions test/srt/test_hidden_states.py
@@ -0,0 +1,72 @@
"""
Usage:
python3 test_hidden_states.py
"""
[Collaborator review comment on lines +1 to +4: "delete this."]


import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import sglang as sgl


class TestHiddenState(unittest.TestCase):
def test_return_hidden_states(self):
prompts = ["Today is", "Today is a sunny day and I like"]
model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = tokenizer(prompts).input_ids

sampling_params = {"temperature": 0, "max_new_tokens": 8}

engine = sgl.Engine(
model_path=model_path,
random_seed=42,
return_hidden_states=True,
skip_tokenizer_init=True,
)
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
engine.shutdown()

for output in outputs:
self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
for hidden_state in output["meta_info"]["hidden_states"]:
self.assertIsInstance(hidden_state, torch.Tensor)
# Checks that splicing of the batch was done correctly
self.assertGreater(
outputs[1]["meta_info"]["hidden_states"][0].shape[0],
outputs[0]["meta_info"]["hidden_states"][0].shape[0],
)

model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.bfloat16, device_map="cuda"
)

for input_id, output in zip(input_ids, outputs):
with torch.inference_mode():
hf_out = model(
torch.tensor(
[input_id + output["token_ids"][:-1]], device=model.device
),
output_hidden_states=True,
)
sg_hidden_states = torch.cat(
[
i.unsqueeze(0) if len(i.shape) == 1 else i
for i in output["meta_info"]["hidden_states"]
]
).to("cuda")

self.assertTrue(
torch.allclose(
hf_out["hidden_states"][-1][0],
sg_hidden_states.to("cuda"),
atol=4e-1,
rtol=0,
)
)


if __name__ == "__main__":
unittest.main()