diff --git a/CHANGELOG.md b/CHANGELOG.md index e7186cff..664598e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `SourceMixtureDataset` for composing a training mixture based on ratios of source datasets. +- Added `NumpyFSLDatasetMixture` for constructing a `NumpyDatasetBase` from a `SourceMixtureDataset`. Note this is only supported for FSL datasets. +- Added tests for `SourceMixture*` and `NumpyFSLDatasetMixture`. - Added `DownstreamEvaluatorCallbackConfig` class for running in-loop downstream eval via [OLMo-in-loop-evals](https://github.com/allenai/OLMo-in-loop-evals). +### Changed + +- Moved some types into `olmo_core.data.types` to avoid some circular dependencies. + ### Fixed - Made GCS client more robust by automatically retrying timeout errors for most operations. diff --git a/pyproject.toml b/pyproject.toml index 89f40d2d..b770d742 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dev = [ "black>=23.1,<24.0", "isort>=5.12,<5.13", "pytest", + "pytest-memray", "pytest-sphinx", "pytest-xdist", "twine>=1.11.0", diff --git a/src/olmo_core/data/__init__.py b/src/olmo_core/data/__init__.py index b3920e3a..b710100e 100644 --- a/src/olmo_core/data/__init__.py +++ b/src/olmo_core/data/__init__.py @@ -24,8 +24,6 @@ from .numpy_dataset import ( NumpyDatasetBase, NumpyDatasetConfig, - NumpyDatasetDType, - NumpyDatasetType, NumpyFSLDataset, NumpyPaddedFSLDataset, NumpyVSLDataset, @@ -38,6 +36,7 @@ VSLNaturalCurriculum, ) from .tokenizer import TokenizerConfig, TokenizerName +from .types import NumpyDatasetDType, NumpyDatasetType __all__ = [ "NumpyDatasetBase", diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index b6c9ad73..6862e7e6 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -29,6 +29,8 @@ import torch.nn.functional as F from torch.utils.data import Dataset +from olmo_core.data.source_mixture import SourceMixtureDatasetConfig +from olmo_core.data.types import NumpyDatasetDType, NumpyDatasetType, NumpyUIntTypes from olmo_core.exceptions import OLMoConfigurationError, OLMoEnvironmentError from ..aliases import PathOrStr @@ -60,11 +62,9 @@ "VSLGrowP2Curriculum", "VSLGrowLinearCurriculum", "NumpyVSLDataset", - "NumpyDatasetType", "NumpyDatasetConfig", "VSLCurriculumType", "VSLCurriculumConfig", - "NumpyDatasetDType", ] @@ -99,7 +99,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16, + dtype: NumpyUIntTypes = np.uint16, ): if not paths: raise OLMoConfigurationError("At least one path is required") @@ -135,7 +135,7 @@ def file_sizes(self) -> Tuple[int, ...]: The size, in bytes, of each numpy array. """ if self._array_file_sizes is None: - self._array_file_sizes = tuple(self.map(get_file_size)) + self._array_file_sizes = tuple(self.map(lambda path, _: get_file_size(path))) return self._array_file_sizes @property @@ -153,7 +153,7 @@ def vocab_size(self) -> int: @property def dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> NumpyUIntTypes: """ The numpy datatype of the arrays. """ @@ -208,6 +208,13 @@ def work_dir_set(self) -> bool: """ return self._work_dir_set + @property + def num_tokens(self) -> int: + """ + Get the total number of tokens in the dataset. 
+ """ + raise NotImplementedError + def _get_file_size(self, path: PathOrStr): path_idx = self.paths.index(path) return self.file_sizes[path_idx] @@ -235,7 +242,7 @@ def _warmup_clients(self): def map( self, - func: Callable[[PathOrStr], T], + func: Callable[[PathOrStr, int], T], *, max_workers: Optional[int] = None, method: Literal["threads", "processes"] = "threads", @@ -244,7 +251,7 @@ def map( """ Call a function on each path in the dataset, returning a list of the results, in order. - :param func: The function to map to the paths. + :param func: The function to map to the paths and their indices. :param max_workers: The number of workers threads/processes. Set to 0 to execute synchronously in the main thread/process. :param method: Whether to use multi-threading or multi-processing. @@ -254,7 +261,7 @@ def map( paths = _paths or self.paths if max_workers == 0: - return [func(path) for path in paths] + return [func(path, idx) for idx, path in enumerate(paths)] executor_class: Union[ Type[concurrent.futures.ThreadPoolExecutor], @@ -269,16 +276,9 @@ def map( raise ValueError(method) with executor_class(max_workers=max_workers) as executor: - path_to_future = {} - for path in paths: - if path not in path_to_future: - path_to_future[path] = executor.submit(func, path) - - results = [] - for path in paths: - results.append(path_to_future[path].result()) + futures = [executor.submit(func, path, idx) for idx, path in enumerate(paths)] - return results + return [future.result() for future in futures] def prepare(self): """ @@ -347,7 +347,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16, + dtype: NumpyUIntTypes = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, generate_doc_lengths: bool = False, @@ -483,7 +483,9 @@ def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor: path, start_idx, start_idx + self.sequence_length, self.dtype ) - def _get_file_size_and_length(self, path, dtype=None) -> Tuple[int, int]: + def _get_file_size_and_length( + self, path: PathOrStr, idx: int, dtype: Optional[NumpyUIntTypes] = None + ) -> Tuple[int, int]: dtype = dtype or self.dtype item_size = dtype(0).itemsize file_size = get_file_size(path) @@ -503,6 +505,156 @@ def _get_file_size_and_length(self, path, dtype=None) -> Tuple[int, int]: raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'") +class NumpyFSLDatasetMixture(NumpyFSLDataset): + """ + A version of :class:`NumpyFSLDataset` built from a mixture of sources and their expected token ratios relative to each other. A ``path_offset_index`` is used to determine the number of instances to retain from a path when constructing the local indices. 
+ """ + + def __init__( + self, + *paths: PathOrStr, + path_offset_index: Dict[Tuple[str, int], int], + seed: int, + sequence_length: int, + pad_token_id: int, + eos_token_id: int, + vocab_size: int, + dtype: NumpyUIntTypes = np.uint16, + metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, + include_instance_metadata: Optional[bool] = None, + generate_doc_lengths: bool = False, + max_target_sequence_length: Optional[int] = None, + ): + if max_target_sequence_length is not None and ( + max_target_sequence_length < sequence_length + or max_target_sequence_length % sequence_length != 0 + ): + raise OLMoConfigurationError( + "'max_target_sequence_length' should be a multiple of 'sequence_length'" + ) + + if include_instance_metadata is None and metadata: + include_instance_metadata = True + + if isinstance(metadata, list): + if len(metadata) != len(paths): + raise OLMoConfigurationError( + "'metadata' should have the same length as the number of file paths" + ) + else: + metadata = [metadata or {}] * len(paths) + + super().__init__( + *paths, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + vocab_size=vocab_size, + dtype=dtype, + sequence_length=sequence_length, + metadata=metadata, + include_instance_metadata=include_instance_metadata, + generate_doc_lengths=generate_doc_lengths, + max_target_sequence_length=max_target_sequence_length, + ) + self._metadata = tuple(metadata) + self._include_instance_metadata = include_instance_metadata + self._num_instances: Optional[int] = None + self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None + self._lengths_dtype: Optional[NumpyUIntTypes] = None + self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None + self._path_offset_index = path_offset_index + self._seed = seed + + def prepare(self): + if self.fs_local_rank == 0: + log.info("Gathering indices...") + self._write_document_indices() + barrier() + len(self) + + def _get_indices_path(self, path: PathOrStr) -> Path: + sha256_hash = hashlib.sha256() + sha256_hash.update(str(path).encode()) + sha256_hash.update(str(self._get_file_size(path)).encode()) + path_hash = sha256_hash.hexdigest() + return ( + self.work_dir + / "dataset-common" + / f"mixture-instance-indices-{self.sequence_length}-{path_hash}.npy" + ) + + def _write_document_indices(self): + paths_needed: List[Tuple[PathOrStr, int]] = [] + for idx, path in enumerate(self.paths): + indices_path = self._get_indices_path(path) + if indices_path.is_file() and not self._bust_index_cache: + log.info(f"Reusing document indices for '{path}' at:\n'{indices_path}'") + elif path not in paths_needed: + paths_needed.append((path, idx)) + + if paths_needed: + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] + for path, idx in paths_needed: + indices_path = self._get_indices_path(path) + log.info(f"Gathering instance indices for '{path}'...") + # NOTE: We limit the number of instances by total target token count // sequence length + max_instances = ( + self._path_offset_index[(str(path), idx)] // self.sequence_length + ) + future = executor.submit( + run_worker_func, + segment_documents_into_instances, + path, + indices_path, + max_sequence_length=self.sequence_length, + eos_token_id=self.eos_token_id, + dtype=self.dtype, + indices_dtype=self.dtype, + sample=(max_instances, self._seed), + ) + futures.append(future) + + concurrent.futures.wait(futures, return_when="ALL_COMPLETED") + + # Log results. 
+ for path, future in zip([item[0] for item in paths_needed], futures): + _, total_instances = future.result() + log.info( + f"Created {total_instances:,d} instances of sequence length up to " + f"{self.sequence_length} from '{path}'" + ) + + def _get_file_size_and_length( + self, path: PathOrStr, idx: int, dtype: Optional[NumpyUIntTypes] = None + ) -> Tuple[int, int]: + dtype = dtype or self.dtype + item_size = dtype(0).itemsize + file_size = self._get_size_from_offset_index((path, idx)) + if ( + self.max_target_sequence_length is None + or self.max_target_sequence_length == self.sequence_length + ): + return file_size, file_size // (item_size * self.sequence_length) + elif self.max_target_sequence_length > self.sequence_length: + num_max_seq_len_instances = file_size // (item_size * self.max_target_sequence_length) + return ( + file_size, + num_max_seq_len_instances + * (self.max_target_sequence_length // self.sequence_length), + ) + else: + raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'") + + def _get_size_from_offset_index(self, path_index: Tuple[PathOrStr, int]) -> int: + try: + path, idx = path_index + # Get size in bytes from tokens in the supplied index * itemsize + return self._path_offset_index[(str(path), idx)] * self.dtype(0).itemsize + except KeyError: + raise OLMoEnvironmentError(f"Item not found in path index @ {path_index}") + + class NumpyPaddedFSLDataset(NumpyFSLDataset): """ A version of :class:`NumpyFSLDataset` that creates a single instance from each document. @@ -516,7 +668,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16, + dtype: NumpyUIntTypes = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, ): @@ -537,7 +689,8 @@ def offsets(self) -> Tuple[Tuple[int, int], ...]: if self._array_instance_offsets is None: item_size = self.indices_dtype(0).itemsize num_instances_per_path = self.map( - lambda path: get_file_size(self._get_instance_indices_path(path)) // (item_size * 2) + lambda path, _: get_file_size(self._get_instance_indices_path(path)) + // (item_size * 2) ) array_instance_offsets = [] start_offset = 0 @@ -550,7 +703,7 @@ def offsets(self) -> Tuple[Tuple[int, int], ...]: @property def indices_dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> NumpyUIntTypes: return np.uint32 def prepare(self): @@ -945,7 +1098,7 @@ def __init__( max_sequence_length: int, min_sequence_length: int = 256, curriculum: Optional[VSLCurriculum] = None, - dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16, + dtype: NumpyUIntTypes = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, ): @@ -985,9 +1138,7 @@ def __init__( self._curriculum = curriculum or VSLNaturalCurriculum() self._num_instances: Optional[int] = None self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None - self._lengths_dtype: Optional[ - Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] - ] = None + self._lengths_dtype: Optional[NumpyUIntTypes] = None self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None @property @@ -1226,13 +1377,13 @@ def instances_per_bucket(self) -> Tuple[Tuple[int, int], ...]: @property def indices_dtype( self, - ) -> Union[Type[np.uint8], 
Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> NumpyUIntTypes: return np.uint32 @property def lengths_dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> NumpyUIntTypes: if self._lengths_dtype is None: for dtype in ( np.uint8, @@ -1247,39 +1398,6 @@ def lengths_dtype( return self._lengths_dtype -class NumpyDatasetType(StrEnum): - """ - An enumeration of the different :class:`NumpyDatasetBase` implementations. - """ - - fsl = "fsl" - """ - Fixed sequenced length ➡️ :class:`NumpyFSLDataset`. - """ - - padded_fsl = "padded_fsl" - """ - Padded fixed sequence length ➡️ :class:`NumpyPaddedFSLDataset`. - """ - - vsl = "vsl" - """ - Variable sequenced length ➡️ :class:`NumpyVSLDataset`. - """ - - -class NumpyDatasetDType(StrEnum): - uint8 = "uint8" - uint16 = "uint16" - uint32 = "uint32" - uint64 = "uint64" - - def as_np_dtype( - self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: - return getattr(np, str(self)) - - class VSLCurriculumType(StrEnum): """ An enumeration of the different VSL curriculum implementations. @@ -1364,6 +1482,10 @@ class NumpyDatasetConfig(Config): """ The type of dataset. """ + source_mixture_config: Optional[SourceMixtureDatasetConfig] = None + """ + The source mixture dataset config. + """ sequence_length: Optional[int] = None """ The sequence length for a :class:`NumpyFSLDataset`. @@ -1437,6 +1559,10 @@ def validate(self): self.sequence_length = None self.max_target_sequence_length = None + if self.source_mixture_config and self.mix: + # NOTE(tylerm): This could be revisited as I think they could play nicely together. + raise OLMoConfigurationError("Only one of 'source_mixture_config' or 'mix' can be set") + @property def effective_sequence_length(self) -> int: if self.sequence_length is not None: @@ -1480,7 +1606,7 @@ def from_data_mix( def get_dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> NumpyUIntTypes: if self.dtype is not None: return NumpyDatasetDType(self.dtype).as_np_dtype() @@ -1501,8 +1627,10 @@ def build(self) -> NumpyDatasetBase: """ Construct the corresponding :class:`NumpyDatasetBase`. 
""" - if (self.paths is None) == (self.mix is None): - raise OLMoConfigurationError("Exactly one of 'paths' or 'mix' is required") + if (self.paths is None) == (self.mix is None) == (self.source_mixture_config is None): + raise OLMoConfigurationError( + "Exactly one of 'paths' or 'mix' or 'source_mixture' is required" + ) paths: List[str] = [] metadata = self.metadata @@ -1519,6 +1647,8 @@ def build(self) -> NumpyDatasetBase: paths.extend(matches) elif self.paths: paths = self.paths + elif self.source_mixture_config and self.name == NumpyDatasetType.fsl: + log.info("Building dataset from source mixture...") else: assert self.mix is not None if self.mix_base_dir is None: @@ -1555,18 +1685,35 @@ def build(self) -> NumpyDatasetBase: raise OLMoConfigurationError( "'vsl_curriculum' is only a valid field for VSL datasets" ) - dataset = NumpyFSLDataset( - *paths, - sequence_length=self.sequence_length, - max_target_sequence_length=self.max_target_sequence_length, - pad_token_id=self.tokenizer.pad_token_id, - eos_token_id=self.tokenizer.eos_token_id, - vocab_size=self.tokenizer.vocab_size, - dtype=self.get_dtype(), - metadata=metadata, - include_instance_metadata=self.include_instance_metadata, - generate_doc_lengths=self.generate_doc_lengths, - ) + if self.source_mixture_config: + mixture = self.source_mixture_config.build() + return NumpyFSLDatasetMixture( + *mixture.to_paths(), + seed=mixture.seed, + sequence_length=self.sequence_length, + max_target_sequence_length=self.max_target_sequence_length, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + vocab_size=self.tokenizer.vocab_size, + dtype=self.get_dtype(), + metadata=self.metadata, + include_instance_metadata=self.include_instance_metadata, + generate_doc_lengths=self.generate_doc_lengths, + path_offset_index=mixture.to_index(), + ) + else: + dataset = NumpyFSLDataset( + *paths, + sequence_length=self.sequence_length, + max_target_sequence_length=self.max_target_sequence_length, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + vocab_size=self.tokenizer.vocab_size, + dtype=self.get_dtype(), + metadata=metadata, + include_instance_metadata=self.include_instance_metadata, + generate_doc_lengths=self.generate_doc_lengths, + ) elif self.name == NumpyDatasetType.padded_fsl: if self.sequence_length is None: raise OLMoConfigurationError("'sequence_length' is required for padded FSL dataset") diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py new file mode 100644 index 00000000..c5858070 --- /dev/null +++ b/src/olmo_core/data/source_mixture.py @@ -0,0 +1,364 @@ +import logging +import math +import random +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from itertools import chain +from typing import Dict, List, Optional, Tuple + +from rich.console import Console +from rich.progress import Progress +from rich.table import Table + +from olmo_core.aliases import PathOrStr +from olmo_core.config import Config +from olmo_core.data.types import NumpyDatasetDType +from olmo_core.exceptions import OLMoConfigurationError +from olmo_core.io import get_file_size + +__all__ = [ + "SourceMixtureConfig", + "SourceMixtureDataset", + "SourceMixtureDatasetConfig", +] + +log = logging.getLogger(__name__) + + +@dataclass +class SourceMixtureConfig(Config): + """ + A configuration class for building a source mixture. + """ + + source_name: str + """ + The name of the source. 
+ """ + target_ratio: float + """ + The target ratio of the source in the mixture. + """ + paths: List[PathOrStr] + """ + A list of paths to the source data. + """ + max_repetition_ratio: float = 1.0 + """ + The maximum ratio of repetitions of the source data to include in the mixture. + This can be used to upsample the source data by setting the repetition ratio > 1. + """ + max_source_fraction: float = 1.0 + """ + The maximum ratio of the source data to include in the mixture. + """ + + def validate(self): + if self.target_ratio: + if not 0 <= self.target_ratio <= 1: + raise OLMoConfigurationError("target_ratio must be in the range [0, 1]") + if not 0 <= self.max_source_fraction <= 1: + raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]") + if self.max_source_fraction < self.target_ratio: + raise OLMoConfigurationError("max_source_fraction must be >= target_ratio") + + if self.max_repetition_ratio < 1: + raise OLMoConfigurationError("max_repetition_ratio must be >= 1") + + if not self.paths: + raise OLMoConfigurationError("paths must not be empty") + + if not 0 <= self.max_source_fraction <= 1: + raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]") + + +@dataclass +class SourceTokenDetails: + """ + A class to hold intermediate selection details for a mixture source. + """ + + config: SourceMixtureConfig + """ + The configuration object associated with the source. + """ + population: int + """ + The total number of tokens available for the source. + """ + num_selected: int + """ + The number of tokens to select for the source. + """ + + def for_table(self, max_tokens: int) -> Dict: + return { + "source_name": self.config.source_name, + "source_population": f"{self.population:.2e}", + "num_selected": f"{self.num_selected:.2e}", + "target_ratio": str(self.config.target_ratio), + "max_repetion_ratio": str(self.config.max_repetition_ratio), + "max_source_fraction": str(self.config.max_source_fraction), + "observed_source_ratio": f"{(self.num_selected / self.population):.4}", + "observed_global_ratio": f"{(self.num_selected / max_tokens):.4}", + } + + +@dataclass +class SourcePathTokens: + path: PathOrStr + tokens: int + + +@dataclass +class SourceMixtureOutcome: + name: str + """ + The name of the source. + """ + path_tokens: List[SourcePathTokens] + """ + A list of paths and the associated token counts. + """ + + +@dataclass +class SourceMixtureDataset: + """ + A dataset consisting of a fractionalized mixture of data sources. + """ + + seed: int + """ + The seed used to generate the dataset. + """ + sources: List[SourceMixtureOutcome] + """ + A list of sources and the associated paths and token counts. + """ + + def to_index(self) -> Dict[Tuple[str, int], int]: + """ + Convert the dataset to an indexed array of dict((int, path), int). + """ + return { + (str(outcome.path), idx): outcome.tokens + for idx, outcome in enumerate( + list(chain.from_iterable([outcome.path_tokens for outcome in self.sources])) + ) + } + + def to_paths(self) -> List[PathOrStr]: + """ + Convert the dataset to a list of paths while maintaining stable ordering. + """ + return [ + item.path + for item in list(chain.from_iterable([outcome.path_tokens for outcome in self.sources])) + ] + + +@dataclass +class SourceMixtureDatasetConfig(Config): + """ + A configuration class for building a dataset from a fractionalized mixture of sources. + """ + + max_tokens: int + """ + The maximum number of tokens to include in the dataset. 
+ """ + source_configs: List[SourceMixtureConfig] + """ + A list of source configurations. + """ + sequence_length: int + """ + The instance sequence length of the dataset. + """ + dtype: NumpyDatasetDType + """ + The data type of the dataset. + """ + processes: int = 1 + """ + The number of processes to use for counting tokens in parallel. + """ + seed: int = 42 + """ + The seed used to generate the dataset. + """ + + def validate(self): + if self.max_tokens <= 0: + raise OLMoConfigurationError("max_tokens must be > 0") + + if not self.source_configs: + raise OLMoConfigurationError("source_configs must not be empty") + + if (total := sum([source.target_ratio for source in self.source_configs])) != 1.0: + raise OLMoConfigurationError(f"target_ratios must sum to 1, got {total}") + + def build(self) -> SourceMixtureDataset: + self.validate() + random.seed(self.seed) + available_tokens_by_source: Dict[str, int] = {} + + log.info("---------------------------------------------------------") + log.info("Generating a source mixture from configurations:") + log.info(self.source_configs) + + # Count the number of tokens available for each source + for source_config in self.source_configs: + log.info(f"Counting tokens for source: {source_config.source_name}") + available_tokens_by_source[source_config.source_name] = self._count_tokens_for_paths( + paths=source_config.paths, source=source_config.source_name + ) + + tokens_details_by_source: List[SourceTokenDetails] = [] + + # Calculate the number of tokens available and to include for each source + for source_config in self.source_configs: + num_for_source = available_tokens_by_source[source_config.source_name] + needed_for_source = int(self.max_tokens * source_config.target_ratio) + max_for_source = int( + (num_for_source * source_config.max_source_fraction) + * source_config.max_repetition_ratio + ) + + # Ensure that the max tokens for a source meet the target ratio requirement + if max_for_source < needed_for_source: + raise OLMoConfigurationError( + f"Insufficient tokens for source: {source_config.source_name} @ target global ratio: {source_config.target_ratio} :: {max_for_source} < {needed_for_source}" + ) + + tokens_details_by_source.append( + SourceTokenDetails( + config=source_config, + population=num_for_source, + num_selected=needed_for_source, + ) + ) + + completed: List[SourceMixtureOutcome] = [] + for source in tokens_details_by_source: + completed.append( + SourceMixtureOutcome( + name=source.config.source_name, + path_tokens=self.get_paths_and_tokens_for_source( + source_config=source.config, + token_details=source, + ), + ) + ) + + self.render_mixture_outcome_tables(tokens_details_by_source) + + for outcome in completed: + for item in outcome.path_tokens: + log.info(f"Selected {item.tokens} tokens from {outcome.name} at {item.path}") + + return SourceMixtureDataset(seed=self.seed, sources=completed) + + def get_paths_and_tokens_for_source( + self, source_config: SourceMixtureConfig, token_details: SourceTokenDetails + ) -> List[SourcePathTokens]: + """ + Get the paths and resulting token count for a source. 
+ """ + take_ratio = token_details.num_selected / token_details.population + path_tokens = [] + + # When we need more than 1 repetition of the source data we have a take ration > 1 + if take_ratio > 1: + take_ratios = [] + remaining = take_ratio + + while remaining > 0: + chunk = min(1.0, remaining) + take_ratios.append(chunk) + remaining -= chunk + + for ratio in take_ratios: + for path in source_config.paths: + tokens_to_keep = int(math.ceil(self._count_tokens_for_file(path) * ratio)) + path_tokens.append(SourcePathTokens(path=path, tokens=tokens_to_keep)) + + return path_tokens + + for path in source_config.paths: + tokens_to_keep = int(math.ceil(self._count_tokens_for_file(path) * take_ratio)) + path_tokens.append(SourcePathTokens(path=path, tokens=tokens_to_keep)) + + return path_tokens + + def _count_tokens_for_paths(self, paths: List[PathOrStr], source: Optional[str]) -> int: + """ + Count the number of tokens for a set of source files in parallel. + + Args: + source_config: The source configuration. + dtype: The data type of the source tokens. + """ + + with ThreadPoolExecutor(max_workers=self.processes) as executor: + futures = [] + for path in paths: + futures.append(executor.submit(self._count_tokens_for_file, path)) + + with Progress() as progress: + results = [] + task = progress.add_task( + f"Counting available tokens for source: {source}", total=len(futures) + ) + for future in as_completed(futures): + progress.update(task, advance=1) + results.append(future.result()) + + return sum(results) + + def _count_tokens_for_file(self, path: PathOrStr) -> int: + return self._bytes_to_tokens(get_file_size(path), self.dtype) + + def _bytes_to_tokens(self, num_bytes: int, dtype: NumpyDatasetDType) -> int: + """ + Convert bytes to tokens based on the dtype. + """ + npdtype = dtype.as_np_dtype() + return num_bytes // npdtype(int(0)).itemsize + + def render_mixture_outcome_tables(self, results: List[SourceTokenDetails]) -> None: + """ + Render tables enumerating the global and per-source mixture outcomes. + """ + + console = Console() + + source_rows = [item.for_table(self.max_tokens) for item in results] + source_headers = source_rows[0].keys() + + source_table = Table(title="Outcome by source") + for header in source_headers: + source_table.add_column(header) + + for row in source_rows: + source_table.add_row(*[row[header] for header in source_headers]) + + console.print(source_table) + + total_tokens = sum([item.population for item in results]) + selected_tokens = sum([item.num_selected for item in results]) + observed_global_ratio = f"{(selected_tokens / total_tokens):.4}" + + global_table = Table(title="Global outcome") + global_headers = [ + "total_tokens", + "selected_tokens", + "observed_global_ratio", + ] + + for header in global_headers: + global_table.add_column(header) + + global_table.add_row(f"{total_tokens:.2e}", f"{selected_tokens:.2e}", observed_global_ratio) + console.print(global_table) diff --git a/src/olmo_core/data/types.py b/src/olmo_core/data/types.py new file mode 100644 index 00000000..d08571f2 --- /dev/null +++ b/src/olmo_core/data/types.py @@ -0,0 +1,40 @@ +from typing import Type, Union + +import numpy as np + +from olmo_core.config import StrEnum + +NumpyUIntTypes = Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] + + +class NumpyDatasetType(StrEnum): + """ + An enumeration of the different :class:`NumpyDatasetBase` implementations. + """ + + fsl = "fsl" + """ + Fixed sequenced length ➡️ :class:`NumpyFSLDataset`. 
+ """ + + padded_fsl = "padded_fsl" + """ + Padded fixed sequence length ➡️ :class:`NumpyPaddedFSLDataset`. + """ + + vsl = "vsl" + """ + Variable sequenced length ➡️ :class:`NumpyVSLDataset`. + """ + + +class NumpyDatasetDType(StrEnum): + uint8 = "uint8" + uint16 = "uint16" + uint32 = "uint32" + uint64 = "uint64" + + def as_np_dtype( + self, + ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + return getattr(np, str(self)) diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index 2a5293db..dd44d3af 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -1,6 +1,7 @@ import gzip import math import os +import random from contextlib import contextmanager from pathlib import Path from typing import ( @@ -328,7 +329,7 @@ def memmap_to_write( file until the context exists successfully. """ path.parent.mkdir(exist_ok=True, parents=True) - tmp_path = path.with_suffix(".npy.tmp") + tmp_path = path.with_suffix(f".{random.randint(0,2**32)}.npy.tmp") mmap = np.memmap(tmp_path, dtype=dtype, mode="w+", shape=shape) try: yield mmap @@ -411,23 +412,33 @@ def segment_documents_into_instances( indices_dtype: Union[ Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64] ] = np.uint32, + sample: Optional[Tuple[int, int]] = None, ) -> Tuple[int, int]: """ Segment documents into instances of at most ``sequence_length`` tokens. Saving the indices of the instances to ``target``. + Sample a subset of the instances if ``sample`` is provided as a tuple of ``(max_instances, seed)``. + Returns the number of original documents and the number of resulting instances documents. """ total_og_docs = 0 - indices = [] - for start_idx, end_idx in iter_document_indices(path, eos_token_id=eos_token_id, dtype=dtype): - total_og_docs += 1 - length = end_idx - start_idx - indices.append(start_idx) - indices.append(start_idx + min(length, max_sequence_length)) - start_idx += length + idx_gen = ( + idx + for start_idx, end_idx in iter_document_indices( + path, eos_token_id=eos_token_id, dtype=dtype + ) + for idx in (start_idx, start_idx + min(end_idx - start_idx, max_sequence_length)) + ) + indices = np.fromiter(idx_gen, dtype=indices_dtype) + total_og_docs = len(indices) // 2 - with memmap_to_write(target, dtype=indices_dtype, shape=(len(indices),)) as indices_mmap: + if sample is not None: + max_instances, seed = sample + rng = get_rng(seed) + indices = rng.choice(indices.reshape(-1, 2), size=max_instances).reshape(-1) + + with memmap_to_write(target, dtype=indices_dtype, shape=(indices.size,)) as indices_mmap: indices_mmap[:] = indices return total_og_docs, len(indices) // 2 diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index efe15a4f..bc7f829d 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -180,9 +180,11 @@ def build_common_components( vsl_curriculum=VSLCurriculumConfig( name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False ), - work_dir=None - if is_url(root_dir) - else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache", + work_dir=( + None + if is_url(root_dir) + else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache" + ), ) data_loader_config = NumpyDataLoaderConfig( @@ -202,9 +204,11 @@ def build_common_components( mix_base_dir=root_dir, sequence_length=dataset_config.effective_sequence_length, tokenizer=tokenizer_config, - work_dir=None - if is_url(root_dir) - else 
f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache", + work_dir=( + None + if is_url(root_dir) + else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache" + ), ), eval_interval=1000, ), diff --git a/src/test/data/fixtures.py b/src/test/data/fixtures.py new file mode 100644 index 00000000..4fcaa84d --- /dev/null +++ b/src/test/data/fixtures.py @@ -0,0 +1,66 @@ +from pathlib import Path +from typing import Type, Union + +import numpy as np + +from olmo_core.data import NumpyDatasetBase, NumpyDatasetConfig, TokenizerConfig +from olmo_core.data.source_mixture import ( + SourceMixtureConfig, + SourceMixtureDatasetConfig, +) +from olmo_core.data.types import NumpyDatasetDType + +from ..utils import mk_mmaps + + +def get_fsl_mixture( + tmp_path: Path, + dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint32, + seed: int = 42, + sequence_length: int = 4, + num_tokens: int = 20 * 1000, + eos: int = 0, +) -> NumpyDatasetBase: + seed = 42 + mmap1 = mk_mmaps( + tmp_path, "mmap1", 1, num_tokens * 2, dtype, eos=eos, seed=seed, seq_length=sequence_length + ) + mmap2 = mk_mmaps( + tmp_path, "mmap2", 1, num_tokens * 2, dtype, eos=eos, seed=seed, seq_length=sequence_length + ) + + tokenizer = TokenizerConfig( + vocab_size=32_000, + eos_token_id=eos, + pad_token_id=-1, + ) + + mixture_config = SourceMixtureDatasetConfig( + max_tokens=num_tokens, + sequence_length=sequence_length, + source_configs=[ + SourceMixtureConfig( + source_name="mmap1", + paths=[i[0] for i in mmap1], + target_ratio=0.8, + ), + SourceMixtureConfig( + source_name="mmap2", + paths=[i[0] for i in mmap2], + target_ratio=0.2, + ), + ], + dtype=NumpyDatasetDType.uint16, + processes=1, + seed=seed, + ) + + ds = NumpyDatasetConfig( + source_mixture_config=mixture_config, + sequence_length=sequence_length, + tokenizer=tokenizer, + include_instance_metadata=False, + ).build() + ds.prepare() + + return ds diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index 19e0bd19..e8c49952 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -10,8 +10,15 @@ NumpyVSLDataset, TokenizerConfig, ) +from olmo_core.data.source_mixture import ( + SourceMixtureConfig, + SourceMixtureDatasetConfig, +) +from olmo_core.data.types import NumpyDatasetDType from olmo_core.data.utils import get_document_indices, write_document_indices +from ..utils import mk_mmaps + def test_numpy_fsl_dataset(tmp_path: Path): mmap1 = np.memmap(tmp_path / "mmap1.npy", mode="w+", dtype=np.uint16, shape=(16,)) @@ -65,6 +72,120 @@ def test_numpy_padded_fsl_dataset(tmp_path: Path): assert len(ds) == 4 +def test_numpy_fsl_mixture_dataset(tmp_path: Path): + # NOTE: At small token counts the take_ratio can be finicky so we test at small but real world-ish scale + npdtype = np.uint16 + seed = 42 + mmap1 = mk_mmaps(tmp_path, "mmap1", 1, 20 * 1000, npdtype, eos=0, seed=seed) + mmap2 = mk_mmaps(tmp_path, "mmap2", 1, 20 * 1000, npdtype, eos=0, seed=seed) + + sequence_length = 4 + tokenizer = TokenizerConfig( + vocab_size=32_000, + eos_token_id=0, + pad_token_id=-1, + ) + + mixture_config = SourceMixtureDatasetConfig( + max_tokens=10_000, + sequence_length=sequence_length, + source_configs=[ + SourceMixtureConfig( + source_name="mmap1", + paths=[i[0] for i in mmap1], + target_ratio=0.8, + ), + SourceMixtureConfig( + source_name="mmap2", + paths=[i[0] for i in mmap2], + target_ratio=0.2, + ), + ], + dtype=NumpyDatasetDType.uint16, + processes=1, + seed=seed, + ) + + ds = 
NumpyDatasetConfig( + source_mixture_config=mixture_config, + sequence_length=sequence_length, + tokenizer=tokenizer, + include_instance_metadata=False, + ).build() + ds.prepare() + + expected = "68144f" + assert ds.fingerprint.endswith( + expected + ), f"Fingerprint mismatch, expected {expected}, got {ds.fingerprint[-6:]}...Do you need to update expected fingerprint?" + assert ds[0]["input_ids"].tolist() == [ + 56423, + 24546, + 15796, + 52203, + ] # stable because we pass a seed + assert ds.num_tokens == 10000 + assert len(ds) == 2500 + + +def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path): + # NOTE: At small token counts the take_ratio can be finicky so we test at small but real world-ish scale + npdtype = np.uint16 + seed = 42 + mmap1 = mk_mmaps(tmp_path, "mmap1", 1, 10 * 1000, npdtype, eos=0, seed=seed) + mmap2 = mk_mmaps(tmp_path, "mmap2", 1, 20 * 1000, npdtype, eos=0, seed=seed) + + sequence_length = 4 + tokenizer = TokenizerConfig( + vocab_size=32_000, + eos_token_id=0, + pad_token_id=-1, + ) + + source1_paths = [i[0] for i in mmap1] * 2 # duplicate the paths + + mixture_config = SourceMixtureDatasetConfig( + max_tokens=10_000, + sequence_length=sequence_length, + source_configs=[ + SourceMixtureConfig( + source_name="mmap1", + paths=source1_paths, + target_ratio=0.8, + ), + SourceMixtureConfig( + source_name="mmap2", + paths=[i[0] for i in mmap2], + target_ratio=0.2, + ), + ], + dtype=NumpyDatasetDType.uint16, + processes=1, + seed=seed, + ) + + ds = NumpyDatasetConfig( + source_mixture_config=mixture_config, + sequence_length=sequence_length, + tokenizer=tokenizer, + include_instance_metadata=False, + ).build() + ds.prepare() + + expected = "190cd0" + assert ds.fingerprint.endswith( + expected + ), f"Fingerprint mismatch, expected {expected}, got {ds.fingerprint[-6:]}...Do you need to update expected fingerprint?" 
+ assert ds[0]["input_ids"].tolist() == [ + 56423, + 24546, + 15796, + 52203, + ] # stable because we pass a seed + assert ds.num_tokens == 10000 + assert len(ds) == 2500 + + def write_data_file(data: List[int], path: Path, dtype, eos_token_id: int): path.parent.mkdir(exist_ok=True, parents=True) mmap = np.memmap(path, mode="w+", dtype=dtype, shape=(len(data),)) diff --git a/src/test/data/source_mixture_test.py b/src/test/data/source_mixture_test.py new file mode 100644 index 00000000..264ff52f --- /dev/null +++ b/src/test/data/source_mixture_test.py @@ -0,0 +1,312 @@ +import logging +from itertools import chain +from pathlib import Path + +import pytest + +from olmo_core.data import NumpyDatasetDType +from olmo_core.data.source_mixture import ( + SourceMixtureConfig, + SourceMixtureDataset, + SourceMixtureDatasetConfig, +) +from olmo_core.exceptions import OLMoConfigurationError + +from ..utils import mk_mmaps + + +def test_source_mixture_config(tmp_path: Path, caplog, capsys): + source_paths = { + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + } + + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.33, + paths=[i[0] for i in source_paths["1"]], + ), + SourceMixtureConfig( + source_name="2", target_ratio=0.33, paths=[i[0] for i in source_paths["2"]] + ), + SourceMixtureConfig( + source_name="3", + target_ratio=0.34, + paths=[i[0] for i in source_paths["3"]], + ), + ] + + max_tokens = 5_000_000 + + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + # NOTE: We need to disable capsys so we can override log capture as + # we want to see the rendered tables in the case + with capsys.disabled(), caplog.at_level(logging.DEBUG): + config.validate() + mixture = config.build() + print(caplog.text) + assert isinstance(mixture, SourceMixtureDataset) + + +def test_source_mixture_config_validation(): + with pytest.raises(OLMoConfigurationError): + SourceMixtureConfig( + source_name="source1", target_ratio=1.2, paths=["/path/to/source1"] + ).validate() + + with pytest.raises(OLMoConfigurationError): + SourceMixtureConfig( + source_name="source1", + target_ratio=0.5, + max_source_fraction=0.4, + paths=["/path/to/source1"], + ).validate() + + with pytest.raises(OLMoConfigurationError): + SourceMixtureConfig(source_name="source1", target_ratio=0.5, paths=[]).validate() + + config = SourceMixtureConfig( + source_name="source1", target_ratio=0.5, paths=["/path/to/source1"] + ) + config.validate() + + +def test_dataset_mixture_config_validation(): + source_configs = [ + SourceMixtureConfig(source_name="source1", target_ratio=0.5, paths=["/path/to/source1"]), + SourceMixtureConfig(source_name="source2", target_ratio=0.5, paths=["/path/to/source2"]), + ] + + config = SourceMixtureDatasetConfig( + max_tokens=1000, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + config.validate() + + source_configs_invalid = [ + SourceMixtureConfig(source_name="source1", target_ratio=0.7, paths=["/path/to/source1"]), + SourceMixtureConfig(source_name="source2", target_ratio=0.5, paths=["/path/to/source2"]), + ] + + config_invalid = SourceMixtureDatasetConfig( + max_tokens=1000, + source_configs=source_configs_invalid, + 
dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + with pytest.raises(OLMoConfigurationError): + config_invalid.validate() + + +def test_dataset_mixture_build(tmp_path: Path): + source_paths = { + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + } + + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.33, + paths=[i[0] for i in source_paths["1"]], + ), + SourceMixtureConfig( + source_name="2", target_ratio=0.33, paths=[i[0] for i in source_paths["2"]] + ), + SourceMixtureConfig( + source_name="3", + target_ratio=0.34, + paths=[i[0] for i in source_paths["3"]], + ), + ] + + max_tokens = 5_000_000 + + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) + + +def test_dataset_mixture_build_insufficient_source_data(tmp_path: Path): + source_paths = { + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + } + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.5, + paths=[i[0] for i in source_paths["1"]], + ), + SourceMixtureConfig( + source_name="2", target_ratio=0.25, paths=[i[0] for i in source_paths["2"]] + ), + SourceMixtureConfig( + source_name="3", + target_ratio=0.25, + paths=[i[0] for i in source_paths["3"]], + ), + ] + + max_tokens = 5_000_000 + + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + # Should raise exception because the target ratio for source 1 @50% (2.5M) is infeasible without repetition (default max_repetition_ratio=1) + with pytest.raises(OLMoConfigurationError): + config.build() + + +def test_dataset_mixture_build_with_repetition(tmp_path: Path): + """ + Test building a dataset with repetition of a source. + + Source 1 has a target ratio of 90% and a max repetition ratio of 4.0, so it should be possible to meet the target of 3600 tokens with 1 file of 1000 tokens repeated 4 times. 
+ """ + source_paths = { + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + } + + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.5, + max_repetition_ratio=3.0, # Allow 3x repetition of source1 so that we can meet the target of 2.5M + paths=[i[0] for i in source_paths["1"]], + ), + SourceMixtureConfig( + source_name="2", target_ratio=0.25, paths=[i[0] for i in source_paths["2"]] + ), + SourceMixtureConfig( + source_name="3", + target_ratio=0.25, + paths=[i[0] for i in source_paths["3"]], + ), + ] + + max_tokens = 5_000_000 + + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + mixture = config.build() + sources = [source for source in mixture.sources] + all_paths = [] + for source in sources: + all_paths.extend([item for item in source.path_tokens]) + + total_tokens = sum([item.tokens for item in all_paths]) + assert isinstance(mixture, SourceMixtureDataset) + assert total_tokens == 5_000_000 + + +def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): + source_paths = { + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + } + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.25, + paths=[i[0] for i in source_paths["1"]], + max_source_fraction=0.10, # Allow only 10% of source1 to be used (population is 1M tokens) + ), + SourceMixtureConfig( + source_name="2", + target_ratio=0.25, + paths=[i[0] for i in source_paths["2"]], + ), + SourceMixtureConfig( + source_name="3", + target_ratio=0.5, + paths=[i[0] for i in source_paths["3"]], + ), + ] + + # 5 source files * 1_000_000 tokens per file + max_tokens = len(list(chain(*source_paths.values()))) * 1_000_000 + + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + # Should raise exception because the target ratio for source 1 is infeasible because + # we limit usage to 10% of the source + with pytest.raises(OLMoConfigurationError): + config.build() + + +def test_dataset_mixture_build_duplicate_paths(tmp_path: Path): + sources = { + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=500_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + } + + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.33, # 990k tokens + max_repetition_ratio=2.0, + paths=[sources["1"][0][0], sources["1"][0][0]], # Duplicate the 1 path for source 1 + ), + SourceMixtureConfig(source_name="2", target_ratio=0.33, paths=[i[0] for i in sources["2"]]), + SourceMixtureConfig( + source_name="3", + target_ratio=0.34, + paths=[i[0] for i in sources["3"]], + ), + ] + + max_tokens = 3_000_000 + + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + expected = [sources["1"][0][0]] + [item[0] for item in 
list(chain(*sources.values()))] + mixture = config.build() + index = mixture.to_index() + paths = mixture.to_paths() + assert paths == expected + assert len(index) == 6 + assert isinstance(mixture, SourceMixtureDataset) + assert len(mixture.sources) == 3 diff --git a/src/test/data/utils_test.py b/src/test/data/utils_test.py index 4d678c09..caef70ba 100644 --- a/src/test/data/utils_test.py +++ b/src/test/data/utils_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch from olmo_core.data.utils import ( @@ -8,10 +9,41 @@ iter_batched, iter_document_indices, melt_batch, + segment_documents_into_instances, write_document_indices, ) +@pytest.mark.limit_memory("245 KB") +def test_segment_documents_into_instances(tmp_path): + data = [1, 2, 3, 4, 0, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 0] * 1000 + data_path = tmp_path / "data.npy" + max_sequence_length = 4 + mmap = np.memmap(data_path, mode="w+", dtype=np.uint16, shape=(len(data),)) + indices_path = tmp_path / "indices.npy" + mmap[:] = data + mmap.flush() + + eos = 0 + dtype = np.uint16 + sample = (2, 42) + + results = [] + for _ in range(10): + results.append( + segment_documents_into_instances( + path=data_path, + target=indices_path, + max_sequence_length=max_sequence_length, + eos_token_id=eos, + dtype=dtype, + sample=sample, + ) + ) + + assert all([r[1] == 2 for r in results]) + + def test_iter_document_indices(tmp_path): data = [1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 0] data_path = tmp_path / "data.npy" diff --git a/src/test/utils.py b/src/test/utils.py index 4b04415e..14a622aa 100644 --- a/src/test/utils.py +++ b/src/test/utils.py @@ -1,3 +1,8 @@ +from os import PathLike +from pathlib import Path +from typing import Any, List, Tuple, Type, Union + +import numpy as np import pytest import torch @@ -89,3 +94,32 @@ def get_default_device(): return torch.device("cuda") else: return torch.device("cpu") + + +Mmaps = List[Tuple[Union[Path, PathLike[Any], str], Any]] + + +def mk_mmaps( + tmp_path: Path, + prefix: str, + num_files: int, + size: int, + dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint32, + eos: int = 0, + seq_length: int = 4, + seed: int = 42, +) -> Mmaps: + mmaps: Mmaps = [] + for i in range(num_files): + filepath = f"{tmp_path}/{prefix}_{i}.npy" + np.random.seed(seed) + data = np.random.randint(1, np.iinfo(dtype).max, size=size, dtype=dtype) + data = np.append( + np.insert(data, np.arange(seq_length + 1, len(data), seq_length), eos), eos + ) + mm = np.memmap(filepath, mode="w+", dtype=dtype, shape=(len(data),)) + mm[:] = data + mm.flush() + mmaps.append((Path(filepath), data)) + + return mmaps
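
For reference, a minimal usage sketch of the new source-mixture API introduced in this patch (not part of the diff itself), pieced together from the fixtures and tests above. The source names, file paths, and token budget below are hypothetical placeholders; the classes, fields, and call pattern follow the code added here.

from olmo_core.data import NumpyDatasetConfig, TokenizerConfig
from olmo_core.data.source_mixture import SourceMixtureConfig, SourceMixtureDatasetConfig
from olmo_core.data.types import NumpyDatasetDType

# Hypothetical pre-tokenized .npy files of uint16 token IDs.
web_paths = ["/data/web/part-00.npy", "/data/web/part-01.npy"]
code_paths = ["/data/code/part-00.npy"]

tokenizer = TokenizerConfig(vocab_size=32_000, eos_token_id=0, pad_token_id=-1)

# Describe the mixture: 80% web, 20% code, capped at a 10M-token budget.
mixture_config = SourceMixtureDatasetConfig(
    max_tokens=10_000_000,
    sequence_length=1024,
    source_configs=[
        SourceMixtureConfig(source_name="web", target_ratio=0.8, paths=web_paths),
        SourceMixtureConfig(source_name="code", target_ratio=0.2, paths=code_paths),
    ],
    dtype=NumpyDatasetDType.uint16,
    processes=4,
    seed=42,
)

# NumpyDatasetConfig.build() routes a source_mixture_config to NumpyFSLDatasetMixture
# (FSL datasets only); the mixture's to_paths()/to_index() determine how many instances
# are retained from each file.
dataset = NumpyDatasetConfig(
    source_mixture_config=mixture_config,
    sequence_length=1024,
    tokenizer=tokenizer,
    include_instance_metadata=False,
).build()
dataset.prepare()  # gathers and writes the sampled instance indices on fs_local_rank 0

Under the hood, SourceMixtureDatasetConfig.build() counts the tokens available per source, checks feasibility against target_ratio, max_source_fraction, and max_repetition_ratio, and returns a SourceMixtureDataset whose per-path token counts seed the NumpyFSLDatasetMixture offsets.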