diff --git a/master/internal/api_project.go b/master/internal/api_project.go index c8632d26c966..91e80578cdfa 100644 --- a/master/internal/api_project.go +++ b/master/internal/api_project.go @@ -3,8 +3,8 @@ package internal import ( "context" "fmt" + "regexp" "sort" - "strings" "github.com/uptrace/bun" @@ -28,6 +28,10 @@ import ( "github.com/determined-ai/determined/proto/pkg/rbacv1" "github.com/determined-ai/determined/proto/pkg/workspacev1" ) +const ( + // ProjectKeyRegex is the regex pattern for a project key. + ProjectKeyRegex = "^[A-Z0-9]{5}$" +) var defaultRunsTableColumns = []*projectv1.ProjectColumn{ { @@ -223,6 +227,7 @@ func getRunSummaryMetrics(ctx context.Context, whereClause string, group []int) return columns, nil } + func (a *apiServer) GetProjectByID( ctx context.Context, id int32, curUser model.User, ) (*projectv1.Project, error) { @@ -711,6 +716,19 @@ func (a *apiServer) getProjectNumericMetricsRange( return metricsValues, searcherMetricsValue, nil } +func validateProjectKey(key string) error { + switch { + case len(key) > project.MaxProjectKeyLength: + return errors.Errorf("project key cannot be longer than %d characters", project.MaxProjectKeyLength) + case len(key) < 1: + return errors.New("project key cannot be empty") + case !regexp.MustCompile(ProjectKeyRegex).MatchString(key): + return errors.Errorf("project key can only contain alphanumeric characters") + default: + return nil + } +} + func (a *apiServer) PostProject( ctx context.Context, req *apiv1.PostProjectRequest, ) (*apiv1.PostProjectResponse, error) { @@ -726,28 +744,24 @@ func (a *apiServer) PostProject( return nil, status.Error(codes.PermissionDenied, err.Error()) } - var projectKey string - if req.Key == nil { - projectKey, err = project.GenerateProjectKey(ctx, req.Name) - if err != nil { - return nil, fmt.Errorf("error generating project key: %w", err) + if req.Key != nil { + if err = validateProjectKey(*req.Key); err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) } - } else { - projectKey = *req.Key } - p := &projectv1.Project{} - err = a.m.db.QueryProto("insert_project", p, req.Name, req.Description, - req.WorkspaceId, curUser.ID, projectKey) - - if err != nil && strings.Contains(err.Error(), db.CodeUniqueViolation) { - if strings.Contains(err.Error(), "projects_key_key") { - return nil, - status.Errorf(codes.AlreadyExists, "project with key %s already exists", projectKey) - } + p := &model.Project{ + Name: req.Name, + Description: req.Description, + WorkspaceID: int(req.WorkspaceId), + UserID: int(curUser.ID), + Username: curUser.Username, } - return &apiv1.PostProjectResponse{Project: p}, + if err = project.InsertProject(ctx, p, req.Key); err != nil { + return nil, err + } + return &apiv1.PostProjectResponse{Project: p.Proto()}, errors.Wrapf(err, "error creating project %s in database", req.Name) } diff --git a/master/internal/api_project_intg_test.go b/master/internal/api_project_intg_test.go index 35d1bfebcbff..6abcdcc39a14 100644 --- a/master/internal/api_project_intg_test.go +++ b/master/internal/api_project_intg_test.go @@ -6,6 +6,7 @@ package internal import ( "context" "fmt" + "strings" "testing" "time" @@ -25,6 +26,7 @@ import ( "github.com/determined-ai/determined/master/internal/mocks" "github.com/determined-ai/determined/master/internal/project" "github.com/determined-ai/determined/master/pkg/model" + "github.com/determined-ai/determined/master/pkg/syncx/errgroupx" "github.com/determined-ai/determined/proto/pkg/apiv1" "github.com/determined-ai/determined/proto/pkg/projectv1" "github.com/determined-ai/determined/proto/pkg/rbacv1" @@ -437,21 +439,12 @@ func TestCreateProjectWithoutProjectKey(t *testing.T) { require.NoError(t, werr) projectName := "test-project" + uuid.New().String() - projectKeyPrefix := projectName[:3] + projectKeyPrefix := strings.ToUpper(projectName[:project.MaxProjectKeyPrefixLength]) resp, err := api.PostProject(ctx, &apiv1.PostProjectRequest{ Name: projectName, WorkspaceId: wresp.Workspace.Id, }) require.NoError(t, err) - - // Check that the project key is generated correctly. - countPostFix := 0 - err = db.Bun().NewSelect(). - ColumnExpr("COUNT(*)"). - Table("projects"). - Where("key ILIKE ?", (projectKeyPrefix+"%")). - Scan(ctx, &countPostFix) - require.NoError(t, err) - require.Equal(t, (projectKeyPrefix + fmt.Sprintf("%d", countPostFix)), resp.Project.Key) + require.Equal(t, projectKeyPrefix, resp.Project.Key[:project.MaxProjectKeyPrefixLength]) } func TestCreateProjectWithProjectKey(t *testing.T) { @@ -460,7 +453,7 @@ func TestCreateProjectWithProjectKey(t *testing.T) { require.NoError(t, werr) projectName := "test-project" + uuid.New().String() - projectKey := uuid.New().String()[:5] + projectKey := uuid.New().String()[:project.MaxProjectKeyLength] resp, err := api.PostProject(ctx, &apiv1.PostProjectRequest{ Name: projectName, WorkspaceId: wresp.Workspace.Id, Key: &projectKey, }) @@ -482,7 +475,7 @@ func TestCreateProjectWithDuplicateProjectKey(t *testing.T) { require.NoError(t, werr) projectName := "test-project" + uuid.New().String() - projectKey := uuid.New().String()[:5] + projectKey := uuid.New().String()[:project.MaxProjectKeyLength] _, err := api.PostProject(ctx, &apiv1.PostProjectRequest{ Name: projectName, WorkspaceId: wresp.Workspace.Id, Key: &projectKey, }) @@ -501,17 +494,43 @@ func TestCreateProjectWithDefaultKeyAndDuplicatePrefix(t *testing.T) { require.NoError(t, werr) projectName := uuid.New().String() - projectKeyPrefix := projectName[:3] + projectKeyPrefix := strings.ToUpper(projectName[:project.MaxProjectKeyPrefixLength]) resp1, err := api.PostProject(ctx, &apiv1.PostProjectRequest{ Name: projectName, WorkspaceId: wresp.Workspace.Id, }) require.NoError(t, err) - require.Equal(t, (projectKeyPrefix + "1"), resp1.Project.Key) + require.Equal(t, projectKeyPrefix, resp1.Project.Key[:project.MaxProjectKeyPrefixLength]) resp2, err := api.PostProject(ctx, &apiv1.PostProjectRequest{ Name: projectName + "2", WorkspaceId: wresp.Workspace.Id, }) require.NoError(t, err) require.NoError(t, err) - require.Equal(t, (projectKeyPrefix + "2"), resp2.Project.Key) + require.Equal(t, projectKeyPrefix, resp2.Project.Key[:project.MaxProjectKeyPrefixLength]) +} + +func TestConcurrentProjectKeyGenerationAttempts(t *testing.T) { + api, _, ctx := setupAPITest(t, nil) + wresp, werr := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()}) + require.NoError(t, werr) + for x := 0; x < 20; x++ { + errgrp := errgroupx.WithContext(ctx) + for i := 0; i < 20; i++ { + projectName := "test-project" + uuid.New().String() + errgrp.Go(func(context.Context) error { + _, err := api.PostProject(ctx, &apiv1.PostProjectRequest{ + Name: projectName, WorkspaceId: wresp.Workspace.Id, + }) + require.NoError(t, err) + return err + }) + } + require.NoError(t, errgrp.Wait()) + t.Cleanup(func() { + _, err := db.Bun().NewDelete().Table("projects").Where("workspace_id = ?", wresp.Workspace.Id).Exec(ctx) + require.NoError(t, err) + _, err = db.Bun().NewDelete().Table("workspaces").Where("id = ?", wresp.Workspace.Id).Exec(ctx) + require.NoError(t, err) + }) + } } diff --git a/master/internal/db/postgres.go b/master/internal/db/postgres.go index 9a9b01124afe..da84aaf31526 100644 --- a/master/internal/db/postgres.go +++ b/master/internal/db/postgres.go @@ -204,6 +204,10 @@ const ( // insert/update violates a foreign key constraint. Obtained from: // https://www.postgresql.org/docs/10/errcodes-appendix.html CodeForeignKeyViolation = "23503" + // CodeSerializationFailure is the error code that Postgres uses to indicate that a transaction + // failed due to a serialization failure. Obtained from: + // https://www.postgresql.org/docs/10/errcodes-appendix.html + CodeSerializationFailure = "40001" ) // Close closes the underlying pq connection. diff --git a/master/internal/project/postgres_project.go b/master/internal/project/postgres_project.go index a8adb704df15..2dcc987248d4 100644 --- a/master/internal/project/postgres_project.go +++ b/master/internal/project/postgres_project.go @@ -4,9 +4,25 @@ import ( "context" "database/sql" "fmt" + "strings" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" "github.com/determined-ai/determined/master/internal/db" "github.com/determined-ai/determined/master/internal/workspace" + "github.com/determined-ai/determined/master/pkg/model" + "github.com/determined-ai/determined/master/pkg/random" +) + +const ( + // MaxProjectKeyLength is the maximum length of a project key. + MaxProjectKeyLength = 5 + // MaxProjectKeyPrefixLength is the maximum length of a project key prefix. + MaxProjectKeyPrefixLength = 3 + // MaxRetries is the maximum number of retries for transaction conflicts. + MaxRetries = 5 ) // ProjectByName returns a project's ID if it exists in the given workspace and is not archived. @@ -52,11 +68,59 @@ func ProjectIDByName(ctx context.Context, workspaceID int, projectName string) ( } // GenerateProjectKey generates a unique project key for a project based on its name. -func GenerateProjectKey(ctx context.Context, projectName string) (string, error) { - generatedKey := "" - err := db.Bun().NewRaw("SELECT function_generate_project_key(?)", projectName).Scan(ctx, &generatedKey) - if err != nil { - return "", err +func generateProjectKey(ctx context.Context, tx bun.Tx, projectName string) (string, error) { + var key string + found := true + for i := 0; i < MaxRetries && found; i++ { + prefixLength := min(len(projectName), MaxProjectKeyPrefixLength) + prefix := projectName[:prefixLength] + suffix := random.String(MaxProjectKeyLength - prefixLength) + key = strings.ToUpper(prefix + suffix) + err := tx.NewSelect().Model(&model.Project{}).Where("key = ?", key).For("UPDATE").Scan(ctx) + found = err == nil + } + if found { + return "", fmt.Errorf("could not generate a unique project key") + } + return key, nil +} + +// InsertProject inserts a new project into the database. +func InsertProject( + ctx context.Context, + p *model.Project, + requestedKey *string, +) (err error) { +RetryLoop: + for i := 0; i < MaxRetries; i++ { + err = db.Bun().RunInTx(ctx, &sql.TxOptions{Isolation: sql.LevelRepeatableRead}, + func(ctx context.Context, tx bun.Tx) error { + var err error + if requestedKey == nil { + p.Key, err = generateProjectKey(ctx, tx, p.Name) + if err != nil { + return err + } + } else { + p.Key = *requestedKey + } + _, err = tx.NewInsert().Model(p).Exec(ctx) + if err != nil { + return err + } + return nil + }, + ) + + switch { + case err == nil: + break RetryLoop + case requestedKey == nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint"): + log.Debugf("retrying project (%s) insertion due to generated key conflict (%s)", p.Name, p.Key) + continue // retry + default: + break RetryLoop + } } - return generatedKey, nil + return errors.Wrapf(err, "error inserting project %s into database", p.Name) } diff --git a/master/pkg/model/project.go b/master/pkg/model/project.go index 8b8f4c26d705..70c14f0f622b 100644 --- a/master/pkg/model/project.go +++ b/master/pkg/model/project.go @@ -17,17 +17,18 @@ type Project struct { CreatedAt time.Time `bun:"created_at,scanonly"` Archived bool `bun:"archived"` WorkspaceID int `bun:"workspace_id"` - WorkspaceName string `bun:"workspace_name"` + WorkspaceName string `bun:"workspace_name,scanonly"` UserID int `bun:"user_id"` - Username string `bun:"username"` + Username string `bun:"username,scanonly"` Immutable bool `bun:"immutable"` Description string `bun:"description"` Notes []*projectv1.Note `bun:"notes,type:jsonb"` - NumActiveExperiments int32 `bun:"num_active_experiments"` - NumExperiments int32 `bun:"num_experiments"` - State WorkspaceState `bun:"state"` + NumActiveExperiments int32 `bun:"num_active_experiments,scanonly"` + NumExperiments int32 `bun:"num_experiments,scanonly"` + State WorkspaceState `bun:"state,default:'UNSPECIFIED'::workspace_state"` ErrorMessage string `bun:"error_message"` - LastExperimentStartedAt time.Time `bun:"last_experiment_started_at"` + LastExperimentStartedAt time.Time `bun:"last_experiment_started_at,scanonly"` + Key string `bun:"key"` } // Projects is an array of project instances. @@ -55,6 +56,7 @@ func (p Project) Proto() *projectv1.Project { NumActiveExperiments: p.NumActiveExperiments, Notes: p.Notes, LastExperimentStartedAt: lastExperimentStartedAt, + Key: p.Key, } } diff --git a/master/pkg/random/string.go b/master/pkg/random/string.go new file mode 100644 index 000000000000..594baa8db6f4 --- /dev/null +++ b/master/pkg/random/string.go @@ -0,0 +1,31 @@ +package random + +import ( + "crypto/rand" +) + +const ( + // DefaultChars is the default character set used for generating random strings. + DefaultChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +// String generates a random string of length n, using the characters in the charset string +// if provided, or using the default charset if not. +func String(n int, charset ...string) string { + var chars string + if len(charset) == 0 { + chars = DefaultChars + } else { + chars = charset[0] + } + + bytes := make([]byte, n) + _, err := rand.Read(bytes) + if err != nil { + panic(err) + } + for i, b := range bytes { + bytes[i] = chars[b%byte(len(chars))] + } + return string(bytes) +} diff --git a/master/static/migrations/20240409104254_add-custom-project-key.tx.down.sql b/master/static/migrations/20240409104254_add-custom-project-key.tx.down.sql deleted file mode 100644 index c340ee7bf319..000000000000 --- a/master/static/migrations/20240409104254_add-custom-project-key.tx.down.sql +++ /dev/null @@ -1,2 +0,0 @@ -DROP FUNCTION function_generate_project_key(TEXT); -ALTER TABLE projects DROP COLUMN key; diff --git a/master/static/migrations/20240409104254_add-custom-project-key.tx.up.sql b/master/static/migrations/20240409104254_add-custom-project-key.tx.up.sql deleted file mode 100644 index d72d7c942ce5..000000000000 --- a/master/static/migrations/20240409104254_add-custom-project-key.tx.up.sql +++ /dev/null @@ -1,23 +0,0 @@ -CREATE OR REPLACE FUNCTION function_generate_project_key(input_string TEXT) -RETURNS TEXT AS $$ -DECLARE - prefix TEXT; - count_suffix TEXT; - count_value INT; -BEGIN - -- Take the first 3 characters of the input string - prefix := LEFT(input_string, 3); - - execute format('SELECT COUNT(*)+1 FROM projects WHERE key ILIKE $1') - into count_value - using prefix || '%'; - count_suffix := CAST(count_value as text); - - -- Concatenate prefix and count suffix - RETURN lower(prefix || count_suffix); -END; -$$ LANGUAGE plpgsql; - -ALTER TABLE projects ADD COLUMN key VARCHAR(5) UNIQUE; - -UPDATE projects SET key = function_generate_project_key(name) WHERE key IS NULL; diff --git a/master/static/migrations/20240508154139_add-custom-project-key.tx.up.sql b/master/static/migrations/20240508154139_add-custom-project-key.tx.up.sql new file mode 100644 index 000000000000..e2c54329bcf6 --- /dev/null +++ b/master/static/migrations/20240508154139_add-custom-project-key.tx.up.sql @@ -0,0 +1,33 @@ +CREATE OR REPLACE FUNCTION function_generate_project_key( + max_key_length INTEGER, + max_prefix_length INTEGER, + input_string TEXT +) +RETURNS TEXT AS $$ +DECLARE + prefix TEXT; + prefix_length INT; + suffix TEXT; +BEGIN + -- Take the first 3 characters of the input string + prefix := UPPER(LEFT(input_string, max_prefix_length)); + prefix_length := LENGTH(prefix); + -- Generate a random suffix + suffix := UPPER(LEFT(md5(random()::text), max_key_length - prefix_length)); + + -- Check if the key already exists and loop until we find a unique key + WHILE EXISTS(SELECT 1 FROM projects WHERE key = prefix || suffix) LOOP + suffix := UPPER(LEFT(md5(random()::text), max_key_length - prefix_length)); + END LOOP; + + RETURN prefix || suffix; +END; +$$ LANGUAGE plpgsql; + +ALTER TABLE projects ADD COLUMN IF NOT EXISTS key VARCHAR(5) UNIQUE; + +UPDATE + projects +SET + key = function_generate_project_key(5, 3, name) +WHERE key IS NULL;