Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: fix reg* cast escaping #55607

Merged
merged 2 commits into from
Oct 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion pkg/sql/logictest/testdata/logic_test/pgoidtype
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ SELECT 'blah ()'::REGPROC
query error unknown function: blah\(\)
SELECT 'blah( )'::REGPROC

query error unknown function: blah\(, \)\(\)
query error invalid name: expected separator \.: blah\(, \)
SELECT 'blah(, )'::REGPROC

query error more than one function named 'sqrt'
Expand Down Expand Up @@ -347,3 +347,42 @@ FROM
(d.adrelid = a.attrelid AND d.adnum = a.attnum)
JOIN (SELECT 1 AS oid, 1 AS attnum) AS vals ON
(c.oid = vals.oid AND a.attnum = vals.attnum);

statement error relation ".*"regression_53686.*"" does not exist
SELECT '\"regression_53686\"'::regclass

statement ok
CREATE TABLE "regression_53686""" (a int)

query T
SELECT 'regression_53686"'::regclass
----
"regression_53686"""

query T
SELECT 'public.regression_53686"'::regclass
----
"regression_53686"""

query T
SELECT 'pg_catalog."radians"'::regproc
----
radians

query T
SELECT 'pg_catalog."radians"("float4")'::regproc
----
radians

statement error unknown function: pg_catalog.radians"\(\)
SELECT 'pg_catalog."radians"""'::regproc

query TTTTT
SELECT
'12345'::regclass::string,
'12345'::regtype::string,
'12345'::oid::string,
'12345'::regproc::string,
'12345'::regprocedure::string
----
12345 12345 12345 12345 12345
157 changes: 141 additions & 16 deletions pkg/sql/sem/tree/casts.go
Original file line number Diff line number Diff line change
Expand Up @@ -1076,30 +1076,31 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) {
}
case *DString:
s := string(*v)
// Trim whitespace and unwrap outer quotes if necessary.
// This is required to mimic postgres.
s = strings.TrimSpace(s)
origS := s
if len(s) > 1 && s[0] == '"' && s[len(s)-1] == '"' {
s = s[1 : len(s)-1]
}

switch t.Oid() {
case oid.T_oid:
i, err := ParseDInt(s)
// If it is an integer in string form, convert it as an int.
if val, err := ParseDInt(strings.TrimSpace(s)); err == nil {
tmpOid := NewDOid(*val)
oid, err := queryOid(ctx, t, tmpOid)
if err != nil {
return nil, err
oid = tmpOid
oid.semanticType = t
}
return &DOid{semanticType: t, DInt: *i}, nil
return oid, nil
}

switch t.Oid() {
case oid.T_regproc, oid.T_regprocedure:
// Trim procedure type parameters, e.g. `max(int)` becomes `max`.
// Postgres only does this when the cast is ::regprocedure, but we're
// going to always do it.
// We additionally do not yet implement disambiguation based on type
// parameters: we return the match iff there is exactly one.
s = pgSignatureRegexp.ReplaceAllString(s, "$1")
// Resolve function name.
substrs := strings.Split(s, ".")

substrs, err := splitIdentifierList(s)
if err != nil {
return nil, err
}
if len(substrs) > 3 {
// A fully qualified function name in pg's dialect can contain
// at most 3 parts: db.schema.funname.
Expand All @@ -1126,10 +1127,21 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) {
name: parsedTyp.SQLStandardName(),
}, nil
}

// Fall back to searching pg_type, since we don't provide syntax for
// every postgres type that we understand OIDs for.
// Note this section does *not* work if there is a schema in front of the
// type, e.g. "pg_catalog"."int4" (if int4 was not defined).

// Trim whitespace and unwrap outer quotes if necessary.
// This is required to mimic postgres.
s = strings.TrimSpace(s)
if len(s) > 1 && s[0] == '"' && s[len(s)-1] == '"' {
s = s[1 : len(s)-1]
}
// Trim type modifiers, e.g. `numeric(10,3)` becomes `numeric`.
s = pgSignatureRegexp.ReplaceAllString(s, "$1")

dOid, missingTypeErr := queryOid(ctx, t, NewDString(s))
if missingTypeErr == nil {
return dOid, missingTypeErr
Expand All @@ -1155,11 +1167,11 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) {
}, nil

case oid.T_regclass:
tn, err := ctx.Planner.ParseQualifiedTableName(origS)
tn, err := castStringToRegClassTableName(s)
if err != nil {
return nil, err
}
id, err := ctx.Planner.ResolveTableName(ctx.Ctx(), tn)
id, err := ctx.Planner.ResolveTableName(ctx.Ctx(), &tn)
if err != nil {
return nil, err
}
Expand All @@ -1177,3 +1189,116 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) {
return nil, pgerror.Newf(
pgcode.CannotCoerce, "invalid cast: %s -> %s", d.ResolvedType(), t)
}

// castStringToRegClassTableName converts the textual form of a (possibly
// qualified, possibly quoted) relation name, as supplied to a ::regclass cast,
// into a TableName.
//
// The string is split into components with splitIdentifierList; at most three
// components are accepted. Any error from splitting, or from
// NewUnresolvedObjectName, is returned unchanged together with a zero
// TableName.
func castStringToRegClassTableName(s string) (TableName, error) {
	parts, err := splitIdentifierList(s)
	if err != nil {
		return TableName{}, err
	}

	n := len(parts)
	if n > 3 {
		return TableName{}, pgerror.Newf(pgcode.InvalidName, "too many components: %s", s)
	}

	// The components are handed to NewUnresolvedObjectName in the reverse of
	// the order in which they were written: the last written component goes in
	// slot 0.
	var reversed [3]string
	for i, part := range parts {
		reversed[n-1-i] = part
	}

	name, err := NewUnresolvedObjectName(n, reversed, 0)
	if err != nil {
		return TableName{}, err
	}
	return name.ToTableName(), nil
}

// splitIdentifierList splits a dotted SQL name such as `db."Sch".tbl` into its
// individual components. It is based on PostgreSQL's SplitIdentifierString.
//
// Rules:
//   - Components are separated by '.'; whitespace around components and
//     separators is ignored.
//   - A component starting with '"' is a quoted identifier: it extends to the
//     matching closing quote, a doubled quote ("") inside it stands for one
//     literal '"', and its case is preserved.
//   - Any other component is unquoted: it extends to the next separator or
//     whitespace and is lower-cased.
//
// An input that is empty or consists only of whitespace yields an empty list
// and no error; callers decide whether that is acceptable.
//
// An error with code pgcode.InvalidName is returned when a quoted identifier
// is not closed, when something other than a separator follows a component,
// when an unquoted component is empty (e.g. "a..b" or ".a"), or when the input
// ends right after a separator (e.g. "a.").
func splitIdentifierList(in string) ([]string, error) {
	const separator = '.'

	var ret []string
	pos := 0

	// Skip leading whitespace so that the loop below always starts on the
	// first byte of a component.
	for pos < len(in) && isWhitespace(in[pos]) {
		pos++
	}

	for pos < len(in) {
		if in[pos] == '"' {
			var b strings.Builder
			// Attempt to find the ending quote. If the quote is double "",
			// fold it into a " character for the str (e.g. "a""" means a").
			for {
				// Step over the opening quote, or over the second quote of a
				// doubled pair on later iterations.
				pos++
				endIdx := strings.IndexByte(in[pos:], '"')
				if endIdx == -1 {
					return nil, pgerror.Newf(
						pgcode.InvalidName,
						`invalid name: unclosed ": %s`,
						in,
					)
				}
				b.WriteString(in[pos : pos+endIdx])
				pos += endIdx + 1
				// If we reached the end, or the following character is not ",
				// this quote terminates the identifier. What may follow it is
				// validated after the branch.
				if pos == len(in) || in[pos] != '"' {
					break
				}
				b.WriteByte('"')
			}
			ret = append(ret, b.String())
		} else {
			start := pos
			for pos < len(in) && in[pos] != separator && !isWhitespace(in[pos]) {
				pos++
			}
			// An unquoted identifier must contain at least one character;
			// otherwise inputs like ".a" or "a..b" would silently produce
			// empty components.
			if pos == start {
				return nil, pgerror.Newf(
					pgcode.InvalidName,
					"invalid name: empty identifier: %s",
					in,
				)
			}
			// Anything with no quotations should be lowered.
			ret = append(ret, strings.ToLower(in[start:pos]))
		}

		// Ignore whitespace following the component.
		for pos < len(in) && isWhitespace(in[pos]) {
			pos++
		}

		// At this stage, we expect separator or end of string.
		if pos == len(in) {
			break
		}
		if in[pos] != separator {
			return nil, pgerror.Newf(
				pgcode.InvalidName,
				"invalid name: expected separator %c: %s",
				separator,
				in,
			)
		}
		pos++

		// A separator must be followed by another component; skip whitespace
		// and reject a dangling separator instead of silently dropping it.
		for pos < len(in) && isWhitespace(in[pos]) {
			pos++
		}
		if pos == len(in) {
			return nil, pgerror.Newf(
				pgcode.InvalidName,
				"invalid name: expected identifier after separator %c: %s",
				separator,
				in,
			)
		}
	}

	return ret, nil
}

// isWhitespace reports whether ch is one of the ASCII whitespace bytes
// recognized when splitting identifiers: space, horizontal tab, carriage
// return, form feed or newline. Vertical tab is deliberately not in the set.
// The set is meant to stay in sync with parser.SkipWhitespace.
func isWhitespace(ch byte) bool {
	return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\f' || ch == '\n'
}
74 changes: 74 additions & 0 deletions pkg/sql/sem/tree/casts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,77 @@ func TestTupleCastVolatility(t *testing.T) {
}
}
}

// TestCastStringToRegClassTableName exercises castStringToRegClassTableName
// with inputs that must convert to a specific TableName and with inputs that
// must fail with an exact error message.
func TestCastStringToRegClassTableName(t *testing.T) {
	defer leaktest.AfterTest(t)()

	// Inputs that must convert successfully.
	valid := []struct {
		in       string
		expected TableName
	}{
		{in: "a", expected: MakeUnqualifiedTableName("a")},
		// A stray quote inside an unquoted name is kept verbatim.
		{in: `a"`, expected: MakeUnqualifiedTableName(`a"`)},
		// Quoted parts keep their case and fold "" into "; unquoted parts are
		// lower-cased; trailing whitespace is ignored.
		{in: `"a""".bB."cD" `, expected: MakeTableNameWithSchema(`a"`, "bb", "cD")},
	}
	for i := range valid {
		tc := valid[i]
		t.Run(tc.in, func(t *testing.T) {
			got, err := castStringToRegClassTableName(tc.in)
			require.NoError(t, err)
			require.Equal(t, tc.expected, got)
		})
	}

	// Inputs that must be rejected, with the exact error text.
	invalid := []struct {
		in            string
		expectedError string
	}{
		{in: "a.b.c.d", expectedError: "too many components: a.b.c.d"},
		{in: "", expectedError: `invalid table name: `},
	}
	for i := range invalid {
		tc := invalid[i]
		t.Run(tc.in, func(t *testing.T) {
			_, err := castStringToRegClassTableName(tc.in)
			require.EqualError(t, err, tc.expectedError)
		})
	}
}

// TestSplitIdentifierList exercises splitIdentifierList with inputs that must
// split into an exact list of components and with malformed inputs that must
// produce an exact error message.
func TestSplitIdentifierList(t *testing.T) {
	defer leaktest.AfterTest(t)()

	// Well-formed inputs and the components they must produce.
	valid := []struct {
		in       string
		expected []string
	}{
		{in: `abc`, expected: []string{"abc"}},
		// Unquoted parts are lower-cased and trailing whitespace is ignored.
		{in: `abc.dEf `, expected: []string{"abc", "def"}},
		// Whitespace around separators is ignored; quoted parts keep their
		// case and a doubled quote becomes a single literal quote.
		{in: ` "aBc" . d ."HeLLo"""`, expected: []string{"aBc", "d", `HeLLo"`}},
	}
	for i := range valid {
		tc := valid[i]
		t.Run(tc.in, func(t *testing.T) {
			got, err := splitIdentifierList(tc.in)
			require.NoError(t, err)
			require.Equal(t, tc.expected, got)
		})
	}

	// Malformed inputs and the exact error each must yield.
	invalid := []struct {
		in            string
		expectedError string
	}{
		{in: `"unclosed`, expectedError: `invalid name: unclosed ": "unclosed`},
		{in: `"unclosed""`, expectedError: `invalid name: unclosed ": "unclosed""`},
		{in: `hello !`, expectedError: `invalid name: expected separator .: hello !`},
	}
	for i := range invalid {
		tc := invalid[i]
		t.Run(tc.in, func(t *testing.T) {
			_, err := splitIdentifierList(tc.in)
			require.EqualError(t, err, tc.expectedError)
		})
	}
}
2 changes: 1 addition & 1 deletion pkg/sql/sem/tree/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3616,7 +3616,7 @@ func (expr *CaseExpr) Eval(ctx *EvalContext) (Datum, error) {
// pgSignatureRegexp matches a Postgres function type signature, capturing the
// name of the function into group 1.
// e.g. function(a, b, c) or function( a )
var pgSignatureRegexp = regexp.MustCompile(`^\s*([\w\.]+)\s*\((?:(?:\s*\w+\s*,)*\s*\w+)?\s*\)\s*$`)
var pgSignatureRegexp = regexp.MustCompile(`^\s*([\w\."]+)\s*\((?:(?:\s*[\w"]+\s*,)*\s*[\w"]+)?\s*\)\s*$`)

// regTypeInfo contains details on a pg_catalog table that has a reg* type.
type regTypeInfo struct {
Expand Down