Skip to content

Commit

Permalink
Add ICLMultiChoiceTaskDataset.max_sequence_length property
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 29, 2024
1 parent bfdc809 commit ee86f3b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `ICLMultiChoiceTaskDataset.max_sequence_length` property.

## [v0.1.0](https://github.com/allenai/OLMo-in-loop-evals/releases/tag/v0.1.0) - 2024-10-28

### Added
Expand Down
13 changes: 13 additions & 0 deletions src/olmo_eval/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(

# prep examples
self.prep_examples()
self._max_sequence_length: Optional[int] = None

def __getitem__(self, index):
return self.samples[index]
Expand Down Expand Up @@ -147,6 +148,16 @@ def pad_tokens_until_max(self, tokens, max_len=2048):

return tokens

@property
def max_sequence_length(self) -> int:
if self._max_sequence_length is None:
max_seq_len = 0
for sample in self.samples:
if len(sample["query"]) > max_seq_len:
max_seq_len = len(sample["query"])
self._max_sequence_length = max_seq_len
return self._max_sequence_length

def collate_fn(self, data):
# pad to max length
# 'ctx', 'continuation', 'query' can all have variable length
Expand Down Expand Up @@ -1391,6 +1402,8 @@ def __init__(
# prep examples
self.prep_examples()

self._max_sequence_length: Optional[int] = None

def prep_examples(self):
current_doc_id_offset = 0
max_doc_id = 0
Expand Down
1 change: 1 addition & 0 deletions src/test/tasks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_build_task(label: str):
)
task = build_task(label, tokenizer)
assert len(task) >= 2
assert task.max_sequence_length > 0
instance1, instance2 = task[0], task[1]
batch = task.collate_fn([instance1, instance2])
assert isinstance(batch, dict)
Expand Down

0 comments on commit ee86f3b

Please sign in to comment.