Skip to content

Commit

Permalink
ddl: fix partition function check for partitioned table (#7464)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao authored Aug 23, 2018
1 parent 2ac2faf commit 1fa5669
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
3 changes: 3 additions & 0 deletions ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ var (
ErrPartitionFuncNotAllowed = terror.ClassDDL.New(codeErrPartitionFuncNotAllowed, mysql.MySQLErrName[mysql.ErrPartitionFuncNotAllowed])
// ErrUniqueKeyNeedAllFieldsInPf returns must include all columns in the table's partitioning function.
ErrUniqueKeyNeedAllFieldsInPf = terror.ClassDDL.New(codeUniqueKeyNeedAllFieldsInPf, mysql.MySQLErrName[mysql.ErrUniqueKeyNeedAllFieldsInPf])
errWrongExprInPartitionFunc = terror.ClassDDL.New(codeWrongExprInPartitionFunc, mysql.MySQLErrName[mysql.ErrWrongExprInPartitionFunc])
)

// DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache.
Expand Down Expand Up @@ -606,6 +607,7 @@ const (
codeErrFieldTypeNotAllowedAsPartitionField = terror.ErrCode(mysql.ErrFieldTypeNotAllowedAsPartitionField)
codeUniqueKeyNeedAllFieldsInPf = terror.ErrCode(mysql.ErrUniqueKeyNeedAllFieldsInPf)
codePrimaryCantHaveNull = terror.ErrCode(mysql.ErrPrimaryCantHaveNull)
codeWrongExprInPartitionFunc = terror.ErrCode(mysql.ErrWrongExprInPartitionFunc)
)

func init() {
Expand Down Expand Up @@ -652,6 +654,7 @@ func init() {
codeErrFieldTypeNotAllowedAsPartitionField: mysql.ErrFieldTypeNotAllowedAsPartitionField,
codeUniqueKeyNeedAllFieldsInPf: mysql.ErrUniqueKeyNeedAllFieldsInPf,
codePrimaryCantHaveNull: mysql.ErrPrimaryCantHaveNull,
codeWrongExprInPartitionFunc: mysql.ErrWrongExprInPartitionFunc,
}
terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes
}
2 changes: 1 addition & 1 deletion ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ func (d *ddl) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) (err e
return errors.Trace(err)
}

if err = checkPartitionFuncValid(s.Partition.Expr); err != nil {
if err = checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil {
return errors.Trace(err)
}

Expand Down
32 changes: 24 additions & 8 deletions ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/meta"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/table"
Expand Down Expand Up @@ -103,28 +104,43 @@ func checkPartitionNameUnique(tbInfo *model.TableInfo, pi *model.PartitionInfo)
}

// checkPartitionFuncValid checks partition function validly.
func checkPartitionFuncValid(expr ast.ExprNode) error {
func checkPartitionFuncValid(ctx sessionctx.Context, tblInfo *model.TableInfo, expr ast.ExprNode) error {
switch v := expr.(type) {
case *ast.CaseExpr:
return ErrPartitionFunctionIsNotAllowed
case *ast.FuncCastExpr, *ast.CaseExpr:
return errors.Trace(ErrPartitionFunctionIsNotAllowed)
case *ast.FuncCallExpr:
// check function which allowed in partitioning expressions
// see https://dev.mysql.com/doc/mysql-partitioning-excerpt/5.7/en/partitioning-limitations-functions.html
switch v.FnName.L {
case ast.Abs, ast.Ceiling, ast.DateDiff, ast.Day, ast.DayOfMonth, ast.DayOfWeek, ast.DayOfYear, ast.Extract, ast.Floor,
ast.Hour, ast.MicroSecond, ast.Minute, ast.Mod, ast.Month, ast.Quarter, ast.Second, ast.TimeToSec, ast.ToDays,
ast.ToSeconds, ast.UnixTimestamp, ast.Weekday, ast.Year, ast.YearWeek:
ast.ToSeconds, ast.Weekday, ast.Year, ast.YearWeek:
return nil
default:
return ErrPartitionFunctionIsNotAllowed
case ast.UnixTimestamp:
if len(v.Args) == 1 {
col, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, v.Args[0])
if err != nil {
return errors.Trace(err)
}
if col.GetType().Tp != mysql.TypeTimestamp {
return errors.Trace(errWrongExprInPartitionFunc)
}
return nil
}
}
return errors.Trace(ErrPartitionFunctionIsNotAllowed)
case *ast.BinaryOperationExpr:
// The DIV operator (opcode.IntDiv) is also supported; the / operator ( opcode.Div ) is not permitted.
// see https://dev.mysql.com/doc/refman/5.7/en/partitioning-limitations.html
if v.Op == opcode.Div {
return ErrPartitionFunctionIsNotAllowed
switch v.Op {
case opcode.Or, opcode.And, opcode.Xor, opcode.LeftShift, opcode.RightShift, opcode.BitNeg, opcode.Div:
return errors.Trace(ErrPartitionFunctionIsNotAllowed)
}
return nil
case *ast.UnaryOperationExpr:
if v.Op == opcode.BitNeg {
return errors.Trace(ErrPartitionFunctionIsNotAllowed)
}
}
return nil
}
Expand Down

0 comments on commit 1fa5669

Please sign in to comment.