diff --git a/CHANGELOG.md b/CHANGELOG.md index 352ece3a..ef600e43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added support for tensor parallelism. See the `TransformerConfig` class for usage. +- Added more downstream tasks from the model ladder. - Added `io.copy_dir()` function. - Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`. - Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint. diff --git a/pyproject.toml b/pyproject.toml index 9db53a44..40f45842 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "omegaconf", "safetensors", "importlib_resources", - "ai2-olmo-eval==0.2.0", + "ai2-olmo-eval==0.5.0", ] [project.urls] diff --git a/src/olmo_core/train/callbacks/evaluator_callback.py b/src/olmo_core/train/callbacks/evaluator_callback.py index 19bdb191..d84d0a58 100644 --- a/src/olmo_core/train/callbacks/evaluator_callback.py +++ b/src/olmo_core/train/callbacks/evaluator_callback.py @@ -99,7 +99,6 @@ def post_step(self): metrics = [] with cuda_sync_debug_mode(0): for name, value in evaluator.compute_metrics().items(): - value = value.item() metrics.append(f" {name}={format_float(value)}") self.trainer.record_metric(f"eval/{evaluator.name}/{name}", value) log.info("Eval metrics:\n" + "\n".join(metrics)) @@ -161,6 +160,8 @@ class DownstreamEvaluator(Evaluator): "pmi_dc": "PMI-DC accuracy", "ce_loss": "CE loss", "bpb": "BPB", + "soft": "soft loss", + "soft_log": "log soft loss", } def __init__( @@ -184,13 +185,14 @@ def __init__( if is_distributed(): sampler = DistributedSampler( self.task, # type: ignore - drop_last=True, + drop_last=False, shuffle=False, num_replicas=get_world_size(dp_process_group), rank=get_rank(dp_process_group), ) rank_batch_size_instances = max(0, rank_batch_size // self.task.max_sequence_length) + log.info( f"Using per-rank batch size of {rank_batch_size_instances} instances " f"for downstream eval task '{task}' with max sequence length {self.task.max_sequence_length:,d} tokens" @@ -215,9 +217,12 @@ def update_metrics( self.metric.update(batch, logits) def compute_metrics(self) -> Dict[str, torch.Tensor]: - value = self.metric.compute() - label = f"{self.label} ({self.metric_type_to_label[self.task.metric_type]})" - return {label: value} + metric_type_to_value = self.metric.compute() + outputs = {} + for metric_type, value in metric_type_to_value.items(): + key = f"{self.label} ({self.metric_type_to_label[metric_type]})" + outputs[key] = value.item() + return outputs def reset_metrics(self) -> None: self.metric.reset() @@ -227,7 +232,7 @@ def reset_metrics(self) -> None: class DownstreamEvaluatorCallbackConfig(CallbackConfig): tasks: List[str] tokenizer: TokenizerConfig - eval_batch_size: Optional[int] = None + eval_batch_size: Optional[int] = None # NOTE: this counts in number of tokens eval_interval: int = 1000 eval_duration: Duration = field(default_factory=lambda: Duration.epochs(1)) log_interval: int = 5 diff --git a/src/scripts/train/OLMo2-1B.py b/src/scripts/train/OLMo2-1B.py index 1285df7c..99b292e1 100644 --- a/src/scripts/train/OLMo2-1B.py +++ b/src/scripts/train/OLMo2-1B.py @@ -9,6 +9,9 @@ from olmo_core.optim import AdamWConfig, OptimGroupOverride from olmo_core.train import TrainerConfig from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback +from olmo_core.train.callbacks.evaluator_callback import ( + DownstreamEvaluatorCallbackConfig, +) def build_model_config(common: CommonComponents) -> TransformerConfig: @@ -73,6 +76,48 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: cancel_check_interval=10, ), ) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=[ + "arc_challenge_val_rc_5shot", + "arc_challenge_val_mc_5shot", + "arc_challenge_test_rc_5shot", + "arc_challenge_test_mc_5shot", + "arc_easy_val_rc_5shot", + "arc_easy_val_mc_5shot", + "arc_easy_test_rc_5shot", + "arc_easy_test_mc_5shot", + "boolq_val_rc_5shot", + "boolq_val_mc_5shot", + "csqa_val_rc_5shot", + "csqa_val_mc_5shot", + "hellaswag_val_rc_5shot", + "hellaswag_val_mc_5shot", + "openbookqa_val_rc_5shot", + "openbookqa_val_mc_5shot", + "openbookqa_test_rc_5shot", + "openbookqa_test_mc_5shot", + "piqa_val_rc_5shot", + "piqa_val_mc_5shot", + "socialiqa_val_rc_5shot", + "socialiqa_val_mc_5shot", + "winogrande_val_rc_5shot", + "winogrande_val_mc_5shot", + "mmlu_stem_val_rc_5shot", + "mmlu_stem_val_mc_5shot", + "mmlu_humanities_val_rc_5shot", + "mmlu_humanities_val_mc_5shot", + "mmlu_social_sciences_val_rc_5shot", + "mmlu_social_sciences_val_mc_5shot", + "mmlu_other_val_rc_5shot", + "mmlu_other_val_mc_5shot", + ], + tokenizer=common.tokenizer, + eval_batch_size=1024 * 4096, + eval_interval=1000, + ), + ) )