diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/commands/ListCommand.java b/src/main/java/com/google/cloud/spanner/pgadapter/commands/ListCommand.java index 0fe2874ce..78e4da533 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/commands/ListCommand.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/commands/ListCommand.java @@ -31,7 +31,7 @@ public class ListCommand extends Command { + " pg_catalog\\.pg_encoding_to_char\\(d\\.encoding\\) as \"Encoding\",\n" + " pg_catalog\\.array_to_string\\(d\\.datacl, '\\\\n'\\) AS \"Access privileges\"\n" + "FROM pg_catalog\\.pg_database d\n.*\n?" - + "ORDER BY 1;$"); + + "ORDER BY 1;?$"); private static final String OUTPUT_QUERY = "SELECT '%s' AS Name;"; diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java index 1ec41d7f8..f73f99fd6 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java @@ -22,6 +22,7 @@ import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.AutocommitDmlMode; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult.ResultType; @@ -30,7 +31,6 @@ import com.google.cloud.spanner.pgadapter.parsers.copy.TokenMgrError; import com.google.cloud.spanner.pgadapter.utils.MutationWriter; import com.google.cloud.spanner.pgadapter.utils.MutationWriter.CopyTransactionMode; -import com.google.cloud.spanner.pgadapter.utils.StatementParser; import com.google.common.base.Strings; import com.google.spanner.v1.TypeCode; import java.util.LinkedHashMap; @@ -61,11 +61,9 @@ public class CopyStatement extends IntermediateStatement { private Future updateCount; private final ExecutorService executor = Executors.newSingleThreadExecutor(); - public CopyStatement(OptionsMetadata options, String sql, Connection connection) { - super(options, sql); - this.sql = sql; - this.command = StatementParser.parseCommand(sql); - this.connection = connection; + public CopyStatement( + OptionsMetadata options, ParsedStatement parsedStatement, Connection connection) { + super(options, parsedStatement, connection); } @Override @@ -332,7 +330,7 @@ private CopyTransactionMode getTransactionMode() { private void parseCopyStatement() throws Exception { try { - parse(sql, this.options); + parse(parsedStatement.getSqlWithoutComments(), this.options); } catch (Exception | TokenMgrError e) { throw SpannerExceptionFactory.newSpannerException( ErrorCode.INVALID_ARGUMENT, "Invalid COPY statement syntax: " + e); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePortalStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePortalStatement.java index cb842f3c2..bce37c02a 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePortalStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePortalStatement.java @@ -18,6 +18,7 @@ import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.pgadapter.metadata.DescribeMetadata; import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata; @@ -38,9 +39,10 @@ public class IntermediatePortalStatement extends IntermediatePreparedStatement { protected List parameterFormatCodes; protected List resultFormatCodes; - public IntermediatePortalStatement(OptionsMetadata options, String sql, Connection connection) { - super(options, sql, connection); - this.statement = Statement.of(sql); + public IntermediatePortalStatement( + OptionsMetadata options, ParsedStatement parsedStatement, Connection connection) { + super(options, parsedStatement, connection); + this.statement = Statement.of(parsedStatement.getSqlWithoutComments()); this.parameterFormatCodes = new ArrayList<>(); this.resultFormatCodes = new ArrayList<>(); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java index d2fb47bc7..63809e8c5 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java @@ -18,6 +18,7 @@ import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.metadata.DescribeMetadata; @@ -25,9 +26,6 @@ import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.parsers.Parser; import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode; -import com.google.cloud.spanner.pgadapter.utils.StatementParser; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Set; import org.postgresql.core.Oid; @@ -37,15 +35,12 @@ */ public class IntermediatePreparedStatement extends IntermediateStatement { - private static final Charset UTF8 = StandardCharsets.UTF_8; protected List parameterDataTypes; protected Statement statement; - public IntermediatePreparedStatement(OptionsMetadata options, String sql, Connection connection) { - super(options, sql); - this.sql = replaceKnownUnsupportedQueries(sql); - this.command = StatementParser.parseCommand(this.sql); - this.connection = connection; + public IntermediatePreparedStatement( + OptionsMetadata options, ParsedStatement parsedStatement, Connection connection) { + super(options, parsedStatement, connection); this.parameterDataTypes = null; } @@ -96,10 +91,10 @@ public void execute() { public IntermediatePortalStatement bind( byte[][] parameters, List parameterFormatCodes, List resultFormatCodes) { IntermediatePortalStatement portal = - new IntermediatePortalStatement(this.options, this.sql, this.connection); + new IntermediatePortalStatement(this.options, this.parsedStatement, this.connection); portal.setParameterFormatCodes(parameterFormatCodes); portal.setResultFormatCodes(resultFormatCodes); - Statement.Builder builder = Statement.newBuilder(sql); + Statement.Builder builder = Statement.newBuilder(this.parsedStatement.getSqlWithoutComments()); for (int index = 0; index < parameters.length; index++) { short formatCode = portal.getParameterFormatCode(index); int type = this.parseType(parameters, index); @@ -114,8 +109,8 @@ public IntermediatePortalStatement bind( @Override public DescribeMetadata describe() { - if (PARSER.isQuery(this.sql)) { - Statement statement = Statement.of(this.sql); + if (this.parsedStatement.isQuery()) { + Statement statement = Statement.of(this.parsedStatement.getSqlWithoutComments()); try (ResultSet resultSet = connection.analyzeQuery(statement, QueryAnalyzeMode.PLAN)) { // TODO: Remove ResultSet.next() call once this is supported in the client library. // See https://github.com/googleapis/java-spanner/pull/1691 @@ -132,7 +127,8 @@ public DescribeMetadata describe() { * parameter types. */ private int[] getParameterTypes() { - Set parameters = PARSER.getQueryParameters(this.sql); + Set parameters = + PARSER.getQueryParameters(this.parsedStatement.getSqlWithoutComments()); return new int[parameters.size()]; } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java index ed0932d6c..07bbd3a5e 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java @@ -21,15 +21,15 @@ import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.connection.AbstractStatementParser; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.PostgreSQLStatementParser; import com.google.cloud.spanner.connection.StatementResult; -import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.pgadapter.metadata.DescribeMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.utils.StatementParser; import com.google.common.base.Preconditions; -import java.util.ArrayList; +import com.google.common.collect.ImmutableList; import java.util.List; /** @@ -46,54 +46,66 @@ public class IntermediateStatement { protected ResultSet statementResult; protected boolean hasMoreData; protected SpannerException exception; - protected String sql; - protected String command; + protected final ParsedStatement parsedStatement; + protected final String command; protected boolean executed; - protected Connection connection; + protected final Connection connection; protected Long updateCount; - protected List statements; + protected final ImmutableList statements; private static final char STATEMENT_DELIMITER = ';'; private static final char SINGLE_QUOTE = '\''; - public IntermediateStatement(OptionsMetadata options, String sql, Connection connection) { + public IntermediateStatement( + OptionsMetadata options, ParsedStatement parsedStatement, Connection connection) { + this( + options, + parsedStatement, + connection, + parseStatements(parsedStatement.getSqlWithoutComments())); + } + + protected IntermediateStatement( + OptionsMetadata options, + ParsedStatement parsedStatement, + Connection connection, + ImmutableList statements) { this.options = options; - this.sql = replaceKnownUnsupportedQueries(sql); - this.statements = parseStatements(sql); - this.command = StatementParser.parseCommand(this.sql); + this.parsedStatement = replaceKnownUnsupportedQueries(parsedStatement); + this.statements = statements; + this.command = StatementParser.parseCommand(this.parsedStatement.getSqlWithoutComments()); this.connection = connection; // Note: This determines the result type based on the first statement in the SQL statement. That // means that it assumes that if this is a batch of statements, all the statements in the batch // will have the same type of result (that is; they are all DML statements, all DDL statements, // all queries, etc.). That is a safe assumption for now, as PgAdapter currently only supports // all-DML and all-DDL batches. - this.resultType = determineResultType(this.sql); - } - - protected IntermediateStatement(OptionsMetadata options, String sql) { - this.options = options; - this.resultType = determineResultType(sql); + this.resultType = determineResultType(this.parsedStatement); } - protected String replaceKnownUnsupportedQueries(String sql) { + protected ParsedStatement replaceKnownUnsupportedQueries(ParsedStatement parsedStatement) { if (this.options.isReplaceJdbcMetadataQueries() - && JdbcMetadataStatementHelper.isPotentialJdbcMetadataStatement(sql)) { - return JdbcMetadataStatementHelper.replaceJdbcMetadataStatement(sql); + && JdbcMetadataStatementHelper.isPotentialJdbcMetadataStatement( + parsedStatement.getSqlWithoutComments())) { + return PARSER.parse( + Statement.of( + JdbcMetadataStatementHelper.replaceJdbcMetadataStatement( + parsedStatement.getSqlWithoutComments()))); } - return sql; + return parsedStatement; } /** * Determines the result type based on the given sql string. The sql string must already been * stripped of any comments that might precede the actual sql string. * - * @param sql The sql string to determine the type of result for + * @param parsedStatement The parsed statement to determine the type of result for * @return The {@link ResultType} that the given sql string will produce */ - protected static ResultType determineResultType(String sql) { - if (PARSER.isUpdateStatement(sql)) { + protected static ResultType determineResultType(ParsedStatement parsedStatement) { + if (parsedStatement.isUpdate()) { return ResultType.UPDATE_COUNT; - } else if (PARSER.isQuery(sql)) { + } else if (parsedStatement.isQuery()) { return ResultType.RESULT_SET; } else { return ResultType.NO_RESULT; @@ -101,34 +113,43 @@ protected static ResultType determineResultType(String sql) { } // Split statements by ';' delimiter, but ignore anything that is nested with '' or "". - private List splitStatements(String sql) { - List statements = new ArrayList<>(); - boolean quoteEsacpe = false; + private static ImmutableList splitStatements(String sql) { + // First check trivial cases with only one statement. + int firstIndexOfDelimiter = sql.indexOf(STATEMENT_DELIMITER); + if (firstIndexOfDelimiter == -1) { + return ImmutableList.of(sql); + } + if (firstIndexOfDelimiter == sql.length() - 1) { + return ImmutableList.of(sql.substring(0, sql.length() - 1)); + } + + ImmutableList.Builder builder = ImmutableList.builder(); + // TODO: Fix this parsing, as it does not take all types of quotes into consideration. + boolean quoteEscape = false; int index = 0; for (int i = 0; i < sql.length(); ++i) { if (sql.charAt(i) == SINGLE_QUOTE) { - quoteEsacpe = !quoteEsacpe; + quoteEscape = !quoteEscape; } - if (sql.charAt(i) == STATEMENT_DELIMITER && !quoteEsacpe) { - String stmt = sql.substring(index, i + 1).trim(); + if (sql.charAt(i) == STATEMENT_DELIMITER && !quoteEscape) { + String stmt = sql.substring(index, i).trim(); // Statements with only ';' character are empty and dropped. - if (stmt.length() > 1) { - statements.add(stmt); + if (stmt.length() > 0) { + builder.add(stmt); } index = i + 1; } } if (index < sql.length()) { - statements.add(sql.substring(index, sql.length()).trim()); + builder.add(sql.substring(index).trim()); } - return statements; + return builder.build(); } - protected List parseStatements(String sql) { + protected static ImmutableList parseStatements(String sql) { Preconditions.checkNotNull(sql); - List statements = splitStatements(sql); - return statements; + return splitStatements(sql); } /** @@ -205,7 +226,7 @@ public ResultType getResultType() { } public String getSql() { - return this.sql; + return this.parsedStatement.getSqlWithoutComments(); } public Exception getException() { @@ -282,14 +303,17 @@ public void execute() { long[] updateCounts = connection.runBatch(); updateBatchResultCount(updateCounts); } else { - StatementResult result = connection.execute(Statement.of(this.sql)); + StatementResult result = + connection.execute(Statement.of(this.parsedStatement.getSqlWithoutComments())); updateResultCount(result); } } catch (SpannerException e) { if (statements.size() > 1) { SpannerException exception = SpannerExceptionFactory.newSpannerException( - e.getErrorCode(), e.getMessage() + " \"" + this.sql + "\"", e); + e.getErrorCode(), + e.getMessage() + " \"" + this.parsedStatement.getSqlWithoutComments() + "\"", + e); handleExecutionException(exception); } else { handleExecutionException(e); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/MatcherStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/MatcherStatement.java index 808462720..756000194 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/MatcherStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/MatcherStatement.java @@ -14,10 +14,12 @@ package com.google.cloud.spanner.pgadapter.statements; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; +import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.commands.Command; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; -import com.google.cloud.spanner.pgadapter.utils.StatementParser; import org.json.simple.JSONObject; /** @@ -28,16 +30,17 @@ */ public class MatcherStatement extends IntermediateStatement { - private JSONObject commandMetadataJSON; - public MatcherStatement( - OptionsMetadata options, String sql, ConnectionHandler connectionHandler) { - super(options, sql); - this.connection = connectionHandler.getSpannerConnection(); - this.commandMetadataJSON = connectionHandler.getServer().getOptions().getCommandMetadataJSON(); - this.sql = translateSQL(sql); - this.statements = parseStatements(sql); - this.command = StatementParser.parseCommand(sql); + OptionsMetadata options, + ParsedStatement parsedStatement, + ConnectionHandler connectionHandler) { + super( + options, + translateSQL( + parsedStatement, + connectionHandler.getSpannerConnection(), + connectionHandler.getServer().getOptions().getCommandMetadataJSON()), + connectionHandler.getSpannerConnection()); } @Override @@ -49,17 +52,19 @@ public void execute() { * Translate a Postgres Specific command into something Spanner can handle. Currently, this is * only concerned with PSQL specific meta-commands. * - * @param sql The SQL statement to be translated. + * @param parsedStatement The SQL statement to be translated. * @return The translated SQL statement if it matches any {@link Command} statement. Otherwise * gives out the original Statement. */ - private String translateSQL(String sql) { + private static ParsedStatement translateSQL( + ParsedStatement parsedStatement, Connection connection, JSONObject commandMetadataJSON) { for (Command currentCommand : - Command.getCommands(sql, this.connection, this.commandMetadataJSON)) { + Command.getCommands( + parsedStatement.getSqlWithoutComments(), connection, commandMetadataJSON)) { if (currentCommand.is()) { - return currentCommand.translate(); + return PARSER.parse(Statement.of(currentCommand.translate())); } } - return sql; + return parsedStatement; } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/StatementParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/StatementParser.java index 2ad32f46d..f42c9731b 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/StatementParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/StatementParser.java @@ -43,10 +43,25 @@ public static String singleQuoteEscape(String sql) { /** Determines the (update) command that was received from the sql string. */ public static String parseCommand(String sql) { Preconditions.checkNotNull(sql); - String[] tokens = sql.split("\\s+", 2); - if (tokens.length > 0) { - return tokens[0].toUpperCase(); + for (int i = 0; i < sql.length(); i++) { + if (Character.isSpaceChar(sql.charAt(i))) { + return sql.substring(0, i).toUpperCase(); + } + } + return sql; + } + + /** Returns true if the given sql string is the given command. */ + public static boolean isCommand(String command, String query) { + Preconditions.checkNotNull(command); + Preconditions.checkNotNull(query); + if (query.equalsIgnoreCase(command)) { + return true; + } + if (query.length() <= command.length()) { + return false; } - return null; + return Character.isSpaceChar(query.charAt(command.length())) + && query.substring(0, command.length()).equalsIgnoreCase(command); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java index db9440770..901460725 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java @@ -15,14 +15,15 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.connection.AbstractStatementParser; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement; import com.google.cloud.spanner.pgadapter.wireoutput.ParseCompleteResponse; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import java.text.MessageFormat; -import java.util.ArrayList; -import java.util.List; /** Creates a prepared statement. */ public class ParseMessage extends ControlMessage { @@ -30,23 +31,26 @@ public class ParseMessage extends ControlMessage { AbstractStatementParser.getInstance(Dialect.POSTGRESQL); protected static final char IDENTIFIER = 'P'; - private String name; - private IntermediatePreparedStatement statement; - private List parameterDataTypes; + private final String name; + private final IntermediatePreparedStatement statement; + private final ImmutableList parameterDataTypes; public ParseMessage(ConnectionHandler connection) throws Exception { super(connection); this.name = this.readString(); - String queryString = PARSER.removeCommentsAndTrim(this.readString()); - this.parameterDataTypes = new ArrayList<>(); + ParsedStatement parsedStatement = PARSER.parse(Statement.of(this.readString())); + ImmutableList.Builder builder = ImmutableList.builder(); short numberOfParameters = this.inputStream.readShort(); for (int i = 0; i < numberOfParameters; i++) { int type = this.inputStream.readInt(); - this.parameterDataTypes.add(type); + builder.add(type); } + this.parameterDataTypes = builder.build(); this.statement = new IntermediatePreparedStatement( - connection.getServer().getOptions(), queryString, connection.getSpannerConnection()); + connection.getServer().getOptions(), + parsedStatement, + connection.getSpannerConnection()); this.statement.setParameterDataTypes(this.parameterDataTypes); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/QueryMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/QueryMessage.java index 589a5ad1a..7dca2ac83 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/QueryMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/QueryMessage.java @@ -15,7 +15,9 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.connection.AbstractStatementParser; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; @@ -34,25 +36,31 @@ public class QueryMessage extends ControlMessage { private static final AbstractStatementParser PARSER = AbstractStatementParser.getInstance(Dialect.POSTGRESQL); protected static final char IDENTIFIER = 'Q'; - protected static final String COPY = "COPY"; + public static final String COPY = "COPY"; - private IntermediateStatement statement; + private final boolean isCopy; + private final IntermediateStatement statement; public QueryMessage(ConnectionHandler connection) throws Exception { super(connection); - String query = PARSER.removeCommentsAndTrim(this.readAll()); - String command = StatementParser.parseCommand(query); - if (COPY.equalsIgnoreCase(command)) { + ParsedStatement parsedStatement = PARSER.parse(Statement.of(this.readAll())); + this.isCopy = StatementParser.isCommand(COPY, parsedStatement.getSqlWithoutComments()); + if (isCopy) { this.statement = new CopyStatement( - connection.getServer().getOptions(), query, this.connection.getSpannerConnection()); + connection.getServer().getOptions(), + parsedStatement, + this.connection.getSpannerConnection()); } else if (!connection.getServer().getOptions().requiresMatcher()) { this.statement = new IntermediateStatement( - connection.getServer().getOptions(), query, this.connection.getSpannerConnection()); + connection.getServer().getOptions(), + parsedStatement, + this.connection.getSpannerConnection()); } else { this.statement = - new MatcherStatement(connection.getServer().getOptions(), query, this.connection); + new MatcherStatement( + connection.getServer().getOptions(), parsedStatement, this.connection); } this.connection.addActiveStatement(this.statement); } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/PSQLTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/PSQLTest.java index 5652ae41c..902ef17a6 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/PSQLTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/PSQLTest.java @@ -14,6 +14,10 @@ package com.google.cloud.spanner.pgadapter; +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractStatementParser; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.pgadapter.metadata.CommandMetadataParser; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; @@ -33,6 +37,12 @@ @RunWith(JUnit4.class) public class PSQLTest { + private static final AbstractStatementParser PARSER = + AbstractStatementParser.getInstance(Dialect.POSTGRESQL); + + private static ParsedStatement parse(String sql) { + return PARSER.parse(Statement.of(sql)); + } @Rule public MockitoRule rule = MockitoJUnit.rule(); @@ -81,9 +91,10 @@ public void testDescribeTranslates() { + "FROM" + " information_schema.tables AS t " + "WHERE" - + " t.table_schema = 'public';"; + + " t.table_schema = 'public'"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -110,9 +121,10 @@ public void testDescribeTableMatchTranslates() { + " WHERE" + " t.table_schema='public'" + " AND" - + " LOWER(t.table_name) = LOWER('users');"; + + " LOWER(t.table_name) = LOWER('users')"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -139,9 +151,10 @@ public void testDescribeTableMatchHandlesBobbyTables() { + " WHERE" + " t.table_schema='public'" + " AND" - + " LOWER(t.table_name) = LOWER('bobby\\'; DROP TABLE USERS; SELECT\\'');"; + + " LOWER(t.table_name) = LOWER('bobby\\'; DROP TABLE USERS; SELECT\\'')"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -164,9 +177,10 @@ public void testDescribeTableCatalogTranslates() { + " false as bool2," + " false as relhasoids," + " '' as str1," - + " '' as str2;"; + + " '' as str2"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -200,9 +214,10 @@ public void testDescribeTableMetadataTranslates() { + " information_schema.columns AS t" + " WHERE" + " t.table_schema='public'" - + " AND t.table_name = '-1';"; + + " AND t.table_name = '-1'"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -237,9 +252,10 @@ public void testDescribeTableMetadataHandlesBobbyTables() { + " information_schema.columns AS t" + " WHERE" + " t.table_schema='public'" - + " AND t.table_name = 'bobby\\'; DROP TABLE USERS; SELECT\\'';"; + + " AND t.table_name = 'bobby\\'; DROP TABLE USERS; SELECT\\''"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -251,9 +267,10 @@ public void testDescribeTableAttributesTranslates() { "SELECT c.oid::pg_catalog.regclass FROM pg_catalog.pg_class c, pg_catalog.pg_inherits i" + " WHERE c.oid=i.inhparent AND i.inhrelid = '-2264987671676060158' AND c.relkind !=" + " 'p' ORDER BY inhseqno;"; - String result = "SELECT 1 LIMIT 0;"; + String result = "SELECT 1 LIMIT 0"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -265,9 +282,10 @@ public void testDescribeMoreTableAttributesTranslates() { "SELECT c.oid::pg_catalog.regclass FROM pg_catalog.pg_class c, pg_catalog.pg_inherits i" + " WHERE c.oid=i.inhrelid AND i.inhparent = '-2264987671676060158' ORDER BY" + " c.relname;"; - String result = "SELECT 1 LIMIT 0;"; + String result = "SELECT 1 LIMIT 0"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -282,9 +300,10 @@ public void testListCommandTranslates() { + " pg_catalog.array_to_string(d.datacl, '\\n') AS \"Access privileges\"\n" + "FROM pg_catalog.pg_database d\n" + "ORDER BY 1;"; - String result = "SELECT '' AS Name;"; + String result = "SELECT '' AS Name"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -300,9 +319,10 @@ public void testListDatabaseCommandTranslates() { + "FROM pg_catalog.pg_database d\n" + "WHERE d.datname OPERATOR(pg_catalog.~) '^(users)$'\n" + "ORDER BY 1;"; - String result = "SELECT '' AS Name;"; + String result = "SELECT '' AS Name"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -318,11 +338,12 @@ public void testListDatabaseCommandFailsButPrintsUnknown() { + "FROM pg_catalog.pg_database d\n" + "WHERE d.datname OPERATOR(pg_catalog.~) '^(users)$'\n" + "ORDER BY 1;"; - String result = "SELECT '' AS Name;"; + String result = "SELECT '' AS Name"; // TODO: Add Connection#getDatabase() to Connection API and test here what happens if that // method throws an exception. - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -346,9 +367,10 @@ public void testDescribeAllTableMetadataTranslates() { + " AND n.nspname !~ '^pg_toast'\n" + " AND pg_catalog.pg_table_is_visible(c.oid)\n" + "ORDER BY 1,2;"; - String result = "SELECT * FROM information_schema.tables;"; + String result = "SELECT * FROM information_schema.tables"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -372,9 +394,10 @@ public void testDescribeSelectedTableMetadataTranslates() { + " AND pg_catalog.pg_table_is_visible(c.oid)\n" + "ORDER BY 1,2;"; String result = - "SELECT * FROM information_schema.tables WHERE LOWER(table_name) = LOWER('users');"; + "SELECT * FROM information_schema.tables WHERE LOWER(table_name) = LOWER('users')"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -399,9 +422,10 @@ public void testDescribeSelectedTableMetadataHandlesBobbyTables() { + "ORDER BY 1,2;"; String result = "SELECT * FROM information_schema.tables WHERE LOWER(table_name) =" - + " LOWER('bobby\\'; DROP TABLE USERS; SELECT\\'');"; + + " LOWER('bobby\\'; DROP TABLE USERS; SELECT\\'')"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -429,9 +453,10 @@ public void testDescribeAllIndexMetadataTranslates() { + " AND pg_catalog.pg_table_is_visible(c.oid)\n" + "ORDER BY 1,2;"; String result = - "SELECT table_catalog, table_schema, table_name, index_name, index_type, parent_table_name, is_unique, is_null_filtered, index_state, spanner_is_managed FROM information_schema.indexes;"; + "SELECT table_catalog, table_schema, table_name, index_name, index_type, parent_table_name, is_unique, is_null_filtered, index_state, spanner_is_managed FROM information_schema.indexes"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -459,9 +484,10 @@ public void testDescribeSelectedIndexMetadataTranslates() { + "ORDER BY 1,2;"; String result = "SELECT table_catalog, table_schema, table_name, index_name, index_type, parent_table_name, is_unique, is_null_filtered, index_state, spanner_is_managed FROM information_schema.indexes WHERE LOWER(index_name) =" - + " LOWER('index');"; + + " LOWER('index')"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -489,9 +515,10 @@ public void testDescribeSelectedIndexMetadataHandlesBobbyTables() { + "ORDER BY 1,2;"; String result = "SELECT table_catalog, table_schema, table_name, index_name, index_type, parent_table_name, is_unique, is_null_filtered, index_state, spanner_is_managed FROM information_schema.indexes WHERE LOWER(index_name) =" - + " LOWER('bobby\\'; DROP TABLE USERS; SELECT\\'');"; + + " LOWER('bobby\\'; DROP TABLE USERS; SELECT\\'')"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -505,9 +532,10 @@ public void testDescribeAllSchemaMetadataTranslates() { + "FROM pg_catalog.pg_namespace n\n" + "WHERE n.nspname !~ '^pg_' AND n.nspname <> 'information_schema'\n" + "ORDER BY 1;"; - String result = "SELECT * FROM information_schema.schemata;"; + String result = "SELECT * FROM information_schema.schemata"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -522,9 +550,10 @@ public void testDescribeSelectedSchemaMetadataTranslates() { + "WHERE n.nspname OPERATOR(pg_catalog.~) '^(schema)$'\n" + "ORDER BY 1;"; String result = - "SELECT * FROM information_schema.schemata WHERE LOWER(schema_name) = LOWER('schema');"; + "SELECT * FROM information_schema.schemata WHERE LOWER(schema_name) = LOWER('schema')"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -540,9 +569,10 @@ public void testDescribeSelectedSchemaMetadataHandlesBobbyTables() { + "ORDER BY 1;"; String result = "SELECT * FROM information_schema.schemata WHERE LOWER(schema_name) = LOWER('bobby\\'; DROP" - + " TABLE USERS; SELECT\\'');"; + + " TABLE USERS; SELECT\\'')"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -577,9 +607,10 @@ public void testTableSelectAutocomplete() { String result = "SELECT table_name AS quote_ident FROM information_schema.tables WHERE" + " table_schema = 'public' and STARTS_WITH(LOWER(table_name)," - + " LOWER('user')) LIMIT 1000;"; + + " LOWER('user')) LIMIT 1000"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -613,9 +644,10 @@ public void testTableInsertAutocomplete() { String result = "SELECT table_name AS quote_ident FROM information_schema.tables WHERE" + " table_schema = 'public' and STARTS_WITH(LOWER(table_name)," - + " LOWER('user')) LIMIT 1000;"; + + " LOWER('user')) LIMIT 1000"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -633,9 +665,10 @@ public void testTableAttributesAutocomplete() { + "LIMIT 1000"; String result = "SELECT column_name AS quote_ident FROM information_schema.columns WHERE" - + " table_name = 'user' AND STARTS_WITH(LOWER(COLUMN_NAME), LOWER('age')) LIMIT 1000;"; + + " table_name = 'user' AND STARTS_WITH(LOWER(COLUMN_NAME), LOWER('age')) LIMIT 1000"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -667,9 +700,10 @@ public void testDescribeTableAutocomplete() { + "LIMIT 1000"; String result = "SELECT table_name AS quote_ident FROM information_schema.tables WHERE " - + "table_schema = 'public' AND STARTS_WITH(LOWER(table_name), LOWER('user')) LIMIT 1000;"; + + "table_schema = 'public' AND STARTS_WITH(LOWER(table_name), LOWER('user')) LIMIT 1000"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -701,9 +735,10 @@ public void testDescribeTableMetadataAutocomplete() { + "LIMIT 1000"; String result = "SELECT table_name AS quote_ident FROM INFORMATION_SCHEMA.TABLES WHERE" - + " STARTS_WITH(LOWER(table_name), LOWER('user')) LIMIT 1000;"; + + " STARTS_WITH(LOWER(table_name), LOWER('user')) LIMIT 1000"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -735,9 +770,10 @@ public void testDescribeIndexMetadataAutocomplete() { + "LIMIT 1000"; String result = "SELECT index_name AS quote_ident FROM INFORMATION_SCHEMA.INDEXES WHERE" - + " STARTS_WITH(LOWER(index_name), LOWER('index')) LIMIT 1000;"; + + " STARTS_WITH(LOWER(index_name), LOWER('index')) LIMIT 1000"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -751,9 +787,10 @@ public void testDescribeSchemaMetadataAutocomplete() { + "LIMIT 1000"; String result = "SELECT schema_name AS quote_ident FROM INFORMATION_SCHEMA.SCHEMATA WHERE" - + " STARTS_WITH(LOWER(schema_name), LOWER('schema')) LIMIT 1000;"; + + " STARTS_WITH(LOWER(schema_name), LOWER('schema')) LIMIT 1000"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), result); } @@ -766,12 +803,12 @@ public void testDynamicCommands() throws Exception { + " \"commands\": " + " [ " + " {" - + " \"input_pattern\": \"^SELECT \\* FROM USERS;$\", " + + " \"input_pattern\": \"^SELECT \\* FROM USERS;?$\", " + " \"output_pattern\": \"RESULT 1\", " + " \"matcher_array\": []" + " }," + " {" - + " \"input_pattern\": \"^SELECT (?.*) FROM USERS WHERE (?.*) = (?.*);$\", " + + " \"input_pattern\": \"^SELECT (?.*) FROM USERS WHERE (?.*) = (?.*);?$\", " + " \"output_pattern\": \"RESULT 2: selector=%s, arg2=%s, arg1=%s\", " + " \"matcher_array\": [ \"selector\", \"arg2\", \"arg1\" ]" + " }" @@ -787,10 +824,11 @@ public void testDynamicCommands() throws Exception { String secondSQL = "SELECT name FROM USERS WHERE age = 30;"; String expectedSecondResult = "RESULT 2: selector=name, arg2=30, arg1=age"; - MatcherStatement matcherStatement = new MatcherStatement(options, firstSQL, connectionHandler); - Assert.assertEquals(matcherStatement.getSql(), expectedFirstResult); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(firstSQL), connectionHandler); + Assert.assertEquals(expectedFirstResult, matcherStatement.getSql()); - matcherStatement = new MatcherStatement(options, secondSQL, connectionHandler); + matcherStatement = new MatcherStatement(options, parse(secondSQL), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), expectedSecondResult); } @@ -802,7 +840,7 @@ public void testMatcherGroupInPlaceReplacements() throws Exception { + " \"commands\": " + " [ " + " {" - + " \"input_pattern\": \"^SELECT (?.*) FROM (?.*);$\", " + + " \"input_pattern\": \"^SELECT (?.*) FROM (?
.*);?$\", " + " \"output_pattern\": \"TABLE: ${table}, EXPRESSION: ${expression}\", " + " \"matcher_array\": []" + " }" @@ -816,7 +854,8 @@ public void testMatcherGroupInPlaceReplacements() throws Exception { String sql = "SELECT * FROM USERS;"; String expectedResult = "TABLE: USERS, EXPRESSION: *"; - MatcherStatement matcherStatement = new MatcherStatement(options, sql, connectionHandler); + MatcherStatement matcherStatement = + new MatcherStatement(options, parse(sql), connectionHandler); Assert.assertEquals(matcherStatement.getSql(), expectedResult); } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java index 7a3e2e51a..bd1a610c5 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java @@ -31,10 +31,13 @@ import static org.mockito.Mockito.when; import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractStatementParser; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.connection.StatementResult.ResultType; @@ -100,6 +103,8 @@ @RunWith(JUnit4.class) public class ProtocolTest { + private static final AbstractStatementParser PARSER = + AbstractStatementParser.getInstance(Dialect.POSTGRESQL); @Rule public MockitoRule rule = MockitoJUnit.rule(); @Mock private ConnectionHandler connectionHandler; @@ -139,6 +144,10 @@ private void readUntilNullTerminator(DataInputStream input) throws Exception { } while (c != '\0'); } + private static ParsedStatement parse(String sql) { + return PARSER.parse(Statement.of(sql)); + } + @AfterClass public static void cleanup() { // TODO: Make error log file configurable and turn off writing to a file during tests. @@ -1243,7 +1252,7 @@ public void testMultipleCopyDataMessages() throws Exception { when(connection.executeQuery(any(Statement.class))).thenReturn(spannerType); CopyStatement copyStatement = - new CopyStatement(options, "COPY keyvalue FROM STDIN;", connection); + new CopyStatement(options, parse("COPY keyvalue FROM STDIN;"), connection); copyStatement.execute(); when(connectionHandler.getActiveStatement()).thenReturn(copyStatement); @@ -1345,7 +1354,8 @@ public void testCopyFromFilePipe() throws Exception { byte[] payload = Files.readAllBytes(Paths.get("./src/test/resources/small-file-test.txt")); CopyStatement copyStatement = - new CopyStatement(mock(OptionsMetadata.class), "COPY keyvalue FROM STDIN;", connection); + new CopyStatement( + mock(OptionsMetadata.class), parse("COPY keyvalue FROM STDIN;"), connection); copyStatement.execute(); MutationWriter mw = copyStatement.getMutationWriter(); @@ -1365,7 +1375,8 @@ public void testCopyBatchSizeLimit() throws Exception { byte[] payload = Files.readAllBytes(Paths.get("./src/test/resources/batch-size-test.txt")); CopyStatement copyStatement = - new CopyStatement(mock(OptionsMetadata.class), "COPY keyvalue FROM STDIN;", connection); + new CopyStatement( + mock(OptionsMetadata.class), parse("COPY keyvalue FROM STDIN;"), connection); assertFalse(copyStatement.isExecuted()); copyStatement.execute(); @@ -1396,7 +1407,8 @@ public void testCopyDataRowLengthMismatchLimit() throws Exception { byte[] payload = "1\t'one'\n2".getBytes(); CopyStatement copyStatement = - new CopyStatement(mock(OptionsMetadata.class), "COPY keyvalue FROM STDIN;", connection); + new CopyStatement( + mock(OptionsMetadata.class), parse("COPY keyvalue FROM STDIN;"), connection); assertFalse(copyStatement.isExecuted()); copyStatement.execute(); @@ -1427,7 +1439,8 @@ public void testCopyResumeErrorOutputFile() throws Exception { byte[] payload = Files.readAllBytes(Paths.get("./src/test/resources/test-copy-output.txt")); CopyStatement copyStatement = - new CopyStatement(mock(OptionsMetadata.class), "COPY keyvalue FROM STDIN;", connection); + new CopyStatement( + mock(OptionsMetadata.class), parse("COPY keyvalue FROM STDIN;"), connection); assertFalse(copyStatement.isExecuted()); copyStatement.execute(); assertTrue(copyStatement.isExecuted()); @@ -1468,7 +1481,7 @@ public void testCopyResumeErrorStartOutputFile() throws Exception { } CopyStatement copyStatement = - new CopyStatement(options, "COPY keyvalue FROM STDIN;", connection); + new CopyStatement(options, parse("COPY keyvalue FROM STDIN;"), connection); assertFalse(copyStatement.isExecuted()); copyStatement.execute(); assertTrue(copyStatement.isExecuted()); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/StatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/StatementTest.java index 53c2c870b..bcd2985b6 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/StatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/StatementTest.java @@ -21,12 +21,15 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractStatementParser; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; @@ -64,6 +67,12 @@ @RunWith(JUnit4.class) public class StatementTest { + private static final AbstractStatementParser PARSER = + AbstractStatementParser.getInstance(Dialect.POSTGRESQL); + + private static ParsedStatement parse(String sql) { + return PARSER.parse(Statement.of(sql)); + } @Rule public MockitoRule rule = MockitoJUnit.rule(); @Mock private Connection connection; @@ -96,7 +105,7 @@ public void testBasicSelectStatement() throws Exception { when(resultSet.next()).thenReturn(true); IntermediateStatement intermediateStatement = - new IntermediateStatement(options, "SELECT * FROM users", connection); + new IntermediateStatement(options, parse("SELECT * FROM users"), connection); assertFalse(intermediateStatement.isExecuted()); assertEquals(intermediateStatement.getCommand(), "SELECT"); @@ -126,7 +135,7 @@ public void testBasicUpdateStatement() throws Exception { IntermediateStatement intermediateStatement = new IntermediateStatement( - options, "UPDATE users SET name = someName WHERE id = 10", connection); + options, parse("UPDATE users SET name = someName WHERE id = 10"), connection); assertFalse(intermediateStatement.isExecuted()); assertEquals(intermediateStatement.getCommand(), "UPDATE"); @@ -158,7 +167,7 @@ public void testBasicZeroUpdateCountResultStatement() throws Exception { IntermediateStatement intermediateStatement = new IntermediateStatement( - options, "UPDATE users SET name = someName WHERE id = -1", connection); + options, parse("UPDATE users SET name = someName WHERE id = -1"), connection); assertFalse(intermediateStatement.isExecuted()); assertEquals(intermediateStatement.getCommand(), "UPDATE"); @@ -190,7 +199,7 @@ public void testBasicNoResultStatement() throws Exception { IntermediateStatement intermediateStatement = new IntermediateStatement( - options, "CREATE TABLE users (name varchar(100) primary key)", connection); + options, parse("CREATE TABLE users (name varchar(100) primary key)"), connection); assertFalse(intermediateStatement.isExecuted()); assertEquals(intermediateStatement.getCommand(), "CREATE"); @@ -216,7 +225,7 @@ public void testBasicNoResultStatement() throws Exception { @Test(expected = IllegalStateException.class) public void testDescribeBasicStatementThrowsException() throws Exception { IntermediateStatement intermediateStatement = - new IntermediateStatement(options, "SELECT * FROM users", connection); + new IntermediateStatement(options, parse("SELECT * FROM users"), connection); intermediateStatement.describe(); } @@ -229,7 +238,7 @@ public void testBasicStatementExceptionGetsSetOnExceptedExecution() throws Excep when(connection.execute(Statement.of("SELECT * FROM users"))).thenThrow(thrownException); IntermediateStatement intermediateStatement = - new IntermediateStatement(options, "SELECT * FROM users", connection); + new IntermediateStatement(options, parse("SELECT * FROM users"), connection); intermediateStatement.execute(); @@ -256,7 +265,7 @@ public void testPreparedStatement() throws Exception { when(connection.execute(statement)).thenReturn(statementResult); IntermediatePreparedStatement intermediateStatement = - new IntermediatePreparedStatement(options, sqlStatement, connection); + new IntermediatePreparedStatement(options, parse(sqlStatement), connection); intermediateStatement.setParameterDataTypes(parameterDataTypes); assertEquals(intermediateStatement.getSql(), sqlStatement); @@ -280,7 +289,7 @@ public void testPreparedStatementIllegalTypeThrowsException() throws Exception { List parameterDataTypes = Arrays.asList(Oid.JSON); IntermediatePreparedStatement intermediateStatement = - new IntermediatePreparedStatement(options, sqlStatement, connection); + new IntermediatePreparedStatement(options, parse(sqlStatement), connection); intermediateStatement.setParameterDataTypes(parameterDataTypes); byte[][] parameters = {"{}".getBytes()}; @@ -295,7 +304,7 @@ public void testPreparedStatementDescribeDoesNotThrowException() throws Exceptio .thenReturn(resultSet); IntermediatePreparedStatement intermediateStatement = - new IntermediatePreparedStatement(options, sqlStatement, connection); + new IntermediatePreparedStatement(options, parse(sqlStatement), connection); intermediateStatement.describe(); } @@ -307,7 +316,7 @@ public void testPortalStatement() throws Exception { .thenReturn(resultSet); IntermediatePortalStatement intermediateStatement = - new IntermediatePortalStatement(options, sqlStatement, connection); + new IntermediatePortalStatement(options, parse(sqlStatement), connection); intermediateStatement.describe(); @@ -345,7 +354,7 @@ public void testPortalStatementDescribePropagatesFailure() throws Exception { String sqlStatement = "SELECT * FROM users WHERE age > $1 AND age < $2 AND name = $3"; IntermediatePortalStatement intermediateStatement = - new IntermediatePortalStatement(options, sqlStatement, connection); + new IntermediatePortalStatement(options, parse(sqlStatement), connection); when(connection.analyzeQuery(Statement.of(sqlStatement), QueryAnalyzeMode.PLAN)) .thenThrow( @@ -359,14 +368,14 @@ public void testBatchStatements() throws Exception { String sql = "INSERT INTO users (id) VALUES (1); INSERT INTO users (id) VALUES (2);INSERT INTO users (id) VALUES (3);"; IntermediateStatement intermediateStatement = - new IntermediateStatement(options, sql, connection); + new IntermediateStatement(options, parse(sql), connection); assertTrue(intermediateStatement.isBatchedQuery()); List result = intermediateStatement.getStatements(); assertEquals(result.size(), 3); - assertEquals(result.get(0), "INSERT INTO users (id) VALUES (1);"); - assertEquals(result.get(1), "INSERT INTO users (id) VALUES (2);"); - assertEquals(result.get(2), "INSERT INTO users (id) VALUES (3);"); + assertEquals(result.get(0), "INSERT INTO users (id) VALUES (1)"); + assertEquals(result.get(1), "INSERT INTO users (id) VALUES (2)"); + assertEquals(result.get(2), "INSERT INTO users (id) VALUES (3)"); } @Test @@ -374,29 +383,29 @@ public void testAdditionalBatchStatements() throws Exception { String sql = "BEGIN TRANSACTION; INSERT INTO users (id) VALUES (1); INSERT INTO users (id) VALUES (2); INSERT INTO users (id) VALUES (3); COMMIT;"; IntermediateStatement intermediateStatement = - new IntermediateStatement(options, sql, connection); + new IntermediateStatement(options, parse(sql), connection); assertTrue(intermediateStatement.isBatchedQuery()); List result = intermediateStatement.getStatements(); assertEquals(result.size(), 5); - assertEquals(result.get(0), "BEGIN TRANSACTION;"); - assertEquals(result.get(1), "INSERT INTO users (id) VALUES (1);"); - assertEquals(result.get(2), "INSERT INTO users (id) VALUES (2);"); - assertEquals(result.get(3), "INSERT INTO users (id) VALUES (3);"); - assertEquals(result.get(4), "COMMIT;"); + assertEquals(result.get(0), "BEGIN TRANSACTION"); + assertEquals(result.get(1), "INSERT INTO users (id) VALUES (1)"); + assertEquals(result.get(2), "INSERT INTO users (id) VALUES (2)"); + assertEquals(result.get(3), "INSERT INTO users (id) VALUES (3)"); + assertEquals(result.get(4), "COMMIT"); } @Test public void testBatchStatementsWithEmptyStatements() throws Exception { String sql = "INSERT INTO users (id) VALUES (1); ;;; INSERT INTO users (id) VALUES (2);"; IntermediateStatement intermediateStatement = - new IntermediateStatement(options, sql, connection); + new IntermediateStatement(options, parse(sql), connection); assertTrue(intermediateStatement.isBatchedQuery()); List result = intermediateStatement.getStatements(); assertEquals(result.size(), 2); - assertEquals(result.get(0), "INSERT INTO users (id) VALUES (1);"); - assertEquals(result.get(1), "INSERT INTO users (id) VALUES (2);"); + assertEquals(result.get(0), "INSERT INTO users (id) VALUES (1)"); + assertEquals(result.get(1), "INSERT INTO users (id) VALUES (2)"); } @Test @@ -404,13 +413,13 @@ public void testBatchStatementsWithQuotes() throws Exception { String sql = "INSERT INTO users (name) VALUES (';;test;;'); INSERT INTO users (name1, name2) VALUES ('''''', ';'';');"; IntermediateStatement intermediateStatement = - new IntermediateStatement(options, sql, connection); + new IntermediateStatement(options, parse(sql), connection); assertTrue(intermediateStatement.isBatchedQuery()); List result = intermediateStatement.getStatements(); assertEquals(result.size(), 2); - assertEquals(result.get(0), "INSERT INTO users (name) VALUES (';;test;;');"); - assertEquals(result.get(1), "INSERT INTO users (name1, name2) VALUES ('''''', ';'';');"); + assertEquals(result.get(0), "INSERT INTO users (name) VALUES (';;test;;')"); + assertEquals(result.get(1), "INSERT INTO users (name1, name2) VALUES ('''''', ';'';')"); } @Test @@ -436,7 +445,7 @@ public void testBatchStatementsWithComments() throws Exception { assertTrue(intermediateStatement.isBatchedQuery()); List result = intermediateStatement.getStatements(); assertEquals(result.size(), 2); - assertEquals(result.get(0), "INSERT INTO users (name) VALUES (';;test;;');"); + assertEquals(result.get(0), "INSERT INTO users (name) VALUES (';;test;;')"); assertEquals(result.get(1), "INSERT INTO users (name1, name2) VALUES ('''''', ';'';')"); } @@ -449,7 +458,8 @@ public void testCopyInvalidBuildMutation() throws Exception { Mockito.when(statementResult.getUpdateCount()).thenReturn(1L); CopyStatement statement = - new CopyStatement(mock(OptionsMetadata.class), "COPY keyvalue FROM STDIN;", connection); + new CopyStatement( + mock(OptionsMetadata.class), parse("COPY keyvalue FROM STDIN;"), connection); statement.execute(); byte[] payload = "2 3\n".getBytes(); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java index 36e35107a..8d2b952c8 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java @@ -22,7 +22,11 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractStatementParser; +import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.connection.StatementResult.ResultType; @@ -34,12 +38,20 @@ @RunWith(JUnit4.class) public class IntermediateStatementTest { + private static final AbstractStatementParser PARSER = + AbstractStatementParser.getInstance(Dialect.POSTGRESQL); + + private static ParsedStatement parse(String sql) { + return PARSER.parse(Statement.of(sql)); + } + @Mock private Connection connection; @Test public void testUpdateResultCount_ResultSet() { IntermediateStatement statement = - new IntermediateStatement(mock(OptionsMetadata.class), "select foo from bar", connection); + new IntermediateStatement( + mock(OptionsMetadata.class), parse("select foo from bar"), connection); ResultSet resultSet = mock(ResultSet.class); when(resultSet.next()).thenReturn(true, false); StatementResult result = mock(StatementResult.class); @@ -57,7 +69,8 @@ public void testUpdateResultCount_ResultSet() { @Test public void testUpdateResultCount_UpdateCount() { IntermediateStatement statement = - new IntermediateStatement(mock(OptionsMetadata.class), "update bar set foo=1", connection); + new IntermediateStatement( + mock(OptionsMetadata.class), parse("update bar set foo=1"), connection); StatementResult result = mock(StatementResult.class); when(result.getResultType()).thenReturn(ResultType.UPDATE_COUNT); when(result.getResultSet()).thenThrow(new IllegalStateException()); @@ -74,7 +87,9 @@ public void testUpdateResultCount_UpdateCount() { public void testUpdateResultCount_NoResult() { IntermediateStatement statement = new IntermediateStatement( - mock(OptionsMetadata.class), "create table bar (foo bigint primary key)", connection); + mock(OptionsMetadata.class), + parse("create table bar (foo bigint primary key)"), + connection); StatementResult result = mock(StatementResult.class); when(result.getResultType()).thenReturn(ResultType.NO_RESULT); when(result.getResultSet()).thenThrow(new IllegalStateException());