From a71a457261ea8e522d320726a8aeea6768f08acf Mon Sep 17 00:00:00 2001 From: rahul2393 Date: Thu, 18 Aug 2022 15:36:15 +0530 Subject: [PATCH] feat: add support of positional parameter in the queries (#110) * feat: add support of positional parameter in the queries * incorporate requested changes * fixed nits and added tests * run test --- README.md | 16 ++- driver.go | 4 +- examples/transactions/main.go | 4 +- statement_parser.go | 63 ++++++++--- statement_parser_test.go | 194 +++++++++++++++++++++++++++++----- stmt.go | 2 +- 6 files changed, 231 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 04c1d8de..befa4722 100644 --- a/README.md +++ b/README.md @@ -37,13 +37,19 @@ for rows.Next() { ## Statements -Statements support follows the official [Google Cloud Spanner Go](https://pkg.go.dev/cloud.google.com/go/spanner) client style arguments. +Statements support follows the official [Google Cloud Spanner Go](https://pkg.go.dev/cloud.google.com/go/spanner) client style arguments as well as positional paramaters. + +### Using positional patameter ```go -db.QueryContext(ctx, "SELECT id, text FROM tweets WHERE likes > @likes", 500) +db.QueryContext(ctx, "SELECT id, text FROM tweets WHERE likes > ?", 500) + +db.ExecContext(ctx, "INSERT INTO tweets (id, text, rts) VALUES (?, ?, ?)", id, text, 10000) +``` -db.ExecContext(ctx, "INSERT INTO tweets (id, text, rts) VALUES (@id, @text, @rts)", id, text, 10000) +### Using named patameter +```go db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", 14544498215374) ``` @@ -51,8 +57,8 @@ db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", 14544498215374) - Read-write transactions always uses the strongest isolation level and ignore the user-specified level. - Read-only transactions do strong-reads by default. Read-only transactions must be ended by calling -either Commit or Rollback. Calling either of these methods will end the current read-only -transaction and return the session that is used to the session pool. + either Commit or Rollback. Calling either of these methods will end the current read-only + transaction and return the session that is used to the session pool. ``` go tx, err := db.BeginTx(ctx, &sql.TxOptions{}) // Read-write transaction. diff --git a/driver.go b/driver.go index 3989ecc6..ea107f3f 100644 --- a/driver.go +++ b/driver.go @@ -737,11 +737,11 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) { } func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - args, err := parseNamedParameters(query) + parsedSQL, args, err := parseParameters(query) if err != nil { return nil, err } - return &stmt{conn: c, query: query, numArgs: len(args)}, nil + return &stmt{conn: c, query: parsedSQL, numArgs: len(args)}, nil } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { diff --git a/examples/transactions/main.go b/examples/transactions/main.go index 23119ca1..3f2ef787 100644 --- a/examples/transactions/main.go +++ b/examples/transactions/main.go @@ -49,7 +49,7 @@ func transaction(projectId, instanceId, databaseId string) error { } // The row that we inserted will be readable for the same transaction that started it. - rows, err := tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = @id", 123) + rows, err := tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", 123) if err != nil { _ = tx.Rollback() return err @@ -92,7 +92,7 @@ func transaction(projectId, instanceId, databaseId string) error { } // This should now find the row. - row = db.QueryRowContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = @id", 123) + row = db.QueryRowContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", 123) if err := row.Err(); err != nil { return err } diff --git a/statement_parser.go b/statement_parser.go index 62f2324c..595d4a78 100644 --- a/statement_parser.go +++ b/statement_parser.go @@ -20,6 +20,7 @@ import ( "encoding/json" "reflect" "regexp" + "strconv" "strings" "sync" "unicode" @@ -45,18 +46,19 @@ func union(m1 map[string]bool, m2 map[string]bool) map[string]bool { return res } -// parseNamedParameters returns the named parameters in the given sql string. +// parseParameters returns the parameters in the given sql string, if the input +// sql contains positional parameters it returns the converted sql string with +// all positional parameters replaced with named parameters. // The sql string must be a valid Cloud Spanner sql statement. It may contain // comments and (string) literals without any restrictions. That is, string // literals containing for example an email address ('test@test.com') will be // recognized as a string literal and not returned as a named parameter. -func parseNamedParameters(sql string) ([]string, error) { +func parseParameters(sql string) (string, []string, error) { sql, err := removeCommentsAndTrim(sql) if err != nil { - return nil, err + return sql, nil, err } - sql = removeStatementHint(sql) - return findParams(sql) + return findParams('?', sql) } // RemoveCommentsAndTrim removes any comments in the query string and trims any @@ -188,9 +190,9 @@ func removeStatementHint(sql string) string { return sql } -// This function assumes that all comments and statement hints have already +// This function assumes that all comments have already // been removed from the statement. -func findParams(sql string) ([]string, error) { +func findParams(positionalParamChar rune, sql string) (string, []string, error) { const paramPrefix = '@' const singleQuote = '\'' const doubleQuote = '"' @@ -199,14 +201,19 @@ func findParams(sql string) ([]string, error) { var startQuote rune lastCharWasEscapeChar := false isTripleQuoted := false - res := make([]string, 0) + hasNamedParameter := false + hasPositionalParameter := false + namedParams := make([]string, 0) + parsedSQL := strings.Builder{} + parsedSQL.Grow(len(sql)) + positionalParameterIndex := 1 index := 0 runes := []rune(sql) for index < len(runes) { c := runes[index] if isInQuoted { if (c == '\n' || c == '\r') && !isTripleQuoted { - return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", sql)) + return sql, nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", sql)) } else if c == startQuote { if lastCharWasEscapeChar { lastCharWasEscapeChar = false @@ -215,6 +222,8 @@ func findParams(sql string) ([]string, error) { isInQuoted = false startQuote = 0 isTripleQuoted = false + parsedSQL.WriteRune(c) + parsedSQL.WriteRune(c) index += 2 } } else { @@ -226,23 +235,41 @@ func findParams(sql string) ([]string, error) { } else { lastCharWasEscapeChar = false } + parsedSQL.WriteRune(c) } else { // We are not in a quoted string. It's a parameter if it is an '@' followed by a letter or an underscore. // See https://cloud.google.com/spanner/docs/lexical#identifiers for identifier rules. if c == paramPrefix && len(runes) > index+1 && (unicode.IsLetter(runes[index+1]) || runes[index+1] == '_') { + if hasPositionalParameter { + return sql, nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement must not contain both named and positional parameter: %s", sql)) + } + parsedSQL.WriteRune(c) index++ startIndex := index for index < len(runes) { if !(unicode.IsLetter(runes[index]) || unicode.IsDigit(runes[index]) || runes[index] == '_') { - res = append(res, string(runes[startIndex:index])) + hasNamedParameter = true + namedParams = append(namedParams, string(runes[startIndex:index])) + parsedSQL.WriteRune(runes[index]) break } if index == len(runes)-1 { - res = append(res, string(runes[startIndex:])) + hasNamedParameter = true + namedParams = append(namedParams, string(runes[startIndex:])) + parsedSQL.WriteRune(runes[index]) break } + parsedSQL.WriteRune(runes[index]) index++ } + } else if c == positionalParamChar { + if hasNamedParameter { + return sql, nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement must not contain both named and positional parameter: %s", sql)) + } + hasPositionalParameter = true + parsedSQL.WriteString("@p" + strconv.Itoa(positionalParameterIndex)) + namedParams = append(namedParams, "p"+strconv.Itoa(positionalParameterIndex)) + positionalParameterIndex++ } else { if c == singleQuote || c == doubleQuote || c == backtick { isInQuoted = true @@ -250,17 +277,27 @@ func findParams(sql string) ([]string, error) { // Check whether it is a triple-quote. if len(runes) > index+2 && runes[index+1] == startQuote && runes[index+2] == startQuote { isTripleQuoted = true + parsedSQL.WriteRune(c) + parsedSQL.WriteRune(c) index += 2 } } + parsedSQL.WriteRune(c) } } index++ } if isInQuoted { - return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", sql)) + return sql, nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", sql)) + } + if hasNamedParameter { + return sql, namedParams, nil + } + sql = strings.TrimSpace(parsedSQL.String()) + if len(sql) > 0 && sql[len(sql)-1] == ';' { + sql = sql } - return res, nil + return sql, namedParams, nil } // isDDL returns true if the given sql string is a DDL statement. diff --git a/statement_parser_test.go b/statement_parser_test.go index 7d429afe..cfc213d4 100644 --- a/statement_parser_test.go +++ b/statement_parser_test.go @@ -20,6 +20,7 @@ import ( "cloud.google.com/go/spanner" "github.com/google/go-cmp/cmp" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestRemoveCommentsAndTrim(t *testing.T) { @@ -465,61 +466,196 @@ SELECT SchoolID FROM Roster`, func TestFindParams(t *testing.T) { tests := []struct { - input string - want []string - wantErr bool + input string + wantSQL string + want []string + wantErr error + skipRemoveComments bool }{ { - input: `SELECT * FROM PersonsTable WHERE id=@id`, - want: []string{"id"}, + input: `SELECT * FROM PersonsTable WHERE id=@id`, + wantSQL: `SELECT * FROM PersonsTable WHERE id=@id`, + want: []string{"id"}, }, { - input: `SELECT * FROM PersonsTable WHERE id=@id AND name=@name`, - want: []string{"id", "name"}, + input: `SELECT * FROM PersonsTable WHERE id=@id AND name=@name`, + wantSQL: `SELECT * FROM PersonsTable WHERE id=@id AND name=@name`, + want: []string{"id", "name"}, }, { - input: `SELECT * FROM PersonsTable WHERE Name like @name AND Email='test@test.com'`, - want: []string{"name"}, + input: `SELECT * FROM PersonsTable WHERE Name like @name AND Email='test@test.com'`, + wantSQL: `SELECT * FROM PersonsTable WHERE Name like @name AND Email='test@test.com'`, + want: []string{"name"}, }, { input: `SELECT * FROM """strange - @table -""" WHERE Name like @name AND Email='test@test.com'`, + @table + """ WHERE Name like @name AND Email='test@test.com'`, + wantSQL: `SELECT * FROM """strange + @table + """ WHERE Name like @name AND Email='test@test.com'`, want: []string{"name"}, }, { - input: `@{JOIN_METHOD=HASH_JOIN} SELECT * FROM PersonsTable WHERE Name like @name AND Email='test@test.com'`, - want: []string{"name"}, + input: `@{JOIN_METHOD=HASH_JOIN} SELECT * FROM PersonsTable WHERE Name like @name AND Email='test@test.com'`, + wantSQL: `@{JOIN_METHOD=HASH_JOIN} SELECT * FROM PersonsTable WHERE Name like @name AND Email='test@test.com'`, + want: []string{"name"}, + }, + { + input: "INSERT INTO Foo (Col1, Col2, Col3) VALUES (@param1, @param2, @param3)", + wantSQL: "INSERT INTO Foo (Col1, Col2, Col3) VALUES (@param1, @param2, @param3)", + want: []string{"param1", "param2", "param3"}, + }, + { + input: "SELECT * FROM PersonsTable@{FORCE_INDEX=`my_index`} WHERE id=@id AND name=@name", + wantSQL: "SELECT * FROM PersonsTable@{FORCE_INDEX=`my_index`} WHERE id=@id AND name=@name", + want: []string{"id", "name"}, + }, + { + input: "SELECT * FROM PersonsTable @{FORCE_INDEX=my_index} WHERE id=@id AND name=@name", + wantSQL: "SELECT * FROM PersonsTable @{FORCE_INDEX=my_index} WHERE id=@id AND name=@name", + want: []string{"id", "name"}, + }, + { + input: `SELECT * FROM PersonsTable WHERE id=?`, + wantSQL: `SELECT * FROM PersonsTable WHERE id=@p1`, + want: []string{"p1"}, + }, + { + input: `?'?test?"?test?"?'?`, + wantSQL: `@p1'?test?"?test?"?'@p2`, + want: []string{"p1", "p2"}, + }, + { + input: `?'?it\'?s'?`, + wantSQL: `@p1'?it\'?s'@p2`, + want: []string{"p1", "p2"}, + }, + { + input: `?'?it\"?s'?`, + wantSQL: `@p1'?it\"?s'@p2`, + want: []string{"p1", "p2"}, + }, + { + input: `?"?it\"?s"?`, + wantSQL: `@p1"?it\"?s"@p2`, + want: []string{"p1", "p2"}, + }, + { + input: `?'''?it\'?s'''?`, + wantSQL: `@p1'''?it\'?s'''@p2`, + want: []string{"p1", "p2"}, + }, + { + input: `?"""?it\"?s"""?`, + wantSQL: `@p1"""?it\"?s"""@p2`, + want: []string{"p1", "p2"}, + }, + { + input: `?` + "`?it" + `\` + "`?s`" + `?`, + wantSQL: `@p1` + "`?it" + `\` + "`?s`" + `@p2`, + want: []string{"p1", "p2"}, + }, + { + input: `?'''?it\'?s + ?it\'?s'''?`, + wantSQL: `@p1'''?it\'?s + ?it\'?s'''@p2`, + want: []string{"p1", "p2"}, }, { - input: "INSERT INTO Foo (Col1, Col2, Col3) VALUES (@param1, @param2, @param3)", - want: []string{"param1", "param2", "param3"}, + input: `?'''?it\'?s + ?it\'?s'''?`, + wantSQL: `@p1'''?it\'?s + ?it\'?s'''@p2`, + want: []string{"p1", "p2"}, }, { - input: "SELECT * FROM PersonsTable@{FORCE_INDEX=`my_index`} WHERE id=@id AND name=@name", - want: []string{"id", "name"}, + input: `select 1, ?, 'test?test', "test?test", foo.* from` + "`foo`" + `where col1=? and col2='test' and col3=? and col4='?' and col5="?" and col6='?''?''?'`, + wantSQL: `select 1, @p1, 'test?test', "test?test", foo.* from` + "`foo`" + `where col1=@p2 and col2='test' and col3=@p3 and col4='?' and col5="?" and col6='?''?''?'`, + want: []string{"p1", "p2", "p3"}, }, { - input: "SELECT * FROM PersonsTable @{FORCE_INDEX=my_index} WHERE id=@id AND name=@name", - want: []string{"id", "name"}, + input: `select * from foo where name=? and col2 like ? and col3 > ?`, + wantSQL: `select * from foo where name=@p1 and col2 like @p2 and col3 > @p3`, + want: []string{"p1", "p2", "p3"}, + }, + { + input: `select * from foo where id between ? and ?`, + wantSQL: `select * from foo where id between @p1 and @p2`, + want: []string{"p1", "p2"}, + }, + { + input: `select * from foo limit ? offset ?`, + wantSQL: `select * from foo limit @p1 offset @p2`, + want: []string{"p1", "p2"}, + }, + { + input: `select * from foo where col1=? and col2 like ? and col3 > ? and col4 < ? and col5 != ? and col6 not in (?, ?, ?) and col7 in (?, ?, ?) and col8 between ? and ?`, + wantSQL: `select * from foo where col1=@p1 and col2 like @p2 and col3 > @p3 and col4 < @p4 and col5 != @p5 and col6 not in (@p6, @p7, @p8) and col7 in (@p9, @p10, @p11) and col8 between @p12 and @p13`, + want: []string{"p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13"}, + }, + { + input: `select * from foo where ?='''strange @table'''`, + wantSQL: `select * from foo where @p1='''strange @table'''`, + want: []string{"p1"}, + }, + { + input: `select foo from bar where id=@ order by value`, + wantSQL: `select foo from bar where id=@ order by value`, + want: []string{}, + }, + { + input: `?'?it\'?s + ?it\'?s'?`, + wantErr: spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", `?'?it\'?s + ?it\'?s'?`)), + skipRemoveComments: true, + }, + { + input: `?'?it\'?s + ?it\'?s?`, + wantErr: spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", `?'?it\'?s + ?it\'?s?`)), + skipRemoveComments: true, + }, + { + input: `?'''?it\'?s + ?it\'?s'?`, + wantErr: spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", `?'''?it\'?s + ?it\'?s'?`)), + skipRemoveComments: true, }, } for _, tc := range tests { - sql, err := removeCommentsAndTrim(tc.input) - if err != nil { - t.Fatal(err) + sql := tc.input + if !tc.skipRemoveComments { + var err error + sql, err = removeCommentsAndTrim(tc.input) + if err != nil && tc.wantErr == nil { + t.Fatal(err) + } } - got, err := parseNamedParameters(removeStatementHint(sql)) - if err != nil && !tc.wantErr { + gotSQL, got, err := parseParameters(sql) + if err != nil && tc.wantErr == nil { t.Error(err) continue } - if tc.wantErr { - t.Errorf("missing expected error for %q", tc.input) + if tc.wantErr != nil { + if err == nil { + t.Errorf("missing expected error for %q", tc.input) + continue + } + if !cmp.Equal(err.Error(), tc.wantErr.Error()) { + t.Errorf("parseParameters error mismatch\nGot: %s\nWant: %s", err.Error(), tc.wantErr) + } continue } if !cmp.Equal(got, tc.want) { - t.Errorf("parseNamedParameters result mismatch\nGot: %s\nWant: %s", got, tc.want) + t.Errorf("parseParameters result mismatch\nGot: %s\nWant: %s", got, tc.want) + } + if !cmp.Equal(gotSQL, tc.wantSQL) { + t.Errorf("parseParameters sql mismatch\nGot: %s\nWant: %s", gotSQL, tc.wantSQL) } } } @@ -776,11 +912,11 @@ func TestRemoveCommentsAndTrim_Errors(t *testing.T) { } func TestFindParams_Errors(t *testing.T) { - _, err := findParams("SELECT 'Hello World FROM SomeTable WHERE id=@id") + _, _, err := findParams('?', "SELECT 'Hello World FROM SomeTable WHERE id=@id") if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { t.Errorf("error code mismatch\nGot: %v\nWant: %v\n", g, w) } - _, err = findParams("SELECT 'Hello World\nFROM SomeTable WHERE id=@id") + _, _, err = findParams('?', "SELECT 'Hello World\nFROM SomeTable WHERE id=@id") if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { t.Errorf("error code mismatch\nGot: %v\nWant: %v\n", g, w) } diff --git a/stmt.go b/stmt.go index 2f37be1a..a12ddddb 100644 --- a/stmt.go +++ b/stmt.go @@ -65,7 +65,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv } func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement, error) { - names, err := parseNamedParameters(q) + q, names, err := parseParameters(q) if err != nil { return spanner.Statement{}, err }