Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend FastGen benchmark to use AML endpoints #865

Merged
merged 7 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/inference/mii/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ python run_benchmark.py --tp_size 1 2
```

By default the benchmark runs with DeepSpeed-MII as the backend inference
server. To change the backend to vLLM, provide the `--vllm` flag:
server. To change the backend to vLLM, provide the `--backend vllm` arg:

```bash
python run_benchmark.py --vllm
python run_benchmark.py --backend vllm
mrwyattii marked this conversation as resolved.
Show resolved Hide resolved
```

The run_all.sh script performs benchmarks across various models, client numbers,
Expand Down
8 changes: 4 additions & 4 deletions benchmarks/inference/mii/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
MODELS=(meta-llama/Llama-2-7b-hf meta-llama/Llama-2-13b-hf meta-llama/Llama-2-70b-hf tiiuae/falcon-40B tiiuae/falcon-180B microsoft/phi-2 mistralai/Mixtral-8x7B-v0.1)

for MODEL in ${MODELS[@]}; do
python ./run_benchmark.py --model ${MODEL} --stream
python ./run_benchmark.py --model ${MODEL} --stream --vllm
python ./run_benchmark.py --model ${MODEL} --stream --backend fastgen
python ./run_benchmark.py --model ${MODEL} --stream --backend vllm
done

# Extra runs for Mixtral with non-default settings
python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024
python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --vllm
python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --backend fastgen
python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --backend vllm
6 changes: 4 additions & 2 deletions benchmarks/inference/mii/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def run_benchmark() -> None:
args = parse_args(server_args=True, client_args=True)

for server_args in get_args_product(args, which=SERVER_PARAMS):
start_server(server_args)
if server_args.backend != "aml":
start_server(server_args)

for client_args in get_args_product(server_args, which=CLIENT_PARAMS):
if results_exist(client_args) and not args.overwrite_results:
Expand All @@ -33,7 +34,8 @@ def run_benchmark() -> None:
print_summary(client_args, response_details)
save_json_results(client_args, response_details)

stop_server(server_args)
if server_args.backend != "aml":
stop_server(server_args)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/inference/mii/run_example.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ python ./run_benchmark.py \
--max_ragged_batch_size 768 \
--mean_prompt_length 2600 \
--mean_max_new_tokens 60 \
--stream
--stream \
--backend fastgen \

### Gernerate the plots
python ./src/plot_th_lat.py
Expand Down
194 changes: 113 additions & 81 deletions benchmarks/inference/mii/src/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# DeepSpeed Team

import argparse
import asyncio
import json
import multiprocessing
Expand All @@ -12,18 +13,30 @@
import requests
import threading
import time
from typing import List, Iterable
from typing import List, Iterable, Union

import numpy as np
from transformers import AutoTokenizer

from .postprocess_results import ResponseDetails
from .random_query_generator import RandomQueryGenerator
from .sample_input import all_text
from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS
try:
lekurile marked this conversation as resolved.
Show resolved Hide resolved
from .postprocess_results import ResponseDetails
from .random_query_generator import RandomQueryGenerator
from .sample_input import all_text
from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS
except ImportError:
from postprocess_results import ResponseDetails
from random_query_generator import RandomQueryGenerator
from sample_input import all_text
from utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS


def call_mii(client, input_tokens, max_new_tokens, stream):
def call_fastgen(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
import mii

client = mii.client(args.deployment_name)

output_tokens = []
token_gen_time = []
time_last_token = 0
Expand All @@ -38,7 +51,7 @@ def callback(response):

time_last_token = start_time = time.time()
token_gen_time = []
if stream:
if args.stream:
output_tokens = []
client.generate(
input_tokens, max_new_tokens=max_new_tokens, streaming_fn=callback
Expand All @@ -57,7 +70,12 @@ def callback(response):
)


def call_vllm(input_tokens, max_new_tokens, stream=True):
def call_vllm(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
if not args.stream:
raise NotImplementedError("Not implemented for non-streaming")

api_url = "http://localhost:26500/generate"
headers = {"User-Agent": "Benchmark Client"}
pload = {
Expand All @@ -68,7 +86,7 @@ def call_vllm(input_tokens, max_new_tokens, stream=True):
"top_p": 0.9,
"max_tokens": max_new_tokens,
"ignore_eos": False,
"stream": stream,
"stream": args.stream,
}

def clear_line(n: int = 1) -> None:
Expand All @@ -90,76 +108,107 @@ def get_streaming_response(
yield output, time_now - time_last_token
time_last_token = time_now

# For non-streaming, but currently non-streaming is not fully implemented
def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data["text"]
return output

token_gen_time = []
start_time = time.time()
response = requests.post(api_url, headers=headers, json=pload, stream=stream)
if stream:
token_gen_time = []
for h, t in get_streaming_response(response, start_time):
output = h
token_gen_time.append(t)

return ResponseDetails(
generated_tokens=output,
prompt=input_tokens,
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time,
)
else:
output = get_response(response)
raise NotImplementedError("Not implemented for non-streaming")
response = requests.post(api_url, headers=headers, json=pload, stream=args.stream)
for h, t in get_streaming_response(response, start_time):
output = h
token_gen_time.append(t)

return ResponseDetails(
generated_tokens=output,
prompt=input_tokens,
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time,
)


## TODO (lekurile): Create AML call function
lekurile marked this conversation as resolved.
Show resolved Hide resolved
def call_aml(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
if args.stream:
raise NotImplementedError("Not implemented for streaming")

headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + args.aml_api_key),
"azureml-model-deployment": args.deployment_name,
}
pload = {
"input_data": {
"input_string": [
input_tokens,
],
"parameters": {
"max_new_tokens": max_new_tokens,
"do_sample": True,
"return_full_text": False,
},
}
}

def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data[0]["0"]
return output

token_gen_time = []
start_time = time.time()
response = requests.post(args.aml_api_url, headers=headers, json=pload)
output = get_response(response)

return ResponseDetails(
generated_tokens=output,
prompt=input_tokens,
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time,
)


def _run_parallel(
deployment_name,
warmup,
barrier,
query_queue,
result_queue,
num_clients,
stream,
vllm,
barrier: Union[threading.Barrier, multiprocessing.Barrier],
query_queue: Union[queue.Queue, multiprocessing.Queue],
result_queue: Union[queue.Queue, multiprocessing.Queue],
args: argparse.Namespace,
):
pid = os.getpid()
session_id = f"test_session_p{pid}_t{threading.get_ident()}"

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
if not vllm:
import mii

client = mii.client(deployment_name)
backend_call_fns = {"fastgen": call_fastgen, "vllm": call_vllm, "aml": call_aml}
call_fn = backend_call_fns[args.backend]

barrier.wait()

for _ in range(warmup):
for _ in range(args.warmup):
print(f"warmup queue size: {query_queue.qsize()} ({pid})", flush=True)
input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0)

if vllm:
call_vllm(input_tokens, req_max_new_tokens, stream)
else:
call_mii(client, input_tokens, req_max_new_tokens, stream)
_ = call_fn(input_tokens, req_max_new_tokens, args)
# call_fastgen(client, input_tokens, req_max_new_tokens, stream)

barrier.wait()

time.sleep(random.uniform(0, num_clients) * 0.01)
time.sleep(random.uniform(0, args.num_clients) * 0.01)
try:
while not query_queue.empty():
print(f"queue size: {query_queue.qsize()} ({pid})", flush=True)
input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0)

# Set max_new_tokens following normal distribution
if vllm:
r = call_vllm(input_tokens, req_max_new_tokens)
else:
r = call_mii(client, input_tokens, req_max_new_tokens, stream)
r = call_fn(input_tokens, req_max_new_tokens, args)
# r = call_fastgen(client, input_tokens, req_max_new_tokens, stream)

result_queue.put(r)
except queue.Empty:
Expand All @@ -180,22 +229,7 @@ def run_client(args):
6. The main process marks the end time after receiving `num_requests' results
"""

# Unpack arguments
model = args.model
deployment_name = args.deployment_name
mean_prompt_length = args.mean_prompt_length
mean_max_new_tokens = args.mean_max_new_tokens
num_clients = args.num_clients
num_requests = args.num_requests
warmup = args.warmup
max_prompt_length = args.max_prompt_length
prompt_length_var = args.prompt_length_var
max_new_tokens_var = args.max_new_tokens_var
stream = args.stream
vllm = args.vllm
use_thread = args.use_thread

if use_thread:
if args.use_thread:
runnable_cls = threading.Thread
barrier_cls = threading.Barrier
queue_cls = queue.Queue
Expand All @@ -204,42 +238,40 @@ def run_client(args):
barrier_cls = multiprocessing.Barrier
queue_cls = multiprocessing.Queue

barrier = barrier_cls(num_clients + 1)
barrier = barrier_cls(args.num_clients + 1)
query_queue = queue_cls()
result_queue = queue_cls()

processes = [
runnable_cls(
target=_run_parallel,
args=(
deployment_name,
warmup,
barrier,
query_queue,
result_queue,
num_clients,
stream,
vllm,
args,
),
)
for i in range(num_clients)
for i in range(args.num_clients)
]
for p in processes:
p.start()

tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer = AutoTokenizer.from_pretrained(args.model)
query_generator = RandomQueryGenerator(all_text, tokenizer, seed=42)
request_text = query_generator.get_random_request_text(
mean_prompt_length,
mean_prompt_length * prompt_length_var,
max_prompt_length,
num_requests + warmup * num_clients,
args.mean_prompt_length,
args.mean_prompt_length * args.prompt_length_var,
args.max_prompt_length,
args.num_requests + args.warmup * args.num_clients,
)

for t in request_text:
# Set max_new_tokens following normal distribution
req_max_new_tokens = int(
np.random.normal(
mean_max_new_tokens, max_new_tokens_var * mean_max_new_tokens
args.mean_max_new_tokens,
args.max_new_tokens_var * args.mean_max_new_tokens,
)
)
query_queue.put((t, req_max_new_tokens))
Expand All @@ -252,10 +284,10 @@ def run_client(args):
barrier.wait()

response_details = []
while len(response_details) < num_requests:
while len(response_details) < args.num_requests:
res = result_queue.get()
# vLLM returns concatinated tokens
if vllm:
if args.backend == "vllm":
all_tokens = tokenizer.tokenize(res.generated_tokens)
res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)) :]
response_details.append(res)
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/inference/mii/src/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# DeepSpeed Team

ARG_DEFAULTS = {
"model": "meta-llama/Llama-2-7b-hf",
"deployment_name": "benchmark-deployment",
"tp_size": 1,
"max_ragged_batch_size": 768,
"num_replicas": 1,
Expand Down
Loading
Loading