diff --git a/pkg/api/api_test.go b/pkg/api/api_test.go index 0948884..e85eb31 100644 --- a/pkg/api/api_test.go +++ b/pkg/api/api_test.go @@ -50,6 +50,8 @@ func jpfEcho(arguments []interface{}) (interface{}, error) { } func TestSearch(t *testing.T) { + type Label string + type args struct { expression string data interface{} @@ -145,6 +147,38 @@ func TestSearch(t *testing.T) { data: []interface{}{map[string]any{}, nil, map[string]any{"foo": "bar"}}, }, want: true, + }, { + args: args{ + expression: "length(@[?metric.__name__ == 'foo'])", + data: []struct { + Metric map[Label]any + }{{ + Metric: map[Label]any{ + "__name__": "foo", + }, + }, { + Metric: map[Label]any{ + "__name__": "bar", + }, + }}, + }, + want: 1.0, + }, { + args: args{ + expression: "length(@[?metric.__name__ == 'foo'])", + data: []struct { + Metric map[string]string + }{{ + Metric: map[string]string{ + "__name__": "foo", + }, + }, { + Metric: map[string]string{ + "__name__": "bar", + }, + }}, + }, + want: 1.0, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/interpreter/interpreter.go b/pkg/interpreter/interpreter.go index eb3c437..0cd0c19 100644 --- a/pkg/interpreter/interpreter.go +++ b/pkg/interpreter/interpreter.go @@ -154,19 +154,7 @@ func (intr *treeInterpreter) execute(node parsing.ASTNode, value interface{}, fu } return functionCaller.CallFunction(node.Value.(string), resolvedArgs) case parsing.ASTField: - key := node.Value.(string) - var result interface{} - if m, ok := value.(map[string]interface{}); ok { - result = m[key] - if result != nil { - return result, nil - } - } - result = intr.fieldFromStruct(node.Value.(string), value) - if result != nil { - return result, nil - } - return nil, nil + return extractField(value, node.Value.(string)) case parsing.ASTFilterProjection: left, err := intr.execute(node.Children[0], value, functionCaller) if err != nil { @@ -462,29 +450,40 @@ func (intr *treeInterpreter) execute(node parsing.ASTNode, value interface{}, fu return nil, errors.New("Unknown AST node: " + node.NodeType.String()) } -func (intr *treeInterpreter) fieldFromStruct(key string, value interface{}) interface{} { - rv := reflect.ValueOf(value) - first, n := utf8.DecodeRuneInString(key) - fieldName := string(unicode.ToUpper(first)) + key[n:] - if rv.Kind() == reflect.Struct { - v := rv.FieldByName(fieldName) - if !v.IsValid() { - return nil - } - return v.Interface() - } else if rv.Kind() == reflect.Ptr { - // Handle multiple levels of indirection? - if rv.IsNil() { - return nil - } - rv = rv.Elem() - v := rv.FieldByName(fieldName) - if !v.IsValid() { - return nil - } - return v.Interface() +func extractField(value any, field string) (any, error) { + if value == nil { + return nil, nil + } + if m, ok := value.(map[string]interface{}); ok { + return m[field], nil + } + return extractFieldUsingReflection(reflect.ValueOf(value), field) +} + +func extractFieldUsingReflection(value reflect.Value, field string) (any, error) { + if value.Kind() == reflect.Ptr { + if value.IsNil() { + return nil, nil + } + return extractFieldUsingReflection(value.Elem(), field) + } else if value.Kind() == reflect.Struct { + first, n := utf8.DecodeRuneInString(field) + fieldName := string(unicode.ToUpper(first)) + field[n:] + value := value.FieldByName(fieldName) + if value.IsValid() { + return value.Interface(), nil + } + } else if value.Kind() == reflect.Map { + keyType := value.Type().Key() + if reflect.TypeOf(field).ConvertibleTo(keyType) { + key := reflect.ValueOf(field) + value := value.MapIndex(key.Convert(keyType)) + if value.IsValid() { + return value.Interface(), nil + } + } } - return nil + return nil, nil } func (intr *treeInterpreter) flattenWithReflection(value interface{}) (interface{}, error) {