From 50ffd0f50aac3f500116fa273e835deb8f7f813e Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Tue, 28 Jun 2016 20:56:42 +0000 Subject: [PATCH] sql: fix and enhance expression normalization. This patch is needed so that the expressions in the Sqllogictest do not cause the server to panic. Summary: - Bug fixes: - an invalid cast was attempted for some expressions; is now fixed. - NULL caused invalid transformations, NULL is now normalized on a fast path. - transformations that rebalance arithmetic across a comparison now first check that the proposed replacement arithmetic expression has a valid overload w.r.t typing. - Improvements: more simplifications. --- sql/parser/col_types.go | 29 +++ sql/parser/expr.go | 43 ++++ sql/parser/normalize.go | 411 ++++++++++++++++++++++++++++------- sql/parser/normalize_test.go | 32 +++ sql/parser/type_check.go | 17 ++ 5 files changed, 456 insertions(+), 76 deletions(-) diff --git a/sql/parser/col_types.go b/sql/parser/col_types.go index 9c9d6fe71d7e..d54750e73999 100644 --- a/sql/parser/col_types.go +++ b/sql/parser/col_types.go @@ -19,6 +19,8 @@ package parser import ( "bytes" "fmt" + + "github.com/pkg/errors" ) // ColumnType represents a type in a column definition. @@ -247,3 +249,30 @@ func (node *TimestampTZColType) String() string { return AsString(node) } func (node *IntervalColType) String() string { return AsString(node) } func (node *StringColType) String() string { return AsString(node) } func (node *BytesColType) String() string { return AsString(node) } + +// DatumTypeToColumnType produces a SQL column type equivalent to the +// given Datum type. Used to generate CastExpr nodes during +// normalization. +func DatumTypeToColumnType(d Datum) (ColumnType, error) { + switch d.(type) { + case *DInt: + return &IntColType{"INT", 0}, nil + case *DFloat: + return &FloatColType{"FLOAT", 0}, nil + case *DDecimal: + return &DecimalColType{"DECIMAL", 0, 0}, nil + case *DTimestamp: + return &TimestampColType{}, nil + case *DTimestampTZ: + return &TimestampTZColType{}, nil + case *DInterval: + return &IntervalColType{}, nil + case *DDate: + return &DateColType{}, nil + case *DString: + return &StringColType{"STRING", 0}, nil + case *DBytes: + return &BytesColType{"BYTES"}, nil + } + return nil, errors.Errorf("internal error: unknown Datum type %T", d) +} diff --git a/sql/parser/expr.go b/sql/parser/expr.go index 0328a1e84561..fb8cbace3aa8 100644 --- a/sql/parser/expr.go +++ b/sql/parser/expr.go @@ -456,6 +456,21 @@ type IfExpr struct { typeAnnotation } +// TypedTrueExpr returns the IfExpr's True expression as a TypedExpr. +func (node *IfExpr) TypedTrueExpr() TypedExpr { + return node.True.(TypedExpr) +} + +// TypedCondExpr returns the IfExpr's Cond expression as a TypedExpr. +func (node *IfExpr) TypedCondExpr() TypedExpr { + return node.Cond.(TypedExpr) +} + +// TypedElseExpr returns the IfExpr's Else expression as a TypedExpr. +func (node *IfExpr) TypedElseExpr() TypedExpr { + return node.Else.(TypedExpr) +} + // Format implements the NodeFormatter interface. func (node *IfExpr) Format(buf *bytes.Buffer, f FmtFlags) { buf.WriteString("IF(") @@ -492,6 +507,11 @@ type CoalesceExpr struct { typeAnnotation } +// TypedExprAt returns the expression at the specified index as a TypedExpr. +func (node *CoalesceExpr) TypedExprAt(idx int) TypedExpr { + return node.Exprs[idx].(TypedExpr) +} + // Format implements the NodeFormatter interface. func (node *CoalesceExpr) Format(buf *bytes.Buffer, f FmtFlags) { buf.WriteString(node.Name) @@ -967,6 +987,24 @@ func (node *BinaryExpr) memoizeFn() { node.fn = fn } +// newBinExprIfValidOverloads constructs a new BinaryExpr if and only +// if the pair of arguments have a valid implementation. +func newBinExprIfValidOverload(op BinaryOperator, left TypedExpr, right TypedExpr) *BinaryExpr { + leftRet, rightRet := left.ReturnType(), right.ReturnType() + fn, ok := BinOps[op].lookupImpl(leftRet, rightRet) + if ok { + expr := &BinaryExpr{ + Operator: op, + Left: left, + Right: right, + fn: fn, + } + expr.typ = fn.ReturnType + return expr + } + return nil +} + // Format implements the NodeFormatter interface. func (node *BinaryExpr) Format(buf *bytes.Buffer, f FmtFlags) { binExprFmtWithParen(buf, f, node.Left, node.Operator.String(), node.Right) @@ -1013,6 +1051,11 @@ func (node *UnaryExpr) Format(buf *bytes.Buffer, f FmtFlags) { exprFmtWithParen(buf, f, node.Expr) } +// TypedInnerExpr returns the UnaryExpr's inner expression as a TypedExpr. +func (node *UnaryExpr) TypedInnerExpr() TypedExpr { + return node.Expr.(TypedExpr) +} + // FuncExpr represents a function call. type FuncExpr struct { Name *QualifiedName diff --git a/sql/parser/normalize.go b/sql/parser/normalize.go index 446eb658ce6f..95b6bb9cd2e3 100644 --- a/sql/parser/normalize.go +++ b/sql/parser/normalize.go @@ -16,22 +16,174 @@ package parser -import "fmt" +import "github.com/pkg/errors" type normalizableExpr interface { Expr normalize(*normalizeVisitor) TypedExpr } +func (expr *CastExpr) normalize(v *normalizeVisitor) TypedExpr { + if expr.Expr == DNull { + return DNull + } + return expr +} + +func (expr *CoalesceExpr) normalize(v *normalizeVisitor) TypedExpr { + for i := range expr.Exprs { + subExpr := expr.TypedExprAt(i) + + if i == len(expr.Exprs)-1 { + return subExpr + } + + if !v.isConst(subExpr) { + exprCopy := *expr + exprCopy.Exprs = expr.Exprs[i:] + return &exprCopy + } + + val, err := subExpr.Eval(v.ctx) + if err != nil { + v.err = err + return expr + } + + if val != DNull { + return subExpr + } + } + return expr +} + +func (expr *IfExpr) normalize(v *normalizeVisitor) TypedExpr { + if v.isConst(expr.Cond) { + cond, err := expr.TypedCondExpr().Eval(v.ctx) + if err != nil { + v.err = err + return expr + } + if d, err := GetBool(cond.(Datum)); err == nil { + if d { + return expr.TypedTrueExpr() + } + return expr.TypedElseExpr() + } + return DNull + } + return expr +} + +func (expr *UnaryExpr) normalize(v *normalizeVisitor) TypedExpr { + val := expr.TypedInnerExpr() + + if val == DNull { + return val + } + + switch expr.Operator { + case UnaryPlus: + // +a -> a + return val + case UnaryMinus: + // -0 -> 0 (except for float which has negative zero) + if !val.ReturnType().TypeEqual(TypeFloat) && IsNumericZero(val) { + return val + } + switch b := val.(type) { + // -(a - b) -> (b - a) + case *BinaryExpr: + if b.Operator == Minus { + exprCopy := *b + b = &exprCopy + b.Left, b.Right = b.Right, b.Left + b.memoizeFn() + return b + } + // - (- a) -> a + case *UnaryExpr: + if b.Operator == UnaryMinus { + return b.TypedInnerExpr() + } + } + } + + return expr +} + func (expr *BinaryExpr) normalize(v *normalizeVisitor) TypedExpr { left := expr.TypedLeft() right := expr.TypedRight() + expectedType := expr.ReturnType() if left == DNull || right == DNull { return DNull } - return expr + var final TypedExpr + + switch expr.Operator { + case Plus: + if IsNumericZero(right) { + final, v.err = ReType(left, expectedType) + break + } + if IsNumericZero(left) { + final, v.err = ReType(right, expectedType) + break + } + case Minus: + if IsNumericZero(right) { + final, v.err = ReType(left, expectedType) + break + } + case Mult: + if IsNumericOne(right) { + final, v.err = ReType(left, expectedType) + break + } + if IsNumericOne(left) { + final, v.err = ReType(right, expectedType) + } + if IsNumericZero(left) || IsNumericZero(right) { + final, v.err = SameTypeZero(expectedType) + break + } + case Div, FloorDiv: + if IsNumericOne(right) { + final, v.err = ReType(left, expectedType) + break + } + if IsNumericZero(left) { + cbz, err := CanBeZeroDivider(right) + if err != nil { + final, v.err = expr, err + } + if !cbz { + final, v.err = ReType(left, expectedType) + break + } + } + case Mod: + if IsNumericOne(right) { + final, v.err = SameTypeZero(expectedType) + break + } + if IsNumericZero(left) { + cbz, err := CanBeZeroDivider(right) + if err != nil { + final, v.err = expr, err + break + } + if !cbz { + final, v.err = ReType(left, expectedType) + break + } + } + } + + return final } func (expr *AndExpr) normalize(v *normalizeVisitor) TypedExpr { @@ -68,7 +220,7 @@ func (expr *AndExpr) normalize(v *normalizeVisitor) TypedExpr { return expr } if right != DNull { - if d, err := GetBool(expr.Right.(Datum)); err == nil { + if d, err := GetBool(right.(Datum)); err == nil { if !d { return right } @@ -87,10 +239,6 @@ func (expr *AndExpr) normalize(v *normalizeVisitor) TypedExpr { func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { switch expr.Operator { case EQ, GE, GT, LE, LT: - if expr.TypedLeft() == DNull || expr.TypedRight() == DNull { - return DNull - } - // We want var nodes (VariableExpr, QualifiedName, etc) to be immediate // children of the comparison expression and not second or third // children. That is, we want trees that look like: @@ -114,6 +262,10 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { // tree or we would not have entered this code path. exprCopied := false for { + if expr.TypedLeft() == DNull || expr.TypedRight() == DNull { + return DNull + } + if v.isConst(expr.Left) { switch expr.Right.(type) { case *BinaryExpr, VariableExpr: @@ -121,6 +273,13 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { default: return expr } + + invertedOp, err := invertComparisonOp(expr.Operator) + if err != nil { + v.err = nil + return expr + } + // The left side is const and the right side is a binary expression or a // variable. Flip the comparison op so that the right side is const and // the left side is a binary expression or variable. @@ -130,11 +289,8 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { expr = &exprCopy exprCopied = true } - expr = NewTypedComparisonExpr( - invertComparisonOp(expr.Operator), - expr.TypedRight(), - expr.TypedLeft(), - ) + + expr = NewTypedComparisonExpr(invertedOp, expr.TypedRight(), expr.TypedLeft()) } else if !v.isConst(expr.Right) { return expr } @@ -156,6 +312,22 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { // [+-/] 2 -> a [-+*] // / \ / \ // a 1 2 1 + var op BinaryOperator + switch left.Operator { + case Plus: + op = Minus + case Minus: + op = Plus + case Div: + op = Mult + } + + newBinExpr := newBinExprIfValidOverload(op, + expr.TypedRight(), left.TypedRight()) + if newBinExpr == nil { + // Substitution is not possible type-wise. Nothing else to do. + break + } if !exprCopied { exprCopy := *expr @@ -163,26 +335,12 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { exprCopied = true } - leftCopy := *left - left = &leftCopy - - switch left.Operator { - case Plus: - left.Operator = Minus - case Minus: - left.Operator = Plus - case Div: - left.Operator = Mult - } - expr.Left = left.Left - left.Left = expr.Right - - left.memoizeFn() - expr.Right, v.err = left.Eval(v.ctx) + expr.Right, v.err = newBinExpr.Eval(v.ctx) if v.err != nil { return nil } + expr.memoizeFn() if !isVar(expr.Left) { // Continue as long as the left side of the comparison is not a @@ -197,30 +355,45 @@ func (expr *ComparisonExpr) normalize(v *normalizeVisitor) TypedExpr { // / \ / \ // 1 a 1 2 + op := expr.Operator + var newBinExpr *BinaryExpr + + switch left.Operator { + case Plus: + // + // (A + X) cmp B => X cmp (B - C) + // + newBinExpr = newBinExprIfValidOverload(Minus, + expr.TypedRight(), left.TypedLeft()) + case Minus: + // + // (A - X) cmp B => X cmp' (A - B) + // + newBinExpr = newBinExprIfValidOverload(Minus, + left.TypedLeft(), expr.TypedRight()) + op, v.err = invertComparisonOp(op) + if v.err != nil { + return expr + } + } + + if newBinExpr == nil { + break + } + if !exprCopied { exprCopy := *expr expr = &exprCopy exprCopied = true } - leftCopy := *left - left = &leftCopy - - // Clear the function caches; we're about to change stuff. - left.Right, expr.Right = expr.Right, left.Right - if left.Operator == Plus { - left.Operator = Minus - left.Left, left.Right = left.Right, left.Left - } else { - expr.Operator = invertComparisonOp(expr.Operator) - } - - left.memoizeFn() - expr.Left, v.err = left.Eval(v.ctx) + expr.Operator = op + expr.Left = left.Right + expr.Right, v.err = newBinExpr.Eval(v.ctx) if v.err != nil { return nil } - expr.Left, expr.Right = expr.Right, expr.Left + expr.memoizeFn() if !isVar(expr.Left) { // Continue as long as the left side of the comparison is not a @@ -307,7 +480,11 @@ func (expr *OrExpr) normalize(v *normalizeVisitor) TypedExpr { } func (expr *ParenExpr) normalize(v *normalizeVisitor) TypedExpr { - return expr.TypedInnerExpr() + newExpr := expr.TypedInnerExpr() + if normalizeable, ok := newExpr.(normalizableExpr); ok { + newExpr = normalizeable.normalize(v) + } + return newExpr } func (expr *RangeCond) normalize(v *normalizeVisitor) TypedExpr { @@ -319,16 +496,16 @@ func (expr *RangeCond) normalize(v *normalizeVisitor) TypedExpr { if expr.Not { // "a NOT BETWEEN b AND c" -> "a < b OR a > c" return NewTypedOrExpr( - NewTypedComparisonExpr(LT, left, from), - NewTypedComparisonExpr(GT, left, to), - ) + NewTypedComparisonExpr(LT, left, from).normalize(v), + NewTypedComparisonExpr(GT, left, to).normalize(v), + ).normalize(v) } // "a BETWEEN b AND c" -> "a >= b AND a <= c" return NewTypedAndExpr( - NewTypedComparisonExpr(GE, left, from), - NewTypedComparisonExpr(LE, left, to), - ) + NewTypedComparisonExpr(GE, left, from).normalize(v), + NewTypedComparisonExpr(LE, left, to).normalize(v), + ).normalize(v) } // NormalizeExpr normalizes a typed expression, simplifying where possible, @@ -364,25 +541,7 @@ func (v *normalizeVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) { return false, expr } - // Normalize expressions that know how to normalize themselves. - if normalizeable, ok := expr.(normalizableExpr); ok { - expr = normalizeable.normalize(v) - if v.err != nil { - return false, expr - } - } - switch expr.(type) { - case *CaseExpr, *IfExpr, *NullIfExpr, *CoalesceExpr: - // Conditional expressions need to be evaluated during the downward - // traversal in order to avoid evaluating sub-expressions which should - // not be evaluated due to the case/conditional. - if v.isConst(expr) { - expr, v.err = expr.(TypedExpr).Eval(v.ctx) - if v.err != nil { - return false, expr - } - } case *Subquery: // Avoid normalizing subqueries. We need the subquery to be expanded in // order to do so properly. @@ -409,7 +568,11 @@ func (v *normalizeVisitor) VisitPost(expr Expr) Expr { // Evaluate all constant expressions. if v.isConst(expr) { - expr, v.err = expr.(TypedExpr).Eval(v.ctx) + newExpr, err := expr.(TypedExpr).Eval(v.ctx) + if err != nil { + return expr + } + expr = newExpr } return expr } @@ -418,20 +581,20 @@ func (v *normalizeVisitor) isConst(expr Expr) bool { return v.isConstVisitor.run(expr) } -func invertComparisonOp(op ComparisonOperator) ComparisonOperator { +func invertComparisonOp(op ComparisonOperator) (ComparisonOperator, error) { switch op { case EQ: - return EQ + return EQ, nil case GE: - return LE + return LE, nil case GT: - return LT + return LT, nil case LE: - return GE + return GE, nil case LT: - return GT + return GT, nil default: - panic(fmt.Sprintf("unable to invert: %s", op)) + return op, errors.Errorf("internal error: unable to invert: %s", op) } } @@ -496,3 +659,99 @@ func ContainsVars(expr Expr) bool { WalkExprConst(&v, expr) return v.containsVars } + +// DecimalZero represents the constant 0 as DECIMAL. +var DecimalZero DDecimal + +// DecimalOne represents the constant 1 as DECIMAL. +var DecimalOne DDecimal + +func init() { + DecimalOne.Dec.SetUnscaled(1).SetScale(0) + DecimalZero.Dec.SetUnscaled(0).SetScale(0) +} + +// IsNumericZero returns true if the datum is a number and equal to +// zero. +func IsNumericZero(expr TypedExpr) bool { + if d, ok := expr.(Datum); ok { + switch t := d.(type) { + case *DDecimal: + return t.Dec.Cmp(&DecimalZero.Dec) == 0 + case *DFloat: + return *t == 0 + case *DInt: + return *t == 0 + } + } + return false +} + +// CanBeZeroDivider returns true if the expr may be a number and equal +// to zero. It also returns true if it is not known yet whether the +// expr is a number (e.g. it is a more complex sub-expression that has +// resisted normalization). +func CanBeZeroDivider(expr TypedExpr) (bool, error) { + if d, ok := expr.(Datum); ok { + switch t := d.(type) { + case *DDecimal: + return t.Dec.Cmp(&DecimalZero.Dec) == 0, nil + case *DFloat: + return *t == 0, nil + case *DInt: + return *t == 0, nil + } + if _, ok := d.(dividerDatum); ok { + return true, errors.Errorf("internal error: unknown dividerDatum in IsValidDivider(): %T", d) + } + // All other Datums are non-numeric and thus cannot be zero. + return false, nil + } + // Other type of expression; may evaluate to zero. + return true, nil +} + +// IsNumericOne returns true if the datum is a number and equal to +// one. +func IsNumericOne(expr TypedExpr) bool { + if d, ok := expr.(Datum); ok { + switch t := d.(type) { + case *DDecimal: + return t.Dec.Cmp(&DecimalOne.Dec) == 0 + case *DFloat: + return *t == 1.0 + case *DInt: + return *t == 1 + } + } + return false +} + +// SameTypeZero returns a datum of equivalent type with value zero. +// The argument must be a datum of a numeric type. +func SameTypeZero(d Datum) (TypedExpr, error) { + switch d.(type) { + case *DDecimal: + return &DecimalZero, nil + case *DFloat: + return NewDFloat(0.0), nil + case *DInt: + return NewDInt(0), nil + } + return nil, errors.Errorf("internal error: zero not defined for Datum type %T", d) +} + +// ReType ensures that the given numeric expression evaluates +// to the requested type, inserting a cast if necessary. +func ReType(expr TypedExpr, wantedType Datum) (TypedExpr, error) { + if expr.ReturnType().TypeEqual(wantedType) { + return expr, nil + } + reqType, err := DatumTypeToColumnType(wantedType) + if err != nil { + return nil, err + } + res := &CastExpr{Expr: expr, Type: reqType} + res.typ = wantedType + return res, nil +} diff --git a/sql/parser/normalize_test.go b/sql/parser/normalize_test.go index 833920e3ecd6..c9ca323fabbf 100644 --- a/sql/parser/normalize_test.go +++ b/sql/parser/normalize_test.go @@ -32,6 +32,34 @@ func TestNormalizeExpr(t *testing.T) { }{ {`(a)`, `a`}, {`((((a))))`, `a`}, + {`CAST(NULL AS INTEGER)`, `NULL`}, + {`+a`, `a`}, + {`-(-a)`, `a`}, + {`-+-a`, `a`}, + {`-(a-b)`, `b - a`}, + {`-0`, `0`}, + {`-NULL`, `NULL`}, + {`-1`, `-1`}, + {`a+0`, `a`}, + {`0+a`, `a`}, + {`a+(2-2)`, `a`}, + {`a-0`, `a`}, + {`a*1`, `a`}, + {`1*a`, `a`}, + {`0*a`, `0`}, + {`a+NULL`, `NULL`}, + {`a/1`, `CAST(a AS DECIMAL)`}, + {`0/a`, `0 / a`}, + {`0/1`, `0`}, + {`0%a`, `0 % a`}, + {`0%1`, `0`}, + {`1%1`, `0`}, + {`0%2`, `0`}, + {`1%2`, `1`}, + {`1%-2`, `1`}, + {`12 BETWEEN 24 AND 36`, `false`}, + {`12 BETWEEN 10 AND 20`, `true`}, + {`10 BETWEEN a AND 20`, `a <= 10`}, {`a BETWEEN b AND c`, `(a >= b) AND (a <= c)`}, {`a NOT BETWEEN b AND c`, `(a < b) OR (a > c)`}, {`a BETWEEN NULL AND c`, `NULL`}, @@ -89,6 +117,10 @@ func TestNormalizeExpr(t *testing.T) { {`(1, 2, 3) = (1, 2, 3)`, `true`}, {`(1, 2, 3) IN ((1, 2, 3), (4, 5, 6))`, `true`}, {`(1, 'one')`, `(1, 'one')`}, + {`IF((true AND a < 0), (0 + a)::decimal, 2 / (1 - 1))`, `IF(a < 0, CAST(a AS DECIMAL), 2 / 0)`}, + {`IF((true OR a < 0), (0 + a)::decimal, 2 / (1 - 1))`, `CAST(a AS DECIMAL)`}, + {`COALESCE(NULL, (NULL < 3), a = 2 - 1, d)`, `COALESCE(a = 1, d)`}, + {`COALESCE(NULL, a)`, `a`}, } for _, d := range testData { expr, err := ParseExprTraditional(d.expr) diff --git a/sql/parser/type_check.go b/sql/parser/type_check.go index dc90d0d3fd9f..7a1c1fb30893 100644 --- a/sql/parser/type_check.go +++ b/sql/parser/type_check.go @@ -49,6 +49,23 @@ var ( TypeTuple Datum = &DTuple{} ) +// dividerDatum is used during normalization to determine which Datums +// are possible valid second operands to an arithmetic divide. See +// CanBeZeroDivider() for an example. +// Ensure this is implemented appropriately if/when new numeric types +// are added. +type dividerDatum interface { + isDividerDatum() +} + +var _ dividerDatum = NewDInt(0) +var _ dividerDatum = NewDFloat(0) +var _ dividerDatum = &DDecimal{} + +func (*DInt) isDividerDatum() {} +func (*DFloat) isDividerDatum() {} +func (*DDecimal) isDividerDatum() {} + // SemaContext defines the context in which to perform semantic analysis on an // expression syntax tree. type SemaContext struct {