chore: prune cruft out of no_op fixture (#9912)
The no_op fixture is one of the oldest fixtures we have, and it has
built up a lot of cruft.

It is also based on legacy infrastructure that we are trying to remove
(SearcherContext and WorkloadSequencer), so we're removing dead or
unnecessary code now to make the later work of rewriting the fixture
easier.

Pointless tests, now removed:
  - test_fail_on_chechpoint_save: added in d0b384a, used to test the
    golang workload sequencer, which is long-dead.  This is basically
    testing the fixture at this point.
  - test_fail_on_preclose_chechpoint_save: added in 2e429f3, also used
    to test the golang workload sequencer.
  - test_fail_on_first_validation: added in 39726d6, also used to test
    the golang workload sequencer.
  - test_perform_initial_validation: added in 3485d82, used to test the
    golang workload sequencer.

Unused configs, now removed:
  - single-in-epochs.yaml
  - single-in-records.yaml
  - adaptive_chaos.yaml

Unused hyperparameters, now removed:
  - validation_secs
  - save_checkpoint_secs
  - load_checkpoint_secs
  - validation_set_size
  - fail_on_first_validation
  - fail_on_chechpoint_save

Weird cruft:
  - the metrics_sigma hparam was used exclusively to crash the trial
    (tests set it to -1.0 to trip an assertion).  I replaced it with the
    much clearer "crash_on_startup" hyperparameter, sketched below.
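
For context, the swap looks roughly like this in the trial code (a
minimal, self-contained sketch rather than the literal fixture diff;
the hparams-dict plumbing and the NoOpTrialSketch name are illustrative
assumptions, but the two assertion strings are exactly what the updated
tests grep for in trial logs):

    # Before: tests crashed the trial by passing metrics_sigma=-1.0,
    # which tripped a range check on an otherwise-numeric hparam:
    #     assert 0 <= self.metrics_sigma
    #
    # After: a dedicated boolean hyperparameter states the intent.
    class NoOpTrialSketch:
        def __init__(self, hparams: dict) -> None:
            self.crash_on_startup = hparams.get("crash_on_startup", False)
            # crash_on_startup=true fails here, and the assertion text
            # is what the tests match against in the trial logs.
            assert not self.crash_on_startup

    # Tests force a failure with
    #   --config hyperparameters.crash_on_startup=true
    # and clear it again with
    #   --config hyperparameters.crash_on_startup=false
    trial = NoOpTrialSketch({"crash_on_startup": False})  # starts cleanly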
rb-determined-ai authored Sep 11, 2024
1 parent 11de119 commit 8c799b8
Showing 23 changed files with 30 additions and 218 deletions.
31 changes: 0 additions & 31 deletions e2e_tests/tests/cluster/test_checkpoints.py
@@ -538,34 +538,3 @@ def assert_checkpoint_state(
         new_resources,
         bindings.checkpointv1State.PARTIALLY_DELETED,
     )
-
-
-@pytest.mark.e2e_cpu
-def test_fail_on_chechpoint_save() -> None:
-    sess = api_utils.user_session()
-    error_log = "failed on checkpoint save"
-    config_obj = conf.load_config(conf.fixtures_path("no_op/single.yaml"))
-    config_obj["hyperparameters"]["fail_on_chechpoint_save"] = error_log
-    exp.run_failure_test_with_temp_config(
-        sess,
-        config_obj,
-        conf.fixtures_path("no_op"),
-        error_log,
-    )
-
-
-@pytest.mark.e2e_cpu
-def test_fail_on_preclose_chechpoint_save() -> None:
-    sess = api_utils.user_session()
-    error_log = "failed on checkpoint save"
-    config_obj = conf.load_config(conf.fixtures_path("no_op/single.yaml"))
-    config_obj["hyperparameters"]["fail_on_chechpoint_save"] = error_log
-    config_obj["searcher"]["max_length"] = {"batches": 1}
-    config_obj["min_validation_period"] = {"batches": 1}
-    config_obj["max_restarts"] = 1
-    exp.run_failure_test_with_temp_config(
-        sess,
-        config_obj,
-        conf.fixtures_path("no_op"),
-        error_log,
-    )
23 changes: 12 additions & 11 deletions e2e_tests/tests/cluster/test_exp_continue.py
@@ -18,13 +18,13 @@ def test_continue_config_file_cli() -> None:
         sess,
         conf.fixtures_path("no_op/single-medium-train-step.yaml"),
         conf.fixtures_path("no_op"),
-        ["--config", "hyperparameters.metrics_sigma=-1.0"],
+        ["--config", "hyperparameters.crash_on_startup=true"],
     )
     exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR)
 
     with tempfile.NamedTemporaryFile() as tf:
         with open(tf.name, "w") as f:
-            util.yaml_safe_dump({"hyperparameters": {"metrics_sigma": 1.0}}, f)
+            util.yaml_safe_dump({"hyperparameters": {"crash_on_startup": False}}, f)
         detproc.check_call(sess, ["det", "e", "continue", str(exp_id), "--config-file", tf.name])
 
     exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED)
@@ -37,15 +37,15 @@ def test_continue_config_file_and_args_cli() -> None:
         sess,
         conf.fixtures_path("no_op/single-medium-train-step.yaml"),
         conf.fixtures_path("no_op"),
-        ["--config", "hyperparameters.metrics_sigma=-1.0"],
+        ["--config", "hyperparameters.crash_on_startup=true"],
     )
     exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR)
 
     expected_name = "checkThis"
     with tempfile.NamedTemporaryFile() as tf:
         with open(tf.name, "w") as f:
             util.yaml_safe_dump(
-                {"name": expected_name, "hyperparameters": {"metrics_sigma": -1.0}}, f
+                {"name": expected_name, "hyperparameters": {"crash_on_startup": True}}, f
             )
 
         stdout = detproc.check_output(
@@ -58,7 +58,7 @@ def test_continue_config_file_and_args_cli() -> None:
                 "--config-file",
                 tf.name,
                 "--config",
-                "hyperparameters.metrics_sigma=1.0",
+                "hyperparameters.crash_on_startup=false",
                 "-f",
             ],
         )
@@ -85,12 +85,13 @@ def test_continue_fixing_broken_config() -> None:
         sess,
         conf.fixtures_path("no_op/single-medium-train-step.yaml"),
         conf.fixtures_path("no_op"),
-        ["--config", "hyperparameters.metrics_sigma=-1.0"],
+        ["--config", "hyperparameters.crash_on_startup=true"],
     )
     exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR)
 
     detproc.check_call(
-        sess, ["det", "e", "continue", str(exp_id), "--config", "hyperparameters.metrics_sigma=1.0"]
+        sess,
+        ["det", "e", "continue", str(exp_id), "--config", "hyperparameters.crash_on_startup=false"],
     )
     exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED)
 
@@ -99,7 +100,7 @@ def test_continue_fixing_broken_config() -> None:
 
     # Trial logs show both tasks logs with the failure message in it.
     trial_logs = "\n".join(exp.trial_logs(sess, trials[0].trial.id))
-    assert "assert 0 <= self.metrics_sigma" in trial_logs
+    assert "assert not self.crash_on_startup" in trial_logs
     assert "resources exited successfully with a zero exit code" in trial_logs
 
 

Expand All @@ -110,7 +111,7 @@ def test_continue_max_restart() -> None:
sess,
conf.fixtures_path("no_op/single-medium-train-step.yaml"),
conf.fixtures_path("no_op"),
["--config", "hyperparameters.metrics_sigma=-1.0", "--config", "max_restarts=2"],
["--config", "hyperparameters.crash_on_startup=true", "--config", "max_restarts=2"],
)
exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR)

@@ -119,7 +120,7 @@ def test_continue_max_restart() -> None:
 
     def count_times_ran() -> int:
         return "\n".join(exp.trial_logs(sess, trials[0].trial.id)).count(
-            "assert 0 <= self.metrics_sigma"
+            "assert not self.crash_on_startup"
         )
 
     def get_trial_restarts() -> int:
@@ -148,7 +149,7 @@ def test_continue_trial_time() -> None:
         sess,
         conf.fixtures_path("no_op/single-medium-train-step.yaml"),
         conf.fixtures_path("no_op"),
-        ["--config", "hyperparameters.metrics_sigma=-1.0"],
+        ["--config", "hyperparameters.crash_on_startup=true"],
     )
     exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR)
 
16 changes: 8 additions & 8 deletions e2e_tests/tests/cluster/test_log_policies.py
@@ -13,7 +13,7 @@
 @pytest.mark.parametrize("should_match", [True, False])
 def test_log_policy_cancel_retries(should_match: bool) -> None:
     sess = api_utils.user_session()
-    regex = r"assert 0 <= self\.metrics_sigma"
+    regex = r"assert not self\.crash_on_startup"
     if not should_match:
         regex = r"(.*) this should not match (.*)"
 
@@ -26,7 +26,7 @@ def test_log_policy_cancel_retries(should_match: bool) -> None:
             },
         },
     ]
-    config["hyperparameters"]["metrics_sigma"] = -1
+    config["hyperparameters"]["crash_on_startup"] = True
     config["max_restarts"] = 1
 
     with tempfile.NamedTemporaryFile() as tf:
@@ -52,7 +52,7 @@ def test_log_policy_cancel_retries(should_match: bool) -> None:
 @pytest.mark.parametrize("should_match", [True, False])
 def test_log_policy_exclude_node_k8s(should_match: bool) -> None:
     sess = api_utils.user_session()
-    regex = r"assert 0 <= self\.metrics_sigma"
+    regex = r"assert not self\.crash_on_startup"
     if not should_match:
         regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
 
@@ -65,7 +65,7 @@ def test_log_policy_exclude_node_k8s(should_match: bool) -> None:
             },
         },
     ]
-    config["hyperparameters"]["metrics_sigma"] = -1
+    config["hyperparameters"]["crash_on_startup"] = True
     config["max_restarts"] = 1
 
     agents = bindings.get_GetAgents(sess).agents
@@ -111,7 +111,7 @@ def test_log_policy_exclude_node_k8s(should_match: bool) -> None:
 @pytest.mark.parametrize("should_match", [True, False])
 def test_log_policy_exclude_node_single_agent(should_match: bool) -> None:
     sess = api_utils.user_session()
-    regex = r"assert 0 <= self\.metrics_sigma"
+    regex = r"assert not self\.crash_on_startup"
     if not should_match:
         regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
 
@@ -124,7 +124,7 @@ def test_log_policy_exclude_node_single_agent(should_match: bool) -> None:
             },
         },
     ]
-    config["hyperparameters"]["metrics_sigma"] = -1
+    config["hyperparameters"]["crash_on_startup"] = True
     config["max_restarts"] = 1
 
     agents = bindings.get_GetAgents(sess).agents
@@ -170,7 +170,7 @@ def test_log_policy_exclude_slurm(should_match: bool) -> None:
     if len(agents) != 1:
         pytest.skip("can only be run on a single agent cluster")
 
-    regex = r"assert 0 <= self\.metrics_sigma"
+    regex = r"assert not self\.crash_on_startup"
     if not should_match:
         regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
 
@@ -183,7 +183,7 @@ def test_log_policy_exclude_slurm(should_match: bool) -> None:
             },
         },
     ]
-    config["hyperparameters"]["metrics_sigma"] = -1
+    config["hyperparameters"]["crash_on_startup"] = True
     config["max_restarts"] = 1
 
     with tempfile.NamedTemporaryFile() as tf:
4 changes: 2 additions & 2 deletions e2e_tests/tests/cluster/test_webhooks.py
@@ -94,7 +94,7 @@ def test_log_pattern_send_webhook(should_match: bool) -> None:
     sess = api_utils.admin_session()
     ws_id = []
 
-    regex = r"assert 0 <= self\.metrics_sigma"
+    regex = r"assert not self\.crash_on_startup"
     if not should_match:
         regex = r"(.*)cuda(.*)"
 
@@ -177,7 +177,7 @@ def test_log_pattern_send_webhook(should_match: bool) -> None:
         conf.fixtures_path("no_op"),
         [
             "--config",
-            "hyperparameters.metrics_sigma=-1.0",
+            "hyperparameters.crash_on_startup=True",
             "--config",
             "integrations.webhooks.webhook_name=['specific-webhook']",
             "--project_id",
6 changes: 0 additions & 6 deletions e2e_tests/tests/config.py
@@ -171,12 +171,6 @@ def set_hparam(config: Dict[Any, Any], name: str, value: Any) -> Dict[Any, Any]:
     return config
 
 
-def set_perform_initial_validation(config: Dict[Any, Any], init_val: bool) -> Dict[Any, Any]:
-    config = config.copy()
-    config["perform_initial_validation"] = init_val
-    return config
-
-
 def set_pod_spec(config: Dict[Any, Any], pod_spec: Dict[Any, Any]) -> Dict[Any, Any]:
     config = config.copy()
     config.setdefault("environment", {})
1 change: 0 additions & 1 deletion e2e_tests/tests/experiment/__init__.py
@@ -3,7 +3,6 @@
     activate_experiments,
     archive_experiments,
     assert_performed_final_checkpoint,
-    assert_performed_initial_validation,
     cancel_experiments,
     cancel_single,
     cancel_experiment,
12 changes: 0 additions & 12 deletions e2e_tests/tests/experiment/experiment.py
@@ -616,18 +616,6 @@ def assert_patterns_in_trial_logs(sess: api.Session, trial_id: int, patterns: Li
     )
 
 
-def assert_performed_initial_validation(sess: api.Session, exp_id: int) -> None:
-    trials = experiment_trials(sess, exp_id)
-
-    assert len(trials) > 0
-    workloads = trials[0].workloads
-
-    assert len(workloads) > 0
-    zeroth_step = workloads_with_validation(workloads)[0]
-
-    assert zeroth_step.totalBatches == 0
-
-
 def last_workload_matches_last_checkpoint(
     workloads: Sequence[bindings.v1WorkloadContainer],
 ) -> None:
24 changes: 0 additions & 24 deletions e2e_tests/tests/experiment/test_core.py
@@ -286,30 +286,6 @@ def test_kill_experiment_ignoring_preemption() -> None:
     exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.CANCELED)
 
 
-@pytest.mark.e2e_cpu
-def test_fail_on_first_validation() -> None:
-    sess = api_utils.user_session()
-    error_log = "failed on first validation"
-    config_obj = conf.load_config(conf.fixtures_path("no_op/single.yaml"))
-    config_obj["hyperparameters"]["fail_on_first_validation"] = error_log
-    exp.run_failure_test_with_temp_config(
-        sess,
-        config_obj,
-        conf.fixtures_path("no_op"),
-        error_log,
-    )
-
-
-@pytest.mark.e2e_cpu
-def test_perform_initial_validation() -> None:
-    sess = api_utils.user_session()
-    config = conf.load_config(conf.fixtures_path("no_op/single.yaml"))
-    config = conf.set_max_length(config, {"batches": 1})
-    config = conf.set_perform_initial_validation(config, True)
-    exp_id = exp.run_basic_test_with_temp_config(sess, config, conf.fixtures_path("no_op"), 1)
-    exp.assert_performed_initial_validation(sess, exp_id)
-
-
 @pytest.mark.e2e_cpu_2a
 @pytest.mark.parametrize(
     "name,searcher_cfg",
1 change: 0 additions & 1 deletion e2e_tests/tests/fixtures/no_op/adaptive.yaml
@@ -10,7 +10,6 @@ hyperparameters:
     type: double
     minval: 0.5
     maxval: 0.9
-  metrics_sigma: 0
 searcher:
   name: adaptive_asha
   metric: validation_error
24 changes: 0 additions & 24 deletions e2e_tests/tests/fixtures/no_op/adaptive_chaos.yaml

This file was deleted.

@@ -7,7 +7,6 @@ hyperparameters:
   global_batch_size: 32
   metrics_progression: decreasing
   metrics_base: 0.5
-  metrics_sigma: 0
   request_stop:
     type: categorical
     vals: [True, False]
1 change: 0 additions & 1 deletion e2e_tests/tests/fixtures/no_op/grid-short.yaml
@@ -6,7 +6,6 @@ checkpoint_storage:
 hyperparameters:
   global_batch_size: 32
   metrics_progression: decreasing
-  metrics_sigma: 0
   learning_rate:
     count: 3
     maxval: 1
