Skip to content

Commit

Permalink
Merge pull request cockroachdb#1807 from tschottdorf/sql_literals
Browse files Browse the repository at this point in the history
parse literals more like Postgres
  • Loading branch information
tbg committed Jul 26, 2015
2 parents 9488706 + c620ff6 commit cdd21fd
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 41 deletions.
30 changes: 18 additions & 12 deletions sql/parser/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,22 @@ var (
)

func encodeSQLString(buf []byte, in []byte) []byte {
buf = append(buf, '\'')
for _, ch := range in {
if encodedChar := encodeMap[ch]; encodedChar == dontEscape {
buf = append(buf, ch)
} else {
buf = append(buf, '\\')
buf = append(buf, encodedChar)
// See http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html
start := 0
for i, ch := range in {
if encodedChar := encodeMap[ch]; encodedChar != dontEscape {
if start == 0 {
buf = append(buf, 'e', '\'') // begin e'xxx' string
}
buf = append(buf, in[start:i]...)
buf = append(buf, '\\', encodedChar)
start = i + 1
}
}
if start == 0 {
buf = append(buf, '\'') // begin 'xxx' string if nothing was escaped
}
buf = append(buf, in[start:]...)
buf = append(buf, '\'')
return buf
}
Expand All @@ -59,18 +66,18 @@ func encodeSQLIdent(buf *bytes.Buffer, s string) {
return
}

// The only characters we need to escape are '"' and '\\'.
// The only character that requires escaping is a double quote.
_ = buf.WriteByte('"')
start := 0
for i, n := 0, len(s); i < n; i++ {
ch := s[i]
if ch == '"' || ch == '\\' {
if ch == '"' {
if start != i {
_, _ = buf.WriteString(s[start:i])
}
start = i + 1
_ = buf.WriteByte('\\')
_ = buf.WriteByte(ch)
_ = buf.WriteByte(ch) // add extra copy of ch
}
}
if start < len(s) {
Expand Down Expand Up @@ -99,14 +106,13 @@ func encodeSQLBytes(buf []byte, v []byte) []byte {
func init() {
encodeRef := map[byte]byte{
'\x00': '0',
'\'': '\'',
'"': '"',
'\b': 'b',
'\f': 'f',
'\n': 'n',
'\r': 'r',
'\t': 't',
'\\': '\\',
'\'': '\'',
}

for i := range encodeMap {
Expand Down
6 changes: 3 additions & 3 deletions sql/parser/expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func TestQualifiedNameString(t *testing.T) {
// and is then followed by [a-zA-Z0-9$_] or extended ascii.
{"foo$09", "foo$09"},
{"_Ab10", "_Ab10"},
// Everything else quotes the string and escapes '"' and '\\'.
// Everything else quotes the string and escapes double quotes.
{".foobar", `".foobar"`},
{`".foobar"`, `"\".foobar\""`},
{`\".foobar\"`, `"\\\".foobar\\\""`},
{`".foobar"`, `""".foobar"""`},
{`\".foobar\"`, `"\"".foobar\"""`},
}

for _, tc := range testCases {
Expand Down
34 changes: 25 additions & 9 deletions sql/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,21 @@ func TestParse(t *testing.T) {
{`SELECT a FROM t`},
{`SELECT a.b FROM t`},
{`SELECT 'a' FROM t`},
{`SELECT 'a\'a' FROM t`},

{`SELECT 'a' AS "12345"`},
{`SELECT 'a' AS clnm`},

{`SELECT 'a\\na' FROM t`},
{`SELECT '\\n' FROM t`},
// Escaping may change since the scanning process loses information
// (you can write e'\'' or ''''), but these are the idempotent cases.
// Generally, anything that needs to escape plus \ and ' leads to an
// escaped string.
{`SELECT e'a\'a' FROM t`},
{`SELECT e'a\\\\na' FROM t`},
{`SELECT e'\\\\n' FROM t`},
{`SELECT "a""a" FROM t`},
{`SELECT a FROM "t\n"`}, // no escaping in sql identifiers
{`SELECT a FROM "t"""`}, // no escaping in sql identifiers

{`SELECT "FROM" FROM t`},
{`SELECT CAST(1 AS TEXT)`},
{`SELECT FROM t AS bar`},
Expand Down Expand Up @@ -246,11 +254,19 @@ func TestParse2(t *testing.T) {
// {`SELECT 010 FROM t`, ``},
// {`SELECT 0xf0 FROM t`, ``},
// {`SELECT 0xF0 FROM t`, ``},
// Escaped string literals are not always escaped the same.
{`SELECT 'a''a' FROM t`,
`SELECT 'a\'a' FROM t`},
{`SELECT "a""a" FROM t`,
`SELECT "a\"a" FROM t`},
// Escaped string literals are not always escaped the same because
// '''' and e'\'' scan to the same token. It's more convenient to
// prefer escaping ' and \, so we do that.
{`SELECT 'a''a'`,
`SELECT e'a\'a'`},
{`SELECT 'a\a'`,
`SELECT e'a\\a'`},
{`SELECT 'a\n'`,
`SELECT e'a\\n'`},
{"SELECT '\n'",
`SELECT e'\n'`},
{"SELECT '\n\\'",
`SELECT e'\n\\'`},
{`SELECT "a'a" FROM t`,
`SELECT "a'a" FROM t`},
// Comments are stripped.
Expand Down Expand Up @@ -319,7 +335,7 @@ func TestParseSyntax(t *testing.T) {
{`SELECT 1 FROM t FOR SHARE`},
{`SELECT 1 FROM t FOR KEY SHARE`},
{`SELECT ((1)) FROM t WHERE ((a)) IN (((1))) AND ((a, b)) IN ((((1, 1))), ((2, 2)))`},
{`SELECT '\'\"\b\n\r\t\\' FROM t`},
{`SELECT e'\'\"\b\n\r\t\\' FROM t`},
{`SELECT '\x' FROM t`},
{`SELECT 1 FROM t GROUP BY a`},
{`SELECT 1 FROM t ORDER BY a`},
Expand Down
37 changes: 25 additions & 12 deletions sql/parser/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
)

const eof = -1
const errUnterminated = "unterminated string"
const errUnsupportedEscape = "octal, hex and unicode escape not supported"

type scanner struct {
in string
Expand Down Expand Up @@ -498,29 +500,40 @@ func (s *scanner) scanString(lval *sqlSymType, ch int, allowEscapes bool) bool {

case '\\':
t := s.peek()
// We always allow the quote character and "\" to be escaped.
if t == ch || t == '\\' {
lval.str += s.in[start : s.pos-1]
start = s.pos
s.pos++
continue
}
if allowEscapes {
lval.str += s.in[start : s.pos-1]
if t == ch || t == '\\' {
start = s.pos
s.pos++
continue
}

switch t {
case 'b', 'f', 'n', 'r', 't', '\'', '"':
lval.str += s.in[start : s.pos-1]
// TODO(pmattis): Handle other back-slash escapes? Octal? Hexadecimal?
// Unicode?
case 'b', 'f', 'n', 'r', 't', '\'':
lval.str += string(decodeMap[byte(t)])
s.pos++
start = s.pos
continue
case 'x', 'u', 'U':
fallthrough
case '0', '1', '2', '3', '4', '5', '6', '7':
lval.id = ERROR
lval.str = errUnsupportedEscape
return false
}
// TODO(pmattis): Handle other back-slash escapes? Octal? Hexadecimal?
// Unicode?

// If we end up here, it's a redundant escape - simply drop the
// backslash. For example, e'\"' is equivalent to e'"', and
// e'\a\b' to e'a\b'. This is what Postgres does:
// http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE
start = s.pos
}

case eof:
lval.id = ERROR
lval.str = "unterminated string"
lval.str = errUnterminated
return false
}
}
Expand Down
24 changes: 19 additions & 5 deletions sql/parser/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package parser

import (
"reflect"
"strings"
"testing"
)

Expand Down Expand Up @@ -216,10 +217,11 @@ func TestScanString(t *testing.T) {
{`'a''b'`, `a'b`},
{`"a" "b"`, `a`},
{`'a' 'b'`, `a`},
{`'\n'`, "\\n"},
{`"\""`, `"`},
{`'\''`, `'`},
{`'\0\'\"\b\f\n\r\t\\'`, `\0'\"\b\f\n\r\t\`},
{`'\n'`, `\n`},
{`e'\n'`, "\n"},
{`'\\n'`, `\\n`},
{`'\'''`, `\'`},
{`'\0\'`, `\0\`},
{`"a"
"b"`, `ab`},
{`"a"
Expand All @@ -228,7 +230,19 @@ func TestScanString(t *testing.T) {
'b'`, `ab`},
{`'a'
"b"`, `a`},
{`e'foo\"\'\\\b\f\n\r\tbar'`, "foo\"'\\\b\f\n\r\tbar"},
{`e'\"'`, `"`}, // redundant escape
{`e'\a'`, `a`}, // redundant escape
{"'\n\\'", "\n\\"},
{`e'foo\"\'\\\b\f\n\r\tbar'`,
strings.Join([]string{`foo"'\`, "\b\f\n\r\t", `bar`}, "")},
{`e'\\0'`, `\0`},
{`'\0'`, `\0`},
{`e'\0'`, errUnsupportedEscape},
{`"''"`, `''`},
{`'""'''`, `""'`},
{`""""`, `"`},
{`''''`, `'`},
{`''''''`, `''`},
}
for _, d := range testData {
s := newScanner(d.sql)
Expand Down

0 comments on commit cdd21fd

Please sign in to comment.