diff --git a/sql/parser/encode.go b/sql/parser/encode.go index 86d2eed83ee3..fec478f6a43f 100644 --- a/sql/parser/encode.go +++ b/sql/parser/encode.go @@ -38,15 +38,22 @@ var ( ) func encodeSQLString(buf []byte, in []byte) []byte { - buf = append(buf, '\'') - for _, ch := range in { - if encodedChar := encodeMap[ch]; encodedChar == dontEscape { - buf = append(buf, ch) - } else { - buf = append(buf, '\\') - buf = append(buf, encodedChar) + // See http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html + start := 0 + for i, ch := range in { + if encodedChar := encodeMap[ch]; encodedChar != dontEscape { + if start == 0 { + buf = append(buf, 'e', '\'') // begin e'xxx' string + } + buf = append(buf, in[start:i]...) + buf = append(buf, '\\', encodedChar) + start = i + 1 } } + if start == 0 { + buf = append(buf, '\'') // begin 'xxx' string if nothing was escaped + } + buf = append(buf, in[start:]...) buf = append(buf, '\'') return buf } @@ -59,18 +66,18 @@ func encodeSQLIdent(buf *bytes.Buffer, s string) { return } - // The only characters we need to escape are '"' and '\\'. + // The only character that requires escaping is a double quote. _ = buf.WriteByte('"') start := 0 for i, n := 0, len(s); i < n; i++ { ch := s[i] - if ch == '"' || ch == '\\' { + if ch == '"' { if start != i { _, _ = buf.WriteString(s[start:i]) } start = i + 1 - _ = buf.WriteByte('\\') _ = buf.WriteByte(ch) + _ = buf.WriteByte(ch) // add extra copy of ch } } if start < len(s) { @@ -99,14 +106,13 @@ func encodeSQLBytes(buf []byte, v []byte) []byte { func init() { encodeRef := map[byte]byte{ '\x00': '0', - '\'': '\'', - '"': '"', '\b': 'b', '\f': 'f', '\n': 'n', '\r': 'r', '\t': 't', '\\': '\\', + '\'': '\'', } for i := range encodeMap { diff --git a/sql/parser/expr_test.go b/sql/parser/expr_test.go index b166b78bb7b9..794a72971bd5 100644 --- a/sql/parser/expr_test.go +++ b/sql/parser/expr_test.go @@ -32,10 +32,10 @@ func TestQualifiedNameString(t *testing.T) { // and is then followed by [a-zA-Z0-9$_] or extended ascii. {"foo$09", "foo$09"}, {"_Ab10", "_Ab10"}, - // Everything else quotes the string and escapes '"' and '\\'. + // Everything else quotes the string and escapes double quotes. {".foobar", `".foobar"`}, - {`".foobar"`, `"\".foobar\""`}, - {`\".foobar\"`, `"\\\".foobar\\\""`}, + {`".foobar"`, `""".foobar"""`}, + {`\".foobar\"`, `"\"".foobar\"""`}, } for _, tc := range testCases { diff --git a/sql/parser/parse_test.go b/sql/parser/parse_test.go index c69bc75d4f05..2b6f3858af0e 100644 --- a/sql/parser/parse_test.go +++ b/sql/parser/parse_test.go @@ -117,13 +117,21 @@ func TestParse(t *testing.T) { {`SELECT a FROM t`}, {`SELECT a.b FROM t`}, {`SELECT 'a' FROM t`}, - {`SELECT 'a\'a' FROM t`}, {`SELECT 'a' AS "12345"`}, {`SELECT 'a' AS clnm`}, - {`SELECT 'a\\na' FROM t`}, - {`SELECT '\\n' FROM t`}, + // Escaping may change since the scanning process loses information + // (you can write e'\'' or ''''), but these are the idempotent cases. + // Generally, anything that needs to escape plus \ and ' leads to an + // escaped string. + {`SELECT e'a\'a' FROM t`}, + {`SELECT e'a\\\\na' FROM t`}, + {`SELECT e'\\\\n' FROM t`}, + {`SELECT "a""a" FROM t`}, + {`SELECT a FROM "t\n"`}, // no escaping in sql identifiers + {`SELECT a FROM "t"""`}, // no escaping in sql identifiers + {`SELECT "FROM" FROM t`}, {`SELECT CAST(1 AS TEXT)`}, {`SELECT FROM t AS bar`}, @@ -246,11 +254,19 @@ func TestParse2(t *testing.T) { // {`SELECT 010 FROM t`, ``}, // {`SELECT 0xf0 FROM t`, ``}, // {`SELECT 0xF0 FROM t`, ``}, - // Escaped string literals are not always escaped the same. - {`SELECT 'a''a' FROM t`, - `SELECT 'a\'a' FROM t`}, - {`SELECT "a""a" FROM t`, - `SELECT "a\"a" FROM t`}, + // Escaped string literals are not always escaped the same because + // '''' and e'\'' scan to the same token. It's more convenient to + // prefer escaping ' and \, so we do that. + {`SELECT 'a''a'`, + `SELECT e'a\'a'`}, + {`SELECT 'a\a'`, + `SELECT e'a\\a'`}, + {`SELECT 'a\n'`, + `SELECT e'a\\n'`}, + {"SELECT '\n'", + `SELECT e'\n'`}, + {"SELECT '\n\\'", + `SELECT e'\n\\'`}, {`SELECT "a'a" FROM t`, `SELECT "a'a" FROM t`}, // Comments are stripped. @@ -319,7 +335,7 @@ func TestParseSyntax(t *testing.T) { {`SELECT 1 FROM t FOR SHARE`}, {`SELECT 1 FROM t FOR KEY SHARE`}, {`SELECT ((1)) FROM t WHERE ((a)) IN (((1))) AND ((a, b)) IN ((((1, 1))), ((2, 2)))`}, - {`SELECT '\'\"\b\n\r\t\\' FROM t`}, + {`SELECT e'\'\"\b\n\r\t\\' FROM t`}, {`SELECT '\x' FROM t`}, {`SELECT 1 FROM t GROUP BY a`}, {`SELECT 1 FROM t ORDER BY a`}, diff --git a/sql/parser/scan.go b/sql/parser/scan.go index 36dd6355207c..0c851582c48b 100644 --- a/sql/parser/scan.go +++ b/sql/parser/scan.go @@ -29,6 +29,8 @@ import ( ) const eof = -1 +const errUnterminated = "unterminated string" +const errUnsupportedEscape = "octal, hex and unicode escape not supported" type scanner struct { in string @@ -498,29 +500,40 @@ func (s *scanner) scanString(lval *sqlSymType, ch int, allowEscapes bool) bool { case '\\': t := s.peek() - // We always allow the quote character and "\" to be escaped. - if t == ch || t == '\\' { - lval.str += s.in[start : s.pos-1] - start = s.pos - s.pos++ - continue - } if allowEscapes { + lval.str += s.in[start : s.pos-1] + if t == ch || t == '\\' { + start = s.pos + s.pos++ + continue + } + switch t { - case 'b', 'f', 'n', 'r', 't', '\'', '"': - lval.str += s.in[start : s.pos-1] + // TODO(pmattis): Handle other back-slash escapes? Octal? Hexadecimal? + // Unicode? + case 'b', 'f', 'n', 'r', 't', '\'': lval.str += string(decodeMap[byte(t)]) s.pos++ start = s.pos continue + case 'x', 'u', 'U': + fallthrough + case '0', '1', '2', '3', '4', '5', '6', '7': + lval.id = ERROR + lval.str = errUnsupportedEscape + return false } - // TODO(pmattis): Handle other back-slash escapes? Octal? Hexadecimal? - // Unicode? + + // If we end up here, it's a redundant escape - simply drop the + // backslash. For example, e'\"' is equivalent to e'"', and + // e'\a\b' to e'a\b'. This is what Postgres does: + // http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE + start = s.pos } case eof: lval.id = ERROR - lval.str = "unterminated string" + lval.str = errUnterminated return false } } diff --git a/sql/parser/scan_test.go b/sql/parser/scan_test.go index 76c4c58cb27e..8fb64761f860 100644 --- a/sql/parser/scan_test.go +++ b/sql/parser/scan_test.go @@ -19,6 +19,7 @@ package parser import ( "reflect" + "strings" "testing" ) @@ -216,10 +217,11 @@ func TestScanString(t *testing.T) { {`'a''b'`, `a'b`}, {`"a" "b"`, `a`}, {`'a' 'b'`, `a`}, - {`'\n'`, "\\n"}, - {`"\""`, `"`}, - {`'\''`, `'`}, - {`'\0\'\"\b\f\n\r\t\\'`, `\0'\"\b\f\n\r\t\`}, + {`'\n'`, `\n`}, + {`e'\n'`, "\n"}, + {`'\\n'`, `\\n`}, + {`'\'''`, `\'`}, + {`'\0\'`, `\0\`}, {`"a" "b"`, `ab`}, {`"a" @@ -228,7 +230,19 @@ func TestScanString(t *testing.T) { 'b'`, `ab`}, {`'a' "b"`, `a`}, - {`e'foo\"\'\\\b\f\n\r\tbar'`, "foo\"'\\\b\f\n\r\tbar"}, + {`e'\"'`, `"`}, // redundant escape + {`e'\a'`, `a`}, // redundant escape + {"'\n\\'", "\n\\"}, + {`e'foo\"\'\\\b\f\n\r\tbar'`, + strings.Join([]string{`foo"'\`, "\b\f\n\r\t", `bar`}, "")}, + {`e'\\0'`, `\0`}, + {`'\0'`, `\0`}, + {`e'\0'`, errUnsupportedEscape}, + {`"''"`, `''`}, + {`'""'''`, `""'`}, + {`""""`, `"`}, + {`''''`, `'`}, + {`''''''`, `''`}, } for _, d := range testData { s := newScanner(d.sql)