Add some more Config methods
epwalsh committed Sep 26, 2024
1 parent 916ecf9 commit f7f3709
Showing 5 changed files with 84 additions and 12 deletions.
CHANGELOG.md (4 changes: 4 additions & 0 deletions)

@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

## Unreleased

+### Added
+
+- Added `Config.validate()`, `Config.replace()`, and `Config.apply()` methods.
+
### Fixed

- Ensure additional cached-path clients are added in the process pool workers used by some dataset preparation methods.
README.md (21 changes: 12 additions & 9 deletions)

@@ -22,15 +22,18 @@ pip install ai2-olmo-core
## Official training scripts

Official training scripts for various model sizes can be found in [`src/scripts/train/`](https://github.com/allenai/OLMo-core/tree/main/src/scripts/train).
-Throughput numbers from a cluster with NVIDIA H100 GPUs are reported below.
-
-| Model size | Context length | Precision | Throughput[^1] | Launch command |
-| :--------: | :------------: | :-------: | -----------: | :------ |
-| 1B | 4K | BF16 | 44,000 TPS | `python src/scripts/train/OLMo-1B.py launch "${run_name}" "${cluster}"` |
-| | | FP8 | 51,000 TPS | `python src/scripts/train/OLMo-1B.py launch "${run_name}" "${cluster}" --model.float8_config.enabled=true` |
-| 7B | 4K | BF16 | 10,000 TPS | `python src/scripts/train/OLMo-7B.py launch "${run_name}" "${cluster}"` |
-| | | FP8 | 13,000 TPS | `python src/scripts/train/OLMo-7B.py launch "${run_name}" "${cluster}" --model.float8_config.enabled=true` |
-| 13B | 4K | BF16 | 4,600 TPS | `python src/scripts/train/OLMo-13B.py launch "${run_name}" "${cluster}"` |
+To see the exact usage for each script, run the script without any arguments.
+
+Throughput numbers from these scripts with various configuration settings are reported below, measured on a cluster with NVIDIA H100 GPUs.
+
+| Model size | Context length | Precision | Throughput[^1] | Train script | Overrides |
+| :--------: | :------------: | :-------: | -----------: | :----------- | :-------- |
+| 1B | 4K | BF16 | 44,000 TPS | `OLMo-1B.py` | |
+| 1B | 256-8192 VSL | BF16 | 49,000 TPS | `OLMo-1B.py` | `--dataset.name=vsl` |
+| 1B | 4K | FP8 | 51,000 TPS | `OLMo-1B.py` | `--model.float8_config.enabled=true` |
+| 7B | 4K | BF16 | 10,000 TPS | `OLMo-7B.py` | |
+| | | FP8 | 13,000 TPS | `OLMo-7B.py` | `--model.float8_config.enabled=true` |
+| 13B | 4K | BF16 | 4,600 TPS | `OLMo-13B.py` | |

[^1]: Throughput reported in tokens per second per device.

src/olmo_core/config.py (55 changes: 52 additions & 3 deletions)

@@ -1,6 +1,17 @@
-from dataclasses import dataclass, fields, is_dataclass
+from dataclasses import dataclass, fields, is_dataclass, replace
from enum import Enum
-from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    cast,
+)

import torch
from omegaconf import OmegaConf as om
@@ -117,6 +128,36 @@ def as_config_dict(self) -> Dict[str, Any]:
            recurse=True,
        )

+    def apply(self, func: Callable[["Config"], None]):
+        """
+        Recursively apply a function to every config instance field, including ``self``.
+
+        :param func: The function to apply.
+        """
+
+        def apply(d):
+            if isinstance(d, Config):
+                func(d)
+
+            if is_dataclass(d):
+                for field in fields(d):
+                    value = getattr(d, field.name)
+                    apply(value)
+            elif isinstance(d, dict):
+                for value in d.values():
+                    apply(value)
+            elif isinstance(d, (list, tuple, set)):
+                for x in d:
+                    apply(x)
+
+        apply(self)
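For illustration, a minimal sketch of how `apply` walks a config tree. The `OptimizerConfig` and `TrainerConfig` classes below are hypothetical stand-ins built on this `Config` base, not classes from this change:

```python
from dataclasses import dataclass, field

from olmo_core.config import Config


@dataclass
class OptimizerConfig(Config):  # hypothetical example subclass
    lr: float = 1e-3


@dataclass
class TrainerConfig(Config):  # hypothetical example subclass
    optim: OptimizerConfig = field(default_factory=OptimizerConfig)


cfg = TrainerConfig()
# `apply` visits `cfg` itself first, then recurses into dataclass
# fields, dict values, and list/tuple/set items.
cfg.apply(lambda c: print(type(c).__name__))
# -> TrainerConfig
# -> OptimizerConfig
```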

+    def validate(self):
+        """
+        Validate fields in ``self``. This may modify ``self`` in-place.
+        """
+        pass
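By default `validate` is a no-op hook; subclasses override it to enforce or repair invariants (see `NumpyDatasetConfig.validate` below). A hedged sketch of an override that rejects bad values, with made-up fields, assuming `OLMoConfigurationError` is importable from `olmo_core.exceptions` as this module's own usage suggests:

```python
from dataclasses import dataclass

from olmo_core.config import Config
from olmo_core.exceptions import OLMoConfigurationError


@dataclass
class BatchConfig(Config):  # hypothetical example subclass
    global_batch_size: int = 512
    micro_batch_size: int = 16

    def validate(self):
        # Reject inconsistent settings instead of silently fixing them.
        if self.global_batch_size % self.micro_batch_size != 0:
            raise OLMoConfigurationError(
                "global_batch_size must be divisible by micro_batch_size"
            )
```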

    def merge(self, dotlist: List[str]) -> Self:
        """
        Merge self with fields from a "dotlist", creating a new object.
@@ -126,10 +167,18 @@ def merge(self, dotlist: List[str]) -> Self:
        try:
            merge_fields = om.from_dotlist(_clean_opts(dotlist))
            merged = om.merge(self, merge_fields)
-            return cast(Self, om.to_object(merged))
+            out = cast(Self, om.to_object(merged))
+            out.apply(lambda c: c.validate())
+            return out
        except OmegaConfBaseException as e:
            raise OLMoConfigurationError(str(e))
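With the new `out.apply(lambda c: c.validate())` call, every config in the merged tree is validated before `merge` returns. Continuing the hypothetical `TrainerConfig` sketch from above:

```python
# Dotlist overrides use OmegaConf's `key.path=value` syntax.
cfg = TrainerConfig().merge(["optim.lr=3e-4"])
assert cfg.optim.lr == 3e-4
# Any validate() override anywhere in the tree has now been run.
```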

+    def replace(self, **changes) -> Self:
+        """
+        Creates a new object of the same type, replacing fields with values from ``changes``.
+        """
+        return replace(self, **changes)
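Because `Config` is a dataclass, `replace` simply delegates to `dataclasses.replace`: it returns a modified copy and leaves the original untouched. Again using the hypothetical sketch from above:

```python
base = TrainerConfig()
variant = base.replace(optim=base.optim.replace(lr=1e-4))
assert base.optim.lr == 1e-3     # original unchanged
assert variant.optim.lr == 1e-4  # copy carries the override
```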

    @classmethod
    def from_dict(cls: Type[C], data: Dict[str, Any], overrides: Optional[List[str]] = None) -> C:
        """
src/olmo_core/data/numpy_dataset.py (9 changes: 9 additions & 0 deletions)

@@ -1389,6 +1389,15 @@ class NumpyDatasetConfig(Config):
    all of your runs.
    """

+    def validate(self):
+        if self.name in (NumpyDatasetType.fsl, NumpyDatasetType.padded_fsl):
+            self.max_sequence_length = None
+            self.min_sequence_length = None
+            self.vsl_curriculum = None
+        elif self.name == NumpyDatasetType.vsl:
+            self.sequence_length = None
+            self.max_target_sequence_length = None
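The effect: whichever sequence-length fields don't apply to the chosen dataset type get cleared, so stale FSL settings can't leak into a VSL run (and vice versa). A hedged sketch, assuming a `NumpyDatasetConfig` instance `cfg` built elsewhere with its required fields (tokenizer, data paths, etc.):

```python
# Switch an FSL-configured dataset to VSL; validate() clears the
# fields that no longer apply. This also runs automatically inside
# Config.merge() via apply().
cfg.name = NumpyDatasetType.vsl
cfg.validate()
assert cfg.sequence_length is None
assert cfg.max_target_sequence_length is None
```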

    @property
    def effective_sequence_length(self) -> int:
        if self.sequence_length is not None:
src/olmo_core/internal/experiment.py (7 changes: 7 additions & 0 deletions)

@@ -12,6 +12,8 @@
    NumpyDatasetConfig,
    NumpyDatasetType,
    TokenizerConfig,
+    VSLCurriculumConfig,
+    VSLCurriculumType,
)
from olmo_core.distributed.utils import get_num_nodes, init_hybrid_shard_mesh
from olmo_core.float8 import Float8Config
@@ -163,6 +165,11 @@ def build_common_components(
        mix_base_dir=root_dir,
        sequence_length=4096,
        max_target_sequence_length=8192,
+        min_sequence_length=256,
+        max_sequence_length=8192,
+        vsl_curriculum=VSLCurriculumConfig(
+            name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False
+        ),
        work_dir=None
        if is_url(root_dir)
        else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache",