From 09cd2e3d3e5a54ed03bb1d786cead8c3a55cdf81 Mon Sep 17 00:00:00 2001 From: featherchen Date: Wed, 23 Oct 2024 02:40:22 -0700 Subject: [PATCH] feat(test): TestGetTask Signed-off-by: featherchen --- .../pkg/repositories/gormimpl/task_repo.go | 4 ++-- .../repositories/gormimpl/task_repo_test.go | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 828888e646c..c8c9a6948f8 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -51,8 +51,8 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models timer := r.metrics.GetDuration.Start() var tx *gorm.DB if input.Version == "" { - tx := r.db.WithContext(ctx).Limit(1) - tx = tx.Order("DESC") + tx := r.db.WithContext(ctx).Where("project = ? AND domain = ? AND name = ?", input.Project, input.Domain, input.Name).Limit(1) + tx = tx.Order("version DESC") tx.Find(&task) } else { tx = r.db.WithContext(ctx).Where(&models.Task{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go index 3309ad36090..5204c2a1cf2 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go @@ -79,6 +79,28 @@ func TestGetTask(t *testing.T) { assert.Equal(t, version, output.Version) assert.Equal(t, []byte{1, 2}, output.Closure) assert.Equal(t, pythonTestTaskType, output.Type) + + //When version is empty, return the latest task + GlobalMock = mocket.Catcher.Reset() + GlobalMock.Logging = true + + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "tasks" WHERE project = $1 AND domain = $2 AND name = $3 ORDER BY version DESC LIMIT 1`). + WithReply(tasks) + output, err = taskRepo.Get(context.Background(), interfaces.Identifier{ + Project: project, + Domain: domain, + Name: name, + Version: "", + }) + + assert.NoError(t, err) + assert.Equal(t, project, output.Project) + assert.Equal(t, domain, output.Domain) + assert.Equal(t, name, output.Name) + assert.Equal(t, "v2", output.Version) + assert.Equal(t, []byte{3, 4}, output.Closure) + assert.Equal(t, pythonTestTaskType, output.Type) } func TestListTasks(t *testing.T) {