diff --git a/evals/eval_on_hotpot.py b/evals/eval_on_hotpot.py
index da102c8e..6fa5748b 100644
--- a/evals/eval_on_hotpot.py
+++ b/evals/eval_on_hotpot.py
@@ -10,8 +10,10 @@
 from evals.qa_dataset_utils import load_qa_dataset
 from evals.qa_metrics_utils import get_metrics
 from evals.qa_context_provider_utils import qa_context_providers, valid_pipeline_slices
+import random
 
 logger = logging.getLogger(__name__)
+random.seed(42)  # fixed seed keeps the sampled evaluation subset reproducible
 
 
 async def answer_qa_instance(instance, context_provider):
@@ -77,7 +79,9 @@ async def eval_on_QA_dataset(
     dataset = load_qa_dataset(dataset_name_or_filename)
     context_provider = qa_context_providers[context_provider_name]
     eval_metrics = get_metrics(metric_name_list)
-    instances = dataset if not num_samples else dataset[:num_samples]
+    # Clamp the sample size: random.sample raises ValueError when asked for more
+    # items than the population holds, whereas the old slice just returned them all.
+    instances = dataset if not num_samples else random.sample(dataset, min(num_samples, len(dataset)))
 
     if "promptfoo_metrics" in eval_metrics:
         promptfoo_results = await eval_metrics["promptfoo_metrics"].measure(