
[Feat] return hidden states #3364

Open · wants to merge 30 commits into main

Conversation

@Jackmin801 commented Feb 7, 2025

Motivation

This PR adds a return_hidden_states argument to ServerArgs, which makes the results contain the last-layer hidden states in output["meta_info"]["hidden_states"].
These hidden states are useful, for example, for verifying computations (e.g. https://arxiv.org/abs/2501.16007).

Modifications

  • Add return_hidden_states to ServerArgs
  • Change the logic that determines capture_hidden_mode to accommodate return_hidden_states
  • Modify the scheduler's process_batch_results to save the hidden states to the Req
  • Add return_hidden_states and hidden_states to the necessary dataclasses
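The capture-mode selection in the second bullet can be sketched as a small standalone function (a toy illustration, not SGLang's code: the real CaptureHiddenMode enum and spec_info object live in SGLang, and here spec_info is simplified to the mode value itself):

```python
# Toy sketch of the capture_hidden_mode selection described above.
# CaptureHiddenMode and spec_info are simplified stand-ins for SGLang's types.
from enum import Enum, auto

class CaptureHiddenMode(Enum):
    NULL = auto()   # do not capture hidden states
    FULL = auto()   # capture hidden states for all positions

def resolve_capture_hidden_mode(return_hidden_states, spec_info=None):
    # The server-level flag takes precedence; otherwise defer to the
    # speculative-decoding info, falling back to NULL.
    if return_hidden_states:
        return CaptureHiddenMode.FULL
    if spec_info is not None:
        return spec_info
    return CaptureHiddenMode.NULL
```

This mirrors the precedence order in the diff discussed below: the server argument wins, speculative decoding comes next, and NULL is the default.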

Script used to test changes

# launch the offline engine
from transformers import AutoTokenizer

import sglang as sgl

def main():
    MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    llm = sgl.Engine(
        model_path=MODEL_NAME,
        skip_tokenizer_init=True,
        disable_cuda_graph=False,
        return_hidden_states=True,  # enable the new flag so hidden states are returned
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 10}

    input_ids = tokenizer(prompts).input_ids
    # skip_tokenizer_init=True, so pass token ids rather than raw text
    outputs = llm.generate(input_ids=input_ids, sampling_params=sampling_params)
    for input_id, output in zip(input_ids, outputs):
        print("===============================")
        print(input_id)
        print(output)
        print()
        if "token_ids" in output:
            print(input_id, output["token_ids"], len(input_id), len(output["token_ids"]))
        else:
            print(output["text"])
        if "hidden_states" in output["meta_info"]:
            print(
                [i.shape for i in output["meta_info"]["hidden_states"]],
                len(output["meta_info"]["hidden_states"]),
            )

if __name__ == "__main__":
    main()

Checklist

  • Format your code according to the Code Formatting with Pre-Commit.
  • Add unit tests as outlined in the Running Unit Tests.
  • Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.

@zhaochenyang20 (Collaborator)

This is good to see. But could you update our docs to demonstrate the usage and add unit tests for your feature?

docs/backend/server_arguments.md (outdated)
test/srt/test_srt_engine.py (outdated)
@zhaochenyang20 (Collaborator)

Thanks. I will try to get someone familiar with hidden states to help.

@zhaochenyang20 (Collaborator) left a comment

Great! Also, could you add the API to the server as well, not only the engine, like how we do for update_weights_from_dist? You can refer to the Engine API and Server / HTTPS API docs.

Comment on lines +182 to +193
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
@zhaochenyang20 (Collaborator) commented Feb 8, 2025

Thanks. I think we should not change the docs of the offline API. Instead, we should change this:

https://docs.sglang.ai/backend/native_api.html

Also, I think the best way to do this is not to add a serving argument, but rather to make a new native API instead, just like:

@app.post("/update_weights_from_distributed")

And this:

def update_weights_from_distributed(self, name: str, dtype, shape):

This would be much easier to use and would not require launching a specific engine, which costs a lot of time in the docs CI.

@zhaochenyang20 (Collaborator)

Also, update the beginning of the native API docs.

https://docs.sglang.ai/backend/native_api.html

Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:

/generate (text generation model)

/get_model_info

/get_server_info

/health

/health_generate

/flush_cache

/update_weights

/encode (embedding model)

/classify (reward model)

We mainly use requests to test these APIs in the following examples. You can also use curl.
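For instance, a hypothetical client call to the /generate endpoint with requests might look like the sketch below. The server URL, the payload fields, and where hidden states would surface in the response are assumptions based on this PR's discussion, not the final API:

```python
# Hypothetical sketch of calling the native /generate API with requests.
# Payload field names follow the documented /generate examples; the hidden
# states location in the response is an assumption based on this PR.
import requests

def build_generate_payload(prompt: str, max_new_tokens: int = 10) -> dict:
    # Build the JSON body for a single-prompt /generate request.
    return {
        "text": prompt,
        "sampling_params": {"temperature": 0.8, "max_new_tokens": max_new_tokens},
    }

if __name__ == "__main__":
    # Requires a running SGLang server, e.g. on localhost:30000.
    resp = requests.post(
        "http://localhost:30000/generate",
        json=build_generate_payload("The capital of France is"),
    )
    out = resp.json()
    print(out["text"])
    # If the server were started with return_hidden_states, the hidden
    # states would appear under out["meta_info"]["hidden_states"].
```

The same request can of course be issued with curl by posting the equivalent JSON body.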

@zhaochenyang20 (Collaborator)

You can add a separate test file as test_hidden_state.py, but also add it in https://github.com/sgl-project/sglang/blob/main/test/srt/run_suite.py

Just like test_update_weights_from_disk in the run_suite.

@zhaochenyang20 (Collaborator) left a comment

The test looks good.

Comment on lines +351 to +357
CaptureHiddenMode.FULL
if self.model_runner.server_args.return_hidden_states
else (
spec_info.capture_hidden_mode
if spec_info
else CaptureHiddenMode.NULL
)
@Jackmin801 (Author)

@zhaochenyang20 What I meant is here. It seems it is necessary for capture_hidden_mode to be known at engine init time. Otherwise, the decode CUDA graph will not contain the return-hidden-states logic, and this can't be changed by sampling args.
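The constraint described here can be illustrated with a toy Python analogue of graph capture (an illustration of the general idea only, not SGLang code): once the decode step is "captured", the hidden-state branch is fixed inside the captured callable, and later calls simply replay it.

```python
# Toy analogue (not SGLang code) of why a captured CUDA graph must know the
# capture_hidden_mode at capture time: the flag is baked into the captured
# callable, so per-request sampling args cannot change it afterwards.

def capture_decode_graph(return_hidden_states: bool):
    # "Capture": build the fixed sequence of ops once, with the flag baked in.
    def replay(token: int) -> dict:
        out = {"next_token": token + 1}   # stand-in for one decode step
        if return_hidden_states:          # this branch was fixed at capture
            out["hidden_states"] = [0.0]  # stand-in for copying hidden states
        return out
    return replay

graph_without = capture_decode_graph(return_hidden_states=False)
graph_with = capture_decode_graph(return_hidden_states=True)
```

Replaying graph_without can never produce hidden states, no matter what a later request asks for, which is why the flag has to be a server argument rather than a sampling parameter.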
