Skip to content

Commit

Permalink
more wip
Browse files Browse the repository at this point in the history
  • Loading branch information
carolinaecalderon committed May 6, 2024
1 parent 7dbdd86 commit a04212f
Show file tree
Hide file tree
Showing 71 changed files with 356 additions and 338 deletions.
6 changes: 6 additions & 0 deletions master/.golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ linters:
- funlen # TODO(DET-9959)
- nestif # TODO(DET-9960)
- depguard # TOo many errors for now -- if we enable, this will require a large
- protogetter
- revive
- tagalign
- nakedret
- testifylint
- perfsprint


# Toss up linters.
Expand Down
3 changes: 2 additions & 1 deletion master/internal/api_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"strconv"

petname "github.com/dustinkirkland/golang-petname"
pstruct "github.com/golang/protobuf/ptypes/struct"
Expand Down Expand Up @@ -176,7 +177,7 @@ func (a *apiServer) GetCommands(
return nil, err
}

workspaceNotFoundErr := api.NotFoundErrs("workspace", fmt.Sprint(req.WorkspaceId), true)
workspaceNotFoundErr := api.NotFoundErrs("workspace", strconv.Itoa(int(req.WorkspaceId)), true)

if req.WorkspaceId != 0 {
// check if the workspace exists.
Expand Down
11 changes: 5 additions & 6 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (a *apiServer) getExperiment(
func (a *apiServer) getExperimentTx(
ctx context.Context, idb bun.IDB, curUser model.User, experimentID int,
) (*experimentv1.Experiment, error) {
expNotFound := api.NotFoundErrs("experiment", fmt.Sprint(experimentID), true)
expNotFound := api.NotFoundErrs("experiment", strconv.Itoa(experimentID), true)
exp := &experimentv1.Experiment{}
expMap := map[string]interface{}{}
query := `
Expand Down Expand Up @@ -273,7 +273,7 @@ func (a *apiServer) GetSearcherEvents(

e, ok := experiment.ExperimentRegistry.Load(int(req.ExperimentId))
if !ok {
return nil, api.NotFoundErrs("experiment", fmt.Sprint(req.ExperimentId), true)
return nil, api.NotFoundErrs("experiment", strconv.Itoa(int(req.ExperimentId)), true)
}
w, err := e.GetSearcherEventsWatcher()
if err != nil {
Expand Down Expand Up @@ -316,7 +316,7 @@ func (a *apiServer) PostSearcherOperations(

e, ok := experiment.ExperimentRegistry.Load(int(req.ExperimentId))
if !ok {
return nil, api.NotFoundErrs("experiment", fmt.Sprint(req.ExperimentId), true)
strconv.Itoa(int(req.ExperimentId))
}
if err := e.PerformSearcherOperations(req); err != nil {
return nil, status.Errorf(codes.Internal, "failed to post operations: %v", err)
Expand Down Expand Up @@ -836,7 +836,7 @@ func (a *apiServer) GetExperimentValidationHistory(
var resp apiv1.GetExperimentValidationHistoryResponse
switch err := a.m.db.QueryProto("proto_experiment_validation_history", &resp, req.ExperimentId); {
case err == db.ErrNotFound:
return nil, api.NotFoundErrs("experiment", fmt.Sprint(req.ExperimentId), true)
strconv.Itoa(int(req.ExperimentId))
case err != nil:
return nil, errors.Wrapf(err,
"error fetching validation history for experiment from database: %d", req.ExperimentId)
Expand Down Expand Up @@ -898,7 +898,7 @@ func (a *apiServer) PreviewHPSearch(

sm := searcher.NewSearchMethod(sc)
s := searcher.NewSearcher(req.Seed, sm, hc)
sim, err := searcher.Simulate(s, nil, searcher.RandomValidation, true, sc.Metric())
sim, err := searcher.Simulate(s, nil, searcher.RandomValidation, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -2806,7 +2806,6 @@ func (a *apiServer) PutTrial(ctx context.Context, req *apiv1.PutTrialRequest) (
trial = innerResp.Trial
return nil
})

if err != nil {
return nil, fmt.Errorf("failed to run create trial tx: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,9 @@ func TestMoveExperiments(t *testing.T) {

t.Run("Move mixed of existent and non-existent experiments", func(t *testing.T) {
exp := createTestExp(t, api, curUser)
expIds := []int32{-1, 0, int32(exp.ID)}
expIDs := []int32{-1, 0, int32(exp.ID)}
result, err := api.MoveExperiments(ctx, &apiv1.MoveExperimentsRequest{
ExperimentIds: expIds,
ExperimentIds: expIDs,
DestinationProjectId: int32(projectID),
Filters: nil,
})
Expand All @@ -457,7 +457,7 @@ func TestMoveExperiments(t *testing.T) {
require.Len(t, successIDList, 1)
require.Len(t, errorIDList, 2)
require.Equal(t, successIDList[0], int32(exp.ID))
require.Len(t, result.Results, len(expIds))
require.Len(t, result.Results, len(expIDs))
require.NoError(t, err)
})
}
Expand Down Expand Up @@ -1906,7 +1906,7 @@ func TestAuthZGetExperimentAndCanDoActions(t *testing.T) {

// FilterExperimentsQuery error returned unmodified.
expectedErr := fmt.Errorf("canGetExperimentError")
authZExp.On("FilterExperimentsQuery", mock.Anything, mock.Anything, mock.Anything,
authZExp.On("FilterExperimentsQuery", mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything).
Return(resQuery, expectedErr).Once().Run(func(args mock.Arguments) {
q := args.Get(3).(*bun.SelectQuery).Where("0 = 1")
Expand Down
34 changes: 17 additions & 17 deletions master/internal/api_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ func (a *apiServer) GetModels(
archFilterExpr = strconv.FormatBool(req.Archived.Value)
}
userFilterExpr := strings.Join(req.Users, ",")
userIds := make([]string, 0, len(req.UserIds))
userIDs := make([]string, 0, len(req.UserIds))
for _, userID := range req.UserIds {
userIds = append(userIds, strconv.Itoa(int(userID)))
userIDs = append(userIDs, strconv.Itoa(int(userID)))
}
userIDFilterExpr := strings.Join(userIds, ",")
userIDFilterExpr := strings.Join(userIDs, ",")
labelFilterExpr := strings.Join(req.Labels, ",")
// Construct the ordering expression.
sortColMap := map[apiv1.GetModelsRequest_SortBy]string{
Expand Down Expand Up @@ -140,36 +140,36 @@ func (a *apiServer) GetModels(
if err != nil {
return nil, err
}
var workspaceIdsGiven []int32
var workspaceIDsGiven []int32
if req.WorkspaceIds != nil {
// default is to use workspace ids.
workspaceIdsGiven = req.WorkspaceIds
// default is to use workspace IDs.
workspaceIDsGiven = req.WorkspaceIds
} else if req.WorkspaceIds == nil && req.WorkspaceNames != nil {
// get the ids of the corresponding workspaces
if err := db.Bun().NewSelect().Table("workspaces").Column("id").
Where("name in (?)", bun.In(req.WorkspaceNames)).Distinct().
Scan(ctx, &workspaceIdsGiven); err != nil {
Scan(ctx, &workspaceIDsGiven); err != nil {
return nil, fmt.Errorf("getting workspace ids from names: %w", err)
}
}
// function below returns a list of workspaces that have permissions
// filtered according to user given workspaces.
// if global permissions and no filter list given by user then it's an empty list.
workspaceIdsWithPermsAndFilterList, err := modelauth.AuthZProvider.Get().
CanGetModels(ctx, *curUser, workspaceIdsGiven)
workspaceIDsWithPermsAndFilterList, err := modelauth.AuthZProvider.Get().
CanGetModels(ctx, *curUser, workspaceIDsGiven)
if err != nil {
return nil, authz.SubIfUnauthorized(err, errors.Errorf(
"current user doesn't have view permissions in related workspaces"))
}
var workspaceIds []string
var workspaceIdsWithPermsAndFilter string
if workspaceIdsWithPermsAndFilterList == nil {
workspaceIdsWithPermsAndFilter = ""
var workspaceIDs []string
var workspaceIDsWithPermsAndFilter string
if workspaceIDsWithPermsAndFilterList == nil {
workspaceIDsWithPermsAndFilter = ""
} else {
for _, wID := range workspaceIdsWithPermsAndFilterList {
workspaceIds = append(workspaceIds, strconv.Itoa(int(wID)))
for _, wID := range workspaceIDsWithPermsAndFilterList {
workspaceIDs = append(workspaceIDs, strconv.Itoa(int(wID)))
}
workspaceIdsWithPermsAndFilter = strings.Join(workspaceIds, ",")
workspaceIDsWithPermsAndFilter = strings.Join(workspaceIDs, ",")
}

err = a.m.db.QueryProtof(
Expand All @@ -183,7 +183,7 @@ func (a *apiServer) GetModels(
labelFilterExpr,
nameFilter,
descFilterExpr,
workspaceIdsWithPermsAndFilter,
workspaceIDsWithPermsAndFilter,
)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions master/internal/api_resourcepool_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,9 @@ func TestPatchBindingsSucceeds(t *testing.T) {
})
require.NoError(t, err)
require.Equal(t, 2, len(resp.WorkspaceIds))
expectedIds := set.FromSlice[int32](workspaceIDs)
expectedIDs := set.FromSlice[int32](workspaceIDs)
for _, id := range resp.WorkspaceIds {
require.True(t, expectedIds.Contains(id))
require.True(t, expectedIDs.Contains(id))
}

require.True(t, mockRM.AssertExpectations(t))
Expand Down
26 changes: 13 additions & 13 deletions master/internal/api_runs_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func TestSearchRunsFilter(t *testing.T) {
}
}

func TestMoveRunsIds(t *testing.T) {
func TestMoveRunsIDs(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
_, projectIDInt := createProjectAndWorkspace(ctx, t, api)
sourceprojectID := int32(1)
Expand All @@ -258,10 +258,10 @@ func TestMoveRunsIds(t *testing.T) {
run1, _ := createTestTrial(t, api, curUser)
run2, _ := createTestTrial(t, api, curUser)

moveIds := []int32{int32(run1.ID)}
moveIDs := []int32{int32(run1.ID)}

moveReq := &apiv1.MoveRunsRequest{
RunIds: moveIds,
RunIds: moveIDs,
SourceProjectId: sourceprojectID,
DestinationProjectId: destprojectID,
SkipMultitrial: false,
Expand Down Expand Up @@ -294,7 +294,7 @@ func TestMoveRunsIds(t *testing.T) {
resp, err = api.SearchRuns(ctx, req)
require.NoError(t, err)
require.Len(t, resp.Runs, 1)
require.Equal(t, moveIds[0], resp.Runs[0].Id)
require.Equal(t, moveIDs[0], resp.Runs[0].Id)

// Experiment in new project
exp, err := api.getExperiment(ctx, curUser, run1.ExperimentID)
Expand Down Expand Up @@ -341,10 +341,10 @@ func TestMoveRunsMultiTrialSkip(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
sourceprojectID, destprojectID, runID1, runID2, _ := setUpMultiTrialExperiments(ctx, t, api, curUser)

moveIds := []int32{runID1}
moveIDs := []int32{runID1}

moveReq := &apiv1.MoveRunsRequest{
RunIds: moveIds,
RunIds: moveIDs,
SourceProjectId: sourceprojectID,
DestinationProjectId: destprojectID,
SkipMultitrial: true,
Expand Down Expand Up @@ -382,10 +382,10 @@ func TestMoveRunsMultiTrialNoSkip(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
sourceprojectID, destprojectID, runID1, runID2, expID := setUpMultiTrialExperiments(ctx, t, api, curUser)

moveIds := []int32{runID1}
moveIDs := []int32{runID1}

moveReq := &apiv1.MoveRunsRequest{
RunIds: moveIds,
RunIds: moveIDs,
SourceProjectId: sourceprojectID,
DestinationProjectId: destprojectID,
SkipMultitrial: false,
Expand Down Expand Up @@ -460,10 +460,10 @@ func TestMoveRunsFilter(t *testing.T) {
require.NoError(t, err)

// If provided with filter MoveRuns should ignore these move ids
moveIds := []int32{resp.Runs[0].Id, resp.Runs[1].Id}
moveIDs := []int32{resp.Runs[0].Id, resp.Runs[1].Id}

moveReq := &apiv1.MoveRunsRequest{
RunIds: moveIds,
RunIds: moveIDs,
SourceProjectId: sourceprojectID,
DestinationProjectId: destprojectID,
Filter: ptrs.Ptr(`{"filterGroup":{"children":[{"columnName":"hp.test1.test2","kind":"field",` +
Expand Down Expand Up @@ -547,7 +547,7 @@ func TestDeleteRunsNonTerminal(t *testing.T) {
require.Len(t, searchResp.Runs, 2)
}

func TestDeleteRunsIds(t *testing.T) {
func TestDeleteRunsIDs(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
projectID, _, runID1, runID2, _ := setUpMultiTrialExperiments(ctx, t, api, curUser)

Expand All @@ -574,7 +574,7 @@ func TestDeleteRunsIds(t *testing.T) {
require.Len(t, searchResp.Runs, 0)
}

func TestDeleteRunsIdsNonExistant(t *testing.T) {
func TestDeleteRunsIDsNonExistant(t *testing.T) {
api, _, ctx := setupAPITest(t, nil)
_, projectIDInt := createProjectAndWorkspace(ctx, t, api)
projectID := int32(projectIDInt)
Expand Down Expand Up @@ -771,7 +771,7 @@ func TestDeleteRunsNoInput(t *testing.T) {
require.Len(t, resp.Results, 0)
}

func TestArchiveUnarchiveIds(t *testing.T) {
func TestArchiveUnarchiveIDs(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
projectID, _, runID1, runID2, _ := setUpMultiTrialExperiments(ctx, t, api, curUser)

Expand Down
10 changes: 5 additions & 5 deletions master/internal/api_trials_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ func TestMultiTrialSampleMetrics(t *testing.T) {
maxDataPoints := 7

var trainMetricNames []string
var metricIds []string
var metricIDs []string
for metricName := range expectedTrainMetrics[0].AvgMetrics.AsMap() {
trainMetricNames = append(trainMetricNames, metricName)
metricIds = append(metricIds, "training."+metricName)
metricIDs = append(metricIDs, "training."+metricName)
}
actualTrainingMetrics, err := api.multiTrialSample(int32(trial.ID), trainMetricNames,
model.TrainingMetricGroup, maxDataPoints, 0, 10, nil, []string{})
Expand All @@ -320,7 +320,7 @@ func TestMultiTrialSampleMetrics(t *testing.T) {
var validationMetricNames []string
for metricName := range expectedValMetrics[0].AvgMetrics.AsMap() {
validationMetricNames = append(validationMetricNames, metricName)
metricIds = append(metricIds, "validation."+metricName)
metricIDs = append(metricIDs, "validation."+metricName)
}
actualValidationTrainingMetrics, err := api.multiTrialSample(int32(trial.ID),
validationMetricNames, model.ValidationMetricGroup, maxDataPoints,
Expand All @@ -331,7 +331,7 @@ func TestMultiTrialSampleMetrics(t *testing.T) {
var genericMetricNames []string
for metricName := range expectedValMetrics[0].AvgMetrics.AsMap() {
genericMetricNames = append(genericMetricNames, metricName)
metricIds = append(metricIds, "mygroup."+metricName)
metricIDs = append(metricIDs, "mygroup."+metricName)
}
actualGenericTrainingMetrics, err := api.multiTrialSample(int32(trial.ID),
genericMetricNames, model.MetricGroup("mygroup"), maxDataPoints,
Expand All @@ -343,7 +343,7 @@ func TestMultiTrialSampleMetrics(t *testing.T) {
require.True(t, isMultiTrialSampleCorrect(expectedValMetrics, actualValidationTrainingMetrics[0]))

actualAllMetrics, err := api.multiTrialSample(int32(trial.ID), []string{},
"", maxDataPoints, 0, 10, nil, metricIds)
"", maxDataPoints, 0, 10, nil, metricIDs)
require.Equal(t, 3, len(actualAllMetrics))
require.NoError(t, err)
require.Equal(t, maxDataPoints, len(actualAllMetrics[1].Data)) // max datapoints check
Expand Down
2 changes: 1 addition & 1 deletion master/internal/authz/permission_denied_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type PermissionDeniedError struct {
// Error returns an error string.
func (p PermissionDeniedError) Error() string {
if len(p.RequiredPermissions) == 0 {
return strings.TrimSpace(fmt.Sprintf("%s access denied", p.Prefix))
return p.Prefix + " access denied"
}

permissions := make([]string, len(p.RequiredPermissions))
Expand Down
8 changes: 4 additions & 4 deletions master/internal/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (f *FileCache) FileContent(expID int, path string) ([]byte, error) {
return []byte{}, err
}
for _, file := range fileTree {
if file.Path == path && !file.IsDir {
if file.GetPath() == path && !file.GetIsDir() {
folder.lock.Lock()
defer folder.lock.Unlock()
file, err := os.ReadFile(f.genPath(expID, path))
Expand Down Expand Up @@ -214,7 +214,7 @@ func genNestedTree(fileTree []*experimentv1.FileNode) []*experimentv1.FileNode {
var fileTreeNested []*experimentv1.FileNode
for _, file := range fileTree {
fileTreeNested = insertToTree(
fileTreeNested, strings.Split(file.Path, string(os.PathSeparator)), file)
fileTreeNested, strings.Split(file.GetPath(), string(os.PathSeparator)), file)
}
return fileTreeNested
}
Expand All @@ -225,14 +225,14 @@ func insertToTree(
if len(paths) > 0 {
var i int
for i = 0; i < len(root); i++ {
if root[i].Name == paths[0] {
if root[i].GetName() == paths[0] {
break
}
}
if i == len(root) {
root = append(root, node)
}
root[i].Files = insertToTree(root[i].Files, paths[1:], node)
root[i].Files = insertToTree(root[i].GetFiles(), paths[1:], node)
}
return root
}
2 changes: 1 addition & 1 deletion master/internal/cache/cache_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestCache(t *testing.T) {
files, _, err := cache.getFileTree(expID)
require.NoError(t, err)
require.NotEmpty(t, files)
path := files[0].Path
path := files[0].GetPath()
_, err = cache.FileContent(expID, path)
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion master/internal/checkpoint_gc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestRunCheckpointGCTask(t *testing.T) {
task.DefaultService = tt.args.as(t)
defer func() { task.DefaultService = tmp }()

jobID := db.RequireMockJob(t, pgDB, &user.ID)
jobID := db.RequireMockJob(t, &user.ID)

if err := runCheckpointGCTask(
tt.args.rm,
Expand Down
Loading

0 comments on commit a04212f

Please sign in to comment.