Add a callback for downstream evals, update Docker builds (#73)
Replaces #72. Gives us complete parity with the original downstream
evals in the OLMo repo.
epwalsh authored Oct 30, 2024
1 parent ecd523e commit 3fe59b6
Showing 16 changed files with 321 additions and 32 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/main.yml
@@ -118,7 +118,7 @@ jobs:
src/test/
- name: Test checkpoint (GPU)
-        image: olmo-core
+        image: olmo-core-nightly
gpus: 2
run: |
pytest -v --color=yes --durations=3 -m gpu \
@@ -180,10 +180,11 @@ jobs:
gpuCount: ${{ matrix.task.gpus }}
constraints:
cluster:
-          - ai2/allennlp-cirrascale
-          - ai2/allennlp-elanding-a100-40g
+          # - ai2/allennlp-cirrascale
+          # - ai2/allennlp-elanding-a100-40g
- ai2/pluto-cirrascale
- ai2/jupiter-cirrascale-2
+          # - ai2/saturn-cirrascale
envVars:
- name: CUBLAS_WORKSPACE_CONFIG
value: ":16:8"
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

+### Added
+
+- Added `DownstreamEvaluatorCallbackConfig` class for running in-loop downstream eval via [OLMo-in-loop-evals](https://github.com/allenai/OLMo-in-loop-evals).
+
+### Removed
+
+- Removed `flash-attn` from the Beaker images since `flash-attn` currently can't be built for torch 2.5.1. We are waiting on updates from the `flash-attn` maintainers. See https://github.com/Dao-AILab/flash-attention/issues/1302.

## [v1.5.0](https://github.com/allenai/OLMo-core/releases/tag/v1.5.0) - 2024-10-23

### Added
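For context, here is a minimal sketch of wiring up the new callback, mirroring the `src/examples/train.py` diff later in this commit. The `trainer_config` builder is assumed to exist as in that script; only the callback registration is shown.

# Sketch only: mirrors the src/examples/train.py changes in this commit.
from olmo_core.data.tokenizer import TokenizerConfig
from olmo_core.train.callbacks import DownstreamEvaluatorCallbackConfig

def add_downstream_evals(trainer_config):
    # Register in-loop downstream evals, backed by the new ai2-olmo-eval dependency.
    # `trainer_config` is assumed to be the chained TrainerConfig builder from the example.
    return trainer_config.with_callback(
        "downstream_evaluator",
        DownstreamEvaluatorCallbackConfig(
            tasks=["hellaswag"],                 # task name from the example diff
            tokenizer=TokenizerConfig.dolma2(),  # tokenizer config, as in the example
            eval_interval=250,                   # evaluate every 250 steps
        ),
    )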
12 changes: 8 additions & 4 deletions Makefile
@@ -1,11 +1,13 @@
-BASE_IMAGE = ghcr.io/allenai/pytorch:2.4.1-cuda12.1-python3.11
+# NOTE: make sure CUDA versions match across these variables
+BASE_IMAGE = ghcr.io/allenai/pytorch:2.5.1-cuda12.1-python3.11-v2024.10.29
+CUDA_TOOLKIT_VERSION = 12.1.0
+TORCH_CUDA_VERSION = 121

# NOTE: when upgrading the nightly version you also need to upgrade the torch version specification
# in 'pyproject.toml' to include that nightly version.
-NIGHTLY_VERSION = "2.6.0.dev20241009+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121"
-TORCHAO_VERSION = "torchao==0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121"
+NIGHTLY_VERSION = "2.6.0.dev20241009+cu121"
+TORCHAO_VERSION = "torchao==0.5.0"
MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://[email protected]/epwalsh/megablocks.git@epwalsh/deps"
-CUDA_TOOLKIT_VERSION = 12.1.0

VERSION = $(shell python src/olmo_core/version.py)
VERSION_SHORT = $(shell python src/olmo_core/version.py short)
@@ -49,6 +51,7 @@ stable-image :
--build-arg BUILDKIT_INLINE_CACHE=1 \
--build-arg BASE=$(BASE_IMAGE) \
--build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \
+		--build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--target stable \
@@ -62,6 +65,7 @@ nightly-image :
--build-arg BUILDKIT_INLINE_CACHE=1 \
--build-arg BASE=$(BASE_IMAGE) \
--build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \
+		--build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--build-arg NIGHTLY_VERSION=$(NIGHTLY_VERSION) \
1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
    "omegaconf",
    "safetensors",
    "importlib_resources",
+    "ai2-olmo-eval==0.2.0",
]

[project.urls]
29 changes: 25 additions & 4 deletions src/Dockerfile
@@ -13,22 +13,39 @@ WORKDIR /app/build
ARG CUDA_TOOLKIT_VERSION
RUN conda install -y -c nvidia cuda-toolkit==${CUDA_TOOLKIT_VERSION}

+ARG TORCH_CUDA_VERSION
+
# Build megablocks and grouped-gemm.
ENV TORCH_CUDA_ARCH_LIST="8.0 9.0"
ENV GROUPED_GEMM_CUTLASS=1
ARG MEGABLOCKS_VERSION
-RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" \
-    && rm -rf torch-*.whl numpy-*.whl triton-*.whl
+RUN pip wheel --no-build-isolation --no-cache-dir \
+    --extra-index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} \
+    "${MEGABLOCKS_VERSION}"
+
+# Flash-attn from pre-built wheel (can't get this to work at the moment)
+#RUN wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl
+
+# Only keep the target wheels and dependencies with CUDA extensions.
+RUN echo "Built wheels:" \
+    && ls -lh . \
+    && ls -1 | grep -Ev 'megablocks|grouped_gemm|stanford_stk|flash_attn' | xargs rm \
+    && echo "Final wheels:" \
+    && ls -lh .

#########################################################################
# Stable image
#########################################################################

FROM ${BASE} as stable

+ARG TORCH_CUDA_VERSION
+
# Install torchao.
ARG TORCHAO_VERSION
-RUN pip install --no-cache-dir ${TORCHAO_VERSION}
+RUN pip install --no-cache-dir \
+    --extra-index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} \
+    ${TORCHAO_VERSION}

# Copy and install wheels from build image.
COPY --from=build /app/build /app/build
@@ -50,5 +67,9 @@ WORKDIR /app/olmo-core

FROM stable as nightly

+ARG TORCH_CUDA_VERSION
+
ARG NIGHTLY_VERSION
-RUN pip install --no-cache-dir --pre torch==${NIGHTLY_VERSION}
+RUN pip install --no-cache-dir --pre \
+    --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} \
+    torch==${NIGHTLY_VERSION}
11 changes: 10 additions & 1 deletion src/examples/train.py
@@ -31,6 +31,7 @@
CheckpointerCallback,
CometCallback,
ConfigSaverCallback,
+    DownstreamEvaluatorCallbackConfig,
GPUMemoryMonitorCallback,
GradClipperCallback,
LMEvaluatorCallbackConfig,
@@ -133,7 +134,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
.with_callback("config_saver", ConfigSaverCallback())
.with_callback("profiler", ProfilerCallback(enabled=False))
        .with_callback(
-            "evaluator",
+            "lm_evaluator",
LMEvaluatorCallbackConfig(
eval_dataset=NumpyDatasetConfig(
paths=["/net/nfs/allennlp/llm-data/c4/en/c4-validation.00000-00008.npy"],
@@ -147,6 +148,14 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
eval_duration=Duration.steps(10),
),
)
+        .with_callback(
+            "downstream_evaluator",
+            DownstreamEvaluatorCallbackConfig(
+                tasks=["hellaswag"],
+                tokenizer=tokenizer_config,
+                eval_interval=250,
+            ),
+        )
)

return ExperimentConfig(
6 changes: 4 additions & 2 deletions src/olmo_core/data/mixes/__init__.py
@@ -1,4 +1,5 @@
import os
+from abc import abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, List, Tuple
@@ -15,7 +16,8 @@ class DataMixBase(StrEnum):
Base class for enumeration of data mixes.
"""

-    def build(self, base_dir: str, tokenizer: TokenizerName) -> Tuple[List[str], List[str]]:
+    @abstractmethod
+    def build(self, base_dir: str, tokenizer: str) -> Tuple[List[str], List[str]]:
"""
Construct the data mix.
@@ -37,7 +39,7 @@ class DataMix(DataMixBase):
dolma17 = "dolma17"
v3_small_ppl_validation = "v3-small-ppl-validation"

-    def build(self, base_dir: str, tokenizer: TokenizerName) -> Tuple[List[str], List[str]]:
+    def build(self, base_dir: str, tokenizer: str) -> Tuple[List[str], List[str]]:
if not base_dir.endswith("/"):
base_dir = base_dir + "/"

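A usage sketch of the updated `build` signature, which now takes the tokenizer as a plain string. The data location is hypothetical, and naming the second returned list `labels` is an assumption — the diff only shows the `Tuple[List[str], List[str]]` return type.

# Sketch: calling a DataMix with the new string-typed tokenizer argument.
from olmo_core.data.mixes import DataMix

# `build` normalizes `base_dir` to end with "/" before joining paths.
paths, labels = DataMix.dolma17.build(
    base_dir="s3://my-bucket/data",        # hypothetical location
    tokenizer="allenai/dolma2-tokenizer",  # tokenizer identifier string
)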
48 changes: 46 additions & 2 deletions src/olmo_core/data/tokenizer.py
@@ -3,15 +3,31 @@

from ..config import Config, StrEnum

+__all__ = [
+    "TokenizerConfig",
+    "TokenizerName",
+]


class TokenizerName(StrEnum):
"""
-    An enumeration of supported tokenizer names.
+    An enumeration of tokenizer identifiers commonly used by OLMo researchers.
"""

    dolma2 = "allenai/dolma2-tokenizer"
+    """
+    The dolma2 tokenizer.
+    """

    gpt_neox_olmo_dolma_v1_5 = "allenai/gpt-neox-olmo-dolma-v1_5"
+    """
+    A modified GPT NeoX tokenizer.
+    """

    gpt2 = "gpt2"
+    """
+    The base GPT2 tokenizer.
+    """


@dataclass
@@ -21,10 +37,29 @@ class TokenizerConfig(Config):
"""

    vocab_size: int
+    """
+    The vocab size.
+    """

    eos_token_id: int
+    """
+    The end-of-sentence token ID.
+    """

    pad_token_id: int
+    """
+    The padding token ID.
+    """

bos_token_id: Optional[int] = None
-    identifier: Optional[TokenizerName] = None
+    """
+    The begin-of-sentence token ID.
+    """
+
+    identifier: Optional[str] = None
+    """
+    The identifier of the tokenizer. Could be a path or HuggingFace identifier.
+    """

def padded_vocab_size(self, pad_multiple: int = 128) -> int:
"""
@@ -35,6 +70,9 @@ def padded_vocab_size(self, pad_multiple: int = 128) -> int:

@classmethod
    def dolma2(cls) -> "TokenizerConfig":
+        """
+        Get a :data:`~TokenizerName.dolma2` tokenizer config.
+        """
return cls(
vocab_size=100278,
eos_token_id=100257,
@@ -44,6 +82,9 @@ def gpt_neox_olmo_dolma_v1_5(cls) -> "TokenizerConfig":

@classmethod
    def gpt_neox_olmo_dolma_v1_5(cls) -> "TokenizerConfig":
+        """
+        Get a :data:`~TokenizerName.gpt_neox_olmo_dolma_v1_5` tokenizer config.
+        """
return cls(
vocab_size=50280,
eos_token_id=50279,
@@ -53,6 +94,9 @@ def gpt2(cls) -> "TokenizerConfig":

@classmethod
    def gpt2(cls) -> "TokenizerConfig":
+        """
+        Get a :data:`~TokenizerName.gpt2` tokenizer config.
+        """
return cls(
vocab_size=50280,
eos_token_id=50256,
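A quick sanity check of `padded_vocab_size` against the dolma2 numbers above, assuming it rounds `vocab_size` up to the nearest multiple of `pad_multiple` (128 by default):

from olmo_core.data.tokenizer import TokenizerConfig

tok = TokenizerConfig.dolma2()
print(tok.vocab_size)           # 100278, per the config above
print(tok.padded_vocab_size())  # 100352 == 784 * 128, the next multiple of 128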
19 changes: 18 additions & 1 deletion src/olmo_core/distributed/utils.py
@@ -13,7 +13,7 @@

from ..config import StrEnum
from ..exceptions import OLMoConfigurationError, OLMoEnvironmentError
-from ..utils import get_default_device, set_env_var
+from ..utils import get_default_device, move_to_device, set_env_var

OLMO_SHARED_FS_ENV_VAR = "OLMO_SHARED_FS"
OLMO_FS_LOCAL_RANK_ENV_VAR = "FS_LOCAL_RANK"
@@ -270,6 +270,23 @@ def scatter_object(obj: T, src: int = 0, group: Optional[dist.ProcessGroup] = No
return output_list[0]


+def all_gather(
+    tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None
+) -> List[torch.Tensor]:
+    """
+    All-gather tensors from the whole group into a list.
+    """
+    if not is_distributed():
+        return [tensor]
+
+    shapes = all_gather_object(tensor.shape, group=group)
+    output_list = [
+        move_to_device(torch.zeros(shape, dtype=tensor.dtype), tensor.device) for shape in shapes
+    ]
+    dist.all_gather(output_list, tensor, group=group)
+    return output_list


def all_gather_object(obj: T, group: Optional[dist.ProcessGroup] = None) -> List[T]:
"""
All-gather an object using pickle to all ranks in a process group.
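The shape pre-exchange is what lets the new `all_gather` handle ragged inputs; a sketch of the intended use (assumes `torch.distributed` is initialized, though the helper degrades gracefully to a single-element list when it isn't):

# Sketch: gathering per-rank 1-D tensors that may differ in length,
# e.g. uneven numbers of eval scores per rank. `all_gather` first
# exchanges shapes via `all_gather_object`, then allocates a matching
# buffer per rank before the collective.
import torch

from olmo_core.distributed.utils import all_gather

def gather_scores(local_scores: torch.Tensor) -> torch.Tensor:
    pieces = all_gather(local_scores)  # one tensor per rank; shapes may differ
    return torch.cat(pieces)           # concatenate along dim 0 on every rank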
7 changes: 7 additions & 0 deletions src/olmo_core/eval/metrics.py
@@ -67,6 +67,10 @@ def __init__(
def update(
self, value: Union[float, torch.Tensor], weight: Union[float, torch.Tensor] = 1.0
    ) -> None:
+        """
+        :param value: The latest value to update the metric with. Could be a tensor of values.
+        :param weight: The corresponding weight(s) for ``value``. Should be broadcastable to the shape of ``value``.
+        """
value = self.as_tensor(value)
weight = torch.broadcast_to(self.as_tensor(weight), value.shape)
if value.numel() == 0:
@@ -75,6 +79,9 @@ def update(
self.weight += weight.sum()

    def compute(self) -> torch.Tensor:
+        """
+        Computes the mean over the values and weights given.
+        """
weighted_sum = all_reduce_value(
self.weighted_sum, device=self.device, group=self.process_group
)
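An illustration of the weighted-mean semantics documented above. The enclosing class is only partially visible in this hunk, so the stand-in below (hypothetically named `MeanMetric`) re-implements the same `update`/`compute` bookkeeping locally, assuming `weighted_sum` accumulates `value * weight` and ignoring the distributed all-reduce:

import torch

class MeanMetric:  # hypothetical stand-in mirroring the diff's logic
    def __init__(self) -> None:
        self.weighted_sum = torch.tensor(0.0)
        self.weight = torch.tensor(0.0)

    def update(self, value: torch.Tensor, weight: float = 1.0) -> None:
        # Broadcast the weight(s) to the value's shape, as in the diff.
        w = torch.broadcast_to(torch.as_tensor(weight), value.shape)
        self.weighted_sum = self.weighted_sum + (value * w).sum()
        self.weight = self.weight + w.sum()

    def compute(self) -> torch.Tensor:
        return self.weighted_sum / self.weight

m = MeanMetric()
m.update(torch.tensor([0.0, 1.0, 1.0]))    # three values, weight 1 each
m.update(torch.tensor([1.0]), weight=2.0)  # one value counted twice
print(m.compute())                         # tensor(0.8000) = 4 / 5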
7 changes: 6 additions & 1 deletion src/olmo_core/train/callbacks/__init__.py
@@ -3,7 +3,11 @@
from .comet import CometCallback, CometNotificationSetting
from .config_saver import ConfigSaverCallback
from .console_logger import ConsoleLoggerCallback
-from .evaluator_callback import EvaluatorCallback, LMEvaluatorCallbackConfig
+from .evaluator_callback import (
+    DownstreamEvaluatorCallbackConfig,
+    EvaluatorCallback,
+    LMEvaluatorCallbackConfig,
+)
from .float8_handler import Float8HandlerCallback
from .garbage_collector import GarbageCollectorCallback
from .gpu_memory_monitor import GPUMemoryMonitorCallback
@@ -27,6 +31,7 @@
    "EvaluatorCallback",
    "Float8HandlerCallback",
    "LMEvaluatorCallbackConfig",
+    "DownstreamEvaluatorCallbackConfig",
"MoEHandlerCallback",
"GarbageCollectorCallback",
"GPUMemoryMonitorCallback",