From e649e9574b7143ec88d392f87927e6962f13e991 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 6 Mar 2023 11:15:20 +0100 Subject: [PATCH 01/42] go: add a slices2 package Signed-off-by: Vicent Marti --- go/slices2/slices.go | 530 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 530 insertions(+) create mode 100644 go/slices2/slices.go diff --git a/go/slices2/slices.go b/go/slices2/slices.go new file mode 100644 index 00000000000..413490a4b9e --- /dev/null +++ b/go/slices2/slices.go @@ -0,0 +1,530 @@ +package slices2 + +// All returns true if all elements return true for given predicate +func All[T any](s []T, fn func(T) bool) bool { + for _, e := range s { + if !fn(e) { + return false + } + } + return true +} + +// Any returns true if at least one element returns true for given predicate +func Any[T any](s []T, fn func(T) bool) bool { + for _, e := range s { + if fn(e) { + return true + } + } + return false +} + +// AppendToGroup adds the key, value to the given map where each key +// points to a slice of values +func AppendToGroup[M ~map[K][]V, K comparable, V any](m M, k K, v V) { + lst, ok := m[k] + if !ok { + lst = make([]V, 0) + } + lst = append(lst, v) + m[k] = lst +} + +// Associate returns a map containing key-value pairs returned by the given +// function applied to the elements of the given slice +func Associate[T, V any, K comparable](s []T, fn func(T) (K, V)) map[K]V { + ret := make(map[K]V) + for _, e := range s { + k, v := fn(e) + ret[k] = v + } + return ret +} + +// Chunked splits the slice into a slice of slices, each not exceeding given size +// The last slice might have fewer elements than the given size +func Chunked[T any](s []T, chunkSize int) [][]T { + sz := len(s) + ret := make([][]T, 0, sz/chunkSize+2) + var sub []T + for i := 0; i < sz; i++ { + if i%chunkSize == 0 { + if len(sub) > 0 { + ret = append(ret, sub) + } + sub = make([]T, 0, chunkSize) + } + sub = append(sub, s[i]) + } + if len(sub) > 0 { + ret = append(ret, sub) + 
} + return ret +} + +func ChunkedBy[T any](s []T, fn func(T, T) bool) [][]T { + ret := make([][]T, 0) + switch len(s) { + case 0: + return ret + case 1: + ret = append(ret, []T{s[0]}) + return ret + } + var currentSubList = []T{s[0]} + for _, e := range s[1:] { + if fn(currentSubList[len(currentSubList)-1], e) { + currentSubList = append(currentSubList, e) + } else { + // save current sub list and start a new one + ret = append(ret, currentSubList) + currentSubList = []T{e} + } + } + if len(currentSubList) > 0 { + ret = append(ret, currentSubList) + } + return ret +} + +// Distinct returns a slice containing only distinct elements from the given slice +// Elements will retain their original order. +func Distinct[T comparable](s []T) []T { + m := make(map[T]bool) + ret := make([]T, 0) + for _, e := range s { + _, ok := m[e] + if ok { + continue + } + m[e] = true + ret = append(ret, e) + } + return ret +} + +// DistinctBy returns a slice containing only distinct elements from the +// given slice as distinguished by the given selector function +// Elements will retain their original order. 
func DistinctBy[T any, K comparable](s []T, fn func(T) K) []T {
	// Only key presence matters, so the zero-width struct{} is used as the
	// map value type.
	seen := make(map[K]struct{})
	ret := make([]T, 0)
	for _, e := range s {
		k := fn(e)
		if _, dup := seen[k]; dup {
			continue
		}
		seen[k] = struct{}{}
		ret = append(ret, e)
	}
	return ret
}

// Drop returns a slice containing all elements except the first n.
// A negative n is treated as 0. When n is at least len(s) a new, empty,
// non-nil slice is returned; otherwise the result shares its backing array
// with s.
func Drop[T any](s []T, n int) []T {
	if n < 0 {
		// s[n:] would panic for a negative n; dropping "fewer than zero"
		// elements is the same as dropping none.
		n = 0
	}
	if n >= len(s) {
		return make([]T, 0)
	}
	return s[n:]
}

// DropLast returns a slice containing all elements except the last n.
// A negative n is treated as 0. When n is at least len(s) a new, empty,
// non-nil slice is returned; otherwise the result shares its backing array
// with s.
func DropLast[T any](s []T, n int) []T {
	if n < 0 {
		// Without this clamp s[:len(s)-n] reaches past len(s): it panics, or,
		// when s has spare capacity, silently exposes elements that are not
		// part of s.
		n = 0
	}
	if n >= len(s) {
		return make([]T, 0)
	}
	return s[:len(s)-n]
}

// DropLastWhile returns a slice containing all elements except the last elements
// that satisfy the given predicate. The result shares its backing array with s.
func DropLastWhile[T any](s []T, fn func(T) bool) []T {
	end := len(s)
	for end > 0 && fn(s[end-1]) {
		end--
	}
	return s[:end]
}

// DropWhile returns a slice containing all elements except the first elements
// that satisfy the given predicate. The result shares its backing array with s.
func DropWhile[T any](s []T, fn func(T) bool) []T {
	start := 0
	for start < len(s) && fn(s[start]) {
		start++
	}
	return s[start:]
}

// Filter returns the slice obtained after retaining only those elements
// in the given slice for which the given function returns true.
// The result is always a new, non-nil slice; s is not modified.
func Filter[T any](s []T, fn func(T) bool) []T {
	ret := make([]T, 0)
	for _, e := range s {
		if fn(e) {
			ret = append(ret, e)
		}
	}
	return ret
}

// FilterInPlace retains only those elements of s for which the given function
// returns true, writing the kept elements into the backing array of s instead
// of allocating a new slice. The contents of s are overwritten, so callers
// must use the returned slice and discard s.
func FilterInPlace[T any](s []T, fn func(T) bool) []T {
	filtered := s[:0]
	for _, e := range s {
		if fn(e) {
			filtered = append(filtered, e)
		}
	}
	return filtered
}

// FilterIndexed returns the slice obtained after retaining only those elements
// in the given slice for which the given function returns true. Predicate
// receives the value as well as its index in the slice.
func FilterIndexed[T any](s []T, fn func(int, T) bool) []T {
	ret := make([]T, 0)
	for i, e := range s {
		if fn(i, e) {
			ret = append(ret, e)
		}
	}
	return ret
}

// FilterMap returns the slice obtained after both filtering and mapping using
// the given function. The function should return two values -
// first, the result of the mapping operation and
// second, whether the element should be included or not.
// This is faster than doing a separate filter and map operations,
// since it avoids extra allocations and slice traversals.
func FilterMap[T1, T2 any](
	s []T1,
	fn func(T1) (T2, bool),
) []T2 {
	ret := make([]T2, 0)
	for _, e := range s {
		if mapped, keep := fn(e); keep {
			ret = append(ret, mapped)
		}
	}
	return ret
}

// Find returns the first element of the given slice for which the given
// predicate returns true, together with true. If no element matches, it
// returns the zero value of T and false.
func Find[T any](s []T, fn func(T) bool) (T, bool) {
	for _, e := range s {
		if fn(e) {
			return e, true
		}
	}
	var zero T
	return zero, false
}

// Fold accumulates values starting with given initial value and applying
// given function to current accumulator and each element.
func Fold[T, R any](s []T, initial R, fn func(R, T) R) R {
	acc := initial
	for _, e := range s {
		acc = fn(acc, e)
	}
	return acc
}

// FoldIndexed accumulates values starting with given initial value and applying
// given function to current accumulator and each element. Function also
// receives index of current element.
func FoldIndexed[T, R any](s []T, initial R, fn func(int, R, T) R) R {
	acc := initial
	for i, e := range s {
		acc = fn(i, acc, e)
	}
	return acc
}

// FoldItems accumulates values starting with given initial value and applying
// given function to current accumulator and each key, value.
// Go does not specify map iteration order, so the result is only
// deterministic when fn does not depend on the order of the entries.
func FoldItems[M ~map[K]V, K comparable, V, R any](
	m M,
	initial R,
	fn func(R, K, V) R,
) R {
	acc := initial
	for k, v := range m {
		acc = fn(acc, k, v)
	}
	return acc
}

// GetOrInsert checks if a value corresponding to the given key is present
// in the map. If present it returns the existing value. If not, it invokes the
// given callback function to get a new value for the given key, inserts it in
// the map and returns the new value
func GetOrInsert[M ~map[K]V, K comparable, V any](m M, k K, fn func(K) V) V {
	if existing, ok := m[k]; ok {
		return existing
	}
	// not present; compute the value once, remember it and hand it back
	v := fn(k)
	m[k] = v
	return v
}

// GroupBy returns a map containing key to list of values
// returned by the given function applied to the elements of the given slice
func GroupBy[T, V any, K comparable](
	s []T,
	fn func(T) (K, V),
) map[K][]V {
	ret := make(map[K][]V)
	for _, e := range s {
		k, v := fn(e)
		// append on the nil slice of a missing key allocates the group
		ret[k] = append(ret[k], v)
	}
	return ret
}

// Map returns the slice obtained after applying the given function over every
// element in the given slice
func Map[T1, T2 any](s []T1, fn func(T1) T2) []T2 {
	ret := make([]T2, len(s))
	for i, e := range s {
		ret[i] = fn(e)
	}
	return ret
}

// MapIndexed returns the slice obtained after applying the given function over every
// element in the given slice. The function also receives the index of each
// element in the slice.
func MapIndexed[T1, T2 any](s []T1, fn func(int, T1) T2) []T2 {
	ret := make([]T2, len(s))
	for i, e := range s {
		ret[i] = fn(i, e)
	}
	return ret
}

// Partition returns two slices where the first slice contains elements for
// which the predicate returned true and the second slice contains elements for
// which it returned false.
func Partition[T any](s []T, fn func(T) bool) ([]T, []T) {
	trueList := make([]T, 0)
	falseList := make([]T, 0)
	for _, e := range s {
		if fn(e) {
			trueList = append(trueList, e)
		} else {
			falseList = append(falseList, e)
		}
	}
	return trueList, falseList
}

// Reduce accumulates the values starting with the first element and applying the
// operation from left to right to the current accumulator value and each element
// The input slice must have at least one element; Reduce panics on an empty
// slice.
func Reduce[T any](s []T, fn func(T, T) T) T {
	if len(s) == 0 {
		// State the violated precondition instead of failing with an opaque
		// slice-bounds runtime error.
		panic("slices2.Reduce: empty slice")
	}
	acc := s[0]
	for _, e := range s[1:] {
		acc = fn(acc, e)
	}
	return acc
}

// ReduceIndexed accumulates the values starting with the first element and applying the
// operation from left to right to the current accumulator value and each element
// The input slice must have at least one element; ReduceIndexed panics on an
// empty slice. The function also receives the index of the element.
func ReduceIndexed[T any](s []T, fn func(int, T, T) T) T {
	if len(s) == 0 {
		// State the violated precondition instead of failing with an opaque
		// index-out-of-range runtime error.
		panic("slices2.ReduceIndexed: empty slice")
	}
	acc := s[0]
	for i := 1; i < len(s); i++ {
		// i is the index of the element in s, so the first call sees 1
		acc = fn(i, acc, s[i])
	}
	return acc
}

// Reverse reverses the elements of the list in place
func Reverse[T any](s []T) {
	for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
		s[i], s[j] = s[j], s[i]
	}
}

// Reversed returns a new list with the elements in reverse order
func Reversed[T any](s []T) []T {
	ret := make([]T, len(s))
	for i, e := range s {
		ret[len(s)-1-i] = e
	}
	return ret
}

// Take returns the slice obtained after taking the first n elements from the
// given slice.
// If n is greater than the length of the slice, returns the entire slice.
// A negative n is treated as 0. The result shares its backing array with s.
func Take[T any](s []T, n int) []T {
	if n < 0 {
		// s[:n] would panic for a negative n
		n = 0
	}
	if len(s) <= n {
		return s
	}
	return s[:n]
}

// TakeLast returns the slice obtained after taking the last n elements from the
// given slice.
+func TakeLast[T any](s []T, n int) []T { + if len(s) <= n { + return s + } + return s[len(s)-n:] +} + +// TakeLastWhile returns a slice containing the last elements satisfying the given +// predicate +func TakeLastWhile[T any](s []T, fn func(T) bool) []T { + if len(s) == 0 { + return s + } + i := len(s) - 1 + for ; i >= 0; i-- { + if !fn(s[i]) { + break + } + } + return s[i+1:] +} + +// TakeWhile returns a list containing the first elements satisfying the +// given predicate +func TakeWhile[T any](s []T, fn func(T) bool) []T { + if len(s) == 0 { + return s + } + i := 0 + for ; i < len(s); i++ { + if !fn(s[i]) { + break + } + } + return s[:i] +} + +// TransformMap applies the given function to each key, value in the map, +// and returns a new map of the same type after transforming the keys +// and values depending on the callback functions return values. If the last +// bool return value from the callback function is false, the entry is dropped +func TransformMap[M ~map[K]V, K comparable, V any]( + m M, + fn func(k K, v V) (K, V, bool), +) M { + ret := make(map[K]V) + for k, v := range m { + newK, newV, include := fn(k, v) + if include { + ret[newK] = newV + } + } + return ret +} + +// Unzip returns two slices, where the first slice is built from the first +// values of each pair from the input slice, and the second slice is built +// from the second values of each pair +func Unzip[T1 any, T2 any](ps []*Pair[T1, T2]) ([]T1, []T2) { + l := len(ps) + s1 := make([]T1, 0, l) + s2 := make([]T2, 0, l) + for _, p := range ps { + s1 = append(s1, p.Fst) + s2 = append(s2, p.Snd) + } + return s1, s2 +} + +// Windowed returns a slice of sliding windows into the given slice of the +// given size, and with the given step +func Windowed[T any](s []T, size, step int) [][]T { + ret := make([][]T, 0) + sz := len(s) + if sz == 0 { + return ret + } + start := 0 + end := 0 + updateEnd := func() { + e := start + size + if e >= sz { + e = sz + } + end = e + } + updateStart := func() { + 
s := start + step + if s >= sz { + s = sz + } + start = s + } + updateEnd() + + for { + sub := make([]T, 0, end) + for i := start; i < end; i++ { + sub = append(sub, s[i]) + } + ret = append(ret, sub) + updateStart() + updateEnd() + if start == end { + break + } + } + return ret +} + +// Zip returns a slice of pairs from the elements of both slices with the same +// index. The returned slice has the length of the shortest input slice +func Zip[T1 any, T2 any](s1 []T1, s2 []T2) []*Pair[T1, T2] { + minLen := len(s1) + if minLen > len(s2) { + minLen = len(s2) + } + + // Allocate enough space to avoid copies and extra allocations + ret := make([]*Pair[T1, T2], 0, minLen) + + for i := 0; i < minLen; i++ { + ret = append(ret, &Pair[T1, T2]{ + Fst: s1[i], + Snd: s2[i], + }) + } + return ret +} + +// Pair represents a generic pair of two values +type Pair[T1, T2 any] struct { + Fst T1 + Snd T2 +} From c3842f446a034a5629ab4e06a3faf7d329426b34 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 6 Mar 2023 11:15:33 +0100 Subject: [PATCH 02/42] evalengine: compiler [wip] Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/arena.go | 81 + go/vt/vtgate/evalengine/compiler.go | 1241 ++++++++++++++++ .../evalengine/compiler_instructions.go | 1311 +++++++++++++++++ go/vt/vtgate/evalengine/compiler_test.go | 242 +++ go/vt/vtgate/evalengine/perf_test.go | 60 + go/vt/vtgate/evalengine/vm.go | 63 + 6 files changed, 2998 insertions(+) create mode 100644 go/vt/vtgate/evalengine/arena.go create mode 100644 go/vt/vtgate/evalengine/compiler.go create mode 100644 go/vt/vtgate/evalengine/compiler_instructions.go create mode 100644 go/vt/vtgate/evalengine/compiler_test.go create mode 100644 go/vt/vtgate/evalengine/perf_test.go create mode 100644 go/vt/vtgate/evalengine/vm.go diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go new file mode 100644 index 00000000000..1e7092f0065 --- /dev/null +++ b/go/vt/vtgate/evalengine/arena.go @@ -0,0 +1,81 @@ +package 
evalengine + +import "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" + +type Arena struct { + aInt64 []evalInt64 + aUint64 []evalUint64 + aFloat64 []evalFloat + aDecimal []evalDecimal + aBytes []evalBytes +} + +func (a *Arena) reset() { + a.aInt64 = a.aInt64[:0] + a.aUint64 = a.aUint64[:0] + a.aFloat64 = a.aFloat64[:0] + a.aDecimal = a.aDecimal[:0] + a.aBytes = a.aBytes[:0] +} + +func (a *Arena) newEvalDecimalWithPrec(dec decimal.Decimal, prec int32) *evalDecimal { + if cap(a.aDecimal) > len(a.aDecimal) { + a.aDecimal = a.aDecimal[:len(a.aDecimal)+1] + } else { + a.aDecimal = append(a.aDecimal, evalDecimal{}) + } + val := &a.aDecimal[len(a.aDecimal)-1] + val.dec = dec + val.length = prec + return val +} + +func (a *Arena) newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal { + if m == 0 && d == 0 { + return a.newEvalDecimalWithPrec(dec, -dec.Exponent()) + } + return a.newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) +} + +func (a *Arena) newEvalInt64(i int64) *evalInt64 { + if cap(a.aInt64) > len(a.aInt64) { + a.aInt64 = a.aInt64[:len(a.aInt64)+1] + } else { + a.aInt64 = append(a.aInt64, evalInt64{}) + } + val := &a.aInt64[len(a.aInt64)-1] + val.i = i + return val +} + +func (a *Arena) newEvalUint64(u uint64) *evalUint64 { + if cap(a.aUint64) > len(a.aUint64) { + a.aUint64 = a.aUint64[:len(a.aUint64)+1] + } else { + a.aUint64 = append(a.aUint64, evalUint64{}) + } + val := &a.aUint64[len(a.aUint64)-1] + val.u = u + val.hexLiteral = false + return val +} + +func (a *Arena) newEvalFloat(f float64) *evalFloat { + if cap(a.aFloat64) > len(a.aFloat64) { + a.aFloat64 = a.aFloat64[:len(a.aFloat64)+1] + } else { + a.aFloat64 = append(a.aFloat64, evalFloat{}) + } + val := &a.aFloat64[len(a.aFloat64)-1] + val.f = f + return val +} + +func (a *Arena) newEvalBytesEmpty() *evalBytes { + if cap(a.aBytes) > len(a.aBytes) { + a.aBytes = a.aBytes[:len(a.aBytes)+1] + } else { + a.aBytes = append(a.aBytes, evalBytes{}) + } + return &a.aBytes[len(a.aBytes)-1] +} diff 
--git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go new file mode 100644 index 00000000000..1bce3fbb009 --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler.go @@ -0,0 +1,1241 @@ +package evalengine + +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/slices2" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine/internal/json" +) + +type frame func(vm *VirtualMachine) int + +type compiler struct { + ins []frame + log CompilerLog + fields []*querypb.Field + + stack struct { + cur int + max int + } +} + +type CompilerLog interface { + Disasm(ins string, args ...any) + Stack(old, new int) +} + +type compiledCoercion struct { + col collations.Collation + left collations.Coercion + right collations.Coercion +} + +type CompilerOption func(c *compiler) + +func WithCompilerLog(log CompilerLog) CompilerOption { + return func(c *compiler) { + c.log = log + } +} + +func Compile(expr Expr, fields []*querypb.Field, options ...CompilerOption) (*Program, error) { + comp := compiler{ + fields: fields, + } + for _, opt := range options { + opt(&comp) + } + _, err := comp.compileExpr(expr) + if err != nil { + return nil, err + } + if comp.stack.cur != 1 { + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "bad compilation: stack pointer at %d after compilation", comp.stack.cur) + } + return &Program{code: comp.ins, original: expr, stack: comp.stack.max}, nil +} + +func (c *compiler) adjustStack(offset int) { + c.stack.cur += offset + if c.stack.cur < 0 { + panic("negative stack position") + } + if c.stack.cur > c.stack.max { + c.stack.max = c.stack.cur + } + if c.log != nil { + c.log.Stack(c.stack.cur-offset, c.stack.cur) + } +} + +func (c *compiler) compileColumn(offset int) (ctype, error) { + if offset >= len(c.fields) { + return ctype{}, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, 
"Missing field for column %d", offset) + } + + field := c.fields[offset] + col := collations.TypedCollation{ + Collation: collations.ID(field.Charset), + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + if col.Collation != collations.CollationBinaryID { + col.Repertoire = collations.RepertoireUnicode + } + + switch tt := field.Type; { + case sqltypes.IsSigned(tt): + c.emitPushColumn_i(offset) + case sqltypes.IsUnsigned(tt): + c.emitPushColumn_u(offset) + case sqltypes.IsFloat(tt): + c.emitPushColumn_f(offset) + case sqltypes.IsDecimal(tt): + c.emitPushColumn_d(offset) + case sqltypes.IsText(tt): + if tt == sqltypes.HexNum { + c.emitPushColumn_hexnum(offset) + } else if tt == sqltypes.HexVal { + c.emitPushColumn_hexval(offset) + } else { + c.emitPushColumn_text(offset, col) + } + case sqltypes.IsBinary(tt): + c.emitPushColumn_bin(offset) + case sqltypes.IsNull(tt): + c.emitPushNull() + case tt == sqltypes.TypeJSON: + c.emitPushColumn_json(offset) + default: + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Type is not supported: %s", tt) + } + + var flag typeFlag + if (field.Flags & uint32(querypb.MySqlFlag_NOT_NULL_FLAG)) == 0 { + flag |= flagNullable + } + + return ctype{ + Type: field.Type, + Flag: flag, + Col: col, + }, nil +} + +type ctype struct { + Type sqltypes.Type + Flag typeFlag + Col collations.TypedCollation +} + +func (ct ctype) nullable() bool { + return ct.Flag&flagNullable != 0 +} + +func (ct ctype) isTextual() bool { + return sqltypes.IsText(ct.Type) || sqltypes.IsBinary(ct.Type) +} + +func (ct ctype) isHexOrBitLiteral() bool { + return ct.Flag&flagBit != 0 || ct.Flag&flagHex != 0 +} + +type jump struct { + from, to int +} + +func (j *jump) offset() int { + return j.to - j.from +} + +func (c *compiler) jumpFrom() *jump { + return &jump{from: len(c.ins)} +} + +func (c *compiler) jumpDestination(j *jump) { + j.to = len(c.ins) +} + +func (c *compiler) unsupported(expr Expr) error { + return 
vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported compilation for expression '%s'", FormatExpr(expr)) +} + +func (c *compiler) compileExpr(expr Expr) (ctype, error) { + switch expr := expr.(type) { + case *Literal: + if expr.inner == nil { + c.emitPushNull() + } else if err := c.emitPushLiteral(expr.inner); err != nil { + return ctype{}, err + } + + t, f := expr.typeof(nil) + return ctype{t, f, evalCollation(expr.inner)}, nil + + case *Column: + return c.compileColumn(expr.Offset) + + case *ArithmeticExpr: + return c.compileArithmetic(expr) + + case *BitwiseExpr: + return c.compileBitwise(expr) + + case *BitwiseNotExpr: + return c.compileBitwiseNot(expr) + + case *NegateExpr: + return c.compileNegate(expr) + + case *builtinBitCount: + return c.compileBitCount(expr) + + case *ComparisonExpr: + return c.compileComparison(expr) + + case *CollateExpr: + return c.compileCollate(expr) + + case *ConvertExpr: + return c.compileConvert(expr) + + case *CaseExpr: + return c.compileCase(expr) + + case callable: + return c.compileCallable(expr) + + default: + return ctype{}, c.unsupported(expr) + } +} + +func (c *compiler) compileCallable(call callable) (ctype, error) { + switch call := call.(type) { + case *builtinMultiComparison: + return c.compileMultiComparison(call) + case *builtinJSONExtract: + return c.compileJSONExtract(call) + case *builtinJSONUnquote: + return c.compileJSONUnquote(call) + case *builtinJSONContainsPath: + return c.compileJSONContainsPath(call) + default: + return ctype{}, c.unsupported(call) + } +} + +func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { + lt, err := c.compileExpr(expr.Left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(expr.Right) + if err != nil { + return ctype{}, err + } + + swapped := false + skip := c.jumpFrom() + + switch expr.Op.(type) { + case compareNullSafeEQ: + c.emitNullCmp(skip) + default: + c.emitNullCheck2(skip) + } + + switch { + case compareAsStrings(lt.Type, 
rt.Type): + var merged collations.TypedCollation + var coerceLeft collations.Coercion + var coerceRight collations.Coercion + var env = collations.Local() + + if lt.Col.Collation != rt.Col.Collation { + merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) + } else { + merged = lt.Col + } + if err != nil { + return ctype{}, err + } + + if coerceLeft == nil && coerceRight == nil { + c.emitCmpString_collate(env.LookupByID(merged.Collation)) + } else { + if coerceLeft == nil { + coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } + } + if coerceRight == nil { + coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } + } + c.emitCmpString_coerce(&compiledCoercion{ + col: env.LookupByID(merged.Collation), + left: coerceLeft, + right: coerceRight, + }) + } + + case compareAsSameNumericType(lt.Type, rt.Type) || compareAsDecimal(lt.Type, rt.Type): + switch lt.Type { + case sqltypes.Int64: + switch rt.Type { + case sqltypes.Int64: + c.emitCmpNum_ii() + case sqltypes.Uint64: + c.emitCmpNum_iu(2, 1) + case sqltypes.Float64: + c.emitCmpNum_if(2, 1) + case sqltypes.Decimal: + c.emitCmpNum_id(2, 1) + } + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.emitCmpNum_iu(1, 2) + swapped = true + case sqltypes.Uint64: + c.emitCmpNum_uu() + case sqltypes.Float64: + c.emitCmpNum_uf(2, 1) + case sqltypes.Decimal: + c.emitCmpNum_ud(2, 1) + } + case sqltypes.Float64: + switch rt.Type { + case sqltypes.Int64: + c.emitCmpNum_if(1, 2) + swapped = true + case sqltypes.Uint64: + c.emitCmpNum_uf(1, 2) + swapped = true + case sqltypes.Float64: + c.emitCmpNum_ff() + case sqltypes.Decimal: + c.emitCmpNum_fd(2, 1) + } + + case sqltypes.Decimal: + switch rt.Type { + case sqltypes.Int64: + c.emitCmpNum_id(1, 2) + swapped = true + case sqltypes.Uint64: + c.emitCmpNum_ud(1, 2) + swapped = true + case sqltypes.Float64: + c.emitCmpNum_fd(1, 2) + 
swapped = true + case sqltypes.Decimal: + c.emitCmpNum_dd() + } + } + + case compareAsDates(lt.Type, rt.Type) || compareAsDateAndString(lt.Type, rt.Type) || compareAsDateAndNumeric(lt.Type, rt.Type): + return ctype{}, c.unsupported(expr) + + default: + lt = c.compileToFloat(lt, 2) + rt = c.compileToFloat(rt, 1) + c.emitCmpNum_ff() + } + + cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric} + + switch expr.Op.(type) { + case compareEQ: + c.emitCmp_eq() + case compareNE: + c.emitCmp_ne() + case compareLT: + if swapped { + c.emitCmp_gt() + } else { + c.emitCmp_lt() + } + case compareLE: + if swapped { + c.emitCmp_ge() + } else { + c.emitCmp_le() + } + case compareGT: + if swapped { + c.emitCmp_lt() + } else { + c.emitCmp_gt() + } + case compareGE: + if swapped { + c.emitCmp_le() + } else { + c.emitCmp_ge() + } + case compareNullSafeEQ: + c.jumpDestination(skip) + c.emitCmp_eq() + return cmptype, nil + + default: + panic("unexpected comparison operator") + } + + c.jumpDestination(skip) + return cmptype, nil +} + +func (c *compiler) compileArithmetic(expr *ArithmeticExpr) (ctype, error) { + switch expr.Op.(type) { + case *opArithAdd: + return c.compileArithmeticAdd(expr.Left, expr.Right) + case *opArithSub: + return c.compileArithmeticSub(expr.Left, expr.Right) + case *opArithMul: + return c.compileArithmeticMul(expr.Left, expr.Right) + case *opArithDiv: + return c.compileArithmeticDiv(expr.Left, expr.Right) + default: + panic("unexpected arithmetic operator") + } +} + +func (c *compiler) compileToNumeric(ct ctype, offset int) ctype { + if sqltypes.IsNumber(ct.Type) { + return ct + } + if ct.Type == sqltypes.VarBinary && (ct.Flag&flagHex) != 0 { + c.emitConvert_hex(offset) + return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} + } + c.emitConvert_xf(offset) + return ctype{sqltypes.Float64, ct.Flag, collationNumeric} +} + +func (c *compiler) compileToInt64(ct ctype, offset int) ctype { + switch ct.Type { + case sqltypes.Int64: + return ct + // TODO: 
specialization + default: + c.emitConvert_xi(offset) + } + return ctype{sqltypes.Int64, ct.Flag, collationNumeric} +} + +func (c *compiler) compileToUint64(ct ctype, offset int) ctype { + switch ct.Type { + case sqltypes.Uint64: + return ct + case sqltypes.Int64: + c.emitConvert_iu(offset) + // TODO: specialization + default: + c.emitConvert_xu(offset) + } + return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} +} + +func (c *compiler) compileToFloat(ct ctype, offset int) ctype { + if sqltypes.IsFloat(ct.Type) { + return ct + } + switch ct.Type { + case sqltypes.Int64: + c.emitConvert_if(offset) + case sqltypes.Uint64: + // only emit u->f conversion if this is not a hex value; hex values + // will already be converted + c.emitConvert_uf(offset) + default: + c.emitConvert_xf(offset) + } + return ctype{sqltypes.Float64, ct.Flag, collationNumeric} +} + +func (c *compiler) compileNumericPriority(lt, rt ctype) (ctype, ctype, bool) { + switch lt.Type { + case sqltypes.Int64: + if rt.Type == sqltypes.Uint64 || rt.Type == sqltypes.Float64 || rt.Type == sqltypes.Decimal { + return rt, lt, true + } + case sqltypes.Uint64: + if rt.Type == sqltypes.Float64 || rt.Type == sqltypes.Decimal { + return rt, lt, true + } + case sqltypes.Decimal: + if rt.Type == sqltypes.Float64 { + return rt, lt, true + } + } + return lt, rt, false +} + +func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + swap := false + skip := c.jumpFrom() + c.emitNullCheck2(skip) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + lt, rt, swap = c.compileNumericPriority(lt, rt) + + var sumtype sqltypes.Type + + switch lt.Type { + case sqltypes.Int64: + c.emitAdd_ii() + sumtype = sqltypes.Int64 + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.emitAdd_ui(swap) + case sqltypes.Uint64: + 
c.emitAdd_uu() + } + sumtype = sqltypes.Uint64 + case sqltypes.Decimal: + if swap { + c.compileToDecimal(rt, 2) + } else { + c.compileToDecimal(rt, 1) + } + c.emitAdd_dd() + sumtype = sqltypes.Decimal + case sqltypes.Float64: + if swap { + c.compileToFloat(rt, 2) + } else { + c.compileToFloat(rt, 1) + } + c.emitAdd_ff() + sumtype = sqltypes.Float64 + } + + c.jumpDestination(skip) + return ctype{Type: sumtype, Col: collationNumeric}, nil +} + +func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck2(skip) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + var subtype sqltypes.Type + + switch lt.Type { + case sqltypes.Int64: + switch rt.Type { + case sqltypes.Int64: + c.emitSub_ii() + subtype = sqltypes.Int64 + case sqltypes.Uint64: + c.emitSub_iu() + subtype = sqltypes.Uint64 + case sqltypes.Float64: + c.compileToFloat(lt, 2) + c.emitSub_ff() + subtype = sqltypes.Float64 + case sqltypes.Decimal: + c.compileToDecimal(lt, 2) + c.emitSub_dd() + subtype = sqltypes.Decimal + } + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.emitSub_ui() + subtype = sqltypes.Uint64 + case sqltypes.Uint64: + c.emitSub_uu() + subtype = sqltypes.Uint64 + case sqltypes.Float64: + c.compileToFloat(lt, 2) + c.emitSub_ff() + subtype = sqltypes.Float64 + case sqltypes.Decimal: + c.compileToDecimal(lt, 2) + c.emitSub_dd() + subtype = sqltypes.Decimal + } + case sqltypes.Float64: + c.compileToFloat(rt, 1) + c.emitSub_ff() + subtype = sqltypes.Float64 + case sqltypes.Decimal: + switch rt.Type { + case sqltypes.Float64: + c.compileToFloat(lt, 2) + c.emitSub_ff() + subtype = sqltypes.Float64 + default: + c.compileToDecimal(rt, 1) + c.emitSub_dd() + subtype = sqltypes.Decimal + } + } + + if subtype == 0 { + panic("did not compile?") + } + + 
c.jumpDestination(skip) + return ctype{Type: subtype, Col: collationNumeric}, nil +} + +func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { + if sqltypes.IsDecimal(ct.Type) { + return ct + } + switch ct.Type { + case sqltypes.Int64: + c.emitConvert_id(offset) + case sqltypes.Uint64: + c.emitConvert_ud(offset) + default: + c.emitConvert_xd(offset, 0, 0) + } + return ctype{sqltypes.Decimal, ct.Flag, collationNumeric} +} + +func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + swap := false + skip := c.jumpFrom() + c.emitNullCheck2(skip) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + lt, rt, swap = c.compileNumericPriority(lt, rt) + + var multype sqltypes.Type + + switch lt.Type { + case sqltypes.Int64: + c.emitMul_ii() + multype = sqltypes.Int64 + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.emitMul_ui(swap) + case sqltypes.Uint64: + c.emitMul_uu() + } + multype = sqltypes.Uint64 + case sqltypes.Float64: + if swap { + c.compileToFloat(rt, 2) + } else { + c.compileToFloat(rt, 1) + } + c.emitMul_ff() + multype = sqltypes.Float64 + case sqltypes.Decimal: + if swap { + c.compileToDecimal(rt, 2) + } else { + c.compileToDecimal(rt, 1) + } + c.emitMul_dd() + multype = sqltypes.Decimal + } + + c.jumpDestination(skip) + return ctype{Type: multype, Col: collationNumeric}, nil +} + +func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck2(skip) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + if lt.Type == sqltypes.Float64 || rt.Type == sqltypes.Float64 { + c.compileToFloat(lt, 2) + c.compileToFloat(rt, 1) 
+ c.emitDiv_ff() + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + } else { + c.compileToDecimal(lt, 2) + c.compileToDecimal(rt, 1) + c.emitDiv_dd() + return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil + } +} + +func (c *compiler) compileCollate(expr *CollateExpr) (ctype, error) { + ct, err := c.compileExpr(expr.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + switch ct.Type { + case sqltypes.VarChar: + if err := collations.Local().EnsureCollate(ct.Col.Collation, expr.TypedCollation.Collation); err != nil { + return ctype{}, vterrors.New(vtrpc.Code_INVALID_ARGUMENT, err.Error()) + } + fallthrough + case sqltypes.VarBinary: + c.emitCollate(expr.TypedCollation.Collation) + default: + return ctype{}, c.unsupported(expr) + } + + c.jumpDestination(skip) + + ct.Col = expr.TypedCollation + ct.Flag |= flagExplicitCollation + return ct, nil +} + +func (c *compiler) extractJSONPath(expr Expr) (*json.Path, error) { + path, ok := expr.(*Literal) + if !ok { + return nil, errJSONPath + } + pathBytes, ok := path.inner.(*evalBytes) + if !ok { + return nil, errJSONPath + } + var parser json.PathParser + return parser.ParseBytes(pathBytes.bytes) +} + +func (c *compiler) extractOneOrAll(fname string, expr Expr) (jsonMatch, error) { + lit, ok := expr.(*Literal) + if !ok { + return jsonMatchInvalid, errOneOrAll(fname) + } + b, ok := lit.inner.(*evalBytes) + if !ok { + return jsonMatchInvalid, errOneOrAll(fname) + } + return intoOneOrAll(fname, b.string()) +} + +func (c *compiler) compileJSONExtract(call *builtinJSONExtract) (ctype, error) { + doct, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + if slices2.All(call.Arguments[1:], func(expr Expr) bool { return expr.constant() }) { + paths := make([]*json.Path, 0, len(call.Arguments[1:])) + + for _, arg := range call.Arguments[1:] { + jp, err := c.extractJSONPath(arg) + if err != nil { + return ctype{}, err + } 
+ paths = append(paths, jp) + } + + jt, err := c.compileParseJSON("JSON_EXTRACT", doct, 1) + if err != nil { + return ctype{}, err + } + + c.emitFn_JSON_EXTRACT0(paths) + return jt, nil + } + + return ctype{}, c.unsupported(call) +} + +func (c *compiler) compileJSONUnquote(call *builtinJSONUnquote) (ctype, error) { + arg, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + _, err = c.compileParseJSON("JSON_UNQUOTE", arg, 1) + if err != nil { + return ctype{}, err + } + + c.emitFn_JSON_UNQUOTE() + c.jumpDestination(skip) + return ctype{Type: sqltypes.Blob, Col: collationJSON}, nil +} + +func (c *compiler) compileJSONContainsPath(call *builtinJSONContainsPath) (ctype, error) { + doct, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + if !call.Arguments[1].constant() { + return ctype{}, c.unsupported(call) + } + + if !slices2.All(call.Arguments[2:], func(expr Expr) bool { return expr.constant() }) { + return ctype{}, c.unsupported(call) + } + + match, err := c.extractOneOrAll("JSON_CONTAINS_PATH", call.Arguments[1]) + if err != nil { + return ctype{}, err + } + + paths := make([]*json.Path, 0, len(call.Arguments[2:])) + + for _, arg := range call.Arguments[2:] { + jp, err := c.extractJSONPath(arg) + if err != nil { + return ctype{}, err + } + paths = append(paths, jp) + } + + _, err = c.compileParseJSON("JSON_CONTAINS_PATH", doct, 1) + if err != nil { + return ctype{}, err + } + + c.emitFn_JSON_CONTAINS_PATH(match, paths) + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} + +func (c *compiler) compileParseJSON(fn string, doct ctype, offset int) (ctype, error) { + switch doct.Type { + case sqltypes.TypeJSON: + case sqltypes.VarChar, sqltypes.VarBinary: + c.emitParse_j(offset) + default: + return ctype{}, errJSONType(fn) + } + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + +func (c *compiler) compileToJSON(doct 
ctype, offset int) (ctype, error) { + switch doct.Type { + case sqltypes.TypeJSON: + return doct, nil + case sqltypes.Float64: + c.emitConvert_fj(offset) + case sqltypes.Int64, sqltypes.Uint64, sqltypes.Decimal: + c.emitConvert_nj(offset) + case sqltypes.VarChar: + c.emitConvert_cj(offset) + case sqltypes.VarBinary: + c.emitConvert_bj(offset) + default: + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", doct.Type) + } + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + +func (c *compiler) compileBitwise(expr *BitwiseExpr) (ctype, error) { + switch expr.Op.(type) { + case *opBitAnd: + return c.compileBitwiseOp(expr.Left, expr.Right, and) + case *opBitOr: + return c.compileBitwiseOp(expr.Left, expr.Right, or) + case *opBitXor: + return c.compileBitwiseOp(expr.Left, expr.Right, xor) + case *opBitShl: + return c.compileBitwiseShift(expr.Left, expr.Right, -1) + case *opBitShr: + return c.compileBitwiseShift(expr.Left, expr.Right, 1) + default: + panic("unexpected arithmetic operator") + } +} + +type bitwiseOp int + +const ( + and bitwiseOp = iota + or + xor +) + +func (c *compiler) compileBitwiseOp(left Expr, right Expr, op bitwiseOp) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck2(skip) + + if lt.Type == sqltypes.VarBinary && rt.Type == sqltypes.VarBinary { + if !lt.isHexOrBitLiteral() || !rt.isHexOrBitLiteral() { + c.emitBitOp_bb(op) + c.jumpDestination(skip) + return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + } + } + + lt = c.compileToUint64(lt, 2) + rt = c.compileToUint64(rt, 1) + + c.emitBitOp_uu(op) + c.jumpDestination(skip) + return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil +} + +func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, error) { + lt, err := 
c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck2(skip) + + if lt.Type == sqltypes.VarBinary && !lt.isHexOrBitLiteral() { + _ = c.compileToUint64(rt, 1) + if i < 0 { + c.emitBitShiftLeft_bu() + } else { + c.emitBitShiftRight_bu() + } + c.jumpDestination(skip) + return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + } + + _ = c.compileToUint64(lt, 2) + _ = c.compileToUint64(rt, 1) + + if i < 0 { + c.emitBitShiftLeft_uu() + } else { + c.emitBitShiftRight_uu() + } + + c.jumpDestination(skip) + return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil +} + +func (c *compiler) compileBitCount(expr *builtinBitCount) (ctype, error) { + ct, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { + c.emitBitCount_b() + c.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil + } + + _ = c.compileToUint64(ct, 1) + c.emitBitCount_u() + c.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil +} + +func (c *compiler) compileBitwiseNot(expr *BitwiseNotExpr) (ctype, error) { + ct, err := c.compileExpr(expr.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { + c.emitBitwiseNot_b() + c.jumpDestination(skip) + return ct, nil + } + + ct = c.compileToUint64(ct, 1) + c.emitBitwiseNot_u() + c.jumpDestination(skip) + return ct, nil +} + +func (c *compiler) compileConvert(conv *ConvertExpr) (ctype, error) { + arg, err := c.compileExpr(conv.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + var convt ctype + + switch conv.Type { + case "BINARY": + convt = ctype{Type: 
conv.convertToBinaryType(arg.Type), Col: collationBinary} + c.emitConvert_xb(1, convt.Type, conv.Length, conv.HasLength) + + case "CHAR", "NCHAR": + convt = ctype{ + Type: conv.convertToCharType(arg.Type), + Col: collations.TypedCollation{Collation: conv.Collation}, + } + c.emitConvert_xc(1, convt.Type, convt.Col.Collation, conv.Length, conv.HasLength) + + case "DECIMAL": + convt = ctype{Type: sqltypes.Decimal, Col: collationNumeric} + m, d := conv.decimalPrecision() + c.emitConvert_xd(1, m, d) + + case "DOUBLE", "REAL": + convt = c.compileToFloat(arg, 1) + + case "SIGNED", "SIGNED INTEGER": + convt = c.compileToInt64(arg, 1) + + case "UNSIGNED", "UNSIGNED INTEGER": + convt = c.compileToUint64(arg, 1) + + case "JSON": + // TODO: what does NULL map to? + convt, err = c.compileToJSON(arg, 1) + if err != nil { + return ctype{}, err + } + + default: + return ctype{}, c.unsupported(conv) + } + + c.jumpDestination(skip) + return convt, nil +} + +func (c *compiler) compileCase(cs *CaseExpr) (ctype, error) { + var ca collationAggregation + var ta typeAggregation + var local = collations.Local() + + for _, wt := range cs.cases { + when, err := c.compileExpr(wt.when) + if err != nil { + return ctype{}, err + } + + if err := c.compileCheckTrue(when, 1); err != nil { + return ctype{}, err + } + + then, err := c.compileExpr(wt.then) + if err != nil { + return ctype{}, err + } + + ta.add(then.Type, then.Flag) + if err := ca.add(local, then.Col); err != nil { + return ctype{}, err + } + } + + if cs.Else != nil { + els, err := c.compileExpr(cs.Else) + if err != nil { + return ctype{}, err + } + + ta.add(els.Type, els.Flag) + if err := ca.add(local, els.Col); err != nil { + return ctype{}, err + } + } + + ct := ctype{Type: ta.result(), Col: ca.result()} + c.emitCmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col) + return ct, nil +} + +func (c *compiler) compileCheckTrue(when ctype, offset int) error { + switch when.Type { + case sqltypes.Int64: + c.emitConvert_iB(offset) + case 
sqltypes.Uint64: + c.emitConvert_uB(offset) + case sqltypes.Float64: + c.emitConvert_fB(offset) + case sqltypes.Decimal: + c.emitConvert_dB(offset) + case sqltypes.VarChar, sqltypes.VarBinary: + c.emitConvert_bB(offset) + case sqltypes.Null: + c.emitSetBool(offset, false) + default: + return vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported Truth check: %s", when.Type) + } + return nil +} + +func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) { + arg, err := c.compileExpr(expr.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + arg = c.compileToNumeric(arg, 1) + var neg sqltypes.Type + + switch arg.Type { + case sqltypes.Int64: + neg = sqltypes.Int64 + c.emitNeg_i() + case sqltypes.Uint64: + if arg.Flag&flagHex != 0 { + neg = sqltypes.Float64 + c.emitNeg_hex() + } else { + neg = sqltypes.Int64 + c.emitNeg_u() + } + case sqltypes.Float64: + neg = sqltypes.Float64 + c.emitNeg_f() + case sqltypes.Decimal: + neg = sqltypes.Decimal + c.emitNeg_d() + default: + panic("unexpected Numeric type") + } + + c.jumpDestination(skip) + return ctype{Type: neg, Col: collationNumeric}, nil +} + +func (c *compiler) compileMultiComparison(call *builtinMultiComparison) (ctype, error) { + var ( + integers int + floats int + decimals int + text int + binary int + args []ctype + ) + + /* + If any argument is NULL, the result is NULL. No comparison is needed. + If all arguments are integer-valued, they are compared as integers. + If at least one argument is double precision, they are compared as double-precision values. Otherwise, if at least one argument is a DECIMAL value, they are compared as DECIMAL values. + If the arguments comprise a mix of numbers and strings, they are compared as strings. + If any argument is a nonbinary (character) string, the arguments are compared as nonbinary strings. + In all other cases, the arguments are compared as binary strings. 
+ */ + + for _, expr := range call.Arguments { + tt, err := c.compileExpr(expr) + if err != nil { + return ctype{}, err + } + + args = append(args, tt) + + switch tt.Type { + case sqltypes.Int64: + integers++ + case sqltypes.Uint64: + integers++ + case sqltypes.Float64: + floats++ + case sqltypes.Decimal: + decimals++ + case sqltypes.Text, sqltypes.VarChar: + text++ + case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: + binary++ + default: + return ctype{}, c.unsupported(call) + } + } + + return ctype{}, c.unsupported(call) + + /* + if integers == len(args) { + for i, tt := range args { + c.compileToInt64(tt, len(args)-i) + } + return c.emitFn_MULTICMP_i(len(args)) + } + if binary > 0 || text > 0 { + if binary > 0 { + return c.emitFn_MULTICMP_b(len(args)) + } + + return c.compileMultiComparison_c() + } else { + if floats > 0 { + return c.compileMultiComparison_f() + } + if decimals > 0 { + return c.compileMultiComparison_d() + } + } + return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected argument for GREATEST/LEAST") + + skip := c.jumpFrom() + c.emitNullCheckn(skip, len(call.Arguments)) + */ +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go new file mode 100644 index 00000000000..7e9f81f0cab --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -0,0 +1,1311 @@ +package evalengine + +import ( + "math" + "math/bits" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/slices2" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" + "vitess.io/vitess/go/vt/vtgate/evalengine/internal/json" +) + +func cmpnum[N interface{ int64 | uint64 | float64 }](a, b N) int { + switch { + case a == b: + return 0 + case a < b: + return -1 + default: + return 1 + } +} + +func (c *compiler) emit(f frame, asm string, args ...any) { + if c.log != nil { + 
c.log.Disasm(asm, args...) + } + c.ins = append(c.ins, f) +} + +func (c *compiler) emitPushLiteral(lit eval) error { + c.adjustStack(1) + + if lit == evalBoolTrue { + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = evalBoolTrue + vm.sp++ + return 1 + }, "PUSH true") + return nil + } + + if lit == evalBoolFalse { + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = evalBoolFalse + vm.sp++ + return 1 + }, "PUSH false") + return nil + } + + switch lit := lit.(type) { + case *evalInt64: + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalInt64(lit.i) + vm.sp++ + return 1 + }, "PUSH INT64(%s)", lit.ToRawBytes()) + case *evalUint64: + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalUint64(lit.u) + vm.sp++ + return 1 + }, "PUSH UINT64(%s)", lit.ToRawBytes()) + case *evalFloat: + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalFloat(lit.f) + vm.sp++ + return 1 + }, "PUSH FLOAT64(%s)", lit.ToRawBytes()) + case *evalDecimal: + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalDecimalWithPrec(lit.dec, lit.length) + vm.sp++ + return 1 + }, "PUSH DECIMAL(%s)", lit.ToRawBytes()) + case *evalBytes: + c.emit(func(vm *VirtualMachine) int { + b := vm.arena.newEvalBytesEmpty() + *b = *lit + vm.stack[vm.sp] = b + vm.sp++ + return 1 + }, "PUSH VARCHAR(%q)", lit.ToRawBytes()) + default: + return vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported literal kind '%T'", lit) + } + + return nil +} + +func (c *compiler) emitSetBool(offset int, b bool) { + if b { + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = evalBoolTrue + return 1 + }, "SET (SP-%d), BOOL(true)", offset) + } else { + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = evalBoolFalse + return 1 + }, "SET (SP-%d), BOOL(false)", offset) + } +} + +func (c *compiler) emitPushNull() { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = nil + vm.sp++ + 
return 1 + }, "PUSH NULL") +} + +func (c *compiler) emitPushColumn_i(offset int) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + var ival int64 + ival, vm.err = vm.row[offset].ToInt64() + vm.stack[vm.sp] = vm.arena.newEvalInt64(ival) + vm.sp++ + return 1 + }, "PUSH INT64(:%d)", offset) +} + +func (c *compiler) emitPushColumn_u(offset int) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + var uval uint64 + uval, vm.err = vm.row[offset].ToUint64() + vm.stack[vm.sp] = vm.arena.newEvalUint64(uval) + vm.sp++ + return 1 + }, "PUSH UINT64(:%d)", offset) +} + +func (c *compiler) emitPushColumn_f(offset int) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + var fval float64 + fval, vm.err = vm.row[offset].ToFloat64() + vm.stack[vm.sp] = vm.arena.newEvalFloat(fval) + vm.sp++ + return 1 + }, "PUSH FLOAT64(:%d)", offset) +} + +func (c *compiler) emitPushColumn_d(offset int) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + var dec decimal.Decimal + dec, vm.err = decimal.NewFromMySQL(vm.row[offset].Raw()) + vm.stack[vm.sp] = vm.arena.newEvalDecimal(dec, 0, 0) + vm.sp++ + return 1 + }, "PUSH DECIMAL(:%d)", offset) +} + +func (c *compiler) emitPushColumn_hexnum(offset int) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + var raw []byte + raw, vm.err = parseHexNumber(vm.row[offset].Raw()) + vm.stack[vm.sp] = newEvalBytesHex(raw) + vm.sp++ + return 1 + }, "PUSH HEXNUM(:%d)", offset) +} + +func (c *compiler) emitPushColumn_hexval(offset int) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + hex := vm.row[offset].Raw() + var raw []byte + raw, vm.err = parseHexLiteral(hex[2 : len(hex)-1]) + vm.stack[vm.sp] = newEvalBytesHex(raw) + vm.sp++ + return 1 + }, "PUSH HEXVAL(:%d)", offset) +} + +func (c *compiler) emitPushColumn_text(offset int, col collations.TypedCollation) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalText(vm.row[offset].Raw(), col) + 
vm.sp++ + return 1 + }, "PUSH VARCHAR(:%d) COLLATE %d", offset, col.Collation) +} + +func (c *compiler) emitPushColumn_bin(offset int) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBinary(vm.row[offset].Raw()) + vm.sp++ + return 1 + }, "PUSH VARBINARY(:%d)", offset) +} + +func (c *compiler) emitPushColumn_json(offset int) { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + var parser json.Parser + vm.stack[vm.sp], vm.err = parser.ParseBytes(vm.row[offset].Raw()) + vm.sp++ + return 1 + }, "PUSH JSON(:%d)", offset) +} + +func (c *compiler) emitConvert_hex(offset int) { + c.emit(func(vm *VirtualMachine) int { + var ok bool + vm.stack[vm.sp-offset], ok = vm.stack[vm.sp-offset].(*evalBytes).toNumericHex() + if !ok { + vm.err = errDeoptimize + } + return 1 + }, "CONV VARBINARY(SP-%d), HEX", offset) +} + +func (c *compiler) emitConvert_if(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalInt64) + vm.stack[vm.sp-offset] = vm.arena.newEvalFloat(arg.toFloat0()) + return 1 + }, "CONV INT64(SP-%d), FLOAT64", offset) +} + +func (c *compiler) emitConvert_uf(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalUint64) + vm.stack[vm.sp-offset] = vm.arena.newEvalFloat(arg.toFloat0()) + return 1 + }, "CONV UINT64(SP-%d), FLOAT64)", offset) +} + +func (c *compiler) emitConvert_xf(offset int) { + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset], _ = evalToNumeric(vm.stack[vm.sp-offset]).toFloat() + return 1 + }, "CONV (SP-%d), FLOAT64", offset) +} + +func (c *compiler) emitConvert_id(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalInt64) + vm.stack[vm.sp-offset] = vm.arena.newEvalDecimalWithPrec(decimal.NewFromInt(arg.i), 0) + return 1 + }, "CONV INT64(SP-%d), FLOAT64", offset) +} + +func (c *compiler) emitConvert_ud(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := 
vm.stack[vm.sp-offset].(*evalUint64) + vm.stack[vm.sp-offset] = vm.arena.newEvalDecimalWithPrec(decimal.NewFromUint(arg.u), 0) + return 1 + }, "CONV UINT64(SP-%d), FLOAT64)", offset) +} + +func (c *compiler) emitConvert_xd(offset int, m, d int32) { + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = evalToNumeric(vm.stack[vm.sp-offset]).toDecimal(m, d) + return 1 + }, "CONV (SP-%d), FLOAT64", offset) +} + +func (c *compiler) emitParse_j(offset int) { + c.emit(func(vm *VirtualMachine) int { + var p json.Parser + arg := vm.stack[vm.sp-offset].(*evalBytes) + vm.stack[vm.sp-offset], vm.err = p.ParseBytes(arg.bytes) + return 1 + }, "PARSE_JSON VARCHAR(SP-%d)", offset) +} + +func (c *compiler) emitAdd_ii() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.i, vm.err = mathAdd_ii0(l.i, r.i) + vm.sp-- + return 1 + }, "ADD INT64(SP-2), INT64(SP-1)") +} + +func (c *compiler) emitAdd_ui(swap bool) { + c.adjustStack(-1) + + if swap { + c.emit(func(vm *VirtualMachine) int { + var u uint64 + l := vm.stack[vm.sp-1].(*evalUint64) + r := vm.stack[vm.sp-2].(*evalInt64) + u, vm.err = mathAdd_ui0(l.u, r.i) + vm.stack[vm.sp-2] = vm.arena.newEvalUint64(u) + vm.sp-- + return 1 + }, "ADD UINT64(SP-1), INT64(SP-2)") + } else { + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.u, vm.err = mathAdd_ui0(l.u, r.i) + vm.sp-- + return 1 + }, "ADD UINT64(SP-2), INT64(SP-1)") + } +} + +func (c *compiler) emitAdd_uu() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u, vm.err = mathAdd_uu0(l.u, r.u) + vm.sp-- + return 1 + }, "ADD UINT64(SP-2), UINT64(SP-1)") +} + +func (c *compiler) emitAdd_ff() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := 
vm.stack[vm.sp-1].(*evalFloat) + l.f += r.f + vm.sp-- + return 1 + }, "ADD FLOAT64(SP-2), FLOAT64(SP-1)") +} + +func (c *compiler) emitAdd_dd() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + mathAdd_dd0(l, r) + vm.sp-- + return 1 + }, "ADD DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (c *compiler) emitSub_ii() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.i, vm.err = mathSub_ii0(l.i, r.i) + vm.sp-- + return 1 + }, "SUB INT64(SP-2), INT64(SP-1)") +} + +func (c *compiler) emitSub_iu() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalUint64) + r.u, vm.err = mathSub_iu0(l.i, r.u) + vm.stack[vm.sp-2] = r + vm.sp-- + return 1 + }, "SUB INT64(SP-2), UINT64(SP-1)") +} + +func (c *compiler) emitSub_ff() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + l.f -= r.f + vm.sp-- + return 1 + }, "SUB FLOAT64(SP-2), FLOAT64(SP-1)") +} + +func (c *compiler) emitSub_dd() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + mathSub_dd0(l, r) + vm.sp-- + return 1 + }, "SUB DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (c *compiler) emitSub_ui() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.u, vm.err = mathSub_ui0(l.u, r.i) + vm.sp-- + return 1 + }, "SUB UINT64(SP-2), INT64(SP-1)") +} + +func (c *compiler) emitSub_uu() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u, vm.err = mathSub_uu0(l.u, r.u) + vm.sp-- + return 1 + }, "SUB UINT64(SP-2), 
UINT64(SP-1)") +} + +func (c *compiler) emitMul_ii() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.i, vm.err = mathMul_ii0(l.i, r.i) + vm.sp-- + return 1 + }, "MUL INT64(SP-2), INT64(SP-1)") +} + +func (c *compiler) emitMul_ui(swap bool) { + c.adjustStack(-1) + + if swap { + c.emit(func(vm *VirtualMachine) int { + var u uint64 + l := vm.stack[vm.sp-1].(*evalUint64) + r := vm.stack[vm.sp-2].(*evalInt64) + u, vm.err = mathMul_ui0(l.u, r.i) + vm.stack[vm.sp-2] = vm.arena.newEvalUint64(u) + vm.sp-- + return 1 + }, "MUL UINT64(SP-1), INT64(SP-2)") + } else { + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.u, vm.err = mathMul_ui0(l.u, r.i) + vm.sp-- + return 1 + }, "MUL UINT64(SP-2), INT64(SP-1)") + } +} + +func (c *compiler) emitMul_uu() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u, vm.err = mathMul_uu0(l.u, r.u) + vm.sp-- + return 1 + }, "MUL UINT64(SP-2), UINT64(SP-1)") +} + +func (c *compiler) emitMul_ff() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + l.f *= r.f + vm.sp-- + return 1 + }, "MUL FLOAT64(SP-2), FLOAT64(SP-1)") +} + +func (c *compiler) emitMul_dd() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + mathMul_dd0(l, r) + vm.sp-- + return 1 + }, "MUL DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (c *compiler) emitDiv_ff() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + if r.f == 0.0 { + vm.stack[vm.sp-2] = nil + } else { + l.f, vm.err = mathDiv_ff0(l.f, r.f) + } + vm.sp-- + return 1 + }, "DIV FLOAT64(SP-2), 
FLOAT64(SP-1)") +} + +func (c *compiler) emitDiv_dd() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + mathDiv_dd0(l, r, divPrecisionIncrement) + } + vm.sp-- + return 1 + }, "DIV DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (c *compiler) emitNullCmp(j *jump) { + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2] + r := vm.stack[vm.sp-1] + if l == nil || r == nil { + if l == r { + vm.flags.cmp = 0 + } else { + vm.flags.cmp = 1 + } + vm.sp -= 2 + return j.offset() + } + return 1 + }, "NULLCMP SP-1, SP-2") +} + +func (c *compiler) emitNullCheck1(j *jump) { + c.emit(func(vm *VirtualMachine) int { + if vm.stack[vm.sp-1] == nil { + return j.offset() + } + return 1 + }, "NULLCHECK SP-1") +} + +func (c *compiler) emitNullCheck2(j *jump) { + c.emit(func(vm *VirtualMachine) int { + if vm.stack[vm.sp-2] == nil || vm.stack[vm.sp-1] == nil { + vm.stack[vm.sp-2] = nil + vm.sp-- + return j.offset() + } + return 1 + }, "NULLCHECK SP-1, SP-2") +} + +func (c *compiler) emitCmp_eq() { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) + vm.sp++ + return 1 + }, "CMPFLAG EQ") +} + +func (c *compiler) emitCmp_ne() { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) + vm.sp++ + return 1 + }, "CMPFLAG NE") +} + +func (c *compiler) emitCmp_lt() { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) + vm.sp++ + return 1 + }, "CMPFLAG LT") +} + +func (c *compiler) emitCmp_le() { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) + vm.sp++ + return 1 + }, "CMPFLAG LE") +} + +func (c *compiler) emitCmp_gt() { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = 
newEvalBool(vm.flags.cmp > 0) + vm.sp++ + return 1 + }, "CMPFLAG GT") +} + +func (c *compiler) emitCmp_ge() { + c.adjustStack(1) + + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) + vm.sp++ + return 1 + }, "CMPFLAG GE") +} + +func (c *compiler) emitCmpNum_ii() { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + vm.sp -= 2 + vm.flags.cmp = cmpnum(l.i, r.i) + return 1 + }, "CMP INT64(SP-2), INT64(SP-1)") +} + +func (c *compiler) emitCmpNum_iu(left, right int) { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalInt64) + r := vm.stack[vm.sp-right].(*evalUint64) + vm.sp -= 2 + if l.i < 0 { + vm.flags.cmp = -1 + } else { + vm.flags.cmp = cmpnum(uint64(l.i), r.u) + } + return 1 + }, "CMP INT64(SP-%d), UINT64(SP-%d)", left, right) +} + +func (c *compiler) emitCmpNum_if(left, right int) { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalInt64) + r := vm.stack[vm.sp-right].(*evalFloat) + vm.sp -= 2 + vm.flags.cmp = cmpnum(float64(l.i), r.f) + return 1 + }, "CMP INT64(SP-%d), FLOAT64(SP-%d)", left, right) +} + +func (c *compiler) emitCmpNum_id(left, right int) { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalInt64) + r := vm.stack[vm.sp-right].(*evalDecimal) + vm.sp -= 2 + vm.flags.cmp = decimal.NewFromInt(l.i).Cmp(r.dec) + return 1 + }, "CMP INT64(SP-%d), DECIMAL(SP-%d)", left, right) +} + +func (c *compiler) emitCmpNum_uu() { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + vm.sp -= 2 + vm.flags.cmp = cmpnum(l.u, r.u) + return 1 + }, "CMP UINT64(SP-2), UINT64(SP-1)") +} + +func (c *compiler) emitCmpNum_uf(left, right int) { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := 
vm.stack[vm.sp-left].(*evalUint64) + r := vm.stack[vm.sp-right].(*evalFloat) + vm.sp -= 2 + vm.flags.cmp = cmpnum(float64(l.u), r.f) + return 1 + }, "CMP UINT64(SP-%d), FLOAT64(SP-%d)", left, right) +} + +func (c *compiler) emitCmpNum_ud(left, right int) { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalUint64) + r := vm.stack[vm.sp-right].(*evalDecimal) + vm.sp -= 2 + vm.flags.cmp = decimal.NewFromUint(l.u).Cmp(r.dec) + return 1 + }, "CMP UINT64(SP-%d), DECIMAL(SP-%d)", left, right) +} + +func (c *compiler) emitCmpNum_ff() { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + vm.sp -= 2 + vm.flags.cmp = cmpnum(l.f, r.f) + return 1 + }, "CMP FLOAT64(SP-2), FLOAT64(SP-1)") +} + +func (c *compiler) emitCmpNum_fd(left, right int) { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalFloat) + r := vm.stack[vm.sp-right].(*evalDecimal) + vm.sp -= 2 + fval, ok := r.dec.Float64() + if !ok { + vm.err = errDecimalOutOfRange + } + vm.flags.cmp = cmpnum(l.f, fval) + return 1 + }, "CMP FLOAT64(SP-%d), DECIMAL(SP-%d)", left, right) +} + +func (c *compiler) emitCmpNum_dd() { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + vm.sp -= 2 + vm.flags.cmp = l.dec.Cmp(r.dec) + return 1 + }, "CMP DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (c *compiler) emitCmpString_collate(collation collations.Collation) { + c.adjustStack(-2) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + vm.sp -= 2 + vm.flags.cmp = collation.Collate(l.bytes, r.bytes, false) + return 1 + }, "CMP VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) +} + +func (c *compiler) emitCmpString_coerce(coercion *compiledCoercion) { + c.adjustStack(-2) + + c.emit(func(vm 
*VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + vm.sp -= 2 + + var bl, br []byte + bl, vm.err = coercion.left(nil, l.bytes) + if vm.err != nil { + return 0 + } + br, vm.err = coercion.right(nil, r.bytes) + if vm.err != nil { + return 0 + } + vm.flags.cmp = coercion.col.Collate(bl, br, false) + return 1 + }, "CMP VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) +} + +func (c *compiler) emitCollate(col collations.ID) { + c.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalBytes) + a.tt = int16(sqltypes.VarChar) + a.col.Collation = col + return 1 + }, "COLLATE VARCHAR(SP-1), %d", col) +} + +func (c *compiler) emitFn_JSON_UNQUOTE() { + c.emit(func(vm *VirtualMachine) int { + j := vm.stack[vm.sp-1].(*evalJSON) + b := vm.arena.newEvalBytesEmpty() + b.tt = int16(sqltypes.Blob) + b.col = collationJSON + if jbytes, ok := j.StringBytes(); ok { + b.bytes = jbytes + } else { + b.bytes = j.MarshalTo(nil) + } + vm.stack[vm.sp-1] = b + return 1 + }, "FN JSON_UNQUOTE (SP-1)") +} + +func (c *compiler) emitFn_JSON_EXTRACT0(jp []*json.Path) { + multi := len(jp) > 1 || slices2.Any(jp, func(path *json.Path) bool { return path.ContainsWildcards() }) + + if multi { + c.emit(func(vm *VirtualMachine) int { + matches := make([]*json.Value, 0, 4) + arg := vm.stack[vm.sp-1].(*evalJSON) + for _, jp := range jp { + jp.Match(arg, true, func(value *json.Value) { + matches = append(matches, value) + }) + } + if len(matches) == 0 { + vm.stack[vm.sp-1] = nil + } else { + vm.stack[vm.sp-1] = json.NewArray(matches) + } + return 1 + }, "FN JSON_EXTRACT, SP-1, [static]") + } else { + c.emit(func(vm *VirtualMachine) int { + var match *json.Value + arg := vm.stack[vm.sp-1].(*evalJSON) + jp[0].Match(arg, true, func(value *json.Value) { + match = value + }) + if match == nil { + vm.stack[vm.sp-1] = nil + } else { + vm.stack[vm.sp-1] = match + } + return 1 + }, "FN JSON_EXTRACT, SP-1, [static]") + } +} + +func (c 
*compiler) emitFn_JSON_CONTAINS_PATH(match jsonMatch, paths []*json.Path) { + switch match { + case jsonMatchOne: + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalJSON) + matched := false + for _, p := range paths { + p.Match(arg, true, func(*json.Value) { matched = true }) + if matched { + break + } + } + vm.stack[vm.sp-1] = newEvalBool(matched) + return 1 + }, "FN JSON_CONTAINS_PATH, SP-1, 'one', [static]") + case jsonMatchAll: + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalJSON) + matched := true + for _, p := range paths { + matched = false + p.Match(arg, true, func(*json.Value) { matched = true }) + if !matched { + break + } + } + vm.stack[vm.sp-1] = newEvalBool(matched) + return 1 + }, "FN JSON_CONTAINS_PATH, SP-1, 'all', [static]") + } +} + +func (c *compiler) emitBitOp_bb(op bitwiseOp) { + c.adjustStack(-1) + + switch op { + case and: + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] & r.bytes[i] + } + vm.sp-- + return 1 + }, "AND BINARY(SP-2), BINARY(SP-1)") + case or: + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] | r.bytes[i] + } + vm.sp-- + return 1 + }, "OR BINARY(SP-2), BINARY(SP-1)") + case xor: + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] ^ r.bytes[i] + } + vm.sp-- + return 1 + }, "XOR BINARY(SP-2), BINARY(SP-1)") + } +} + +func (c *compiler) emitBitOp_uu(op bitwiseOp) { + 
c.adjustStack(-1) + + switch op { + case and: + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u & r.u + vm.sp-- + return 1 + }, "AND UINT64(SP-2), UINT64(SP-1)") + case or: + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u | r.u + vm.sp-- + return 1 + }, "OR UINT64(SP-2), UINT64(SP-1)") + case xor: + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u ^ r.u + vm.sp-- + return 1 + }, "XOR UINT64(SP-2), UINT64(SP-1)") + } +} + +func (c *compiler) emitConvert_iu(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalInt64) + vm.stack[vm.sp-offset] = vm.arena.newEvalUint64(uint64(arg.i)) + return 1 + }, "CONV INT64(SP-%d), UINT64", offset) +} + +func (c *compiler) emitConvert_xi(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := evalToNumeric(vm.stack[vm.sp-offset]) + vm.stack[vm.sp-offset] = arg.toInt64() + return 1 + }, "CONV (SP-%d), INT64", offset) +} + +func (c *compiler) emitConvert_xu(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := evalToNumeric(vm.stack[vm.sp-offset]) + vm.stack[vm.sp-offset] = arg.toUint64() + return 1 + }, "CONV (SP-%d), UINT64", offset) +} + +func (c *compiler) emitConvert_xb(offset int, t sqltypes.Type, length int, hasLength bool) { + if hasLength { + c.emit(func(vm *VirtualMachine) int { + arg := evalToBinary(vm.stack[vm.sp-offset]) + arg.truncateInPlace(length) + arg.tt = int16(t) + vm.stack[vm.sp-offset] = arg + return 1 + }, "CONV (SP-%d), VARBINARY[%d]", offset, length) + } else { + c.emit(func(vm *VirtualMachine) int { + arg := evalToBinary(vm.stack[vm.sp-offset]) + arg.tt = int16(t) + vm.stack[vm.sp-offset] = arg + return 1 + }, "CONV (SP-%d), VARBINARY", offset) + } +} + +func (c *compiler) emitConvert_fj(offset int) { + c.emit(func(vm 
*VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalFloat) + vm.stack[vm.sp-offset] = evalConvert_fj(arg) + return 1 + }, "CONV FLOAT64(SP-%d), JSON") +} + +func (c *compiler) emitConvert_nj(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(evalNumeric) + vm.stack[vm.sp-offset] = evalConvert_nj(arg) + return 1 + }, "CONV numeric(SP-%d), JSON") +} + +func (c *compiler) emitConvert_cj(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalBytes) + vm.stack[vm.sp-offset], vm.err = evalConvert_cj(arg) + return 1 + }, "CONV VARCHAR(SP-%d), JSON") +} + +func (c *compiler) emitConvert_bj(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalBytes) + vm.stack[vm.sp-offset] = evalConvert_bj(arg) + return 1 + }, "CONV VARBINARY(SP-%d), JSON") +} + +func (c *compiler) emitConvert_xc(offset int, t sqltypes.Type, collation collations.ID, length int, hasLength bool) { + if hasLength { + c.emit(func(vm *VirtualMachine) int { + arg, err := evalToVarchar(vm.stack[vm.sp-offset], collation, true) + if err != nil { + vm.stack[vm.sp-offset] = nil + } else { + arg.truncateInPlace(length) + arg.tt = int16(t) + vm.stack[vm.sp-offset] = arg + } + return 1 + }, "CONV (SP-%d), VARCHAR[%d]", offset, length) + } else { + c.emit(func(vm *VirtualMachine) int { + arg, err := evalToVarchar(vm.stack[vm.sp-offset], collation, true) + if err != nil { + vm.stack[vm.sp-offset] = nil + } else { + arg.tt = int16(t) + vm.stack[vm.sp-offset] = arg + } + return 1 + }, "CONV (SP-%d), VARCHAR", offset) + } +} + +func (c *compiler) emitBitShiftLeft_bu() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalUint64) + + var ( + bits = int(r.u & 7) + bytes = int(r.u >> 3) + length = len(l.bytes) + out = make([]byte, length) + ) + + for i := 0; i < length; i++ { + pos := i + bytes + 1 + switch { + case pos < length: + 
out[i] = l.bytes[pos] >> (8 - bits) + fallthrough + case pos == length: + out[i] |= l.bytes[pos-1] << bits + } + } + l.bytes = out + + vm.sp-- + return 1 + }, "<< BINARY(SP-2), UINT64(SP-1)") +} + +func (c *compiler) emitBitShiftRight_bu() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalUint64) + + var ( + bits = int(r.u & 7) + bytes = int(r.u >> 3) + length = len(l.bytes) + out = make([]byte, length) + ) + + for i := length - 1; i >= 0; i-- { + switch { + case i > bytes: + out[i] = l.bytes[i-bytes-1] << (8 - bits) + fallthrough + case i == bytes: + out[i] |= l.bytes[i-bytes] >> bits + } + } + l.bytes = out + + vm.sp-- + return 1 + }, ">> BINARY(SP-2), UINT64(SP-1)") +} + +func (c *compiler) emitBitShiftLeft_uu() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u << r.u + + vm.sp-- + return 1 + }, "<< UINT64(SP-2), UINT64(SP-1)") +} + +func (c *compiler) emitBitShiftRight_uu() { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u >> r.u + + vm.sp-- + return 1 + }, ">> UINT64(SP-2), UINT64(SP-1)") +} + +func (c *compiler) emitBitCount_b() { + c.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalBytes) + count := 0 + for _, b := range a.bytes { + count += bits.OnesCount8(b) + } + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) + return 1 + }, "BIT_COUNT BINARY(SP-1)") +} + +func (c *compiler) emitBitCount_u() { + c.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalUint64) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(bits.OnesCount64(a.u))) + return 1 + }, "BIT_COUNT UINT64(SP-1)") +} + +func (c *compiler) emitBitwiseNot_b() { + c.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalBytes) + for i := range a.bytes { + 
a.bytes[i] = ^a.bytes[i] + } + return 1 + }, "~ BINARY(SP-1)") +} + +func (c *compiler) emitBitwiseNot_u() { + c.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalUint64) + a.u = ^a.u + return 1 + }, "~ UINT64(SP-1)") +} + +func (c *compiler) emitCmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation) { + elseOffset := 0 + if hasElse { + elseOffset = 1 + } + + stackDepth := 2*cases + elseOffset + c.adjustStack(-(stackDepth - 1)) + + c.emit(func(vm *VirtualMachine) int { + end := vm.sp - elseOffset + for sp := vm.sp - stackDepth; sp < end; sp += 2 { + if vm.stack[sp].(*evalInt64).i != 0 { + vm.stack[vm.sp-stackDepth], vm.err = evalCoerce(vm.stack[sp+1], tt, cc.Collation) + goto done + } + } + if elseOffset != 0 { + vm.stack[vm.sp-stackDepth], vm.err = evalCoerce(vm.stack[vm.sp-1], tt, cc.Collation) + } else { + vm.stack[vm.sp-stackDepth] = nil + } + done: + vm.sp -= stackDepth - 1 + return 1 + }, "CASE [%d cases, else = %v]", cases, hasElse) +} + +func (c *compiler) emitConvert_iB(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalInt64).i != 0) + return 1 + }, "CONV INT64(SP-%d), BOOL", offset) +} + +func (c *compiler) emitConvert_uB(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalUint64).u != 0) + return 1 + }, "CONV UINT64(SP-%d), BOOL", offset) +} + +func (c *compiler) emitConvert_fB(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) + return 1 + }, "CONV FLOAT64(SP-%d), BOOL", offset) +} + +func (c *compiler) emitConvert_dB(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) + return 1 + }, 
"CONV DECIMAL(SP-%d), BOOL", offset) +} + +func (c *compiler) emitConvert_bB(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && parseStringToFloat(arg.(*evalBytes).string()) != 0.0) + return 1 + }, "CONV VARBINARY(SP-%d), BOOL", offset) +} + +func (c *compiler) emitNeg_i() { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalInt64) + if arg.i == math.MinInt64 { + vm.err = errDeoptimize + } else { + arg.i = -arg.i + } + return 1 + }, "NEG INT64(SP-1)") +} + +func (c *compiler) emitNeg_hex() { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalUint64) + vm.stack[vm.sp-1] = vm.arena.newEvalFloat(-float64(arg.u)) + return 1 + }, "NEG HEX(SP-1)") +} + +func (c *compiler) emitNeg_u() { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalUint64) + if arg.u > math.MaxInt64+1 { + vm.err = errDeoptimize + } else { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(-int64(arg.u)) + } + return 1 + }, "NEG UINT64(SP-1)") +} + +func (c *compiler) emitNeg_f() { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalFloat) + arg.f = -arg.f + return 1 + }, "NEG FLOAT64(SP-1)") +} + +func (c *compiler) emitNeg_d() { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalDecimal) + arg.dec = arg.dec.NegInPlace() + return 1 + }, "NEG DECIMAL(SP-1)") +} diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go new file mode 100644 index 00000000000..6531cf0248b --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -0,0 +1,242 @@ +package evalengine_test + +import ( + "fmt" + "strconv" + "strings" + "testing" + + "github.com/olekukonko/tablewriter" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + 
"vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/evalengine/testcases" +) + +type Tracker struct { + buf strings.Builder + tbl *tablewriter.Table + supported, total int +} + +func NewTracker() *Tracker { + track := &Tracker{} + track.tbl = tablewriter.NewWriter(&track.buf) + return track +} + +func (s *Tracker) Add(name string, supported, total int) { + s.tbl.Append([]string{ + name[strings.LastIndexByte(name, '.')+1:], + strconv.Itoa(supported), + strconv.Itoa(total), + fmt.Sprintf("%.02f%%", 100*float64(supported)/float64(total)), + }) + s.supported += supported + s.total += total +} + +func (s *Tracker) String() string { + s.tbl.SetBorder(false) + s.tbl.SetColumnAlignment([]int{ + tablewriter.ALIGN_LEFT, + tablewriter.ALIGN_RIGHT, + tablewriter.ALIGN_RIGHT, + tablewriter.ALIGN_RIGHT, + }) + s.tbl.SetFooterAlignment(tablewriter.ALIGN_RIGHT) + s.tbl.SetFooter([]string{ + "", + strconv.Itoa(s.supported), + strconv.Itoa(s.total), + fmt.Sprintf("%.02f%%", 100*float64(s.supported)/float64(s.total)), + }) + s.tbl.Render() + return s.buf.String() +} + +func TestCompilerReference(t *testing.T) { + track := NewTracker() + + for _, tc := range testcases.Cases { + name := fmt.Sprintf("%T", tc) + t.Run(name, func(t *testing.T) { + var supported, total int + env := tc.Environment() + + tc.Test(func(query string, row []sqltypes.Value) { + stmt, err := sqlparser.Parse("SELECT " + query) + if err != nil { + // no need to test un-parseable queries + return + } + + queryAST := stmt.(*sqlparser.Select).SelectExprs[0].(*sqlparser.AliasedExpr).Expr + converted, err := evalengine.TranslateEx(queryAST, &evalengine.LookupIntegrationTest{collations.CollationUtf8mb4ID}, false) + if err != nil { + return + } + + env.Row = row + expected, err := env.Evaluate(converted) + if err != nil { + // do not attempt failing queries for now + return + } + + total++ + + program, err := evalengine.Compile(converted, env.Fields) + if err != nil { + if vterrors.Code(err) != 
vtrpcpb.Code_UNIMPLEMENTED { + t.Errorf("failed compilation:\nSQL: %s\nError: %s", query, err) + } + return + } + + var vm evalengine.VirtualMachine + res, err := func() (res evalengine.EvalResult, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("PANIC: %v", r) + } + }() + res, err = vm.Run(program, env.Row) + return + }() + + if err != nil { + t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, err) + return + } + + eval := expected.String() + comp := res.String() + + if eval != comp { + t.Errorf("bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp) + return + } + + supported++ + }) + + track.Add(name, supported, total) + }) + } + + t.Logf("\n%s", track.String()) +} + +type debugCompiler struct { + t testing.TB +} + +func (d *debugCompiler) Disasm(ins string, args ...any) { + ins = fmt.Sprintf(ins, args...) + d.t.Logf("> %s", ins) +} + +func (d *debugCompiler) Stack(old, new int) { + d.t.Logf("\tsp = %d -> %d", old, new) +} + +func TestCompiler(t *testing.T) { + var testCases = []struct { + expression string + values []sqltypes.Value + result string + }{ + { + expression: "1 + column0", + values: []sqltypes.Value{sqltypes.NewInt64(1)}, + result: "INT64(2)", + }, + { + expression: "1 + column0", + values: []sqltypes.Value{sqltypes.NewFloat64(1)}, + result: "FLOAT64(2)", + }, + { + expression: "1.0e0 - column0", + values: []sqltypes.Value{sqltypes.NewFloat64(1)}, + result: "FLOAT64(0)", + }, + { + expression: "128 - column0", + values: []sqltypes.Value{sqltypes.NewFloat64(1)}, + result: "FLOAT64(127)", + }, + { + expression: "(128 - column0) * 3", + values: []sqltypes.Value{sqltypes.NewFloat64(1)}, + result: "FLOAT64(381)", + }, + { + expression: "1.0e0 < column0", + values: []sqltypes.Value{sqltypes.NewFloat64(2)}, + result: "INT64(1)", + }, + { + expression: "1.0e0 < column0", + values: []sqltypes.Value{sqltypes.NewFloat64(-1)}, + result: "INT64(0)", + }, + { + expression: `'foo' = 
'FOO' collate utf8mb4_0900_as_cs`, + result: "INT64(0)", + }, + { + expression: `'foo' < 'bar'`, + result: "INT64(0)", + }, + { + expression: `case when false then 0 else 18446744073709551615 end`, + result: `DECIMAL(18446744073709551615)`, + }, + { + expression: `case when true then _binary "foobar" else 'foo' collate utf8mb4_0900_as_cs end`, + result: `VARCHAR("foobar")`, + }, + { + expression: `- 18446744073709551615`, + result: `DECIMAL(-18446744073709551615)`, + }, + { + expression: `CAST(CAST(true AS JSON) AS BINARY)`, + result: `BLOB("true")`, + }, + } + + for _, tc := range testCases { + t.Run(tc.expression, func(t *testing.T) { + expr := parseExpr(t, tc.expression) + converted, err := evalengine.TranslateEx(expr, &evalengine.LookupIntegrationTest{collations.CollationUtf8mb4ID}, false) + if err != nil { + t.Fatal(err) + } + + compiled, err := evalengine.Compile(converted, makeFields(tc.values), evalengine.WithCompilerLog(&debugCompiler{t})) + if err != nil { + t.Fatal(err) + } + + var vm evalengine.VirtualMachine + // re-run the same evaluation multiple times to ensure results are always consistent + for i := 0; i < 8; i++ { + res, err := vm.Run(compiled, tc.values) + if err != nil { + t.Fatal(err) + } + + if res.String() != tc.result { + t.Fatalf("bad evaluation: got %s, want %s (iteration %d)", res, tc.result, i) + } + } + }) + } +} diff --git a/go/vt/vtgate/evalengine/perf_test.go b/go/vt/vtgate/evalengine/perf_test.go new file mode 100644 index 00000000000..52914b0e7d9 --- /dev/null +++ b/go/vt/vtgate/evalengine/perf_test.go @@ -0,0 +1,60 @@ +package evalengine_test + +import ( + "testing" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +func BenchmarkCompilerExpressions(b *testing.B) { + var testCases = []struct { + name string + expression string + values []sqltypes.Value + }{ + {"complex_arith", "((23 + column0) * 4.0e0) = ((column1 / 3.33e0) * 100)", 
[]sqltypes.Value{sqltypes.NewInt64(666), sqltypes.NewUint64(420)}}, + {"comparison_i64", "column0 = 12", []sqltypes.Value{sqltypes.NewInt64(666)}}, + {"comparison_u64", "column0 = 12", []sqltypes.Value{sqltypes.NewUint64(666)}}, + {"comparison_dec", "column0 = 12", []sqltypes.Value{sqltypes.NewDecimal("420")}}, + {"comparison_f", "column0 = 12", []sqltypes.Value{sqltypes.NewFloat64(420.0)}}, + } + + for _, tc := range testCases { + expr := parseExpr(b, tc.expression) + converted, err := evalengine.TranslateEx(expr, &evalengine.LookupIntegrationTest{collations.CollationUtf8mb4ID}, true) + if err != nil { + b.Fatal(err) + } + + compiled, err := evalengine.Compile(converted, makeFields(tc.values)) + if err != nil { + b.Fatal(err) + } + + b.Run(tc.name+"/eval=ast", func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + var env evalengine.ExpressionEnv + env.Row = tc.values + for n := 0; n < b.N; n++ { + _, _ = env.Evaluate(converted) + } + }) + + b.Run(tc.name+"/eval=vm", func(b *testing.B) { + compiled := compiled + values := tc.values + + b.ResetTimer() + b.ReportAllocs() + + var vm evalengine.VirtualMachine + for n := 0; n < b.N; n++ { + _, _ = vm.Run(compiled, values) + } + }) + } +} diff --git a/go/vt/vtgate/evalengine/vm.go b/go/vt/vtgate/evalengine/vm.go new file mode 100644 index 00000000000..f44597edb9a --- /dev/null +++ b/go/vt/vtgate/evalengine/vm.go @@ -0,0 +1,63 @@ +package evalengine + +import ( + "errors" + + "vitess.io/vitess/go/sqltypes" +) + +var errDeoptimize = errors.New("de-optimize") + +type VirtualMachine struct { + arena Arena + row []sqltypes.Value + stack []eval + sp int + + err error + flags struct { + cmp int + null bool + } +} + +type Program struct { + code []frame + original Expr + stack int +} + +func (vm *VirtualMachine) Run(p *Program, row []sqltypes.Value) (EvalResult, error) { + vm.arena.reset() + vm.row = row + vm.sp = 0 + if len(vm.stack) < p.stack { + vm.stack = make([]eval, p.stack) + } + e, err := vm.execute(p) + if err != 
nil { + if err == errDeoptimize { + var env ExpressionEnv + env.Row = row + return env.Evaluate(p.original) + } + return EvalResult{}, err + } + return EvalResult{e}, nil +} + +func (vm *VirtualMachine) execute(p *Program) (eval, error) { + code := p.code + ip := 0 + + for ip < len(code) { + ip += code[ip](vm) + if vm.err != nil { + return nil, vm.err + } + } + if vm.sp == 0 { + return nil, nil + } + return vm.stack[vm.sp-1], nil +} From 1bc18460a004796be13207d1f7544bdde879a67c Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 6 Mar 2023 13:25:45 +0100 Subject: [PATCH 03/42] evalengine: fix test helpers Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler_test.go | 22 +++++++++++++++++++++- go/vt/vtgate/evalengine/perf_test.go | 6 +++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 6531cf0248b..12032b95ddf 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -10,6 +10,7 @@ import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -17,6 +18,21 @@ import ( "vitess.io/vitess/go/vt/vtgate/evalengine/testcases" ) +func makeFields(values []sqltypes.Value) (fields []*querypb.Field) { + for _, v := range values { + field := &querypb.Field{ + Type: v.Type(), + } + if sqltypes.IsText(field.Type) { + field.Charset = uint32(collations.CollationUtf8mb4ID) + } else { + field.Charset = uint32(collations.CollationBinaryID) + } + fields = append(fields, field) + } + return +} + type Tracker struct { buf strings.Builder tbl *tablewriter.Table @@ -214,7 +230,11 @@ func TestCompiler(t *testing.T) { for _, tc := range testCases { t.Run(tc.expression, func(t *testing.T) { - expr := parseExpr(t, tc.expression) + expr, err := 
sqlparser.ParseExpr(tc.expression) + if err != nil { + t.Fatal(err) + } + converted, err := evalengine.TranslateEx(expr, &evalengine.LookupIntegrationTest{collations.CollationUtf8mb4ID}, false) if err != nil { t.Fatal(err) diff --git a/go/vt/vtgate/evalengine/perf_test.go b/go/vt/vtgate/evalengine/perf_test.go index 52914b0e7d9..74e4e46bdac 100644 --- a/go/vt/vtgate/evalengine/perf_test.go +++ b/go/vt/vtgate/evalengine/perf_test.go @@ -5,6 +5,7 @@ import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/evalengine" ) @@ -22,7 +23,10 @@ func BenchmarkCompilerExpressions(b *testing.B) { } for _, tc := range testCases { - expr := parseExpr(b, tc.expression) + expr, err := sqlparser.ParseExpr(tc.expression) + if err != nil { + b.Fatal(err) + } converted, err := evalengine.TranslateEx(expr, &evalengine.LookupIntegrationTest{collations.CollationUtf8mb4ID}, true) if err != nil { b.Fatal(err) From 0a7fe36e85dd816c445d3fc4ad723353b868c7eb Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 6 Mar 2023 15:21:49 +0100 Subject: [PATCH 04/42] evalengine/compiler: Add support for REPEAT Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 46 +++++++++++++++++-- .../evalengine/compiler_instructions.go | 33 +++++++++++++ 2 files changed, 75 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 1bce3fbb009..61846f32e6c 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -13,9 +13,10 @@ import ( type frame func(vm *VirtualMachine) int type compiler struct { - ins []frame - log CompilerLog - fields []*querypb.Field + ins []frame + log CompilerLog + fields []*querypb.Field + defaultCollation collations.ID stack struct { cur int @@ -42,9 +43,16 @@ func WithCompilerLog(log CompilerLog) CompilerOption { } } +func WithDefaultCollation(collation collations.ID) 
CompilerOption { + return func(c *compiler) { + c.defaultCollation = collation + } +} + func Compile(expr Expr, fields []*querypb.Field, options ...CompilerOption) (*Program, error) { comp := compiler{ - fields: fields, + fields: fields, + defaultCollation: collations.Default(), } for _, opt := range options { opt(&comp) @@ -224,6 +232,8 @@ func (c *compiler) compileCallable(call callable) (ctype, error) { return c.compileJSONUnquote(call) case *builtinJSONContainsPath: return c.compileJSONContainsPath(call) + case *builtinRepeat: + return c.compileRepeat(call) default: return ctype{}, c.unsupported(call) } @@ -424,6 +434,8 @@ func (c *compiler) compileToInt64(ct ctype, offset int) ctype { switch ct.Type { case sqltypes.Int64: return ct + case sqltypes.Uint64: + c.emitConvert_ui(offset) // TODO: specialization default: c.emitConvert_xi(offset) @@ -1239,3 +1251,29 @@ func (c *compiler) compileMultiComparison(call *builtinMultiComparison) (ctype, c.emitNullCheckn(skip, len(call.Arguments)) */ } + +func (c *compiler) compileRepeat(expr *builtinRepeat) (ctype, error) { + str, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + repeat, err := c.compileExpr(expr.Arguments[1]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck2(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.emitConvert_xc(2, str.Type, c.defaultCollation, 0, false) + } + _ = c.compileToInt64(repeat, 1) + + c.emitRepeat(1) + c.jumpDestination(skip) + return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 7e9f81f0cab..fd1396f977f 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1,6 +1,7 @@ package evalengine import ( + "bytes" "math" "math/bits" @@ -967,6 +968,14 @@ func (c *compiler) 
emitBitOp_uu(op bitwiseOp) { } } +func (c *compiler) emitConvert_ui(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalUint64) + vm.stack[vm.sp-offset] = vm.arena.newEvalInt64(int64(arg.u)) + return 1 + }, "CONV UINT64(SP-%d), INT64", offset) +} + func (c *compiler) emitConvert_iu(offset int) { c.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset].(*evalInt64) @@ -1309,3 +1318,27 @@ func (c *compiler) emitNeg_d() { return 1 }, "NEG DECIMAL(SP-1)") } + +func (c *compiler) emitRepeat(i int) { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-2].(*evalBytes) + repeat := vm.stack[vm.sp-1].(*evalInt64) + + if repeat.i < 0 { + repeat.i = 0 + } + + if !checkMaxLength(int64(len(str.bytes)), repeat.i) { + vm.stack[vm.sp-2] = nil + vm.sp-- + return 1 + } + + str.tt = int16(sqltypes.VarChar) + str.bytes = bytes.Repeat(str.bytes, int(repeat.i)) + vm.sp-- + return 1 + }, "REPEAT VARCHAR(SP-2) INT64(SP-1)") +} From e689facfc15c9af69e43f461f8c73d7bc5893ab0 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 6 Mar 2023 16:04:15 +0100 Subject: [PATCH 05/42] evalengine/compiler: Add support for base64 functions Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 57 +++++++++++++++++++ .../evalengine/compiler_instructions.go | 31 ++++++++++ 2 files changed, 88 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 61846f32e6c..22c3560f8a6 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -234,6 +234,10 @@ func (c *compiler) compileCallable(call callable) (ctype, error) { return c.compileJSONContainsPath(call) case *builtinRepeat: return c.compileRepeat(call) + case *builtinToBase64: + return c.compileToBase64(call) + case *builtinFromBase64: + return c.compileFromBase64(call) default: return ctype{}, c.unsupported(call) } @@ -1277,3 +1281,56 @@ func (c *compiler) 
compileRepeat(expr *builtinRepeat) (ctype, error) { c.jumpDestination(skip) return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil } + +func (c *compiler) compileToBase64(call *builtinToBase64) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.emitConvert_xc(1, str.Type, c.defaultCollation, 0, false) + } + + t := sqltypes.VarChar + if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { + t = sqltypes.Text + } + + col := collations.TypedCollation{ + Collation: c.defaultCollation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + + c.emitToBase64(t, col) + c.jumpDestination(skip) + + return ctype{Type: t, Col: col}, nil +} + +func (c *compiler) compileFromBase64(call *builtinFromBase64) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.emitConvert_xc(1, str.Type, c.defaultCollation, 0, false) + } + + c.emitFromBase64() + c.jumpDestination(skip) + + return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index fd1396f977f..00764c677a2 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1342,3 +1342,34 @@ func (c *compiler) emitRepeat(i int) { return 1 }, "REPEAT VARCHAR(SP-2) INT64(SP-1)") } + +func (c *compiler) emitToBase64(t sqltypes.Type, col collations.TypedCollation) { + c.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + + encoded := make([]byte, mysqlBase64.EncodedLen(len(str.bytes))) + 
mysqlBase64.Encode(encoded, str.bytes) + + str.tt = int16(t) + str.col = col + str.bytes = encoded + return 1 + }, "TO_BASE64 VARCHAR(SP-1)") +} + +func (c *compiler) emitFromBase64() { + c.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + + decoded := make([]byte, mysqlBase64.DecodedLen(len(str.bytes))) + + n, err := mysqlBase64.Decode(decoded, str.bytes) + if err != nil { + vm.stack[vm.sp-1] = nil + return 1 + } + str.tt = int16(sqltypes.VarBinary) + str.bytes = decoded[:n] + return 1 + }, "FROM_BASE64 VARCHAR(SP-1)") +} From 12697d136bbe0856b586f1e0f946626d142ef27d Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 6 Mar 2023 16:20:44 +0100 Subject: [PATCH 06/42] evalengine/compiler: Implement UPPER / LOWER Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 39 +++++++++++++++---- .../evalengine/compiler_instructions.go | 32 +++++++++++++++ 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 22c3560f8a6..83923106665 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -238,6 +238,8 @@ func (c *compiler) compileCallable(call callable) (ctype, error) { return c.compileToBase64(call) case *builtinFromBase64: return c.compileFromBase64(call) + case *builtinChangeCase: + return c.compileChangeCase(call) default: return ctype{}, c.unsupported(call) } @@ -1273,7 +1275,7 @@ func (c *compiler) compileRepeat(expr *builtinRepeat) (ctype, error) { switch { case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): default: - c.emitConvert_xc(2, str.Type, c.defaultCollation, 0, false) + c.emitConvert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) } _ = c.compileToInt64(repeat, 1) @@ -1291,17 +1293,17 @@ func (c *compiler) compileToBase64(call *builtinToBase64) (ctype, error) { skip := c.jumpFrom() c.emitNullCheck1(skip) - switch { - case sqltypes.IsText(str.Type) || 
sqltypes.IsBinary(str.Type): - default: - c.emitConvert_xc(1, str.Type, c.defaultCollation, 0, false) - } - t := sqltypes.VarChar if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { t = sqltypes.Text } + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.emitConvert_xc(1, t, c.defaultCollation, 0, false) + } + col := collations.TypedCollation{ Collation: c.defaultCollation, Coercibility: collations.CoerceCoercible, @@ -1326,7 +1328,7 @@ func (c *compiler) compileFromBase64(call *builtinFromBase64) (ctype, error) { switch { case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): default: - c.emitConvert_xc(1, str.Type, c.defaultCollation, 0, false) + c.emitConvert_xc(1, sqltypes.VarBinary, c.defaultCollation, 0, false) } c.emitFromBase64() @@ -1334,3 +1336,24 @@ func (c *compiler) compileFromBase64(call *builtinFromBase64) (ctype, error) { return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } + +func (c *compiler) compileChangeCase(call *builtinChangeCase) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + } + + c.emitChangeCase(call.upcase) + c.jumpDestination(skip) + + return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 00764c677a2..b1c80489518 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1373,3 +1373,35 @@ func (c *compiler) emitFromBase64() { return 1 }, "FROM_BASE64 VARCHAR(SP-1)") } + +func (c *compiler) emitChangeCase(upcase bool) { + if upcase { + c.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + + 
coll := collations.Local().LookupByID(str.col.Collation) + csa, ok := coll.(collations.CaseAwareCollation) + if !ok { + vm.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") + } else { + str.bytes = csa.ToUpper(nil, str.bytes) + } + str.tt = int16(sqltypes.VarChar) + return 1 + }, "UPPER VARCHAR(SP-1)") + } else { + c.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + + coll := collations.Local().LookupByID(str.col.Collation) + csa, ok := coll.(collations.CaseAwareCollation) + if !ok { + vm.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") + } else { + str.bytes = csa.ToLower(nil, str.bytes) + } + str.tt = int16(sqltypes.VarChar) + return 1 + }, "LOWER VARCHAR(SP-1)") + } +} From 679e49492003718734f0ab32a9ec0039957705d2 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 6 Mar 2023 16:34:17 +0100 Subject: [PATCH 07/42] evalengine/compiler: Implement string length functions Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 35 +++++++++++++++++++ .../evalengine/compiler_instructions.go | 31 ++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 83923106665..5de2af9eb48 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -240,6 +240,12 @@ func (c *compiler) compileCallable(call callable) (ctype, error) { return c.compileFromBase64(call) case *builtinChangeCase: return c.compileChangeCase(call) + case *builtinCharLength: + return c.compileLength(call, charLen) + case *builtinLength: + return c.compileLength(call, byteLen) + case *builtinBitLength: + return c.compileLength(call, bitLen) default: return ctype{}, c.unsupported(call) } @@ -1357,3 +1363,32 @@ func (c *compiler) compileChangeCase(call *builtinChangeCase) (ctype, error) { return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil } + +type lengthOp int + +const ( + charLen lengthOp = iota + 
byteLen + bitLen +) + +func (c *compiler) compileLength(call callable, op lengthOp) (ctype, error) { + str, err := c.compileExpr(call.callable()[0]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + } + + c.emitLength(op) + c.jumpDestination(skip) + + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index b1c80489518..135f3da4d5e 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -6,6 +6,7 @@ import ( "math/bits" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/slices2" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/proto/vtrpc" @@ -1405,3 +1406,33 @@ func (c *compiler) emitChangeCase(upcase bool) { }, "LOWER VARCHAR(SP-1)") } } + +func (c *compiler) emitLength(op lengthOp) { + switch op { + case charLen: + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + + if sqltypes.IsBinary(arg.SQLType()) { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) + } else { + coll := collations.Local().LookupByID(arg.col.Collation) + count := charset.Length(coll.Charset(), arg.bytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) + } + return 1 + }, "CHAR_LENGTH VARCHAR(SP-1)") + case byteLen: + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) + return 1 + }, "LENGTH VARCHAR(SP-1)") + case bitLen: + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes) * 8)) + return 1 + }, 
"BIT_LENGTH VARCHAR(SP-1)") + } +} From 469f9147dbeb2e11ea17c2e4fc24d7340305baba Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 6 Mar 2023 16:40:25 +0100 Subject: [PATCH 08/42] evalengine/compiler: Implement ASCII() function Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 23 +++++++++++++++++++ .../evalengine/compiler_instructions.go | 12 ++++++++++ 2 files changed, 35 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 5de2af9eb48..cae5088d475 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -246,6 +246,8 @@ func (c *compiler) compileCallable(call callable) (ctype, error) { return c.compileLength(call, byteLen) case *builtinBitLength: return c.compileLength(call, bitLen) + case *builtinASCII: + return c.compileASCII(call) default: return ctype{}, c.unsupported(call) } @@ -1392,3 +1394,24 @@ func (c *compiler) compileLength(call callable, op lengthOp) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil } + +func (c *compiler) compileASCII(call *builtinASCII) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + } + + c.emitASCII() + c.jumpDestination(skip) + + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 135f3da4d5e..417fab93829 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1436,3 +1436,15 @@ func (c *compiler) emitLength(op lengthOp) { }, "BIT_LENGTH VARCHAR(SP-1)") } } + +func (c *compiler) emitASCII() { + c.emit(func(vm 
*VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + if len(arg.bytes) == 0 { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(0) + } else { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(arg.bytes[0])) + } + return 1 + }, "ASCII VARCHAR(SP-1)") +} From c7b3da3699b9aad4ba521e4e336c54dbe37b3dee Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 7 Mar 2023 10:07:47 +0100 Subject: [PATCH 09/42] evalengine/compiler: Implement LIKE support Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 73 +++++++++++++++++++ .../evalengine/compiler_instructions.go | 46 ++++++++++++ 2 files changed, 119 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index cae5088d475..2d621ecaf69 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -214,6 +214,9 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { case *CaseExpr: return c.compileCase(expr) + case *LikeExpr: + return c.compileLike(expr) + case callable: return c.compileCallable(expr) @@ -1415,3 +1418,73 @@ func (c *compiler) compileASCII(call *builtinASCII) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil } + +func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { + lt, err := c.compileExpr(expr.Left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(expr.Right) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck2(skip) + + if !sqltypes.IsText(lt.Type) && !sqltypes.IsBinary(lt.Type) { + c.emitConvert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) + lt.Col = collations.TypedCollation{ + Collation: c.defaultCollation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + if !sqltypes.IsText(rt.Type) && !sqltypes.IsBinary(rt.Type) { + c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + rt.Col = collations.TypedCollation{ + Collation: 
c.defaultCollation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + var merged collations.TypedCollation + var coerceLeft collations.Coercion + var coerceRight collations.Coercion + var env = collations.Local() + + if lt.Col.Collation != rt.Col.Collation { + merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) + } else { + merged = lt.Col + } + if err != nil { + return ctype{}, err + } + + if coerceLeft == nil && coerceRight == nil { + c.emitLike_collate(expr, env.LookupByID(merged.Collation)) + } else { + if coerceLeft == nil { + coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } + } + if coerceRight == nil { + coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } + } + c.emitLike_coerce(expr, &compiledCoercion{ + col: env.LookupByID(merged.Collation), + left: coerceLeft, + right: coerceRight, + }) + } + + c.jumpDestination(skip) + + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 417fab93829..d4af0047351 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1448,3 +1448,49 @@ func (c *compiler) emitASCII() { return 1 }, "ASCII VARCHAR(SP-1)") } + +func (c *compiler) emitLike_collate(expr *LikeExpr, collation collations.Collation) { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + vm.sp-- + + match := expr.matchWildcard(l.bytes, r.bytes, collation.ID()) + if match { + vm.stack[vm.sp-1] = evalBoolTrue + } else { + vm.stack[vm.sp-1] = evalBoolFalse + } + return 1 + }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) +} + +func (c *compiler) emitLike_coerce(expr 
*LikeExpr, coercion *compiledCoercion) { + c.adjustStack(-1) + + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + vm.sp-- + + var bl, br []byte + bl, vm.err = coercion.left(nil, l.bytes) + if vm.err != nil { + return 0 + } + br, vm.err = coercion.right(nil, r.bytes) + if vm.err != nil { + return 0 + } + + match := expr.matchWildcard(bl, br, coercion.col.ID()) + if match { + vm.stack[vm.sp-1] = evalBoolTrue + } else { + vm.stack[vm.sp-1] = evalBoolFalse + } + return 1 + }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) +} From 1afc00cc41b811f59ec18589808fe157020c8cef Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Tue, 7 Mar 2023 11:37:55 +0100 Subject: [PATCH 10/42] evalengine/compiler: fix naming for instructions Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 10 ++--- .../evalengine/compiler_instructions.go | 38 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 2d621ecaf69..691fb381c4b 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -1290,7 +1290,7 @@ func (c *compiler) compileRepeat(expr *builtinRepeat) (ctype, error) { } _ = c.compileToInt64(repeat, 1) - c.emitRepeat(1) + c.emitFn_REPEAT(1) c.jumpDestination(skip) return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil } @@ -1321,7 +1321,7 @@ func (c *compiler) compileToBase64(call *builtinToBase64) (ctype, error) { Repertoire: collations.RepertoireASCII, } - c.emitToBase64(t, col) + c.emitFn_TO_BASE64(t, col) c.jumpDestination(skip) return ctype{Type: t, Col: col}, nil @@ -1363,7 +1363,7 @@ func (c *compiler) compileChangeCase(call *builtinChangeCase) (ctype, error) { c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } - c.emitChangeCase(call.upcase) + c.emitFn_LUCASE(call.upcase) c.jumpDestination(skip) return 
ctype{Type: sqltypes.VarChar, Col: str.Col}, nil @@ -1392,7 +1392,7 @@ func (c *compiler) compileLength(call callable, op lengthOp) (ctype, error) { c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } - c.emitLength(op) + c.emitFn_LENGTH(op) c.jumpDestination(skip) return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil @@ -1413,7 +1413,7 @@ func (c *compiler) compileASCII(call *builtinASCII) (ctype, error) { c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } - c.emitASCII() + c.emitFn_ASCII() c.jumpDestination(skip) return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index d4af0047351..96c51b25e9e 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1107,7 +1107,7 @@ func (c *compiler) emitBitShiftLeft_bu() { vm.sp-- return 1 - }, "<< BINARY(SP-2), UINT64(SP-1)") + }, "BIT_SHL BINARY(SP-2), UINT64(SP-1)") } func (c *compiler) emitBitShiftRight_bu() { @@ -1137,7 +1137,7 @@ func (c *compiler) emitBitShiftRight_bu() { vm.sp-- return 1 - }, ">> BINARY(SP-2), UINT64(SP-1)") + }, "BIT_SHR BINARY(SP-2), UINT64(SP-1)") } func (c *compiler) emitBitShiftLeft_uu() { @@ -1150,7 +1150,7 @@ func (c *compiler) emitBitShiftLeft_uu() { vm.sp-- return 1 - }, "<< UINT64(SP-2), UINT64(SP-1)") + }, "BIT_SHL UINT64(SP-2), UINT64(SP-1)") } func (c *compiler) emitBitShiftRight_uu() { @@ -1163,7 +1163,7 @@ func (c *compiler) emitBitShiftRight_uu() { vm.sp-- return 1 - }, ">> UINT64(SP-2), UINT64(SP-1)") + }, "BIT_SHR UINT64(SP-2), UINT64(SP-1)") } func (c *compiler) emitBitCount_b() { @@ -1193,7 +1193,7 @@ func (c *compiler) emitBitwiseNot_b() { a.bytes[i] = ^a.bytes[i] } return 1 - }, "~ BINARY(SP-1)") + }, "BIT_NOT BINARY(SP-1)") } func (c *compiler) emitBitwiseNot_u() { @@ -1201,7 +1201,7 @@ func (c *compiler) emitBitwiseNot_u() { a := 
vm.stack[vm.sp-1].(*evalUint64) a.u = ^a.u return 1 - }, "~ UINT64(SP-1)") + }, "BIT_NOT UINT64(SP-1)") } func (c *compiler) emitCmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation) { @@ -1320,7 +1320,7 @@ func (c *compiler) emitNeg_d() { }, "NEG DECIMAL(SP-1)") } -func (c *compiler) emitRepeat(i int) { +func (c *compiler) emitFn_REPEAT(i int) { c.adjustStack(-1) c.emit(func(vm *VirtualMachine) int { @@ -1341,10 +1341,10 @@ func (c *compiler) emitRepeat(i int) { str.bytes = bytes.Repeat(str.bytes, int(repeat.i)) vm.sp-- return 1 - }, "REPEAT VARCHAR(SP-2) INT64(SP-1)") + }, "FN REPEAT VARCHAR(SP-2) INT64(SP-1)") } -func (c *compiler) emitToBase64(t sqltypes.Type, col collations.TypedCollation) { +func (c *compiler) emitFn_TO_BASE64(t sqltypes.Type, col collations.TypedCollation) { c.emit(func(vm *VirtualMachine) int { str := vm.stack[vm.sp-1].(*evalBytes) @@ -1355,7 +1355,7 @@ func (c *compiler) emitToBase64(t sqltypes.Type, col collations.TypedCollation) str.col = col str.bytes = encoded return 1 - }, "TO_BASE64 VARCHAR(SP-1)") + }, "FN TO_BASE64 VARCHAR(SP-1)") } func (c *compiler) emitFromBase64() { @@ -1375,7 +1375,7 @@ func (c *compiler) emitFromBase64() { }, "FROM_BASE64 VARCHAR(SP-1)") } -func (c *compiler) emitChangeCase(upcase bool) { +func (c *compiler) emitFn_LUCASE(upcase bool) { if upcase { c.emit(func(vm *VirtualMachine) int { str := vm.stack[vm.sp-1].(*evalBytes) @@ -1389,7 +1389,7 @@ func (c *compiler) emitChangeCase(upcase bool) { } str.tt = int16(sqltypes.VarChar) return 1 - }, "UPPER VARCHAR(SP-1)") + }, "FN UPPER VARCHAR(SP-1)") } else { c.emit(func(vm *VirtualMachine) int { str := vm.stack[vm.sp-1].(*evalBytes) @@ -1403,11 +1403,11 @@ func (c *compiler) emitChangeCase(upcase bool) { } str.tt = int16(sqltypes.VarChar) return 1 - }, "LOWER VARCHAR(SP-1)") + }, "FN LOWER VARCHAR(SP-1)") } } -func (c *compiler) emitLength(op lengthOp) { +func (c *compiler) emitFn_LENGTH(op lengthOp) { switch op { case charLen: 
c.emit(func(vm *VirtualMachine) int { @@ -1421,23 +1421,23 @@ func (c *compiler) emitLength(op lengthOp) { vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) } return 1 - }, "CHAR_LENGTH VARCHAR(SP-1)") + }, "FN CHAR_LENGTH VARCHAR(SP-1)") case byteLen: c.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-1].(*evalBytes) vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) return 1 - }, "LENGTH VARCHAR(SP-1)") + }, "FN LENGTH VARCHAR(SP-1)") case bitLen: c.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-1].(*evalBytes) vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes) * 8)) return 1 - }, "BIT_LENGTH VARCHAR(SP-1)") + }, "FN BIT_LENGTH VARCHAR(SP-1)") } } -func (c *compiler) emitASCII() { +func (c *compiler) emitFn_ASCII() { c.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-1].(*evalBytes) if len(arg.bytes) == 0 { @@ -1446,7 +1446,7 @@ func (c *compiler) emitASCII() { vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(arg.bytes[0])) } return 1 - }, "ASCII VARCHAR(SP-1)") + }, "FN ASCII VARCHAR(SP-1)") } func (c *compiler) emitLike_collate(expr *LikeExpr, collation collations.Collation) { From 3e4d778b66dd2a33eac15b18340d228586ca3c85 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Tue, 7 Mar 2023 15:19:05 +0100 Subject: [PATCH 11/42] evalengine/compiler: use new collation API Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 8 ++++---- go/vt/vtgate/evalengine/compiler_instructions.go | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 691fb381c4b..6f1344b3f13 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -297,7 +297,7 @@ func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { } if coerceLeft == nil && coerceRight == nil { - c.emitCmpString_collate(env.LookupByID(merged.Collation)) + 
c.emitCmpString_collate(merged.Collation.Get()) } else { if coerceLeft == nil { coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } @@ -306,7 +306,7 @@ func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } } c.emitCmpString_coerce(&compiledCoercion{ - col: env.LookupByID(merged.Collation), + col: merged.Collation.Get(), left: coerceLeft, right: coerceRight, }) @@ -1469,7 +1469,7 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { } if coerceLeft == nil && coerceRight == nil { - c.emitLike_collate(expr, env.LookupByID(merged.Collation)) + c.emitLike_collate(expr, merged.Collation.Get()) } else { if coerceLeft == nil { coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } @@ -1478,7 +1478,7 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } } c.emitLike_coerce(expr, &compiledCoercion{ - col: env.LookupByID(merged.Collation), + col: merged.Collation.Get(), left: coerceLeft, right: coerceRight, }) diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 96c51b25e9e..8b7e4bf1a6f 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1380,7 +1380,7 @@ func (c *compiler) emitFn_LUCASE(upcase bool) { c.emit(func(vm *VirtualMachine) int { str := vm.stack[vm.sp-1].(*evalBytes) - coll := collations.Local().LookupByID(str.col.Collation) + coll := str.col.Collation.Get() csa, ok := coll.(collations.CaseAwareCollation) if !ok { vm.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") @@ -1394,7 +1394,7 @@ func (c *compiler) emitFn_LUCASE(upcase bool) { c.emit(func(vm *VirtualMachine) int { str := vm.stack[vm.sp-1].(*evalBytes) - coll := collations.Local().LookupByID(str.col.Collation) + coll := str.col.Collation.Get() 
csa, ok := coll.(collations.CaseAwareCollation) if !ok { vm.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") @@ -1416,7 +1416,7 @@ func (c *compiler) emitFn_LENGTH(op lengthOp) { if sqltypes.IsBinary(arg.SQLType()) { vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) } else { - coll := collations.Local().LookupByID(arg.col.Collation) + coll := arg.col.Collation.Get() count := charset.Length(coll.Charset(), arg.bytes) vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) } From 1d4f093be06601bc40c620c9576f5b60ee5ecc3d Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Tue, 7 Mar 2023 13:14:33 +0100 Subject: [PATCH 12/42] evalengine/compiler: implement multi-comparisons Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/arena.go | 22 +++- go/vt/vtgate/evalengine/compiler.go | 86 +++++++++----- .../evalengine/compiler_instructions.go | 108 ++++++++++++++++++ go/vt/vtgate/evalengine/mysql_test.go | 2 +- 4 files changed, 186 insertions(+), 32 deletions(-) diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go index 1e7092f0065..a5460f629d1 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -1,6 +1,10 @@ package evalengine -import "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" +) type Arena struct { aInt64 []evalInt64 @@ -79,3 +83,19 @@ func (a *Arena) newEvalBytesEmpty() *evalBytes { } return &a.aBytes[len(a.aBytes)-1] } + +func (a *Arena) newEvalBinary(raw []byte) *evalBytes { + b := a.newEvalBytesEmpty() + b.tt = int16(sqltypes.VarBinary) + b.col = collationBinary + b.bytes = raw + return b +} + +func (a *Arena) newEvalText(raw []byte, tc collations.TypedCollation) eval { + b := a.newEvalBytesEmpty() + b.tt = int16(sqltypes.VarChar) + b.col = tc + b.bytes = raw + return b +} diff --git 
a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 6f1344b3f13..50c9cec44a8 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -1196,12 +1196,13 @@ func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) { func (c *compiler) compileMultiComparison(call *builtinMultiComparison) (ctype, error) { var ( - integers int - floats int - decimals int - text int - binary int - args []ctype + integersI int + integersU int + floats int + decimals int + text int + binary int + args []ctype ) /* @@ -1223,9 +1224,9 @@ func (c *compiler) compileMultiComparison(call *builtinMultiComparison) (ctype, switch tt.Type { case sqltypes.Int64: - integers++ + integersI++ case sqltypes.Uint64: - integers++ + integersU++ case sqltypes.Float64: floats++ case sqltypes.Decimal: @@ -1239,34 +1240,59 @@ func (c *compiler) compileMultiComparison(call *builtinMultiComparison) (ctype, } } - return ctype{}, c.unsupported(call) - - /* - if integers == len(args) { + if integersI+integersU == len(args) { + if integersI == len(args) { + c.emitFn_MULTICMP_i(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + } + if integersU == len(args) { + c.emitFn_MULTICMP_u(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil + } + return c.compileMultiComparison_d(args, call.cmp < 0) + } + if binary > 0 || text > 0 { + if text > 0 { + return c.compileMultiComparison_c(args, call.cmp < 0) + } + c.emitFn_MULTICMP_b(len(args), call.cmp < 0) + return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + } else { + if floats > 0 { for i, tt := range args { - c.compileToInt64(tt, len(args)-i) + c.compileToFloat(tt, len(args)-i) } - return c.emitFn_MULTICMP_i(len(args)) + c.emitFn_MULTICMP_f(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil } - if binary > 0 || text > 0 { - if binary > 0 { - return 
c.emitFn_MULTICMP_b(len(args)) - } + if decimals > 0 { + return c.compileMultiComparison_d(args, call.cmp < 0) + } + } + return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected argument for GREATEST/LEAST") +} - return c.compileMultiComparison_c() - } else { - if floats > 0 { - return c.compileMultiComparison_f() - } - if decimals > 0 { - return c.compileMultiComparison_d() - } +func (c *compiler) compileMultiComparison_c(args []ctype, lessThan bool) (ctype, error) { + env := collations.Local() + + var ca collationAggregation + for _, arg := range args { + if err := ca.add(env, arg.Col); err != nil { + return ctype{}, err } - return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected argument for GREATEST/LEAST") + } - skip := c.jumpFrom() - c.emitNullCheckn(skip, len(call.Arguments)) - */ + tc := ca.result() + c.emitFn_MULTICMP_c(len(args), lessThan, tc) + return ctype{Type: sqltypes.VarChar, Col: tc}, nil +} + +func (c *compiler) compileMultiComparison_d(args []ctype, lessThan bool) (ctype, error) { + for i, tt := range args { + c.compileToDecimal(tt, len(args)-i) + } + c.emitFn_MULTICMP_d(len(args), lessThan) + return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil } func (c *compiler) compileRepeat(expr *builtinRepeat) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 8b7e4bf1a6f..36fde0d75e2 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1494,3 +1494,111 @@ func (c *compiler) emitLike_coerce(expr *LikeExpr, coercion *compiledCoercion) { return 1 }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) } + +func (c *compiler) emitFn_MULTICMP_i(args int, lessThan bool) { + c.adjustStack(-(args - 1)) + + c.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].(*evalInt64) + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := 
vm.stack[sp].(*evalInt64) + if lessThan == (y.i < x.i) { + x = y + } + } + vm.stack[vm.sp-args] = x + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP INT64(SP-%d)...INT64(SP-1)", args) +} + +func (c *compiler) emitFn_MULTICMP_u(args int, lessThan bool) { + c.adjustStack(-(args - 1)) + + c.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].(*evalUint64) + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].(*evalUint64) + if lessThan == (y.u < x.u) { + x = y + } + } + vm.stack[vm.sp-args] = x + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP UINT64(SP-%d)...UINT64(SP-1)", args) +} + +func (c *compiler) emitFn_MULTICMP_f(args int, lessThan bool) { + c.adjustStack(-(args - 1)) + + c.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].(*evalFloat) + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].(*evalFloat) + if lessThan == (y.f < x.f) { + x = y + } + } + vm.stack[vm.sp-args] = x + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP FLOAT64(SP-%d)...FLOAT64(SP-1)", args) +} + +func (c *compiler) emitFn_MULTICMP_d(args int, lessThan bool) { + c.adjustStack(-(args - 1)) + + c.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].(*evalDecimal) + xprec := x.length + + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].(*evalDecimal) + if lessThan == (y.dec.Cmp(x.dec) < 0) { + x = y + } + if y.length > xprec { + xprec = y.length + } + } + vm.stack[vm.sp-args] = vm.arena.newEvalDecimalWithPrec(x.dec, xprec) + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP DECIMAL(SP-%d)...DECIMAL(SP-1)", args) +} + +func (c *compiler) emitFn_MULTICMP_b(args int, lessThan bool) { + c.adjustStack(-(args - 1)) + + c.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].ToRawBytes() + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].ToRawBytes() + if lessThan == (bytes.Compare(y, x) < 0) { + x = y + } + } + vm.stack[vm.sp-args] = vm.arena.newEvalBinary(x) + vm.sp -= args - 1 + return 1 + 
}, "FN MULTICMP VARBINARY(SP-%d)...VARBINARY(SP-1)", args) +} + +func (c *compiler) emitFn_MULTICMP_c(args int, lessThan bool, tc collations.TypedCollation) { + col := tc.Collation.Get() + + c.adjustStack(-(args - 1)) + c.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].ToRawBytes() + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].ToRawBytes() + if lessThan == (col.Collate(y, x, false) < 0) { + x = y + } + } + vm.stack[vm.sp-args] = vm.arena.newEvalText(x, tc) + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP FLOAT64(SP-%d)...FLOAT64(SP-1)", args) +} diff --git a/go/vt/vtgate/evalengine/mysql_test.go b/go/vt/vtgate/evalengine/mysql_test.go index 9fece5a9db5..4ccb4c30029 100644 --- a/go/vt/vtgate/evalengine/mysql_test.go +++ b/go/vt/vtgate/evalengine/mysql_test.go @@ -127,6 +127,6 @@ func TestMySQLGolden(t *testing.T) { func TestDebug1(t *testing.T) { // Debug - eval, err := testSingle(t, `SELECT 12.0 * 0`) + eval, err := testSingle(t, `SELECT LEAST(18446744073709551615, CAST(0 AS UNSIGNED), CAST(1 AS UNSIGNED))`) t.Logf("eval=%s err=%v coll=%s", eval.String(), err, eval.Collation().Get().Name()) } From 196d4a57306823cb0df034324007b836a3de9e5a Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Wed, 8 Mar 2023 12:21:33 +0100 Subject: [PATCH 13/42] evalengine/compiler: implement JSON_ARRAY Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 24 +++++++++++++++++++ .../evalengine/compiler_instructions.go | 13 ++++++++++ 2 files changed, 37 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 50c9cec44a8..a97081121be 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -235,6 +235,10 @@ func (c *compiler) compileCallable(call callable) (ctype, error) { return c.compileJSONUnquote(call) case *builtinJSONContainsPath: return c.compileJSONContainsPath(call) + case *builtinJSONArray: + return c.compileJSONArray(call) + case 
*builtinJSONObject: + return c.compileJSONObject(call) case *builtinRepeat: return c.compileRepeat(call) case *builtinToBase64: @@ -1514,3 +1518,23 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil } + +func (c *compiler) compileJSONArray(call *builtinJSONArray) (ctype, error) { + for _, arg := range call.Arguments { + tt, err := c.compileExpr(arg) + if err != nil { + return ctype{}, err + } + + tt, err = c.compileToJSON(tt, 1) + if err != nil { + return ctype{}, err + } + } + c.emitFn_JSON_ARRAY(len(call.Arguments)) + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + +func (c *compiler) compileJSONObject(call *builtinJSONObject) (ctype, error) { + return ctype{}, c.unsupported(call) +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 36fde0d75e2..c6c75a30c31 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -889,6 +889,19 @@ func (c *compiler) emitFn_JSON_CONTAINS_PATH(match jsonMatch, paths []*json.Path } } +func (c *compiler) emitFn_JSON_ARRAY(args int) { + c.adjustStack(-(args - 1)) + c.emit(func(vm *VirtualMachine) int { + ary := make([]*json.Value, 0, args) + for sp := vm.sp - args; sp < vm.sp; sp++ { + ary = append(ary, vm.stack[sp].(*json.Value)) + } + vm.stack[vm.sp-args] = json.NewArray(ary) + vm.sp -= args - 1 + return 1 + }, "FN JSON_ARRAY (SP-%d)...(SP-1)", args) +} + func (c *compiler) emitBitOp_bb(op bitwiseOp) { c.adjustStack(-1) From 2ef498c74d747b0e488a5c9dfbb21cd6dfc99acc Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 8 Mar 2023 12:44:56 +0100 Subject: [PATCH 14/42] Fix issues after rebasing This fixes the compiler for the newly exposed issues after rebasing on top of evalengine fixes. 
Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 25 +++++++++++++++---- .../evalengine/compiler_instructions.go | 14 ++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index a97081121be..b4b162d0696 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -477,6 +477,21 @@ func (c *compiler) compileToUint64(ct ctype, offset int) ctype { return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} } +func (c *compiler) compileToBitwiseUint64(ct ctype, offset int) ctype { + switch ct.Type { + case sqltypes.Uint64: + return ct + case sqltypes.Int64: + c.emitConvert_iu(offset) + case sqltypes.Decimal: + c.emitConvert_dbit(offset) + // TODO: specialization + default: + c.emitConvert_xu(offset) + } + return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} +} + func (c *compiler) compileToFloat(ct ctype, offset int) ctype { if sqltypes.IsFloat(ct.Type) { return ct @@ -957,8 +972,8 @@ func (c *compiler) compileBitwiseOp(left Expr, right Expr, op bitwiseOp) (ctype, } } - lt = c.compileToUint64(lt, 2) - rt = c.compileToUint64(rt, 1) + lt = c.compileToBitwiseUint64(lt, 2) + rt = c.compileToBitwiseUint64(rt, 1) c.emitBitOp_uu(op) c.jumpDestination(skip) @@ -990,7 +1005,7 @@ func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, err return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } - _ = c.compileToUint64(lt, 2) + _ = c.compileToBitwiseUint64(lt, 2) _ = c.compileToUint64(rt, 1) if i < 0 { @@ -1018,7 +1033,7 @@ func (c *compiler) compileBitCount(expr *builtinBitCount) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil } - _ = c.compileToUint64(ct, 1) + _ = c.compileToBitwiseUint64(ct, 1) c.emitBitCount_u() c.jumpDestination(skip) return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil @@ -1039,7 +1054,7 @@ func (c *compiler) compileBitwiseNot(expr 
*BitwiseNotExpr) (ctype, error) { return ct, nil } - ct = c.compileToUint64(ct, 1) + ct = c.compileToBitwiseUint64(ct, 1) c.emitBitwiseNot_u() c.jumpDestination(skip) return ct, nil diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index c6c75a30c31..dd9041a12e4 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1014,6 +1014,18 @@ func (c *compiler) emitConvert_xu(offset int) { }, "CONV (SP-%d), UINT64", offset) } +// emitConvert_dbit is a special instruction emission for converting +// a bigdecimal in preparation for a bitwise operation. In that case +// we need to convert the bigdecimal to an int64 and then cast to +// uint64 to ensure we match the behavior of MySQL. +func (c *compiler) emitConvert_dbit(offset int) { + c.emit(func(vm *VirtualMachine) int { + arg := evalToNumeric(vm.stack[vm.sp-offset]) + vm.stack[vm.sp-offset] = vm.arena.newEvalUint64(uint64(arg.toInt64().i)) + return 1 + }, "CONV DECIMAL_BITWISE(SP-%d), UINT64", offset) +} + func (c *compiler) emitConvert_xb(offset int, t sqltypes.Type, length int, hasLength bool) { if hasLength { c.emit(func(vm *VirtualMachine) int { @@ -1328,7 +1340,7 @@ func (c *compiler) emitNeg_f() { func (c *compiler) emitNeg_d() { c.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-1].(*evalDecimal) - arg.dec = arg.dec.NegInPlace() + arg.dec = arg.dec.Neg() return 1 }, "NEG DECIMAL(SP-1)") } From b49e88cfb52667adcc6f78469ea40db540a9a1e8 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 8 Mar 2023 12:53:47 +0100 Subject: [PATCH 15/42] evalengine/compiler: Implement hex() and convert using Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/arena.go | 2 +- go/vt/vtgate/evalengine/compiler.go | 62 ++++++++++++++++++- .../evalengine/compiler_instructions.go | 18 ++++++ 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/evalengine/arena.go 
b/go/vt/vtgate/evalengine/arena.go index a5460f629d1..deb7d6b17cd 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -92,7 +92,7 @@ func (a *Arena) newEvalBinary(raw []byte) *evalBytes { return b } -func (a *Arena) newEvalText(raw []byte, tc collations.TypedCollation) eval { +func (a *Arena) newEvalText(raw []byte, tc collations.TypedCollation) *evalBytes { b := a.newEvalBytesEmpty() b.tt = int16(sqltypes.VarChar) b.col = tc diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index b4b162d0696..be32721a0d6 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -211,6 +211,9 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { case *ConvertExpr: return c.compileConvert(expr) + case *ConvertUsingExpr: + return c.compileConvertUsing(expr) + case *CaseExpr: return c.compileCase(expr) @@ -255,6 +258,8 @@ func (c *compiler) compileCallable(call callable) (ctype, error) { return c.compileLength(call, bitLen) case *builtinASCII: return c.compileASCII(call) + case *builtinHex: + return c.compileHex(call) default: return ctype{}, c.unsupported(call) } @@ -1110,6 +1115,26 @@ func (c *compiler) compileConvert(conv *ConvertExpr) (ctype, error) { c.jumpDestination(skip) return convt, nil + +} + +func (c *compiler) compileConvertUsing(conv *ConvertUsingExpr) (ctype, error) { + _, err := c.compileExpr(conv.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + c.emitConvert_xc(1, sqltypes.VarChar, conv.Collation, 0, false) + c.jumpDestination(skip) + + col := collations.TypedCollation{ + Collation: conv.Collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + return ctype{Type: sqltypes.VarChar, Col: col}, nil } func (c *compiler) compileCase(cs *CaseExpr) (ctype, error) { @@ -1464,6 +1489,41 @@ func (c *compiler) compileASCII(call *builtinASCII) (ctype, error) { return 
ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil } +func (c *compiler) compileHex(call *builtinHex) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.jumpFrom() + c.emitNullCheck1(skip) + + col := collations.TypedCollation{ + Collation: c.defaultCollation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + + t := sqltypes.VarChar + if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { + t = sqltypes.Text + } + + switch { + case sqltypes.IsNumber(str.Type), sqltypes.IsDecimal(str.Type): + c.emitFn_HEXd(col) + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + c.emitFn_HEXc(t, col) + default: + c.emitConvert_xc(1, t, c.defaultCollation, 0, false) + c.emitFn_HEXc(t, col) + } + + c.jumpDestination(skip) + + return ctype{Type: t, Col: col}, nil +} + func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { lt, err := c.compileExpr(expr.Left) if err != nil { @@ -1541,7 +1601,7 @@ func (c *compiler) compileJSONArray(call *builtinJSONArray) (ctype, error) { return ctype{}, err } - tt, err = c.compileToJSON(tt, 1) + _, err = c.compileToJSON(tt, 1) if err != nil { return ctype{}, err } diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index dd9041a12e4..5bae7496ceb 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -1474,6 +1474,24 @@ func (c *compiler) emitFn_ASCII() { }, "FN ASCII VARCHAR(SP-1)") } +func (c *compiler) emitFn_HEXc(t sqltypes.Type, col collations.TypedCollation) { + c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + encoded := vm.arena.newEvalText(hexEncodeBytes(arg.bytes), col) + encoded.tt = int16(t) + vm.stack[vm.sp-1] = encoded + return 1 + }, "FN HEX VARCHAR(SP-1)") +} + +func (c *compiler) emitFn_HEXd(col collations.TypedCollation) { + 
c.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(evalNumeric) + vm.stack[vm.sp-1] = vm.arena.newEvalText(hexEncodeUint(uint64(arg.toInt64().i)), col) + return 1 + }, "FN HEX NUMERIC(SP-1)") +} + func (c *compiler) emitLike_collate(expr *LikeExpr, collation collations.Collation) { c.adjustStack(-1) From 04471147c2a37b335155e2e9964ba31edc43f458 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Wed, 8 Mar 2023 13:18:10 +0100 Subject: [PATCH 16/42] evalengine/compiler: implement JSON_OBJECT Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 40 ++++++++++++++++++- .../evalengine/compiler_instructions.go | 30 ++++++++++++++ go/vt/vtgate/evalengine/fn_json.go | 4 +- 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index be32721a0d6..fca5ceee86a 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -2,6 +2,7 @@ package evalengine import ( "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/slices2" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -924,6 +925,8 @@ func (c *compiler) compileToJSON(doct ctype, offset int) (ctype, error) { c.emitConvert_cj(offset) case sqltypes.VarBinary: c.emitConvert_bj(offset) + case sqltypes.Null: + c.emitConvert_Nj(offset) default: return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", doct.Type) } @@ -1611,5 +1614,40 @@ func (c *compiler) compileJSONArray(call *builtinJSONArray) (ctype, error) { } func (c *compiler) compileJSONObject(call *builtinJSONObject) (ctype, error) { - return ctype{}, c.unsupported(call) + for i := 0; i < len(call.Arguments); i += 2 { + key, err := c.compileExpr(call.Arguments[i]) + if err != nil { + return ctype{}, err + } + c.compileToJSONKey(key) + val, err := c.compileExpr(call.Arguments[i+1]) + if err != nil { + 
return ctype{}, err + } + val, err = c.compileToJSON(val, 1) + if err != nil { + return ctype{}, err + } + } + c.emitFn_JSON_OBJECT(len(call.Arguments)) + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + +func isEncodingJSONSafe(col collations.ID) bool { + switch col.Get().Charset().(type) { + case charset.Charset_utf8mb4, charset.Charset_utf8mb3, charset.Charset_binary: + return true + default: + return false + } +} + +func (c *compiler) compileToJSONKey(key ctype) { + if key.Type == sqltypes.VarChar && isEncodingJSONSafe(key.Col.Collation) { + return + } + if key.Type == sqltypes.VarBinary { + return + } + c.emitConvert_xc(1, sqltypes.VarChar, collations.CollationUtf8mb4ID, 0, false) } diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 5bae7496ceb..338d84dd67c 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -902,6 +902,29 @@ func (c *compiler) emitFn_JSON_ARRAY(args int) { }, "FN JSON_ARRAY (SP-%d)...(SP-1)", args) } +func (c *compiler) emitFn_JSON_OBJECT(args int) { + c.adjustStack(-(args - 1)) + c.emit(func(vm *VirtualMachine) int { + j := json.NewObject() + obj, _ := j.Object() + + for sp := vm.sp - args; sp < vm.sp; sp += 2 { + key := vm.stack[sp] + val := vm.stack[sp+1] + + if key == nil { + vm.err = errJSONKeyIsNil + return 0 + } + + obj.Set(key.(*evalBytes).string(), val.(*evalJSON), json.Set) + } + vm.stack[vm.sp-args] = j + vm.sp -= args - 1 + return 1 + }, "FN JSON_ARRAY (SP-%d)...(SP-1)", args) +} + func (c *compiler) emitBitOp_bb(op bitwiseOp) { c.adjustStack(-1) @@ -1077,6 +1100,13 @@ func (c *compiler) emitConvert_bj(offset int) { }, "CONV VARBINARY(SP-%d), JSON") } +func (c *compiler) emitConvert_Nj(offset int) { + c.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = json.ValueNull + return 1 + }, "CONV NULL(SP-%d), JSON") +} + func (c *compiler) emitConvert_xc(offset int, t 
sqltypes.Type, collation collations.ID, length int, hasLength bool) { if hasLength { c.emit(func(vm *VirtualMachine) int { diff --git a/go/vt/vtgate/evalengine/fn_json.go b/go/vt/vtgate/evalengine/fn_json.go index 2107e4adaeb..d0aced97417 100644 --- a/go/vt/vtgate/evalengine/fn_json.go +++ b/go/vt/vtgate/evalengine/fn_json.go @@ -146,6 +146,8 @@ func (call *builtinJSONUnquote) typeof(env *ExpressionEnv) (sqltypes.Type, typeF return sqltypes.Blob, f } +var errJSONKeyIsNil = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "JSON documents may not contain NULL member names.") + func (call *builtinJSONObject) eval(env *ExpressionEnv) (eval, error) { j := json.NewObject() obj, _ := j.Object() @@ -156,7 +158,7 @@ func (call *builtinJSONObject) eval(env *ExpressionEnv) (eval, error) { return nil, err } if key == nil { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "JSON documents may not contain NULL member names.") + return nil, errJSONKeyIsNil } key1, err := evalToVarchar(key, collations.CollationUtf8mb4ID, true) if err != nil { From 961470379c3e52eafa2ca9e67a87a8d629ffd795 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Wed, 8 Mar 2023 15:11:46 +0100 Subject: [PATCH 17/42] compiler: add support for tuple comparisons Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 51 ++++++++ .../evalengine/compiler_instructions.go | 121 +++++++++++++++++- 2 files changed, 166 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index fca5ceee86a..937ed4a7ce1 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -224,6 +224,9 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { case callable: return c.compileCallable(expr) + case TupleExpr: + return c.compileTuple(expr) + default: return ctype{}, c.unsupported(expr) } @@ -266,6 +269,36 @@ func (c *compiler) compileCallable(call callable) (ctype, error) { } } +func (c *compiler) 
compileComparisonTuple(expr *ComparisonExpr) (ctype, error) { + + switch expr.Op.(type) { + case compareNullSafeEQ: + c.emitCmpTupleNullsafe() + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + case compareEQ: + c.emitCmpTuple(true) + c.emitCmp_eq_n() + case compareNE: + c.emitCmpTuple(true) + c.emitCmp_ne_n() + case compareLT: + c.emitCmpTuple(false) + c.emitCmp_lt_n() + case compareLE: + c.emitCmpTuple(false) + c.emitCmp_le_n() + case compareGT: + c.emitCmpTuple(false) + c.emitCmp_gt_n() + case compareGE: + c.emitCmpTuple(false) + c.emitCmp_ge_n() + default: + panic("invalid comparison operator") + } + return ctype{Type: sqltypes.Int64, Flag: flagNullable, Col: collationNumeric}, nil +} + func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { lt, err := c.compileExpr(expr.Left) if err != nil { @@ -277,6 +310,13 @@ func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { return ctype{}, err } + if lt.Type == sqltypes.Tuple || rt.Type == sqltypes.Tuple { + if lt.Type != rt.Type { + return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "did not typecheck tuples during comparison") + } + return c.compileComparisonTuple(expr) + } + swapped := false skip := c.jumpFrom() @@ -1651,3 +1691,14 @@ func (c *compiler) compileToJSONKey(key ctype) { } c.emitConvert_xc(1, sqltypes.VarChar, collations.CollationUtf8mb4ID, 0, false) } + +func (c *compiler) compileTuple(tuple TupleExpr) (ctype, error) { + for _, arg := range tuple { + _, err := c.compileExpr(arg) + if err != nil { + return ctype{}, err + } + } + c.emitPackTuple(len(tuple)) + return ctype{Type: sqltypes.Tuple, Col: collationBinary}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_instructions.go index 338d84dd67c..84b096c12bf 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_instructions.go @@ -576,7 +576,6 @@ func (c *compiler) emitNullCheck2(j *jump) { 
func (c *compiler) emitCmp_eq() { c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) vm.sp++ @@ -586,7 +585,6 @@ func (c *compiler) emitCmp_eq() { func (c *compiler) emitCmp_ne() { c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) vm.sp++ @@ -596,7 +594,6 @@ func (c *compiler) emitCmp_ne() { func (c *compiler) emitCmp_lt() { c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) vm.sp++ @@ -606,7 +603,6 @@ func (c *compiler) emitCmp_lt() { func (c *compiler) emitCmp_le() { c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) vm.sp++ @@ -616,7 +612,6 @@ func (c *compiler) emitCmp_le() { func (c *compiler) emitCmp_gt() { c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) vm.sp++ @@ -626,7 +621,6 @@ func (c *compiler) emitCmp_gt() { func (c *compiler) emitCmp_ge() { c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) vm.sp++ @@ -634,6 +628,121 @@ func (c *compiler) emitCmp_ge() { }, "CMPFLAG GE") } +func (c *compiler) emitCmp_eq_n() { + c.adjustStack(1) + c.emit(func(vm *VirtualMachine) int { + if vm.flags.null { + vm.stack[vm.sp] = nil + } else { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) + } + vm.sp++ + return 1 + }, "CMPFLAG EQ [NULL]") +} + +func (c *compiler) emitCmp_ne_n() { + c.adjustStack(1) + c.emit(func(vm *VirtualMachine) int { + if vm.flags.null { + vm.stack[vm.sp] = nil + } else { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) + } + vm.sp++ + return 1 + }, "CMPFLAG NE [NULL]") +} + +func (c *compiler) emitCmp_lt_n() { + c.adjustStack(1) + c.emit(func(vm *VirtualMachine) int { + if vm.flags.null { + vm.stack[vm.sp] = nil + } else { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) + } + vm.sp++ + return 1 + }, "CMPFLAG LT 
[NULL]") +} + +func (c *compiler) emitCmp_le_n() { + c.adjustStack(1) + c.emit(func(vm *VirtualMachine) int { + if vm.flags.null { + vm.stack[vm.sp] = nil + } else { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) + } + vm.sp++ + return 1 + }, "CMPFLAG LE [NULL]") +} + +func (c *compiler) emitCmp_gt_n() { + c.adjustStack(1) + c.emit(func(vm *VirtualMachine) int { + if vm.flags.null { + vm.stack[vm.sp] = nil + } else { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) + } + vm.sp++ + return 1 + }, "CMPFLAG GT [NULL]") +} + +func (c *compiler) emitCmp_ge_n() { + c.adjustStack(1) + c.emit(func(vm *VirtualMachine) int { + if vm.flags.null { + vm.stack[vm.sp] = nil + } else { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) + } + vm.sp++ + return 1 + }, "CMPFLAG GE [NULL]") +} + +func (c *compiler) emitPackTuple(tlen int) { + c.adjustStack(-(tlen - 1)) + c.emit(func(vm *VirtualMachine) int { + tuple := make([]eval, tlen) + copy(tuple, vm.stack[vm.sp-tlen:]) + vm.stack[vm.sp-tlen] = &evalTuple{tuple} + vm.sp -= tlen - 1 + return 1 + }, "TUPLE (SP-%d)...(SP-1)", tlen) +} + +func (c *compiler) emitCmpTuple(fullEquality bool) { + c.adjustStack(-2) + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalTuple) + r := vm.stack[vm.sp-1].(*evalTuple) + vm.sp -= 2 + vm.flags.cmp, vm.flags.null, vm.err = evalCompareMany(l.t, r.t, fullEquality) + return 1 + }, "CMP TUPLE(SP-2), TUPLE(SP-1)") +} + +func (c *compiler) emitCmpTupleNullsafe() { + c.adjustStack(-1) + c.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalTuple) + r := vm.stack[vm.sp-1].(*evalTuple) + + var equals bool + equals, vm.err = evalCompareTuplesNullSafe(l.t, r.t) + + vm.stack[vm.sp-2] = newEvalBool(equals) + vm.sp -= 1 + return 1 + }, "CMP NULLSAFE TUPLE(SP-2), TUPLE(SP-1)") +} + func (c *compiler) emitCmpNum_ii() { c.adjustStack(-2) From c802f542c275c3a578f1c5f7a997bc00efae35d8 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Wed, 8 Mar 2023 16:04:27 +0100 Subject: [PATCH 
18/42] evalengine/compiler: arrange in separate files Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 1564 +--------- .../vtgate/evalengine/compiler_arithmetic.go | 289 ++ ...mpiler_instructions.go => compiler_asm.go} | 2640 +++++++++-------- go/vt/vtgate/evalengine/compiler_bit.go | 138 + go/vt/vtgate/evalengine/compiler_compare.go | 297 ++ go/vt/vtgate/evalengine/compiler_convert.go | 108 + go/vt/vtgate/evalengine/compiler_fn.go | 333 +++ go/vt/vtgate/evalengine/compiler_json.go | 208 ++ go/vt/vtgate/evalengine/compiler_logical.go | 46 + go/vt/vtgate/evalengine/compiler_test.go | 4 +- 10 files changed, 2833 insertions(+), 2794 deletions(-) create mode 100644 go/vt/vtgate/evalengine/compiler_arithmetic.go rename go/vt/vtgate/evalengine/{compiler_instructions.go => compiler_asm.go} (68%) create mode 100644 go/vt/vtgate/evalengine/compiler_bit.go create mode 100644 go/vt/vtgate/evalengine/compiler_compare.go create mode 100644 go/vt/vtgate/evalengine/compiler_convert.go create mode 100644 go/vt/vtgate/evalengine/compiler_fn.go create mode 100644 go/vt/vtgate/evalengine/compiler_json.go create mode 100644 go/vt/vtgate/evalengine/compiler_logical.go diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 937ed4a7ce1..b31650e2b98 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -2,31 +2,22 @@ package evalengine import ( "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/mysql/collations/charset" - "vitess.io/vitess/go/slices2" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/evalengine/internal/json" ) type frame func(vm *VirtualMachine) int type compiler struct { - ins []frame - log CompilerLog + asm assembler fields []*querypb.Field defaultCollation collations.ID - - stack struct { - cur int - max int - } } -type CompilerLog 
interface { - Disasm(ins string, args ...any) +type AssemblerLog interface { + Instruction(ins string, args ...any) Stack(old, new int) } @@ -38,9 +29,9 @@ type compiledCoercion struct { type CompilerOption func(c *compiler) -func WithCompilerLog(log CompilerLog) CompilerOption { +func WithAssemblerLog(log AssemblerLog) CompilerOption { return func(c *compiler) { - c.log = log + c.asm.log = log } } @@ -62,77 +53,10 @@ func Compile(expr Expr, fields []*querypb.Field, options ...CompilerOption) (*Pr if err != nil { return nil, err } - if comp.stack.cur != 1 { - return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "bad compilation: stack pointer at %d after compilation", comp.stack.cur) - } - return &Program{code: comp.ins, original: expr, stack: comp.stack.max}, nil -} - -func (c *compiler) adjustStack(offset int) { - c.stack.cur += offset - if c.stack.cur < 0 { - panic("negative stack position") - } - if c.stack.cur > c.stack.max { - c.stack.max = c.stack.cur - } - if c.log != nil { - c.log.Stack(c.stack.cur-offset, c.stack.cur) - } -} - -func (c *compiler) compileColumn(offset int) (ctype, error) { - if offset >= len(c.fields) { - return ctype{}, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Missing field for column %d", offset) - } - - field := c.fields[offset] - col := collations.TypedCollation{ - Collation: collations.ID(field.Charset), - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireASCII, - } - if col.Collation != collations.CollationBinaryID { - col.Repertoire = collations.RepertoireUnicode - } - - switch tt := field.Type; { - case sqltypes.IsSigned(tt): - c.emitPushColumn_i(offset) - case sqltypes.IsUnsigned(tt): - c.emitPushColumn_u(offset) - case sqltypes.IsFloat(tt): - c.emitPushColumn_f(offset) - case sqltypes.IsDecimal(tt): - c.emitPushColumn_d(offset) - case sqltypes.IsText(tt): - if tt == sqltypes.HexNum { - c.emitPushColumn_hexnum(offset) - } else if tt == sqltypes.HexVal { - c.emitPushColumn_hexval(offset) - } else { - 
c.emitPushColumn_text(offset, col) - } - case sqltypes.IsBinary(tt): - c.emitPushColumn_bin(offset) - case sqltypes.IsNull(tt): - c.emitPushNull() - case tt == sqltypes.TypeJSON: - c.emitPushColumn_json(offset) - default: - return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Type is not supported: %s", tt) - } - - var flag typeFlag - if (field.Flags & uint32(querypb.MySqlFlag_NOT_NULL_FLAG)) == 0 { - flag |= flagNullable + if comp.asm.stack.cur != 1 { + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "bad compilation: stack pointer at %d after compilation", comp.asm.stack.cur) } - - return ctype{ - Type: field.Type, - Flag: flag, - Col: col, - }, nil + return &Program{code: comp.asm.ins, original: expr, stack: comp.asm.stack.max}, nil } type ctype struct { @@ -153,22 +77,6 @@ func (ct ctype) isHexOrBitLiteral() bool { return ct.Flag&flagBit != 0 || ct.Flag&flagHex != 0 } -type jump struct { - from, to int -} - -func (j *jump) offset() int { - return j.to - j.from -} - -func (c *compiler) jumpFrom() *jump { - return &jump{from: len(c.ins)} -} - -func (c *compiler) jumpDestination(j *jump) { - j.to = len(c.ins) -} - func (c *compiler) unsupported(expr Expr) error { return vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported compilation for expression '%s'", FormatExpr(expr)) } @@ -177,8 +85,8 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { switch expr := expr.(type) { case *Literal: if expr.inner == nil { - c.emitPushNull() - } else if err := c.emitPushLiteral(expr.inner); err != nil { + c.asm.PushNull() + } else if err := c.asm.PushLiteral(expr.inner); err != nil { return ctype{}, err } @@ -200,9 +108,6 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { case *NegateExpr: return c.compileNegate(expr) - case *builtinBitCount: - return c.compileBitCount(expr) - case *ComparisonExpr: return c.compileComparison(expr) @@ -222,7 +127,7 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { return c.compileLike(expr) case callable: 
- return c.compileCallable(expr) + return c.compileFn(expr) case TupleExpr: return c.compileTuple(expr) @@ -232,257 +137,69 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { } } -func (c *compiler) compileCallable(call callable) (ctype, error) { - switch call := call.(type) { - case *builtinMultiComparison: - return c.compileMultiComparison(call) - case *builtinJSONExtract: - return c.compileJSONExtract(call) - case *builtinJSONUnquote: - return c.compileJSONUnquote(call) - case *builtinJSONContainsPath: - return c.compileJSONContainsPath(call) - case *builtinJSONArray: - return c.compileJSONArray(call) - case *builtinJSONObject: - return c.compileJSONObject(call) - case *builtinRepeat: - return c.compileRepeat(call) - case *builtinToBase64: - return c.compileToBase64(call) - case *builtinFromBase64: - return c.compileFromBase64(call) - case *builtinChangeCase: - return c.compileChangeCase(call) - case *builtinCharLength: - return c.compileLength(call, charLen) - case *builtinLength: - return c.compileLength(call, byteLen) - case *builtinBitLength: - return c.compileLength(call, bitLen) - case *builtinASCII: - return c.compileASCII(call) - case *builtinHex: - return c.compileHex(call) - default: - return ctype{}, c.unsupported(call) - } -} - -func (c *compiler) compileComparisonTuple(expr *ComparisonExpr) (ctype, error) { - - switch expr.Op.(type) { - case compareNullSafeEQ: - c.emitCmpTupleNullsafe() - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil - case compareEQ: - c.emitCmpTuple(true) - c.emitCmp_eq_n() - case compareNE: - c.emitCmpTuple(true) - c.emitCmp_ne_n() - case compareLT: - c.emitCmpTuple(false) - c.emitCmp_lt_n() - case compareLE: - c.emitCmpTuple(false) - c.emitCmp_le_n() - case compareGT: - c.emitCmpTuple(false) - c.emitCmp_gt_n() - case compareGE: - c.emitCmpTuple(false) - c.emitCmp_ge_n() - default: - panic("invalid comparison operator") - } - return ctype{Type: sqltypes.Int64, Flag: flagNullable, Col: 
collationNumeric}, nil -} - -func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { - lt, err := c.compileExpr(expr.Left) - if err != nil { - return ctype{}, err - } - - rt, err := c.compileExpr(expr.Right) - if err != nil { - return ctype{}, err +func (c *compiler) compileColumn(offset int) (ctype, error) { + if offset >= len(c.fields) { + return ctype{}, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Missing field for column %d", offset) } - if lt.Type == sqltypes.Tuple || rt.Type == sqltypes.Tuple { - if lt.Type != rt.Type { - return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "did not typecheck tuples during comparison") - } - return c.compileComparisonTuple(expr) + field := c.fields[offset] + col := collations.TypedCollation{ + Collation: collations.ID(field.Charset), + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, } - - swapped := false - skip := c.jumpFrom() - - switch expr.Op.(type) { - case compareNullSafeEQ: - c.emitNullCmp(skip) - default: - c.emitNullCheck2(skip) + if col.Collation != collations.CollationBinaryID { + col.Repertoire = collations.RepertoireUnicode } - switch { - case compareAsStrings(lt.Type, rt.Type): - var merged collations.TypedCollation - var coerceLeft collations.Coercion - var coerceRight collations.Coercion - var env = collations.Local() - - if lt.Col.Collation != rt.Col.Collation { - merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ - ConvertToSuperset: true, - ConvertWithCoercion: true, - }) - } else { - merged = lt.Col - } - if err != nil { - return ctype{}, err - } - - if coerceLeft == nil && coerceRight == nil { - c.emitCmpString_collate(merged.Collation.Get()) + switch tt := field.Type; { + case sqltypes.IsSigned(tt): + c.asm.PushColumn_i(offset) + case sqltypes.IsUnsigned(tt): + c.asm.PushColumn_u(offset) + case sqltypes.IsFloat(tt): + c.asm.PushColumn_f(offset) + case sqltypes.IsDecimal(tt): + 
c.asm.PushColumn_d(offset) + case sqltypes.IsText(tt): + if tt == sqltypes.HexNum { + c.asm.PushColumn_hexnum(offset) + } else if tt == sqltypes.HexVal { + c.asm.PushColumn_hexval(offset) } else { - if coerceLeft == nil { - coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } - } - if coerceRight == nil { - coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } - } - c.emitCmpString_coerce(&compiledCoercion{ - col: merged.Collation.Get(), - left: coerceLeft, - right: coerceRight, - }) - } - - case compareAsSameNumericType(lt.Type, rt.Type) || compareAsDecimal(lt.Type, rt.Type): - switch lt.Type { - case sqltypes.Int64: - switch rt.Type { - case sqltypes.Int64: - c.emitCmpNum_ii() - case sqltypes.Uint64: - c.emitCmpNum_iu(2, 1) - case sqltypes.Float64: - c.emitCmpNum_if(2, 1) - case sqltypes.Decimal: - c.emitCmpNum_id(2, 1) - } - case sqltypes.Uint64: - switch rt.Type { - case sqltypes.Int64: - c.emitCmpNum_iu(1, 2) - swapped = true - case sqltypes.Uint64: - c.emitCmpNum_uu() - case sqltypes.Float64: - c.emitCmpNum_uf(2, 1) - case sqltypes.Decimal: - c.emitCmpNum_ud(2, 1) - } - case sqltypes.Float64: - switch rt.Type { - case sqltypes.Int64: - c.emitCmpNum_if(1, 2) - swapped = true - case sqltypes.Uint64: - c.emitCmpNum_uf(1, 2) - swapped = true - case sqltypes.Float64: - c.emitCmpNum_ff() - case sqltypes.Decimal: - c.emitCmpNum_fd(2, 1) - } - - case sqltypes.Decimal: - switch rt.Type { - case sqltypes.Int64: - c.emitCmpNum_id(1, 2) - swapped = true - case sqltypes.Uint64: - c.emitCmpNum_ud(1, 2) - swapped = true - case sqltypes.Float64: - c.emitCmpNum_fd(1, 2) - swapped = true - case sqltypes.Decimal: - c.emitCmpNum_dd() - } + c.asm.PushColumn_text(offset, col) } - - case compareAsDates(lt.Type, rt.Type) || compareAsDateAndString(lt.Type, rt.Type) || compareAsDateAndNumeric(lt.Type, rt.Type): - return ctype{}, c.unsupported(expr) - + case sqltypes.IsBinary(tt): + c.asm.PushColumn_bin(offset) + case sqltypes.IsNull(tt): + 
c.asm.PushNull() + case tt == sqltypes.TypeJSON: + c.asm.PushColumn_json(offset) default: - lt = c.compileToFloat(lt, 2) - rt = c.compileToFloat(rt, 1) - c.emitCmpNum_ff() + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Type is not supported: %s", tt) } - cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric} - - switch expr.Op.(type) { - case compareEQ: - c.emitCmp_eq() - case compareNE: - c.emitCmp_ne() - case compareLT: - if swapped { - c.emitCmp_gt() - } else { - c.emitCmp_lt() - } - case compareLE: - if swapped { - c.emitCmp_ge() - } else { - c.emitCmp_le() - } - case compareGT: - if swapped { - c.emitCmp_lt() - } else { - c.emitCmp_gt() - } - case compareGE: - if swapped { - c.emitCmp_le() - } else { - c.emitCmp_ge() - } - case compareNullSafeEQ: - c.jumpDestination(skip) - c.emitCmp_eq() - return cmptype, nil - - default: - panic("unexpected comparison operator") + var flag typeFlag + if (field.Flags & uint32(querypb.MySqlFlag_NOT_NULL_FLAG)) == 0 { + flag |= flagNullable } - c.jumpDestination(skip) - return cmptype, nil + return ctype{ + Type: field.Type, + Flag: flag, + Col: col, + }, nil } -func (c *compiler) compileArithmetic(expr *ArithmeticExpr) (ctype, error) { - switch expr.Op.(type) { - case *opArithAdd: - return c.compileArithmeticAdd(expr.Left, expr.Right) - case *opArithSub: - return c.compileArithmeticSub(expr.Left, expr.Right) - case *opArithMul: - return c.compileArithmeticMul(expr.Left, expr.Right) - case *opArithDiv: - return c.compileArithmeticDiv(expr.Left, expr.Right) - default: - panic("unexpected arithmetic operator") +func (c *compiler) compileTuple(tuple TupleExpr) (ctype, error) { + for _, arg := range tuple { + _, err := c.compileExpr(arg) + if err != nil { + return ctype{}, err + } } + c.asm.PackTuple(len(tuple)) + return ctype{Type: sqltypes.Tuple, Col: collationBinary}, nil } func (c *compiler) compileToNumeric(ct ctype, offset int) ctype { @@ -490,10 +207,10 @@ func (c *compiler) compileToNumeric(ct ctype, 
offset int) ctype { return ct } if ct.Type == sqltypes.VarBinary && (ct.Flag&flagHex) != 0 { - c.emitConvert_hex(offset) + c.asm.Convert_hex(offset) return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} } - c.emitConvert_xf(offset) + c.asm.Convert_xf(offset) return ctype{sqltypes.Float64, ct.Flag, collationNumeric} } @@ -502,10 +219,10 @@ func (c *compiler) compileToInt64(ct ctype, offset int) ctype { case sqltypes.Int64: return ct case sqltypes.Uint64: - c.emitConvert_ui(offset) + c.asm.Convert_ui(offset) // TODO: specialization default: - c.emitConvert_xi(offset) + c.asm.Convert_xi(offset) } return ctype{sqltypes.Int64, ct.Flag, collationNumeric} } @@ -515,10 +232,10 @@ func (c *compiler) compileToUint64(ct ctype, offset int) ctype { case sqltypes.Uint64: return ct case sqltypes.Int64: - c.emitConvert_iu(offset) + c.asm.Convert_iu(offset) // TODO: specialization default: - c.emitConvert_xu(offset) + c.asm.Convert_xu(offset) } return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} } @@ -528,12 +245,12 @@ func (c *compiler) compileToBitwiseUint64(ct ctype, offset int) ctype { case sqltypes.Uint64: return ct case sqltypes.Int64: - c.emitConvert_iu(offset) + c.asm.Convert_iu(offset) case sqltypes.Decimal: - c.emitConvert_dbit(offset) + c.asm.Convert_dbit(offset) // TODO: specialization default: - c.emitConvert_xu(offset) + c.asm.Convert_xu(offset) } return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} } @@ -544,1161 +261,28 @@ func (c *compiler) compileToFloat(ct ctype, offset int) ctype { } switch ct.Type { case sqltypes.Int64: - c.emitConvert_if(offset) + c.asm.Convert_if(offset) case sqltypes.Uint64: // only emit u->f conversion if this is not a hex value; hex values // will already be converted - c.emitConvert_uf(offset) + c.asm.Convert_uf(offset) default: - c.emitConvert_xf(offset) + c.asm.Convert_xf(offset) } return ctype{sqltypes.Float64, ct.Flag, collationNumeric} } -func (c *compiler) compileNumericPriority(lt, rt ctype) (ctype, ctype, bool) { - 
switch lt.Type { - case sqltypes.Int64: - if rt.Type == sqltypes.Uint64 || rt.Type == sqltypes.Float64 || rt.Type == sqltypes.Decimal { - return rt, lt, true - } - case sqltypes.Uint64: - if rt.Type == sqltypes.Float64 || rt.Type == sqltypes.Decimal { - return rt, lt, true - } - case sqltypes.Decimal: - if rt.Type == sqltypes.Float64 { - return rt, lt, true - } - } - return lt, rt, false -} - -func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { - lt, err := c.compileExpr(left) - if err != nil { - return ctype{}, err - } - - rt, err := c.compileExpr(right) - if err != nil { - return ctype{}, err - } - - swap := false - skip := c.jumpFrom() - c.emitNullCheck2(skip) - lt = c.compileToNumeric(lt, 2) - rt = c.compileToNumeric(rt, 1) - lt, rt, swap = c.compileNumericPriority(lt, rt) - - var sumtype sqltypes.Type - - switch lt.Type { - case sqltypes.Int64: - c.emitAdd_ii() - sumtype = sqltypes.Int64 - case sqltypes.Uint64: - switch rt.Type { - case sqltypes.Int64: - c.emitAdd_ui(swap) - case sqltypes.Uint64: - c.emitAdd_uu() - } - sumtype = sqltypes.Uint64 - case sqltypes.Decimal: - if swap { - c.compileToDecimal(rt, 2) - } else { - c.compileToDecimal(rt, 1) - } - c.emitAdd_dd() - sumtype = sqltypes.Decimal - case sqltypes.Float64: - if swap { - c.compileToFloat(rt, 2) - } else { - c.compileToFloat(rt, 1) - } - c.emitAdd_ff() - sumtype = sqltypes.Float64 - } - - c.jumpDestination(skip) - return ctype{Type: sumtype, Col: collationNumeric}, nil -} - -func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) { - lt, err := c.compileExpr(left) - if err != nil { - return ctype{}, err - } - - rt, err := c.compileExpr(right) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck2(skip) - lt = c.compileToNumeric(lt, 2) - rt = c.compileToNumeric(rt, 1) - - var subtype sqltypes.Type - - switch lt.Type { - case sqltypes.Int64: - switch rt.Type { - case sqltypes.Int64: - c.emitSub_ii() - subtype = sqltypes.Int64 - 
case sqltypes.Uint64: - c.emitSub_iu() - subtype = sqltypes.Uint64 - case sqltypes.Float64: - c.compileToFloat(lt, 2) - c.emitSub_ff() - subtype = sqltypes.Float64 - case sqltypes.Decimal: - c.compileToDecimal(lt, 2) - c.emitSub_dd() - subtype = sqltypes.Decimal - } - case sqltypes.Uint64: - switch rt.Type { - case sqltypes.Int64: - c.emitSub_ui() - subtype = sqltypes.Uint64 - case sqltypes.Uint64: - c.emitSub_uu() - subtype = sqltypes.Uint64 - case sqltypes.Float64: - c.compileToFloat(lt, 2) - c.emitSub_ff() - subtype = sqltypes.Float64 - case sqltypes.Decimal: - c.compileToDecimal(lt, 2) - c.emitSub_dd() - subtype = sqltypes.Decimal - } - case sqltypes.Float64: - c.compileToFloat(rt, 1) - c.emitSub_ff() - subtype = sqltypes.Float64 - case sqltypes.Decimal: - switch rt.Type { - case sqltypes.Float64: - c.compileToFloat(lt, 2) - c.emitSub_ff() - subtype = sqltypes.Float64 - default: - c.compileToDecimal(rt, 1) - c.emitSub_dd() - subtype = sqltypes.Decimal - } - } - - if subtype == 0 { - panic("did not compile?") - } - - c.jumpDestination(skip) - return ctype{Type: subtype, Col: collationNumeric}, nil -} - func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { if sqltypes.IsDecimal(ct.Type) { return ct } switch ct.Type { case sqltypes.Int64: - c.emitConvert_id(offset) + c.asm.Convert_id(offset) case sqltypes.Uint64: - c.emitConvert_ud(offset) + c.asm.Convert_ud(offset) default: - c.emitConvert_xd(offset, 0, 0) + c.asm.Convert_xd(offset, 0, 0) } return ctype{sqltypes.Decimal, ct.Flag, collationNumeric} } - -func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { - lt, err := c.compileExpr(left) - if err != nil { - return ctype{}, err - } - - rt, err := c.compileExpr(right) - if err != nil { - return ctype{}, err - } - - swap := false - skip := c.jumpFrom() - c.emitNullCheck2(skip) - lt = c.compileToNumeric(lt, 2) - rt = c.compileToNumeric(rt, 1) - lt, rt, swap = c.compileNumericPriority(lt, rt) - - var multype sqltypes.Type - - switch 
lt.Type { - case sqltypes.Int64: - c.emitMul_ii() - multype = sqltypes.Int64 - case sqltypes.Uint64: - switch rt.Type { - case sqltypes.Int64: - c.emitMul_ui(swap) - case sqltypes.Uint64: - c.emitMul_uu() - } - multype = sqltypes.Uint64 - case sqltypes.Float64: - if swap { - c.compileToFloat(rt, 2) - } else { - c.compileToFloat(rt, 1) - } - c.emitMul_ff() - multype = sqltypes.Float64 - case sqltypes.Decimal: - if swap { - c.compileToDecimal(rt, 2) - } else { - c.compileToDecimal(rt, 1) - } - c.emitMul_dd() - multype = sqltypes.Decimal - } - - c.jumpDestination(skip) - return ctype{Type: multype, Col: collationNumeric}, nil -} - -func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) { - lt, err := c.compileExpr(left) - if err != nil { - return ctype{}, err - } - - rt, err := c.compileExpr(right) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck2(skip) - lt = c.compileToNumeric(lt, 2) - rt = c.compileToNumeric(rt, 1) - - if lt.Type == sqltypes.Float64 || rt.Type == sqltypes.Float64 { - c.compileToFloat(lt, 2) - c.compileToFloat(rt, 1) - c.emitDiv_ff() - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil - } else { - c.compileToDecimal(lt, 2) - c.compileToDecimal(rt, 1) - c.emitDiv_dd() - return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil - } -} - -func (c *compiler) compileCollate(expr *CollateExpr) (ctype, error) { - ct, err := c.compileExpr(expr.Inner) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - switch ct.Type { - case sqltypes.VarChar: - if err := collations.Local().EnsureCollate(ct.Col.Collation, expr.TypedCollation.Collation); err != nil { - return ctype{}, vterrors.New(vtrpc.Code_INVALID_ARGUMENT, err.Error()) - } - fallthrough - case sqltypes.VarBinary: - c.emitCollate(expr.TypedCollation.Collation) - default: - return ctype{}, c.unsupported(expr) - } - - c.jumpDestination(skip) - - ct.Col = expr.TypedCollation - ct.Flag |= 
flagExplicitCollation - return ct, nil -} - -func (c *compiler) extractJSONPath(expr Expr) (*json.Path, error) { - path, ok := expr.(*Literal) - if !ok { - return nil, errJSONPath - } - pathBytes, ok := path.inner.(*evalBytes) - if !ok { - return nil, errJSONPath - } - var parser json.PathParser - return parser.ParseBytes(pathBytes.bytes) -} - -func (c *compiler) extractOneOrAll(fname string, expr Expr) (jsonMatch, error) { - lit, ok := expr.(*Literal) - if !ok { - return jsonMatchInvalid, errOneOrAll(fname) - } - b, ok := lit.inner.(*evalBytes) - if !ok { - return jsonMatchInvalid, errOneOrAll(fname) - } - return intoOneOrAll(fname, b.string()) -} - -func (c *compiler) compileJSONExtract(call *builtinJSONExtract) (ctype, error) { - doct, err := c.compileExpr(call.Arguments[0]) - if err != nil { - return ctype{}, err - } - - if slices2.All(call.Arguments[1:], func(expr Expr) bool { return expr.constant() }) { - paths := make([]*json.Path, 0, len(call.Arguments[1:])) - - for _, arg := range call.Arguments[1:] { - jp, err := c.extractJSONPath(arg) - if err != nil { - return ctype{}, err - } - paths = append(paths, jp) - } - - jt, err := c.compileParseJSON("JSON_EXTRACT", doct, 1) - if err != nil { - return ctype{}, err - } - - c.emitFn_JSON_EXTRACT0(paths) - return jt, nil - } - - return ctype{}, c.unsupported(call) -} - -func (c *compiler) compileJSONUnquote(call *builtinJSONUnquote) (ctype, error) { - arg, err := c.compileExpr(call.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - _, err = c.compileParseJSON("JSON_UNQUOTE", arg, 1) - if err != nil { - return ctype{}, err - } - - c.emitFn_JSON_UNQUOTE() - c.jumpDestination(skip) - return ctype{Type: sqltypes.Blob, Col: collationJSON}, nil -} - -func (c *compiler) compileJSONContainsPath(call *builtinJSONContainsPath) (ctype, error) { - doct, err := c.compileExpr(call.Arguments[0]) - if err != nil { - return ctype{}, err - } - - if 
!call.Arguments[1].constant() { - return ctype{}, c.unsupported(call) - } - - if !slices2.All(call.Arguments[2:], func(expr Expr) bool { return expr.constant() }) { - return ctype{}, c.unsupported(call) - } - - match, err := c.extractOneOrAll("JSON_CONTAINS_PATH", call.Arguments[1]) - if err != nil { - return ctype{}, err - } - - paths := make([]*json.Path, 0, len(call.Arguments[2:])) - - for _, arg := range call.Arguments[2:] { - jp, err := c.extractJSONPath(arg) - if err != nil { - return ctype{}, err - } - paths = append(paths, jp) - } - - _, err = c.compileParseJSON("JSON_CONTAINS_PATH", doct, 1) - if err != nil { - return ctype{}, err - } - - c.emitFn_JSON_CONTAINS_PATH(match, paths) - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil -} - -func (c *compiler) compileParseJSON(fn string, doct ctype, offset int) (ctype, error) { - switch doct.Type { - case sqltypes.TypeJSON: - case sqltypes.VarChar, sqltypes.VarBinary: - c.emitParse_j(offset) - default: - return ctype{}, errJSONType(fn) - } - return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil -} - -func (c *compiler) compileToJSON(doct ctype, offset int) (ctype, error) { - switch doct.Type { - case sqltypes.TypeJSON: - return doct, nil - case sqltypes.Float64: - c.emitConvert_fj(offset) - case sqltypes.Int64, sqltypes.Uint64, sqltypes.Decimal: - c.emitConvert_nj(offset) - case sqltypes.VarChar: - c.emitConvert_cj(offset) - case sqltypes.VarBinary: - c.emitConvert_bj(offset) - case sqltypes.Null: - c.emitConvert_Nj(offset) - default: - return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", doct.Type) - } - return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil -} - -func (c *compiler) compileBitwise(expr *BitwiseExpr) (ctype, error) { - switch expr.Op.(type) { - case *opBitAnd: - return c.compileBitwiseOp(expr.Left, expr.Right, and) - case *opBitOr: - return c.compileBitwiseOp(expr.Left, expr.Right, or) - case *opBitXor: - return 
c.compileBitwiseOp(expr.Left, expr.Right, xor) - case *opBitShl: - return c.compileBitwiseShift(expr.Left, expr.Right, -1) - case *opBitShr: - return c.compileBitwiseShift(expr.Left, expr.Right, 1) - default: - panic("unexpected arithmetic operator") - } -} - -type bitwiseOp int - -const ( - and bitwiseOp = iota - or - xor -) - -func (c *compiler) compileBitwiseOp(left Expr, right Expr, op bitwiseOp) (ctype, error) { - lt, err := c.compileExpr(left) - if err != nil { - return ctype{}, err - } - - rt, err := c.compileExpr(right) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck2(skip) - - if lt.Type == sqltypes.VarBinary && rt.Type == sqltypes.VarBinary { - if !lt.isHexOrBitLiteral() || !rt.isHexOrBitLiteral() { - c.emitBitOp_bb(op) - c.jumpDestination(skip) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil - } - } - - lt = c.compileToBitwiseUint64(lt, 2) - rt = c.compileToBitwiseUint64(rt, 1) - - c.emitBitOp_uu(op) - c.jumpDestination(skip) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil -} - -func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, error) { - lt, err := c.compileExpr(left) - if err != nil { - return ctype{}, err - } - - rt, err := c.compileExpr(right) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck2(skip) - - if lt.Type == sqltypes.VarBinary && !lt.isHexOrBitLiteral() { - _ = c.compileToUint64(rt, 1) - if i < 0 { - c.emitBitShiftLeft_bu() - } else { - c.emitBitShiftRight_bu() - } - c.jumpDestination(skip) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil - } - - _ = c.compileToBitwiseUint64(lt, 2) - _ = c.compileToUint64(rt, 1) - - if i < 0 { - c.emitBitShiftLeft_uu() - } else { - c.emitBitShiftRight_uu() - } - - c.jumpDestination(skip) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil -} - -func (c *compiler) compileBitCount(expr *builtinBitCount) (ctype, error) { - ct, err := 
c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { - c.emitBitCount_b() - c.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil - } - - _ = c.compileToBitwiseUint64(ct, 1) - c.emitBitCount_u() - c.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil -} - -func (c *compiler) compileBitwiseNot(expr *BitwiseNotExpr) (ctype, error) { - ct, err := c.compileExpr(expr.Inner) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { - c.emitBitwiseNot_b() - c.jumpDestination(skip) - return ct, nil - } - - ct = c.compileToBitwiseUint64(ct, 1) - c.emitBitwiseNot_u() - c.jumpDestination(skip) - return ct, nil -} - -func (c *compiler) compileConvert(conv *ConvertExpr) (ctype, error) { - arg, err := c.compileExpr(conv.Inner) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - var convt ctype - - switch conv.Type { - case "BINARY": - convt = ctype{Type: conv.convertToBinaryType(arg.Type), Col: collationBinary} - c.emitConvert_xb(1, convt.Type, conv.Length, conv.HasLength) - - case "CHAR", "NCHAR": - convt = ctype{ - Type: conv.convertToCharType(arg.Type), - Col: collations.TypedCollation{Collation: conv.Collation}, - } - c.emitConvert_xc(1, convt.Type, convt.Col.Collation, conv.Length, conv.HasLength) - - case "DECIMAL": - convt = ctype{Type: sqltypes.Decimal, Col: collationNumeric} - m, d := conv.decimalPrecision() - c.emitConvert_xd(1, m, d) - - case "DOUBLE", "REAL": - convt = c.compileToFloat(arg, 1) - - case "SIGNED", "SIGNED INTEGER": - convt = c.compileToInt64(arg, 1) - - case "UNSIGNED", "UNSIGNED INTEGER": - convt = c.compileToUint64(arg, 1) - - case "JSON": - // TODO: what does NULL map to? 
- convt, err = c.compileToJSON(arg, 1) - if err != nil { - return ctype{}, err - } - - default: - return ctype{}, c.unsupported(conv) - } - - c.jumpDestination(skip) - return convt, nil - -} - -func (c *compiler) compileConvertUsing(conv *ConvertUsingExpr) (ctype, error) { - _, err := c.compileExpr(conv.Inner) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - c.emitConvert_xc(1, sqltypes.VarChar, conv.Collation, 0, false) - c.jumpDestination(skip) - - col := collations.TypedCollation{ - Collation: conv.Collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireASCII, - } - return ctype{Type: sqltypes.VarChar, Col: col}, nil -} - -func (c *compiler) compileCase(cs *CaseExpr) (ctype, error) { - var ca collationAggregation - var ta typeAggregation - var local = collations.Local() - - for _, wt := range cs.cases { - when, err := c.compileExpr(wt.when) - if err != nil { - return ctype{}, err - } - - if err := c.compileCheckTrue(when, 1); err != nil { - return ctype{}, err - } - - then, err := c.compileExpr(wt.then) - if err != nil { - return ctype{}, err - } - - ta.add(then.Type, then.Flag) - if err := ca.add(local, then.Col); err != nil { - return ctype{}, err - } - } - - if cs.Else != nil { - els, err := c.compileExpr(cs.Else) - if err != nil { - return ctype{}, err - } - - ta.add(els.Type, els.Flag) - if err := ca.add(local, els.Col); err != nil { - return ctype{}, err - } - } - - ct := ctype{Type: ta.result(), Col: ca.result()} - c.emitCmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col) - return ct, nil -} - -func (c *compiler) compileCheckTrue(when ctype, offset int) error { - switch when.Type { - case sqltypes.Int64: - c.emitConvert_iB(offset) - case sqltypes.Uint64: - c.emitConvert_uB(offset) - case sqltypes.Float64: - c.emitConvert_fB(offset) - case sqltypes.Decimal: - c.emitConvert_dB(offset) - case sqltypes.VarChar, sqltypes.VarBinary: - c.emitConvert_bB(offset) - case 
sqltypes.Null: - c.emitSetBool(offset, false) - default: - return vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported Truth check: %s", when.Type) - } - return nil -} - -func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) { - arg, err := c.compileExpr(expr.Inner) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - arg = c.compileToNumeric(arg, 1) - var neg sqltypes.Type - - switch arg.Type { - case sqltypes.Int64: - neg = sqltypes.Int64 - c.emitNeg_i() - case sqltypes.Uint64: - if arg.Flag&flagHex != 0 { - neg = sqltypes.Float64 - c.emitNeg_hex() - } else { - neg = sqltypes.Int64 - c.emitNeg_u() - } - case sqltypes.Float64: - neg = sqltypes.Float64 - c.emitNeg_f() - case sqltypes.Decimal: - neg = sqltypes.Decimal - c.emitNeg_d() - default: - panic("unexpected Numeric type") - } - - c.jumpDestination(skip) - return ctype{Type: neg, Col: collationNumeric}, nil -} - -func (c *compiler) compileMultiComparison(call *builtinMultiComparison) (ctype, error) { - var ( - integersI int - integersU int - floats int - decimals int - text int - binary int - args []ctype - ) - - /* - If any argument is NULL, the result is NULL. No comparison is needed. - If all arguments are integer-valued, they are compared as integers. - If at least one argument is double precision, they are compared as double-precision values. Otherwise, if at least one argument is a DECIMAL value, they are compared as DECIMAL values. - If the arguments comprise a mix of numbers and strings, they are compared as strings. - If any argument is a nonbinary (character) string, the arguments are compared as nonbinary strings. - In all other cases, the arguments are compared as binary strings. 
- */ - - for _, expr := range call.Arguments { - tt, err := c.compileExpr(expr) - if err != nil { - return ctype{}, err - } - - args = append(args, tt) - - switch tt.Type { - case sqltypes.Int64: - integersI++ - case sqltypes.Uint64: - integersU++ - case sqltypes.Float64: - floats++ - case sqltypes.Decimal: - decimals++ - case sqltypes.Text, sqltypes.VarChar: - text++ - case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: - binary++ - default: - return ctype{}, c.unsupported(call) - } - } - - if integersI+integersU == len(args) { - if integersI == len(args) { - c.emitFn_MULTICMP_i(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil - } - if integersU == len(args) { - c.emitFn_MULTICMP_u(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil - } - return c.compileMultiComparison_d(args, call.cmp < 0) - } - if binary > 0 || text > 0 { - if text > 0 { - return c.compileMultiComparison_c(args, call.cmp < 0) - } - c.emitFn_MULTICMP_b(len(args), call.cmp < 0) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil - } else { - if floats > 0 { - for i, tt := range args { - c.compileToFloat(tt, len(args)-i) - } - c.emitFn_MULTICMP_f(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil - } - if decimals > 0 { - return c.compileMultiComparison_d(args, call.cmp < 0) - } - } - return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected argument for GREATEST/LEAST") -} - -func (c *compiler) compileMultiComparison_c(args []ctype, lessThan bool) (ctype, error) { - env := collations.Local() - - var ca collationAggregation - for _, arg := range args { - if err := ca.add(env, arg.Col); err != nil { - return ctype{}, err - } - } - - tc := ca.result() - c.emitFn_MULTICMP_c(len(args), lessThan, tc) - return ctype{Type: sqltypes.VarChar, Col: tc}, nil -} - -func (c *compiler) compileMultiComparison_d(args []ctype, lessThan bool) (ctype, error) { - for i, 
tt := range args { - c.compileToDecimal(tt, len(args)-i) - } - c.emitFn_MULTICMP_d(len(args), lessThan) - return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil -} - -func (c *compiler) compileRepeat(expr *builtinRepeat) (ctype, error) { - str, err := c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - repeat, err := c.compileExpr(expr.Arguments[1]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck2(skip) - - switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): - default: - c.emitConvert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) - } - _ = c.compileToInt64(repeat, 1) - - c.emitFn_REPEAT(1) - c.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil -} - -func (c *compiler) compileToBase64(call *builtinToBase64) (ctype, error) { - str, err := c.compileExpr(call.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - t := sqltypes.VarChar - if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { - t = sqltypes.Text - } - - switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): - default: - c.emitConvert_xc(1, t, c.defaultCollation, 0, false) - } - - col := collations.TypedCollation{ - Collation: c.defaultCollation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireASCII, - } - - c.emitFn_TO_BASE64(t, col) - c.jumpDestination(skip) - - return ctype{Type: t, Col: col}, nil -} - -func (c *compiler) compileFromBase64(call *builtinFromBase64) (ctype, error) { - str, err := c.compileExpr(call.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): - default: - c.emitConvert_xc(1, sqltypes.VarBinary, c.defaultCollation, 0, false) - } - - c.emitFromBase64() - c.jumpDestination(skip) - - return ctype{Type: 
sqltypes.VarBinary, Col: collationBinary}, nil -} - -func (c *compiler) compileChangeCase(call *builtinChangeCase) (ctype, error) { - str, err := c.compileExpr(call.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): - default: - c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) - } - - c.emitFn_LUCASE(call.upcase) - c.jumpDestination(skip) - - return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil -} - -type lengthOp int - -const ( - charLen lengthOp = iota - byteLen - bitLen -) - -func (c *compiler) compileLength(call callable, op lengthOp) (ctype, error) { - str, err := c.compileExpr(call.callable()[0]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): - default: - c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) - } - - c.emitFn_LENGTH(op) - c.jumpDestination(skip) - - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil -} - -func (c *compiler) compileASCII(call *builtinASCII) (ctype, error) { - str, err := c.compileExpr(call.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): - default: - c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) - } - - c.emitFn_ASCII() - c.jumpDestination(skip) - - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil -} - -func (c *compiler) compileHex(call *builtinHex) (ctype, error) { - str, err := c.compileExpr(call.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck1(skip) - - col := collations.TypedCollation{ - Collation: c.defaultCollation, - Coercibility: collations.CoerceCoercible, - Repertoire: 
collations.RepertoireASCII, - } - - t := sqltypes.VarChar - if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { - t = sqltypes.Text - } - - switch { - case sqltypes.IsNumber(str.Type), sqltypes.IsDecimal(str.Type): - c.emitFn_HEXd(col) - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): - c.emitFn_HEXc(t, col) - default: - c.emitConvert_xc(1, t, c.defaultCollation, 0, false) - c.emitFn_HEXc(t, col) - } - - c.jumpDestination(skip) - - return ctype{Type: t, Col: col}, nil -} - -func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { - lt, err := c.compileExpr(expr.Left) - if err != nil { - return ctype{}, err - } - - rt, err := c.compileExpr(expr.Right) - if err != nil { - return ctype{}, err - } - - skip := c.jumpFrom() - c.emitNullCheck2(skip) - - if !sqltypes.IsText(lt.Type) && !sqltypes.IsBinary(lt.Type) { - c.emitConvert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) - lt.Col = collations.TypedCollation{ - Collation: c.defaultCollation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireASCII, - } - } - - if !sqltypes.IsText(rt.Type) && !sqltypes.IsBinary(rt.Type) { - c.emitConvert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) - rt.Col = collations.TypedCollation{ - Collation: c.defaultCollation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireASCII, - } - } - - var merged collations.TypedCollation - var coerceLeft collations.Coercion - var coerceRight collations.Coercion - var env = collations.Local() - - if lt.Col.Collation != rt.Col.Collation { - merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ - ConvertToSuperset: true, - ConvertWithCoercion: true, - }) - } else { - merged = lt.Col - } - if err != nil { - return ctype{}, err - } - - if coerceLeft == nil && coerceRight == nil { - c.emitLike_collate(expr, merged.Collation.Get()) - } else { - if coerceLeft == nil { - coerceLeft = func(dst, in []byte) 
([]byte, error) { return in, nil } - } - if coerceRight == nil { - coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } - } - c.emitLike_coerce(expr, &compiledCoercion{ - col: merged.Collation.Get(), - left: coerceLeft, - right: coerceRight, - }) - } - - c.jumpDestination(skip) - - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil -} - -func (c *compiler) compileJSONArray(call *builtinJSONArray) (ctype, error) { - for _, arg := range call.Arguments { - tt, err := c.compileExpr(arg) - if err != nil { - return ctype{}, err - } - - _, err = c.compileToJSON(tt, 1) - if err != nil { - return ctype{}, err - } - } - c.emitFn_JSON_ARRAY(len(call.Arguments)) - return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil -} - -func (c *compiler) compileJSONObject(call *builtinJSONObject) (ctype, error) { - for i := 0; i < len(call.Arguments); i += 2 { - key, err := c.compileExpr(call.Arguments[i]) - if err != nil { - return ctype{}, err - } - c.compileToJSONKey(key) - val, err := c.compileExpr(call.Arguments[i+1]) - if err != nil { - return ctype{}, err - } - val, err = c.compileToJSON(val, 1) - if err != nil { - return ctype{}, err - } - } - c.emitFn_JSON_OBJECT(len(call.Arguments)) - return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil -} - -func isEncodingJSONSafe(col collations.ID) bool { - switch col.Get().Charset().(type) { - case charset.Charset_utf8mb4, charset.Charset_utf8mb3, charset.Charset_binary: - return true - default: - return false - } -} - -func (c *compiler) compileToJSONKey(key ctype) { - if key.Type == sqltypes.VarChar && isEncodingJSONSafe(key.Col.Collation) { - return - } - if key.Type == sqltypes.VarBinary { - return - } - c.emitConvert_xc(1, sqltypes.VarChar, collations.CollationUtf8mb4ID, 0, false) -} - -func (c *compiler) compileTuple(tuple TupleExpr) (ctype, error) { - for _, arg := range tuple { - _, err := c.compileExpr(arg) - if err != nil { - return ctype{}, err - } - } - c.emitPackTuple(len(tuple)) - 
return ctype{Type: sqltypes.Tuple, Col: collationBinary}, nil -} diff --git a/go/vt/vtgate/evalengine/compiler_arithmetic.go b/go/vt/vtgate/evalengine/compiler_arithmetic.go new file mode 100644 index 00000000000..c3852bd23be --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_arithmetic.go @@ -0,0 +1,289 @@ +package evalengine + +import "vitess.io/vitess/go/sqltypes" + +func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) { + arg, err := c.compileExpr(expr.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + arg = c.compileToNumeric(arg, 1) + var neg sqltypes.Type + + switch arg.Type { + case sqltypes.Int64: + neg = sqltypes.Int64 + c.asm.Neg_i() + case sqltypes.Uint64: + if arg.Flag&flagHex != 0 { + neg = sqltypes.Float64 + c.asm.Neg_hex() + } else { + neg = sqltypes.Int64 + c.asm.Neg_u() + } + case sqltypes.Float64: + neg = sqltypes.Float64 + c.asm.Neg_f() + case sqltypes.Decimal: + neg = sqltypes.Decimal + c.asm.Neg_d() + default: + panic("unexpected Numeric type") + } + + c.asm.jumpDestination(skip) + return ctype{Type: neg, Col: collationNumeric}, nil +} + +func (c *compiler) compileArithmetic(expr *ArithmeticExpr) (ctype, error) { + switch expr.Op.(type) { + case *opArithAdd: + return c.compileArithmeticAdd(expr.Left, expr.Right) + case *opArithSub: + return c.compileArithmeticSub(expr.Left, expr.Right) + case *opArithMul: + return c.compileArithmeticMul(expr.Left, expr.Right) + case *opArithDiv: + return c.compileArithmeticDiv(expr.Left, expr.Right) + default: + panic("unexpected arithmetic operator") + } +} + +func (c *compiler) compileNumericPriority(lt, rt ctype) (ctype, ctype, bool) { + switch lt.Type { + case sqltypes.Int64: + if rt.Type == sqltypes.Uint64 || rt.Type == sqltypes.Float64 || rt.Type == sqltypes.Decimal { + return rt, lt, true + } + case sqltypes.Uint64: + if rt.Type == sqltypes.Float64 || rt.Type == sqltypes.Decimal { + return rt, lt, true + } + case sqltypes.Decimal: 
+ if rt.Type == sqltypes.Float64 { + return rt, lt, true + } + } + return lt, rt, false +} + +func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + swap := false + skip := c.asm.jumpFrom() + c.asm.NullCheck2(skip) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + lt, rt, swap = c.compileNumericPriority(lt, rt) + + var sumtype sqltypes.Type + + switch lt.Type { + case sqltypes.Int64: + c.asm.Add_ii() + sumtype = sqltypes.Int64 + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.asm.Add_ui(swap) + case sqltypes.Uint64: + c.asm.Add_uu() + } + sumtype = sqltypes.Uint64 + case sqltypes.Decimal: + if swap { + c.compileToDecimal(rt, 2) + } else { + c.compileToDecimal(rt, 1) + } + c.asm.Add_dd() + sumtype = sqltypes.Decimal + case sqltypes.Float64: + if swap { + c.compileToFloat(rt, 2) + } else { + c.compileToFloat(rt, 1) + } + c.asm.Add_ff() + sumtype = sqltypes.Float64 + } + + c.asm.jumpDestination(skip) + return ctype{Type: sumtype, Col: collationNumeric}, nil +} + +func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck2(skip) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + var subtype sqltypes.Type + + switch lt.Type { + case sqltypes.Int64: + switch rt.Type { + case sqltypes.Int64: + c.asm.Sub_ii() + subtype = sqltypes.Int64 + case sqltypes.Uint64: + c.asm.Sub_iu() + subtype = sqltypes.Uint64 + case sqltypes.Float64: + c.compileToFloat(lt, 2) + c.asm.Sub_ff() + subtype = sqltypes.Float64 + case sqltypes.Decimal: + c.compileToDecimal(lt, 2) + c.asm.Sub_dd() + subtype = sqltypes.Decimal + } + case sqltypes.Uint64: 
+ switch rt.Type { + case sqltypes.Int64: + c.asm.Sub_ui() + subtype = sqltypes.Uint64 + case sqltypes.Uint64: + c.asm.Sub_uu() + subtype = sqltypes.Uint64 + case sqltypes.Float64: + c.compileToFloat(lt, 2) + c.asm.Sub_ff() + subtype = sqltypes.Float64 + case sqltypes.Decimal: + c.compileToDecimal(lt, 2) + c.asm.Sub_dd() + subtype = sqltypes.Decimal + } + case sqltypes.Float64: + c.compileToFloat(rt, 1) + c.asm.Sub_ff() + subtype = sqltypes.Float64 + case sqltypes.Decimal: + switch rt.Type { + case sqltypes.Float64: + c.compileToFloat(lt, 2) + c.asm.Sub_ff() + subtype = sqltypes.Float64 + default: + c.compileToDecimal(rt, 1) + c.asm.Sub_dd() + subtype = sqltypes.Decimal + } + } + + if subtype == 0 { + panic("did not compile?") + } + + c.asm.jumpDestination(skip) + return ctype{Type: subtype, Col: collationNumeric}, nil +} + +func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + swap := false + skip := c.asm.jumpFrom() + c.asm.NullCheck2(skip) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + lt, rt, swap = c.compileNumericPriority(lt, rt) + + var multype sqltypes.Type + + switch lt.Type { + case sqltypes.Int64: + c.asm.Mul_ii() + multype = sqltypes.Int64 + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.asm.Mul_ui(swap) + case sqltypes.Uint64: + c.asm.Mul_uu() + } + multype = sqltypes.Uint64 + case sqltypes.Float64: + if swap { + c.compileToFloat(rt, 2) + } else { + c.compileToFloat(rt, 1) + } + c.asm.Mul_ff() + multype = sqltypes.Float64 + case sqltypes.Decimal: + if swap { + c.compileToDecimal(rt, 2) + } else { + c.compileToDecimal(rt, 1) + } + c.asm.Mul_dd() + multype = sqltypes.Decimal + } + + c.asm.jumpDestination(skip) + return ctype{Type: multype, Col: collationNumeric}, nil +} + +func (c *compiler) compileArithmeticDiv(left, right Expr) 
(ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck2(skip) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + if lt.Type == sqltypes.Float64 || rt.Type == sqltypes.Float64 { + c.compileToFloat(lt, 2) + c.compileToFloat(rt, 1) + c.asm.Div_ff() + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + } else { + c.compileToDecimal(lt, 2) + c.compileToDecimal(rt, 1) + c.asm.Div_dd() + return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil + } +} diff --git a/go/vt/vtgate/evalengine/compiler_instructions.go b/go/vt/vtgate/evalengine/compiler_asm.go similarity index 68% rename from go/vt/vtgate/evalengine/compiler_instructions.go rename to go/vt/vtgate/evalengine/compiler_asm.go index 84b096c12bf..a9427b8193a 100644 --- a/go/vt/vtgate/evalengine/compiler_instructions.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -15,283 +15,79 @@ import ( "vitess.io/vitess/go/vt/vtgate/evalengine/internal/json" ) -func cmpnum[N interface{ int64 | uint64 | float64 }](a, b N) int { - switch { - case a == b: - return 0 - case a < b: - return -1 - default: - return 1 - } -} - -func (c *compiler) emit(f frame, asm string, args ...any) { - if c.log != nil { - c.log.Disasm(asm, args...) 
- } - c.ins = append(c.ins, f) +type jump struct { + from, to int } -func (c *compiler) emitPushLiteral(lit eval) error { - c.adjustStack(1) - - if lit == evalBoolTrue { - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = evalBoolTrue - vm.sp++ - return 1 - }, "PUSH true") - return nil - } - - if lit == evalBoolFalse { - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = evalBoolFalse - vm.sp++ - return 1 - }, "PUSH false") - return nil - } - - switch lit := lit.(type) { - case *evalInt64: - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = vm.arena.newEvalInt64(lit.i) - vm.sp++ - return 1 - }, "PUSH INT64(%s)", lit.ToRawBytes()) - case *evalUint64: - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = vm.arena.newEvalUint64(lit.u) - vm.sp++ - return 1 - }, "PUSH UINT64(%s)", lit.ToRawBytes()) - case *evalFloat: - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = vm.arena.newEvalFloat(lit.f) - vm.sp++ - return 1 - }, "PUSH FLOAT64(%s)", lit.ToRawBytes()) - case *evalDecimal: - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = vm.arena.newEvalDecimalWithPrec(lit.dec, lit.length) - vm.sp++ - return 1 - }, "PUSH DECIMAL(%s)", lit.ToRawBytes()) - case *evalBytes: - c.emit(func(vm *VirtualMachine) int { - b := vm.arena.newEvalBytesEmpty() - *b = *lit - vm.stack[vm.sp] = b - vm.sp++ - return 1 - }, "PUSH VARCHAR(%q)", lit.ToRawBytes()) - default: - return vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported literal kind '%T'", lit) - } - - return nil +func (j *jump) offset() int { + return j.to - j.from } -func (c *compiler) emitSetBool(offset int, b bool) { - if b { - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset] = evalBoolTrue - return 1 - }, "SET (SP-%d), BOOL(true)", offset) - } else { - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset] = evalBoolFalse - return 1 - }, "SET (SP-%d), BOOL(false)", offset) +type assembler struct { + ins []frame + log AssemblerLog + stack struct { + cur int + 
max int } } -func (c *compiler) emitPushNull() { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = nil - vm.sp++ - return 1 - }, "PUSH NULL") -} - -func (c *compiler) emitPushColumn_i(offset int) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - var ival int64 - ival, vm.err = vm.row[offset].ToInt64() - vm.stack[vm.sp] = vm.arena.newEvalInt64(ival) - vm.sp++ - return 1 - }, "PUSH INT64(:%d)", offset) -} - -func (c *compiler) emitPushColumn_u(offset int) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - var uval uint64 - uval, vm.err = vm.row[offset].ToUint64() - vm.stack[vm.sp] = vm.arena.newEvalUint64(uval) - vm.sp++ - return 1 - }, "PUSH UINT64(:%d)", offset) -} - -func (c *compiler) emitPushColumn_f(offset int) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - var fval float64 - fval, vm.err = vm.row[offset].ToFloat64() - vm.stack[vm.sp] = vm.arena.newEvalFloat(fval) - vm.sp++ - return 1 - }, "PUSH FLOAT64(:%d)", offset) -} - -func (c *compiler) emitPushColumn_d(offset int) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - var dec decimal.Decimal - dec, vm.err = decimal.NewFromMySQL(vm.row[offset].Raw()) - vm.stack[vm.sp] = vm.arena.newEvalDecimal(dec, 0, 0) - vm.sp++ - return 1 - }, "PUSH DECIMAL(:%d)", offset) -} - -func (c *compiler) emitPushColumn_hexnum(offset int) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - var raw []byte - raw, vm.err = parseHexNumber(vm.row[offset].Raw()) - vm.stack[vm.sp] = newEvalBytesHex(raw) - vm.sp++ - return 1 - }, "PUSH HEXNUM(:%d)", offset) -} - -func (c *compiler) emitPushColumn_hexval(offset int) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - hex := vm.row[offset].Raw() - var raw []byte - raw, vm.err = parseHexLiteral(hex[2 : len(hex)-1]) - vm.stack[vm.sp] = newEvalBytesHex(raw) - vm.sp++ - return 1 - }, "PUSH HEXVAL(:%d)", offset) -} - -func (c *compiler) emitPushColumn_text(offset int, col 
collations.TypedCollation) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalText(vm.row[offset].Raw(), col) - vm.sp++ - return 1 - }, "PUSH VARCHAR(:%d) COLLATE %d", offset, col.Collation) -} - -func (c *compiler) emitPushColumn_bin(offset int) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBinary(vm.row[offset].Raw()) - vm.sp++ - return 1 - }, "PUSH VARBINARY(:%d)", offset) -} - -func (c *compiler) emitPushColumn_json(offset int) { - c.adjustStack(1) - - c.emit(func(vm *VirtualMachine) int { - var parser json.Parser - vm.stack[vm.sp], vm.err = parser.ParseBytes(vm.row[offset].Raw()) - vm.sp++ - return 1 - }, "PUSH JSON(:%d)", offset) -} - -func (c *compiler) emitConvert_hex(offset int) { - c.emit(func(vm *VirtualMachine) int { - var ok bool - vm.stack[vm.sp-offset], ok = vm.stack[vm.sp-offset].(*evalBytes).toNumericHex() - if !ok { - vm.err = errDeoptimize - } - return 1 - }, "CONV VARBINARY(SP-%d), HEX", offset) +func (asm *assembler) jumpFrom() *jump { + return &jump{from: len(asm.ins)} } -func (c *compiler) emitConvert_if(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalInt64) - vm.stack[vm.sp-offset] = vm.arena.newEvalFloat(arg.toFloat0()) - return 1 - }, "CONV INT64(SP-%d), FLOAT64", offset) +func (asm *assembler) jumpDestination(j *jump) { + j.to = len(asm.ins) } -func (c *compiler) emitConvert_uf(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalUint64) - vm.stack[vm.sp-offset] = vm.arena.newEvalFloat(arg.toFloat0()) - return 1 - }, "CONV UINT64(SP-%d), FLOAT64)", offset) +func (asm *assembler) adjustStack(offset int) { + asm.stack.cur += offset + if asm.stack.cur < 0 { + panic("negative stack position") + } + if asm.stack.cur > asm.stack.max { + asm.stack.max = asm.stack.cur + } + if asm.log != nil { + asm.log.Stack(asm.stack.cur-offset, asm.stack.cur) + } } -func (c *compiler) 
emitConvert_xf(offset int) { - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset], _ = evalToNumeric(vm.stack[vm.sp-offset]).toFloat() - return 1 - }, "CONV (SP-%d), FLOAT64", offset) +func (asm *assembler) emit(f frame, instruction string, args ...any) { + if asm.log != nil { + asm.log.Instruction(instruction, args...) + } + asm.ins = append(asm.ins, f) } -func (c *compiler) emitConvert_id(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalInt64) - vm.stack[vm.sp-offset] = vm.arena.newEvalDecimalWithPrec(decimal.NewFromInt(arg.i), 0) - return 1 - }, "CONV INT64(SP-%d), FLOAT64", offset) -} +func (asm *assembler) Add_dd() { + asm.adjustStack(-1) -func (c *compiler) emitConvert_ud(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalUint64) - vm.stack[vm.sp-offset] = vm.arena.newEvalDecimalWithPrec(decimal.NewFromUint(arg.u), 0) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + mathAdd_dd0(l, r) + vm.sp-- return 1 - }, "CONV UINT64(SP-%d), FLOAT64)", offset) + }, "ADD DECIMAL(SP-2), DECIMAL(SP-1)") } -func (c *compiler) emitConvert_xd(offset int, m, d int32) { - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset] = evalToNumeric(vm.stack[vm.sp-offset]).toDecimal(m, d) - return 1 - }, "CONV (SP-%d), FLOAT64", offset) -} +func (asm *assembler) Add_ff() { + asm.adjustStack(-1) -func (c *compiler) emitParse_j(offset int) { - c.emit(func(vm *VirtualMachine) int { - var p json.Parser - arg := vm.stack[vm.sp-offset].(*evalBytes) - vm.stack[vm.sp-offset], vm.err = p.ParseBytes(arg.bytes) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + l.f += r.f + vm.sp-- return 1 - }, "PARSE_JSON VARCHAR(SP-%d)", offset) + }, "ADD FLOAT64(SP-2), FLOAT64(SP-1)") } -func (c *compiler) emitAdd_ii() { - c.adjustStack(-1) +func (asm *assembler) 
Add_ii() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-2].(*evalInt64) r := vm.stack[vm.sp-1].(*evalInt64) l.i, vm.err = mathAdd_ii0(l.i, r.i) @@ -300,11 +96,11 @@ func (c *compiler) emitAdd_ii() { }, "ADD INT64(SP-2), INT64(SP-1)") } -func (c *compiler) emitAdd_ui(swap bool) { - c.adjustStack(-1) +func (asm *assembler) Add_ui(swap bool) { + asm.adjustStack(-1) if swap { - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { var u uint64 l := vm.stack[vm.sp-1].(*evalUint64) r := vm.stack[vm.sp-2].(*evalInt64) @@ -314,7 +110,7 @@ func (c *compiler) emitAdd_ui(swap bool) { return 1 }, "ADD UINT64(SP-1), INT64(SP-2)") } else { - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-2].(*evalUint64) r := vm.stack[vm.sp-1].(*evalInt64) l.u, vm.err = mathAdd_ui0(l.u, r.i) @@ -324,10 +120,10 @@ func (c *compiler) emitAdd_ui(swap bool) { } } -func (c *compiler) emitAdd_uu() { - c.adjustStack(-1) +func (asm *assembler) Add_uu() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-2].(*evalUint64) r := vm.stack[vm.sp-1].(*evalUint64) l.u, vm.err = mathAdd_uu0(l.u, r.u) @@ -336,327 +132,311 @@ func (c *compiler) emitAdd_uu() { }, "ADD UINT64(SP-2), UINT64(SP-1)") } -func (c *compiler) emitAdd_ff() { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalFloat) - r := vm.stack[vm.sp-1].(*evalFloat) - l.f += r.f - vm.sp-- - return 1 - }, "ADD FLOAT64(SP-2), FLOAT64(SP-1)") -} - -func (c *compiler) emitAdd_dd() { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalDecimal) - r := vm.stack[vm.sp-1].(*evalDecimal) - mathAdd_dd0(l, r) - vm.sp-- - return 1 - }, "ADD DECIMAL(SP-2), DECIMAL(SP-1)") -} - -func (c *compiler) emitSub_ii() { - c.adjustStack(-1) - - c.emit(func(vm 
*VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalInt64) - r := vm.stack[vm.sp-1].(*evalInt64) - l.i, vm.err = mathSub_ii0(l.i, r.i) - vm.sp-- - return 1 - }, "SUB INT64(SP-2), INT64(SP-1)") -} - -func (c *compiler) emitSub_iu() { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalInt64) - r := vm.stack[vm.sp-1].(*evalUint64) - r.u, vm.err = mathSub_iu0(l.i, r.u) - vm.stack[vm.sp-2] = r - vm.sp-- +func (asm *assembler) BitCount_b() { + asm.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalBytes) + count := 0 + for _, b := range a.bytes { + count += bits.OnesCount8(b) + } + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) return 1 - }, "SUB INT64(SP-2), UINT64(SP-1)") + }, "BIT_COUNT BINARY(SP-1)") } -func (c *compiler) emitSub_ff() { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalFloat) - r := vm.stack[vm.sp-1].(*evalFloat) - l.f -= r.f - vm.sp-- +func (asm *assembler) BitCount_u() { + asm.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalUint64) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(bits.OnesCount64(a.u))) return 1 - }, "SUB FLOAT64(SP-2), FLOAT64(SP-1)") + }, "BIT_COUNT UINT64(SP-1)") } -func (c *compiler) emitSub_dd() { - c.adjustStack(-1) +func (asm *assembler) BitOp_bb(op bitwiseOp) { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalDecimal) - r := vm.stack[vm.sp-1].(*evalDecimal) - mathSub_dd0(l, r) - vm.sp-- - return 1 - }, "SUB DECIMAL(SP-2), DECIMAL(SP-1)") -} - -func (c *compiler) emitSub_ui() { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalInt64) - l.u, vm.err = mathSub_ui0(l.u, r.i) - vm.sp-- - return 1 - }, "SUB UINT64(SP-2), INT64(SP-1)") -} - -func (c *compiler) emitSub_uu() { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r 
:= vm.stack[vm.sp-1].(*evalUint64) - l.u, vm.err = mathSub_uu0(l.u, r.u) - vm.sp-- - return 1 - }, "SUB UINT64(SP-2), UINT64(SP-1)") -} - -func (c *compiler) emitMul_ii() { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalInt64) - r := vm.stack[vm.sp-1].(*evalInt64) - l.i, vm.err = mathMul_ii0(l.i, r.i) - vm.sp-- - return 1 - }, "MUL INT64(SP-2), INT64(SP-1)") + switch op { + case and: + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] & r.bytes[i] + } + vm.sp-- + return 1 + }, "AND BINARY(SP-2), BINARY(SP-1)") + case or: + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] | r.bytes[i] + } + vm.sp-- + return 1 + }, "OR BINARY(SP-2), BINARY(SP-1)") + case xor: + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] ^ r.bytes[i] + } + vm.sp-- + return 1 + }, "XOR BINARY(SP-2), BINARY(SP-1)") + } } -func (c *compiler) emitMul_ui(swap bool) { - c.adjustStack(-1) +func (asm *assembler) BitOp_uu(op bitwiseOp) { + asm.adjustStack(-1) - if swap { - c.emit(func(vm *VirtualMachine) int { - var u uint64 - l := vm.stack[vm.sp-1].(*evalUint64) - r := vm.stack[vm.sp-2].(*evalInt64) - u, vm.err = mathMul_ui0(l.u, r.i) - vm.stack[vm.sp-2] = vm.arena.newEvalUint64(u) + switch op { + case and: + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u & r.u vm.sp-- 
return 1 - }, "MUL UINT64(SP-1), INT64(SP-2)") - } else { - c.emit(func(vm *VirtualMachine) int { + }, "AND UINT64(SP-2), UINT64(SP-1)") + case or: + asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalInt64) - l.u, vm.err = mathMul_ui0(l.u, r.i) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u | r.u vm.sp-- return 1 - }, "MUL UINT64(SP-2), INT64(SP-1)") + }, "OR UINT64(SP-2), UINT64(SP-1)") + case xor: + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u ^ r.u + vm.sp-- + return 1 + }, "XOR UINT64(SP-2), UINT64(SP-1)") } } -func (c *compiler) emitMul_uu() { - c.adjustStack(-1) +func (asm *assembler) BitShiftLeft_bu() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) r := vm.stack[vm.sp-1].(*evalUint64) - l.u, vm.err = mathMul_uu0(l.u, r.u) - vm.sp-- - return 1 - }, "MUL UINT64(SP-2), UINT64(SP-1)") -} -func (c *compiler) emitMul_ff() { - c.adjustStack(-1) + var ( + bits = int(r.u & 7) + bytes = int(r.u >> 3) + length = len(l.bytes) + out = make([]byte, length) + ) + + for i := 0; i < length; i++ { + pos := i + bytes + 1 + switch { + case pos < length: + out[i] = l.bytes[pos] >> (8 - bits) + fallthrough + case pos == length: + out[i] |= l.bytes[pos-1] << bits + } + } + l.bytes = out - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalFloat) - r := vm.stack[vm.sp-1].(*evalFloat) - l.f *= r.f vm.sp-- return 1 - }, "MUL FLOAT64(SP-2), FLOAT64(SP-1)") + }, "BIT_SHL BINARY(SP-2), UINT64(SP-1)") } -func (c *compiler) emitMul_dd() { - c.adjustStack(-1) +func (asm *assembler) BitShiftLeft_uu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u << r.u - c.emit(func(vm *VirtualMachine) 
int { - l := vm.stack[vm.sp-2].(*evalDecimal) - r := vm.stack[vm.sp-1].(*evalDecimal) - mathMul_dd0(l, r) vm.sp-- return 1 - }, "MUL DECIMAL(SP-2), DECIMAL(SP-1)") + }, "BIT_SHL UINT64(SP-2), UINT64(SP-1)") } -func (c *compiler) emitDiv_ff() { - c.adjustStack(-1) +func (asm *assembler) BitShiftRight_bu() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalFloat) - r := vm.stack[vm.sp-1].(*evalFloat) - if r.f == 0.0 { - vm.stack[vm.sp-2] = nil - } else { - l.f, vm.err = mathDiv_ff0(l.f, r.f) - } - vm.sp-- - return 1 - }, "DIV FLOAT64(SP-2), FLOAT64(SP-1)") -} + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalUint64) -func (c *compiler) emitDiv_dd() { - c.adjustStack(-1) + var ( + bits = int(r.u & 7) + bytes = int(r.u >> 3) + length = len(l.bytes) + out = make([]byte, length) + ) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalDecimal) - r := vm.stack[vm.sp-1].(*evalDecimal) - if r.dec.IsZero() { - vm.stack[vm.sp-2] = nil - } else { - mathDiv_dd0(l, r, divPrecisionIncrement) + for i := length - 1; i >= 0; i-- { + switch { + case i > bytes: + out[i] = l.bytes[i-bytes-1] << (8 - bits) + fallthrough + case i == bytes: + out[i] |= l.bytes[i-bytes] >> bits + } } + l.bytes = out + vm.sp-- return 1 - }, "DIV DECIMAL(SP-2), DECIMAL(SP-1)") + }, "BIT_SHR BINARY(SP-2), UINT64(SP-1)") } -func (c *compiler) emitNullCmp(j *jump) { - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2] - r := vm.stack[vm.sp-1] - if l == nil || r == nil { - if l == r { - vm.flags.cmp = 0 - } else { - vm.flags.cmp = 1 - } - vm.sp -= 2 - return j.offset() - } +func (asm *assembler) BitShiftRight_uu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u >> r.u + + vm.sp-- return 1 - }, "NULLCMP SP-1, SP-2") + }, "BIT_SHR UINT64(SP-2), UINT64(SP-1)") } -func (c 
*compiler) emitNullCheck1(j *jump) { - c.emit(func(vm *VirtualMachine) int { - if vm.stack[vm.sp-1] == nil { - return j.offset() +func (asm *assembler) BitwiseNot_b() { + asm.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalBytes) + for i := range a.bytes { + a.bytes[i] = ^a.bytes[i] } return 1 - }, "NULLCHECK SP-1") + }, "BIT_NOT BINARY(SP-1)") } -func (c *compiler) emitNullCheck2(j *jump) { - c.emit(func(vm *VirtualMachine) int { - if vm.stack[vm.sp-2] == nil || vm.stack[vm.sp-1] == nil { - vm.stack[vm.sp-2] = nil - vm.sp-- - return j.offset() - } +func (asm *assembler) BitwiseNot_u() { + asm.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalUint64) + a.u = ^a.u return 1 - }, "NULLCHECK SP-1, SP-2") + }, "BIT_NOT UINT64(SP-1)") } -func (c *compiler) emitCmp_eq() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { +func (asm *assembler) Cmp_eq() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) vm.sp++ return 1 }, "CMPFLAG EQ") } -func (c *compiler) emitCmp_ne() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) +func (asm *assembler) Cmp_eq_n() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { + if vm.flags.null { + vm.stack[vm.sp] = nil + } else { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) + } vm.sp++ return 1 - }, "CMPFLAG NE") + }, "CMPFLAG EQ [NULL]") } -func (c *compiler) emitCmp_lt() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) +func (asm *assembler) Cmp_ge() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) vm.sp++ return 1 - }, "CMPFLAG LT") + }, "CMPFLAG GE") } -func (c *compiler) emitCmp_le() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) +func (asm *assembler) Cmp_ge_n() 
{ + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { + if vm.flags.null { + vm.stack[vm.sp] = nil + } else { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) + } vm.sp++ return 1 - }, "CMPFLAG LE") + }, "CMPFLAG GE [NULL]") } -func (c *compiler) emitCmp_gt() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { +func (asm *assembler) Cmp_gt() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) vm.sp++ return 1 }, "CMPFLAG GT") } -func (c *compiler) emitCmp_ge() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) - vm.sp++ - return 1 - }, "CMPFLAG GE") -} - -func (c *compiler) emitCmp_eq_n() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { +func (asm *assembler) Cmp_gt_n() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) } vm.sp++ return 1 - }, "CMPFLAG EQ [NULL]") + }, "CMPFLAG GT [NULL]") +} + +func (asm *assembler) Cmp_le() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) + vm.sp++ + return 1 + }, "CMPFLAG LE") } -func (c *compiler) emitCmp_ne_n() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { +func (asm *assembler) Cmp_le_n() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) } vm.sp++ return 1 - }, "CMPFLAG NE [NULL]") + }, "CMPFLAG LE [NULL]") +} + +func (asm *assembler) Cmp_lt() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) + vm.sp++ + return 1 + }, "CMPFLAG LT") } -func (c *compiler) emitCmp_lt_n() { - c.adjustStack(1) - 
c.emit(func(vm *VirtualMachine) int { +func (asm *assembler) Cmp_lt_n() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { if vm.flags.null { vm.stack[vm.sp] = nil } else { @@ -666,87 +446,124 @@ func (c *compiler) emitCmp_lt_n() { return 1 }, "CMPFLAG LT [NULL]") } - -func (c *compiler) emitCmp_le_n() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { - if vm.flags.null { - vm.stack[vm.sp] = nil - } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) - } +func (asm *assembler) Cmp_ne() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) vm.sp++ return 1 - }, "CMPFLAG LE [NULL]") + }, "CMPFLAG NE") } -func (c *compiler) emitCmp_gt_n() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { +func (asm *assembler) Cmp_ne_n() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) + vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) } vm.sp++ return 1 - }, "CMPFLAG GT [NULL]") + }, "CMPFLAG NE [NULL]") } -func (c *compiler) emitCmp_ge_n() { - c.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { - if vm.flags.null { - vm.stack[vm.sp] = nil +func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation) { + elseOffset := 0 + if hasElse { + elseOffset = 1 + } + + stackDepth := 2*cases + elseOffset + asm.adjustStack(-(stackDepth - 1)) + + asm.emit(func(vm *VirtualMachine) int { + end := vm.sp - elseOffset + for sp := vm.sp - stackDepth; sp < end; sp += 2 { + if vm.stack[sp].(*evalInt64).i != 0 { + vm.stack[vm.sp-stackDepth], vm.err = evalCoerce(vm.stack[sp+1], tt, cc.Collation) + goto done + } + } + if elseOffset != 0 { + vm.stack[vm.sp-stackDepth], vm.err = evalCoerce(vm.stack[vm.sp-1], tt, cc.Collation) } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) + vm.stack[vm.sp-stackDepth] = nil } - vm.sp++ + done: + vm.sp 
-= stackDepth - 1 return 1 - }, "CMPFLAG GE [NULL]") + }, "CASE [%d cases, else = %v]", cases, hasElse) } -func (c *compiler) emitPackTuple(tlen int) { - c.adjustStack(-(tlen - 1)) - c.emit(func(vm *VirtualMachine) int { - tuple := make([]eval, tlen) - copy(tuple, vm.stack[vm.sp-tlen:]) - vm.stack[vm.sp-tlen] = &evalTuple{tuple} - vm.sp -= tlen - 1 +func (asm *assembler) CmpNum_dd() { + asm.adjustStack(-2) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + vm.sp -= 2 + vm.flags.cmp = l.dec.Cmp(r.dec) return 1 - }, "TUPLE (SP-%d)...(SP-1)", tlen) + }, "CMP DECIMAL(SP-2), DECIMAL(SP-1)") } -func (c *compiler) emitCmpTuple(fullEquality bool) { - c.adjustStack(-2) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalTuple) - r := vm.stack[vm.sp-1].(*evalTuple) +func (asm *assembler) CmpNum_fd(left, right int) { + asm.adjustStack(-2) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalFloat) + r := vm.stack[vm.sp-right].(*evalDecimal) vm.sp -= 2 - vm.flags.cmp, vm.flags.null, vm.err = evalCompareMany(l.t, r.t, fullEquality) + fval, ok := r.dec.Float64() + if !ok { + vm.err = errDecimalOutOfRange + } + vm.flags.cmp = cmpnum(l.f, fval) return 1 - }, "CMP TUPLE(SP-2), TUPLE(SP-1)") + }, "CMP FLOAT64(SP-%d), DECIMAL(SP-%d)", left, right) } -func (c *compiler) emitCmpTupleNullsafe() { - c.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalTuple) - r := vm.stack[vm.sp-1].(*evalTuple) +func (asm *assembler) CmpNum_ff() { + asm.adjustStack(-2) - var equals bool - equals, vm.err = evalCompareTuplesNullSafe(l.t, r.t) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + vm.sp -= 2 + vm.flags.cmp = cmpnum(l.f, r.f) + return 1 + }, "CMP FLOAT64(SP-2), FLOAT64(SP-1)") +} - vm.stack[vm.sp-2] = newEvalBool(equals) - vm.sp -= 1 +func (asm *assembler) CmpNum_id(left, 
right int) { + asm.adjustStack(-2) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalInt64) + r := vm.stack[vm.sp-right].(*evalDecimal) + vm.sp -= 2 + vm.flags.cmp = decimal.NewFromInt(l.i).Cmp(r.dec) return 1 - }, "CMP NULLSAFE TUPLE(SP-2), TUPLE(SP-1)") + }, "CMP INT64(SP-%d), DECIMAL(SP-%d)", left, right) +} + +func (asm *assembler) CmpNum_if(left, right int) { + asm.adjustStack(-2) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalInt64) + r := vm.stack[vm.sp-right].(*evalFloat) + vm.sp -= 2 + vm.flags.cmp = cmpnum(float64(l.i), r.f) + return 1 + }, "CMP INT64(SP-%d), FLOAT64(SP-%d)", left, right) } -func (c *compiler) emitCmpNum_ii() { - c.adjustStack(-2) +func (asm *assembler) CmpNum_ii() { + asm.adjustStack(-2) - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-2].(*evalInt64) r := vm.stack[vm.sp-1].(*evalInt64) vm.sp -= 2 @@ -755,10 +572,10 @@ func (c *compiler) emitCmpNum_ii() { }, "CMP INT64(SP-2), INT64(SP-1)") } -func (c *compiler) emitCmpNum_iu(left, right int) { - c.adjustStack(-2) +func (asm *assembler) CmpNum_iu(left, right int) { + asm.adjustStack(-2) - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-left].(*evalInt64) r := vm.stack[vm.sp-right].(*evalUint64) vm.sp -= 2 @@ -771,34 +588,34 @@ func (c *compiler) emitCmpNum_iu(left, right int) { }, "CMP INT64(SP-%d), UINT64(SP-%d)", left, right) } -func (c *compiler) emitCmpNum_if(left, right int) { - c.adjustStack(-2) +func (asm *assembler) CmpNum_ud(left, right int) { + asm.adjustStack(-2) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-left].(*evalInt64) - r := vm.stack[vm.sp-right].(*evalFloat) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalUint64) + r := vm.stack[vm.sp-right].(*evalDecimal) vm.sp -= 2 - vm.flags.cmp = cmpnum(float64(l.i), r.f) + vm.flags.cmp = decimal.NewFromUint(l.u).Cmp(r.dec) 
return 1 - }, "CMP INT64(SP-%d), FLOAT64(SP-%d)", left, right) + }, "CMP UINT64(SP-%d), DECIMAL(SP-%d)", left, right) } -func (c *compiler) emitCmpNum_id(left, right int) { - c.adjustStack(-2) +func (asm *assembler) CmpNum_uf(left, right int) { + asm.adjustStack(-2) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-left].(*evalInt64) - r := vm.stack[vm.sp-right].(*evalDecimal) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-left].(*evalUint64) + r := vm.stack[vm.sp-right].(*evalFloat) vm.sp -= 2 - vm.flags.cmp = decimal.NewFromInt(l.i).Cmp(r.dec) + vm.flags.cmp = cmpnum(float64(l.u), r.f) return 1 - }, "CMP INT64(SP-%d), DECIMAL(SP-%d)", left, right) + }, "CMP UINT64(SP-%d), FLOAT64(SP-%d)", left, right) } -func (c *compiler) emitCmpNum_uu() { - c.adjustStack(-2) +func (asm *assembler) CmpNum_uu() { + asm.adjustStack(-2) - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-2].(*evalUint64) r := vm.stack[vm.sp-1].(*evalUint64) vm.sp -= 2 @@ -807,360 +624,228 @@ func (c *compiler) emitCmpNum_uu() { }, "CMP UINT64(SP-2), UINT64(SP-1)") } -func (c *compiler) emitCmpNum_uf(left, right int) { - c.adjustStack(-2) +func (asm *assembler) CmpString_coerce(coercion *compiledCoercion) { + asm.adjustStack(-2) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-left].(*evalUint64) - r := vm.stack[vm.sp-right].(*evalFloat) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) vm.sp -= 2 - vm.flags.cmp = cmpnum(float64(l.u), r.f) + + var bl, br []byte + bl, vm.err = coercion.left(nil, l.bytes) + if vm.err != nil { + return 0 + } + br, vm.err = coercion.right(nil, r.bytes) + if vm.err != nil { + return 0 + } + vm.flags.cmp = coercion.col.Collate(bl, br, false) return 1 - }, "CMP UINT64(SP-%d), FLOAT64(SP-%d)", left, right) + }, "CMP VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) } -func (c *compiler) 
emitCmpNum_ud(left, right int) { - c.adjustStack(-2) +func (asm *assembler) CmpString_collate(collation collations.Collation) { + asm.adjustStack(-2) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-left].(*evalUint64) - r := vm.stack[vm.sp-right].(*evalDecimal) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) vm.sp -= 2 - vm.flags.cmp = decimal.NewFromUint(l.u).Cmp(r.dec) + vm.flags.cmp = collation.Collate(l.bytes, r.bytes, false) return 1 - }, "CMP UINT64(SP-%d), DECIMAL(SP-%d)", left, right) + }, "CMP VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) } -func (c *compiler) emitCmpNum_ff() { - c.adjustStack(-2) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalFloat) - r := vm.stack[vm.sp-1].(*evalFloat) +func (asm *assembler) CmpTuple(fullEquality bool) { + asm.adjustStack(-2) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalTuple) + r := vm.stack[vm.sp-1].(*evalTuple) vm.sp -= 2 - vm.flags.cmp = cmpnum(l.f, r.f) + vm.flags.cmp, vm.flags.null, vm.err = evalCompareMany(l.t, r.t, fullEquality) return 1 - }, "CMP FLOAT64(SP-2), FLOAT64(SP-1)") + }, "CMP TUPLE(SP-2), TUPLE(SP-1)") } -func (c *compiler) emitCmpNum_fd(left, right int) { - c.adjustStack(-2) +func (asm *assembler) CmpTupleNullsafe() { + asm.adjustStack(-1) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalTuple) + r := vm.stack[vm.sp-1].(*evalTuple) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-left].(*evalFloat) - r := vm.stack[vm.sp-right].(*evalDecimal) - vm.sp -= 2 - fval, ok := r.dec.Float64() - if !ok { - vm.err = errDecimalOutOfRange - } - vm.flags.cmp = cmpnum(l.f, fval) + var equals bool + equals, vm.err = evalCompareTuplesNullSafe(l.t, r.t) + + vm.stack[vm.sp-2] = newEvalBool(equals) + vm.sp -= 1 return 1 - }, "CMP FLOAT64(SP-%d), DECIMAL(SP-%d)", left, right) + }, "CMP NULLSAFE TUPLE(SP-2), TUPLE(SP-1)") } -func (c 
*compiler) emitCmpNum_dd() { - c.adjustStack(-2) +func (asm *assembler) Collate(col collations.ID) { + asm.emit(func(vm *VirtualMachine) int { + a := vm.stack[vm.sp-1].(*evalBytes) + a.tt = int16(sqltypes.VarChar) + a.col.Collation = col + return 1 + }, "COLLATE VARCHAR(SP-1), %d", col) +} - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalDecimal) - r := vm.stack[vm.sp-1].(*evalDecimal) - vm.sp -= 2 - vm.flags.cmp = l.dec.Cmp(r.dec) +func (asm *assembler) Convert_bB(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && parseStringToFloat(arg.(*evalBytes).string()) != 0.0) return 1 - }, "CMP DECIMAL(SP-2), DECIMAL(SP-1)") + }, "CONV VARBINARY(SP-%d), BOOL", offset) } -func (c *compiler) emitCmpString_collate(collation collations.Collation) { - c.adjustStack(-2) +func (asm *assembler) Convert_bj(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalBytes) + vm.stack[vm.sp-offset] = evalConvert_bj(arg) + return 1 + }, "CONV VARBINARY(SP-%d), JSON") +} - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - vm.sp -= 2 - vm.flags.cmp = collation.Collate(l.bytes, r.bytes, false) +func (asm *assembler) Convert_cj(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalBytes) + vm.stack[vm.sp-offset], vm.err = evalConvert_cj(arg) return 1 - }, "CMP VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) + }, "CONV VARCHAR(SP-%d), JSON") } -func (c *compiler) emitCmpString_coerce(coercion *compiledCoercion) { - c.adjustStack(-2) +func (asm *assembler) Convert_dB(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) + return 1 + }, "CONV DECIMAL(SP-%d), BOOL", offset) +} - c.emit(func(vm *VirtualMachine) 
int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - vm.sp -= 2 +// Convert_dbit is a special instruction emission for converting +// a bigdecimal in preparation for a bitwise operation. In that case +// we need to convert the bigdecimal to an int64 and then cast to +// uint64 to ensure we match the behavior of MySQL. +func (asm *assembler) Convert_dbit(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := evalToNumeric(vm.stack[vm.sp-offset]) + vm.stack[vm.sp-offset] = vm.arena.newEvalUint64(uint64(arg.toInt64().i)) + return 1 + }, "CONV DECIMAL_BITWISE(SP-%d), UINT64", offset) +} - var bl, br []byte - bl, vm.err = coercion.left(nil, l.bytes) - if vm.err != nil { - return 0 - } - br, vm.err = coercion.right(nil, r.bytes) - if vm.err != nil { - return 0 - } - vm.flags.cmp = coercion.col.Collate(bl, br, false) +func (asm *assembler) Convert_fB(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) return 1 - }, "CMP VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) + }, "CONV FLOAT64(SP-%d), BOOL", offset) } -func (c *compiler) emitCollate(col collations.ID) { - c.emit(func(vm *VirtualMachine) int { - a := vm.stack[vm.sp-1].(*evalBytes) - a.tt = int16(sqltypes.VarChar) - a.col.Collation = col +func (asm *assembler) Convert_fj(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalFloat) + vm.stack[vm.sp-offset] = evalConvert_fj(arg) return 1 - }, "COLLATE VARCHAR(SP-1), %d", col) + }, "CONV FLOAT64(SP-%d), JSON") } -func (c *compiler) emitFn_JSON_UNQUOTE() { - c.emit(func(vm *VirtualMachine) int { - j := vm.stack[vm.sp-1].(*evalJSON) - b := vm.arena.newEvalBytesEmpty() - b.tt = int16(sqltypes.Blob) - b.col = collationJSON - if jbytes, ok := j.StringBytes(); ok { - b.bytes = jbytes - } else { - b.bytes = j.MarshalTo(nil) +func (asm *assembler) 
Convert_hex(offset int) { + asm.emit(func(vm *VirtualMachine) int { + var ok bool + vm.stack[vm.sp-offset], ok = vm.stack[vm.sp-offset].(*evalBytes).toNumericHex() + if !ok { + vm.err = errDeoptimize } - vm.stack[vm.sp-1] = b return 1 - }, "FN JSON_UNQUOTE (SP-1)") + }, "CONV VARBINARY(SP-%d), HEX", offset) } -func (c *compiler) emitFn_JSON_EXTRACT0(jp []*json.Path) { - multi := len(jp) > 1 || slices2.Any(jp, func(path *json.Path) bool { return path.ContainsWildcards() }) +func (asm *assembler) Convert_iB(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalInt64).i != 0) + return 1 + }, "CONV INT64(SP-%d), BOOL", offset) +} - if multi { - c.emit(func(vm *VirtualMachine) int { - matches := make([]*json.Value, 0, 4) - arg := vm.stack[vm.sp-1].(*evalJSON) - for _, jp := range jp { - jp.Match(arg, true, func(value *json.Value) { - matches = append(matches, value) - }) - } - if len(matches) == 0 { - vm.stack[vm.sp-1] = nil - } else { - vm.stack[vm.sp-1] = json.NewArray(matches) - } - return 1 - }, "FN JSON_EXTRACT, SP-1, [static]") - } else { - c.emit(func(vm *VirtualMachine) int { - var match *json.Value - arg := vm.stack[vm.sp-1].(*evalJSON) - jp[0].Match(arg, true, func(value *json.Value) { - match = value - }) - if match == nil { - vm.stack[vm.sp-1] = nil - } else { - vm.stack[vm.sp-1] = match - } - return 1 - }, "FN JSON_EXTRACT, SP-1, [static]") - } -} - -func (c *compiler) emitFn_JSON_CONTAINS_PATH(match jsonMatch, paths []*json.Path) { - switch match { - case jsonMatchOne: - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalJSON) - matched := false - for _, p := range paths { - p.Match(arg, true, func(*json.Value) { matched = true }) - if matched { - break - } - } - vm.stack[vm.sp-1] = newEvalBool(matched) - return 1 - }, "FN JSON_CONTAINS_PATH, SP-1, 'one', [static]") - case jsonMatchAll: - c.emit(func(vm *VirtualMachine) int { - arg := 
vm.stack[vm.sp-1].(*evalJSON) - matched := true - for _, p := range paths { - matched = false - p.Match(arg, true, func(*json.Value) { matched = true }) - if !matched { - break - } - } - vm.stack[vm.sp-1] = newEvalBool(matched) - return 1 - }, "FN JSON_CONTAINS_PATH, SP-1, 'all', [static]") - } -} - -func (c *compiler) emitFn_JSON_ARRAY(args int) { - c.adjustStack(-(args - 1)) - c.emit(func(vm *VirtualMachine) int { - ary := make([]*json.Value, 0, args) - for sp := vm.sp - args; sp < vm.sp; sp++ { - ary = append(ary, vm.stack[sp].(*json.Value)) - } - vm.stack[vm.sp-args] = json.NewArray(ary) - vm.sp -= args - 1 +func (asm *assembler) Convert_id(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalInt64) + vm.stack[vm.sp-offset] = vm.arena.newEvalDecimalWithPrec(decimal.NewFromInt(arg.i), 0) return 1 - }, "FN JSON_ARRAY (SP-%d)...(SP-1)", args) + }, "CONV INT64(SP-%d), FLOAT64", offset) } -func (c *compiler) emitFn_JSON_OBJECT(args int) { - c.adjustStack(-(args - 1)) - c.emit(func(vm *VirtualMachine) int { - j := json.NewObject() - obj, _ := j.Object() - - for sp := vm.sp - args; sp < vm.sp; sp += 2 { - key := vm.stack[sp] - val := vm.stack[sp+1] - - if key == nil { - vm.err = errJSONKeyIsNil - return 0 - } - - obj.Set(key.(*evalBytes).string(), val.(*evalJSON), json.Set) - } - vm.stack[vm.sp-args] = j - vm.sp -= args - 1 +func (asm *assembler) Convert_if(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalInt64) + vm.stack[vm.sp-offset] = vm.arena.newEvalFloat(arg.toFloat0()) return 1 - }, "FN JSON_ARRAY (SP-%d)...(SP-1)", args) + }, "CONV INT64(SP-%d), FLOAT64", offset) } -func (c *compiler) emitBitOp_bb(op bitwiseOp) { - c.adjustStack(-1) - - switch op { - case and: - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - if len(l.bytes) != len(r.bytes) { - vm.err = errBitwiseOperandsLength - return 0 - } - for i := range 
l.bytes { - l.bytes[i] = l.bytes[i] & r.bytes[i] - } - vm.sp-- - return 1 - }, "AND BINARY(SP-2), BINARY(SP-1)") - case or: - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - if len(l.bytes) != len(r.bytes) { - vm.err = errBitwiseOperandsLength - return 0 - } - for i := range l.bytes { - l.bytes[i] = l.bytes[i] | r.bytes[i] - } - vm.sp-- - return 1 - }, "OR BINARY(SP-2), BINARY(SP-1)") - case xor: - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - if len(l.bytes) != len(r.bytes) { - vm.err = errBitwiseOperandsLength - return 0 - } - for i := range l.bytes { - l.bytes[i] = l.bytes[i] ^ r.bytes[i] - } - vm.sp-- - return 1 - }, "XOR BINARY(SP-2), BINARY(SP-1)") - } +func (asm *assembler) Convert_iu(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalInt64) + vm.stack[vm.sp-offset] = vm.arena.newEvalUint64(uint64(arg.i)) + return 1 + }, "CONV INT64(SP-%d), UINT64", offset) } -func (c *compiler) emitBitOp_uu(op bitwiseOp) { - c.adjustStack(-1) - - switch op { - case and: - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalUint64) - l.u = l.u & r.u - vm.sp-- - return 1 - }, "AND UINT64(SP-2), UINT64(SP-1)") - case or: - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalUint64) - l.u = l.u | r.u - vm.sp-- - return 1 - }, "OR UINT64(SP-2), UINT64(SP-1)") - case xor: - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalUint64) - l.u = l.u ^ r.u - vm.sp-- - return 1 - }, "XOR UINT64(SP-2), UINT64(SP-1)") - } +func (asm *assembler) Convert_nj(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(evalNumeric) + vm.stack[vm.sp-offset] = evalConvert_nj(arg) + return 1 + }, "CONV numeric(SP-%d), JSON") 
} -func (c *compiler) emitConvert_ui(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalUint64) - vm.stack[vm.sp-offset] = vm.arena.newEvalInt64(int64(arg.u)) +func (asm *assembler) Convert_Nj(offset int) { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = json.ValueNull return 1 - }, "CONV UINT64(SP-%d), INT64", offset) + }, "CONV NULL(SP-%d), JSON") } -func (c *compiler) emitConvert_iu(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalInt64) - vm.stack[vm.sp-offset] = vm.arena.newEvalUint64(uint64(arg.i)) +func (asm *assembler) Convert_uB(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset] + vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalUint64).u != 0) return 1 - }, "CONV INT64(SP-%d), UINT64", offset) + }, "CONV UINT64(SP-%d), BOOL", offset) } -func (c *compiler) emitConvert_xi(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := evalToNumeric(vm.stack[vm.sp-offset]) - vm.stack[vm.sp-offset] = arg.toInt64() +func (asm *assembler) Convert_ud(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalUint64) + vm.stack[vm.sp-offset] = vm.arena.newEvalDecimalWithPrec(decimal.NewFromUint(arg.u), 0) return 1 - }, "CONV (SP-%d), INT64", offset) + }, "CONV UINT64(SP-%d), FLOAT64)", offset) } -func (c *compiler) emitConvert_xu(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := evalToNumeric(vm.stack[vm.sp-offset]) - vm.stack[vm.sp-offset] = arg.toUint64() +func (asm *assembler) Convert_uf(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalUint64) + vm.stack[vm.sp-offset] = vm.arena.newEvalFloat(arg.toFloat0()) return 1 - }, "CONV (SP-%d), UINT64", offset) + }, "CONV UINT64(SP-%d), FLOAT64)", offset) } -// emitConvert_dbit is a special instruction emission for converting -// a bigdecimal in preparation for a bitwise 
operation. In that case -// we need to convert the bigdecimal to an int64 and then cast to -// uint64 to ensure we match the behavior of MySQL. -func (c *compiler) emitConvert_dbit(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := evalToNumeric(vm.stack[vm.sp-offset]) - vm.stack[vm.sp-offset] = vm.arena.newEvalUint64(uint64(arg.toInt64().i)) +func (asm *assembler) Convert_ui(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalUint64) + vm.stack[vm.sp-offset] = vm.arena.newEvalInt64(int64(arg.u)) return 1 - }, "CONV DECIMAL_BITWISE(SP-%d), UINT64", offset) + }, "CONV UINT64(SP-%d), INT64", offset) } -func (c *compiler) emitConvert_xb(offset int, t sqltypes.Type, length int, hasLength bool) { +func (asm *assembler) Convert_xb(offset int, t sqltypes.Type, length int, hasLength bool) { if hasLength { - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { arg := evalToBinary(vm.stack[vm.sp-offset]) arg.truncateInPlace(length) arg.tt = int16(t) @@ -1168,7 +853,7 @@ func (c *compiler) emitConvert_xb(offset int, t sqltypes.Type, length int, hasLe return 1 }, "CONV (SP-%d), VARBINARY[%d]", offset, length) } else { - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { arg := evalToBinary(vm.stack[vm.sp-offset]) arg.tt = int16(t) vm.stack[vm.sp-offset] = arg @@ -1177,48 +862,9 @@ func (c *compiler) emitConvert_xb(offset int, t sqltypes.Type, length int, hasLe } } -func (c *compiler) emitConvert_fj(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalFloat) - vm.stack[vm.sp-offset] = evalConvert_fj(arg) - return 1 - }, "CONV FLOAT64(SP-%d), JSON") -} - -func (c *compiler) emitConvert_nj(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(evalNumeric) - vm.stack[vm.sp-offset] = evalConvert_nj(arg) - return 1 - }, "CONV numeric(SP-%d), JSON") -} - -func (c *compiler) emitConvert_cj(offset int) { - 
c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalBytes) - vm.stack[vm.sp-offset], vm.err = evalConvert_cj(arg) - return 1 - }, "CONV VARCHAR(SP-%d), JSON") -} - -func (c *compiler) emitConvert_bj(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset].(*evalBytes) - vm.stack[vm.sp-offset] = evalConvert_bj(arg) - return 1 - }, "CONV VARBINARY(SP-%d), JSON") -} - -func (c *compiler) emitConvert_Nj(offset int) { - c.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset] = json.ValueNull - return 1 - }, "CONV NULL(SP-%d), JSON") -} - -func (c *compiler) emitConvert_xc(offset int, t sqltypes.Type, collation collations.ID, length int, hasLength bool) { +func (asm *assembler) Convert_xc(offset int, t sqltypes.Type, collation collations.ID, length int, hasLength bool) { if hasLength { - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { arg, err := evalToVarchar(vm.stack[vm.sp-offset], collation, true) if err != nil { vm.stack[vm.sp-offset] = nil @@ -1230,7 +876,7 @@ func (c *compiler) emitConvert_xc(offset int, t sqltypes.Type, collation collati return 1 }, "CONV (SP-%d), VARCHAR[%d]", offset, length) } else { - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { arg, err := evalToVarchar(vm.stack[vm.sp-offset], collation, true) if err != nil { vm.stack[vm.sp-offset] = nil @@ -1243,544 +889,934 @@ func (c *compiler) emitConvert_xc(offset int, t sqltypes.Type, collation collati } } -func (c *compiler) emitBitShiftLeft_bu() { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalUint64) - - var ( - bits = int(r.u & 7) - bytes = int(r.u >> 3) - length = len(l.bytes) - out = make([]byte, length) - ) - - for i := 0; i < length; i++ { - pos := i + bytes + 1 - switch { - case pos < length: - out[i] = l.bytes[pos] >> (8 - bits) - fallthrough - case pos == length: - out[i] |= 
l.bytes[pos-1] << bits - } - } - l.bytes = out +func (asm *assembler) Convert_xd(offset int, m, d int32) { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = evalToNumeric(vm.stack[vm.sp-offset]).toDecimal(m, d) + return 1 + }, "CONV (SP-%d), FLOAT64", offset) +} - vm.sp-- +func (asm *assembler) Convert_xf(offset int) { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset], _ = evalToNumeric(vm.stack[vm.sp-offset]).toFloat() return 1 - }, "BIT_SHL BINARY(SP-2), UINT64(SP-1)") + }, "CONV (SP-%d), FLOAT64", offset) +} +func (asm *assembler) Convert_xi(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := evalToNumeric(vm.stack[vm.sp-offset]) + vm.stack[vm.sp-offset] = arg.toInt64() + return 1 + }, "CONV (SP-%d), INT64", offset) } -func (c *compiler) emitBitShiftRight_bu() { - c.adjustStack(-1) +func (asm *assembler) Convert_xu(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := evalToNumeric(vm.stack[vm.sp-offset]) + vm.stack[vm.sp-offset] = arg.toUint64() + return 1 + }, "CONV (SP-%d), UINT64", offset) +} - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalUint64) +func (asm *assembler) Div_dd() { + asm.adjustStack(-1) - var ( - bits = int(r.u & 7) - bytes = int(r.u >> 3) - length = len(l.bytes) - out = make([]byte, length) - ) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + mathDiv_dd0(l, r, divPrecisionIncrement) + } + vm.sp-- + return 1 + }, "DIV DECIMAL(SP-2), DECIMAL(SP-1)") +} - for i := length - 1; i >= 0; i-- { - switch { - case i > bytes: - out[i] = l.bytes[i-bytes-1] << (8 - bits) - fallthrough - case i == bytes: - out[i] |= l.bytes[i-bytes] >> bits +func (asm *assembler) Div_ff() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := 
vm.stack[vm.sp-1].(*evalFloat) + if r.f == 0.0 { + vm.stack[vm.sp-2] = nil + } else { + l.f, vm.err = mathDiv_ff0(l.f, r.f) + } + vm.sp-- + return 1 + }, "DIV FLOAT64(SP-2), FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_ASCII() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + if len(arg.bytes) == 0 { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(0) + } else { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(arg.bytes[0])) + } + return 1 + }, "FN ASCII VARCHAR(SP-1)") +} + +func (asm *assembler) Fn_FROM_BASE64() { + asm.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + + decoded := make([]byte, mysqlBase64.DecodedLen(len(str.bytes))) + + n, err := mysqlBase64.Decode(decoded, str.bytes) + if err != nil { + vm.stack[vm.sp-1] = nil + return 1 + } + str.tt = int16(sqltypes.VarBinary) + str.bytes = decoded[:n] + return 1 + }, "FROM_BASE64 VARCHAR(SP-1)") +} + +func (asm *assembler) Fn_HEX_c(t sqltypes.Type, col collations.TypedCollation) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + encoded := vm.arena.newEvalText(hexEncodeBytes(arg.bytes), col) + encoded.tt = int16(t) + vm.stack[vm.sp-1] = encoded + return 1 + }, "FN HEX VARCHAR(SP-1)") +} + +func (asm *assembler) Fn_HEX_d(col collations.TypedCollation) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(evalNumeric) + vm.stack[vm.sp-1] = vm.arena.newEvalText(hexEncodeUint(uint64(arg.toInt64().i)), col) + return 1 + }, "FN HEX NUMERIC(SP-1)") +} + +func (asm *assembler) Fn_JSON_ARRAY(args int) { + asm.adjustStack(-(args - 1)) + asm.emit(func(vm *VirtualMachine) int { + ary := make([]*json.Value, 0, args) + for sp := vm.sp - args; sp < vm.sp; sp++ { + ary = append(ary, vm.stack[sp].(*json.Value)) + } + vm.stack[vm.sp-args] = json.NewArray(ary) + vm.sp -= args - 1 + return 1 + }, "FN JSON_ARRAY (SP-%d)...(SP-1)", args) +} + +func (asm *assembler) Fn_JSON_CONTAINS_PATH(match jsonMatch, paths 
[]*json.Path) { + switch match { + case jsonMatchOne: + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalJSON) + matched := false + for _, p := range paths { + p.Match(arg, true, func(*json.Value) { matched = true }) + if matched { + break + } + } + vm.stack[vm.sp-1] = newEvalBool(matched) + return 1 + }, "FN JSON_CONTAINS_PATH, SP-1, 'one', [static]") + case jsonMatchAll: + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalJSON) + matched := true + for _, p := range paths { + matched = false + p.Match(arg, true, func(*json.Value) { matched = true }) + if !matched { + break + } + } + vm.stack[vm.sp-1] = newEvalBool(matched) + return 1 + }, "FN JSON_CONTAINS_PATH, SP-1, 'all', [static]") + } +} + +func (asm *assembler) Fn_JSON_EXTRACT0(jp []*json.Path) { + multi := len(jp) > 1 || slices2.Any(jp, func(path *json.Path) bool { return path.ContainsWildcards() }) + + if multi { + asm.emit(func(vm *VirtualMachine) int { + matches := make([]*json.Value, 0, 4) + arg := vm.stack[vm.sp-1].(*evalJSON) + for _, jp := range jp { + jp.Match(arg, true, func(value *json.Value) { + matches = append(matches, value) + }) + } + if len(matches) == 0 { + vm.stack[vm.sp-1] = nil + } else { + vm.stack[vm.sp-1] = json.NewArray(matches) + } + return 1 + }, "FN JSON_EXTRACT, SP-1, [static]") + } else { + asm.emit(func(vm *VirtualMachine) int { + var match *json.Value + arg := vm.stack[vm.sp-1].(*evalJSON) + jp[0].Match(arg, true, func(value *json.Value) { + match = value + }) + if match == nil { + vm.stack[vm.sp-1] = nil + } else { + vm.stack[vm.sp-1] = match + } + return 1 + }, "FN JSON_EXTRACT, SP-1, [static]") + } +} + +func (asm *assembler) Fn_JSON_OBJECT(args int) { + asm.adjustStack(-(args - 1)) + asm.emit(func(vm *VirtualMachine) int { + j := json.NewObject() + obj, _ := j.Object() + + for sp := vm.sp - args; sp < vm.sp; sp += 2 { + key := vm.stack[sp] + val := vm.stack[sp+1] + + if key == nil { + vm.err = errJSONKeyIsNil + return 0 + } + 
+ obj.Set(key.(*evalBytes).string(), val.(*evalJSON), json.Set) + } + vm.stack[vm.sp-args] = j + vm.sp -= args - 1 + return 1 + }, "FN JSON_ARRAY (SP-%d)...(SP-1)", args) +} + +func (asm *assembler) Fn_JSON_UNQUOTE() { + asm.emit(func(vm *VirtualMachine) int { + j := vm.stack[vm.sp-1].(*evalJSON) + b := vm.arena.newEvalBytesEmpty() + b.tt = int16(sqltypes.Blob) + b.col = collationJSON + if jbytes, ok := j.StringBytes(); ok { + b.bytes = jbytes + } else { + b.bytes = j.MarshalTo(nil) + } + vm.stack[vm.sp-1] = b + return 1 + }, "FN JSON_UNQUOTE (SP-1)") +} + +func (asm *assembler) Fn_LENGTH(op lengthOp) { + switch op { + case charLen: + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + + if sqltypes.IsBinary(arg.SQLType()) { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) + } else { + coll := arg.col.Collation.Get() + count := charset.Length(coll.Charset(), arg.bytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) + } + return 1 + }, "FN CHAR_LENGTH VARCHAR(SP-1)") + case byteLen: + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) + return 1 + }, "FN LENGTH VARCHAR(SP-1)") + case bitLen: + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes) * 8)) + return 1 + }, "FN BIT_LENGTH VARCHAR(SP-1)") + } +} + +func (asm *assembler) Fn_LUCASE(upcase bool) { + if upcase { + asm.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + + coll := str.col.Collation.Get() + csa, ok := coll.(collations.CaseAwareCollation) + if !ok { + vm.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") + } else { + str.bytes = csa.ToUpper(nil, str.bytes) + } + str.tt = int16(sqltypes.VarChar) + return 1 + }, "FN UPPER VARCHAR(SP-1)") + } else { + asm.emit(func(vm *VirtualMachine) int { + str := 
vm.stack[vm.sp-1].(*evalBytes) + + coll := str.col.Collation.Get() + csa, ok := coll.(collations.CaseAwareCollation) + if !ok { + vm.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") + } else { + str.bytes = csa.ToLower(nil, str.bytes) + } + str.tt = int16(sqltypes.VarChar) + return 1 + }, "FN LOWER VARCHAR(SP-1)") + } +} + +func (asm *assembler) Fn_MULTICMP_b(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].ToRawBytes() + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].ToRawBytes() + if lessThan == (bytes.Compare(y, x) < 0) { + x = y + } + } + vm.stack[vm.sp-args] = vm.arena.newEvalBinary(x) + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP VARBINARY(SP-%d)...VARBINARY(SP-1)", args) +} + +func (asm *assembler) Fn_MULTICMP_c(args int, lessThan bool, tc collations.TypedCollation) { + col := tc.Collation.Get() + + asm.adjustStack(-(args - 1)) + asm.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].ToRawBytes() + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].ToRawBytes() + if lessThan == (col.Collate(y, x, false) < 0) { + x = y + } + } + vm.stack[vm.sp-args] = vm.arena.newEvalText(x, tc) + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP FLOAT64(SP-%d)...FLOAT64(SP-1)", args) +} + +func (asm *assembler) Fn_MULTICMP_d(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].(*evalDecimal) + xprec := x.length + + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].(*evalDecimal) + if lessThan == (y.dec.Cmp(x.dec) < 0) { + x = y + } + if y.length > xprec { + xprec = y.length + } + } + vm.stack[vm.sp-args] = vm.arena.newEvalDecimalWithPrec(x.dec, xprec) + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP DECIMAL(SP-%d)...DECIMAL(SP-1)", args) +} + +func (asm *assembler) Fn_MULTICMP_f(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + 
+ asm.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].(*evalFloat) + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].(*evalFloat) + if lessThan == (y.f < x.f) { + x = y + } + } + vm.stack[vm.sp-args] = x + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP FLOAT64(SP-%d)...FLOAT64(SP-1)", args) +} + +func (asm *assembler) Fn_MULTICMP_i(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].(*evalInt64) + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].(*evalInt64) + if lessThan == (y.i < x.i) { + x = y } } - l.bytes = out + vm.stack[vm.sp-args] = x + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP INT64(SP-%d)...INT64(SP-1)", args) +} + +func (asm *assembler) Fn_MULTICMP_u(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(vm *VirtualMachine) int { + x := vm.stack[vm.sp-args].(*evalUint64) + for sp := vm.sp - args + 1; sp < vm.sp; sp++ { + y := vm.stack[sp].(*evalUint64) + if lessThan == (y.u < x.u) { + x = y + } + } + vm.stack[vm.sp-args] = x + vm.sp -= args - 1 + return 1 + }, "FN MULTICMP UINT64(SP-%d)...UINT64(SP-1)", args) +} + +func (asm *assembler) Fn_REPEAT(i int) { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-2].(*evalBytes) + repeat := vm.stack[vm.sp-1].(*evalInt64) + + if repeat.i < 0 { + repeat.i = 0 + } + + if !checkMaxLength(int64(len(str.bytes)), repeat.i) { + vm.stack[vm.sp-2] = nil + vm.sp-- + return 1 + } + + str.tt = int16(sqltypes.VarChar) + str.bytes = bytes.Repeat(str.bytes, int(repeat.i)) + vm.sp-- + return 1 + }, "FN REPEAT VARCHAR(SP-2) INT64(SP-1)") +} + +func (asm *assembler) Fn_TO_BASE64(t sqltypes.Type, col collations.TypedCollation) { + asm.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + + encoded := make([]byte, mysqlBase64.EncodedLen(len(str.bytes))) + mysqlBase64.Encode(encoded, str.bytes) + + str.tt = int16(t) + 
str.col = col + str.bytes = encoded + return 1 + }, "FN TO_BASE64 VARCHAR(SP-1)") +} + +func (asm *assembler) Like_coerce(expr *LikeExpr, coercion *compiledCoercion) { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + vm.sp-- + + var bl, br []byte + bl, vm.err = coercion.left(nil, l.bytes) + if vm.err != nil { + return 0 + } + br, vm.err = coercion.right(nil, r.bytes) + if vm.err != nil { + return 0 + } + + match := expr.matchWildcard(bl, br, coercion.col.ID()) + if match { + vm.stack[vm.sp-1] = evalBoolTrue + } else { + vm.stack[vm.sp-1] = evalBoolFalse + } + return 1 + }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) +} + +func (asm *assembler) Like_collate(expr *LikeExpr, collation collations.Collation) { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + vm.sp-- + + match := expr.matchWildcard(l.bytes, r.bytes, collation.ID()) + if match { + vm.stack[vm.sp-1] = evalBoolTrue + } else { + vm.stack[vm.sp-1] = evalBoolFalse + } + return 1 + }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) +} +func (asm *assembler) Mul_dd() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + mathMul_dd0(l, r) vm.sp-- return 1 - }, "BIT_SHR BINARY(SP-2), UINT64(SP-1)") + }, "MUL DECIMAL(SP-2), DECIMAL(SP-1)") } -func (c *compiler) emitBitShiftLeft_uu() { - c.adjustStack(-1) +func (asm *assembler) Mul_ff() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalUint64) - l.u = l.u << r.u + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + l.f *= r.f + vm.sp-- + return 1 + }, "MUL FLOAT64(SP-2), 
FLOAT64(SP-1)") +} +func (asm *assembler) Mul_ii() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.i, vm.err = mathMul_ii0(l.i, r.i) vm.sp-- return 1 - }, "BIT_SHL UINT64(SP-2), UINT64(SP-1)") + }, "MUL INT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) Mul_ui(swap bool) { + asm.adjustStack(-1) + + if swap { + asm.emit(func(vm *VirtualMachine) int { + var u uint64 + l := vm.stack[vm.sp-1].(*evalUint64) + r := vm.stack[vm.sp-2].(*evalInt64) + u, vm.err = mathMul_ui0(l.u, r.i) + vm.stack[vm.sp-2] = vm.arena.newEvalUint64(u) + vm.sp-- + return 1 + }, "MUL UINT64(SP-1), INT64(SP-2)") + } else { + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.u, vm.err = mathMul_ui0(l.u, r.i) + vm.sp-- + return 1 + }, "MUL UINT64(SP-2), INT64(SP-1)") + } } -func (c *compiler) emitBitShiftRight_uu() { - c.adjustStack(-1) +func (asm *assembler) Mul_uu() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { + asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-2].(*evalUint64) r := vm.stack[vm.sp-1].(*evalUint64) - l.u = l.u >> r.u - + l.u, vm.err = mathMul_uu0(l.u, r.u) vm.sp-- return 1 - }, "BIT_SHR UINT64(SP-2), UINT64(SP-1)") + }, "MUL UINT64(SP-2), UINT64(SP-1)") } -func (c *compiler) emitBitCount_b() { - c.emit(func(vm *VirtualMachine) int { - a := vm.stack[vm.sp-1].(*evalBytes) - count := 0 - for _, b := range a.bytes { - count += bits.OnesCount8(b) - } - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) +func (asm *assembler) Neg_d() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalDecimal) + arg.dec = arg.dec.Neg() return 1 - }, "BIT_COUNT BINARY(SP-1)") + }, "NEG DECIMAL(SP-1)") } -func (c *compiler) emitBitCount_u() { - c.emit(func(vm *VirtualMachine) int { - a := vm.stack[vm.sp-1].(*evalUint64) - vm.stack[vm.sp-1] = 
vm.arena.newEvalInt64(int64(bits.OnesCount64(a.u))) +func (asm *assembler) Neg_f() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalFloat) + arg.f = -arg.f return 1 - }, "BIT_COUNT UINT64(SP-1)") + }, "NEG FLOAT64(SP-1)") } -func (c *compiler) emitBitwiseNot_b() { - c.emit(func(vm *VirtualMachine) int { - a := vm.stack[vm.sp-1].(*evalBytes) - for i := range a.bytes { - a.bytes[i] = ^a.bytes[i] - } +func (asm *assembler) Neg_hex() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalUint64) + vm.stack[vm.sp-1] = vm.arena.newEvalFloat(-float64(arg.u)) return 1 - }, "BIT_NOT BINARY(SP-1)") + }, "NEG HEX(SP-1)") } -func (c *compiler) emitBitwiseNot_u() { - c.emit(func(vm *VirtualMachine) int { - a := vm.stack[vm.sp-1].(*evalUint64) - a.u = ^a.u +func (asm *assembler) Neg_i() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalInt64) + if arg.i == math.MinInt64 { + vm.err = errDeoptimize + } else { + arg.i = -arg.i + } return 1 - }, "BIT_NOT UINT64(SP-1)") + }, "NEG INT64(SP-1)") } -func (c *compiler) emitCmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation) { - elseOffset := 0 - if hasElse { - elseOffset = 1 - } - - stackDepth := 2*cases + elseOffset - c.adjustStack(-(stackDepth - 1)) - - c.emit(func(vm *VirtualMachine) int { - end := vm.sp - elseOffset - for sp := vm.sp - stackDepth; sp < end; sp += 2 { - if vm.stack[sp].(*evalInt64).i != 0 { - vm.stack[vm.sp-stackDepth], vm.err = evalCoerce(vm.stack[sp+1], tt, cc.Collation) - goto done - } - } - if elseOffset != 0 { - vm.stack[vm.sp-stackDepth], vm.err = evalCoerce(vm.stack[vm.sp-1], tt, cc.Collation) +func (asm *assembler) Neg_u() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalUint64) + if arg.u > math.MaxInt64+1 { + vm.err = errDeoptimize } else { - vm.stack[vm.sp-stackDepth] = nil + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(-int64(arg.u)) } - done: - vm.sp -= stackDepth - 1 return 1 
- }, "CASE [%d cases, else = %v]", cases, hasElse) + }, "NEG UINT64(SP-1)") } -func (c *compiler) emitConvert_iB(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalInt64).i != 0) +func (asm *assembler) NullCheck1(j *jump) { + asm.emit(func(vm *VirtualMachine) int { + if vm.stack[vm.sp-1] == nil { + return j.offset() + } return 1 - }, "CONV INT64(SP-%d), BOOL", offset) + }, "NULLCHECK SP-1") } -func (c *compiler) emitConvert_uB(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalUint64).u != 0) +func (asm *assembler) NullCheck2(j *jump) { + asm.emit(func(vm *VirtualMachine) int { + if vm.stack[vm.sp-2] == nil || vm.stack[vm.sp-1] == nil { + vm.stack[vm.sp-2] = nil + vm.sp-- + return j.offset() + } return 1 - }, "CONV UINT64(SP-%d), BOOL", offset) + }, "NULLCHECK SP-1, SP-2") } -func (c *compiler) emitConvert_fB(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) +func (asm *assembler) NullCmp(j *jump) { + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2] + r := vm.stack[vm.sp-1] + if l == nil || r == nil { + if l == r { + vm.flags.cmp = 0 + } else { + vm.flags.cmp = 1 + } + vm.sp -= 2 + return j.offset() + } return 1 - }, "CONV FLOAT64(SP-%d), BOOL", offset) + }, "NULLCMP SP-1, SP-2") } -func (c *compiler) emitConvert_dB(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) +func (asm *assembler) PackTuple(tlen int) { + asm.adjustStack(-(tlen - 1)) + asm.emit(func(vm *VirtualMachine) int { + tuple := make([]eval, tlen) + copy(tuple, vm.stack[vm.sp-tlen:]) + vm.stack[vm.sp-tlen] = &evalTuple{tuple} + vm.sp -= tlen - 1 return 1 - }, 
"CONV DECIMAL(SP-%d), BOOL", offset) + }, "TUPLE (SP-%d)...(SP-1)", tlen) } -func (c *compiler) emitConvert_bB(offset int) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && parseStringToFloat(arg.(*evalBytes).string()) != 0.0) +func (asm *assembler) Parse_j(offset int) { + asm.emit(func(vm *VirtualMachine) int { + var p json.Parser + arg := vm.stack[vm.sp-offset].(*evalBytes) + vm.stack[vm.sp-offset], vm.err = p.ParseBytes(arg.bytes) return 1 - }, "CONV VARBINARY(SP-%d), BOOL", offset) + }, "PARSE_JSON VARCHAR(SP-%d)", offset) } -func (c *compiler) emitNeg_i() { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalInt64) - if arg.i == math.MinInt64 { - vm.err = errDeoptimize - } else { - arg.i = -arg.i - } +func (asm *assembler) PushColumn_bin(offset int) { + asm.adjustStack(1) + + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalBinary(vm.row[offset].Raw()) + vm.sp++ return 1 - }, "NEG INT64(SP-1)") + }, "PUSH VARBINARY(:%d)", offset) } -func (c *compiler) emitNeg_hex() { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalUint64) - vm.stack[vm.sp-1] = vm.arena.newEvalFloat(-float64(arg.u)) +func (asm *assembler) PushColumn_d(offset int) { + asm.adjustStack(1) + + asm.emit(func(vm *VirtualMachine) int { + var dec decimal.Decimal + dec, vm.err = decimal.NewFromMySQL(vm.row[offset].Raw()) + vm.stack[vm.sp] = vm.arena.newEvalDecimal(dec, 0, 0) + vm.sp++ return 1 - }, "NEG HEX(SP-1)") + }, "PUSH DECIMAL(:%d)", offset) } -func (c *compiler) emitNeg_u() { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalUint64) - if arg.u > math.MaxInt64+1 { - vm.err = errDeoptimize - } else { - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(-int64(arg.u)) - } +func (asm *assembler) PushColumn_f(offset int) { + asm.adjustStack(1) + + asm.emit(func(vm *VirtualMachine) int { + var fval float64 + fval, vm.err = vm.row[offset].ToFloat64() 
+ vm.stack[vm.sp] = vm.arena.newEvalFloat(fval) + vm.sp++ return 1 - }, "NEG UINT64(SP-1)") + }, "PUSH FLOAT64(:%d)", offset) } -func (c *compiler) emitNeg_f() { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalFloat) - arg.f = -arg.f +func (asm *assembler) PushColumn_hexnum(offset int) { + asm.adjustStack(1) + + asm.emit(func(vm *VirtualMachine) int { + var raw []byte + raw, vm.err = parseHexNumber(vm.row[offset].Raw()) + vm.stack[vm.sp] = newEvalBytesHex(raw) + vm.sp++ return 1 - }, "NEG FLOAT64(SP-1)") + }, "PUSH HEXNUM(:%d)", offset) } -func (c *compiler) emitNeg_d() { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalDecimal) - arg.dec = arg.dec.Neg() +func (asm *assembler) PushColumn_hexval(offset int) { + asm.adjustStack(1) + + asm.emit(func(vm *VirtualMachine) int { + hex := vm.row[offset].Raw() + var raw []byte + raw, vm.err = parseHexLiteral(hex[2 : len(hex)-1]) + vm.stack[vm.sp] = newEvalBytesHex(raw) + vm.sp++ return 1 - }, "NEG DECIMAL(SP-1)") + }, "PUSH HEXVAL(:%d)", offset) } -func (c *compiler) emitFn_REPEAT(i int) { - c.adjustStack(-1) +func (asm *assembler) PushColumn_i(offset int) { + asm.adjustStack(1) - c.emit(func(vm *VirtualMachine) int { - str := vm.stack[vm.sp-2].(*evalBytes) - repeat := vm.stack[vm.sp-1].(*evalInt64) - - if repeat.i < 0 { - repeat.i = 0 - } + asm.emit(func(vm *VirtualMachine) int { + var ival int64 + ival, vm.err = vm.row[offset].ToInt64() + vm.stack[vm.sp] = vm.arena.newEvalInt64(ival) + vm.sp++ + return 1 + }, "PUSH INT64(:%d)", offset) +} - if !checkMaxLength(int64(len(str.bytes)), repeat.i) { - vm.stack[vm.sp-2] = nil - vm.sp-- - return 1 - } +func (asm *assembler) PushColumn_json(offset int) { + asm.adjustStack(1) - str.tt = int16(sqltypes.VarChar) - str.bytes = bytes.Repeat(str.bytes, int(repeat.i)) - vm.sp-- + asm.emit(func(vm *VirtualMachine) int { + var parser json.Parser + vm.stack[vm.sp], vm.err = parser.ParseBytes(vm.row[offset].Raw()) + vm.sp++ return 1 - }, "FN 
REPEAT VARCHAR(SP-2) INT64(SP-1)") + }, "PUSH JSON(:%d)", offset) } -func (c *compiler) emitFn_TO_BASE64(t sqltypes.Type, col collations.TypedCollation) { - c.emit(func(vm *VirtualMachine) int { - str := vm.stack[vm.sp-1].(*evalBytes) - - encoded := make([]byte, mysqlBase64.EncodedLen(len(str.bytes))) - mysqlBase64.Encode(encoded, str.bytes) +func (asm *assembler) PushColumn_text(offset int, col collations.TypedCollation) { + asm.adjustStack(1) - str.tt = int16(t) - str.col = col - str.bytes = encoded + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = newEvalText(vm.row[offset].Raw(), col) + vm.sp++ return 1 - }, "FN TO_BASE64 VARCHAR(SP-1)") + }, "PUSH VARCHAR(:%d) COLLATE %d", offset, col.Collation) } -func (c *compiler) emitFromBase64() { - c.emit(func(vm *VirtualMachine) int { - str := vm.stack[vm.sp-1].(*evalBytes) - - decoded := make([]byte, mysqlBase64.DecodedLen(len(str.bytes))) +func (asm *assembler) PushColumn_u(offset int) { + asm.adjustStack(1) - n, err := mysqlBase64.Decode(decoded, str.bytes) - if err != nil { - vm.stack[vm.sp-1] = nil - return 1 - } - str.tt = int16(sqltypes.VarBinary) - str.bytes = decoded[:n] + asm.emit(func(vm *VirtualMachine) int { + var uval uint64 + uval, vm.err = vm.row[offset].ToUint64() + vm.stack[vm.sp] = vm.arena.newEvalUint64(uval) + vm.sp++ return 1 - }, "FROM_BASE64 VARCHAR(SP-1)") + }, "PUSH UINT64(:%d)", offset) } -func (c *compiler) emitFn_LUCASE(upcase bool) { - if upcase { - c.emit(func(vm *VirtualMachine) int { - str := vm.stack[vm.sp-1].(*evalBytes) +func (asm *assembler) PushLiteral(lit eval) error { + asm.adjustStack(1) - coll := str.col.Collation.Get() - csa, ok := coll.(collations.CaseAwareCollation) - if !ok { - vm.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") - } else { - str.bytes = csa.ToUpper(nil, str.bytes) - } - str.tt = int16(sqltypes.VarChar) + if lit == evalBoolTrue { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = evalBoolTrue + vm.sp++ return 1 - }, 
"FN UPPER VARCHAR(SP-1)") - } else { - c.emit(func(vm *VirtualMachine) int { - str := vm.stack[vm.sp-1].(*evalBytes) + }, "PUSH true") + return nil + } - coll := str.col.Collation.Get() - csa, ok := coll.(collations.CaseAwareCollation) - if !ok { - vm.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") - } else { - str.bytes = csa.ToLower(nil, str.bytes) - } - str.tt = int16(sqltypes.VarChar) + if lit == evalBoolFalse { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = evalBoolFalse + vm.sp++ return 1 - }, "FN LOWER VARCHAR(SP-1)") + }, "PUSH false") + return nil } -} - -func (c *compiler) emitFn_LENGTH(op lengthOp) { - switch op { - case charLen: - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalBytes) - if sqltypes.IsBinary(arg.SQLType()) { - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) - } else { - coll := arg.col.Collation.Get() - count := charset.Length(coll.Charset(), arg.bytes) - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) - } + switch lit := lit.(type) { + case *evalInt64: + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalInt64(lit.i) + vm.sp++ return 1 - }, "FN CHAR_LENGTH VARCHAR(SP-1)") - case byteLen: - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalBytes) - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) + }, "PUSH INT64(%s)", lit.ToRawBytes()) + case *evalUint64: + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalUint64(lit.u) + vm.sp++ return 1 - }, "FN LENGTH VARCHAR(SP-1)") - case bitLen: - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalBytes) - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes) * 8)) + }, "PUSH UINT64(%s)", lit.ToRawBytes()) + case *evalFloat: + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalFloat(lit.f) + vm.sp++ return 1 - }, "FN BIT_LENGTH VARCHAR(SP-1)") + }, "PUSH FLOAT64(%s)", 
lit.ToRawBytes()) + case *evalDecimal: + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalDecimalWithPrec(lit.dec, lit.length) + vm.sp++ + return 1 + }, "PUSH DECIMAL(%s)", lit.ToRawBytes()) + case *evalBytes: + asm.emit(func(vm *VirtualMachine) int { + b := vm.arena.newEvalBytesEmpty() + *b = *lit + vm.stack[vm.sp] = b + vm.sp++ + return 1 + }, "PUSH VARCHAR(%q)", lit.ToRawBytes()) + default: + return vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported literal kind '%T'", lit) } -} -func (c *compiler) emitFn_ASCII() { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalBytes) - if len(arg.bytes) == 0 { - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(0) - } else { - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(arg.bytes[0])) - } - return 1 - }, "FN ASCII VARCHAR(SP-1)") + return nil } -func (c *compiler) emitFn_HEXc(t sqltypes.Type, col collations.TypedCollation) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalBytes) - encoded := vm.arena.newEvalText(hexEncodeBytes(arg.bytes), col) - encoded.tt = int16(t) - vm.stack[vm.sp-1] = encoded - return 1 - }, "FN HEX VARCHAR(SP-1)") -} +func (asm *assembler) PushNull() { + asm.adjustStack(1) -func (c *compiler) emitFn_HEXd(col collations.TypedCollation) { - c.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(evalNumeric) - vm.stack[vm.sp-1] = vm.arena.newEvalText(hexEncodeUint(uint64(arg.toInt64().i)), col) + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = nil + vm.sp++ return 1 - }, "FN HEX NUMERIC(SP-1)") + }, "PUSH NULL") } -func (c *compiler) emitLike_collate(expr *LikeExpr, collation collations.Collation) { - c.adjustStack(-1) - - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - vm.sp-- - - match := expr.matchWildcard(l.bytes, r.bytes, collation.ID()) - if match { - vm.stack[vm.sp-1] = evalBoolTrue - } else { - vm.stack[vm.sp-1] = evalBoolFalse - } - 
return 1 - }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) +func (asm *assembler) SetBool(offset int, b bool) { + if b { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = evalBoolTrue + return 1 + }, "SET (SP-%d), BOOL(true)", offset) + } else { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = evalBoolFalse + return 1 + }, "SET (SP-%d), BOOL(false)", offset) + } } -func (c *compiler) emitLike_coerce(expr *LikeExpr, coercion *compiledCoercion) { - c.adjustStack(-1) +func (asm *assembler) Sub_dd() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + mathSub_dd0(l, r) vm.sp-- - - var bl, br []byte - bl, vm.err = coercion.left(nil, l.bytes) - if vm.err != nil { - return 0 - } - br, vm.err = coercion.right(nil, r.bytes) - if vm.err != nil { - return 0 - } - - match := expr.matchWildcard(bl, br, coercion.col.ID()) - if match { - vm.stack[vm.sp-1] = evalBoolTrue - } else { - vm.stack[vm.sp-1] = evalBoolFalse - } return 1 - }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) + }, "SUB DECIMAL(SP-2), DECIMAL(SP-1)") } -func (c *compiler) emitFn_MULTICMP_i(args int, lessThan bool) { - c.adjustStack(-(args - 1)) +func (asm *assembler) Sub_ff() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - x := vm.stack[vm.sp-args].(*evalInt64) - for sp := vm.sp - args + 1; sp < vm.sp; sp++ { - y := vm.stack[sp].(*evalInt64) - if lessThan == (y.i < x.i) { - x = y - } - } - vm.stack[vm.sp-args] = x - vm.sp -= args - 1 + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + l.f -= r.f + vm.sp-- return 1 - }, "FN MULTICMP INT64(SP-%d)...INT64(SP-1)", args) + }, "SUB FLOAT64(SP-2), FLOAT64(SP-1)") } -func (c 
*compiler) emitFn_MULTICMP_u(args int, lessThan bool) { - c.adjustStack(-(args - 1)) +func (asm *assembler) Sub_ii() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - x := vm.stack[vm.sp-args].(*evalUint64) - for sp := vm.sp - args + 1; sp < vm.sp; sp++ { - y := vm.stack[sp].(*evalUint64) - if lessThan == (y.u < x.u) { - x = y - } - } - vm.stack[vm.sp-args] = x - vm.sp -= args - 1 + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + l.i, vm.err = mathSub_ii0(l.i, r.i) + vm.sp-- return 1 - }, "FN MULTICMP UINT64(SP-%d)...UINT64(SP-1)", args) + }, "SUB INT64(SP-2), INT64(SP-1)") } -func (c *compiler) emitFn_MULTICMP_f(args int, lessThan bool) { - c.adjustStack(-(args - 1)) +func (asm *assembler) Sub_iu() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - x := vm.stack[vm.sp-args].(*evalFloat) - for sp := vm.sp - args + 1; sp < vm.sp; sp++ { - y := vm.stack[sp].(*evalFloat) - if lessThan == (y.f < x.f) { - x = y - } - } - vm.stack[vm.sp-args] = x - vm.sp -= args - 1 + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalUint64) + r.u, vm.err = mathSub_iu0(l.i, r.u) + vm.stack[vm.sp-2] = r + vm.sp-- return 1 - }, "FN MULTICMP FLOAT64(SP-%d)...FLOAT64(SP-1)", args) + }, "SUB INT64(SP-2), UINT64(SP-1)") } -func (c *compiler) emitFn_MULTICMP_d(args int, lessThan bool) { - c.adjustStack(-(args - 1)) - - c.emit(func(vm *VirtualMachine) int { - x := vm.stack[vm.sp-args].(*evalDecimal) - xprec := x.length +func (asm *assembler) Sub_ui() { + asm.adjustStack(-1) - for sp := vm.sp - args + 1; sp < vm.sp; sp++ { - y := vm.stack[sp].(*evalDecimal) - if lessThan == (y.dec.Cmp(x.dec) < 0) { - x = y - } - if y.length > xprec { - xprec = y.length - } - } - vm.stack[vm.sp-args] = vm.arena.newEvalDecimalWithPrec(x.dec, xprec) - vm.sp -= args - 1 + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := 
vm.stack[vm.sp-1].(*evalInt64) + l.u, vm.err = mathSub_ui0(l.u, r.i) + vm.sp-- return 1 - }, "FN MULTICMP DECIMAL(SP-%d)...DECIMAL(SP-1)", args) + }, "SUB UINT64(SP-2), INT64(SP-1)") } -func (c *compiler) emitFn_MULTICMP_b(args int, lessThan bool) { - c.adjustStack(-(args - 1)) +func (asm *assembler) Sub_uu() { + asm.adjustStack(-1) - c.emit(func(vm *VirtualMachine) int { - x := vm.stack[vm.sp-args].ToRawBytes() - for sp := vm.sp - args + 1; sp < vm.sp; sp++ { - y := vm.stack[sp].ToRawBytes() - if lessThan == (bytes.Compare(y, x) < 0) { - x = y - } - } - vm.stack[vm.sp-args] = vm.arena.newEvalBinary(x) - vm.sp -= args - 1 + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u, vm.err = mathSub_uu0(l.u, r.u) + vm.sp-- return 1 - }, "FN MULTICMP VARBINARY(SP-%d)...VARBINARY(SP-1)", args) + }, "SUB UINT64(SP-2), UINT64(SP-1)") } -func (c *compiler) emitFn_MULTICMP_c(args int, lessThan bool, tc collations.TypedCollation) { - col := tc.Collation.Get() - - c.adjustStack(-(args - 1)) - c.emit(func(vm *VirtualMachine) int { - x := vm.stack[vm.sp-args].ToRawBytes() - for sp := vm.sp - args + 1; sp < vm.sp; sp++ { - y := vm.stack[sp].ToRawBytes() - if lessThan == (col.Collate(y, x, false) < 0) { - x = y - } - } - vm.stack[vm.sp-args] = vm.arena.newEvalText(x, tc) - vm.sp -= args - 1 +func cmpnum[N interface{ int64 | uint64 | float64 }](a, b N) int { + switch { + case a == b: + return 0 + case a < b: + return -1 + default: return 1 - }, "FN MULTICMP FLOAT64(SP-%d)...FLOAT64(SP-1)", args) + } } diff --git a/go/vt/vtgate/evalengine/compiler_bit.go b/go/vt/vtgate/evalengine/compiler_bit.go new file mode 100644 index 00000000000..0654ce6ed02 --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_bit.go @@ -0,0 +1,138 @@ +package evalengine + +import "vitess.io/vitess/go/sqltypes" + +func (c *compiler) compileBitwise(expr *BitwiseExpr) (ctype, error) { + switch expr.Op.(type) { + case *opBitAnd: + return 
c.compileBitwiseOp(expr.Left, expr.Right, and) + case *opBitOr: + return c.compileBitwiseOp(expr.Left, expr.Right, or) + case *opBitXor: + return c.compileBitwiseOp(expr.Left, expr.Right, xor) + case *opBitShl: + return c.compileBitwiseShift(expr.Left, expr.Right, -1) + case *opBitShr: + return c.compileBitwiseShift(expr.Left, expr.Right, 1) + default: + panic("unexpected arithmetic operator") + } +} + +type bitwiseOp int + +const ( + and bitwiseOp = iota + or + xor +) + +func (c *compiler) compileBitwiseOp(left Expr, right Expr, op bitwiseOp) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck2(skip) + + if lt.Type == sqltypes.VarBinary && rt.Type == sqltypes.VarBinary { + if !lt.isHexOrBitLiteral() || !rt.isHexOrBitLiteral() { + c.asm.BitOp_bb(op) + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + } + } + + lt = c.compileToBitwiseUint64(lt, 2) + rt = c.compileToBitwiseUint64(rt, 1) + + c.asm.BitOp_uu(op) + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil +} + +func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck2(skip) + + if lt.Type == sqltypes.VarBinary && !lt.isHexOrBitLiteral() { + _ = c.compileToUint64(rt, 1) + if i < 0 { + c.asm.BitShiftLeft_bu() + } else { + c.asm.BitShiftRight_bu() + } + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + } + + _ = c.compileToBitwiseUint64(lt, 2) + _ = c.compileToUint64(rt, 1) + + if i < 0 { + c.asm.BitShiftLeft_uu() + } else { + c.asm.BitShiftRight_uu() + } + + 
c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_BIT_COUNT(expr *builtinBitCount) (ctype, error) { + ct, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { + c.asm.BitCount_b() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil + } + + _ = c.compileToBitwiseUint64(ct, 1) + c.asm.BitCount_u() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil +} + +func (c *compiler) compileBitwiseNot(expr *BitwiseNotExpr) (ctype, error) { + ct, err := c.compileExpr(expr.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { + c.asm.BitwiseNot_b() + c.asm.jumpDestination(skip) + return ct, nil + } + + ct = c.compileToBitwiseUint64(ct, 1) + c.asm.BitwiseNot_u() + c.asm.jumpDestination(skip) + return ct, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go new file mode 100644 index 00000000000..c3caa45a659 --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -0,0 +1,297 @@ +package evalengine + +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +func (c *compiler) compileComparisonTuple(expr *ComparisonExpr) (ctype, error) { + switch expr.Op.(type) { + case compareNullSafeEQ: + c.asm.CmpTupleNullsafe() + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + case compareEQ: + c.asm.CmpTuple(true) + c.asm.Cmp_eq_n() + case compareNE: + c.asm.CmpTuple(true) + c.asm.Cmp_ne_n() + case compareLT: + c.asm.CmpTuple(false) + c.asm.Cmp_lt_n() + case compareLE: + 
c.asm.CmpTuple(false) + c.asm.Cmp_le_n() + case compareGT: + c.asm.CmpTuple(false) + c.asm.Cmp_gt_n() + case compareGE: + c.asm.CmpTuple(false) + c.asm.Cmp_ge_n() + default: + panic("invalid comparison operator") + } + return ctype{Type: sqltypes.Int64, Flag: flagNullable, Col: collationNumeric}, nil +} + +func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { + lt, err := c.compileExpr(expr.Left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(expr.Right) + if err != nil { + return ctype{}, err + } + + if lt.Type == sqltypes.Tuple || rt.Type == sqltypes.Tuple { + if lt.Type != rt.Type { + return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "did not typecheck tuples during comparison") + } + return c.compileComparisonTuple(expr) + } + + swapped := false + skip := c.asm.jumpFrom() + + switch expr.Op.(type) { + case compareNullSafeEQ: + c.asm.NullCmp(skip) + default: + c.asm.NullCheck2(skip) + } + + switch { + case compareAsStrings(lt.Type, rt.Type): + var merged collations.TypedCollation + var coerceLeft collations.Coercion + var coerceRight collations.Coercion + var env = collations.Local() + + if lt.Col.Collation != rt.Col.Collation { + merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) + } else { + merged = lt.Col + } + if err != nil { + return ctype{}, err + } + + if coerceLeft == nil && coerceRight == nil { + c.asm.CmpString_collate(merged.Collation.Get()) + } else { + if coerceLeft == nil { + coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } + } + if coerceRight == nil { + coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } + } + c.asm.CmpString_coerce(&compiledCoercion{ + col: merged.Collation.Get(), + left: coerceLeft, + right: coerceRight, + }) + } + + case compareAsSameNumericType(lt.Type, rt.Type) || compareAsDecimal(lt.Type, rt.Type): + switch lt.Type { + case 
sqltypes.Int64: + switch rt.Type { + case sqltypes.Int64: + c.asm.CmpNum_ii() + case sqltypes.Uint64: + c.asm.CmpNum_iu(2, 1) + case sqltypes.Float64: + c.asm.CmpNum_if(2, 1) + case sqltypes.Decimal: + c.asm.CmpNum_id(2, 1) + } + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.asm.CmpNum_iu(1, 2) + swapped = true + case sqltypes.Uint64: + c.asm.CmpNum_uu() + case sqltypes.Float64: + c.asm.CmpNum_uf(2, 1) + case sqltypes.Decimal: + c.asm.CmpNum_ud(2, 1) + } + case sqltypes.Float64: + switch rt.Type { + case sqltypes.Int64: + c.asm.CmpNum_if(1, 2) + swapped = true + case sqltypes.Uint64: + c.asm.CmpNum_uf(1, 2) + swapped = true + case sqltypes.Float64: + c.asm.CmpNum_ff() + case sqltypes.Decimal: + c.asm.CmpNum_fd(2, 1) + } + + case sqltypes.Decimal: + switch rt.Type { + case sqltypes.Int64: + c.asm.CmpNum_id(1, 2) + swapped = true + case sqltypes.Uint64: + c.asm.CmpNum_ud(1, 2) + swapped = true + case sqltypes.Float64: + c.asm.CmpNum_fd(1, 2) + swapped = true + case sqltypes.Decimal: + c.asm.CmpNum_dd() + } + } + + case compareAsDates(lt.Type, rt.Type) || compareAsDateAndString(lt.Type, rt.Type) || compareAsDateAndNumeric(lt.Type, rt.Type): + return ctype{}, c.unsupported(expr) + + default: + lt = c.compileToFloat(lt, 2) + rt = c.compileToFloat(rt, 1) + c.asm.CmpNum_ff() + } + + cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric} + + switch expr.Op.(type) { + case compareEQ: + c.asm.Cmp_eq() + case compareNE: + c.asm.Cmp_ne() + case compareLT: + if swapped { + c.asm.Cmp_gt() + } else { + c.asm.Cmp_lt() + } + case compareLE: + if swapped { + c.asm.Cmp_ge() + } else { + c.asm.Cmp_le() + } + case compareGT: + if swapped { + c.asm.Cmp_lt() + } else { + c.asm.Cmp_gt() + } + case compareGE: + if swapped { + c.asm.Cmp_le() + } else { + c.asm.Cmp_ge() + } + case compareNullSafeEQ: + c.asm.jumpDestination(skip) + c.asm.Cmp_eq() + return cmptype, nil + + default: + panic("unexpected comparison operator") + } + + c.asm.jumpDestination(skip) + 
return cmptype, nil +} + +func (c *compiler) compileCheckTrue(when ctype, offset int) error { + switch when.Type { + case sqltypes.Int64: + c.asm.Convert_iB(offset) + case sqltypes.Uint64: + c.asm.Convert_uB(offset) + case sqltypes.Float64: + c.asm.Convert_fB(offset) + case sqltypes.Decimal: + c.asm.Convert_dB(offset) + case sqltypes.VarChar, sqltypes.VarBinary: + c.asm.Convert_bB(offset) + case sqltypes.Null: + c.asm.SetBool(offset, false) + default: + return vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported Truth check: %s", when.Type) + } + return nil +} + +func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { + lt, err := c.compileExpr(expr.Left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(expr.Right) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck2(skip) + + if !sqltypes.IsText(lt.Type) && !sqltypes.IsBinary(lt.Type) { + c.asm.Convert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) + lt.Col = collations.TypedCollation{ + Collation: c.defaultCollation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + if !sqltypes.IsText(rt.Type) && !sqltypes.IsBinary(rt.Type) { + c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + rt.Col = collations.TypedCollation{ + Collation: c.defaultCollation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + var merged collations.TypedCollation + var coerceLeft collations.Coercion + var coerceRight collations.Coercion + var env = collations.Local() + + if lt.Col.Collation != rt.Col.Collation { + merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) + } else { + merged = lt.Col + } + if err != nil { + return ctype{}, err + } + + if coerceLeft == nil && coerceRight == nil { + c.asm.Like_collate(expr, merged.Collation.Get()) + } 
else { + if coerceLeft == nil { + coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } + } + if coerceRight == nil { + coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } + } + c.asm.Like_coerce(expr, &compiledCoercion{ + col: merged.Collation.Get(), + left: coerceLeft, + right: coerceRight, + }) + } + + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_convert.go b/go/vt/vtgate/evalengine/compiler_convert.go new file mode 100644 index 00000000000..701883eb405 --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_convert.go @@ -0,0 +1,108 @@ +package evalengine + +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +func (c *compiler) compileCollate(expr *CollateExpr) (ctype, error) { + ct, err := c.compileExpr(expr.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + switch ct.Type { + case sqltypes.VarChar: + if err := collations.Local().EnsureCollate(ct.Col.Collation, expr.TypedCollation.Collation); err != nil { + return ctype{}, vterrors.New(vtrpc.Code_INVALID_ARGUMENT, err.Error()) + } + fallthrough + case sqltypes.VarBinary: + c.asm.Collate(expr.TypedCollation.Collation) + default: + return ctype{}, c.unsupported(expr) + } + + c.asm.jumpDestination(skip) + + ct.Col = expr.TypedCollation + ct.Flag |= flagExplicitCollation + return ct, nil +} + +func (c *compiler) compileConvert(conv *ConvertExpr) (ctype, error) { + arg, err := c.compileExpr(conv.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + var convt ctype + + switch conv.Type { + case "BINARY": + convt = ctype{Type: conv.convertToBinaryType(arg.Type), Col: collationBinary} + c.asm.Convert_xb(1, convt.Type, conv.Length, conv.HasLength) + + case "CHAR", "NCHAR": 
+ convt = ctype{ + Type: conv.convertToCharType(arg.Type), + Col: collations.TypedCollation{Collation: conv.Collation}, + } + c.asm.Convert_xc(1, convt.Type, convt.Col.Collation, conv.Length, conv.HasLength) + + case "DECIMAL": + convt = ctype{Type: sqltypes.Decimal, Col: collationNumeric} + m, d := conv.decimalPrecision() + c.asm.Convert_xd(1, m, d) + + case "DOUBLE", "REAL": + convt = c.compileToFloat(arg, 1) + + case "SIGNED", "SIGNED INTEGER": + convt = c.compileToInt64(arg, 1) + + case "UNSIGNED", "UNSIGNED INTEGER": + convt = c.compileToUint64(arg, 1) + + case "JSON": + // TODO: what does NULL map to? + convt, err = c.compileToJSON(arg, 1) + if err != nil { + return ctype{}, err + } + + default: + return ctype{}, c.unsupported(conv) + } + + c.asm.jumpDestination(skip) + return convt, nil + +} + +func (c *compiler) compileConvertUsing(conv *ConvertUsingExpr) (ctype, error) { + _, err := c.compileExpr(conv.Inner) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + c.asm.Convert_xc(1, sqltypes.VarChar, conv.Collation, 0, false) + c.asm.jumpDestination(skip) + + col := collations.TypedCollation{ + Collation: conv.Collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + return ctype{Type: sqltypes.VarChar, Col: col}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go new file mode 100644 index 00000000000..bc907faff45 --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -0,0 +1,333 @@ +package evalengine + +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +func (c *compiler) compileFn(call callable) (ctype, error) { + switch call := call.(type) { + case *builtinMultiComparison: + return c.compileFn_MULTICMP(call) + case *builtinJSONExtract: + return c.compileFn_JSON_EXTRACT(call) + case 
*builtinJSONUnquote: + return c.compileFn_JSON_UNQUOTE(call) + case *builtinJSONContainsPath: + return c.compileFn_JSON_CONTAINS_PATH(call) + case *builtinJSONArray: + return c.compileFn_JSON_ARRAY(call) + case *builtinJSONObject: + return c.compileFn_JSON_OBJECT(call) + case *builtinRepeat: + return c.compileFn_REPEAT(call) + case *builtinToBase64: + return c.compileFn_TO_BASE64(call) + case *builtinFromBase64: + return c.compileFn_FROM_BASE64(call) + case *builtinChangeCase: + return c.compileFn_CCASE(call) + case *builtinCharLength: + return c.compileFn_LENGTH(call, charLen) + case *builtinLength: + return c.compileFn_LENGTH(call, byteLen) + case *builtinBitLength: + return c.compileFn_LENGTH(call, bitLen) + case *builtinASCII: + return c.compileFn_ASCII(call) + case *builtinHex: + return c.compileFn_HEX(call) + case *builtinBitCount: + return c.compileFn_BIT_COUNT(call) + default: + return ctype{}, c.unsupported(call) + } +} + +func (c *compiler) compileFn_MULTICMP(call *builtinMultiComparison) (ctype, error) { + var ( + integersI int + integersU int + floats int + decimals int + text int + binary int + args []ctype + ) + + /* + If any argument is NULL, the result is NULL. No comparison is needed. + If all arguments are integer-valued, they are compared as integers. + If at least one argument is double precision, they are compared as double-precision values. Otherwise, if at least one argument is a DECIMAL value, they are compared as DECIMAL values. + If the arguments comprise a mix of numbers and strings, they are compared as strings. + If any argument is a nonbinary (character) string, the arguments are compared as nonbinary strings. + In all other cases, the arguments are compared as binary strings. 
+ */ + + for _, expr := range call.Arguments { + tt, err := c.compileExpr(expr) + if err != nil { + return ctype{}, err + } + + args = append(args, tt) + + switch tt.Type { + case sqltypes.Int64: + integersI++ + case sqltypes.Uint64: + integersU++ + case sqltypes.Float64: + floats++ + case sqltypes.Decimal: + decimals++ + case sqltypes.Text, sqltypes.VarChar: + text++ + case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: + binary++ + default: + return ctype{}, c.unsupported(call) + } + } + + if integersI+integersU == len(args) { + if integersI == len(args) { + c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + } + if integersU == len(args) { + c.asm.Fn_MULTICMP_u(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil + } + return c.compileFN_MULTICMP_d(args, call.cmp < 0) + } + if binary > 0 || text > 0 { + if text > 0 { + return c.compileFn_MULTICMP_c(args, call.cmp < 0) + } + c.asm.Fn_MULTICMP_b(len(args), call.cmp < 0) + return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + } else { + if floats > 0 { + for i, tt := range args { + c.compileToFloat(tt, len(args)-i) + } + c.asm.Fn_MULTICMP_f(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + } + if decimals > 0 { + return c.compileFN_MULTICMP_d(args, call.cmp < 0) + } + } + return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected argument for GREATEST/LEAST") +} + +func (c *compiler) compileFn_MULTICMP_c(args []ctype, lessThan bool) (ctype, error) { + env := collations.Local() + + var ca collationAggregation + for _, arg := range args { + if err := ca.add(env, arg.Col); err != nil { + return ctype{}, err + } + } + + tc := ca.result() + c.asm.Fn_MULTICMP_c(len(args), lessThan, tc) + return ctype{Type: sqltypes.VarChar, Col: tc}, nil +} + +func (c *compiler) compileFN_MULTICMP_d(args []ctype, lessThan bool) (ctype, error) { + for i, tt := range args { + 
c.compileToDecimal(tt, len(args)-i) + } + c.asm.Fn_MULTICMP_d(len(args), lessThan) + return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_REPEAT(expr *builtinRepeat) (ctype, error) { + str, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + repeat, err := c.compileExpr(expr.Arguments[1]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck2(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.asm.Convert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) + } + _ = c.compileToInt64(repeat, 1) + + c.asm.Fn_REPEAT(1) + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil +} + +func (c *compiler) compileFn_TO_BASE64(call *builtinToBase64) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + t := sqltypes.VarChar + if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { + t = sqltypes.Text + } + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.asm.Convert_xc(1, t, c.defaultCollation, 0, false) + } + + col := collations.TypedCollation{ + Collation: c.defaultCollation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + + c.asm.Fn_TO_BASE64(t, col) + c.asm.jumpDestination(skip) + + return ctype{Type: t, Col: col}, nil +} + +func (c *compiler) compileFn_FROM_BASE64(call *builtinFromBase64) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.asm.Convert_xc(1, sqltypes.VarBinary, c.defaultCollation, 0, false) + } + + c.asm.Fn_FROM_BASE64() + c.asm.jumpDestination(skip) + + 
return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil +} + +func (c *compiler) compileFn_CCASE(call *builtinChangeCase) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + } + + c.asm.Fn_LUCASE(call.upcase) + c.asm.jumpDestination(skip) + + return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil +} + +type lengthOp int + +const ( + charLen lengthOp = iota + byteLen + bitLen +) + +func (c *compiler) compileFn_LENGTH(call callable, op lengthOp) (ctype, error) { + str, err := c.compileExpr(call.callable()[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + } + + c.asm.Fn_LENGTH(op) + c.asm.jumpDestination(skip) + + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_ASCII(call *builtinASCII) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + switch { + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + default: + c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + } + + c.asm.Fn_ASCII() + c.asm.jumpDestination(skip) + + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_HEX(call *builtinHex) (ctype, error) { + str, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + col := collations.TypedCollation{ + Collation: c.defaultCollation, + Coercibility: 
collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + + t := sqltypes.VarChar + if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { + t = sqltypes.Text + } + + switch { + case sqltypes.IsNumber(str.Type), sqltypes.IsDecimal(str.Type): + c.asm.Fn_HEX_d(col) + case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + c.asm.Fn_HEX_c(t, col) + default: + c.asm.Convert_xc(1, t, c.defaultCollation, 0, false) + c.asm.Fn_HEX_c(t, col) + } + + c.asm.jumpDestination(skip) + + return ctype{Type: t, Col: col}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go new file mode 100644 index 00000000000..be0223db2b5 --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -0,0 +1,208 @@ +package evalengine + +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/collations/charset" + "vitess.io/vitess/go/slices2" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine/internal/json" +) + +func isEncodingJSONSafe(col collations.ID) bool { + switch col.Get().Charset().(type) { + case charset.Charset_utf8mb4, charset.Charset_utf8mb3, charset.Charset_binary: + return true + default: + return false + } +} + +func (c *compiler) compileParseJSON(fn string, doct ctype, offset int) (ctype, error) { + switch doct.Type { + case sqltypes.TypeJSON: + case sqltypes.VarChar, sqltypes.VarBinary: + c.asm.Parse_j(offset) + default: + return ctype{}, errJSONType(fn) + } + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + +func (c *compiler) compileToJSON(doct ctype, offset int) (ctype, error) { + switch doct.Type { + case sqltypes.TypeJSON: + return doct, nil + case sqltypes.Float64: + c.asm.Convert_fj(offset) + case sqltypes.Int64, sqltypes.Uint64, sqltypes.Decimal: + c.asm.Convert_nj(offset) + case sqltypes.VarChar: + c.asm.Convert_cj(offset) + case 
sqltypes.VarBinary: + c.asm.Convert_bj(offset) + case sqltypes.Null: + c.asm.Convert_Nj(offset) + default: + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", doct.Type) + } + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + +func (c *compiler) compileFn_JSON_ARRAY(call *builtinJSONArray) (ctype, error) { + for _, arg := range call.Arguments { + tt, err := c.compileExpr(arg) + if err != nil { + return ctype{}, err + } + + _, err = c.compileToJSON(tt, 1) + if err != nil { + return ctype{}, err + } + } + c.asm.Fn_JSON_ARRAY(len(call.Arguments)) + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + +func (c *compiler) compileFn_JSON_OBJECT(call *builtinJSONObject) (ctype, error) { + for i := 0; i < len(call.Arguments); i += 2 { + key, err := c.compileExpr(call.Arguments[i]) + if err != nil { + return ctype{}, err + } + c.compileToJSONKey(key) + val, err := c.compileExpr(call.Arguments[i+1]) + if err != nil { + return ctype{}, err + } + val, err = c.compileToJSON(val, 1) + if err != nil { + return ctype{}, err + } + } + c.asm.Fn_JSON_OBJECT(len(call.Arguments)) + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + +func (c *compiler) compileToJSONKey(key ctype) { + if key.Type == sqltypes.VarChar && isEncodingJSONSafe(key.Col.Collation) { + return + } + if key.Type == sqltypes.VarBinary { + return + } + c.asm.Convert_xc(1, sqltypes.VarChar, collations.CollationUtf8mb4ID, 0, false) +} + +func (c *compiler) jsonExtractPath(expr Expr) (*json.Path, error) { + path, ok := expr.(*Literal) + if !ok { + return nil, errJSONPath + } + pathBytes, ok := path.inner.(*evalBytes) + if !ok { + return nil, errJSONPath + } + var parser json.PathParser + return parser.ParseBytes(pathBytes.bytes) +} + +func (c *compiler) jsonExtractOneOrAll(fname string, expr Expr) (jsonMatch, error) { + lit, ok := expr.(*Literal) + if !ok { + return jsonMatchInvalid, errOneOrAll(fname) + } + b, ok := 
lit.inner.(*evalBytes) + if !ok { + return jsonMatchInvalid, errOneOrAll(fname) + } + return intoOneOrAll(fname, b.string()) +} + +func (c *compiler) compileFn_JSON_EXTRACT(call *builtinJSONExtract) (ctype, error) { + doct, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + if slices2.All(call.Arguments[1:], func(expr Expr) bool { return expr.constant() }) { + paths := make([]*json.Path, 0, len(call.Arguments[1:])) + + for _, arg := range call.Arguments[1:] { + jp, err := c.jsonExtractPath(arg) + if err != nil { + return ctype{}, err + } + paths = append(paths, jp) + } + + jt, err := c.compileParseJSON("JSON_EXTRACT", doct, 1) + if err != nil { + return ctype{}, err + } + + c.asm.Fn_JSON_EXTRACT0(paths) + return jt, nil + } + + return ctype{}, c.unsupported(call) +} + +func (c *compiler) compileFn_JSON_UNQUOTE(call *builtinJSONUnquote) (ctype, error) { + arg, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + _, err = c.compileParseJSON("JSON_UNQUOTE", arg, 1) + if err != nil { + return ctype{}, err + } + + c.asm.Fn_JSON_UNQUOTE() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Blob, Col: collationJSON}, nil +} + +func (c *compiler) compileFn_JSON_CONTAINS_PATH(call *builtinJSONContainsPath) (ctype, error) { + doct, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + if !call.Arguments[1].constant() { + return ctype{}, c.unsupported(call) + } + + if !slices2.All(call.Arguments[2:], func(expr Expr) bool { return expr.constant() }) { + return ctype{}, c.unsupported(call) + } + + match, err := c.jsonExtractOneOrAll("JSON_CONTAINS_PATH", call.Arguments[1]) + if err != nil { + return ctype{}, err + } + + paths := make([]*json.Path, 0, len(call.Arguments[2:])) + + for _, arg := range call.Arguments[2:] { + jp, err := c.jsonExtractPath(arg) + if err != nil { + return ctype{}, err + } + paths = 
append(paths, jp) + } + + _, err = c.compileParseJSON("JSON_CONTAINS_PATH", doct, 1) + if err != nil { + return ctype{}, err + } + + c.asm.Fn_JSON_CONTAINS_PATH(match, paths) + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_logical.go b/go/vt/vtgate/evalengine/compiler_logical.go new file mode 100644 index 00000000000..0c029e975e2 --- /dev/null +++ b/go/vt/vtgate/evalengine/compiler_logical.go @@ -0,0 +1,46 @@ +package evalengine + +import "vitess.io/vitess/go/mysql/collations" + +func (c *compiler) compileCase(cs *CaseExpr) (ctype, error) { + var ca collationAggregation + var ta typeAggregation + var local = collations.Local() + + for _, wt := range cs.cases { + when, err := c.compileExpr(wt.when) + if err != nil { + return ctype{}, err + } + + if err := c.compileCheckTrue(when, 1); err != nil { + return ctype{}, err + } + + then, err := c.compileExpr(wt.then) + if err != nil { + return ctype{}, err + } + + ta.add(then.Type, then.Flag) + if err := ca.add(local, then.Col); err != nil { + return ctype{}, err + } + } + + if cs.Else != nil { + els, err := c.compileExpr(cs.Else) + if err != nil { + return ctype{}, err + } + + ta.add(els.Type, els.Flag) + if err := ca.add(local, els.Col); err != nil { + return ctype{}, err + } + } + + ct := ctype{Type: ta.result(), Col: ca.result()} + c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col) + return ct, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 12032b95ddf..46455eadb9d 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -152,7 +152,7 @@ type debugCompiler struct { t testing.TB } -func (d *debugCompiler) Disasm(ins string, args ...any) { +func (d *debugCompiler) Instruction(ins string, args ...any) { ins = fmt.Sprintf(ins, args...) 
d.t.Logf("> %s", ins) } @@ -240,7 +240,7 @@ func TestCompiler(t *testing.T) { t.Fatal(err) } - compiled, err := evalengine.Compile(converted, makeFields(tc.values), evalengine.WithCompilerLog(&debugCompiler{t})) + compiled, err := evalengine.Compile(converted, makeFields(tc.values), evalengine.WithAssemblerLog(&debugCompiler{t})) if err != nil { t.Fatal(err) } From 9834fc3c6b11357586257dd5187e98718a60f3ba Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 9 Mar 2023 15:35:14 +0100 Subject: [PATCH 19/42] evalengine/compiler: Implement COLLATION function Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_asm.go | 8 ++++++++ go/vt/vtgate/evalengine/compiler_fn.go | 22 ++++++++++++++++++++++ go/vt/vtgate/evalengine/compiler_json.go | 4 ++-- go/vt/vtgate/evalengine/fn_json.go | 3 +-- 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index a9427b8193a..9bcea5d9904 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -1810,6 +1810,14 @@ func (asm *assembler) Sub_uu() { }, "SUB UINT64(SP-2), UINT64(SP-1)") } +func (asm *assembler) Collation(col collations.TypedCollation) { + asm.emit(func(vm *VirtualMachine) int { + v := evalCollation(vm.stack[vm.sp-1]) + vm.stack[vm.sp-1] = vm.arena.newEvalText([]byte(v.Collation.Get().Name()), col) + return 1 + }, "COLLATION (SP-1)") +} + func cmpnum[N interface{ int64 | uint64 | float64 }](a, b N) int { switch { case a == b: diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index bc907faff45..d4827826e29 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -41,6 +41,8 @@ func (c *compiler) compileFn(call callable) (ctype, error) { return c.compileFn_HEX(call) case *builtinBitCount: return c.compileFn_BIT_COUNT(call) + case *builtinCollation: + return c.compileFn_COLLATION(call) 
default: return ctype{}, c.unsupported(call) } @@ -331,3 +333,23 @@ func (c *compiler) compileFn_HEX(call *builtinHex) (ctype, error) { return ctype{Type: t, Col: col}, nil } + +func (c *compiler) compileFn_COLLATION(expr *builtinCollation) (ctype, error) { + _, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + + col := collations.TypedCollation{ + Collation: collations.CollationUtf8ID, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + + c.asm.Collation(col) + c.asm.jumpDestination(skip) + + return ctype{Type: sqltypes.VarChar, Col: col}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index be0223db2b5..5b4ba288337 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -77,7 +77,7 @@ func (c *compiler) compileFn_JSON_OBJECT(call *builtinJSONObject) (ctype, error) if err != nil { return ctype{}, err } - val, err = c.compileToJSON(val, 1) + _, err = c.compileToJSON(val, 1) if err != nil { return ctype{}, err } @@ -93,7 +93,7 @@ func (c *compiler) compileToJSONKey(key ctype) { if key.Type == sqltypes.VarBinary { return } - c.asm.Convert_xc(1, sqltypes.VarChar, collations.CollationUtf8mb4ID, 0, false) + c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } func (c *compiler) jsonExtractPath(expr Expr) (*json.Path, error) { diff --git a/go/vt/vtgate/evalengine/fn_json.go b/go/vt/vtgate/evalengine/fn_json.go index d0aced97417..75e88c82d75 100644 --- a/go/vt/vtgate/evalengine/fn_json.go +++ b/go/vt/vtgate/evalengine/fn_json.go @@ -17,7 +17,6 @@ limitations under the License. 
package evalengine import ( - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -160,7 +159,7 @@ func (call *builtinJSONObject) eval(env *ExpressionEnv) (eval, error) { if key == nil { return nil, errJSONKeyIsNil } - key1, err := evalToVarchar(key, collations.CollationUtf8mb4ID, true) + key1, err := evalToVarchar(key, env.DefaultCollation, true) if err != nil { return nil, err } From 1f16cafc10b7fd067393c783f038d6db88053c8c Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 9 Mar 2023 16:34:00 +0100 Subject: [PATCH 20/42] evalengine/compiler: Implement CEIL Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_asm.go | 22 +++++++++++++++ go/vt/vtgate/evalengine/compiler_fn.go | 32 ++++++++++++++++++++++ go/vt/vtgate/evalengine/testcases/cases.go | 5 ++++ 3 files changed, 59 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 9bcea5d9904..8b17f6b7eef 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -1818,6 +1818,28 @@ func (asm *assembler) Collation(col collations.TypedCollation) { }, "COLLATION (SP-1)") } +func (asm *assembler) Fn_CEIL_f() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = math.Ceil(f.f) + return 1 + }, "CEIL FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_CEIL_d() { + asm.emit(func(vm *VirtualMachine) int { + d := vm.stack[vm.sp-1].(*evalDecimal) + c := d.dec.Ceil() + i, valid := c.Int64() + if valid { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(i) + } else { + vm.err = errDeoptimize + } + return 1 + }, "CEIL DECIMAL(SP-1)") +} + func cmpnum[N interface{ int64 | uint64 | float64 }](a, b N) int { switch { case a == b: diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index d4827826e29..7611884f946 100644 --- 
a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -43,6 +43,8 @@ func (c *compiler) compileFn(call callable) (ctype, error) { return c.compileFn_BIT_COUNT(call) case *builtinCollation: return c.compileFn_COLLATION(call) + case *builtinCeil: + return c.compileFn_CEIL(call) default: return ctype{}, c.unsupported(call) } @@ -353,3 +355,33 @@ func (c *compiler) compileFn_COLLATION(expr *builtinCollation) (ctype, error) { return ctype{Type: sqltypes.VarChar, Col: col}, nil } + +func (c *compiler) compileFn_CEIL(expr *builtinCeil) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + convt := ctype{Type: arg.Type, Col: collationNumeric} + switch { + case sqltypes.IsIntegral(arg.Type): + // No-op for integers. + case sqltypes.IsFloat(arg.Type): + c.asm.Fn_CEIL_f() + case sqltypes.IsDecimal(arg.Type): + // We assume here the most common case here is that + // the decimal fits into an integer. 
+ convt.Type = sqltypes.Int64 + c.asm.Fn_CEIL_d() + default: + convt.Type = sqltypes.Float64 + c.asm.Convert_xf(1) + c.asm.Fn_CEIL_f() + } + + c.asm.jumpDestination(skip) + return convt, nil +} diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 7b112b991a2..261553d1abc 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -224,6 +224,11 @@ func (Ceil) Test(yield Iterator) { yield(fmt.Sprintf("CEIL(%s)", num), nil) yield(fmt.Sprintf("CEILING(%s)", num), nil) } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("CEIL(%s)", num), nil) + yield(fmt.Sprintf("CEILING(%s)", num), nil) + } } // HACK: for CASE comparisons, the expression is supposed to decompose like this: From 500bcf9c1b9053182e737c326a3b8fb773534046 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Wed, 8 Mar 2023 16:41:50 +0100 Subject: [PATCH 21/42] evalengine/compiler: IS expressions and JSON_KEYS Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 3 ++ go/vt/vtgate/evalengine/compiler_asm.go | 54 ++++++++++++++++++++- go/vt/vtgate/evalengine/compiler_fn.go | 2 + go/vt/vtgate/evalengine/compiler_json.go | 26 ++++++++++ go/vt/vtgate/evalengine/compiler_logical.go | 14 +++++- go/vt/vtgate/evalengine/compiler_test.go | 1 + 6 files changed, 97 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index b31650e2b98..ce09e77f9a6 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -126,6 +126,9 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { case *LikeExpr: return c.compileLike(expr) + case *IsExpr: + return c.compileIs(expr) + case callable: return c.compileFn(expr) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 8b17f6b7eef..730b69b3955 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ 
b/go/vt/vtgate/evalengine/compiler_asm.go @@ -706,7 +706,7 @@ func (asm *assembler) Convert_bj(offset int) { arg := vm.stack[vm.sp-offset].(*evalBytes) vm.stack[vm.sp-offset] = evalConvert_bj(arg) return 1 - }, "CONV VARBINARY(SP-%d), JSON") + }, "CONV VARBINARY(SP-%d), JSON", offset) } func (asm *assembler) Convert_cj(offset int) { @@ -714,7 +714,7 @@ func (asm *assembler) Convert_cj(offset int) { arg := vm.stack[vm.sp-offset].(*evalBytes) vm.stack[vm.sp-offset], vm.err = evalConvert_cj(arg) return 1 - }, "CONV VARCHAR(SP-%d), JSON") + }, "CONV VARCHAR(SP-%d), JSON", offset) } func (asm *assembler) Convert_dB(offset int) { @@ -1325,6 +1325,13 @@ func (asm *assembler) Fn_TO_BASE64(t sqltypes.Type, col collations.TypedCollatio }, "FN TO_BASE64 VARCHAR(SP-1)") } +func (asm *assembler) Is(check func(eval) bool) { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-1] = newEvalBool(check(vm.stack[vm.sp-1])) + return 1 + }, "IS (SP-1)") +} + func (asm *assembler) Like_coerce(expr *LikeExpr, coercion *compiledCoercion) { asm.adjustStack(-1) @@ -1810,6 +1817,49 @@ func (asm *assembler) Sub_uu() { }, "SUB UINT64(SP-2), UINT64(SP-1)") } +func (asm *assembler) Fn_JSON_KEYS(jp *json.Path) { + if jp == nil { + asm.emit(func(vm *VirtualMachine) int { + doc := vm.stack[vm.sp-1] + if doc == nil { + return 1 + } + j := doc.(*evalJSON) + if obj, ok := j.Object(); ok { + var keys []*json.Value + obj.Visit(func(key []byte, _ *json.Value) { + keys = append(keys, json.NewString(key)) + }) + vm.stack[vm.sp-1] = json.NewArray(keys) + } else { + vm.stack[vm.sp-1] = nil + } + return 1 + }, "FN JSON_KEYS (SP-1)") + } else { + asm.emit(func(vm *VirtualMachine) int { + doc := vm.stack[vm.sp-1] + if doc == nil { + return 1 + } + var obj *json.Object + jp.Match(doc.(*evalJSON), false, func(value *json.Value) { + obj, _ = value.Object() + }) + if obj != nil { + var keys []*json.Value + obj.Visit(func(key []byte, _ *json.Value) { + keys = append(keys, json.NewString(key)) + }) + 
vm.stack[vm.sp-1] = json.NewArray(keys) + } else { + vm.stack[vm.sp-1] = nil + } + return 1 + }, "FN JSON_KEYS (SP-1), %q", jp.String()) + } +} + func (asm *assembler) Collation(col collations.TypedCollation) { asm.emit(func(vm *VirtualMachine) int { v := evalCollation(vm.stack[vm.sp-1]) diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 7611884f946..182e081284d 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -21,6 +21,8 @@ func (c *compiler) compileFn(call callable) (ctype, error) { return c.compileFn_JSON_ARRAY(call) case *builtinJSONObject: return c.compileFn_JSON_OBJECT(call) + case *builtinJSONKeys: + return c.compileFn_JSON_KEYS(call) case *builtinRepeat: return c.compileFn_REPEAT(call) case *builtinToBase64: diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index 5b4ba288337..dc2761991f8 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -206,3 +206,29 @@ func (c *compiler) compileFn_JSON_CONTAINS_PATH(call *builtinJSONContainsPath) ( c.asm.Fn_JSON_CONTAINS_PATH(match, paths) return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil } + +func (c *compiler) compileFn_JSON_KEYS(call *builtinJSONKeys) (ctype, error) { + doc, err := c.compileExpr(call.Arguments[0]) + if err != nil { + return ctype{}, err + } + + doc, err = c.compileParseJSON("JSON_KEYS", doc, 1) + if err != nil { + return ctype{}, err + } + + var jp *json.Path + if len(call.Arguments) == 2 { + jp, err = c.jsonExtractPath(call.Arguments[1]) + if err != nil { + return ctype{}, err + } + if jp.ContainsWildcards() { + return ctype{}, errInvalidPathForTransform + } + } + + c.asm.Fn_JSON_KEYS(jp) + return ctype{Type: sqltypes.TypeJSON, Flag: flagNullable, Col: collationJSON}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_logical.go b/go/vt/vtgate/evalengine/compiler_logical.go index 
0c029e975e2..bd1d10ef974 100644 --- a/go/vt/vtgate/evalengine/compiler_logical.go +++ b/go/vt/vtgate/evalengine/compiler_logical.go @@ -1,6 +1,18 @@ package evalengine -import "vitess.io/vitess/go/mysql/collations" +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" +) + +func (c *compiler) compileIs(is *IsExpr) (ctype, error) { + _, err := c.compileExpr(is.Inner) + if err != nil { + return ctype{}, err + } + c.asm.Is(is.Check) + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} func (c *compiler) compileCase(cs *CaseExpr) (ctype, error) { var ca collationAggregation diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 46455eadb9d..baa6ad94616 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -111,6 +111,7 @@ func TestCompilerReference(t *testing.T) { if vterrors.Code(err) != vtrpcpb.Code_UNIMPLEMENTED { t.Errorf("failed compilation:\nSQL: %s\nError: %s", query, err) } + t.Logf("unsupported: %s", query) return } From 9cff5d754af0b7facd4227f77b284eb0ebdbc438 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Thu, 9 Mar 2023 18:05:24 +0100 Subject: [PATCH 22/42] evalengine: fix IS and CEIL corner cases Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/cached_size.go | 21 +------- go/vt/vtgate/evalengine/eval.go | 4 ++ go/vt/vtgate/evalengine/expr_compare.go | 43 ++++----------- go/vt/vtgate/evalengine/expr_env.go | 13 +++-- go/vt/vtgate/evalengine/fn_numeric.go | 6 ++- .../evalengine/integration/comparison_test.go | 2 +- .../evalengine/integration/fuzz_test.go | 5 ++ go/vt/vtgate/evalengine/mysql_test.go | 2 +- go/vt/vtgate/evalengine/testcases/cases.go | 52 +++++++----------- go/vt/vtgate/evalengine/testcases/helpers.go | 54 ++++++++++++++++++- go/vt/vtgate/evalengine/testcases/inputs.go | 3 +- go/vt/vtgate/evalengine/translate.go | 1 - go/vt/vtgate/evalengine/translate_simplify.go | 54 +------------------ 13 
files changed, 112 insertions(+), 148 deletions(-) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 1e0740cbb50..1733706123b 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -17,13 +17,7 @@ limitations under the License. package evalengine -import ( - "math" - "reflect" - "unsafe" - - hack "vitess.io/vitess/go/hack" -) +import hack "vitess.io/vitess/go/hack" type cachedObject interface { CachedSize(alloc bool) int64 @@ -209,8 +203,6 @@ func (cached *ConvertUsingExpr) CachedSize(alloc bool) int64 { size += cached.UnaryExpr.CachedSize(false) return size } - -//go:nocheckptr func (cached *InExpr) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -221,17 +213,6 @@ func (cached *InExpr) CachedSize(alloc bool) int64 { } // field BinaryExpr vitess.io/vitess/go/vt/vtgate/evalengine.BinaryExpr size += cached.BinaryExpr.CachedSize(false) - // field Hashed map[[16]byte]int - if cached.Hashed != nil { - size += int64(48) - hmap := reflect.ValueOf(cached.Hashed) - numBuckets := int(math.Pow(2, float64((*(*uint8)(unsafe.Pointer(hmap.Pointer() + uintptr(9))))))) - numOldBuckets := (*(*uint16)(unsafe.Pointer(hmap.Pointer() + uintptr(10)))) - size += hack.RuntimeAllocSize(int64(numOldBuckets * 208)) - if len(cached.Hashed) > 0 || numBuckets > 1 { - size += hack.RuntimeAllocSize(int64(numBuckets * 208)) - } - } return size } func (cached *IsExpr) CachedSize(alloc bool) int64 { diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index f5f42c30625..d9169444d3b 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -52,6 +52,10 @@ const ( // flagExplicitCollation marks that this value has an explicit collation flagExplicitCollation typeFlag = 1 << 10 + // flagAmbiguousType marks that the type of this value depends on the value at runtime + // and cannot be computed accurately + flagAmbiguousType typeFlag = 1 << 
11 + // flagIntegerRange are the flags that mark overflow/underflow in integers flagIntegerRange = flagIntegerOvf | flagIntegerCap | flagIntegerUdf ) diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index a401ddd940d..5012d027e5d 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -21,7 +21,6 @@ import ( "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vthash" ) type ( @@ -45,7 +44,6 @@ type ( InExpr struct { BinaryExpr Negate bool - Hashed map[vthash.Hash]int } ComparisonOp interface { @@ -291,37 +289,18 @@ func (i *InExpr) eval(env *ExpressionEnv) (eval, error) { } var foundNull, found bool - var hasher = vthash.New() - if i.Hashed != nil { - if left, ok := left.(hashable); ok { - left.Hash(&hasher) - - hash := hasher.Sum128() - hasher.Reset() - - if idx, ok := i.Hashed[hash]; ok { - var numeric int - numeric, foundNull, err = evalCompareAll(left, rtuple.t[idx], true) - if err != nil { - return nil, err - } - found = numeric == 0 - } + for _, rtuple := range rtuple.t { + numeric, isNull, err := evalCompareAll(left, rtuple, true) + if err != nil { + return nil, err } - } else { - for _, rtuple := range rtuple.t { - numeric, isNull, err := evalCompareAll(left, rtuple, true) - if err != nil { - return nil, err - } - if isNull { - foundNull = true - continue - } - if numeric == 0 { - found = true - break - } + if isNull { + foundNull = true + continue + } + if numeric == 0 { + found = true + break } } diff --git a/go/vt/vtgate/evalengine/expr_env.go b/go/vt/vtgate/evalengine/expr_env.go index bcefe10f81b..ae73066dee6 100644 --- a/go/vt/vtgate/evalengine/expr_env.go +++ b/go/vt/vtgate/evalengine/expr_env.go @@ -17,6 +17,8 @@ limitations under the License. 
package evalengine import ( + "errors" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -43,9 +45,14 @@ func (env *ExpressionEnv) Evaluate(expr Expr) (EvalResult, error) { return EvalResult{e}, err } -func (env *ExpressionEnv) TypeOf(expr Expr) (ty sqltypes.Type, err error) { - ty, _ = expr.typeof(env) - return +var ErrAmbiguousType = errors.New("the type of this expression cannot be statically computed") + +func (env *ExpressionEnv) TypeOf(expr Expr) (sqltypes.Type, error) { + ty, f := expr.typeof(env) + if f&flagAmbiguousType != 0 { + return ty, ErrAmbiguousType + } + return ty, nil } func (env *ExpressionEnv) collation() collations.TypedCollation { diff --git a/go/vt/vtgate/evalengine/fn_numeric.go b/go/vt/vtgate/evalengine/fn_numeric.go index 3c157f9bc89..19f12c9cfe3 100644 --- a/go/vt/vtgate/evalengine/fn_numeric.go +++ b/go/vt/vtgate/evalengine/fn_numeric.go @@ -56,10 +56,12 @@ func (call *builtinCeil) eval(env *ExpressionEnv) (eval, error) { func (call *builtinCeil) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { t, f := call.Arguments[0].typeof(env) - if sqltypes.IsIntegral(t) { + if sqltypes.IsSigned(t) { return sqltypes.Int64, f + } else if sqltypes.IsUnsigned(t) { + return sqltypes.Uint64, f } else if sqltypes.Decimal == t { - return sqltypes.Decimal, f + return sqltypes.Int64, f | flagAmbiguousType } else { return sqltypes.Float64, f } diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index 03bf5598ad2..c541238fc9a 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -131,7 +131,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys } else { localVal = v.String() } - if debugCheckTypes { + if debugCheckTypes && localType != -1 { tt := v.Type() if tt != sqltypes.Null && tt != localType { 
t.Errorf("evaluation type mismatch: eval=%v vs typeof=%v\nlocal: %s\nquery: %s (SIMPLIFY=%v)", diff --git a/go/vt/vtgate/evalengine/integration/fuzz_test.go b/go/vt/vtgate/evalengine/integration/fuzz_test.go index 55b5cc3c4c0..875fab6540c 100644 --- a/go/vt/vtgate/evalengine/integration/fuzz_test.go +++ b/go/vt/vtgate/evalengine/integration/fuzz_test.go @@ -18,6 +18,7 @@ package integration import ( "encoding/json" + "errors" "fmt" "math/rand" "os" @@ -151,6 +152,10 @@ func evaluateLocalEvalengine(env *evalengine.ExpressionEnv, query string) (evale eval, err = env.Evaluate(local) if err == nil && debugCheckTypes { tt, err = env.TypeOf(local) + if errors.Is(err, evalengine.ErrAmbiguousType) { + tt = -1 + err = nil + } } return }() diff --git a/go/vt/vtgate/evalengine/mysql_test.go b/go/vt/vtgate/evalengine/mysql_test.go index 4ccb4c30029..dc0eed7c77f 100644 --- a/go/vt/vtgate/evalengine/mysql_test.go +++ b/go/vt/vtgate/evalengine/mysql_test.go @@ -127,6 +127,6 @@ func TestMySQLGolden(t *testing.T) { func TestDebug1(t *testing.T) { // Debug - eval, err := testSingle(t, `SELECT LEAST(18446744073709551615, CAST(0 AS UNSIGNED), CAST(1 AS UNSIGNED))`) + eval, err := testSingle(t, `SELECT 0x1 IN (0xFF, 64)`) t.Logf("eval=%s err=%v coll=%s", eval.String(), err, eval.Collation().Get().Name()) } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 261553d1abc..152f6c571e6 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -78,6 +78,7 @@ type FnBitLength struct{ defaultEnv } type FnAscii struct{ defaultEnv } type FnRepeat struct{ defaultEnv } type FnHex struct{ defaultEnv } +type InStatement struct{ defaultEnv } var Cases = []TestCase{ JSONExtract{}, @@ -116,6 +117,7 @@ var Cases = []TestCase{ FnAscii{}, FnRepeat{}, FnHex{}, + InStatement{}, } func (JSONPathOperations) Test(yield Iterator) { @@ -231,39 +233,6 @@ func (Ceil) Test(yield Iterator) { } } -// HACK: for 
CASE comparisons, the expression is supposed to decompose like this: -// -// CASE a WHEN b THEN bb WHEN c THEN cc ELSE d -// => CASE WHEN a = b THEN bb WHEN a == c THEN cc ELSE d -// -// See: https://dev.mysql.com/doc/refman/5.7/en/flow-control-functions.html#operator_case -// However, MySQL does not seem to be using the real `=` operator for some of these comparisons -// namely, numerical comparisons are coerced into an unsigned form when they shouldn't. -// Example: -// -// SELECT -1 = 18446744073709551615 -// => 0 -// SELECT -1 WHEN 18446744073709551615 THEN 1 ELSE 0 END -// => 1 -// -// This does not happen for other types, which all follow the behavior of the `=` operator, -// so we're going to assume this is a bug for now. -func comparisonSkip(a, b string) bool { - if a == "-1" && b == "18446744073709551615" { - return true - } - if b == "-1" && a == "18446744073709551615" { - return true - } - if a == "9223372036854775808" && b == "-9223372036854775808" { - return true - } - if a == "-9223372036854775808" && b == "9223372036854775808" { - return true - } - return false -} - func (CaseExprWithValue) Test(yield Iterator) { var elements []string elements = append(elements, inputBitwise...) @@ -271,7 +240,7 @@ func (CaseExprWithValue) Test(yield Iterator) { for _, cmpbase := range elements { for _, val1 := range elements { - if comparisonSkip(cmpbase, val1) { + if !(bugs{}).CanCompare(cmpbase, val1) { continue } yield(fmt.Sprintf("case %s when %s then 1 else 0 end", cmpbase, val1), nil) @@ -768,3 +737,18 @@ func (FnHex) Test(yield Iterator) { yield(fmt.Sprintf("hex(%s)", str), nil) } } + +func (InStatement) Test(yield Iterator) { + roots := append([]string(nil), inputBitwise...) + roots = append(roots, inputComparisonElement...) + + genSubsets(roots, 3, func(inputs []string) { + if !(bugs{}).CanCompare(inputs...) 
{ + return + } + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) + yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + }) +} diff --git a/go/vt/vtgate/evalengine/testcases/helpers.go b/go/vt/vtgate/evalengine/testcases/helpers.go index 6402549576f..fa352c31eec 100644 --- a/go/vt/vtgate/evalengine/testcases/helpers.go +++ b/go/vt/vtgate/evalengine/testcases/helpers.go @@ -16,7 +16,11 @@ limitations under the License. package testcases -import "vitess.io/vitess/go/sqltypes" +import ( + "strings" + + "vitess.io/vitess/go/sqltypes" +) func perm(a []string, f func([]string)) { perm1(a, f, 0) @@ -67,3 +71,51 @@ func mustJSON(j string) sqltypes.Value { } return v } + +type bugs struct{} + +// CanCompare skips comparisons in CASE and IN expressions that behave in unexpected +// ways. The following is an example of expressions giving un-intuitive results (i.e. +// results that do not match the behavior of the `=` operator, which is supposed to apply, +// pair-wise, to the comparisons in a CASE or IN statement): +// +// SELECT -1 IN (0xFF, 18446744073709551615) => 1 +// SELECT -1 IN (0, 18446744073709551615) => 0 +// SELECT -1 IN (0.0, 18446744073709551615) => 1 +// +// SELECT 'FOO' IN ('foo', 0x00) => 0 +// SELECT 'FOO' IN ('foo', 0) => 1 +// SELECT 'FOO' IN ('foo', 0x00, CAST('bar' as char)) => 1 +// SELECT 'FOO' IN ('foo', 0x00, 'bar') => 0 +// +// SELECT 9223372036854775808 IN (0.0e0, -9223372036854775808) => 1 +// SELECT 9223372036854775808 IN (0, -9223372036854775808) => 0 +// SELECT 9223372036854775808 IN (0.0, -9223372036854775808) => 1 +// +// Generally speaking, it's counter-intuitive that adding more (unrelated) types to the +// right-hand of the IN operator would change the result of the operation itself. 
It seems +// like there's logic that changes the way the elements are compared with a type aggregation +// but this is not documented anywhere. +func (bugs) CanCompare(elems ...string) bool { + var invalid = map[string]string{ + "18446744073709551615": "-1", + `9223372036854775808`: `-9223372036854775808`, + } + + for i, e := range elems { + if strings.HasPrefix(e, "_binary ") || + strings.HasPrefix(e, "0x") || + strings.HasPrefix(e, "X'") || + strings.HasSuffix(e, "collate utf8mb4_0900_as_cs") { + return false + } + if other, ok := invalid[e]; ok { + for j := 0; j < len(elems); j++ { + if i != j && elems[j] == other { + return false + } + } + } + } + return true +} diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index 067106df468..3ae80620f99 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -57,7 +57,8 @@ var inputBitwise = []string{ "64", "'64'", "_binary '64'", "X'40'", "_binary X'40'", } -var inputComparisonElement = []string{"NULL", "-1", "0", "1", +var inputComparisonElement = []string{ + "NULL", "-1", "0", "1", `'foo'`, `'bar'`, `'FOO'`, `'BAR'`, `'foo' collate utf8mb4_0900_as_cs`, `'FOO' collate utf8mb4_0900_as_cs`, diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 562dc920c21..4ace1b54f4f 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -60,7 +60,6 @@ func (ast *astCompiler) translateComparisonExpr2(op sqlparser.ComparisonExprOper return &InExpr{ BinaryExpr: binaryExpr, Negate: op == sqlparser.NotInOp, - Hashed: nil, }, nil } diff --git a/go/vt/vtgate/evalengine/translate_simplify.go b/go/vt/vtgate/evalengine/translate_simplify.go index b76adfe96a0..3e957a943fc 100644 --- a/go/vt/vtgate/evalengine/translate_simplify.go +++ b/go/vt/vtgate/evalengine/translate_simplify.go @@ -16,12 +16,6 @@ limitations under the License. 
package evalengine -import ( - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/vthash" -) - func (expr *Literal) constant() bool { return true } @@ -96,54 +90,10 @@ func (inexpr *InExpr) simplify(env *ExpressionEnv) error { return err } - tuple, ok := inexpr.Right.(TupleExpr) - if !ok { - return nil - } - - var ( - collation collations.ID - typ sqltypes.Type - optimize = true - ) - - for i, expr := range tuple { - if lit, ok := expr.(*Literal); ok { - thisColl := evalCollation(lit.inner).Collation - thisTyp := lit.inner.SQLType() - if i == 0 { - collation = thisColl - typ = thisTyp - continue - } - if collation == thisColl && typ == thisTyp { - continue - } - } - optimize = false - break + if err := inexpr.Right.simplify(env); err != nil { + return err } - if optimize { - inexpr.Hashed = make(map[vthash.Hash]int) - hasher := vthash.New() - for i, expr := range tuple { - lit := expr.(*Literal) - inner, ok := lit.inner.(hashable) - if !ok { - inexpr.Hashed = nil - break - } - - inner.Hash(&hasher) - hash := hasher.Sum128() - hasher.Reset() - - if _, found := inexpr.Hashed[hash]; !found { - inexpr.Hashed[hash] = i - } - } - } return nil } From e2c84d2f5ab6fa0477648f8111435f9d242c5f74 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Fri, 10 Mar 2023 11:16:39 +0100 Subject: [PATCH 23/42] evalengine/compiler: implement IN Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 3 ++ go/vt/vtgate/evalengine/compiler_asm.go | 57 +++++++++++++++++++++ go/vt/vtgate/evalengine/compiler_compare.go | 52 +++++++++++++++++++ go/vt/vtgate/evalengine/expr_compare.go | 47 ++++++++++------- go/vt/vtgate/evalengine/vm.go | 2 + 5 files changed, 143 insertions(+), 18 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index ce09e77f9a6..128bbb3a413 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -129,6 +129,9 @@ func (c *compiler) 
compileExpr(expr Expr) (ctype, error) { case *IsExpr: return c.compileIs(expr) + case *InExpr: + return c.compileIn(expr) + case callable: return c.compileFn(expr) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 730b69b3955..c550c0164fd 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -13,6 +13,7 @@ import ( "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" "vitess.io/vitess/go/vt/vtgate/evalengine/internal/json" + "vitess.io/vitess/go/vt/vthash" ) type jump struct { @@ -1890,6 +1891,62 @@ func (asm *assembler) Fn_CEIL_d() { }, "CEIL DECIMAL(SP-1)") } +func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) { + if not { + asm.emit(func(vm *VirtualMachine) int { + lhs := vm.stack[vm.sp-1] + if lhs != nil { + vm.hash.Reset() + lhs.(hashable).Hash(&vm.hash) + _, in := table[vm.hash.Sum128()] + vm.stack[vm.sp-1] = newEvalBool(!in) + } + return 1 + }, "NOT IN (SP-1), [static table]") + } else { + asm.emit(func(vm *VirtualMachine) int { + lhs := vm.stack[vm.sp-1] + if lhs != nil { + vm.hash.Reset() + lhs.(hashable).Hash(&vm.hash) + _, in := table[vm.hash.Sum128()] + vm.stack[vm.sp-1] = newEvalBool(in) + } + return 1 + }, "IN (SP-1), [static table]") + } +} + +func (asm *assembler) In_slow(not bool) { + asm.adjustStack(-1) + + if not { + asm.emit(func(vm *VirtualMachine) int { + lhs := vm.stack[vm.sp-2] + rhs := vm.stack[vm.sp-1].(*evalTuple) + + var in boolean + in, vm.err = evalInExpr(lhs, rhs) + + vm.stack[vm.sp-2] = in.not().eval() + vm.sp -= 1 + return 1 + }, "NOT IN (SP-2), TUPLE(SP-1)") + } else { + asm.emit(func(vm *VirtualMachine) int { + lhs := vm.stack[vm.sp-2] + rhs := vm.stack[vm.sp-1].(*evalTuple) + + var in boolean + in, vm.err = evalInExpr(lhs, rhs) + + vm.stack[vm.sp-2] = in.eval() + vm.sp -= 1 + return 1 + }, "IN (SP-2), TUPLE(SP-1)") + } +} + func cmpnum[N interface{ int64 | uint64 | 
float64 }](a, b N) int { switch { case a == b: diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index c3caa45a659..087f82481cf 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -5,6 +5,7 @@ import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vthash" ) func (c *compiler) compileComparisonTuple(expr *ComparisonExpr) (ctype, error) { @@ -295,3 +296,54 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { c.asm.jumpDestination(skip) return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil } + +func (c *compiler) compileInTable(lhs ctype, rhs TupleExpr) map[vthash.Hash]struct{} { + var ( + table = make(map[vthash.Hash]struct{}) + hasher = vthash.New() + ) + + for _, expr := range rhs { + lit, ok := expr.(*Literal) + if !ok { + return nil + } + inner, ok := lit.inner.(hashable) + if !ok { + return nil + } + + thisColl := evalCollation(lit.inner).Collation + thisTyp := lit.inner.SQLType() + + if thisTyp != lhs.Type || thisColl != lhs.Col.Collation { + return nil + } + + inner.Hash(&hasher) + table[hasher.Sum128()] = struct{}{} + hasher.Reset() + } + + return table +} + +func (c *compiler) compileIn(expr *InExpr) (ctype, error) { + lhs, err := c.compileExpr(expr.Left) + if err != nil { + return ctype{}, err + } + + rhs := expr.Right.(TupleExpr) + + if table := c.compileInTable(lhs, rhs); table != nil { + c.asm.In_table(expr.Negate, table) + } else { + _, err := c.compileTuple(rhs) + if err != nil { + return ctype{}, err + } + c.asm.In_slow(expr.Negate) + } + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index 5012d027e5d..8d8f64dfadc 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -274,25
+274,16 @@ func (c *ComparisonExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { return sqltypes.Int64, f1 | f2 } -// eval implements the ComparisonOp interface -func (i *InExpr) eval(env *ExpressionEnv) (eval, error) { - left, right, err := i.arguments(env) - if err != nil { - return nil, err - } - rtuple, ok := right.(*evalTuple) - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple") - } - if left == nil { - return nil, nil +func evalInExpr(lhs eval, rhs *evalTuple) (boolean, error) { + if lhs == nil { + return boolNULL, nil } var foundNull, found bool - for _, rtuple := range rtuple.t { - numeric, isNull, err := evalCompareAll(left, rtuple, true) + for _, rtuple := range rhs.t { + numeric, isNull, err := evalCompareAll(lhs, rtuple, true) if err != nil { - return nil, err + return boolNULL, err } if isNull { foundNull = true @@ -306,12 +297,32 @@ func (i *InExpr) eval(env *ExpressionEnv) (eval, error) { switch { case found: - return newEvalBool(!i.Negate), nil + return boolTrue, nil case foundNull: - return nil, nil + return boolNULL, nil default: - return newEvalBool(i.Negate), nil + return boolFalse, nil + } +} + +// eval implements the ComparisonOp interface +func (i *InExpr) eval(env *ExpressionEnv) (eval, error) { + left, right, err := i.arguments(env) + if err != nil { + return nil, err + } + rtuple, ok := right.(*evalTuple) + if !ok { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple") + } + in, err := evalInExpr(left, rtuple) + if err != nil { + return nil, err + } + if i.Negate { + in = in.not() } + return in.eval(), nil } func (i *InExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { diff --git a/go/vt/vtgate/evalengine/vm.go b/go/vt/vtgate/evalengine/vm.go index f44597edb9a..fd3e4af7a21 100644 --- a/go/vt/vtgate/evalengine/vm.go +++ b/go/vt/vtgate/evalengine/vm.go @@ -4,12 +4,14 @@ import ( "errors" "vitess.io/vitess/go/sqltypes" + 
"vitess.io/vitess/go/vt/vthash" ) var errDeoptimize = errors.New("de-optimize") type VirtualMachine struct { arena Arena + hash vthash.Hasher row []sqltypes.Value stack []eval sp int From a98190683e225558b056e4e334607a237d293ceb Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Fri, 10 Mar 2023 11:38:57 +0100 Subject: [PATCH 24/42] evalengine/compiler: implement WEIGHT_STRING Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler_asm.go | 287 +++++++++++++----------- go/vt/vtgate/evalengine/compiler_fn.go | 32 ++- 2 files changed, 187 insertions(+), 132 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index c550c0164fd..dee8f3e20bd 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -963,6 +963,36 @@ func (asm *assembler) Fn_ASCII() { }, "FN ASCII VARCHAR(SP-1)") } +func (asm *assembler) Fn_CEIL_d() { + asm.emit(func(vm *VirtualMachine) int { + d := vm.stack[vm.sp-1].(*evalDecimal) + c := d.dec.Ceil() + i, valid := c.Int64() + if valid { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(i) + } else { + vm.err = errDeoptimize + } + return 1 + }, "FN CEIL DECIMAL(SP-1)") +} + +func (asm *assembler) Fn_CEIL_f() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = math.Ceil(f.f) + return 1 + }, "FN CEIL FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_COLLATION(col collations.TypedCollation) { + asm.emit(func(vm *VirtualMachine) int { + v := evalCollation(vm.stack[vm.sp-1]) + vm.stack[vm.sp-1] = vm.arena.newEvalText([]byte(v.Collation.Get().Name()), col) + return 1 + }, "FN COLLATION (SP-1)") +} + func (asm *assembler) Fn_FROM_BASE64() { asm.emit(func(vm *VirtualMachine) int { str := vm.stack[vm.sp-1].(*evalBytes) @@ -977,7 +1007,7 @@ func (asm *assembler) Fn_FROM_BASE64() { str.tt = int16(sqltypes.VarBinary) str.bytes = decoded[:n] return 1 - }, "FROM_BASE64 VARCHAR(SP-1)") + }, "FN FROM_BASE64 VARCHAR(SP-1)") } func (asm 
*assembler) Fn_HEX_c(t sqltypes.Type, col collations.TypedCollation) { @@ -1079,6 +1109,49 @@ func (asm *assembler) Fn_JSON_EXTRACT0(jp []*json.Path) { } } +func (asm *assembler) Fn_JSON_KEYS(jp *json.Path) { + if jp == nil { + asm.emit(func(vm *VirtualMachine) int { + doc := vm.stack[vm.sp-1] + if doc == nil { + return 1 + } + j := doc.(*evalJSON) + if obj, ok := j.Object(); ok { + var keys []*json.Value + obj.Visit(func(key []byte, _ *json.Value) { + keys = append(keys, json.NewString(key)) + }) + vm.stack[vm.sp-1] = json.NewArray(keys) + } else { + vm.stack[vm.sp-1] = nil + } + return 1 + }, "FN JSON_KEYS (SP-1)") + } else { + asm.emit(func(vm *VirtualMachine) int { + doc := vm.stack[vm.sp-1] + if doc == nil { + return 1 + } + var obj *json.Object + jp.Match(doc.(*evalJSON), false, func(value *json.Value) { + obj, _ = value.Object() + }) + if obj != nil { + var keys []*json.Value + obj.Visit(func(key []byte, _ *json.Value) { + keys = append(keys, json.NewString(key)) + }) + vm.stack[vm.sp-1] = json.NewArray(keys) + } else { + vm.stack[vm.sp-1] = nil + } + return 1 + }, "FN JSON_KEYS (SP-1), %q", jp.String()) + } +} + func (asm *assembler) Fn_JSON_OBJECT(args int) { asm.adjustStack(-(args - 1)) asm.emit(func(vm *VirtualMachine) int { @@ -1326,11 +1399,85 @@ func (asm *assembler) Fn_TO_BASE64(t sqltypes.Type, col collations.TypedCollatio }, "FN TO_BASE64 VARCHAR(SP-1)") } +func (asm *assembler) Fn_WEIGHT_STRING_b(length int) { + asm.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + w := collations.Binary.WeightString(make([]byte, 0, length), str.bytes, collations.PadToMax) + vm.stack[vm.sp-1] = vm.arena.newEvalBinary(w) + return 1 + }, "FN WEIGHT_STRING VARBINARY(SP-1)") +} + +func (asm *assembler) Fn_WEIGHT_STRING_c(col collations.Collation, length int) { + asm.emit(func(vm *VirtualMachine) int { + str := vm.stack[vm.sp-1].(*evalBytes) + w := col.WeightString(nil, str.bytes, length) + vm.stack[vm.sp-1] = vm.arena.newEvalBinary(w) + 
return 1 + }, "FN WEIGHT_STRING VARCHAR(SP-1)") +} + +func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) { + if not { + asm.emit(func(vm *VirtualMachine) int { + lhs := vm.stack[vm.sp-1] + if lhs != nil { + vm.hash.Reset() + lhs.(hashable).Hash(&vm.hash) + _, in := table[vm.hash.Sum128()] + vm.stack[vm.sp-1] = newEvalBool(!in) + } + return 1 + }, "NOT IN (SP-1), [static table]") + } else { + asm.emit(func(vm *VirtualMachine) int { + lhs := vm.stack[vm.sp-1] + if lhs != nil { + vm.hash.Reset() + lhs.(hashable).Hash(&vm.hash) + _, in := table[vm.hash.Sum128()] + vm.stack[vm.sp-1] = newEvalBool(in) + } + return 1 + }, "IN (SP-1), [static table]") + } +} + +func (asm *assembler) In_slow(not bool) { + asm.adjustStack(-1) + + if not { + asm.emit(func(vm *VirtualMachine) int { + lhs := vm.stack[vm.sp-2] + rhs := vm.stack[vm.sp-1].(*evalTuple) + + var in boolean + in, vm.err = evalInExpr(lhs, rhs) + + vm.stack[vm.sp-2] = in.not().eval() + vm.sp -= 1 + return 1 + }, "NOT IN (SP-2), TUPLE(SP-1)") + } else { + asm.emit(func(vm *VirtualMachine) int { + lhs := vm.stack[vm.sp-2] + rhs := vm.stack[vm.sp-1].(*evalTuple) + + var in boolean + in, vm.err = evalInExpr(lhs, rhs) + + vm.stack[vm.sp-2] = in.eval() + vm.sp -= 1 + return 1 + }, "IN (SP-2), TUPLE(SP-1)") + } +} + func (asm *assembler) Is(check func(eval) bool) { asm.emit(func(vm *VirtualMachine) int { vm.stack[vm.sp-1] = newEvalBool(check(vm.stack[vm.sp-1])) return 1 - }, "IS (SP-1)") + }, "IS (SP-1), [static]") } func (asm *assembler) Like_coerce(expr *LikeExpr, coercion *compiledCoercion) { @@ -1745,6 +1892,13 @@ func (asm *assembler) SetBool(offset int, b bool) { } } +func (asm *assembler) SetNull(offset int) { + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp-offset] = nil + return 1 + }, "SET (SP-%d), NULL", offset) +} + func (asm *assembler) Sub_dd() { asm.adjustStack(-1) @@ -1818,135 +1972,6 @@ func (asm *assembler) Sub_uu() { }, "SUB UINT64(SP-2), UINT64(SP-1)") } -func (asm 
*assembler) Fn_JSON_KEYS(jp *json.Path) { - if jp == nil { - asm.emit(func(vm *VirtualMachine) int { - doc := vm.stack[vm.sp-1] - if doc == nil { - return 1 - } - j := doc.(*evalJSON) - if obj, ok := j.Object(); ok { - var keys []*json.Value - obj.Visit(func(key []byte, _ *json.Value) { - keys = append(keys, json.NewString(key)) - }) - vm.stack[vm.sp-1] = json.NewArray(keys) - } else { - vm.stack[vm.sp-1] = nil - } - return 1 - }, "FN JSON_KEYS (SP-1)") - } else { - asm.emit(func(vm *VirtualMachine) int { - doc := vm.stack[vm.sp-1] - if doc == nil { - return 1 - } - var obj *json.Object - jp.Match(doc.(*evalJSON), false, func(value *json.Value) { - obj, _ = value.Object() - }) - if obj != nil { - var keys []*json.Value - obj.Visit(func(key []byte, _ *json.Value) { - keys = append(keys, json.NewString(key)) - }) - vm.stack[vm.sp-1] = json.NewArray(keys) - } else { - vm.stack[vm.sp-1] = nil - } - return 1 - }, "FN JSON_KEYS (SP-1), %q", jp.String()) - } -} - -func (asm *assembler) Collation(col collations.TypedCollation) { - asm.emit(func(vm *VirtualMachine) int { - v := evalCollation(vm.stack[vm.sp-1]) - vm.stack[vm.sp-1] = vm.arena.newEvalText([]byte(v.Collation.Get().Name()), col) - return 1 - }, "COLLATION (SP-1)") -} - -func (asm *assembler) Fn_CEIL_f() { - asm.emit(func(vm *VirtualMachine) int { - f := vm.stack[vm.sp-1].(*evalFloat) - f.f = math.Ceil(f.f) - return 1 - }, "CEIL FLOAT64(SP-1)") -} - -func (asm *assembler) Fn_CEIL_d() { - asm.emit(func(vm *VirtualMachine) int { - d := vm.stack[vm.sp-1].(*evalDecimal) - c := d.dec.Ceil() - i, valid := c.Int64() - if valid { - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(i) - } else { - vm.err = errDeoptimize - } - return 1 - }, "CEIL DECIMAL(SP-1)") -} - -func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) { - if not { - asm.emit(func(vm *VirtualMachine) int { - lhs := vm.stack[vm.sp-1] - if lhs != nil { - vm.hash.Reset() - lhs.(hashable).Hash(&vm.hash) - _, in := table[vm.hash.Sum128()] - 
vm.stack[vm.sp-1] = newEvalBool(!in) - } - return 1 - }, "NOT IN (SP-1), [static table]") - } else { - asm.emit(func(vm *VirtualMachine) int { - lhs := vm.stack[vm.sp-1] - if lhs != nil { - vm.hash.Reset() - lhs.(hashable).Hash(&vm.hash) - _, in := table[vm.hash.Sum128()] - vm.stack[vm.sp-1] = newEvalBool(in) - } - return 1 - }, "IN (SP-1), [static table]") - } -} - -func (asm *assembler) In_slow(not bool) { - asm.adjustStack(-1) - - if not { - asm.emit(func(vm *VirtualMachine) int { - lhs := vm.stack[vm.sp-2] - rhs := vm.stack[vm.sp-1].(*evalTuple) - - var in boolean - in, vm.err = evalInExpr(lhs, rhs) - - vm.stack[vm.sp-2] = in.not().eval() - vm.sp -= 1 - return 1 - }, "NOT IN (SP-2), TUPLE(SP-1)") - } else { - asm.emit(func(vm *VirtualMachine) int { - lhs := vm.stack[vm.sp-2] - rhs := vm.stack[vm.sp-1].(*evalTuple) - - var in boolean - in, vm.err = evalInExpr(lhs, rhs) - - vm.stack[vm.sp-2] = in.eval() - vm.sp -= 1 - return 1 - }, "IN (SP-2), TUPLE(SP-1)") - } -} - func cmpnum[N interface{ int64 | uint64 | float64 }](a, b N) int { switch { case a == b: diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 182e081284d..76855682835 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -47,6 +47,8 @@ func (c *compiler) compileFn(call callable) (ctype, error) { return c.compileFn_COLLATION(call) case *builtinCeil: return c.compileFn_CEIL(call) + case *builtinWeightString: + return c.compileFn_WEIGHT_STRING(call) default: return ctype{}, c.unsupported(call) } @@ -352,7 +354,7 @@ func (c *compiler) compileFn_COLLATION(expr *builtinCollation) (ctype, error) { Repertoire: collations.RepertoireASCII, } - c.asm.Collation(col) + c.asm.Fn_COLLATION(col) c.asm.jumpDestination(skip) return ctype{Type: sqltypes.VarChar, Col: col}, nil @@ -387,3 +389,31 @@ func (c *compiler) compileFn_CEIL(expr *builtinCeil) (ctype, error) { c.asm.jumpDestination(skip) return convt, nil } + +func (c 
*compiler) compileFn_WEIGHT_STRING(call *builtinWeightString) (ctype, error) { + str, err := c.compileExpr(call.String) + if err != nil { + return ctype{}, err + } + + switch str.Type { + case sqltypes.Int64, sqltypes.Uint64: + return ctype{}, c.unsupported(call) + + case sqltypes.VarChar, sqltypes.VarBinary: + skip := c.asm.jumpFrom() + c.asm.NullCheck1(skip) + + if call.Cast == "binary" { + c.asm.Fn_WEIGHT_STRING_b(call.Len) + } else { + c.asm.Fn_WEIGHT_STRING_c(str.Col.Collation.Get(), call.Len) + } + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + + default: + c.asm.SetNull(1) + return ctype{Type: sqltypes.VarBinary, Flag: flagNullable | flagNull, Col: collationBinary}, nil + } +} From 1f513a3440e57506f7dbbd856663285f5d132a6e Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Fri, 10 Mar 2023 12:40:09 +0100 Subject: [PATCH 25/42] evalengine/compiler: optimize null checks Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 18 +++++++++++++ .../vtgate/evalengine/compiler_arithmetic.go | 18 +++++-------- go/vt/vtgate/evalengine/compiler_asm.go | 6 +++-- go/vt/vtgate/evalengine/compiler_bit.go | 12 +++------ go/vt/vtgate/evalengine/compiler_compare.go | 10 +++---- go/vt/vtgate/evalengine/compiler_convert.go | 17 +++++------- go/vt/vtgate/evalengine/compiler_fn.go | 27 +++++++------------ go/vt/vtgate/evalengine/compiler_json.go | 8 +++--- 8 files changed, 57 insertions(+), 59 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 128bbb3a413..cc122155e24 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -292,3 +292,21 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { } return ctype{sqltypes.Decimal, ct.Flag, collationNumeric} } + +func (c *compiler) compileNullCheck1(ct ctype) *jump { + if ct.Flag&flagNullable != 0 { + j := c.asm.jumpFrom() + c.asm.NullCheck1(j) + return j + } + return 
nil +} + +func (c *compiler) compileNullCheck2(lt, rt ctype) *jump { + if lt.Flag&flagNullable != 0 || rt.Flag&flagNullable != 0 { + j := c.asm.jumpFrom() + c.asm.NullCheck2(j) + return j + } + return nil +} diff --git a/go/vt/vtgate/evalengine/compiler_arithmetic.go b/go/vt/vtgate/evalengine/compiler_arithmetic.go index c3852bd23be..62319f1a03f 100644 --- a/go/vt/vtgate/evalengine/compiler_arithmetic.go +++ b/go/vt/vtgate/evalengine/compiler_arithmetic.go @@ -8,9 +8,7 @@ func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) - + skip := c.compileNullCheck1(arg) arg = c.compileToNumeric(arg, 1) var neg sqltypes.Type @@ -85,8 +83,7 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { } swap := false - skip := c.asm.jumpFrom() - c.asm.NullCheck2(skip) + skip := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) lt, rt, swap = c.compileNumericPriority(lt, rt) @@ -138,8 +135,7 @@ func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck2(skip) + skip := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) @@ -217,8 +213,7 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { } swap := false - skip := c.asm.jumpFrom() - c.asm.NullCheck2(skip) + skip := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) lt, rt, swap = c.compileNumericPriority(lt, rt) @@ -270,8 +265,7 @@ func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck2(skip) + skip := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) @@ -279,11 +273,13 @@ func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) { c.compileToFloat(lt, 2) 
c.compileToFloat(rt, 1) c.asm.Div_ff() + c.asm.jumpDestination(skip) return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil } else { c.compileToDecimal(lt, 2) c.compileToDecimal(rt, 1) c.asm.Div_dd() + c.asm.jumpDestination(skip) return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil } } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index dee8f3e20bd..8806224fb52 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -38,7 +38,9 @@ func (asm *assembler) jumpFrom() *jump { } func (asm *assembler) jumpDestination(j *jump) { - j.to = len(asm.ins) + if j != nil { + j.to = len(asm.ins) + } } func (asm *assembler) adjustStack(offset int) { @@ -1666,7 +1668,7 @@ func (asm *assembler) NullCheck2(j *jump) { }, "NULLCHECK SP-1, SP-2") } -func (asm *assembler) NullCmp(j *jump) { +func (asm *assembler) Cmp_nullsafe(j *jump) { asm.emit(func(vm *VirtualMachine) int { l := vm.stack[vm.sp-2] r := vm.stack[vm.sp-1] diff --git a/go/vt/vtgate/evalengine/compiler_bit.go b/go/vt/vtgate/evalengine/compiler_bit.go index 0654ce6ed02..c56d4c6266b 100644 --- a/go/vt/vtgate/evalengine/compiler_bit.go +++ b/go/vt/vtgate/evalengine/compiler_bit.go @@ -38,8 +38,7 @@ func (c *compiler) compileBitwiseOp(left Expr, right Expr, op bitwiseOp) (ctype, return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck2(skip) + skip := c.compileNullCheck2(lt, rt) if lt.Type == sqltypes.VarBinary && rt.Type == sqltypes.VarBinary { if !lt.isHexOrBitLiteral() || !rt.isHexOrBitLiteral() { @@ -68,8 +67,7 @@ func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, err return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck2(skip) + skip := c.compileNullCheck2(lt, rt) if lt.Type == sqltypes.VarBinary && !lt.isHexOrBitLiteral() { _ = c.compileToUint64(rt, 1) @@ -101,8 +99,7 @@ func (c *compiler) compileFn_BIT_COUNT(expr *builtinBitCount) (ctype, error) { return 
ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(ct) if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { c.asm.BitCount_b() @@ -122,8 +119,7 @@ func (c *compiler) compileBitwiseNot(expr *BitwiseNotExpr) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(ct) if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { c.asm.BitwiseNot_b() diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index 087f82481cf..cacdcb66438 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -56,13 +56,14 @@ func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { } swapped := false - skip := c.asm.jumpFrom() + var skip *jump switch expr.Op.(type) { case compareNullSafeEQ: - c.asm.NullCmp(skip) + skip = c.asm.jumpFrom() + c.asm.Cmp_nullsafe(skip) default: - c.asm.NullCheck2(skip) + skip = c.compileNullCheck2(lt, rt) } switch { @@ -239,8 +240,7 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck2(skip) + skip := c.compileNullCheck2(lt, rt) if !sqltypes.IsText(lt.Type) && !sqltypes.IsBinary(lt.Type) { c.asm.Convert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) diff --git a/go/vt/vtgate/evalengine/compiler_convert.go b/go/vt/vtgate/evalengine/compiler_convert.go index 701883eb405..131f011cdba 100644 --- a/go/vt/vtgate/evalengine/compiler_convert.go +++ b/go/vt/vtgate/evalengine/compiler_convert.go @@ -13,8 +13,7 @@ func (c *compiler) compileCollate(expr *CollateExpr) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(ct) switch ct.Type { case sqltypes.VarChar: @@ -31,7 +30,7 @@ func (c *compiler) compileCollate(expr *CollateExpr) (ctype, error) { c.asm.jumpDestination(skip) ct.Col 
= expr.TypedCollation - ct.Flag |= flagExplicitCollation + ct.Flag |= flagExplicitCollation | flagNullable return ct, nil } @@ -41,9 +40,7 @@ func (c *compiler) compileConvert(conv *ConvertExpr) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) - + skip := c.compileNullCheck1(arg) var convt ctype switch conv.Type { @@ -84,18 +81,18 @@ func (c *compiler) compileConvert(conv *ConvertExpr) (ctype, error) { } c.asm.jumpDestination(skip) + convt.Flag = arg.Flag | flagNullable return convt, nil } func (c *compiler) compileConvertUsing(conv *ConvertUsingExpr) (ctype, error) { - _, err := c.compileExpr(conv.Inner) + ct, err := c.compileExpr(conv.Inner) if err != nil { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(ct) c.asm.Convert_xc(1, sqltypes.VarChar, conv.Collation, 0, false) c.asm.jumpDestination(skip) @@ -104,5 +101,5 @@ func (c *compiler) compileConvertUsing(conv *ConvertUsingExpr) (ctype, error) { Coercibility: collations.CoerceCoercible, Repertoire: collations.RepertoireASCII, } - return ctype{Type: sqltypes.VarChar, Col: col}, nil + return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil } diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 76855682835..079ba9e7464 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -166,8 +166,7 @@ func (c *compiler) compileFn_REPEAT(expr *builtinRepeat) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck2(skip) + skip := c.compileNullCheck2(str, repeat) switch { case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): @@ -187,8 +186,7 @@ func (c *compiler) compileFn_TO_BASE64(call *builtinToBase64) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(str) t := sqltypes.VarChar if str.Type == sqltypes.Blob || str.Type 
== sqltypes.TypeJSON { @@ -219,8 +217,7 @@ func (c *compiler) compileFn_FROM_BASE64(call *builtinFromBase64) (ctype, error) return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(str) switch { case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): @@ -240,8 +237,7 @@ func (c *compiler) compileFn_CCASE(call *builtinChangeCase) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(str) switch { case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): @@ -269,8 +265,7 @@ func (c *compiler) compileFn_LENGTH(call callable, op lengthOp) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(str) switch { case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): @@ -290,8 +285,7 @@ func (c *compiler) compileFn_ASCII(call *builtinASCII) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(str) switch { case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): @@ -311,8 +305,7 @@ func (c *compiler) compileFn_HEX(call *builtinHex) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(str) col := collations.TypedCollation{ Collation: c.defaultCollation, @@ -366,8 +359,7 @@ func (c *compiler) compileFn_CEIL(expr *builtinCeil) (ctype, error) { return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(arg) convt := ctype{Type: arg.Type, Col: collationNumeric} switch { @@ -401,8 +393,7 @@ func (c *compiler) compileFn_WEIGHT_STRING(call *builtinWeightString) (ctype, er return ctype{}, c.unsupported(call) case sqltypes.VarChar, sqltypes.VarBinary: - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) + skip := c.compileNullCheck1(str) if call.Cast == "binary" { c.asm.Fn_WEIGHT_STRING_b(call.Len) diff 
--git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index dc2761991f8..9433254d8b2 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -27,7 +27,7 @@ func (c *compiler) compileParseJSON(fn string, doct ctype, offset int) (ctype, e default: return ctype{}, errJSONType(fn) } - return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil + return ctype{Type: sqltypes.TypeJSON, Flag: doct.Flag, Col: collationJSON}, nil } func (c *compiler) compileToJSON(doct ctype, offset int) (ctype, error) { @@ -156,9 +156,7 @@ func (c *compiler) compileFn_JSON_UNQUOTE(call *builtinJSONUnquote) (ctype, erro return ctype{}, err } - skip := c.asm.jumpFrom() - c.asm.NullCheck1(skip) - + skip := c.compileNullCheck1(arg) _, err = c.compileParseJSON("JSON_UNQUOTE", arg, 1) if err != nil { return ctype{}, err @@ -166,7 +164,7 @@ func (c *compiler) compileFn_JSON_UNQUOTE(call *builtinJSONUnquote) (ctype, erro c.asm.Fn_JSON_UNQUOTE() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Blob, Col: collationJSON}, nil + return ctype{Type: sqltypes.Blob, Flag: flagNullable, Col: collationJSON}, nil } func (c *compiler) compileFn_JSON_CONTAINS_PATH(call *builtinJSONContainsPath) (ctype, error) { From 85b97bfbda10e419d5f8fc043ddc86117e5e8bc3 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Fri, 10 Mar 2023 12:59:27 +0100 Subject: [PATCH 26/42] evalengine/compiler: test for runtime errors Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler_json.go | 14 ++++++--- go/vt/vtgate/evalengine/compiler_test.go | 38 ++++++++++++++---------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index 9433254d8b2..f6691d44ba4 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -72,7 +72,9 @@ func (c *compiler) compileFn_JSON_OBJECT(call 
*builtinJSONObject) (ctype, error) if err != nil { return ctype{}, err } - c.compileToJSONKey(key) + if err := c.compileToJSONKey(key); err != nil { + return ctype{}, err + } val, err := c.compileExpr(call.Arguments[i+1]) if err != nil { return ctype{}, err @@ -86,14 +88,18 @@ func (c *compiler) compileFn_JSON_OBJECT(call *builtinJSONObject) (ctype, error) return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil } -func (c *compiler) compileToJSONKey(key ctype) { +func (c *compiler) compileToJSONKey(key ctype) error { + if key.Type == sqltypes.Null { + return errJSONKeyIsNil + } if key.Type == sqltypes.VarChar && isEncodingJSONSafe(key.Col.Collation) { - return + return nil } if key.Type == sqltypes.VarBinary { - return + return nil } c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) + return nil } func (c *compiler) jsonExtractPath(expr Expr) (*json.Path, error) { diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index baa6ad94616..890d073c873 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -97,26 +97,27 @@ func TestCompilerReference(t *testing.T) { return } - env.Row = row - expected, err := env.Evaluate(converted) - if err != nil { - // do not attempt failing queries for now - return - } - total++ + env.Row = row + expected, evalErr := env.Evaluate(converted) + program, compileErr := evalengine.Compile(converted, env.Fields) - program, err := evalengine.Compile(converted, env.Fields) - if err != nil { - if vterrors.Code(err) != vtrpcpb.Code_UNIMPLEMENTED { - t.Errorf("failed compilation:\nSQL: %s\nError: %s", query, err) + if compileErr != nil { + switch { + case vterrors.Code(compileErr) == vtrpcpb.Code_UNIMPLEMENTED: + t.Logf("unsupported: %s", query) + case evalErr == nil: + t.Errorf("failed compilation:\nSQL: %s\nError: %s", query, compileErr) + case evalErr.Error() != compileErr.Error(): + t.Errorf("error mismatch:\nSQL: %s\nError 
eval: %s\nError comp: %s", query, evalErr, compileErr) + default: + supported++ } - t.Logf("unsupported: %s", query) return } var vm evalengine.VirtualMachine - res, err := func() (res evalengine.EvalResult, err error) { + res, vmErr := func() (res evalengine.EvalResult, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("PANIC: %v", r) @@ -126,8 +127,15 @@ func TestCompilerReference(t *testing.T) { return }() - if err != nil { - t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, err) + if vmErr != nil { + switch { + case evalErr == nil: + t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, err) + case evalErr.Error() != vmErr.Error(): + t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr) + default: + supported++ + } return } From 35ad0fea16329a022c6c330b026b7e947cf133f3 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Fri, 10 Mar 2023 13:45:45 +0100 Subject: [PATCH 27/42] evalengine: Add support for FLOOR(). We already had CEIL() and this is very similar, in fact, it's a tiny bit simpler since we don't have to do the add one logic if divmod returns a non zero remainder. 
Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_asm.go | 22 +++++++++ go/vt/vtgate/evalengine/compiler_fn.go | 31 +++++++++++++ go/vt/vtgate/evalengine/compiler_json.go | 2 +- go/vt/vtgate/evalengine/fn_numeric.go | 45 +++++++++++++++++++ .../evalengine/internal/decimal/decimal.go | 14 ++++++ go/vt/vtgate/evalengine/testcases/cases.go | 25 +++++++++++ go/vt/vtgate/evalengine/testcases/helpers.go | 14 +++--- go/vt/vtgate/evalengine/translate_builtin.go | 5 +++ 8 files changed, 150 insertions(+), 8 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 8806224fb52..9393af829d7 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -987,6 +987,28 @@ func (asm *assembler) Fn_CEIL_f() { }, "FN CEIL FLOAT64(SP-1)") } +func (asm *assembler) Fn_FLOOR_d() { + asm.emit(func(vm *VirtualMachine) int { + d := vm.stack[vm.sp-1].(*evalDecimal) + c := d.dec.Floor() + i, valid := c.Int64() + if valid { + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(i) + } else { + vm.err = errDeoptimize + } + return 1 + }, "FN FLOOR DECIMAL(SP-1)") +} + +func (asm *assembler) Fn_FLOOR_f() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = math.Floor(f.f) + return 1 + }, "FN FLOOR FLOAT64(SP-1)") +} + func (asm *assembler) Fn_COLLATION(col collations.TypedCollation) { asm.emit(func(vm *VirtualMachine) int { v := evalCollation(vm.stack[vm.sp-1]) diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 079ba9e7464..6c16374d82e 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -47,6 +47,8 @@ func (c *compiler) compileFn(call callable) (ctype, error) { return c.compileFn_COLLATION(call) case *builtinCeil: return c.compileFn_CEIL(call) + case *builtinFloor: + return c.compileFn_FLOOR(call) case *builtinWeightString: return c.compileFn_WEIGHT_STRING(call) 
default: @@ -382,6 +384,35 @@ func (c *compiler) compileFn_CEIL(expr *builtinCeil) (ctype, error) { return convt, nil } +func (c *compiler) compileFn_FLOOR(expr *builtinFloor) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + convt := ctype{Type: arg.Type, Col: collationNumeric} + switch { + case sqltypes.IsIntegral(arg.Type): + // No-op for integers. + case sqltypes.IsFloat(arg.Type): + c.asm.Fn_FLOOR_f() + case sqltypes.IsDecimal(arg.Type): + // We assume here the most common case here is that + // the decimal fits into an integer. + convt.Type = sqltypes.Int64 + c.asm.Fn_FLOOR_d() + default: + convt.Type = sqltypes.Float64 + c.asm.Convert_xf(1) + c.asm.Fn_FLOOR_f() + } + + c.asm.jumpDestination(skip) + return convt, nil +} + func (c *compiler) compileFn_WEIGHT_STRING(call *builtinWeightString) (ctype, error) { str, err := c.compileExpr(call.String) if err != nil { diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index f6691d44ba4..e39065a89e6 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -217,7 +217,7 @@ func (c *compiler) compileFn_JSON_KEYS(call *builtinJSONKeys) (ctype, error) { return ctype{}, err } - doc, err = c.compileParseJSON("JSON_KEYS", doc, 1) + _, err = c.compileParseJSON("JSON_KEYS", doc, 1) if err != nil { return ctype{}, err } diff --git a/go/vt/vtgate/evalengine/fn_numeric.go b/go/vt/vtgate/evalengine/fn_numeric.go index 19f12c9cfe3..297cbdb59cc 100644 --- a/go/vt/vtgate/evalengine/fn_numeric.go +++ b/go/vt/vtgate/evalengine/fn_numeric.go @@ -66,3 +66,48 @@ func (call *builtinCeil) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { return sqltypes.Float64, f } } + +type builtinFloor struct { + CallExpr +} + +var _ Expr = (*builtinFloor)(nil) + +func (call *builtinFloor) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + 
if err != nil { + return nil, err + } + if arg == nil { + return nil, nil + } + + switch num := arg.(type) { + case *evalInt64, *evalUint64: + return num, nil + case *evalDecimal: + dec := num.dec + dec = dec.Floor() + intnum, isfit := dec.Int64() + if isfit { + return newEvalInt64(intnum), nil + } + return newEvalDecimalWithPrec(dec, 0), nil + default: + f, _ := evalToNumeric(num).toFloat() + return newEvalFloat(math.Floor(f.f)), nil + } +} + +func (call *builtinFloor) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + t, f := call.Arguments[0].typeof(env) + if sqltypes.IsSigned(t) { + return sqltypes.Int64, f + } else if sqltypes.IsUnsigned(t) { + return sqltypes.Uint64, f + } else if sqltypes.Decimal == t { + return sqltypes.Int64, f | flagAmbiguousType + } else { + return sqltypes.Float64, f + } +} diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal.go b/go/vt/vtgate/evalengine/internal/decimal/decimal.go index 5eb9aef5f9f..5be44b02e00 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal.go @@ -474,6 +474,20 @@ func (d Decimal) Ceil() Decimal { return Decimal{value: z, exp: 0} } +func (d Decimal) Floor() Decimal { + if d.isInteger() { + return d + } + + exp := big.NewInt(10) + + // NOTE(vadim): must negate after casting to prevent int32 overflow + exp.Exp(exp, big.NewInt(-int64(d.exp)), nil) + + z, _ := new(big.Int).DivMod(d.value, exp, new(big.Int)) + return Decimal{value: z, exp: 0} +} + func (d Decimal) truncate(precision int32) Decimal { d.ensureInitialized() if precision >= 0 && -precision > d.exp { diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 152f6c571e6..aa7bb2aca44 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -48,6 +48,7 @@ type JSONObject struct{ defaultEnv } type CharsetConversionOperators struct{ defaultEnv } type CaseExprWithPredicate struct{ 
defaultEnv } type Ceil struct{ defaultEnv } +type Floor struct{ defaultEnv } type CaseExprWithValue struct{ defaultEnv } type Base64 struct{ defaultEnv } type Conversion struct{ defaultEnv } @@ -88,6 +89,7 @@ var Cases = []TestCase{ CharsetConversionOperators{}, CaseExprWithPredicate{}, Ceil{}, + Floor{}, CaseExprWithValue{}, Base64{}, Conversion{}, @@ -233,6 +235,29 @@ func (Ceil) Test(yield Iterator) { } } +func (Floor) Test(yield Iterator) { + var floorInputs = []string{ + "0", + "1", + "-1", + "'1.5'", + "NULL", + "'ABC'", + "1.5e0", + "-1.5e0", + "9223372036854775810.4", + "-9223372036854775810.4", + } + + for _, num := range floorInputs { + yield(fmt.Sprintf("FLOOR(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("FLOOR(%s)", num), nil) + } +} + func (CaseExprWithValue) Test(yield Iterator) { var elements []string elements = append(elements, inputBitwise...) diff --git a/go/vt/vtgate/evalengine/testcases/helpers.go b/go/vt/vtgate/evalengine/testcases/helpers.go index fa352c31eec..193533fcaa8 100644 --- a/go/vt/vtgate/evalengine/testcases/helpers.go +++ b/go/vt/vtgate/evalengine/testcases/helpers.go @@ -79,14 +79,14 @@ type bugs struct{} // results that do not match the behavior of the `=` operator, which is supposed to apply, // pair-wise, to the comparisons in a CASE or IN statement): // -// SELECT -1 IN (0xFF, 18446744073709551615) => 1 -// SELECT -1 IN (0, 18446744073709551615) => 0 -// SELECT -1 IN (0.0, 18446744073709551615) => 1 +// SELECT -1 IN (0xFF, 18446744073709551615) => 1 +// SELECT -1 IN (0, 18446744073709551615) => 0 +// SELECT -1 IN (0.0, 18446744073709551615) => 1 // -// SELECT 'FOO' IN ('foo', 0x00) => 0 -// SELECT 'FOO' IN ('foo', 0) => 1 -// SELECT 'FOO' IN ('foo', 0x00, CAST('bar' as char)) => 1 -// SELECT 'FOO' IN ('foo', 0x00, 'bar') => 0 +// SELECT 'FOO' IN ('foo', 0x00) => 0 +// SELECT 'FOO' IN ('foo', 0) => 1 +// SELECT 'FOO' IN ('foo', 0x00, CAST('bar' as char)) => 1 +// SELECT 'FOO' IN ('foo', 0x00, 
'bar') => 0 // // SELECT 9223372036854775808 IN (0.0e0, -9223372036854775808) => 1 // SELECT 9223372036854775808 IN (0, -9223372036854775808) => 0 diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 4228970ecbd..6a0e5bc9419 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -102,6 +102,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (Expr, error) return nil, argError(method) } return &builtinCeil{CallExpr: call}, nil + case "floor": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinFloor{CallExpr: call}, nil case "lower", "lcase": if len(args) != 1 { return nil, argError(method) From 35a9998dcb100aab8d5eee996ea5356aa92c720d Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Fri, 10 Mar 2023 14:02:54 +0100 Subject: [PATCH 28/42] evalengine: Add support for ABS() Since we already have CEIL(), FLOOR() etc. let's add some more numeric operations here to the eval engine. 
Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_asm.go | 34 ++++++++++++++ go/vt/vtgate/evalengine/compiler_fn.go | 32 +++++++++++++ go/vt/vtgate/evalengine/fn_numeric.go | 45 +++++++++++++++++++ .../evalengine/internal/decimal/decimal.go | 4 +- .../internal/decimal/decimal_test.go | 20 ++++----- go/vt/vtgate/evalengine/testcases/cases.go | 25 +++++++++++ go/vt/vtgate/evalengine/translate_builtin.go | 5 +++ 7 files changed, 153 insertions(+), 12 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 9393af829d7..59a083f288b 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -1009,6 +1009,40 @@ func (asm *assembler) Fn_FLOOR_f() { }, "FN FLOOR FLOAT64(SP-1)") } +func (asm *assembler) Fn_ABS_i() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalInt64) + if f.i >= 0 { + return 1 + } + if f.i == math.MinInt64 { + vm.err = vterrors.NewErrorf(vtrpc.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "BIGINT value is out of range") + return 1 + } + f.i = -f.i + return 1 + }, "FN ABS INT64(SP-1)") +} + +func (asm *assembler) Fn_ABS_d() { + asm.emit(func(vm *VirtualMachine) int { + d := vm.stack[vm.sp-1].(*evalDecimal) + d.dec = d.dec.Abs() + return 1 + }, "FN ABS DECIMAL(SP-1)") +} + +func (asm *assembler) Fn_ABS_f() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + if f.f >= 0 { + return 1 + } + f.f = -f.f + return 1 + }, "FN ABS FLOAT64(SP-1)") +} + func (asm *assembler) Fn_COLLATION(col collations.TypedCollation) { asm.emit(func(vm *VirtualMachine) int { v := evalCollation(vm.stack[vm.sp-1]) diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 6c16374d82e..977ccec5f3a 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -49,6 +49,8 @@ func (c *compiler) compileFn(call callable) (ctype, error) { 
return c.compileFn_CEIL(call) case *builtinFloor: return c.compileFn_FLOOR(call) + case *builtinAbs: + return c.compileFn_ABS(call) case *builtinWeightString: return c.compileFn_WEIGHT_STRING(call) default: @@ -413,6 +415,36 @@ func (c *compiler) compileFn_FLOOR(expr *builtinFloor) (ctype, error) { return convt, nil } +func (c *compiler) compileFn_ABS(expr *builtinAbs) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + convt := ctype{Type: arg.Type, Col: collationNumeric} + switch { + case sqltypes.IsUnsigned(arg.Type): + // No-op if it's unsigned since that's already positive. + case sqltypes.IsSigned(arg.Type): + c.asm.Fn_ABS_i() + case sqltypes.IsFloat(arg.Type): + c.asm.Fn_ABS_f() + case sqltypes.IsDecimal(arg.Type): + // We assume here the most common case here is that + // the decimal fits into an integer. + c.asm.Fn_ABS_d() + default: + convt.Type = sqltypes.Float64 + c.asm.Convert_xf(1) + c.asm.Fn_ABS_f() + } + + c.asm.jumpDestination(skip) + return convt, nil +} + func (c *compiler) compileFn_WEIGHT_STRING(call *builtinWeightString) (ctype, error) { str, err := c.compileExpr(call.String) if err != nil { diff --git a/go/vt/vtgate/evalengine/fn_numeric.go b/go/vt/vtgate/evalengine/fn_numeric.go index 297cbdb59cc..9bec5feaf1c 100644 --- a/go/vt/vtgate/evalengine/fn_numeric.go +++ b/go/vt/vtgate/evalengine/fn_numeric.go @@ -20,6 +20,8 @@ import ( "math" "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" ) type builtinCeil struct { @@ -111,3 +113,46 @@ func (call *builtinFloor) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { return sqltypes.Float64, f } } + +type builtinAbs struct { + CallExpr +} + +var _ Expr = (*builtinAbs)(nil) + +func (call *builtinAbs) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + return nil, 
nil + } + + switch num := arg.(type) { + case *evalUint64: + return num, nil + case *evalInt64: + if num.i < 0 { + if num.i == math.MinInt64 { + return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "BIGINT value is out of range") + } + return newEvalInt64(-num.i), nil + } + return num, nil + case *evalDecimal: + return newEvalDecimalWithPrec(num.dec.Abs(), num.length), nil + default: + f, _ := evalToNumeric(num).toFloat() + return newEvalFloat(math.Abs(f.f)), nil + } +} + +func (call *builtinAbs) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + t, f := call.Arguments[0].typeof(env) + if sqltypes.IsNumber(t) { + return t, f + } else { + return sqltypes.Float64, f + } +} diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal.go b/go/vt/vtgate/evalengine/internal/decimal/decimal.go index 5be44b02e00..b93a1b9069b 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal.go @@ -263,7 +263,7 @@ func (d Decimal) rescale(exp int32) Decimal { } // abs returns the absolute value of the decimal. 
-func (d Decimal) abs() Decimal { +func (d Decimal) Abs() Decimal { if d.Sign() >= 0 { return d } @@ -438,7 +438,7 @@ func (d Decimal) divRound(d2 Decimal, precision int32) Decimal { // now rv2 = abs(r.value) * 2 r2 := Decimal{value: &rv2, exp: r.exp + precision} // r2 is now 2 * r * 10 ^ precision - var c = r2.Cmp(d2.abs()) + var c = r2.Cmp(d2.Abs()) if c < 0 { return q diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go b/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go index 88edef68ff7..333d24e5d4d 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go @@ -761,7 +761,7 @@ func TestDecimal_QuoRem(t *testing.T) { t.Errorf("quotient wrong precision: d=%v, d2= %v, prec=%d, q=%v, r=%v", d, d2, prec, q, r) } - if r.abs().Cmp(d2.abs().mul(New(1, -prec))) >= 0 { + if r.Abs().Cmp(d2.Abs().mul(New(1, -prec))) >= 0 { t.Errorf("remainder too large: d=%v, d2= %v, prec=%d, q=%v, r=%v", d, d2, prec, q, r) } @@ -825,7 +825,7 @@ func TestDecimal_QuoRem2(t *testing.T) { d, d2, prec, q, r) } // rule 3: abs(r)= 0 { + if r.Abs().Cmp(d2.Abs().mul(New(1, -prec))) >= 0 { t.Errorf("remainder too large, d=%v, d2=%v, prec=%d, q=%v, r=%v", d, d2, prec, q, r) } @@ -879,11 +879,11 @@ func TestDecimal_DivRound(t *testing.T) { if sign(q)*sign(d)*sign(d2) < 0 { t.Errorf("sign of quotient wrong, got: %v/%v is about %v", d, d2, q) } - x := q.mul(d2).abs().sub(d.abs()).mul(New(2, 0)) - if x.Cmp(d2.abs().mul(New(1, -prec))) > 0 { + x := q.mul(d2).Abs().sub(d.Abs()).mul(New(2, 0)) + if x.Cmp(d2.Abs().mul(New(1, -prec))) > 0 { t.Errorf("wrong rounding, got: %v/%v prec=%d is about %v", d, d2, prec, q) } - if x.Cmp(d2.abs().mul(New(-1, -prec))) <= 0 { + if x.Cmp(d2.Abs().mul(New(-1, -prec))) <= 0 { t.Errorf("wrong rounding, got: %v/%v prec=%d is about %v", d, d2, prec, q) } if !q.Equal(result) { @@ -904,11 +904,11 @@ func TestDecimal_DivRound2(t *testing.T) { if sign(q)*sign(d)*sign(d2) < 0 { 
t.Errorf("sign of quotient wrong, got: %v/%v is about %v", d, d2, q) } - x := q.mul(d2).abs().sub(d.abs()).mul(New(2, 0)) - if x.Cmp(d2.abs().mul(New(1, -prec))) > 0 { + x := q.mul(d2).Abs().sub(d.Abs()).mul(New(2, 0)) + if x.Cmp(d2.Abs().mul(New(1, -prec))) > 0 { t.Errorf("wrong rounding, got: %v/%v prec=%d is about %v", d, d2, prec, q) } - if x.Cmp(d2.abs().mul(New(-1, -prec))) <= 0 { + if x.Cmp(d2.Abs().mul(New(-1, -prec))) <= 0 { t.Errorf("wrong rounding, got: %v/%v prec=%d is about %v", d, d2, prec, q) } } @@ -973,7 +973,7 @@ func TestDecimal_Abs1(t *testing.T) { a := New(-1234, -4) b := New(1234, -4) - c := a.abs() + c := a.Abs() if c.Cmp(b) != 0 { t.Errorf("error") } @@ -983,7 +983,7 @@ func TestDecimal_Abs2(t *testing.T) { a := New(-1234, -4) b := New(1234, -4) - c := b.abs() + c := b.Abs() if c.Cmp(a) == 0 { t.Errorf("error") } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index aa7bb2aca44..3fa156a7187 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -49,6 +49,7 @@ type CharsetConversionOperators struct{ defaultEnv } type CaseExprWithPredicate struct{ defaultEnv } type Ceil struct{ defaultEnv } type Floor struct{ defaultEnv } +type Abs struct{ defaultEnv } type CaseExprWithValue struct{ defaultEnv } type Base64 struct{ defaultEnv } type Conversion struct{ defaultEnv } @@ -90,6 +91,7 @@ var Cases = []TestCase{ CaseExprWithPredicate{}, Ceil{}, Floor{}, + Abs{}, CaseExprWithValue{}, Base64{}, Conversion{}, @@ -258,6 +260,29 @@ func (Floor) Test(yield Iterator) { } } +func (Abs) Test(yield Iterator) { + var absInputs = []string{ + "0", + "1", + "-1", + "'1.5'", + "NULL", + "'ABC'", + "1.5e0", + "-1.5e0", + "9223372036854775810.4", + "-9223372036854775810.4", + } + + for _, num := range absInputs { + yield(fmt.Sprintf("ABS(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("ABS(%s)", num), nil) + } +} + func 
(CaseExprWithValue) Test(yield Iterator) { var elements []string elements = append(elements, inputBitwise...) diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 6a0e5bc9419..f4b0e98122b 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -107,6 +107,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (Expr, error) return nil, argError(method) } return &builtinFloor{CallExpr: call}, nil + case "abs": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinAbs{CallExpr: call}, nil case "lower", "lcase": if len(args) != 1 { return nil, argError(method) From d71e051d667a374ffece23652d91418da466a4de Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Sat, 11 Mar 2023 11:11:10 +0100 Subject: [PATCH 29/42] evalengine: Implement trigonometry functions Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/api_types.go | 15 +- go/vt/vtgate/evalengine/compiler_asm.go | 92 ++++++ go/vt/vtgate/evalengine/compiler_fn.go | 206 ++++++++++++++ go/vt/vtgate/evalengine/fn_numeric.go | 261 ++++++++++++++++++ .../evalengine/integration/comparison_test.go | 8 +- .../evalengine/integration/fuzz_test.go | 42 ++- go/vt/vtgate/evalengine/testcases/cases.go | 132 +++++++++ go/vt/vtgate/evalengine/testcases/inputs.go | 20 ++ go/vt/vtgate/evalengine/translate_builtin.go | 59 ++++ 9 files changed, 812 insertions(+), 23 deletions(-) diff --git a/go/vt/vtgate/evalengine/api_types.go b/go/vt/vtgate/evalengine/api_types.go index 817689d04c8..bd7a4313903 100644 --- a/go/vt/vtgate/evalengine/api_types.go +++ b/go/vt/vtgate/evalengine/api_types.go @@ -17,7 +17,6 @@ limitations under the License. 
package evalengine import ( - "fmt" "strconv" "vitess.io/vitess/go/mysql/collations" @@ -176,16 +175,10 @@ func ToNative(v sqltypes.Value) (any, error) { return out, err } -func NormalizeValue(v sqltypes.Value, coll collations.ID) string { +func NormalizeValue(v sqltypes.Value, coll collations.ID) sqltypes.Value { typ := v.Type() - if typ == sqltypes.Null { - return "NULL" - } if typ == sqltypes.VarChar && coll == collations.CollationBinaryID { - return fmt.Sprintf("VARBINARY(%q)", v.Raw()) - } - if v.IsQuoted() || typ == sqltypes.Bit { - return fmt.Sprintf("%v(%q)", typ, v.Raw()) + return sqltypes.NewVarBinary(string(v.Raw())) } if typ == sqltypes.Float32 || typ == sqltypes.Float64 { var bitsize = 64 @@ -196,7 +189,7 @@ func NormalizeValue(v sqltypes.Value, coll collations.ID) string { if err != nil { panic(err) } - return fmt.Sprintf("%v(%s)", typ, FormatFloat(typ, f)) + return sqltypes.MakeTrusted(typ, FormatFloat(typ, f)) } - return fmt.Sprintf("%v(%s)", typ, v.Raw()) + return v } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 59a083f288b..b13fbed9a5c 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -1043,6 +1043,98 @@ func (asm *assembler) Fn_ABS_f() { }, "FN ABS FLOAT64(SP-1)") } +func (asm *assembler) Fn_PI() { + asm.adjustStack(1) + asm.emit(func(vm *VirtualMachine) int { + vm.stack[vm.sp] = vm.arena.newEvalFloat(math.Pi) + vm.sp++ + return 1 + }, "FN PI") +} + +func (asm *assembler) Fn_ACOS() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + if f.f < -1 || f.f > 1 { + vm.stack[vm.sp-1] = nil + return 1 + } + f.f = math.Acos(f.f) + return 1 + }, "FN ACOS FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_ASIN() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + if f.f < -1 || f.f > 1 { + vm.stack[vm.sp-1] = nil + return 1 + } + f.f = math.Asin(f.f) + return 1 + }, "FN ASIN FLOAT64(SP-1)") 
+} + +func (asm *assembler) Fn_ATAN() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = math.Atan(f.f) + return 1 + }, "FN ATAN FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_ATAN2() { + asm.adjustStack(-1) + asm.emit(func(vm *VirtualMachine) int { + f1 := vm.stack[vm.sp-2].(*evalFloat) + f2 := vm.stack[vm.sp-1].(*evalFloat) + f1.f = math.Atan2(f1.f, f2.f) + vm.sp-- + return 1 + }, "FN ATAN2 FLOAT64(SP-2) FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_COS() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = math.Cos(f.f) + return 1 + }, "FN COS FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_SIN() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = math.Sin(f.f) + return 1 + }, "FN SIN FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_TAN() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = math.Tan(f.f) + return 1 + }, "FN TAN FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_DEGREES() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = f.f * (180 / math.Pi) + return 1 + }, "FN DEGREES FLOAT64(SP-1)") +} + +func (asm *assembler) Fn_RADIANS() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = f.f * (math.Pi / 180) + return 1 + }, "FN RADIANS FLOAT64(SP-1)") +} + func (asm *assembler) Fn_COLLATION(col collations.TypedCollation) { asm.emit(func(vm *VirtualMachine) int { v := evalCollation(vm.stack[vm.sp-1]) diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 977ccec5f3a..91786dc087f 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -51,6 +51,26 @@ func (c *compiler) compileFn(call callable) (ctype, error) { return c.compileFn_FLOOR(call) case *builtinAbs: return c.compileFn_ABS(call) + case *builtinPi: + return c.compileFn_PI(call) + case 
*builtinAcos: + return c.compileFn_ACOS(call) + case *builtinAsin: + return c.compileFn_ASIN(call) + case *builtinAtan: + return c.compileFn_ATAN(call) + case *builtinAtan2: + return c.compileFn_ATAN2(call) + case *builtinCos: + return c.compileFn_COS(call) + case *builtinSin: + return c.compileFn_SIN(call) + case *builtinTan: + return c.compileFn_TAN(call) + case *builtinDegrees: + return c.compileFn_DEGREES(call) + case *builtinRadians: + return c.compileFn_RADIANS(call) case *builtinWeightString: return c.compileFn_WEIGHT_STRING(call) default: @@ -445,6 +465,192 @@ func (c *compiler) compileFn_ABS(expr *builtinAbs) (ctype, error) { return convt, nil } +func (c *compiler) compileFn_PI(expr *builtinPi) (ctype, error) { + c.asm.Fn_PI() + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_ACOS(expr *builtinAcos) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_ACOS() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: flagNullable}, nil +} + +func (c *compiler) compileFn_ASIN(expr *builtinAsin) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_ASIN() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: flagNullable}, nil +} + +func (c *compiler) compileFn_ATAN(expr *builtinAtan) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_ATAN() + 
c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_ATAN2(expr *builtinAtan2) (ctype, error) { + arg1, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + arg2, err := c.compileExpr(expr.Arguments[1]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck2(arg1, arg2) + + switch { + case sqltypes.IsFloat(arg1.Type) && sqltypes.IsFloat(arg2.Type): + case sqltypes.IsFloat(arg1.Type): + c.asm.Convert_xf(1) + case sqltypes.IsFloat(arg2.Type): + c.asm.Convert_xf(2) + default: + c.asm.Convert_xf(2) + c.asm.Convert_xf(1) + } + + c.asm.Fn_ATAN2() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_COS(expr *builtinCos) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_COS() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_SIN(expr *builtinSin) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_SIN() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_TAN(expr *builtinTan) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_TAN() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + +func (c 
*compiler) compileFn_DEGREES(expr *builtinDegrees) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_DEGREES() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + +func (c *compiler) compileFn_RADIANS(expr *builtinRadians) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_RADIANS() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + func (c *compiler) compileFn_WEIGHT_STRING(call *builtinWeightString) (ctype, error) { str, err := c.compileExpr(call.String) if err != nil { diff --git a/go/vt/vtgate/evalengine/fn_numeric.go b/go/vt/vtgate/evalengine/fn_numeric.go index 9bec5feaf1c..7f4c3c725ef 100644 --- a/go/vt/vtgate/evalengine/fn_numeric.go +++ b/go/vt/vtgate/evalengine/fn_numeric.go @@ -156,3 +156,264 @@ func (call *builtinAbs) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { return sqltypes.Float64, f } } + +type builtinPi struct { + CallExpr +} + +var _ Expr = (*builtinPi)(nil) + +func (call *builtinPi) eval(env *ExpressionEnv) (eval, error) { + return newEvalFloat(math.Pi), nil +} + +func (call *builtinPi) typeof(_ *ExpressionEnv) (sqltypes.Type, typeFlag) { + return sqltypes.Float64, 0 +} + +type builtinAcos struct { + CallExpr +} + +var _ Expr = (*builtinAcos)(nil) + +func (call *builtinAcos) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + if f.f < -1 || f.f > 1 { + return nil, nil + } + return newEvalFloat(math.Acos(f.f)), nil +} + 
+func (call *builtinAcos) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f | flagNullable +} + +type builtinAsin struct { + CallExpr +} + +var _ Expr = (*builtinAsin)(nil) + +func (call *builtinAsin) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + if f.f < -1 || f.f > 1 { + return nil, nil + } + return newEvalFloat(math.Asin(f.f)), nil +} + +func (call *builtinAsin) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f | flagNullable +} + +type builtinAtan struct { + CallExpr +} + +var _ Expr = (*builtinAtan)(nil) + +func (call *builtinAtan) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + return newEvalFloat(math.Atan(f.f)), nil +} + +func (call *builtinAtan) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f +} + +type builtinAtan2 struct { + CallExpr +} + +var _ Expr = (*builtinAtan2)(nil) + +func (call *builtinAtan2) eval(env *ExpressionEnv) (eval, error) { + arg1, arg2, err := call.arg2(env) + if err != nil { + return nil, err + } + if arg1 == nil || arg2 == nil { + return nil, nil + } + + f1, _ := evalToNumeric(arg1).toFloat() + f2, _ := evalToNumeric(arg2).toFloat() + return newEvalFloat(math.Atan2(f1.f, f2.f)), nil +} + +func (call *builtinAtan2) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f +} + +type builtinCos struct { + CallExpr +} + +var _ Expr = (*builtinCos)(nil) + +func (call *builtinCos) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { 
+ return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + return newEvalFloat(math.Cos(f.f)), nil +} + +func (call *builtinCos) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f +} + +type builtinCot struct { + CallExpr +} + +var _ Expr = (*builtinCot)(nil) + +func (call *builtinCot) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + return newEvalFloat(1.0 / math.Tan(f.f)), nil +} + +func (call *builtinCot) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f +} + +type builtinSin struct { + CallExpr +} + +var _ Expr = (*builtinSin)(nil) + +func (call *builtinSin) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + return newEvalFloat(math.Sin(f.f)), nil +} + +func (call *builtinSin) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f +} + +type builtinTan struct { + CallExpr +} + +var _ Expr = (*builtinTan)(nil) + +func (call *builtinTan) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + return newEvalFloat(math.Tan(f.f)), nil +} + +func (call *builtinTan) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f +} + +type builtinDegrees struct { + CallExpr +} + +var _ Expr = (*builtinDegrees)(nil) + +func (call *builtinDegrees) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + 
return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + return newEvalFloat(f.f * (180 / math.Pi)), nil +} + +func (call *builtinDegrees) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f +} + +type builtinRadians struct { + CallExpr +} + +var _ Expr = (*builtinRadians)(nil) + +func (call *builtinRadians) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + return nil, nil + } + + f, _ := evalToNumeric(arg).toFloat() + return newEvalFloat(f.f * (math.Pi / 180)), nil +} + +func (call *builtinRadians) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { + _, f := call.Arguments[0].typeof(env) + return sqltypes.Float64, f +} diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index c541238fc9a..de26de542a2 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -115,7 +115,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys local, localType, localErr := evaluateLocalEvalengine(env, localQuery) remote, remoteErr := conn.ExecuteFetch(remoteQuery, 1, true) - var localVal, remoteVal string + var localVal, remoteVal sqltypes.Value var localCollation, remoteCollation collations.ID if localErr == nil { v := local.Value() @@ -129,7 +129,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys if debugNormalize { localVal = evalengine.NormalizeValue(v, local.Collation()) } else { - localVal = v.String() + localVal = v } if debugCheckTypes && localType != -1 { tt := v.Type() @@ -143,7 +143,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys if debugNormalize { remoteVal = evalengine.NormalizeValue(remote.Rows[0][0], 
collations.ID(remote.Fields[0].Charset)) } else { - remoteVal = remote.Rows[0][0].String() + remoteVal = remote.Rows[0][0] } if debugCheckCollations { if remote.Rows[0][0].IsNull() { @@ -158,7 +158,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys if diff := compareResult(localErr, remoteErr, localVal, remoteVal, localCollation, remoteCollation); diff != "" { t.Errorf("%s\nquery: %s (SIMPLIFY=%v)\nrow: %v", diff, localQuery, debugSimplify, env.Row) } else if debugPrintAll { - t.Logf("local=%s mysql=%s\nquery: %s\nrow: %v", localVal, remoteVal, localQuery, env.Row) + t.Logf("local=%s mysql=%s\nquery: %s\nrow: %v", localVal.String(), remoteVal.String(), localQuery, env.Row) } } diff --git a/go/vt/vtgate/evalengine/integration/fuzz_test.go b/go/vt/vtgate/evalengine/integration/fuzz_test.go index 875fab6540c..59d6f345dc3 100644 --- a/go/vt/vtgate/evalengine/integration/fuzz_test.go +++ b/go/vt/vtgate/evalengine/integration/fuzz_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "math/rand" "os" "regexp" @@ -205,10 +206,10 @@ func TestGenerateFuzzCases(t *testing.T) { remoteErr: remoteErr, } if localErr == nil { - res.localVal = eval.Value().String() + res.localVal = eval.Value() } if remoteErr == nil { - res.remoteVal = remote.Rows[0][0].String() + res.remoteVal = remote.Rows[0][0] } if res.Error() != "" { return &res @@ -270,7 +271,7 @@ func TestGenerateFuzzCases(t *testing.T) { } else { golden = append(golden, evaltest{ Query: query, - Value: fail.remoteVal, + Value: fail.remoteVal.String(), }) } } @@ -290,10 +291,22 @@ func TestGenerateFuzzCases(t *testing.T) { type mismatch struct { expr sqlparser.Expr localErr, remoteErr error - localVal, remoteVal string + localVal, remoteVal sqltypes.Value } -func compareResult(localErr, remoteErr error, localVal, remoteVal string, localCollation, remoteCollation collations.ID) string { +const tolerance = 1e-14 + +func close(a, b float64) bool { + if a == b { + return 
true + } + if b == 0 { + return math.Abs(a) < tolerance + } + return math.Abs((a-b)/b) < tolerance +} + +func compareResult(localErr, remoteErr error, localVal, remoteVal sqltypes.Value, localCollation, remoteCollation collations.ID) string { if localErr != nil { if remoteErr == nil { return fmt.Sprintf("%v; mysql response: %s", localErr, remoteVal) @@ -322,13 +335,26 @@ func compareResult(localErr, remoteErr error, localVal, remoteVal string, localC remoteCollationName = coll.Name() } - if localVal != remoteVal { + if localVal.IsFloat() && remoteVal.IsFloat() { + localFloat, err := localVal.ToFloat64() + if err != nil { + return fmt.Sprintf("error converting local value to float: %v", err) + } + remoteFloat, err := remoteVal.ToFloat64() + if err != nil { + return fmt.Sprintf("error converting remote value to float: %v", err) + } + if !close(localFloat, remoteFloat) { + return fmt.Sprintf("different results: %s; mysql response: %s (local collation: %s; mysql collation: %s)", + localVal.String(), remoteVal.String(), localCollationName, remoteCollationName) + } + } else if localVal.String() != remoteVal.String() { return fmt.Sprintf("different results: %s; mysql response: %s (local collation: %s; mysql collation: %s)", - localVal, remoteVal, localCollationName, remoteCollationName) + localVal.String(), remoteVal.String(), localCollationName, remoteCollationName) } if localCollation != remoteCollation { return fmt.Sprintf("different collations: %s; mysql response: %s (local result: %s; mysql result: %s)", - localCollationName, remoteCollationName, localVal, remoteVal, + localCollationName, remoteCollationName, localVal.String(), remoteVal.String(), ) } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 3fa156a7187..5c02ff9723d 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -50,6 +50,17 @@ type CaseExprWithPredicate struct{ defaultEnv } type Ceil struct{ 
defaultEnv } type Floor struct{ defaultEnv } type Abs struct{ defaultEnv } +type Pi struct{ defaultEnv } +type Acos struct{ defaultEnv } +type Asin struct{ defaultEnv } +type Atan struct{ defaultEnv } +type Atan2 struct{ defaultEnv } +type Cos struct{ defaultEnv } +type Cot struct{ defaultEnv } +type Sin struct{ defaultEnv } +type Tan struct{ defaultEnv } +type Degrees struct{ defaultEnv } +type Radians struct{ defaultEnv } type CaseExprWithValue struct{ defaultEnv } type Base64 struct{ defaultEnv } type Conversion struct{ defaultEnv } @@ -92,6 +103,17 @@ var Cases = []TestCase{ Ceil{}, Floor{}, Abs{}, + Pi{}, + Acos{}, + Asin{}, + Atan{}, + Atan2{}, + Cos{}, + Cot{}, + Sin{}, + Tan{}, + Degrees{}, + Radians{}, CaseExprWithValue{}, Base64{}, Conversion{}, @@ -283,6 +305,116 @@ func (Abs) Test(yield Iterator) { } } +func (Pi) Test(yield Iterator) { + yield("PI()+0.000000000000000000", nil) +} + +func (Acos) Test(yield Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("ACOS(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("ACOS(%s)", num), nil) + } +} + +func (Asin) Test(yield Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("ASIN(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("ASIN(%s)", num), nil) + } +} + +func (Atan) Test(yield Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("ATAN(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("ATAN(%s)", num), nil) + } +} + +func (Atan2) Test(yield Iterator) { + for _, num1 := range radianInputs { + for _, num2 := range radianInputs { + yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil) + } + } + + for _, num1 := range inputBitwise { + for _, num2 := range inputBitwise { + yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil) + } + } +} + +func (Cos) Test(yield 
Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("COS(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("COS(%s)", num), nil) + } +} + +func (Cot) Test(yield Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("COT(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("COT(%s)", num), nil) + } +} + +func (Sin) Test(yield Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("SIN(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("SIN(%s)", num), nil) + } +} + +func (Tan) Test(yield Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("TAN(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("TAN(%s)", num), nil) + } +} + +func (Degrees) Test(yield Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("DEGREES(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("DEGREES(%s)", num), nil) + } +} + +func (Radians) Test(yield Iterator) { + for _, num := range radianInputs { + yield(fmt.Sprintf("RADIANS(%s)", num), nil) + } + + for _, num := range inputBitwise { + yield(fmt.Sprintf("RADIANS(%s)", num), nil) + } +} + func (CaseExprWithValue) Test(yield Iterator) { var elements []string elements = append(elements, inputBitwise...) 
diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index 3ae80620f99..f871502f0d7 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -19,6 +19,9 @@ package testcases import ( "math" "strconv" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) var inputJSONObjects = []string{ @@ -57,6 +60,23 @@ var inputBitwise = []string{ "64", "'64'", "_binary '64'", "X'40'", "_binary X'40'", } +var radianInputs = []string{ + "0", + "1", + "-1", + "'1.5'", + "NULL", + "'ABC'", + "1.5e0", + "-1.5e0", + "9223372036854775810.4", + "-9223372036854775810.4", + string(evalengine.FormatFloat(sqltypes.Float64, math.Pi)), + string(evalengine.FormatFloat(sqltypes.Float64, math.MaxFloat64)), + string(evalengine.FormatFloat(sqltypes.Float64, math.SmallestNonzeroFloat32)), + string(evalengine.FormatFloat(sqltypes.Float64, math.SmallestNonzeroFloat64)), +} + var inputComparisonElement = []string{ "NULL", "-1", "0", "1", `'foo'`, `'bar'`, `'FOO'`, `'BAR'`, diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index f4b0e98122b..2b07a0ea9e2 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -112,6 +112,65 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (Expr, error) return nil, argError(method) } return &builtinAbs{CallExpr: call}, nil + case "pi": + if len(args) != 0 { + return nil, argError(method) + } + return &builtinPi{CallExpr: call}, nil + case "acos": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinAcos{CallExpr: call}, nil + case "asin": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinAsin{CallExpr: call}, nil + case "atan": + switch len(args) { + case 1: + return &builtinAtan{CallExpr: call}, nil + case 2: + return &builtinAtan2{CallExpr: call}, nil + default: + 
return nil, argError(method) + } + case "atan2": + if len(args) != 2 { + return nil, argError(method) + } + return &builtinAtan2{CallExpr: call}, nil + case "cos": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinCos{CallExpr: call}, nil + case "cot": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinCot{CallExpr: call}, nil + case "sin": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinSin{CallExpr: call}, nil + case "tan": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinTan{CallExpr: call}, nil + case "degrees": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinDegrees{CallExpr: call}, nil + case "radians": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinRadians{CallExpr: call}, nil case "lower", "lcase": if len(args) != 1 { return nil, argError(method) From 7cd35c2c9aa7085269a36cf59cef64be6f0455e1 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 13 Mar 2023 15:59:07 +0100 Subject: [PATCH 30/42] evalengine: simplify test case generation Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler_test.go | 12 +- go/vt/vtgate/evalengine/expr_env.go | 2 +- .../evalengine/integration/comparison_test.go | 7 +- go/vt/vtgate/evalengine/testcases/cases.go | 297 +++++++----------- go/vt/vtgate/evalengine/testcases/helpers.go | 16 + go/vt/vtgate/evalengine/testcases/inputs.go | 9 + 6 files changed, 145 insertions(+), 198 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 890d073c873..6f20cac4873 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -47,7 +47,7 @@ func NewTracker() *Tracker { func (s *Tracker) Add(name string, supported, total int) { s.tbl.Append([]string{ - name[strings.LastIndexByte(name, '.')+1:], + name, strconv.Itoa(supported), strconv.Itoa(total), fmt.Sprintf("%.02f%%", 
100*float64(supported)/float64(total)), @@ -79,12 +79,12 @@ func TestCompilerReference(t *testing.T) { track := NewTracker() for _, tc := range testcases.Cases { - name := fmt.Sprintf("%T", tc) - t.Run(name, func(t *testing.T) { + t.Run(tc.Name(), func(t *testing.T) { var supported, total int - env := tc.Environment() + env := evalengine.EmptyExpressionEnv() + env.Fields = tc.Schema - tc.Test(func(query string, row []sqltypes.Value) { + tc.Run(func(query string, row []sqltypes.Value) { stmt, err := sqlparser.Parse("SELECT " + query) if err != nil { // no need to test un-parseable queries @@ -150,7 +150,7 @@ func TestCompilerReference(t *testing.T) { supported++ }) - track.Add(name, supported, total) + track.Add(tc.Name(), supported, total) }) } diff --git a/go/vt/vtgate/evalengine/expr_env.go b/go/vt/vtgate/evalengine/expr_env.go index ae73066dee6..33b45aff1b0 100644 --- a/go/vt/vtgate/evalengine/expr_env.go +++ b/go/vt/vtgate/evalengine/expr_env.go @@ -65,7 +65,7 @@ func (env *ExpressionEnv) collation() collations.TypedCollation { // EmptyExpressionEnv returns a new ExpressionEnv with no bind vars or row func EmptyExpressionEnv() *ExpressionEnv { - return EnvWithBindVars(map[string]*querypb.BindVariable{}, collations.Unknown) + return EnvWithBindVars(map[string]*querypb.BindVariable{}, collations.Default()) } // EnvWithBindVars returns an expression environment with no current row, but with bindvars diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index de26de542a2..8104746b3b0 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -167,9 +167,10 @@ func TestMySQL(t *testing.T) { defer conn.Close() for _, tc := range testcases.Cases { - t.Run(fmt.Sprintf("%T", tc), func(t *testing.T) { - env := tc.Environment() - tc.Test(func(query string, row []sqltypes.Value) { + t.Run(tc.Name(), func(t *testing.T) { + env := 
evalengine.EmptyExpressionEnv() + env.Fields = tc.Schema + tc.Run(func(query string, row []sqltypes.Value) { env.Row = row compareRemoteExprEnv(t, env, conn, query) }) diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 5c02ff9723d..e1db6b0a7d0 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -23,130 +23,64 @@ import ( "strconv" "strings" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - "vitess.io/vitess/go/vt/vtgate/evalengine" ) -type TestCase interface { - Test(yield Iterator) - Environment() *evalengine.ExpressionEnv -} - -type Iterator func(query string, row []sqltypes.Value) - -type defaultEnv struct{} - -func (defaultEnv) Environment() *evalengine.ExpressionEnv { - return evalengine.EnvWithBindVars(nil, collations.CollationUtf8mb4ID) -} - -type JSONPathOperations struct{ defaultEnv } -type JSONArray struct{ defaultEnv } -type JSONObject struct{ defaultEnv } -type CharsetConversionOperators struct{ defaultEnv } -type CaseExprWithPredicate struct{ defaultEnv } -type Ceil struct{ defaultEnv } -type Floor struct{ defaultEnv } -type Abs struct{ defaultEnv } -type Pi struct{ defaultEnv } -type Acos struct{ defaultEnv } -type Asin struct{ defaultEnv } -type Atan struct{ defaultEnv } -type Atan2 struct{ defaultEnv } -type Cos struct{ defaultEnv } -type Cot struct{ defaultEnv } -type Sin struct{ defaultEnv } -type Tan struct{ defaultEnv } -type Degrees struct{ defaultEnv } -type Radians struct{ defaultEnv } -type CaseExprWithValue struct{ defaultEnv } -type Base64 struct{ defaultEnv } -type Conversion struct{ defaultEnv } -type LargeDecimals struct{ defaultEnv } -type LargeIntegers struct{ defaultEnv } -type DecimalClamping struct{ defaultEnv } -type BitwiseOperatorsUnary struct{ defaultEnv } -type BitwiseOperators struct{ defaultEnv } -type WeightString struct{ defaultEnv } -type 
FloatFormatting struct{ defaultEnv } -type UnderscoreAndPercentage struct{ defaultEnv } -type Types struct{ defaultEnv } -type HexArithmetic struct{ defaultEnv } -type NumericTypes struct{ defaultEnv } -type NegateArithmetic struct{ defaultEnv } -type CollationOperations struct{ defaultEnv } -type LikeComparison struct{ defaultEnv } -type MultiComparisons struct{ defaultEnv } -type IsStatement struct{ defaultEnv } -type TupleComparisons struct{ defaultEnv } -type Comparisons struct{ defaultEnv } -type JSONExtract struct{} -type FnLower struct{ defaultEnv } -type FnUpper struct{ defaultEnv } -type FnCharLength struct{ defaultEnv } -type FnLength struct{ defaultEnv } -type FnBitLength struct{ defaultEnv } -type FnAscii struct{ defaultEnv } -type FnRepeat struct{ defaultEnv } -type FnHex struct{ defaultEnv } -type InStatement struct{ defaultEnv } - var Cases = []TestCase{ - JSONExtract{}, - JSONPathOperations{}, - JSONArray{}, - JSONObject{}, - CharsetConversionOperators{}, - CaseExprWithPredicate{}, - Ceil{}, - Floor{}, - Abs{}, - Pi{}, - Acos{}, - Asin{}, - Atan{}, - Atan2{}, - Cos{}, - Cot{}, - Sin{}, - Tan{}, - Degrees{}, - Radians{}, - CaseExprWithValue{}, - Base64{}, - Conversion{}, - LargeDecimals{}, - LargeIntegers{}, - DecimalClamping{}, - BitwiseOperatorsUnary{}, - BitwiseOperators{}, - WeightString{}, - FloatFormatting{}, - UnderscoreAndPercentage{}, - Types{}, - HexArithmetic{}, - NumericTypes{}, - NegateArithmetic{}, - CollationOperations{}, - LikeComparison{}, - MultiComparisons{}, - IsStatement{}, - TupleComparisons{}, - Comparisons{}, - FnLower{}, - FnUpper{}, - FnCharLength{}, - FnLength{}, - FnBitLength{}, - FnAscii{}, - FnRepeat{}, - FnHex{}, - InStatement{}, -} - -func (JSONPathOperations) Test(yield Iterator) { + {Run: JSONExtract, Schema: JSONExtract_Schema}, + {Run: JSONPathOperations}, + {Run: JSONArray}, + {Run: JSONObject}, + {Run: CharsetConversionOperators}, + {Run: CaseExprWithPredicate}, + {Run: CaseExprWithValue}, + {Run: Base64}, + 
{Run: Conversion}, + {Run: LargeDecimals}, + {Run: LargeIntegers}, + {Run: DecimalClamping}, + {Run: BitwiseOperatorsUnary}, + {Run: BitwiseOperators}, + {Run: WeightString}, + {Run: FloatFormatting}, + {Run: UnderscoreAndPercentage}, + {Run: Types}, + {Run: HexArithmetic}, + {Run: NumericTypes}, + {Run: NegateArithmetic}, + {Run: CollationOperations}, + {Run: LikeComparison}, + {Run: MultiComparisons}, + {Run: IsStatement}, + {Run: TupleComparisons}, + {Run: Comparisons}, + {Run: InStatement}, + {Run: FnLower}, + {Run: FnUpper}, + {Run: FnCharLength}, + {Run: FnLength}, + {Run: FnBitLength}, + {Run: FnAscii}, + {Run: FnRepeat}, + {Run: FnHex}, + {Run: FnCeil}, + {Run: FnFloor}, + {Run: FnAbs}, + {Run: FnPi}, + {Run: FnAcos}, + {Run: FnAsin}, + {Run: FnAtan}, + {Run: FnAtan2}, + {Run: FnCos}, + {Run: FnCot}, + {Run: FnSin}, + {Run: FnTan}, + {Run: FnDegrees}, + {Run: FnRadians}, +} + +func JSONPathOperations(yield Query) { for _, obj := range inputJSONObjects { yield(fmt.Sprintf("JSON_KEYS('%s')", obj), nil) @@ -165,7 +99,7 @@ func (JSONPathOperations) Test(yield Iterator) { } } -func (JSONArray) Test(yield Iterator) { +func JSONArray(yield Query) { for _, a := range inputJSONPrimitives { yield(fmt.Sprintf("JSON_ARRAY(%s)", a), nil) for _, b := range inputJSONPrimitives { @@ -175,7 +109,7 @@ func (JSONArray) Test(yield Iterator) { yield("JSON_ARRAY()", nil) } -func (JSONObject) Test(yield Iterator) { +func JSONObject(yield Query) { for _, a := range inputJSONPrimitives { for _, b := range inputJSONPrimitives { yield(fmt.Sprintf("JSON_OBJECT(%s, %s)", a, b), nil) @@ -184,7 +118,7 @@ func (JSONObject) Test(yield Iterator) { yield("JSON_OBJECT()", nil) } -func (CharsetConversionOperators) Test(yield Iterator) { +func CharsetConversionOperators(yield Query) { var introducers = []string{ "", "_latin1", "_utf8mb4", "_utf8", "_binary", } @@ -204,7 +138,7 @@ func (CharsetConversionOperators) Test(yield Iterator) { } } -func (CaseExprWithPredicate) Test(yield Iterator) { 
+func CaseExprWithPredicate(yield Query) { var predicates = []string{ "true", "false", @@ -234,7 +168,7 @@ func (CaseExprWithPredicate) Test(yield Iterator) { }) } -func (Ceil) Test(yield Iterator) { +func FnCeil(yield Query) { var ceilInputs = []string{ "0", "1", @@ -259,7 +193,7 @@ func (Ceil) Test(yield Iterator) { } } -func (Floor) Test(yield Iterator) { +func FnFloor(yield Query) { var floorInputs = []string{ "0", "1", @@ -282,7 +216,7 @@ func (Floor) Test(yield Iterator) { } } -func (Abs) Test(yield Iterator) { +func FnAbs(yield Query) { var absInputs = []string{ "0", "1", @@ -305,11 +239,11 @@ func (Abs) Test(yield Iterator) { } } -func (Pi) Test(yield Iterator) { +func FnPi(yield Query) { yield("PI()+0.000000000000000000", nil) } -func (Acos) Test(yield Iterator) { +func FnAcos(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("ACOS(%s)", num), nil) } @@ -319,7 +253,7 @@ func (Acos) Test(yield Iterator) { } } -func (Asin) Test(yield Iterator) { +func FnAsin(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("ASIN(%s)", num), nil) } @@ -329,7 +263,7 @@ func (Asin) Test(yield Iterator) { } } -func (Atan) Test(yield Iterator) { +func FnAtan(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("ATAN(%s)", num), nil) } @@ -339,7 +273,7 @@ func (Atan) Test(yield Iterator) { } } -func (Atan2) Test(yield Iterator) { +func FnAtan2(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) @@ -355,7 +289,7 @@ func (Atan2) Test(yield Iterator) { } } -func (Cos) Test(yield Iterator) { +func FnCos(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("COS(%s)", num), nil) } @@ -365,7 +299,7 @@ func (Cos) Test(yield Iterator) { } } -func (Cot) Test(yield Iterator) { +func FnCot(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("COT(%s)", num), nil) } @@ -375,7 +309,7 @@ func (Cot) Test(yield Iterator) { } } 
-func (Sin) Test(yield Iterator) { +func FnSin(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("SIN(%s)", num), nil) } @@ -385,7 +319,7 @@ func (Sin) Test(yield Iterator) { } } -func (Tan) Test(yield Iterator) { +func FnTan(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("TAN(%s)", num), nil) } @@ -395,7 +329,7 @@ func (Tan) Test(yield Iterator) { } } -func (Degrees) Test(yield Iterator) { +func FnDegrees(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("DEGREES(%s)", num), nil) } @@ -405,7 +339,7 @@ func (Degrees) Test(yield Iterator) { } } -func (Radians) Test(yield Iterator) { +func FnRadians(yield Query) { for _, num := range radianInputs { yield(fmt.Sprintf("RADIANS(%s)", num), nil) } @@ -415,7 +349,7 @@ func (Radians) Test(yield Iterator) { } } -func (CaseExprWithValue) Test(yield Iterator) { +func CaseExprWithValue(yield Query) { var elements []string elements = append(elements, inputBitwise...) elements = append(elements, inputComparisonElement...) 
@@ -430,7 +364,7 @@ func (CaseExprWithValue) Test(yield Iterator) { } } -func (Base64) Test(yield Iterator) { +func Base64(yield Query) { var inputs = []string{ `'bGlnaHQgdw=='`, `'bGlnaHQgd28='`, @@ -448,17 +382,9 @@ func (Base64) Test(yield Iterator) { } } -func (Conversion) Test(yield Iterator) { - var right = []string{ - "BINARY", "BINARY(1)", "BINARY(0)", "BINARY(16)", "BINARY(-1)", - "CHAR", "CHAR(1)", "CHAR(0)", "CHAR(16)", "CHAR(-1)", - "NCHAR", "NCHAR(1)", "NCHAR(0)", "NCHAR(16)", "NCHAR(-1)", - "DECIMAL", "DECIMAL(0, 4)", "DECIMAL(12, 0)", "DECIMAL(12, 4)", - "DOUBLE", "REAL", - "SIGNED", "UNSIGNED", "SIGNED INTEGER", "UNSIGNED INTEGER", "JSON", - } +func Conversion(yield Query) { for _, lhs := range inputConversions { - for _, rhs := range right { + for _, rhs := range inputConversionTypes { yield(fmt.Sprintf("CAST(%s AS %s)", lhs, rhs), nil) yield(fmt.Sprintf("CONVERT(%s, %s)", lhs, rhs), nil) yield(fmt.Sprintf("CAST(CAST(%s AS JSON) AS %s)", lhs, rhs), nil) @@ -466,7 +392,7 @@ func (Conversion) Test(yield Iterator) { } } -func (LargeDecimals) Test(yield Iterator) { +func LargeDecimals(yield Query) { var largepi = inputPi + inputPi for pos := 0; pos < len(largepi); pos++ { @@ -475,7 +401,7 @@ func (LargeDecimals) Test(yield Iterator) { } } -func (LargeIntegers) Test(yield Iterator) { +func LargeIntegers(yield Query) { var largepi = inputPi + inputPi for pos := 1; pos < len(largepi); pos++ { @@ -484,7 +410,7 @@ func (LargeIntegers) Test(yield Iterator) { } } -func (DecimalClamping) Test(yield Iterator) { +func DecimalClamping(yield Query) { for pos := 0; pos < len(inputPi); pos++ { for m := 0; m < min(len(inputPi), 67); m += 2 { for d := 0; d <= min(m, 33); d += 2 { @@ -494,7 +420,7 @@ func (DecimalClamping) Test(yield Iterator) { } } -func (BitwiseOperatorsUnary) Test(yield Iterator) { +func BitwiseOperatorsUnary(yield Query) { for _, op := range []string{"~", "BIT_COUNT"} { for _, rhs := range inputBitwise { yield(fmt.Sprintf("%s(%s)", op, rhs), nil) 
@@ -502,7 +428,7 @@ func (BitwiseOperatorsUnary) Test(yield Iterator) { } } -func (BitwiseOperators) Test(yield Iterator) { +func BitwiseOperators(yield Query) { for _, op := range []string{"&", "|", "^", "<<", ">>"} { for _, lhs := range inputBitwise { for _, rhs := range inputBitwise { @@ -512,7 +438,7 @@ func (BitwiseOperators) Test(yield Iterator) { } } -func (WeightString) Test(yield Iterator) { +func WeightString(yield Query) { var inputs = []string{ `'foobar'`, `_latin1 'foobar'`, `'foobar' as char(12)`, `'foobar' as binary(12)`, @@ -526,7 +452,7 @@ func (WeightString) Test(yield Iterator) { } } -func (FloatFormatting) Test(yield Iterator) { +func FloatFormatting(yield Query) { var floats = []string{ `18446744073709551615`, `9223372036854775807`, @@ -554,7 +480,7 @@ func (FloatFormatting) Test(yield Iterator) { } } -func (UnderscoreAndPercentage) Test(yield Iterator) { +func UnderscoreAndPercentage(yield Query) { var queries = []string{ `'pokemon' LIKE 'poke%'`, `'pokemon' LIKE 'poke\%'`, @@ -578,7 +504,7 @@ func (UnderscoreAndPercentage) Test(yield Iterator) { } } -func (Types) Test(yield Iterator) { +func Types(yield Query) { var queries = []string{ "1 > 3", "3 > 1", @@ -609,7 +535,7 @@ func (Types) Test(yield Iterator) { } } -func (HexArithmetic) Test(yield Iterator) { +func HexArithmetic(yield Query) { var cases = []string{ `0`, `1`, `1.0`, `0.0`, `1.0e0`, `0.0e0`, `X'00'`, `X'1234'`, `X'ff'`, @@ -626,7 +552,7 @@ func (HexArithmetic) Test(yield Iterator) { } } -func (NumericTypes) Test(yield Iterator) { +func NumericTypes(yield Query) { var numbers = []string{ `1234`, `-1234`, `18446744073709551614`, @@ -653,7 +579,7 @@ func (NumericTypes) Test(yield Iterator) { } } -func (NegateArithmetic) Test(yield Iterator) { +func NegateArithmetic(yield Query) { var cases = []string{ `0`, `1`, `1.0`, `0.0`, `1.0e0`, `0.0e0`, `X'00'`, `X'1234'`, `X'ff'`, @@ -671,7 +597,7 @@ func (NegateArithmetic) Test(yield Iterator) { } } -func (CollationOperations) Test(yield 
Iterator) { +func CollationOperations(yield Query) { var cases = []string{ "COLLATION('foobar')", "COLLATION(_latin1 'foobar')", @@ -685,7 +611,7 @@ func (CollationOperations) Test(yield Iterator) { } } -func (LikeComparison) Test(yield Iterator) { +func LikeComparison(yield Query) { var left = []string{ `'foobar'`, `'FOOBAR'`, `'1234'`, `1234`, @@ -708,7 +634,7 @@ func (LikeComparison) Test(yield Iterator) { } } -func (MultiComparisons) Test(yield Iterator) { +func MultiComparisons(yield Query) { var numbers = []string{ `0`, `-1`, `1`, `0.0`, `1.0`, `-1.0`, `1.0E0`, `-1.0E0`, `0.0E0`, strconv.FormatUint(math.MaxUint64, 10), @@ -752,7 +678,7 @@ func (MultiComparisons) Test(yield Iterator) { } } -func (IsStatement) Test(yield Iterator) { +func IsStatement(yield Query) { var left = []string{ "NULL", "TRUE", "FALSE", `1`, `0`, `1.0`, `0.0`, `-1`, `666`, @@ -775,7 +701,7 @@ func (IsStatement) Test(yield Iterator) { } } -func (TupleComparisons) Test(yield Iterator) { +func TupleComparisons(yield Query) { var elems = []string{"NULL", "-1", "0", "1"} var operators = []string{"=", "!=", "<=>", "<", "<=", ">", ">="} @@ -793,7 +719,7 @@ func (TupleComparisons) Test(yield Iterator) { } } -func (Comparisons) Test(yield Iterator) { +func Comparisons(yield Query) { var operators = []string{"=", "!=", "<=>", "<", "<=", ">", ">="} for _, op := range operators { for i := 0; i < len(inputComparisonElement); i++ { @@ -804,7 +730,7 @@ func (Comparisons) Test(yield Iterator) { } } -func (JSONExtract) Test(yield Iterator) { +func JSONExtract(yield Query) { var cases = []struct { Operator string Path string @@ -844,60 +770,55 @@ func (JSONExtract) Test(yield Iterator) { } } -func (JSONExtract) Environment() *evalengine.ExpressionEnv { - env := new(evalengine.ExpressionEnv) - env.DefaultCollation = collations.CollationUtf8mb4ID - env.Fields = []*querypb.Field{ - { - Name: "column0", - Type: sqltypes.TypeJSON, - ColumnType: "JSON", - }, - } - return env +var JSONExtract_Schema = 
[]*querypb.Field{ + { + Name: "column0", + Type: sqltypes.TypeJSON, + ColumnType: "JSON", + }, } -func (FnLower) Test(yield Iterator) { +func FnLower(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("LOWER(%s)", str), nil) yield(fmt.Sprintf("LCASE(%s)", str), nil) } } -func (FnUpper) Test(yield Iterator) { +func FnUpper(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("UPPER(%s)", str), nil) yield(fmt.Sprintf("UCASE(%s)", str), nil) } } -func (FnCharLength) Test(yield Iterator) { +func FnCharLength(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("CHAR_LENGTH(%s)", str), nil) yield(fmt.Sprintf("CHARACTER_LENGTH(%s)", str), nil) } } -func (FnLength) Test(yield Iterator) { +func FnLength(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("LENGTH(%s)", str), nil) yield(fmt.Sprintf("OCTET_LENGTH(%s)", str), nil) } } -func (FnBitLength) Test(yield Iterator) { +func FnBitLength(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("BIT_LENGTH(%s)", str), nil) } } -func (FnAscii) Test(yield Iterator) { +func FnAscii(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("ASCII(%s)", str), nil) } } -func (FnRepeat) Test(yield Iterator) { +func FnRepeat(yield Query) { counts := []string{"-1", "1.2", "3", "1073741825"} for _, str := range inputStrings { for _, cnt := range counts { @@ -906,7 +827,7 @@ func (FnRepeat) Test(yield Iterator) { } } -func (FnHex) Test(yield Iterator) { +func FnHex(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("hex(%s)", str), nil) } @@ -920,7 +841,7 @@ func (FnHex) Test(yield Iterator) { } } -func (InStatement) Test(yield Iterator) { +func InStatement(yield Query) { roots := append([]string(nil), inputBitwise...) roots = append(roots, inputComparisonElement...) 
diff --git a/go/vt/vtgate/evalengine/testcases/helpers.go b/go/vt/vtgate/evalengine/testcases/helpers.go index 193533fcaa8..242f803ccc8 100644 --- a/go/vt/vtgate/evalengine/testcases/helpers.go +++ b/go/vt/vtgate/evalengine/testcases/helpers.go @@ -17,11 +17,27 @@ limitations under the License. package testcases import ( + "reflect" + "runtime" "strings" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" ) +type Query func(query string, row []sqltypes.Value) +type Runner func(yield Query) +type TestCase struct { + Run Runner + Schema []*querypb.Field +} + +func (tc TestCase) Name() string { + ptr := reflect.ValueOf(tc.Run).Pointer() + name := runtime.FuncForPC(ptr).Name() + return name[strings.LastIndexByte(name, '.')+1:] +} + func perm(a []string, f func([]string)) { perm1(a, f, 0) } diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index f871502f0d7..94f818dc0f3 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -131,3 +131,12 @@ var inputStrings = []string{ // "_utf32 'AabcÃ…Ã¥'", // "_ucs2 'AabcÃ…Ã¥'", } + +var inputConversionTypes = []string{ + "BINARY", "BINARY(1)", "BINARY(0)", "BINARY(16)", "BINARY(-1)", + "CHAR", "CHAR(1)", "CHAR(0)", "CHAR(16)", "CHAR(-1)", + "NCHAR", "NCHAR(1)", "NCHAR(0)", "NCHAR(16)", "NCHAR(-1)", + "DECIMAL", "DECIMAL(0, 4)", "DECIMAL(12, 0)", "DECIMAL(12, 4)", + "DOUBLE", "REAL", + "SIGNED", "UNSIGNED", "SIGNED INTEGER", "UNSIGNED INTEGER", "JSON", +} From 07d300382fab71fc061f3c2fe56f6837f4b339c2 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 13 Mar 2023 16:32:47 +0100 Subject: [PATCH 31/42] evalengine: add missing headers Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler.go | 16 ++++++++++++++++ go/vt/vtgate/evalengine/compiler_arithmetic.go | 16 ++++++++++++++++ go/vt/vtgate/evalengine/compiler_asm.go | 16 ++++++++++++++++ go/vt/vtgate/evalengine/compiler_bit.go | 16 
++++++++++++++++ go/vt/vtgate/evalengine/compiler_compare.go | 16 ++++++++++++++++ go/vt/vtgate/evalengine/compiler_convert.go | 16 ++++++++++++++++ go/vt/vtgate/evalengine/compiler_fn.go | 16 ++++++++++++++++ go/vt/vtgate/evalengine/compiler_json.go | 16 ++++++++++++++++ go/vt/vtgate/evalengine/compiler_logical.go | 16 ++++++++++++++++ go/vt/vtgate/evalengine/compiler_test.go | 16 ++++++++++++++++ 10 files changed, 160 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index cc122155e24..24ff089c047 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import ( diff --git a/go/vt/vtgate/evalengine/compiler_arithmetic.go b/go/vt/vtgate/evalengine/compiler_arithmetic.go index 62319f1a03f..a81afacc53a 100644 --- a/go/vt/vtgate/evalengine/compiler_arithmetic.go +++ b/go/vt/vtgate/evalengine/compiler_arithmetic.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import "vitess.io/vitess/go/sqltypes" diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index b13fbed9a5c..acf4df112fa 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import ( diff --git a/go/vt/vtgate/evalengine/compiler_bit.go b/go/vt/vtgate/evalengine/compiler_bit.go index c56d4c6266b..9adb0efe242 100644 --- a/go/vt/vtgate/evalengine/compiler_bit.go +++ b/go/vt/vtgate/evalengine/compiler_bit.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import "vitess.io/vitess/go/sqltypes" diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index cacdcb66438..34274da8844 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import ( diff --git a/go/vt/vtgate/evalengine/compiler_convert.go b/go/vt/vtgate/evalengine/compiler_convert.go index 131f011cdba..7f071b85f8f 100644 --- a/go/vt/vtgate/evalengine/compiler_convert.go +++ b/go/vt/vtgate/evalengine/compiler_convert.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import ( diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 91786dc087f..39365596fb6 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import ( diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index e39065a89e6..e18bf35242f 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import ( diff --git a/go/vt/vtgate/evalengine/compiler_logical.go b/go/vt/vtgate/evalengine/compiler_logical.go index bd1d10ef974..b6f964c3fd5 100644 --- a/go/vt/vtgate/evalengine/compiler_logical.go +++ b/go/vt/vtgate/evalengine/compiler_logical.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import ( diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 6f20cac4873..dd17522398c 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + package evalengine_test import ( From 9ff967c28997e145f14a0e35d636f1ec6251fe92 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 13 Mar 2023 16:36:58 +0100 Subject: [PATCH 32/42] evalengine: sizegen Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/cached_size.go | 156 +++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 1733706123b..fe221bbe581 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -339,6 +339,66 @@ func (cached *builtinASCII) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinAbs) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinAcos) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinAsin) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinAtan) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinAtan2) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + 
size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinBitCount) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -423,6 +483,54 @@ func (cached *builtinCollation) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinCos) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinCot) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinDegrees) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinFloor) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinFromBase64) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -567,6 +675,30 @@ func (cached *builtinMultiComparison) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinPi) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr 
vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinRadians) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinRepeat) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -579,6 +711,30 @@ func (cached *builtinRepeat) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinSin) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} +func (cached *builtinTan) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinToBase64) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) From 26871f527985475a16e38bc41e0210d4c0c02d83 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 13 Mar 2023 16:45:59 +0100 Subject: [PATCH 33/42] slices2: update licensing Signed-off-by: Vicent Marti --- go/slices2/slices.go | 527 ++----------------------------------------- 1 file changed, 18 insertions(+), 509 deletions(-) diff --git a/go/slices2/slices.go b/go/slices2/slices.go index 413490a4b9e..c69acb13cd5 100644 --- a/go/slices2/slices.go +++ b/go/slices2/slices.go @@ -1,3 +1,21 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package slices2 contains generic Slice helpers; +// Some of this code is sourced from https://github.com/luraim/fun (Apache v2) package slices2 // All returns true if all elements return true for given predicate @@ -19,512 +37,3 @@ func Any[T any](s []T, fn func(T) bool) bool { } return false } - -// AppendToGroup adds the key, value to the given map where each key -// points to a slice of values -func AppendToGroup[M ~map[K][]V, K comparable, V any](m M, k K, v V) { - lst, ok := m[k] - if !ok { - lst = make([]V, 0) - } - lst = append(lst, v) - m[k] = lst -} - -// Associate returns a map containing key-value pairs returned by the given -// function applied to the elements of the given slice -func Associate[T, V any, K comparable](s []T, fn func(T) (K, V)) map[K]V { - ret := make(map[K]V) - for _, e := range s { - k, v := fn(e) - ret[k] = v - } - return ret -} - -// Chunked splits the slice into a slice of slices, each not exceeding given size -// The last slice might have fewer elements than the given size -func Chunked[T any](s []T, chunkSize int) [][]T { - sz := len(s) - ret := make([][]T, 0, sz/chunkSize+2) - var sub []T - for i := 0; i < sz; i++ { - if i%chunkSize == 0 { - if len(sub) > 0 { - ret = append(ret, sub) - } - sub = make([]T, 0, chunkSize) - } - sub = append(sub, s[i]) - } - if len(sub) > 0 { - ret = append(ret, sub) - } - return ret -} - -func ChunkedBy[T any](s []T, fn func(T, T) bool) [][]T { - ret := make([][]T, 0) - switch len(s) { - case 0: - return ret - case 1: - ret = append(ret, []T{s[0]}) - return ret - } - var currentSubList = 
[]T{s[0]} - for _, e := range s[1:] { - if fn(currentSubList[len(currentSubList)-1], e) { - currentSubList = append(currentSubList, e) - } else { - // save current sub list and start a new one - ret = append(ret, currentSubList) - currentSubList = []T{e} - } - } - if len(currentSubList) > 0 { - ret = append(ret, currentSubList) - } - return ret -} - -// Distinct returns a slice containing only distinct elements from the given slice -// Elements will retain their original order. -func Distinct[T comparable](s []T) []T { - m := make(map[T]bool) - ret := make([]T, 0) - for _, e := range s { - _, ok := m[e] - if ok { - continue - } - m[e] = true - ret = append(ret, e) - } - return ret -} - -// DistinctBy returns a slice containing only distinct elements from the -// given slice as distinguished by the given selector function -// Elements will retain their original order. -func DistinctBy[T any, K comparable](s []T, fn func(T) K) []T { - m := make(map[K]bool) - ret := make([]T, 0) - for _, e := range s { - k := fn(e) - _, ok := m[k] - if ok { - continue - } - m[k] = true - ret = append(ret, e) - } - return ret -} - -// Drop returns a slice containing all elements except the first n -func Drop[T any](s []T, n int) []T { - if n >= len(s) { - return make([]T, 0) - } - return s[n:] -} - -// DropLast returns a slice containing all elements except the last n -func DropLast[T any](s []T, n int) []T { - if n >= len(s) { - return make([]T, 0) - } - return s[:len(s)-n] -} - -// DropLastWhile returns a slice containing all elements except the last elements -// that satisfy the given predicate -func DropLastWhile[T any](s []T, fn func(T) bool) []T { - if len(s) == 0 { - return s - } - i := len(s) - 1 - for ; i >= 0; i-- { - if !fn(s[i]) { - break - } - } - return s[:i+1] -} - -// DropWhile returns a slice containing all elements except the first elements -// that satisfy the given predicate -func DropWhile[T any](s []T, fn func(T) bool) []T { - if len(s) == 0 { - return s - } - i 
:= 0 - for ; i < len(s); i++ { - if !fn(s[i]) { - break - } - } - return s[i:] -} - -// Filter returns the slice obtained after retaining only those elements -// in the given slice for which the given function returns true -func Filter[T any](s []T, fn func(T) bool) []T { - ret := make([]T, 0) - for _, e := range s { - if fn(e) { - ret = append(ret, e) - } - } - return ret -} - -func FilterInPlace[T any](s []T, fn func(T) bool) []T { - filtered := s[:0] - for _, e := range s { - if fn(e) { - filtered = append(filtered, e) - } - } - return filtered -} - -// FilterIndexed returns the slice obtained after retaining only those elements -// in the given slice for which the given function returns true. Predicate -// receives the value as well as its index in the slice. -func FilterIndexed[T any](s []T, fn func(int, T) bool) []T { - ret := make([]T, 0) - for i, e := range s { - if fn(i, e) { - ret = append(ret, e) - } - } - return ret -} - -// FilterMap returns the slice obtained after both filtering and mapping using -// the given function. The function should return two values - -// first, the result of the mapping operation and -// second, whether the element should be included or not. -// This is faster than doing a separate filter and map operations, -// since it avoids extra allocations and slice traversals. -func FilterMap[T1, T2 any]( - s []T1, - fn func(T1) (T2, bool), -) []T2 { - - ret := make([]T2, 0) - for _, e := range s { - m, ok := fn(e) - if ok { - ret = append(ret, m) - } - } - return ret -} - -func Find[T any](s []T, fn func(T) bool) (T, bool) { - for _, e := range s { - if fn(e) { - return e, true - } - } - return *new(T), false -} - -// Fold accumulates values starting with given initial value and applying -// given function to current accumulator and each element. 
-func Fold[T, R any](s []T, initial R, fn func(R, T) R) R { - acc := initial - for _, e := range s { - acc = fn(acc, e) - } - return acc -} - -// FoldIndexed accumulates values starting with given initial value and applying -// given function to current accumulator and each element. Function also -// receives index of current element. -func FoldIndexed[T, R any](s []T, initial R, fn func(int, R, T) R) R { - acc := initial - for i, e := range s { - acc = fn(i, acc, e) - } - return acc -} - -// FoldItems accumulates values starting with given intial value and applying -// given function to current accumulator and each key, value. -func FoldItems[M ~map[K]V, K comparable, V, R any]( - m M, - initial R, - fn func(R, K, V) R, -) R { - acc := initial - for k, v := range m { - acc = fn(acc, k, v) - } - return acc -} - -// GetOrInsert checks if a value corresponding to the given key is present -// in the map. If present it returns the existing value. If not, it invokes the -// given callback function to get a new value for the given key, inserts it in -// the map and returns the new value -func GetOrInsert[M ~map[K]V, K comparable, V any](m M, k K, fn func(K) V) V { - v, ok := m[k] - if ok { - // present, return existing value - return v - } - // not present; get value, insert in map and return the new value - v = fn(k) - m[k] = v - return v -} - -// GroupBy returns a map containing key to list of values -// returned by the given function applied to the elements of the given slice -func GroupBy[T, V any, K comparable]( - s []T, - fn func(T) (K, V), -) map[K][]V { - ret := make(map[K][]V) - for _, e := range s { - k, v := fn(e) - lst, ok := ret[k] - if !ok { - lst = make([]V, 0) - } - lst = append(lst, v) - ret[k] = lst - } - return ret -} - -// Map returns the slice obtained after applying the given function over every -// element in the given slice -func Map[T1, T2 any](s []T1, fn func(T1) T2) []T2 { - ret := make([]T2, 0, len(s)) - for _, e := range s { - ret = 
append(ret, fn(e)) - } - return ret -} - -// MapIndexed returns the slice obtained after applying the given function over every -// element in the given slice. The function also receives the index of each -// element in the slice. -func MapIndexed[T1, T2 any](s []T1, fn func(int, T1) T2) []T2 { - ret := make([]T2, 0, len(s)) - for i, e := range s { - ret = append(ret, fn(i, e)) - } - return ret -} - -// Partition returns two slices where the first slice contains elements for -// which the predicate returned true and the second slice contains elements for -// which it returned false. -func Partition[T any](s []T, fn func(T) bool) ([]T, []T) { - trueList := make([]T, 0) - falseList := make([]T, 0) - for _, e := range s { - if fn(e) { - trueList = append(trueList, e) - } else { - falseList = append(falseList, e) - } - } - return trueList, falseList -} - -// Reduce accumulates the values starting with the first element and applying the -// operation from left to right to the current accumulator value and each element -// The input slice must have at least one element. -func Reduce[T any](s []T, fn func(T, T) T) T { - if len(s) == 1 { - return s[0] - } - return Fold(s[1:], s[0], fn) -} - -// ReduceIndexed accumulates the values starting with the first element and applying the -// operation from left to right to the current accumulator value and each element -// The input slice must have at least one element. The function also receives -// the index of the element. 
-func ReduceIndexed[T any](s []T, fn func(int, T, T) T) T { - if len(s) == 1 { - return s[0] - } - acc := s[0] - for i, e := range s[1:] { - acc = fn(i+1, acc, e) - } - return acc -} - -// Reverse reverses the elements of the list in place -func Reverse[T any](s []T) { - for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { - s[i], s[j] = s[j], s[i] - } -} - -// Reversed returns a new list with the elements in reverse order -func Reversed[T any](s []T) []T { - ret := make([]T, 0, len(s)) - for i := len(s) - 1; i >= 0; i-- { - ret = append(ret, s[i]) - } - return ret -} - -// Take returns the slice obtained after taking the first n elements from the -// given slice. -// If n is greater than the length of the slice, returns the entire slice -func Take[T any](s []T, n int) []T { - if len(s) <= n { - return s - } - return s[:n] -} - -// TakeLast returns the slice obtained after taking the last n elements from the -// given slice. -func TakeLast[T any](s []T, n int) []T { - if len(s) <= n { - return s - } - return s[len(s)-n:] -} - -// TakeLastWhile returns a slice containing the last elements satisfying the given -// predicate -func TakeLastWhile[T any](s []T, fn func(T) bool) []T { - if len(s) == 0 { - return s - } - i := len(s) - 1 - for ; i >= 0; i-- { - if !fn(s[i]) { - break - } - } - return s[i+1:] -} - -// TakeWhile returns a list containing the first elements satisfying the -// given predicate -func TakeWhile[T any](s []T, fn func(T) bool) []T { - if len(s) == 0 { - return s - } - i := 0 - for ; i < len(s); i++ { - if !fn(s[i]) { - break - } - } - return s[:i] -} - -// TransformMap applies the given function to each key, value in the map, -// and returns a new map of the same type after transforming the keys -// and values depending on the callback functions return values. 
If the last -// bool return value from the callback function is false, the entry is dropped -func TransformMap[M ~map[K]V, K comparable, V any]( - m M, - fn func(k K, v V) (K, V, bool), -) M { - ret := make(map[K]V) - for k, v := range m { - newK, newV, include := fn(k, v) - if include { - ret[newK] = newV - } - } - return ret -} - -// Unzip returns two slices, where the first slice is built from the first -// values of each pair from the input slice, and the second slice is built -// from the second values of each pair -func Unzip[T1 any, T2 any](ps []*Pair[T1, T2]) ([]T1, []T2) { - l := len(ps) - s1 := make([]T1, 0, l) - s2 := make([]T2, 0, l) - for _, p := range ps { - s1 = append(s1, p.Fst) - s2 = append(s2, p.Snd) - } - return s1, s2 -} - -// Windowed returns a slice of sliding windows into the given slice of the -// given size, and with the given step -func Windowed[T any](s []T, size, step int) [][]T { - ret := make([][]T, 0) - sz := len(s) - if sz == 0 { - return ret - } - start := 0 - end := 0 - updateEnd := func() { - e := start + size - if e >= sz { - e = sz - } - end = e - } - updateStart := func() { - s := start + step - if s >= sz { - s = sz - } - start = s - } - updateEnd() - - for { - sub := make([]T, 0, end) - for i := start; i < end; i++ { - sub = append(sub, s[i]) - } - ret = append(ret, sub) - updateStart() - updateEnd() - if start == end { - break - } - } - return ret -} - -// Zip returns a slice of pairs from the elements of both slices with the same -// index. 
The returned slice has the length of the shortest input slice -func Zip[T1 any, T2 any](s1 []T1, s2 []T2) []*Pair[T1, T2] { - minLen := len(s1) - if minLen > len(s2) { - minLen = len(s2) - } - - // Allocate enough space to avoid copies and extra allocations - ret := make([]*Pair[T1, T2], 0, minLen) - - for i := 0; i < minLen; i++ { - ret = append(ret, &Pair[T1, T2]{ - Fst: s1[i], - Snd: s2[i], - }) - } - return ret -} - -// Pair represents a generic pair of two values -type Pair[T1, T2 any] struct { - Fst T1 - Snd T2 -} From 1bc5a96d913be7e744947bbb6c55ae2292a6d6ed Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Wed, 15 Mar 2023 10:16:59 +0100 Subject: [PATCH 34/42] evalengine: add documentation Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/api_types.go | 4 ++++ go/vt/vtgate/evalengine/arena.go | 19 +++++++++++++++++++ go/vt/vtgate/evalengine/vm.go | 16 ++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/go/vt/vtgate/evalengine/api_types.go b/go/vt/vtgate/evalengine/api_types.go index bd7a4313903..822391c5f74 100644 --- a/go/vt/vtgate/evalengine/api_types.go +++ b/go/vt/vtgate/evalengine/api_types.go @@ -175,6 +175,10 @@ func ToNative(v sqltypes.Value) (any, error) { return out, err } +// NormalizeValue returns a normalized form of this value that matches the output +// of the evaluation engine. This is used to mask quirks in the way MySQL sends SQL +// values over the wire, to allow comparing our implementation against MySQL's in +// integration tests. func NormalizeValue(v sqltypes.Value, coll collations.ID) sqltypes.Value { typ := v.Type() if typ == sqltypes.VarChar && coll == collations.CollationBinaryID { diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go index deb7d6b17cd..bda082f99cf 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package evalengine import ( @@ -6,6 +22,9 @@ import ( "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" ) +// Arena is an arena memory allocator for eval types. +// It allocates the types from reusable slices to prevent heap allocations. +// After each evaluation execution, (*Arena).reset() should be called to reset the arena. type Arena struct { aInt64 []evalInt64 aUint64 []evalUint64 diff --git a/go/vt/vtgate/evalengine/vm.go b/go/vt/vtgate/evalengine/vm.go index fd3e4af7a21..31aff262ca6 100644 --- a/go/vt/vtgate/evalengine/vm.go +++ b/go/vt/vtgate/evalengine/vm.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + package evalengine import ( From 5efd9f0b76f39fd8b8e7ba7e390f56853449d1e2 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 15 Mar 2023 10:15:26 +0100 Subject: [PATCH 35/42] Move normalization logic as it's test only Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/api_types.go | 26 ----------------- .../evalengine/integration/comparison_test.go | 28 +++++++++++++++++-- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/go/vt/vtgate/evalengine/api_types.go b/go/vt/vtgate/evalengine/api_types.go index 822391c5f74..bcb647caa89 100644 --- a/go/vt/vtgate/evalengine/api_types.go +++ b/go/vt/vtgate/evalengine/api_types.go @@ -17,9 +17,6 @@ limitations under the License. package evalengine import ( - "strconv" - - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" @@ -174,26 +171,3 @@ func ToNative(v sqltypes.Value) (any, error) { } return out, err } - -// NormalizeValue returns a normalized form of this value that matches the output -// of the evaluation engine. This is used to mask quirks in the way MySQL sends SQL -// values over the wire, to allow comparing our implementation against MySQL's in -// integration tests. 
-func NormalizeValue(v sqltypes.Value, coll collations.ID) sqltypes.Value { - typ := v.Type() - if typ == sqltypes.VarChar && coll == collations.CollationBinaryID { - return sqltypes.NewVarBinary(string(v.Raw())) - } - if typ == sqltypes.Float32 || typ == sqltypes.Float64 { - var bitsize = 64 - if typ == sqltypes.Float32 { - bitsize = 32 - } - f, err := strconv.ParseFloat(v.RawStr(), bitsize) - if err != nil { - panic(err) - } - return sqltypes.MakeTrusted(typ, FormatFloat(typ, f)) - } - return v -} diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index 8104746b3b0..450eb3691b0 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -18,6 +18,7 @@ package integration import ( "fmt" + "strconv" "strings" "testing" "time" @@ -58,6 +59,29 @@ func init() { servenv.OnParse(registerFlags) } +// normalizeValue returns a normalized form of this value that matches the output +// of the evaluation engine. This is used to mask quirks in the way MySQL sends SQL +// values over the wire, to allow comparing our implementation against MySQL's in +// integration tests. 
+func normalizeValue(v sqltypes.Value, coll collations.ID) sqltypes.Value { + typ := v.Type() + if typ == sqltypes.VarChar && coll == collations.CollationBinaryID { + return sqltypes.NewVarBinary(string(v.Raw())) + } + if typ == sqltypes.Float32 || typ == sqltypes.Float64 { + var bitsize = 64 + if typ == sqltypes.Float32 { + bitsize = 32 + } + f, err := strconv.ParseFloat(v.RawStr(), bitsize) + if err != nil { + panic(err) + } + return sqltypes.MakeTrusted(typ, evalengine.FormatFloat(typ, f)) + } + return v +} + func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mysql.Conn, expr string) { t.Helper() @@ -127,7 +151,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys } } if debugNormalize { - localVal = evalengine.NormalizeValue(v, local.Collation()) + localVal = normalizeValue(v, local.Collation()) } else { localVal = v } @@ -141,7 +165,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys } if remoteErr == nil { if debugNormalize { - remoteVal = evalengine.NormalizeValue(remote.Rows[0][0], collations.ID(remote.Fields[0].Charset)) + remoteVal = normalizeValue(remote.Rows[0][0], collations.ID(remote.Fields[0].Charset)) } else { remoteVal = remote.Rows[0][0] } From e0fca5e208b69ec01ec7d0402fa07f74ee7b4c5f Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 16 Mar 2023 11:33:38 +0100 Subject: [PATCH 36/42] evalengine/compiler: Add missing COT to compiler. 
Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_asm.go | 8 ++++++++ go/vt/vtgate/evalengine/compiler_fn.go | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index acf4df112fa..6a1a4381abb 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -1119,6 +1119,14 @@ func (asm *assembler) Fn_COS() { }, "FN COS FLOAT64(SP-1)") } +func (asm *assembler) Fn_COT() { + asm.emit(func(vm *VirtualMachine) int { + f := vm.stack[vm.sp-1].(*evalFloat) + f.f = 1.0 / math.Tan(f.f) + return 1 + }, "FN COT FLOAT64(SP-1)") +} + func (asm *assembler) Fn_SIN() { asm.emit(func(vm *VirtualMachine) int { f := vm.stack[vm.sp-1].(*evalFloat) diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 39365596fb6..85f439a9a41 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -79,6 +79,8 @@ func (c *compiler) compileFn(call callable) (ctype, error) { return c.compileFn_ATAN2(call) case *builtinCos: return c.compileFn_COS(call) + case *builtinCot: + return c.compileFn_COT(call) case *builtinSin: return c.compileFn_SIN(call) case *builtinTan: @@ -591,6 +593,25 @@ func (c *compiler) compileFn_COS(expr *builtinCos) (ctype, error) { return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil } +func (c *compiler) compileFn_COT(expr *builtinCot) (ctype, error) { + arg, err := c.compileExpr(expr.Arguments[0]) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck1(arg) + + switch { + case sqltypes.IsFloat(arg.Type): + default: + c.asm.Convert_xf(1) + } + + c.asm.Fn_COT() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil +} + func (c *compiler) compileFn_SIN(expr *builtinSin) (ctype, error) { arg, err := c.compileExpr(expr.Arguments[0]) if err != nil { From 
310d5726fa8b40ccefc975323323d8cdbb9579da Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 16 Mar 2023 11:39:30 +0100 Subject: [PATCH 37/42] evalengine/compiler: Fix flags for arithmetic operations Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_fn.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 85f439a9a41..51c75c257f2 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -403,7 +403,7 @@ func (c *compiler) compileFn_CEIL(expr *builtinCeil) (ctype, error) { skip := c.compileNullCheck1(arg) - convt := ctype{Type: arg.Type, Col: collationNumeric} + convt := ctype{Type: arg.Type, Col: collationNumeric, Flag: arg.Flag} switch { case sqltypes.IsIntegral(arg.Type): // No-op for integers. @@ -432,7 +432,7 @@ func (c *compiler) compileFn_FLOOR(expr *builtinFloor) (ctype, error) { skip := c.compileNullCheck1(arg) - convt := ctype{Type: arg.Type, Col: collationNumeric} + convt := ctype{Type: arg.Type, Col: collationNumeric, Flag: arg.Flag} switch { case sqltypes.IsIntegral(arg.Type): // No-op for integers. @@ -461,7 +461,7 @@ func (c *compiler) compileFn_ABS(expr *builtinAbs) (ctype, error) { skip := c.compileNullCheck1(arg) - convt := ctype{Type: arg.Type, Col: collationNumeric} + convt := ctype{Type: arg.Type, Col: collationNumeric, Flag: arg.Flag} switch { case sqltypes.IsUnsigned(arg.Type): // No-op if it's unsigned since that's already positive. 
@@ -504,7 +504,7 @@ func (c *compiler) compileFn_ACOS(expr *builtinAcos) (ctype, error) { c.asm.Fn_ACOS() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: flagNullable}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag | flagNullable}, nil } func (c *compiler) compileFn_ASIN(expr *builtinAsin) (ctype, error) { @@ -523,7 +523,7 @@ func (c *compiler) compileFn_ASIN(expr *builtinAsin) (ctype, error) { c.asm.Fn_ASIN() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: flagNullable}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag | flagNullable}, nil } func (c *compiler) compileFn_ATAN(expr *builtinAtan) (ctype, error) { @@ -542,7 +542,7 @@ func (c *compiler) compileFn_ATAN(expr *builtinAtan) (ctype, error) { c.asm.Fn_ATAN() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil } func (c *compiler) compileFn_ATAN2(expr *builtinAtan2) (ctype, error) { @@ -571,7 +571,7 @@ func (c *compiler) compileFn_ATAN2(expr *builtinAtan2) (ctype, error) { c.asm.Fn_ATAN2() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg1.Flag | arg2.Flag}, nil } func (c *compiler) compileFn_COS(expr *builtinCos) (ctype, error) { @@ -590,7 +590,7 @@ func (c *compiler) compileFn_COS(expr *builtinCos) (ctype, error) { c.asm.Fn_COS() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil } func (c *compiler) compileFn_COT(expr *builtinCot) (ctype, error) { @@ -609,7 +609,7 @@ func (c *compiler) compileFn_COT(expr *builtinCot) (ctype, error) { c.asm.Fn_COT() c.asm.jumpDestination(skip) - 
return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil } func (c *compiler) compileFn_SIN(expr *builtinSin) (ctype, error) { @@ -628,7 +628,7 @@ func (c *compiler) compileFn_SIN(expr *builtinSin) (ctype, error) { c.asm.Fn_SIN() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil } func (c *compiler) compileFn_TAN(expr *builtinTan) (ctype, error) { @@ -647,7 +647,7 @@ func (c *compiler) compileFn_TAN(expr *builtinTan) (ctype, error) { c.asm.Fn_TAN() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil } func (c *compiler) compileFn_DEGREES(expr *builtinDegrees) (ctype, error) { @@ -666,7 +666,7 @@ func (c *compiler) compileFn_DEGREES(expr *builtinDegrees) (ctype, error) { c.asm.Fn_DEGREES() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil } func (c *compiler) compileFn_RADIANS(expr *builtinRadians) (ctype, error) { @@ -685,7 +685,7 @@ func (c *compiler) compileFn_RADIANS(expr *builtinRadians) (ctype, error) { c.asm.Fn_RADIANS() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil } func (c *compiler) compileFn_WEIGHT_STRING(call *builtinWeightString) (ctype, error) { From ffaa7370cd7b3a8cd7a7b35559442d9ee07be987 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 16 Mar 2023 15:45:43 +0100 Subject: [PATCH 38/42] refactor comparisons Signed-off-by: Andres Taylor --- go/vt/vtgate/evalengine/compiler_compare.go | 178 +++++++++++--------- 1 file changed, 95 insertions(+), 83 
deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index 34274da8844..e8d63f05682 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -84,93 +84,12 @@ func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { switch { case compareAsStrings(lt.Type, rt.Type): - var merged collations.TypedCollation - var coerceLeft collations.Coercion - var coerceRight collations.Coercion - var env = collations.Local() - - if lt.Col.Collation != rt.Col.Collation { - merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ - ConvertToSuperset: true, - ConvertWithCoercion: true, - }) - } else { - merged = lt.Col - } - if err != nil { + if err := c.compareAsStrings(lt, rt, err); err != nil { return ctype{}, err } - if coerceLeft == nil && coerceRight == nil { - c.asm.CmpString_collate(merged.Collation.Get()) - } else { - if coerceLeft == nil { - coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } - } - if coerceRight == nil { - coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } - } - c.asm.CmpString_coerce(&compiledCoercion{ - col: merged.Collation.Get(), - left: coerceLeft, - right: coerceRight, - }) - } - case compareAsSameNumericType(lt.Type, rt.Type) || compareAsDecimal(lt.Type, rt.Type): - switch lt.Type { - case sqltypes.Int64: - switch rt.Type { - case sqltypes.Int64: - c.asm.CmpNum_ii() - case sqltypes.Uint64: - c.asm.CmpNum_iu(2, 1) - case sqltypes.Float64: - c.asm.CmpNum_if(2, 1) - case sqltypes.Decimal: - c.asm.CmpNum_id(2, 1) - } - case sqltypes.Uint64: - switch rt.Type { - case sqltypes.Int64: - c.asm.CmpNum_iu(1, 2) - swapped = true - case sqltypes.Uint64: - c.asm.CmpNum_uu() - case sqltypes.Float64: - c.asm.CmpNum_uf(2, 1) - case sqltypes.Decimal: - c.asm.CmpNum_ud(2, 1) - } - case sqltypes.Float64: - switch rt.Type { - case sqltypes.Int64: - 
c.asm.CmpNum_if(1, 2) - swapped = true - case sqltypes.Uint64: - c.asm.CmpNum_uf(1, 2) - swapped = true - case sqltypes.Float64: - c.asm.CmpNum_ff() - case sqltypes.Decimal: - c.asm.CmpNum_fd(2, 1) - } - - case sqltypes.Decimal: - switch rt.Type { - case sqltypes.Int64: - c.asm.CmpNum_id(1, 2) - swapped = true - case sqltypes.Uint64: - c.asm.CmpNum_ud(1, 2) - swapped = true - case sqltypes.Float64: - c.asm.CmpNum_fd(1, 2) - swapped = true - case sqltypes.Decimal: - c.asm.CmpNum_dd() - } - } + swapped = c.compareNumericTypes(lt, rt) case compareAsDates(lt.Type, rt.Type) || compareAsDateAndString(lt.Type, rt.Type) || compareAsDateAndNumeric(lt.Type, rt.Type): return ctype{}, c.unsupported(expr) @@ -225,6 +144,99 @@ func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { return cmptype, nil } +func (c *compiler) compareNumericTypes(lt ctype, rt ctype) (swapped bool) { + switch lt.Type { + case sqltypes.Int64: + switch rt.Type { + case sqltypes.Int64: + c.asm.CmpNum_ii() + case sqltypes.Uint64: + c.asm.CmpNum_iu(2, 1) + case sqltypes.Float64: + c.asm.CmpNum_if(2, 1) + case sqltypes.Decimal: + c.asm.CmpNum_id(2, 1) + } + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.asm.CmpNum_iu(1, 2) + swapped = true + case sqltypes.Uint64: + c.asm.CmpNum_uu() + case sqltypes.Float64: + c.asm.CmpNum_uf(2, 1) + case sqltypes.Decimal: + c.asm.CmpNum_ud(2, 1) + } + case sqltypes.Float64: + switch rt.Type { + case sqltypes.Int64: + c.asm.CmpNum_if(1, 2) + swapped = true + case sqltypes.Uint64: + c.asm.CmpNum_uf(1, 2) + swapped = true + case sqltypes.Float64: + c.asm.CmpNum_ff() + case sqltypes.Decimal: + c.asm.CmpNum_fd(2, 1) + } + + case sqltypes.Decimal: + switch rt.Type { + case sqltypes.Int64: + c.asm.CmpNum_id(1, 2) + swapped = true + case sqltypes.Uint64: + c.asm.CmpNum_ud(1, 2) + swapped = true + case sqltypes.Float64: + c.asm.CmpNum_fd(1, 2) + swapped = true + case sqltypes.Decimal: + c.asm.CmpNum_dd() + } + } + return +} + +func (c 
*compiler) compareAsStrings(lt ctype, rt ctype, err error) error { + var merged collations.TypedCollation + var coerceLeft collations.Coercion + var coerceRight collations.Coercion + var env = collations.Local() + + if lt.Col.Collation != rt.Col.Collation { + merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) + } else { + merged = lt.Col + } + if err != nil { + return err + } + + if coerceLeft == nil && coerceRight == nil { + c.asm.CmpString_collate(merged.Collation.Get()) + } else { + if coerceLeft == nil { + coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } + } + if coerceRight == nil { + coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } + } + c.asm.CmpString_coerce(&compiledCoercion{ + col: merged.Collation.Get(), + left: coerceLeft, + right: coerceRight, + }) + } + return nil +} + func (c *compiler) compileCheckTrue(when ctype, offset int) error { switch when.Type { case sqltypes.Int64: From 96c5759b7f7515350994b6967f9e2bfe9e882116 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 16 Mar 2023 15:53:02 +0100 Subject: [PATCH 39/42] no need to pass in err to method Signed-off-by: Andres Taylor --- go/vt/vtgate/evalengine/compiler_compare.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index e8d63f05682..0cfaa7703de 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -84,7 +84,7 @@ func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { switch { case compareAsStrings(lt.Type, rt.Type): - if err := c.compareAsStrings(lt, rt, err); err != nil { + if err := c.compareAsStrings(lt, rt); err != nil { return ctype{}, err } @@ -201,11 +201,12 @@ func (c *compiler) compareNumericTypes(lt ctype, rt ctype) (swapped bool) { return 
} -func (c *compiler) compareAsStrings(lt ctype, rt ctype, err error) error { +func (c *compiler) compareAsStrings(lt ctype, rt ctype) error { var merged collations.TypedCollation var coerceLeft collations.Coercion var coerceRight collations.Coercion var env = collations.Local() + var err error if lt.Col.Collation != rt.Col.Collation { merged, coerceLeft, coerceRight, err = env.MergeCollations(lt.Col, rt.Col, collations.CoercionOptions{ From 44328b9bed2f8a30e9b683949a5211bee1633348 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Thu, 16 Mar 2023 16:10:59 +0100 Subject: [PATCH 40/42] evalengine/compiler: remove redundant constants Signed-off-by: Vicent Marti --- go/vt/vtgate/evalengine/compiler_asm.go | 201 ++++++++++++------------ go/vt/vtgate/evalengine/compiler_bit.go | 20 +-- go/vt/vtgate/evalengine/compiler_fn.go | 18 +-- 3 files changed, 114 insertions(+), 125 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 6a1a4381abb..e0971910fef 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -171,84 +171,88 @@ func (asm *assembler) BitCount_u() { }, "BIT_COUNT UINT64(SP-1)") } -func (asm *assembler) BitOp_bb(op bitwiseOp) { +func (asm *assembler) BitOp_and_bb() { asm.adjustStack(-1) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] & r.bytes[i] + } + vm.sp-- + return 1 + }, "AND BINARY(SP-2), BINARY(SP-1)") +} - switch op { - case and: - asm.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - if len(l.bytes) != len(r.bytes) { - vm.err = errBitwiseOperandsLength - return 0 - } - for i := range l.bytes { - l.bytes[i] = l.bytes[i] & r.bytes[i] - } - vm.sp-- - return 1 - }, "AND BINARY(SP-2), 
BINARY(SP-1)") - case or: - asm.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - if len(l.bytes) != len(r.bytes) { - vm.err = errBitwiseOperandsLength - return 0 - } - for i := range l.bytes { - l.bytes[i] = l.bytes[i] | r.bytes[i] - } - vm.sp-- - return 1 - }, "OR BINARY(SP-2), BINARY(SP-1)") - case xor: - asm.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalBytes) - r := vm.stack[vm.sp-1].(*evalBytes) - if len(l.bytes) != len(r.bytes) { - vm.err = errBitwiseOperandsLength - return 0 - } - for i := range l.bytes { - l.bytes[i] = l.bytes[i] ^ r.bytes[i] - } - vm.sp-- - return 1 - }, "XOR BINARY(SP-2), BINARY(SP-1)") - } +func (asm *assembler) BitOp_or_bb() { + asm.adjustStack(-1) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] | r.bytes[i] + } + vm.sp-- + return 1 + }, "OR BINARY(SP-2), BINARY(SP-1)") } -func (asm *assembler) BitOp_uu(op bitwiseOp) { +func (asm *assembler) BitOp_xor_bb() { asm.adjustStack(-1) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalBytes) + r := vm.stack[vm.sp-1].(*evalBytes) + if len(l.bytes) != len(r.bytes) { + vm.err = errBitwiseOperandsLength + return 0 + } + for i := range l.bytes { + l.bytes[i] = l.bytes[i] ^ r.bytes[i] + } + vm.sp-- + return 1 + }, "XOR BINARY(SP-2), BINARY(SP-1)") +} - switch op { - case and: - asm.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalUint64) - l.u = l.u & r.u - vm.sp-- - return 1 - }, "AND UINT64(SP-2), UINT64(SP-1)") - case or: - asm.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalUint64) - l.u = l.u | r.u - vm.sp-- - return 1 - }, "OR UINT64(SP-2), UINT64(SP-1)") - case xor: - 
asm.emit(func(vm *VirtualMachine) int { - l := vm.stack[vm.sp-2].(*evalUint64) - r := vm.stack[vm.sp-1].(*evalUint64) - l.u = l.u ^ r.u - vm.sp-- - return 1 - }, "XOR UINT64(SP-2), UINT64(SP-1)") - } +func (asm *assembler) BitOp_and_uu() { + asm.adjustStack(-1) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u & r.u + vm.sp-- + return 1 + }, "AND UINT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) BitOp_or_uu() { + asm.adjustStack(-1) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u | r.u + vm.sp-- + return 1 + }, "OR UINT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) BitOp_xor_uu() { + asm.adjustStack(-1) + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + l.u = l.u ^ r.u + vm.sp-- + return 1 + }, "XOR UINT64(SP-2), UINT64(SP-1)") } func (asm *assembler) BitShiftLeft_bu() { @@ -1365,34 +1369,35 @@ func (asm *assembler) Fn_JSON_UNQUOTE() { }, "FN JSON_UNQUOTE (SP-1)") } -func (asm *assembler) Fn_LENGTH(op lengthOp) { - switch op { - case charLen: - asm.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalBytes) +func (asm *assembler) Fn_CHAR_LENGTH() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) - if sqltypes.IsBinary(arg.SQLType()) { - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) - } else { - coll := arg.col.Collation.Get() - count := charset.Length(coll.Charset(), arg.bytes) - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) - } - return 1 - }, "FN CHAR_LENGTH VARCHAR(SP-1)") - case byteLen: - asm.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalBytes) + if sqltypes.IsBinary(arg.SQLType()) { vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) - return 1 - }, "FN LENGTH VARCHAR(SP-1)") - case bitLen: - 
asm.emit(func(vm *VirtualMachine) int { - arg := vm.stack[vm.sp-1].(*evalBytes) - vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes) * 8)) - return 1 - }, "FN BIT_LENGTH VARCHAR(SP-1)") - } + } else { + coll := arg.col.Collation.Get() + count := charset.Length(coll.Charset(), arg.bytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(count)) + } + return 1 + }, "FN CHAR_LENGTH VARCHAR(SP-1)") +} + +func (asm *assembler) Fn_LENGTH() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes))) + return 1 + }, "FN LENGTH VARCHAR(SP-1)") +} + +func (asm *assembler) Fn_BIT_LENGTH() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalBytes) + vm.stack[vm.sp-1] = vm.arena.newEvalInt64(int64(len(arg.bytes) * 8)) + return 1 + }, "FN BIT_LENGTH VARCHAR(SP-1)") } func (asm *assembler) Fn_LUCASE(upcase bool) { diff --git a/go/vt/vtgate/evalengine/compiler_bit.go b/go/vt/vtgate/evalengine/compiler_bit.go index 9adb0efe242..2bf3f0d98d3 100644 --- a/go/vt/vtgate/evalengine/compiler_bit.go +++ b/go/vt/vtgate/evalengine/compiler_bit.go @@ -21,11 +21,11 @@ import "vitess.io/vitess/go/sqltypes" func (c *compiler) compileBitwise(expr *BitwiseExpr) (ctype, error) { switch expr.Op.(type) { case *opBitAnd: - return c.compileBitwiseOp(expr.Left, expr.Right, and) + return c.compileBitwiseOp(expr.Left, expr.Right, c.asm.BitOp_and_bb, c.asm.BitOp_and_uu) case *opBitOr: - return c.compileBitwiseOp(expr.Left, expr.Right, or) + return c.compileBitwiseOp(expr.Left, expr.Right, c.asm.BitOp_or_bb, c.asm.BitOp_or_uu) case *opBitXor: - return c.compileBitwiseOp(expr.Left, expr.Right, xor) + return c.compileBitwiseOp(expr.Left, expr.Right, c.asm.BitOp_xor_bb, c.asm.BitOp_xor_uu) case *opBitShl: return c.compileBitwiseShift(expr.Left, expr.Right, -1) case *opBitShr: @@ -35,15 +35,7 @@ func (c *compiler) compileBitwise(expr *BitwiseExpr) (ctype, error) { } } -type bitwiseOp 
int - -const ( - and bitwiseOp = iota - or - xor -) - -func (c *compiler) compileBitwiseOp(left Expr, right Expr, op bitwiseOp) (ctype, error) { +func (c *compiler) compileBitwiseOp(left Expr, right Expr, asm_ins_bb, asm_ins_uu func()) (ctype, error) { lt, err := c.compileExpr(left) if err != nil { return ctype{}, err @@ -58,7 +50,7 @@ func (c *compiler) compileBitwiseOp(left Expr, right Expr, op bitwiseOp) (ctype, if lt.Type == sqltypes.VarBinary && rt.Type == sqltypes.VarBinary { if !lt.isHexOrBitLiteral() || !rt.isHexOrBitLiteral() { - c.asm.BitOp_bb(op) + asm_ins_bb() c.asm.jumpDestination(skip) return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } @@ -67,7 +59,7 @@ func (c *compiler) compileBitwiseOp(left Expr, right Expr, op bitwiseOp) (ctype, lt = c.compileToBitwiseUint64(lt, 2) rt = c.compileToBitwiseUint64(rt, 1) - c.asm.BitOp_uu(op) + asm_ins_uu() c.asm.jumpDestination(skip) return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil } diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 51c75c257f2..1b7a10e776f 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -48,11 +48,11 @@ func (c *compiler) compileFn(call callable) (ctype, error) { case *builtinChangeCase: return c.compileFn_CCASE(call) case *builtinCharLength: - return c.compileFn_LENGTH(call, charLen) + return c.compileFn_xLENGTH(call, c.asm.Fn_CHAR_LENGTH) case *builtinLength: - return c.compileFn_LENGTH(call, byteLen) + return c.compileFn_xLENGTH(call, c.asm.Fn_LENGTH) case *builtinBitLength: - return c.compileFn_LENGTH(call, bitLen) + return c.compileFn_xLENGTH(call, c.asm.Fn_BIT_LENGTH) case *builtinASCII: return c.compileFn_ASCII(call) case *builtinHex: @@ -293,15 +293,7 @@ func (c *compiler) compileFn_CCASE(call *builtinChangeCase) (ctype, error) { return ctype{Type: sqltypes.VarChar, Col: str.Col}, nil } -type lengthOp int - -const ( - charLen lengthOp = iota - byteLen - bitLen 
-) - -func (c *compiler) compileFn_LENGTH(call callable, op lengthOp) (ctype, error) { +func (c *compiler) compileFn_xLENGTH(call callable, asm_ins func()) (ctype, error) { str, err := c.compileExpr(call.callable()[0]) if err != nil { return ctype{}, err @@ -315,7 +307,7 @@ func (c *compiler) compileFn_LENGTH(call callable, op lengthOp) (ctype, error) { c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } - c.asm.Fn_LENGTH(op) + asm_ins() c.asm.jumpDestination(skip) return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil From 10ab0a1030bc05b722468bdfb86cf8ef24189a2e Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Thu, 16 Mar 2023 16:46:14 +0100 Subject: [PATCH 41/42] evalengine/compiler: de-duplicate math functions Signed-off-by: Vicent Marti --- go/vt/dbconnpool/connection_pool.go | 3 +- go/vt/vtgate/evalengine/compiler_fn.go | 202 +++---------------------- 2 files changed, 18 insertions(+), 187 deletions(-) diff --git a/go/vt/dbconnpool/connection_pool.go b/go/vt/dbconnpool/connection_pool.go index bb339862dcf..e8e4acce017 100644 --- a/go/vt/dbconnpool/connection_pool.go +++ b/go/vt/dbconnpool/connection_pool.go @@ -22,6 +22,7 @@ object to pool these DBConnections. 
package dbconnpool import ( + "context" "errors" "net" "sync" @@ -29,8 +30,6 @@ import ( "vitess.io/vitess/go/netutil" - "context" - "vitess.io/vitess/go/pools" "vitess.io/vitess/go/stats" "vitess.io/vitess/go/vt/dbconfigs" diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index 1b7a10e776f..b178d713031 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -70,25 +70,25 @@ func (c *compiler) compileFn(call callable) (ctype, error) { case *builtinPi: return c.compileFn_PI(call) case *builtinAcos: - return c.compileFn_ACOS(call) + return c.compileFn_math1(call, c.asm.Fn_ACOS, flagNullable) case *builtinAsin: - return c.compileFn_ASIN(call) + return c.compileFn_math1(call, c.asm.Fn_ASIN, flagNullable) case *builtinAtan: - return c.compileFn_ATAN(call) + return c.compileFn_math1(call, c.asm.Fn_ATAN, 0) case *builtinAtan2: return c.compileFn_ATAN2(call) case *builtinCos: - return c.compileFn_COS(call) + return c.compileFn_math1(call, c.asm.Fn_COS, 0) case *builtinCot: - return c.compileFn_COT(call) + return c.compileFn_math1(call, c.asm.Fn_COT, 0) case *builtinSin: - return c.compileFn_SIN(call) + return c.compileFn_math1(call, c.asm.Fn_SIN, 0) case *builtinTan: - return c.compileFn_TAN(call) + return c.compileFn_math1(call, c.asm.Fn_TAN, 0) case *builtinDegrees: - return c.compileFn_DEGREES(call) + return c.compileFn_math1(call, c.asm.Fn_DEGREES, 0) case *builtinRadians: - return c.compileFn_RADIANS(call) + return c.compileFn_math1(call, c.asm.Fn_RADIANS, 0) case *builtinWeightString: return c.compileFn_WEIGHT_STRING(call) default: @@ -475,66 +475,22 @@ func (c *compiler) compileFn_ABS(expr *builtinAbs) (ctype, error) { return convt, nil } -func (c *compiler) compileFn_PI(expr *builtinPi) (ctype, error) { +func (c *compiler) compileFn_PI(_ *builtinPi) (ctype, error) { c.asm.Fn_PI() return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil } -func (c *compiler) 
compileFn_ACOS(expr *builtinAcos) (ctype, error) { - arg, err := c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_ACOS() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag | flagNullable}, nil -} - -func (c *compiler) compileFn_ASIN(expr *builtinAsin) (ctype, error) { - arg, err := c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_ASIN() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag | flagNullable}, nil -} - -func (c *compiler) compileFn_ATAN(expr *builtinAtan) (ctype, error) { - arg, err := c.compileExpr(expr.Arguments[0]) +func (c *compiler) compileFn_math1(expr callable, asm_ins func(), nullable typeFlag) (ctype, error) { + arg, err := c.compileExpr(expr.callable()[0]) if err != nil { return ctype{}, err } skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_ATAN() + c.compileToFloat(arg, 1) + asm_ins() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag | nullable}, nil } func (c *compiler) compileFn_ATAN2(expr *builtinAtan2) (ctype, error) { @@ -549,137 +505,13 @@ func (c *compiler) compileFn_ATAN2(expr *builtinAtan2) (ctype, error) { } skip := c.compileNullCheck2(arg1, arg2) - - switch { - case sqltypes.IsFloat(arg1.Type) && sqltypes.IsFloat(arg2.Type): - case sqltypes.IsFloat(arg1.Type): - c.asm.Convert_xf(1) - case sqltypes.IsFloat(arg2.Type): - c.asm.Convert_xf(2) - default: - c.asm.Convert_xf(2) - 
c.asm.Convert_xf(1) - } - + c.compileToFloat(arg1, 2) + c.compileToFloat(arg2, 1) c.asm.Fn_ATAN2() c.asm.jumpDestination(skip) return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg1.Flag | arg2.Flag}, nil } -func (c *compiler) compileFn_COS(expr *builtinCos) (ctype, error) { - arg, err := c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_COS() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil -} - -func (c *compiler) compileFn_COT(expr *builtinCot) (ctype, error) { - arg, err := c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_COT() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil -} - -func (c *compiler) compileFn_SIN(expr *builtinSin) (ctype, error) { - arg, err := c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_SIN() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil -} - -func (c *compiler) compileFn_TAN(expr *builtinTan) (ctype, error) { - arg, err := c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_TAN() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil -} - -func (c *compiler) compileFn_DEGREES(expr *builtinDegrees) (ctype, error) { - arg, err := 
c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_DEGREES() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil -} - -func (c *compiler) compileFn_RADIANS(expr *builtinRadians) (ctype, error) { - arg, err := c.compileExpr(expr.Arguments[0]) - if err != nil { - return ctype{}, err - } - - skip := c.compileNullCheck1(arg) - - switch { - case sqltypes.IsFloat(arg.Type): - default: - c.asm.Convert_xf(1) - } - - c.asm.Fn_RADIANS() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg.Flag}, nil -} - func (c *compiler) compileFn_WEIGHT_STRING(call *builtinWeightString) (ctype, error) { str, err := c.compileExpr(call.String) if err != nil { From 45f64971292bf85fb294f2664003411d9d9102a2 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 16 Mar 2023 17:24:11 +0100 Subject: [PATCH 42/42] evalengine: Add additional arithmetic tests and fix edge cases This also fixes how we deal with boolean values so we don't rewrite internal booleans on the stack and update what literal booleans mean accidentally. 
Signed-off-by: Dirkjan Bussink --- .../vtgate/evalengine/api_arithmetic_test.go | 2 +- go/vt/vtgate/evalengine/arithmetic.go | 15 ++++----- go/vt/vtgate/evalengine/compiler_asm.go | 33 ++++++++----------- go/vt/vtgate/evalengine/compiler_compare.go | 10 +++--- go/vt/vtgate/evalengine/compiler_json.go | 4 +-- go/vt/vtgate/evalengine/compiler_logical.go | 2 +- go/vt/vtgate/evalengine/compiler_test.go | 4 +++ go/vt/vtgate/evalengine/eval.go | 2 ++ go/vt/vtgate/evalengine/expr_literal.go | 3 ++ .../evalengine/integration/comparison_test.go | 4 ++- .../evalengine/integration/fuzz_test.go | 14 +++++--- go/vt/vtgate/evalengine/testcases/cases.go | 19 +++++++++++ 12 files changed, 70 insertions(+), 42 deletions(-) diff --git a/go/vt/vtgate/evalengine/api_arithmetic_test.go b/go/vt/vtgate/evalengine/api_arithmetic_test.go index 9a941d5242e..cadd979ca41 100644 --- a/go/vt/vtgate/evalengine/api_arithmetic_test.go +++ b/go/vt/vtgate/evalengine/api_arithmetic_test.go @@ -350,7 +350,7 @@ func TestArithmetics(t *testing.T) { // testing for overflow of float64 v1: NewFloat64(math.MaxFloat64), v2: NewFloat64(0.5), - err: dataOutOfRangeError(math.MaxFloat64, 0.5, "BIGINT", "/").Error(), + err: dataOutOfRangeError(math.MaxFloat64, 0.5, "DOUBLE", "/").Error(), }}, }, { operator: "*", diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index 47a2880a746..5b6c0ee538c 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -17,6 +17,7 @@ limitations under the License. 
package evalengine import ( + "math" "strings" "golang.org/x/exp/constraints" @@ -193,7 +194,7 @@ func mathSub_iu(v1 int64, v2 uint64) (*evalUint64, error) { } func mathSub_iu0(v1 int64, v2 uint64) (uint64, error) { - if v1 < 0 || v1 < int64(v2) { + if v1 < 0 { return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") } return mathSub_uu0(uint64(v1), v2) @@ -218,12 +219,12 @@ func mathSub_ui(v1 uint64, v2 int64) (*evalUint64, error) { } func mathSub_ui0(v1 uint64, v2 int64) (uint64, error) { - if int64(v1) < v2 && v2 > 0 { + if v2 > 0 && v1 < uint64(v2) { return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") } // uint - (- int) = uint + int if v2 < 0 { - return mathAdd_ui0(v1, -v2) + return mathAdd_uu0(v1, uint64(-v2)) } return mathSub_uu0(v1, uint64(v2)) } @@ -237,7 +238,7 @@ func mathMul_ui0(v1 uint64, v2 int64) (uint64, error) { if v1 == 0 || v2 == 0 { return 0, nil } - if v2 < 0 || int64(v1) < 0 { + if v2 < 0 { return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") } return mathMul_uu0(v1, uint64(v2)) @@ -405,11 +406,9 @@ func mathDiv_ff(v1, v2 float64) (eval, error) { func mathDiv_ff0(v1, v2 float64) (float64, error) { result := v1 / v2 - divisorLessThanOne := v2 < 1 - resultMismatch := v2*result != v1 - if divisorLessThanOne && resultMismatch { - return 0, dataOutOfRangeError(v1, v2, "BIGINT", "/") + if math.IsInf(result, 1) || math.IsInf(result, -1) { + return 0, dataOutOfRangeError(v1, v2, "DOUBLE", "/") } return result, nil } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index e0971910fef..4e484e9996d 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -819,10 +819,21 @@ func (asm *assembler) Convert_iu(offset int) { }, "CONV INT64(SP-%d), UINT64", offset) } -func (asm *assembler) Convert_nj(offset int) { +func (asm *assembler) Convert_nj(offset int, isBool bool) { asm.emit(func(vm *VirtualMachine) int { arg := 
vm.stack[vm.sp-offset].(evalNumeric) - vm.stack[vm.sp-offset] = evalConvert_nj(arg) + if intArg, ok := arg.(*evalInt64); isBool && ok { + switch intArg.i { + case 0: + vm.stack[vm.sp-offset] = json.ValueFalse + case 1: + vm.stack[vm.sp-offset] = json.ValueTrue + default: + vm.stack[vm.sp-offset] = json.NewNumber(intArg.ToRawBytes()) + } + } else { + vm.stack[vm.sp-offset] = json.NewNumber(arg.ToRawBytes()) + } return 1 }, "CONV numeric(SP-%d), JSON") } @@ -1989,24 +2000,6 @@ func (asm *assembler) PushColumn_u(offset int) { func (asm *assembler) PushLiteral(lit eval) error { asm.adjustStack(1) - if lit == evalBoolTrue { - asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = evalBoolTrue - vm.sp++ - return 1 - }, "PUSH true") - return nil - } - - if lit == evalBoolFalse { - asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = evalBoolFalse - vm.sp++ - return 1 - }, "PUSH false") - return nil - } - switch lit := lit.(type) { case *evalInt64: asm.emit(func(vm *VirtualMachine) int { diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index 0cfaa7703de..15d28370ef5 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -28,7 +28,7 @@ func (c *compiler) compileComparisonTuple(expr *ComparisonExpr) (ctype, error) { switch expr.Op.(type) { case compareNullSafeEQ: c.asm.CmpTupleNullsafe() - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil case compareEQ: c.asm.CmpTuple(true) c.asm.Cmp_eq_n() @@ -50,7 +50,7 @@ func (c *compiler) compileComparisonTuple(expr *ComparisonExpr) (ctype, error) { default: panic("invalid comparison operator") } - return ctype{Type: sqltypes.Int64, Flag: flagNullable, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil } func (c *compiler) 
compileComparison(expr *ComparisonExpr) (ctype, error) { @@ -100,7 +100,7 @@ func (c *compiler) compileComparison(expr *ComparisonExpr) (ctype, error) { c.asm.CmpNum_ff() } - cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric} + cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean} switch expr.Op.(type) { case compareEQ: @@ -323,7 +323,7 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil } func (c *compiler) compileInTable(lhs ctype, rhs TupleExpr) map[vthash.Hash]struct{} { @@ -374,5 +374,5 @@ func (c *compiler) compileIn(expr *InExpr) (ctype, error) { } c.asm.In_slow(expr.Negate) } - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil } diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index e18bf35242f..f6232394849 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -53,7 +53,7 @@ func (c *compiler) compileToJSON(doct ctype, offset int) (ctype, error) { case sqltypes.Float64: c.asm.Convert_fj(offset) case sqltypes.Int64, sqltypes.Uint64, sqltypes.Decimal: - c.asm.Convert_nj(offset) + c.asm.Convert_nj(offset, doct.Flag&flagIsBoolean != 0) case sqltypes.VarChar: c.asm.Convert_cj(offset) case sqltypes.VarBinary: @@ -224,7 +224,7 @@ func (c *compiler) compileFn_JSON_CONTAINS_PATH(call *builtinJSONContainsPath) ( } c.asm.Fn_JSON_CONTAINS_PATH(match, paths) - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil } func (c *compiler) compileFn_JSON_KEYS(call *builtinJSONKeys) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/compiler_logical.go 
b/go/vt/vtgate/evalengine/compiler_logical.go index b6f964c3fd5..e3ec25754fe 100644 --- a/go/vt/vtgate/evalengine/compiler_logical.go +++ b/go/vt/vtgate/evalengine/compiler_logical.go @@ -27,7 +27,7 @@ func (c *compiler) compileIs(is *IsExpr) (ctype, error) { return ctype{}, err } c.asm.Is(is.Check) - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil } func (c *compiler) compileCase(cs *CaseExpr) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index dd17522398c..ac03d186607 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -251,6 +251,10 @@ func TestCompiler(t *testing.T) { expression: `CAST(CAST(true AS JSON) AS BINARY)`, result: `BLOB("true")`, }, + { + expression: `JSON_ARRAY(true, 1.0)`, + result: `JSON("[true, 1.0]")`, + }, } for _, tc := range testCases { diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index d9169444d3b..0c740cb3efc 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -36,6 +36,8 @@ const ( flagNull typeFlag = 1 << 0 // flagNullable marks that this value CAN be null flagNullable typeFlag = 1 << 1 + // flagIsBoolean marks that this value should be interpreted as boolean + flagIsBoolean typeFlag = 1 << 2 // flagIntegerUdf marks that this value is math.MinInt64, and will underflow if negated flagIntegerUdf typeFlag = 1 << 5 diff --git a/go/vt/vtgate/evalengine/expr_literal.go b/go/vt/vtgate/evalengine/expr_literal.go index 7a6867704c6..c469a35c39d 100644 --- a/go/vt/vtgate/evalengine/expr_literal.go +++ b/go/vt/vtgate/evalengine/expr_literal.go @@ -52,6 +52,9 @@ func (l *Literal) typeof(*ExpressionEnv) (sqltypes.Type, typeFlag) { if e.i == math.MinInt64 { f |= flagIntegerUdf } + if e == evalBoolTrue || e == evalBoolFalse { + f |= flagIsBoolean + } case *evalUint64: if 
e.hexLiteral { f |= flagHex diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index 450eb3691b0..c101d0aa719 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -141,6 +141,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys var localVal, remoteVal sqltypes.Value var localCollation, remoteCollation collations.ID + var decimals uint32 if localErr == nil { v := local.Value() if debugCheckCollations { @@ -166,6 +167,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys if remoteErr == nil { if debugNormalize { remoteVal = normalizeValue(remote.Rows[0][0], collations.ID(remote.Fields[0].Charset)) + decimals = remote.Fields[0].Decimals } else { remoteVal = remote.Rows[0][0] } @@ -179,7 +181,7 @@ func compareRemoteExprEnv(t *testing.T, env *evalengine.ExpressionEnv, conn *mys } } - if diff := compareResult(localErr, remoteErr, localVal, remoteVal, localCollation, remoteCollation); diff != "" { + if diff := compareResult(localErr, remoteErr, localVal, remoteVal, localCollation, remoteCollation, decimals); diff != "" { t.Errorf("%s\nquery: %s (SIMPLIFY=%v)\nrow: %v", diff, localQuery, debugSimplify, env.Row) } else if debugPrintAll { t.Logf("local=%s mysql=%s\nquery: %s\nrow: %v", localVal.String(), remoteVal.String(), localQuery, env.Row) diff --git a/go/vt/vtgate/evalengine/integration/fuzz_test.go b/go/vt/vtgate/evalengine/integration/fuzz_test.go index 59d6f345dc3..4b5a92abe3e 100644 --- a/go/vt/vtgate/evalengine/integration/fuzz_test.go +++ b/go/vt/vtgate/evalengine/integration/fuzz_test.go @@ -296,7 +296,13 @@ type mismatch struct { const tolerance = 1e-14 -func close(a, b float64) bool { +func close(a, b float64, decimals uint32) bool { + if decimals > 0 { + ratio := math.Pow(10, float64(decimals)) + a = math.Round(a*ratio) / ratio + b = 
math.Round(b*ratio) / ratio + } + if a == b { return true } @@ -306,7 +312,7 @@ func close(a, b float64) bool { return math.Abs((a-b)/b) < tolerance } -func compareResult(localErr, remoteErr error, localVal, remoteVal sqltypes.Value, localCollation, remoteCollation collations.ID) string { +func compareResult(localErr, remoteErr error, localVal, remoteVal sqltypes.Value, localCollation, remoteCollation collations.ID, decimals uint32) string { if localErr != nil { if remoteErr == nil { return fmt.Sprintf("%v; mysql response: %s", localErr, remoteVal) @@ -344,7 +350,7 @@ func compareResult(localErr, remoteErr error, localVal, remoteVal sqltypes.Value if err != nil { return fmt.Sprintf("error converting remote value to float: %v", err) } - if !close(localFloat, remoteFloat) { + if !close(localFloat, remoteFloat, decimals) { return fmt.Sprintf("different results: %s; mysql response: %s (local collation: %s; mysql collation: %s)", localVal.String(), remoteVal.String(), localCollationName, remoteCollationName) } @@ -362,5 +368,5 @@ func compareResult(localErr, remoteErr error, localVal, remoteVal sqltypes.Value } func (cr *mismatch) Error() string { - return compareResult(cr.localErr, cr.remoteErr, cr.localVal, cr.remoteVal, collations.Unknown, collations.Unknown) + return compareResult(cr.localErr, cr.remoteErr, cr.localVal, cr.remoteVal, collations.Unknown, collations.Unknown, 0) } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index e1db6b0a7d0..d93cb3b837b 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -46,6 +46,7 @@ var Cases = []TestCase{ {Run: FloatFormatting}, {Run: UnderscoreAndPercentage}, {Run: Types}, + {Run: Arithmetic}, {Run: HexArithmetic}, {Run: NumericTypes}, {Run: NegateArithmetic}, @@ -535,6 +536,24 @@ func Types(yield Query) { } } +func Arithmetic(yield Query) { + operators := []string{"+", "-", "*", "/"} + + for _, op := range operators { 
+ for _, lhs := range inputConversions { + for _, rhs := range inputConversions { + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + } + } + + for _, lhs := range inputBitwise { + for _, rhs := range inputBitwise { + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + } + } + } +} + func HexArithmetic(yield Query) { var cases = []string{ `0`, `1`, `1.0`, `0.0`, `1.0e0`, `0.0e0`,