Skip to content

Commit

Permalink
fix: assign only run in a single run experiment as best_trial_id (#9051)
Browse files Browse the repository at this point in the history
  • Loading branch information
corban-beaird authored Jul 24, 2024
1 parent 543380d commit 4f81548
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
16 changes: 14 additions & 2 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ func TestSearchExperiments(t *testing.T) {
require.Nil(t, resp.Experiments[0].BestTrial)
require.Equal(t, int32(exp.ID), resp.Experiments[0].Experiment.Id)

require.Nil(t, resp.Experiments[1].BestTrial) // Still nil since no validations reported.
require.NotNil(t, resp.Experiments[1].BestTrial) // Now has a best trial, since it's the only one.
require.Equal(t, int32(noValidationsExp.ID), resp.Experiments[1].Experiment.Id)

// Validations returned properly.
Expand Down Expand Up @@ -1147,7 +1147,7 @@ func TestSearchExperiments(t *testing.T) {
require.Nil(t, resp.Experiments[0].BestTrial)
require.Equal(t, int32(exp.ID), resp.Experiments[0].Experiment.Id)

require.Nil(t, resp.Experiments[1].BestTrial)
require.NotNil(t, resp.Experiments[1].BestTrial)
require.Equal(t, int32(noValidationsExp.ID), resp.Experiments[1].Experiment.Id)

require.NotNil(t, resp.Experiments[2].BestTrial)
Expand All @@ -1169,6 +1169,18 @@ func TestSearchExperiments(t *testing.T) {
require.Equal(t, string(latestExpected), string(latestActual))

require.Equal(t, int32(5), resp.Experiments[2].BestTrial.Restarts)

// when single-trial experiment with no validation metrics completes
// ensure that the only trial is the best trial.
_, err = db.Bun().NewUpdate().Table("experiments").
Set("state = ?", model.CompletedState).
Where("id = ?", noValidationsExp.ID).
Exec(ctx)
require.NoError(t, err)
resp, err = api.SearchExperiments(ctx, req)
require.NoError(t, err)
require.NotNil(t, resp.Experiments[1].BestTrial)
require.Equal(t, int32(noValidationsExp.ID), resp.Experiments[1].Experiment.Id)
}

func TestSearchExperimentsFilters(t *testing.T) {
Expand Down
16 changes: 16 additions & 0 deletions master/internal/db/postgres_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ func AddTrial(ctx context.Context, trial *model.Trial, taskID model.TaskID) erro
return fmt.Errorf("inserting project hyperparameters: %w", err)
}
}

var isSingleTrial bool
err = tx.NewSelect().
ColumnExpr("config->'searcher'->>'name' = 'single'").
Table("experiments").
Where("id = ?", run.ExperimentID).
Scan(ctx, &isSingleTrial)
if err != nil {
return fmt.Errorf("getting experiment config while inserting trial: %w", err)
}
if isSingleTrial {
if _, err := tx.NewUpdate().Table("experiments").Set("best_trial_id = ?", run.ID).
Where("id = ?", run.ExperimentID).Exec(ctx); err != nil {
return fmt.Errorf("updating best trial id for single trial experiment: %w", err)
}
}
return nil
})
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
-- backfill existing single-trial experiments without best_trial_id
WITH single_run_experiments AS (
SELECT
id
FROM
experiments e
WHERE
e.config->'searcher'->>'name' = 'single'
),
br_no_validation AS (
SELECT r.experiment_id, r.id, r.best_validation_id
FROM
runs r
INNER JOIN single_run_experiments sre
ON r.experiment_id = sre.id
ORDER BY searcher_metric_value_signed)
UPDATE
experiments
SET
best_trial_id = brnv.id
FROM
br_no_validation brnv
WHERE
experiments.id = brnv.experiment_id;

0 comments on commit 4f81548

Please sign in to comment.