Extend FastGen benchmark to use AML endpoints (#865)
Add AML backend to MII benchmarking suite.

Co-authored-by: Lev Kurilenko <[email protected]>
mrwyattii and lekurile authored Feb 29, 2024
1 parent 6540db6 commit 8182a8b
Showing 9 changed files with 242 additions and 154 deletions.
16 changes: 14 additions & 2 deletions benchmarks/inference/mii/README.md
@@ -24,10 +24,22 @@ python run_benchmark.py --tp_size 1 2
```

By default the benchmark runs with DeepSpeed-MII as the backend inference
server. To change the backend to vLLM, provide the `--vllm` flag:
server. The benchmark also supports vLLM and Azure endpoints. To change the
backend to vLLM, provide the `--backend vllm` arg:

```bash
python run_benchmark.py --vllm
python run_benchmark.py --backend vllm
```

To benchmark against an Azure endpoint, provide the `--backend aml` arg as well as
the following values:
- `--aml_api_url`: API URL that points to an AML endpoint
- `--aml_api_key`: API key for the given AML endpoint
- `--deployment_name`: The name of the AML endpoint deployment you want to test against
- `--model`: The name of the HuggingFace-hosted model deployed on the AML endpoint. This is used to load a tokenizer and correctly calculate the number of tokens in the prompts and responses (see the tokenizer sketch below the example).

```bash
python run_benchmark.py --backend aml --model mistralai/Mixtral-8x7B-v0.1 --deployment_name mistralai-mixtral-8x7b-v01-4 --aml_api_url <URL obtained from Azure> --aml_api_key <Authentication key obtained from Azure>
```
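
With the AML backend the model itself runs behind the endpoint, so `--model` only drives client-side tokenization: loading a tokenizer and counting tokens in prompts and responses. A minimal sketch of that token accounting, assuming the standard `transformers` API (the model name and strings below are placeholders, not benchmark output):

```python
# Sketch: counting prompt/response tokens with the HuggingFace tokenizer,
# the same kind of accounting the benchmark client performs.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

prompt = "Explain tensor parallelism in one paragraph."       # placeholder
response = "Tensor parallelism splits each weight matrix..."  # placeholder

prompt_tokens = len(tokenizer.tokenize(prompt))
response_tokens = len(tokenizer.tokenize(response))
print(f"prompt tokens: {prompt_tokens}, generated tokens: {response_tokens}")
```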

The run_all.sh script performs benchmarks across various models, client numbers,
8 changes: 4 additions & 4 deletions benchmarks/inference/mii/run_all.sh
@@ -6,10 +6,10 @@
MODELS=(meta-llama/Llama-2-7b-hf meta-llama/Llama-2-13b-hf meta-llama/Llama-2-70b-hf tiiuae/falcon-40B tiiuae/falcon-180B microsoft/phi-2 mistralai/Mixtral-8x7B-v0.1)

for MODEL in ${MODELS[@]}; do
python ./run_benchmark.py --model ${MODEL} --stream
python ./run_benchmark.py --model ${MODEL} --stream --vllm
python ./run_benchmark.py --model ${MODEL} --stream --backend fastgen
python ./run_benchmark.py --model ${MODEL} --stream --backend vllm
done

# Extra runs for Mixtral with non-default settings
python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024
python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --vllm
python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --backend fastgen
python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --backend vllm
6 changes: 4 additions & 2 deletions benchmarks/inference/mii/run_benchmark.py
@@ -20,7 +20,8 @@ def run_benchmark() -> None:
args = parse_args(server_args=True, client_args=True)

for server_args in get_args_product(args, which=SERVER_PARAMS):
start_server(server_args)
if server_args.backend != "aml":
start_server(server_args)

for client_args in get_args_product(server_args, which=CLIENT_PARAMS):
if results_exist(client_args) and not args.overwrite_results:
@@ -33,7 +34,8 @@ def run_benchmark() -> None:
print_summary(client_args, response_details)
save_json_results(client_args, response_details)

stop_server(server_args)
if server_args.backend != "aml":
stop_server(server_args)


if __name__ == "__main__":
3 changes: 2 additions & 1 deletion benchmarks/inference/mii/run_example.sh
@@ -11,7 +11,8 @@ python ./run_benchmark.py \
--max_ragged_batch_size 768 \
--mean_prompt_length 2600 \
--mean_max_new_tokens 60 \
--stream
--stream \
--backend fastgen \

### Generate the plots
python ./src/plot_th_lat.py
191 changes: 110 additions & 81 deletions benchmarks/inference/mii/src/client.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

import argparse
import asyncio
import json
import multiprocessing
@@ -12,18 +13,30 @@
import requests
import threading
import time
from typing import List, Iterable
from typing import List, Iterable, Union

import numpy as np
from transformers import AutoTokenizer

from .postprocess_results import ResponseDetails
from .random_query_generator import RandomQueryGenerator
from .sample_input import all_text
from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS
try:
from .postprocess_results import ResponseDetails
from .random_query_generator import RandomQueryGenerator
from .sample_input import all_text
from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS
except ImportError:
from postprocess_results import ResponseDetails
from random_query_generator import RandomQueryGenerator
from sample_input import all_text
from utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS


def call_mii(client, input_tokens, max_new_tokens, stream):
def call_fastgen(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
import mii

client = mii.client(args.deployment_name)

output_tokens = []
token_gen_time = []
time_last_token = 0
@@ -38,7 +51,7 @@ def callback(response):

time_last_token = start_time = time.time()
token_gen_time = []
if stream:
if args.stream:
output_tokens = []
client.generate(
input_tokens, max_new_tokens=max_new_tokens, streaming_fn=callback
@@ -57,7 +70,12 @@ def callback(response):
)


def call_vllm(input_tokens, max_new_tokens, stream=True):
def call_vllm(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
if not args.stream:
raise NotImplementedError("Not implemented for non-streaming")

api_url = "http://localhost:26500/generate"
headers = {"User-Agent": "Benchmark Client"}
pload = {
@@ -68,7 +86,7 @@ def call_vllm(input_tokens, max_new_tokens, stream=True):
"top_p": 0.9,
"max_tokens": max_new_tokens,
"ignore_eos": False,
"stream": stream,
"stream": args.stream,
}

def clear_line(n: int = 1) -> None:
@@ -90,76 +108,104 @@ def get_streaming_response(
yield output, time_now - time_last_token
time_last_token = time_now

# For non-streaming, but currently non-streaming is not fully implemented
def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data["text"]
return output

token_gen_time = []
start_time = time.time()
response = requests.post(api_url, headers=headers, json=pload, stream=stream)
if stream:
token_gen_time = []
for h, t in get_streaming_response(response, start_time):
output = h
token_gen_time.append(t)

return ResponseDetails(
generated_tokens=output,
prompt=input_tokens,
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time,
)
else:
output = get_response(response)
raise NotImplementedError("Not implemented for non-streaming")
response = requests.post(api_url, headers=headers, json=pload, stream=args.stream)
for h, t in get_streaming_response(response, start_time):
output = h
token_gen_time.append(t)

return ResponseDetails(
generated_tokens=output,
prompt=input_tokens,
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time,
)


def call_aml(
input_tokens: str, max_new_tokens: int, args: argparse.Namespace
) -> ResponseDetails:
if args.stream:
raise NotImplementedError("Not implemented for streaming")

headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + args.aml_api_key),
"azureml-model-deployment": args.deployment_name,
}
pload = {
"input_data": {
"input_string": [
input_tokens,
],
"parameters": {
"max_new_tokens": max_new_tokens,
"do_sample": True,
"return_full_text": False,
},
}
}

def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data[0]["0"]
return output

token_gen_time = []
start_time = time.time()
response = requests.post(args.aml_api_url, headers=headers, json=pload)
output = get_response(response)

return ResponseDetails(
generated_tokens=output,
prompt=input_tokens,
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time,
)


def _run_parallel(
deployment_name,
warmup,
barrier,
query_queue,
result_queue,
num_clients,
stream,
vllm,
barrier: Union[threading.Barrier, multiprocessing.Barrier],
query_queue: Union[queue.Queue, multiprocessing.Queue],
result_queue: Union[queue.Queue, multiprocessing.Queue],
args: argparse.Namespace,
):
pid = os.getpid()
session_id = f"test_session_p{pid}_t{threading.get_ident()}"

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
if not vllm:
import mii

client = mii.client(deployment_name)
backend_call_fns = {"fastgen": call_fastgen, "vllm": call_vllm, "aml": call_aml}
call_fn = backend_call_fns[args.backend]

barrier.wait()

for _ in range(warmup):
for _ in range(args.warmup):
print(f"warmup queue size: {query_queue.qsize()} ({pid})", flush=True)
input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0)

if vllm:
call_vllm(input_tokens, req_max_new_tokens, stream)
else:
call_mii(client, input_tokens, req_max_new_tokens, stream)
_ = call_fn(input_tokens, req_max_new_tokens, args)

barrier.wait()

time.sleep(random.uniform(0, num_clients) * 0.01)
time.sleep(random.uniform(0, args.num_clients) * 0.01)
try:
while not query_queue.empty():
print(f"queue size: {query_queue.qsize()} ({pid})", flush=True)
input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0)

# Set max_new_tokens following normal distribution
if vllm:
r = call_vllm(input_tokens, req_max_new_tokens)
else:
r = call_mii(client, input_tokens, req_max_new_tokens, stream)
r = call_fn(input_tokens, req_max_new_tokens, args)

result_queue.put(r)
except queue.Empty:
@@ -180,22 +226,7 @@ def run_client(args):
6. The main process marks the end time after receiving `num_requests' results
"""

# Unpack arguments
model = args.model
deployment_name = args.deployment_name
mean_prompt_length = args.mean_prompt_length
mean_max_new_tokens = args.mean_max_new_tokens
num_clients = args.num_clients
num_requests = args.num_requests
warmup = args.warmup
max_prompt_length = args.max_prompt_length
prompt_length_var = args.prompt_length_var
max_new_tokens_var = args.max_new_tokens_var
stream = args.stream
vllm = args.vllm
use_thread = args.use_thread

if use_thread:
if args.use_thread:
runnable_cls = threading.Thread
barrier_cls = threading.Barrier
queue_cls = queue.Queue
@@ -204,42 +235,40 @@
barrier_cls = multiprocessing.Barrier
queue_cls = multiprocessing.Queue

barrier = barrier_cls(num_clients + 1)
barrier = barrier_cls(args.num_clients + 1)
query_queue = queue_cls()
result_queue = queue_cls()

processes = [
runnable_cls(
target=_run_parallel,
args=(
deployment_name,
warmup,
barrier,
query_queue,
result_queue,
num_clients,
stream,
vllm,
args,
),
)
for i in range(num_clients)
for i in range(args.num_clients)
]
for p in processes:
p.start()

tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer = AutoTokenizer.from_pretrained(args.model)
query_generator = RandomQueryGenerator(all_text, tokenizer, seed=42)
request_text = query_generator.get_random_request_text(
mean_prompt_length,
mean_prompt_length * prompt_length_var,
max_prompt_length,
num_requests + warmup * num_clients,
args.mean_prompt_length,
args.mean_prompt_length * args.prompt_length_var,
args.max_prompt_length,
args.num_requests + args.warmup * args.num_clients,
)

for t in request_text:
# Set max_new_tokens following normal distribution
req_max_new_tokens = int(
np.random.normal(
mean_max_new_tokens, max_new_tokens_var * mean_max_new_tokens
args.mean_max_new_tokens,
args.max_new_tokens_var * args.mean_max_new_tokens,
)
)
query_queue.put((t, req_max_new_tokens))
@@ -252,10 +281,10 @@ def run_client(args):
barrier.wait()

response_details = []
while len(response_details) < num_requests:
while len(response_details) < args.num_requests:
res = result_queue.get()
# vLLM returns concatenated tokens
if vllm:
if args.backend == "vllm":
all_tokens = tokenizer.tokenize(res.generated_tokens)
res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)) :]
response_details.append(res)
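
For a quick sanity check of an AML deployment outside the benchmark, the same request and response schema that `call_aml` uses above can be exercised directly with `requests`. A minimal sketch, with the URL, key, and deployment name left as placeholders to be filled in from Azure:

```python
# Sketch: one-off request against an AML endpoint using the same schema as call_aml.
# The URL, API key, and deployment name are placeholders obtained from Azure.
import json
import requests

api_url = "<URL obtained from Azure>"
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer " + "<Authentication key obtained from Azure>",
    "azureml-model-deployment": "<deployment name>",
}
payload = {
    "input_data": {
        "input_string": ["DeepSpeed is"],
        "parameters": {"max_new_tokens": 16, "do_sample": True, "return_full_text": False},
    }
}

response = requests.post(api_url, headers=headers, json=payload)
# call_aml above reads the generated text from data[0]["0"]
print(json.loads(response.content)[0]["0"])
```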
2 changes: 2 additions & 0 deletions benchmarks/inference/mii/src/defaults.py
@@ -4,6 +4,8 @@
# DeepSpeed Team

ARG_DEFAULTS = {
"model": "meta-llama/Llama-2-7b-hf",
"deployment_name": "benchmark-deployment",
"tp_size": 1,
"max_ragged_batch_size": 768,
"num_replicas": 1,
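
The diff for `src/utils.py`, where `parse_args` consumes these defaults, is not shown in this view. As an assumption about the wiring (not the actual contents of `utils.py`), a table like `ARG_DEFAULTS` is typically applied to an argparse parser via `set_defaults`:

```python
# Sketch (assumption): applying ARG_DEFAULTS to an argparse parser.
# The real parse_args in src/utils.py is not included in this diff.
import argparse

ARG_DEFAULTS = {
    "model": "meta-llama/Llama-2-7b-hf",
    "deployment_name": "benchmark-deployment",
    "tp_size": 1,
}

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str)
parser.add_argument("--deployment_name", type=str)
parser.add_argument("--tp_size", type=int)
parser.set_defaults(**ARG_DEFAULTS)

args = parser.parse_args([])  # no CLI overrides, so the defaults apply
print(args.model, args.deployment_name, args.tp_size)
```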