From ee86f3b58ee48f6ede1e49993deae49cd2597a0a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 29 Oct 2024 11:10:21 -0700 Subject: [PATCH] Add `ICLMultiChoiceTaskDataset.max_sequence_length` property --- CHANGELOG.md | 4 ++++ src/olmo_eval/tasks.py | 13 +++++++++++++ src/test/tasks_test.py | 1 + 3 files changed, 18 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cbbc431..69af929 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Added `ICLMultiChoiceTaskDataset.max_sequence_length` property. + ## [v0.1.0](https://github.com/allenai/OLMo-in-loop-evals/releases/tag/v0.1.0) - 2024-10-28 ### Added diff --git a/src/olmo_eval/tasks.py b/src/olmo_eval/tasks.py index 87709aa..422ac3a 100644 --- a/src/olmo_eval/tasks.py +++ b/src/olmo_eval/tasks.py @@ -57,6 +57,7 @@ def __init__( # prep examples self.prep_examples() + self._max_sequence_length: Optional[int] = None def __getitem__(self, index): return self.samples[index] @@ -147,6 +148,16 @@ def pad_tokens_until_max(self, tokens, max_len=2048): return tokens + @property + def max_sequence_length(self) -> int: + if self._max_sequence_length is None: + max_seq_len = 0 + for sample in self.samples: + if len(sample["query"]) > max_seq_len: + max_seq_len = len(sample["query"]) + self._max_sequence_length = max_seq_len + return self._max_sequence_length + def collate_fn(self, data): # pad to max length # 'ctx', 'continuation', 'query' can all have variable length @@ -1391,6 +1402,8 @@ def __init__( # prep examples self.prep_examples() + self._max_sequence_length: Optional[int] = None + def prep_examples(self): current_doc_id_offset = 0 max_doc_id = 0 diff --git a/src/test/tasks_test.py b/src/test/tasks_test.py index 2c03cf6..95bdc89 100644 --- a/src/test/tasks_test.py +++ b/src/test/tasks_test.py @@ -13,6 +13,7 @@ def test_build_task(label: str): ) task = build_task(label, tokenizer) assert len(task) >= 2 + assert task.max_sequence_length > 0 instance1, instance2 = task[0], task[1] batch = task.collate_fn([instance1, instance2]) assert isinstance(batch, dict)