diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java index 71fd10427..138d1b487 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java @@ -133,10 +133,7 @@ public IntermediatePortalStatement bind( @Override public DescribeMetadata describe() { - Set parameters = - ImmutableSortedSet.orderedBy(Comparator.comparing(o -> o.substring(1))) - .addAll(PARSER.getQueryParameters(this.parsedStatement.getSqlWithoutComments())) - .build(); + Set parameters = extractParameters(this.parsedStatement.getSqlWithoutComments()); ResultSet columnsResultSet = null; try { @@ -195,6 +192,19 @@ public DescribeMetadata describe() { } } + /** + * Extracts the statement parameters from the given sql string and returns these as a sorted set. + * The parameters are ordered by their index and not by the textual value (i.e. "$9" comes before + * "$10"). + */ + @VisibleForTesting + static ImmutableSortedSet extractParameters(String sql) { + return ImmutableSortedSet.orderedBy( + Comparator.comparing(o -> Integer.valueOf(o.substring(1)))) + .addAll(PARSER.getQueryParameters(sql)) + .build(); + } + /** * Transforms a query or DML statement into a SELECT statement that selects the parameters in the * statements. Examples: diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java index b4573339f..c9c143817 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java @@ -14,6 +14,7 @@ package com.google.cloud.spanner.pgadapter.statements; +import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.extractParameters; import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.transformDeleteToSelectParams; import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.transformInsertToSelectParams; import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.transformUpdateToSelectParams; @@ -39,9 +40,6 @@ import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSortedSet; -import java.util.Comparator; -import java.util.Set; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -177,6 +175,13 @@ public void testTransformInsertValuesToSelectParams() { transformInsert( "insert\ninto\nfoo\n(col1,\ncol2 ) values ($1 + $2 + 5, $3 || to_char($4) || coalesce($5, ''))") .getSql()); + assertEquals( + "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from " + + "(select col1=$1, col2=$2, col3=$3, col4=$4, col5=$5, col6=$6, col7=$7, col8=$8, col9=$9, col10=$10 from foo) p", + transformInsert( + "insert into foo (col1, col2, col3, col4, col5, col6, col7, col8, col9, col10) " + + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)") + .getSql()); } @Test @@ -225,6 +230,12 @@ public void testTransformUpdateToSelectParams() { assertNull(transformUpdate("update foo col1=1")); assertNull(transformUpdate("update foo col1=1 hwere col1=2")); assertNull(transformUpdate("udpate foo col1=1 where col1=2")); + + assertEquals( + "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col1=$1, col2=$2, col3=$3, col4=$4, col5=$5, col6=$6, col7=$7, col8=$8, col9=$9 from foo where id=$10) p", + transformUpdate( + "update foo set col1=$1, col2=$2, col3=$3, col4=$4, col5=$5, col6=$6, col7=$7, col8=$8, col9=$9 where id=$10") + .getSql()); } @Test @@ -236,7 +247,7 @@ public void testTransformDeleteToSelectParams() { "select $1, $2 from (select 1 from foo where id=$1 and bar > $2) p", transformDelete("delete foo\nwhere id=$1 and bar > $2").getSql()); assertEquals( - "select $1, $2, $3, $4, $5, $6, $7, $8, $9 " + "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 " + "from (select 1 from all_types " + "where col_bigint=$1 " + "and col_bool=$2 " @@ -246,7 +257,8 @@ public void testTransformDeleteToSelectParams() { + "and col_numeric=$6 " + "and col_timestamptz=$7 " + "and col_date=$8 " - + "and col_varchar=$9" + + "and col_varchar=$9 " + + "and col_jsonb=$10" + ") p", transformDelete( "delete " @@ -259,7 +271,8 @@ public void testTransformDeleteToSelectParams() { + "and col_numeric=$6 " + "and col_timestamptz=$7 " + "and col_date=$8 " - + "and col_varchar=$9") + + "and col_varchar=$9 " + + "and col_jsonb=$10") .getSql()); assertNull(transformDelete("delete from foo")); @@ -268,26 +281,14 @@ public void testTransformDeleteToSelectParams() { } private static Statement transformInsert(String sql) { - Set parameters = - ImmutableSortedSet.orderedBy(Comparator.comparing(o -> o.substring(1))) - .addAll(PARSER.getQueryParameters(sql)) - .build(); - return transformInsertToSelectParams(mock(Connection.class), sql, parameters); + return transformInsertToSelectParams(mock(Connection.class), sql, extractParameters(sql)); } private static Statement transformUpdate(String sql) { - Set parameters = - ImmutableSortedSet.orderedBy(Comparator.comparing(o -> o.substring(1))) - .addAll(PARSER.getQueryParameters(sql)) - .build(); - return transformUpdateToSelectParams(sql, parameters); + return transformUpdateToSelectParams(sql, extractParameters(sql)); } private static Statement transformDelete(String sql) { - Set parameters = - ImmutableSortedSet.orderedBy(Comparator.comparing(o -> o.substring(1))) - .addAll(PARSER.getQueryParameters(sql)) - .build(); - return transformDeleteToSelectParams(sql, parameters); + return transformDeleteToSelectParams(sql, extractParameters(sql)); } }