This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Benchmark Fix: Remove special tokens from warmup prompts #140

Merged
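For context, the change replaces a custom decode helper with the tokenizer's built-in `skip_special_tokens` flag. The minimal sketch below is not part of this PR; it only illustrates the behavior the fix relies on, assuming a Hugging Face tokenizer (the model name is an arbitrary example).

```python
# Minimal sketch (not part of this PR): decoding with skip_special_tokens=True
# drops ids such as BOS/EOS that the tokenizer inserted during encoding.
# The model name below is an arbitrary example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Encoding typically prepends/appends special token ids (e.g. a BOS id).
prompt_ids = tokenizer("hello world").input_ids

with_specials = tokenizer.decode(prompt_ids)
without_specials = tokenizer.decode(prompt_ids, skip_special_tokens=True)

print(repr(with_specials))     # may include markers such as "</s>"
print(repr(without_specials))  # "hello world"
```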
16 changes: 5 additions & 11 deletions neuralmagic/benchmarks/scripts/common.py
@@ -46,14 +46,6 @@ def get_benchmarking_context() -> dict:
     }
 
 
-def remove_special_tokens_and_decode(
-        prompt_ids: list[int], tokenizer: PreTrainedTokenizerBase) -> str:
-    # Remove special tokens from prompt ids
-    prompt_ids = list(
-        filter(lambda id: id not in tokenizer.all_special_ids, prompt_ids))
-    return tokenizer.decode(prompt_ids)
-
-
 def generate_synthetic_requests(
         num_input_tokens: int, num_output_tokens: int, num_requests: int,
         tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
@@ -88,7 +80,7 @@ def generate_synthetic_requests(
             continue
 
         prompt_ids = prompt_ids[:num_input_tokens]
-        prompt = remove_special_tokens_and_decode(prompt_ids, tokenizer)
+        prompt = tokenizer.decode(prompt_ids, skip_special_tokens=True)
 
         sampled_requests.append((prompt, num_input_tokens, num_output_tokens))
 
@@ -103,15 +95,17 @@ def warmup_requests(tokenizer: PreTrainedTokenizerBase,
     """
     Given a tokenizer, generate `num_requests` requests used for warmup
     """
-    words = list(tokenizer.get_vocab().keys())
+    all_words = list(tokenizer.get_vocab().keys())
+    # Remove special tokens like <s>, </s>, <pad> etc. from all_words
+    words = list(filter(lambda word: not word.startswith('<'), all_words))
     requests = []
     for _ in range(num_requests):
         # We make up random prompts for warmups in order to avoid the effects of
         # prefix caching during actual benchmarking.
         prompt = " ".join(random.choices(words, k=num_input_tokens))
         prompt_ids = tokenizer(prompt).input_ids
         prompt_ids = prompt_ids[:num_input_tokens]
-        prompt = remove_special_tokens_and_decode(prompt_ids, tokenizer)
+        prompt = tokenizer.decode(prompt_ids, skip_special_tokens=True)
         requests.append((prompt, num_input_tokens, num_output_tokens))
     return requests
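As a sanity check of the new warmup path, the hypothetical snippet below (not from this PR) mirrors the vocabulary filtering added to `warmup_requests` and verifies that no registered special tokens survive in the sampled word list; it assumes a Hugging Face tokenizer whose special tokens look like `<s>`, `</s>`, `<pad>`.

```python
# Hypothetical sanity check (not from this PR): mirror the vocabulary filtering
# added to warmup_requests and confirm no special tokens remain.
# Assumes a Hugging Face tokenizer whose special tokens look like <s>, </s>, <pad>.
import random
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

all_words = list(tokenizer.get_vocab().keys())
words = [w for w in all_words if not w.startswith('<')]

# None of the tokenizer's registered special tokens should be left in `words`.
assert not any(w in words for w in tokenizer.all_special_tokens)

# Build a random warmup prompt the same way the benchmark does.
prompt = " ".join(random.choices(words, k=32))
prompt_ids = tokenizer(prompt).input_ids
print(f"kept {len(words)} of {len(all_words)} vocab entries; "
      f"sample prompt encodes to {len(prompt_ids)} ids")
```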