Skip to content

Commit

Permalink
Switch to comet callback in official train scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 11, 2024
1 parent d90f5da commit ad4c8bb
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/olmo_core/internal/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from olmo_core.train.callbacks import (
Callback,
CometCallback,
ConfigSaverCallback,
Float8HandlerCallback,
GPUMemoryMonitorCallback,
Expand Down Expand Up @@ -298,6 +299,7 @@ def train(config: ExperimentConfig):

# Record the config to W&B and each checkpoint dir.
config_dict = config.as_config_dict()
cast(CometCallback, trainer.callbacks["comet"]).config = config_dict
cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict
cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict

Expand Down
14 changes: 12 additions & 2 deletions src/scripts/train/OLMo-13B.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, WandBCallback
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,13 +57,23 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
save_async=True,
),
)
.with_callback(
"comet",
CometCallback(
name=common.run_name,
workspace="ai2",
project="OLMo-core-13B",
enabled=True,
cancel_check_interval=10,
),
)
.with_callback(
"wandb",
WandBCallback(
name=common.run_name,
entity="ai2-llm",
project="OLMo-core-13B",
enabled=True,
enabled=False,
cancel_check_interval=10,
),
)
Expand Down
14 changes: 12 additions & 2 deletions src/scripts/train/OLMo-1B.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, WandBCallback
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,13 +56,23 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
save_async=True,
),
)
.with_callback(
"comet",
CometCallback(
name=common.run_name,
workspace="ai2",
project="OLMo-core-1B",
enabled=True,
cancel_check_interval=10,
),
)
.with_callback(
"wandb",
WandBCallback(
name=common.run_name,
entity="ai2-llm",
project="OLMo-core-1B",
enabled=True,
enabled=False,
cancel_check_interval=10,
),
)
Expand Down
14 changes: 12 additions & 2 deletions src/scripts/train/OLMo-7B.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, WandBCallback
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,13 +57,23 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
save_async=True,
),
)
.with_callback(
"comet",
CometCallback(
name=common.run_name,
workspace="ai2",
project="OLMo-core-7B",
enabled=True,
cancel_check_interval=10,
),
)
.with_callback(
"wandb",
WandBCallback(
name=common.run_name,
entity="ai2-llm",
project="OLMo-core-7B",
enabled=True,
enabled=False,
cancel_check_interval=10,
),
)
Expand Down

0 comments on commit ad4c8bb

Please sign in to comment.