From 75d88183839f42e1623df4636386d9ff33a08c60 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 18 Dec 2024 10:56:40 -0800
Subject: [PATCH 1/2] Add option to fix context length

---
 CHANGELOG.md           |  5 ++++
 src/olmo_eval/tasks.py | 57 +++++++++++++++++++++++++++++++++++++++---
 2 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 796dd8b..ef8949e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## Unreleased

+### Added
+
+- Allowed passing additional kwargs to the task through `build_task()`.
+- Added the option to fix the context length of every batch to the model's context length.
+
 ## [v0.2.0](https://github.com/allenai/OLMo-in-loop-evals/releases/tag/v0.2.0) - 2024-10-29

 ### Added
diff --git a/src/olmo_eval/tasks.py b/src/olmo_eval/tasks.py
index 422ac3a..c9b23dd 100644
--- a/src/olmo_eval/tasks.py
+++ b/src/olmo_eval/tasks.py
@@ -1,7 +1,7 @@
 import abc
 import logging
 import re
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Type, Union, cast

 import datasets
 import torch
@@ -26,6 +26,7 @@ def __init__(
         dataset_path: str,
         dataset_name: Union[str, Sequence[str], None] = None,
         model_ctx_len: int = 2048,
+        fixed_ctx_len: bool = False,
         split="validation",
         metric_type=None,  # Override default metric type
         prompts: Optional[List[Optional[str]]] = None,  # List of prompt variants to use
@@ -36,6 +37,7 @@
         self.dataset_path = dataset_path
         self.dataset_name = dataset_name
         self.model_ctx_len = model_ctx_len
+        self.fixed_ctx_len = fixed_ctx_len
         self.prompts = prompts or [None]
         self.current_prompt: Optional[str] = None
         if metric_type is not None:
@@ -213,7 +215,12 @@ def collate_fn(self, data):
             cont_byte_lens.append(sample["cont_byte_len"])

             queries.append(
-                torch.LongTensor(self.pad_tokens_until_max(sample["query"], max_len=max_query_len))
+                torch.LongTensor(
+                    self.pad_tokens_until_max(
+                        sample["query"],
+                        max_len=self.model_ctx_len if self.fixed_ctx_len else max_query_len,
+                    )
+                )
             )
             dc_queries.append(
                 torch.LongTensor(
@@ -298,11 +305,13 @@ def __init__(
         tokenizer,
         dataset_path="piqa",
         dataset_name="plain_text",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -341,11 +350,13 @@ def __init__(
         tokenizer,
         dataset_path="hellaswag",
         dataset_name=None,
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     @classmethod
@@ -406,12 +417,14 @@ def __init__(
         tokenizer,
         dataset_path="winogrande",
         dataset_name="winogrande_xl",
+        **kwargs,
     ):
         # all winogrande datasets have same val set
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def prep_examples(self):
@@ -508,11 +521,13 @@ def __init__(
         tokenizer,
         dataset_path="openbookqa",
         dataset_name="main",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -548,11 +563,13 @@ def __init__(
         tokenizer,
         dataset_path="boolq",
         dataset_name=None,
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -598,11 +615,13 @@ def __init__(
         tokenizer,
         dataset_path="sciq",
         dataset_name=None,
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -645,11 +664,13 @@ def __init__(
         tokenizer,
         dataset_path: str = "ai2_arc",
         dataset_name: Optional[str] = "ARC-Easy",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -685,11 +706,13 @@ def __init__(
         tokenizer,
         dataset_path="ai2_arc",
         dataset_name="ARC-Challenge",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )


@@ -726,11 +749,13 @@ def __init__(
         tokenizer,
         dataset_path="allenai/basic_arithmetic",
         dataset_name: Optional[str] = None,
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )


@@ -752,11 +777,13 @@ def __init__(
         tokenizer,
         dataset_path="tau/commonsense_qa",
         dataset_name=None,
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )


@@ -777,11 +804,13 @@ def __init__(
         tokenizer,
         dataset_path="social_i_qa",
         dataset_name=None,
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -830,11 +859,13 @@ def __init__(
         tokenizer,
         dataset_path="super_glue",
         dataset_name="copa",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -877,11 +908,13 @@ def __init__(
         tokenizer,
         dataset_path="glue",
         dataset_name="rte",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -921,11 +954,13 @@ def __init__(
         tokenizer,
         dataset_path="super_glue",
         dataset_name="cb",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -968,11 +1003,13 @@ def __init__(
         tokenizer,
         dataset_path="glue",
         dataset_name="mrpc",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     @classmethod
@@ -1038,11 +1075,13 @@ def __init__(
         tokenizer,
         dataset_path="glue",
         dataset_name="sst2",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     @classmethod
@@ -1168,6 +1207,7 @@ def __init__(
         prompt_variations=None,
         mc_labels=False,
         metric_type=None,
+        **kwargs,
     ):
         dataset_names = []
         # Collect the relevant categories
@@ -1203,6 +1243,7 @@
             split=split,
             prompts=prompts,
             metric_type=metric_type,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -1288,11 +1329,13 @@ def __init__(
         tokenizer,
         dataset_path="trivia_qa",
         dataset_name="rc.wikipedia.nocontext",
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -1326,11 +1369,13 @@ def __init__(
         tokenizer,
         dataset_path="nq_open",
         dataset_name=None,
+        **kwargs,
     ):
         super().__init__(
             tokenizer=tokenizer,
             dataset_path=dataset_path,
             dataset_name=dataset_name,
+            **kwargs,
         )

     def doc_to_text(self, doc):
@@ -1357,6 +1402,7 @@ def __init__(
         dataset_path: str,
         dataset_name: Union[str, Sequence[str], None] = None,
         model_ctx_len: int = 2048,
+        fixed_ctx_len: bool = False,
         split=None,
         metric_type=None,
         prompts: Optional[List[Optional[str]]] = None,  # List of prompt variants to use
@@ -1367,6 +1413,7 @@
         self.dataset_path = dataset_path
         self.dataset_name = dataset_name
         self.model_ctx_len = model_ctx_len
+        self.fixed_ctx_len = fixed_ctx_len
         self.log_instances = 0  # Set to > 0 to log the first few instances as a sanity check

         self.samples: List[Dict[str, Any]] = []
@@ -1847,9 +1894,11 @@ def list_tasks() -> List[str]:
     return list(label_to_task_map.keys())


-def build_task(label: str, tokenizer: Tokenizer) -> ICLMultiChoiceTaskDataset:
+def build_task(label: str, tokenizer: Tokenizer, **kwargs) -> ICLMultiChoiceTaskDataset:
     task_class = label_to_task_map[label]
     task_kwargs = {}
     if isinstance(task_class, tuple):
         task_class, task_kwargs = task_class
-    return task_class(tokenizer=tokenizer, **task_kwargs)  # type: ignore
+    return cast(Type[ICLMultiChoiceTaskDataset], task_class)(
+        tokenizer=tokenizer, **task_kwargs, **kwargs
+    )

From 762dbea9dc3c0f12a1b507139f77c22d343c4ea7 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 18 Dec 2024 10:59:23 -0800
Subject: [PATCH 2/2] clean up

---
 src/olmo_eval/tasks.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/olmo_eval/tasks.py b/src/olmo_eval/tasks.py
index c9b23dd..eef5846 100644
--- a/src/olmo_eval/tasks.py
+++ b/src/olmo_eval/tasks.py
@@ -1894,11 +1894,16 @@ def list_tasks() -> List[str]:
     return list(label_to_task_map.keys())


-def build_task(label: str, tokenizer: Tokenizer, **kwargs) -> ICLMultiChoiceTaskDataset:
+def build_task(
+    label: str,
+    tokenizer: Tokenizer,
+    model_ctx_len: int = 2048,
+    fixed_ctx_len: bool = False,
+) -> ICLMultiChoiceTaskDataset:
     task_class = label_to_task_map[label]
     task_kwargs = {}
     if isinstance(task_class, tuple):
         task_class, task_kwargs = task_class
     return cast(Type[ICLMultiChoiceTaskDataset], task_class)(
-        tokenizer=tokenizer, **task_kwargs, **kwargs
+        tokenizer=tokenizer, model_ctx_len=model_ctx_len, fixed_ctx_len=fixed_ctx_len, **task_kwargs
     )
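
A rough usage sketch of the API as it stands after PATCH 2/2. `build_task` and `list_tasks` come from `olmo_eval.tasks` as shown in the diff; the `make_fixed_length_tasks` helper and the `tokenizer` argument are hypothetical stand-ins for illustration and are not part of this change.

```python
# Sketch only, assuming the final signature from PATCH 2/2:
#   build_task(label, tokenizer, model_ctx_len=2048, fixed_ctx_len=False)
# The tokenizer is supplied by the caller; constructing one is outside this patch.
from olmo_eval.tasks import build_task, list_tasks


def make_fixed_length_tasks(tokenizer, model_ctx_len: int = 2048):
    """Build every registered task so each batch is padded to the full model context length."""
    return {
        label: build_task(
            label,
            tokenizer,
            model_ctx_len=model_ctx_len,
            fixed_ctx_len=True,  # pad queries to model_ctx_len instead of the per-batch max
        )
        for label in list_tasks()
    }
```

With `fixed_ctx_len=True`, the modified `collate_fn` pads every query to `model_ctx_len` rather than to the longest query in the batch, so all batches share the same shape.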