Skip to content

Commit

Permalink
fix: sorting by arbitrary metadata (#9874)
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanuelAaron authored Aug 28, 2024
1 parent c1b7767 commit 0ef81aa
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 22 deletions.
11 changes: 11 additions & 0 deletions master/internal/api_runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,17 @@ func sortRuns(sortString *string, runQuery *bun.SelectQuery) error {
hpQuery := strings.Join(hp, "->")
queryArgs = append(queryArgs, bun.Safe(sortDirection))
runQuery.OrderExpr(fmt.Sprintf(`r.hparams->%s ?`, hpQuery), queryArgs...)
case strings.HasPrefix(paramDetail[0], "metadata."):
param := strings.ReplaceAll(paramDetail[0], "'", "")
mdt := strings.Split(strings.TrimPrefix(param, "metadata."), ".")
var queryArgs []interface{}
for i := 0; i < len(mdt); i++ {
queryArgs = append(queryArgs, mdt[i])
mdt[i] = "?"
}
mdtQuery := strings.Join(mdt, "->")
queryArgs = append(queryArgs, bun.Safe(sortDirection))
runQuery.OrderExpr(fmt.Sprintf(`rm.metadata->%s ?`, mdtQuery), queryArgs...)
case strings.Contains(paramDetail[0], "."):
metricGroup, metricName, metricQualifier, err := parseMetricsName(paramDetail[0])
if err != nil {
Expand Down
90 changes: 68 additions & 22 deletions master/internal/api_runs_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,35 +193,81 @@ func TestSearchRunsSort(t *testing.T) {
HParams: hyperparameters2,
}, task2.TaskID))

// Sort by start time
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("startTime=asc"),
})

// Get runs in project
resp, err = api.SearchRuns(ctx, req)
require.NoError(t, err)
require.Equal(t, int32(exp.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp2.ID), resp.Runs[1].Experiment.Id)
require.Len(t, resp.Runs, 2)

// Sort by hyperparameter
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("hp.global_batch_size=desc"),
// add metadata
rawMetadata := map[string]any{
"number_key": 1,
"nested": map[string]any{
"number_key": 1,
},
}
metadata := newProtoStruct(t, rawMetadata)
_, err = api.PostRunMetadata(ctx, &apiv1.PostRunMetadataRequest{
RunId: resp.Runs[0].Id,
Metadata: metadata,
})

require.NoError(t, err)
require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id)

// Sort by nested hyperparameter
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("hp.test1.test2=desc"),
rawMetadata = map[string]any{
"number_key": 2,
"nested": map[string]any{
"number_key": 2,
},
}
metadata = newProtoStruct(t, rawMetadata)
_, err = api.PostRunMetadata(ctx, &apiv1.PostRunMetadataRequest{
RunId: resp.Runs[1].Id,
Metadata: metadata,
})

require.NoError(t, err)
require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id)

tests := map[string]struct {
sortBy string
reverse bool
}{
"StartTime": {
sortBy: "startTime=asc",
reverse: false,
},
"Hyperparameter": {
sortBy: "hp.global_batch_size=desc",
reverse: true,
},
"HyperparameterNested": {
sortBy: "hp.test1.test2=desc",
reverse: true,
},
"Metadata": {
sortBy: "metadata.number_key=desc",
reverse: true,
},
"MetadataNested": {
sortBy: "metadata.nested.number_key=desc",
reverse: true,
},
}

for testCase, testVars := range tests {
t.Run(testCase, func(t *testing.T) {
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: &projectID,
Sort: ptrs.Ptr(testVars.sortBy),
})

require.NoError(t, err)
if testVars.reverse {
require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id)
} else {
require.Equal(t, int32(exp.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp2.ID), resp.Runs[1].Experiment.Id)
}
})
}
}

func TestSearchRunsFilter(t *testing.T) {
Expand Down

0 comments on commit 0ef81aa

Please sign in to comment.