Skip to content

Commit

Permalink
expression: migrate builtin tests [a-e]
Browse files Browse the repository at this point in the history
This refers to pingcap#26855 .

Signed-off-by: tison <[email protected]>
  • Loading branch information
tisonkun committed Oct 26, 2021
1 parent 19a2b3c commit b19ee43
Show file tree
Hide file tree
Showing 18 changed files with 1,136 additions and 1,025 deletions.
389 changes: 203 additions & 186 deletions expression/builtin_arithmetic_test.go

Large diffs are not rendered by default.

486 changes: 254 additions & 232 deletions expression/builtin_cast_test.go

Large diffs are not rendered by default.

28 changes: 16 additions & 12 deletions expression/builtin_cast_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import (
"testing"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
)

var vecBuiltinCastCases = map[string][]vecExprBenchCase{
Expand Down Expand Up @@ -157,7 +157,9 @@ func TestVectorizedBuiltinCastFunc(t *testing.T) {
testVectorizedBuiltinFunc(t, vecBuiltinCastCases)
}

func (s *testEvaluatorSuite) TestVectorizedCastRealAsTime(c *C) {
func TestVectorizedCastRealAsTime(t *testing.T) {
t.Parallel()

col := &Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}
baseFunc, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col}, 0)
if err != nil {
Expand All @@ -171,16 +173,16 @@ func (s *testEvaluatorSuite) TestVectorizedCastRealAsTime(c *C) {

for _, input := range inputs {
result := chunk.NewColumn(types.NewFieldType(mysql.TypeDatetime), input.NumRows())
c.Assert(cast.vecEvalTime(input, result), IsNil)
require.NoError(t, cast.vecEvalTime(input, result))
for i := 0; i < input.NumRows(); i++ {
res, isNull, err := cast.evalTime(input.GetRow(i))
c.Assert(err, IsNil)
require.NoError(t, err)
if isNull {
c.Assert(result.IsNull(i), IsTrue)
require.True(t, result.IsNull(i))
continue
}
c.Assert(result.IsNull(i), IsFalse)
c.Assert(result.GetTime(i).Compare(res), Equals, 0)
require.False(t, result.IsNull(i))
require.Zero(t, result.GetTime(i).Compare(res))
}
}
}
Expand All @@ -199,7 +201,9 @@ func genCastRealAsTime() *chunk.Chunk {
}

// for issue https://github.com/pingcap/tidb/issues/16825
func (s *testEvaluatorSuite) TestVectorizedCastStringAsDecimalWithUnsignedFlagInUnion(c *C) {
func TestVectorizedCastStringAsDecimalWithUnsignedFlagInUnion(t *testing.T) {
t.Parallel()

col := &Column{RetType: types.NewFieldType(mysql.TypeString), Index: 0}
baseFunc, err := newBaseBuiltinFunc(mock.NewContext(), "", []Expression{col}, 0)
if err != nil {
Expand All @@ -219,12 +223,12 @@ func (s *testEvaluatorSuite) TestVectorizedCastStringAsDecimalWithUnsignedFlagIn

for _, input := range inputs {
result := chunk.NewColumn(types.NewFieldType(mysql.TypeNewDecimal), input.NumRows())
c.Assert(cast.vecEvalDecimal(input, result), IsNil)
require.NoError(t, cast.vecEvalDecimal(input, result))
for i := 0; i < input.NumRows(); i++ {
res, isNull, err := cast.evalDecimal(input.GetRow(i))
c.Assert(isNull, IsFalse)
c.Assert(err, IsNil)
c.Assert(result.GetDecimal(i).Compare(res), Equals, 0)
require.False(t, isNull)
require.NoError(t, err)
require.Zero(t, result.GetDecimal(i).Compare(res))
}
}
}
Expand Down
157 changes: 86 additions & 71 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@
package expression

import (
"testing"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/stretchr/testify/require"
)

func (s *testEvaluatorSuite) TestCompareFunctionWithRefine(c *C) {
func TestCompareFunctionWithRefine(t *testing.T) {
t.Parallel()
ctx := createContext(t)

tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong, mysql.NotNullFlag).build()
tests := []struct {
exprStr string
Expand Down Expand Up @@ -69,17 +73,20 @@ func (s *testEvaluatorSuite) TestCompareFunctionWithRefine(c *C) {
{"-123456789123456789123456789.12345 < a", "1"},
{"'aaaa'=a", "eq(0, a)"},
}
cols, names, err := ColumnInfos2ColumnsAndNames(s.ctx, model.NewCIStr(""), tblInfo.Name, tblInfo.Cols(), tblInfo)
c.Assert(err, IsNil)
cols, names, err := ColumnInfos2ColumnsAndNames(ctx, model.NewCIStr(""), tblInfo.Name, tblInfo.Cols(), tblInfo)
require.NoError(t, err)
schema := NewSchema(cols...)
for _, t := range tests {
f, err := ParseSimpleExprsWithNames(s.ctx, t.exprStr, schema, names)
c.Assert(err, IsNil)
c.Assert(f[0].String(), Equals, t.result)
for _, test := range tests {
f, err := ParseSimpleExprsWithNames(ctx, test.exprStr, schema, names)
require.NoError(t, err)
require.Equal(t, test.result, f[0].String())
}
}

func (s *testEvaluatorSuite) TestCompare(c *C) {
func TestCompare(t *testing.T) {
t.Parallel()
ctx := createContext(t)

intVal, uintVal, realVal, stringVal, decimalVal := 1, uint64(1), 1.1, "123", types.NewDecFromFloatForTest(123.123)
timeVal := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 6)
durationVal := types.Duration{Duration: 12*time.Hour + 1*time.Minute + 1*time.Second}
Expand Down Expand Up @@ -133,36 +140,39 @@ func (s *testEvaluatorSuite) TestCompare(c *C) {
{jsonVal, jsonVal, ast.NullEQ, mysql.TypeJSON, 1},
}

for _, t := range tests {
bf, err := funcs[t.funcName].getFunction(s.ctx, s.primitiveValsToConstants([]interface{}{t.arg0, t.arg1}))
c.Assert(err, IsNil)
for _, test := range tests {
bf, err := funcs[test.funcName].getFunction(ctx, primitiveValsToConstants(ctx, []interface{}{test.arg0, test.arg1}))
require.NoError(t, err)
args := bf.getArgs()
c.Assert(args[0].GetType().Tp, Equals, t.tp)
c.Assert(args[1].GetType().Tp, Equals, t.tp)
require.Equal(t, test.tp, args[0].GetType().Tp)
require.Equal(t, test.tp, args[1].GetType().Tp)
res, isNil, err := bf.evalInt(chunk.Row{})
c.Assert(err, IsNil)
c.Assert(isNil, IsFalse)
c.Assert(res, Equals, t.expected)
require.NoError(t, err)
require.False(t, isNil)
require.Equal(t, test.expected, res)
}

// test <non-const decimal expression> <cmp> <const string expression>
decimalCol, stringCon := &Column{RetType: types.NewFieldType(mysql.TypeNewDecimal)}, &Constant{RetType: types.NewFieldType(mysql.TypeVarchar)}
bf, err := funcs[ast.LT].getFunction(s.ctx, []Expression{decimalCol, stringCon})
c.Assert(err, IsNil)
bf, err := funcs[ast.LT].getFunction(ctx, []Expression{decimalCol, stringCon})
require.NoError(t, err)
args := bf.getArgs()
c.Assert(args[0].GetType().Tp, Equals, mysql.TypeNewDecimal)
c.Assert(args[1].GetType().Tp, Equals, mysql.TypeNewDecimal)
require.Equal(t, mysql.TypeNewDecimal, args[0].GetType().Tp)
require.Equal(t, mysql.TypeNewDecimal, args[1].GetType().Tp)

// test <time column> <cmp> <non-time const>
timeCol := &Column{RetType: types.NewFieldType(mysql.TypeDatetime)}
bf, err = funcs[ast.LT].getFunction(s.ctx, []Expression{timeCol, stringCon})
c.Assert(err, IsNil)
bf, err = funcs[ast.LT].getFunction(ctx, []Expression{timeCol, stringCon})
require.NoError(t, err)
args = bf.getArgs()
c.Assert(args[0].GetType().Tp, Equals, mysql.TypeDatetime)
c.Assert(args[1].GetType().Tp, Equals, mysql.TypeDatetime)
require.Equal(t, mysql.TypeDatetime, args[0].GetType().Tp)
require.Equal(t, mysql.TypeDatetime, args[1].GetType().Tp)
}

func (s *testEvaluatorSuite) TestCoalesce(c *C) {
func TestCoalesce(t *testing.T) {
t.Parallel()
ctx := createContext(t)

cases := []struct {
args []interface{}
expected interface{}
Expand All @@ -173,7 +183,7 @@ func (s *testEvaluatorSuite) TestCoalesce(c *C) {
{[]interface{}{nil, nil}, nil, true, false},
{[]interface{}{nil, nil, nil}, nil, true, false},
{[]interface{}{nil, 1}, int64(1), false, false},
{[]interface{}{nil, 1.1}, float64(1.1), false, false},
{[]interface{}{nil, 1.1}, 1.1, false, false},
{[]interface{}{1, 1.1}, float64(1), false, false},
{[]interface{}{nil, types.NewDecFromFloatForTest(123.456)}, types.NewDecFromFloatForTest(123.456), false, false},
{[]interface{}{1, types.NewDecFromFloatForTest(123.456)}, types.NewDecFromInt(1), false, false},
Expand All @@ -183,37 +193,40 @@ func (s *testEvaluatorSuite) TestCoalesce(c *C) {
{[]interface{}{tm, dt}, tm, false, false},
}

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.Coalesce, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
for _, test := range cases {
f, err := newFunctionForTest(ctx, ast.Coalesce, primitiveValsToConstants(ctx, test.args)...)
require.NoError(t, err)

d, err := f.Eval(chunk.Row{})

if t.getErr {
c.Assert(err, NotNil)
if test.getErr {
require.Error(t, err)
} else {
c.Assert(err, IsNil)
if t.isNil {
c.Assert(d.Kind(), Equals, types.KindNull)
require.NoError(t, err)
if test.isNil {
require.Equal(t, types.KindNull, d.Kind())
} else {
c.Assert(d.GetValue(), DeepEquals, t.expected)
require.Equal(t, test.expected, d.GetValue())
}
}
}

_, err := funcs[ast.Length].getFunction(s.ctx, []Expression{NewZero()})
c.Assert(err, IsNil)
_, err := funcs[ast.Length].getFunction(ctx, []Expression{NewZero()})
require.NoError(t, err)
}

func (s *testEvaluatorSuite) TestIntervalFunc(c *C) {
sc := s.ctx.GetSessionVars().StmtCtx
func TestIntervalFunc(t *testing.T) {
t.Parallel()
ctx := createContext(t)

sc := ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreTruncate
sc.IgnoreTruncate = true
defer func() {
sc.IgnoreTruncate = origin
}()

for _, t := range []struct {
for _, test := range []struct {
args []types.Datum
ret int64
getErr bool
Expand Down Expand Up @@ -245,30 +258,32 @@ func (s *testEvaluatorSuite) TestIntervalFunc(c *C) {
{types.MakeDatums("9007199254740992", "9007199254740993"), 1, false},
} {
fc := funcs[ast.Interval]
f, err := fc.getFunction(s.ctx, s.datumsToConstants(t.args))
c.Assert(err, IsNil)
if t.getErr {
f, err := fc.getFunction(ctx, datumsToConstants(test.args))
require.NoError(t, err)
if test.getErr {
v, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, NotNil)
c.Assert(v.GetInt64(), Equals, t.ret)
require.Error(t, err)
require.Equal(t, test.ret, v.GetInt64())
continue
}
v, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(v.GetInt64(), Equals, t.ret)
require.NoError(t, err)
require.Equal(t, test.ret, v.GetInt64())
}
}

// greatest/least function is compatible with MySQL 8.0
func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
sc := s.ctx.GetSessionVars().StmtCtx
func TestGreatestLeastFunc(t *testing.T) {
t.Parallel()
ctx := createContext(t)
sc := ctx.GetSessionVars().StmtCtx
originIgnoreTruncate := sc.IgnoreTruncate
sc.IgnoreTruncate = true
defer func() {
sc.IgnoreTruncate = originIgnoreTruncate
}()

for _, t := range []struct {
for _, test := range []struct {
args []interface{}
expectedGreatest interface{}
expectedLeast interface{}
Expand Down Expand Up @@ -332,36 +347,36 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
"905969664", "1990-06-16 17:22:56.005534", false, false,
},
} {
f0, err := newFunctionForTest(s.ctx, ast.Greatest, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
f0, err := newFunctionForTest(ctx, ast.Greatest, primitiveValsToConstants(ctx, test.args)...)
require.NoError(t, err)
d, err := f0.Eval(chunk.Row{})
if t.getErr {
c.Assert(err, NotNil)
if test.getErr {
require.Error(t, err)
} else {
c.Assert(err, IsNil)
if t.isNil {
c.Assert(d.Kind(), Equals, types.KindNull)
require.NoError(t, err)
if test.isNil {
require.Equal(t, types.KindNull, d.Kind())
} else {
c.Assert(d.GetValue(), DeepEquals, t.expectedGreatest)
require.Equal(t, test.expectedGreatest, d.GetValue())
}
}

f1, err := newFunctionForTest(s.ctx, ast.Least, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
f1, err := newFunctionForTest(ctx, ast.Least, primitiveValsToConstants(ctx, test.args)...)
require.NoError(t, err)
d, err = f1.Eval(chunk.Row{})
if t.getErr {
c.Assert(err, NotNil)
if test.getErr {
require.Error(t, err)
} else {
c.Assert(err, IsNil)
if t.isNil {
c.Assert(d.Kind(), Equals, types.KindNull)
require.NoError(t, err)
if test.isNil {
require.Equal(t, types.KindNull, d.Kind())
} else {
c.Assert(d.GetValue(), DeepEquals, t.expectedLeast)
require.Equal(t, test.expectedLeast, d.GetValue())
}
}
}
_, err := funcs[ast.Greatest].getFunction(s.ctx, []Expression{NewZero(), NewOne()})
c.Assert(err, IsNil)
_, err = funcs[ast.Least].getFunction(s.ctx, []Expression{NewZero(), NewOne()})
c.Assert(err, IsNil)
_, err := funcs[ast.Greatest].getFunction(ctx, []Expression{NewZero(), NewOne()})
require.NoError(t, err)
_, err = funcs[ast.Least].getFunction(ctx, []Expression{NewZero(), NewOne()})
require.NoError(t, err)
}
Loading

0 comments on commit b19ee43

Please sign in to comment.