Skip to content

Commit

Permalink
sql: fix some type / string conversion brokenness
Browse files Browse the repository at this point in the history
This ensures that e.g. `drop type "a+b"` works if there is a type
indeed named `a+b`.

Release note (bug fix): DROP TYPE and certain other statements that
work over SQL scalar types now properly support type names containing
special characters.
  • Loading branch information
knz committed Dec 2, 2020
1 parent 13ce0e8 commit d3f7d85
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 70 deletions.
2 changes: 1 addition & 1 deletion pkg/internal/sqlsmith/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ ORDER BY
currentCols = nil
}

coltyp, err := s.typeFromName(typ)
coltyp, err := s.typeFromSQLTypeSyntax(typ)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/internal/sqlsmith/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import (
"github.com/lib/pq/oid"
)

func (s *Smither) typeFromName(name string) (*types.T, error) {
typRef, err := parser.ParseType(name)
func (s *Smither) typeFromSQLTypeSyntax(typeStr string) (*types.T, error) {
typRef, err := parser.GetTypeFromValidSQLSyntax(typeStr)
if err != nil {
return nil, errors.AssertionFailedf("failed to parse type: %v", name)
return nil, errors.AssertionFailedf("failed to parse type: %v", typeStr)
}
typ, err := tree.ResolveType(context.Background(), typRef, s)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/faketreeeval/evalctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ func (ep *DummyEvalPlanner) ResolveTableName(
return 0, errors.WithStack(errEvalPlanner)
}

// ParseType is part of the tree.EvalPlanner interface.
func (ep *DummyEvalPlanner) ParseType(sql string) (*types.T, error) {
// GetTypeFromValidSQLSyntax is part of the tree.EvalPlanner interface.
func (ep *DummyEvalPlanner) GetTypeFromValidSQLSyntax(sql string) (*types.T, error) {
return nil, errors.WithStack(errEvalPlanner)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/optgen/exprgen/parse_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func ParseType(typeStr string) (*types.T, error) {
}
return types.MakeTuple(colTypes), nil
}
typ, err := parser.ParseType(typeStr)
typ, err := parser.GetTypeFromValidSQLSyntax(typeStr)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/testutils/testcat/test_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ func (ts *TableStat) Histogram() []cat.HistogramBucket {
if ts.js.HistogramColumnType == "" || ts.js.HistogramBuckets == nil {
return nil
}
colTypeRef, err := parser.ParseType(ts.js.HistogramColumnType)
colTypeRef, err := parser.GetTypeFromValidSQLSyntax(ts.js.HistogramColumnType)
if err != nil {
panic(err)
}
Expand Down
84 changes: 50 additions & 34 deletions pkg/sql/parser/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,13 @@ func HasMultipleStatements(sql string) bool {
return false
}

// ParseQualifiedTableName parses a SQL string of the form
// `[ database_name . ] [ schema_name . ] table_name`.
// ParseQualifiedTableName parses a possibly qualified table name. The
// table name must contain one or more name parts, using the full
// input SQL syntax: each name part containing special characters, or
// non-lowercase characters, must be enclosed in double quote. The
// name may not be an invalid table name (the caller is responsible
// for guaranteeing that only valid table names are provided as
// input).
func ParseQualifiedTableName(sql string) (*tree.TableName, error) {
name, err := ParseTableName(sql)
if err != nil {
Expand All @@ -285,7 +290,12 @@ func ParseQualifiedTableName(sql string) (*tree.TableName, error) {
return &tn, nil
}

// ParseTableName parses a table name.
// ParseTableName parses a table name. The table name must contain one
// or more name parts, using the full input SQL syntax: each name
// part containing special characters, or non-lowercase characters,
// must be enclosed in double quote. The name may not be an invalid
// table name (the caller is responsible for guaranteeing that only
// valid table names are provided as input).
func ParseTableName(sql string) (*tree.UnresolvedObjectName, error) {
// We wrap the name we want to parse into a dummy statement since our parser
// can only parse full statements.
Expand All @@ -300,30 +310,6 @@ func ParseTableName(sql string) (*tree.UnresolvedObjectName, error) {
return rename.Name, nil
}

// ParseTableNameWithQualifiedNames can be used to parse an input table name that
// might be prefixed with an unquoted qualified name. The standard ParseTableName
// cannot do this due to limitations with our parser. In particular, the parser
// can't parse different productions individually -- it must parse them as part
// of a top level statement. This causes qualified names that contain keywords
// to require quotes, which are not required in some cases due to Postgres
// compatibility (in particular, as arguments to pg_dump). This function gets
// around this limitation by parsing the input table name as a column name
// with a fake non-keyword prefix, and then shifting the result down into an
// UnresolvedObjectName.
func ParseTableNameWithQualifiedNames(sql string) (*tree.UnresolvedObjectName, error) {
stmt, err := ParseOne(fmt.Sprintf("SELECT fakeprefix.%s", sql))
if err != nil {
return nil, err
}
un := stmt.AST.(*tree.Select).Select.(*tree.SelectClause).Exprs[0].Expr.(*tree.UnresolvedName)
var nameParts [3]string
numParts := un.NumParts - 1
for i := 0; i < numParts; i++ {
nameParts[i] = un.Parts[i]
}
return tree.NewUnresolvedObjectName(numParts, nameParts, 0 /* annotationIdx */)
}

// parseExprsWithInt parses one or more sql expressions.
func parseExprsWithInt(exprs []string, nakedIntType *types.T) (tree.Exprs, error) {
stmt, err := ParseOneWithInt(fmt.Sprintf("SET ROW (%s)", strings.Join(exprs, ",")), nakedIntType)
Expand All @@ -337,22 +323,31 @@ func parseExprsWithInt(exprs []string, nakedIntType *types.T) (tree.Exprs, error
return set.Values, nil
}

// ParseExprs is a short-hand for parseExprsWithInt(sql, defaultNakedIntType).
// ParseExprs parses a comma-delimited sequence of SQL scalar
// expressions. The caller is responsible for ensuring that the input
// is, in fact, a comma-delimited sequence of SQL scalar expressions —
// the results are undefined if the string contains invalid SQL
// syntax.
func ParseExprs(sql []string) (tree.Exprs, error) {
if len(sql) == 0 {
return tree.Exprs{}, nil
}
return parseExprsWithInt(sql, defaultNakedIntType)
}

// ParseExpr is a short-hand for parseExprsWithInt([]string{sql},
// defaultNakedIntType).
// ParseExpr parses a SQL scalar expression. The caller is responsible
// for ensuring that the input is, in fact, a valid SQL scalar
// expression — the results are undefined if the string contains
// invalid SQL syntax.
func ParseExpr(sql string) (tree.Expr, error) {
return ParseExprWithInt(sql, defaultNakedIntType)
}

// ParseExprWithInt is a short-hand for parseExprsWithInt([]string{sql},
// nakedIntType).'
// ParseExprWithInt parses a SQL scalar expression, using the given
// type when INT is used as type name in the SQL syntax. The caller is
// responsible for ensuring that the input is, in fact, a valid SQL
// scalar expression — the results are undefined if the string
// contains invalid SQL syntax.
func ParseExprWithInt(sql string, nakedIntType *types.T) (tree.Expr, error) {
exprs, err := parseExprsWithInt([]string{sql}, nakedIntType)
if err != nil {
Expand All @@ -364,8 +359,29 @@ func ParseExprWithInt(sql string, nakedIntType *types.T) (tree.Expr, error) {
return exprs[0], nil
}

// ParseType parses a column type.
func ParseType(sql string) (tree.ResolvableTypeReference, error) {
// GetTypeReferenceFromName turns a type name into a type
// reference. This supports only “simple” (single-identifier)
// references to built-in types, when the identifer has already been
// parsed away from the input SQL syntax.
func GetTypeReferenceFromName(typeName tree.Name) (tree.ResolvableTypeReference, error) {
expr, err := ParseExpr(fmt.Sprintf("1::%s", typeName.String()))
if err != nil {
return nil, err
}

cast, ok := expr.(*tree.CastExpr)
if !ok {
return nil, errors.AssertionFailedf("expected a tree.CastExpr, but found %T", expr)
}

return cast.Type, nil
}

// GetTypeFromValidSQLSyntax retrieves a type from its SQL syntax. The caller is
// responsible for guaranteeing that the type expression is valid
// SQL. This includes verifying that complex identifiers are enclosed
// in double quotes, etc.
func GetTypeFromValidSQLSyntax(sql string) (tree.ResolvableTypeReference, error) {
expr, err := ParseExpr(fmt.Sprintf("1::%s", sql))
if err != nil {
return nil, err
Expand Down
18 changes: 0 additions & 18 deletions pkg/sql/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestParse verifies that we can parse the supplied SQL and regenerate the SQL
Expand Down Expand Up @@ -2873,23 +2872,6 @@ func TestParseDatadriven(t *testing.T) {
})
}

func TestParseTableNameWithQualifiedNames(t *testing.T) {
testdata := []struct {
name string
expected string
}{
{"unique", `"unique"`},
{"unique.index", `"unique".index`},
{"table.index.primary", `"table".index.primary`},
}

for _, tc := range testdata {
name, err := parser.ParseTableNameWithQualifiedNames(tc.name)
require.NoError(t, err)
require.Equal(t, tc.expected, name.String())
}
}

func TestParsePanic(t *testing.T) {
// Replicates #1801.
defer func() {
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,10 @@ func (p *planner) DistSQLPlanner() *DistSQLPlanner {
return p.extendedEvalCtx.DistSQLPlanner
}

// ParseType implements the tree.EvalPlanner interface.
// GetTypeFromValidSQLSyntax implements the tree.EvalPlanner interface.
// We define this here to break the dependency from eval.go to the parser.
func (p *planner) ParseType(sql string) (*types.T, error) {
ref, err := parser.ParseType(sql)
func (p *planner) GetTypeFromValidSQLSyntax(sql string) (*types.T, error) {
ref, err := parser.GetTypeFromValidSQLSyntax(sql)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/schemachange/alter_column_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestColumnConversions(t *testing.T) {
defer leaktest.AfterTest(t)()

columnType := func(typStr string) *types.T {
t, err := parser.ParseType(typStr)
t, err := parser.GetTypeFromValidSQLSyntax(typStr)
if err != nil {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/sem/tree/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -4258,7 +4258,7 @@ func ParseDOid(ctx *EvalContext, s string, t *types.T) (*DOid, error) {
}
return queryOid(ctx, t, NewDString(funcDef.Name))
case oid.T_regtype:
parsedTyp, err := ctx.Planner.ParseType(s)
parsedTyp, err := ctx.Planner.GetTypeFromValidSQLSyntax(s)
if err == nil {
return &DOid{
semanticType: t,
Expand Down
7 changes: 5 additions & 2 deletions pkg/sql/sem/tree/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3026,8 +3026,11 @@ type EvalDatabase interface {
type EvalPlanner interface {
EvalDatabase
TypeReferenceResolver
// ParseType parses a column type.
ParseType(sql string) (*types.T, error)

// GetTypeFromValidSQLSyntax parses a column type when the input
// string uses the parseable SQL representation of a type name, e.g.
// "INT(13)" or "pg_catalog.int".
GetTypeFromValidSQLSyntax(sql string) (*types.T, error)

// EvalSubquery returns the Datum for the given subquery node.
EvalSubquery(expr *Subquery) (Datum, error)
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/stats/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type JSONStatistic struct {
NullCount uint64 `json:"null_count"`
// HistogramColumnType is the string representation of the column type for the
// histogram (or unset if there is no histogram). Parsable with
// tree.ParseType.
// tree.GetTypeFromValidSQLSyntax.
HistogramColumnType string `json:"histo_col_type"`
HistogramBuckets []JSONHistoBucket `json:"histo_buckets,omitempty"`
}
Expand Down Expand Up @@ -106,7 +106,7 @@ func (js *JSONStatistic) GetHistogram(
return nil, nil
}
h := &HistogramData{}
colTypeRef, err := parser.ParseType(js.HistogramColumnType)
colTypeRef, err := parser.GetTypeFromValidSQLSyntax(js.HistogramColumnType)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/virtual_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func (v *virtualSchemaEntry) GetObjectByName(
// invalid input type like "notatype" will be parsed successfully as
// a ResolvableTypeReference, so the error here does not need to be
// intercepted and inspected.
typRef, err := parser.ParseType(name)
typRef, err := parser.GetTypeReferenceFromName(tree.Name(name))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit d3f7d85

Please sign in to comment.