Skip to content

Commit

Permalink
fix: fix regression caused by join on trials view (#9091)
Browse files Browse the repository at this point in the history
  • Loading branch information
salonig23 authored Apr 5, 2024
1 parent bdab9e4 commit cf2f2be
Show file tree
Hide file tree
Showing 18 changed files with 209 additions and 86 deletions.
11 changes: 5 additions & 6 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2752,12 +2752,11 @@ func (a *apiServer) createTrialTx(
a.m.taskSpec.LogRetentionDays)

if err := db.AddTask(ctx, &model.Task{
TaskID: taskID,
TaskType: model.TaskTypeTrial,
StartTime: time.Now(),
JobID: nil,
LogVersion: model.CurrentTaskLogVersion,
LogRetentionDays: a.m.taskSpec.LogRetentionDays,
TaskID: taskID,
TaskType: model.TaskTypeTrial,
StartTime: time.Now(),
JobID: nil,
LogVersion: model.CurrentTaskLogVersion,
}); err != nil {
return nil, err
}
Expand Down
43 changes: 13 additions & 30 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,40 +492,23 @@ func TestHPSearchContinueCompletedError(t *testing.T) {
}

func TestPutExperimentRetainLogs(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
exp := createTestExp(t, api, curUser)

trialIDs, taskIDs, err := db.ExperimentsTrialAndTaskIDs(ctx, db.Bun(), []int{(exp.ID)})
require.NoError(t, err)
api, _, ctx := setupAPITest(t, nil)
exp, trialIDs, _ := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)

_, err = db.Bun().NewUpdate().Table("experiments").
Set("state = ?", model.CompletedState).
Where("id = ?", exp.ID).
Exec(ctx)
require.NoError(t, err)
_, err = db.Bun().NewUpdate().Table("runs").
Set("state = ?", model.CompletedState).
Where("id IN (?)", bun.In(trialIDs)).
Exec(ctx)
err := CompleteExpAndTrials(ctx, exp.Id, trialIDs)
require.NoError(t, err)

numDays := -1
res, err := api.PutExperimentRetainLogs(ctx, &apiv1.PutExperimentRetainLogsRequest{
ExperimentId: int32(exp.ID), NumDays: int32(numDays),
ExperimentId: exp.Id, NumDays: int32(numDays),
})
require.NoError(t, err)
require.NotNil(t, res)

var logRetentionDays []int
err = db.Bun().NewSelect().Table("tasks").
Column("log_retention_days").
Where("task_id IN (?)", bun.In(taskIDs)).
Scan(ctx, &logRetentionDays)
newLogRetentionDays := []int32{-1, -1, -1, -1, -1}
updatedLogRetentionDays, err := getLogRetentionDays(ctx, trialIDs)
require.NoError(t, err)

for _, v := range logRetentionDays {
require.Equal(t, v, numDays)
}
require.Equal(t, updatedLogRetentionDays, newLogRetentionDays)
}

func TestPutExperimentsRetainLogs(t *testing.T) {
Expand Down Expand Up @@ -558,17 +541,17 @@ func TestPutExperimentsRetainLogs(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, res)

_, taskIDs, err := db.ExperimentsTrialAndTaskIDs(ctx, db.Bun(), intExpIDS)
_, _, err = db.ExperimentsTrialAndTaskIDs(ctx, db.Bun(), intExpIDS)
require.NoError(t, err)

var logRetentionDays []int
err = db.Bun().NewSelect().Table("tasks").
var trialLogRetentionDays []int
err = db.Bun().NewSelect().Table("runs").
Column("log_retention_days").
Where("task_id IN (?)", bun.In(taskIDs)).
Scan(ctx, &logRetentionDays)
Where("id IN (?)", bun.In(trialIDs)).
Scan(ctx, &trialLogRetentionDays)
require.NoError(t, err)

for _, v := range logRetentionDays {
for _, v := range trialLogRetentionDays {
require.Equal(t, v, numDays)
}
}
Expand Down
22 changes: 11 additions & 11 deletions master/internal/api_logretention_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func setRetentionTime(timestamp string) error {
return err
}

func completeExpAndTrials(ctx context.Context, expID int32, trialIDs []int) error {
func CompleteExpAndTrials(ctx context.Context, expID int32, trialIDs []int) error {
_, err := db.Bun().NewUpdate().Table("experiments").
Set("state = ?", model.CompletedState).
Where("id = ?", expID).
Expand All @@ -74,7 +74,7 @@ func resetRetentionTime() error {
}

// nolint: exhaustruct
func createTestRetentionExperiment(
func CreateTestRetentionExperiment(
ctx context.Context, t *testing.T, api *apiServer, config string, numTrials int,
) (*experimentv1.Experiment, []int, []model.TaskID) {
conf := fmt.Sprintf(`
Expand Down Expand Up @@ -125,19 +125,19 @@ func TestDeleteExpiredTaskLogs(t *testing.T) {
require.NoError(t, err)

// Create experiment1 with 5 trials and no special config.
experiment1, trialIDs1, taskIDs1 := createTestRetentionExperiment(ctx, t, api, "", 5)
experiment1, trialIDs1, taskIDs1 := CreateTestRetentionExperiment(ctx, t, api, "", 5)
require.Nil(t, experiment1.EndTime)
require.Len(t, trialIDs1, 5)
require.Len(t, taskIDs1, 5)

// Create experiment2 with 5 trials and a config to expire in 1000 days.
experiment2, trialIDs2, taskIDs2 := createTestRetentionExperiment(ctx, t, api, logRetentionConfig1000days, 5)
experiment2, trialIDs2, taskIDs2 := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfig1000days, 5)
require.Nil(t, experiment2.EndTime)
require.Len(t, trialIDs2, 5)
require.Len(t, taskIDs2, 5)

// Create experiment3 with 5 trials and a config to never expire.
experiment3, trialIDs3, taskIDs3 := createTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)
experiment3, trialIDs3, taskIDs3 := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)
require.Nil(t, experiment3.EndTime)
require.Len(t, trialIDs3, 5)
require.Len(t, taskIDs3, 5)
Expand Down Expand Up @@ -312,7 +312,7 @@ func TestScheduleRetentionNoConfig(t *testing.T) {
require.NoError(t, err)

// Create an experiment with 5 trials and no special config.
experiment, trialIDs, taskIDs := createTestRetentionExperiment(ctx, t, api, "", 5)
experiment, trialIDs, taskIDs := CreateTestRetentionExperiment(ctx, t, api, "", 5)
require.Nil(t, experiment.EndTime)
require.Len(t, trialIDs, 5)
require.Len(t, taskIDs, 5)
Expand Down Expand Up @@ -361,7 +361,7 @@ func TestScheduleRetentionNoConfig(t *testing.T) {
}

// Mark experiments and trials as completed.
err = completeExpAndTrials(ctx, experiment.Id, trialIDs)
err = CompleteExpAndTrials(ctx, experiment.Id, trialIDs)
require.NoError(t, err)

// Advance time by 1 day.
Expand Down Expand Up @@ -409,7 +409,7 @@ func TestScheduleRetention1000days(t *testing.T) {
require.NoError(t, err)

// Create an experiment with 5 trials and a config to expire in 1000 days.
experiment, trialIDs, taskIDs := createTestRetentionExperiment(ctx, t, api, logRetentionConfig1000days, 5)
experiment, trialIDs, taskIDs := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfig1000days, 5)
require.Nil(t, experiment.EndTime)
require.Len(t, trialIDs, 5)
require.Len(t, taskIDs, 5)
Expand Down Expand Up @@ -458,7 +458,7 @@ func TestScheduleRetention1000days(t *testing.T) {
}

// Mark experiments and trials as completed.
err = completeExpAndTrials(ctx, experiment.Id, trialIDs)
err = CompleteExpAndTrials(ctx, experiment.Id, trialIDs)
require.NoError(t, err)

// Advance time by 998 days.
Expand Down Expand Up @@ -513,7 +513,7 @@ func TestScheduleRetentionNeverExpire(t *testing.T) {
require.NoError(t, err)

// Create an experiment with 5 trials and config to never expire.
experiment, trialIDs, taskIDs := createTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)
experiment, trialIDs, taskIDs := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)
require.Nil(t, experiment.EndTime)
require.Len(t, trialIDs, 5)
require.Len(t, taskIDs, 5)
Expand Down Expand Up @@ -562,7 +562,7 @@ func TestScheduleRetentionNeverExpire(t *testing.T) {
}

// Mark experiments and trials as completed.
err = completeExpAndTrials(ctx, experiment.Id, trialIDs)
err = CompleteExpAndTrials(ctx, experiment.Id, trialIDs)
require.NoError(t, err)

// Advance time by 100 days.
Expand Down
11 changes: 5 additions & 6 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,12 +710,11 @@ func (a *apiServer) PutTrialRetainLogs(
}

err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if _, err := tx.NewUpdate().Table("tasks"). // TODO(nick-runs) call runs package.
Set("log_retention_days = ?", req.NumDays).
TableExpr("run_id_task_id as r").
Where("r.run_id = ? and tasks.task_id = r.task_id", req.TrialId).
Exec(ctx); err != nil {
return fmt.Errorf("updating log retention days for tasks: %w", err)
if _, err := tx.NewUpdate().Table("runs").
Set("log_retention_days = ?", req.NumDays).
Where("id = ?", req.TrialId).
Exec(ctx); err != nil {
return fmt.Errorf("updating log retention days for trial: %w", err)
}
return nil
})
Expand Down
36 changes: 36 additions & 0 deletions master/internal/api_trials_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -1401,3 +1402,38 @@ func TestGetTrialByExternalID(t *testing.T) {

require.Equal(t, int(resp.Trial.Id), trial.ID)
}

// getLogRetentionDays reads the log_retention_days column of the runs table
// for every row whose id is in trialIDs. The query has no ORDER BY, so the
// order of the returned values is whatever the database yields.
func getLogRetentionDays(ctx context.Context, trialIDs []int) ([]int32, error) {
	var days []int32

	query := db.Bun().NewSelect().
		Table("runs").
		Column("log_retention_days").
		Where("id IN (?)", bun.In(trialIDs))

	if err := query.Scan(ctx, &days); err != nil {
		return days, err
	}
	return days, nil
}

// TestPutTrialRetainLogs checks that PutTrialRetainLogs changes the
// log_retention_days value read back from the runs table for each trial it is
// called on.
func TestPutTrialRetainLogs(t *testing.T) {
	api, _, ctx := setupAPITest(t, nil)
	exp, trialIDs, _ := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)
	// The per-trial expectations below are indexed by position, so fail fast
	// (instead of panicking on an out-of-range index) if the trial count is off.
	require.Len(t, trialIDs, 5)

	err := CompleteExpAndTrials(ctx, exp.Id, trialIDs)
	require.NoError(t, err)

	// The experiment was created with logRetentionConfigForever; every trial is
	// expected to start out with a stored value of -1.
	orgLogRetentionDays, err := getLogRetentionDays(ctx, trialIDs)
	require.NoError(t, err)
	// testify takes (expected, actual); keeping that order makes failure diffs
	// label the two sides correctly.
	require.Equal(t, []int32{-1, -1, -1, -1, -1}, orgLogRetentionDays)

	newLogRetentionDays := []int32{10, 10, 10, 10, 10}
	for i, v := range trialIDs {
		res, err := api.PutTrialRetainLogs(ctx, &apiv1.PutTrialRetainLogsRequest{
			TrialId: int32(v), NumDays: newLogRetentionDays[i],
		})
		require.NoError(t, err)
		require.NotNil(t, res)
	}

	// Re-read the column to confirm each update was persisted.
	updatedLogRetentionDays, err := getLogRetentionDays(ctx, trialIDs)
	require.NoError(t, err)
	require.Equal(t, newLogRetentionDays, updatedLogRetentionDays)
}
11 changes: 5 additions & 6 deletions master/internal/checkpoint_gc.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,11 @@ func runCheckpointGCTask(
syslog := logrus.WithField("component", "checkpointgc").WithFields(logCtx.Fields())

if err := db.AddTask(context.TODO(), &model.Task{
TaskID: taskID,
TaskType: model.TaskTypeCheckpointGC,
StartTime: time.Now().UTC(),
JobID: &jobID,
LogVersion: model.CurrentTaskLogVersion,
LogRetentionDays: taskSpec.LogRetentionDays,
TaskID: taskID,
TaskType: model.TaskTypeCheckpointGC,
StartTime: time.Now().UTC(),
JobID: &jobID,
LogVersion: model.CurrentTaskLogVersion,
}); err != nil {
return errors.Wrapf(err, "persisting GC task %s", taskID)
}
Expand Down
11 changes: 5 additions & 6 deletions master/internal/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,11 @@ func (c *Command) registerJobAndTask(ctx context.Context, tx bun.Tx) error {
}

if err := internaldb.AddTaskTx(ctx, tx, &model.Task{
TaskID: c.taskID,
TaskType: c.taskType,
StartTime: c.registeredTime,
JobID: &c.jobID,
LogVersion: model.CurrentTaskLogVersion,
LogRetentionDays: c.Base.LogRetentionDays,
TaskID: c.taskID,
TaskType: c.taskType,
StartTime: c.registeredTime,
JobID: &c.jobID,
LogVersion: model.CurrentTaskLogVersion,
}); err != nil {
return fmt.Errorf("persisting task %v: %w", c.taskID, err)
}
Expand Down
3 changes: 1 addition & 2 deletions master/internal/db/postgres_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func AddTask(ctx context.Context, t *model.Task) error {
func AddTaskTx(ctx context.Context, idb bun.IDB, t *model.Task) error {
_, err := idb.NewInsert().Model(t).
Column("task_id", "task_type", "start_time", "job_id", "log_version",
"config", "forked_from", "parent_id", "task_state", "no_pause", "log_retention_days").
"config", "forked_from", "parent_id", "task_state", "no_pause").
On("CONFLICT (task_id) DO UPDATE").
Set("task_type=EXCLUDED.task_type").
Set("start_time=EXCLUDED.start_time").
Expand All @@ -46,7 +46,6 @@ func AddTaskTx(ctx context.Context, idb bun.IDB, t *model.Task) error {
Set("parent_id=EXCLUDED.parent_id").
Set("task_state=EXCLUDED.task_state").
Set("no_pause=EXCLUDED.no_pause").
Set("log_retention_days=EXCLUDED.log_retention_days").
Exec(ctx)
return MatchSentinelError(err)
}
Expand Down
7 changes: 4 additions & 3 deletions master/internal/experiment/bulk_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ func BulkUpdateLogRentention(ctx context.Context, database db.DB,
}
}

_, taskIDs, err := db.ExperimentsTrialAndTaskIDs(ctx, db.Bun(), intExpIDs)
trialIDs, taskIDs, err := db.ExperimentsTrialAndTaskIDs(ctx, db.Bun(), intExpIDs)
if err != nil {
return nil, errors.Wrapf(err, "failed to gather trial IDs for experiments")
}
Expand All @@ -817,14 +817,15 @@ func BulkUpdateLogRentention(ctx context.Context, database db.DB,
}

err = db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if _, err := tx.NewUpdate().Table("tasks").
if _, err := tx.NewUpdate().Table("runs").
Set("log_retention_days = ?", numDays).
Where("task_id IN (?)", bun.In(taskIDs)).
Where("id IN (?)", bun.In(trialIDs)).
Exec(ctx); err != nil {
return fmt.Errorf("updating log retention days for tasks: %w", err)
}
return nil
})

if err != nil {
return nil, err
}
Expand Down
8 changes: 5 additions & 3 deletions master/internal/logretention/logretention.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ func DeleteExpiredTaskLogs(ctx context.Context, days *int16) (int64, error) {
log.WithField("default-retention-days", defaultLogRetentionDays).Trace("deleting expired task logs")
r, err := db.Bun().NewRaw(fmt.Sprintf(`
WITH log_retention_tasks AS (
SELECT task_id, end_time, COALESCE(log_retention_days, %d) AS log_retention_days FROM tasks
WHERE task_id IN (SELECT DISTINCT task_id FROM task_logs)
AND end_time IS NOT NULL
SELECT COALESCE(r.log_retention_days, %d) as log_retention_days, t.task_id, t.end_time
FROM runs as r
JOIN run_id_task_id as r_t ON r.id = r_t.run_id
JOIN tasks as t ON r_t.task_id = t.task_id
WHERE t.end_time IS NOT NULL
)
DELETE FROM task_logs
WHERE task_id IN (
Expand Down
9 changes: 4 additions & 5 deletions master/internal/populate_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,10 @@ func PopulateExpTrialsMetrics(pgdb *db.PgDB, masterConfig *config.Config, trivia
// create task
tID := model.NewTaskID()
tIn := &model.Task{
TaskID: tID,
JobID: &jID,
TaskType: model.TaskTypeTrial,
StartTime: time.Now().UTC().Truncate(time.Millisecond),
LogRetentionDays: masterConfig.RetentionPolicy.LogRetentionDays,
TaskID: tID,
JobID: &jID,
TaskType: model.TaskTypeTrial,
StartTime: time.Now().UTC().Truncate(time.Millisecond),
}
if err = db.AddTask(ctx, tIn); err != nil {
return err
Expand Down
11 changes: 5 additions & 6 deletions master/internal/trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,12 +478,11 @@ func (t *trial) maybeAllocateTask() error {

func (t *trial) addTask(ctx context.Context) error {
return db.AddTask(ctx, &model.Task{
TaskID: t.taskID,
TaskType: model.TaskTypeTrial,
StartTime: t.jobSubmissionTime, // TODO: Why is this the job submission time..?
JobID: &t.jobID,
LogVersion: model.CurrentTaskLogVersion,
LogRetentionDays: t.taskSpec.LogRetentionDays,
TaskID: t.taskID,
TaskType: model.TaskTypeTrial,
StartTime: t.jobSubmissionTime, // TODO: Why is this the job submission time..?
JobID: &t.jobID,
LogVersion: model.CurrentTaskLogVersion,
})
}

Expand Down
2 changes: 2 additions & 0 deletions master/pkg/model/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ func (t *Trial) ToRunAndTrialV2(experimentsProjectID int) (*Run, *TrialV2) {
Restarts: t.Restarts,
RunnerState: t.RunnerState,
LastActivity: t.LastActivity,
LogRetentionDays: t.LogRetentionDays,
}
v2 := &TrialV2{
RunID: t.ID,
Expand Down Expand Up @@ -516,6 +517,7 @@ type Run struct {
Restarts int `db:"restarts"`
RunnerState string `db:"runner_state"`
LastActivity *time.Time `db:"last_activity"`
LogRetentionDays *int16 `db:"log_retention_days"`
}

// RunTaskID represents a row from the `run_id_task_id` table.
Expand Down
Loading

0 comments on commit cf2f2be

Please sign in to comment.