diff --git a/CHANGELOG.md b/CHANGELOG.md index 064ab8b3..cb37372d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Added `retries` field to `BeakerLaunchConfig`. + ## [v1.6.0](https://github.com/allenai/OLMo-core/releases/tag/v1.6.0) - 2024-11-01 ### Added diff --git a/pyproject.toml b/pyproject.toml index b770d742..0633a4c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dev = [ "sphinx-autodoc-typehints==1.23.3", ] beaker = [ - "beaker-py", + "beaker-py>=1.32.0", "GitPython>=3.0,<4.0", ] wandb = [ diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 3bfa2f29..50353db9 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -19,6 +19,7 @@ ExperimentSpec, Job, Priority, + RetrySpec, TaskResources, TaskSpec, ) @@ -175,6 +176,11 @@ class BeakerLaunchConfig(Config): If the job should be preemptible. """ + retries: Optional[int] = None + """ + The number of times to retry the experiment if it fails. ``None`` or ``0`` leaves the experiment's retry spec unset. + """ + env_vars: List[BeakerEnvVar] = field(default_factory=list) """ Additional env vars to include. @@ -360,7 +366,12 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: for bucket in self.weka_buckets: task_spec = task_spec.with_dataset(bucket.mount, weka=bucket.bucket) - return ExperimentSpec(description=self.description, budget=self.budget, tasks=[task_spec]) + return ExperimentSpec( + description=self.description, + budget=self.budget, + tasks=[task_spec], + retry=None if not self.retries else RetrySpec(allowed_task_retries=self.retries), + ) def _follow_experiment(self, experiment: Experiment): # Wait for job to start...