From 3be0994e97e596dc491c9645cefeec796cc15b66 Mon Sep 17 00:00:00 2001 From: congqixia Date: Wed, 25 Oct 2023 10:08:01 +0800 Subject: [PATCH] Add wildcard outputfield expansion logic back (#604) Signed-off-by: Congqi Xia --- client/data.go | 34 +++++++++++++++++++++++++++++- client/data_test.go | 39 +++++++++++++++++++++++++++++++++++ test/testcases/search_test.go | 12 ++++------- 3 files changed, 76 insertions(+), 9 deletions(-) diff --git a/client/data.go b/client/data.go index 4d18ed4a3..106f53f32 100644 --- a/client/data.go +++ b/client/data.go @@ -92,7 +92,8 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s return sr, nil } -func (c *GrpcClient) parseSearchResult(_ *entity.Schema, outputFields []string, fieldDataList []*schemapb.FieldData, _, from, to int) ([]entity.Column, error) { +func (c *GrpcClient) parseSearchResult(sch *entity.Schema, outputFields []string, fieldDataList []*schemapb.FieldData, _, from, to int) ([]entity.Column, error) { + outputFields = expandWildcard(sch, outputFields) // duplicated name will have only one column now outputSet := make(map[string]struct{}) for _, output := range outputFields { @@ -141,6 +142,37 @@ func (c *GrpcClient) parseSearchResult(_ *entity.Schema, outputFields []string, return columns, nil } +func expandWildcard(schema *entity.Schema, outputFields []string) []string { + wildcard := false + for _, outputField := range outputFields { + if outputField == "*" { + wildcard = true + } + } + if !wildcard { + return outputFields + } + + set := make(map[string]struct{}) + result := make([]string, 0, len(schema.Fields)) + for _, field := range schema.Fields { + result = append(result, field.Name) + set[field.Name] = struct{}{} + } + + // add dynamic fields output + for _, output := range outputFields { + if output == "*" { + continue + } + _, ok := set[output] + if !ok { + result = append(result, output) + } + } + return result +} + func PKs2Expr(backName string, ids entity.Column) string { var expr string var pkName = ids.Name() diff --git a/client/data_test.go b/client/data_test.go index 0e787b68e..aff3bb91c 100644 --- a/client/data_test.go +++ b/client/data_test.go @@ -1178,3 +1178,42 @@ func TestVector2PlaceHolder(t *testing.T) { } }) } + +type WildcardSuite struct { + suite.Suite + + schema *entity.Schema +} + +func (s *WildcardSuite) SetupTest() { + s.schema = entity.NewSchema(). + WithField(entity.NewField().WithName("pk").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). + WithField(entity.NewField().WithName("attr").WithDataType(entity.FieldTypeInt64)). + WithField(entity.NewField().WithName("$meta").WithDataType(entity.FieldTypeJSON).WithIsDynamic(true)). + WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)) +} + +func (s *WildcardSuite) TestExpandWildcard() { + type testCase struct { + tag string + input []string + expect []string + } + + cases := []testCase{ + {tag: "normal", input: []string{"pk", "attr"}, expect: []string{"pk", "attr"}}, + {tag: "with_wildcard", input: []string{"*"}, expect: []string{"pk", "attr", "$meta", "vector"}}, + {tag: "wildcard_dynamic", input: []string{"*", "a"}, expect: []string{"pk", "attr", "$meta", "vector", "a"}}, + } + + for _, tc := range cases { + s.Run(tc.tag, func() { + output := expandWildcard(s.schema, tc.input) + s.ElementsMatch(tc.expect, output) + }) + } +} + +func TestExpandWildcard(t *testing.T) { + suite.Run(t, new(WildcardSuite)) +} diff --git a/test/testcases/search_test.go b/test/testcases/search_test.go index ffaddc212..7f4ea91ad 100644 --- a/test/testcases/search_test.go +++ b/test/testcases/search_test.go @@ -1099,8 +1099,7 @@ func TestSearchScannAllMetricsWithRawData(t *testing.T) { sp, ) common.CheckErr(t, errSearch, true) - // TODO output_fields include * ??? - common.CheckOutputFields(t, resSearch[0].Fields, []string{"*", common.DefaultIntFieldName, common.DefaultFloatFieldName, + common.CheckOutputFields(t, resSearch[0].Fields, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultJSONFieldName, common.DefaultFloatVecFieldName}) common.CheckSearchResult(t, resSearch, 1, common.DefaultTopK) } @@ -1148,8 +1147,7 @@ func TestRangeSearchScannL2(t *testing.T) { // verify error nil, output all fields, range score common.CheckErr(t, errSearch, true) common.CheckSearchResult(t, resSearch, 1, common.DefaultTopK) - // TODO output_fields include * ??? https://github.com/milvus-io/milvus-sdk-go/issues/596 - common.CheckOutputFields(t, resSearch[0].Fields, []string{"*", common.DefaultIntFieldName, common.DefaultFloatFieldName, + common.CheckOutputFields(t, resSearch[0].Fields, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultJSONFieldName, common.DefaultFloatVecFieldName}) for _, s := range resSearch[0].Scores { require.GreaterOrEqual(t, s, float32(15.0)) @@ -1208,8 +1206,7 @@ func TestRangeSearchScannIPCosine(t *testing.T) { // verify error nil, output all fields, range score common.CheckErr(t, errSearch, true) common.CheckSearchResult(t, resSearch, 1, common.DefaultTopK) - // TODO output_fields include * ??? https://github.com/milvus-io/milvus-sdk-go/issues/596 - common.CheckOutputFields(t, resSearch[0].Fields, []string{"*", common.DefaultIntFieldName, common.DefaultFloatFieldName, + common.CheckOutputFields(t, resSearch[0].Fields, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultJSONFieldName, common.DefaultFloatVecFieldName}) for _, s := range resSearch[0].Scores { require.GreaterOrEqual(t, s, float32(0)) @@ -1270,8 +1267,7 @@ func TestRangeSearchScannBinary(t *testing.T) { // verify error nil, output all fields, range score common.CheckErr(t, errSearch, true) common.CheckSearchResult(t, resSearch, 1, common.DefaultTopK) - // TODO output_fields include * ??? https://github.com/milvus-io/milvus-sdk-go/issues/596 - common.CheckOutputFields(t, resSearch[0].Fields, []string{"*", common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultBinaryVecFieldName}) + common.CheckOutputFields(t, resSearch[0].Fields, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultBinaryVecFieldName}) for _, s := range resSearch[0].Scores { require.GreaterOrEqual(t, s, float32(0)) require.Less(t, s, float32(100))