From f7f370930045a7b4a4b2ddf8663d6f79a156ad02 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 26 Sep 2024 13:11:25 -0700
Subject: [PATCH] Add some more `Config` methods

---
 CHANGELOG.md                         |  4 ++
 README.md                            | 21 ++++++-----
 src/olmo_core/config.py              | 55 ++++++++++++++++++++++++++--
 src/olmo_core/data/numpy_dataset.py  |  9 +++++
 src/olmo_core/internal/experiment.py |  7 ++++
 5 files changed, 84 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb56709e..49591d31 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- Added `Config.validate()`, `Config.replace()`, and `Config.apply()` methods.
+
 ### Fixed
 
 - Ensure additional cached-path clients are added in the process pool workers from some dataset preparation methods.
diff --git a/README.md b/README.md
index 4bd5c3cd..aeece343 100644
--- a/README.md
+++ b/README.md
@@ -22,15 +22,18 @@ pip install ai2-olmo-core
 
 ## Official training scripts
 
 Official training scripts for various model sizes can be found in [`src/scripts/train/`](https://github.com/allenai/OLMo-core/tree/main/src/scripts/train).
-Throughput numbers from a cluster with NVIDIA H100 GPUs are reported below.
-
-| Model size | Context length | Precision | Throughput[^1] | Launch command |
-| :--------: | :------------: | :-------: | -----------: | :------ |
-| 1B | 4K | BF16 | 44,000 TPS | `python src/scripts/train/OLMo-1B.py launch "${run_name}" "${cluster}"` |
-| | | FP8 | 51,000 TPS | `python src/scripts/train/OLMo-1B.py launch "${run_name}" "${cluster}" --model.float8_config.enabled=true` |
-| 7B | 4K | BF16 | 10,000 TPS | `python src/scripts/train/OLMo-7B.py launch "${run_name}" "${cluster}"` |
-| | | FP8 | 13,000 TPS | `python src/scripts/train/OLMo-7B.py launch "${run_name}" "${cluster}" --model.float8_config.enabled=true` |
-| 13B | 4K | BF16 | 4,600 TPS | `python src/scripts/train/OLMo-13B.py launch "${run_name}" "${cluster}"` |
+To see the exact usage for each script, run the script without any arguments.
+
+Throughput numbers from these scripts with various configuration settings are reported below, measured on a cluster with NVIDIA H100 GPUs.
+
+| Model size | Context length | Precision | Throughput[^1] | Train script | Overrides |
+| :--------: | :------------: | :-------: | -----------: | :----------- | :-------- |
+| 1B | 4K | BF16 | 44,000 TPS | `OLMo-1B.py` | |
+| 1B | 256-8192 VSL | BF16 | 49,000 TPS | `OLMo-1B.py` | `--dataset.name=vsl` |
+| | | FP8 | 51,000 TPS | `OLMo-1B.py` | `--model.float8_config.enabled=true` |
+| 7B | 4K | BF16 | 10,000 TPS | `OLMo-7B.py` | |
+| | | FP8 | 13,000 TPS | `OLMo-7B.py` | `--model.float8_config.enabled=true` |
+| 13B | 4K | BF16 | 4,600 TPS | `OLMo-13B.py` | |
 
 [^1]: Throughput reported in tokens per second per device.
diff --git a/src/olmo_core/config.py b/src/olmo_core/config.py
index 103619a6..3e02b556 100644
--- a/src/olmo_core/config.py
+++ b/src/olmo_core/config.py
@@ -1,6 +1,17 @@
-from dataclasses import dataclass, fields, is_dataclass
+from dataclasses import dataclass, fields, is_dataclass, replace
 from enum import Enum
-from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    cast,
+)
 
 import torch
 from omegaconf import OmegaConf as om
@@ -117,6 +128,36 @@ def as_config_dict(self) -> Dict[str, Any]:
             recurse=True,
         )
 
+    def apply(self, func: Callable[["Config"], None]):
+        """
+        Recursively apply a function to every config instance field, including ``self``.
+
+        :param func: The function to apply.
+        """
+
+        def apply(d):
+            if isinstance(d, Config):
+                func(d)
+
+            if is_dataclass(d):
+                for field in fields(d):
+                    value = getattr(d, field.name)
+                    apply(value)
+            elif isinstance(d, dict):
+                for value in d.values():
+                    apply(value)
+            elif isinstance(d, (list, tuple, set)):
+                for x in d:
+                    apply(x)
+
+        apply(self)
+
+    def validate(self):
+        """
+        Validate fields in ``self``. This may modify ``self`` in-place.
+        """
+        pass
+
     def merge(self, dotlist: List[str]) -> Self:
         """
         Merge self with fields from a "dotlist", creating a new object.
@@ -126,10 +167,18 @@ def merge(self, dotlist: List[str]) -> Self:
         try:
             merge_fields = om.from_dotlist(_clean_opts(dotlist))
             merged = om.merge(self, merge_fields)
-            return cast(Self, om.to_object(merged))
+            out = cast(Self, om.to_object(merged))
+            out.apply(lambda c: c.validate())
+            return out
         except OmegaConfBaseException as e:
             raise OLMoConfigurationError(str(e))
 
+    def replace(self, **changes) -> Self:
+        """
+        Creates a new object of the same type, replacing fields with values from ``changes``.
+        """
+        return replace(self, **changes)
+
     @classmethod
     def from_dict(cls: Type[C], data: Dict[str, Any], overrides: Optional[List[str]] = None) -> C:
         """
diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py
index a9a0132a..522c97d0 100644
--- a/src/olmo_core/data/numpy_dataset.py
+++ b/src/olmo_core/data/numpy_dataset.py
@@ -1389,6 +1389,15 @@ class NumpyDatasetConfig(Config):
         all of your runs.
     """
 
+    def validate(self):
+        if self.name in (NumpyDatasetType.fsl, NumpyDatasetType.padded_fsl):
+            self.max_sequence_length = None
+            self.min_sequence_length = None
+            self.vsl_curriculum = None
+        elif self.name == NumpyDatasetType.vsl:
+            self.sequence_length = None
+            self.max_target_sequence_length = None
+
     @property
     def effective_sequence_length(self) -> int:
         if self.sequence_length is not None:
diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py
index a7ea94ad..d6b36c0f 100644
--- a/src/olmo_core/internal/experiment.py
+++ b/src/olmo_core/internal/experiment.py
@@ -12,6 +12,8 @@
     NumpyDatasetConfig,
     NumpyDatasetType,
     TokenizerConfig,
+    VSLCurriculumConfig,
+    VSLCurriculumType,
 )
 from olmo_core.distributed.utils import get_num_nodes, init_hybrid_shard_mesh
 from olmo_core.float8 import Float8Config
@@ -163,6 +165,11 @@ def build_common_components(
         mix_base_dir=root_dir,
         sequence_length=4096,
         max_target_sequence_length=8192,
+        min_sequence_length=256,
+        max_sequence_length=8192,
+        vsl_curriculum=VSLCurriculumConfig(
+            name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False
+        ),
         work_dir=None
         if is_url(root_dir)
         else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache",
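
For reviewers, here is a minimal sketch of how the new `Config` methods compose. The `ExperimentConfig` and `OptimizerConfig` dataclasses below are hypothetical, invented only for illustration; the pieces that come from this patch are `Config`, its dotlist `merge()`, and the new `apply()`, `validate()`, and `replace()` hooks.

```python
from dataclasses import dataclass, field

from olmo_core.config import Config


@dataclass
class OptimizerConfig(Config):
    """Hypothetical nested config, for illustration only."""

    lr: float = 1e-3

    def validate(self):
        # `merge()` now calls `validate()` on every nested config via `apply()`.
        if self.lr <= 0:
            raise ValueError("lr must be positive")


@dataclass
class ExperimentConfig(Config):
    """Hypothetical top-level config, for illustration only."""

    run_name: str = "default"
    optim: OptimizerConfig = field(default_factory=OptimizerConfig)


config = ExperimentConfig()

# `replace()` is a thin wrapper around `dataclasses.replace()`: it returns
# a new instance of the same type with the given fields swapped out.
warmup = config.replace(run_name="warmup")

# `apply()` walks `self` plus every nested dataclass, dict, list, tuple,
# and set, calling the function on each `Config` instance it finds.
# Prints "ExperimentConfig" then "OptimizerConfig".
config.apply(lambda c: print(type(c).__name__))

# `merge()` still folds in dotlist overrides, and now also runs
# `validate()` recursively on the merged result.
merged = config.merge(["optim.lr=0.01"])
assert merged.optim.lr == 0.01
```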
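The `--dataset.name=vsl` override in the README table relies on the new `NumpyDatasetConfig.validate()` hook: rather than rejecting a config that still carries FSL-only settings, it clears the fields that don't apply to the chosen dataset type. A sketch of that behavior, assuming `NumpyDatasetConfig` accepts these constructor arguments (the paths and tokenizer values are placeholders):

```python
from olmo_core.data import NumpyDatasetConfig, NumpyDatasetType, TokenizerConfig

# Placeholder data path and tokenizer, for illustration only.
dataset = NumpyDatasetConfig(
    paths=["/tmp/data/part-0-00000.npy"],
    tokenizer=TokenizerConfig.dolma2(),
    name=NumpyDatasetType.fsl,
    sequence_length=4096,
)

# `merge()` parses the dotlist override, then runs `validate()` on the
# result. For a VSL dataset, `validate()` nulls out `sequence_length` and
# `max_target_sequence_length`, since those only apply to FSL datasets.
dataset = dataset.merge(["name=vsl"])
assert dataset.sequence_length is None
```

The design choice here is that `validate()` resolves conflicting fields in place instead of raising, so a single CLI override like `--dataset.name=vsl` works without the user also having to unset every FSL-specific field.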