Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async eval callback #702

Merged
merged 66 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
e70a2f4
Async eval callback
aspfohl Oct 27, 2023
acd8b2e
add very basic tests
aspfohl Nov 8, 2023
aef814b
more tests
aspfohl Nov 8, 2023
007ae90
bump mcli
aspfohl Nov 8, 2023
d04aa40
Merge branch 'main' into anna/asynceval
aspfohl Nov 9, 2023
6cd020f
woop, missing import
aspfohl Nov 9, 2023
9fbe7a1
instance not specified error
aspfohl Nov 9, 2023
547ec21
fixes
aspfohl Nov 9, 2023
9819456
fix typing
aspfohl Nov 9, 2023
ba871d7
small testing fixes
aspfohl Nov 9, 2023
21e9880
launch new run only on main process
aspfohl Nov 9, 2023
47a8255
logger name
aspfohl Nov 9, 2023
08c24be
items
aspfohl Nov 9, 2023
bf415f0
format
aspfohl Nov 10, 2023
ecacdac
Merge branch 'main' into anna/asynceval
aspfohl Nov 10, 2023
5616ae4
Update llmfoundry/callbacks/async_eval_callback.py
aspfohl Nov 10, 2023
bc1647a
Update llmfoundry/callbacks/async_eval_callback.py
aspfohl Nov 10, 2023
3358837
feedback
aspfohl Nov 10, 2023
28e47df
Apply suggestions from code review
aspfohl Nov 13, 2023
78cc0b8
small updates
aspfohl Nov 13, 2023
7baa53f
Merge branch 'main' into anna/asynceval
aspfohl Nov 13, 2023
b58ccf9
use parameters from train.py to capture overrides and mounted paramet…
aspfohl Nov 13, 2023
d85ee5e
config_overrides
aspfohl Nov 14, 2023
0e96fea
Merge branch 'main' into anna/asynceval
aspfohl Nov 14, 2023
194774d
updates
aspfohl Nov 15, 2023
ba1280a
Merge branch 'main' into anna/asynceval
aspfohl Nov 15, 2023
08857d5
fix test
aspfohl Nov 15, 2023
6ce8b77
small fixes
aspfohl Nov 15, 2023
de155f7
add logging
aspfohl Nov 15, 2023
e5f9e9e
remove last launch check
aspfohl Nov 15, 2023
3f518f9
better logging
aspfohl Nov 15, 2023
ea742ef
fix parameters
aspfohl Nov 16, 2023
e3623f3
fix double unit in the name
aspfohl Nov 16, 2023
2deef3f
sadz
aspfohl Nov 16, 2023
91bdf43
Merge branch 'main' into anna/asynceval
aspfohl Nov 16, 2023
e8f4661
Merge branch 'main' into anna/asynceval
dakinggg Nov 30, 2023
f9e2dc7
Merge branch 'main' into anna/asynceval
aspfohl Dec 2, 2023
add7fbb
fies
aspfohl Dec 2, 2023
238086f
git integration path validation and update
aspfohl Dec 4, 2023
53a9943
detect forks, better error/comment
aspfohl Dec 4, 2023
1f35a7b
version import
aspfohl Dec 4, 2023
99f48cb
merge with main
aspfohl Dec 4, 2023
1184531
last checkpoint
aspfohl Dec 5, 2023
7af7383
post_close -> close
aspfohl Dec 6, 2023
9337af0
add todos, fix path bug
aspfohl Dec 6, 2023
87ffd86
add missing args
aspfohl Dec 6, 2023
e940f1c
remove eval_loader in callback too
aspfohl Dec 6, 2023
bb040d1
remove fit end event (already doing on close)
aspfohl Dec 6, 2023
14f386f
misc fixes
aspfohl Dec 6, 2023
ac37d09
fix test
aspfohl Dec 6, 2023
535d5fb
Merge branch 'main' into anna/asynceval
aspfohl Dec 7, 2023
9e11cf7
add back eval interval
aspfohl Dec 8, 2023
aa652f3
build_loggers and add tests
aspfohl Dec 8, 2023
818f4ac
Merge branch 'main' into anna/asynceval
aspfohl Dec 8, 2023
0e4f085
updates
aspfohl Dec 11, 2023
bf06c19
Merge branch 'main' into anna/asynceval
aspfohl Dec 11, 2023
dc25b2a
typing
aspfohl Dec 11, 2023
6865b96
changes
aspfohl Dec 11, 2023
9754b03
Merge branch 'main' into anna/asynceval
aspfohl Dec 11, 2023
1ac70cc
typing?
aspfohl Dec 11, 2023
1a485d4
Merge branch 'main' into anna/asynceval
aspfohl Dec 13, 2023
cd2a31d
metadata in eval.py
aspfohl Dec 18, 2023
f6393a9
Merge branch 'main' into anna/asynceval
aspfohl Dec 18, 2023
04865db
actually, just log metadata on every model eval
aspfohl Dec 19, 2023
557aca8
Merge branch 'main' into anna/asynceval
aspfohl Dec 19, 2023
2fd2317
Merge branch 'main' into anna/asynceval
aspfohl Dec 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llmfoundry/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

try:
from llmfoundry.callbacks.async_eval_callback import AsyncEval
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.generate_callback import Generate
Expand All @@ -28,4 +29,5 @@
'EvalGauntlet',
'ModelGauntlet',
'HuggingFaceCheckpointer',
'AsyncEval',
]
156 changes: 156 additions & 0 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright 2023 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import logging
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
import os
from typing import Any, Dict, Optional, Union

from composer.core import Callback, Event, State, Time
from composer.loggers import Logger
from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR,
RUN_NAME_ENV_VAR)
from composer.utils import create_interval_scheduler, dist
from mcli.api.runs import ComputeConfig # TODO: should be available in root

from mcli import Run, RunConfig, create_run, get_run

log = logging.getLogger(__name__)

MAX_RUN_NAME_LENGTH = 40
aspfohl marked this conversation as resolved.
Show resolved Hide resolved

# Note: train parameter names. See comments if they are different from eval
REQUIRED_PARAMS_FOR_EVAL = {
'device_eval_batch_size',
'icl_tasks', # only required for eval
'max_seq_len',
'model', # models
'save_folder', # required, but used as load_path
}
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
OPTIONAL_PARAMS_FOR_EVAL = {
'dist_timeout',
'eval_gauntlet',
'fsdp_config', # fsdp_dict_cfg
'icl_subset_num_batches',
'loggers',
'precision',
'python_log_level',
'seed',
}


def get_run_name(previous_run_name: str, count: int) -> str:
return f'eval{count}-{previous_run_name[:MAX_RUN_NAME_LENGTH]}'


def get_load_path(save_folder: str,
save_latest_filename: Optional[str] = None) -> str:
# TODO: check that the prefix is remote and not a local file (not supported of course)

if not save_latest_filename:
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'

return f'{save_folder}/{save_latest_filename}'
aspfohl marked this conversation as resolved.
Show resolved Hide resolved


class AsyncEval(Callback):
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
"""Run the eval loop asynchronously as part of a MosaicML platform run

Args:
interval: Union[str, int, Time]: The interval describing how often eval runs should be
launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
"""

def __init__(
self,
interval: Union[str, int, Time],
compute: Optional[ComputeConfig] = None,
):
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
self.check_interval = create_interval_scheduler(interval)
self.compute = compute
self.count = 0

# Run these during init to fail fast in any of the error cases
self.current_run = self._get_current_run()
self._get_eval_parameters()

def run_event(self, event: Event, state: State, logger: Logger) -> None:
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
del logger
if state.get_elapsed_duration() is not None and self.check_interval(
state, event):
new_run = self._launch_run()
logger.info(f'Launched new run {new_run.name} for eval loop')
self.count += 1

def _get_current_run(self) -> Run:
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR,
'false').lower() == 'false':
raise Exception(
'AsyncEval callback is only supported when running on the MosaicML platform'
)

run_name = os.environ.get(RUN_NAME_ENV_VAR, None)
if not run_name:
raise Exception(
'RUN_NAME environment variable must be set to use the AsyncEval callback'
)

# allows the MapiException to be raised if the run doesn't exist
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
return get_run(run_name, include_details=True)

def _get_eval_parameters(self) -> Dict[str, Any]:
cfg_params = self.current_run.submitted_config.parameters or {}
looking_for = REQUIRED_PARAMS_FOR_EVAL.copy()

# Go through all parameters and pull out the ones needed for eval
subset_keys = {}
for key in cfg_params:
if key in OPTIONAL_PARAMS_FOR_EVAL:
subset_keys[key] = cfg_params[key]
elif key in REQUIRED_PARAMS_FOR_EVAL:
subset_keys[key] = cfg_params[key]
looking_for.remove(key)

if looking_for:
raise Exception(
f'Missing the following required parameters for async eval: {looking_for}'
)

# Convert the save_folder to a load_path
subset_keys['load_path'] = get_load_path(
subset_keys.pop('save_folder'),
cfg_params.get('save_latest_filename', None))

# Rename the keys to match the eval script
subset_keys['models'] = [cfg_params.pop('model')]
if 'fsdp_cfg' in subset_keys:
subset_keys['fsdp_dict_cfg'] = cfg_params.pop('fsdp_cfg')

cfg_params['run_name'] = get_run_name(self.current_run.name, self.count)
return cfg_params

def _launch_run(self) -> Run:
cfg = self.current_run.submitted_config
default_compute = {
'nodes': 1,
'cluster': self.current_run.cluster,
}
params = self._get_eval_parameters()

# TODO: This just runs an eval run, but we also want to attach the
# deployment, which would require a hf conversion and parametrizing the
# dependent_deployment in the run config
command = 'cd llm-foundry/scripts \n composer eval/eval.py $PARAMETERS'
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
c = RunConfig(
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
name=get_run_name(self.current_run.name, self.count),
image=self.current_run.image,
compute=self.compute or default_compute,
command=command,
integrations=cfg.integrations,
env_variables=cfg.env_variables,
metadata=cfg.metadata,
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
parameters=params,
)

return create_run(c)
8 changes: 5 additions & 3 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from torch.optim.optimizer import Optimizer
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling,
HuggingFaceCheckpointer, LayerFreezing,
MonolithicCheckpointSaver,
from llmfoundry.callbacks import (AsyncEval, EvalGauntlet, FDiffMetrics,
GlobalLRScaling, HuggingFaceCheckpointer,
LayerFreezing, MonolithicCheckpointSaver,
ScheduledGarbageCollector)
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
DecoupledLionW, DecoupledLionW_8bit)
Expand Down Expand Up @@ -118,6 +118,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
return EarlyStopper(**kwargs)
elif name == 'hf_checkpointer':
return HuggingFaceCheckpointer(**kwargs)
elif name == 'async_eval':
return AsyncEval(**kwargs)
else:
raise ValueError(f'Not sure how to build callback: {name}')

Expand Down
7 changes: 5 additions & 2 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,9 @@ def main(cfg: DictConfig) -> Trainer:
for name, callback_cfg in callback_configs.items()
] if callback_configs else []

use_async_eval = any(
isinstance(callback, Evaluator.Async) for callback in callbacks)
aspfohl marked this conversation as resolved.
Show resolved Hide resolved

# Algorithms
algorithms = [
build_algorithm(str(name), algorithm_cfg)
Expand Down Expand Up @@ -556,14 +559,14 @@ def main(cfg: DictConfig) -> Trainer:

eval_gauntlet_callback = None

if icl_tasks_config is not None:
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
if icl_tasks_config is not None and not use_async_eval:
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
icl_evaluators, _, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks_config, eval_gauntlet_config, tokenizer,
device_eval_batch_size, icl_seq_len if icl_seq_len else max_seq_len,
icl_subset_num_batches)
evaluators.extend(icl_evaluators)

if eval_gauntlet_callback is not None:
if eval_gauntlet_callback is not None and not use_async_eval:
callbacks.append(eval_gauntlet_callback)

# Build Model
Expand Down
Loading