Skip to content

Commit

Permalink
fix: handle failed validations correctly [DET-3838] (#1069)
Browse files Browse the repository at this point in the history
A previous change made the experiment actor reject completed validation workloads that were missing validation metrics. This was incorrect, because validation workloads that fail intentionally carry no metrics. Instead, we now check that the exit reason is nil before trying to compute the best validation.
  • Loading branch information
stoksc authored Aug 12, 2020
1 parent a5f325f commit 39726d6
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 6 deletions.
1 change: 1 addition & 0 deletions e2e_tests/tests/experiment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
run_basic_test,
run_basic_test_with_temp_config,
run_failure_test,
run_failure_test_with_temp_config,
s3_checkpoint_config,
s3_checkpoint_config_no_creds,
shared_fs_checkpoint_config,
Expand Down
9 changes: 9 additions & 0 deletions e2e_tests/tests/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,15 @@ def run_basic_test_with_temp_config(
return experiment_id


def run_failure_test_with_temp_config(
    config: Dict[Any, Any], model_def_path: str, error_str: Optional[str] = None,
) -> None:
    """Dump ``config`` to a temporary YAML file and run a failure test against it.

    The temp file lives only for the duration of the call; ``run_failure_test``
    is expected to observe the experiment fail (optionally matching ``error_str``
    in the trial logs).
    """
    with tempfile.NamedTemporaryFile() as config_file:
        # Re-open by name so yaml writes text; the NamedTemporaryFile handle is binary.
        with open(config_file.name, "w") as out:
            yaml.dump(config, out)
        run_failure_test(config_file.name, model_def_path, error_str=error_str)


def shared_fs_checkpoint_config() -> Dict[str, str]:
return {
"type": "shared_fs",
Expand Down
3 changes: 3 additions & 0 deletions e2e_tests/tests/fixtures/no_op/model_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.chaos_probability_train = self.env.hparams.get("chaos_probability_train")
self.chaos_probability_validate = self.env.hparams.get("chaos_probability_validate")
self.chaos_probability_checkpoint = self.env.hparams.get("chaos_probability_checkpoint")
self.fail_on_first_validation = self.env.hparams.get("fail_on_first_validation", "")
self.validation_set_size = self.env.hparams.get("validation_set_size", 32 * 32)
self.train_batch_secs = self.env.hparams.get("training_batch_seconds", 0)
self.validation_secs = self.env.hparams.get(
Expand Down Expand Up @@ -103,6 +104,8 @@ def train_for_step(self, step_id: int, num_batches: int) -> Dict[str, Any]:
return response

def compute_validation_metrics(self, step_id: int) -> Dict[str, Any]:
if self.fail_on_first_validation:
raise Exception(self.fail_on_first_validation)
self.chaos_failure(self.chaos_probability_validate)
time.sleep(self.validation_secs)
metrics = {
Expand Down
10 changes: 10 additions & 0 deletions e2e_tests/tests/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,13 @@ def test_pytorch_parallel() -> None:
exp.run_basic_test_with_temp_config(
config, conf.official_examples_path("trial/mnist_pytorch"), 1
)


@pytest.mark.e2e_cpu  # type: ignore
def test_fail_on_first_validation() -> None:
    """An experiment whose trial raises during its first validation must fail.

    Exercises the master-side fix: a failed validation workload carries no
    metrics, and the experiment actor must treat that as a failure rather
    than erroring on the missing metrics.
    """
    expected_error = "failed on first validation"
    config = conf.load_config(conf.fixtures_path("no_op/single.yaml"))
    # The no_op model raises Exception(<value>) on its first validation when
    # this hyperparameter is set, so the value doubles as the log needle.
    config["hyperparameters"]["fail_on_first_validation"] = expected_error
    exp.run_failure_test_with_temp_config(
        config, conf.fixtures_path("no_op"), expected_error,
    )
10 changes: 4 additions & 6 deletions master/internal/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,10 @@ func (e *experiment) Receive(ctx *actor.Context) error {
e.processOperations(ctx, ops, err)
case trialCompletedWorkload:
e.searcher.WorkloadCompleted(msg.completedMessage, msg.unitsCompleted)
e.processOperations(ctx, nil, nil) // we call processOperations to flush searcher events.
if msg.completedMessage.Workload.Kind == searcher.ComputeValidationMetrics {
if msg.completedMessage.ValidationMetrics == nil {
return fmt.Errorf("completed validation workload missing metrics %s",
msg.completedMessage.Workload)
}
e.processOperations(ctx, nil, nil) // We call processOperations to flush searcher events.
if msg.completedMessage.Workload.Kind == searcher.ComputeValidationMetrics &&
// Messages indicating trial failures won't have metrics (or need their status).
msg.completedMessage.ExitedReason == nil {
ctx.Respond(e.isBestValidation(*msg.completedMessage.ValidationMetrics))
}
progress := e.searcher.Progress()
Expand Down

0 comments on commit 39726d6

Please sign in to comment.