From b6fcfbd4b641577fb3e16fb7929474517a1a2496 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Tue, 28 Jun 2016 20:56:42 +0000 Subject: [PATCH] sql: fix and enhance expression normalization. This patch is needed so that the expressions in the Sqllogictest do not cause the server to panic. Summary: - Bug fixes: - an invalid cast was attempted for some expressions; is now fixed. - NULL caused invalid transformations, NULL is now normalized on a fast path. - transformations that rebalance arithmetic across a comparison now first check that the proposed replacement arithmetic expression has a valid overload w.r.t typing. - Improvements: more simplifications. --- sql/parser/col_types.go | 29 +++ sql/parser/expr.go | 44 ++++ sql/parser/normalize.go | 481 ++++++++++++++++++++++++++++------- sql/parser/normalize_test.go | 32 +++ sql/parser/type_check.go | 17 ++ 5 files changed, 507 insertions(+), 96 deletions(-) diff --git a/sql/parser/col_types.go b/sql/parser/col_types.go index 9c9d6fe71d7e..a5ca0d01a214 100644 --- a/sql/parser/col_types.go +++ b/sql/parser/col_types.go @@ -19,6 +19,8 @@ package parser import ( "bytes" "fmt" + + "github.com/pkg/errors" ) // ColumnType represents a type in a column definition. @@ -247,3 +249,30 @@ func (node *TimestampTZColType) String() string { return AsString(node) } func (node *IntervalColType) String() string { return AsString(node) } func (node *StringColType) String() string { return AsString(node) } func (node *BytesColType) String() string { return AsString(node) } + +// DatumTypeToColumnType produces a SQL column type equivalent to the +// given Datum type. Used to generate CastExpr nodes during +// normalization. +func DatumTypeToColumnType(d Datum) (ColumnType, error) { + switch d.(type) { + case *DInt: + return intColTypeInt, nil + case *DFloat: + return floatColTypeFloat, nil + case *DDecimal: + return decimalColTypeDecimal, nil + case *DTimestamp: + return timestampColTypeTimestamp, nil + case *DTimestampTZ: + return timestampTzColTypeTimestampWithTZ, nil + case *DInterval: + return intervalColTypeInterval, nil + case *DDate: + return dateColTypeDate, nil + case *DString: + return stringColTypeString, nil + case *DBytes: + return bytesColTypeBytes, nil + } + return nil, errors.Errorf("internal error: unknown Datum type %T", d) +} diff --git a/sql/parser/expr.go b/sql/parser/expr.go index 629abdc00911..899d030f8127 100644 --- a/sql/parser/expr.go +++ b/sql/parser/expr.go @@ -468,6 +468,21 @@ type IfExpr struct { typeAnnotation } +// TypedTrueExpr returns the IfExpr's True expression as a TypedExpr. +func (node *IfExpr) TypedTrueExpr() TypedExpr { + return node.True.(TypedExpr) +} + +// TypedCondExpr returns the IfExpr's Cond expression as a TypedExpr. +func (node *IfExpr) TypedCondExpr() TypedExpr { + return node.Cond.(TypedExpr) +} + +// TypedElseExpr returns the IfExpr's Else expression as a TypedExpr. +func (node *IfExpr) TypedElseExpr() TypedExpr { + return node.Else.(TypedExpr) +} + // Format implements the NodeFormatter interface. func (node *IfExpr) Format(buf *bytes.Buffer, f FmtFlags) { buf.WriteString("IF(") @@ -504,6 +519,11 @@ type CoalesceExpr struct { typeAnnotation } +// TypedExprAt returns the expression at the specified index as a TypedExpr. +func (node *CoalesceExpr) TypedExprAt(idx int) TypedExpr { + return node.Exprs[idx].(TypedExpr) +} + // Format implements the NodeFormatter interface. func (node *CoalesceExpr) Format(buf *bytes.Buffer, f FmtFlags) { buf.WriteString(node.Name) @@ -979,6 +999,25 @@ func (node *BinaryExpr) memoizeFn() { node.fn = fn } +// newBinExprIfValidOverload constructs a new BinaryExpr if and only +// if the pair of arguments have a valid implementation for the given +// BinaryOperator. +func newBinExprIfValidOverload(op BinaryOperator, left TypedExpr, right TypedExpr) *BinaryExpr { + leftRet, rightRet := left.ReturnType(), right.ReturnType() + fn, ok := BinOps[op].lookupImpl(leftRet, rightRet) + if ok { + expr := &BinaryExpr{ + Operator: op, + Left: left, + Right: right, + fn: fn, + } + expr.typ = fn.returnType() + return expr + } + return nil +} + // Format implements the NodeFormatter interface. func (node *BinaryExpr) Format(buf *bytes.Buffer, f FmtFlags) { binExprFmtWithParen(buf, f, node.Left, node.Operator.String(), node.Right) @@ -1025,6 +1064,11 @@ func (node *UnaryExpr) Format(buf *bytes.Buffer, f FmtFlags) { exprFmtWithParen(buf, f, node.Expr) } +// TypedInnerExpr returns the UnaryExpr's inner expression as a TypedExpr. +func (node *UnaryExpr) TypedInnerExpr() TypedExpr { + return node.Expr.(TypedExpr) +} + // FuncExpr represents a function call. type FuncExpr struct { Name *QualifiedName diff --git a/sql/parser/normalize.go b/sql/parser/normalize.go index 8ff91d66f043..a1bc6c4d7785 100644 --- a/sql/parser/normalize.go +++ b/sql/parser/normalize.go @@ -16,27 +16,194 @@ package parser -import "fmt" +import "github.com/pkg/errors" type normalizableExpr interface { Expr normalize(*normalizeVisitor) TypedExpr } +func (expr *CastExpr) normalize(v *normalizeVisitor) TypedExpr { + if expr.Expr == DNull { + return DNull + } + return expr +} + +func (expr *CoalesceExpr) normalize(v *normalizeVisitor) TypedExpr { + // This normalization checks whether COALESCE can be simplified + // based on constant expressions at the start of the COALESCE + // argument list. All known-null constant arguments are simply + // removed, and any known-nonnull constant argument before + // non-constant argument cause the entire COALESCE expression to + // collapse to that argument. + last := len(expr.Exprs) - 1 + for i := range expr.Exprs { + subExpr := expr.TypedExprAt(i) + + if i == last { + return subExpr + } + + if !v.isConst(subExpr) { + exprCopy := *expr + exprCopy.Exprs = expr.Exprs[i:] + return &exprCopy + } + + val, err := subExpr.Eval(v.ctx) + if err != nil { + v.err = err + return expr + } + + if val != DNull { + return subExpr + } + } + return expr +} + +func (expr *IfExpr) normalize(v *normalizeVisitor) TypedExpr { + if v.isConst(expr.Cond) { + cond, err := expr.TypedCondExpr().Eval(v.ctx) + if err != nil { + v.err = err + return expr + } + if d, err := GetBool(cond); err == nil { + if d { + return expr.TypedTrueExpr() + } + return expr.TypedElseExpr() + } + return DNull + } + return expr +} + +func (expr *UnaryExpr) normalize(v *normalizeVisitor) TypedExpr { + val := expr.TypedInnerExpr() + + if val == DNull { + return val + } + + switch expr.Operator { + case UnaryPlus: + // +a -> a + return val + case UnaryMinus: + // -0 -> 0 (except for float which has negative zero) + if !val.ReturnType().TypeEqual(TypeFloat) && IsNumericZero(val) { + return val + } + switch b := val.(type) { + // -(a - b) -> (b - a) + case *BinaryExpr: + if b.Operator == Minus { + newBinExpr := newBinExprIfValidOverload(Minus, + b.TypedRight(), b.TypedLeft()) + if newBinExpr != nil { + newBinExpr.memoizeFn() + b = newBinExpr + } + return b + } + // - (- a) -> a + case *UnaryExpr: + if b.Operator == UnaryMinus { + return b.TypedInnerExpr() + } + } + } + + return expr +} + func (expr *BinaryExpr) normalize(v *normalizeVisitor) TypedExpr { left := expr.TypedLeft() right := expr.TypedRight() + expectedType := expr.ReturnType() if left == DNull || right == DNull { return DNull } - return expr + var final TypedExpr + + switch expr.Operator { + case Plus: + if IsNumericZero(right) { + final, v.err = ReType(left, expectedType) + break + } + if IsNumericZero(left) { + final, v.err = ReType(right, expectedType) + break + } + case Minus: + if IsNumericZero(right) { + final, v.err = ReType(left, expectedType) + break + } + case Mult: + if IsNumericOne(right) { + final, v.err = ReType(left, expectedType) + break + } + if IsNumericOne(left) { + final, v.err = ReType(right, expectedType) + break + } + if IsNumericZero(left) || IsNumericZero(right) { + final, v.err = SameTypeZero(expectedType) + break + } + case Div, FloorDiv: + if IsNumericOne(right) { + final, v.err = ReType(left, expectedType) + break + } + if IsNumericZero(left) { + cbz, err := CanBeZeroDivider(right) + if err != nil { + v.err = err + break + } + if !cbz { + final, v.err = ReType(left, expectedType) + break + } + } + case Mod: + if IsNumericOne(right) { + final, v.err = SameTypeZero(expectedType) + break + } + if IsNumericZero(left) { + cbz, err := CanBeZeroDivider(right) + if err != nil { + v.err = err + break + } + if !cbz { + final, v.err = ReType(left, expectedType) + break + } + } + } + + if final == nil { + return expr + } + return final } func (expr *AndExpr) normalize(v *normalizeVisitor) TypedExpr { left := expr.TypedLeft() right := expr.TypedRight() + var dleft, dright Datum if left == DNull && right == DNull { return DNull @@ -44,31 +211,31 @@ func (expr *AndExpr) normalize(v *normalizeVisitor) TypedExpr { // Use short-circuit evaluation to simplify AND expressions. if v.isConst(left) { - left, v.err = left.Eval(v.ctx) + dleft, v.err = left.Eval(v.ctx) if v.err != nil { return expr } - if left != DNull { - if d, err := GetBool(left.(Datum)); err == nil { + if dleft != DNull { + if d, err := GetBool(dleft); err == nil { if !d { - return left + return dleft } return right } return DNull } return NewTypedAndExpr( - left, + dleft, right, ) } if v.isConst(right) { - right, v.err = right.Eval(v.ctx) + dright, v.err = right.Eval(v.ctx) if v.err != nil { return expr } - if right != DNull { - if d, err := GetBool(expr.Right.(Datum)); err == nil { + if dright != DNull { + if d, err := GetBool(dright); err == nil { if !d { return right } @@ -78,7 +245,7 @@ func (expr *AndExpr) normalize(v *normalizeVisitor) TypedExpr { } return NewTypedAndExpr( left, - right, + dright, ) } return expr @@ -87,10 +254,6 @@ func (expr *AndExpr) normalize(v *normalizeVisitor) TypedExpr { func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { switch expr.Operator { case EQ, GE, GT, LE, LT: - if expr.TypedLeft() == DNull || expr.TypedRight() == DNull { - return DNull - } - // We want var nodes (VariableExpr, QualifiedName, etc) to be immediate // children of the comparison expression and not second or third // children. That is, we want trees that look like: @@ -114,6 +277,10 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { // tree or we would not have entered this code path. exprCopied := false for { + if expr.TypedLeft() == DNull || expr.TypedRight() == DNull { + return DNull + } + if v.isConst(expr.Left) { switch expr.Right.(type) { case *BinaryExpr, VariableExpr: @@ -121,6 +288,13 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { default: return expr } + + invertedOp, err := invertComparisonOp(expr.Operator) + if err != nil { + v.err = nil + return expr + } + // The left side is const and the right side is a binary expression or a // variable. Flip the comparison op so that the right side is const and // the left side is a binary expression or variable. @@ -130,11 +304,8 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { expr = &exprCopy exprCopied = true } - expr = NewTypedComparisonExpr( - invertComparisonOp(expr.Operator), - expr.TypedRight(), - expr.TypedLeft(), - ) + + expr = NewTypedComparisonExpr(invertedOp, expr.TypedRight(), expr.TypedLeft()) } else if !v.isConst(expr.Right) { return expr } @@ -156,6 +327,22 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { // [+-/] 2 -> a [-+*] // / \ / \ // a 1 2 1 + var op BinaryOperator + switch left.Operator { + case Plus: + op = Minus + case Minus: + op = Plus + case Div: + op = Mult + } + + newBinExpr := newBinExprIfValidOverload(op, + expr.TypedRight(), left.TypedRight()) + if newBinExpr == nil { + // Substitution is not possible type-wise. Nothing else to do. + break + } if !exprCopied { exprCopy := *expr @@ -163,26 +350,12 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { exprCopied = true } - leftCopy := *left - left = &leftCopy - - switch left.Operator { - case Plus: - left.Operator = Minus - case Minus: - left.Operator = Plus - case Div: - left.Operator = Mult - } - expr.Left = left.Left - left.Left = expr.Right - - left.memoizeFn() - expr.Right, v.err = left.Eval(v.ctx) + expr.Right, v.err = newBinExpr.Eval(v.ctx) if v.err != nil { - return nil + return expr } + expr.memoizeFn() if !isVar(expr.Left) { // Continue as long as the left side of the comparison is not a @@ -197,30 +370,45 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { // / \ / \ // 1 a 1 2 + op := expr.Operator + var newBinExpr *BinaryExpr + + switch left.Operator { + case Plus: + // + // (A + X) cmp B => X cmp (B - C) + // + newBinExpr = newBinExprIfValidOverload(Minus, + expr.TypedRight(), left.TypedLeft()) + case Minus: + // + // (A - X) cmp B => X cmp' (A - B) + // + newBinExpr = newBinExprIfValidOverload(Minus, + left.TypedLeft(), expr.TypedRight()) + op, v.err = invertComparisonOp(op) + if v.err != nil { + return expr + } + } + + if newBinExpr == nil { + break + } + if !exprCopied { exprCopy := *expr expr = &exprCopy exprCopied = true } - leftCopy := *left - left = &leftCopy - - // Clear the function caches; we're about to change stuff. - left.Right, expr.Right = expr.Right, left.Right - if left.Operator == Plus { - left.Operator = Minus - left.Left, left.Right = left.Right, left.Left - } else { - expr.Operator = invertComparisonOp(expr.Operator) - } - - left.memoizeFn() - expr.Left, v.err = left.Eval(v.ctx) + expr.Operator = op + expr.Left = left.Right + expr.Right, v.err = newBinExpr.Eval(v.ctx) if v.err != nil { return nil } - expr.Left, expr.Right = expr.Right, expr.Left + expr.memoizeFn() if !isVar(expr.Left) { // Continue as long as the left side of the comparison is not a @@ -264,6 +452,7 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { func (expr *OrExpr) normalize(v *normalizeVisitor) TypedExpr { left := expr.TypedLeft() right := expr.TypedRight() + var dleft, dright Datum if left == DNull && right == DNull { return DNull @@ -271,31 +460,31 @@ func (expr *OrExpr) normalize(v *normalizeVisitor) TypedExpr { // Use short-circuit evaluation to simplify OR expressions. if v.isConst(left) { - left, v.err = left.Eval(v.ctx) + dleft, v.err = left.Eval(v.ctx) if v.err != nil { return expr } - if left != DNull { - if d, err := GetBool(left.(Datum)); err == nil { + if dleft != DNull { + if d, err := GetBool(dleft); err == nil { if d { - return left + return dleft } return right } return DNull } return NewTypedOrExpr( - left, + dleft, right, ) } if v.isConst(right) { - right, v.err = right.Eval(v.ctx) + dright, v.err = right.Eval(v.ctx) if v.err != nil { return expr } - if right != DNull { - if d, err := GetBool(right.(Datum)); err == nil { + if dright != DNull { + if d, err := GetBool(dright); err == nil { if d { return right } @@ -305,14 +494,21 @@ func (expr *OrExpr) normalize(v *normalizeVisitor) TypedExpr { } return NewTypedOrExpr( left, - right, + dright, ) } return expr } func (expr *ParenExpr) normalize(v *normalizeVisitor) TypedExpr { - return expr.TypedInnerExpr() + newExpr := expr.TypedInnerExpr() + if normalizeable, ok := newExpr.(normalizableExpr); ok { + newExpr = normalizeable.normalize(v) + if v.err != nil { + return expr + } + } + return newExpr } func (expr *AnnotateTypeExpr) normalize(v *normalizeVisitor) TypedExpr { @@ -329,17 +525,27 @@ func (expr *RangeCond) normalize(v *normalizeVisitor) TypedExpr { if expr.Not { // "a NOT BETWEEN b AND c" -> "a < b OR a > c" - return NewTypedOrExpr( - NewTypedComparisonExpr(LT, left, from), - NewTypedComparisonExpr(GT, left, to), - ) + newLeft := NewTypedComparisonExpr(LT, left, from).normalize(v) + if v.err != nil { + return expr + } + newRight := NewTypedComparisonExpr(GT, left, to).normalize(v) + if v.err != nil { + return expr + } + return NewTypedOrExpr(newLeft, newRight).normalize(v) } // "a BETWEEN b AND c" -> "a >= b AND a <= c" - return NewTypedAndExpr( - NewTypedComparisonExpr(GE, left, from), - NewTypedComparisonExpr(LE, left, to), - ) + newLeft := NewTypedComparisonExpr(GE, left, from).normalize(v) + if v.err != nil { + return expr + } + newRight := NewTypedComparisonExpr(LE, left, to).normalize(v) + if v.err != nil { + return expr + } + return NewTypedAndExpr(newLeft, newRight).normalize(v) } // NormalizeExpr normalizes a typed expression, simplifying where possible, @@ -375,25 +581,7 @@ func (v *normalizeVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) { return false, expr } - // Normalize expressions that know how to normalize themselves. - if normalizeable, ok := expr.(normalizableExpr); ok { - expr = normalizeable.normalize(v) - if v.err != nil { - return false, expr - } - } - switch expr.(type) { - case *CaseExpr, *IfExpr, *NullIfExpr, *CoalesceExpr: - // Conditional expressions need to be evaluated during the downward - // traversal in order to avoid evaluating sub-expressions which should - // not be evaluated due to the case/conditional. - if v.isConst(expr) { - expr, v.err = expr.(TypedExpr).Eval(v.ctx) - if v.err != nil { - return false, expr - } - } case *Subquery: // Avoid normalizing subqueries. We need the subquery to be expanded in // order to do so properly. @@ -420,7 +608,11 @@ func (v *normalizeVisitor) VisitPost(expr Expr) Expr { // Evaluate all constant expressions. if v.isConst(expr) { - expr, v.err = expr.(TypedExpr).Eval(v.ctx) + newExpr, err := expr.(TypedExpr).Eval(v.ctx) + if err != nil { + return expr + } + expr = newExpr } return expr } @@ -429,20 +621,20 @@ func (v *normalizeVisitor) isConst(expr Expr) bool { return v.isConstVisitor.run(expr) } -func invertComparisonOp(op ComparisonOperator) ComparisonOperator { +func invertComparisonOp(op ComparisonOperator) (ComparisonOperator, error) { switch op { case EQ: - return EQ + return EQ, nil case GE: - return LE + return LE, nil case GT: - return LT + return LT, nil case LE: - return GE + return GE, nil case LT: - return GT + return GT, nil default: - panic(fmt.Sprintf("unable to invert: %s", op)) + return op, errors.Errorf("internal error: unable to invert: %s", op) } } @@ -507,3 +699,100 @@ func ContainsVars(expr Expr) bool { WalkExprConst(&v, expr) return v.containsVars } + +// DecimalZero represents the constant 0 as DECIMAL. +var DecimalZero DDecimal + +// DecimalOne represents the constant 1 as DECIMAL. +var DecimalOne DDecimal + +func init() { + DecimalOne.Dec.SetUnscaled(1).SetScale(0) + DecimalZero.Dec.SetUnscaled(0).SetScale(0) +} + +// IsNumericZero returns true if the datum is a number and equal to +// zero. +func IsNumericZero(expr TypedExpr) bool { + if d, ok := expr.(Datum); ok { + switch t := d.(type) { + case *DDecimal: + return t.Dec.Cmp(&DecimalZero.Dec) == 0 + case *DFloat: + return *t == 0 + case *DInt: + return *t == 0 + } + } + return false +} + +// CanBeZeroDivider returns true if the expr may be a number and equal +// to zero. It also returns true if it is not known yet whether the +// expr is a number (e.g. it is a more complex sub-expression that has +// resisted normalization). +func CanBeZeroDivider(expr TypedExpr) (bool, error) { + if d, ok := expr.(Datum); ok { + if _, ok := d.(dividerDatum); ok { + switch t := d.(type) { + case *DDecimal: + return t.Dec.Cmp(&DecimalZero.Dec) == 0, nil + case *DFloat: + return *t == 0, nil + case *DInt: + return *t == 0, nil + default: + return true, errors.Errorf("internal error: unknown dividerDatum in CanBeZeroDivider(): %T", d) + } + } + // All other Datums are non-numeric and thus cannot be zero. + return false, nil + } + // Other type of expression; may evaluate to zero. + return true, nil +} + +// IsNumericOne returns true if the datum is a number and equal to +// one. +func IsNumericOne(expr TypedExpr) bool { + if d, ok := expr.(Datum); ok { + switch t := d.(type) { + case *DDecimal: + return t.Dec.Cmp(&DecimalOne.Dec) == 0 + case *DFloat: + return *t == 1.0 + case *DInt: + return *t == 1 + } + } + return false +} + +// SameTypeZero returns a datum of equivalent type with value zero. +// The argument must be a datum of a numeric type. +func SameTypeZero(d Datum) (TypedExpr, error) { + switch d.(type) { + case *DDecimal: + return &DecimalZero, nil + case *DFloat: + return NewDFloat(0.0), nil + case *DInt: + return NewDInt(0), nil + } + return nil, errors.Errorf("internal error: zero not defined for Datum type %T", d) +} + +// ReType ensures that the given numeric expression evaluates +// to the requested type, inserting a cast if necessary. +func ReType(expr TypedExpr, wantedType Datum) (TypedExpr, error) { + if expr.ReturnType().TypeEqual(wantedType) { + return expr, nil + } + reqType, err := DatumTypeToColumnType(wantedType) + if err != nil { + return nil, err + } + res := &CastExpr{Expr: expr, Type: reqType} + res.typ = wantedType + return res, nil +} diff --git a/sql/parser/normalize_test.go b/sql/parser/normalize_test.go index fa34b1a73d85..d9568ddb01e4 100644 --- a/sql/parser/normalize_test.go +++ b/sql/parser/normalize_test.go @@ -32,6 +32,34 @@ func TestNormalizeExpr(t *testing.T) { }{ {`(a)`, `a`}, {`((((a))))`, `a`}, + {`CAST(NULL AS INTEGER)`, `NULL`}, + {`+a`, `a`}, + {`-(-a)`, `a`}, + {`-+-a`, `a`}, + {`-(a-b)`, `b - a`}, + {`-0`, `0`}, + {`-NULL`, `NULL`}, + {`-1`, `-1`}, + {`a+0`, `a`}, + {`0+a`, `a`}, + {`a+(2-2)`, `a`}, + {`a-0`, `a`}, + {`a*1`, `a`}, + {`1*a`, `a`}, + {`0*a`, `0`}, + {`a+NULL`, `NULL`}, + {`a/1`, `CAST(a AS DECIMAL)`}, + {`0/a`, `0 / a`}, + {`0/1`, `0`}, + {`0%a`, `0 % a`}, + {`0%1`, `0`}, + {`1%1`, `0`}, + {`0%2`, `0`}, + {`1%2`, `1`}, + {`1%-2`, `1`}, + {`12 BETWEEN 24 AND 36`, `false`}, + {`12 BETWEEN 10 AND 20`, `true`}, + {`10 BETWEEN a AND 20`, `a <= 10`}, {`a BETWEEN b AND c`, `(a >= b) AND (a <= c)`}, {`a NOT BETWEEN b AND c`, `(a < b) OR (a > c)`}, {`a BETWEEN NULL AND c`, `NULL`}, @@ -100,6 +128,10 @@ func TestNormalizeExpr(t *testing.T) { {`ANNOTATE_TYPE(1, float)`, `1.0`}, // TODO(nvanbenschoten) introduce a shorthand type annotation notation. // {`1!float`, `1.0`}, + {`IF((true AND a < 0), (0 + a)::decimal, 2 / (1 - 1))`, `IF(a < 0, CAST(a AS DECIMAL), 2 / 0)`}, + {`IF((true OR a < 0), (0 + a)::decimal, 2 / (1 - 1))`, `CAST(a AS DECIMAL)`}, + {`COALESCE(NULL, (NULL < 3), a = 2 - 1, d)`, `COALESCE(a = 1, d)`}, + {`COALESCE(NULL, a)`, `a`}, } for _, d := range testData { expr, err := ParseExprTraditional(d.expr) diff --git a/sql/parser/type_check.go b/sql/parser/type_check.go index 159713c21508..4c5525510d6d 100644 --- a/sql/parser/type_check.go +++ b/sql/parser/type_check.go @@ -49,6 +49,23 @@ var ( TypeTuple Datum = &DTuple{} ) +// dividerDatum is used during normalization to determine which Datums +// are possible valid second operands to an arithmetic divide. See +// CanBeZeroDivider() for an example. +// Ensure this is implemented appropriately if/when new numeric types +// are added. +type dividerDatum interface { + dividerDatum() +} + +var _ dividerDatum = NewDInt(0) +var _ dividerDatum = NewDFloat(0) +var _ dividerDatum = &DDecimal{} + +func (*DInt) dividerDatum() {} +func (*DFloat) dividerDatum() {} +func (*DDecimal) dividerDatum() {} + // SemaContext defines the context in which to perform semantic analysis on an // expression syntax tree. type SemaContext struct {