Skip to content

Commit

Permalink
fix: parse table names with schema prefix (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
olavloite authored Jun 29, 2022
1 parent 1c60253 commit cbdf28d
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.cloud.spanner.pgadapter.parsers.Parser;
import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode;
import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSortedSet;
Expand Down Expand Up @@ -280,7 +281,7 @@ private Statement transformDmlToSelectParams(Set<String> parameters) {
return null;
}
parser.eat("into");
String table = parser.readIdentifier();
TableOrIndexName table = parser.readTableOrIndexName();
if (table == null) {
return null;
}
Expand Down Expand Up @@ -346,10 +347,10 @@ private Statement transformDmlToSelectParams(Set<String> parameters) {
* without a column list. The query that is used does not use the INFORMATION_SCHEMA, but queries
* the table directly, so it can use the same transaction as the actual insert statement.
*/
static List<String> getAllColumns(Connection connection, String table) {
static List<String> getAllColumns(Connection connection, TableOrIndexName table) {
try (ResultSet resultSet =
connection.analyzeQuery(
Statement.of("SELECT * FROM \"" + table + "\" LIMIT 1"), QueryAnalyzeMode.PLAN)) {
Statement.of("SELECT * FROM " + table + " LIMIT 1"), QueryAnalyzeMode.PLAN)) {
return resultSet.getType().getStructFields().stream()
.map(StructField::getName)
.collect(Collectors.toList());
Expand Down Expand Up @@ -377,7 +378,7 @@ static Statement transformUpdateToSelectParams(String sql, Set<String> parameter
return null;
}
parser.eat("only");
String table = parser.readIdentifier();
TableOrIndexName table = parser.readTableOrIndexName();
if (table == null) {
return null;
}
Expand Down Expand Up @@ -427,7 +428,7 @@ static Statement transformDeleteToSelectParams(String sql, Set<String> parameter
return null;
}
parser.eat("from");
String table = parser.readIdentifier();
TableOrIndexName table = parser.readTableOrIndexName();
if (table == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,52 @@

package com.google.cloud.spanner.pgadapter.statements;

import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.Objects;
import javax.annotation.Nullable;

/** A very simple parser that can interpret SQL statements to find specific parts in the string. */
class SimpleParser {
private static final Set<Character> OPERATORS = ImmutableSet.of('+', '-', '*', '/', '!');
/** Name of table or index. */
static class TableOrIndexName {
/** Schema is an optional schema name prefix. */
final String schema;
/** Name is the actual object name. */
final String name;

TableOrIndexName(String name) {
this.schema = null;
this.name = name;
}

TableOrIndexName(String schema, String name) {
this.schema = schema;
this.name = name;
}

@Override
public String toString() {
if (schema == null) {
return name;
}
return schema + "." + name;
}

@Override
public boolean equals(Object o) {
if (!(o instanceof TableOrIndexName)) {
return false;
}
TableOrIndexName other = (TableOrIndexName) o;
return Objects.equals(this.schema, other.schema) && Objects.equals(this.name, other.name);
}

@Override
public int hashCode() {
return Objects.hash(this.schema, this.name);
}
}

private final String sql;
private int pos;
Expand Down Expand Up @@ -103,7 +141,22 @@ String parseExpression(@Nullable String delimiter) {
return sql.substring(start, pos).trim();
}

String readIdentifier() {
TableOrIndexName readTableOrIndexName() {
String nameOrSchema = readTableOrIndexNamePart();
if (nameOrSchema == null) {
return null;
}
if (eat(".")) {
String name = readTableOrIndexNamePart();
if (name == null) {
name = "";
}
return new TableOrIndexName(nameOrSchema, name);
}
return new TableOrIndexName(nameOrSchema);
}

String readTableOrIndexNamePart() {
skipWhitespaces();
boolean quoted = sql.charAt(pos) == '"';
int start = pos;
Expand All @@ -112,11 +165,18 @@ String readIdentifier() {
}
while (pos < sql.length()) {
if (quoted) {
if (sql.charAt(pos) == '"' && sql.charAt(pos - 1) != '\\') {
return sql.substring(start, ++pos);
if (sql.charAt(pos) == '"') {
if (pos < (sql.length() - 1) && sql.charAt(pos + 1) == '"') {
pos++;
} else {
return sql.substring(start, ++pos);
}
}
} else {
if (Character.isWhitespace(sql.charAt(pos))) {
if (Character.isWhitespace(sql.charAt(pos))
|| sql.charAt(pos) == '.'
|| sql.charAt(pos) == ','
|| sql.charAt(pos) == '"') {
return sql.substring(start, pos);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand Down Expand Up @@ -483,6 +484,55 @@ public void testDescribeDmlWithNonExistingTable() throws SQLException {
assertEquals(QueryMode.PLAN, requests.get(1).getQueryMode());
}

@Test
public void testDescribeDmlWithSchemaPrefix() throws SQLException {
String sql = "update public.my_table set value=? where id=?";
String describeSql = "select $1, $2 from (select value=$1 from public.my_table where id=$2) p";
mockSpanner.putStatementResult(
StatementResult.query(
Statement.of(describeSql),
com.google.spanner.v1.ResultSet.newBuilder()
.setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING, TypeCode.INT64)))
.build()));
try (Connection connection = DriverManager.getConnection(createUrl())) {
try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
ParameterMetaData metadata = preparedStatement.getParameterMetaData();
assertEquals(Types.VARCHAR, metadata.getParameterType(1));
assertEquals(Types.BIGINT, metadata.getParameterType(2));
}
}

List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
assertEquals(1, requests.size());
assertEquals(describeSql, requests.get(0).getSql());
assertEquals(QueryMode.PLAN, requests.get(0).getQueryMode());
}

@Test
public void testDescribeDmlWithQuotedSchemaPrefix() throws SQLException {
String sql = "update \"public\".\"my_table\" set value=? where id=?";
String describeSql =
"select $1, $2 from (select value=$1 from \"public\".\"my_table\" where id=$2) p";
mockSpanner.putStatementResult(
StatementResult.query(
Statement.of(describeSql),
com.google.spanner.v1.ResultSet.newBuilder()
.setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING, TypeCode.INT64)))
.build()));
try (Connection connection = DriverManager.getConnection(createUrl())) {
try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
ParameterMetaData metadata = preparedStatement.getParameterMetaData();
assertEquals(Types.VARCHAR, metadata.getParameterType(1));
assertEquals(Types.BIGINT, metadata.getParameterType(2));
}
}

List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
assertEquals(1, requests.size());
assertEquals(describeSql, requests.get(0).getSql());
assertEquals(QueryMode.PLAN, requests.get(0).getQueryMode());
}

@Test
public void testTwoDmlStatements() throws SQLException {
try (Connection connection = DriverManager.getConnection(createUrl())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName;
import java.util.Arrays;
import java.util.Collections;
import org.junit.Test;
Expand Down Expand Up @@ -46,16 +47,55 @@ public void testEat() {
}

@Test
public void testReadIdentifier() {
assertEquals("foo", new SimpleParser("foo bar").readIdentifier());
assertEquals("foo", new SimpleParser("foo").readIdentifier());
assertEquals("\"foo\"", new SimpleParser("\"foo\" bar").readIdentifier());
assertEquals("\"foo\"", new SimpleParser("\"foo\"").readIdentifier());
assertEquals("foo", new SimpleParser(" foo bar").readIdentifier());
assertEquals("foo", new SimpleParser("\tfoo").readIdentifier());
assertEquals("\"foo\"", new SimpleParser("\n\"foo\" bar").readIdentifier());
assertEquals("\"foo\"", new SimpleParser(" \"foo\"").readIdentifier());
assertNull(new SimpleParser("\"foo").readIdentifier());
public void testReadTableOrIndexNamePart() {
assertEquals("foo", new SimpleParser("foo bar").readTableOrIndexNamePart());
assertEquals("foo", new SimpleParser("foo").readTableOrIndexNamePart());
assertEquals("\"foo\"", new SimpleParser("\"foo\" bar").readTableOrIndexNamePart());
assertEquals("\"foo\"", new SimpleParser("\"foo\"").readTableOrIndexNamePart());
assertEquals("foo", new SimpleParser(" foo bar").readTableOrIndexNamePart());
assertEquals("foo", new SimpleParser("\tfoo").readTableOrIndexNamePart());
assertEquals("\"foo\"", new SimpleParser("\n\"foo\" bar").readTableOrIndexNamePart());
assertEquals("\"foo\"", new SimpleParser(" \"foo\"").readTableOrIndexNamePart());
assertEquals("\"foo\"\"bar\"", new SimpleParser("\"foo\"\"bar\"").readTableOrIndexNamePart());
assertEquals("foo", new SimpleParser("foo\"bar\"").readTableOrIndexNamePart());
assertEquals("foo", new SimpleParser("foo.bar").readTableOrIndexNamePart());
assertEquals("foo", new SimpleParser("foo").readTableOrIndexNamePart());
assertEquals("\"foo\"", new SimpleParser("\"foo\".bar").readTableOrIndexNamePart());
assertEquals("\"foo\"", new SimpleParser("\"foo\"").readTableOrIndexNamePart());
assertNull(new SimpleParser("\"foo").readTableOrIndexNamePart());
}

@Test
public void testReadTableOrIndexName() {
assertEquals(new TableOrIndexName("foo"), new SimpleParser("foo bar").readTableOrIndexName());
assertEquals(new TableOrIndexName("foo"), new SimpleParser("foo").readTableOrIndexName());
assertEquals(
new TableOrIndexName("\"foo\""), new SimpleParser("\"foo\" bar").readTableOrIndexName());
assertEquals(
new TableOrIndexName("\"foo\""), new SimpleParser("\"foo\"").readTableOrIndexName());
assertEquals(new TableOrIndexName("foo"), new SimpleParser(" foo bar").readTableOrIndexName());
assertEquals(new TableOrIndexName("foo"), new SimpleParser("\tfoo").readTableOrIndexName());
assertEquals(
new TableOrIndexName("\"foo\""), new SimpleParser("\n\"foo\" bar").readTableOrIndexName());
assertEquals(
new TableOrIndexName("\"foo\""), new SimpleParser(" \"foo\"").readTableOrIndexName());
assertEquals(
new TableOrIndexName("\"foo\"\"bar\""),
new SimpleParser("\"foo\"\"bar\"").readTableOrIndexName());
assertEquals(
new TableOrIndexName("foo"), new SimpleParser("foo\"bar\"").readTableOrIndexName());
assertEquals(
new TableOrIndexName("foo", "bar"), new SimpleParser("foo.bar").readTableOrIndexName());
assertEquals(new TableOrIndexName("foo"), new SimpleParser("foo").readTableOrIndexName());
assertEquals(
new TableOrIndexName("\"foo\"", "bar"),
new SimpleParser("\"foo\".bar").readTableOrIndexName());
assertEquals(
new TableOrIndexName("\"foo\"", "\"bar\""),
new SimpleParser("\"foo\".\"bar\"").readTableOrIndexName());
assertEquals(
new TableOrIndexName("\"foo\""), new SimpleParser("\"foo\"").readTableOrIndexName());
assertNull(new SimpleParser("\"foo").readTableOrIndexName());
}

@Test
Expand Down

0 comments on commit cbdf28d

Please sign in to comment.