Skip to content

Commit

Permalink
Implement JDBC PreparedStatement.getMetaData
Browse files Browse the repository at this point in the history
  • Loading branch information
puchengy authored and electrum committed Feb 29, 2020
1 parent f5b8496 commit 455d6b9
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.client.ClientStandardTypes.ROW;
import static io.prestosql.client.ClientStandardTypes.VARCHAR;
import static java.lang.String.format;
import static java.util.Collections.unmodifiableList;
import static java.util.Objects.requireNonNull;
Expand All @@ -41,6 +42,7 @@ public class ClientTypeSignature
private static final Pattern PATTERN = Pattern.compile(".*[<>,].*");
private final String rawType;
private final List<ClientTypeSignatureParameter> arguments;
public static final int VARCHAR_UNBOUNDED_LENGTH = Integer.MAX_VALUE;

public ClientTypeSignature(String rawType)
{
Expand Down Expand Up @@ -87,6 +89,10 @@ public String toString()
return rowToString();
}

if (rawType.equals(VARCHAR) && arguments.get(0).getKind() == ParameterKind.LONG && arguments.get(0).getLongLiteral() == VARCHAR_UNBOUNDED_LENGTH) {
return "varchar";
}

if (arguments.isEmpty()) {
return rawType;
}
Expand All @@ -108,6 +114,9 @@ private String rowToString()
})
.collect(joining(","));

if (fields.isEmpty()) {
return "row";
}
return format("row(%s)", fields);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
*/
package io.prestosql.jdbc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import io.prestosql.client.ClientTypeSignature;
import io.prestosql.client.ClientTypeSignatureParameter;
import org.joda.time.DateTimeZone;

import java.io.InputStream;
Expand All @@ -36,6 +41,7 @@
import java.sql.SQLFeatureNotSupportedException;
import java.sql.SQLType;
import java.sql.SQLXML;
import java.sql.Statement;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
Expand All @@ -44,8 +50,13 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.io.BaseEncoding.base16;
import static io.prestosql.client.ClientTypeSignature.VARCHAR_UNBOUNDED_LENGTH;
import static io.prestosql.jdbc.ColumnInfo.setTypeInfo;
import static io.prestosql.jdbc.ObjectCasts.castToBigDecimal;
import static io.prestosql.jdbc.ObjectCasts.castToBinary;
import static io.prestosql.jdbc.ObjectCasts.castToBoolean;
Expand All @@ -61,13 +72,18 @@
import static io.prestosql.jdbc.PrestoResultSet.DATE_FORMATTER;
import static io.prestosql.jdbc.PrestoResultSet.TIMESTAMP_FORMATTER;
import static io.prestosql.jdbc.PrestoResultSet.TIME_FORMATTER;
import static java.lang.Long.parseLong;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class PrestoPreparedStatement
extends PrestoStatement
implements PreparedStatement
{
private static final Pattern TOP_LEVEL_TYPE_PATTERN = Pattern.compile("(.+?)\\((.+)\\)");
private static final Pattern TIMESTAMP_WITH_TIME_ZONE_PRECISION_PATTERN = Pattern.compile("timestamp\\((\\d+)\\) with time zone");
private static final Pattern TIME_WITH_TIME_ZONE_PRECISION_PATTERN = Pattern.compile("time\\((\\d+)\\) with time zone");

private final Map<Integer, String> parameters = new HashMap<>();
private final String statementName;
private final String originalSql;
Expand Down Expand Up @@ -479,7 +495,9 @@ public void setArray(int parameterIndex, Array x)
public ResultSetMetaData getMetaData()
throws SQLException
{
throw new SQLFeatureNotSupportedException("getMetaData");
try (Statement statement = connection().createStatement(); ResultSet resultSet = statement.executeQuery("DESCRIBE OUTPUT " + statementName)) {
return new PrestoResultSetMetaData(getDescribeOutputColumnInfoList(resultSet));
}
}

@Override
Expand Down Expand Up @@ -858,4 +876,70 @@ private static String typedNull(String prestoType)
{
return format("CAST(NULL AS %s)", prestoType);
}

private static List<ColumnInfo> getDescribeOutputColumnInfoList(ResultSet resultSet)
throws SQLException
{
ImmutableList.Builder<ColumnInfo> list = ImmutableList.builder();
while (resultSet.next()) {
String columnName = resultSet.getString("Column Name");
String catalog = resultSet.getString("Catalog");
String schema = resultSet.getString("Schema");
String table = resultSet.getString("Table");
ClientTypeSignature clientTypeSignature = getClientTypeSignatureFromTypeString(resultSet.getString("Type"));
ColumnInfo.Builder builder = new ColumnInfo.Builder()
.setColumnName(columnName)
.setColumnLabel(columnName)
.setCatalogName(catalog)
.setSchemaName(schema)
.setTableName(table)
.setColumnTypeSignature(clientTypeSignature)
.setNullable(ColumnInfo.Nullable.UNKNOWN);
setTypeInfo(builder, clientTypeSignature);
list.add(builder.build());
}
return list.build();
}

@VisibleForTesting
static ClientTypeSignature getClientTypeSignatureFromTypeString(String type)
{
String topLevelType;
List<ClientTypeSignatureParameter> arguments = new ArrayList<>();
Matcher topLevelMatcher = TOP_LEVEL_TYPE_PATTERN.matcher(type);
if (topLevelMatcher.matches()) {
topLevelType = topLevelMatcher.group(1);
String typeParameters = topLevelMatcher.group(2);
if (topLevelType.equals("decimal")) {
List<String> precisionAndScale = Splitter.on(',').splitToList(typeParameters);
checkArgument(precisionAndScale.size() == 2, "Invalid decimal parameters: %s", typeParameters);
arguments.add(ClientTypeSignatureParameter.ofLong(parseLong(precisionAndScale.get(0))));
arguments.add(ClientTypeSignatureParameter.ofLong(parseLong(precisionAndScale.get(1))));
}
else if (topLevelType.equals("char") || topLevelType.equals("varchar")) {
long precision = parseLong(typeParameters);
arguments.add(ClientTypeSignatureParameter.ofLong(precision));
}
// TODO support array, map, row etc top level types' parameters using recursive parser, current behavior is their parameter list will be empty
}
else {
Matcher timestampMatcher = TIMESTAMP_WITH_TIME_ZONE_PRECISION_PATTERN.matcher(type);
Matcher timeMatcher = TIME_WITH_TIME_ZONE_PRECISION_PATTERN.matcher(type);
if (timestampMatcher.matches()) {
topLevelType = "timestamp with time zone";
arguments.add(ClientTypeSignatureParameter.ofLong(parseLong(timestampMatcher.group(1))));
}
else if (timeMatcher.matches()) {
topLevelType = "time with time zone";
arguments.add(ClientTypeSignatureParameter.ofLong(parseLong(timeMatcher.group(1))));
}
else {
topLevelType = type;
if (topLevelType.equals("varchar")) {
arguments.add(ClientTypeSignatureParameter.ofLong(VARCHAR_UNBOUNDED_LENGTH));
}
}
}
return new ClientTypeSignature(topLevelType, arguments);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
*/
package io.prestosql.jdbc;

import com.google.common.collect.ImmutableList;
import io.airlift.log.Logging;
import io.prestosql.client.ClientTypeSignature;
import io.prestosql.client.ClientTypeSignatureParameter;
import io.prestosql.plugin.blackhole.BlackHolePlugin;
import io.prestosql.server.testing.TestingPrestoServer;
import org.testng.annotations.AfterClass;
Expand All @@ -27,6 +30,7 @@
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Time;
Expand All @@ -41,6 +45,7 @@

import static com.google.common.base.Strings.repeat;
import static com.google.common.primitives.Ints.asList;
import static io.prestosql.client.ClientTypeSignature.VARCHAR_UNBOUNDED_LENGTH;
import static io.prestosql.jdbc.TestPrestoDriver.closeQuietly;
import static io.prestosql.jdbc.TestPrestoDriver.waitForNodeRefresh;
import static java.lang.String.format;
Expand Down Expand Up @@ -103,6 +108,106 @@ public void testExecuteQuery()
}
}

@Test
public void testGetMetadata()
throws Exception
{
try (Connection connection = createConnection("blackhole", "blackhole")) {
try (Statement statement = connection.createStatement()) {
statement.execute("CREATE TABLE test_get_metadata (" +
"c_boolean boolean, " +
"c_decimal decimal, " +
"c_decimal_2 decimal(10,3)," +
"c_varchar varchar, " +
"c_varchar_2 varchar(10), " +
"c_row row(x integer, y array(integer)), " +
"c_array array(integer), " +
"c_map map(integer, integer))");
}

try (PreparedStatement statement = connection.prepareStatement(
"SELECT * FROM test_get_metadata")) {
ResultSetMetaData metadata = statement.getMetaData();
assertEquals(metadata.getColumnCount(), 8);
for (int i = 1; i <= metadata.getColumnCount(); i++) {
assertEquals(metadata.getCatalogName(i), "blackhole");
assertEquals(metadata.getSchemaName(i), "blackhole");
assertEquals(metadata.getTableName(i), "test_get_metadata");
}

assertEquals(metadata.getColumnName(1), "c_boolean");
assertEquals(metadata.getColumnTypeName(1), "boolean");

assertEquals(metadata.getColumnName(2), "c_decimal");
assertEquals(metadata.getColumnTypeName(2), "decimal(38,0)");

assertEquals(metadata.getColumnName(3), "c_decimal_2");
assertEquals(metadata.getColumnTypeName(3), "decimal(10,3)");

assertEquals(metadata.getColumnName(4), "c_varchar");
assertEquals(metadata.getColumnTypeName(4), "varchar");

assertEquals(metadata.getColumnName(5), "c_varchar_2");
assertEquals(metadata.getColumnTypeName(5), "varchar(10)");

assertEquals(metadata.getColumnName(6), "c_row");
assertEquals(metadata.getColumnTypeName(6), "row");

assertEquals(metadata.getColumnName(7), "c_array");
assertEquals(metadata.getColumnTypeName(7), "array");

assertEquals(metadata.getColumnName(8), "c_map");
assertEquals(metadata.getColumnTypeName(8), "map");
}

try (Statement statement = connection.createStatement()) {
statement.execute("DROP TABLE test_get_metadata");
}
}
}

@Test
public void testGetClientTypeSignatureFromTypeString()
{
ClientTypeSignature actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("boolean");
ClientTypeSignature expectedClientTypeSignature = new ClientTypeSignature("boolean", ImmutableList.of());
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);

actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("decimal(10,3)");
expectedClientTypeSignature = new ClientTypeSignature("decimal", ImmutableList.of(
ClientTypeSignatureParameter.ofLong(10),
ClientTypeSignatureParameter.ofLong(3)));
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);

actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("varchar");
expectedClientTypeSignature = new ClientTypeSignature("varchar", ImmutableList.of(ClientTypeSignatureParameter.ofLong(VARCHAR_UNBOUNDED_LENGTH)));
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);

actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("varchar(10)");
expectedClientTypeSignature = new ClientTypeSignature("varchar", ImmutableList.of(ClientTypeSignatureParameter.ofLong(10)));
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);

actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("row(x integer, y array(integer))");
expectedClientTypeSignature = new ClientTypeSignature("row", ImmutableList.of());
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);

actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("array(integer)");
expectedClientTypeSignature = new ClientTypeSignature("array", ImmutableList.of());
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);

actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("map(integer, integer)");
expectedClientTypeSignature = new ClientTypeSignature("map", ImmutableList.of());
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);

actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("timestamp(12) with time zone");
expectedClientTypeSignature = new ClientTypeSignature("timestamp with time zone", ImmutableList.of(ClientTypeSignatureParameter.ofLong(12)));
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);

actualClientTypeSignature = PrestoPreparedStatement.getClientTypeSignatureFromTypeString("time(13) with time zone");
expectedClientTypeSignature = new ClientTypeSignature("time with time zone", ImmutableList.of(ClientTypeSignatureParameter.ofLong(13)));
assertEquals(actualClientTypeSignature, expectedClientTypeSignature);
}

@Test
public void testDeallocate()
throws Exception
Expand Down

0 comments on commit 455d6b9

Please sign in to comment.