Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

collation: cast charset according to the function's resulting charset #29029

Closed
wants to merge 5 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
convert binary to non-binary
xiongjiwei committed Nov 4, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 46269e094c8bfd8c146d96e9d7dca819f286d10c
1 change: 1 addition & 0 deletions errno/errcode.go
Original file line number Diff line number Diff line change
@@ -901,6 +901,7 @@ const (
ErrFKIncompatibleColumns = 3780
ErrFunctionalIndexRowValueIsNotAllowed = 3800
ErrDependentByFunctionalIndex = 3837
ErrCannotConvertString = 3854
ErrInvalidJSONValueForFuncIndex = 3903
ErrJSONValueOutOfRangeForFuncIndex = 3904
ErrFunctionalIndexDataIsTooLong = 3907
1 change: 1 addition & 0 deletions errno/errname.go
Original file line number Diff line number Diff line change
@@ -896,6 +896,7 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{
ErrFKIncompatibleColumns: mysql.Message("Referencing column '%s' in foreign key constraint '%s' are incompatible", nil),
ErrFunctionalIndexRowValueIsNotAllowed: mysql.Message("Expression of expression index '%s' cannot refer to a row value", nil),
ErrDependentByFunctionalIndex: mysql.Message("Column '%s' has an expression index dependency and cannot be dropped or renamed", nil),
ErrCannotConvertString: mysql.Message("Cannot convert string '%.64s' from %s to %s", nil),
ErrInvalidJSONValueForFuncIndex: mysql.Message("Invalid JSON value for CAST for expression index '%s'", nil),
ErrJSONValueOutOfRangeForFuncIndex: mysql.Message("Out of range JSON value for CAST for expression index '%s'", nil),
ErrFunctionalIndexDataIsTooLong: mysql.Message("Data too long for expression index '%s'", nil),
6 changes: 3 additions & 3 deletions expression/builtin.go
Original file line number Diff line number Diff line change
@@ -91,7 +91,7 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi
if ctx == nil {
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}
ec, err := deriveCollation(ctx, funcName, args, retType, retType)
ec, _, err := deriveCollation(ctx, funcName, args, retType, retType)
if err != nil {
return baseBuiltinFunc{}, err
}
@@ -125,7 +125,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex

// derive collation information for string function, and we must do it
// before doing implicit cast.
ec, err := deriveCollation(ctx, funcName, args, retType, argTps...)
ec, retTp, err := deriveCollation(ctx, funcName, args, retType, argTps...)
if err != nil {
return
}
@@ -139,7 +139,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
case types.ETDecimal:
args[i] = WrapWithCastAsDecimal(ctx, args[i])
case types.ETString:
args[i] = WrapWithCastAsString(ctx, args[i])
args[i] = WrapWithCastAsStringWithTp(ctx, args[i], retTp)
case types.ETDatetime:
args[i] = WrapWithCastAsTime(ctx, args[i], types.NewFieldType(mysql.TypeDatetime))
case types.ETTimestamp:
46 changes: 46 additions & 0 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
@@ -23,12 +23,16 @@
package expression

import (
"fmt"
"math"
"strconv"
"strings"
"unicode/utf8"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
@@ -37,6 +41,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/dbterror"
"github.com/pingcap/tipb/go-tipb"
)

@@ -108,6 +113,11 @@ var (
_ builtinFunc = &builtinCastJSONAsJSONSig{}
)

var (
// errCannotConvertString is returned when a string cannot be converted from one charset to another.
// It is built from errno.ErrCannotConvertString in the expression error class.
errCannotConvertString = dbterror.ClassExpression.NewStd(errno.ErrCannotConvertString)
)

type castAsIntFunctionClass struct {
baseFunctionClass

@@ -1112,6 +1122,23 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is
if isNull || err != nil {
return res, isNull, err
}
ov := res
fromChs := b.args[0].GetType().Charset
toChs := b.tp.Charset
if toChs == charset.CharsetBin && fromChs != charset.CharsetBin {
res, err = charset.NewEncoding(fromChs).EncodeString(res)
} else if toChs != charset.CharsetBin && fromChs == charset.CharsetBin {
res, err = charset.NewEncoding(toChs).DecodeString(res)
// If toChs is utf8 or utf8mb4, DecodeString does nothing and returns a nil error, so we still need to check whether the binary literal can be converted to utf8.
if toChs == charset.CharsetUTF8 || toChs == charset.CharsetUTF8MB4 {
if !utf8.ValidString(res) {
return "", false, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", ov), fromChs, toChs)
}
}
}
if err != nil {
return "", false, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", ov), fromChs, toChs)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc, false)
if err != nil {
@@ -1907,6 +1934,25 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression {
return BuildCastFunction(ctx, expr, tp)
}

// WrapWithCastAsStringWithTp wraps `expr` with a `cast` to a string type that
// carries the charset and collation of `toTp`.
//
// When `expr` is already a string expression and `toTp` is non-nil:
//   - if both charsets are equal, `expr` is returned unchanged;
//   - otherwise `expr` is cast to a TypeVarString that takes its charset and
//     collation from `toTp` while keeping the Flen and Decimal of `expr`.
//
// In every other case it falls back to WrapWithCastAsString.
func WrapWithCastAsStringWithTp(ctx sessionctx.Context, expr Expression, toTp *types.FieldType) Expression {
	srcTp := expr.GetType()
	if srcTp.EvalType() != types.ETString || toTp == nil {
		return WrapWithCastAsString(ctx, expr)
	}

	// Same charset: no conversion is needed.
	if srcTp.Charset == toTp.Charset {
		return expr
	}

	// Only the charset/collation change; length and decimal stay as in the source.
	castTp := &types.FieldType{
		Tp:      mysql.TypeVarString,
		Flen:    srcTp.Flen,
		Decimal: srcTp.Decimal,
		Charset: toTp.Charset,
		Collate: toTp.Collate,
	}
	return BuildCastFunction(ctx, expr, castTp)
}

// WrapWithCastAsString wraps `expr` with `cast` if the return type of expr is
// not type string, otherwise, returns `expr` directly.
func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression {
27 changes: 26 additions & 1 deletion expression/builtin_cast_vec.go
Original file line number Diff line number Diff line change
@@ -15,10 +15,13 @@
package expression

import (
"fmt"
"math"
"strconv"
"strings"
"unicode/utf8"

"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
@@ -1820,14 +1823,36 @@ func (b *builtinCastStringAsStringSig) vecEvalString(input *chunk.Chunk, result

var res string
var isNull bool

fromChs := b.args[0].GetType().Charset
toChs := b.tp.Charset
transferString := func(s string) (string, error) { return s, nil }
if toChs == charset.CharsetBin && fromChs != charset.CharsetBin {
transferString = charset.NewEncoding(fromChs).EncodeString
} else if toChs != charset.CharsetBin && fromChs == charset.CharsetBin {
transferString = charset.NewEncoding(toChs).DecodeString
if toChs == charset.CharsetUTF8 || toChs == charset.CharsetUTF8MB4 {
transferString = func(s string) (string, error) {
if !utf8.ValidString(s) {
return "", errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), fromChs, toChs)
}
return s, nil
}
}
}

sc := b.ctx.GetSessionVars().StmtCtx
result.ReserveString(n)
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
continue
}
res, err = types.ProduceStrWithSpecifiedTp(buf.GetString(i), b.tp, sc, false)
res, err = transferString(buf.GetString(i))
if err != nil {
return errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", buf.GetString(i)), fromChs, toChs)
}
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc, false)
if err != nil {
return err
}
2 changes: 1 addition & 1 deletion expression/builtin_compare.go
Original file line number Diff line number Diff line change
@@ -1210,7 +1210,7 @@ func GetCmpFunction(ctx sessionctx.Context, lhs, rhs Expression) CompareFunc {
case types.ETDecimal:
return CompareDecimal
case types.ETString:
coll, _ := CheckAndDeriveCollationFromExprs(ctx, "", types.ETInt, lhs, rhs)
coll, _, _ := CheckAndDeriveCollationFromExprs(ctx, "", types.ETInt, lhs, rhs)
return genCompareString(coll.Collation)
case types.ETDuration:
return CompareDuration
4 changes: 2 additions & 2 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
@@ -94,7 +94,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
}

if types.IsNonBinaryStr(lhs) && !types.IsBinaryStr(rhs) {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp)
ec, _, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp)
if err != nil {
return nil, err
}
@@ -104,7 +104,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
resultFieldType.Flag |= mysql.BinaryFlag
}
} else if types.IsNonBinaryStr(rhs) && !types.IsBinaryStr(lhs) {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp)
ec, _, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp)
if err != nil {
return nil, err
}
4 changes: 4 additions & 0 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
@@ -2972,6 +2972,10 @@ func (c *quoteFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
}
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
bf.tp.Flen = 2*args[0].GetType().Flen + 2
// If arg is NULL, quote function will return 'NULL', the Flen should be 4.
if args[0].GetType().Tp == mysql.TypeNull {
bf.tp.Flen = 4
}
if bf.tp.Flen > mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
52 changes: 25 additions & 27 deletions expression/collation.go
Original file line number Diff line number Diff line change
@@ -192,7 +192,7 @@ func deriveCoercibilityForColumn(c *Column) Coercibility {
return CoercibilityImplicit
}

func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType, argTps ...types.EvalType) (ec *ExprCollation, err error) {
func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType, argTps ...types.EvalType) (ec *ExprCollation, retTp *types.FieldType, err error) {
switch funcName {
case ast.Concat, ast.ConcatWS, ast.Lower, ast.Lcase, ast.Reverse, ast.Upper, ast.Ucase, ast.Quote, ast.Coalesce:
return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args...)
@@ -215,53 +215,48 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression,
case ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, ast.NE, ast.NullEQ, ast.Strcmp:
// if compare type is string, we should determine which collation should be used.
if argTps[0] == types.ETString {
ec, err = CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args...)
ec, retTp, err = CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args...)
if err != nil {
return nil, err
return nil, nil, err
}
ec.Coer = CoercibilityNumeric
ec.Repe = ASCII
return ec, nil
return ec, retTp, nil
}
case ast.If:
return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args[1], args[2])
case ast.Ifnull:
return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args[0], args[1])
case ast.Like:
ec, err = CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args[0], args[1])
ec, retTp, err = CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args[0], args[1])
if err != nil {
return nil, err
return nil, nil, err
}
ec.Coer = CoercibilityNumeric
ec.Repe = ASCII
return ec, nil
return ec, retTp, nil
case ast.In:
if args[0].GetType().EvalType() == types.ETString {
return CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args...)
}
case ast.DateFormat, ast.TimeFormat:
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return &ExprCollation{args[1].Coercibility(), args[1].Repertoire(), charsetInfo, collation}, nil
return &ExprCollation{args[1].Coercibility(), args[1].Repertoire(), charsetInfo, collation}, nil, nil
case ast.Cast:
// We assume all the cast are implicit.
ec = &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType().Charset, args[0].GetType().Collate}
// Non-string type cast to string type should use @@character_set_connection and @@collation_connection.
// String type cast to string type should keep its original charset and collation. It should not happen.
if retType == types.ETString && argTps[0] != types.ETString {
ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo()
}
return ec, nil
// We assume all the cast are implicit, keep the collation related fields to its original value.
return &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType().Charset, args[0].GetType().Collate}, nil, nil
case ast.Case:
// FIXME: case function aggregate collation is not correct.
return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args...)
ec, _, err = CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args...)
return ec, nil, err
case ast.Database, ast.User, ast.CurrentUser, ast.Version, ast.CurrentRole, ast.TiDBVersion:
chs, coll := charset.GetDefaultCharsetAndCollate()
return &ExprCollation{CoercibilitySysconst, UNICODE, chs, coll}, nil
return &ExprCollation{CoercibilitySysconst, UNICODE, chs, coll}, nil, nil
case ast.Format, ast.Space, ast.ToBase64, ast.UUID, ast.Hex, ast.MD5, ast.SHA, ast.SHA2:
// Should return the ASCII repertoire. MySQL's doc says it depends on character_set_connection, but that is not true according to its source code.
ec = &ExprCollation{Coer: CoercibilityCoercible, Repe: ASCII}
ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo()
return ec, nil
return ec, nil, nil
}

ec = &ExprCollation{CoercibilityNumeric, ASCII, charset.CharsetBin, charset.CollationBin}
@@ -272,7 +267,7 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression,
ec.Repe = UNICODE
}
}
return ec, nil
return ec, nil, nil
}

// DeriveCollationFromExprs derives collation information from these expressions.
@@ -284,14 +279,14 @@ func DeriveCollationFromExprs(ctx sessionctx.Context, exprs ...Expression) (dstC
}

// CheckAndDeriveCollationFromExprs derives collation information from these expressions, return error if derives collation error.
func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, evalType types.EvalType, args ...Expression) (et *ExprCollation, err error) {
func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, evalType types.EvalType, args ...Expression) (et *ExprCollation, retTp *types.FieldType, err error) {
ec := inferCollation(args...)
if ec == nil {
return nil, illegalMixCollationErr(funcName, args)
return nil, nil, illegalMixCollationErr(funcName, args)
}

if evalType != types.ETString && ec.Coer == CoercibilityNone {
return nil, illegalMixCollationErr(funcName, args)
return nil, nil, illegalMixCollationErr(funcName, args)
}

if evalType == types.ETString && ec.Coer == CoercibilityNumeric {
@@ -301,10 +296,9 @@ func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, e
}

if !safeConvert(ctx, ec, args...) {
return nil, illegalMixCollationErr(funcName, args)
return nil, nil, illegalMixCollationErr(funcName, args)
}

return ec, nil
return ec, &types.FieldType{Charset: ec.Charset, Collate: ec.Collation}, nil
}

func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) bool {
@@ -322,7 +316,11 @@ func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression)
if err != nil {
return false
}
if !isNull && !isValidString(str, ec.Charset) {
// if value is NULL or binary string, just skip it.
if isNull || types.IsBinaryStr(c.GetType()) {
continue
}
Comment on lines +320 to +323
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can move types.IsBinaryStr(c.GetType()) to the beginning of this loop to avoid unnecessary EvalString.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xiongjiwei Please address this comment.

if !isValidString(str, ec.Charset) {
return false
}
} else {
4 changes: 2 additions & 2 deletions expression/collation_test.go
Original file line number Diff line number Diff line change
@@ -622,13 +622,13 @@ func TestDeriveCollation(t *testing.T) {
[]types.EvalType{types.ETInt},
types.ETString,
false,
&ExprCollation{CoercibilityExplicit, ASCII, charset.CharsetUTF8MB4, charset.CollationUTF8MB4},
&ExprCollation{CoercibilityExplicit, ASCII, charset.CharsetBinary, charset.CollationBin},
},
}

for i, test := range tests {
for _, fc := range test.fcs {
ec, err := deriveCollation(ctx, fc, test.args, test.retTp, test.argTps...)
ec, _, err := deriveCollation(ctx, fc, test.args, test.retTp, test.argTps...)
if test.err {
require.Error(t, err, "Number: %d, function: %s", i, fc)
require.Nil(t, ec, i)
7 changes: 6 additions & 1 deletion expression/distsql_builtin.go
Original file line number Diff line number Diff line change
@@ -1218,7 +1218,12 @@ func convertUint(val []byte) (*Constant, error) {
func convertString(val []byte, tp *tipb.FieldType) (*Constant, error) {
var d types.Datum
d.SetBytesAsString(val, protoToCollation(tp.Collate), uint32(tp.Flen))
return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeVarString)}, nil
return &Constant{Value: d, RetType: &types.FieldType{
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When converting a pb value to a string expression, we should use the charset information carried in the pb.

Tp: mysql.TypeString,
Flag: uint(tp.Flag),
Charset: tp.Charset,
Flen: int(tp.Flen),
}}, nil
}

func convertFloat(val []byte, f32 bool) (*Constant, error) {
2 changes: 1 addition & 1 deletion expression/integration_test.go
Original file line number Diff line number Diff line change
@@ -1180,7 +1180,7 @@ func (s *testIntegrationSuite2) TestStringBuiltin(c *C) {

// for insert
result = tk.MustQuery(`select insert("中文", 1, 1, cast("aaa" as binary)), insert("ba", -1, 1, "aaa"), insert("ba", 1, 100, "aaa"), insert("ba", 100, 1, "aaa");`)
result.Check(testkit.Rows("aaa文 ba aaa ba"))
result.Check(testkit.Rows("aaa\xb8\xad文 ba aaa ba"))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is compatible with MySQL versions before 8.0.24.

Copy link
Contributor

@Defined2014 Defined2014 Oct 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this change happen? Is it because of the implicit cast?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Before 8.0.24, MySQL used the 1st and 4th arguments to determine the resulting charset; from 8.0.24 on, it uses only the 1st argument. In this case the resulting charset is binary under the old behavior and utf8mb4 under the new one, and a length of 1 means one byte for the binary charset but one character for utf8mb4.

result = tk.MustQuery(`select insert("bb", NULL, 1, "aa"), insert("bb", 1, NULL, "aa"), insert(NULL, 1, 1, "aaa"), insert("bb", 1, 1, NULL);`)
result.Check(testkit.Rows("<nil> <nil> <nil> <nil>"))
result = tk.MustQuery(`SELECT INSERT("bb", 0, 1, NULL), INSERT("bb", 0, NULL, "aaa");`)
1 change: 0 additions & 1 deletion expression/typeinfer_test.go
Original file line number Diff line number Diff line change
@@ -275,7 +275,6 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase {
{"CONCAT('T', 'i', 'DB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength},
{"CONCAT_WS('-', 'T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 6, types.UnspecifiedLength},
{"CONCAT_WS(',', 'TiDB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 25, types.UnspecifiedLength},
{"CONCAT(c_bchar, 0x80)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 23, types.UnspecifiedLength},
{"left(c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength},
{"right(c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength},
{"lower(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength},
Loading