Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: add max_allowed_packet check in concat/concat_ws #11137

Merged
merged 15 commits into from
Jul 16, 2019
36 changes: 34 additions & 2 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,29 +272,44 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
if bf.tp.Flen >= mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinConcatSig{bf}

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
}

sig := &builtinConcatSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinConcatSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinConcatSig) Clone() builtinFunc {
newSig := &builtinConcatSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

// evalString evals a builtinConcatSig
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat
func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err error) {
var s []byte
var targetLength int
for _, a := range b.getArgs() {
d, isNull, err = a.EvalString(b.ctx, row)
if isNull || err != nil {
return d, isNull, err
}
targetLength += len(d)
if uint64(targetLength) > b.maxAllowedPacket {
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat", b.maxAllowedPacket))
return "", true, nil
}
s = append(s, []byte(d)...)
}
return string(s), false, nil
Expand Down Expand Up @@ -337,17 +352,25 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.Flen = mysql.MaxBlobWidth
}

sig := &builtinConcatWSSig{bf}
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
}

sig := &builtinConcatWSSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinConcatWSSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinConcatWSSig) Clone() builtinFunc {
newSig := &builtinConcatWSSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -357,6 +380,7 @@ func (b *builtinConcatWSSig) evalString(row chunk.Row) (string, bool, error) {
args := b.getArgs()
strs := make([]string, 0, len(args))
var sep string
var targetLength int
for i, arg := range args {
val, isNull, err := arg.EvalString(b.ctx, row)
if err != nil {
Expand All @@ -377,6 +401,14 @@ func (b *builtinConcatWSSig) evalString(row chunk.Row) (string, bool, error) {
sep = val
continue
}
targetLength += len(val)
if i > 1 {
targetLength += len(sep)
}
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat_ws", b.maxAllowedPacket))
return "", true, nil
}
strs = append(strs, val)
}

Expand Down
91 changes: 91 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,50 @@ func (s *testEvaluatorSuite) TestConcat(c *C) {
}
}

func (s *testEvaluatorSuite) TestConcatSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
concat := &builtinConcatSig{base, 5}

cases := []struct {
args []interface{}
warnings int
res string
}{
{[]interface{}{"a", "b"}, 0, "ab"},
{[]interface{}{"aaa", "bbb"}, 1, ""},
{[]interface{}{"中", "a"}, 0, "中a"},
{[]interface{}{"中文", "a"}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendString(1, t.args[1].(string))

res, isNull, err := concat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warnings == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(warnings, HasLen, t.warnings)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestConcatWS(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
Expand Down Expand Up @@ -246,6 +290,53 @@ func (s *testEvaluatorSuite) TestConcatWS(c *C) {
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestConcatWSSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
concat := &builtinConcatWSSig{base, 6}

cases := []struct {
args []interface{}
warnings int
res string
}{
{[]interface{}{",", "a", "b"}, 0, "a,b"},
{[]interface{}{",", "aaa", "bbb"}, 1, ""},
{[]interface{}{",", "中", "a"}, 0, "中,a"},
{[]interface{}{",", "中文", "a"}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendString(1, t.args[1].(string))
input.AppendString(2, t.args[2].(string))

res, isNull, err := concat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warnings == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(warnings, HasLen, t.warnings)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestLeft(c *C) {
defer testleak.AfterTest(c)()
stmtCtx := s.ctx.GetSessionVars().StmtCtx
Expand Down
1 change: 1 addition & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ func (s *testIntegrationSuite) TestMathBuiltin(c *C) {
}

func (s *testIntegrationSuite) TestStringBuiltin(c *C) {
s.ctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864")
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
defer s.cleanEnv(c)
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
2 changes: 2 additions & 0 deletions planner/core/exhaust_physical_plans_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
)

Expand All @@ -40,6 +41,7 @@ func (s *testUnitTestSuit) rewriteSimpleExpr(str string, schema *expression.Sche

func (s *testUnitTestSuit) TestIndexJoinAnalyzeLookUpFilters(c *C) {
s.ctx.GetSessionVars().PlanID = -1
s.ctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864")
joinNode := LogicalJoin{}.Init(s.ctx)
dataSourceNode := DataSource{}.Init(s.ctx)
dsSchema := expression.NewSchema()
Expand Down
2 changes: 2 additions & 0 deletions planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/testleak"
)

Expand Down Expand Up @@ -1360,6 +1361,7 @@ func (s *testPlanSuite) TestValidate(c *C) {
err: ErrUnknownColumn,
},
}
s.ctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864")
for _, tt := range tests {
sql := tt.sql
comment := Commentf("for %s", sql)
Expand Down