Skip to content

Commit

Permalink
transformers request generator is controlled by --subset-size CLI p…
Browse files Browse the repository at this point in the history
…arameter
  • Loading branch information
Dmytro Parfeniuk committed Oct 28, 2024
1 parent ecf2984 commit 2c94aae
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/guidellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@
"until the user exits. "
),
)
@click.option(
"--subset-size",
type=int,
default=None,
help=(
"The number of subsets to use from the dataset. "
"If not provided, all subsets will be used."
),
)
def generate_benchmark_report_cli(
target: str,
backend: BackendEnginePublic,
Expand All @@ -164,6 +173,7 @@ def generate_benchmark_report_cli(
max_requests: Union[Literal["dataset"], int, None],
output_path: str,
enable_continuous_refresh: bool,
subset_size: Optional[int],
):
"""
Generate a benchmark report for a specified backend and dataset.
Expand All @@ -181,6 +191,7 @@ def generate_benchmark_report_cli(
max_requests=max_requests,
output_path=output_path,
cont_refresh_table=enable_continuous_refresh,
subset_size=subset_size,
)


Expand All @@ -197,6 +208,7 @@ def generate_benchmark_report(
max_requests: Union[Literal["dataset"], int, None],
output_path: str,
cont_refresh_table: bool,
subset_size: Optional[int],
) -> GuidanceReport:
"""
Generate a benchmark report for a specified backend and dataset.
Expand Down Expand Up @@ -251,7 +263,7 @@ def generate_benchmark_report(
request_generator = FileRequestGenerator(path=data, tokenizer=tokenizer_inst)
elif data_type == "transformers":
request_generator = TransformersDatasetRequestGenerator(
dataset=data, tokenizer=tokenizer_inst
dataset=data, tokenizer=tokenizer_inst, subset_size=subset_size
)
else:
raise ValueError(f"Unknown data type: {data_type}")
Expand Down
6 changes: 6 additions & 0 deletions src/guidellm/request/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class TransformersDatasetRequestGenerator(RequestGenerator):
:type mode: str
:param async_queue_size: The size of the request queue.
:type async_queue_size: int
:param subset_size: The number of the subsets to use from the database.
:type subset_size: Optional[int]
"""

def __init__(
Expand All @@ -45,6 +47,7 @@ def __init__(
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
mode: GenerationMode = "async",
async_queue_size: int = 50,
subset_size: Optional[int] = None,
**kwargs,
):
self._dataset = dataset
Expand All @@ -58,6 +61,9 @@ def __init__(
self._hf_column = resolve_transformers_dataset_column(
self._hf_dataset, column=column
)
if subset_size is not None and isinstance(self._hf_dataset, Dataset):
self._hf_dataset = self._hf_dataset.select(range(subset_size))

self._hf_dataset_iterator = iter(self._hf_dataset)

# NOTE: Must be after all the parameters since the queue population
Expand Down

0 comments on commit 2c94aae

Please sign in to comment.