diff --git a/executor/distsql.go b/executor/distsql.go index 465e6746226a3..b19761ad69397 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -158,7 +158,7 @@ func statementContextToFlags(sc *stmtctx.StatementContext) uint64 { var flags uint64 if sc.InInsertStmt { flags |= FlagInInsertStmt - } else if sc.InUpdateOrDeleteStmt { + } else if sc.InUpdateStmt || sc.InDeleteStmt { flags |= FlagInUpdateOrDeleteStmt } else if sc.InSelectStmt { flags |= FlagInSelectStmt diff --git a/executor/executor_test.go b/executor/executor_test.go index 7530b11285170..7674438b770cf 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -304,6 +304,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo, c *C, tk *testkit.TestKit, ctx sessionctx.Context, selectSQL, deleteSQL string) { for _, tt := range tests { c.Assert(ctx.NewTxn(), IsNil) + ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true data, reachLimit, err1 := ld.InsertData(tt.data1, tt.data2) c.Assert(err1, IsNil) c.Assert(reachLimit, IsFalse) diff --git a/executor/prepared.go b/executor/prepared.go index e3668cac337fc..a0434be190cbc 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -286,7 +286,7 @@ func ResetStmtCtx(ctx sessionctx.Context, s ast.StmtNode) { sc.IgnoreTruncate = false sc.OverflowAsWarning = false sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.InUpdateOrDeleteStmt = true + sc.InUpdateStmt = true sc.DividedByZeroAsWarning = stmt.IgnoreErr sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr sc.Priority = stmt.Priority @@ -294,7 +294,7 @@ func ResetStmtCtx(ctx sessionctx.Context, s ast.StmtNode) { sc.IgnoreTruncate = false sc.OverflowAsWarning = false sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.InUpdateOrDeleteStmt = true + sc.InDeleteStmt = true sc.DividedByZeroAsWarning = stmt.IgnoreErr sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr sc.Priority = stmt.Priority diff --git a/executor/write_test.go b/executor/write_test.go index b9225703e44ed..722b560be5621 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -217,6 +217,27 @@ func (s *testSuite) TestInsert(c *C) { tk.MustExec("insert into test values(2, 3)") tk.MustQuery("select * from test use index (id) where id = 2").Check(testkit.Rows("2 2", "2 3")) + // issue 6360 + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a bigint unsigned);") + tk.MustExec(" set @orig_sql_mode = @@sql_mode; set @@sql_mode = 'strict_all_tables';") + _, err = tk.Exec("insert into t value (-1);") + c.Assert(types.ErrWarnDataOutOfRange.Equal(err), IsTrue) + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t value (-1);") + // TODO: the following warning messages are not consistent with MySQL, fix them in the future PRs + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t select -1;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t select cast(-1 as unsigned);") + tk.MustExec("insert into t value (-1.111);") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t value ('-1.111');") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 BIGINT UNSIGNED value is out of range in '-1'")) + r = tk.MustQuery("select * from t;") + r.Check(testkit.Rows("0", "0", "18446744073709551615", "0", "0")) + tk.MustExec("set @@sql_mode = @orig_sql_mode;") + // issue 6424 tk.MustExec("drop table if exists t") tk.MustExec("create table t(a time(6))") @@ -1344,6 +1365,26 @@ func makeLoadDataInfo(column int, specifiedColumns []string, ctx sessionctx.Cont return } +// related to issue 6360 +func (s *testSuite) TestLoadDataOverflowBigintUnsigned(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test; drop table if exists load_data_test;") + tk.MustExec("CREATE TABLE load_data_test (a bigint unsigned);") + tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test") + ctx := tk.Se.(sessionctx.Context) + ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataInfo) + c.Assert(ok, IsTrue) + defer ctx.SetValue(executor.LoadDataVarKey, nil) + c.Assert(ld, NotNil) + tests := []testCase{ + {nil, []byte("-1\n-18446744073709551615\n-18446744073709551616\n"), []string{"0", "0", "0"}, nil}, + {nil, []byte("-9223372036854775809\n18446744073709551616\n"), []string{"0", "18446744073709551615"}, nil}, + } + deleteSQL := "delete from load_data_test" + selectSQL := "select * from load_data_test;" + checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) +} + func (s *testSuite) TestBatchInsertDelete(c *C) { originLimit := atomic.LoadUint64(&kv.TxnEntryCountLimit) defer func() { diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 8b3bd7f68eeeb..74124b0bc9891 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -457,7 +457,8 @@ func (b *builtinCastIntAsRealSig) evalReal(row types.Row) (res float64, isNull b res = float64(val) } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) res = float64(uVal) } return res, false, errors.Trace(err) @@ -482,7 +483,8 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row types.Row) (res *types.MyDe res = types.NewDecFromInt(val) } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) if err != nil { return res, false, errors.Trace(err) } @@ -511,7 +513,8 @@ func (b *builtinCastIntAsStringSig) evalString(row types.Row) (res string, isNul res = strconv.FormatInt(val, 10) } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) if err != nil { return res, false, errors.Trace(err) } @@ -732,7 +735,8 @@ func (b *builtinCastRealAsIntSig) evalInt(row types.Row) (res int64, isNull bool res, err = types.ConvertFloatToInt(val, types.SignedLowerBound[mysql.TypeLonglong], types.SignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) } else { var uintVal uint64 - uintVal, err = types.ConvertFloatToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) + sc := b.ctx.GetSessionVars().StmtCtx + uintVal, err = types.ConvertFloatToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) res = int64(uintVal) } return res, isNull, errors.Trace(err) diff --git a/expression/errors.go b/expression/errors.go index 512880b8c688d..7fea525479418 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -59,7 +59,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { return err } sc := ctx.GetSessionVars().StmtCtx - if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateOrDeleteStmt) { + if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) { return err } sc.AppendWarning(err) @@ -69,7 +69,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { // handleDivisionByZeroError reports error or warning depend on the context. func handleDivisionByZeroError(ctx sessionctx.Context) error { sc := ctx.GetSessionVars().StmtCtx - if sc.InInsertStmt || sc.InUpdateOrDeleteStmt { + if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt { if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() { return nil } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 5347e686f8c5d..fdaa42b13b408 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -42,8 +42,10 @@ type StatementContext struct { // Set the following variables before execution InInsertStmt bool - InUpdateOrDeleteStmt bool + InUpdateStmt bool + InDeleteStmt bool InSelectStmt bool + InLoadDataStmt bool IgnoreTruncate bool IgnoreZeroInDate bool DividedByZeroAsWarning bool @@ -223,3 +225,21 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails { sc.mu.Unlock() return details } + +// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types. +// This is the case for `insert`, `update`, `alter table` and `load data infile` statements, when not in strict SQL mode. +// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html +func (sc *StatementContext) ShouldClipToZero() bool { + // TODO: Currently altering column of integer to unsigned integer is not supported. + // If it is supported one day, that case should be added here. + return sc.InInsertStmt || sc.InLoadDataStmt +} + +// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows, +// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types. +func (sc *StatementContext) ShouldIgnoreOverflowError() bool { + if (sc.InInsertStmt && sc.TruncateAsWarning) || sc.InLoadDataStmt { + return true + } + return false +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 42328538afc92..8476590b9e5bd 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -469,7 +469,7 @@ func (s *SessionVars) GetTimeZone() *time.Location { func (s *SessionVars) ResetPrevAffectedRows() { s.PrevAffectedRows = 0 if s.StmtCtx != nil { - if s.StmtCtx.InUpdateOrDeleteStmt || s.StmtCtx.InInsertStmt { + if s.StmtCtx.InUpdateStmt || s.StmtCtx.InDeleteStmt || s.StmtCtx.InInsertStmt { s.PrevAffectedRows = int64(s.StmtCtx.AffectedRows()) } else if s.StmtCtx.InSelectStmt { s.PrevAffectedRows = -1 diff --git a/types/convert.go b/types/convert.go index 87b87efbecf1c..57a8471dbc8a7 100644 --- a/types/convert.go +++ b/types/convert.go @@ -106,7 +106,11 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) { } // ConvertIntToUint converts an int value to an uint value. -func ConvertIntToUint(val int64, upperBound uint64, tp byte) (uint64, error) { +func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) { + if sc.ShouldClipToZero() && val < 0 { + return 0, overflow(val, tp) + } + if uint64(val) > upperBound { return upperBound, overflow(val, tp) } @@ -124,9 +128,12 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) { } // ConvertFloatToUint converts a float value to an uint value. -func ConvertFloatToUint(fval float64, upperBound uint64, tp byte) (uint64, error) { +func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) { val := RoundFloat(fval) if val < 0 { + if sc.ShouldClipToZero() { + return 0, overflow(val, tp) + } return uint64(int64(val)), overflow(val, tp) } @@ -343,7 +350,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble) } bound := UnsignedUpperBound[mysql.TypeLonglong] - u, err := ConvertFloatToUint(f, bound, mysql.TypeDouble) + u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble) return int64(u), errors.Trace(err) case json.TypeCodeString: return StrToInt(sc, hack.String(j.GetString())) @@ -366,7 +373,7 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6 case json.TypeCodeInt64: return float64(j.GetInt64()), nil case json.TypeCodeUint64: - u, err := ConvertIntToUint(j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + u, err := ConvertIntToUint(sc, j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) return float64(u), errors.Trace(err) case json.TypeCodeFloat64: return j.GetFloat64(), nil diff --git a/types/datum.go b/types/datum.go index b1bb1ed02dd08..78b0eaa20194a 100644 --- a/types/datum.go +++ b/types/datum.go @@ -850,21 +850,21 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( ) switch d.k { case KindInt64: - val, err = ConvertIntToUint(d.GetInt64(), upperBound, tp) + val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp) case KindUint64: val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp) case KindFloat32, KindFloat64: - val, err = ConvertFloatToUint(d.GetFloat64(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp) case KindString, KindBytes: - val, err = StrToUint(sc, d.GetString()) - if err != nil { - return ret, errors.Trace(err) + uval, err1 := StrToUint(sc, d.GetString()) + if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() { + return ret, errors.Trace(err1) } - val, err = ConvertUintToUint(val, upperBound, tp) + val, err = ConvertUintToUint(uval, upperBound, tp) if err != nil { return ret, errors.Trace(err) } - ret.SetUint64(val) + err = err1 case KindMysqlTime: dec := d.GetMysqlTime().ToNumber() err = dec.Round(dec, 0, ModeHalfEven) @@ -872,7 +872,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( if err == nil { err = err1 } - val, err1 = ConvertIntToUint(ival, upperBound, tp) + val, err1 = ConvertIntToUint(sc, ival, upperBound, tp) if err == nil { err = err1 } @@ -881,18 +881,18 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( err = dec.Round(dec, 0, ModeHalfEven) ival, err1 := dec.ToInt() if err1 == nil { - val, err = ConvertIntToUint(ival, upperBound, tp) + val, err = ConvertIntToUint(sc, ival, upperBound, tp) } case KindMysqlDecimal: fval, err1 := d.GetMysqlDecimal().ToFloat64() - val, err = ConvertFloatToUint(fval, upperBound, tp) + val, err = ConvertFloatToUint(sc, fval, upperBound, tp) if err == nil { err = err1 } case KindMysqlEnum: - val, err = ConvertFloatToUint(d.GetMysqlEnum().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp) case KindMysqlSet: - val, err = ConvertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp) case KindBinaryLiteral, KindMysqlBit: val, err = d.GetBinaryLiteral().ToInt(sc) case KindMysqlJSON: @@ -1105,7 +1105,7 @@ func ProduceDecWithSpecifiedTp(dec *MyDecimal, tp *FieldType, sc *stmtctx.Statem return nil, errors.Trace(err) } if !dec.IsZero() && frac > decimal && dec.Compare(&old) != 0 { - if sc.InInsertStmt || sc.InUpdateOrDeleteStmt { + if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt { // fix /~https://github.com/pingcap/tidb/issues/3895 // fix /~https://github.com/pingcap/tidb/issues/5532 sc.AppendWarning(ErrTruncated)