# Add model ladder building blocks (#114)
Ported over from https://github.com/allenai/OLMo/blob/ladder-1xC/scripts/ladder_peteish.py.

### Example

You can run a given model size (e.g. 190M) of the peteish ladder on Beaker as follows:

```fish
python src/scripts/train/OLMo2-ladder.py launch 190M ai2/jupiter-cirrascale-2
```

### Notes for reviewers

The key file to review is `src/olmo_core/model_ladder.py`.
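For orientation, here is a rough sketch of the shape of the new API — the class and method names below are illustrative assumptions, not the actual interface; `src/olmo_core/model_ladder.py` is the source of truth:

```python
# Illustrative sketch only; see src/olmo_core/model_ladder.py for the real API.
from dataclasses import dataclass


@dataclass
class ModelLadder:
    """Describes a family of training runs across a ladder of model sizes."""

    sequence_length: int = 4096

    def get_model_config(self, size: str):
        """Map a size label like "190M" to a full model config."""
        ...

    def get_optim_config(self, size: str):
        """Scale optimizer settings (e.g. learning rate) with model size."""
        ...
```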
epwalsh authored Nov 27, 2024
1 parent 1647f78 commit 8e716b5
Showing 18 changed files with 970 additions and 88 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added an implementation of nGPT called `NormalizedTransformer`.
- Added an example showing how to convert a HuggingFace Llama 3.2 checkpoint into the right format for OLMo-core.
- Added an API for scaling RoPE embeddings.
+- Added a `ModelLadder` API.

### Changed

16 changes: 8 additions & 8 deletions README.md
@@ -39,14 +39,14 @@ Throughput numbers from these scripts with various different configuration setti

| Model size | Model arch.   | Context length | Precision | Throughput[^1] | Training script | Commandline overrides                                    |
| :--------: | :--------: | :------------: | :-------: | -----------: | :----------- | :-------- |
-| **1B** | OLMo-1124 | 4096 | BF16 | 55,000 TPS | `OLMo-1B.py` | |
-| | | 4096 | BF16/FP8[^2] | 65,000 TPS | `OLMo-1B.py` | `--model.float8_config.enabled=true` |
-| **7B** | OLMo-1124 | 4096 | BF16 | 10,000 TPS | `OLMo-7B.py` | |
-| | | 4096 | BF16/FP8 | 13,000 TPS | `OLMo-7B.py` | `--model.float8_config.enabled=true` |
-| **8B** | Llama | 4096 | BF16 | 9,500 TPS | `Llama-8B.py` | |
-| | | 4096 | BF16/FP8 | 12,500 TPS | `Llama-8B.py` | `--model.float8_config.enabled=true` |
-| **13B** | OLMo-1124 | 4096 | BF16 | 4,600 TPS | `OLMo-13B.py` | |
-| | | 4096 | BF16/FP8 | 5,500 TPS | `OLMo-13B.py` | `--model.float8_config.enabled=true` |
+| **1B** | OLMo-1124 | 4096 | BF16 | 55,000 TPS | `OLMo2-1B.py` | |
+| | | 4096 | BF16/FP8[^2] | 65,000 TPS | `OLMo2-1B.py` | `--model.float8_config.enabled=true` |
+| **7B** | OLMo-1124 | 4096 | BF16 | 10,000 TPS | `OLMo2-7B.py` | |
+| | | 4096 | BF16/FP8 | 13,000 TPS | `OLMo2-7B.py` | `--model.float8_config.enabled=true` |
+| **8B** | Llama | 4096 | BF16 | 9,500 TPS | `Llama3-8B.py` | |
+| | | 4096 | BF16/FP8 | 12,500 TPS | `Llama3-8B.py` | `--model.float8_config.enabled=true` |
+| **13B** | OLMo-1124 | 4096 | BF16 | 4,600 TPS | `OLMo2-13B.py` | |
+| | | 4096 | BF16/FP8 | 5,500 TPS | `OLMo2-13B.py` | `--model.float8_config.enabled=true` |

[^1]: Throughput reported in tokens per second per device.
[^2]: In this setup most matrix multiplications are computed in `float8`, everything else is in `bfloat16`.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -52,6 +52,7 @@ specific to your environment. Then you can install OLMo-core from PyPI with:
   float8
   io
   launch
+   model_ladder
   nn/index
   optim
   train/index
5 changes: 5 additions & 0 deletions docs/source/model_ladder.rst
@@ -0,0 +1,5 @@
``model_ladder``
================

.. automodule:: olmo_core.model_ladder
   :members:
22 changes: 20 additions & 2 deletions src/olmo_core/config.py
@@ -158,14 +158,32 @@ def validate(self):
        """
        pass

-    def merge(self, dotlist: List[str]) -> Self:
+    def merge(self, dotlist: List[str], prefix: Optional[str] = None, strict: bool = True) -> Self:
        """
        Merge self with fields from a "dotlist", creating a new object.

        :param dotlist: A list of field attributes with dot notation, e.g. ``foo.bar=1``.
+        :param prefix: Only use override items in the dotlist that start with a given prefix name,
+            and strip that prefix (including the subsequent ".") before applying the overrides.
+        :param strict: Parse the dotlist strictly.
        """
        try:
-            merge_fields = om.from_dotlist(_clean_opts(dotlist))
+            dotlist = _clean_opts(dotlist)
+            if prefix is not None:
+                dotlist = [o[len(prefix) + 1 :] for o in dotlist if o.startswith(f"{prefix}.")]
+            if not strict:
+                field_names = set(f.name for f in fields(self))
+                dotlist = [
+                    o
+                    for o in dotlist
+                    if any(
+                        [
+                            o.startswith(f"{name}=") or o.startswith(f"{name}.")
+                            for name in field_names
+                        ]
+                    )
+                ]
+            merge_fields = om.from_dotlist(dotlist)
            merged = om.merge(self, merge_fields)
            out = cast(Self, om.to_object(merged))
            out.apply(lambda c: c.validate())
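The new `prefix`/`strict` options let a script route one mixed list of command-line overrides to the right sub-config. A minimal sketch of the intended usage — the `ModelConfig` class here is hypothetical:

```python
from dataclasses import dataclass

from olmo_core.config import Config


@dataclass
class ModelConfig(Config):  # hypothetical config, for illustration only
    lr: float = 1e-3
    dim: int = 768


overrides = ["model.lr=3e-4", "trainer.max_steps=1000", "model.bogus=1"]

# Keep only the "model.*" items, strip the prefix, and (with strict=False)
# silently drop keys that don't match a field on ModelConfig.
model_cfg = ModelConfig().merge(overrides, prefix="model", strict=False)
assert model_cfg.lr == 3e-4
```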
125 changes: 125 additions & 0 deletions src/olmo_core/internal/common.py
@@ -0,0 +1,125 @@
import logging
from typing import List, Optional

from beaker import Beaker

from olmo_core.io import is_url
from olmo_core.launch.beaker import (
    BeakerEnvSecret,
    BeakerLaunchConfig,
    BeakerWekaBucket,
    OLMoCoreBeakerImage,
)
from olmo_core.utils import generate_uuid

log = logging.getLogger(__name__)
_BEAKER_CLIENT: Optional[Beaker] = None
_BEAKER_USERNAME: Optional[str] = None


def get_beaker_client() -> Beaker:
    global _BEAKER_CLIENT

    if _BEAKER_CLIENT is None:
        _BEAKER_CLIENT = Beaker.from_env()

    return _BEAKER_CLIENT


def get_beaker_username() -> str:
    global _BEAKER_USERNAME

    if _BEAKER_USERNAME is None:
        _BEAKER_USERNAME = get_beaker_client().account.whoami().name

    return _BEAKER_USERNAME


def get_root_dir(cluster: str) -> str:
    root_dir: str = "weka://oe-training-default/ai2-llm"
    if "jupiter" in cluster:
        root_dir = "/weka/oe-training-default/ai2-llm"
    elif "augusta" in cluster:
        root_dir = "gs://ai2-llm"
    return root_dir


def get_work_dir(root_dir: str) -> str:
    return (
        "./dataset-cache"
        if is_url(root_dir)
        else f"{root_dir}/checkpoints/{get_beaker_username().lower()}/dataset-cache"
    )


def build_launch_config(
    *,
    name: str,
    root_dir: str,
    cmd: List[str],
    cluster: str,
    task_name: str = "train",
    workspace: str = "ai2/OLMo-core",
    budget: str = "ai2/oe-training",
) -> BeakerLaunchConfig:
    weka_buckets: List[BeakerWekaBucket] = []
    if root_dir.startswith("/weka/"):
        weka_buckets.append(BeakerWekaBucket("oe-training-default", "/weka/oe-training-default"))

    beaker_user = get_beaker_username()

    return BeakerLaunchConfig(
        name=f"{name}-{generate_uuid()[:8]}",
        budget=budget,
        cmd=cmd,
        task_name=task_name,
        workspace=workspace,
        clusters=[cluster],
        weka_buckets=weka_buckets,
        beaker_image=OLMoCoreBeakerImage.nightly,  # some features require nightly at the moment
        num_nodes=1,
        num_gpus=8,
        shared_filesystem=not is_url(root_dir),
        allow_dirty=False,
        env_secrets=[
            BeakerEnvSecret(name="BEAKER_TOKEN", secret=f"{beaker_user}_BEAKER_TOKEN"),
            BeakerEnvSecret(name="WANDB_API_KEY", secret=f"{beaker_user}_WANDB_API_KEY"),
            BeakerEnvSecret(name="COMET_API_KEY", secret=f"{beaker_user}_COMET_API_KEY"),
            BeakerEnvSecret(name="AWS_CONFIG", secret=f"{beaker_user}_AWS_CONFIG"),
            BeakerEnvSecret(name="AWS_CREDENTIALS", secret=f"{beaker_user}_AWS_CREDENTIALS"),
            BeakerEnvSecret(name="R2_ENDPOINT_URL", secret="R2_ENDPOINT_URL"),
            BeakerEnvSecret(name="WEKA_ENDPOINT_URL", secret="WEKA_ENDPOINT_URL"),
        ],
        setup_steps=[
            # Clone repo.
            'git clone "$REPO_URL" .',
            'git checkout "$GIT_REF"',
            "git submodule update --init --recursive",
            # Setup python environment.
            "conda shell.bash activate base",
            "pip install -e '.[all]'",
            "pip freeze",
            # Move AWS credentials from env to relevant files
            "mkdir -p ~/.aws",
            "printenv AWS_CONFIG > ~/.aws/config",
            "printenv AWS_CREDENTIALS > ~/.aws/credentials",
        ],
    )


CLUSTER_TO_GPU_TYPE = {
    "ai2/jupiter-cirrascale-2": "NVIDIA H100 80GB HBM3",
    "ai2/pluto-cirrascale": "NVIDIA H100",
    "ai2/augusta-google-1": "NVIDIA H100",
}


def get_gpu_type(cluster: str) -> str:
    if cluster in CLUSTER_TO_GPU_TYPE:
        return CLUSTER_TO_GPU_TYPE[cluster]
    else:
        log.warning(f"Missing cluster '{cluster}' in CLUSTER_TO_GPU_TYPE mapping")
        beaker = get_beaker_client()
        nodes = beaker.cluster.nodes(cluster)
        assert nodes and nodes[0].limits.gpu_type
        return nodes[0].limits.gpu_type
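These helpers get exercised by `experiment.py` below; as a standalone sketch (the run name and command here are illustrative, the signatures come from the file above):

```python
from olmo_core.internal.common import (
    build_launch_config,
    get_gpu_type,
    get_root_dir,
    get_work_dir,
)

cluster = "ai2/jupiter-cirrascale-2"
root_dir = get_root_dir(cluster)   # -> "/weka/oe-training-default/ai2-llm"
work_dir = get_work_dir(root_dir)  # a checkpoints subdir, since root_dir isn't a URL

launch = build_launch_config(
    name="ladder-190M",  # a unique UUID suffix gets appended automatically
    root_dir=root_dir,
    cmd=["src/scripts/train/OLMo2-ladder.py", "train", "190M", cluster],
    cluster=cluster,
)
print(get_gpu_type(cluster))  # "NVIDIA H100 80GB HBM3"
```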
88 changes: 17 additions & 71 deletions src/olmo_core/internal/experiment.py
Expand Up @@ -3,7 +3,6 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, cast

from beaker import Beaker
from rich import print
from torch.distributed.device_mesh import DeviceMesh

@@ -19,13 +18,7 @@
)
from olmo_core.distributed.utils import get_num_nodes, init_hybrid_shard_mesh
from olmo_core.float8 import Float8Config
-from olmo_core.io import is_url
-from olmo_core.launch.beaker import (
-    BeakerEnvSecret,
-    BeakerLaunchConfig,
-    BeakerWekaBucket,
-    OLMoCoreBeakerImage,
-)
+from olmo_core.launch.beaker import BeakerLaunchConfig
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import CosWithWarmup, OptimConfig
from olmo_core.train import (
@@ -46,12 +39,9 @@
    SchedulerCallback,
    WandBCallback,
)
-from olmo_core.utils import (
-    generate_uuid,
-    get_default_device,
-    prepare_cli_environment,
-    seed_all,
-)
+from olmo_core.utils import get_default_device, prepare_cli_environment, seed_all

+from .common import build_launch_config, get_beaker_username, get_root_dir, get_work_dir

log = logging.getLogger(__name__)

@@ -109,7 +99,7 @@ def prepare_environment(self):
        elif self == SubCmd.train:
            prepare_training_environment()
        else:
-            raise NotADirectoryError(self)
+            raise NotImplementedError(self)

    def run(self, config: ExperimentConfig):
        print(config)
@@ -133,7 +123,7 @@ def run(self, config: ExperimentConfig):
        elif self == SubCmd.launch_prep:
            launch_prep(config)
        else:
-            raise NotADirectoryError(self)
+            raise NotImplementedError(self)


def build_common_components(
@@ -145,57 +135,21 @@ def build_common_components(
    *,
    global_batch_size: int,
) -> CommonComponents:
-    root_dir: str = "weka://oe-training-default/ai2-llm"
-    weka_buckets: List[BeakerWekaBucket] = []
-    if "jupiter" in cluster:
-        root_dir = "/weka/oe-training-default/ai2-llm"
-        weka_buckets.append(BeakerWekaBucket("oe-training-default", "/weka/oe-training-default"))
-    elif "augusta" in cluster:
-        root_dir = "gs://ai2-llm"
-
-    beaker_user = (Beaker.from_env().account.whoami().name).upper()
+    root_dir = get_root_dir(cluster)

    cmd_to_launch = SubCmd.train
    if cmd == SubCmd.launch_prep:
        cmd_to_launch = SubCmd.prep

-    launch_config = BeakerLaunchConfig(
-        name=f"{run_name}-{cmd_to_launch}-{generate_uuid()[:8]}",
-        budget="ai2/oe-training",
+    launch_config = build_launch_config(
+        name=f"{run_name}-{cmd_to_launch}",
+        root_dir=root_dir,
        cmd=[script, cmd_to_launch, run_name, cluster, *overrides],
-        task_name="train",
-        workspace="ai2/OLMo-core",
-        clusters=[cluster],
-        weka_buckets=weka_buckets,
-        beaker_image=OLMoCoreBeakerImage.nightly,  # some features require nightly at the moment
-        num_nodes=1,
-        num_gpus=8,
-        shared_filesystem=not is_url(root_dir),
-        allow_dirty=False,
-        env_secrets=[
-            BeakerEnvSecret(name="BEAKER_TOKEN", secret=f"{beaker_user}_BEAKER_TOKEN"),
-            BeakerEnvSecret(name="WANDB_API_KEY", secret=f"{beaker_user}_WANDB_API_KEY"),
-            BeakerEnvSecret(name="COMET_API_KEY", secret=f"{beaker_user}_COMET_API_KEY"),
-            BeakerEnvSecret(name="AWS_CONFIG", secret=f"{beaker_user}_AWS_CONFIG"),
-            BeakerEnvSecret(name="AWS_CREDENTIALS", secret=f"{beaker_user}_AWS_CREDENTIALS"),
-            BeakerEnvSecret(name="R2_ENDPOINT_URL", secret="R2_ENDPOINT_URL"),
-            BeakerEnvSecret(name="WEKA_ENDPOINT_URL", secret="WEKA_ENDPOINT_URL"),
-        ],
-        setup_steps=[
-            # Clone repo.
-            'git clone "$REPO_URL" .',
-            'git checkout "$GIT_REF"',
-            "git submodule update --init --recursive",
-            # Setup python environment.
-            "conda shell.bash activate base",
-            "pip install -e '.[all]'",
-            "pip freeze",
-            # Move AWS credentials from env to relevant files
-            "mkdir -p ~/.aws",
-            "printenv AWS_CONFIG > ~/.aws/config",
-            "printenv AWS_CREDENTIALS > ~/.aws/credentials",
-        ],
+        cluster=cluster,
    )

+    beaker_user = get_beaker_username()

    tokenizer_config = TokenizerConfig.dolma2()

    dataset_config = NumpyDatasetConfig.from_data_mix(
@@ -209,11 +163,7 @@
        vsl_curriculum=VSLCurriculumConfig(
            name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False
        ),
-        work_dir=(
-            "./dataset-cache"
-            if is_url(root_dir)
-            else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache"
-        ),
+        work_dir=get_work_dir(root_dir),
    )

    data_loader_config = NumpyDataLoaderConfig(
@@ -234,11 +184,7 @@
                mix_base_dir=root_dir,
                sequence_length=dataset_config.effective_sequence_length,
                tokenizer=tokenizer_config,
-                work_dir=(
-                    "./dataset-cache"
-                    if is_url(root_dir)
-                    else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache"
-                ),
+                work_dir=get_work_dir(root_dir),
            ),
            eval_interval=1000,
        ),
@@ -345,7 +291,7 @@ def train(config: ExperimentConfig):
    data_loader = config.data_loader.build(dataset)
    trainer = config.trainer.build(model, optim, data_loader)

-    # Record the config to W&B and each checkpoint dir.
+    # Record the config to W&B/Comet and each checkpoint dir.
    config_dict = config.as_config_dict()
    cast(CometCallback, trainer.callbacks["comet"]).config = config_dict
    cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict
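Since the launch command is assembled as `[script, cmd_to_launch, run_name, cluster, *overrides]`, every script built on these helpers shares the same invocation shape; for example (the script name and override here are illustrative):

```fish
python src/scripts/train/OLMo2-7B.py launch my-run ai2/jupiter-cirrascale-2 --model.float8_config.enabled=true
```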
[Diffs for the remaining 11 changed files are not shown.]
