Skip to content

Commit

Permalink
*: fix the lower bound when converting numbers less than 0 to unsigne…
Browse files Browse the repository at this point in the history
…d integers (pingcap#8544)
  • Loading branch information
exialin committed Jan 11, 2019
1 parent 1570609 commit 4830b11
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 28 deletions.
2 changes: 1 addition & 1 deletion executor/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func statementContextToFlags(sc *stmtctx.StatementContext) uint64 {
var flags uint64
if sc.InInsertStmt {
flags |= FlagInInsertStmt
} else if sc.InUpdateOrDeleteStmt {
} else if sc.InUpdateStmt || sc.InDeleteStmt {
flags |= FlagInUpdateOrDeleteStmt
} else if sc.InSelectStmt {
flags |= FlagInSelectStmt
Expand Down
1 change: 1 addition & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo,
c *C, tk *testkit.TestKit, ctx sessionctx.Context, selectSQL, deleteSQL string) {
for _, tt := range tests {
c.Assert(ctx.NewTxn(), IsNil)
ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true
data, reachLimit, err1 := ld.InsertData(tt.data1, tt.data2)
c.Assert(err1, IsNil)
c.Assert(reachLimit, IsFalse)
Expand Down
4 changes: 2 additions & 2 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,15 @@ func ResetStmtCtx(ctx sessionctx.Context, s ast.StmtNode) {
sc.IgnoreTruncate = false
sc.OverflowAsWarning = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.InUpdateOrDeleteStmt = true
sc.InUpdateStmt = true
sc.DividedByZeroAsWarning = stmt.IgnoreErr
sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.Priority = stmt.Priority
case *ast.DeleteStmt:
sc.IgnoreTruncate = false
sc.OverflowAsWarning = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.InUpdateOrDeleteStmt = true
sc.InDeleteStmt = true
sc.DividedByZeroAsWarning = stmt.IgnoreErr
sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.Priority = stmt.Priority
Expand Down
41 changes: 41 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,27 @@ func (s *testSuite) TestInsert(c *C) {
tk.MustExec("insert into test values(2, 3)")
tk.MustQuery("select * from test use index (id) where id = 2").Check(testkit.Rows("2 2", "2 3"))

// issue 6360
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(a bigint unsigned);")
tk.MustExec(" set @orig_sql_mode = @@sql_mode; set @@sql_mode = 'strict_all_tables';")
_, err = tk.Exec("insert into t value (-1);")
c.Assert(types.ErrWarnDataOutOfRange.Equal(err), IsTrue)
tk.MustExec("set @@sql_mode = '';")
tk.MustExec("insert into t value (-1);")
// TODO: the following warning messages are not consistent with MySQL, fix them in the future PRs
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint"))
tk.MustExec("insert into t select -1;")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint"))
tk.MustExec("insert into t select cast(-1 as unsigned);")
tk.MustExec("insert into t value (-1.111);")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint"))
tk.MustExec("insert into t value ('-1.111');")
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 BIGINT UNSIGNED value is out of range in '-1'"))
r = tk.MustQuery("select * from t;")
r.Check(testkit.Rows("0", "0", "18446744073709551615", "0", "0"))
tk.MustExec("set @@sql_mode = @orig_sql_mode;")

// issue 6424
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a time(6))")
Expand Down Expand Up @@ -1344,6 +1365,26 @@ func makeLoadDataInfo(column int, specifiedColumns []string, ctx sessionctx.Cont
return
}

// related to issue 6360
func (s *testSuite) TestLoadDataOverflowBigintUnsigned(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test; drop table if exists load_data_test;")
tk.MustExec("CREATE TABLE load_data_test (a bigint unsigned);")
tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test")
ctx := tk.Se.(sessionctx.Context)
ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataInfo)
c.Assert(ok, IsTrue)
defer ctx.SetValue(executor.LoadDataVarKey, nil)
c.Assert(ld, NotNil)
tests := []testCase{
{nil, []byte("-1\n-18446744073709551615\n-18446744073709551616\n"), []string{"0", "0", "0"}, nil},
{nil, []byte("-9223372036854775809\n18446744073709551616\n"), []string{"0", "18446744073709551615"}, nil},
}
deleteSQL := "delete from load_data_test"
selectSQL := "select * from load_data_test;"
checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL)
}

func (s *testSuite) TestBatchInsertDelete(c *C) {
originLimit := atomic.LoadUint64(&kv.TxnEntryCountLimit)
defer func() {
Expand Down
12 changes: 8 additions & 4 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,8 @@ func (b *builtinCastIntAsRealSig) evalReal(row types.Row) (res float64, isNull b
res = float64(val)
} else {
var uVal uint64
uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
sc := b.ctx.GetSessionVars().StmtCtx
uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
res = float64(uVal)
}
return res, false, errors.Trace(err)
Expand All @@ -482,7 +483,8 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row types.Row) (res *types.MyDe
res = types.NewDecFromInt(val)
} else {
var uVal uint64
uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
sc := b.ctx.GetSessionVars().StmtCtx
uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
if err != nil {
return res, false, errors.Trace(err)
}
Expand Down Expand Up @@ -511,7 +513,8 @@ func (b *builtinCastIntAsStringSig) evalString(row types.Row) (res string, isNul
res = strconv.FormatInt(val, 10)
} else {
var uVal uint64
uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
sc := b.ctx.GetSessionVars().StmtCtx
uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
if err != nil {
return res, false, errors.Trace(err)
}
Expand Down Expand Up @@ -732,7 +735,8 @@ func (b *builtinCastRealAsIntSig) evalInt(row types.Row) (res int64, isNull bool
res, err = types.ConvertFloatToInt(val, types.SignedLowerBound[mysql.TypeLonglong], types.SignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble)
} else {
var uintVal uint64
uintVal, err = types.ConvertFloatToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble)
sc := b.ctx.GetSessionVars().StmtCtx
uintVal, err = types.ConvertFloatToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble)
res = int64(uintVal)
}
return res, isNull, errors.Trace(err)
Expand Down
4 changes: 2 additions & 2 deletions expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error {
return err
}
sc := ctx.GetSessionVars().StmtCtx
if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateOrDeleteStmt) {
if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) {
return err
}
sc.AppendWarning(err)
Expand All @@ -69,7 +69,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error {
// handleDivisionByZeroError reports error or warning depend on the context.
func handleDivisionByZeroError(ctx sessionctx.Context) error {
sc := ctx.GetSessionVars().StmtCtx
if sc.InInsertStmt || sc.InUpdateOrDeleteStmt {
if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt {
if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() {
return nil
}
Expand Down
22 changes: 21 additions & 1 deletion sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ type StatementContext struct {
// Set the following variables before execution

InInsertStmt bool
InUpdateOrDeleteStmt bool
InUpdateStmt bool
InDeleteStmt bool
InSelectStmt bool
InLoadDataStmt bool
IgnoreTruncate bool
IgnoreZeroInDate bool
DividedByZeroAsWarning bool
Expand Down Expand Up @@ -223,3 +225,21 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails {
sc.mu.Unlock()
return details
}

// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
func (sc *StatementContext) ShouldClipToZero() bool {
// TODO: Currently altering column of integer to unsigned integer is not supported.
// If it is supported one day, that case should be added here.
return sc.InInsertStmt || sc.InLoadDataStmt
}

// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows,
// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types.
func (sc *StatementContext) ShouldIgnoreOverflowError() bool {
if (sc.InInsertStmt && sc.TruncateAsWarning) || sc.InLoadDataStmt {
return true
}
return false
}
2 changes: 1 addition & 1 deletion sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ func (s *SessionVars) GetTimeZone() *time.Location {
func (s *SessionVars) ResetPrevAffectedRows() {
s.PrevAffectedRows = 0
if s.StmtCtx != nil {
if s.StmtCtx.InUpdateOrDeleteStmt || s.StmtCtx.InInsertStmt {
if s.StmtCtx.InUpdateStmt || s.StmtCtx.InDeleteStmt || s.StmtCtx.InInsertStmt {
s.PrevAffectedRows = int64(s.StmtCtx.AffectedRows())
} else if s.StmtCtx.InSelectStmt {
s.PrevAffectedRows = -1
Expand Down
15 changes: 11 additions & 4 deletions types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) {
}

// ConvertIntToUint converts an int value to an uint value.
func ConvertIntToUint(val int64, upperBound uint64, tp byte) (uint64, error) {
func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) {
if sc.ShouldClipToZero() && val < 0 {
return 0, overflow(val, tp)
}

if uint64(val) > upperBound {
return upperBound, overflow(val, tp)
}
Expand All @@ -124,9 +128,12 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {
}

// ConvertFloatToUint converts a float value to an uint value.
func ConvertFloatToUint(fval float64, upperBound uint64, tp byte) (uint64, error) {
func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) {
val := RoundFloat(fval)
if val < 0 {
if sc.ShouldClipToZero() {
return 0, overflow(val, tp)
}
return uint64(int64(val)), overflow(val, tp)
}

Expand Down Expand Up @@ -343,7 +350,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned
return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble)
}
bound := UnsignedUpperBound[mysql.TypeLonglong]
u, err := ConvertFloatToUint(f, bound, mysql.TypeDouble)
u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble)
return int64(u), errors.Trace(err)
case json.TypeCodeString:
return StrToInt(sc, hack.String(j.GetString()))
Expand All @@ -366,7 +373,7 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6
case json.TypeCodeInt64:
return float64(j.GetInt64()), nil
case json.TypeCodeUint64:
u, err := ConvertIntToUint(j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
u, err := ConvertIntToUint(sc, j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
return float64(u), errors.Trace(err)
case json.TypeCodeFloat64:
return j.GetFloat64(), nil
Expand Down
26 changes: 13 additions & 13 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -850,29 +850,29 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
)
switch d.k {
case KindInt64:
val, err = ConvertIntToUint(d.GetInt64(), upperBound, tp)
val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp)
case KindUint64:
val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp)
case KindFloat32, KindFloat64:
val, err = ConvertFloatToUint(d.GetFloat64(), upperBound, tp)
val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp)
case KindString, KindBytes:
val, err = StrToUint(sc, d.GetString())
if err != nil {
return ret, errors.Trace(err)
uval, err1 := StrToUint(sc, d.GetString())
if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() {
return ret, errors.Trace(err1)
}
val, err = ConvertUintToUint(val, upperBound, tp)
val, err = ConvertUintToUint(uval, upperBound, tp)
if err != nil {
return ret, errors.Trace(err)
}
ret.SetUint64(val)
err = err1
case KindMysqlTime:
dec := d.GetMysqlTime().ToNumber()
err = dec.Round(dec, 0, ModeHalfEven)
ival, err1 := dec.ToInt()
if err == nil {
err = err1
}
val, err1 = ConvertIntToUint(ival, upperBound, tp)
val, err1 = ConvertIntToUint(sc, ival, upperBound, tp)
if err == nil {
err = err1
}
Expand All @@ -881,18 +881,18 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
err = dec.Round(dec, 0, ModeHalfEven)
ival, err1 := dec.ToInt()
if err1 == nil {
val, err = ConvertIntToUint(ival, upperBound, tp)
val, err = ConvertIntToUint(sc, ival, upperBound, tp)
}
case KindMysqlDecimal:
fval, err1 := d.GetMysqlDecimal().ToFloat64()
val, err = ConvertFloatToUint(fval, upperBound, tp)
val, err = ConvertFloatToUint(sc, fval, upperBound, tp)
if err == nil {
err = err1
}
case KindMysqlEnum:
val, err = ConvertFloatToUint(d.GetMysqlEnum().ToNumber(), upperBound, tp)
val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp)
case KindMysqlSet:
val, err = ConvertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp)
val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp)
case KindBinaryLiteral, KindMysqlBit:
val, err = d.GetBinaryLiteral().ToInt(sc)
case KindMysqlJSON:
Expand Down Expand Up @@ -1105,7 +1105,7 @@ func ProduceDecWithSpecifiedTp(dec *MyDecimal, tp *FieldType, sc *stmtctx.Statem
return nil, errors.Trace(err)
}
if !dec.IsZero() && frac > decimal && dec.Compare(&old) != 0 {
if sc.InInsertStmt || sc.InUpdateOrDeleteStmt {
if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt {
// fix https://github.com/pingcap/tidb/issues/3895
// fix https://github.com/pingcap/tidb/issues/5532
sc.AppendWarning(ErrTruncated)
Expand Down

0 comments on commit 4830b11

Please sign in to comment.