diff --git a/CHANGELOG.md b/CHANGELOG.md index c1bbb367..bb56709e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Ensure additional cached-path clients are added in the process pool workers from some dataset preparation methods. +- Fixed `label_mask` tensor created by `NumpyPaddedFSLDataset`: it is now a boolean tensor padded with `False`, rather than a tensor with the dtype of `input_ids` padded with `pad_token_id`. ## [v1.3.1](https://github.com/allenai/OLMo-core/releases/tag/v1.3.1) - 2024-09-26 diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index a8c34bde..a9a0132a 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -537,7 +537,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: item = super().__getitem__(index) pad_shape = (0, self.sequence_length - len(item["input_ids"])) item["label_mask"] = F.pad( - torch.ones_like(item["input_ids"]), pad_shape, value=self.pad_token_id + torch.ones_like(item["input_ids"], dtype=torch.bool), pad_shape, value=False ) item["input_ids"] = F.pad(item["input_ids"], pad_shape, value=self.pad_token_id) return item