Skip to content

Commit

Permalink
introduce typed array
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei committed Dec 19, 2022
1 parent ad20431 commit f404166
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 94 deletions.
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/kv/sql2kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func collectGeneratedColumns(se *session, meta *model.TableInfo, cols []*table.C
var genCols []genCol
for i, col := range cols {
if col.GeneratedExpr != nil {
expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names)
expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names, false)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6164,7 +6164,7 @@ func (d *ddl) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName m
// After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic.
// The recover step causes DDL wait a few seconds, makes the unit test painfully slow.
// For same reason, decide whether index is global here.
indexColumns, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications)
indexColumns, _, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -6274,7 +6274,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as
if err != nil {
return nil, errors.Trace(err)
}
expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr)
expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr, true)
if err != nil {
// TODO: refine the error message.
return nil, err
Expand Down Expand Up @@ -6389,7 +6389,7 @@ func (d *ddl) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.Inde
// After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic.
// The recover step causes DDL wait a few seconds, makes the unit test painfully slow.
// For same reason, decide whether index is global here.
indexColumns, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications)
indexColumns, _, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications)
if err != nil {
return errors.Trace(err)
}
Expand Down
24 changes: 18 additions & 6 deletions ddl/generated_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,14 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol
}

type illegalFunctionChecker struct {
hasIllegalFunc bool
hasAggFunc bool
hasRowVal bool // hasRowVal checks whether the functional index refers to a row value
hasWindowFunc bool
hasNotGAFunc4ExprIdx bool
otherErr error
hasIllegalFunc bool
hasAggFunc bool
hasRowVal bool // hasRowVal checks whether the functional index refers to a row value
hasWindowFunc bool
hasNotGAFunc4ExprIdx bool
hasCastArrayFunc bool
disallowCastArrayFunc bool
otherErr error
}

func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
Expand Down Expand Up @@ -308,7 +310,14 @@ func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipC
case *ast.WindowFuncExpr:
c.hasWindowFunc = true
return inNode, true
case *ast.FuncCastExpr:
c.hasCastArrayFunc = c.hasCastArrayFunc || node.Tp.IsArray()
if c.disallowCastArrayFunc && node.Tp.IsArray() {
c.otherErr = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
return inNode, true
}
}
c.disallowCastArrayFunc = true
return inNode, false
}

Expand Down Expand Up @@ -355,6 +364,9 @@ func checkIllegalFn4Generated(name string, genType int, expr ast.ExprNode) error
if genType == typeIndex && c.hasNotGAFunc4ExprIdx && !config.GetGlobalConfig().Experimental.AllowsExpressionIndex {
return dbterror.ErrUnsupportedExpressionIndex
}
if genType == typeColumn && c.hasCastArrayFunc {
return expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
}
return nil
}

Expand Down
21 changes: 12 additions & 9 deletions ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,28 @@ var (
telemetryAddIndexIngestUsage = metrics.TelemetryAddIndexIngestCnt
)

func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, error) {
func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) {
// Build offsets.
idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications))
var col *model.ColumnInfo
var mvIndex bool
maxIndexLength := config.GetGlobalConfig().MaxIndexLength
// The sum of length of all index columns.
sumLength := 0
for _, ip := range indexPartSpecifications {
col = model.FindColumnInfo(columns, ip.Column.Name.L)
if col == nil {
return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name)
return nil, false, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name)
}

if err := checkIndexColumn(ctx, col, ip.Length); err != nil {
return nil, err
return nil, false, err
}
mvIndex = mvIndex || col.FieldType.IsArray()
indexColLen := ip.Length
indexColumnLength, err := getIndexColumnLength(col, ip.Length)
if err != nil {
return nil, err
return nil, false, err
}
sumLength += indexColumnLength

Expand All @@ -92,12 +94,12 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde
// The multiple column index and the unique index in which the length sum exceeds the maximum size
// will return an error instead produce a warning.
if ctx == nil || ctx.GetSessionVars().StrictSQLMode || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 {
return nil, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength)
return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength)
}
// truncate index length and produce warning message in non-restrict sql mode.
colLenPerUint, err := getIndexColumnLength(col, 1)
if err != nil {
return nil, err
return nil, false, err
}
indexColLen = maxIndexLength / colLenPerUint
// produce warning message
Expand All @@ -111,7 +113,7 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde
})
}

return idxParts, nil
return idxParts, mvIndex, nil
}

// CheckPKOnGeneratedColumn checks the specification of PK is valid.
Expand Down Expand Up @@ -154,7 +156,7 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn
}

// JSON column cannot index.
if col.FieldType.GetType() == mysql.TypeJSON {
if col.FieldType.GetType() == mysql.TypeJSON && !col.FieldType.IsArray() {
if col.Hidden {
return dbterror.ErrFunctionalIndexOnJSONOrGeometryFunction
}
Expand Down Expand Up @@ -263,7 +265,7 @@ func BuildIndexInfo(
return nil, errors.Trace(err)
}

idxColumns, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications)
idxColumns, mvIndex, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -276,6 +278,7 @@ func BuildIndexInfo(
Primary: isPrimary,
Unique: isUnique,
Global: isGlobal,
MVIndex: mvIndex,
}

if indexOption != nil {
Expand Down
2 changes: 1 addition & 1 deletion ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, tblInfo *
return nil
}

e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr)
e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr, false)
if err != nil {
return errors.Trace(err)
}
Expand Down
81 changes: 18 additions & 63 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ func (c *castAsArrayFunctionClass) verifyArgs(args []Expression) error {
}

if args[0].GetType().EvalType() != types.ETJson {
return types.ErrInvalidJSONData.GenWithStackByArgs(1, "cast_as_array")
return types.ErrInvalidJSONData.GenWithStackByArgs("1", "cast_as_array")
}

return nil
Expand All @@ -435,7 +435,7 @@ func (c *castAsArrayFunctionClass) getFunction(ctx sessionctx.Context, args []Ex
case mysql.TypeYear, mysql.TypeJSON:
return nil, ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("CAST-ing data to array of %s", arrayType.String()))
}
if arrayType.GetCharset() != charset.CharsetUTF8MB4 || arrayType.GetCharset() != charset.CharsetBin {
if arrayType.EvalType() == types.ETString && arrayType.GetCharset() != charset.CharsetUTF8MB4 && arrayType.GetCharset() != charset.CharsetBin {
return nil, ErrNotSupportedYet.GenWithStackByArgs("specifying charset for multi-valued index", arrayType.String())
}

Expand Down Expand Up @@ -467,60 +467,9 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS
return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAST-ing Non-JSON Array type to array")
}

arrayVals := make([]any, 0, len(b.args))

tp := b.tp.ArrayType()
for i := 0; i < val.GetElemCount(); i++ {
arrayElem := val.ArrayGetElem(i)
switch tp.EvalType() {
case types.ETInt:
switch arrayElem.TypeCode {
case types.JSONTypeCodeInt64, types.JSONTypeCodeUint64:
v, err := types.ConvertJSONToInt(b.ctx.GetSessionVars().StmtCtx, arrayElem, mysql.HasUnsignedFlag(b.tp.GetFlag()), b.tp.ArrayType().GetType())
if err != nil {
return types.BinaryJSON{}, false, errIncorrectArgs
}
if mysql.HasUnsignedFlag(b.tp.GetFlag()) {
arrayVals = append(arrayVals, uint64(v))
} else {
arrayVals = append(arrayVals, v)
}
default:
return types.BinaryJSON{}, false, errIncorrectArgs
}
case types.ETDecimal:
//types.ConvertJSONToDecimal()
case types.ETReal:
switch arrayElem.TypeCode {
case types.JSONTypeCodeFloat64, types.JSONTypeCodeInt64, types.JSONTypeCodeUint64:
v, err := types.ConvertJSONToFloat(b.ctx.GetSessionVars().StmtCtx, arrayElem)
if err != nil {
return types.BinaryJSON{}, false, errIncorrectArgs
}
arrayVals = append(arrayVals, v)
default:
return types.BinaryJSON{}, false, errIncorrectArgs
}
case types.ETDatetime, types.ETTimestamp:

case types.ETDuration:
case types.ETString:
switch arrayElem.TypeCode {
case types.JSONTypeCodeString:
s, err := types.ProduceStrWithSpecifiedTp(string(arrayElem.GetString()), tp, b.ctx.GetSessionVars().StmtCtx, false)
if err != nil {
return types.BinaryJSON{}, false, err
}
arrayVals = append(arrayVals, s)
default:
return types.BinaryJSON{}, false, errIncorrectArgs
}
default:
return types.BinaryJSON{}, false, errIncorrectArgs
}
}
// TODO: impl the cast(... as ... array) function

return types.CreateBinaryJSON(arrayVals), false, nil
return types.BinaryJSON{}, false, nil
}

type castAsJSONFunctionClass struct {
Expand Down Expand Up @@ -2030,6 +1979,19 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp

// BuildCastFunction builds a CAST ScalarFunction from the Expression.
func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
res, err := BuildCastFunctionWithCheck(ctx, expr, tp)
terror.Log(err)
// We do not fold CAST if the eval type of this scalar function is ETJson
// since we may reset the flag of the field type of CastAsJson later which
// would affect the evaluation of it.
if tp.EvalType() != types.ETJson {
res = FoldConstant(res)
}
return res
}

// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any.
func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression, err error) {
argType := expr.GetType()
// If source argument's nullable, then target type should be nullable
if !mysql.HasNotNullFlag(argType.GetFlag()) {
Expand Down Expand Up @@ -2061,19 +2023,12 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT
}
}
f, err := fc.getFunction(ctx, []Expression{expr})
terror.Log(err)
res = &ScalarFunction{
FuncName: model.NewCIStr(ast.Cast),
RetType: tp,
Function: f,
}
// We do not fold CAST if the eval type of this scalar function is ETJson
// since we may reset the flag of the field type of CastAsJson later which
// would affect the evaluation of it.
if tp.EvalType() != types.ETJson {
res = FoldConstant(res)
}
return res
return res, err
}

// WrapWithCastAsInt wraps `expr` with `cast` if the return type of expr is not
Expand Down
4 changes: 2 additions & 2 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ var EvalAstExpr func(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, e
// RewriteAstExpr rewrites ast expression directly.
// Note: initialized in planner/core
// import expression and planner/core together to use EvalAstExpr
var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice) (Expression, error)
var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice, allowCastArray bool) (Expression, error)

// VecExpr contains all vectorized evaluation methods.
type VecExpr interface {
Expand Down Expand Up @@ -998,7 +998,7 @@ func ColumnInfos2ColumnsAndNames(ctx sessionctx.Context, dbName, tblName model.C
if err != nil {
return nil, nil, errors.Trace(err)
}
e, err := RewriteAstExpr(ctx, expr, mockSchema, names)
e, err := RewriteAstExpr(ctx, expr, mockSchema, names, false)
if err != nil {
return nil, nil, errors.Trace(err)
}
Expand Down
48 changes: 48 additions & 0 deletions expression/multi_valued_index_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


package expression_test

import (
"testing"

"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/testkit"
)

func TestMultiValuedIndexDDL(t *testing.T) {
store := testkit.CreateMockStore(t)

tk := testkit.NewTestKit(t, store)
tk.MustExec("USE test;")

tk.MustExec("create table t(a json);")
tk.MustGetErrCode("select cast(a as signed array) from t", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select json_extract(cast(a as signed array), '$[0]') from t", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select * from t where cast(a as signed array)", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet)

tk.MustExec("drop table t")
tk.MustGetErrCode("CREATE TABLE t(x INT, KEY k ((1 AND CAST(JSON_ARRAY(x) AS UNSIGNED ARRAY))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(cast(f1 as unsigned array) as unsigned array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->>'$[*]' as unsigned array))));", errno.ErrInvalidJSONData)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as year array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as json array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as char(10) charset gbk array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create table t(j json, gc json as ((concat(cast(j->'$[*]' as unsigned array),\"x\"))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create table t(j json, gc json as (cast(j->'$[*]' as unsigned array)));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create view v as select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet)
tk.MustExec("create table t(a json, index idx((cast(a as signed array))));")
}
8 changes: 4 additions & 4 deletions expression/simple_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableI
return nil, errors.Trace(err)
}
expr := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr
return RewriteSimpleExprWithTableInfo(ctx, tableInfo, expr)
return RewriteSimpleExprWithTableInfo(ctx, tableInfo, expr, false)
}

// ParseSimpleExprCastWithTableInfo parses simple expression string to Expression.
Expand All @@ -63,13 +63,13 @@ func ParseSimpleExprCastWithTableInfo(ctx sessionctx.Context, exprStr string, ta
}

// RewriteSimpleExprWithTableInfo rewrites simple ast.ExprNode to expression.Expression.
func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo, expr ast.ExprNode) (Expression, error) {
func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo, expr ast.ExprNode, allowCastArray bool) (Expression, error) {
dbName := model.NewCIStr(ctx.GetSessionVars().CurrentDB)
columns, names, err := ColumnInfos2ColumnsAndNames(ctx, dbName, tbl.Name, tbl.Cols(), tbl)
if err != nil {
return nil, err
}
e, err := RewriteAstExpr(ctx, expr, NewSchema(columns...), names)
e, err := RewriteAstExpr(ctx, expr, NewSchema(columns...), names, allowCastArray)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -111,7 +111,7 @@ func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *S

// RewriteSimpleExprWithNames rewrites simple ast.ExprNode to expression.Expression.
func RewriteSimpleExprWithNames(ctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names []*types.FieldName) (Expression, error) {
e, err := RewriteAstExpr(ctx, expr, schema, names)
e, err := RewriteAstExpr(ctx, expr, schema, names, false)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit f404166

Please sign in to comment.