diff --git a/expression/builtin.go b/expression/builtin.go index f10cf9aa3dfa9..18e78ba17bdb8 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -874,6 +874,7 @@ var funcs = map[string]functionClass{ ast.JSONObject: &jsonObjectFunctionClass{baseFunctionClass{ast.JSONObject, 0, -1}}, ast.JSONArray: &jsonArrayFunctionClass{baseFunctionClass{ast.JSONArray, 0, -1}}, ast.JSONContains: &jsonContainsFunctionClass{baseFunctionClass{ast.JSONContains, 2, 3}}, + ast.JSONOverlaps: &jsonOverlapsFunctionClass{baseFunctionClass{ast.JSONOverlaps, 2, 2}}, ast.JSONContainsPath: &jsonContainsPathFunctionClass{baseFunctionClass{ast.JSONContainsPath, 3, -1}}, ast.JSONValid: &jsonValidFunctionClass{baseFunctionClass{ast.JSONValid, 1, 1}}, ast.JSONArrayAppend: &jsonArrayAppendFunctionClass{baseFunctionClass{ast.JSONArrayAppend, 3, -1}}, diff --git a/expression/builtin_json.go b/expression/builtin_json.go index e317fa88e952a..eeabef6fe2880 100644 --- a/expression/builtin_json.go +++ b/expression/builtin_json.go @@ -44,6 +44,7 @@ var ( _ functionClass = &jsonObjectFunctionClass{} _ functionClass = &jsonArrayFunctionClass{} _ functionClass = &jsonContainsFunctionClass{} + _ functionClass = &jsonOverlapsFunctionClass{} _ functionClass = &jsonContainsPathFunctionClass{} _ functionClass = &jsonValidFunctionClass{} _ functionClass = &jsonArrayAppendFunctionClass{} @@ -72,6 +73,7 @@ var ( _ builtinFunc = &builtinJSONRemoveSig{} _ builtinFunc = &builtinJSONMergeSig{} _ builtinFunc = &builtinJSONContainsSig{} + _ builtinFunc = &builtinJSONOverlapsSig{} _ builtinFunc = &builtinJSONStorageSizeSig{} _ builtinFunc = &builtinJSONDepthSig{} _ builtinFunc = &builtinJSONSearchSig{} @@ -820,6 +822,62 @@ func (b *builtinJSONContainsSig) evalInt(row chunk.Row) (res int64, isNull bool, return 0, false, nil } +type jsonOverlapsFunctionClass struct { + baseFunctionClass +} + +type builtinJSONOverlapsSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONOverlapsSig) Clone() builtinFunc { + newSig := &builtinJSONOverlapsSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonOverlapsFunctionClass) verifyArgs(args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType().EvalType(); evalType != types.ETJson && evalType != types.ETString { + return types.ErrInvalidJSONData.GenWithStackByArgs(1, "json_overlaps") + } + if evalType := args[1].GetType().EvalType(); evalType != types.ETJson && evalType != types.ETString { + return types.ErrInvalidJSONData.GenWithStackByArgs(2, "json_overlaps") + } + return nil +} + +func (c *jsonOverlapsFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + argTps := []types.EvalType{types.ETJson, types.ETJson} + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) + if err != nil { + return nil, err + } + sig := &builtinJSONOverlapsSig{bf} + return sig, nil +} + +func (b *builtinJSONOverlapsSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { + obj, isNull, err := b.args[0].EvalJSON(b.ctx, row) + if isNull || err != nil { + return res, isNull, err + } + target, isNull, err := b.args[1].EvalJSON(b.ctx, row) + if isNull || err != nil { + return res, isNull, err + } + if types.OverlapsBinaryJSON(obj, target) { + return 1, false, nil + } + return 0, false, nil +} + type jsonValidFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_json_test.go b/expression/builtin_json_test.go index 3e142860fdd27..72e3c725594c4 100644 --- a/expression/builtin_json_test.go +++ b/expression/builtin_json_test.go @@ -466,6 +466,71 @@ func TestJSONContains(t *testing.T) { } } +func TestJSONOverlaps(t *testing.T) { + ctx := createContext(t) + fc := funcs[ast.JSONOverlaps] + tbl := []struct { + input []any + expected any + err error + }{ + {[]any{`[1,2,[1,3]]`, `a:1`}, 1, types.ErrInvalidJSONText}, + {[]any{`a:1`, `1`}, 1, types.ErrInvalidJSONText}, + {[]any{nil, `1`}, nil, nil}, + {[]any{`1`, nil}, nil, nil}, + + {[]any{`[1, 2]`, `[2,3]`}, 1, nil}, + {[]any{`[1, 2]`, `[2]`}, 1, nil}, + {[]any{`[1, 2]`, `2`}, 1, nil}, + {[]any{`[{"a":1}]`, `{"a":1}`}, 1, nil}, + {[]any{`[{"a":1}]`, `{"a":1,"b":2}`}, 0, nil}, + {[]any{`[{"a":1}]`, `{"a":2}`}, 0, nil}, + {[]any{`{"a":[1,2]}`, `{"a":[1]}`}, 0, nil}, + {[]any{`{"a":[1,2]}`, `{"a":[2,1]}`}, 0, nil}, + {[]any{`[1,1,1]`, `1`}, 1, nil}, + {[]any{`1`, `1`}, 1, nil}, + {[]any{`0`, `1`}, 0, nil}, + {[]any{`[[1,2], 3]`, `[1,[2,3]]`}, 0, nil}, + {[]any{`[[1,2], 3]`, `[1,3]`}, 1, nil}, + {[]any{`{"a":1,"b":10,"d":10}`, `{"a":5,"e":10,"f":1,"d":20}`}, 0, nil}, + {[]any{`[4,5,"6",7]`, `6`}, 0, nil}, + {[]any{`[4,5,6,7]`, `"6"`}, 0, nil}, + + {[]any{`[2,3]`, `[1, 2]`}, 1, nil}, + {[]any{`[2]`, `[1, 2]`}, 1, nil}, + {[]any{`2`, `[1, 2]`}, 1, nil}, + {[]any{`{"a":1}`, `[{"a":1}]`}, 1, nil}, + {[]any{`{"a":1,"b":2}`, `[{"a":1}]`}, 0, nil}, + {[]any{`{"a":2}`, `[{"a":1}]`}, 0, nil}, + {[]any{`{"a":[1]}`, `{"a":[1,2]}`}, 0, nil}, + {[]any{`{"a":[2,1]}`, `{"a":[1,2]}`}, 0, nil}, + {[]any{`1`, `[1,1,1]`}, 1, nil}, + {[]any{`1`, `1`}, 1, nil}, + {[]any{`1`, `0`}, 0, nil}, + {[]any{`[1,[2,3]]`, `[[1,2], 3]`}, 0, nil}, + {[]any{`[1,3]`, `[[1,2], 3]`}, 1, nil}, + {[]any{`{"a":5,"e":10,"f":1,"d":20}`, `{"a":1,"b":10,"d":10}`}, 0, nil}, + {[]any{`6`, `[4,5,"6",7]`}, 0, nil}, + {[]any{`"6"`, `[4,5,6,7]`}, 0, nil}, + } + for _, tt := range tbl { + args := types.MakeDatums(tt.input...) + f, err := fc.getFunction(ctx, datumsToConstants(args)) + require.NoError(t, err, tt.input) + d, err := evalBuiltinFunc(f, chunk.Row{}) + if tt.err == nil { + require.NoError(t, err, tt.input) + if tt.expected == nil { + require.True(t, d.IsNull(), tt.input) + } else { + require.Equal(t, int64(tt.expected.(int)), d.GetInt64(), tt.input) + } + } else { + require.True(t, tt.err.(*terror.Error).Equal(err), tt.input) + } + } +} + func TestJSONContainsPath(t *testing.T) { ctx := createContext(t) fc := funcs[ast.JSONContainsPath] diff --git a/expression/builtin_json_vec.go b/expression/builtin_json_vec.go index fb24808ff2c73..45cca97232d2c 100644 --- a/expression/builtin_json_vec.go +++ b/expression/builtin_json_vec.go @@ -359,6 +359,51 @@ func (b *builtinJSONContainsSig) vecEvalInt(input *chunk.Chunk, result *chunk.Co return nil } +func (b *builtinJSONOverlapsSig) vectorized() bool { + return true +} + +func (b *builtinJSONOverlapsSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { + nr := input.NumRows() + + objCol, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(objCol) + + if err := b.args[0].VecEvalJSON(b.ctx, input, objCol); err != nil { + return err + } + + targetCol, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(targetCol) + + if err := b.args[1].VecEvalJSON(b.ctx, input, targetCol); err != nil { + return err + } + + result.ResizeInt64(nr, false) + resI64s := result.Int64s() + + result.MergeNulls(objCol, targetCol) + for i := 0; i < nr; i++ { + if result.IsNull(i) { + continue + } + if types.OverlapsBinaryJSON(objCol.GetJSON(i), targetCol.GetJSON(i)) { + resI64s[i] = 1 + } else { + resI64s[i] = 0 + } + } + + return nil +} + func (b *builtinJSONQuoteSig) vectorized() bool { return true } diff --git a/parser/ast/functions.go b/parser/ast/functions.go index d33550fc67626..fdedf53b701cf 100644 --- a/parser/ast/functions.go +++ b/parser/ast/functions.go @@ -331,6 +331,7 @@ const ( JSONInsert = "json_insert" JSONReplace = "json_replace" JSONRemove = "json_remove" + JSONOverlaps = "json_overlaps" JSONContains = "json_contains" JSONMemberOf = "json_memberof" JSONContainsPath = "json_contains_path" diff --git a/types/json_binary_functions.go b/types/json_binary_functions.go index 2b02d5a0f65e7..84bd86445aa94 100644 --- a/types/json_binary_functions.go +++ b/types/json_binary_functions.go @@ -1106,6 +1106,48 @@ func ContainsBinaryJSON(obj, target BinaryJSON) bool { } } +// OverlapsBinaryJSON is similar with ContainsBinaryJSON, but it checks the `OR` relationship. +func OverlapsBinaryJSON(obj, target BinaryJSON) bool { + if obj.TypeCode != JSONTypeCodeArray && target.TypeCode == JSONTypeCodeArray { + obj, target = target, obj + } + switch obj.TypeCode { + case JSONTypeCodeObject: + if target.TypeCode == JSONTypeCodeObject { + elemCount := target.GetElemCount() + for i := 0; i < elemCount; i++ { + key := target.objectGetKey(i) + val := target.objectGetVal(i) + if exp, exists := obj.objectSearchKey(key); exists && CompareBinaryJSON(exp, val) == 0 { + return true + } + } + } + return false + case JSONTypeCodeArray: + if target.TypeCode == JSONTypeCodeArray { + for i := 0; i < obj.GetElemCount(); i++ { + o := obj.arrayGetElem(i) + for j := 0; j < target.GetElemCount(); j++ { + if CompareBinaryJSON(o, target.arrayGetElem(j)) == 0 { + return true + } + } + } + return false + } + elemCount := obj.GetElemCount() + for i := 0; i < elemCount; i++ { + if CompareBinaryJSON(obj.arrayGetElem(i), target) == 0 { + return true + } + } + return false + default: + return CompareBinaryJSON(obj, target) == 0 + } +} + // GetElemDepth for JSON_DEPTH // Returns the maximum depth of a JSON document // rules referenced by MySQL JSON_DEPTH function