Skip to content

Commit

Permalink
fix: fix contains filter for hyperparameters and metadata (#9779)
Browse files (browse the repository at this point in the history)
(cherry picked from commit 61aad78)
  • Loading branch information
AmanuelAaron authored and determined-ci committed Aug 1, 2024
1 parent 501d45c commit 94f916d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 61 deletions.
34 changes: 20 additions & 14 deletions master/internal/api_runs_intg_test.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -238,7 +238,10 @@ func TestSearchRunsFilter(t *testing.T) {
require.NoError(t, err)
require.Empty(t, resp.Runs)

hyperparameters := map[string]any{"global_batch_size": 1, "test1": map[string]any{"test2": 1}}
hyperparameters := map[string]any{
"global_batch_size": 1, "test1": map[string]any{"test2": 1},
"stringVal": "apple", "test3": map[string]any{"stringVal": "apple"},
}

exp := createTestExpWithProjectID(t, api, curUser, projectIDInt)

Expand All @@ -255,7 +258,10 @@ func TestSearchRunsFilter(t *testing.T) {
require.NoError(t, err)
require.Len(t, resp.Runs, 1)

hyperparameters2 := map[string]any{"global_batch_size": 2, "test1": map[string]any{"test2": 5}}
hyperparameters2 := map[string]any{
"global_batch_size": 2, "test1": map[string]any{"test2": 5},
"stringVal": "bright", "test3": map[string]any{"stringVal": "bright"},
}

// Add second experiment
exp2 := createTestExpWithProjectID(t, api, curUser, projectIDInt)
Expand All @@ -274,10 +280,10 @@ func TestSearchRunsFilter(t *testing.T) {
require.Len(t, resp.Runs, 2)

rawMetadata := map[string]any{
"string_key": "a",
"string_key": "apple",
"number_key": 1,
"nested": map[string]any{
"string_key": "a",
"string_key": "apple",
"number_key": 1,
},
}
Expand All @@ -289,10 +295,10 @@ func TestSearchRunsFilter(t *testing.T) {
require.NoError(t, err)

rawMetadata = map[string]any{
"string_key": "b",
"string_key": "bright",
"number_key": 2,
"nested": map[string]any{
"string_key": "b",
"string_key": "bright",
"number_key": 2,
},
}
Expand Down Expand Up @@ -351,14 +357,14 @@ func TestSearchRunsFilter(t *testing.T) {
},
"HyperParamContains": {
expectedNumRuns: 1,
filter: `{"filterGroup":{"children":[{"columnName":"hp.global_batch_size","kind":"field",` +
`"location":"LOCATION_TYPE_RUN_HYPERPARAMETERS","operator":"contains","type":"COLUMN_TYPE_NUMBER","value":1}],` +
filter: `{"filterGroup":{"children":[{"columnName":"hp.stringVal","kind":"field",` +
`"location":"LOCATION_TYPE_RUN_HYPERPARAMETERS","operator":"contains","type":"COLUMN_TYPE_TEXT","value":"a"}],` +
`"conjunction":"and","kind":"group"},"showArchived":false}`,
},
"HyperParamNotContains": {
expectedNumRuns: 1,
filter: `{"filterGroup":{"children":[{"columnName":"hp.global_batch_size","kind":"field",` +
`"location":"LOCATION_TYPE_RUN_HYPERPARAMETERS","operator":"notContains","type":"COLUMN_TYPE_NUMBER","value":1}],` +
filter: `{"filterGroup":{"children":[{"columnName":"hp.stringVal","kind":"field",` +
`"location":"LOCATION_TYPE_RUN_HYPERPARAMETERS","operator":"notContains","type":"COLUMN_TYPE_TEXT","value":"a"}],` +
`"conjunction":"and","kind":"group"},"showArchived":false}`,
},
"HyperParamOperator": {
Expand All @@ -381,14 +387,14 @@ func TestSearchRunsFilter(t *testing.T) {
},
"HyperParamNestedContains": {
expectedNumRuns: 1,
filter: `{"filterGroup":{"children":[{"columnName":"hp.test1.test2","kind":"field",` +
`"location":"LOCATION_TYPE_RUN_HYPERPARAMETERS","operator":"contains","type":"COLUMN_TYPE_NUMBER","value":1}],` +
filter: `{"filterGroup":{"children":[{"columnName":"hp.test3.stringVal","kind":"field",` +
`"location":"LOCATION_TYPE_RUN_HYPERPARAMETERS","operator":"contains","type":"COLUMN_TYPE_TEXT","value":"a"}],` +
`"conjunction":"and","kind":"group"},"showArchived":false}`,
},
"HyperParamNestedNotContains": {
expectedNumRuns: 1,
filter: `{"filterGroup":{"children":[{"columnName":"hp.test1.test2","kind":"field",` +
`"location":"LOCATION_TYPE_RUN_HYPERPARAMETERS","operator":"notContains","type":"COLUMN_TYPE_NUMBER","value":1}],` +
filter: `{"filterGroup":{"children":[{"columnName":"hp.test3.stringVal","kind":"field",` +
`"location":"LOCATION_TYPE_RUN_HYPERPARAMETERS","operator":"notContains","type":"COLUMN_TYPE_TEXT","value":"a"}],` +
`"conjunction":"and","kind":"group"},"showArchived":false}`,
},
"HyperParamNestedOperator": {
Expand Down
100 changes: 53 additions & 47 deletions master/internal/experiment_filter.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -188,71 +188,74 @@ func runMetadataToSQL(c string, filterColumnType *string, filterValue *interface
queryColumnType = *filterColumnType
}
var queryArgs []interface{}
runHparam := strings.TrimPrefix(c, "metadata.")
runMetadata := strings.TrimPrefix(c, "metadata.")
queryArgs = append(queryArgs, runMetadata)
oSQL, err := o.toSQL()
if err != nil {
return nil, err
}
var queryString string
switch o {
case empty:
queryString = fmt.Sprintf(`r.id NOT IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s')`, runHparam)
queryString = fmt.Sprintf(`r.id NOT IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s)`, "?")
case notEmpty:
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s')`, runHparam)
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s)`, "?")
case contains:
queryArgs = append(queryArgs, queryValue)
queryLikeValue := `%` + queryValue.(string) + `%`
queryArgs = append(queryArgs, queryLikeValue)
switch queryColumnType {
case projectv1.ColumnType_COLUMN_TYPE_NUMBER.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND (number_value=%s OR float_value=%s))`,
runHparam, "?", "?")
"?", "?", "?")
queryArgs = append(queryArgs, queryValue)
case projectv1.ColumnType_COLUMN_TYPE_TEXT.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND string_value LIKE %s)`,
runHparam, "?")
"?", "?")
default:
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND boolean_value=%s)`,
runHparam, "?")
"?", "?")
}
case doesNotContain:
queryArgs = append(queryArgs, queryValue)
queryLikeValue := `%` + queryValue.(string) + `%`
queryArgs = append(queryArgs, queryLikeValue)
switch queryColumnType {
case projectv1.ColumnType_COLUMN_TYPE_NUMBER.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND (number_value!=%s AND float_value!=%s))`,
runHparam, "?", "?")
"?", "?", "?")
queryArgs = append(queryArgs, queryValue)
case projectv1.ColumnType_COLUMN_TYPE_TEXT.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND string_value NOT LIKE %s)`,
runHparam, "?")
"?", "?")
default:
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND boolean_value!=%s)`,
runHparam, "?")
"?", "?")
}
default:
queryArgs = append(queryArgs, bun.Safe(oSQL), queryValue)
switch queryColumnType {
case projectv1.ColumnType_COLUMN_TYPE_NUMBER.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND (integer_value %s %s OR float_value %s %s))`,
runHparam, "?", "?", "?", "?")
"?", "?", "?", "?", "?")
queryArgs = append(queryArgs, bun.Safe(oSQL), queryValue)
case projectv1.ColumnType_COLUMN_TYPE_TEXT.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND string_value %s %s)`,
runHparam, "?", "?")
"?", "?", "?")
case projectv1.ColumnType_COLUMN_TYPE_DATE.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND timestamp_value %s %s)`,
runHparam, "?", "?")
"?", "?", "?")
default:
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key='%s'
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM runs_metadata_index WHERE flat_key=%s
AND boolean_value %s %s)`,
runHparam, "?", "?")
"?", "?", "?")
}
}

Expand Down Expand Up @@ -281,54 +284,57 @@ func runHpToSQL(c string, filterColumnType *string, filterValue *interface{},
}
var queryArgs []interface{}
runHparam := strings.TrimPrefix(c, "hp.")
queryArgs = append(queryArgs, runHparam)
oSQL, err := o.toSQL()
if err != nil {
return nil, err
}
var queryString string
switch o {
case empty:
queryString = fmt.Sprintf(`r.id NOT IN (SELECT run_id FROM run_hparams WHERE hparam='%s')`, runHparam)
queryString = fmt.Sprintf(`r.id NOT IN (SELECT run_id FROM run_hparams WHERE hparam=%s)`, "?")
case notEmpty:
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s')`, runHparam)
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s)`, "?")
case contains:
queryArgs = append(queryArgs, queryValue)
queryLikeValue := `%` + queryValue.(string) + `%`
queryArgs = append(queryArgs, queryLikeValue)
switch queryColumnType {
case projectv1.ColumnType_COLUMN_TYPE_NUMBER.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND number_val=%s)`,
runHparam, "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND number_val=%s)`,
"?", "?")
case projectv1.ColumnType_COLUMN_TYPE_TEXT.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND text_val LIKE %s)`,
runHparam, "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND text_val LIKE %s)`,
"?", "?")
default:
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND bool_val=%s)`,
runHparam, "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND bool_val=%s)`,
"?", "?")
}
case doesNotContain:
queryArgs = append(queryArgs, queryValue)
queryLikeValue := `%` + queryValue.(string) + `%`
queryArgs = append(queryArgs, queryLikeValue)
switch queryColumnType {
case projectv1.ColumnType_COLUMN_TYPE_NUMBER.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND number_val!=%s)`,
runHparam, "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND number_val!=%s)`,
"?", "?")
case projectv1.ColumnType_COLUMN_TYPE_TEXT.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND text_val NOT LIKE %s)`,
runHparam, "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND text_val NOT LIKE %s)`,
"?", "?")
default:
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND bool_val!=%s)`,
runHparam, "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND bool_val!=%s)`,
"?", "?")
}
default:
queryArgs = append(queryArgs, bun.Safe(oSQL), queryValue)
switch queryColumnType {
case projectv1.ColumnType_COLUMN_TYPE_NUMBER.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND number_val %s %s)`,
runHparam, "?", "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND number_val %s %s)`,
"?", "?", "?")
case projectv1.ColumnType_COLUMN_TYPE_TEXT.String():
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND text_val %s %s)`,
runHparam, "?", "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND text_val %s %s)`,
"?", "?", "?")
default:
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam='%s' AND bool_val %s %s)`,
runHparam, "?", "?")
queryString = fmt.Sprintf(`r.id IN (SELECT run_id FROM run_hparams WHERE hparam=%s AND bool_val %s %s)`,
"?", "?", "?")
}
}

Expand Down

0 comments on commit 94f916d

Please sign in to comment.