diff --git a/java/.gitignore b/java/.gitignore index 376e06946d7de..59c2e7b2a0c6d 100644 --- a/java/.gitignore +++ b/java/.gitignore @@ -23,3 +23,6 @@ install_manifest.txt target/ ?/ !/c/ + +# Generated properties file +flight/flight-sql-jdbc-driver/src/main/resources/properties/flight.properties diff --git a/java/dev/checkstyle/suppressions.xml b/java/dev/checkstyle/suppressions.xml index c3f61f46c92be..585985bf32dbc 100644 --- a/java/dev/checkstyle/suppressions.xml +++ b/java/dev/checkstyle/suppressions.xml @@ -39,4 +39,6 @@ + + diff --git a/java/flight/flight-sql-jdbc-driver/jdbc-spotbugs-exclude.xml b/java/flight/flight-sql-jdbc-driver/jdbc-spotbugs-exclude.xml new file mode 100644 index 0000000000000..af75d70425cb4 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/jdbc-spotbugs-exclude.xml @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java/flight/flight-sql-jdbc-driver/pom.xml b/java/flight/flight-sql-jdbc-driver/pom.xml new file mode 100644 index 0000000000000..b8a49165adb4a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/pom.xml @@ -0,0 +1,375 @@ + + + + + + arrow-flight + org.apache.arrow + 10.0.0-SNAPSHOT + ../pom.xml + + 4.0.0 + + flight-sql-jdbc-driver + Arrow Flight SQL JDBC Driver + (Contrib/Experimental) A JDBC driver based on Arrow Flight SQL. + jar + https://arrow.apache.org + + + ${project.parent.groupId}:${project.parent.artifactId} + ${project.parent.version} + ${project.name} + ${project.version} + ${project.build.directory}/coverage-reports/jacoco-ut.html + + + + + org.apache.arrow + flight-core + ${project.version} + + + io.netty + netty-transport-native-kqueue + + + io.netty + netty-transport-native-epoll + + + + + + + org.apache.arrow + arrow-memory-core + ${project.version} + + + + + org.apache.arrow + arrow-memory-netty + ${project.version} + runtime + + + + + org.apache.arrow + arrow-vector + ${project.version} + ${arrow.vector.classifier} + + + + com.google.guava + guava + + + + org.slf4j + slf4j-api + runtime + + + + com.google.protobuf + protobuf-java + + + org.hamcrest + hamcrest-core + 1.3 + test + + + me.alexpanov + free-port-finder + 1.1.1 + test + + + + commons-io + commons-io + 2.6 + test + + + + org.mockito + mockito-core + 3.12.4 + test + + + + org.mockito + mockito-inline + 3.12.4 + test + + + + io.netty + netty-common + + + + org.apache.arrow + flight-sql + ${project.version} + + + + org.apache.calcite.avatica + avatica + 1.18.0 + + + org.bouncycastle + bcpkix-jdk15on + 1.61 + + + + joda-time + joda-time + 2.10.14 + + + + + + + src/main/resources + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.2.4 + + + package + + shade + + + false + false + false + + + *:* + + + + + com. + cfjd.com. + + com.sun.** + + + + org. + cfjd.org. + + org.apache.arrow.driver.jdbc.** + org.slf4j.** + + org.apache.arrow.flight.name + org.apache.arrow.flight.version + org.apache.arrow.flight.jdbc-driver.name + org.apache.arrow.flight.jdbc-driver.version + + + + io. + cfjd.io. + + + + META-INF.native.libnetty_ + META-INF.native.libcfjd_netty_ + + + META-INF.native.netty_ + META-INF.native.cfjd_netty_ + + + + + + + + org.apache.calcite.avatica:* + + META-INF/services/java.sql.Driver + + + + *:* + + **/*.SF + **/*.RSA + **/*.DSA + META-INF/native/libio_grpc_netty* + META-INF/native/io_grpc_netty_shaded* + + + + + + + + + org.codehaus.mojo + properties-maven-plugin + 1.1.0 + + + write-project-properties-to-file + generate-resources + + write-project-properties + + + src/main/resources/properties/flight.properties + + + + + + org.jacoco + jacoco-maven-plugin + + + + prepare-agent + + + + ${jacoco.ut.execution.data.file} + + surefireArgLine + + + + + report + test + + report + + + + ${jacoco.ut.execution.data.file} + + + + + + check + + check + + + ${jacoco.ut.execution.data.file} + + + CLASS + + org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl + + org.apache.arrow.driver.jdbc.utils.UrlParser + + + + BRANCH + COVEREDRATIO + 0.80 + + + + + + + + + + + + + + jdk8 + + 1.8 + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${surefireArgLine} + + **/IT*.java + + false + + ${project.basedir}/../../../testing/data + + + + + + + + + jdk9+ + + [9,] + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${surefireArgLine} --add-opens=java.base/java.nio=ALL-UNNAMED + + **/IT*.java + + false + + ${project.basedir}/../../../testing/data + + + + + + + + diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java new file mode 100644 index 0000000000000..da2b0b00edaef --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java @@ -0,0 +1,1218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.sql.Types.BIGINT; +import static java.sql.Types.BINARY; +import static java.sql.Types.BIT; +import static java.sql.Types.CHAR; +import static java.sql.Types.DATE; +import static java.sql.Types.DECIMAL; +import static java.sql.Types.FLOAT; +import static java.sql.Types.INTEGER; +import static java.sql.Types.LONGNVARCHAR; +import static java.sql.Types.LONGVARBINARY; +import static java.sql.Types.NUMERIC; +import static java.sql.Types.REAL; +import static java.sql.Types.SMALLINT; +import static java.sql.Types.TIMESTAMP; +import static java.sql.Types.TINYINT; +import static java.sql.Types.VARCHAR; +import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.doesBitmaskTranslateToEnum; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import org.apache.arrow.driver.jdbc.utils.SqlTypes; +import org.apache.arrow.driver.jdbc.utils.VectorSchemaRootTransformer; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; +import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlOuterJoinsSupportLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedElementActions; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedGroupBy; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedPositionedCommands; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedResultSetType; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedSubqueries; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedUnions; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportsConvert; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlTransactionIsolationLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SupportedAnsi92SqlGrammarLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SupportedSqlGrammar; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaDatabaseMetaData; + +import com.google.protobuf.ProtocolMessageEnum; + +/** + * Arrow Flight JDBC's implementation of {@link DatabaseMetaData}. + */ +public class ArrowDatabaseMetadata extends AvaticaDatabaseMetaData { + private static final String JAVA_REGEX_SPECIALS = "[]()|^-+*?{}$\\."; + private static final Charset CHARSET = StandardCharsets.UTF_8; + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + static final int NO_DECIMAL_DIGITS = 0; + private static final int BASE10_RADIX = 10; + static final int COLUMN_SIZE_BYTE = (int) Math.ceil((Byte.SIZE - 1) * Math.log(2) / Math.log(10)); + static final int COLUMN_SIZE_SHORT = + (int) Math.ceil((Short.SIZE - 1) * Math.log(2) / Math.log(10)); + static final int COLUMN_SIZE_INT = + (int) Math.ceil((Integer.SIZE - 1) * Math.log(2) / Math.log(10)); + static final int COLUMN_SIZE_LONG = (int) Math.ceil((Long.SIZE - 1) * Math.log(2) / Math.log(10)); + static final int COLUMN_SIZE_VARCHAR_AND_BINARY = 65536; + static final int COLUMN_SIZE_DATE = "YYYY-MM-DD".length(); + static final int COLUMN_SIZE_TIME = "HH:MM:ss".length(); + static final int COLUMN_SIZE_TIME_MILLISECONDS = "HH:MM:ss.SSS".length(); + static final int COLUMN_SIZE_TIME_MICROSECONDS = "HH:MM:ss.SSSSSS".length(); + static final int COLUMN_SIZE_TIME_NANOSECONDS = "HH:MM:ss.SSSSSSSSS".length(); + static final int COLUMN_SIZE_TIMESTAMP_SECONDS = COLUMN_SIZE_DATE + 1 + COLUMN_SIZE_TIME; + static final int COLUMN_SIZE_TIMESTAMP_MILLISECONDS = + COLUMN_SIZE_DATE + 1 + COLUMN_SIZE_TIME_MILLISECONDS; + static final int COLUMN_SIZE_TIMESTAMP_MICROSECONDS = + COLUMN_SIZE_DATE + 1 + COLUMN_SIZE_TIME_MICROSECONDS; + static final int COLUMN_SIZE_TIMESTAMP_NANOSECONDS = + COLUMN_SIZE_DATE + 1 + COLUMN_SIZE_TIME_NANOSECONDS; + static final int DECIMAL_DIGITS_TIME_MILLISECONDS = 3; + static final int DECIMAL_DIGITS_TIME_MICROSECONDS = 6; + static final int DECIMAL_DIGITS_TIME_NANOSECONDS = 9; + private static final Schema GET_COLUMNS_SCHEMA = new Schema( + Arrays.asList( + Field.nullable("TABLE_CAT", Types.MinorType.VARCHAR.getType()), + Field.nullable("TABLE_SCHEM", Types.MinorType.VARCHAR.getType()), + Field.notNullable("TABLE_NAME", Types.MinorType.VARCHAR.getType()), + Field.notNullable("COLUMN_NAME", Types.MinorType.VARCHAR.getType()), + Field.nullable("DATA_TYPE", Types.MinorType.INT.getType()), + Field.nullable("TYPE_NAME", Types.MinorType.VARCHAR.getType()), + Field.nullable("COLUMN_SIZE", Types.MinorType.INT.getType()), + Field.nullable("BUFFER_LENGTH", Types.MinorType.INT.getType()), + Field.nullable("DECIMAL_DIGITS", Types.MinorType.INT.getType()), + Field.nullable("NUM_PREC_RADIX", Types.MinorType.INT.getType()), + Field.notNullable("NULLABLE", Types.MinorType.INT.getType()), + Field.nullable("REMARKS", Types.MinorType.VARCHAR.getType()), + Field.nullable("COLUMN_DEF", Types.MinorType.VARCHAR.getType()), + Field.nullable("SQL_DATA_TYPE", Types.MinorType.INT.getType()), + Field.nullable("SQL_DATETIME_SUB", Types.MinorType.INT.getType()), + Field.notNullable("CHAR_OCTET_LENGTH", Types.MinorType.INT.getType()), + Field.notNullable("ORDINAL_POSITION", Types.MinorType.INT.getType()), + Field.notNullable("IS_NULLABLE", Types.MinorType.VARCHAR.getType()), + Field.nullable("SCOPE_CATALOG", Types.MinorType.VARCHAR.getType()), + Field.nullable("SCOPE_SCHEMA", Types.MinorType.VARCHAR.getType()), + Field.nullable("SCOPE_TABLE", Types.MinorType.VARCHAR.getType()), + Field.nullable("SOURCE_DATA_TYPE", Types.MinorType.SMALLINT.getType()), + Field.notNullable("IS_AUTOINCREMENT", Types.MinorType.VARCHAR.getType()), + Field.notNullable("IS_GENERATEDCOLUMN", Types.MinorType.VARCHAR.getType()) + )); + private final Map cachedSqlInfo = + Collections.synchronizedMap(new EnumMap<>(SqlInfo.class)); + private static final Map sqlTypesToFlightEnumConvertTypes = new HashMap<>(); + + static { + sqlTypesToFlightEnumConvertTypes.put(BIT, SqlSupportsConvert.SQL_CONVERT_BIT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(INTEGER, SqlSupportsConvert.SQL_CONVERT_INTEGER_VALUE); + sqlTypesToFlightEnumConvertTypes.put(NUMERIC, SqlSupportsConvert.SQL_CONVERT_NUMERIC_VALUE); + sqlTypesToFlightEnumConvertTypes.put(SMALLINT, SqlSupportsConvert.SQL_CONVERT_SMALLINT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(TINYINT, SqlSupportsConvert.SQL_CONVERT_TINYINT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(FLOAT, SqlSupportsConvert.SQL_CONVERT_FLOAT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(BIGINT, SqlSupportsConvert.SQL_CONVERT_BIGINT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(REAL, SqlSupportsConvert.SQL_CONVERT_REAL_VALUE); + sqlTypesToFlightEnumConvertTypes.put(DECIMAL, SqlSupportsConvert.SQL_CONVERT_DECIMAL_VALUE); + sqlTypesToFlightEnumConvertTypes.put(BINARY, SqlSupportsConvert.SQL_CONVERT_BINARY_VALUE); + sqlTypesToFlightEnumConvertTypes.put(LONGVARBINARY, + SqlSupportsConvert.SQL_CONVERT_LONGVARBINARY_VALUE); + sqlTypesToFlightEnumConvertTypes.put(CHAR, SqlSupportsConvert.SQL_CONVERT_CHAR_VALUE); + sqlTypesToFlightEnumConvertTypes.put(VARCHAR, SqlSupportsConvert.SQL_CONVERT_VARCHAR_VALUE); + sqlTypesToFlightEnumConvertTypes.put(LONGNVARCHAR, + SqlSupportsConvert.SQL_CONVERT_LONGVARCHAR_VALUE); + sqlTypesToFlightEnumConvertTypes.put(DATE, SqlSupportsConvert.SQL_CONVERT_DATE_VALUE); + sqlTypesToFlightEnumConvertTypes.put(TIMESTAMP, SqlSupportsConvert.SQL_CONVERT_TIMESTAMP_VALUE); + } + + ArrowDatabaseMetadata(final AvaticaConnection connection) { + super(connection); + } + + @Override + public String getDatabaseProductName() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.FLIGHT_SQL_SERVER_NAME, String.class); + } + + @Override + public String getDatabaseProductVersion() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.FLIGHT_SQL_SERVER_VERSION, String.class); + } + + @Override + public String getIdentifierQuoteString() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR, String.class); + } + + @Override + public boolean isReadOnly() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY, Boolean.class); + } + + @Override + public String getSQLKeywords() throws SQLException { + return convertListSqlInfoToString( + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_KEYWORDS, List.class)); + } + + @Override + public String getNumericFunctions() throws SQLException { + return convertListSqlInfoToString( + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_NUMERIC_FUNCTIONS, List.class)); + } + + @Override + public String getStringFunctions() throws SQLException { + return convertListSqlInfoToString( + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_STRING_FUNCTIONS, List.class)); + } + + @Override + public String getSystemFunctions() throws SQLException { + return convertListSqlInfoToString( + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SYSTEM_FUNCTIONS, List.class)); + } + + @Override + public String getTimeDateFunctions() throws SQLException { + return convertListSqlInfoToString( + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DATETIME_FUNCTIONS, List.class)); + } + + @Override + public String getSearchStringEscape() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SEARCH_STRING_ESCAPE, String.class); + } + + @Override + public String getExtraNameCharacters() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_EXTRA_NAME_CHARACTERS, String.class); + } + + @Override + public boolean supportsColumnAliasing() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_COLUMN_ALIASING, Boolean.class); + } + + @Override + public boolean nullPlusNonNullIsNull() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_NULL_PLUS_NULL_IS_NULL, Boolean.class); + } + + @Override + public boolean supportsConvert() throws SQLException { + return !getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_CONVERT, Map.class).isEmpty(); + } + + @Override + public boolean supportsConvert(final int fromType, final int toType) throws SQLException { + final Map> sqlSupportsConvert = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_CONVERT, Map.class); + + if (!sqlTypesToFlightEnumConvertTypes.containsKey(fromType)) { + return false; + } + + final List list = + sqlSupportsConvert.get(sqlTypesToFlightEnumConvertTypes.get(fromType)); + + return list != null && list.contains(sqlTypesToFlightEnumConvertTypes.get(toType)); + } + + @Override + public boolean supportsTableCorrelationNames() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_TABLE_CORRELATION_NAMES, + Boolean.class); + } + + @Override + public boolean supportsDifferentTableCorrelationNames() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES, + Boolean.class); + } + + @Override + public boolean supportsExpressionsInOrderBy() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY, + Boolean.class); + } + + @Override + public boolean supportsOrderByUnrelated() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_ORDER_BY_UNRELATED, Boolean.class); + } + + @Override + public boolean supportsGroupBy() throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GROUP_BY, Integer.class); + return bitmask != 0; + } + + @Override + public boolean supportsGroupByUnrelated() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GROUP_BY, + SqlSupportedGroupBy.SQL_GROUP_BY_UNRELATED); + } + + @Override + public boolean supportsLikeEscapeClause() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE, Boolean.class); + } + + @Override + public boolean supportsNonNullableColumns() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_NON_NULLABLE_COLUMNS, + Boolean.class); + } + + @Override + public boolean supportsMinimumSQLGrammar() throws SQLException { + return checkEnumLevel( + Arrays.asList(getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_EXTENDED_GRAMMAR), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_CORE_GRAMMAR), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_MINIMUM_GRAMMAR))); + } + + @Override + public boolean supportsCoreSQLGrammar() throws SQLException { + return checkEnumLevel( + Arrays.asList(getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_EXTENDED_GRAMMAR), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_CORE_GRAMMAR))); + } + + @Override + public boolean supportsExtendedSQLGrammar() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_EXTENDED_GRAMMAR); + } + + @Override + public boolean supportsANSI92EntryLevelSQL() throws SQLException { + return checkEnumLevel( + Arrays.asList(getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_ENTRY_SQL), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_INTERMEDIATE_SQL), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_FULL_SQL))); + } + + @Override + public boolean supportsANSI92IntermediateSQL() throws SQLException { + return checkEnumLevel( + Arrays.asList(getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_ENTRY_SQL), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_INTERMEDIATE_SQL))); + } + + @Override + public boolean supportsANSI92FullSQL() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_FULL_SQL); + } + + @Override + public boolean supportsIntegrityEnhancementFacility() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY, + Boolean.class); + } + + @Override + public boolean supportsOuterJoins() throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_OUTER_JOINS_SUPPORT_LEVEL, Integer.class); + return bitmask != 0; + } + + @Override + public boolean supportsFullOuterJoins() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_OUTER_JOINS_SUPPORT_LEVEL, + SqlOuterJoinsSupportLevel.SQL_FULL_OUTER_JOINS); + } + + @Override + public boolean supportsLimitedOuterJoins() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_OUTER_JOINS_SUPPORT_LEVEL, + SqlOuterJoinsSupportLevel.SQL_LIMITED_OUTER_JOINS); + } + + @Override + public String getSchemaTerm() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SCHEMA_TERM, String.class); + } + + @Override + public String getProcedureTerm() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_PROCEDURE_TERM, String.class); + } + + @Override + public String getCatalogTerm() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_CATALOG_TERM, String.class); + } + + @Override + public boolean isCatalogAtStart() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_CATALOG_AT_START, Boolean.class); + } + + @Override + public boolean supportsSchemasInProcedureCalls() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SCHEMAS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_PROCEDURE_CALLS); + } + + @Override + public boolean supportsSchemasInIndexDefinitions() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SCHEMAS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_INDEX_DEFINITIONS); + } + + @Override + public boolean supportsSchemasInPrivilegeDefinitions() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SCHEMAS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS); + } + + @Override + public boolean supportsCatalogsInIndexDefinitions() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_CATALOGS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_INDEX_DEFINITIONS); + } + + @Override + public boolean supportsCatalogsInPrivilegeDefinitions() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_CATALOGS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS); + } + + @Override + public boolean supportsPositionedDelete() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_POSITIONED_COMMANDS, + SqlSupportedPositionedCommands.SQL_POSITIONED_DELETE); + } + + @Override + public boolean supportsPositionedUpdate() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_POSITIONED_COMMANDS, + SqlSupportedPositionedCommands.SQL_POSITIONED_UPDATE); + } + + @Override + public boolean supportsResultSetType(final int type) throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_RESULT_SET_TYPES, Integer.class); + + switch (type) { + case ResultSet.TYPE_FORWARD_ONLY: + return doesBitmaskTranslateToEnum(SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY, + bitmask); + case ResultSet.TYPE_SCROLL_INSENSITIVE: + return doesBitmaskTranslateToEnum(SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE, + bitmask); + case ResultSet.TYPE_SCROLL_SENSITIVE: + return doesBitmaskTranslateToEnum(SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE, + bitmask); + default: + throw new SQLException( + "Invalid result set type argument. The informed type is not defined in java.sql.ResultSet."); + } + } + + @Override + public boolean supportsSelectForUpdate() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SELECT_FOR_UPDATE_SUPPORTED, Boolean.class); + } + + @Override + public boolean supportsStoredProcedures() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_STORED_PROCEDURES_SUPPORTED, Boolean.class); + } + + @Override + public boolean supportsSubqueriesInComparisons() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_SUBQUERIES, + SqlSupportedSubqueries.SQL_SUBQUERIES_IN_COMPARISONS); + } + + @Override + public boolean supportsSubqueriesInExists() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_SUBQUERIES, + SqlSupportedSubqueries.SQL_SUBQUERIES_IN_EXISTS); + } + + @Override + public boolean supportsSubqueriesInIns() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_SUBQUERIES, + SqlSupportedSubqueries.SQL_SUBQUERIES_IN_INS); + } + + @Override + public boolean supportsSubqueriesInQuantifieds() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_SUBQUERIES, + SqlSupportedSubqueries.SQL_SUBQUERIES_IN_QUANTIFIEDS); + } + + @Override + public boolean supportsCorrelatedSubqueries() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_CORRELATED_SUBQUERIES_SUPPORTED, + Boolean.class); + } + + @Override + public boolean supportsUnion() throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_UNIONS, Integer.class); + return bitmask != 0; + } + + @Override + public boolean supportsUnionAll() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_UNIONS, + SqlSupportedUnions.SQL_UNION_ALL); + } + + @Override + public int getMaxBinaryLiteralLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_BINARY_LITERAL_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxCharLiteralLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_CHAR_LITERAL_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxColumnNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMN_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInGroupBy() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_GROUP_BY, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInIndex() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_INDEX, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInOrderBy() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_ORDER_BY, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInSelect() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_SELECT, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInTable() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_TABLE, + Long.class).intValue(); + } + + @Override + public int getMaxConnections() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_CONNECTIONS, Long.class).intValue(); + } + + @Override + public int getMaxCursorNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_CURSOR_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxIndexLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_INDEX_LENGTH, Long.class).intValue(); + } + + @Override + public int getMaxSchemaNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DB_SCHEMA_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxProcedureNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_PROCEDURE_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxCatalogNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_CATALOG_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxRowSize() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_ROW_SIZE, Long.class).intValue(); + } + + @Override + public boolean doesMaxRowSizeIncludeBlobs() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_ROW_SIZE_INCLUDES_BLOBS, Boolean.class); + } + + @Override + public int getMaxStatementLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_STATEMENT_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxStatements() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_STATEMENTS, Long.class).intValue(); + } + + @Override + public int getMaxTableNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_TABLE_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxTablesInSelect() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_TABLES_IN_SELECT, + Long.class).intValue(); + } + + @Override + public int getMaxUserNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_USERNAME_LENGTH, Long.class).intValue(); + } + + @Override + public int getDefaultTransactionIsolation() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DEFAULT_TRANSACTION_ISOLATION, + Long.class).intValue(); + } + + @Override + public boolean supportsTransactions() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_TRANSACTIONS_SUPPORTED, Boolean.class); + } + + @Override + public boolean supportsTransactionIsolationLevel(final int level) throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS, + Integer.class); + + switch (level) { + case Connection.TRANSACTION_NONE: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_NONE, bitmask); + case Connection.TRANSACTION_READ_COMMITTED: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_READ_COMMITTED, + bitmask); + case Connection.TRANSACTION_READ_UNCOMMITTED: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_READ_UNCOMMITTED, + bitmask); + case Connection.TRANSACTION_REPEATABLE_READ: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_REPEATABLE_READ, + bitmask); + case Connection.TRANSACTION_SERIALIZABLE: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_SERIALIZABLE, + bitmask); + default: + throw new SQLException( + "Invalid transaction isolation level argument. The informed level is not defined in java.sql.Connection."); + } + } + + @Override + public boolean dataDefinitionCausesTransactionCommit() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT, + Boolean.class); + } + + @Override + public boolean dataDefinitionIgnoredInTransactions() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED, + Boolean.class); + } + + @Override + public boolean supportsBatchUpdates() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_BATCH_UPDATES_SUPPORTED, Boolean.class); + } + + @Override + public boolean supportsSavepoints() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SAVEPOINTS_SUPPORTED, Boolean.class); + } + + @Override + public boolean supportsNamedParameters() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_NAMED_PARAMETERS_SUPPORTED, Boolean.class); + } + + @Override + public boolean locatorsUpdateCopy() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_LOCATORS_UPDATE_COPY, Boolean.class); + } + + @Override + public boolean supportsStoredFunctionsUsingCallSyntax() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty( + SqlInfo.SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED, Boolean.class); + } + + @Override + public ArrowFlightConnection getConnection() throws SQLException { + return (ArrowFlightConnection) super.getConnection(); + } + + private T getSqlInfoAndCacheIfCacheIsEmpty(final SqlInfo sqlInfoCommand, + final Class desiredType) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + if (cachedSqlInfo.isEmpty()) { + final FlightInfo sqlInfo = connection.getClientHandler().getSqlInfo(); + synchronized (cachedSqlInfo) { + if (cachedSqlInfo.isEmpty()) { + try (final ResultSet resultSet = + ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo( + connection, sqlInfo, null)) { + while (resultSet.next()) { + cachedSqlInfo.put(SqlInfo.forNumber((Integer) resultSet.getObject("info_name")), + resultSet.getObject("value")); + } + } + } + } + } + return desiredType.cast(cachedSqlInfo.get(sqlInfoCommand)); + } + + private String convertListSqlInfoToString(final List sqlInfoList) { + return sqlInfoList.stream().map(Object::toString).collect(Collectors.joining(", ")); + } + + private boolean getSqlInfoEnumOptionAndCacheIfCacheIsEmpty( + final SqlInfo sqlInfoCommand, + final ProtocolMessageEnum enumInstance + ) throws SQLException { + final int bitmask = getSqlInfoAndCacheIfCacheIsEmpty(sqlInfoCommand, Integer.class); + return doesBitmaskTranslateToEnum(enumInstance, bitmask); + } + + private boolean checkEnumLevel(final List toCheck) { + return toCheck.stream().anyMatch(e -> e); + } + + @Override + public ResultSet getCatalogs() throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoCatalogs = connection.getClientHandler().getCatalogs(); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_CATALOGS_SCHEMA, allocator) + .renameFieldVector("catalog_name", "TABLE_CAT") + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoCatalogs, + transformer); + } + + @Override + public ResultSet getImportedKeys(final String catalog, final String schema, final String table) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoImportedKeys = + connection.getClientHandler().getImportedKeys(catalog, schema, table); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = getForeignKeysTransformer(allocator); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoImportedKeys, + transformer); + } + + @Override + public ResultSet getExportedKeys(final String catalog, final String schema, final String table) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoExportedKeys = + connection.getClientHandler().getExportedKeys(catalog, schema, table); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = getForeignKeysTransformer(allocator); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoExportedKeys, + transformer); + } + + @Override + public ResultSet getCrossReference(final String parentCatalog, final String parentSchema, + final String parentTable, + final String foreignCatalog, final String foreignSchema, + final String foreignTable) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoCrossReference = connection.getClientHandler().getCrossReference( + parentCatalog, parentSchema, parentTable, foreignCatalog, foreignSchema, foreignTable); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = getForeignKeysTransformer(allocator); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoCrossReference, + transformer); + } + + /** + * Transformer used on getImportedKeys, getExportedKeys and getCrossReference methods, since + * all three share the same schema. + */ + private VectorSchemaRootTransformer getForeignKeysTransformer(final BufferAllocator allocator) { + return new VectorSchemaRootTransformer.Builder(Schemas.GET_IMPORTED_KEYS_SCHEMA, + allocator) + .renameFieldVector("pk_catalog_name", "PKTABLE_CAT") + .renameFieldVector("pk_db_schema_name", "PKTABLE_SCHEM") + .renameFieldVector("pk_table_name", "PKTABLE_NAME") + .renameFieldVector("pk_column_name", "PKCOLUMN_NAME") + .renameFieldVector("fk_catalog_name", "FKTABLE_CAT") + .renameFieldVector("fk_db_schema_name", "FKTABLE_SCHEM") + .renameFieldVector("fk_table_name", "FKTABLE_NAME") + .renameFieldVector("fk_column_name", "FKCOLUMN_NAME") + .renameFieldVector("key_sequence", "KEY_SEQ") + .renameFieldVector("fk_key_name", "FK_NAME") + .renameFieldVector("pk_key_name", "PK_NAME") + .renameFieldVector("update_rule", "UPDATE_RULE") + .renameFieldVector("delete_rule", "DELETE_RULE") + .addEmptyField("DEFERRABILITY", new ArrowType.Int(Byte.SIZE, false)) + .build(); + } + + @Override + public ResultSet getSchemas(final String catalog, final String schemaPattern) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoSchemas = + connection.getClientHandler().getSchemas(catalog, schemaPattern); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_SCHEMAS_SCHEMA, allocator) + .renameFieldVector("db_schema_name", "TABLE_SCHEM") + .renameFieldVector("catalog_name", "TABLE_CATALOG") + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoSchemas, + transformer); + } + + @Override + public ResultSet getTableTypes() throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoTableTypes = connection.getClientHandler().getTableTypes(); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_TABLE_TYPES_SCHEMA, allocator) + .renameFieldVector("table_type", "TABLE_TYPE") + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoTableTypes, + transformer); + } + + @Override + public ResultSet getTables(final String catalog, final String schemaPattern, + final String tableNamePattern, + final String[] types) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final List typesList = types == null ? null : Arrays.asList(types); + final FlightInfo flightInfoTables = + connection.getClientHandler() + .getTables(catalog, schemaPattern, tableNamePattern, typesList, false); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, allocator) + .renameFieldVector("catalog_name", "TABLE_CAT") + .renameFieldVector("db_schema_name", "TABLE_SCHEM") + .renameFieldVector("table_name", "TABLE_NAME") + .renameFieldVector("table_type", "TABLE_TYPE") + .addEmptyField("REMARKS", Types.MinorType.VARBINARY) + .addEmptyField("TYPE_CAT", Types.MinorType.VARBINARY) + .addEmptyField("TYPE_SCHEM", Types.MinorType.VARBINARY) + .addEmptyField("TYPE_NAME", Types.MinorType.VARBINARY) + .addEmptyField("SELF_REFERENCING_COL_NAME", Types.MinorType.VARBINARY) + .addEmptyField("REF_GENERATION", Types.MinorType.VARBINARY) + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoTables, + transformer); + } + + @Override + public ResultSet getPrimaryKeys(final String catalog, final String schema, final String table) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoPrimaryKeys = + connection.getClientHandler().getPrimaryKeys(catalog, schema, table); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_PRIMARY_KEYS_SCHEMA, allocator) + .renameFieldVector("catalog_name", "TABLE_CAT") + .renameFieldVector("db_schema_name", "TABLE_SCHEM") + .renameFieldVector("table_name", "TABLE_NAME") + .renameFieldVector("column_name", "COLUMN_NAME") + .renameFieldVector("key_sequence", "KEY_SEQ") + .renameFieldVector("key_name", "PK_NAME") + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoPrimaryKeys, + transformer); + } + + @Override + public ResultSet getColumns(final String catalog, final String schemaPattern, + final String tableNamePattern, + final String columnNamePattern) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoTables = + connection.getClientHandler() + .getTables(catalog, schemaPattern, tableNamePattern, null, true); + + final BufferAllocator allocator = connection.getBufferAllocator(); + + final Pattern columnNamePat = + columnNamePattern != null ? Pattern.compile(sqlToRegexLike(columnNamePattern)) : null; + + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoTables, + (originalRoot, transformedRoot) -> { + int columnCounter = 0; + if (transformedRoot == null) { + transformedRoot = VectorSchemaRoot.create(GET_COLUMNS_SCHEMA, allocator); + } + + final int originalRootRowCount = originalRoot.getRowCount(); + + final VarCharVector catalogNameVector = + (VarCharVector) originalRoot.getVector("catalog_name"); + final VarCharVector tableNameVector = + (VarCharVector) originalRoot.getVector("table_name"); + final VarCharVector schemaNameVector = + (VarCharVector) originalRoot.getVector("db_schema_name"); + + final VarBinaryVector schemaVector = + (VarBinaryVector) originalRoot.getVector("table_schema"); + + for (int i = 0; i < originalRootRowCount; i++) { + final Text catalogName = catalogNameVector.getObject(i); + final Text tableName = tableNameVector.getObject(i); + final Text schemaName = schemaNameVector.getObject(i); + + final Schema currentSchema; + try { + currentSchema = MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel( + new ByteArrayInputStream(schemaVector.get(i))))); + } catch (final IOException e) { + throw new IOException( + String.format("Failed to deserialize schema for table %s", tableName), e); + } + final List tableColumns = currentSchema.getFields(); + + columnCounter = setGetColumnsVectorSchemaRootFromFields(transformedRoot, columnCounter, + tableColumns, + catalogName, tableName, schemaName, columnNamePat); + } + + transformedRoot.setRowCount(columnCounter); + + originalRoot.clear(); + return transformedRoot; + }); + } + + private int setGetColumnsVectorSchemaRootFromFields(final VectorSchemaRoot currentRoot, + int insertIndex, + final List tableColumns, + final Text catalogName, + final Text tableName, final Text schemaName, + final Pattern columnNamePattern) { + int ordinalIndex = 1; + final int tableColumnsSize = tableColumns.size(); + + final VarCharVector tableCatVector = (VarCharVector) currentRoot.getVector("TABLE_CAT"); + final VarCharVector tableSchemVector = (VarCharVector) currentRoot.getVector("TABLE_SCHEM"); + final VarCharVector tableNameVector = (VarCharVector) currentRoot.getVector("TABLE_NAME"); + final VarCharVector columnNameVector = (VarCharVector) currentRoot.getVector("COLUMN_NAME"); + final IntVector dataTypeVector = (IntVector) currentRoot.getVector("DATA_TYPE"); + final VarCharVector typeNameVector = (VarCharVector) currentRoot.getVector("TYPE_NAME"); + final IntVector columnSizeVector = (IntVector) currentRoot.getVector("COLUMN_SIZE"); + final IntVector decimalDigitsVector = (IntVector) currentRoot.getVector("DECIMAL_DIGITS"); + final IntVector numPrecRadixVector = (IntVector) currentRoot.getVector("NUM_PREC_RADIX"); + final IntVector nullableVector = (IntVector) currentRoot.getVector("NULLABLE"); + final IntVector ordinalPositionVector = (IntVector) currentRoot.getVector("ORDINAL_POSITION"); + final VarCharVector isNullableVector = (VarCharVector) currentRoot.getVector("IS_NULLABLE"); + final VarCharVector isAutoincrementVector = (VarCharVector) currentRoot.getVector("IS_AUTOINCREMENT"); + final VarCharVector isGeneratedColumnVector = (VarCharVector) currentRoot.getVector("IS_GENERATEDCOLUMN"); + + for (int i = 0; i < tableColumnsSize; i++, ordinalIndex++) { + final Field field = tableColumns.get(i); + final FlightSqlColumnMetadata columnMetadata = new FlightSqlColumnMetadata(field.getMetadata()); + final String columnName = field.getName(); + + if (columnNamePattern != null && !columnNamePattern.matcher(columnName).matches()) { + continue; + } + final ArrowType fieldType = field.getType(); + + if (catalogName != null) { + tableCatVector.setSafe(insertIndex, catalogName); + } + + if (schemaName != null) { + tableSchemVector.setSafe(insertIndex, schemaName); + } + + if (tableName != null) { + tableNameVector.setSafe(insertIndex, tableName); + } + + if (columnName != null) { + columnNameVector.setSafe(insertIndex, columnName.getBytes(CHARSET)); + } + + dataTypeVector.setSafe(insertIndex, SqlTypes.getSqlTypeIdFromArrowType(fieldType)); + byte[] typeName = columnMetadata.getTypeName() != null ? + columnMetadata.getTypeName().getBytes(CHARSET) : + SqlTypes.getSqlTypeNameFromArrowType(fieldType).getBytes(CHARSET); + typeNameVector.setSafe(insertIndex, typeName); + + // We're not setting COLUMN_SIZE for ROWID SQL Types, as there's no such Arrow type. + // We're not setting COLUMN_SIZE nor DECIMAL_DIGITS for Float/Double as their precision and scale are variable. + if (fieldType instanceof ArrowType.Decimal) { + numPrecRadixVector.setSafe(insertIndex, BASE10_RADIX); + } else if (fieldType instanceof ArrowType.Int) { + numPrecRadixVector.setSafe(insertIndex, BASE10_RADIX); + } else if (fieldType instanceof ArrowType.FloatingPoint) { + numPrecRadixVector.setSafe(insertIndex, BASE10_RADIX); + } + + Integer decimalDigits = columnMetadata.getScale(); + if (decimalDigits == null) { + decimalDigits = getDecimalDigits(fieldType); + } + if (decimalDigits != null) { + decimalDigitsVector.setSafe(insertIndex, decimalDigits); + } + + Integer columnSize = columnMetadata.getPrecision(); + if (columnSize == null) { + columnSize = getColumnSize(fieldType); + } + if (columnSize != null) { + columnSizeVector.setSafe(insertIndex, columnSize); + } + + nullableVector.setSafe(insertIndex, field.isNullable() ? 1 : 0); + + isNullableVector.setSafe(insertIndex, booleanToYesOrNo(field.isNullable())); + + Boolean autoIncrement = columnMetadata.isAutoIncrement(); + if (autoIncrement != null) { + isAutoincrementVector.setSafe(insertIndex, booleanToYesOrNo(autoIncrement)); + } else { + isAutoincrementVector.setSafe(insertIndex, EMPTY_BYTE_ARRAY); + } + + // Fields also don't hold information about IS_AUTOINCREMENT and IS_GENERATEDCOLUMN, + // so we're setting an empty string (as bytes), which means it couldn't be determined. + isGeneratedColumnVector.setSafe(insertIndex, EMPTY_BYTE_ARRAY); + + ordinalPositionVector.setSafe(insertIndex, ordinalIndex); + + insertIndex++; + } + return insertIndex; + } + + private static byte[] booleanToYesOrNo(boolean autoIncrement) { + return autoIncrement ? "YES".getBytes(CHARSET) : "NO".getBytes(CHARSET); + } + + static Integer getDecimalDigits(final ArrowType fieldType) { + // We're not setting DECIMAL_DIGITS for Float/Double as their precision and scale are variable. + if (fieldType instanceof ArrowType.Decimal) { + final ArrowType.Decimal thisDecimal = (ArrowType.Decimal) fieldType; + return thisDecimal.getScale(); + } else if (fieldType instanceof ArrowType.Int) { + return NO_DECIMAL_DIGITS; + } else if (fieldType instanceof ArrowType.Timestamp) { + switch (((ArrowType.Timestamp) fieldType).getUnit()) { + case SECOND: + return NO_DECIMAL_DIGITS; + case MILLISECOND: + return DECIMAL_DIGITS_TIME_MILLISECONDS; + case MICROSECOND: + return DECIMAL_DIGITS_TIME_MICROSECONDS; + case NANOSECOND: + return DECIMAL_DIGITS_TIME_NANOSECONDS; + default: + break; + } + } else if (fieldType instanceof ArrowType.Time) { + switch (((ArrowType.Time) fieldType).getUnit()) { + case SECOND: + return NO_DECIMAL_DIGITS; + case MILLISECOND: + return DECIMAL_DIGITS_TIME_MILLISECONDS; + case MICROSECOND: + return DECIMAL_DIGITS_TIME_MICROSECONDS; + case NANOSECOND: + return DECIMAL_DIGITS_TIME_NANOSECONDS; + default: + break; + } + } else if (fieldType instanceof ArrowType.Date) { + return NO_DECIMAL_DIGITS; + } + + return null; + } + + static Integer getColumnSize(final ArrowType fieldType) { + // We're not setting COLUMN_SIZE for ROWID SQL Types, as there's no such Arrow type. + // We're not setting COLUMN_SIZE nor DECIMAL_DIGITS for Float/Double as their precision and scale are variable. + if (fieldType instanceof ArrowType.Decimal) { + final ArrowType.Decimal thisDecimal = (ArrowType.Decimal) fieldType; + return thisDecimal.getPrecision(); + } else if (fieldType instanceof ArrowType.Int) { + final ArrowType.Int thisInt = (ArrowType.Int) fieldType; + switch (thisInt.getBitWidth()) { + case Byte.SIZE: + return COLUMN_SIZE_BYTE; + case Short.SIZE: + return COLUMN_SIZE_SHORT; + case Integer.SIZE: + return COLUMN_SIZE_INT; + case Long.SIZE: + return COLUMN_SIZE_LONG; + default: + break; + } + } else if (fieldType instanceof ArrowType.Utf8 || fieldType instanceof ArrowType.Binary) { + return COLUMN_SIZE_VARCHAR_AND_BINARY; + } else if (fieldType instanceof ArrowType.Timestamp) { + switch (((ArrowType.Timestamp) fieldType).getUnit()) { + case SECOND: + return COLUMN_SIZE_TIMESTAMP_SECONDS; + case MILLISECOND: + return COLUMN_SIZE_TIMESTAMP_MILLISECONDS; + case MICROSECOND: + return COLUMN_SIZE_TIMESTAMP_MICROSECONDS; + case NANOSECOND: + return COLUMN_SIZE_TIMESTAMP_NANOSECONDS; + default: + break; + } + } else if (fieldType instanceof ArrowType.Time) { + switch (((ArrowType.Time) fieldType).getUnit()) { + case SECOND: + return COLUMN_SIZE_TIME; + case MILLISECOND: + return COLUMN_SIZE_TIME_MILLISECONDS; + case MICROSECOND: + return COLUMN_SIZE_TIME_MICROSECONDS; + case NANOSECOND: + return COLUMN_SIZE_TIME_NANOSECONDS; + default: + break; + } + } else if (fieldType instanceof ArrowType.Date) { + return COLUMN_SIZE_DATE; + } + + return null; + } + + static String sqlToRegexLike(final String sqlPattern) { + final int len = sqlPattern.length(); + final StringBuilder javaPattern = new StringBuilder(len + len); + + for (int i = 0; i < len; i++) { + final char currentChar = sqlPattern.charAt(i); + + if (JAVA_REGEX_SPECIALS.indexOf(currentChar) >= 0) { + javaPattern.append('\\'); + } + + switch (currentChar) { + case '_': + javaPattern.append('.'); + break; + case '%': + javaPattern.append("."); + javaPattern.append('*'); + break; + default: + javaPattern.append(currentChar); + break; + } + } + return javaPattern.toString(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java new file mode 100644 index 0000000000000..d2b6e89e3fb81 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.replaceSemiColons; + +import java.sql.SQLException; +import java.util.Properties; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaFactory; + +import io.netty.util.concurrent.DefaultThreadFactory; + +/** + * Connection to the Arrow Flight server. + */ +public final class ArrowFlightConnection extends AvaticaConnection { + + private final BufferAllocator allocator; + private final ArrowFlightSqlClientHandler clientHandler; + private final ArrowFlightConnectionConfigImpl config; + private ExecutorService executorService; + + /** + * Creates a new {@link ArrowFlightConnection}. + * + * @param driver the {@link ArrowFlightJdbcDriver} to use. + * @param factory the {@link AvaticaFactory} to use. + * @param url the URL to use. + * @param properties the {@link Properties} to use. + * @param config the {@link ArrowFlightConnectionConfigImpl} to use. + * @param allocator the {@link BufferAllocator} to use. + * @param clientHandler the {@link ArrowFlightSqlClientHandler} to use. + */ + private ArrowFlightConnection(final ArrowFlightJdbcDriver driver, final AvaticaFactory factory, + final String url, final Properties properties, + final ArrowFlightConnectionConfigImpl config, + final BufferAllocator allocator, + final ArrowFlightSqlClientHandler clientHandler) { + super(driver, factory, url, properties); + this.config = Preconditions.checkNotNull(config, "Config cannot be null."); + this.allocator = Preconditions.checkNotNull(allocator, "Allocator cannot be null."); + this.clientHandler = Preconditions.checkNotNull(clientHandler, "Handler cannot be null."); + } + + /** + * Creates a new {@link ArrowFlightConnection} to a {@link FlightClient}. + * + * @param driver the {@link ArrowFlightJdbcDriver} to use. + * @param factory the {@link AvaticaFactory} to use. + * @param url the URL to establish the connection to. + * @param properties the {@link Properties} to use for this session. + * @param allocator the {@link BufferAllocator} to use. + * @return a new {@link ArrowFlightConnection}. + * @throws SQLException on error. + */ + static ArrowFlightConnection createNewConnection(final ArrowFlightJdbcDriver driver, + final AvaticaFactory factory, + String url, final Properties properties, + final BufferAllocator allocator) + throws SQLException { + url = replaceSemiColons(url); + final ArrowFlightConnectionConfigImpl config = new ArrowFlightConnectionConfigImpl(properties); + final ArrowFlightSqlClientHandler clientHandler = createNewClientHandler(config, allocator); + return new ArrowFlightConnection(driver, factory, url, properties, config, allocator, clientHandler); + } + + private static ArrowFlightSqlClientHandler createNewClientHandler( + final ArrowFlightConnectionConfigImpl config, + final BufferAllocator allocator) throws SQLException { + try { + return new ArrowFlightSqlClientHandler.Builder() + .withHost(config.getHost()) + .withPort(config.getPort()) + .withUsername(config.getUser()) + .withPassword(config.getPassword()) + .withTrustStorePath(config.getTrustStorePath()) + .withTrustStorePassword(config.getTrustStorePassword()) + .withSystemTrustStore(config.useSystemTrustStore()) + .withBufferAllocator(allocator) + .withEncryption(config.useEncryption()) + .withDisableCertificateVerification(config.getDisableCertificateVerification()) + .withToken(config.getToken()) + .withCallOptions(config.toCallOption()) + .build(); + } catch (final SQLException e) { + try { + allocator.close(); + } catch (final Exception allocatorCloseEx) { + e.addSuppressed(allocatorCloseEx); + } + throw e; + } + } + + void reset() throws SQLException { + // Clean up any open Statements + try { + AutoCloseables.close(statementMap.values()); + } catch (final Exception e) { + throw AvaticaConnection.HELPER.createException(e.getMessage(), e); + } + + statementMap.clear(); + + // Reset Holdability + this.setHoldability(this.metaData.getResultSetHoldability()); + + // Reset Meta + ((ArrowFlightMetaImpl) this.meta).setDefaultConnectionProperties(); + } + + /** + * Gets the client {@link #clientHandler} backing this connection. + * + * @return the handler. + */ + ArrowFlightSqlClientHandler getClientHandler() throws SQLException { + return clientHandler; + } + + /** + * Gets the {@link ExecutorService} of this connection. + * + * @return the {@link #executorService}. + */ + synchronized ExecutorService getExecutorService() { + return executorService = executorService == null ? + Executors.newFixedThreadPool(config.threadPoolSize(), + new DefaultThreadFactory(getClass().getSimpleName())) : + executorService; + } + + @Override + public Properties getClientInfo() { + final Properties copy = new Properties(); + copy.putAll(info); + return copy; + } + + @Override + public void close() throws SQLException { + if (executorService != null) { + executorService.shutdown(); + } + + try { + AutoCloseables.close(clientHandler); + allocator.getChildAllocators().forEach(AutoCloseables::closeNoChecked); + AutoCloseables.close(allocator); + + super.close(); + } catch (final Exception e) { + throw AvaticaConnection.HELPER.createException(e.getMessage(), e); + } + } + + BufferAllocator getBufferAllocator() { + return allocator; + } + + public ArrowFlightMetaImpl getMeta() { + return (ArrowFlightMetaImpl) this.meta; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java new file mode 100644 index 0000000000000..8365c7bb57a4a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; +import java.sql.Statement; + +import org.apache.arrow.flight.FlightInfo; + +/** + * A {@link Statement} that deals with {@link FlightInfo}. + */ +public interface ArrowFlightInfoStatement extends Statement { + + @Override + ArrowFlightConnection getConnection() throws SQLException; + + /** + * Executes the query this {@link Statement} is holding. + * + * @return the {@link FlightInfo} for the results. + * @throws SQLException on error. + */ + FlightInfo executeFlightInfoQuery() throws SQLException; +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArray.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArray.java new file mode 100644 index 0000000000000..ed67c97cf691e --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArray.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Array; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.Arrays; +import java.util.Map; + +import org.apache.arrow.driver.jdbc.accessor.impl.complex.AbstractArrowFlightJdbcListVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.SqlTypes; +import org.apache.arrow.memory.util.LargeMemoryUtil; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.JsonStringArrayList; +import org.apache.arrow.vector.util.TransferPair; + +/** + * Implementation of {@link Array} using an underlying {@link FieldVector}. + * + * @see AbstractArrowFlightJdbcListVectorAccessor + */ +public class ArrowFlightJdbcArray implements Array { + + private final FieldVector dataVector; + private final long startOffset; + private final long valuesCount; + + /** + * Instantiate an {@link Array} backed up by given {@link FieldVector}, limited by a start offset and values count. + * + * @param dataVector underlying FieldVector, containing the Array items. + * @param startOffset offset from FieldVector pointing to this Array's first value. + * @param valuesCount how many items this Array contains. + */ + public ArrowFlightJdbcArray(FieldVector dataVector, long startOffset, long valuesCount) { + this.dataVector = dataVector; + this.startOffset = startOffset; + this.valuesCount = valuesCount; + } + + @Override + public String getBaseTypeName() { + final ArrowType arrowType = this.dataVector.getField().getType(); + return SqlTypes.getSqlTypeNameFromArrowType(arrowType); + } + + @Override + public int getBaseType() { + final ArrowType arrowType = this.dataVector.getField().getType(); + return SqlTypes.getSqlTypeIdFromArrowType(arrowType); + } + + @Override + public Object getArray() throws SQLException { + return getArray(null); + } + + @Override + public Object getArray(Map> map) throws SQLException { + if (map != null) { + throw new SQLFeatureNotSupportedException(); + } + + return getArrayNoBoundCheck(this.dataVector, this.startOffset, this.valuesCount); + } + + @Override + public Object getArray(long index, int count) throws SQLException { + return getArray(index, count, null); + } + + private void checkBoundaries(long index, int count) { + if (index < 0 || index + count > this.startOffset + this.valuesCount) { + throw new ArrayIndexOutOfBoundsException(); + } + } + + private static Object getArrayNoBoundCheck(ValueVector dataVector, long start, long count) { + Object[] result = new Object[LargeMemoryUtil.checkedCastToInt(count)]; + for (int i = 0; i < count; i++) { + result[i] = dataVector.getObject(LargeMemoryUtil.checkedCastToInt(start + i)); + } + + return result; + } + + @Override + public Object getArray(long index, int count, Map> map) throws SQLException { + if (map != null) { + throw new SQLFeatureNotSupportedException(); + } + + checkBoundaries(index, count); + return getArrayNoBoundCheck(this.dataVector, + LargeMemoryUtil.checkedCastToInt(this.startOffset + index), count); + } + + @Override + public ResultSet getResultSet() throws SQLException { + return this.getResultSet(null); + } + + @Override + public ResultSet getResultSet(Map> map) throws SQLException { + if (map != null) { + throw new SQLFeatureNotSupportedException(); + } + + return getResultSetNoBoundariesCheck(this.dataVector, this.startOffset, this.valuesCount); + } + + @Override + public ResultSet getResultSet(long index, int count) throws SQLException { + return getResultSet(index, count, null); + } + + private static ResultSet getResultSetNoBoundariesCheck(ValueVector dataVector, long start, + long count) + throws SQLException { + TransferPair transferPair = dataVector.getTransferPair(dataVector.getAllocator()); + transferPair.splitAndTransfer(LargeMemoryUtil.checkedCastToInt(start), + LargeMemoryUtil.checkedCastToInt(count)); + FieldVector vectorSlice = (FieldVector) transferPair.getTo(); + + VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.of(vectorSlice); + return ArrowFlightJdbcVectorSchemaRootResultSet.fromVectorSchemaRoot(vectorSchemaRoot); + } + + @Override + public ResultSet getResultSet(long index, int count, Map> map) + throws SQLException { + if (map != null) { + throw new SQLFeatureNotSupportedException(); + } + + checkBoundaries(index, count); + return getResultSetNoBoundariesCheck(this.dataVector, + LargeMemoryUtil.checkedCastToInt(this.startOffset + index), count); + } + + @Override + public void free() { + + } + + @Override + public String toString() { + JsonStringArrayList array = new JsonStringArrayList<>((int) this.valuesCount); + + try { + array.addAll(Arrays.asList((Object[]) getArray())); + } catch (SQLException e) { + throw new RuntimeException(e); + } + + return array.toString(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSource.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSource.java new file mode 100644 index 0000000000000..46a1d3ff87c34 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSource.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; +import java.util.Map; +import java.util.Properties; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; + +import javax.sql.ConnectionEvent; +import javax.sql.ConnectionEventListener; +import javax.sql.ConnectionPoolDataSource; +import javax.sql.PooledConnection; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; + +/** + * {@link ConnectionPoolDataSource} implementation for Arrow Flight JDBC Driver. + */ +public class ArrowFlightJdbcConnectionPoolDataSource extends ArrowFlightJdbcDataSource + implements ConnectionPoolDataSource, ConnectionEventListener, AutoCloseable { + private final Map> pool = + new ConcurrentHashMap<>(); + + /** + * Instantiates a new DataSource. + * + * @param properties the properties + * @param config the config. + */ + protected ArrowFlightJdbcConnectionPoolDataSource(final Properties properties, + final ArrowFlightConnectionConfigImpl config) { + super(properties, config); + } + + /** + * Creates a new {@link ArrowFlightJdbcConnectionPoolDataSource}. + * + * @param properties the properties. + * @return a new data source. + */ + public static ArrowFlightJdbcConnectionPoolDataSource createNewDataSource( + final Properties properties) { + return new ArrowFlightJdbcConnectionPoolDataSource(properties, + new ArrowFlightConnectionConfigImpl(properties)); + } + + @Override + public PooledConnection getPooledConnection() throws SQLException { + final ArrowFlightConnectionConfigImpl config = getConfig(); + return this.getPooledConnection(config.getUser(), config.getPassword()); + } + + @Override + public PooledConnection getPooledConnection(final String username, final String password) + throws SQLException { + final Properties properties = getProperties(username, password); + Queue objectPool = + pool.computeIfAbsent(properties, s -> new ConcurrentLinkedQueue<>()); + ArrowFlightJdbcPooledConnection pooledConnection = objectPool.poll(); + if (pooledConnection == null) { + pooledConnection = createPooledConnection(new ArrowFlightConnectionConfigImpl(properties)); + } else { + pooledConnection.reset(); + } + return pooledConnection; + } + + private ArrowFlightJdbcPooledConnection createPooledConnection( + final ArrowFlightConnectionConfigImpl config) + throws SQLException { + ArrowFlightJdbcPooledConnection pooledConnection = + new ArrowFlightJdbcPooledConnection(getConnection(config.getUser(), config.getPassword())); + pooledConnection.addConnectionEventListener(this); + return pooledConnection; + } + + @Override + public void connectionClosed(ConnectionEvent connectionEvent) { + final ArrowFlightJdbcPooledConnection pooledConnection = + (ArrowFlightJdbcPooledConnection) connectionEvent.getSource(); + Queue connectionQueue = + pool.get(pooledConnection.getProperties()); + connectionQueue.add(pooledConnection); + } + + @Override + public void connectionErrorOccurred(ConnectionEvent connectionEvent) { + + } + + @Override + public void close() throws Exception { + SQLException lastException = null; + for (Queue connections : this.pool.values()) { + while (!connections.isEmpty()) { + PooledConnection pooledConnection = connections.poll(); + try { + pooledConnection.close(); + } catch (SQLException e) { + lastException = e; + } + } + } + + if (lastException != null) { + throw lastException; + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursor.java new file mode 100644 index 0000000000000..45c23e4d5298b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursor.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + + +import java.util.ArrayList; +import java.util.Calendar; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.util.AbstractCursor; +import org.apache.calcite.avatica.util.ArrayImpl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Arrow Flight Jdbc's Cursor class. + */ +public class ArrowFlightJdbcCursor extends AbstractCursor { + + private static final Logger LOGGER; + private final VectorSchemaRoot root; + private final int rowCount; + private int currentRow = -1; + + static { + LOGGER = LoggerFactory.getLogger(ArrowFlightJdbcCursor.class); + } + + public ArrowFlightJdbcCursor(VectorSchemaRoot root) { + this.root = root; + rowCount = root.getRowCount(); + } + + @Override + public List createAccessors(List columns, + Calendar localCalendar, + ArrayImpl.Factory factory) { + final List fieldVectors = root.getFieldVectors(); + + return IntStream.range(0, fieldVectors.size()).mapToObj(root::getVector) + .map(this::createAccessor) + .collect(Collectors.toCollection(() -> new ArrayList<>(fieldVectors.size()))); + } + + private Accessor createAccessor(FieldVector vector) { + return ArrowFlightJdbcAccessorFactory.createAccessor(vector, this::getCurrentRow, + (boolean wasNull) -> { + // AbstractCursor creates a boolean array of length 1 to hold the wasNull value + this.wasNull[0] = wasNull; + }); + } + + /** + * ArrowFlightJdbcAccessors do not use {@link AbstractCursor.Getter}, as it would box primitive types and cause + * performance issues. Each Accessor implementation works directly on Arrow Vectors. + */ + @Override + protected Getter createGetter(int column) { + throw new UnsupportedOperationException("Not allowed."); + } + + @Override + public boolean next() { + currentRow++; + return currentRow < rowCount; + } + + @Override + public void close() { + try { + AutoCloseables.close(root); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + } + } + + private int getCurrentRow() { + return currentRow; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDataSource.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDataSource.java new file mode 100644 index 0000000000000..a57eeaa830492 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDataSource.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; + +import java.io.PrintWriter; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.Properties; +import java.util.logging.Logger; + +import javax.sql.DataSource; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.util.Preconditions; + +/** + * {@link DataSource} implementation for Arrow Flight JDBC Driver. + */ +public class ArrowFlightJdbcDataSource implements DataSource { + private final Properties properties; + private final ArrowFlightConnectionConfigImpl config; + private PrintWriter logWriter; + + /** + * Instantiates a new DataSource. + */ + protected ArrowFlightJdbcDataSource(final Properties properties, + final ArrowFlightConnectionConfigImpl config) { + this.properties = Preconditions.checkNotNull(properties); + this.config = Preconditions.checkNotNull(config); + } + + /** + * Gets the {@link #config} for this {@link ArrowFlightJdbcDataSource}. + * + * @return the {@link ArrowFlightConnectionConfigImpl}. + */ + protected final ArrowFlightConnectionConfigImpl getConfig() { + return config; + } + + /** + * Gets a copy of the {@link #properties} for this {@link ArrowFlightJdbcDataSource} with + * the provided {@code username} and {@code password}. + * + * @return the {@link Properties} for this data source. + */ + protected final Properties getProperties(final String username, final String password) { + final Properties newProperties = new Properties(); + newProperties.putAll(this.properties); + if (username != null) { + newProperties.replace(ArrowFlightConnectionProperty.USER.camelName(), username); + } + if (password != null) { + newProperties.replace(ArrowFlightConnectionProperty.PASSWORD.camelName(), password); + } + return ArrowFlightJdbcDriver.lowerCasePropertyKeys(newProperties); + } + + /** + * Creates a new {@link ArrowFlightJdbcDataSource}. + * + * @param properties the properties. + * @return a new data source. + */ + public static ArrowFlightJdbcDataSource createNewDataSource(final Properties properties) { + return new ArrowFlightJdbcDataSource(properties, + new ArrowFlightConnectionConfigImpl(properties)); + } + + @Override + public ArrowFlightConnection getConnection() throws SQLException { + return getConnection(config.getUser(), config.getPassword()); + } + + @Override + public ArrowFlightConnection getConnection(final String username, final String password) + throws SQLException { + final Properties properties = getProperties(username, password); + return new ArrowFlightJdbcDriver().connect(config.url(), properties); + } + + @Override + public T unwrap(Class aClass) throws SQLException { + throw new SQLException("ArrowFlightJdbcDataSource is not a wrapper."); + } + + @Override + public boolean isWrapperFor(Class aClass) { + return false; + } + + @Override + public PrintWriter getLogWriter() { + return this.logWriter; + } + + @Override + public void setLogWriter(PrintWriter logWriter) { + this.logWriter = logWriter; + } + + @Override + public void setLoginTimeout(int timeout) throws SQLException { + throw new SQLFeatureNotSupportedException("Setting Login timeout is not supported."); + } + + @Override + public int getLoginTimeout() { + return 0; + } + + @Override + public Logger getParentLogger() { + return Logger.getLogger("ArrowFlightJdbc"); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java new file mode 100644 index 0000000000000..a72fbd3a4d592 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.replaceSemiColons; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.apache.arrow.driver.jdbc.utils.UrlParser; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.util.VisibleForTesting; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.DriverVersion; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.UnregisteredDriver; + +/** + * JDBC driver for querying data from an Apache Arrow Flight server. + */ +public class ArrowFlightJdbcDriver extends UnregisteredDriver { + private static final String CONNECT_STRING_PREFIX = "jdbc:arrow-flight-sql://"; + private static final String CONNECT_STRING_PREFIX_DEPRECATED = "jdbc:arrow-flight://"; + private static final String CONNECTION_STRING_EXPECTED = "jdbc:arrow-flight-sql://[host][:port][?param1=value&...]"; + private static DriverVersion version; + + static { + // Special code for supporting Java9 and higher. + // Netty requires some extra properties to unlock some native memory management api + // Setting this property if not already set externally + // This has to be done before any netty class is being loaded + final String key = "cfjd.io.netty.tryReflectionSetAccessible"; + final String tryReflectionSetAccessible = System.getProperty(key); + if (tryReflectionSetAccessible == null) { + System.setProperty(key, Boolean.TRUE.toString()); + } + + new ArrowFlightJdbcDriver().register(); + } + + @Override + public ArrowFlightConnection connect(final String url, final Properties info) + throws SQLException { + final Properties properties = new Properties(info); + properties.putAll(info); + + if (url != null) { + final Map propertiesFromUrl = getUrlsArgs(url); + properties.putAll(propertiesFromUrl); + } + + try { + return ArrowFlightConnection.createNewConnection( + this, + factory, + url, + lowerCasePropertyKeys(properties), + new RootAllocator(Long.MAX_VALUE)); + } catch (final FlightRuntimeException e) { + throw new SQLException("Failed to connect.", e); + } + } + + @Override + protected String getFactoryClassName(final JdbcVersion jdbcVersion) { + return ArrowFlightJdbcFactory.class.getName(); + } + + @Override + protected DriverVersion createDriverVersion() { + if (version == null) { + final InputStream flightProperties = this.getClass().getResourceAsStream("/properties/flight.properties"); + if (flightProperties == null) { + throw new RuntimeException("Flight Properties not found. Ensure the JAR was built properly."); + } + try (final Reader reader = new BufferedReader(new InputStreamReader(flightProperties, StandardCharsets.UTF_8))) { + final Properties properties = new Properties(); + properties.load(reader); + + final String parentName = properties.getProperty("org.apache.arrow.flight.name"); + final String parentVersion = properties.getProperty("org.apache.arrow.flight.version"); + final String[] pVersion = parentVersion.split("\\."); + + final int parentMajorVersion = Integer.parseInt(pVersion[0]); + final int parentMinorVersion = Integer.parseInt(pVersion[1]); + + final String childName = properties.getProperty("org.apache.arrow.flight.jdbc-driver.name"); + final String childVersion = properties.getProperty("org.apache.arrow.flight.jdbc-driver.version"); + final String[] cVersion = childVersion.split("\\."); + + final int childMajorVersion = Integer.parseInt(cVersion[0]); + final int childMinorVersion = Integer.parseInt(cVersion[1]); + + version = new DriverVersion( + childName, + childVersion, + parentName, + parentVersion, + true, + childMajorVersion, + childMinorVersion, + parentMajorVersion, + parentMinorVersion); + } catch (final IOException e) { + throw new RuntimeException("Failed to load driver version.", e); + } + } + + return version; + } + + @Override + public Meta createMeta(final AvaticaConnection connection) { + return new ArrowFlightMetaImpl(connection); + } + + @Override + protected String getConnectStringPrefix() { + return CONNECT_STRING_PREFIX; + } + + @Override + public boolean acceptsURL(final String url) { + Preconditions.checkNotNull(url); + return url.startsWith(CONNECT_STRING_PREFIX) || url.startsWith(CONNECT_STRING_PREFIX_DEPRECATED); + } + + /** + * Parses the provided url based on the format this driver accepts, retrieving + * arguments after the {@link #CONNECT_STRING_PREFIX}. + *

+ * This method gets the args if the provided URL follows this pattern: + * {@code jdbc:arrow-flight-sql://:[/?key1=val1&key2=val2&(...)]} + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
GroupDefinitionValue
? — inaccessible{@link #getConnectStringPrefix} + * the URL prefix accepted by this driver, i.e., + * {@code "jdbc:arrow-flight-sql://"} + *
1IPv4 host name + * first word after previous group and before "{@code :}" + *
2IPv4 port number + * first number after previous group and before "{@code /?}" + *
3custom call parameters + * all parameters provided after "{@code /?}" — must follow the + * pattern: "{@code key=value}" with "{@code &}" separating a + * parameter from another + *
+ * + * @param url The url to parse. + * @return the parsed arguments. + * @throws SQLException If an error occurs while trying to parse the URL. + */ + @VisibleForTesting // ArrowFlightJdbcDriverTest + Map getUrlsArgs(String url) + throws SQLException { + + /* + * + * Perhaps this logic should be inside a utility class, separated from this + * one, so as to better delegate responsibilities and concerns throughout + * the code and increase maintainability. + * + * ===== + * + * Keep in mind that the URL must ALWAYS follow the pattern: + * "jdbc:arrow-flight-sql://:[/?param1=value1¶m2=value2&(...)]." + * + */ + + final Properties resultMap = new Properties(); + url = replaceSemiColons(url); + + if (!url.startsWith("jdbc:")) { + throw new SQLException("Connection string must start with 'jdbc:'. Expected format: " + + CONNECTION_STRING_EXPECTED); + } + + // It's necessary to use a string without "jdbc:" at the beginning to be parsed as a valid URL. + url = url.substring(5); + + final URI uri; + + try { + uri = URI.create(url); + } catch (final IllegalArgumentException e) { + throw new SQLException("Malformed/invalid URL!", e); + } + + if (!Objects.equals(uri.getScheme(), "arrow-flight") && + !Objects.equals(uri.getScheme(), "arrow-flight-sql")) { + throw new SQLException("URL Scheme must be 'arrow-flight'. Expected format: " + + CONNECTION_STRING_EXPECTED); + } + + if (uri.getHost() == null) { + throw new SQLException("URL must have a host. Expected format: " + CONNECTION_STRING_EXPECTED); + } else if (uri.getPort() < 0) { + throw new SQLException("URL must have a port. Expected format: " + CONNECTION_STRING_EXPECTED); + } + resultMap.put(ArrowFlightConnectionProperty.HOST.camelName(), uri.getHost()); // host + resultMap.put(ArrowFlightConnectionProperty.PORT.camelName(), uri.getPort()); // port + + final String extraParams = uri.getRawQuery(); // optional params + if (extraParams != null) { + final Map keyValuePairs = UrlParser.parse(extraParams, "&"); + resultMap.putAll(keyValuePairs); + } + + return resultMap; + } + + static Properties lowerCasePropertyKeys(final Properties properties) { + final Properties resultProperty = new Properties(); + properties.forEach((k, v) -> resultProperty.put(k.toString().toLowerCase(), v)); + return resultProperty; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java new file mode 100644 index 0000000000000..a54fbb9511b55 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.Properties; +import java.util.TimeZone; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaFactory; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaSpecificDatabaseMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.QueryState; +import org.apache.calcite.avatica.UnregisteredDriver; + +/** + * Factory for the Arrow Flight JDBC Driver. + */ +public class ArrowFlightJdbcFactory implements AvaticaFactory { + private final int major; + private final int minor; + + // This need to be public so Avatica can call this constructor + public ArrowFlightJdbcFactory() { + this(4, 1); + } + + private ArrowFlightJdbcFactory(final int major, final int minor) { + this.major = major; + this.minor = minor; + } + + @Override + public AvaticaConnection newConnection(final UnregisteredDriver driver, + final AvaticaFactory factory, + final String url, + final Properties info) throws SQLException { + return ArrowFlightConnection.createNewConnection( + (ArrowFlightJdbcDriver) driver, + factory, + url, + info, + new RootAllocator(Long.MAX_VALUE)); + } + + @Override + public AvaticaStatement newStatement( + final AvaticaConnection connection, + final Meta.StatementHandle handle, + final int resultType, + final int resultSetConcurrency, + final int resultSetHoldability) { + return new ArrowFlightStatement((ArrowFlightConnection) connection, + handle, resultType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public ArrowFlightPreparedStatement newPreparedStatement( + final AvaticaConnection connection, + final Meta.StatementHandle statementHandle, + final Meta.Signature signature, + final int resultType, + final int resultSetConcurrency, + final int resultSetHoldability) throws SQLException { + return ArrowFlightPreparedStatement.createNewPreparedStatement( + (ArrowFlightConnection) connection, statementHandle, signature, + resultType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public ArrowFlightJdbcVectorSchemaRootResultSet newResultSet(final AvaticaStatement statement, + final QueryState state, + final Meta.Signature signature, + final TimeZone timeZone, + final Meta.Frame frame) + throws SQLException { + final ResultSetMetaData metaData = newResultSetMetaData(statement, signature); + + return new ArrowFlightJdbcFlightStreamResultSet(statement, state, signature, metaData, timeZone, + frame); + } + + @Override + public AvaticaSpecificDatabaseMetaData newDatabaseMetaData(final AvaticaConnection connection) { + return new ArrowDatabaseMetadata(connection); + } + + @Override + public ResultSetMetaData newResultSetMetaData( + final AvaticaStatement avaticaStatement, + final Meta.Signature signature) { + return new AvaticaResultSetMetaData(avaticaStatement, + null, signature); + } + + @Override + public int getJdbcMajorVersion() { + return major; + } + + @Override + public int getJdbcMinorVersion() { + return minor; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java new file mode 100644 index 0000000000000..4c01cb6e5813c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.FlightStreamQueue.createNewQueue; + +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.Optional; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.driver.jdbc.utils.FlightStreamQueue; +import org.apache.arrow.driver.jdbc.utils.VectorSchemaRootTransformer; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaResultSet; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.QueryState; + +/** + * {@link ResultSet} implementation for Arrow Flight used to access the results of multiple {@link FlightStream} + * objects. + */ +public final class ArrowFlightJdbcFlightStreamResultSet + extends ArrowFlightJdbcVectorSchemaRootResultSet { + + private final ArrowFlightConnection connection; + private FlightStream currentFlightStream; + private FlightStreamQueue flightStreamQueue; + + private VectorSchemaRootTransformer transformer; + private VectorSchemaRoot currentVectorSchemaRoot; + + private Schema schema; + + ArrowFlightJdbcFlightStreamResultSet(final AvaticaStatement statement, + final QueryState state, + final Meta.Signature signature, + final ResultSetMetaData resultSetMetaData, + final TimeZone timeZone, + final Meta.Frame firstFrame) throws SQLException { + super(statement, state, signature, resultSetMetaData, timeZone, firstFrame); + this.connection = (ArrowFlightConnection) statement.connection; + } + + ArrowFlightJdbcFlightStreamResultSet(final ArrowFlightConnection connection, + final QueryState state, + final Meta.Signature signature, + final ResultSetMetaData resultSetMetaData, + final TimeZone timeZone, + final Meta.Frame firstFrame) throws SQLException { + super(null, state, signature, resultSetMetaData, timeZone, firstFrame); + this.connection = connection; + } + + /** + * Create a {@link ResultSet} which pulls data from given {@link FlightInfo}. + * + * @param connection The connection linked to the returned ResultSet. + * @param flightInfo The FlightInfo from which data will be iterated by the returned ResultSet. + * @param transformer Optional transformer for processing VectorSchemaRoot before access from ResultSet + * @return A ResultSet which pulls data from given FlightInfo. + */ + static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo( + final ArrowFlightConnection connection, + final FlightInfo flightInfo, + final VectorSchemaRootTransformer transformer) throws SQLException { + // Similar to how org.apache.calcite.avatica.util.ArrayFactoryImpl does + + final TimeZone timeZone = TimeZone.getDefault(); + final QueryState state = new QueryState(); + + final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null); + + final AvaticaResultSetMetaData resultSetMetaData = + new AvaticaResultSetMetaData(null, null, signature); + final ArrowFlightJdbcFlightStreamResultSet resultSet = + new ArrowFlightJdbcFlightStreamResultSet(connection, state, signature, resultSetMetaData, + timeZone, null); + + resultSet.transformer = transformer; + + resultSet.execute(flightInfo); + return resultSet; + } + + private void loadNewQueue() { + Optional.ofNullable(flightStreamQueue).ifPresent(AutoCloseables::closeNoChecked); + flightStreamQueue = createNewQueue(connection.getExecutorService()); + } + + private void loadNewFlightStream() throws SQLException { + if (currentFlightStream != null) { + AutoCloseables.closeNoChecked(currentFlightStream); + } + this.currentFlightStream = getNextFlightStream(true); + } + + @Override + protected AvaticaResultSet execute() throws SQLException { + final FlightInfo flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); + + if (flightInfo != null) { + schema = flightInfo.getSchema(); + execute(flightInfo); + } + return this; + } + + private void execute(final FlightInfo flightInfo) throws SQLException { + loadNewQueue(); + flightStreamQueue.enqueue(connection.getClientHandler().getStreams(flightInfo)); + loadNewFlightStream(); + + // Ownership of the root will be passed onto the cursor. + if (currentFlightStream != null) { + executeForCurrentFlightStream(); + } + } + + private void executeForCurrentFlightStream() throws SQLException { + final VectorSchemaRoot originalRoot = currentFlightStream.getRoot(); + + if (transformer != null) { + try { + currentVectorSchemaRoot = transformer.transform(originalRoot, currentVectorSchemaRoot); + } catch (final Exception e) { + throw new SQLException("Failed to transform VectorSchemaRoot.", e); + } + } else { + currentVectorSchemaRoot = originalRoot; + } + + if (schema != null) { + execute(currentVectorSchemaRoot, schema); + } else { + execute(currentVectorSchemaRoot); + } + } + + @Override + public boolean next() throws SQLException { + if (currentVectorSchemaRoot == null) { + return false; + } + while (true) { + final boolean hasNext = super.next(); + final int maxRows = statement != null ? statement.getMaxRows() : 0; + if (maxRows != 0 && this.getRow() > maxRows) { + if (statement.isCloseOnCompletion()) { + statement.close(); + } + return false; + } + + if (hasNext) { + return true; + } + + if (currentFlightStream != null) { + currentFlightStream.getRoot().clear(); + if (currentFlightStream.next()) { + executeForCurrentFlightStream(); + continue; + } + + flightStreamQueue.enqueue(currentFlightStream); + } + + currentFlightStream = getNextFlightStream(false); + + if (currentFlightStream != null) { + executeForCurrentFlightStream(); + continue; + } + + if (statement != null && statement.isCloseOnCompletion()) { + statement.close(); + } + + return false; + } + } + + @Override + protected void cancel() { + super.cancel(); + final FlightStream currentFlightStream = this.currentFlightStream; + if (currentFlightStream != null) { + currentFlightStream.cancel("Cancel", null); + } + + if (flightStreamQueue != null) { + try { + flightStreamQueue.close(); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + } + + @Override + public synchronized void close() { + try { + if (flightStreamQueue != null) { + // flightStreamQueue should close currentFlightStream internally + flightStreamQueue.close(); + } else if (currentFlightStream != null) { + // close is only called for currentFlightStream if there's no queue + currentFlightStream.close(); + } + } catch (final Exception e) { + throw new RuntimeException(e); + } finally { + super.close(); + } + } + + private FlightStream getNextFlightStream(final boolean isExecution) throws SQLException { + if (isExecution) { + final int statementTimeout = statement != null ? statement.getQueryTimeout() : 0; + return statementTimeout != 0 ? + flightStreamQueue.next(statementTimeout, TimeUnit.SECONDS) : flightStreamQueue.next(); + } else { + return flightStreamQueue.next(); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcPooledConnection.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcPooledConnection.java new file mode 100644 index 0000000000000..96a2d9dda1d18 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcPooledConnection.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; + +import javax.sql.ConnectionEvent; +import javax.sql.ConnectionEventListener; +import javax.sql.PooledConnection; +import javax.sql.StatementEventListener; + +import org.apache.arrow.driver.jdbc.utils.ConnectionWrapper; + +/** + * {@link PooledConnection} implementation for Arrow Flight JDBC Driver. + */ +public class ArrowFlightJdbcPooledConnection implements PooledConnection { + + private final ArrowFlightConnection connection; + private final Set eventListeners; + private final Set statementEventListeners; + + private final class ConnectionHandle extends ConnectionWrapper { + private boolean closed = false; + + public ConnectionHandle() { + super(connection); + } + + @Override + public void close() throws SQLException { + if (!closed) { + closed = true; + onConnectionClosed(); + } + } + + @Override + public boolean isClosed() throws SQLException { + return this.closed || super.isClosed(); + } + } + + ArrowFlightJdbcPooledConnection(ArrowFlightConnection connection) { + this.connection = connection; + this.eventListeners = Collections.synchronizedSet(new HashSet<>()); + this.statementEventListeners = Collections.synchronizedSet(new HashSet<>()); + } + + public Properties getProperties() { + return connection.getClientInfo(); + } + + @Override + public Connection getConnection() throws SQLException { + return new ConnectionHandle(); + } + + @Override + public void close() throws SQLException { + this.connection.close(); + } + + void reset() throws SQLException { + this.connection.reset(); + } + + @Override + public void addConnectionEventListener(ConnectionEventListener listener) { + eventListeners.add(listener); + } + + @Override + public void removeConnectionEventListener(ConnectionEventListener listener) { + this.eventListeners.remove(listener); + } + + @Override + public void addStatementEventListener(StatementEventListener listener) { + statementEventListeners.add(listener); + } + + @Override + public void removeStatementEventListener(StatementEventListener listener) { + this.statementEventListeners.remove(listener); + } + + private void onConnectionClosed() { + ConnectionEvent connectionEvent = new ConnectionEvent(this); + eventListeners.forEach(listener -> listener.connectionClosed(connectionEvent)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTime.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTime.java new file mode 100644 index 0000000000000..109048bc05c91 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTime.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY; + +import java.sql.Time; +import java.time.LocalTime; +import java.time.temporal.ChronoField; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.util.VisibleForTesting; + +import com.google.common.collect.ImmutableList; + +/** + * Wrapper class for Time objects to include the milliseconds part in ISO 8601 format in this#toString. + */ +public class ArrowFlightJdbcTime extends Time { + private static final List LEADING_ZEROES = ImmutableList.of("", "0", "00"); + + // Desired length of the millisecond portion should be 3 + private static final int DESIRED_MILLIS_LENGTH = 3; + + // Millis of the date time object. + private final int millisReprValue; + + /** + * Constructs this object based on epoch millis. + * + * @param milliseconds milliseconds representing Time. + */ + public ArrowFlightJdbcTime(final long milliseconds) { + super(milliseconds); + millisReprValue = getMillisReprValue(milliseconds); + } + + @VisibleForTesting + ArrowFlightJdbcTime(final LocalTime time) { + // Although the constructor is deprecated, this is the exact same code as Time#valueOf(LocalTime) + super(time.getHour(), time.getMinute(), time.getSecond()); + millisReprValue = time.get(ChronoField.MILLI_OF_SECOND); + } + + private int getMillisReprValue(long milliseconds) { + // Extract the millisecond part from epoch nano day + if (milliseconds >= MILLIS_PER_DAY) { + // Convert to Epoch Day + milliseconds %= MILLIS_PER_DAY; + } else if (milliseconds < 0) { + // LocalTime#ofNanoDay only accepts positive values + milliseconds -= ((milliseconds / MILLIS_PER_DAY) - 1) * MILLIS_PER_DAY; + } + return LocalTime.ofNanoOfDay(TimeUnit.MILLISECONDS.toNanos(milliseconds)) + .get(ChronoField.MILLI_OF_SECOND); + } + + @Override + public String toString() { + final StringBuilder time = new StringBuilder().append(super.toString()); + + if (millisReprValue > 0) { + final String millisString = Integer.toString(millisReprValue); + + // dot to separate the fractional seconds + time.append("."); + + final int millisLength = millisString.length(); + if (millisLength < DESIRED_MILLIS_LENGTH) { + // add necessary leading zeroes + time.append(LEADING_ZEROES.get(DESIRED_MILLIS_LENGTH - millisLength)); + } + time.append(millisString); + } + + return time.toString(); + } + + // Spotbugs requires these methods to be overridden + @Override + public boolean equals(Object obj) { + return super.equals(obj); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.millisReprValue); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java new file mode 100644 index 0000000000000..9e377e51decc9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.util.Objects.isNull; + +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.TimeZone; + +import org.apache.arrow.driver.jdbc.utils.ConvertUtils; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaResultSet; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.Meta.Frame; +import org.apache.calcite.avatica.Meta.Signature; +import org.apache.calcite.avatica.QueryState; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link ResultSet} implementation used to access a {@link VectorSchemaRoot}. + */ +public class ArrowFlightJdbcVectorSchemaRootResultSet extends AvaticaResultSet { + + private static final Logger LOGGER = + LoggerFactory.getLogger(ArrowFlightJdbcVectorSchemaRootResultSet.class); + VectorSchemaRoot vectorSchemaRoot; + + ArrowFlightJdbcVectorSchemaRootResultSet(final AvaticaStatement statement, final QueryState state, + final Signature signature, + final ResultSetMetaData resultSetMetaData, + final TimeZone timeZone, final Frame firstFrame) + throws SQLException { + super(statement, state, signature, resultSetMetaData, timeZone, firstFrame); + } + + /** + * Instantiate a ResultSet backed up by given VectorSchemaRoot. + * + * @param vectorSchemaRoot root from which the ResultSet will access. + * @return a ResultSet which accesses the given VectorSchemaRoot + */ + public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot( + final VectorSchemaRoot vectorSchemaRoot) + throws SQLException { + // Similar to how org.apache.calcite.avatica.util.ArrayFactoryImpl does + + final TimeZone timeZone = TimeZone.getDefault(); + final QueryState state = new QueryState(); + + final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null); + + final AvaticaResultSetMetaData resultSetMetaData = + new AvaticaResultSetMetaData(null, null, signature); + final ArrowFlightJdbcVectorSchemaRootResultSet + resultSet = + new ArrowFlightJdbcVectorSchemaRootResultSet(null, state, signature, resultSetMetaData, + timeZone, null); + + resultSet.execute(vectorSchemaRoot); + return resultSet; + } + + @Override + protected AvaticaResultSet execute() throws SQLException { + throw new RuntimeException("Can only execute with execute(VectorSchemaRoot)"); + } + + void execute(final VectorSchemaRoot vectorSchemaRoot) { + final List fields = vectorSchemaRoot.getSchema().getFields(); + final List columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(fields); + signature.columns.clear(); + signature.columns.addAll(columns); + + this.vectorSchemaRoot = vectorSchemaRoot; + execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns); + } + + void execute(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) { + final List columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(schema.getFields()); + signature.columns.clear(); + signature.columns.addAll(columns); + + this.vectorSchemaRoot = vectorSchemaRoot; + execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns); + } + + @Override + protected void cancel() { + signature.columns.clear(); + super.cancel(); + try { + AutoCloseables.close(vectorSchemaRoot); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() { + final Set exceptions = new HashSet<>(); + try { + if (isClosed()) { + return; + } + } catch (final SQLException e) { + exceptions.add(e); + } + try { + AutoCloseables.close(vectorSchemaRoot); + } catch (final Exception e) { + exceptions.add(e); + } + if (!isNull(statement)) { + try { + super.close(); + } catch (final Exception e) { + exceptions.add(e); + } + } + exceptions.parallelStream().forEach(e -> LOGGER.error(e.getMessage(), e)); + exceptions.stream().findAny().ifPresent(e -> { + throw new RuntimeException(e); + }); + } + +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java new file mode 100644 index 0000000000000..cc7addc3a74d1 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.lang.String.format; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.SQLTimeoutException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; +import org.apache.arrow.util.Preconditions; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaParameter; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.MetaImpl; +import org.apache.calcite.avatica.NoSuchStatementException; +import org.apache.calcite.avatica.QueryState; +import org.apache.calcite.avatica.remote.TypedValue; + +/** + * Metadata handler for Arrow Flight. + */ +public class ArrowFlightMetaImpl extends MetaImpl { + private final Map statementHandlePreparedStatementMap; + + /** + * Constructs a {@link MetaImpl} object specific for Arrow Flight. + * @param connection A {@link AvaticaConnection}. + */ + public ArrowFlightMetaImpl(final AvaticaConnection connection) { + super(connection); + this.statementHandlePreparedStatementMap = new ConcurrentHashMap<>(); + setDefaultConnectionProperties(); + } + + static Signature newSignature(final String sql) { + return new Signature( + new ArrayList(), + sql, + Collections.emptyList(), + Collections.emptyMap(), + null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor + StatementType.SELECT + ); + } + + @Override + public void closeStatement(final StatementHandle statementHandle) { + PreparedStatement preparedStatement = statementHandlePreparedStatementMap.remove(statementHandle); + // Testing if the prepared statement was created because the statement can be not created until this moment + if (preparedStatement != null) { + preparedStatement.close(); + } + } + + @Override + public void commit(final ConnectionHandle connectionHandle) { + // TODO Fill this stub. + } + + @Override + public ExecuteResult execute(final StatementHandle statementHandle, + final List typedValues, final long maxRowCount) { + // TODO Why is maxRowCount ignored? + Preconditions.checkNotNull(statementHandle.signature, "Signature not found."); + return new ExecuteResult( + Collections.singletonList(MetaResultSet.create( + statementHandle.connectionId, statementHandle.id, + true, statementHandle.signature, null))); + } + + @Override + public ExecuteResult execute(final StatementHandle statementHandle, + final List typedValues, final int maxRowsInFirstFrame) { + return execute(statementHandle, typedValues, (long) maxRowsInFirstFrame); + } + + @Override + public ExecuteBatchResult executeBatch(final StatementHandle statementHandle, + final List> parameterValuesList) + throws IllegalStateException { + throw new IllegalStateException("executeBatch not implemented."); + } + + @Override + public Frame fetch(final StatementHandle statementHandle, final long offset, + final int fetchMaxRowCount) { + /* + * ArrowFlightMetaImpl does not use frames. + * Instead, we have accessors that contain a VectorSchemaRoot with + * the results. + */ + throw AvaticaConnection.HELPER.wrap( + format("%s does not use frames.", this), + AvaticaConnection.HELPER.unsupported()); + } + + @Override + public StatementHandle prepare(final ConnectionHandle connectionHandle, + final String query, final long maxRowCount) { + final StatementHandle handle = super.createStatement(connectionHandle); + handle.signature = newSignature(query); + return handle; + } + + @Override + public ExecuteResult prepareAndExecute(final StatementHandle statementHandle, + final String query, final long maxRowCount, + final PrepareCallback prepareCallback) + throws NoSuchStatementException { + return prepareAndExecute( + statementHandle, query, maxRowCount, -1 /* Not used */, prepareCallback); + } + + @Override + public ExecuteResult prepareAndExecute(final StatementHandle handle, + final String query, final long maxRowCount, + final int maxRowsInFirstFrame, + final PrepareCallback callback) + throws NoSuchStatementException { + try { + final PreparedStatement preparedStatement = + ((ArrowFlightConnection) connection).getClientHandler().prepare(query); + final StatementType statementType = preparedStatement.getType(); + statementHandlePreparedStatementMap.put(handle, preparedStatement); + final Signature signature = newSignature(query); + final long updateCount = + statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1; + synchronized (callback.getMonitor()) { + callback.clear(); + callback.assign(signature, null, updateCount); + } + callback.execute(); + final MetaResultSet metaResultSet = MetaResultSet.create(handle.connectionId, handle.id, + false, signature, null); + return new ExecuteResult(Collections.singletonList(metaResultSet)); + } catch (SQLTimeoutException e) { + // So far AvaticaStatement(executeInternal) only handles NoSuchStatement and Runtime Exceptions. + throw new RuntimeException(e); + } catch (SQLException e) { + throw new NoSuchStatementException(handle); + } + } + + @Override + public ExecuteBatchResult prepareAndExecuteBatch( + final StatementHandle statementHandle, final List queries) + throws NoSuchStatementException { + // TODO Fill this stub. + return null; + } + + @Override + public void rollback(final ConnectionHandle connectionHandle) { + // TODO Fill this stub. + } + + @Override + public boolean syncResults(final StatementHandle statementHandle, + final QueryState queryState, final long offset) + throws NoSuchStatementException { + // TODO Fill this stub. + return false; + } + + void setDefaultConnectionProperties() { + // TODO Double-check this. + connProps.setDirty(false) + .setAutoCommit(true) + .setReadOnly(true) + .setCatalog(null) + .setSchema(null) + .setTransactionIsolation(Connection.TRANSACTION_NONE); + } + + PreparedStatement getPreparedStatement(StatementHandle statementHandle) { + return statementHandlePreparedStatementMap.get(statementHandle); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java new file mode 100644 index 0000000000000..80029f38f0958 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.ConvertUtils; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaPreparedStatement; +import org.apache.calcite.avatica.Meta.Signature; +import org.apache.calcite.avatica.Meta.StatementHandle; + + +/** + * Arrow Flight JBCS's implementation {@link PreparedStatement}. + */ +public class ArrowFlightPreparedStatement extends AvaticaPreparedStatement + implements ArrowFlightInfoStatement { + + private final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement; + + private ArrowFlightPreparedStatement(final ArrowFlightConnection connection, + final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement, + final StatementHandle handle, + final Signature signature, final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) + throws SQLException { + super(connection, handle, signature, resultSetType, resultSetConcurrency, resultSetHoldability); + this.preparedStatement = Preconditions.checkNotNull(preparedStatement); + } + + /** + * Creates a new {@link ArrowFlightPreparedStatement} from the provided information. + * + * @param connection the {@link Connection} to use. + * @param statementHandle the {@link StatementHandle} to use. + * @param signature the {@link Signature} to use. + * @param resultSetType the ResultSet type. + * @param resultSetConcurrency the ResultSet concurrency. + * @param resultSetHoldability the ResultSet holdability. + * @return a new {@link PreparedStatement}. + * @throws SQLException on error. + */ + static ArrowFlightPreparedStatement createNewPreparedStatement( + final ArrowFlightConnection connection, + final StatementHandle statementHandle, + final Signature signature, + final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) throws SQLException { + + final ArrowFlightSqlClientHandler.PreparedStatement prepare = connection.getClientHandler().prepare(signature.sql); + final Schema resultSetSchema = prepare.getDataSetSchema(); + + signature.columns.addAll(ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); + + return new ArrowFlightPreparedStatement( + connection, prepare, statementHandle, + signature, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public ArrowFlightConnection getConnection() throws SQLException { + return (ArrowFlightConnection) super.getConnection(); + } + + @Override + public synchronized void close() throws SQLException { + this.preparedStatement.close(); + super.close(); + } + + @Override + public FlightInfo executeFlightInfoQuery() throws SQLException { + return preparedStatement.executeQuery(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java new file mode 100644 index 0000000000000..5bc7c2ab9b4f8 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; +import org.apache.arrow.driver.jdbc.utils.ConvertUtils; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.Meta.StatementHandle; + +/** + * A SQL statement for querying data from an Arrow Flight server. + */ +public class ArrowFlightStatement extends AvaticaStatement implements ArrowFlightInfoStatement { + + ArrowFlightStatement(final ArrowFlightConnection connection, + final StatementHandle handle, final int resultSetType, + final int resultSetConcurrency, final int resultSetHoldability) { + super(connection, handle, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public ArrowFlightConnection getConnection() throws SQLException { + return (ArrowFlightConnection) super.getConnection(); + } + + @Override + public FlightInfo executeFlightInfoQuery() throws SQLException { + final PreparedStatement preparedStatement = getConnection().getMeta().getPreparedStatement(handle); + final Meta.Signature signature = getSignature(); + if (signature == null) { + return null; + } + + final Schema resultSetSchema = preparedStatement.getDataSetSchema(); + signature.columns.addAll(ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); + setSignature(signature); + + return preparedStatement.executeQuery(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessor.java new file mode 100644 index 0000000000000..3821ee1dc8755 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessor.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor; + +import static org.apache.calcite.avatica.util.Cursor.Accessor; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.SQLException; +import java.sql.SQLXML; +import java.sql.Struct; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.Map; +import java.util.function.IntSupplier; + +/** + * Base Jdbc Accessor. + */ +public abstract class ArrowFlightJdbcAccessor implements Accessor { + private final IntSupplier currentRowSupplier; + + // All the derived accessor classes should alter this as they encounter null Values + protected boolean wasNull; + protected ArrowFlightJdbcAccessorFactory.WasNullConsumer wasNullConsumer; + + protected ArrowFlightJdbcAccessor(final IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer wasNullConsumer) { + this.currentRowSupplier = currentRowSupplier; + this.wasNullConsumer = wasNullConsumer; + } + + protected int getCurrentRow() { + return currentRowSupplier.getAsInt(); + } + + // It needs to be public so this method can be accessed when creating the complex types. + public abstract Class getObjectClass(); + + @Override + public boolean wasNull() { + return wasNull; + } + + @Override + public String getString() throws SQLException { + final Object object = getObject(); + if (object == null) { + return null; + } + + return object.toString(); + } + + @Override + public boolean getBoolean() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public byte getByte() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public short getShort() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public int getInt() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public long getLong() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public float getFloat() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public double getDouble() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public BigDecimal getBigDecimal(final int i) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public byte[] getBytes() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public InputStream getAsciiStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public InputStream getUnicodeStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public InputStream getBinaryStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Object getObject() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Reader getCharacterStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Object getObject(final Map> map) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Ref getRef() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Blob getBlob() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Clob getClob() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Array getArray() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Struct getStruct() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Date getDate(final Calendar calendar) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Time getTime(final Calendar calendar) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Timestamp getTimestamp(final Calendar calendar) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public URL getURL() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public NClob getNClob() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public SQLXML getSQLXML() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public String getNString() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Reader getNCharacterStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public T getObject(final Class type) throws SQLException { + final Object value; + if (type == Byte.class) { + value = getByte(); + } else if (type == Short.class) { + value = getShort(); + } else if (type == Integer.class) { + value = getInt(); + } else if (type == Long.class) { + value = getLong(); + } else if (type == Float.class) { + value = getFloat(); + } else if (type == Double.class) { + value = getDouble(); + } else if (type == Boolean.class) { + value = getBoolean(); + } else if (type == BigDecimal.class) { + value = getBigDecimal(); + } else if (type == String.class) { + value = getString(); + } else if (type == byte[].class) { + value = getBytes(); + } else { + value = getObject(); + } + return !type.isPrimitive() && wasNull ? null : type.cast(value); + } + + private static SQLException getOperationNotSupported(final Class type) { + return new SQLException(String.format("Operation not supported for type: %s.", type.getName())); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java new file mode 100644 index 0000000000000..813b40a8070f7 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor; + +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.ArrowFlightJdbcNullVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.binary.ArrowFlightJdbcBinaryVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDurationVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcIntervalVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcDenseUnionVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcFixedSizeListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcLargeListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcMapVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcStructVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcUnionVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcBaseIntVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcBitVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcDecimalVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcFloat4VectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcFloat8VectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.UnionVector; + +/** + * Factory to instantiate the accessors. + */ +public class ArrowFlightJdbcAccessorFactory { + + /** + * Create an accessor according to its type. + * + * @param vector an instance of an arrow vector. + * @param getCurrentRow a supplier to check which row is being accessed. + * @return an instance of one of the accessors. + */ + public static ArrowFlightJdbcAccessor createAccessor(ValueVector vector, + IntSupplier getCurrentRow, + WasNullConsumer setCursorWasNull) { + if (vector instanceof UInt1Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt1Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof UInt2Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt2Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof UInt4Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt4Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof UInt8Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt8Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TinyIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((TinyIntVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof SmallIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((SmallIntVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof IntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((IntVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof BigIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((BigIntVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof Float4Vector) { + return new ArrowFlightJdbcFloat4VectorAccessor((Float4Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof Float8Vector) { + return new ArrowFlightJdbcFloat8VectorAccessor((Float8Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof BitVector) { + return new ArrowFlightJdbcBitVectorAccessor((BitVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DecimalVector) { + return new ArrowFlightJdbcDecimalVectorAccessor((DecimalVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof Decimal256Vector) { + return new ArrowFlightJdbcDecimalVectorAccessor((Decimal256Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof VarBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor((VarBinaryVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof LargeVarBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor((LargeVarBinaryVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof FixedSizeBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor((FixedSizeBinaryVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeStampVector) { + return new ArrowFlightJdbcTimeStampVectorAccessor((TimeStampVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeNanoVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeNanoVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeMicroVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeMicroVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeMilliVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeMilliVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeSecVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeSecVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DateDayVector) { + return new ArrowFlightJdbcDateVectorAccessor(((DateDayVector) vector), getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DateMilliVector) { + return new ArrowFlightJdbcDateVectorAccessor(((DateMilliVector) vector), getCurrentRow, + setCursorWasNull); + } else if (vector instanceof VarCharVector) { + return new ArrowFlightJdbcVarCharVectorAccessor((VarCharVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof LargeVarCharVector) { + return new ArrowFlightJdbcVarCharVectorAccessor((LargeVarCharVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DurationVector) { + return new ArrowFlightJdbcDurationVectorAccessor((DurationVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof IntervalDayVector) { + return new ArrowFlightJdbcIntervalVectorAccessor(((IntervalDayVector) vector), getCurrentRow, + setCursorWasNull); + } else if (vector instanceof IntervalYearVector) { + return new ArrowFlightJdbcIntervalVectorAccessor(((IntervalYearVector) vector), getCurrentRow, + setCursorWasNull); + } else if (vector instanceof StructVector) { + return new ArrowFlightJdbcStructVectorAccessor((StructVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof MapVector) { + return new ArrowFlightJdbcMapVectorAccessor((MapVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof ListVector) { + return new ArrowFlightJdbcListVectorAccessor((ListVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof LargeListVector) { + return new ArrowFlightJdbcLargeListVectorAccessor((LargeListVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof FixedSizeListVector) { + return new ArrowFlightJdbcFixedSizeListVectorAccessor((FixedSizeListVector) vector, + getCurrentRow, setCursorWasNull); + } else if (vector instanceof UnionVector) { + return new ArrowFlightJdbcUnionVectorAccessor((UnionVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DenseUnionVector) { + return new ArrowFlightJdbcDenseUnionVectorAccessor((DenseUnionVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof NullVector || vector == null) { + return new ArrowFlightJdbcNullVectorAccessor(setCursorWasNull); + } + + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getName()); + } + + /** + * Functional interface used to propagate that the value accessed was null or not. + */ + @FunctionalInterface + public interface WasNullConsumer { + void setWasNull(boolean wasNull); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessor.java new file mode 100644 index 0000000000000..f40a5797293f9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessor.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.NullVector; + +/** + * Accessor for the Arrow type {@link NullVector}. + */ +public class ArrowFlightJdbcNullVectorAccessor extends ArrowFlightJdbcAccessor { + public ArrowFlightJdbcNullVectorAccessor( + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(null, setCursorWasNull); + } + + @Override + public Class getObjectClass() { + return Object.class; + } + + @Override + public boolean wasNull() { + return true; + } + + @Override + public Object getObject() { + this.wasNullConsumer.setWasNull(true); + return null; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessor.java new file mode 100644 index 0000000000000..c50d734972134 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessor.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.binary; + +import java.io.ByteArrayInputStream; +import java.io.CharArrayReader; +import java.io.InputStream; +import java.io.Reader; +import java.nio.charset.StandardCharsets; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.VarBinaryVector; + +/** + * Accessor for the Arrow types: {@link FixedSizeBinaryVector}, {@link VarBinaryVector} + * and {@link LargeVarBinaryVector}. + */ +public class ArrowFlightJdbcBinaryVectorAccessor extends ArrowFlightJdbcAccessor { + + private interface ByteArrayGetter { + byte[] get(int index); + } + + private final ByteArrayGetter getter; + + public ArrowFlightJdbcBinaryVectorAccessor(FixedSizeBinaryVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector::get, currentRowSupplier, setCursorWasNull); + } + + public ArrowFlightJdbcBinaryVectorAccessor(VarBinaryVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector::get, currentRowSupplier, setCursorWasNull); + } + + public ArrowFlightJdbcBinaryVectorAccessor(LargeVarBinaryVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector::get, currentRowSupplier, setCursorWasNull); + } + + private ArrowFlightJdbcBinaryVectorAccessor(ByteArrayGetter getter, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.getter = getter; + } + + @Override + public byte[] getBytes() { + byte[] bytes = getter.get(getCurrentRow()); + this.wasNull = bytes == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return bytes; + } + + @Override + public Object getObject() { + return this.getBytes(); + } + + @Override + public Class getObjectClass() { + return byte[].class; + } + + @Override + public String getString() { + byte[] bytes = this.getBytes(); + if (bytes == null) { + return null; + } + + return new String(bytes, StandardCharsets.UTF_8); + } + + @Override + public InputStream getAsciiStream() { + byte[] bytes = getBytes(); + if (bytes == null) { + return null; + } + + return new ByteArrayInputStream(bytes); + } + + @Override + public InputStream getUnicodeStream() { + byte[] bytes = getBytes(); + if (bytes == null) { + return null; + } + + return new ByteArrayInputStream(bytes); + } + + @Override + public InputStream getBinaryStream() { + byte[] bytes = getBytes(); + if (bytes == null) { + return null; + } + + return new ByteArrayInputStream(bytes); + } + + @Override + public Reader getCharacterStream() { + String string = getString(); + if (string == null) { + return null; + } + + return new CharArrayReader(string.toCharArray()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessor.java new file mode 100644 index 0000000000000..f6c14a47f521c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessor.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorGetter.Getter; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorGetter.Holder; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorGetter.createGetter; +import static org.apache.arrow.driver.jdbc.utils.DateTimeUtils.getTimestampValue; +import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY; +import static org.apache.calcite.avatica.util.DateTimeUtils.unixDateToString; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.DateTimeUtils; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Accessor for the Arrow types: {@link DateDayVector} and {@link DateMilliVector}. + */ +public class ArrowFlightJdbcDateVectorAccessor extends ArrowFlightJdbcAccessor { + + private final Getter getter; + private final TimeUnit timeUnit; + private final Holder holder; + + /** + * Instantiate an accessor for a {@link DateDayVector}. + * + * @param vector an instance of a DateDayVector. + * @param currentRowSupplier the supplier to track the lines. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcDateVectorAccessor(DateDayVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + /** + * Instantiate an accessor for a {@link DateMilliVector}. + * + * @param vector an instance of a DateMilliVector. + * @param currentRowSupplier the supplier to track the lines. + */ + public ArrowFlightJdbcDateVectorAccessor(DateMilliVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + @Override + public Class getObjectClass() { + return Date.class; + } + + @Override + public Object getObject() { + return this.getDate(null); + } + + @Override + public Date getDate(Calendar calendar) { + fillHolder(); + if (this.wasNull) { + return null; + } + + long value = holder.value; + long milliseconds = this.timeUnit.toMillis(value); + + long millisWithCalendar = DateTimeUtils.applyCalendarOffset(milliseconds, calendar); + + return new Date(getTimestampValue(millisWithCalendar).getTime()); + } + + private void fillHolder() { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) { + Date date = getDate(calendar); + if (date == null) { + return null; + } + return new Timestamp(date.getTime()); + } + + @Override + public String getString() { + fillHolder(); + if (wasNull) { + return null; + } + long milliseconds = timeUnit.toMillis(holder.value); + return unixDateToString((int) (milliseconds / MILLIS_PER_DAY)); + } + + protected static TimeUnit getTimeUnitForVector(ValueVector vector) { + if (vector instanceof DateDayVector) { + return TimeUnit.DAYS; + } else if (vector instanceof DateMilliVector) { + return TimeUnit.MILLISECONDS; + } + + throw new IllegalArgumentException("Invalid Arrow vector"); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorGetter.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorGetter.java new file mode 100644 index 0000000000000..ea545851a3aeb --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorGetter.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.holders.NullableDateDayHolder; +import org.apache.arrow.vector.holders.NullableDateMilliHolder; + +/** + * Auxiliary class used to unify data access on TimeStampVectors. + */ +final class ArrowFlightJdbcDateVectorGetter { + + private ArrowFlightJdbcDateVectorGetter() { + // Prevent instantiation. + } + + /** + * Auxiliary class meant to unify Date*Vector#get implementations with different classes of ValueHolders. + */ + static class Holder { + int isSet; // Tells if value is set; 0 = not set, 1 = set + long value; // Holds actual value in its respective timeunit + } + + /** + * Functional interface used to unify Date*Vector#get implementations. + */ + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(DateDayVector vector) { + NullableDateDayHolder auxHolder = new NullableDateDayHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(DateMilliVector vector) { + NullableDateMilliHolder auxHolder = new NullableDateMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessor.java new file mode 100644 index 0000000000000..22a0e6f892378 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessor.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import java.time.Duration; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.DurationVector; + +/** + * Accessor for the Arrow type {@link DurationVector}. + */ +public class ArrowFlightJdbcDurationVectorAccessor extends ArrowFlightJdbcAccessor { + + private final DurationVector vector; + + public ArrowFlightJdbcDurationVectorAccessor(DurationVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Duration.class; + } + + @Override + public Object getObject() { + Duration duration = vector.getObject(getCurrentRow()); + this.wasNull = duration == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return duration; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessor.java new file mode 100644 index 0000000000000..283dc9160a9e9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessor.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.utils.IntervalStringUtils.formatIntervalDay; +import static org.apache.arrow.driver.jdbc.utils.IntervalStringUtils.formatIntervalYear; +import static org.apache.arrow.vector.util.DateUtility.yearsToMonths; + +import java.sql.SQLException; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.holders.NullableIntervalDayHolder; +import org.apache.arrow.vector.holders.NullableIntervalYearHolder; +import org.joda.time.Period; + +/** + * Accessor for the Arrow type {@link IntervalDayVector}. + */ +public class ArrowFlightJdbcIntervalVectorAccessor extends ArrowFlightJdbcAccessor { + + private final BaseFixedWidthVector vector; + private final StringGetter stringGetter; + private final Class objectClass; + + /** + * Instantiate an accessor for a {@link IntervalDayVector}. + * + * @param vector an instance of a IntervalDayVector. + * @param currentRowSupplier the supplier to track the rows. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcIntervalVectorAccessor(IntervalDayVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + stringGetter = (index) -> { + final NullableIntervalDayHolder holder = new NullableIntervalDayHolder(); + vector.get(index, holder); + if (holder.isSet == 0) { + return null; + } else { + final int days = holder.days; + final int millis = holder.milliseconds; + return formatIntervalDay(new Period().plusDays(days).plusMillis(millis)); + } + }; + objectClass = java.time.Duration.class; + } + + /** + * Instantiate an accessor for a {@link IntervalYearVector}. + * + * @param vector an instance of a IntervalYearVector. + * @param currentRowSupplier the supplier to track the rows. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcIntervalVectorAccessor(IntervalYearVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + stringGetter = (index) -> { + final NullableIntervalYearHolder holder = new NullableIntervalYearHolder(); + vector.get(index, holder); + if (holder.isSet == 0) { + return null; + } else { + final int interval = holder.value; + final int years = (interval / yearsToMonths); + final int months = (interval % yearsToMonths); + return formatIntervalYear(new Period().plusYears(years).plusMonths(months)); + } + }; + objectClass = java.time.Period.class; + } + + @Override + public Class getObjectClass() { + return objectClass; + } + + @Override + public String getString() throws SQLException { + String result = stringGetter.get(getCurrentRow()); + wasNull = result == null; + wasNullConsumer.setWasNull(wasNull); + return result; + } + + @Override + public Object getObject() { + Object object = vector.getObject(getCurrentRow()); + wasNull = object == null; + wasNullConsumer.setWasNull(wasNull); + return object; + } + + /** + * Functional interface used to unify Interval*Vector#getAsStringBuilder implementations. + */ + @FunctionalInterface + interface StringGetter { + String get(int index); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessor.java new file mode 100644 index 0000000000000..a23883baf1e9b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessor.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorGetter.Getter; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorGetter.Holder; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorGetter.createGetter; + +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.time.temporal.ChronoUnit; +import java.util.Calendar; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.DateUtility; + +/** + * Accessor for the Arrow types extending from {@link TimeStampVector}. + */ +public class ArrowFlightJdbcTimeStampVectorAccessor extends ArrowFlightJdbcAccessor { + + private final TimeZone timeZone; + private final Getter getter; + private final TimeUnit timeUnit; + private final LongToLocalDateTime longToLocalDateTime; + private final Holder holder; + + /** + * Functional interface used to convert a number (in any time resolution) to LocalDateTime. + */ + interface LongToLocalDateTime { + LocalDateTime fromLong(long value); + } + + /** + * Instantiate a ArrowFlightJdbcTimeStampVectorAccessor for given vector. + */ + public ArrowFlightJdbcTimeStampVectorAccessor(TimeStampVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + + this.timeZone = getTimeZoneForVector(vector); + this.timeUnit = getTimeUnitForVector(vector); + this.longToLocalDateTime = getLongToLocalDateTimeForVector(vector, this.timeZone); + } + + @Override + public Class getObjectClass() { + return Timestamp.class; + } + + @Override + public Object getObject() { + return this.getTimestamp(null); + } + + private LocalDateTime getLocalDateTime(Calendar calendar) { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + long value = holder.value; + + LocalDateTime localDateTime = this.longToLocalDateTime.fromLong(value); + + if (calendar != null) { + TimeZone timeZone = calendar.getTimeZone(); + long millis = this.timeUnit.toMillis(value); + localDateTime = localDateTime + .minus(timeZone.getOffset(millis) - this.timeZone.getOffset(millis), ChronoUnit.MILLIS); + } + return localDateTime; + } + + @Override + public Date getDate(Calendar calendar) { + LocalDateTime localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return new Date(Timestamp.valueOf(localDateTime).getTime()); + } + + @Override + public Time getTime(Calendar calendar) { + LocalDateTime localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return new Time(Timestamp.valueOf(localDateTime).getTime()); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) { + LocalDateTime localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return Timestamp.valueOf(localDateTime); + } + + protected static TimeUnit getTimeUnitForVector(TimeStampVector vector) { + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + switch (arrowType.getUnit()) { + case NANOSECOND: + return TimeUnit.NANOSECONDS; + case MICROSECOND: + return TimeUnit.MICROSECONDS; + case MILLISECOND: + return TimeUnit.MILLISECONDS; + case SECOND: + return TimeUnit.SECONDS; + default: + throw new UnsupportedOperationException("Invalid Arrow time unit"); + } + } + + protected static LongToLocalDateTime getLongToLocalDateTimeForVector(TimeStampVector vector, + TimeZone timeZone) { + String timeZoneID = timeZone.getID(); + + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + switch (arrowType.getUnit()) { + case NANOSECOND: + return nanoseconds -> DateUtility.getLocalDateTimeFromEpochNano(nanoseconds, timeZoneID); + case MICROSECOND: + return microseconds -> DateUtility.getLocalDateTimeFromEpochMicro(microseconds, timeZoneID); + case MILLISECOND: + return milliseconds -> DateUtility.getLocalDateTimeFromEpochMilli(milliseconds, timeZoneID); + case SECOND: + return seconds -> DateUtility.getLocalDateTimeFromEpochMilli( + TimeUnit.SECONDS.toMillis(seconds), timeZoneID); + default: + throw new UnsupportedOperationException("Invalid Arrow time unit"); + } + } + + protected static TimeZone getTimeZoneForVector(TimeStampVector vector) { + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + String timezoneName = arrowType.getTimezone(); + if (timezoneName == null) { + return TimeZone.getTimeZone("UTC"); + } + + return TimeZone.getTimeZone(timezoneName); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorGetter.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorGetter.java new file mode 100644 index 0000000000000..03fb35face722 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorGetter.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.holders.NullableTimeStampMicroHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMicroTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampNanoHolder; +import org.apache.arrow.vector.holders.NullableTimeStampNanoTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampSecHolder; +import org.apache.arrow.vector.holders.NullableTimeStampSecTZHolder; + +/** + * Auxiliary class used to unify data access on TimeStampVectors. + */ +final class ArrowFlightJdbcTimeStampVectorGetter { + + private ArrowFlightJdbcTimeStampVectorGetter() { + // Prevent instantiation. + } + + /** + * Auxiliary class meant to unify TimeStamp*Vector#get implementations with different classes of ValueHolders. + */ + static class Holder { + int isSet; // Tells if value is set; 0 = not set, 1 = set + long value; // Holds actual value in its respective timeunit + } + + /** + * Functional interface used to unify TimeStamp*Vector#get implementations. + */ + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(TimeStampVector vector) { + if (vector instanceof TimeStampNanoVector) { + return createGetter((TimeStampNanoVector) vector); + } else if (vector instanceof TimeStampNanoTZVector) { + return createGetter((TimeStampNanoTZVector) vector); + } else if (vector instanceof TimeStampMicroVector) { + return createGetter((TimeStampMicroVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + return createGetter((TimeStampMicroTZVector) vector); + } else if (vector instanceof TimeStampMilliVector) { + return createGetter((TimeStampMilliVector) vector); + } else if (vector instanceof TimeStampMilliTZVector) { + return createGetter((TimeStampMilliTZVector) vector); + } else if (vector instanceof TimeStampSecVector) { + return createGetter((TimeStampSecVector) vector); + } else if (vector instanceof TimeStampSecTZVector) { + return createGetter((TimeStampSecTZVector) vector); + } + + throw new UnsupportedOperationException("Unsupported Timestamp vector type"); + } + + private static Getter createGetter(TimeStampNanoVector vector) { + NullableTimeStampNanoHolder auxHolder = new NullableTimeStampNanoHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampNanoTZVector vector) { + NullableTimeStampNanoTZHolder auxHolder = new NullableTimeStampNanoTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMicroVector vector) { + NullableTimeStampMicroHolder auxHolder = new NullableTimeStampMicroHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMicroTZVector vector) { + NullableTimeStampMicroTZHolder auxHolder = new NullableTimeStampMicroTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMilliVector vector) { + NullableTimeStampMilliHolder auxHolder = new NullableTimeStampMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMilliTZVector vector) { + NullableTimeStampMilliTZHolder auxHolder = new NullableTimeStampMilliTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampSecVector vector) { + NullableTimeStampSecHolder auxHolder = new NullableTimeStampSecHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampSecTZVector vector) { + NullableTimeStampSecTZHolder auxHolder = new NullableTimeStampSecTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessor.java new file mode 100644 index 0000000000000..6c2173d5e5656 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessor.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorGetter.Getter; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorGetter.Holder; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorGetter.createGetter; + +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.ArrowFlightJdbcTime; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.DateTimeUtils; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Accessor for the Arrow types: {@link TimeNanoVector}, {@link TimeMicroVector}, {@link TimeMilliVector} + * and {@link TimeSecVector}. + */ +public class ArrowFlightJdbcTimeVectorAccessor extends ArrowFlightJdbcAccessor { + + private final Getter getter; + private final TimeUnit timeUnit; + private final Holder holder; + + /** + * Instantiate an accessor for a {@link TimeNanoVector}. + * + * @param vector an instance of a TimeNanoVector. + * @param currentRowSupplier the supplier to track the lines. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcTimeVectorAccessor(TimeNanoVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + /** + * Instantiate an accessor for a {@link TimeMicroVector}. + * + * @param vector an instance of a TimeMicroVector. + * @param currentRowSupplier the supplier to track the lines. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcTimeVectorAccessor(TimeMicroVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + /** + * Instantiate an accessor for a {@link TimeMilliVector}. + * + * @param vector an instance of a TimeMilliVector. + * @param currentRowSupplier the supplier to track the lines. + */ + public ArrowFlightJdbcTimeVectorAccessor(TimeMilliVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + /** + * Instantiate an accessor for a {@link TimeSecVector}. + * + * @param vector an instance of a TimeSecVector. + * @param currentRowSupplier the supplier to track the lines. + */ + public ArrowFlightJdbcTimeVectorAccessor(TimeSecVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + @Override + public Class getObjectClass() { + return Time.class; + } + + @Override + public Object getObject() { + return this.getTime(null); + } + + @Override + public Time getTime(Calendar calendar) { + fillHolder(); + if (this.wasNull) { + return null; + } + + long value = holder.value; + long milliseconds = this.timeUnit.toMillis(value); + + return new ArrowFlightJdbcTime(DateTimeUtils.applyCalendarOffset(milliseconds, calendar)); + } + + private void fillHolder() { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) { + Time time = getTime(calendar); + if (time == null) { + return null; + } + return new Timestamp(time.getTime()); + } + + protected static TimeUnit getTimeUnitForVector(ValueVector vector) { + if (vector instanceof TimeNanoVector) { + return TimeUnit.NANOSECONDS; + } else if (vector instanceof TimeMicroVector) { + return TimeUnit.MICROSECONDS; + } else if (vector instanceof TimeMilliVector) { + return TimeUnit.MILLISECONDS; + } else if (vector instanceof TimeSecVector) { + return TimeUnit.SECONDS; + } + + throw new IllegalArgumentException("Invalid Arrow vector"); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorGetter.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorGetter.java new file mode 100644 index 0000000000000..fb254c694014d --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorGetter.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.holders.NullableTimeMicroHolder; +import org.apache.arrow.vector.holders.NullableTimeMilliHolder; +import org.apache.arrow.vector.holders.NullableTimeNanoHolder; +import org.apache.arrow.vector.holders.NullableTimeSecHolder; + +/** + * Auxiliary class used to unify data access on Time*Vectors. + */ +final class ArrowFlightJdbcTimeVectorGetter { + + private ArrowFlightJdbcTimeVectorGetter() { + // Prevent instantiation. + } + + /** + * Auxiliary class meant to unify TimeStamp*Vector#get implementations with different classes of ValueHolders. + */ + static class Holder { + int isSet; // Tells if value is set; 0 = not set, 1 = set + long value; // Holds actual value in its respective timeunit + } + + /** + * Functional interface used to unify TimeStamp*Vector#get implementations. + */ + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(TimeNanoVector vector) { + NullableTimeNanoHolder auxHolder = new NullableTimeNanoHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeMicroVector vector) { + NullableTimeMicroHolder auxHolder = new NullableTimeMicroHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeMilliVector vector) { + NullableTimeMilliHolder auxHolder = new NullableTimeMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeSecVector vector) { + NullableTimeSecHolder auxHolder = new NullableTimeSecHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListVectorAccessor.java new file mode 100644 index 0000000000000..d3338608f8352 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListVectorAccessor.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.sql.Array; +import java.util.List; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.ArrowFlightJdbcArray; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; + +/** + * Base Accessor for the Arrow types {@link ListVector}, {@link LargeListVector} and {@link FixedSizeListVector}. + */ +public abstract class AbstractArrowFlightJdbcListVectorAccessor extends ArrowFlightJdbcAccessor { + + protected AbstractArrowFlightJdbcListVectorAccessor(IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + } + + @Override + public Class getObjectClass() { + return List.class; + } + + protected abstract long getStartOffset(int index); + + protected abstract long getEndOffset(int index); + + protected abstract FieldVector getDataVector(); + + protected abstract boolean isNull(int index); + + @Override + public final Array getArray() { + int index = getCurrentRow(); + FieldVector dataVector = getDataVector(); + + this.wasNull = isNull(index); + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + long startOffset = getStartOffset(index); + long endOffset = getEndOffset(index); + + long valuesCount = endOffset - startOffset; + return new ArrowFlightJdbcArray(dataVector, startOffset, valuesCount); + } +} + diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessor.java new file mode 100644 index 0000000000000..0465765f18358 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessor.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.SQLException; +import java.sql.SQLXML; +import java.sql.Struct; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.Map; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.accessor.impl.ArrowFlightJdbcNullVectorAccessor; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.UnionVector; + +/** + * Base accessor for {@link UnionVector} and {@link DenseUnionVector}. + */ +public abstract class AbstractArrowFlightJdbcUnionVectorAccessor extends ArrowFlightJdbcAccessor { + + /** + * Array of accessors for each type contained in UnionVector. + * Index corresponds to UnionVector and DenseUnionVector typeIds which are both limited to 128. + */ + private final ArrowFlightJdbcAccessor[] accessors = new ArrowFlightJdbcAccessor[128]; + + private final ArrowFlightJdbcNullVectorAccessor nullAccessor = + new ArrowFlightJdbcNullVectorAccessor((boolean wasNull) -> { + }); + + protected AbstractArrowFlightJdbcUnionVectorAccessor(IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + } + + protected abstract ArrowFlightJdbcAccessor createAccessorForVector(ValueVector vector); + + protected abstract byte getCurrentTypeId(); + + protected abstract ValueVector getVectorByTypeId(byte typeId); + + /** + * Returns an accessor for UnionVector child vector on current row. + * + * @return ArrowFlightJdbcAccessor for child vector on current row. + */ + protected ArrowFlightJdbcAccessor getAccessor() { + // Get the typeId and child vector for the current row being accessed. + byte typeId = this.getCurrentTypeId(); + ValueVector vector = this.getVectorByTypeId(typeId); + + if (typeId < 0) { + // typeId may be negative if the current row has no type defined. + return this.nullAccessor; + } + + // Ensure there is an accessor for given typeId + if (this.accessors[typeId] == null) { + this.accessors[typeId] = this.createAccessorForVector(vector); + } + + return this.accessors[typeId]; + } + + @Override + public Class getObjectClass() { + return getAccessor().getObjectClass(); + } + + @Override + public boolean wasNull() { + return getAccessor().wasNull(); + } + + @Override + public String getString() throws SQLException { + return getAccessor().getString(); + } + + @Override + public boolean getBoolean() throws SQLException { + return getAccessor().getBoolean(); + } + + @Override + public byte getByte() throws SQLException { + return getAccessor().getByte(); + } + + @Override + public short getShort() throws SQLException { + return getAccessor().getShort(); + } + + @Override + public int getInt() throws SQLException { + return getAccessor().getInt(); + } + + @Override + public long getLong() throws SQLException { + return getAccessor().getLong(); + } + + @Override + public float getFloat() throws SQLException { + return getAccessor().getFloat(); + } + + @Override + public double getDouble() throws SQLException { + return getAccessor().getDouble(); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + return getAccessor().getBigDecimal(); + } + + @Override + public BigDecimal getBigDecimal(int i) throws SQLException { + return getAccessor().getBigDecimal(i); + } + + @Override + public byte[] getBytes() throws SQLException { + return getAccessor().getBytes(); + } + + @Override + public InputStream getAsciiStream() throws SQLException { + return getAccessor().getAsciiStream(); + } + + @Override + public InputStream getUnicodeStream() throws SQLException { + return getAccessor().getUnicodeStream(); + } + + @Override + public InputStream getBinaryStream() throws SQLException { + return getAccessor().getBinaryStream(); + } + + @Override + public Object getObject() throws SQLException { + return getAccessor().getObject(); + } + + @Override + public Reader getCharacterStream() throws SQLException { + return getAccessor().getCharacterStream(); + } + + @Override + public Object getObject(Map> map) throws SQLException { + return getAccessor().getObject(map); + } + + @Override + public Ref getRef() throws SQLException { + return getAccessor().getRef(); + } + + @Override + public Blob getBlob() throws SQLException { + return getAccessor().getBlob(); + } + + @Override + public Clob getClob() throws SQLException { + return getAccessor().getClob(); + } + + @Override + public Array getArray() throws SQLException { + return getAccessor().getArray(); + } + + @Override + public Struct getStruct() throws SQLException { + return getAccessor().getStruct(); + } + + @Override + public Date getDate(Calendar calendar) throws SQLException { + return getAccessor().getDate(calendar); + } + + @Override + public Time getTime(Calendar calendar) throws SQLException { + return getAccessor().getTime(calendar); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) throws SQLException { + return getAccessor().getTimestamp(calendar); + } + + @Override + public URL getURL() throws SQLException { + return getAccessor().getURL(); + } + + @Override + public NClob getNClob() throws SQLException { + return getAccessor().getNClob(); + } + + @Override + public SQLXML getSQLXML() throws SQLException { + return getAccessor().getSQLXML(); + } + + @Override + public String getNString() throws SQLException { + return getAccessor().getNString(); + } + + @Override + public Reader getNCharacterStream() throws SQLException { + return getAccessor().getNCharacterStream(); + } + + @Override + public T getObject(Class type) throws SQLException { + return getAccessor().getObject(type); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessor.java new file mode 100644 index 0000000000000..ba5b83ade63ab --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessor.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.DenseUnionVector; + +/** + * Accessor for the Arrow type {@link DenseUnionVector}. + */ +public class ArrowFlightJdbcDenseUnionVectorAccessor + extends AbstractArrowFlightJdbcUnionVectorAccessor { + + private final DenseUnionVector vector; + + /** + * Instantiate an accessor for a {@link DenseUnionVector}. + * + * @param vector an instance of a DenseUnionVector. + * @param currentRowSupplier the supplier to track the rows. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcDenseUnionVectorAccessor(DenseUnionVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected ArrowFlightJdbcAccessor createAccessorForVector(ValueVector vector) { + return ArrowFlightJdbcAccessorFactory.createAccessor(vector, + () -> this.vector.getOffset(this.getCurrentRow()), (boolean wasNull) -> { + }); + } + + @Override + protected byte getCurrentTypeId() { + int index = getCurrentRow(); + return this.vector.getTypeId(index); + } + + @Override + protected ValueVector getVectorByTypeId(byte typeId) { + return this.vector.getVectorByType(typeId); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcFixedSizeListVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcFixedSizeListVectorAccessor.java new file mode 100644 index 0000000000000..7bdd3abfd0cad --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcFixedSizeListVectorAccessor.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.List; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; + +/** + * Accessor for the Arrow type {@link FixedSizeListVector}. + */ +public class ArrowFlightJdbcFixedSizeListVectorAccessor + extends AbstractArrowFlightJdbcListVectorAccessor { + + private final FixedSizeListVector vector; + + public ArrowFlightJdbcFixedSizeListVectorAccessor(FixedSizeListVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected long getStartOffset(int index) { + return (long) vector.getListSize() * index; + } + + @Override + protected long getEndOffset(int index) { + return (long) vector.getListSize() * (index + 1); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } + + @Override + public Object getObject() { + List object = vector.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return object; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcLargeListVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcLargeListVectorAccessor.java new file mode 100644 index 0000000000000..f7608bb06e5fe --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcLargeListVectorAccessor.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.List; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.LargeListVector; + +/** + * Accessor for the Arrow type {@link LargeListVector}. + */ +public class ArrowFlightJdbcLargeListVectorAccessor + extends AbstractArrowFlightJdbcListVectorAccessor { + + private final LargeListVector vector; + + public ArrowFlightJdbcLargeListVectorAccessor(LargeListVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected long getStartOffset(int index) { + return vector.getOffsetBuffer().getLong((long) index * LargeListVector.OFFSET_WIDTH); + } + + @Override + protected long getEndOffset(int index) { + return vector.getOffsetBuffer().getLong((long) (index + 1) * LargeListVector.OFFSET_WIDTH); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } + + @Override + public Object getObject() { + List object = vector.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return object; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcListVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcListVectorAccessor.java new file mode 100644 index 0000000000000..a329a344073a6 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcListVectorAccessor.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.List; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.complex.ListVector; + +/** + * Accessor for the Arrow type {@link ListVector}. + */ +public class ArrowFlightJdbcListVectorAccessor extends AbstractArrowFlightJdbcListVectorAccessor { + + private final ListVector vector; + + public ArrowFlightJdbcListVectorAccessor(ListVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected long getStartOffset(int index) { + return vector.getOffsetBuffer().getInt((long) index * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected long getEndOffset(int index) { + return vector.getOffsetBuffer() + .getInt((long) (index + 1) * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } + + @Override + public Object getObject() { + List object = vector.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return object; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessor.java new file mode 100644 index 0000000000000..bf1225b33de64 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessor.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.Map; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.impl.UnionMapReader; +import org.apache.arrow.vector.util.JsonStringHashMap; + +/** + * Accessor for the Arrow type {@link MapVector}. + */ +public class ArrowFlightJdbcMapVectorAccessor extends AbstractArrowFlightJdbcListVectorAccessor { + + private final MapVector vector; + + public ArrowFlightJdbcMapVectorAccessor(MapVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Map.class; + } + + @Override + public Object getObject() { + int index = getCurrentRow(); + + this.wasNull = vector.isNull(index); + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + Map result = new JsonStringHashMap<>(); + UnionMapReader reader = vector.getReader(); + + reader.setPosition(index); + while (reader.next()) { + Object key = reader.key().readObject(); + Object value = reader.value().readObject(); + + result.put(key, value); + } + + return result; + } + + @Override + protected long getStartOffset(int index) { + return vector.getOffsetBuffer().getInt((long) index * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected long getEndOffset(int index) { + return vector.getOffsetBuffer() + .getInt((long) (index + 1) * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessor.java new file mode 100644 index 0000000000000..8a7ac117113c5 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessor.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.sql.Struct; +import java.util.List; +import java.util.Map; +import java.util.function.IntSupplier; +import java.util.stream.Collectors; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.calcite.avatica.util.StructImpl; + +/** + * Accessor for the Arrow type {@link StructVector}. + */ +public class ArrowFlightJdbcStructVectorAccessor extends ArrowFlightJdbcAccessor { + + private final StructVector vector; + + public ArrowFlightJdbcStructVectorAccessor(StructVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Map.class; + } + + @Override + public Object getObject() { + Map object = vector.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return object; + } + + @Override + public Struct getStruct() { + int currentRow = getCurrentRow(); + + this.wasNull = vector.isNull(currentRow); + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + List attributes = vector.getChildrenFromFields() + .stream() + .map(vector -> vector.getObject(currentRow)) + .collect(Collectors.toList()); + + return new StructImpl(attributes); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessor.java new file mode 100644 index 0000000000000..5b5a0a472d5a3 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessor.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.UnionVector; + +/** + * Accessor for the Arrow type {@link UnionVector}. + */ +public class ArrowFlightJdbcUnionVectorAccessor extends AbstractArrowFlightJdbcUnionVectorAccessor { + + private final UnionVector vector; + + /** + * Instantiate an accessor for a {@link UnionVector}. + * + * @param vector an instance of a UnionVector. + * @param currentRowSupplier the supplier to track the rows. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcUnionVectorAccessor(UnionVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected ArrowFlightJdbcAccessor createAccessorForVector(ValueVector vector) { + return ArrowFlightJdbcAccessorFactory.createAccessor(vector, this::getCurrentRow, + (boolean wasNull) -> { + }); + } + + @Override + protected byte getCurrentTypeId() { + int index = getCurrentRow(); + return (byte) this.vector.getTypeValue(index); + } + + @Override + protected ValueVector getVectorByTypeId(byte typeId) { + return this.vector.getVectorByType(typeId); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java new file mode 100644 index 0000000000000..aea9b75fa6c3f --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcNumericGetter.Getter; +import static org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcNumericGetter.createGetter; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcNumericGetter.NumericHolder; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.types.Types.MinorType; + +/** + * Accessor for the arrow types: TinyIntVector, SmallIntVector, IntVector, BigIntVector, + * UInt1Vector, UInt2Vector, UInt4Vector and UInt8Vector. + */ +public class ArrowFlightJdbcBaseIntVectorAccessor extends ArrowFlightJdbcAccessor { + + private final MinorType type; + private final boolean isUnsigned; + private final int bytesToAllocate; + private final Getter getter; + private final NumericHolder holder; + + public ArrowFlightJdbcBaseIntVectorAccessor(UInt1Vector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, true, UInt1Vector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(UInt2Vector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, true, UInt2Vector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(UInt4Vector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, true, UInt4Vector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(UInt8Vector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, true, UInt8Vector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(TinyIntVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, false, TinyIntVector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(SmallIntVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, false, SmallIntVector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(IntVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, false, IntVector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(BigIntVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, false, BigIntVector.TYPE_WIDTH, setCursorWasNull); + } + + private ArrowFlightJdbcBaseIntVectorAccessor(BaseIntVector vector, IntSupplier currentRowSupplier, + boolean isUnsigned, int bytesToAllocate, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.type = vector.getMinorType(); + this.holder = new NumericHolder(); + this.getter = createGetter(vector); + this.isUnsigned = isUnsigned; + this.bytesToAllocate = bytesToAllocate; + } + + @Override + public long getLong() { + getter.get(getCurrentRow(), holder); + + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public Class getObjectClass() { + return Long.class; + } + + @Override + public String getString() { + final long number = getLong(); + + if (this.wasNull) { + return null; + } else { + return isUnsigned ? Long.toUnsignedString(number) : Long.toString(number); + } + } + + @Override + public byte getByte() { + return (byte) getLong(); + } + + @Override + public short getShort() { + return (short) getLong(); + } + + @Override + public int getInt() { + return (int) getLong(); + } + + @Override + public float getFloat() { + return (float) getLong(); + } + + @Override + public double getDouble() { + return (double) getLong(); + } + + @Override + public BigDecimal getBigDecimal() { + final BigDecimal value = BigDecimal.valueOf(getLong()); + return this.wasNull ? null : value; + } + + @Override + public BigDecimal getBigDecimal(int scale) { + final BigDecimal value = + BigDecimal.valueOf(this.getDouble()).setScale(scale, RoundingMode.HALF_UP); + return this.wasNull ? null : value; + } + + @Override + public Number getObject() { + final Number number; + switch (type) { + case TINYINT: + case UINT1: + number = getByte(); + break; + case SMALLINT: + case UINT2: + number = getShort(); + break; + case INT: + case UINT4: + number = getInt(); + break; + case BIGINT: + case UINT8: + number = getLong(); + break; + default: + throw new IllegalStateException("No valid MinorType was provided."); + } + return wasNull ? null : number; + } + + @Override + public boolean getBoolean() { + final long value = getLong(); + + return value != 0; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java new file mode 100644 index 0000000000000..f55fd12f9a517 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import java.math.BigDecimal; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.holders.NullableBitHolder; + +/** + * Accessor for the arrow {@link BitVector}. + */ +public class ArrowFlightJdbcBitVectorAccessor extends ArrowFlightJdbcAccessor { + + private final BitVector vector; + private final NullableBitHolder holder; + private static final int BYTES_T0_ALLOCATE = 1; + + /** + * Constructor for the BitVectorAccessor. + * + * @param vector an instance of a {@link BitVector}. + * @param currentRowSupplier a supplier to check which row is being accessed. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcBitVectorAccessor(BitVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + this.holder = new NullableBitHolder(); + } + + @Override + public Class getObjectClass() { + return Boolean.class; + } + + @Override + public String getString() { + final boolean value = getBoolean(); + return wasNull ? null : Boolean.toString(value); + } + + @Override + public boolean getBoolean() { + return this.getLong() != 0; + } + + @Override + public byte getByte() { + return (byte) this.getLong(); + } + + @Override + public short getShort() { + return (short) this.getLong(); + } + + @Override + public int getInt() { + return (int) this.getLong(); + } + + @Override + public long getLong() { + vector.get(getCurrentRow(), holder); + + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public float getFloat() { + return this.getLong(); + } + + @Override + public double getDouble() { + return this.getLong(); + } + + @Override + public BigDecimal getBigDecimal() { + final long value = this.getLong(); + + return this.wasNull ? null : BigDecimal.valueOf(value); + } + + @Override + public Object getObject() { + final boolean value = this.getBoolean(); + return this.wasNull ? null : value; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessor.java new file mode 100644 index 0000000000000..0f7d618a6090f --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessor.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; + +/** + * Accessor for {@link DecimalVector} and {@link Decimal256Vector}. + */ +public class ArrowFlightJdbcDecimalVectorAccessor extends ArrowFlightJdbcAccessor { + + private final Getter getter; + + /** + * Functional interface used to unify Decimal*Vector#getObject implementations. + */ + @FunctionalInterface + interface Getter { + BigDecimal getObject(int index); + } + + public ArrowFlightJdbcDecimalVectorAccessor(DecimalVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.getter = vector::getObject; + } + + public ArrowFlightJdbcDecimalVectorAccessor(Decimal256Vector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.getter = vector::getObject; + } + + @Override + public Class getObjectClass() { + return BigDecimal.class; + } + + @Override + public BigDecimal getBigDecimal() { + final BigDecimal value = getter.getObject(getCurrentRow()); + this.wasNull = value == null; + this.wasNullConsumer.setWasNull(this.wasNull); + return value; + } + + @Override + public String getString() { + final BigDecimal value = this.getBigDecimal(); + return this.wasNull ? null : value.toString(); + } + + @Override + public boolean getBoolean() { + final BigDecimal value = this.getBigDecimal(); + + return !this.wasNull && !value.equals(BigDecimal.ZERO); + } + + @Override + public byte getByte() { + final BigDecimal value = this.getBigDecimal(); + + return this.wasNull ? 0 : value.byteValue(); + } + + @Override + public short getShort() { + final BigDecimal value = this.getBigDecimal(); + + return this.wasNull ? 0 : value.shortValue(); + } + + @Override + public int getInt() { + final BigDecimal value = this.getBigDecimal(); + + return this.wasNull ? 0 : value.intValue(); + } + + @Override + public long getLong() { + final BigDecimal value = this.getBigDecimal(); + + return this.wasNull ? 0 : value.longValue(); + } + + @Override + public float getFloat() { + final BigDecimal value = this.getBigDecimal(); + + return this.wasNull ? 0 : value.floatValue(); + } + + @Override + public double getDouble() { + final BigDecimal value = this.getBigDecimal(); + + return this.wasNull ? 0 : value.doubleValue(); + } + + @Override + public BigDecimal getBigDecimal(int scale) { + final BigDecimal value = this.getBigDecimal(); + + return this.wasNull ? null : value.setScale(scale, RoundingMode.HALF_UP); + } + + @Override + public Object getObject() { + return this.getBigDecimal(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessor.java new file mode 100644 index 0000000000000..cbf2d36ff80b1 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessor.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.SQLException; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.holders.NullableFloat4Holder; + +/** + * Accessor for the Float4Vector. + */ +public class ArrowFlightJdbcFloat4VectorAccessor extends ArrowFlightJdbcAccessor { + + private final Float4Vector vector; + private final NullableFloat4Holder holder; + + /** + * Instantiate a accessor for the {@link Float4Vector}. + * + * @param vector an instance of a Float4Vector. + * @param currentRowSupplier the supplier to track the lines. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcFloat4VectorAccessor(Float4Vector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new NullableFloat4Holder(); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Float.class; + } + + @Override + public String getString() { + final float value = this.getFloat(); + + return this.wasNull ? null : Float.toString(value); + } + + @Override + public boolean getBoolean() { + return this.getFloat() != 0.0; + } + + @Override + public byte getByte() { + return (byte) this.getFloat(); + } + + @Override + public short getShort() { + return (short) this.getFloat(); + } + + @Override + public int getInt() { + return (int) this.getFloat(); + } + + @Override + public long getLong() { + return (long) this.getFloat(); + } + + @Override + public float getFloat() { + vector.get(getCurrentRow(), holder); + + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public double getDouble() { + return this.getFloat(); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + final float value = this.getFloat(); + + if (Float.isInfinite(value) || Float.isNaN(value)) { + throw new SQLException("BigDecimal doesn't support Infinite/NaN."); + } + + return this.wasNull ? null : BigDecimal.valueOf(value); + } + + @Override + public BigDecimal getBigDecimal(int scale) throws SQLException { + final float value = this.getFloat(); + if (Float.isInfinite(value) || Float.isNaN(value)) { + throw new SQLException("BigDecimal doesn't support Infinite/NaN."); + } + return this.wasNull ? null : BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP); + } + + @Override + public Object getObject() { + final float value = this.getFloat(); + return this.wasNull ? null : value; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessor.java new file mode 100644 index 0000000000000..dc5542ffc58d8 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessor.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.SQLException; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.holders.NullableFloat8Holder; + +/** + * Accessor for the Float8Vector. + */ +public class ArrowFlightJdbcFloat8VectorAccessor extends ArrowFlightJdbcAccessor { + + private final Float8Vector vector; + private final NullableFloat8Holder holder; + + /** + * Instantiate a accessor for the {@link Float8Vector}. + * + * @param vector an instance of a Float8Vector. + * @param currentRowSupplier the supplier to track the lines. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcFloat8VectorAccessor(Float8Vector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new NullableFloat8Holder(); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Double.class; + } + + @Override + public double getDouble() { + vector.get(getCurrentRow(), holder); + + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public Object getObject() { + final double value = this.getDouble(); + + return this.wasNull ? null : value; + } + + @Override + public String getString() { + final double value = this.getDouble(); + return this.wasNull ? null : Double.toString(value); + } + + @Override + public boolean getBoolean() { + return this.getDouble() != 0.0; + } + + @Override + public byte getByte() { + return (byte) this.getDouble(); + } + + @Override + public short getShort() { + return (short) this.getDouble(); + } + + @Override + public int getInt() { + return (int) this.getDouble(); + } + + @Override + public long getLong() { + return (long) this.getDouble(); + } + + @Override + public float getFloat() { + return (float) this.getDouble(); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + final double value = this.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + throw new SQLException("BigDecimal doesn't support Infinite/NaN."); + } + return this.wasNull ? null : BigDecimal.valueOf(value); + } + + @Override + public BigDecimal getBigDecimal(int scale) throws SQLException { + final double value = this.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + throw new SQLException("BigDecimal doesn't support Infinite/NaN."); + } + return this.wasNull ? null : BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcNumericGetter.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcNumericGetter.java new file mode 100644 index 0000000000000..cc802a0089d3e --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcNumericGetter.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.holders.NullableBigIntHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; +import org.apache.arrow.vector.holders.NullableSmallIntHolder; +import org.apache.arrow.vector.holders.NullableTinyIntHolder; +import org.apache.arrow.vector.holders.NullableUInt1Holder; +import org.apache.arrow.vector.holders.NullableUInt2Holder; +import org.apache.arrow.vector.holders.NullableUInt4Holder; +import org.apache.arrow.vector.holders.NullableUInt8Holder; + +/** + * A custom getter for values from the {@link BaseIntVector}. + */ +class ArrowFlightJdbcNumericGetter { + /** + * A holder for values from the {@link BaseIntVector}. + */ + static class NumericHolder { + int isSet; // Tells if value is set; 0 = not set, 1 = set + long value; // Holds actual value + } + + /** + * Functional interface for a getter to baseInt values. + */ + @FunctionalInterface + interface Getter { + void get(int index, NumericHolder holder); + } + + /** + * Main class that will check the type of the vector to create + * a specific getter. + * + * @param vector an instance of the {@link BaseIntVector} + * @return a getter. + */ + static Getter createGetter(BaseIntVector vector) { + if (vector instanceof UInt1Vector) { + return createGetter((UInt1Vector) vector); + } else if (vector instanceof UInt2Vector) { + return createGetter((UInt2Vector) vector); + } else if (vector instanceof UInt4Vector) { + return createGetter((UInt4Vector) vector); + } else if (vector instanceof UInt8Vector) { + return createGetter((UInt8Vector) vector); + } else if (vector instanceof TinyIntVector) { + return createGetter((TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + return createGetter((SmallIntVector) vector); + } else if (vector instanceof IntVector) { + return createGetter((IntVector) vector); + } else if (vector instanceof BigIntVector) { + return createGetter((BigIntVector) vector); + } + + throw new UnsupportedOperationException("No valid IntVector was provided."); + } + + /** + * Create a specific getter for {@link UInt1Vector}. + * + * @param vector an instance of the {@link UInt1Vector} + * @return a getter. + */ + private static Getter createGetter(UInt1Vector vector) { + NullableUInt1Holder nullableUInt1Holder = new NullableUInt1Holder(); + + return (index, holder) -> { + vector.get(index, nullableUInt1Holder); + + holder.isSet = nullableUInt1Holder.isSet; + holder.value = nullableUInt1Holder.value; + }; + } + + /** + * Create a specific getter for {@link UInt2Vector}. + * + * @param vector an instance of the {@link UInt2Vector} + * @return a getter. + */ + private static Getter createGetter(UInt2Vector vector) { + NullableUInt2Holder nullableUInt2Holder = new NullableUInt2Holder(); + return (index, holder) -> { + vector.get(index, nullableUInt2Holder); + + holder.isSet = nullableUInt2Holder.isSet; + holder.value = nullableUInt2Holder.value; + }; + } + + /** + * Create a specific getter for {@link UInt4Vector}. + * + * @param vector an instance of the {@link UInt4Vector} + * @return a getter. + */ + private static Getter createGetter(UInt4Vector vector) { + NullableUInt4Holder nullableUInt4Holder = new NullableUInt4Holder(); + return (index, holder) -> { + vector.get(index, nullableUInt4Holder); + + holder.isSet = nullableUInt4Holder.isSet; + holder.value = nullableUInt4Holder.value; + }; + } + + /** + * Create a specific getter for {@link UInt8Vector}. + * + * @param vector an instance of the {@link UInt8Vector} + * @return a getter. + */ + private static Getter createGetter(UInt8Vector vector) { + NullableUInt8Holder nullableUInt8Holder = new NullableUInt8Holder(); + return (index, holder) -> { + vector.get(index, nullableUInt8Holder); + + holder.isSet = nullableUInt8Holder.isSet; + holder.value = nullableUInt8Holder.value; + }; + } + + /** + * Create a specific getter for {@link TinyIntVector}. + * + * @param vector an instance of the {@link TinyIntVector} + * @return a getter. + */ + private static Getter createGetter(TinyIntVector vector) { + NullableTinyIntHolder nullableTinyIntHolder = new NullableTinyIntHolder(); + return (index, holder) -> { + vector.get(index, nullableTinyIntHolder); + + holder.isSet = nullableTinyIntHolder.isSet; + holder.value = nullableTinyIntHolder.value; + }; + } + + /** + * Create a specific getter for {@link SmallIntVector}. + * + * @param vector an instance of the {@link SmallIntVector} + * @return a getter. + */ + private static Getter createGetter(SmallIntVector vector) { + NullableSmallIntHolder nullableSmallIntHolder = new NullableSmallIntHolder(); + return (index, holder) -> { + vector.get(index, nullableSmallIntHolder); + + holder.isSet = nullableSmallIntHolder.isSet; + holder.value = nullableSmallIntHolder.value; + }; + } + + /** + * Create a specific getter for {@link IntVector}. + * + * @param vector an instance of the {@link IntVector} + * @return a getter. + */ + private static Getter createGetter(IntVector vector) { + NullableIntHolder nullableIntHolder = new NullableIntHolder(); + return (index, holder) -> { + vector.get(index, nullableIntHolder); + + holder.isSet = nullableIntHolder.isSet; + holder.value = nullableIntHolder.value; + }; + } + + /** + * Create a specific getter for {@link BigIntVector}. + * + * @param vector an instance of the {@link BigIntVector} + * @return a getter. + */ + private static Getter createGetter(BigIntVector vector) { + NullableBigIntHolder nullableBigIntHolder = new NullableBigIntHolder(); + return (index, holder) -> { + vector.get(index, nullableBigIntHolder); + + holder.isSet = nullableBigIntHolder.isSet; + holder.value = nullableBigIntHolder.value; + }; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessor.java new file mode 100644 index 0000000000000..aad8d9094c9c9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessor.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.text; + +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.io.ByteArrayInputStream; +import java.io.CharArrayReader; +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.DateTimeUtils; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; + +/** + * Accessor for the Arrow types: {@link VarCharVector} and {@link LargeVarCharVector}. + */ +public class ArrowFlightJdbcVarCharVectorAccessor extends ArrowFlightJdbcAccessor { + + /** + * Functional interface to help integrating VarCharVector and LargeVarCharVector. + */ + @FunctionalInterface + interface Getter { + byte[] get(int index); + } + + private final Getter getter; + + public ArrowFlightJdbcVarCharVectorAccessor(VarCharVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector::get, currentRowSupplier, setCursorWasNull); + } + + public ArrowFlightJdbcVarCharVectorAccessor(LargeVarCharVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector::get, currentRowSupplier, setCursorWasNull); + } + + ArrowFlightJdbcVarCharVectorAccessor(Getter getter, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.getter = getter; + } + + @Override + public Class getObjectClass() { + return String.class; + } + + @Override + public String getObject() { + final byte[] bytes = getBytes(); + return bytes == null ? null : new String(bytes, UTF_8); + } + + @Override + public String getString() { + return getObject(); + } + + @Override + public byte[] getBytes() { + final byte[] bytes = this.getter.get(getCurrentRow()); + this.wasNull = bytes == null; + this.wasNullConsumer.setWasNull(this.wasNull); + return bytes; + } + + @Override + public boolean getBoolean() throws SQLException { + String value = getString(); + if (value == null || value.equalsIgnoreCase("false") || value.equals("0")) { + return false; + } else if (value.equalsIgnoreCase("true") || value.equals("1")) { + return true; + } else { + throw new SQLException("It is not possible to convert this value to boolean: " + value); + } + } + + @Override + public byte getByte() throws SQLException { + try { + return Byte.parseByte(this.getString()); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public short getShort() throws SQLException { + try { + return Short.parseShort(this.getString()); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public int getInt() throws SQLException { + try { + return Integer.parseInt(this.getString()); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public long getLong() throws SQLException { + try { + return Long.parseLong(this.getString()); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public float getFloat() throws SQLException { + try { + return Float.parseFloat(this.getString()); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public double getDouble() throws SQLException { + try { + return Double.parseDouble(this.getString()); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + try { + return new BigDecimal(this.getString()); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public BigDecimal getBigDecimal(int i) throws SQLException { + try { + return BigDecimal.valueOf(this.getLong(), i); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public InputStream getAsciiStream() { + final String textValue = getString(); + if (textValue == null) { + return null; + } + // Already in UTF-8 + return new ByteArrayInputStream(textValue.getBytes(US_ASCII)); + } + + @Override + public InputStream getUnicodeStream() { + final byte[] value = getBytes(); + if (value == null) { + return null; + } + + // Already in UTF-8 + final Text textValue = new Text(value); + return new ByteArrayInputStream(textValue.getBytes(), 0, textValue.getLength()); + } + + @Override + public Reader getCharacterStream() { + return new CharArrayReader(getString().toCharArray()); + } + + @Override + public Date getDate(Calendar calendar) throws SQLException { + try { + Date date = Date.valueOf(getString()); + if (calendar == null) { + return date; + } + + // Use Calendar to apply time zone's offset + long milliseconds = date.getTime(); + return new Date(DateTimeUtils.applyCalendarOffset(milliseconds, calendar)); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public Time getTime(Calendar calendar) throws SQLException { + try { + Time time = Time.valueOf(getString()); + if (calendar == null) { + return time; + } + + // Use Calendar to apply time zone's offset + long milliseconds = time.getTime(); + return new Time(DateTimeUtils.applyCalendarOffset(milliseconds, calendar)); + } catch (Exception e) { + throw new SQLException(e); + } + } + + @Override + public Timestamp getTimestamp(Calendar calendar) throws SQLException { + try { + Timestamp timestamp = Timestamp.valueOf(getString()); + if (calendar == null) { + return timestamp; + } + + // Use Calendar to apply time zone's offset + long milliseconds = timestamp.getTime(); + return new Timestamp(DateTimeUtils.applyCalendarOffset(milliseconds, calendar)); + } catch (Exception e) { + throw new SQLException(e); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java new file mode 100644 index 0000000000000..afac6c164708b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -0,0 +1,582 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.client; + +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.auth2.BearerCredentialWriter; +import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; +import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; +import org.apache.arrow.flight.client.ClientCookieMiddleware; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.Meta.StatementType; + +/** + * A {@link FlightSqlClient} handler. + */ +public final class ArrowFlightSqlClientHandler implements AutoCloseable { + + private final FlightSqlClient sqlClient; + private final Set options = new HashSet<>(); + + ArrowFlightSqlClientHandler(final FlightSqlClient sqlClient, + final Collection options) { + this.options.addAll(options); + this.sqlClient = Preconditions.checkNotNull(sqlClient); + } + + /** + * Creates a new {@link ArrowFlightSqlClientHandler} from the provided {@code client} and {@code options}. + * + * @param client the {@link FlightClient} to manage under a {@link FlightSqlClient} wrapper. + * @param options the {@link CallOption}s to persist in between subsequent client calls. + * @return a new {@link ArrowFlightSqlClientHandler}. + */ + public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client, + final Collection options) { + return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), options); + } + + /** + * Gets the {@link #options} for the subsequent calls from this handler. + * + * @return the {@link CallOption}s. + */ + private CallOption[] getOptions() { + return options.toArray(new CallOption[0]); + } + + /** + * Makes an RPC "getStream" request based on the provided {@link FlightInfo} + * object. Retrieves the result of the query previously prepared with "getInfo." + * + * @param flightInfo The {@link FlightInfo} instance from which to fetch results. + * @return a {@code FlightStream} of results. + */ + public List getStreams(final FlightInfo flightInfo) { + return flightInfo.getEndpoints().stream() + .map(FlightEndpoint::getTicket) + .map(ticket -> sqlClient.getStream(ticket, getOptions())) + .collect(Collectors.toList()); + } + + /** + * Makes an RPC "getInfo" request based on the provided {@code query} + * object. + * + * @param query The query. + * @return a {@code FlightStream} of results. + */ + public FlightInfo getInfo(final String query) { + return sqlClient.execute(query, getOptions()); + } + + @Override + public void close() throws SQLException { + try { + AutoCloseables.close(sqlClient); + } catch (final Exception e) { + throw new SQLException("Failed to clean up client resources.", e); + } + } + + /** + * A prepared statement handler. + */ + public interface PreparedStatement extends AutoCloseable { + /** + * Executes this {@link PreparedStatement}. + * + * @return the {@link FlightInfo} representing the outcome of this query execution. + * @throws SQLException on error. + */ + FlightInfo executeQuery() throws SQLException; + + /** + * Executes a {@link StatementType#UPDATE} query. + * + * @return the number of rows affected. + */ + long executeUpdate(); + + /** + * Gets the {@link StatementType} of this {@link PreparedStatement}. + * + * @return the Statement Type. + */ + StatementType getType(); + + /** + * Gets the {@link Schema} of this {@link PreparedStatement}. + * + * @return {@link Schema}. + */ + Schema getDataSetSchema(); + + @Override + void close(); + } + + /** + * Creates a new {@link PreparedStatement} for the given {@code query}. + * + * @param query the SQL query. + * @return a new prepared statement. + */ + public PreparedStatement prepare(final String query) { + final FlightSqlClient.PreparedStatement preparedStatement = + sqlClient.prepare(query, getOptions()); + return new PreparedStatement() { + @Override + public FlightInfo executeQuery() throws SQLException { + return preparedStatement.execute(getOptions()); + } + + @Override + public long executeUpdate() { + return preparedStatement.executeUpdate(getOptions()); + } + + @Override + public StatementType getType() { + final Schema schema = preparedStatement.getResultSetSchema(); + return schema.getFields().isEmpty() ? StatementType.UPDATE : StatementType.SELECT; + } + + @Override + public Schema getDataSetSchema() { + return preparedStatement.getResultSetSchema(); + } + + @Override + public void close() { + preparedStatement.close(getOptions()); + } + }; + } + + /** + * Makes an RPC "getCatalogs" request. + * + * @return a {@code FlightStream} of results. + */ + public FlightInfo getCatalogs() { + return sqlClient.getCatalogs(getOptions()); + } + + /** + * Makes an RPC "getImportedKeys" request based on the provided info. + * + * @param catalog The catalog name. Must match the catalog name as it is stored in the database. + * Retrieves those without a catalog. Null means that the catalog name should not be used to + * narrow the search. + * @param schema The schema name. Must match the schema name as it is stored in the database. + * "" retrieves those without a schema. Null means that the schema name should not be used to narrow + * the search. + * @param table The table name. Must match the table name as it is stored in the database. + * @return a {@code FlightStream} of results. + */ + public FlightInfo getImportedKeys(final String catalog, final String schema, final String table) { + return sqlClient.getImportedKeys(TableRef.of(catalog, schema, table), getOptions()); + } + + /** + * Makes an RPC "getExportedKeys" request based on the provided info. + * + * @param catalog The catalog name. Must match the catalog name as it is stored in the database. + * Retrieves those without a catalog. Null means that the catalog name should not be used to + * narrow the search. + * @param schema The schema name. Must match the schema name as it is stored in the database. + * "" retrieves those without a schema. Null means that the schema name should not be used to narrow + * the search. + * @param table The table name. Must match the table name as it is stored in the database. + * @return a {@code FlightStream} of results. + */ + public FlightInfo getExportedKeys(final String catalog, final String schema, final String table) { + return sqlClient.getExportedKeys(TableRef.of(catalog, schema, table), getOptions()); + } + + /** + * Makes an RPC "getSchemas" request based on the provided info. + * + * @param catalog The catalog name. Must match the catalog name as it is stored in the database. + * Retrieves those without a catalog. Null means that the catalog name should not be used to + * narrow the search. + * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the database. + * Null means that schema name should not be used to narrow down the search. + * @return a {@code FlightStream} of results. + */ + public FlightInfo getSchemas(final String catalog, final String schemaPattern) { + return sqlClient.getSchemas(catalog, schemaPattern, getOptions()); + } + + /** + * Makes an RPC "getTableTypes" request. + * + * @return a {@code FlightStream} of results. + */ + public FlightInfo getTableTypes() { + return sqlClient.getTableTypes(getOptions()); + } + + /** + * Makes an RPC "getTables" request based on the provided info. + * + * @param catalog The catalog name. Must match the catalog name as it is stored in the database. + * Retrieves those without a catalog. Null means that the catalog name should not be used to + * narrow the search. + * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the database. + * "" retrieves those without a schema. Null means that the schema name should not be used to + * narrow the search. + * @param tableNamePattern The table name pattern. Must match the table name as it is stored in the database. + * @param types The list of table types, which must be from the list of table types to include. + * Null returns all types. + * @param includeSchema Whether to include schema. + * @return a {@code FlightStream} of results. + */ + public FlightInfo getTables(final String catalog, final String schemaPattern, + final String tableNamePattern, + final List types, final boolean includeSchema) { + + return sqlClient.getTables(catalog, schemaPattern, tableNamePattern, types, includeSchema, + getOptions()); + } + + /** + * Gets SQL info. + * + * @return the SQL info. + */ + public FlightInfo getSqlInfo(SqlInfo... info) { + return sqlClient.getSqlInfo(info, getOptions()); + } + + /** + * Makes an RPC "getPrimaryKeys" request based on the provided info. + * + * @param catalog The catalog name; must match the catalog name as it is stored in the database. + * "" retrieves those without a catalog. + * Null means that the catalog name should not be used to narrow the search. + * @param schema The schema name; must match the schema name as it is stored in the database. + * "" retrieves those without a schema. Null means that the schema name should not be used to narrow + * the search. + * @param table The table name. Must match the table name as it is stored in the database. + * @return a {@code FlightStream} of results. + */ + public FlightInfo getPrimaryKeys(final String catalog, final String schema, final String table) { + return sqlClient.getPrimaryKeys(TableRef.of(catalog, schema, table), getOptions()); + } + + /** + * Makes an RPC "getCrossReference" request based on the provided info. + * + * @param pkCatalog The catalog name. Must match the catalog name as it is stored in the database. + * Retrieves those without a catalog. Null means that the catalog name should not be used to + * narrow the search. + * @param pkSchema The schema name. Must match the schema name as it is stored in the database. + * "" retrieves those without a schema. Null means that the schema name should not be used to narrow + * the search. + * @param pkTable The table name. Must match the table name as it is stored in the database. + * @param fkCatalog The catalog name. Must match the catalog name as it is stored in the database. + * Retrieves those without a catalog. Null means that the catalog name should not be used to + * narrow the search. + * @param fkSchema The schema name. Must match the schema name as it is stored in the database. + * "" retrieves those without a schema. Null means that the schema name should not be used to narrow + * the search. + * @param fkTable The table name. Must match the table name as it is stored in the database. + * @return a {@code FlightStream} of results. + */ + public FlightInfo getCrossReference(String pkCatalog, String pkSchema, String pkTable, + String fkCatalog, String fkSchema, String fkTable) { + return sqlClient.getCrossReference(TableRef.of(pkCatalog, pkSchema, pkTable), + TableRef.of(fkCatalog, fkSchema, fkTable), + getOptions()); + } + + /** + * Builder for {@link ArrowFlightSqlClientHandler}. + */ + public static final class Builder { + private final Set middlewareFactories = new HashSet<>(); + private final Set options = new HashSet<>(); + private String host; + private int port; + private String username; + private String password; + private String trustStorePath; + private String trustStorePassword; + private String token; + private boolean useEncryption; + private boolean disableCertificateVerification; + private boolean useSystemTrustStore; + private BufferAllocator allocator; + + /** + * Sets the host for this handler. + * + * @param host the host. + * @return this instance. + */ + public Builder withHost(final String host) { + this.host = host; + return this; + } + + /** + * Sets the port for this handler. + * + * @param port the port. + * @return this instance. + */ + public Builder withPort(final int port) { + this.port = port; + return this; + } + + /** + * Sets the username for this handler. + * + * @param username the username. + * @return this instance. + */ + public Builder withUsername(final String username) { + this.username = username; + return this; + } + + /** + * Sets the password for this handler. + * + * @param password the password. + * @return this instance. + */ + public Builder withPassword(final String password) { + this.password = password; + return this; + } + + /** + * Sets the KeyStore path for this handler. + * + * @param trustStorePath the KeyStore path. + * @return this instance. + */ + public Builder withTrustStorePath(final String trustStorePath) { + this.trustStorePath = trustStorePath; + return this; + } + + /** + * Sets the KeyStore password for this handler. + * + * @param trustStorePassword the KeyStore password. + * @return this instance. + */ + public Builder withTrustStorePassword(final String trustStorePassword) { + this.trustStorePassword = trustStorePassword; + return this; + } + + /** + * Sets whether to use TLS encryption in this handler. + * + * @param useEncryption whether to use TLS encryption. + * @return this instance. + */ + public Builder withEncryption(final boolean useEncryption) { + this.useEncryption = useEncryption; + return this; + } + + /** + * Sets whether to disable the certificate verification in this handler. + * + * @param disableCertificateVerification whether to disable certificate verification. + * @return this instance. + */ + public Builder withDisableCertificateVerification(final boolean disableCertificateVerification) { + this.disableCertificateVerification = disableCertificateVerification; + return this; + } + + /** + * Sets whether to use the certificates from the operating system. + * + * @param useSystemTrustStore whether to use the system operating certificates. + * @return this instance. + */ + public Builder withSystemTrustStore(final boolean useSystemTrustStore) { + this.useSystemTrustStore = useSystemTrustStore; + return this; + } + + /** + * Sets the token used in the token authetication. + * @param token the token value. + * @return this builder instance. + */ + public Builder withToken(final String token) { + this.token = token; + return this; + } + + /** + * Sets the {@link BufferAllocator} to use in this handler. + * + * @param allocator the allocator. + * @return this instance. + */ + public Builder withBufferAllocator(final BufferAllocator allocator) { + this.allocator = allocator + .newChildAllocator("ArrowFlightSqlClientHandler", 0, allocator.getLimit()); + return this; + } + + /** + * Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this handler. + * + * @param factories the factories to add. + * @return this instance. + */ + public Builder withMiddlewareFactories(final FlightClientMiddleware.Factory... factories) { + return withMiddlewareFactories(Arrays.asList(factories)); + } + + /** + * Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this handler. + * + * @param factories the factories to add. + * @return this instance. + */ + public Builder withMiddlewareFactories( + final Collection factories) { + this.middlewareFactories.addAll(factories); + return this; + } + + /** + * Adds the provided {@link CallOption}s to this handler. + * + * @param options the options + * @return this instance. + */ + public Builder withCallOptions(final CallOption... options) { + return withCallOptions(Arrays.asList(options)); + } + + /** + * Adds the provided {@link CallOption}s to this handler. + * + * @param options the options + * @return this instance. + */ + public Builder withCallOptions(final Collection options) { + this.options.addAll(options); + return this; + } + + /** + * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields. + * + * @return a new client handler. + * @throws SQLException on error. + */ + public ArrowFlightSqlClientHandler build() throws SQLException { + FlightClient client = null; + try { + ClientIncomingAuthHeaderMiddleware.Factory authFactory = null; + if (username != null) { + authFactory = + new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); + withMiddlewareFactories(authFactory); + } + final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator); + withMiddlewareFactories(new ClientCookieMiddleware.Factory()); + middlewareFactories.forEach(clientBuilder::intercept); + Location location; + if (useEncryption) { + location = Location.forGrpcTls(host, port); + clientBuilder.useTls(); + } else { + location = Location.forGrpcInsecure(host, port); + } + clientBuilder.location(location); + + if (useEncryption) { + if (disableCertificateVerification) { + clientBuilder.verifyServer(false); + } else { + if (useSystemTrustStore) { + clientBuilder.trustedCertificates( + ClientAuthenticationUtils.getCertificateInputStreamFromSystem(trustStorePassword)); + } else if (trustStorePath != null) { + clientBuilder.trustedCertificates( + ClientAuthenticationUtils.getCertificateStream(trustStorePath, trustStorePassword)); + } + } + } + + client = clientBuilder.build(); + if (authFactory != null) { + options.add( + ClientAuthenticationUtils.getAuthenticate(client, username, password, authFactory)); + } else if (token != null) { + options.add( + ClientAuthenticationUtils.getAuthenticate( + client, new CredentialCallOption(new BearerCredentialWriter(token)))); + } + return ArrowFlightSqlClientHandler.createNewHandler(client, options); + + } catch (final IllegalArgumentException | GeneralSecurityException | IOException | FlightRuntimeException e) { + final SQLException originalException = new SQLException(e); + if (client != null) { + try { + client.close(); + } catch (final InterruptedException interruptedException) { + originalException.addSuppressed(interruptedException); + } + } + throw originalException; + } + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtils.java new file mode 100644 index 0000000000000..854b036ae6b7b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtils.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.client.utils; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Enumeration; +import java.util.List; + +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter; +import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.util.VisibleForTesting; +import org.bouncycastle.openssl.jcajce.JcaPEMWriter; + +/** + * Utils for {@link FlightClientHandler} authentication. + */ +public final class ClientAuthenticationUtils { + + private ClientAuthenticationUtils() { + // Prevent instantiation. + } + + /** + * Gets the {@link CredentialCallOption} for the provided authentication info. + * + * @param client the client. + * @param credential the credential as CallOptions. + * @param options the {@link CallOption}s to use. + * @return the credential call option. + */ + public static CredentialCallOption getAuthenticate(final FlightClient client, + final CredentialCallOption credential, + final CallOption... options) { + + final List theseOptions = new ArrayList<>(); + theseOptions.add(credential); + theseOptions.addAll(Arrays.asList(options)); + client.handshake(theseOptions.toArray(new CallOption[0])); + + return (CredentialCallOption) theseOptions.get(0); + } + + /** + * Gets the {@link CredentialCallOption} for the provided authentication info. + * + * @param client the client. + * @param username the username. + * @param password the password. + * @param factory the {@link ClientIncomingAuthHeaderMiddleware.Factory} to use. + * @param options the {@link CallOption}s to use. + * @return the credential call option. + */ + public static CredentialCallOption getAuthenticate(final FlightClient client, + final String username, final String password, + final ClientIncomingAuthHeaderMiddleware.Factory factory, + final CallOption... options) { + + return getAuthenticate(client, + new CredentialCallOption(new BasicAuthCredentialWriter(username, password)), + factory, options); + } + + private static CredentialCallOption getAuthenticate(final FlightClient client, + final CredentialCallOption token, + final ClientIncomingAuthHeaderMiddleware.Factory factory, + final CallOption... options) { + final List theseOptions = new ArrayList<>(); + theseOptions.add(token); + theseOptions.addAll(Arrays.asList(options)); + client.handshake(theseOptions.toArray(new CallOption[0])); + return factory.getCredentialCallOption(); + } + + @VisibleForTesting + static KeyStore getKeyStoreInstance(String instance) + throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException { + KeyStore keyStore = KeyStore.getInstance(instance); + keyStore.load(null, null); + + return keyStore; + } + + static String getOperatingSystem() { + return System.getProperty("os.name"); + } + + /** + * Check if the operating system running the software is Windows. + * + * @return whether is the windows system. + */ + public static boolean isWindows() { + return getOperatingSystem().contains("Windows"); + } + + /** + * Check if the operating system running the software is Mac. + * + * @return whether is the mac system. + */ + public static boolean isMac() { + return getOperatingSystem().contains("Mac"); + } + + /** + * It gets the trusted certificate based on the operating system and loads all the certificate into a + * {@link InputStream}. + * + * @return An input stream with all the certificates. + * + * @throws KeyStoreException if a key store could not be loaded. + * @throws CertificateException if a certificate could not be found. + * @throws IOException if it fails reading the file. + */ + public static InputStream getCertificateInputStreamFromSystem(String password) throws KeyStoreException, + CertificateException, IOException, NoSuchAlgorithmException { + + List keyStoreList = new ArrayList<>(); + if (isWindows()) { + keyStoreList.add(getKeyStoreInstance("Windows-ROOT")); + keyStoreList.add(getKeyStoreInstance("Windows-MY")); + } else if (isMac()) { + keyStoreList.add(getKeyStoreInstance("KeychainStore")); + } else { + Path path = Paths.get(System.getProperty("java.home"), "lib", "security", "cacerts"); + try (InputStream fileInputStream = Files.newInputStream(path)) { + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(fileInputStream, password.toCharArray()); + keyStoreList.add(keyStore); + } + } + + return getCertificatesInputStream(keyStoreList); + } + + @VisibleForTesting + static void getCertificatesInputStream(KeyStore keyStore, JcaPEMWriter pemWriter) + throws IOException, KeyStoreException { + Enumeration aliases = keyStore.aliases(); + while (aliases.hasMoreElements()) { + String alias = aliases.nextElement(); + if (keyStore.isCertificateEntry(alias)) { + pemWriter.writeObject(keyStore.getCertificate(alias)); + } + } + pemWriter.flush(); + } + + @VisibleForTesting + static InputStream getCertificatesInputStream(Collection keyStores) + throws IOException, KeyStoreException { + try (final StringWriter writer = new StringWriter(); + final JcaPEMWriter pemWriter = new JcaPEMWriter(writer)) { + + for (KeyStore keyStore : keyStores) { + getCertificatesInputStream(keyStore, pemWriter); + } + + return new ByteArrayInputStream( + writer.toString().getBytes(StandardCharsets.UTF_8)); + } + } + + /** + * Generates an {@link InputStream} that contains certificates for a private + * key. + * + * @param keyStorePath The path of the KeyStore. + * @param keyStorePass The password of the KeyStore. + * @return a new {code InputStream} containing the certificates. + * @throws GeneralSecurityException on error. + * @throws IOException on error. + */ + public static InputStream getCertificateStream(final String keyStorePath, + final String keyStorePass) + throws GeneralSecurityException, IOException { + Preconditions.checkNotNull(keyStorePath, "KeyStore path cannot be null!"); + Preconditions.checkNotNull(keyStorePass, "KeyStorePass cannot be null!"); + final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + + try (final InputStream keyStoreStream = Files + .newInputStream(Paths.get(Preconditions.checkNotNull(keyStorePath)))) { + keyStore.load(keyStoreStream, + Preconditions.checkNotNull(keyStorePass).toCharArray()); + } + + return getSingleCertificateInputStream(keyStore); + } + + private static InputStream getSingleCertificateInputStream(KeyStore keyStore) + throws KeyStoreException, IOException, CertificateException { + final Enumeration aliases = keyStore.aliases(); + + while (aliases.hasMoreElements()) { + final String alias = aliases.nextElement(); + if (keyStore.isCertificateEntry(alias)) { + return toInputStream(keyStore.getCertificate(alias)); + } + } + + throw new CertificateException("Keystore did not have a certificate."); + } + + private static InputStream toInputStream(final Certificate certificate) + throws IOException { + + try (final StringWriter writer = new StringWriter(); + final JcaPEMWriter pemWriter = new JcaPEMWriter(writer)) { + + pemWriter.writeObject(certificate); + pemWriter.flush(); + return new ByteArrayInputStream( + writer.toString().getBytes(StandardCharsets.UTF_8)); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java new file mode 100644 index 0000000000000..ac338a85d6292 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.ArrowFlightConnection; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightCallHeaders; +import org.apache.arrow.flight.HeaderCallOption; +import org.apache.arrow.util.Preconditions; +import org.apache.calcite.avatica.ConnectionConfig; +import org.apache.calcite.avatica.ConnectionConfigImpl; +import org.apache.calcite.avatica.ConnectionProperty; + +/** + * A {@link ConnectionConfig} for the {@link ArrowFlightConnection}. + */ +public final class ArrowFlightConnectionConfigImpl extends ConnectionConfigImpl { + public ArrowFlightConnectionConfigImpl(final Properties properties) { + super(properties); + } + + /** + * Gets the host. + * + * @return the host. + */ + public String getHost() { + return ArrowFlightConnectionProperty.HOST.getString(properties); + } + + /** + * Gets the port. + * + * @return the port. + */ + public int getPort() { + return ArrowFlightConnectionProperty.PORT.getInteger(properties); + } + + /** + * Gets the host. + * + * @return the host. + */ + public String getUser() { + return ArrowFlightConnectionProperty.USER.getString(properties); + } + + /** + * Gets the host. + * + * @return the host. + */ + public String getPassword() { + return ArrowFlightConnectionProperty.PASSWORD.getString(properties); + } + + + public String getToken() { + return ArrowFlightConnectionProperty.TOKEN.getString(properties); + } + + /** + * Gets the KeyStore path. + * + * @return the path. + */ + public String getTrustStorePath() { + return ArrowFlightConnectionProperty.TRUST_STORE.getString(properties); + } + + /** + * Gets the KeyStore password. + * + * @return the password. + */ + public String getTrustStorePassword() { + return ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.getString(properties); + } + + /** + * Check if the JDBC should use the trusted store files from the operating system. + * + * @return whether to use system trusted store certificates. + */ + public boolean useSystemTrustStore() { + return ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.getBoolean(properties); + } + + /** + * Whether to use TLS encryption. + * + * @return whether to use TLS encryption. + */ + public boolean useEncryption() { + return ArrowFlightConnectionProperty.USE_ENCRYPTION.getBoolean(properties); + } + + public boolean getDisableCertificateVerification() { + return ArrowFlightConnectionProperty.CERTIFICATE_VERIFICATION.getBoolean(properties); + } + + /** + * Gets the thread pool size. + * + * @return the thread pool size. + */ + public int threadPoolSize() { + return ArrowFlightConnectionProperty.THREAD_POOL_SIZE.getInteger(properties); + } + + /** + * Gets the {@link CallOption}s from this {@link ConnectionConfig}. + * + * @return the call options. + */ + public CallOption toCallOption() { + final CallHeaders headers = new FlightCallHeaders(); + Map headerAttributes = getHeaderAttributes(); + headerAttributes.forEach(headers::insert); + return new HeaderCallOption(headers); + } + + /** + * Gets which properties should be added as headers. + * + * @return {@link Map} + */ + public Map getHeaderAttributes() { + Map headers = new HashMap<>(); + ArrowFlightConnectionProperty[] builtInProperties = ArrowFlightConnectionProperty.values(); + properties.forEach( + (key, val) -> { + // For built-in properties before adding new headers + if (Arrays.stream(builtInProperties) + .noneMatch(builtInProperty -> builtInProperty.camelName.equalsIgnoreCase(key.toString()))) { + headers.put(key.toString(), val.toString()); + } + }); + return headers; + } + + /** + * Custom {@link ConnectionProperty} for the {@link ArrowFlightConnectionConfigImpl}. + */ + public enum ArrowFlightConnectionProperty implements ConnectionProperty { + HOST("host", null, Type.STRING, true), + PORT("port", null, Type.NUMBER, true), + USER("user", null, Type.STRING, false), + PASSWORD("password", null, Type.STRING, false), + USE_ENCRYPTION("useEncryption", true, Type.BOOLEAN, false), + CERTIFICATE_VERIFICATION("disableCertificateVerification", false, Type.BOOLEAN, false), + TRUST_STORE("trustStore", null, Type.STRING, false), + TRUST_STORE_PASSWORD("trustStorePassword", null, Type.STRING, false), + USE_SYSTEM_TRUST_STORE("useSystemTrustStore", true, Type.BOOLEAN, false), + THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false), + TOKEN("token", null, Type.STRING, false); + + private final String camelName; + private final Object defaultValue; + private final Type type; + private final boolean required; + + ArrowFlightConnectionProperty(final String camelName, final Object defaultValue, + final Type type, final boolean required) { + this.camelName = Preconditions.checkNotNull(camelName); + this.defaultValue = defaultValue; + this.type = Preconditions.checkNotNull(type); + this.required = required; + } + + /** + * Gets the property. + * + * @param properties the properties from which to fetch this property. + * @return the property. + */ + public Object get(final Properties properties) { + Preconditions.checkNotNull(properties, "Properties cannot be null."); + Object value = properties.get(camelName); + if (value == null) { + value = properties.get(camelName.toLowerCase()); + } + if (required) { + if (value == null) { + throw new IllegalStateException(String.format("Required property not provided: <%s>.", this)); + } + return value; + } else { + return value != null ? value : defaultValue; + } + } + + /** + * Gets the property as Boolean. + * + * @param properties the properties from which to fetch this property. + * @return the property. + */ + public Boolean getBoolean(final Properties properties) { + final String valueFromProperties = String.valueOf(get(properties)); + return valueFromProperties.equals("1") || valueFromProperties.equals("true"); + } + + /** + * Gets the property as Integer. + * + * @param properties the properties from which to fetch this property. + * @return the property. + */ + public Integer getInteger(final Properties properties) { + final String valueFromProperties = String.valueOf(get(properties)); + return valueFromProperties.equals("null") ? null : Integer.parseInt(valueFromProperties); + } + + /** + * Gets the property as String. + * + * @param properties the properties from which to fetch this property. + * @return the property. + */ + public String getString(final Properties properties) { + return Objects.toString(get(properties), null); + } + + @Override + public String camelName() { + return camelName; + } + + @Override + public Object defaultValue() { + return defaultValue; + } + + @Override + public Type type() { + return type; + } + + @Override + public PropEnv wrap(final Properties properties) { + throw new UnsupportedOperationException("Operation unsupported."); + } + + @Override + public boolean required() { + return required; + } + + @Override + public Class valueClass() { + return type.defaultValueClass(); + } + + /** + * Replaces the semicolons in the URL to the proper format. + * + * @param url the current connection string + * @return the formatted url + */ + public static String replaceSemiColons(String url) { + if (url != null) { + url = url.replaceFirst(";", "?"); + url = url.replaceAll(";", "&"); + } + return url; + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java new file mode 100644 index 0000000000000..5ee43ce012e94 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static com.google.common.base.Preconditions.checkNotNull; + +import java.sql.Array; +import java.sql.Blob; +import java.sql.CallableStatement; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.SQLClientInfoException; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Savepoint; +import java.sql.Statement; +import java.sql.Struct; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.Executor; + +import org.apache.arrow.driver.jdbc.ArrowFlightJdbcPooledConnection; + +/** + * Auxiliary wrapper class for {@link Connection}, used on {@link ArrowFlightJdbcPooledConnection}. + */ +public class ConnectionWrapper implements Connection { + private final Connection realConnection; + + public ConnectionWrapper(final Connection connection) { + realConnection = checkNotNull(connection); + } + + @Override + public T unwrap(final Class type) { + return type.cast(realConnection); + } + + @Override + public boolean isWrapperFor(final Class type) { + return realConnection.getClass().isAssignableFrom(type); + } + + @Override + public Statement createStatement() throws SQLException { + return realConnection.createStatement(); + } + + @Override + public PreparedStatement prepareStatement(final String sqlQuery) throws SQLException { + return realConnection.prepareStatement(sqlQuery); + } + + @Override + public CallableStatement prepareCall(final String sqlQuery) throws SQLException { + return realConnection.prepareCall(sqlQuery); + } + + @Override + public String nativeSQL(final String sqlStatement) throws SQLException { + return realConnection.nativeSQL(sqlStatement); + } + + @Override + public void setAutoCommit(boolean autoCommit) throws SQLException { + realConnection.setAutoCommit(autoCommit); + } + + @Override + public boolean getAutoCommit() throws SQLException { + return realConnection.getAutoCommit(); + } + + @Override + public void commit() throws SQLException { + realConnection.commit(); + } + + @Override + public void rollback() throws SQLException { + realConnection.rollback(); + } + + @Override + public void close() throws SQLException { + realConnection.close(); + } + + @Override + public boolean isClosed() throws SQLException { + return realConnection.isClosed(); + } + + @Override + public DatabaseMetaData getMetaData() throws SQLException { + return realConnection.getMetaData(); + } + + @Override + public void setReadOnly(final boolean readOnly) throws SQLException { + realConnection.setReadOnly(readOnly); + } + + @Override + public boolean isReadOnly() throws SQLException { + return realConnection.isReadOnly(); + } + + @Override + public void setCatalog(final String catalogName) throws SQLException { + realConnection.setCatalog(catalogName); + } + + @Override + public String getCatalog() throws SQLException { + return realConnection.getCatalog(); + } + + @Override + public void setTransactionIsolation(final int transactionIsolationId) throws SQLException { + realConnection.setTransactionIsolation(transactionIsolationId); + } + + @Override + public int getTransactionIsolation() throws SQLException { + return realConnection.getTransactionIsolation(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return realConnection.getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + realConnection.clearWarnings(); + } + + @Override + public Statement createStatement(final int resultSetTypeId, final int resultSetConcurrencyId) + throws SQLException { + return realConnection.createStatement(resultSetTypeId, resultSetConcurrencyId); + } + + @Override + public PreparedStatement prepareStatement(final String sqlQuery, final int resultSetTypeId, + final int resultSetConcurrencyId) + throws SQLException { + return realConnection.prepareStatement(sqlQuery, resultSetTypeId, resultSetConcurrencyId); + } + + @Override + public CallableStatement prepareCall(final String query, final int resultSetTypeId, + final int resultSetConcurrencyId) + throws SQLException { + return realConnection.prepareCall(query, resultSetTypeId, resultSetConcurrencyId); + } + + @Override + public Map> getTypeMap() throws SQLException { + return realConnection.getTypeMap(); + } + + @Override + public void setTypeMap(final Map> typeNameToClass) throws SQLException { + realConnection.setTypeMap(typeNameToClass); + } + + @Override + public void setHoldability(final int holdabilityId) throws SQLException { + realConnection.setHoldability(holdabilityId); + } + + @Override + public int getHoldability() throws SQLException { + return realConnection.getHoldability(); + } + + @Override + public Savepoint setSavepoint() throws SQLException { + return realConnection.setSavepoint(); + } + + @Override + public Savepoint setSavepoint(final String savepointName) throws SQLException { + return realConnection.setSavepoint(savepointName); + } + + @Override + public void rollback(final Savepoint savepoint) throws SQLException { + realConnection.rollback(savepoint); + } + + @Override + public void releaseSavepoint(final Savepoint savepoint) throws SQLException { + realConnection.releaseSavepoint(savepoint); + } + + @Override + public Statement createStatement(final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) throws SQLException { + return realConnection.createStatement(resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + @Override + public PreparedStatement prepareStatement(final String sqlQuery, + final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) throws SQLException { + return realConnection.prepareStatement(sqlQuery, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + @Override + public CallableStatement prepareCall(final String sqlQuery, + final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) throws SQLException { + return realConnection.prepareCall(sqlQuery, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + @Override + public PreparedStatement prepareStatement(final String sqlQuery, final int autoGeneratedKeysId) + throws SQLException { + return realConnection.prepareStatement(sqlQuery, autoGeneratedKeysId); + } + + @Override + public PreparedStatement prepareStatement(final String sqlQuery, final int... columnIndices) + throws SQLException { + return realConnection.prepareStatement(sqlQuery, columnIndices); + } + + @Override + public PreparedStatement prepareStatement(final String sqlQuery, final String... columnNames) + throws SQLException { + return realConnection.prepareStatement(sqlQuery, columnNames); + } + + @Override + public Clob createClob() throws SQLException { + return realConnection.createClob(); + } + + @Override + public Blob createBlob() throws SQLException { + return realConnection.createBlob(); + } + + @Override + public NClob createNClob() throws SQLException { + return realConnection.createNClob(); + } + + @Override + public SQLXML createSQLXML() throws SQLException { + return realConnection.createSQLXML(); + } + + @Override + public boolean isValid(final int timeout) throws SQLException { + return realConnection.isValid(timeout); + } + + @Override + public void setClientInfo(final String propertyName, final String propertyValue) + throws SQLClientInfoException { + realConnection.setClientInfo(propertyName, propertyValue); + } + + @Override + public void setClientInfo(final Properties properties) throws SQLClientInfoException { + realConnection.setClientInfo(properties); + } + + @Override + public String getClientInfo(final String propertyName) throws SQLException { + return realConnection.getClientInfo(propertyName); + } + + @Override + public Properties getClientInfo() throws SQLException { + return realConnection.getClientInfo(); + } + + @Override + public Array createArrayOf(final String typeName, final Object... elements) throws SQLException { + return realConnection.createArrayOf(typeName, elements); + } + + @Override + public Struct createStruct(final String typeName, final Object... attributes) + throws SQLException { + return realConnection.createStruct(typeName, attributes); + } + + @Override + public void setSchema(final String schemaName) throws SQLException { + realConnection.setSchema(schemaName); + } + + @Override + public String getSchema() throws SQLException { + return realConnection.getSchema(); + } + + @Override + public void abort(final Executor executor) throws SQLException { + realConnection.abort(executor); + } + + @Override + public void setNetworkTimeout(final Executor executor, final int timeoutInMillis) + throws SQLException { + realConnection.setNetworkTimeout(executor, timeoutInMillis); + } + + @Override + public int getNetworkTimeout() throws SQLException { + return realConnection.getNetworkTimeout(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java new file mode 100644 index 0000000000000..324f991ef09e9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.proto.Common; +import org.apache.calcite.avatica.proto.Common.ColumnMetaData.Builder; + +/** + * Convert Fields To Column MetaData List functions. + */ +public final class ConvertUtils { + + private ConvertUtils() { + } + + /** + * Convert Fields To Column MetaData List functions. + * + * @param fields list of {@link Field}. + * @return list of {@link ColumnMetaData}. + */ + public static List convertArrowFieldsToColumnMetaDataList(final List fields) { + return Stream.iterate(0, Math::incrementExact).limit(fields.size()) + .map(index -> { + final Field field = fields.get(index); + final ArrowType fieldType = field.getType(); + + final Builder builder = Common.ColumnMetaData.newBuilder() + .setOrdinal(index) + .setColumnName(field.getName()) + .setLabel(field.getName()); + + setOnColumnMetaDataBuilder(builder, field.getMetadata()); + + builder.setType(Common.AvaticaType.newBuilder() + .setId(SqlTypes.getSqlTypeIdFromArrowType(fieldType)) + .setName(SqlTypes.getSqlTypeNameFromArrowType(fieldType)) + .build()); + + return ColumnMetaData.fromProto(builder.build()); + }).collect(Collectors.toList()); + } + + /** + * Set on Column MetaData Builder. + * + * @param builder {@link Builder} + * @param metadataMap {@link Map} + */ + public static void setOnColumnMetaDataBuilder(final Builder builder, + final Map metadataMap) { + final FlightSqlColumnMetadata columnMetadata = new FlightSqlColumnMetadata(metadataMap); + final String catalogName = columnMetadata.getCatalogName(); + if (catalogName != null) { + builder.setCatalogName(catalogName); + } + final String schemaName = columnMetadata.getSchemaName(); + if (schemaName != null) { + builder.setSchemaName(schemaName); + } + final String tableName = columnMetadata.getTableName(); + if (tableName != null) { + builder.setTableName(tableName); + } + + final Integer precision = columnMetadata.getPrecision(); + if (precision != null) { + builder.setPrecision(precision); + } + final Integer scale = columnMetadata.getScale(); + if (scale != null) { + builder.setScale(scale); + } + + final Boolean isAutoIncrement = columnMetadata.isAutoIncrement(); + if (isAutoIncrement != null) { + builder.setAutoIncrement(isAutoIncrement); + } + final Boolean caseSensitive = columnMetadata.isCaseSensitive(); + if (caseSensitive != null) { + builder.setCaseSensitive(caseSensitive); + } + final Boolean readOnly = columnMetadata.isReadOnly(); + if (readOnly != null) { + builder.setReadOnly(readOnly); + } + final Boolean searchable = columnMetadata.isSearchable(); + if (searchable != null) { + builder.setSearchable(searchable); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtils.java new file mode 100644 index 0000000000000..dd94a09256dd5 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtils.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY; + +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.Calendar; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; + +/** + * Datetime utility functions. + */ +public class DateTimeUtils { + private DateTimeUtils() { + // Prevent instantiation. + } + + /** + * Subtracts given Calendar's TimeZone offset from epoch milliseconds. + */ + public static long applyCalendarOffset(long milliseconds, Calendar calendar) { + if (calendar == null) { + calendar = Calendar.getInstance(); + } + + final TimeZone tz = calendar.getTimeZone(); + final TimeZone defaultTz = TimeZone.getDefault(); + + if (tz != defaultTz) { + milliseconds -= tz.getOffset(milliseconds) - defaultTz.getOffset(milliseconds); + } + + return milliseconds; + } + + + /** + * Converts Epoch millis to a {@link Timestamp} object. + * + * @param millisWithCalendar the Timestamp in Epoch millis + * @return a {@link Timestamp} object representing the given Epoch millis + */ + public static Timestamp getTimestampValue(long millisWithCalendar) { + long milliseconds = millisWithCalendar; + if (milliseconds < 0) { + // LocalTime#ofNanoDay only accepts positive values + milliseconds -= ((milliseconds / MILLIS_PER_DAY) - 1) * MILLIS_PER_DAY; + } + + return Timestamp.valueOf( + LocalDateTime.of( + LocalDate.ofEpochDay(millisWithCalendar / MILLIS_PER_DAY), + LocalTime.ofNanoOfDay(TimeUnit.MILLISECONDS.toNanos(milliseconds % MILLIS_PER_DAY))) + ); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java new file mode 100644 index 0000000000000..e1d770800e40c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static java.lang.String.format; +import static java.util.Collections.synchronizedSet; +import static org.apache.arrow.util.Preconditions.checkNotNull; +import static org.apache.arrow.util.Preconditions.checkState; + +import java.sql.SQLException; +import java.sql.SQLTimeoutException; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.calcite.avatica.AvaticaConnection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Auxiliary class used to handle consuming of multiple {@link FlightStream}. + *

+ * The usage follows this routine: + *

    + *
  1. Create a FlightStreamQueue;
  2. + *
  3. Call enqueue(FlightStream) for all streams to be consumed;
  4. + *
  5. Call next() to get a FlightStream that is ready to consume
  6. + *
  7. Consume the given FlightStream and add it back to the queue - call enqueue(FlightStream)
  8. + *
  9. Repeat from (3) until next() returns null.
  10. + *
+ */ +public class FlightStreamQueue implements AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(FlightStreamQueue.class); + private final CompletionService completionService; + private final Set> futures = synchronizedSet(new HashSet<>()); + private final Set allStreams = synchronizedSet(new HashSet<>()); + private final AtomicBoolean closed = new AtomicBoolean(); + + /** + * Instantiate a new FlightStreamQueue. + */ + protected FlightStreamQueue(final CompletionService executorService) { + completionService = checkNotNull(executorService); + } + + /** + * Creates a new {@link FlightStreamQueue} from the provided {@link ExecutorService}. + * + * @param service the service from which to create a new queue. + * @return a new queue. + */ + public static FlightStreamQueue createNewQueue(final ExecutorService service) { + return new FlightStreamQueue(new ExecutorCompletionService<>(service)); + } + + /** + * Gets whether this queue is closed. + * + * @return a boolean indicating whether this resource is closed. + */ + public boolean isClosed() { + return closed.get(); + } + + /** + * Auxiliary functional interface for getting ready-to-consume FlightStreams. + */ + @FunctionalInterface + interface FlightStreamSupplier { + Future get() throws SQLException; + } + + private FlightStream next(final FlightStreamSupplier flightStreamSupplier) throws SQLException { + checkOpen(); + while (!futures.isEmpty()) { + final Future future = flightStreamSupplier.get(); + futures.remove(future); + try { + final FlightStream stream = future.get(); + if (stream.getRoot().getRowCount() > 0) { + return stream; + } + } catch (final ExecutionException | InterruptedException | CancellationException e) { + throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); + } + } + return null; + } + + /** + * Blocking request with timeout to get the next ready FlightStream in queue. + * + * @param timeoutValue the amount of time to be waited + * @param timeoutUnit the timeoutValue time unit + * @return a FlightStream that is ready to consume or null if all FlightStreams are ended. + */ + public FlightStream next(final long timeoutValue, final TimeUnit timeoutUnit) + throws SQLException { + return next(() -> { + try { + final Future future = completionService.poll(timeoutValue, timeoutUnit); + if (future != null) { + return future; + } + } catch (final InterruptedException e) { + throw new SQLTimeoutException("Query was interrupted", e); + } + + throw new SQLTimeoutException( + String.format("Query timed out after %d %s", timeoutValue, timeoutUnit)); + }); + } + + /** + * Blocking request to get the next ready FlightStream in queue. + * + * @return a FlightStream that is ready to consume or null if all FlightStreams are ended. + */ + public FlightStream next() throws SQLException { + return next(() -> { + try { + return completionService.take(); + } catch (final InterruptedException e) { + throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); + } + }); + } + + /** + * Checks if this queue is open. + */ + public synchronized void checkOpen() { + checkState(!isClosed(), format("%s closed", this.getClass().getSimpleName())); + } + + /** + * Readily adds given {@link FlightStream}s to the queue. + */ + public void enqueue(final Collection flightStreams) { + flightStreams.forEach(this::enqueue); + } + + /** + * Adds given {@link FlightStream} to the queue. + */ + public synchronized void enqueue(final FlightStream flightStream) { + checkNotNull(flightStream); + checkOpen(); + allStreams.add(flightStream); + futures.add(completionService.submit(() -> { + // `FlightStream#next` will block until new data can be read or stream is over. + flightStream.next(); + return flightStream; + })); + } + + private static boolean isCallStatusCancelled(final Exception e) { + return e.getCause() instanceof FlightRuntimeException && + ((FlightRuntimeException) e.getCause()).status().code() == CallStatus.CANCELLED.code(); + } + + @Override + public synchronized void close() throws SQLException { + final Set exceptions = new HashSet<>(); + if (isClosed()) { + return; + } + try { + for (final FlightStream flightStream : allStreams) { + try { + flightStream.cancel("Cancelling this FlightStream.", null); + } catch (final Exception e) { + final String errorMsg = "Failed to cancel a FlightStream."; + LOGGER.error(errorMsg, e); + exceptions.add(new SQLException(errorMsg, e)); + } + } + futures.forEach(future -> { + try { + // TODO: Consider adding a hardcoded timeout? + future.get(); + } catch (final InterruptedException | ExecutionException e) { + // Ignore if future is already cancelled + if (!isCallStatusCancelled(e)) { + final String errorMsg = "Failed consuming a future during close."; + LOGGER.error(errorMsg, e); + exceptions.add(new SQLException(errorMsg, e)); + } + } + }); + for (final FlightStream flightStream : allStreams) { + try { + flightStream.close(); + } catch (final Exception e) { + final String errorMsg = "Failed to close a FlightStream."; + LOGGER.error(errorMsg, e); + exceptions.add(new SQLException(errorMsg, e)); + } + } + } finally { + allStreams.clear(); + futures.clear(); + closed.set(true); + } + if (!exceptions.isEmpty()) { + final SQLException sqlException = new SQLException("Failed to close streams."); + exceptions.forEach(sqlException::setNextException); + throw sqlException; + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/IntervalStringUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/IntervalStringUtils.java new file mode 100644 index 0000000000000..05643274ac348 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/IntervalStringUtils.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import org.apache.arrow.vector.util.DateUtility; +import org.joda.time.Period; + +/** + * Utility class to format periods similar to Oracle's representation + * of "INTERVAL * to *" data type. + */ +public final class IntervalStringUtils { + + /** + * Constructor Method of class. + */ + private IntervalStringUtils( ) {} + + /** + * Formats a period similar to Oracle INTERVAL YEAR TO MONTH data type
. + * For example, the string "+21-02" defines an interval of 21 years and 2 months. + */ + public static String formatIntervalYear(final Period p) { + long months = p.getYears() * (long) DateUtility.yearsToMonths + p.getMonths(); + boolean neg = false; + if (months < 0) { + months = -months; + neg = true; + } + final int years = (int) (months / DateUtility.yearsToMonths); + months = months % DateUtility.yearsToMonths; + + return String.format("%c%03d-%02d", neg ? '-' : '+', years, months); + } + + /** + * Formats a period similar to Oracle INTERVAL DAY TO SECOND data type.
. + * For example, the string "-001 18:25:16.766" defines an interval of + * - 1 day 18 hours 25 minutes 16 seconds and 766 milliseconds. + */ + public static String formatIntervalDay(final Period p) { + long millis = p.getDays() * (long) DateUtility.daysToStandardMillis + millisFromPeriod(p); + + boolean neg = false; + if (millis < 0) { + millis = -millis; + neg = true; + } + + final int days = (int) (millis / DateUtility.daysToStandardMillis); + millis = millis % DateUtility.daysToStandardMillis; + + final int hours = (int) (millis / DateUtility.hoursToMillis); + millis = millis % DateUtility.hoursToMillis; + + final int minutes = (int) (millis / DateUtility.minutesToMillis); + millis = millis % DateUtility.minutesToMillis; + + final int seconds = (int) (millis / DateUtility.secondsToMillis); + millis = millis % DateUtility.secondsToMillis; + + return String.format("%c%03d %02d:%02d:%02d.%03d", neg ? '-' : '+', days, hours, minutes, seconds, millis); + } + + public static int millisFromPeriod(Period period) { + return period.getHours() * DateUtility.hoursToMillis + period.getMinutes() * DateUtility.minutesToMillis + + period.getSeconds() * DateUtility.secondsToMillis + period.getMillis(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java new file mode 100644 index 0000000000000..85c3964303c45 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.sql.Types; +import java.util.HashMap; +import java.util.Map; + +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; + +/** + * SQL Types utility functions. + */ +public class SqlTypes { + private static final Map typeIdToName = new HashMap<>(); + + static { + typeIdToName.put(Types.BIT, "BIT"); + typeIdToName.put(Types.TINYINT, "TINYINT"); + typeIdToName.put(Types.SMALLINT, "SMALLINT"); + typeIdToName.put(Types.INTEGER, "INTEGER"); + typeIdToName.put(Types.BIGINT, "BIGINT"); + typeIdToName.put(Types.FLOAT, "FLOAT"); + typeIdToName.put(Types.REAL, "REAL"); + typeIdToName.put(Types.DOUBLE, "DOUBLE"); + typeIdToName.put(Types.NUMERIC, "NUMERIC"); + typeIdToName.put(Types.DECIMAL, "DECIMAL"); + typeIdToName.put(Types.CHAR, "CHAR"); + typeIdToName.put(Types.VARCHAR, "VARCHAR"); + typeIdToName.put(Types.LONGVARCHAR, "LONGVARCHAR"); + typeIdToName.put(Types.DATE, "DATE"); + typeIdToName.put(Types.TIME, "TIME"); + typeIdToName.put(Types.TIMESTAMP, "TIMESTAMP"); + typeIdToName.put(Types.BINARY, "BINARY"); + typeIdToName.put(Types.VARBINARY, "VARBINARY"); + typeIdToName.put(Types.LONGVARBINARY, "LONGVARBINARY"); + typeIdToName.put(Types.NULL, "NULL"); + typeIdToName.put(Types.OTHER, "OTHER"); + typeIdToName.put(Types.JAVA_OBJECT, "JAVA_OBJECT"); + typeIdToName.put(Types.DISTINCT, "DISTINCT"); + typeIdToName.put(Types.STRUCT, "STRUCT"); + typeIdToName.put(Types.ARRAY, "ARRAY"); + typeIdToName.put(Types.BLOB, "BLOB"); + typeIdToName.put(Types.CLOB, "CLOB"); + typeIdToName.put(Types.REF, "REF"); + typeIdToName.put(Types.DATALINK, "DATALINK"); + typeIdToName.put(Types.BOOLEAN, "BOOLEAN"); + typeIdToName.put(Types.ROWID, "ROWID"); + typeIdToName.put(Types.NCHAR, "NCHAR"); + typeIdToName.put(Types.NVARCHAR, "NVARCHAR"); + typeIdToName.put(Types.LONGNVARCHAR, "LONGNVARCHAR"); + typeIdToName.put(Types.NCLOB, "NCLOB"); + typeIdToName.put(Types.SQLXML, "SQLXML"); + typeIdToName.put(Types.REF_CURSOR, "REF_CURSOR"); + typeIdToName.put(Types.TIME_WITH_TIMEZONE, "TIME_WITH_TIMEZONE"); + typeIdToName.put(Types.TIMESTAMP_WITH_TIMEZONE, "TIMESTAMP_WITH_TIMEZONE"); + } + + /** + * Convert given {@link ArrowType} to its corresponding SQL type name. + * + * @param arrowType type to convert from + * @return corresponding SQL type name. + * @see java.sql.Types + */ + public static String getSqlTypeNameFromArrowType(ArrowType arrowType) { + final int typeId = getSqlTypeIdFromArrowType(arrowType); + return typeIdToName.get(typeId); + } + + + /** + * Convert given {@link ArrowType} to its corresponding SQL type ID. + * + * @param arrowType type to convert from + * @return corresponding SQL type ID. + * @see java.sql.Types + */ + public static int getSqlTypeIdFromArrowType(ArrowType arrowType) { + final ArrowType.ArrowTypeID typeID = arrowType.getTypeID(); + switch (typeID) { + case Int: + final int bitWidth = ((ArrowType.Int) arrowType).getBitWidth(); + switch (bitWidth) { + case 8: + return Types.TINYINT; + case 16: + return Types.SMALLINT; + case 32: + return Types.INTEGER; + case 64: + return Types.BIGINT; + default: + break; + } + break; + case Binary: + return Types.VARBINARY; + case FixedSizeBinary: + return Types.BINARY; + case LargeBinary: + return Types.LONGVARBINARY; + case Utf8: + return Types.VARCHAR; + case LargeUtf8: + return Types.LONGVARCHAR; + case Date: + return Types.DATE; + case Time: + return Types.TIME; + case Timestamp: + return Types.TIMESTAMP; + case Bool: + return Types.BOOLEAN; + case Decimal: + return Types.DECIMAL; + case FloatingPoint: + final FloatingPointPrecision floatingPointPrecision = + ((ArrowType.FloatingPoint) arrowType).getPrecision(); + switch (floatingPointPrecision) { + case DOUBLE: + return Types.DOUBLE; + case SINGLE: + return Types.FLOAT; + default: + break; + } + break; + case List: + case FixedSizeList: + case LargeList: + return Types.ARRAY; + case Struct: + case Duration: + case Interval: + case Map: + case Union: + return Types.JAVA_OBJECT; + case NONE: + case Null: + return Types.NULL; + default: + break; + } + + throw new IllegalArgumentException("Unsupported ArrowType " + arrowType); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java new file mode 100644 index 0000000000000..e52251f53918a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.util.HashMap; +import java.util.Map; + +/** + * URL Parser for extracting key values from a connection string. + */ +public final class UrlParser { + private UrlParser() { + } + + /** + * Parse URL key value parameters. + * + *

URL-decodes keys and values. + * + * @param url {@link String} + * @return {@link Map} + */ + public static Map parse(String url, String separator) { + Map resultMap = new HashMap<>(); + if (url != null) { + String[] keyValues = url.split(separator); + + for (String keyValue : keyValues) { + try { + int separatorKey = keyValue.indexOf("="); // Find the first equal sign to split key and value. + if (separatorKey != -1) { // Avoid crashes when not finding an equal sign in the property value. + String key = keyValue.substring(0, separatorKey); + key = URLDecoder.decode(key, "UTF-8"); + String value = ""; + if (!keyValue.endsWith("=")) { // Avoid crashes for empty values. + value = keyValue.substring(separatorKey + 1); + } + value = URLDecoder.decode(value, "UTF-8"); + resultMap.put(key, value); + } + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + } + return resultMap; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformer.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformer.java new file mode 100644 index 0000000000000..3bab918c83aab --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformer.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Converts Arrow's {@link VectorSchemaRoot} format to one JDBC would expect. + */ +@FunctionalInterface +public interface VectorSchemaRootTransformer { + VectorSchemaRoot transform(VectorSchemaRoot originalRoot, VectorSchemaRoot transformedRoot) + throws Exception; + + /** + * Transformer's helper class; builds a new {@link VectorSchemaRoot}. + */ + class Builder { + + private final Schema schema; + private final BufferAllocator bufferAllocator; + private final List newFields = new ArrayList<>(); + private final Collection tasks = new ArrayList<>(); + + public Builder(final Schema schema, final BufferAllocator bufferAllocator) { + this.schema = schema; + this.bufferAllocator = bufferAllocator + .newChildAllocator("VectorSchemaRootTransformer", 0, bufferAllocator.getLimit()); + } + + /** + * Add task to transform a vector to a new vector renaming it. + * This also adds transformedVectorName to the transformed {@link VectorSchemaRoot} schema. + * + * @param originalVectorName Name of the original vector to be transformed. + * @param transformedVectorName Name of the vector that is the result of the transformation. + * @return a VectorSchemaRoot instance with a task to rename a field vector. + */ + public Builder renameFieldVector(final String originalVectorName, + final String transformedVectorName) { + tasks.add((originalRoot, transformedRoot) -> { + final FieldVector originalVector = originalRoot.getVector(originalVectorName); + final FieldVector transformedVector = transformedRoot.getVector(transformedVectorName); + + final ArrowType originalType = originalVector.getField().getType(); + final ArrowType transformedType = transformedVector.getField().getType(); + if (!originalType.equals(transformedType)) { + throw new IllegalArgumentException(String.format( + "Can not transfer vector with field type %s to %s", originalType, transformedType)); + } + + if (originalVector instanceof BaseVariableWidthVector) { + ((BaseVariableWidthVector) originalVector).transferTo( + ((BaseVariableWidthVector) transformedVector)); + } else if (originalVector instanceof BaseFixedWidthVector) { + ((BaseFixedWidthVector) originalVector).transferTo( + ((BaseFixedWidthVector) transformedVector)); + } else { + throw new IllegalStateException(String.format( + "Can not transfer vector of type %s", originalVector.getClass())); + } + }); + + final Field originalField = schema.findField(originalVectorName); + newFields.add(new Field( + transformedVectorName, + new FieldType(originalField.isNullable(), originalField.getType(), + originalField.getDictionary(), originalField.getMetadata()), + originalField.getChildren()) + ); + + return this; + } + + /** + * Adds an empty field to the transformed {@link VectorSchemaRoot} schema. + * + * @param fieldName Name of the field to be added. + * @param fieldType Type of the field to be added. + * @return a VectorSchemaRoot instance with the current tasks. + */ + public Builder addEmptyField(final String fieldName, final Types.MinorType fieldType) { + newFields.add(Field.nullable(fieldName, fieldType.getType())); + + return this; + } + + /** + * Adds an empty field to the transformed {@link VectorSchemaRoot} schema. + * + * @param fieldName Name of the field to be added. + * @param fieldType Type of the field to be added. + * @return a VectorSchemaRoot instance with the current tasks. + */ + public Builder addEmptyField(final String fieldName, final ArrowType fieldType) { + newFields.add(Field.nullable(fieldName, fieldType)); + + return this; + } + + public VectorSchemaRootTransformer build() { + return (originalRoot, transformedRoot) -> { + if (transformedRoot == null) { + transformedRoot = VectorSchemaRoot.create(new Schema(newFields), bufferAllocator); + } + + for (final Task task : tasks) { + task.run(originalRoot, transformedRoot); + } + + transformedRoot.setRowCount(originalRoot.getRowCount()); + + originalRoot.clear(); + return transformedRoot; + }; + } + + /** + * Functional interface used to a task to transform a VectorSchemaRoot into a new VectorSchemaRoot. + */ + @FunctionalInterface + interface Task { + void run(VectorSchemaRoot originalRoot, VectorSchemaRoot transformedRoot); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/resources/META-INF/services/java.sql.Driver b/java/flight/flight-sql-jdbc-driver/src/main/resources/META-INF/services/java.sql.Driver new file mode 100644 index 0000000000000..83cfb23427f71 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/resources/META-INF/services/java.sql.Driver @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +org.apache.arrow.driver.jdbc.ArrowFlightJdbcDriver \ No newline at end of file diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadataTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadataTest.java new file mode 100644 index 0000000000000..0d930f4c44e1f --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadataTest.java @@ -0,0 +1,1423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static com.google.protobuf.ByteString.copyFrom; +import static java.lang.String.format; +import static java.sql.Types.BIGINT; +import static java.sql.Types.BIT; +import static java.sql.Types.INTEGER; +import static java.sql.Types.JAVA_OBJECT; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; +import static java.util.stream.IntStream.range; +import static org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer.serializeSchema; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportsConvert.SQL_CONVERT_BIGINT_VALUE; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportsConvert.SQL_CONVERT_BIT_VALUE; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportsConvert.SQL_CONVERT_INTEGER_VALUE; +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; + +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.driver.jdbc.utils.ResultSetTestUtils; +import org.apache.arrow.driver.jdbc.utils.ThrowableAssertionUtils; +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedSubqueries; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Message; + +/** + * Class containing the tests from the {@link ArrowDatabaseMetadata}. + */ +@SuppressWarnings("DoubleBraceInitialization") +public class ArrowDatabaseMetadataTest { + public static final boolean EXPECTED_MAX_ROW_SIZE_INCLUDES_BLOBS = false; + private static final MockFlightSqlProducer FLIGHT_SQL_PRODUCER = new MockFlightSqlProducer(); + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = FlightServerTestRule + .createStandardTestRule(FLIGHT_SQL_PRODUCER); + private static final int ROW_COUNT = 10; + private static final List> EXPECTED_GET_CATALOGS_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> format("catalog #%d", i)) + .map(Object.class::cast) + .map(Collections::singletonList) + .collect(toList()); + private static final List> EXPECTED_GET_TABLE_TYPES_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> format("table_type #%d", i)) + .map(Object.class::cast) + .map(Collections::singletonList) + .collect(toList()); + private static final List> EXPECTED_GET_TABLES_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> new Object[] { + format("catalog_name #%d", i), + format("db_schema_name #%d", i), + format("table_name #%d", i), + format("table_type #%d", i), + // TODO Add these fields to FlightSQL, as it's currently not possible to fetch them. + null, null, null, null, null, null}) + .map(Arrays::asList) + .collect(toList()); + private static final List> EXPECTED_GET_SCHEMAS_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> new Object[] { + format("db_schema_name #%d", i), + format("catalog_name #%d", i)}) + .map(Arrays::asList) + .collect(toList()); + private static final List> EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> new Object[] { + format("pk_catalog_name #%d", i), + format("pk_db_schema_name #%d", i), + format("pk_table_name #%d", i), + format("pk_column_name #%d", i), + format("fk_catalog_name #%d", i), + format("fk_db_schema_name #%d", i), + format("fk_table_name #%d", i), + format("fk_column_name #%d", i), + i, + format("fk_key_name #%d", i), + format("pk_key_name #%d", i), + (byte) i, + (byte) i, + // TODO Add this field to FlightSQL, as it's currently not possible to fetch it. + null}) + .map(Arrays::asList) + .collect(toList()); + private static final List> EXPECTED_CROSS_REFERENCE_RESULTS = + EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS; + private static final List> EXPECTED_PRIMARY_KEYS_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> new Object[] { + format("catalog_name #%d", i), + format("db_schema_name #%d", i), + format("table_name #%d", i), + format("column_name #%d", i), + i, + format("key_name #%d", i)}) + .map(Arrays::asList) + .collect(toList()); + private static final List FIELDS_GET_IMPORTED_EXPORTED_KEYS = ImmutableList.of( + "PKTABLE_CAT", "PKTABLE_SCHEM", "PKTABLE_NAME", + "PKCOLUMN_NAME", "FKTABLE_CAT", "FKTABLE_SCHEM", + "FKTABLE_NAME", "FKCOLUMN_NAME", "KEY_SEQ", + "FK_NAME", "PK_NAME", "UPDATE_RULE", "DELETE_RULE", + "DEFERRABILITY"); + private static final List FIELDS_GET_CROSS_REFERENCE = FIELDS_GET_IMPORTED_EXPORTED_KEYS; + private static final String TARGET_TABLE = "TARGET_TABLE"; + private static final String TARGET_FOREIGN_TABLE = "FOREIGN_TABLE"; + private static final String EXPECTED_DATABASE_PRODUCT_NAME = "Test Server Name"; + private static final String EXPECTED_DATABASE_PRODUCT_VERSION = "v0.0.1-alpha"; + private static final String EXPECTED_IDENTIFIER_QUOTE_STRING = "\""; + private static final boolean EXPECTED_IS_READ_ONLY = true; + private static final String EXPECTED_SQL_KEYWORDS = + "ADD, ADD CONSTRAINT, ALTER, ALTER TABLE, ANY, USER, TABLE"; + private static final String EXPECTED_NUMERIC_FUNCTIONS = + "ABS(), ACOS(), ASIN(), ATAN(), CEIL(), CEILING(), COT()"; + private static final String EXPECTED_STRING_FUNCTIONS = + "ASCII, CHAR, CHARINDEX, CONCAT, CONCAT_WS, FORMAT, LEFT"; + private static final String EXPECTED_SYSTEM_FUNCTIONS = + "CAST, CONVERT, CHOOSE, ISNULL, IS_NUMERIC, IIF, TRY_CAST"; + private static final String EXPECTED_TIME_DATE_FUNCTIONS = + "GETDATE(), DATEPART(), DATEADD(), DATEDIFF()"; + private static final String EXPECTED_SEARCH_STRING_ESCAPE = "\\"; + private static final String EXPECTED_EXTRA_NAME_CHARACTERS = ""; + private static final boolean EXPECTED_SUPPORTS_COLUMN_ALIASING = true; + private static final boolean EXPECTED_NULL_PLUS_NULL_IS_NULL = true; + private static final boolean EXPECTED_SQL_SUPPORTS_CONVERT = true; + private static final boolean EXPECTED_INVALID_SQL_SUPPORTS_CONVERT = false; + private static final boolean EXPECTED_SUPPORTS_TABLE_CORRELATION_NAMES = true; + private static final boolean EXPECTED_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES = false; + private static final boolean EXPECTED_EXPRESSIONS_IN_ORDER_BY = true; + private static final boolean EXPECTED_SUPPORTS_ORDER_BY_UNRELATED = true; + private static final boolean EXPECTED_SUPPORTS_GROUP_BY = true; + private static final boolean EXPECTED_SUPPORTS_GROUP_BY_UNRELATED = true; + private static final boolean EXPECTED_SUPPORTS_LIKE_ESCAPE_CLAUSE = true; + private static final boolean EXPECTED_NON_NULLABLE_COLUMNS = true; + private static final boolean EXPECTED_MINIMUM_SQL_GRAMMAR = true; + private static final boolean EXPECTED_CORE_SQL_GRAMMAR = true; + private static final boolean EXPECTED_EXTEND_SQL_GRAMMAR = false; + private static final boolean EXPECTED_ANSI92_ENTRY_LEVEL_SQL = true; + private static final boolean EXPECTED_ANSI92_INTERMEDIATE_SQL = true; + private static final boolean EXPECTED_ANSI92_FULL_SQL = false; + private static final String EXPECTED_SCHEMA_TERM = "schema"; + private static final String EXPECTED_PROCEDURE_TERM = "procedure"; + private static final String EXPECTED_CATALOG_TERM = "catalog"; + private static final boolean EXPECTED_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY = true; + private static final boolean EXPECTED_SUPPORTS_OUTER_JOINS = true; + private static final boolean EXPECTED_SUPPORTS_FULL_OUTER_JOINS = true; + private static final boolean EXPECTED_SUPPORTS_LIMITED_JOINS = false; + private static final boolean EXPECTED_CATALOG_AT_START = true; + private static final boolean EXPECTED_SCHEMAS_IN_PROCEDURE_CALLS = true; + private static final boolean EXPECTED_SCHEMAS_IN_INDEX_DEFINITIONS = true; + private static final boolean EXPECTED_SCHEMAS_IN_PRIVILEGE_DEFINITIONS = false; + private static final boolean EXPECTED_CATALOGS_IN_INDEX_DEFINITIONS = true; + private static final boolean EXPECTED_CATALOGS_IN_PRIVILEGE_DEFINITIONS = false; + private static final boolean EXPECTED_POSITIONED_DELETE = true; + private static final boolean EXPECTED_POSITIONED_UPDATE = false; + private static final boolean EXPECTED_TYPE_FORWARD_ONLY = true; + private static final boolean EXPECTED_TYPE_SCROLL_INSENSITIVE = true; + private static final boolean EXPECTED_TYPE_SCROLL_SENSITIVE = false; + private static final boolean EXPECTED_SELECT_FOR_UPDATE_SUPPORTED = false; + private static final boolean EXPECTED_STORED_PROCEDURES_SUPPORTED = false; + private static final boolean EXPECTED_SUBQUERIES_IN_COMPARISON = true; + private static final boolean EXPECTED_SUBQUERIES_IN_EXISTS = false; + private static final boolean EXPECTED_SUBQUERIES_IN_INS = false; + private static final boolean EXPECTED_SUBQUERIES_IN_QUANTIFIEDS = false; + private static final SqlSupportedSubqueries[] EXPECTED_SUPPORTED_SUBQUERIES = new SqlSupportedSubqueries[] + {SqlSupportedSubqueries.SQL_SUBQUERIES_IN_COMPARISONS}; + private static final boolean EXPECTED_CORRELATED_SUBQUERIES_SUPPORTED = true; + private static final boolean EXPECTED_SUPPORTS_UNION = true; + private static final boolean EXPECTED_SUPPORTS_UNION_ALL = true; + private static final int EXPECTED_MAX_BINARY_LITERAL_LENGTH = 0; + private static final int EXPECTED_MAX_CHAR_LITERAL_LENGTH = 0; + private static final int EXPECTED_MAX_COLUMN_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_COLUMNS_IN_GROUP_BY = 0; + private static final int EXPECTED_MAX_COLUMNS_IN_INDEX = 0; + private static final int EXPECTED_MAX_COLUMNS_IN_ORDER_BY = 0; + private static final int EXPECTED_MAX_COLUMNS_IN_SELECT = 0; + private static final int EXPECTED_MAX_CONNECTIONS = 0; + private static final int EXPECTED_MAX_CURSOR_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_INDEX_LENGTH = 0; + private static final int EXPECTED_SCHEMA_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_PROCEDURE_NAME_LENGTH = 0; + private static final int EXPECTED_MAX_CATALOG_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_ROW_SIZE = 0; + private static final int EXPECTED_MAX_STATEMENT_LENGTH = 0; + private static final int EXPECTED_MAX_STATEMENTS = 0; + private static final int EXPECTED_MAX_TABLE_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_TABLES_IN_SELECT = 0; + private static final int EXPECTED_MAX_USERNAME_LENGTH = 1024; + private static final int EXPECTED_DEFAULT_TRANSACTION_ISOLATION = 0; + private static final boolean EXPECTED_TRANSACTIONS_SUPPORTED = false; + private static final boolean EXPECTED_TRANSACTION_NONE = false; + private static final boolean EXPECTED_TRANSACTION_READ_UNCOMMITTED = false; + private static final boolean EXPECTED_TRANSACTION_READ_COMMITTED = true; + private static final boolean EXPECTED_TRANSACTION_REPEATABLE_READ = false; + private static final boolean EXPECTED_TRANSACTION_SERIALIZABLE = true; + private static final boolean EXPECTED_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT = true; + private static final boolean EXPECTED_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED = false; + private static final boolean EXPECTED_BATCH_UPDATES_SUPPORTED = true; + private static final boolean EXPECTED_SAVEPOINTS_SUPPORTED = false; + private static final boolean EXPECTED_NAMED_PARAMETERS_SUPPORTED = false; + private static final boolean EXPECTED_LOCATORS_UPDATE_COPY = true; + private static final boolean EXPECTED_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED = false; + private static final List> EXPECTED_GET_COLUMNS_RESULTS; + private static Connection connection; + + static { + List expectedGetColumnsDataTypes = Arrays.asList(3, 93, 4); + List expectedGetColumnsTypeName = Arrays.asList("DECIMAL", "TIMESTAMP", "INTEGER"); + List expectedGetColumnsRadix = Arrays.asList(10, null, 10); + List expectedGetColumnsColumnSize = Arrays.asList(5, 29, 10); + List expectedGetColumnsDecimalDigits = Arrays.asList(2, 9, 0); + List expectedGetColumnsIsNullable = Arrays.asList("YES", "YES", "NO"); + EXPECTED_GET_COLUMNS_RESULTS = range(0, ROW_COUNT * 3) + .mapToObj(i -> new Object[] { + format("catalog_name #%d", i / 3), + format("db_schema_name #%d", i / 3), + format("table_name%d", i / 3), + format("column_%d", (i % 3) + 1), + expectedGetColumnsDataTypes.get(i % 3), + expectedGetColumnsTypeName.get(i % 3), + expectedGetColumnsColumnSize.get(i % 3), + null, + expectedGetColumnsDecimalDigits.get(i % 3), + expectedGetColumnsRadix.get(i % 3), + !Objects.equals(expectedGetColumnsIsNullable.get(i % 3), "NO") ? 1 : 0, + null, null, null, null, null, + (i % 3) + 1, + expectedGetColumnsIsNullable.get(i % 3), + null, null, null, null, + "", ""}) + .map(Arrays::asList) + .collect(toList()); + } + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + public final ResultSetTestUtils resultSetTestUtils = new ResultSetTestUtils(collector); + + @BeforeClass + public static void setUpBeforeClass() throws SQLException { + connection = FLIGHT_SERVER_TEST_RULE.getConnection(false); + + final Message commandGetCatalogs = CommandGetCatalogs.getDefaultInstance(); + final Consumer commandGetCatalogsResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_CATALOGS_SCHEMA, + allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + range(0, ROW_COUNT).forEach( + i -> catalogName.setSafe(i, new Text(format("catalog #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetCatalogs, commandGetCatalogsResultProducer); + + final Message commandGetTableTypes = CommandGetTableTypes.getDefaultInstance(); + final Consumer commandGetTableTypesResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_TABLE_TYPES_SCHEMA, + allocator)) { + final VarCharVector tableType = (VarCharVector) root.getVector("table_type"); + range(0, ROW_COUNT).forEach( + i -> tableType.setSafe(i, new Text(format("table_type #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetTableTypes, commandGetTableTypesResultProducer); + + final Message commandGetTables = CommandGetTables.getDefaultInstance(); + final Consumer commandGetTablesResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create( + Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + final VarCharVector tableName = (VarCharVector) root.getVector("table_name"); + final VarCharVector tableType = (VarCharVector) root.getVector("table_type"); + range(0, ROW_COUNT) + .peek(i -> catalogName.setSafe(i, new Text(format("catalog_name #%d", i)))) + .peek(i -> schemaName.setSafe(i, new Text(format("db_schema_name #%d", i)))) + .peek(i -> tableName.setSafe(i, new Text(format("table_name #%d", i)))) + .forEach(i -> tableType.setSafe(i, new Text(format("table_type #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetTables, commandGetTablesResultProducer); + + final Message commandGetTablesWithSchema = CommandGetTables.newBuilder() + .setIncludeSchema(true) + .build(); + final Consumer commandGetTablesWithSchemaResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_TABLES_SCHEMA, + allocator)) { + final byte[] filledTableSchemaBytes = + copyFrom( + serializeSchema(new Schema(Arrays.asList( + Field.nullable("column_1", ArrowType.Decimal.createDecimal(5, 2, 128)), + Field.nullable("column_2", new ArrowType.Timestamp(TimeUnit.NANOSECOND, "UTC")), + Field.notNullable("column_3", Types.MinorType.INT.getType()))))) + .toByteArray(); + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + final VarCharVector tableName = (VarCharVector) root.getVector("table_name"); + final VarCharVector tableType = (VarCharVector) root.getVector("table_type"); + final VarBinaryVector tableSchema = (VarBinaryVector) root.getVector("table_schema"); + range(0, ROW_COUNT) + .peek(i -> catalogName.setSafe(i, new Text(format("catalog_name #%d", i)))) + .peek(i -> schemaName.setSafe(i, new Text(format("db_schema_name #%d", i)))) + .peek(i -> tableName.setSafe(i, new Text(format("table_name%d", i)))) + .peek(i -> tableType.setSafe(i, new Text(format("table_type #%d", i)))) + .forEach(i -> tableSchema.setSafe(i, filledTableSchemaBytes)); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetTablesWithSchema, + commandGetTablesWithSchemaResultProducer); + + final Message commandGetDbSchemas = CommandGetDbSchemas.getDefaultInstance(); + final Consumer commandGetSchemasResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_SCHEMAS_SCHEMA, + allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + range(0, ROW_COUNT) + .peek(i -> catalogName.setSafe(i, new Text(format("catalog_name #%d", i)))) + .forEach(i -> schemaName.setSafe(i, new Text(format("db_schema_name #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetDbSchemas, commandGetSchemasResultProducer); + + final Message commandGetExportedKeys = + CommandGetExportedKeys.newBuilder().setTable(TARGET_TABLE).build(); + final Message commandGetImportedKeys = + CommandGetImportedKeys.newBuilder().setTable(TARGET_TABLE).build(); + final Message commandGetCrossReference = CommandGetCrossReference.newBuilder() + .setPkTable(TARGET_TABLE) + .setFkTable(TARGET_FOREIGN_TABLE) + .build(); + final Consumer commandGetExportedAndImportedKeysResultProducer = + listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create( + Schemas.GET_IMPORTED_KEYS_SCHEMA, + allocator)) { + final VarCharVector pkCatalogName = (VarCharVector) root.getVector("pk_catalog_name"); + final VarCharVector pkSchemaName = (VarCharVector) root.getVector("pk_db_schema_name"); + final VarCharVector pkTableName = (VarCharVector) root.getVector("pk_table_name"); + final VarCharVector pkColumnName = (VarCharVector) root.getVector("pk_column_name"); + final VarCharVector fkCatalogName = (VarCharVector) root.getVector("fk_catalog_name"); + final VarCharVector fkSchemaName = (VarCharVector) root.getVector("fk_db_schema_name"); + final VarCharVector fkTableName = (VarCharVector) root.getVector("fk_table_name"); + final VarCharVector fkColumnName = (VarCharVector) root.getVector("fk_column_name"); + final IntVector keySequence = (IntVector) root.getVector("key_sequence"); + final VarCharVector fkKeyName = (VarCharVector) root.getVector("fk_key_name"); + final VarCharVector pkKeyName = (VarCharVector) root.getVector("pk_key_name"); + final UInt1Vector updateRule = (UInt1Vector) root.getVector("update_rule"); + final UInt1Vector deleteRule = (UInt1Vector) root.getVector("delete_rule"); + range(0, ROW_COUNT) + .peek(i -> pkCatalogName.setSafe(i, new Text(format("pk_catalog_name #%d", i)))) + .peek(i -> pkSchemaName.setSafe(i, new Text(format("pk_db_schema_name #%d", i)))) + .peek(i -> pkTableName.setSafe(i, new Text(format("pk_table_name #%d", i)))) + .peek(i -> pkColumnName.setSafe(i, new Text(format("pk_column_name #%d", i)))) + .peek(i -> fkCatalogName.setSafe(i, new Text(format("fk_catalog_name #%d", i)))) + .peek(i -> fkSchemaName.setSafe(i, new Text(format("fk_db_schema_name #%d", i)))) + .peek(i -> fkTableName.setSafe(i, new Text(format("fk_table_name #%d", i)))) + .peek(i -> fkColumnName.setSafe(i, new Text(format("fk_column_name #%d", i)))) + .peek(i -> keySequence.setSafe(i, i)) + .peek(i -> fkKeyName.setSafe(i, new Text(format("fk_key_name #%d", i)))) + .peek(i -> pkKeyName.setSafe(i, new Text(format("pk_key_name #%d", i)))) + .peek(i -> updateRule.setSafe(i, i)) + .forEach(i -> deleteRule.setSafe(i, i)); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetExportedKeys, + commandGetExportedAndImportedKeysResultProducer); + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetImportedKeys, + commandGetExportedAndImportedKeysResultProducer); + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetCrossReference, + commandGetExportedAndImportedKeysResultProducer); + + final Message commandGetPrimaryKeys = + CommandGetPrimaryKeys.newBuilder().setTable(TARGET_TABLE).build(); + final Consumer commandGetPrimaryKeysResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_PRIMARY_KEYS_SCHEMA, + allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + final VarCharVector tableName = (VarCharVector) root.getVector("table_name"); + final VarCharVector columnName = (VarCharVector) root.getVector("column_name"); + final IntVector keySequence = (IntVector) root.getVector("key_sequence"); + final VarCharVector keyName = (VarCharVector) root.getVector("key_name"); + range(0, ROW_COUNT) + .peek(i -> catalogName.setSafe(i, new Text(format("catalog_name #%d", i)))) + .peek(i -> schemaName.setSafe(i, new Text(format("db_schema_name #%d", i)))) + .peek(i -> tableName.setSafe(i, new Text(format("table_name #%d", i)))) + .peek(i -> columnName.setSafe(i, new Text(format("column_name #%d", i)))) + .peek(i -> keySequence.setSafe(i, i)) + .forEach(i -> keyName.setSafe(i, new Text(format("key_name #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetPrimaryKeys, commandGetPrimaryKeysResultProducer); + + FLIGHT_SQL_PRODUCER.getSqlInfoBuilder() + .withSqlOuterJoinSupportLevel(FlightSql.SqlOuterJoinsSupportLevel.SQL_FULL_OUTER_JOINS) + .withFlightSqlServerName(EXPECTED_DATABASE_PRODUCT_NAME) + .withFlightSqlServerVersion(EXPECTED_DATABASE_PRODUCT_VERSION) + .withSqlIdentifierQuoteChar(EXPECTED_IDENTIFIER_QUOTE_STRING) + .withFlightSqlServerReadOnly(EXPECTED_IS_READ_ONLY) + .withSqlKeywords(EXPECTED_SQL_KEYWORDS.split("\\s*,\\s*")) + .withSqlNumericFunctions(EXPECTED_NUMERIC_FUNCTIONS.split("\\s*,\\s*")) + .withSqlStringFunctions(EXPECTED_STRING_FUNCTIONS.split("\\s*,\\s*")) + .withSqlSystemFunctions(EXPECTED_SYSTEM_FUNCTIONS.split("\\s*,\\s*")) + .withSqlDatetimeFunctions(EXPECTED_TIME_DATE_FUNCTIONS.split("\\s*,\\s*")) + .withSqlSearchStringEscape(EXPECTED_SEARCH_STRING_ESCAPE) + .withSqlExtraNameCharacters(EXPECTED_EXTRA_NAME_CHARACTERS) + .withSqlSupportsColumnAliasing(EXPECTED_SUPPORTS_COLUMN_ALIASING) + .withSqlNullPlusNullIsNull(EXPECTED_NULL_PLUS_NULL_IS_NULL) + .withSqlSupportsConvert(ImmutableMap.of(SQL_CONVERT_BIT_VALUE, + Arrays.asList(SQL_CONVERT_INTEGER_VALUE, SQL_CONVERT_BIGINT_VALUE))) + .withSqlSupportsTableCorrelationNames(EXPECTED_SUPPORTS_TABLE_CORRELATION_NAMES) + .withSqlSupportsDifferentTableCorrelationNames( + EXPECTED_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES) + .withSqlSupportsExpressionsInOrderBy(EXPECTED_EXPRESSIONS_IN_ORDER_BY) + .withSqlSupportsOrderByUnrelated(EXPECTED_SUPPORTS_ORDER_BY_UNRELATED) + .withSqlSupportedGroupBy(FlightSql.SqlSupportedGroupBy.SQL_GROUP_BY_UNRELATED) + .withSqlSupportsLikeEscapeClause(EXPECTED_SUPPORTS_LIKE_ESCAPE_CLAUSE) + .withSqlSupportsNonNullableColumns(EXPECTED_NON_NULLABLE_COLUMNS) + .withSqlSupportedGrammar(FlightSql.SupportedSqlGrammar.SQL_CORE_GRAMMAR, + FlightSql.SupportedSqlGrammar.SQL_MINIMUM_GRAMMAR) + .withSqlAnsi92SupportedLevel(FlightSql.SupportedAnsi92SqlGrammarLevel.ANSI92_ENTRY_SQL, + FlightSql.SupportedAnsi92SqlGrammarLevel.ANSI92_INTERMEDIATE_SQL) + .withSqlSupportsIntegrityEnhancementFacility( + EXPECTED_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY) + .withSqlSchemaTerm(EXPECTED_SCHEMA_TERM) + .withSqlCatalogTerm(EXPECTED_CATALOG_TERM) + .withSqlProcedureTerm(EXPECTED_PROCEDURE_TERM) + .withSqlCatalogAtStart(EXPECTED_CATALOG_AT_START) + .withSqlSchemasSupportedActions( + FlightSql.SqlSupportedElementActions.SQL_ELEMENT_IN_PROCEDURE_CALLS, + FlightSql.SqlSupportedElementActions.SQL_ELEMENT_IN_INDEX_DEFINITIONS) + .withSqlCatalogsSupportedActions( + FlightSql.SqlSupportedElementActions.SQL_ELEMENT_IN_INDEX_DEFINITIONS) + .withSqlSupportedPositionedCommands( + FlightSql.SqlSupportedPositionedCommands.SQL_POSITIONED_DELETE) + .withSqlSelectForUpdateSupported(EXPECTED_SELECT_FOR_UPDATE_SUPPORTED) + .withSqlStoredProceduresSupported(EXPECTED_STORED_PROCEDURES_SUPPORTED) + .withSqlSubQueriesSupported(EXPECTED_SUPPORTED_SUBQUERIES) + .withSqlCorrelatedSubqueriesSupported(EXPECTED_CORRELATED_SUBQUERIES_SUPPORTED) + .withSqlSupportedUnions(FlightSql.SqlSupportedUnions.SQL_UNION_ALL) + .withSqlMaxBinaryLiteralLength(EXPECTED_MAX_BINARY_LITERAL_LENGTH) + .withSqlMaxCharLiteralLength(EXPECTED_MAX_CHAR_LITERAL_LENGTH) + .withSqlMaxColumnNameLength(EXPECTED_MAX_COLUMN_NAME_LENGTH) + .withSqlMaxColumnsInGroupBy(EXPECTED_MAX_COLUMNS_IN_GROUP_BY) + .withSqlMaxColumnsInIndex(EXPECTED_MAX_COLUMNS_IN_INDEX) + .withSqlMaxColumnsInOrderBy(EXPECTED_MAX_COLUMNS_IN_ORDER_BY) + .withSqlMaxColumnsInSelect(EXPECTED_MAX_COLUMNS_IN_SELECT) + .withSqlMaxConnections(EXPECTED_MAX_CONNECTIONS) + .withSqlMaxCursorNameLength(EXPECTED_MAX_CURSOR_NAME_LENGTH) + .withSqlMaxIndexLength(EXPECTED_MAX_INDEX_LENGTH) + .withSqlDbSchemaNameLength(EXPECTED_SCHEMA_NAME_LENGTH) + .withSqlMaxProcedureNameLength(EXPECTED_MAX_PROCEDURE_NAME_LENGTH) + .withSqlMaxCatalogNameLength(EXPECTED_MAX_CATALOG_NAME_LENGTH) + .withSqlMaxRowSize(EXPECTED_MAX_ROW_SIZE) + .withSqlMaxRowSizeIncludesBlobs(EXPECTED_MAX_ROW_SIZE_INCLUDES_BLOBS) + .withSqlMaxStatementLength(EXPECTED_MAX_STATEMENT_LENGTH) + .withSqlMaxStatements(EXPECTED_MAX_STATEMENTS) + .withSqlMaxTableNameLength(EXPECTED_MAX_TABLE_NAME_LENGTH) + .withSqlMaxTablesInSelect(EXPECTED_MAX_TABLES_IN_SELECT) + .withSqlMaxUsernameLength(EXPECTED_MAX_USERNAME_LENGTH) + .withSqlDefaultTransactionIsolation(EXPECTED_DEFAULT_TRANSACTION_ISOLATION) + .withSqlTransactionsSupported(EXPECTED_TRANSACTIONS_SUPPORTED) + .withSqlSupportedTransactionsIsolationLevels( + FlightSql.SqlTransactionIsolationLevel.SQL_TRANSACTION_SERIALIZABLE, + FlightSql.SqlTransactionIsolationLevel.SQL_TRANSACTION_READ_COMMITTED) + .withSqlDataDefinitionCausesTransactionCommit( + EXPECTED_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT) + .withSqlDataDefinitionsInTransactionsIgnored( + EXPECTED_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED) + .withSqlSupportedResultSetTypes( + FlightSql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY, + FlightSql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE) + .withSqlBatchUpdatesSupported(EXPECTED_BATCH_UPDATES_SUPPORTED) + .withSqlSavepointsSupported(EXPECTED_SAVEPOINTS_SUPPORTED) + .withSqlNamedParametersSupported(EXPECTED_NAMED_PARAMETERS_SUPPORTED) + .withSqlLocatorsUpdateCopy(EXPECTED_LOCATORS_UPDATE_COPY) + .withSqlStoredFunctionsUsingCallSyntaxSupported( + EXPECTED_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED); + } + + @AfterClass + public static void tearDown() throws Exception { + AutoCloseables.close(connection, FLIGHT_SQL_PRODUCER); + } + + + @Test + public void testGetCatalogsCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getCatalogs()) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_CATALOGS_RESULTS); + } + } + + @Test + public void testGetCatalogsCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getCatalogs()) { + resultSetTestUtils.testData(resultSet, singletonList("TABLE_CAT"), + EXPECTED_GET_CATALOGS_RESULTS); + } + } + + @Test + public void testTableTypesCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getTableTypes()) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_TABLE_TYPES_RESULTS); + } + } + + @Test + public void testTableTypesCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getTableTypes()) { + resultSetTestUtils.testData(resultSet, singletonList("TABLE_TYPE"), + EXPECTED_GET_TABLE_TYPES_RESULTS); + } + } + + @Test + public void testGetTablesCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getTables(null, null, null, null)) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_TABLES_RESULTS); + } + } + + @Test + public void testGetTablesCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getTables(null, null, null, null)) { + resultSetTestUtils.testData( + resultSet, + ImmutableList.of( + "TABLE_CAT", + "TABLE_SCHEM", + "TABLE_NAME", + "TABLE_TYPE", + "REMARKS", + "TYPE_CAT", + "TYPE_SCHEM", + "TYPE_NAME", + "SELF_REFERENCING_COL_NAME", + "REF_GENERATION"), + EXPECTED_GET_TABLES_RESULTS + ); + } + } + + @Test + public void testGetSchemasCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getSchemas()) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_SCHEMAS_RESULTS); + } + } + + @Test + public void testGetSchemasCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getSchemas()) { + resultSetTestUtils.testData(resultSet, ImmutableList.of("TABLE_SCHEM", "TABLE_CATALOG"), + EXPECTED_GET_SCHEMAS_RESULTS); + } + } + + @Test + public void testGetExportedKeysCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getExportedKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS); + } + } + + @Test + public void testGetExportedKeysCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getExportedKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData( + resultSet, FIELDS_GET_IMPORTED_EXPORTED_KEYS, + EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS); + } + } + + @Test + public void testGetImportedKeysCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getImportedKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS); + } + } + + @Test + public void testGetImportedKeysCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getImportedKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData( + resultSet, FIELDS_GET_IMPORTED_EXPORTED_KEYS, + EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS); + } + } + + @Test + public void testGetCrossReferenceCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getCrossReference(null, null, + TARGET_TABLE, null, null, TARGET_FOREIGN_TABLE)) { + resultSetTestUtils.testData(resultSet, EXPECTED_CROSS_REFERENCE_RESULTS); + } + } + + @Test + public void testGetGetCrossReferenceCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getCrossReference(null, null, + TARGET_TABLE, null, null, TARGET_FOREIGN_TABLE)) { + resultSetTestUtils.testData( + resultSet, FIELDS_GET_CROSS_REFERENCE, EXPECTED_CROSS_REFERENCE_RESULTS); + } + } + + @Test + public void testPrimaryKeysCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getPrimaryKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData(resultSet, EXPECTED_PRIMARY_KEYS_RESULTS); + } + } + + @Test + public void testPrimaryKeysCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getPrimaryKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData( + resultSet, + ImmutableList.of( + "TABLE_CAT", + "TABLE_SCHEM", + "TABLE_NAME", + "COLUMN_NAME", + "KEY_SEQ", + "PK_NAME"), + EXPECTED_PRIMARY_KEYS_RESULTS + ); + } + } + + @Test + public void testGetColumnsCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getColumns(null, null, null, null)) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_COLUMNS_RESULTS); + } + } + + @Test + public void testGetColumnsCanByIndicesFilteringColumnNames() throws SQLException { + try ( + final ResultSet resultSet = connection.getMetaData() + .getColumns(null, null, null, "column_1")) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_COLUMNS_RESULTS + .stream() + .filter(insideList -> Objects.equals(insideList.get(3), "column_1")) + .collect(toList()) + ); + } + } + + @Test + public void testGetSqlInfo() throws SQLException { + final DatabaseMetaData metaData = connection.getMetaData(); + collector.checkThat(metaData.getDatabaseProductName(), is(EXPECTED_DATABASE_PRODUCT_NAME)); + collector.checkThat(metaData.getDatabaseProductVersion(), + is(EXPECTED_DATABASE_PRODUCT_VERSION)); + collector.checkThat(metaData.getIdentifierQuoteString(), is(EXPECTED_IDENTIFIER_QUOTE_STRING)); + collector.checkThat(metaData.isReadOnly(), is(EXPECTED_IS_READ_ONLY)); + collector.checkThat(metaData.getSQLKeywords(), is(EXPECTED_SQL_KEYWORDS)); + collector.checkThat(metaData.getNumericFunctions(), is(EXPECTED_NUMERIC_FUNCTIONS)); + collector.checkThat(metaData.getStringFunctions(), is(EXPECTED_STRING_FUNCTIONS)); + collector.checkThat(metaData.getSystemFunctions(), is(EXPECTED_SYSTEM_FUNCTIONS)); + collector.checkThat(metaData.getTimeDateFunctions(), is(EXPECTED_TIME_DATE_FUNCTIONS)); + collector.checkThat(metaData.getSearchStringEscape(), is(EXPECTED_SEARCH_STRING_ESCAPE)); + collector.checkThat(metaData.getExtraNameCharacters(), is(EXPECTED_EXTRA_NAME_CHARACTERS)); + collector.checkThat(metaData.supportsConvert(), is(EXPECTED_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsConvert(BIT, INTEGER), is(EXPECTED_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsConvert(BIT, BIGINT), is(EXPECTED_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsConvert(BIGINT, INTEGER), + is(EXPECTED_INVALID_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsConvert(JAVA_OBJECT, INTEGER), + is(EXPECTED_INVALID_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsTableCorrelationNames(), + is(EXPECTED_SUPPORTS_TABLE_CORRELATION_NAMES)); + collector.checkThat(metaData.supportsExpressionsInOrderBy(), + is(EXPECTED_EXPRESSIONS_IN_ORDER_BY)); + collector.checkThat(metaData.supportsOrderByUnrelated(), + is(EXPECTED_SUPPORTS_ORDER_BY_UNRELATED)); + collector.checkThat(metaData.supportsGroupBy(), is(EXPECTED_SUPPORTS_GROUP_BY)); + collector.checkThat(metaData.supportsGroupByUnrelated(), + is(EXPECTED_SUPPORTS_GROUP_BY_UNRELATED)); + collector.checkThat(metaData.supportsLikeEscapeClause(), + is(EXPECTED_SUPPORTS_LIKE_ESCAPE_CLAUSE)); + collector.checkThat(metaData.supportsNonNullableColumns(), is(EXPECTED_NON_NULLABLE_COLUMNS)); + collector.checkThat(metaData.supportsMinimumSQLGrammar(), is(EXPECTED_MINIMUM_SQL_GRAMMAR)); + collector.checkThat(metaData.supportsCoreSQLGrammar(), is(EXPECTED_CORE_SQL_GRAMMAR)); + collector.checkThat(metaData.supportsExtendedSQLGrammar(), is(EXPECTED_EXTEND_SQL_GRAMMAR)); + collector.checkThat(metaData.supportsANSI92EntryLevelSQL(), + is(EXPECTED_ANSI92_ENTRY_LEVEL_SQL)); + collector.checkThat(metaData.supportsANSI92IntermediateSQL(), + is(EXPECTED_ANSI92_INTERMEDIATE_SQL)); + collector.checkThat(metaData.supportsANSI92FullSQL(), is(EXPECTED_ANSI92_FULL_SQL)); + collector.checkThat(metaData.supportsOuterJoins(), is(EXPECTED_SUPPORTS_OUTER_JOINS)); + collector.checkThat(metaData.supportsFullOuterJoins(), is(EXPECTED_SUPPORTS_FULL_OUTER_JOINS)); + collector.checkThat(metaData.supportsLimitedOuterJoins(), is(EXPECTED_SUPPORTS_LIMITED_JOINS)); + collector.checkThat(metaData.getSchemaTerm(), is(EXPECTED_SCHEMA_TERM)); + collector.checkThat(metaData.getProcedureTerm(), is(EXPECTED_PROCEDURE_TERM)); + collector.checkThat(metaData.getCatalogTerm(), is(EXPECTED_CATALOG_TERM)); + collector.checkThat(metaData.isCatalogAtStart(), is(EXPECTED_CATALOG_AT_START)); + collector.checkThat(metaData.supportsSchemasInProcedureCalls(), + is(EXPECTED_SCHEMAS_IN_PROCEDURE_CALLS)); + collector.checkThat(metaData.supportsSchemasInIndexDefinitions(), + is(EXPECTED_SCHEMAS_IN_INDEX_DEFINITIONS)); + collector.checkThat(metaData.supportsCatalogsInIndexDefinitions(), + is(EXPECTED_CATALOGS_IN_INDEX_DEFINITIONS)); + collector.checkThat(metaData.supportsPositionedDelete(), is(EXPECTED_POSITIONED_DELETE)); + collector.checkThat(metaData.supportsPositionedUpdate(), is(EXPECTED_POSITIONED_UPDATE)); + collector.checkThat(metaData.supportsResultSetType(ResultSet.TYPE_FORWARD_ONLY), + is(EXPECTED_TYPE_FORWARD_ONLY)); + collector.checkThat(metaData.supportsSelectForUpdate(), + is(EXPECTED_SELECT_FOR_UPDATE_SUPPORTED)); + collector.checkThat(metaData.supportsStoredProcedures(), + is(EXPECTED_STORED_PROCEDURES_SUPPORTED)); + collector.checkThat(metaData.supportsSubqueriesInComparisons(), + is(EXPECTED_SUBQUERIES_IN_COMPARISON)); + collector.checkThat(metaData.supportsSubqueriesInExists(), is(EXPECTED_SUBQUERIES_IN_EXISTS)); + collector.checkThat(metaData.supportsSubqueriesInIns(), is(EXPECTED_SUBQUERIES_IN_INS)); + collector.checkThat(metaData.supportsSubqueriesInQuantifieds(), + is(EXPECTED_SUBQUERIES_IN_QUANTIFIEDS)); + collector.checkThat(metaData.supportsCorrelatedSubqueries(), + is(EXPECTED_CORRELATED_SUBQUERIES_SUPPORTED)); + collector.checkThat(metaData.supportsUnion(), is(EXPECTED_SUPPORTS_UNION)); + collector.checkThat(metaData.supportsUnionAll(), is(EXPECTED_SUPPORTS_UNION_ALL)); + collector.checkThat(metaData.getMaxBinaryLiteralLength(), + is(EXPECTED_MAX_BINARY_LITERAL_LENGTH)); + collector.checkThat(metaData.getMaxCharLiteralLength(), is(EXPECTED_MAX_CHAR_LITERAL_LENGTH)); + collector.checkThat(metaData.getMaxColumnsInGroupBy(), is(EXPECTED_MAX_COLUMNS_IN_GROUP_BY)); + collector.checkThat(metaData.getMaxColumnsInIndex(), is(EXPECTED_MAX_COLUMNS_IN_INDEX)); + collector.checkThat(metaData.getMaxColumnsInOrderBy(), is(EXPECTED_MAX_COLUMNS_IN_ORDER_BY)); + collector.checkThat(metaData.getMaxColumnsInSelect(), is(EXPECTED_MAX_COLUMNS_IN_SELECT)); + collector.checkThat(metaData.getMaxConnections(), is(EXPECTED_MAX_CONNECTIONS)); + collector.checkThat(metaData.getMaxCursorNameLength(), is(EXPECTED_MAX_CURSOR_NAME_LENGTH)); + collector.checkThat(metaData.getMaxIndexLength(), is(EXPECTED_MAX_INDEX_LENGTH)); + collector.checkThat(metaData.getMaxSchemaNameLength(), is(EXPECTED_SCHEMA_NAME_LENGTH)); + collector.checkThat(metaData.getMaxProcedureNameLength(), + is(EXPECTED_MAX_PROCEDURE_NAME_LENGTH)); + collector.checkThat(metaData.getMaxCatalogNameLength(), is(EXPECTED_MAX_CATALOG_NAME_LENGTH)); + collector.checkThat(metaData.getMaxRowSize(), is(EXPECTED_MAX_ROW_SIZE)); + collector.checkThat(metaData.doesMaxRowSizeIncludeBlobs(), + is(EXPECTED_MAX_ROW_SIZE_INCLUDES_BLOBS)); + collector.checkThat(metaData.getMaxStatementLength(), is(EXPECTED_MAX_STATEMENT_LENGTH)); + collector.checkThat(metaData.getMaxStatements(), is(EXPECTED_MAX_STATEMENTS)); + collector.checkThat(metaData.getMaxTableNameLength(), is(EXPECTED_MAX_TABLE_NAME_LENGTH)); + collector.checkThat(metaData.getMaxTablesInSelect(), is(EXPECTED_MAX_TABLES_IN_SELECT)); + collector.checkThat(metaData.getMaxUserNameLength(), is(EXPECTED_MAX_USERNAME_LENGTH)); + collector.checkThat(metaData.getDefaultTransactionIsolation(), + is(EXPECTED_DEFAULT_TRANSACTION_ISOLATION)); + collector.checkThat(metaData.supportsTransactions(), is(EXPECTED_TRANSACTIONS_SUPPORTED)); + collector.checkThat(metaData.supportsBatchUpdates(), is(EXPECTED_BATCH_UPDATES_SUPPORTED)); + collector.checkThat(metaData.supportsSavepoints(), is(EXPECTED_SAVEPOINTS_SUPPORTED)); + collector.checkThat(metaData.supportsNamedParameters(), + is(EXPECTED_NAMED_PARAMETERS_SUPPORTED)); + collector.checkThat(metaData.locatorsUpdateCopy(), is(EXPECTED_LOCATORS_UPDATE_COPY)); + + collector.checkThat(metaData.supportsResultSetType(ResultSet.TYPE_SCROLL_INSENSITIVE), + is(EXPECTED_TYPE_SCROLL_INSENSITIVE)); + collector.checkThat(metaData.supportsResultSetType(ResultSet.TYPE_SCROLL_SENSITIVE), + is(EXPECTED_TYPE_SCROLL_SENSITIVE)); + collector.checkThat(metaData.supportsSchemasInPrivilegeDefinitions(), + is(EXPECTED_SCHEMAS_IN_PRIVILEGE_DEFINITIONS)); + collector.checkThat(metaData.supportsCatalogsInPrivilegeDefinitions(), + is(EXPECTED_CATALOGS_IN_PRIVILEGE_DEFINITIONS)); + collector.checkThat(metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_NONE), + is(EXPECTED_TRANSACTION_NONE)); + collector.checkThat( + metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_READ_COMMITTED), + is(EXPECTED_TRANSACTION_READ_COMMITTED)); + collector.checkThat( + metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_READ_UNCOMMITTED), + is(EXPECTED_TRANSACTION_READ_UNCOMMITTED)); + collector.checkThat( + metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_REPEATABLE_READ), + is(EXPECTED_TRANSACTION_REPEATABLE_READ)); + collector.checkThat( + metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_SERIALIZABLE), + is(EXPECTED_TRANSACTION_SERIALIZABLE)); + collector.checkThat(metaData.dataDefinitionCausesTransactionCommit(), + is(EXPECTED_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT)); + collector.checkThat(metaData.dataDefinitionIgnoredInTransactions(), + is(EXPECTED_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED)); + collector.checkThat(metaData.supportsStoredFunctionsUsingCallSyntax(), + is(EXPECTED_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED)); + collector.checkThat(metaData.supportsIntegrityEnhancementFacility(), + is(EXPECTED_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY)); + collector.checkThat(metaData.supportsDifferentTableCorrelationNames(), + is(EXPECTED_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES)); + + ThrowableAssertionUtils.simpleAssertThrowableClass(SQLException.class, + () -> metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_SERIALIZABLE + 1)); + ThrowableAssertionUtils.simpleAssertThrowableClass(SQLException.class, + () -> metaData.supportsResultSetType(ResultSet.HOLD_CURSORS_OVER_COMMIT)); + } + + @Test + public void testGetColumnsCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getColumns(null, null, null, null)) { + resultSetTestUtils.testData(resultSet, + ImmutableList.of( + "TABLE_CAT", + "TABLE_SCHEM", + "TABLE_NAME", + "COLUMN_NAME", + "DATA_TYPE", + "TYPE_NAME", + "COLUMN_SIZE", + "BUFFER_LENGTH", + "DECIMAL_DIGITS", + "NUM_PREC_RADIX", + "NULLABLE", + "REMARKS", + "COLUMN_DEF", + "SQL_DATA_TYPE", + "SQL_DATETIME_SUB", + "CHAR_OCTET_LENGTH", + "ORDINAL_POSITION", + "IS_NULLABLE", + "SCOPE_CATALOG", + "SCOPE_SCHEMA", + "SCOPE_TABLE", + "SOURCE_DATA_TYPE", + "IS_AUTOINCREMENT", + "IS_GENERATEDCOLUMN"), + EXPECTED_GET_COLUMNS_RESULTS + ); + } + } + + @Test + public void testGetProcedures() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getProcedures(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetProceduresSchema = new HashMap() { + { + put(1, "PROCEDURE_CAT"); + put(2, "PROCEDURE_SCHEM"); + put(3, "PROCEDURE_NAME"); + put(4, "FUTURE_USE1"); + put(5, "FUTURE_USE2"); + put(6, "FUTURE_USE3"); + put(7, "REMARKS"); + put(8, "PROCEDURE_TYPE"); + put(9, "SPECIFIC_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetProceduresSchema); + } + } + + @Test + public void testGetProcedureColumns() throws SQLException { + try (ResultSet resultSet = connection.getMetaData() + .getProcedureColumns(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetProcedureColumnsSchema = + new HashMap() { + { + put(1, "PROCEDURE_CAT"); + put(2, "PROCEDURE_SCHEM"); + put(3, "PROCEDURE_NAME"); + put(4, "COLUMN_NAME"); + put(5, "COLUMN_TYPE"); + put(6, "DATA_TYPE"); + put(7, "TYPE_NAME"); + put(8, "PRECISION"); + put(9, "LENGTH"); + put(10, "SCALE"); + put(11, "RADIX"); + put(12, "NULLABLE"); + put(13, "REMARKS"); + put(14, "COLUMN_DEF"); + put(15, "SQL_DATA_TYPE"); + put(16, "SQL_DATETIME_SUB"); + put(17, "CHAR_OCTET_LENGTH"); + put(18, "ORDINAL_POSITION"); + put(19, "IS_NULLABLE"); + put(20, "SPECIFIC_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetProcedureColumnsSchema); + } + } + + @Test + public void testGetColumnPrivileges() throws SQLException { + try (ResultSet resultSet = connection.getMetaData() + .getColumnPrivileges(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetColumnPrivilegesSchema = + new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "COLUMN_NAME"); + put(5, "GRANTOR"); + put(6, "GRANTEE"); + put(7, "PRIVILEGE"); + put(8, "IS_GRANTABLE"); + } + }; + testEmptyResultSet(resultSet, expectedGetColumnPrivilegesSchema); + } + } + + @Test + public void testGetTablePrivileges() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getTablePrivileges(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetTablePrivilegesSchema = new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "GRANTOR"); + put(5, "GRANTEE"); + put(6, "PRIVILEGE"); + put(7, "IS_GRANTABLE"); + } + }; + testEmptyResultSet(resultSet, expectedGetTablePrivilegesSchema); + } + } + + @Test + public void testGetBestRowIdentifier() throws SQLException { + try (ResultSet resultSet = connection.getMetaData() + .getBestRowIdentifier(null, null, null, 0, true)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetBestRowIdentifierSchema = + new HashMap() { + { + put(1, "SCOPE"); + put(2, "COLUMN_NAME"); + put(3, "DATA_TYPE"); + put(4, "TYPE_NAME"); + put(5, "COLUMN_SIZE"); + put(6, "BUFFER_LENGTH"); + put(7, "DECIMAL_DIGITS"); + put(8, "PSEUDO_COLUMN"); + } + }; + testEmptyResultSet(resultSet, expectedGetBestRowIdentifierSchema); + } + } + + @Test + public void testGetVersionColumns() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getVersionColumns(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetVersionColumnsSchema = new HashMap() { + { + put(1, "SCOPE"); + put(2, "COLUMN_NAME"); + put(3, "DATA_TYPE"); + put(4, "TYPE_NAME"); + put(5, "COLUMN_SIZE"); + put(6, "BUFFER_LENGTH"); + put(7, "DECIMAL_DIGITS"); + put(8, "PSEUDO_COLUMN"); + } + }; + testEmptyResultSet(resultSet, expectedGetVersionColumnsSchema); + } + } + + @Test + public void testGetTypeInfo() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getTypeInfo()) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetTypeInfoSchema = new HashMap() { + { + put(1, "TYPE_NAME"); + put(2, "DATA_TYPE"); + put(3, "PRECISION"); + put(4, "LITERAL_PREFIX"); + put(5, "LITERAL_SUFFIX"); + put(6, "CREATE_PARAMS"); + put(7, "NULLABLE"); + put(8, "CASE_SENSITIVE"); + put(9, "SEARCHABLE"); + put(10, "UNSIGNED_ATTRIBUTE"); + put(11, "FIXED_PREC_SCALE"); + put(12, "AUTO_INCREMENT"); + put(13, "LOCAL_TYPE_NAME"); + put(14, "MINIMUM_SCALE"); + put(15, "MAXIMUM_SCALE"); + put(16, "SQL_DATA_TYPE"); + put(17, "SQL_DATETIME_SUB"); + put(18, "NUM_PREC_RADIX"); + } + }; + testEmptyResultSet(resultSet, expectedGetTypeInfoSchema); + } + } + + @Test + public void testGetIndexInfo() throws SQLException { + try (ResultSet resultSet = connection.getMetaData() + .getIndexInfo(null, null, null, false, true)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetIndexInfoSchema = new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "NON_UNIQUE"); + put(5, "INDEX_QUALIFIER"); + put(6, "INDEX_NAME"); + put(7, "TYPE"); + put(8, "ORDINAL_POSITION"); + put(9, "COLUMN_NAME"); + put(10, "ASC_OR_DESC"); + put(11, "CARDINALITY"); + put(12, "PAGES"); + put(13, "FILTER_CONDITION"); + } + }; + testEmptyResultSet(resultSet, expectedGetIndexInfoSchema); + } + } + + @Test + public void testGetUDTs() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getUDTs(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetUDTsSchema = new HashMap() { + { + put(1, "TYPE_CAT"); + put(2, "TYPE_SCHEM"); + put(3, "TYPE_NAME"); + put(4, "CLASS_NAME"); + put(5, "DATA_TYPE"); + put(6, "REMARKS"); + put(7, "BASE_TYPE"); + } + }; + testEmptyResultSet(resultSet, expectedGetUDTsSchema); + } + } + + @Test + public void testGetSuperTypes() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getSuperTypes(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetSuperTypesSchema = new HashMap() { + { + put(1, "TYPE_CAT"); + put(2, "TYPE_SCHEM"); + put(3, "TYPE_NAME"); + put(4, "SUPERTYPE_CAT"); + put(5, "SUPERTYPE_SCHEM"); + put(6, "SUPERTYPE_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetSuperTypesSchema); + } + } + + @Test + public void testGetSuperTables() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getSuperTables(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetSuperTablesSchema = new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "SUPERTABLE_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetSuperTablesSchema); + } + } + + @Test + public void testGetAttributes() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getAttributes(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetAttributesSchema = new HashMap() { + { + put(1, "TYPE_CAT"); + put(2, "TYPE_SCHEM"); + put(3, "TYPE_NAME"); + put(4, "ATTR_NAME"); + put(5, "DATA_TYPE"); + put(6, "ATTR_TYPE_NAME"); + put(7, "ATTR_SIZE"); + put(8, "DECIMAL_DIGITS"); + put(9, "NUM_PREC_RADIX"); + put(10, "NULLABLE"); + put(11, "REMARKS"); + put(12, "ATTR_DEF"); + put(13, "SQL_DATA_TYPE"); + put(14, "SQL_DATETIME_SUB"); + put(15, "CHAR_OCTET_LENGTH"); + put(16, "ORDINAL_POSITION"); + put(17, "IS_NULLABLE"); + put(18, "SCOPE_CATALOG"); + put(19, "SCOPE_SCHEMA"); + put(20, "SCOPE_TABLE"); + put(21, "SOURCE_DATA_TYPE"); + } + }; + testEmptyResultSet(resultSet, expectedGetAttributesSchema); + } + } + + @Test + public void testGetClientInfoProperties() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getClientInfoProperties()) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetClientInfoPropertiesSchema = + new HashMap() { + { + put(1, "NAME"); + put(2, "MAX_LEN"); + put(3, "DEFAULT_VALUE"); + put(4, "DESCRIPTION"); + } + }; + testEmptyResultSet(resultSet, expectedGetClientInfoPropertiesSchema); + } + } + + @Test + public void testGetFunctions() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getFunctions(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetFunctionsSchema = new HashMap() { + { + put(1, "FUNCTION_CAT"); + put(2, "FUNCTION_SCHEM"); + put(3, "FUNCTION_NAME"); + put(4, "REMARKS"); + put(5, "FUNCTION_TYPE"); + put(6, "SPECIFIC_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetFunctionsSchema); + } + } + + @Test + public void testGetFunctionColumns() throws SQLException { + try ( + ResultSet resultSet = connection.getMetaData().getFunctionColumns(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetFunctionColumnsSchema = new HashMap() { + { + put(1, "FUNCTION_CAT"); + put(2, "FUNCTION_SCHEM"); + put(3, "FUNCTION_NAME"); + put(4, "COLUMN_NAME"); + put(5, "COLUMN_TYPE"); + put(6, "DATA_TYPE"); + put(7, "TYPE_NAME"); + put(8, "PRECISION"); + put(9, "LENGTH"); + put(10, "SCALE"); + put(11, "RADIX"); + put(12, "NULLABLE"); + put(13, "REMARKS"); + put(14, "CHAR_OCTET_LENGTH"); + put(15, "ORDINAL_POSITION"); + put(16, "IS_NULLABLE"); + put(17, "SPECIFIC_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetFunctionColumnsSchema); + } + } + + @Test + public void testGetPseudoColumns() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getPseudoColumns(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetPseudoColumnsSchema = new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "COLUMN_NAME"); + put(5, "DATA_TYPE"); + put(6, "COLUMN_SIZE"); + put(7, "DECIMAL_DIGITS"); + put(8, "NUM_PREC_RADIX"); + put(9, "COLUMN_USAGE"); + put(10, "REMARKS"); + put(11, "CHAR_OCTET_LENGTH"); + put(12, "IS_NULLABLE"); + } + }; + testEmptyResultSet(resultSet, expectedGetPseudoColumnsSchema); + } + } + + private void testEmptyResultSet(final ResultSet resultSet, + final Map expectedResultSetSchema) + throws SQLException { + Assert.assertFalse(resultSet.next()); + final ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + for (final Map.Entry entry : expectedResultSetSchema.entrySet()) { + Assert.assertEquals(entry.getValue(), resultSetMetaData.getColumnLabel(entry.getKey())); + } + } + + @Test + public void testGetColumnSize() { + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_BYTE), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Int(Byte.SIZE, true))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_SHORT), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Int(Short.SIZE, true))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_INT), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Int(Integer.SIZE, true))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_LONG), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Int(Long.SIZE, true))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_VARCHAR_AND_BINARY), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Utf8())); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_VARCHAR_AND_BINARY), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Binary())); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIMESTAMP_SECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Timestamp(TimeUnit.SECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIMESTAMP_MILLISECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIMESTAMP_MICROSECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIMESTAMP_NANOSECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Timestamp(TimeUnit.NANOSECOND, null))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIME), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Time(TimeUnit.SECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIME_MILLISECONDS), + ArrowDatabaseMetadata.getColumnSize( + new ArrowType.Time(TimeUnit.MILLISECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIME_MICROSECONDS), + ArrowDatabaseMetadata.getColumnSize( + new ArrowType.Time(TimeUnit.MICROSECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIME_NANOSECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Time(TimeUnit.NANOSECOND, Integer.SIZE))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_DATE), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Date(DateUnit.DAY))); + + Assert.assertNull(ArrowDatabaseMetadata.getColumnSize(new ArrowType.FloatingPoint( + FloatingPointPrecision.DOUBLE))); + } + + @Test + public void testGetDecimalDigits() { + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.NO_DECIMAL_DIGITS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Int(Byte.SIZE, true))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.NO_DECIMAL_DIGITS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Timestamp(TimeUnit.SECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_MILLISECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Timestamp(TimeUnit.MILLISECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_MICROSECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_NANOSECONDS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Timestamp(TimeUnit.NANOSECOND, null))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.NO_DECIMAL_DIGITS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Time(TimeUnit.SECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_MILLISECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Time(TimeUnit.MILLISECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_MICROSECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Time(TimeUnit.MICROSECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_NANOSECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Time(TimeUnit.NANOSECOND, Integer.SIZE))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.NO_DECIMAL_DIGITS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Date(DateUnit.DAY))); + + Assert.assertNull(ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Utf8())); + } + + @Test + public void testSqlToRegexLike() { + Assert.assertEquals(".*", ArrowDatabaseMetadata.sqlToRegexLike("%")); + Assert.assertEquals(".", ArrowDatabaseMetadata.sqlToRegexLike("_")); + Assert.assertEquals("\\*", ArrowDatabaseMetadata.sqlToRegexLike("*")); + Assert.assertEquals("T\\*E.S.*T", ArrowDatabaseMetadata.sqlToRegexLike("T*E_S%T")); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java new file mode 100644 index 0000000000000..90c926612f15a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.Types; +import java.util.Arrays; +import java.util.HashMap; + +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.util.JsonStringArrayList; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class ArrowFlightJdbcArrayTest { + + @Rule + public RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + IntVector dataVector; + + @Before + public void setup() { + dataVector = rootAllocatorTestRule.createIntVector(); + } + + @After + public void tearDown() { + this.dataVector.close(); + } + + @Test + public void testShouldGetBaseTypeNameReturnCorrectTypeName() { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + Assert.assertEquals("INTEGER", arrowFlightJdbcArray.getBaseTypeName()); + } + + @Test + public void testShouldGetBaseTypeReturnCorrectType() { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + Assert.assertEquals(Types.INTEGER, arrowFlightJdbcArray.getBaseType()); + } + + @Test + public void testShouldGetArrayReturnValidArray() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + Object[] array = (Object[]) arrowFlightJdbcArray.getArray(); + + Object[] expected = new Object[dataVector.getValueCount()]; + for (int i = 0; i < expected.length; i++) { + expected[i] = dataVector.getObject(i); + } + Assert.assertArrayEquals(array, expected); + } + + @Test + public void testShouldGetArrayReturnValidArrayWithOffsets() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + Object[] array = (Object[]) arrowFlightJdbcArray.getArray(1, 5); + + Object[] expected = new Object[5]; + for (int i = 0; i < expected.length; i++) { + expected[i] = dataVector.getObject(i + 1); + } + Assert.assertArrayEquals(array, expected); + } + + @Test(expected = ArrayIndexOutOfBoundsException.class) + public void testShouldGetArrayWithOffsetsThrowArrayIndexOutOfBoundsException() + throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + arrowFlightJdbcArray.getArray(0, dataVector.getValueCount() + 1); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void testShouldGetArrayWithMapNotBeSupported() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + HashMap> map = new HashMap<>(); + arrowFlightJdbcArray.getArray(map); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void testShouldGetArrayWithOffsetsAndMapNotBeSupported() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + HashMap> map = new HashMap<>(); + arrowFlightJdbcArray.getArray(0, 5, map); + } + + @Test + public void testShouldGetResultSetReturnValidResultSet() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + try (ResultSet resultSet = arrowFlightJdbcArray.getResultSet()) { + int count = 0; + while (resultSet.next()) { + Assert.assertEquals((Object) resultSet.getInt(1), dataVector.getObject(count)); + count++; + } + } + } + + @Test + public void testShouldGetResultSetReturnValidResultSetWithOffsets() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + try (ResultSet resultSet = arrowFlightJdbcArray.getResultSet(3, 5)) { + int count = 0; + while (resultSet.next()) { + Assert.assertEquals((Object) resultSet.getInt(1), dataVector.getObject(count + 3)); + count++; + } + Assert.assertEquals(count, 5); + } + } + + @Test + public void testToString() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + + JsonStringArrayList array = new JsonStringArrayList<>(); + array.addAll(Arrays.asList((Object[]) arrowFlightJdbcArray.getArray())); + + Assert.assertEquals(array.toString(), arrowFlightJdbcArray.toString()); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void testShouldGetResultSetWithMapNotBeSupported() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + HashMap> map = new HashMap<>(); + arrowFlightJdbcArray.getResultSet(map); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void testShouldGetResultSetWithOffsetsAndMapNotBeSupported() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + HashMap> map = new HashMap<>(); + arrowFlightJdbcArray.getResultSet(0, 5, map); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionCookieTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionCookieTest.java new file mode 100644 index 0000000000000..c7268e0594ecc --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionCookieTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.Statement; + +import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.junit.Assert; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcConnectionCookieTest { + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = + FlightServerTestRule.createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer()); + + @Test + public void testCookies() throws SQLException { + try (Connection connection = FLIGHT_SERVER_TEST_RULE.getConnection(false); + Statement statement = connection.createStatement()) { + + // Expect client didn't receive cookies before any operation + Assert.assertNull(FLIGHT_SERVER_TEST_RULE.getMiddlewareCookieFactory().getCookie()); + + // Run another action for check if the cookies was sent by the server. + statement.execute(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD); + Assert.assertEquals("k=v", FLIGHT_SERVER_TEST_RULE.getMiddlewareCookieFactory().getCookie()); + } + } +} + diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSourceTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSourceTest.java new file mode 100644 index 0000000000000..f4a5c87a23cc2 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSourceTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Connection; + +import javax.sql.PooledConnection; + +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.utils.ConnectionWrapper; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +public class ArrowFlightJdbcConnectionPoolDataSourceTest { + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE; + + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + + static { + UserPasswordAuthentication authentication = + new UserPasswordAuthentication.Builder() + .user("user1", "pass1") + .user("user2", "pass2") + .build(); + + FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder() + .host("localhost") + .randomPort() + .authentication(authentication) + .producer(PRODUCER) + .build(); + } + + private ArrowFlightJdbcConnectionPoolDataSource dataSource; + + @Before + public void setUp() { + dataSource = FLIGHT_SERVER_TEST_RULE.createConnectionPoolDataSource(false); + } + + @After + public void tearDown() throws Exception { + dataSource.close(); + } + + @Test + public void testShouldInnerConnectionIsClosedReturnCorrectly() throws Exception { + PooledConnection pooledConnection = dataSource.getPooledConnection(); + Connection connection = pooledConnection.getConnection(); + Assert.assertFalse(connection.isClosed()); + connection.close(); + Assert.assertTrue(connection.isClosed()); + } + + @Test + public void testShouldInnerConnectionShouldIgnoreDoubleClose() throws Exception { + PooledConnection pooledConnection = dataSource.getPooledConnection(); + Connection connection = pooledConnection.getConnection(); + Assert.assertFalse(connection.isClosed()); + connection.close(); + Assert.assertTrue(connection.isClosed()); + } + + @Test + public void testShouldInnerConnectionIsClosedReturnTrueIfPooledConnectionCloses() + throws Exception { + PooledConnection pooledConnection = dataSource.getPooledConnection(); + Connection connection = pooledConnection.getConnection(); + Assert.assertFalse(connection.isClosed()); + pooledConnection.close(); + Assert.assertTrue(connection.isClosed()); + } + + @Test + public void testShouldReuseConnectionsOnPool() throws Exception { + PooledConnection pooledConnection = dataSource.getPooledConnection("user1", "pass1"); + ConnectionWrapper connection = ((ConnectionWrapper) pooledConnection.getConnection()); + Assert.assertFalse(connection.isClosed()); + connection.close(); + Assert.assertTrue(connection.isClosed()); + Assert.assertFalse(connection.unwrap(ArrowFlightConnection.class).isClosed()); + + PooledConnection pooledConnection2 = dataSource.getPooledConnection("user1", "pass1"); + ConnectionWrapper connection2 = ((ConnectionWrapper) pooledConnection2.getConnection()); + Assert.assertFalse(connection2.isClosed()); + connection2.close(); + Assert.assertTrue(connection2.isClosed()); + Assert.assertFalse(connection2.unwrap(ArrowFlightConnection.class).isClosed()); + + Assert.assertSame(pooledConnection, pooledConnection2); + Assert.assertNotSame(connection, connection2); + Assert.assertSame(connection.unwrap(ArrowFlightConnection.class), + connection2.unwrap(ArrowFlightConnection.class)); + } + + @Test + public void testShouldNotMixConnectionsForDifferentUsers() throws Exception { + PooledConnection pooledConnection = dataSource.getPooledConnection("user1", "pass1"); + ConnectionWrapper connection = ((ConnectionWrapper) pooledConnection.getConnection()); + Assert.assertFalse(connection.isClosed()); + connection.close(); + Assert.assertTrue(connection.isClosed()); + Assert.assertFalse(connection.unwrap(ArrowFlightConnection.class).isClosed()); + + PooledConnection pooledConnection2 = dataSource.getPooledConnection("user2", "pass2"); + ConnectionWrapper connection2 = ((ConnectionWrapper) pooledConnection2.getConnection()); + Assert.assertFalse(connection2.isClosed()); + connection2.close(); + Assert.assertTrue(connection2.isClosed()); + Assert.assertFalse(connection2.unwrap(ArrowFlightConnection.class).isClosed()); + + Assert.assertNotSame(pooledConnection, pooledConnection2); + Assert.assertNotSame(connection, connection2); + Assert.assertNotSame(connection.unwrap(ArrowFlightConnection.class), + connection2.unwrap(ArrowFlightConnection.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursorTest.java new file mode 100644 index 0000000000000..b818f7115b7f9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursorTest.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.junit.Assert.assertTrue; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.util.Cursor; +import org.junit.After; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +/** + * Tests for {@link ArrowFlightJdbcCursor}. + */ +public class ArrowFlightJdbcCursorTest { + + ArrowFlightJdbcCursor cursor; + BufferAllocator allocator; + + @After + public void cleanUp() { + allocator.close(); + cursor.close(); + } + + @Test + public void testBinaryVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Binary", new ArrowType.Binary(), null); + ((VarBinaryVector) root.getVector("Binary")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDateVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = + getVectorSchemaRoot("Date", new ArrowType.Date(DateUnit.DAY), null); + ((DateDayVector) root.getVector("Date")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDurationVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Duration", + new ArrowType.Duration(TimeUnit.MILLISECOND), null); + ((DurationVector) root.getVector("Duration")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDateInternalNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Interval", + new ArrowType.Interval(IntervalUnit.DAY_TIME), null); + ((IntervalDayVector) root.getVector("Interval")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeStampVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("TimeStamp", + new ArrowType.Timestamp(TimeUnit.MILLISECOND, null), null); + ((TimeStampMilliVector) root.getVector("TimeStamp")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Time", + new ArrowType.Time(TimeUnit.MILLISECOND, 32), null); + ((TimeMilliVector) root.getVector("Time")).setNull(0); + testCursorWasNull(root); + + } + + @Test + public void testFixedSizeListVectorNullTrue() throws SQLException { + List fieldList = new ArrayList<>(); + fieldList.add(new Field("Null", new FieldType(true, new ArrowType.Null(), null), + null)); + final VectorSchemaRoot root = getVectorSchemaRoot("FixedSizeList", + new ArrowType.FixedSizeList(10), fieldList); + ((FixedSizeListVector) root.getVector("FixedSizeList")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testLargeListVectorNullTrue() throws SQLException { + List fieldList = new ArrayList<>(); + fieldList.add(new Field("Null", new FieldType(true, new ArrowType.Null(), null), + null)); + final VectorSchemaRoot root = + getVectorSchemaRoot("LargeList", new ArrowType.LargeList(), fieldList); + ((LargeListVector) root.getVector("LargeList")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testListVectorNullTrue() throws SQLException { + List fieldList = new ArrayList<>(); + fieldList.add(new Field("Null", new FieldType(true, new ArrowType.Null(), null), + null)); + final VectorSchemaRoot root = getVectorSchemaRoot("List", new ArrowType.List(), fieldList); + ((ListVector) root.getVector("List")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testMapVectorNullTrue() throws SQLException { + List structChildren = new ArrayList<>(); + structChildren.add(new Field("Key", new FieldType(false, new ArrowType.Utf8(), null), + null)); + structChildren.add(new Field("Value", new FieldType(false, new ArrowType.Utf8(), null), + null)); + List fieldList = new ArrayList<>(); + fieldList.add(new Field("Struct", new FieldType(false, new ArrowType.Struct(), null), + structChildren)); + final VectorSchemaRoot root = getVectorSchemaRoot("Map", new ArrowType.Map(false), fieldList); + ((MapVector) root.getVector("Map")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testStructVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Struct", new ArrowType.Struct(), null); + ((StructVector) root.getVector("Struct")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testBaseIntVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("BaseInt", + new ArrowType.Int(32, false), null); + ((UInt4Vector) root.getVector("BaseInt")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testBitVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Bit", new ArrowType.Bool(), null); + ((BitVector) root.getVector("Bit")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDecimalVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Decimal", + new ArrowType.Decimal(2, 2, 128), null); + ((DecimalVector) root.getVector("Decimal")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testFloat4VectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Float4", + new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null); + ((Float4Vector) root.getVector("Float4")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testFloat8VectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Float8", + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), null); + ((Float8Vector) root.getVector("Float8")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testVarCharVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("VarChar", new ArrowType.Utf8(), null); + ((VarCharVector) root.getVector("VarChar")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testNullVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Null", new ArrowType.Null(), null); + testCursorWasNull(root); + } + + private VectorSchemaRoot getVectorSchemaRoot(String name, ArrowType arrowType, + List children) { + final Schema schema = new Schema(ImmutableList.of( + new Field( + name, + new FieldType(true, arrowType, + null), + children))); + allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + return root; + } + + private void testCursorWasNull(VectorSchemaRoot root) throws SQLException { + root.setRowCount(1); + cursor = new ArrowFlightJdbcCursor(root); + cursor.next(); + List accessorList = cursor.createAccessors(null, null, null); + accessorList.get(0).getObject(); + assertTrue(cursor.wasNull()); + root.close(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java new file mode 100644 index 0000000000000..682c20c696ac3 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.sql.Connection; +import java.sql.Driver; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Collection; +import java.util.Map; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +/** + * Tests for {@link ArrowFlightJdbcDriver}. + */ +public class ArrowFlightJdbcDriverTest { + + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE; + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + + static { + UserPasswordAuthentication authentication = + new UserPasswordAuthentication.Builder().user("user1", "pass1").user("user2", "pass2") + .build(); + + FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder().host("localhost").randomPort() + .authentication(authentication).producer(PRODUCER).build(); + } + + private BufferAllocator allocator; + private ArrowFlightJdbcConnectionPoolDataSource dataSource; + + @Before + public void setUp() throws Exception { + allocator = new RootAllocator(Long.MAX_VALUE); + dataSource = FLIGHT_SERVER_TEST_RULE.createConnectionPoolDataSource(); + } + + @After + public void tearDown() throws Exception { + Collection childAllocators = allocator.getChildAllocators(); + AutoCloseables.close(childAllocators.toArray(new AutoCloseable[0])); + AutoCloseables.close(dataSource, allocator); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} is registered in the + * {@link DriverManager}. + * + * @throws SQLException If an error occurs. (This is not supposed to happen.) + */ + @Test + public void testDriverIsRegisteredInDriverManager() throws Exception { + assertTrue(DriverManager.getDriver("jdbc:arrow-flight://localhost:32010") instanceof + ArrowFlightJdbcDriver); + assertTrue(DriverManager.getDriver("jdbc:arrow-flight-sql://localhost:32010") instanceof + ArrowFlightJdbcDriver); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} fails when provided with an + * unsupported URL prefix. + * + * @throws SQLException If the test passes. + */ + @Test(expected = SQLException.class) + public void testShouldDeclineUrlWithUnsupportedPrefix() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + + driver.connect("jdbc:mysql://localhost:32010", dataSource.getProperties("flight", "flight123")) + .close(); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} can establish a successful + * connection to the Arrow Flight client. + * + * @throws Exception If the connection fails to be established. + */ + @Test + public void testShouldConnectWhenProvidedWithValidUrl() throws Exception { + // Get the Arrow Flight JDBC driver by providing a URL with a valid prefix. + final Driver driver = new ArrowFlightJdbcDriver(); + + try (Connection connection = + driver.connect("jdbc:arrow-flight://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort() + "?" + + "useEncryption=false", + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()))) { + assertTrue(connection.isValid(300)); + } + try (Connection connection = + driver.connect("jdbc:arrow-flight-sql://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort() + "?" + + "useEncryption=false", + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()))) { + assertTrue(connection.isValid(300)); + } + } + + @Test + public void testConnectWithInsensitiveCasePropertyKeys() throws Exception { + // Get the Arrow Flight JDBC driver by providing a URL with insensitive case property keys. + final Driver driver = new ArrowFlightJdbcDriver(); + + try (Connection connection = + driver.connect("jdbc:arrow-flight://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort() + "?" + + "UseEncryptiOn=false", + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()))) { + assertTrue(connection.isValid(300)); + } + try (Connection connection = + driver.connect("jdbc:arrow-flight-sql://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort() + "?" + + "UseEncryptiOn=false", + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()))) { + assertTrue(connection.isValid(300)); + } + } + + @Test + public void testConnectWithInsensitiveCasePropertyKeys2() throws Exception { + // Get the Arrow Flight JDBC driver by providing a property object with insensitive case keys. + final Driver driver = new ArrowFlightJdbcDriver(); + Properties properties = + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()); + properties.put("UseEncryptiOn", "false"); + + try (Connection connection = + driver.connect("jdbc:arrow-flight://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort(), properties)) { + assertTrue(connection.isValid(300)); + } + try (Connection connection = + driver.connect("jdbc:arrow-flight-sql://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort(), properties)) { + assertTrue(connection.isValid(300)); + } + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. + */ + @Test(expected = SQLException.class) + public void testShouldThrowExceptionWhenAttemptingToConnectToMalformedUrl() throws SQLException { + final Driver driver = new ArrowFlightJdbcDriver(); + final String malformedUri = "yes:??/chainsaw.i=T333"; + + driver.connect(malformedUri, dataSource.getProperties("flight", "flight123")); + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. + * + * @throws Exception If an error occurs. + */ + @Test(expected = SQLException.class) + public void testShouldThrowExceptionWhenAttemptingToConnectToUrlNoPrefix() throws SQLException { + final Driver driver = new ArrowFlightJdbcDriver(); + final String malformedUri = "localhost:32010"; + + driver.connect(malformedUri, dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword())); + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. + */ + @Test + public void testShouldThrowExceptionWhenAttemptingToConnectToUrlNoPort() { + final Driver driver = new ArrowFlightJdbcDriver(); + SQLException e = assertThrows(SQLException.class, () -> { + Properties properties = dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword()); + Connection conn = driver.connect("jdbc:arrow-flight://localhost", properties); + conn.close(); + }); + assertTrue(e.getMessage().contains("URL must have a port")); + e = assertThrows(SQLException.class, () -> { + Properties properties = dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword()); + Connection conn = driver.connect("jdbc:arrow-flight-sql://localhost", properties); + conn.close(); + }); + assertTrue(e.getMessage().contains("URL must have a port")); + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. + */ + @Test + public void testShouldThrowExceptionWhenAttemptingToConnectToUrlNoHost() { + final Driver driver = new ArrowFlightJdbcDriver(); + SQLException e = assertThrows(SQLException.class, () -> { + Properties properties = dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword()); + Connection conn = driver.connect("jdbc:arrow-flight://32010:localhost", properties); + conn.close(); + }); + assertTrue(e.getMessage().contains("URL must have a host")); + + e = assertThrows(SQLException.class, () -> { + Properties properties = dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword()); + Connection conn = driver.connect("jdbc:arrow-flight-sql://32010:localhost", properties); + conn.close(); + }); + assertTrue(e.getMessage().contains("URL must have a host")); + } + + /** + * Tests whether {@link ArrowFlightJdbcDriver#getUrlsArgs} returns the + * correct URL parameters. + * + * @throws Exception If an error occurs. + */ + @Test + public void testDriverUrlParsingMechanismShouldReturnTheDesiredArgsFromUrl() throws Exception { + final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver(); + + final Map parsedArgs = driver.getUrlsArgs( + "jdbc:arrow-flight-sql://localhost:2222/?key1=value1&key2=value2&a=b"); + + // Check size == the amount of args provided (scheme not included) + assertEquals(5, parsedArgs.size()); + + // Check host == the provided host + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "localhost"); + + // Check port == the provided port + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222); + + // Check all other non-default arguments + assertEquals(parsedArgs.get("key1"), "value1"); + assertEquals(parsedArgs.get("key2"), "value2"); + assertEquals(parsedArgs.get("a"), "b"); + } + + @Test + public void testDriverUrlParsingMechanismShouldReturnTheDesiredArgsFromUrlWithSemicolon() throws Exception { + final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver(); + final Map parsedArgs = driver.getUrlsArgs( + "jdbc:arrow-flight-sql://localhost:2222/;key1=value1;key2=value2;a=b"); + + // Check size == the amount of args provided (scheme not included) + assertEquals(5, parsedArgs.size()); + + // Check host == the provided host + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "localhost"); + + // Check port == the provided port + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222); + + // Check all other non-default arguments + assertEquals(parsedArgs.get("key1"), "value1"); + assertEquals(parsedArgs.get("key2"), "value2"); + assertEquals(parsedArgs.get("a"), "b"); + } + + @Test + public void testDriverUrlParsingMechanismShouldReturnTheDesiredArgsFromUrlWithOneSemicolon() throws Exception { + final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver(); + final Map parsedArgs = driver.getUrlsArgs( + "jdbc:arrow-flight-sql://localhost:2222/;key1=value1"); + + // Check size == the amount of args provided (scheme not included) + assertEquals(3, parsedArgs.size()); + + // Check host == the provided host + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "localhost"); + + // Check port == the provided port + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222); + + // Check all other non-default arguments + assertEquals(parsedArgs.get("key1"), "value1"); + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. + * + */ + @Test + public void testDriverUrlParsingMechanismShouldThrowExceptionUponProvidedWithMalformedUrl() { + final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver(); + assertThrows(SQLException.class, () -> driver.getUrlsArgs( + "jdbc:malformed-url-flight://localhost:2222")); + } + + /** + * Tests whether {@code ArrowFlightJdbcDriverTest#getUrlsArgs} returns the + * correct URL parameters when the host is an IP Address. + * + * @throws Exception If an error occurs. + */ + @Test + public void testDriverUrlParsingMechanismShouldWorkWithIPAddress() throws Exception { + final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver(); + final Map parsedArgs = driver.getUrlsArgs("jdbc:arrow-flight-sql://0.0.0.0:2222"); + + // Check size == the amount of args provided (scheme not included) + assertEquals(2, parsedArgs.size()); + + // Check host == the provided host + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "0.0.0.0"); + + // Check port == the provided port + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222); + } + + /** + * Tests whether {@code ArrowFlightJdbcDriverTest#getUrlsArgs} escape especial characters and returns the + * correct URL parameters when the especial character '&' is embedded in the query parameters values. + * + * @throws Exception If an error occurs. + */ + @Test + public void testDriverUrlParsingMechanismShouldWorkWithEmbeddedEspecialCharacter() + throws Exception { + final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver(); + final Map parsedArgs = driver.getUrlsArgs( + "jdbc:arrow-flight-sql://0.0.0.0:2222?test1=test1value&test2%26continue=test2value&test3=test3value"); + + // Check size == the amount of args provided (scheme not included) + assertEquals(5, parsedArgs.size()); + + // Check host == the provided host + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "0.0.0.0"); + + // Check port == the provided port + assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222); + + // Check all other non-default arguments + assertEquals(parsedArgs.get("test1"), "test1value"); + assertEquals(parsedArgs.get("test2&continue"), "test2value"); + assertEquals(parsedArgs.get("test3"), "test3value"); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactoryTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactoryTest.java new file mode 100644 index 0000000000000..c482169852e5e --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactoryTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.lang.reflect.Constructor; +import java.sql.Connection; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.calcite.avatica.UnregisteredDriver; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +import com.google.common.collect.ImmutableMap; + +/** + * Tests for {@link ArrowFlightJdbcDriver}. + */ +public class ArrowFlightJdbcFactoryTest { + + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE; + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + + static { + UserPasswordAuthentication authentication = + new UserPasswordAuthentication.Builder().user("user1", "pass1").user("user2", "pass2") + .build(); + + FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder().host("localhost").randomPort() + .authentication(authentication).producer(PRODUCER).build(); + } + + private BufferAllocator allocator; + private ArrowFlightJdbcConnectionPoolDataSource dataSource; + + @Before + public void setUp() throws Exception { + allocator = new RootAllocator(Long.MAX_VALUE); + dataSource = FLIGHT_SERVER_TEST_RULE.createConnectionPoolDataSource(); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(dataSource, allocator); + } + + @Test + public void testShouldBeAbleToEstablishAConnectionSuccessfully() throws Exception { + UnregisteredDriver driver = new ArrowFlightJdbcDriver(); + Constructor constructor = ArrowFlightJdbcFactory.class.getConstructor(); + constructor.setAccessible(true); + ArrowFlightJdbcFactory factory = constructor.newInstance(); + + final Properties properties = new Properties(); + properties.putAll(ImmutableMap.of( + ArrowFlightConnectionProperty.HOST.camelName(), "localhost", + ArrowFlightConnectionProperty.PORT.camelName(), 32010, + ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), false)); + + try (Connection connection = factory.newConnection(driver, constructor.newInstance(), + "jdbc:arrow-flight-sql://localhost:32010", properties)) { + assert connection.isValid(300); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTimeTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTimeTest.java new file mode 100644 index 0000000000000..104794b3ad145 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTimeTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.endsWith; +import static org.hamcrest.CoreMatchers.is; + +import java.time.LocalTime; +import java.util.concurrent.TimeUnit; + +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcTimeTest { + + @ClassRule + public static final ErrorCollector collector = new ErrorCollector(); + final int hour = 5; + final int minute = 6; + final int second = 7; + + @Test + public void testPrintingMillisNoLeadingZeroes() { + // testing the regular case where the precision of the millisecond is 3 + LocalTime dateTime = LocalTime.of(hour, minute, second, (int) TimeUnit.MILLISECONDS.toNanos(999)); + ArrowFlightJdbcTime time = new ArrowFlightJdbcTime(dateTime); + collector.checkThat(time.toString(), endsWith(".999")); + collector.checkThat(time.getHours(), is(hour)); + collector.checkThat(time.getMinutes(), is(minute)); + collector.checkThat(time.getSeconds(), is(second)); + } + + @Test + public void testPrintingMillisOneLeadingZeroes() { + // test case where one leading zero needs to be added + LocalTime dateTime = LocalTime.of(hour, minute, second, (int) TimeUnit.MILLISECONDS.toNanos(99)); + ArrowFlightJdbcTime time = new ArrowFlightJdbcTime(dateTime); + collector.checkThat(time.toString(), endsWith(".099")); + collector.checkThat(time.getHours(), is(hour)); + collector.checkThat(time.getMinutes(), is(minute)); + collector.checkThat(time.getSeconds(), is(second)); + } + + @Test + public void testPrintingMillisTwoLeadingZeroes() { + // test case where two leading zeroes needs to be added + LocalTime dateTime = LocalTime.of(hour, minute, second, (int) TimeUnit.MILLISECONDS.toNanos(1)); + ArrowFlightJdbcTime time = new ArrowFlightJdbcTime(dateTime); + collector.checkThat(time.toString(), endsWith(".001")); + collector.checkThat(time.getHours(), is(hour)); + collector.checkThat(time.getMinutes(), is(minute)); + collector.checkThat(time.getSeconds(), is(second)); + } + + @Test + public void testEquality() { + // tests #equals and #hashCode for coverage checks + LocalTime dateTime = LocalTime.of(hour, minute, second, (int) TimeUnit.MILLISECONDS.toNanos(1)); + ArrowFlightJdbcTime time1 = new ArrowFlightJdbcTime(dateTime); + ArrowFlightJdbcTime time2 = new ArrowFlightJdbcTime(dateTime); + collector.checkThat(time1, is(time2)); + collector.checkThat(time1.hashCode(), is(time2.hashCode())); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java new file mode 100644 index 0000000000000..51c491be288f3 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.equalTo; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightPreparedStatementTest { + + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = FlightServerTestRule + .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer()); + + private static Connection connection; + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @BeforeClass + public static void setup() throws SQLException { + connection = FLIGHT_SERVER_TEST_RULE.getConnection(false); + } + + @AfterClass + public static void tearDown() throws SQLException { + connection.close(); + } + + @Test + public void testSimpleQueryNoParameterBinding() throws SQLException { + final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD; + try (final PreparedStatement preparedStatement = connection.prepareStatement(query); + final ResultSet resultSet = preparedStatement.executeQuery()) { + CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector); + } + } + + @Test + public void testReturnColumnCount() throws SQLException { + final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD; + try (final PreparedStatement psmt = connection.prepareStatement(query)) { + collector.checkThat("ID", equalTo(psmt.getMetaData().getColumnName(1))); + collector.checkThat("Name", equalTo(psmt.getMetaData().getColumnName(2))); + collector.checkThat("Age", equalTo(psmt.getMetaData().getColumnName(3))); + collector.checkThat("Salary", equalTo(psmt.getMetaData().getColumnName(4))); + collector.checkThat("Hire Date", equalTo(psmt.getMetaData().getColumnName(5))); + collector.checkThat("Last Sale", equalTo(psmt.getMetaData().getColumnName(6))); + collector.checkThat(6, equalTo(psmt.getMetaData().getColumnCount())); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java new file mode 100644 index 0000000000000..155fcc50827a1 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +/** + * Tests for {@link ArrowFlightStatement#execute}. + */ +public class ArrowFlightStatementExecuteTest { + private static final String SAMPLE_QUERY_CMD = "SELECT * FROM this_test"; + private static final int SAMPLE_QUERY_ROWS = Byte.MAX_VALUE; + private static final String VECTOR_NAME = "Unsigned Byte"; + private static final Schema SAMPLE_QUERY_SCHEMA = + new Schema(Collections.singletonList(Field.nullable(VECTOR_NAME, MinorType.UINT1.getType()))); + private static final String SAMPLE_UPDATE_QUERY = + "UPDATE this_table SET this_field = that_field FROM this_test WHERE this_condition"; + private static final long SAMPLE_UPDATE_COUNT = 100L; + private static final String SAMPLE_LARGE_UPDATE_QUERY = + "UPDATE this_large_table SET this_large_field = that_large_field FROM this_large_test WHERE this_large_condition"; + private static final long SAMPLE_LARGE_UPDATE_COUNT = Long.MAX_VALUE; + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + @ClassRule + public static final FlightServerTestRule SERVER_TEST_RULE = FlightServerTestRule.createStandardTestRule(PRODUCER); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + private Connection connection; + private Statement statement; + + @BeforeClass + public static void setUpBeforeClass() { + PRODUCER.addSelectQuery( + SAMPLE_QUERY_CMD, + SAMPLE_QUERY_SCHEMA, + Collections.singletonList(listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(SAMPLE_QUERY_SCHEMA, + allocator)) { + final UInt1Vector vector = (UInt1Vector) root.getVector(VECTOR_NAME); + IntStream.range(0, SAMPLE_QUERY_ROWS).forEach(index -> vector.setSafe(index, index)); + vector.setValueCount(SAMPLE_QUERY_ROWS); + root.setRowCount(SAMPLE_QUERY_ROWS); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + })); + PRODUCER.addUpdateQuery(SAMPLE_UPDATE_QUERY, SAMPLE_UPDATE_COUNT); + PRODUCER.addUpdateQuery(SAMPLE_LARGE_UPDATE_QUERY, SAMPLE_LARGE_UPDATE_COUNT); + } + + @Before + public void setUp() throws SQLException { + connection = SERVER_TEST_RULE.getConnection(false); + statement = connection.createStatement(); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(statement, connection); + } + + @AfterClass + public static void tearDownAfterClass() throws Exception { + AutoCloseables.close(PRODUCER); + } + + @Test + public void testExecuteShouldRunSelectQuery() throws SQLException { + collector.checkThat(statement.execute(SAMPLE_QUERY_CMD), + is(true)); // Means this is a SELECT query. + final Set numbers = + IntStream.range(0, SAMPLE_QUERY_ROWS).boxed() + .map(Integer::byteValue) + .collect(Collectors.toCollection(HashSet::new)); + try (final ResultSet resultSet = statement.getResultSet()) { + final int columnCount = resultSet.getMetaData().getColumnCount(); + collector.checkThat(columnCount, is(1)); + int rowCount = 0; + for (; resultSet.next(); rowCount++) { + collector.checkThat(numbers.remove(resultSet.getByte(1)), is(true)); + } + collector.checkThat(rowCount, is(equalTo(SAMPLE_QUERY_ROWS))); + } + collector.checkThat(numbers, is(Collections.emptySet())); + collector.checkThat( + (long) statement.getUpdateCount(), + is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(-1L)))); + } + + @Test + public void testExecuteShouldRunUpdateQueryForSmallUpdate() throws SQLException { + collector.checkThat(statement.execute(SAMPLE_UPDATE_QUERY), + is(false)); // Means this is an UPDATE query. + collector.checkThat( + (long) statement.getUpdateCount(), + is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(SAMPLE_UPDATE_COUNT)))); + collector.checkThat(statement.getResultSet(), is(nullValue())); + } + + @Test + public void testExecuteShouldRunUpdateQueryForLargeUpdate() throws SQLException { + collector.checkThat(statement.execute(SAMPLE_LARGE_UPDATE_QUERY), is(false)); // UPDATE query. + final long updateCountSmall = statement.getUpdateCount(); + final long updateCountLarge = statement.getLargeUpdateCount(); + collector.checkThat(updateCountLarge, is(equalTo(SAMPLE_LARGE_UPDATE_COUNT))); + collector.checkThat( + updateCountSmall, + is(allOf(equalTo((long) AvaticaUtils.toSaturatedInt(updateCountLarge)), + not(equalTo(updateCountLarge))))); + collector.checkThat(statement.getResultSet(), is(nullValue())); + } + + @Test + public void testUpdateCountShouldStartOnZero() throws SQLException { + collector.checkThat( + (long) statement.getUpdateCount(), + is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(0L)))); + collector.checkThat(statement.getResultSet(), is(nullValue())); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java new file mode 100644 index 0000000000000..43209d8913ebd --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.lang.String.format; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.Statement; +import java.util.Collections; + +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +/** + * Tests for {@link ArrowFlightStatement#executeUpdate}. + */ +public class ArrowFlightStatementExecuteUpdateTest { + private static final String UPDATE_SAMPLE_QUERY = + "UPDATE sample_table SET sample_col = sample_val WHERE sample_condition"; + private static final int UPDATE_SAMPLE_QUERY_AFFECTED_COLS = 10; + private static final String LARGE_UPDATE_SAMPLE_QUERY = + "UPDATE large_sample_table SET large_sample_col = large_sample_val WHERE large_sample_condition"; + private static final long LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS = (long) Integer.MAX_VALUE + 1; + private static final String REGULAR_QUERY_SAMPLE = "SELECT * FROM NOT_UPDATE_QUERY"; + private static final Schema REGULAR_QUERY_SCHEMA = + new Schema( + Collections.singletonList(Field.nullable("placeholder", MinorType.VARCHAR.getType()))); + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + @ClassRule + public static final FlightServerTestRule SERVER_TEST_RULE = FlightServerTestRule.createStandardTestRule(PRODUCER); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + public Connection connection; + public Statement statement; + + @BeforeClass + public static void setUpBeforeClass() { + PRODUCER.addUpdateQuery(UPDATE_SAMPLE_QUERY, UPDATE_SAMPLE_QUERY_AFFECTED_COLS); + PRODUCER.addUpdateQuery(LARGE_UPDATE_SAMPLE_QUERY, LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS); + PRODUCER.addSelectQuery( + REGULAR_QUERY_SAMPLE, + REGULAR_QUERY_SCHEMA, + Collections.singletonList(listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(REGULAR_QUERY_SCHEMA, + allocator)) { + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + })); + } + + @Before + public void setUp() throws SQLException { + connection = SERVER_TEST_RULE.getConnection(false); + statement = connection.createStatement(); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(statement, connection); + } + + @AfterClass + public static void tearDownAfterClass() throws Exception { + AutoCloseables.close(PRODUCER); + } + + @Test + public void testExecuteUpdateShouldReturnNumColsAffectedForNumRowsFittingInt() + throws SQLException { + collector.checkThat(statement.executeUpdate(UPDATE_SAMPLE_QUERY), + is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test + public void testExecuteUpdateShouldReturnSaturatedNumColsAffectedIfDoesNotFitInInt() + throws SQLException { + final long result = statement.executeUpdate(LARGE_UPDATE_SAMPLE_QUERY); + final long expectedRowCountRaw = LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS; + collector.checkThat( + result, + is(allOf( + not(equalTo(expectedRowCountRaw)), + equalTo((long) AvaticaUtils.toSaturatedInt( + expectedRowCountRaw))))); // Because of long-to-integer overflow. + } + + @Test + public void testExecuteLargeUpdateShouldReturnNumColsAffected() throws SQLException { + collector.checkThat( + statement.executeLargeUpdate(LARGE_UPDATE_SAMPLE_QUERY), + is(LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + // TODO Implement `Statement#executeUpdate(String, int)` + public void testExecuteUpdateUnsupportedWithDriverFlag() throws SQLException { + collector.checkThat( + statement.executeUpdate(UPDATE_SAMPLE_QUERY, Statement.RETURN_GENERATED_KEYS), + is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + // TODO Implement `Statement#executeUpdate(String, int[])` + public void testExecuteUpdateUnsupportedWithArrayOfInts() throws SQLException { + collector.checkThat( + statement.executeUpdate(UPDATE_SAMPLE_QUERY, new int[0]), + is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + // TODO Implement `Statement#executeUpdate(String, String[])` + public void testExecuteUpdateUnsupportedWithArraysOfStrings() throws SQLException { + collector.checkThat( + statement.executeUpdate(UPDATE_SAMPLE_QUERY, new String[0]), + is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test + public void testExecuteShouldExecuteUpdateQueryAutomatically() throws SQLException { + collector.checkThat(statement.execute(UPDATE_SAMPLE_QUERY), + is(false)); // Meaning there was an update query. + collector.checkThat(statement.execute(REGULAR_QUERY_SAMPLE), + is(true)); // Meaning there was a select query. + } + + @Test + public void testShouldFailToPrepareStatementForNullQuery() { + int count = 0; + try { + collector.checkThat(statement.execute(null), is(false)); + } catch (final SQLException e) { + count++; + collector.checkThat(e.getCause(), is(instanceOf(NullPointerException.class))); + } + collector.checkThat(count, is(1)); + } + + @Test + public void testShouldFailToPrepareStatementForClosedStatement() throws SQLException { + statement.close(); + collector.checkThat(statement.isClosed(), is(true)); + int count = 0; + try { + statement.execute(UPDATE_SAMPLE_QUERY); + } catch (final SQLException e) { + count++; + collector.checkThat(e.getMessage(), is("Statement closed")); + } + collector.checkThat(count, is(1)); + } + + @Test + public void testShouldFailToPrepareStatementForBadStatement() { + final String badQuery = "BAD INVALID STATEMENT"; + int count = 0; + try { + statement.execute(badQuery); + } catch (final SQLException e) { + count++; + /* + * The error message is up to whatever implementation of `FlightSqlProducer` + * the driver is communicating with. However, for the purpose of this test, + * we simply throw an `IllegalArgumentException` for queries not registered + * in our `MockFlightSqlProducer`. + */ + collector.checkThat( + e.getMessage(), + is(format("Error while executing SQL \"%s\": Query not found", badQuery))); + } + collector.checkThat(count, is(1)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java new file mode 100644 index 0000000000000..6fe7ba7129829 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java @@ -0,0 +1,552 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.junit.Assert.assertNotNull; + +import java.net.URISyntaxException; +import java.sql.Connection; +import java.sql.Driver; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +/** + * Tests for {@link Connection}. + */ +public class ConnectionTest { + + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE; + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + private static final String userTest = "user1"; + private static final String passTest = "pass1"; + + static { + UserPasswordAuthentication authentication = + new UserPasswordAuthentication.Builder() + .user(userTest, passTest) + .build(); + + FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder().host("localhost").randomPort() + .authentication(authentication).producer(PRODUCER).build(); + } + + private BufferAllocator allocator; + + @Before + public void setUp() throws Exception { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @After + public void tearDown() throws Exception { + allocator.getChildAllocators().forEach(BufferAllocator::close); + AutoCloseables.close(allocator); + } + + /** + * Checks if an unencrypted connection can be established successfully when + * the provided valid credentials. + * + * @throws SQLException on error. + */ + @Test + public void testUnencryptedConnectionShouldOpenSuccessfullyWhenProvidedValidCredentials() + throws Exception { + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), + FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put("useEncryption", false); + + try (Connection connection = DriverManager.getConnection( + "jdbc:arrow-flight-sql://" + FLIGHT_SERVER_TEST_RULE.getHost() + ":" + + FLIGHT_SERVER_TEST_RULE.getPort(), properties)) { + assert connection.isValid(300); + } + } + + /** + * Checks if the exception SQLException is thrown when trying to establish a connection without a host. + * + * @throws SQLException on error. + */ + @Test(expected = SQLException.class) + public void testUnencryptedConnectionWithEmptyHost() + throws Exception { + final Properties properties = new Properties(); + + properties.put("user", userTest); + properties.put("password", passTest); + final String invalidUrl = "jdbc:arrow-flight-sql://"; + + DriverManager.getConnection(invalidUrl, properties); + } + + /** + * Try to instantiate a basic FlightClient. + * + * @throws URISyntaxException on error. + */ + @Test + public void testGetBasicClientAuthenticatedShouldOpenConnection() + throws Exception { + + try (ArrowFlightSqlClientHandler client = + new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withPort(FLIGHT_SERVER_TEST_RULE.getPort()) + .withUsername(userTest) + .withPassword(passTest) + .withBufferAllocator(allocator) + .build()) { + assertNotNull(client); + } + } + + /** + * Checks if the exception IllegalArgumentException is thrown when trying to establish an unencrypted + * connection providing with an invalid port. + * + * @throws SQLException on error. + */ + @Test(expected = SQLException.class) + public void testUnencryptedConnectionProvidingInvalidPort() + throws Exception { + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), + false); + final String invalidUrl = "jdbc:arrow-flight-sql://" + FLIGHT_SERVER_TEST_RULE.getHost() + + ":" + 65537; + + DriverManager.getConnection(invalidUrl, properties); + } + + /** + * Try to instantiate a basic FlightClient. + * + * @throws URISyntaxException on error. + */ + @Test + public void testGetBasicClientNoAuthShouldOpenConnection() throws Exception { + + try (ArrowFlightSqlClientHandler client = + new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withBufferAllocator(allocator) + .build()) { + assertNotNull(client); + } + } + + /** + * Checks if an unencrypted connection can be established successfully when + * not providing credentials. + * + * @throws SQLException on error. + */ + @Test + public void testUnencryptedConnectionShouldOpenSuccessfullyWithoutAuthentication() + throws Exception { + final Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), + FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), + false); + try (Connection connection = DriverManager + .getConnection("jdbc:arrow-flight-sql://localhost:32010", properties)) { + assert connection.isValid(300); + } + } + + /** + * Check if an unencrypted connection throws an exception when provided with + * invalid credentials. + * + * @throws SQLException The exception expected to be thrown. + */ + @Test(expected = SQLException.class) + public void testUnencryptedConnectionShouldThrowExceptionWhenProvidedWithInvalidCredentials() + throws Exception { + + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + "invalidUser"); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), + FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), + false); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + "invalidPassword"); + + try (Connection ignored = DriverManager.getConnection("jdbc:arrow-flight-sql://localhost:32010", + properties)) { + Assert.fail(); + } + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using just a connection url. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseCorrectCastUrlWithDriverManager() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s&useEncryption=false", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest)); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with String K-V pairs. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.setProperty(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), "false"); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with Object K-V pairs. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), false); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using just a connection url and using 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseIntegerCorrectCastUrlWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s&useEncryption=0", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest)); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with String K-V pairs and using + * 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseIntegerCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.setProperty(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), "0"); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with Object K-V pairs and using + * 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseIntegerCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), 0); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using just a connection url. + * + * @throws Exception on error. + */ + @Test + public void testThreadPoolSizeConnectionPropertyCorrectCastUrlWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s&threadPoolSize=1&useEncryption=%s", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest, + false)); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with String K-V pairs and using + * 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testThreadPoolSizeConnectionPropertyCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.setProperty(ArrowFlightConnectionProperty.THREAD_POOL_SIZE.camelName(), "1"); + properties.put("useEncryption", false); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with Object K-V pairs and using + * 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testThreadPoolSizeConnectionPropertyCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put(ArrowFlightConnectionProperty.THREAD_POOL_SIZE.camelName(), 1); + properties.put("useEncryption", false); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using just a connection url. + * + * @throws Exception on error. + */ + @Test + public void testPasswordConnectionPropertyIntegerCorrectCastUrlWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s&useEncryption=%s", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest, + false)); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with String K-V pairs and using + * 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testPasswordConnectionPropertyIntegerCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put("useEncryption", false); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with Object K-V pairs and using + * 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testPasswordConnectionPropertyIntegerCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put("useEncryption", false); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java new file mode 100644 index 0000000000000..a5f9938f04bcb --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java @@ -0,0 +1,454 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.junit.Assert.assertNotNull; + +import java.net.URLEncoder; +import java.nio.file.Paths; +import java.sql.Connection; +import java.sql.Driver; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.calcite.avatica.org.apache.http.auth.UsernamePasswordCredentials; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +/** + * Tests encrypted connections. + */ +public class ConnectionTlsTest { + + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE; + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + private static final String userTest = "user1"; + private static final String passTest = "pass1"; + + static { + final FlightSqlTestCertificates.CertKeyPair + certKey = FlightSqlTestCertificates.exampleTlsCerts().get(0); + + UserPasswordAuthentication authentication = new UserPasswordAuthentication.Builder() + .user(userTest, passTest) + .build(); + + FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder() + .host("localhost") + .randomPort() + .authentication(authentication) + .useEncryption(certKey.cert, certKey.key) + .producer(PRODUCER) + .build(); + } + + private String trustStorePath; + private String noCertificateKeyStorePath; + private final String trustStorePass = "flight"; + private BufferAllocator allocator; + + @Before + public void setUp() throws Exception { + trustStorePath = Paths.get( + Preconditions.checkNotNull(getClass().getResource("/keys/keyStore.jks")).toURI()).toString(); + noCertificateKeyStorePath = Paths.get( + Preconditions.checkNotNull(getClass().getResource("/keys/noCertificate.jks")).toURI()).toString(); + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @After + public void tearDown() throws Exception { + allocator.getChildAllocators().forEach(BufferAllocator::close); + AutoCloseables.close(allocator); + } + + /** + * Try to instantiate an encrypted FlightClient. + * + * @throws Exception on error. + */ + @Test + public void testGetEncryptedClientAuthenticatedWithDisableCertVerification() throws Exception { + final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials( + userTest, passTest); + + try (ArrowFlightSqlClientHandler client = + new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withPort(FLIGHT_SERVER_TEST_RULE.getPort()) + .withUsername(credentials.getUserName()) + .withPassword(credentials.getPassword()) + .withDisableCertificateVerification(true) + .withBufferAllocator(allocator) + .withEncryption(true) + .build()) { + assertNotNull(client); + } + } + + /** + * Try to instantiate an encrypted FlightClient. + * + * @throws Exception on error. + */ + @Test + public void testGetEncryptedClientAuthenticated() throws Exception { + final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials( + userTest, passTest); + + try (ArrowFlightSqlClientHandler client = + new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withPort(FLIGHT_SERVER_TEST_RULE.getPort()) + .withUsername(credentials.getUserName()) + .withPassword(credentials.getPassword()) + .withTrustStorePath(trustStorePath) + .withTrustStorePassword(trustStorePass) + .withBufferAllocator(allocator) + .withEncryption(true) + .build()) { + assertNotNull(client); + } + } + + /** + * Try to instantiate an encrypted FlightClient providing a keystore without certificate. It's expected to + * receive the SQLException. + * + * @throws Exception on error. + */ + @Test(expected = SQLException.class) + public void testGetEncryptedClientWithNoCertificateOnKeyStore() throws Exception { + final String noCertificateKeyStorePassword = "flight1"; + + try (ArrowFlightSqlClientHandler ignored = + new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withTrustStorePath(noCertificateKeyStorePath) + .withTrustStorePassword(noCertificateKeyStorePassword) + .withBufferAllocator(allocator) + .withEncryption(true) + .build()) { + Assert.fail(); + } + } + + /** + * Try to instantiate an encrypted FlightClient without credentials. + * + * @throws Exception on error. + */ + @Test + public void testGetNonAuthenticatedEncryptedClientNoAuth() throws Exception { + try (ArrowFlightSqlClientHandler client = + new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withTrustStorePath(trustStorePath) + .withTrustStorePassword(trustStorePass) + .withBufferAllocator(allocator) + .withEncryption(true) + .build()) { + assertNotNull(client); + } + } + + /** + * Try to instantiate an encrypted FlightClient with an invalid password to the keystore file. + * It's expected to receive the SQLException. + * + * @throws Exception on error. + */ + @Test(expected = SQLException.class) + public void testGetEncryptedClientWithKeyStoreBadPasswordAndNoAuth() throws Exception { + String keyStoreBadPassword = "badPassword"; + + try (ArrowFlightSqlClientHandler ignored = + new ArrowFlightSqlClientHandler.Builder() + .withHost(FLIGHT_SERVER_TEST_RULE.getHost()) + .withTrustStorePath(trustStorePath) + .withTrustStorePassword(keyStoreBadPassword) + .withBufferAllocator(allocator) + .withEncryption(true) + .build()) { + Assert.fail(); + } + } + + /** + * Check if an encrypted connection can be established successfully when the + * provided valid credentials and a valid Keystore. + * + * @throws Exception on error. + */ + @Test + public void testGetEncryptedConnectionWithValidCredentialsAndKeyStore() throws Exception { + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), + FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), false); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + + final ArrowFlightJdbcDataSource dataSource = + ArrowFlightJdbcDataSource.createNewDataSource(properties); + try (final Connection connection = dataSource.getConnection()) { + assert connection.isValid(300); + } + } + + /** + * Check if the SQLException is thrown when trying to establish an encrypted connection + * providing valid credentials but invalid password to the Keystore. + * + * @throws SQLException on error. + */ + @Test(expected = SQLException.class) + public void testGetAuthenticatedEncryptedConnectionWithKeyStoreBadPassword() throws Exception { + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), + FLIGHT_SERVER_TEST_RULE.getHost()); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), + FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), true); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), "badpassword"); + + final ArrowFlightJdbcDataSource dataSource = + ArrowFlightJdbcDataSource.createNewDataSource(properties); + try (final Connection ignored = dataSource.getConnection()) { + Assert.fail(); + } + } + + /** + * Check if an encrypted connection can be established successfully when not providing authentication. + * + * @throws Exception on error. + */ + @Test + public void testGetNonAuthenticatedEncryptedConnection() throws Exception { + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), FLIGHT_SERVER_TEST_RULE.getHost()); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), true); + properties.put(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), false); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + + final ArrowFlightJdbcDataSource dataSource = ArrowFlightJdbcDataSource.createNewDataSource(properties); + try (final Connection connection = dataSource.getConnection()) { + assert connection.isValid(300); + } + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * just a connection url. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueCorrectCastUrlWithDriverManager() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + final Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s" + + "&useEncryption=true&useSystemTrustStore=false&%s=%s&%s=%s", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest, + ArrowFlightConnectionProperty.TRUST_STORE.camelName(), + URLEncoder.encode(trustStorePath, "UTF-8"), + ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), + URLEncoder.encode(trustStorePass, "UTF-8"))); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * a connection url and properties with String K-V pairs. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.setProperty(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.setProperty(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + properties.setProperty(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), "true"); + properties.setProperty(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), "false"); + + final Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * a connection url and properties with Object K-V pairs. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), true); + properties.put(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), false); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + + final Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * just a connection url and using 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueIntegerCorrectCastUrlWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + final Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s" + + "&useEncryption=1&useSystemTrustStore=0&%s=%s&%s=%s", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest, + ArrowFlightConnectionProperty.TRUST_STORE.camelName(), + URLEncoder.encode(trustStorePath, "UTF-8"), + ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), + URLEncoder.encode(trustStorePass, "UTF-8"))); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * a connection url and properties with String K-V pairs and using 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueIntegerCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.setProperty(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.setProperty(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + properties.setProperty(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), "1"); + properties.setProperty(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), "0"); + + final Connection connection = DriverManager.getConnection( + String.format("jdbc:arrow-flight-sql://localhost:%s", FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * a connection url and properties with Object K-V pairs and using 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueIntegerCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), 1); + properties.put(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), 0); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + + final Connection connection = DriverManager.getConnection( + String.format("jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java new file mode 100644 index 0000000000000..b251b7df1645b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates.CertKeyPair; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Method; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.authentication.Authentication; +import org.apache.arrow.driver.jdbc.authentication.TokenAuthentication; +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.RequestContext; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import me.alexpanov.net.FreePortFinder; + +/** + * Utility class for unit tests that need to instantiate a {@link FlightServer} + * and interact with it. + */ +public class FlightServerTestRule implements TestRule, AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(FlightServerTestRule.class); + + private final Properties properties; + private final ArrowFlightConnectionConfigImpl config; + private final BufferAllocator allocator; + private final FlightSqlProducer producer; + private final Authentication authentication; + private final CertKeyPair certKeyPair; + + private final MiddlewareCookie.Factory middlewareCookieFactory = new MiddlewareCookie.Factory(); + + private FlightServerTestRule(final Properties properties, + final ArrowFlightConnectionConfigImpl config, + final BufferAllocator allocator, + final FlightSqlProducer producer, + final Authentication authentication, + final CertKeyPair certKeyPair) { + this.properties = Preconditions.checkNotNull(properties); + this.config = Preconditions.checkNotNull(config); + this.allocator = Preconditions.checkNotNull(allocator); + this.producer = Preconditions.checkNotNull(producer); + this.authentication = authentication; + this.certKeyPair = certKeyPair; + } + + /** + * Create a {@link FlightServerTestRule} with standard values such as: user, password, localhost. + * + * @param producer the producer used to create the FlightServerTestRule. + * @return the FlightServerTestRule. + */ + public static FlightServerTestRule createStandardTestRule(final FlightSqlProducer producer) { + UserPasswordAuthentication authentication = + new UserPasswordAuthentication.Builder() + .user("flight-test-user", "flight-test-password") + .build(); + + return new Builder() + .host("localhost") + .randomPort() + .authentication(authentication) + .producer(producer) + .build(); + } + + ArrowFlightJdbcDataSource createDataSource() { + return ArrowFlightJdbcDataSource.createNewDataSource(properties); + } + + ArrowFlightJdbcDataSource createDataSource(String token) { + properties.put("token", token); + return ArrowFlightJdbcDataSource.createNewDataSource(properties); + } + + public ArrowFlightJdbcConnectionPoolDataSource createConnectionPoolDataSource() { + return ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(properties); + } + + public ArrowFlightJdbcConnectionPoolDataSource createConnectionPoolDataSource(boolean useEncryption) { + setUseEncryption(useEncryption); + return ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(properties); + } + + public Connection getConnection(boolean useEncryption, String token) throws SQLException { + properties.put("token", token); + + return getConnection(useEncryption); + } + + public Connection getConnection(boolean useEncryption) throws SQLException { + setUseEncryption(useEncryption); + return this.createDataSource().getConnection(); + } + + private void setUseEncryption(boolean useEncryption) { + properties.put("useEncryption", useEncryption); + } + + public MiddlewareCookie.Factory getMiddlewareCookieFactory() { + return middlewareCookieFactory; + } + + @FunctionalInterface + public interface CheckedFunction { + R apply(T t) throws IOException; + } + + private FlightServer initiateServer(Location location) throws IOException { + FlightServer.Builder builder = FlightServer.builder(allocator, location, producer) + .headerAuthenticator(authentication.authenticate()) + .middleware(FlightServerMiddleware.Key.of("KEY"), middlewareCookieFactory); + if (certKeyPair != null) { + builder.useTls(certKeyPair.cert, certKeyPair.key); + } + return builder.build(); + } + + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + try (FlightServer flightServer = + getStartServer(location -> + initiateServer(location), 3)) { + LOGGER.info("Started " + FlightServer.class.getName() + " as " + flightServer); + base.evaluate(); + } finally { + close(); + } + } + }; + } + + private FlightServer getStartServer(CheckedFunction newServerFromLocation, + int retries) + throws IOException { + + final Deque exceptions = new ArrayDeque<>(); + + for (; retries > 0; retries--) { + final Location location = Location.forGrpcInsecure(config.getHost(), config.getPort()); + final FlightServer server = newServerFromLocation.apply(location); + try { + Method start = server.getClass().getMethod("start"); + start.setAccessible(true); + start.invoke(server); + return server; + } catch (ReflectiveOperationException e) { + exceptions.add(e); + } + } + + exceptions.forEach( + e -> LOGGER.error("Failed to start a new " + FlightServer.class.getName() + ".", e)); + throw new IOException(exceptions.pop().getCause()); + } + + /** + * Sets a port to be used. + * + * @return the port value. + */ + public int getPort() { + return config.getPort(); + } + + /** + * Sets a host to be used. + * + * @return the host value. + */ + public String getHost() { + return config.getHost(); + } + + @Override + public void close() throws Exception { + allocator.getChildAllocators().forEach(BufferAllocator::close); + AutoCloseables.close(allocator); + } + + /** + * Builder for {@link FlightServerTestRule}. + */ + public static final class Builder { + private final Properties properties = new Properties(); + private FlightSqlProducer producer; + private Authentication authentication; + private CertKeyPair certKeyPair; + + /** + * Sets the host for the server rule. + * + * @param host the host value. + * @return the Builder. + */ + public Builder host(final String host) { + properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST.camelName(), + host); + return this; + } + + /** + * Sets a random port to be used by the server rule. + * + * @return the Builder. + */ + public Builder randomPort() { + properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT.camelName(), + FreePortFinder.findFreeLocalPort()); + return this; + } + + /** + * Sets a specific port to be used by the server rule. + * + * @param port the port value. + * @return the Builder. + */ + public Builder port(final int port) { + properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT.camelName(), + port); + return this; + } + + /** + * Sets the producer that will be used in the server rule. + * + * @param producer the flight sql producer. + * @return the Builder. + */ + public Builder producer(final FlightSqlProducer producer) { + this.producer = producer; + return this; + } + + /** + * Sets the type of the authentication that will be used in the server rules. + * There are two types of authentication: {@link UserPasswordAuthentication} and + * {@link TokenAuthentication}. + * + * @param authentication the type of authentication. + * @return the Builder. + */ + public Builder authentication(final Authentication authentication) { + this.authentication = authentication; + return this; + } + + /** + * Enable TLS on the server. + * + * @param certChain The certificate chain to use. + * @param key The private key to use. + * @return the Builder. + */ + public Builder useEncryption(final File certChain, final File key) { + certKeyPair = new CertKeyPair(certChain, key); + return this; + } + + /** + * Builds the {@link FlightServerTestRule} using the provided values. + * + * @return a {@link FlightServerTestRule}. + */ + public FlightServerTestRule build() { + authentication.populateProperties(properties); + return new FlightServerTestRule(properties, new ArrowFlightConnectionConfigImpl(properties), + new RootAllocator(Long.MAX_VALUE), producer, authentication, certKeyPair); + } + } + + /** + * A middleware to handle with the cookies in the server. It is used to test if cookies are + * being sent properly. + */ + static class MiddlewareCookie implements FlightServerMiddleware { + + private final Factory factory; + + public MiddlewareCookie(Factory factory) { + this.factory = factory; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders callHeaders) { + if (!factory.receivedCookieHeader) { + callHeaders.insert("Set-Cookie", "k=v"); + } + } + + @Override + public void onCallCompleted(CallStatus callStatus) { + + } + + @Override + public void onCallErrored(Throwable throwable) { + + } + + /** + * A factory for the MiddlewareCookie. + */ + static class Factory implements FlightServerMiddleware.Factory { + + private boolean receivedCookieHeader = false; + private String cookie; + + @Override + public MiddlewareCookie onCallStarted(CallInfo callInfo, CallHeaders callHeaders, + RequestContext requestContext) { + cookie = callHeaders.get("Cookie"); + receivedCookieHeader = null != cookie; + return new MiddlewareCookie(this); + } + + public String getCookie() { + return cookie; + } + } + } + +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetMetadataTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetMetadataTest.java new file mode 100644 index 0000000000000..64ec7f7d9e1a5 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetMetadataTest.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Types; + +import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.hamcrest.CoreMatchers; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ResultSetMetadataTest { + private static ResultSetMetaData metadata; + + private static Connection connection; + + @Rule + public ErrorCollector collector = new ErrorCollector(); + + @ClassRule + public static final FlightServerTestRule SERVER_TEST_RULE = FlightServerTestRule + .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer()); + + @BeforeClass + public static void setup() throws SQLException { + connection = SERVER_TEST_RULE.getConnection(false); + + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_METADATA_SQL_CMD)) { + metadata = resultSet.getMetaData(); + } + } + + @AfterClass + public static void teardown() throws SQLException { + connection.close(); + } + + /** + * Test if {@link ResultSetMetaData} object is not null. + */ + @Test + public void testShouldGetResultSetMetadata() { + collector.checkThat(metadata, CoreMatchers.is(notNullValue())); + } + + /** + * Test if {@link ResultSetMetaData#getColumnCount()} returns the correct values. + * + * @throws SQLException in case of error. + */ + @Test + public void testShouldGetColumnCount() throws SQLException { + final int columnCount = metadata.getColumnCount(); + + assert columnCount == 3; + } + + /** + * Test if {@link ResultSetMetaData#getColumnTypeName(int)} returns the correct type name for each + * column. + * + * @throws SQLException in case of error. + */ + @Test + public void testShouldGetColumnTypesName() throws SQLException { + final String firstColumn = metadata.getColumnTypeName(1); + final String secondColumn = metadata.getColumnTypeName(2); + final String thirdColumn = metadata.getColumnTypeName(3); + + collector.checkThat(firstColumn, equalTo("BIGINT")); + collector.checkThat(secondColumn, equalTo("VARCHAR")); + collector.checkThat(thirdColumn, equalTo("FLOAT")); + } + + /** + * Test if {@link ResultSetMetaData#getColumnTypeName(int)} passing an column index that does not exist. + * + * @throws SQLException in case of error. + */ + @Test(expected = IndexOutOfBoundsException.class) + public void testShouldGetColumnTypesNameFromOutOfBoundIndex() throws SQLException { + metadata.getColumnTypeName(4); + + Assert.fail(); + } + + /** + * Test if {@link ResultSetMetaData#getColumnName(int)} returns the correct name for each column. + * + * @throws SQLException in case of error. + */ + @Test + public void testShouldGetColumnNames() throws SQLException { + final String firstColumn = metadata.getColumnName(1); + final String secondColumn = metadata.getColumnName(2); + final String thirdColumn = metadata.getColumnName(3); + + collector.checkThat(firstColumn, equalTo("integer0")); + collector.checkThat(secondColumn, equalTo("string1")); + collector.checkThat(thirdColumn, equalTo("float2")); + } + + + /** + * Test {@link ResultSetMetaData#getColumnTypeName(int)} passing an column index that does not exist. + * + * @throws SQLException in case of error. + */ + @Test(expected = IndexOutOfBoundsException.class) + public void testShouldGetColumnNameFromOutOfBoundIndex() throws SQLException { + metadata.getColumnName(4); + + Assert.fail(); + } + + /** + * Test if {@link ResultSetMetaData#getColumnType(int)}returns the correct values. + * + * @throws SQLException in case of error. + */ + @Test + public void testShouldGetColumnType() throws SQLException { + final int firstColumn = metadata.getColumnType(1); + final int secondColumn = metadata.getColumnType(2); + final int thirdColumn = metadata.getColumnType(3); + + collector.checkThat(firstColumn, equalTo(Types.BIGINT)); + collector.checkThat(secondColumn, equalTo(Types.VARCHAR)); + collector.checkThat(thirdColumn, equalTo(Types.FLOAT)); + } + + @Test + public void testShouldGetPrecision() throws SQLException { + collector.checkThat(metadata.getPrecision(1), equalTo(10)); + collector.checkThat(metadata.getPrecision(2), equalTo(65535)); + collector.checkThat(metadata.getPrecision(3), equalTo(15)); + } + + @Test + public void testShouldGetScale() throws SQLException { + collector.checkThat(metadata.getScale(1), equalTo(0)); + collector.checkThat(metadata.getScale(2), equalTo(0)); + collector.checkThat(metadata.getScale(3), equalTo(20)); + } + + @Test + public void testShouldGetCatalogName() throws SQLException { + collector.checkThat(metadata.getCatalogName(1), equalTo("CATALOG_NAME_1")); + collector.checkThat(metadata.getCatalogName(2), equalTo("CATALOG_NAME_2")); + collector.checkThat(metadata.getCatalogName(3), equalTo("CATALOG_NAME_3")); + } + + @Test + public void testShouldGetSchemaName() throws SQLException { + collector.checkThat(metadata.getSchemaName(1), equalTo("SCHEMA_NAME_1")); + collector.checkThat(metadata.getSchemaName(2), equalTo("SCHEMA_NAME_2")); + collector.checkThat(metadata.getSchemaName(3), equalTo("SCHEMA_NAME_3")); + } + + @Test + public void testShouldGetTableName() throws SQLException { + collector.checkThat(metadata.getTableName(1), equalTo("TABLE_NAME_1")); + collector.checkThat(metadata.getTableName(2), equalTo("TABLE_NAME_2")); + collector.checkThat(metadata.getTableName(3), equalTo("TABLE_NAME_3")); + } + + @Test + public void testShouldIsAutoIncrement() throws SQLException { + collector.checkThat(metadata.isAutoIncrement(1), equalTo(true)); + collector.checkThat(metadata.isAutoIncrement(2), equalTo(false)); + collector.checkThat(metadata.isAutoIncrement(3), equalTo(false)); + } + + @Test + public void testShouldIsCaseSensitive() throws SQLException { + collector.checkThat(metadata.isCaseSensitive(1), equalTo(false)); + collector.checkThat(metadata.isCaseSensitive(2), equalTo(true)); + collector.checkThat(metadata.isCaseSensitive(3), equalTo(false)); + } + + @Test + public void testShouldIsReadonly() throws SQLException { + collector.checkThat(metadata.isReadOnly(1), equalTo(true)); + collector.checkThat(metadata.isReadOnly(2), equalTo(false)); + collector.checkThat(metadata.isReadOnly(3), equalTo(false)); + } + + @Test + public void testShouldIsSearchable() throws SQLException { + collector.checkThat(metadata.isSearchable(1), equalTo(true)); + collector.checkThat(metadata.isSearchable(2), equalTo(true)); + collector.checkThat(metadata.isSearchable(3), equalTo(true)); + } + + /** + * Test if {@link ResultSetMetaData#getColumnTypeName(int)} passing an column index that does not exist. + * + * @throws SQLException in case of error. + */ + @Test(expected = IndexOutOfBoundsException.class) + public void testShouldGetColumnTypesFromOutOfBoundIndex() throws SQLException { + metadata.getColumnType(4); + + Assert.fail(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java new file mode 100644 index 0000000000000..33473b6fe2baa --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.lang.String.format; +import static java.util.Collections.synchronizedSet; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.anyOf; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLTimeoutException; +import java.sql.Statement; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CountDownLatch; + +import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +import com.google.common.collect.ImmutableSet; + +public class ResultSetTest { + private static final Random RANDOM = new Random(10); + @ClassRule + public static final FlightServerTestRule SERVER_TEST_RULE = FlightServerTestRule + .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer()); + private static Connection connection; + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @BeforeClass + public static void setup() throws SQLException { + connection = SERVER_TEST_RULE.getConnection(false); + } + + @AfterClass + public static void tearDown() throws SQLException { + connection.close(); + } + + private static void resultSetNextUntilDone(ResultSet resultSet) throws SQLException { + while (resultSet.next()) { + // TODO: implement resultSet.last() + // Pass to the next until resultSet is done + } + } + + private static void setMaxRowsLimit(int maxRowsLimit, Statement statement) throws SQLException { + statement.setLargeMaxRows(maxRowsLimit); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} can run a query successfully. + * + * @throws Exception If the connection fails to be established. + */ + @Test + public void testShouldRunSelectQuery() throws Exception { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { + CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector); + } + } + + @Test + public void testShouldExecuteQueryNotBlockIfClosedBeforeEnd() throws Exception { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { + + for (int i = 0; i < 7500; i++) { + assertTrue(resultSet.next()); + } + } + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} query only returns only the + * amount of value set by {@link org.apache.calcite.avatica.AvaticaStatement#setMaxRows(int)}. + * + * @throws Exception If the connection fails to be established. + */ + @Test + public void testShouldRunSelectQuerySettingMaxRowLimit() throws Exception { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { + + final int maxRowsLimit = 3; + statement.setMaxRows(maxRowsLimit); + + collector.checkThat(statement.getMaxRows(), is(maxRowsLimit)); + + int count = 0; + int columns = 6; + for (; resultSet.next(); count++) { + for (int column = 1; column <= columns; column++) { + resultSet.getObject(column); + } + collector.checkThat("Test Name #" + count, is(resultSet.getString(2))); + } + + collector.checkThat(maxRowsLimit, is(count)); + } + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} fails upon attempting + * to run an invalid query. + * + * @throws Exception If the connection fails to be established. + */ + @Test(expected = SQLException.class) + public void testShouldThrowExceptionUponAttemptingToExecuteAnInvalidSelectQuery() + throws Exception { + Statement statement = connection.createStatement(); + statement.executeQuery("SELECT * FROM SHOULD-FAIL"); + fail(); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} query only returns only the + * amount of value set by {@link org.apache.calcite.avatica.AvaticaStatement#setLargeMaxRows(long)} (int)}. + * + * @throws Exception If the connection fails to be established. + */ + @Test + public void testShouldRunSelectQuerySettingLargeMaxRowLimit() throws Exception { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { + + final long maxRowsLimit = 3; + statement.setLargeMaxRows(maxRowsLimit); + + collector.checkThat(statement.getLargeMaxRows(), is(maxRowsLimit)); + + int count = 0; + int columns = resultSet.getMetaData().getColumnCount(); + for (; resultSet.next(); count++) { + for (int column = 1; column <= columns; column++) { + resultSet.getObject(column); + } + assertEquals("Test Name #" + count, resultSet.getString(2)); + } + + assertEquals(maxRowsLimit, count); + } + } + + @Test + public void testColumnCountShouldRemainConsistentForResultSetThroughoutEntireDuration() + throws SQLException { + final Set counts = new HashSet<>(); + try (final Statement statement = connection.createStatement(); + final ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { + while (resultSet.next()) { + counts.add(resultSet.getMetaData().getColumnCount()); + } + } + collector.checkThat(counts, is(ImmutableSet.of(6))); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} close the statement after complete ResultSet + * when call {@link org.apache.calcite.avatica.AvaticaStatement#closeOnCompletion()}. + * + * @throws Exception If the connection fails to be established. + */ + @Test + public void testShouldCloseStatementWhenIsCloseOnCompletion() throws Exception { + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD); + + statement.closeOnCompletion(); + + resultSetNextUntilDone(resultSet); + + collector.checkThat(statement.isClosed(), is(true)); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} close the statement after complete ResultSet with max rows limit + * when call {@link org.apache.calcite.avatica.AvaticaStatement#closeOnCompletion()}. + * + * @throws Exception If the connection fails to be established. + */ + @Test + public void testShouldCloseStatementWhenIsCloseOnCompletionWithMaxRowsLimit() throws Exception { + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD); + + final long maxRowsLimit = 3; + statement.setLargeMaxRows(maxRowsLimit); + statement.closeOnCompletion(); + + resultSetNextUntilDone(resultSet); + + collector.checkThat(statement.isClosed(), is(true)); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} not close the statement after complete ResultSet with max rows + * limit when call {@link org.apache.calcite.avatica.AvaticaStatement#closeOnCompletion()}. + * + * @throws Exception If the connection fails to be established. + */ + @Test + public void testShouldNotCloseStatementWhenIsNotCloseOnCompletionWithMaxRowsLimit() + throws Exception { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { + + final long maxRowsLimit = 3; + statement.setLargeMaxRows(maxRowsLimit); + + collector.checkThat(statement.isClosed(), is(false)); + resultSetNextUntilDone(resultSet); + collector.checkThat(resultSet.isClosed(), is(false)); + collector.checkThat(resultSet, is(instanceOf(ArrowFlightJdbcFlightStreamResultSet.class))); + } + } + + @Test + public void testShouldCancelQueryUponCancelAfterQueryingResultSet() throws SQLException { + try (final Statement statement = connection.createStatement(); + final ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { + final int column = RANDOM.nextInt(resultSet.getMetaData().getColumnCount()) + 1; + collector.checkThat(resultSet.isClosed(), is(false)); + collector.checkThat(resultSet.next(), is(true)); + collector.checkSucceeds(() -> resultSet.getObject(column)); + statement.cancel(); + // Should reset `ResultSet`; keep both `ResultSet` and `Connection` open. + collector.checkThat(statement.isClosed(), is(false)); + collector.checkThat(resultSet.isClosed(), is(false)); + collector.checkThat(resultSet.getMetaData().getColumnCount(), is(0)); + } + } + + @Test + public void testShouldInterruptFlightStreamsIfQueryIsCancelledMidQuerying() + throws SQLException, InterruptedException { + try (final Statement statement = connection.createStatement()) { + final CountDownLatch latch = new CountDownLatch(1); + final Set exceptions = synchronizedSet(new HashSet<>(1)); + final Thread thread = new Thread(() -> { + try (final ResultSet resultSet = statement.executeQuery( + CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { + final int cachedColumnCount = resultSet.getMetaData().getColumnCount(); + Thread.sleep(300); + while (resultSet.next()) { + resultSet.getObject(RANDOM.nextInt(cachedColumnCount) + 1); + } + } catch (final SQLException | InterruptedException e) { + exceptions.add(e); + } finally { + latch.countDown(); + } + }); + thread.setName("Test Case: interrupt query execution before first retrieval"); + thread.start(); + statement.cancel(); + thread.join(); + collector.checkThat( + exceptions.stream() + .map(Exception::getMessage) + .map(StringBuilder::new) + .reduce(StringBuilder::append) + .orElseThrow(IllegalArgumentException::new) + .toString(), + is("Statement canceled")); + } + } + + @Test + public void testShouldInterruptFlightStreamsIfQueryIsCancelledMidProcessingForTimeConsumingQueries() + throws SQLException, InterruptedException { + final String query = CoreMockedSqlProducers.LEGACY_CANCELLATION_SQL_CMD; + try (final Statement statement = connection.createStatement()) { + final Set exceptions = synchronizedSet(new HashSet<>(1)); + final Thread thread = new Thread(() -> { + try (final ResultSet ignored = statement.executeQuery(query)) { + fail(); + } catch (final SQLException e) { + exceptions.add(e); + } + }); + thread.setName("Test Case: interrupt query execution mid-process"); + thread.setPriority(Thread.MAX_PRIORITY); + thread.start(); + Thread.sleep(5000); // Let the other thread attempt to retrieve results. + statement.cancel(); + thread.join(); + collector.checkThat( + exceptions.stream() + .map(Exception::getMessage) + .map(StringBuilder::new) + .reduce(StringBuilder::append) + .orElseThrow(IllegalStateException::new) + .toString(), + anyOf(is(format("Error while executing SQL \"%s\": Query canceled", query)), + allOf(containsString(format("Error while executing SQL \"%s\"", query)), + containsString("CANCELLED")))); + } + } + + @Test + public void testShouldInterruptFlightStreamsIfQueryTimeoutIsOver() throws SQLException { + final String query = CoreMockedSqlProducers.LEGACY_CANCELLATION_SQL_CMD; + final int timeoutValue = 2; + final String timeoutUnit = "SECONDS"; + try (final Statement statement = connection.createStatement()) { + statement.setQueryTimeout(timeoutValue); + final Set exceptions = new HashSet<>(1); + try { + statement.executeQuery(query); + } catch (final Exception e) { + exceptions.add(e); + } + final Throwable comparisonCause = exceptions.stream() + .findFirst() + .orElseThrow(RuntimeException::new) + .getCause() + .getCause(); + collector.checkThat(comparisonCause, + is(instanceOf(SQLTimeoutException.class))); + collector.checkThat(comparisonCause.getMessage(), + is(format("Query timed out after %d %s", timeoutValue, timeoutUnit))); + } + } + + @Test + public void testFlightStreamsQueryShouldNotTimeout() throws SQLException { + final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD; + final int timeoutValue = 5; + try (Statement statement = connection.createStatement()) { + statement.setQueryTimeout(timeoutValue); + ResultSet resultSet = statement.executeQuery(query); + CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector); + resultSet.close(); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/TokenAuthenticationTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/TokenAuthenticationTest.java new file mode 100644 index 0000000000000..56c8c178f2133 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/TokenAuthenticationTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Connection; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.authentication.TokenAuthentication; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.util.AutoCloseables; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.ClassRule; +import org.junit.Test; + +public class TokenAuthenticationTest { + private static final MockFlightSqlProducer FLIGHT_SQL_PRODUCER = new MockFlightSqlProducer(); + + @ClassRule + public static FlightServerTestRule FLIGHT_SERVER_TEST_RULE; + + static { + FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder() + .host("localhost") + .randomPort() + .authentication(new TokenAuthentication.Builder() + .token("1234") + .build()) + .producer(FLIGHT_SQL_PRODUCER) + .build(); + } + + @AfterClass + public static void tearDownAfterClass() { + AutoCloseables.closeNoChecked(FLIGHT_SQL_PRODUCER); + } + + @Test(expected = SQLException.class) + public void connectUsingTokenAuthenticationShouldFail() throws SQLException { + try (Connection ignored = FLIGHT_SERVER_TEST_RULE.getConnection(false, "invalid")) { + Assert.fail(); + } + } + + @Test + public void connectUsingTokenAuthenticationShouldSuccess() throws SQLException { + try (Connection connection = FLIGHT_SERVER_TEST_RULE.getConnection(false, "1234")) { + Assert.assertFalse(connection.isClosed()); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java new file mode 100644 index 0000000000000..4b3744372c0e8 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor; + +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.binary.ArrowFlightJdbcBinaryVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDurationVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcIntervalVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcDenseUnionVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcFixedSizeListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcLargeListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcMapVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcStructVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcUnionVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcBaseIntVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcBitVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcDecimalVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcFloat4VectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcFloat8VectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.UnionVector; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.Assert; +import org.junit.ClassRule; +import org.junit.Test; + +public class ArrowFlightJdbcAccessorFactoryTest { + public static final IntSupplier GET_CURRENT_ROW = () -> 0; + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Test + public void createAccessorForUInt1Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createUInt1Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForUInt2Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createUInt2Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForUInt4Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createUInt4Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForUInt8Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createUInt8Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForTinyIntVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTinyIntVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForSmallIntVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createSmallIntVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForIntVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createIntVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForBigIntVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createBigIntVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForFloat4Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createFloat4Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcFloat4VectorAccessor); + } + } + + @Test + public void createAccessorForFloat8Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createFloat8Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcFloat8VectorAccessor); + } + } + + @Test + public void createAccessorForBitVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createBitVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBitVectorAccessor); + } + } + + @Test + public void createAccessorForDecimalVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createDecimalVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDecimalVectorAccessor); + } + } + + @Test + public void createAccessorForDecimal256Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createDecimal256Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDecimalVectorAccessor); + } + } + + @Test + public void createAccessorForVarBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createVarBinaryVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBinaryVectorAccessor); + } + } + + @Test + public void createAccessorForLargeVarBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createLargeVarBinaryVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBinaryVectorAccessor); + } + } + + @Test + public void createAccessorForFixedSizeBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createFixedSizeBinaryVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBinaryVectorAccessor); + } + } + + @Test + public void createAccessorForTimeStampVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeStampMilliVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeStampVectorAccessor); + } + } + + @Test + public void createAccessorForTimeNanoVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeNanoVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeVectorAccessor); + } + } + + @Test + public void createAccessorForTimeMicroVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeMicroVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeVectorAccessor); + } + } + + @Test + public void createAccessorForTimeMilliVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeMilliVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeVectorAccessor); + } + } + + @Test + public void createAccessorForTimeSecVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeSecVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeVectorAccessor); + } + } + + @Test + public void createAccessorForDateDayVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createDateDayVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDateVectorAccessor); + } + } + + @Test + public void createAccessorForDateMilliVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createDateMilliVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDateVectorAccessor); + } + } + + @Test + public void createAccessorForVarCharVector() { + try ( + ValueVector valueVector = new VarCharVector("", rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcVarCharVectorAccessor); + } + } + + @Test + public void createAccessorForLargeVarCharVector() { + try (ValueVector valueVector = new LargeVarCharVector("", + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcVarCharVectorAccessor); + } + } + + @Test + public void createAccessorForDurationVector() { + try (ValueVector valueVector = + new DurationVector("", + new FieldType(true, new ArrowType.Duration(TimeUnit.MILLISECOND), null), + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDurationVectorAccessor); + } + } + + @Test + public void createAccessorForIntervalDayVector() { + try (ValueVector valueVector = new IntervalDayVector("", + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcIntervalVectorAccessor); + } + } + + @Test + public void createAccessorForIntervalYearVector() { + try (ValueVector valueVector = new IntervalYearVector("", + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcIntervalVectorAccessor); + } + } + + @Test + public void createAccessorForUnionVector() { + try (ValueVector valueVector = new UnionVector("", rootAllocatorTestRule.getRootAllocator(), + null, null)) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcUnionVectorAccessor); + } + } + + @Test + public void createAccessorForDenseUnionVector() { + try ( + ValueVector valueVector = new DenseUnionVector("", rootAllocatorTestRule.getRootAllocator(), + null, null)) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDenseUnionVectorAccessor); + } + } + + @Test + public void createAccessorForStructVector() { + try (ValueVector valueVector = StructVector.empty("", + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcStructVectorAccessor); + } + } + + @Test + public void createAccessorForListVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createListVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcListVectorAccessor); + } + } + + @Test + public void createAccessorForLargeListVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createLargeListVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcLargeListVectorAccessor); + } + } + + @Test + public void createAccessorForFixedSizeListVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createFixedSizeListVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcFixedSizeListVectorAccessor); + } + } + + @Test + public void createAccessorForMapVector() { + try (ValueVector valueVector = MapVector.empty("", rootAllocatorTestRule.getRootAllocator(), + true)) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcMapVectorAccessor); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java new file mode 100644 index 0000000000000..099b0122179f1 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class ArrowFlightJdbcAccessorTest { + + static class MockedArrowFlightJdbcAccessor extends ArrowFlightJdbcAccessor { + + protected MockedArrowFlightJdbcAccessor() { + super(() -> 0, (boolean wasNull) -> { + }); + } + + @Override + public Class getObjectClass() { + return Long.class; + } + } + + @Mock + MockedArrowFlightJdbcAccessor accessor; + + @Test + public void testShouldGetObjectWithByteClassReturnGetByte() throws SQLException { + byte expected = Byte.MAX_VALUE; + when(accessor.getByte()).thenReturn(expected); + + when(accessor.getObject(Byte.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(Byte.class), (Object) expected); + verify(accessor).getByte(); + } + + @Test + public void testShouldGetObjectWithShortClassReturnGetShort() throws SQLException { + short expected = Short.MAX_VALUE; + when(accessor.getShort()).thenReturn(expected); + + when(accessor.getObject(Short.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(Short.class), (Object) expected); + verify(accessor).getShort(); + } + + @Test + public void testShouldGetObjectWithIntegerClassReturnGetInt() throws SQLException { + int expected = Integer.MAX_VALUE; + when(accessor.getInt()).thenReturn(expected); + + when(accessor.getObject(Integer.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(Integer.class), (Object) expected); + verify(accessor).getInt(); + } + + @Test + public void testShouldGetObjectWithLongClassReturnGetLong() throws SQLException { + long expected = Long.MAX_VALUE; + when(accessor.getLong()).thenReturn(expected); + + when(accessor.getObject(Long.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(Long.class), (Object) expected); + verify(accessor).getLong(); + } + + @Test + public void testShouldGetObjectWithFloatClassReturnGetFloat() throws SQLException { + float expected = Float.MAX_VALUE; + when(accessor.getFloat()).thenReturn(expected); + + when(accessor.getObject(Float.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(Float.class), (Object) expected); + verify(accessor).getFloat(); + } + + @Test + public void testShouldGetObjectWithDoubleClassReturnGetDouble() throws SQLException { + double expected = Double.MAX_VALUE; + when(accessor.getDouble()).thenReturn(expected); + + when(accessor.getObject(Double.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(Double.class), (Object) expected); + verify(accessor).getDouble(); + } + + @Test + public void testShouldGetObjectWithBooleanClassReturnGetBoolean() throws SQLException { + when(accessor.getBoolean()).thenReturn(true); + + when(accessor.getObject(Boolean.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(Boolean.class), true); + verify(accessor).getBoolean(); + } + + @Test + public void testShouldGetObjectWithBigDecimalClassReturnGetBigDecimal() throws SQLException { + BigDecimal expected = BigDecimal.TEN; + when(accessor.getBigDecimal()).thenReturn(expected); + + when(accessor.getObject(BigDecimal.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(BigDecimal.class), expected); + verify(accessor).getBigDecimal(); + } + + @Test + public void testShouldGetObjectWithStringClassReturnGetString() throws SQLException { + String expected = "STRING_VALUE"; + when(accessor.getString()).thenReturn(expected); + + when(accessor.getObject(String.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(String.class), expected); + verify(accessor).getString(); + } + + @Test + public void testShouldGetObjectWithByteArrayClassReturnGetBytes() throws SQLException { + byte[] expected = "STRING_VALUE".getBytes(StandardCharsets.UTF_8); + when(accessor.getBytes()).thenReturn(expected); + + when(accessor.getObject(byte[].class)).thenCallRealMethod(); + + Assert.assertArrayEquals(accessor.getObject(byte[].class), expected); + verify(accessor).getBytes(); + } + + @Test + public void testShouldGetObjectWithObjectClassReturnGetObject() throws SQLException { + Object expected = new Object(); + when(accessor.getObject()).thenReturn(expected); + + when(accessor.getObject(Object.class)).thenCallRealMethod(); + + Assert.assertEquals(accessor.getObject(Object.class), expected); + verify(accessor).getObject(); + } + + @Test + public void testShouldGetObjectWithAccessorsObjectClassReturnGetObject() throws SQLException { + Class objectClass = Long.class; + + when(accessor.getObject(objectClass)).thenCallRealMethod(); + + accessor.getObject(objectClass); + verify(accessor).getObject(objectClass); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetBoolean() throws SQLException { + when(accessor.getBoolean()).thenCallRealMethod(); + accessor.getBoolean(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetByte() throws SQLException { + when(accessor.getByte()).thenCallRealMethod(); + accessor.getByte(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetShort() throws SQLException { + when(accessor.getShort()).thenCallRealMethod(); + accessor.getShort(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetInt() throws SQLException { + when(accessor.getInt()).thenCallRealMethod(); + accessor.getInt(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetLong() throws SQLException { + when(accessor.getLong()).thenCallRealMethod(); + accessor.getLong(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetFloat() throws SQLException { + when(accessor.getFloat()).thenCallRealMethod(); + accessor.getFloat(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetDouble() throws SQLException { + when(accessor.getDouble()).thenCallRealMethod(); + accessor.getDouble(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetBigDecimal() throws SQLException { + when(accessor.getBigDecimal()).thenCallRealMethod(); + accessor.getBigDecimal(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetBytes() throws SQLException { + when(accessor.getBytes()).thenCallRealMethod(); + accessor.getBytes(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetAsciiStream() throws SQLException { + when(accessor.getAsciiStream()).thenCallRealMethod(); + accessor.getAsciiStream(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetUnicodeStream() throws SQLException { + when(accessor.getUnicodeStream()).thenCallRealMethod(); + accessor.getUnicodeStream(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetBinaryStream() throws SQLException { + when(accessor.getBinaryStream()).thenCallRealMethod(); + accessor.getBinaryStream(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetObject() throws SQLException { + when(accessor.getObject()).thenCallRealMethod(); + accessor.getObject(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetObjectMap() throws SQLException { + Map> map = new HashMap<>(); + when(accessor.getObject(map)).thenCallRealMethod(); + accessor.getObject(map); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetCharacterStream() throws SQLException { + when(accessor.getCharacterStream()).thenCallRealMethod(); + accessor.getCharacterStream(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetRef() throws SQLException { + when(accessor.getRef()).thenCallRealMethod(); + accessor.getRef(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetBlob() throws SQLException { + when(accessor.getBlob()).thenCallRealMethod(); + accessor.getBlob(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetClob() throws SQLException { + when(accessor.getClob()).thenCallRealMethod(); + accessor.getClob(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetArray() throws SQLException { + when(accessor.getArray()).thenCallRealMethod(); + accessor.getArray(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetStruct() throws SQLException { + when(accessor.getStruct()).thenCallRealMethod(); + accessor.getStruct(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetURL() throws SQLException { + when(accessor.getURL()).thenCallRealMethod(); + accessor.getURL(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetNClob() throws SQLException { + when(accessor.getNClob()).thenCallRealMethod(); + accessor.getNClob(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetSQLXML() throws SQLException { + when(accessor.getSQLXML()).thenCallRealMethod(); + accessor.getSQLXML(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetNString() throws SQLException { + when(accessor.getNString()).thenCallRealMethod(); + accessor.getNString(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetNCharacterStream() throws SQLException { + when(accessor.getNCharacterStream()).thenCallRealMethod(); + accessor.getNCharacterStream(); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetDate() throws SQLException { + when(accessor.getDate(null)).thenCallRealMethod(); + accessor.getDate(null); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetTime() throws SQLException { + when(accessor.getTime(null)).thenCallRealMethod(); + accessor.getTime(null); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetTimestamp() throws SQLException { + when(accessor.getTimestamp(null)).thenCallRealMethod(); + accessor.getTimestamp(null); + } + + @Test(expected = SQLException.class) + public void testShouldFailToGetBigDecimalWithValue() throws SQLException { + when(accessor.getBigDecimal(0)).thenCallRealMethod(); + accessor.getBigDecimal(0); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessorTest.java new file mode 100644 index 0000000000000..57e7ecfe02580 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessorTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl; + +import org.junit.Assert; +import org.junit.Test; + +public class ArrowFlightJdbcNullVectorAccessorTest { + + ArrowFlightJdbcNullVectorAccessor accessor = + new ArrowFlightJdbcNullVectorAccessor((boolean wasNull) -> { + }); + + @Test + public void testShouldWasNullReturnTrue() { + Assert.assertTrue(accessor.wasNull()); + } + + @Test + public void testShouldGetObjectReturnNull() { + Assert.assertNull(accessor.getObject()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessorTest.java new file mode 100644 index 0000000000000..f4d256c4cf8ac --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessorTest.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.binary; + +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.hamcrest.CoreMatchers.is; + +import java.io.InputStream; +import java.io.Reader; +import java.util.Arrays; +import java.util.Collection; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.commons.io.IOUtils; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcBinaryVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private ValueVector vector; + private final Supplier vectorSupplier; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof VarBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor(((VarBinaryVector) vector), getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof LargeVarBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor(((LargeVarBinaryVector) vector), + getCurrentRow, noOpWasNullConsumer); + } else if (vector instanceof FixedSizeBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor(((FixedSizeBinaryVector) vector), + getCurrentRow, noOpWasNullConsumer); + } + return null; + }; + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createVarBinaryVector(), + "VarBinaryVector"}, + {(Supplier) () -> rootAllocatorTestRule.createLargeVarBinaryVector(), + "LargeVarBinaryVector"}, + {(Supplier) () -> rootAllocatorTestRule.createFixedSizeBinaryVector(), + "FixedSizeBinaryVector"}, + }); + } + + public ArrowFlightJdbcBinaryVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetStringReturnExpectedString() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBinaryVectorAccessor::getString, + (accessor) -> is(new String(accessor.getBytes(), UTF_8))); + } + + @Test + public void testShouldGetStringReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + accessorIterator + .assertAccessorGetter(vector, ArrowFlightJdbcBinaryVectorAccessor::getString, + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetBytesReturnExpectedByteArray() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBinaryVectorAccessor::getBytes, + (accessor, currentRow) -> { + if (vector instanceof VarBinaryVector) { + return is(((VarBinaryVector) vector).get(currentRow)); + } else if (vector instanceof LargeVarBinaryVector) { + return is(((LargeVarBinaryVector) vector).get(currentRow)); + } else if (vector instanceof FixedSizeBinaryVector) { + return is(((FixedSizeBinaryVector) vector).get(currentRow)); + } + return null; + }); + } + + @Test + public void testShouldGetBytesReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getBytes(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetObjectReturnAsGetBytes() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBinaryVectorAccessor::getObject, + (accessor) -> is(accessor.getBytes())); + } + + @Test + public void testShouldGetObjectReturnNull() { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getObject(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetUnicodeStreamReturnCorrectInputStream() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + InputStream inputStream = accessor.getUnicodeStream(); + String actualString = IOUtils.toString(inputStream, UTF_8); + collector.checkThat(accessor.wasNull(), is(false)); + collector.checkThat(actualString, is(accessor.getString())); + }); + } + + @Test + public void testShouldGetUnicodeStreamReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getUnicodeStream(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetAsciiStreamReturnCorrectInputStream() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + InputStream inputStream = accessor.getAsciiStream(); + String actualString = IOUtils.toString(inputStream, US_ASCII); + collector.checkThat(accessor.wasNull(), is(false)); + collector.checkThat(actualString, is(accessor.getString())); + }); + } + + @Test + public void testShouldGetAsciiStreamReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getAsciiStream(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetBinaryStreamReturnCurrentInputStream() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + InputStream inputStream = accessor.getBinaryStream(); + String actualString = IOUtils.toString(inputStream, UTF_8); + collector.checkThat(accessor.wasNull(), is(false)); + collector.checkThat(actualString, is(accessor.getString())); + }); + } + + @Test + public void testShouldGetBinaryStreamReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getBinaryStream(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetCharacterStreamReturnCorrectReader() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Reader characterStream = accessor.getCharacterStream(); + String actualString = IOUtils.toString(characterStream); + collector.checkThat(accessor.wasNull(), is(false)); + collector.checkThat(actualString, is(accessor.getString())); + }); + } + + @Test + public void testShouldGetCharacterStreamReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getCharacterStream(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessorTest.java new file mode 100644 index 0000000000000..36af5134626a5 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessorTest.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor.getTimeUnitForVector; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; + +import java.sql.Date; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Collection; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcDateVectorAccessorTest { + + public static final String AMERICA_VANCOUVER = "America/Vancouver"; + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private BaseFixedWidthVector vector; + private final Supplier vectorSupplier; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + if (vector instanceof DateDayVector) { + return new ArrowFlightJdbcDateVectorAccessor((DateDayVector) vector, getCurrentRow, + (boolean wasNull) -> { + }); + } else if (vector instanceof DateMilliVector) { + return new ArrowFlightJdbcDateVectorAccessor((DateMilliVector) vector, getCurrentRow, + (boolean wasNull) -> { + }); + } + return null; + }; + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createDateDayVector(), + "DateDayVector"}, + {(Supplier) () -> rootAllocatorTestRule.createDateMilliVector(), + "DateMilliVector"}, + }); + } + + public ArrowFlightJdbcDateVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTimestamp(null), + (accessor, currentRow) -> is(getTimestampForVector(currentRow))); + } + + @Test + public void testShouldGetObjectWithDateClassReturnValidDateWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getObject(Date.class), + (accessor, currentRow) -> is(new Date(getTimestampForVector(currentRow).getTime()))); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_VANCOUVER); + Calendar calendar = Calendar.getInstance(timeZone); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Timestamp resultWithoutCalendar = accessor.getTimestamp(null); + final Timestamp result = accessor.getTimestamp(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimestampReturnNull() { + vector.setNull(0); + ArrowFlightJdbcDateVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTimestamp(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetDateReturnValidDateWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getDate(null), + (accessor, currentRow) -> is(new Date(getTimestampForVector(currentRow).getTime()))); + } + + @Test + public void testShouldGetDateReturnValidDateWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_VANCOUVER); + Calendar calendar = Calendar.getInstance(timeZone); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Date resultWithoutCalendar = accessor.getDate(null); + final Date result = accessor.getDate(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetDateReturnNull() { + vector.setNull(0); + ArrowFlightJdbcDateVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getDate(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + private Timestamp getTimestampForVector(int currentRow) { + Object object = vector.getObject(currentRow); + + Timestamp expectedTimestamp = null; + if (object instanceof LocalDateTime) { + expectedTimestamp = Timestamp.valueOf((LocalDateTime) object); + } else if (object instanceof Number) { + long value = ((Number) object).longValue(); + TimeUnit timeUnit = getTimeUnitForVector(vector); + long millis = timeUnit.toMillis(value); + expectedTimestamp = new Timestamp(millis); + } + return expectedTimestamp; + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator + .assertAccessorGetter(vector, ArrowFlightJdbcDateVectorAccessor::getObjectClass, + equalTo(Date.class)); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithoutCalendar() throws Exception { + assertGetStringIsConsistentWithVarCharAccessor(null); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithCalendar() throws Exception { + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + assertGetStringIsConsistentWithVarCharAccessor(calendar); + } + + @Test + public void testValidateGetStringTimeZoneConsistency() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final TimeZone defaultTz = TimeZone.getDefault(); + try { + final String string = accessor.getString(); // Should always be UTC as no calendar is provided + + // Validate with UTC + Date date = accessor.getDate(null); + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + collector.checkThat(date.toString(), is(string)); + + // Validate with different TZ + TimeZone.setDefault(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + collector.checkThat(date.toString(), not(string)); + + collector.checkThat(accessor.wasNull(), is(false)); + } finally { + // Set default Tz back + TimeZone.setDefault(defaultTz); + } + }); + } + + private void assertGetStringIsConsistentWithVarCharAccessor(Calendar calendar) throws Exception { + try (VarCharVector varCharVector = new VarCharVector("", + rootAllocatorTestRule.getRootAllocator())) { + varCharVector.allocateNew(1); + ArrowFlightJdbcVarCharVectorAccessor varCharVectorAccessor = + new ArrowFlightJdbcVarCharVectorAccessor(varCharVector, () -> 0, (boolean wasNull) -> { + }); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final String string = accessor.getString(); + varCharVector.set(0, new Text(string)); + varCharVector.setValueCount(1); + + Date dateFromVarChar = varCharVectorAccessor.getDate(calendar); + Date date = accessor.getDate(calendar); + + collector.checkThat(date, is(dateFromVarChar)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessorTest.java new file mode 100644 index 0000000000000..64ddb573f1bfb --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessorTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.time.Duration; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcDurationVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private DurationVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcDurationVectorAccessor((DurationVector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setup() { + FieldType fieldType = new FieldType(true, new ArrowType.Duration(TimeUnit.MILLISECOND), null); + this.vector = new DurationVector("", fieldType, rootAllocatorTestRule.getRootAllocator()); + + int valueCount = 10; + this.vector.setValueCount(valueCount); + for (int i = 0; i < valueCount; i++) { + this.vector.set(i, java.util.concurrent.TimeUnit.DAYS.toMillis(i + 1)); + } + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void getObject() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDurationVectorAccessor::getObject, + (accessor, currentRow) -> is(Duration.ofDays(currentRow + 1))); + } + + @Test + public void getObjectForNull() throws Exception { + int valueCount = vector.getValueCount(); + for (int i = 0; i < valueCount; i++) { + vector.setNull(i); + } + + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDurationVectorAccessor::getObject, + (accessor, currentRow) -> equalTo(null)); + } + + @Test + public void getString() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcAccessor::getString, + (accessor, currentRow) -> is(Duration.ofDays(currentRow + 1).toString())); + } + + @Test + public void getStringForNull() throws Exception { + int valueCount = vector.getValueCount(); + for (int i = 0; i < valueCount; i++) { + vector.setNull(i); + } + + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcAccessor::getString, + (accessor, currentRow) -> equalTo(null)); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcAccessor::getObjectClass, + (accessor, currentRow) -> equalTo(Duration.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessorTest.java new file mode 100644 index 0000000000000..ea228692202a7 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessorTest.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.utils.IntervalStringUtils.formatIntervalDay; +import static org.apache.arrow.driver.jdbc.utils.IntervalStringUtils.formatIntervalYear; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.joda.time.Period.parse; + +import java.time.Duration; +import java.time.Period; +import java.util.Arrays; +import java.util.Collection; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.ValueVector; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcIntervalVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private final Supplier vectorSupplier; + private ValueVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof IntervalDayVector) { + return new ArrowFlightJdbcIntervalVectorAccessor((IntervalDayVector) vector, + getCurrentRow, noOpWasNullConsumer); + } else if (vector instanceof IntervalYearVector) { + return new ArrowFlightJdbcIntervalVectorAccessor((IntervalYearVector) vector, + getCurrentRow, noOpWasNullConsumer); + } + return null; + }; + + final AccessorTestUtils.AccessorIterator accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> { + IntervalDayVector vector = + new IntervalDayVector("", rootAllocatorTestRule.getRootAllocator()); + + int valueCount = 10; + vector.setValueCount(valueCount); + for (int i = 0; i < valueCount; i++) { + vector.set(i, i + 1, (i + 1) * 1000); + } + return vector; + }, "IntervalDayVector"}, + {(Supplier) () -> { + IntervalYearVector vector = + new IntervalYearVector("", rootAllocatorTestRule.getRootAllocator()); + + int valueCount = 10; + vector.setValueCount(valueCount); + for (int i = 0; i < valueCount; i++) { + vector.set(i, i + 1); + } + return vector; + }, "IntervalYearVector"}, + }); + } + + public ArrowFlightJdbcIntervalVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetObjectReturnValidObject() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcIntervalVectorAccessor::getObject, + (accessor, currentRow) -> is(getExpectedObject(vector, currentRow))); + } + + @Test + public void testShouldGetObjectPassingObjectClassAsParameterReturnValidObject() throws Exception { + Class objectClass = getExpectedObjectClassForVector(vector); + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getObject(objectClass), + (accessor, currentRow) -> is(getExpectedObject(vector, currentRow))); + } + + @Test + public void testShouldGetObjectReturnNull() throws Exception { + setAllNullOnVector(vector); + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcIntervalVectorAccessor::getObject, + (accessor, currentRow) -> equalTo(null)); + } + + private String getStringOnVector(ValueVector vector, int index) { + String object = getExpectedObject(vector, index).toString(); + if (object == null) { + return null; + } else if (vector instanceof IntervalDayVector) { + return formatIntervalDay(parse(object)); + } else if (vector instanceof IntervalYearVector) { + return formatIntervalYear(parse(object)); + } + return null; + } + + @Test + public void testShouldGetIntervalYear( ) { + Assert.assertEquals("-002-00", formatIntervalYear(parse("P-2Y"))); + Assert.assertEquals("-001-01", formatIntervalYear(parse("P-1Y-1M"))); + Assert.assertEquals("-001-02", formatIntervalYear(parse("P-1Y-2M"))); + Assert.assertEquals("-002-03", formatIntervalYear(parse("P-2Y-3M"))); + Assert.assertEquals("-002-04", formatIntervalYear(parse("P-2Y-4M"))); + Assert.assertEquals("-011-01", formatIntervalYear(parse("P-11Y-1M"))); + Assert.assertEquals("+002-00", formatIntervalYear(parse("P+2Y"))); + Assert.assertEquals("+001-01", formatIntervalYear(parse("P+1Y1M"))); + Assert.assertEquals("+001-02", formatIntervalYear(parse("P+1Y2M"))); + Assert.assertEquals("+002-03", formatIntervalYear(parse("P+2Y3M"))); + Assert.assertEquals("+002-04", formatIntervalYear(parse("P+2Y4M"))); + Assert.assertEquals("+011-01", formatIntervalYear(parse("P+11Y1M"))); + } + + @Test + public void testShouldGetIntervalDay( ) { + Assert.assertEquals("-001 00:00:00.000", formatIntervalDay(parse("PT-24H"))); + Assert.assertEquals("+001 00:00:00.000", formatIntervalDay(parse("PT+24H"))); + Assert.assertEquals("-000 01:00:00.000", formatIntervalDay(parse("PT-1H"))); + Assert.assertEquals("-000 01:00:00.001", formatIntervalDay(parse("PT-1H-0M-00.001S"))); + Assert.assertEquals("-000 01:01:01.000", formatIntervalDay(parse("PT-1H-1M-1S"))); + Assert.assertEquals("-000 02:02:02.002", formatIntervalDay(parse("PT-2H-2M-02.002S"))); + Assert.assertEquals("-000 23:59:59.999", formatIntervalDay(parse("PT-23H-59M-59.999S"))); + Assert.assertEquals("-000 11:59:00.100", formatIntervalDay(parse("PT-11H-59M-00.100S"))); + Assert.assertEquals("-000 05:02:03.000", formatIntervalDay(parse("PT-5H-2M-3S"))); + Assert.assertEquals("-000 22:22:22.222", formatIntervalDay(parse("PT-22H-22M-22.222S"))); + Assert.assertEquals("+000 01:00:00.000", formatIntervalDay(parse("PT+1H"))); + Assert.assertEquals("+000 01:00:00.001", formatIntervalDay(parse("PT+1H0M00.001S"))); + Assert.assertEquals("+000 01:01:01.000", formatIntervalDay(parse("PT+1H1M1S"))); + Assert.assertEquals("+000 02:02:02.002", formatIntervalDay(parse("PT+2H2M02.002S"))); + Assert.assertEquals("+000 23:59:59.999", formatIntervalDay(parse("PT+23H59M59.999S"))); + Assert.assertEquals("+000 11:59:00.100", formatIntervalDay(parse("PT+11H59M00.100S"))); + Assert.assertEquals("+000 05:02:03.000", formatIntervalDay(parse("PT+5H2M3S"))); + Assert.assertEquals("+000 22:22:22.222", formatIntervalDay(parse("PT+22H22M22.222S"))); + } + + @Test + public void testIntervalDayWithJodaPeriodObject() { + Assert.assertEquals("+1567 00:00:00.000", + formatIntervalDay(new org.joda.time.Period().plusDays(1567))); + Assert.assertEquals("-1567 00:00:00.000", + formatIntervalDay(new org.joda.time.Period().minusDays(1567))); + } + + @Test + public void testShouldGetStringReturnCorrectString() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcIntervalVectorAccessor::getString, + (accessor, currentRow) -> is(getStringOnVector(vector, currentRow))); + } + + @Test + public void testShouldGetStringReturnNull() throws Exception { + setAllNullOnVector(vector); + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcIntervalVectorAccessor::getString, + (accessor, currentRow) -> equalTo(null)); + } + + @Test + public void testShouldGetObjectClassReturnCorrectClass() throws Exception { + Class expectedObjectClass = getExpectedObjectClassForVector(vector); + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcIntervalVectorAccessor::getObjectClass, + (accessor, currentRow) -> equalTo(expectedObjectClass)); + } + + private Class getExpectedObjectClassForVector(ValueVector vector) { + if (vector instanceof IntervalDayVector) { + return Duration.class; + } else if (vector instanceof IntervalYearVector) { + return Period.class; + } + return null; + } + + private void setAllNullOnVector(ValueVector vector) { + int valueCount = vector.getValueCount(); + if (vector instanceof IntervalDayVector) { + for (int i = 0; i < valueCount; i++) { + ((IntervalDayVector) vector).setNull(i); + } + } else if (vector instanceof IntervalYearVector) { + for (int i = 0; i < valueCount; i++) { + ((IntervalYearVector) vector).setNull(i); + } + } + } + + private Object getExpectedObject(ValueVector vector, int currentRow) { + if (vector instanceof IntervalDayVector) { + return Duration.ofDays(currentRow + 1).plusMillis((currentRow + 1) * 1000L); + } else if (vector instanceof IntervalYearVector) { + return Period.ofMonths(currentRow + 1); + } + return null; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java new file mode 100644 index 0000000000000..38d842724b9c1 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor.getTimeUnitForVector; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor.getTimeZoneForVector; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Collection; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Assume; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcTimeStampVectorAccessorTest { + + public static final String AMERICA_VANCOUVER = "America/Vancouver"; + public static final String ASIA_BANGKOK = "Asia/Bangkok"; + public static final String AMERICA_SAO_PAULO = "America/Sao_Paulo"; + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + private final String timeZone; + + private TimeStampVector vector; + private final Supplier vectorSupplier; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcTimeStampVectorAccessor( + (TimeStampVector) vector, getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1} - TimeZone: {2}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createTimeStampNanoVector(), + "TimeStampNanoVector", + null}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampNanoTZVector("UTC"), + "TimeStampNanoTZVector", + "UTC"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampNanoTZVector( + AMERICA_VANCOUVER), + "TimeStampNanoTZVector", + AMERICA_VANCOUVER}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampNanoTZVector( + ASIA_BANGKOK), + "TimeStampNanoTZVector", + ASIA_BANGKOK}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMicroVector(), + "TimeStampMicroVector", + null}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMicroTZVector( + "UTC"), + "TimeStampMicroTZVector", + "UTC"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMicroTZVector( + AMERICA_VANCOUVER), + "TimeStampMicroTZVector", + AMERICA_VANCOUVER}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMicroTZVector( + ASIA_BANGKOK), + "TimeStampMicroTZVector", + ASIA_BANGKOK}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMilliVector(), + "TimeStampMilliVector", + null}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMilliTZVector( + "UTC"), + "TimeStampMilliTZVector", + "UTC"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMilliTZVector( + AMERICA_VANCOUVER), + "TimeStampMilliTZVector", + AMERICA_VANCOUVER}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMilliTZVector( + ASIA_BANGKOK), + "TimeStampMilliTZVector", + ASIA_BANGKOK}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampSecVector(), + "TimeStampSecVector", + null}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampSecTZVector("UTC"), + "TimeStampSecTZVector", + "UTC"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampSecTZVector( + AMERICA_VANCOUVER), + "TimeStampSecTZVector", + AMERICA_VANCOUVER}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampSecTZVector( + ASIA_BANGKOK), + "TimeStampSecTZVector", + ASIA_BANGKOK} + }); + } + + public ArrowFlightJdbcTimeStampVectorAccessorTest(Supplier vectorSupplier, + String vectorType, + String timeZone) { + this.vectorSupplier = vectorSupplier; + this.timeZone = timeZone; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTimestamp(null), + (accessor, currentRow) -> is(getTimestampForVector(currentRow))); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_SAO_PAULO); + Calendar calendar = Calendar.getInstance(timeZone); + + TimeZone timeZoneForVector = getTimeZoneForVector(vector); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Timestamp resultWithoutCalendar = accessor.getTimestamp(null); + final Timestamp result = accessor.getTimestamp(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimestampReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeStampVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTimestamp(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetDateReturnValidDateWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getDate(null), + (accessor, currentRow) -> is(new Date(getTimestampForVector(currentRow).getTime()))); + } + + @Test + public void testShouldGetDateReturnValidDateWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_SAO_PAULO); + Calendar calendar = Calendar.getInstance(timeZone); + + TimeZone timeZoneForVector = getTimeZoneForVector(vector); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Date resultWithoutCalendar = accessor.getDate(null); + final Date result = accessor.getDate(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetDateReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeStampVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getDate(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetTimeReturnValidTimeWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTime(null), + (accessor, currentRow) -> is(new Time(getTimestampForVector(currentRow).getTime()))); + } + + @Test + public void testShouldGetTimeReturnValidTimeWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_SAO_PAULO); + Calendar calendar = Calendar.getInstance(timeZone); + + TimeZone timeZoneForVector = getTimeZoneForVector(vector); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Time resultWithoutCalendar = accessor.getTime(null); + final Time result = accessor.getTime(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimeReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeStampVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTime(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + private Timestamp getTimestampForVector(int currentRow) { + Object object = vector.getObject(currentRow); + + Timestamp expectedTimestamp = null; + if (object instanceof LocalDateTime) { + expectedTimestamp = Timestamp.valueOf((LocalDateTime) object); + } else if (object instanceof Long) { + TimeUnit timeUnit = getTimeUnitForVector(vector); + long millis = timeUnit.toMillis((Long) object); + long offset = TimeZone.getTimeZone(timeZone).getOffset(millis); + expectedTimestamp = new Timestamp(millis + offset); + } + return expectedTimestamp; + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcTimeStampVectorAccessor::getObjectClass, + equalTo(Timestamp.class)); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithoutCalendar() throws Exception { + assertGetStringIsConsistentWithVarCharAccessor(null); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithCalendar() throws Exception { + // Ignore for TimeStamp vectors with TZ, as VarChar accessor won't consider their TZ + Assume.assumeTrue( + vector instanceof TimeStampNanoVector || vector instanceof TimeStampMicroVector || + vector instanceof TimeStampMilliVector || vector instanceof TimeStampSecVector); + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + assertGetStringIsConsistentWithVarCharAccessor(calendar); + } + + private void assertGetStringIsConsistentWithVarCharAccessor(Calendar calendar) throws Exception { + try (VarCharVector varCharVector = new VarCharVector("", + rootAllocatorTestRule.getRootAllocator())) { + varCharVector.allocateNew(1); + ArrowFlightJdbcVarCharVectorAccessor varCharVectorAccessor = + new ArrowFlightJdbcVarCharVectorAccessor(varCharVector, () -> 0, (boolean wasNull) -> { + }); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final String string = accessor.getString(); + varCharVector.set(0, new Text(string)); + varCharVector.setValueCount(1); + + Timestamp timestampFromVarChar = varCharVectorAccessor.getTimestamp(calendar); + Timestamp timestamp = accessor.getTimestamp(calendar); + + collector.checkThat(timestamp, is(timestampFromVarChar)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessorTest.java new file mode 100644 index 0000000000000..d2f7eb336af59 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessorTest.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorAccessor.getTimeUnitForVector; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; + +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Collection; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcTimeVectorAccessorTest { + + public static final String AMERICA_VANCOUVER = "America/Vancouver"; + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private BaseFixedWidthVector vector; + private final Supplier vectorSupplier; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof TimeNanoVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeNanoVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TimeMicroVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeMicroVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TimeMilliVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeMilliVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TimeSecVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeSecVector) vector, getCurrentRow, + noOpWasNullConsumer); + } + return null; + }; + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createTimeNanoVector(), + "TimeNanoVector"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeMicroVector(), + "TimeMicroVector"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeMilliVector(), + "TimeMilliVector"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeSecVector(), + "TimeSecVector"} + }); + } + + public ArrowFlightJdbcTimeVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTimestamp(null), + (accessor, currentRow) -> is(getTimestampForVector(currentRow))); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_VANCOUVER); + Calendar calendar = Calendar.getInstance(timeZone); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Timestamp resultWithoutCalendar = accessor.getTimestamp(null); + final Timestamp result = accessor.getTimestamp(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimestampReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTimestamp(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetTimeReturnValidTimeWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTime(null), + (accessor, currentRow) -> { + Timestamp expectedTimestamp = getTimestampForVector(currentRow); + return is(new Time(expectedTimestamp.getTime())); + }); + } + + @Test + public void testShouldGetTimeReturnValidTimeWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_VANCOUVER); + Calendar calendar = Calendar.getInstance(timeZone); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Time resultWithoutCalendar = accessor.getTime(null); + final Time result = accessor.getTime(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimeReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTime(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + private Timestamp getTimestampForVector(int currentRow) { + Object object = vector.getObject(currentRow); + + Timestamp expectedTimestamp = null; + if (object instanceof LocalDateTime) { + expectedTimestamp = Timestamp.valueOf((LocalDateTime) object); + } else if (object instanceof Number) { + long value = ((Number) object).longValue(); + TimeUnit timeUnit = getTimeUnitForVector(vector); + long millis = timeUnit.toMillis(value); + expectedTimestamp = new Timestamp(millis); + } + return expectedTimestamp; + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcTimeVectorAccessor::getObjectClass, + equalTo(Time.class)); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithoutCalendar() throws Exception { + assertGetStringIsConsistentWithVarCharAccessor(null); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithCalendar() throws Exception { + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + assertGetStringIsConsistentWithVarCharAccessor(calendar); + } + + @Test + public void testValidateGetStringTimeZoneConsistency() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final TimeZone defaultTz = TimeZone.getDefault(); + try { + final String string = accessor.getString(); // Should always be UTC as no calendar is provided + + // Validate with UTC + Time time = accessor.getTime(null); + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + collector.checkThat(time.toString(), is(string)); + + // Validate with different TZ + TimeZone.setDefault(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + collector.checkThat(time.toString(), not(string)); + + collector.checkThat(accessor.wasNull(), is(false)); + } finally { + // Set default Tz back + TimeZone.setDefault(defaultTz); + } + }); + } + + private void assertGetStringIsConsistentWithVarCharAccessor(Calendar calendar) throws Exception { + try (VarCharVector varCharVector = new VarCharVector("", + rootAllocatorTestRule.getRootAllocator())) { + varCharVector.allocateNew(1); + ArrowFlightJdbcVarCharVectorAccessor varCharVectorAccessor = + new ArrowFlightJdbcVarCharVectorAccessor(varCharVector, () -> 0, (boolean wasNull) -> { + }); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final String string = accessor.getString(); + varCharVector.set(0, new Text(string)); + varCharVector.setValueCount(1); + + Time timeFromVarChar = varCharVectorAccessor.getTime(calendar); + Time time = accessor.getTime(calendar); + + collector.checkThat(time, is(timeFromVarChar)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java new file mode 100644 index 0000000000000..b2eb8f1dbee8f --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import static org.hamcrest.CoreMatchers.equalTo; + +import java.sql.Array; +import java.sql.ResultSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class AbstractArrowFlightJdbcListAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private final Supplier vectorSupplier; + private ValueVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof ListVector) { + return new ArrowFlightJdbcListVectorAccessor((ListVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof LargeListVector) { + return new ArrowFlightJdbcLargeListVectorAccessor((LargeListVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof FixedSizeListVector) { + return new ArrowFlightJdbcFixedSizeListVectorAccessor((FixedSizeListVector) vector, + getCurrentRow, noOpWasNullConsumer); + } + return null; + }; + + final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createListVector(), "ListVector"}, + {(Supplier) () -> rootAllocatorTestRule.createLargeListVector(), + "LargeListVector"}, + {(Supplier) () -> rootAllocatorTestRule.createFixedSizeListVector(), + "FixedSizeListVector"}, + }); + } + + public AbstractArrowFlightJdbcListAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = this.vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetObjectClassReturnCorrectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcListVectorAccessor::getObjectClass, + (accessor, currentRow) -> equalTo(List.class)); + } + + @Test + public void testShouldGetObjectReturnValidList() throws Exception { + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcListVectorAccessor::getObject, + (accessor, currentRow) -> equalTo( + Arrays.asList(0, (currentRow), (currentRow) * 2, (currentRow) * 3, (currentRow) * 4))); + } + + @Test + public void testShouldGetObjectReturnNull() throws Exception { + vector.clear(); + vector.allocateNewSafe(); + vector.setValueCount(5); + + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcListVectorAccessor::getObject, + (accessor, currentRow) -> CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetArrayReturnValidArray() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Array array = accessor.getArray(); + assert array != null; + + Object[] arrayObject = (Object[]) array.getArray(); + + collector.checkThat(arrayObject, equalTo( + new Object[] {0, currentRow, (currentRow) * 2, (currentRow) * 3, (currentRow) * 4})); + }); + } + + @Test + public void testShouldGetArrayReturnNull() throws Exception { + vector.clear(); + vector.allocateNewSafe(); + vector.setValueCount(5); + + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcListVectorAccessor::getArray, + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetArrayReturnValidArrayPassingOffsets() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Array array = accessor.getArray(); + assert array != null; + + Object[] arrayObject = (Object[]) array.getArray(1, 3); + + collector.checkThat(arrayObject, equalTo( + new Object[] {currentRow, (currentRow) * 2, (currentRow) * 3})); + }); + } + + @Test + public void testShouldGetArrayGetResultSetReturnValidResultSet() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Array array = accessor.getArray(); + assert array != null; + + try (ResultSet rs = array.getResultSet()) { + int count = 0; + while (rs.next()) { + final int value = rs.getInt(1); + collector.checkThat(value, equalTo(currentRow * count)); + count++; + } + collector.checkThat(count, equalTo(5)); + } + }); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessorTest.java new file mode 100644 index 0000000000000..2b53b27dc9e13 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessorTest.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.SQLException; +import java.util.Calendar; +import java.util.Map; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.ArrowFlightJdbcNullVectorAccessor; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.ValueVector; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Spy; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class AbstractArrowFlightJdbcUnionVectorAccessorTest { + + @Mock + ArrowFlightJdbcAccessor innerAccessor; + @Spy + AbstractArrowFlightJdbcUnionVectorAccessorMock accessor; + + @Before + public void setup() { + when(accessor.getAccessor()).thenReturn(innerAccessor); + } + + @Test + public void testGetNCharacterStreamUsesSpecificAccessor() throws SQLException { + accessor.getNCharacterStream(); + verify(innerAccessor).getNCharacterStream(); + } + + @Test + public void testGetNStringUsesSpecificAccessor() throws SQLException { + accessor.getNString(); + verify(innerAccessor).getNString(); + } + + @Test + public void testGetSQLXMLUsesSpecificAccessor() throws SQLException { + accessor.getSQLXML(); + verify(innerAccessor).getSQLXML(); + } + + @Test + public void testGetNClobUsesSpecificAccessor() throws SQLException { + accessor.getNClob(); + verify(innerAccessor).getNClob(); + } + + @Test + public void testGetURLUsesSpecificAccessor() throws SQLException { + accessor.getURL(); + verify(innerAccessor).getURL(); + } + + @Test + public void testGetStructUsesSpecificAccessor() throws SQLException { + accessor.getStruct(); + verify(innerAccessor).getStruct(); + } + + @Test + public void testGetArrayUsesSpecificAccessor() throws SQLException { + accessor.getArray(); + verify(innerAccessor).getArray(); + } + + @Test + public void testGetClobUsesSpecificAccessor() throws SQLException { + accessor.getClob(); + verify(innerAccessor).getClob(); + } + + @Test + public void testGetBlobUsesSpecificAccessor() throws SQLException { + accessor.getBlob(); + verify(innerAccessor).getBlob(); + } + + @Test + public void testGetRefUsesSpecificAccessor() throws SQLException { + accessor.getRef(); + verify(innerAccessor).getRef(); + } + + @Test + public void testGetCharacterStreamUsesSpecificAccessor() throws SQLException { + accessor.getCharacterStream(); + verify(innerAccessor).getCharacterStream(); + } + + @Test + public void testGetBinaryStreamUsesSpecificAccessor() throws SQLException { + accessor.getBinaryStream(); + verify(innerAccessor).getBinaryStream(); + } + + @Test + public void testGetUnicodeStreamUsesSpecificAccessor() throws SQLException { + accessor.getUnicodeStream(); + verify(innerAccessor).getUnicodeStream(); + } + + @Test + public void testGetAsciiStreamUsesSpecificAccessor() throws SQLException { + accessor.getAsciiStream(); + verify(innerAccessor).getAsciiStream(); + } + + @Test + public void testGetBytesUsesSpecificAccessor() throws SQLException { + accessor.getBytes(); + verify(innerAccessor).getBytes(); + } + + @Test + public void testGetBigDecimalUsesSpecificAccessor() throws SQLException { + accessor.getBigDecimal(); + verify(innerAccessor).getBigDecimal(); + } + + @Test + public void testGetDoubleUsesSpecificAccessor() throws SQLException { + accessor.getDouble(); + verify(innerAccessor).getDouble(); + } + + @Test + public void testGetFloatUsesSpecificAccessor() throws SQLException { + accessor.getFloat(); + verify(innerAccessor).getFloat(); + } + + @Test + public void testGetLongUsesSpecificAccessor() throws SQLException { + accessor.getLong(); + verify(innerAccessor).getLong(); + } + + @Test + public void testGetIntUsesSpecificAccessor() throws SQLException { + accessor.getInt(); + verify(innerAccessor).getInt(); + } + + @Test + public void testGetShortUsesSpecificAccessor() throws SQLException { + accessor.getShort(); + verify(innerAccessor).getShort(); + } + + @Test + public void testGetByteUsesSpecificAccessor() throws SQLException { + accessor.getByte(); + verify(innerAccessor).getByte(); + } + + @Test + public void testGetBooleanUsesSpecificAccessor() throws SQLException { + accessor.getBoolean(); + verify(innerAccessor).getBoolean(); + } + + @Test + public void testGetStringUsesSpecificAccessor() throws SQLException { + accessor.getString(); + verify(innerAccessor).getString(); + } + + @Test + public void testGetObjectClassUsesSpecificAccessor() { + accessor.getObjectClass(); + verify(innerAccessor).getObjectClass(); + } + + @Test + public void testGetObjectWithClassUsesSpecificAccessor() throws SQLException { + accessor.getObject(Object.class); + verify(innerAccessor).getObject(Object.class); + } + + @Test + public void testGetTimestampUsesSpecificAccessor() throws SQLException { + Calendar calendar = Calendar.getInstance(); + accessor.getTimestamp(calendar); + verify(innerAccessor).getTimestamp(calendar); + } + + @Test + public void testGetTimeUsesSpecificAccessor() throws SQLException { + Calendar calendar = Calendar.getInstance(); + accessor.getTime(calendar); + verify(innerAccessor).getTime(calendar); + } + + @Test + public void testGetDateUsesSpecificAccessor() throws SQLException { + Calendar calendar = Calendar.getInstance(); + accessor.getDate(calendar); + verify(innerAccessor).getDate(calendar); + } + + @Test + public void testGetObjectUsesSpecificAccessor() throws SQLException { + Map> map = mock(Map.class); + accessor.getObject(map); + verify(innerAccessor).getObject(map); + } + + @Test + public void testGetBigDecimalWithScaleUsesSpecificAccessor() throws SQLException { + accessor.getBigDecimal(2); + verify(innerAccessor).getBigDecimal(2); + } + + private static class AbstractArrowFlightJdbcUnionVectorAccessorMock + extends AbstractArrowFlightJdbcUnionVectorAccessor { + protected AbstractArrowFlightJdbcUnionVectorAccessorMock() { + super(() -> 0, (boolean wasNull) -> { + }); + } + + @Override + protected ArrowFlightJdbcAccessor createAccessorForVector(ValueVector vector) { + return new ArrowFlightJdbcNullVectorAccessor((boolean wasNull) -> { + }); + } + + @Override + protected byte getCurrentTypeId() { + return 0; + } + + @Override + protected ValueVector getVectorByTypeId(byte typeId) { + return new NullVector(); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessorTest.java new file mode 100644 index 0000000000000..41d5eb97e85f6 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessorTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.holders.NullableBigIntHolder; +import org.apache.arrow.vector.holders.NullableFloat8Holder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcDenseUnionVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private DenseUnionVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcDenseUnionVectorAccessor( + (DenseUnionVector) vector, getCurrentRow, (boolean wasNull) -> { + //No Operation + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setup() throws Exception { + this.vector = DenseUnionVector.empty("", rootAllocatorTestRule.getRootAllocator()); + this.vector.allocateNew(); + + // write some data + byte bigIntTypeId = + this.vector.registerNewTypeId(Field.nullable("", Types.MinorType.BIGINT.getType())); + byte float8TypeId = + this.vector.registerNewTypeId(Field.nullable("", Types.MinorType.FLOAT8.getType())); + byte timestampMilliTypeId = + this.vector.registerNewTypeId(Field.nullable("", Types.MinorType.TIMESTAMPMILLI.getType())); + + NullableBigIntHolder nullableBigIntHolder = new NullableBigIntHolder(); + nullableBigIntHolder.isSet = 1; + nullableBigIntHolder.value = Long.MAX_VALUE; + this.vector.setTypeId(0, bigIntTypeId); + this.vector.setSafe(0, nullableBigIntHolder); + + NullableFloat8Holder nullableFloat4Holder = new NullableFloat8Holder(); + nullableFloat4Holder.isSet = 1; + nullableFloat4Holder.value = Math.PI; + this.vector.setTypeId(1, float8TypeId); + this.vector.setSafe(1, nullableFloat4Holder); + + NullableTimeStampMilliHolder nullableTimeStampMilliHolder = new NullableTimeStampMilliHolder(); + nullableTimeStampMilliHolder.isSet = 1; + nullableTimeStampMilliHolder.value = 1625702400000L; + this.vector.setTypeId(2, timestampMilliTypeId); + this.vector.setSafe(2, nullableTimeStampMilliHolder); + + nullableBigIntHolder.isSet = 0; + this.vector.setTypeId(3, bigIntTypeId); + this.vector.setSafe(3, nullableBigIntHolder); + + this.vector.setValueCount(5); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void getObject() throws Exception { + List result = accessorIterator.toList(vector); + List expected = Arrays.asList( + Long.MAX_VALUE, + Math.PI, + new Timestamp(1625702400000L), + null, + null); + + collector.checkThat(result, is(expected)); + } + + @Test + public void getObjectForNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcUnionVectorAccessor::getObject, equalTo(null)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessorTest.java new file mode 100644 index 0000000000000..7a81da4240b1a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessorTest.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.sql.Array; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Map; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.impl.UnionMapWriter; +import org.apache.arrow.vector.util.JsonStringHashMap; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcMapVectorAccessorTest { + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private MapVector vector; + + @Before + public void setup() { + vector = MapVector.empty("", rootAllocatorTestRule.getRootAllocator(), false); + UnionMapWriter writer = vector.getWriter(); + writer.allocate(); + writer.setPosition(0); // optional + writer.startMap(); + writer.startEntry(); + writer.key().integer().writeInt(1); + writer.value().integer().writeInt(11); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(2); + writer.value().integer().writeInt(22); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(3); + writer.value().integer().writeInt(33); + writer.endEntry(); + writer.endMap(); + + writer.setPosition(1); + writer.startMap(); + writer.startEntry(); + writer.key().integer().writeInt(2); + writer.endEntry(); + writer.endMap(); + + writer.setPosition(2); + writer.startMap(); + writer.startEntry(); + writer.key().integer().writeInt(0); + writer.value().integer().writeInt(2000); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(1); + writer.value().integer().writeInt(2001); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(2); + writer.value().integer().writeInt(2002); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(3); + writer.value().integer().writeInt(2003); + writer.endEntry(); + writer.endMap(); + + writer.setValueCount(3); + } + + @After + public void tearDown() { + vector.close(); + } + + @Test + public void testShouldGetObjectReturnValidMap() { + AccessorTestUtils.Cursor cursor = new AccessorTestUtils.Cursor(vector.getValueCount()); + ArrowFlightJdbcMapVectorAccessor accessor = + new ArrowFlightJdbcMapVectorAccessor(vector, cursor::getCurrentRow, (boolean wasNull) -> { + }); + + Map expected = new JsonStringHashMap<>(); + expected.put(1, 11); + expected.put(2, 22); + expected.put(3, 33); + Assert.assertEquals(expected, accessor.getObject()); + Assert.assertFalse(accessor.wasNull()); + + cursor.next(); + expected = new JsonStringHashMap<>(); + expected.put(2, null); + Assert.assertEquals(expected, accessor.getObject()); + Assert.assertFalse(accessor.wasNull()); + + cursor.next(); + expected = new JsonStringHashMap<>(); + expected.put(0, 2000); + expected.put(1, 2001); + expected.put(2, 2002); + expected.put(3, 2003); + Assert.assertEquals(expected, accessor.getObject()); + Assert.assertFalse(accessor.wasNull()); + } + + @Test + public void testShouldGetObjectReturnNull() { + vector.setNull(0); + ArrowFlightJdbcMapVectorAccessor accessor = + new ArrowFlightJdbcMapVectorAccessor(vector, () -> 0, (boolean wasNull) -> { + }); + + Assert.assertNull(accessor.getObject()); + Assert.assertTrue(accessor.wasNull()); + } + + @Test + public void testShouldGetArrayReturnValidArray() throws SQLException { + AccessorTestUtils.Cursor cursor = new AccessorTestUtils.Cursor(vector.getValueCount()); + ArrowFlightJdbcMapVectorAccessor accessor = + new ArrowFlightJdbcMapVectorAccessor(vector, cursor::getCurrentRow, (boolean wasNull) -> { + }); + + Array array = accessor.getArray(); + Assert.assertNotNull(array); + Assert.assertFalse(accessor.wasNull()); + + try (ResultSet resultSet = array.getResultSet()) { + Assert.assertTrue(resultSet.next()); + Map entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(1, entry.get("key")); + Assert.assertEquals(11, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(2, entry.get("key")); + Assert.assertEquals(22, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(3, entry.get("key")); + Assert.assertEquals(33, entry.get("value")); + Assert.assertFalse(resultSet.next()); + } + + cursor.next(); + array = accessor.getArray(); + Assert.assertNotNull(array); + Assert.assertFalse(accessor.wasNull()); + try (ResultSet resultSet = array.getResultSet()) { + Assert.assertTrue(resultSet.next()); + Map entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(2, entry.get("key")); + Assert.assertNull(entry.get("value")); + Assert.assertFalse(resultSet.next()); + } + + cursor.next(); + array = accessor.getArray(); + Assert.assertNotNull(array); + Assert.assertFalse(accessor.wasNull()); + try (ResultSet resultSet = array.getResultSet()) { + Assert.assertTrue(resultSet.next()); + Map entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(0, entry.get("key")); + Assert.assertEquals(2000, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(1, entry.get("key")); + Assert.assertEquals(2001, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(2, entry.get("key")); + Assert.assertEquals(2002, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(3, entry.get("key")); + Assert.assertEquals(2003, entry.get("value")); + Assert.assertFalse(resultSet.next()); + } + } + + @Test + public void testShouldGetArrayReturnNull() { + vector.setNull(0); + ((StructVector) vector.getDataVector()).setNull(0); + + ArrowFlightJdbcMapVectorAccessor accessor = + new ArrowFlightJdbcMapVectorAccessor(vector, () -> 0, (boolean wasNull) -> { + }); + + Assert.assertNull(accessor.getArray()); + Assert.assertTrue(accessor.wasNull()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java new file mode 100644 index 0000000000000..b3c85fc0ab1f3 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; + +import java.sql.SQLException; +import java.sql.Struct; +import java.util.HashMap; +import java.util.Map; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.UnionVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.holders.NullableBitHolder; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.JsonStringArrayList; +import org.apache.arrow.vector.util.JsonStringHashMap; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcStructVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private StructVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcStructVectorAccessor((StructVector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setUp() throws Exception { + Map metadata = new HashMap<>(); + metadata.put("k1", "v1"); + FieldType type = new FieldType(true, ArrowType.Struct.INSTANCE, null, metadata); + vector = new StructVector("", rootAllocatorTestRule.getRootAllocator(), type, null); + vector.allocateNew(); + + IntVector intVector = + vector.addOrGet("int", FieldType.nullable(Types.MinorType.INT.getType()), IntVector.class); + Float8Vector float8Vector = + vector.addOrGet("float8", FieldType.nullable(Types.MinorType.FLOAT8.getType()), + Float8Vector.class); + + intVector.setSafe(0, 100); + float8Vector.setSafe(0, 100.05); + vector.setIndexDefined(0); + intVector.setSafe(1, 200); + float8Vector.setSafe(1, 200.1); + vector.setIndexDefined(1); + + vector.setValueCount(2); + } + + @After + public void tearDown() throws Exception { + vector.close(); + } + + @Test + public void testShouldGetObjectClassReturnMapClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcStructVectorAccessor::getObjectClass, + (accessor, currentRow) -> equalTo(Map.class)); + } + + @Test + public void testShouldGetObjectReturnValidMap() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcStructVectorAccessor::getObject, + (accessor, currentRow) -> { + Map expected = new HashMap<>(); + expected.put("int", 100 * (currentRow + 1)); + expected.put("float8", 100.05 * (currentRow + 1)); + + return equalTo(expected); + }); + } + + @Test + public void testShouldGetObjectReturnNull() throws Exception { + vector.setNull(0); + vector.setNull(1); + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcStructVectorAccessor::getObject, + (accessor, currentRow) -> nullValue()); + } + + @Test + public void testShouldGetStructReturnValidStruct() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Struct struct = accessor.getStruct(); + assert struct != null; + + Object[] expected = new Object[] { + 100 * (currentRow + 1), + 100.05 * (currentRow + 1) + }; + + collector.checkThat(struct.getAttributes(), equalTo(expected)); + }); + } + + @Test + public void testShouldGetStructReturnNull() throws Exception { + vector.setNull(0); + vector.setNull(1); + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcStructVectorAccessor::getStruct, + (accessor, currentRow) -> nullValue()); + } + + @Test + public void testShouldGetObjectWorkWithNestedComplexData() throws SQLException { + try (StructVector rootVector = StructVector.empty("", + rootAllocatorTestRule.getRootAllocator())) { + StructVector structVector = rootVector.addOrGetStruct("struct"); + + FieldType intFieldType = FieldType.nullable(Types.MinorType.INT.getType()); + IntVector intVector = structVector.addOrGet("int", intFieldType, IntVector.class); + FieldType float8FieldType = FieldType.nullable(Types.MinorType.FLOAT8.getType()); + Float8Vector float8Vector = + structVector.addOrGet("float8", float8FieldType, Float8Vector.class); + + ListVector listVector = rootVector.addOrGetList("list"); + UnionListWriter listWriter = listVector.getWriter(); + listWriter.allocate(); + + UnionVector unionVector = rootVector.addOrGetUnion("union"); + + intVector.setSafe(0, 100); + intVector.setValueCount(1); + float8Vector.setSafe(0, 100.05); + float8Vector.setValueCount(1); + structVector.setIndexDefined(0); + + listWriter.setPosition(0); + listWriter.startList(); + listWriter.bigInt().writeBigInt(Long.MAX_VALUE); + listWriter.bigInt().writeBigInt(Long.MIN_VALUE); + listWriter.endList(); + listVector.setValueCount(1); + + unionVector.setType(0, Types.MinorType.BIT); + NullableBitHolder holder = new NullableBitHolder(); + holder.isSet = 1; + holder.value = 1; + unionVector.setSafe(0, holder); + unionVector.setValueCount(1); + + rootVector.setIndexDefined(0); + rootVector.setValueCount(1); + + Map expected = new JsonStringHashMap<>(); + Map nestedStruct = new JsonStringHashMap<>(); + nestedStruct.put("int", 100); + nestedStruct.put("float8", 100.05); + expected.put("struct", nestedStruct); + JsonStringArrayList nestedList = new JsonStringArrayList<>(); + nestedList.add(Long.MAX_VALUE); + nestedList.add(Long.MIN_VALUE); + expected.put("list", nestedList); + expected.put("union", true); + + ArrowFlightJdbcStructVectorAccessor accessor = + new ArrowFlightJdbcStructVectorAccessor(rootVector, () -> 0, (boolean wasNull) -> { + }); + + Assert.assertEquals(accessor.getObject(), expected); + Assert.assertEquals(accessor.getString(), expected.toString()); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessorTest.java new file mode 100644 index 0000000000000..9ec9388ff87c9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessorTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.complex.UnionVector; +import org.apache.arrow.vector.holders.NullableBigIntHolder; +import org.apache.arrow.vector.holders.NullableFloat8Holder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder; +import org.apache.arrow.vector.types.Types; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcUnionVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private UnionVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcUnionVectorAccessor((UnionVector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setup() { + this.vector = UnionVector.empty("", rootAllocatorTestRule.getRootAllocator()); + this.vector.allocateNew(); + + NullableBigIntHolder nullableBigIntHolder = new NullableBigIntHolder(); + nullableBigIntHolder.isSet = 1; + nullableBigIntHolder.value = Long.MAX_VALUE; + this.vector.setType(0, Types.MinorType.BIGINT); + this.vector.setSafe(0, nullableBigIntHolder); + + NullableFloat8Holder nullableFloat4Holder = new NullableFloat8Holder(); + nullableFloat4Holder.isSet = 1; + nullableFloat4Holder.value = Math.PI; + this.vector.setType(1, Types.MinorType.FLOAT8); + this.vector.setSafe(1, nullableFloat4Holder); + + NullableTimeStampMilliHolder nullableTimeStampMilliHolder = new NullableTimeStampMilliHolder(); + nullableTimeStampMilliHolder.isSet = 1; + nullableTimeStampMilliHolder.value = 1625702400000L; + this.vector.setType(2, Types.MinorType.TIMESTAMPMILLI); + this.vector.setSafe(2, nullableTimeStampMilliHolder); + + nullableBigIntHolder.isSet = 0; + this.vector.setType(3, Types.MinorType.BIGINT); + this.vector.setSafe(3, nullableBigIntHolder); + + this.vector.setValueCount(5); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void getObject() throws Exception { + List result = accessorIterator.toList(vector); + List expected = Arrays.asList( + Long.MAX_VALUE, + Math.PI, + new Timestamp(1625702400000L), + null, + null); + + collector.checkThat(result, is(expected)); + } + + @Test + public void getObjectForNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcUnionVectorAccessor::getObject, + equalTo(null)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorTest.java new file mode 100644 index 0000000000000..5e54b545a85ac --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorTest.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; + +import java.util.Arrays; +import java.util.Collection; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcBaseIntVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private BaseIntVector vector; + private final Supplier vectorSupplier; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof UInt1Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt1Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof UInt2Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt2Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else { + if (vector instanceof UInt4Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt4Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof UInt8Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt8Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TinyIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((TinyIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof SmallIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((SmallIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof IntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((IntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof BigIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((BigIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } + } + throw new UnsupportedOperationException(); + }; + + private final AccessorTestUtils.AccessorIterator + accessorIterator = new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createIntVector(), "IntVector"}, + {(Supplier) () -> rootAllocatorTestRule.createSmallIntVector(), + "SmallIntVector"}, + {(Supplier) () -> rootAllocatorTestRule.createTinyIntVector(), + "TinyIntVector"}, + {(Supplier) () -> rootAllocatorTestRule.createBigIntVector(), + "BigIntVector"}, + {(Supplier) () -> rootAllocatorTestRule.createUInt1Vector(), "UInt1Vector"}, + {(Supplier) () -> rootAllocatorTestRule.createUInt2Vector(), "UInt2Vector"}, + {(Supplier) () -> rootAllocatorTestRule.createUInt4Vector(), "UInt4Vector"}, + {(Supplier) () -> rootAllocatorTestRule.createUInt8Vector(), "UInt8Vector"} + }); + } + + public ArrowFlightJdbcBaseIntVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldConvertToByteMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getByte, + (accessor, currentRow) -> equalTo((byte) accessor.getLong())); + } + + @Test + public void testShouldConvertToShortMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getShort, + (accessor, currentRow) -> equalTo((short) accessor.getLong())); + } + + @Test + public void testShouldConvertToIntegerMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getInt, + (accessor, currentRow) -> equalTo((int) accessor.getLong())); + } + + @Test + public void testShouldConvertToFloatMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getFloat, + (accessor, currentRow) -> equalTo((float) accessor.getLong())); + } + + @Test + public void testShouldConvertToDoubleMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getDouble, + (accessor, currentRow) -> equalTo((double) accessor.getLong())); + } + + @Test + public void testShouldConvertToBooleanMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getBoolean, + (accessor, currentRow) -> equalTo(accessor.getLong() != 0L)); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcBaseIntVectorAccessor::getObjectClass, + equalTo(Long.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorUnitTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorUnitTest.java new file mode 100644 index 0000000000000..2e64b6fb40276 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorUnitTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.hamcrest.CoreMatchers; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class ArrowFlightJdbcBaseIntVectorAccessorUnitTest { + + @ClassRule + public static RootAllocatorTestRule rule = new RootAllocatorTestRule(); + private static UInt4Vector int4Vector; + private static UInt8Vector int8Vector; + private static IntVector intVectorWithNull; + private static TinyIntVector tinyIntVector; + private static SmallIntVector smallIntVector; + private static IntVector intVector; + private static BigIntVector bigIntVector; + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof UInt1Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt1Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof UInt2Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt2Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof UInt4Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt4Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof UInt8Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt8Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TinyIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((TinyIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof SmallIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((SmallIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof IntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((IntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof BigIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((BigIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } + return null; + }; + + private final AccessorTestUtils.AccessorIterator + accessorIterator = new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @BeforeClass + public static void setup() { + int4Vector = new UInt4Vector("ID", rule.getRootAllocator()); + int4Vector.setSafe(0, 0x80000001); + int4Vector.setValueCount(1); + + int8Vector = new UInt8Vector("ID", rule.getRootAllocator()); + int8Vector.setSafe(0, 0xFFFFFFFFFFFFFFFFL); + int8Vector.setValueCount(1); + + intVectorWithNull = new IntVector("ID", rule.getRootAllocator()); + intVectorWithNull.setNull(0); + intVectorWithNull.setValueCount(1); + + tinyIntVector = new TinyIntVector("ID", rule.getRootAllocator()); + tinyIntVector.setSafe(0, 0xAA); + tinyIntVector.setValueCount(1); + + smallIntVector = new SmallIntVector("ID", rule.getRootAllocator()); + smallIntVector.setSafe(0, 0xAABB); + smallIntVector.setValueCount(1); + + intVector = new IntVector("ID", rule.getRootAllocator()); + intVector.setSafe(0, 0xAABBCCDD); + intVector.setValueCount(1); + + bigIntVector = new BigIntVector("ID", rule.getRootAllocator()); + bigIntVector.setSafe(0, 0xAABBCCDDEEFFAABBL); + bigIntVector.setValueCount(1); + } + + @AfterClass + public static void tearDown() throws Exception { + AutoCloseables.close(bigIntVector, intVector, smallIntVector, tinyIntVector, int4Vector, + int8Vector, intVectorWithNull, rule); + } + + @Test + public void testShouldGetStringFromUnsignedValue() throws Exception { + accessorIterator.assertAccessorGetter(int8Vector, + ArrowFlightJdbcBaseIntVectorAccessor::getString, equalTo("18446744073709551615")); + } + + @Test + public void testShouldGetBytesFromIntVectorThrowsSqlException() throws Exception { + accessorIterator.assertAccessorGetterThrowingException(intVector, ArrowFlightJdbcBaseIntVectorAccessor::getBytes); + } + + @Test + public void testShouldGetStringFromIntVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(intVectorWithNull, + ArrowFlightJdbcBaseIntVectorAccessor::getString, CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetObjectFromInt() throws Exception { + accessorIterator.assertAccessorGetter(intVector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo(0xAABBCCDD)); + } + + @Test + public void testShouldGetObjectFromTinyInt() throws Exception { + accessorIterator.assertAccessorGetter(tinyIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo((byte) 0xAA)); + } + + @Test + public void testShouldGetObjectFromSmallInt() throws Exception { + accessorIterator.assertAccessorGetter(smallIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo((short) 0xAABB)); + } + + @Test + public void testShouldGetObjectFromBigInt() throws Exception { + accessorIterator.assertAccessorGetter(bigIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo(0xAABBCCDDEEFFAABBL)); + } + + @Test + public void testShouldGetObjectFromUnsignedInt() throws Exception { + accessorIterator.assertAccessorGetter(int4Vector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo(0x80000001)); + } + + @Test + public void testShouldGetObjectFromIntVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(intVectorWithNull, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetBigDecimalFromIntVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(intVectorWithNull, + ArrowFlightJdbcBaseIntVectorAccessor::getBigDecimal, CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetBigDecimalWithScaleFromIntVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(intVectorWithNull, accessor -> accessor.getBigDecimal(2), + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetBytesFromSmallVectorThrowsSqlException() throws Exception { + accessorIterator.assertAccessorGetterThrowingException(smallIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getBytes); + } + + @Test + public void testShouldGetBytesFromTinyIntVectorThrowsSqlException() throws Exception { + accessorIterator.assertAccessorGetterThrowingException(tinyIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getBytes); + } + + @Test + public void testShouldGetBytesFromBigIntVectorThrowsSqlException() throws Exception { + accessorIterator.assertAccessorGetterThrowingException(bigIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getBytes); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java new file mode 100644 index 0000000000000..809d6e8d35386 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.math.BigDecimal; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils.AccessorIterator; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils.CheckedFunction; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.BitVector; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcBitVectorAccessorTest { + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> new ArrowFlightJdbcBitVectorAccessor((BitVector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + private final AccessorIterator + accessorIterator = + new AccessorIterator<>(collector, accessorSupplier); + private BitVector vector; + private BitVector vectorWithNull; + private boolean[] arrayToAssert; + + @Before + public void setup() { + this.arrayToAssert = new boolean[] {false, true}; + this.vector = rootAllocatorTestRule.createBitVector(); + this.vectorWithNull = rootAllocatorTestRule.createBitVectorForNullTests(); + } + + @After + public void tearDown() { + this.vector.close(); + this.vectorWithNull.close(); + } + + private void iterate(final CheckedFunction function, + final T result, + final T resultIfFalse, final BitVector vector) throws Exception { + accessorIterator.assertAccessorGetter(vector, function, + ((accessor, currentRow) -> is(arrayToAssert[currentRow] ? result : resultIfFalse)) + ); + } + + @Test + public void testShouldGetBooleanMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getBoolean, true, false, vector); + } + + @Test + public void testShouldGetByteMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getByte, (byte) 1, (byte) 0, vector); + } + + @Test + public void testShouldGetShortMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getShort, (short) 1, (short) 0, vector); + } + + @Test + public void testShouldGetIntMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getInt, 1, 0, vector); + + } + + @Test + public void testShouldGetLongMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getLong, (long) 1, (long) 0, vector); + + } + + @Test + public void testShouldGetFloatMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getFloat, (float) 1, (float) 0, vector); + + } + + @Test + public void testShouldGetDoubleMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getDouble, (double) 1, (double) 0, vector); + + } + + @Test + public void testShouldGetBigDecimalMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getBigDecimal, BigDecimal.ONE, BigDecimal.ZERO, + vector); + } + + @Test + public void testShouldGetBigDecimalMethodFromBitVectorFromNull() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getBigDecimal, null, null, vectorWithNull); + + } + + @Test + public void testShouldGetObjectMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getObject, true, false, vector); + + } + + @Test + public void testShouldGetObjectMethodFromBitVectorFromNull() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getObject, null, null, vectorWithNull); + + } + + @Test + public void testShouldGetStringMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getString, "true", "false", vector); + + } + + @Test + public void testShouldGetStringMethodFromBitVectorFromNull() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getString, null, null, vectorWithNull); + + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBitVectorAccessor::getObjectClass, + equalTo(Boolean.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessorTest.java new file mode 100644 index 0000000000000..b7bd7c40fef20 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessorTest.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collection; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.ValueVector; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcDecimalVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private final Supplier vectorSupplier; + private ValueVector vector; + private ValueVector vectorWithNull; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof DecimalVector) { + return new ArrowFlightJdbcDecimalVectorAccessor((DecimalVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof Decimal256Vector) { + return new ArrowFlightJdbcDecimalVectorAccessor((Decimal256Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } + return null; + }; + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createDecimalVector(), + "DecimalVector"}, + {(Supplier) () -> rootAllocatorTestRule.createDecimal256Vector(), + "Decimal256Vector"}, + }); + } + + public ArrowFlightJdbcDecimalVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + + this.vectorWithNull = vectorSupplier.get(); + this.vectorWithNull.clear(); + this.vectorWithNull.setValueCount(5); + } + + @After + public void tearDown() { + this.vector.close(); + this.vectorWithNull.close(); + } + + @Test + public void testShouldGetBigDecimalFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcDecimalVectorAccessor::getBigDecimal, + (accessor, currentRow) -> CoreMatchers.notNullValue()); + } + + @Test + public void testShouldGetDoubleMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getDouble, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().doubleValue())); + } + + @Test + public void testShouldGetFloatMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getFloat, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().floatValue())); + } + + @Test + public void testShouldGetLongMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getLong, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().longValue())); + } + + @Test + public void testShouldGetIntMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getInt, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().intValue())); + } + + @Test + public void testShouldGetShortMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getShort, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().shortValue())); + } + + @Test + public void testShouldGetByteMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getByte, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().byteValue())); + } + + @Test + public void testShouldGetStringMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getString, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().toString())); + } + + @Test + public void testShouldGetBooleanMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getBoolean, + (accessor, currentRow) -> equalTo(!accessor.getBigDecimal().equals(BigDecimal.ZERO))); + } + + @Test + public void testShouldGetObjectMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getObject, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal())); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcDecimalVectorAccessor::getObjectClass, + (accessor, currentRow) -> equalTo(BigDecimal.class)); + } + + @Test + public void testShouldGetBigDecimalMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getBigDecimal, + (accessor, currentRow) -> CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetObjectMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getObject, + (accessor, currentRow) -> CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetStringMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getString, + (accessor, currentRow) -> CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetByteMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getByte, + (accessor, currentRow) -> is((byte) 0)); + } + + @Test + public void testShouldGetShortMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getShort, + (accessor, currentRow) -> is((short) 0)); + } + + @Test + public void testShouldGetIntMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getInt, + (accessor, currentRow) -> is(0)); + } + + @Test + public void testShouldGetLongMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getLong, + (accessor, currentRow) -> is((long) 0)); + } + + @Test + public void testShouldGetFloatMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getFloat, + (accessor, currentRow) -> is(0.0f)); + } + + @Test + public void testShouldGetDoubleMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getDouble, + (accessor, currentRow) -> is(0.0D)); + } + + @Test + public void testShouldGetBooleanMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getBoolean, + (accessor, currentRow) -> is(false)); + } + + @Test + public void testShouldGetBigDecimalWithScaleMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, accessor -> accessor.getBigDecimal(2), + (accessor, currentRow) -> CoreMatchers.nullValue()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessorTest.java new file mode 100644 index 0000000000000..74a65715ec0fb --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessorTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.Float4Vector; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.rules.ExpectedException; + +public class ArrowFlightJdbcFloat4VectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Rule + public ExpectedException exceptionCollector = ExpectedException.none(); + + private Float4Vector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcFloat4VectorAccessor((Float4Vector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setup() { + this.vector = rootAllocatorTestRule.createFloat4Vector(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetFloatMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getFloat, + (accessor, currentRow) -> is(vector.get(currentRow))); + } + + @Test + public void testShouldGetObjectMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getObject, + (accessor) -> is(accessor.getFloat())); + } + + @Test + public void testShouldGetStringMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getString, + accessor -> is(Float.toString(accessor.getFloat()))); + } + + @Test + public void testShouldGetStringMethodFromFloat4VectorWithNull() throws Exception { + try (final Float4Vector float4Vector = new Float4Vector("ID", + rootAllocatorTestRule.getRootAllocator())) { + float4Vector.setNull(0); + float4Vector.setValueCount(1); + + accessorIterator.assertAccessorGetter(float4Vector, + ArrowFlightJdbcFloat4VectorAccessor::getString, + CoreMatchers.nullValue()); + } + } + + @Test + public void testShouldGetFloatMethodFromFloat4VectorWithNull() throws Exception { + try (final Float4Vector float4Vector = new Float4Vector("ID", + rootAllocatorTestRule.getRootAllocator())) { + float4Vector.setNull(0); + float4Vector.setValueCount(1); + + accessorIterator.assertAccessorGetter(float4Vector, + ArrowFlightJdbcFloat4VectorAccessor::getFloat, is(0.0f)); + } + } + + @Test + public void testShouldGetBigDecimalMethodFromFloat4VectorWithNull() throws Exception { + try (final Float4Vector float4Vector = new Float4Vector("ID", + rootAllocatorTestRule.getRootAllocator())) { + float4Vector.setNull(0); + float4Vector.setValueCount(1); + + accessorIterator.assertAccessorGetter(float4Vector, + ArrowFlightJdbcFloat4VectorAccessor::getBigDecimal, + CoreMatchers.nullValue()); + } + } + + @Test + public void testShouldGetObjectMethodFromFloat4VectorWithNull() throws Exception { + try (final Float4Vector float4Vector = new Float4Vector("ID", + rootAllocatorTestRule.getRootAllocator())) { + float4Vector.setNull(0); + float4Vector.setValueCount(1); + + accessorIterator.assertAccessorGetter(float4Vector, + ArrowFlightJdbcFloat4VectorAccessor::getObject, + CoreMatchers.nullValue()); + } + } + + @Test + public void testShouldGetBooleanMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getBoolean, + accessor -> is(accessor.getFloat() != 0.0f)); + } + + @Test + public void testShouldGetByteMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getByte, + accessor -> is((byte) accessor.getFloat())); + } + + @Test + public void testShouldGetShortMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getShort, + accessor -> is((short) accessor.getFloat())); + } + + @Test + public void testShouldGetIntMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getInt, + accessor -> is((int) accessor.getFloat())); + } + + @Test + public void testShouldGetLongMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getLong, + accessor -> is((long) accessor.getFloat())); + } + + @Test + public void testShouldGetDoubleMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getDouble, + accessor -> is((double) accessor.getFloat())); + } + + @Test + public void testShouldGetBigDecimalMethodFromFloat4Vector() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + float value = accessor.getFloat(); + if (Float.isInfinite(value) || Float.isNaN(value)) { + exceptionCollector.expect(SQLException.class); + } + collector.checkThat(accessor.getBigDecimal(), is(BigDecimal.valueOf(value))); + }); + } + + @Test + public void testShouldGetBigDecimalWithScaleMethodFromFloat4Vector() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + float value = accessor.getFloat(); + if (Float.isInfinite(value) || Float.isNaN(value)) { + exceptionCollector.expect(SQLException.class); + } + collector.checkThat(accessor.getBigDecimal(9), + is(BigDecimal.valueOf(value).setScale(9, RoundingMode.HALF_UP))); + }); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcFloat4VectorAccessor::getObjectClass, + accessor -> equalTo(Float.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessorTest.java new file mode 100644 index 0000000000000..26758287a96f3 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessorTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.Float8Vector; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.rules.ExpectedException; + +public class ArrowFlightJdbcFloat8VectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Rule + public ExpectedException exceptionCollector = ExpectedException.none(); + + + private Float8Vector vector; + private Float8Vector vectorWithNull; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcFloat8VectorAccessor((Float8Vector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setup() { + this.vector = rootAllocatorTestRule.createFloat8Vector(); + this.vectorWithNull = rootAllocatorTestRule.createFloat8VectorForNullTests(); + } + + @After + public void tearDown() { + this.vector.close(); + this.vectorWithNull.close(); + } + + @Test + public void testShouldGetDoubleMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getDouble, + (accessor, currentRow) -> is(vector.getValueAsDouble(currentRow))); + } + + @Test + public void testShouldGetObjectMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getObject, + (accessor) -> is(accessor.getDouble())); + } + + @Test + public void testShouldGetStringMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getString, + (accessor) -> is(Double.toString(accessor.getDouble()))); + } + + @Test + public void testShouldGetBooleanMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getBoolean, + (accessor) -> is(accessor.getDouble() != 0.0)); + } + + @Test + public void testShouldGetByteMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getByte, + (accessor) -> is((byte) accessor.getDouble())); + } + + @Test + public void testShouldGetShortMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getShort, + (accessor) -> is((short) accessor.getDouble())); + } + + @Test + public void testShouldGetIntMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getInt, + (accessor) -> is((int) accessor.getDouble())); + } + + @Test + public void testShouldGetLongMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getLong, + (accessor) -> is((long) accessor.getDouble())); + } + + @Test + public void testShouldGetFloatMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getFloat, + (accessor) -> is((float) accessor.getDouble())); + } + + @Test + public void testShouldGetBigDecimalMethodFromFloat8Vector() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + double value = accessor.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + exceptionCollector.expect(SQLException.class); + } + collector.checkThat(accessor.getBigDecimal(), is(BigDecimal.valueOf(value))); + }); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator + .assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getObjectClass, + equalTo(Double.class)); + } + + @Test + public void testShouldGetStringMethodFromFloat8VectorWithNull() throws Exception { + accessorIterator + .assertAccessorGetter(vectorWithNull, ArrowFlightJdbcFloat8VectorAccessor::getString, + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetFloatMethodFromFloat8VectorWithNull() throws Exception { + accessorIterator + .assertAccessorGetter(vectorWithNull, ArrowFlightJdbcFloat8VectorAccessor::getFloat, + is(0.0f)); + } + + @Test + public void testShouldGetBigDecimalMethodFromFloat8VectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcFloat8VectorAccessor::getBigDecimal, + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetBigDecimalWithScaleMethodFromFloat4Vector() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + double value = accessor.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + exceptionCollector.expect(SQLException.class); + } + collector.checkThat(accessor.getBigDecimal(9), + is(BigDecimal.valueOf(value).setScale(9, RoundingMode.HALF_UP))); + }); + } + + @Test + public void testShouldGetObjectMethodFromFloat8VectorWithNull() throws Exception { + accessorIterator + .assertAccessorGetter(vectorWithNull, ArrowFlightJdbcFloat8VectorAccessor::getObject, + CoreMatchers.nullValue()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessorTest.java new file mode 100644 index 0000000000000..799c517dd561b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessorTest.java @@ -0,0 +1,733 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.text; + +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.apache.commons.io.IOUtils.toByteArray; +import static org.apache.commons.io.IOUtils.toCharArray; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.mockito.Mockito.when; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; +import java.util.Calendar; +import java.util.TimeZone; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.driver.jdbc.utils.ThrowableAssertionUtils; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.util.Text; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + + +@RunWith(MockitoJUnitRunner.class) +public class ArrowFlightJdbcVarCharVectorAccessorTest { + + private ArrowFlightJdbcVarCharVectorAccessor accessor; + private final SimpleDateFormat dateTimeFormat = + new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"); + private final SimpleDateFormat timeFormat = new SimpleDateFormat("HH:mm:ss.SSSXXX"); + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Mock + private ArrowFlightJdbcVarCharVectorAccessor.Getter getter; + + @Rule + public ErrorCollector collector = new ErrorCollector(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Before + public void setUp() { + IntSupplier currentRowSupplier = () -> 0; + accessor = + new ArrowFlightJdbcVarCharVectorAccessor(getter, currentRowSupplier, (boolean wasNull) -> { + }); + } + + @Test + public void testShouldGetStringFromNullReturnNull() { + when(getter.get(0)).thenReturn(null); + final String result = accessor.getString(); + + collector.checkThat(result, equalTo(null)); + } + + @Test + public void testShouldGetStringReturnValidString() { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + final String result = accessor.getString(); + + collector.checkThat(result, instanceOf(String.class)); + collector.checkThat(result, equalTo(value.toString())); + } + + @Test + public void testShouldGetObjectReturnValidString() { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + final String result = accessor.getObject(); + + collector.checkThat(result, instanceOf(String.class)); + collector.checkThat(result, equalTo(value.toString())); + } + + @Test + public void testShouldGetByteThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for byte."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getByte(); + } + + @Test + public void testShouldGetByteThrowsExceptionForOutOfRangePositiveValue() throws Exception { + Text value = new Text("128"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getByte(); + } + + @Test + public void testShouldGetByteThrowsExceptionForOutOfRangeNegativeValue() throws Exception { + Text value = new Text("-129"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getByte(); + } + + @Test + public void testShouldGetByteReturnValidPositiveByte() throws Exception { + Text value = new Text("127"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + byte result = accessor.getByte(); + + collector.checkThat(result, instanceOf(Byte.class)); + collector.checkThat(result, equalTo((byte) 127)); + } + + @Test + public void testShouldGetByteReturnValidNegativeByte() throws Exception { + Text value = new Text("-128"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + byte result = accessor.getByte(); + + collector.checkThat(result, instanceOf(Byte.class)); + collector.checkThat(result, equalTo((byte) -128)); + } + + @Test + public void testShouldGetShortThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for short."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getShort(); + } + + @Test + public void testShouldGetShortThrowsExceptionForOutOfRangePositiveValue() throws Exception { + Text value = new Text("32768"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getShort(); + } + + @Test + public void testShouldGetShortThrowsExceptionForOutOfRangeNegativeValue() throws Exception { + Text value = new Text("-32769"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getShort(); + } + + @Test + public void testShouldGetShortReturnValidPositiveShort() throws Exception { + Text value = new Text("32767"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + short result = accessor.getShort(); + + collector.checkThat(result, instanceOf(Short.class)); + collector.checkThat(result, equalTo((short) 32767)); + } + + @Test + public void testShouldGetShortReturnValidNegativeShort() throws Exception { + Text value = new Text("-32768"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + short result = accessor.getShort(); + + collector.checkThat(result, instanceOf(Short.class)); + collector.checkThat(result, equalTo((short) -32768)); + } + + @Test + public void testShouldGetIntThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for int."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getInt(); + } + + @Test + public void testShouldGetIntThrowsExceptionForOutOfRangePositiveValue() throws Exception { + Text value = new Text("2147483648"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getInt(); + } + + @Test + public void testShouldGetIntThrowsExceptionForOutOfRangeNegativeValue() throws Exception { + Text value = new Text("-2147483649"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getInt(); + } + + @Test + public void testShouldGetIntReturnValidPositiveInteger() throws Exception { + Text value = new Text("2147483647"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + int result = accessor.getInt(); + + collector.checkThat(result, instanceOf(Integer.class)); + collector.checkThat(result, equalTo(2147483647)); + } + + @Test + public void testShouldGetIntReturnValidNegativeInteger() throws Exception { + Text value = new Text("-2147483648"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + int result = accessor.getInt(); + + collector.checkThat(result, instanceOf(Integer.class)); + collector.checkThat(result, equalTo(-2147483648)); + } + + @Test + public void testShouldGetLongThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for long."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getLong(); + } + + @Test + public void testShouldGetLongThrowsExceptionForOutOfRangePositiveValue() throws Exception { + Text value = new Text("9223372036854775808"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getLong(); + } + + @Test + public void testShouldGetLongThrowsExceptionForOutOfRangeNegativeValue() throws Exception { + Text value = new Text("-9223372036854775809"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getLong(); + } + + @Test + public void testShouldGetLongReturnValidPositiveLong() throws Exception { + Text value = new Text("9223372036854775807"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + long result = accessor.getLong(); + + collector.checkThat(result, instanceOf(Long.class)); + collector.checkThat(result, equalTo(9223372036854775807L)); + } + + @Test + public void testShouldGetLongReturnValidNegativeLong() throws Exception { + Text value = new Text("-9223372036854775808"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + long result = accessor.getLong(); + + collector.checkThat(result, instanceOf(Long.class)); + collector.checkThat(result, equalTo(-9223372036854775808L)); + } + + @Test + public void testShouldBigDecimalWithParametersThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for BigDecimal."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getBigDecimal(1); + } + + @Test + public void testShouldGetBigDecimalThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for BigDecimal."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getBigDecimal(); + } + + @Test + public void testShouldGetBigDecimalReturnValidPositiveBigDecimal() throws Exception { + Text value = new Text("9223372036854775807000.999"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + BigDecimal result = accessor.getBigDecimal(); + + collector.checkThat(result, instanceOf(BigDecimal.class)); + collector.checkThat(result, equalTo(new BigDecimal("9223372036854775807000.999"))); + } + + @Test + public void testShouldGetBigDecimalReturnValidNegativeBigDecimal() throws Exception { + Text value = new Text("-9223372036854775807000.999"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + BigDecimal result = accessor.getBigDecimal(); + + collector.checkThat(result, instanceOf(BigDecimal.class)); + collector.checkThat(result, equalTo(new BigDecimal("-9223372036854775807000.999"))); + } + + @Test + public void testShouldGetDoubleThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for double."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getDouble(); + } + + @Test + public void testShouldGetDoubleReturnValidPositiveDouble() throws Exception { + Text value = new Text("1.7976931348623157E308D"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(1.7976931348623157E308D)); + } + + @Test + public void testShouldGetDoubleReturnValidNegativeDouble() throws Exception { + Text value = new Text("-1.7976931348623157E308D"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(-1.7976931348623157E308D)); + } + + @Test + public void testShouldGetDoubleWorkWithPositiveInfinity() throws Exception { + Text value = new Text("Infinity"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(Double.POSITIVE_INFINITY)); + } + + @Test + public void testShouldGetDoubleWorkWithNegativeInfinity() throws Exception { + Text value = new Text("-Infinity"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(Double.NEGATIVE_INFINITY)); + } + + @Test + public void testShouldGetDoubleWorkWithNaN() throws Exception { + Text value = new Text("NaN"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(Double.NaN)); + } + + @Test + public void testShouldGetFloatThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for float."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getFloat(); + } + + @Test + public void testShouldGetFloatReturnValidPositiveFloat() throws Exception { + Text value = new Text("3.4028235E38F"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(3.4028235E38F)); + } + + @Test + public void testShouldGetFloatReturnValidNegativeFloat() throws Exception { + Text value = new Text("-3.4028235E38F"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(-3.4028235E38F)); + } + + @Test + public void testShouldGetFloatWorkWithPositiveInfinity() throws Exception { + Text value = new Text("Infinity"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(Float.POSITIVE_INFINITY)); + } + + @Test + public void testShouldGetFloatWorkWithNegativeInfinity() throws Exception { + Text value = new Text("-Infinity"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(Float.NEGATIVE_INFINITY)); + } + + @Test + public void testShouldGetFloatWorkWithNaN() throws Exception { + Text value = new Text("NaN"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(Float.NaN)); + } + + @Test + public void testShouldGetDateThrowsExceptionForNonDateValue() throws Exception { + Text value = new Text("Invalid value for date."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getDate(null); + } + + @Test + public void testShouldGetDateReturnValidDateWithoutCalendar() throws Exception { + Text value = new Text("2021-07-02"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Date result = accessor.getDate(null); + + collector.checkThat(result, instanceOf(Date.class)); + + Calendar calendar = Calendar.getInstance(); + calendar.setTime(result); + + collector.checkThat(dateTimeFormat.format(calendar.getTime()), + equalTo("2021-07-02T00:00:00.000Z")); + } + + @Test + public void testShouldGetDateReturnValidDateWithCalendar() throws Exception { + Text value = new Text("2021-07-02"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("America/Sao_Paulo")); + Date result = accessor.getDate(calendar); + + calendar = Calendar.getInstance(TimeZone.getTimeZone("Etc/UTC")); + calendar.setTime(result); + + collector.checkThat(dateTimeFormat.format(calendar.getTime()), + equalTo("2021-07-02T03:00:00.000Z")); + } + + @Test + public void testShouldGetTimeThrowsExceptionForNonTimeValue() throws Exception { + Text value = new Text("Invalid value for time."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getTime(null); + } + + @Test + public void testShouldGetTimeReturnValidDateWithoutCalendar() throws Exception { + Text value = new Text("02:30:00"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Time result = accessor.getTime(null); + + Calendar calendar = Calendar.getInstance(); + calendar.setTime(result); + + collector.checkThat(timeFormat.format(calendar.getTime()), equalTo("02:30:00.000Z")); + } + + @Test + public void testShouldGetTimeReturnValidDateWithCalendar() throws Exception { + Text value = new Text("02:30:00"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("America/Sao_Paulo")); + Time result = accessor.getTime(calendar); + + calendar = Calendar.getInstance(TimeZone.getTimeZone("Etc/UTC")); + calendar.setTime(result); + + collector.checkThat(timeFormat.format(calendar.getTime()), equalTo("05:30:00.000Z")); + } + + @Test + public void testShouldGetTimestampThrowsExceptionForNonTimeValue() throws Exception { + Text value = new Text("Invalid value for timestamp."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getTimestamp(null); + } + + @Test + public void testShouldGetTimestampReturnValidDateWithoutCalendar() throws Exception { + Text value = new Text("2021-07-02 02:30:00.000"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Timestamp result = accessor.getTimestamp(null); + + Calendar calendar = Calendar.getInstance(); + calendar.setTime(result); + + collector.checkThat(dateTimeFormat.format(calendar.getTime()), + equalTo("2021-07-02T02:30:00.000Z")); + } + + @Test + public void testShouldGetTimestampReturnValidDateWithCalendar() throws Exception { + Text value = new Text("2021-07-02 02:30:00.000"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("America/Sao_Paulo")); + Timestamp result = accessor.getTimestamp(calendar); + + calendar = Calendar.getInstance(TimeZone.getTimeZone("Etc/UTC")); + calendar.setTime(result); + + collector.checkThat(dateTimeFormat.format(calendar.getTime()), + equalTo("2021-07-02T05:30:00.000Z")); + } + + private void assertGetBoolean(Text value, boolean expectedResult) throws SQLException { + when(getter.get(0)).thenReturn(value == null ? null : value.copyBytes()); + boolean result = accessor.getBoolean(); + collector.checkThat(result, equalTo(expectedResult)); + } + + private void assertGetBooleanForSQLException(Text value) { + when(getter.get(0)).thenReturn(value == null ? null : value.copyBytes()); + ThrowableAssertionUtils.simpleAssertThrowableClass(SQLException.class, () -> accessor.getBoolean()); + } + + @Test + public void testShouldGetBooleanThrowsSQLExceptionForInvalidValue() { + assertGetBooleanForSQLException(new Text("anything")); + } + + @Test + public void testShouldGetBooleanThrowsSQLExceptionForEmpty() { + assertGetBooleanForSQLException(new Text("")); + } + + @Test + public void testShouldGetBooleanReturnFalseFor0() throws Exception { + assertGetBoolean(new Text("0"), false); + } + + @Test + public void testShouldGetBooleanReturnFalseForFalseString() throws Exception { + assertGetBoolean(new Text("false"), false); + } + + @Test + public void testShouldGetBooleanReturnFalseForNull() throws Exception { + assertGetBoolean(null, false); + } + + @Test + public void testShouldGetBytesReturnValidByteArray() { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + final byte[] result = accessor.getBytes(); + + collector.checkThat(result, instanceOf(byte[].class)); + collector.checkThat(result, equalTo(value.toString().getBytes(UTF_8))); + } + + @Test + public void testShouldGetUnicodeStreamReturnValidInputStream() throws Exception { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + try (final InputStream result = accessor.getUnicodeStream()) { + byte[] resultBytes = toByteArray(result); + + collector.checkThat(new String(resultBytes, UTF_8), + equalTo(value.toString())); + } + } + + @Test + public void testShouldGetAsciiStreamReturnValidInputStream() throws Exception { + Text valueText = new Text("Value for Test."); + byte[] valueAscii = valueText.toString().getBytes(US_ASCII); + when(getter.get(0)).thenReturn(valueText.copyBytes()); + + try (final InputStream result = accessor.getAsciiStream()) { + byte[] resultBytes = toByteArray(result); + + Assert.assertArrayEquals(valueAscii, resultBytes); + } + } + + @Test + public void testShouldGetCharacterStreamReturnValidReader() throws Exception { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + try (Reader result = accessor.getCharacterStream()) { + char[] resultChars = toCharArray(result); + + collector.checkThat(new String(resultChars), equalTo(value.toString())); + } + } + + @Test + public void testShouldGetTimeStampBeConsistentWithTimeStampAccessor() throws Exception { + try (TimeStampVector timeStampVector = rootAllocatorTestRule.createTimeStampMilliVector()) { + ArrowFlightJdbcTimeStampVectorAccessor timeStampVectorAccessor = + new ArrowFlightJdbcTimeStampVectorAccessor(timeStampVector, () -> 0, + (boolean wasNull) -> { + }); + + Text value = new Text(timeStampVectorAccessor.getString()); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Timestamp timestamp = accessor.getTimestamp(null); + collector.checkThat(timestamp, equalTo(timeStampVectorAccessor.getTimestamp(null))); + } + } + + @Test + public void testShouldGetTimeBeConsistentWithTimeAccessor() throws Exception { + try (TimeMilliVector timeVector = rootAllocatorTestRule.createTimeMilliVector()) { + ArrowFlightJdbcTimeVectorAccessor timeVectorAccessor = + new ArrowFlightJdbcTimeVectorAccessor(timeVector, () -> 0, (boolean wasNull) -> { + }); + + Text value = new Text(timeVectorAccessor.getString()); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Time time = accessor.getTime(null); + collector.checkThat(time, equalTo(timeVectorAccessor.getTime(null))); + } + } + + @Test + public void testShouldGetDateBeConsistentWithDateAccessor() throws Exception { + try (DateMilliVector dateVector = rootAllocatorTestRule.createDateMilliVector()) { + ArrowFlightJdbcDateVectorAccessor dateVectorAccessor = + new ArrowFlightJdbcDateVectorAccessor(dateVector, () -> 0, (boolean wasNull) -> { + }); + + Text value = new Text(dateVectorAccessor.getString()); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Date date = accessor.getDate(null); + collector.checkThat(date, equalTo(dateVectorAccessor.getDate(null))); + } + } + + @Test + public void testShouldGetObjectClassReturnString() { + final Class clazz = accessor.getObjectClass(); + collector.checkThat(clazz, equalTo(String.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/Authentication.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/Authentication.java new file mode 100644 index 0000000000000..5fe2b0dc057fd --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/Authentication.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.authentication; + +import java.util.Properties; + +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; + +public interface Authentication { + /** + * Create a {@link CallHeaderAuthenticator} which is used to authenticate the connection. + * + * @return a CallHeaderAuthenticator. + */ + CallHeaderAuthenticator authenticate(); + + /** + * Uses the validCredentials variable and populate the Properties object. + * @param properties the Properties object that will be populated. + */ + void populateProperties(Properties properties); +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/TokenAuthentication.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/TokenAuthentication.java new file mode 100644 index 0000000000000..605705d1ca95a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/TokenAuthentication.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.authentication; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; + +public class TokenAuthentication implements Authentication { + private final List validCredentials; + + public TokenAuthentication(List validCredentials) { + this.validCredentials = validCredentials; + } + + @Override + public CallHeaderAuthenticator authenticate() { + return new CallHeaderAuthenticator() { + @Override + public AuthResult authenticate(CallHeaders incomingHeaders) { + String authorization = incomingHeaders.get("authorization"); + if (!validCredentials.contains(authorization)) { + throw CallStatus.UNAUTHENTICATED.withDescription("Invalid credentials.").toRuntimeException(); + } + return new AuthResult() { + @Override + public String getPeerIdentity() { + return authorization; + } + }; + } + }; + } + + @Override + public void populateProperties(Properties properties) { + this.validCredentials.forEach(value -> properties.put( + ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.TOKEN.camelName(), value)); + } + + public static final class Builder { + private final List tokenList = new ArrayList<>(); + + public TokenAuthentication.Builder token(String token) { + tokenList.add("Bearer " + token); + return this; + } + + public TokenAuthentication build() { + return new TokenAuthentication(tokenList); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/UserPasswordAuthentication.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/UserPasswordAuthentication.java new file mode 100644 index 0000000000000..5dc97c858f352 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/UserPasswordAuthentication.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.authentication; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator; + +public class UserPasswordAuthentication implements Authentication { + + private final Map validCredentials; + + public UserPasswordAuthentication(Map validCredentials) { + this.validCredentials = validCredentials; + } + + private String getCredentials(String key) { + return validCredentials.getOrDefault(key, null); + } + + @Override + public CallHeaderAuthenticator authenticate() { + return new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator((username, password) -> { + if (validCredentials.containsKey(username) && getCredentials(username).equals(password)) { + return () -> username; + } + throw CallStatus.UNAUTHENTICATED.withDescription("Invalid credentials.").toRuntimeException(); + })); + } + + @Override + public void populateProperties(Properties properties) { + validCredentials.forEach((key, value) -> { + properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USER.camelName(), key); + properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD.camelName(), value); + }); + } + + public static class Builder { + Map credentials = new HashMap<>(); + + public Builder user(String username, String password) { + credentials.put(username, password); + return this; + } + + public UserPasswordAuthentication build() { + return new UserPasswordAuthentication(credentials); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtilsTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtilsTest.java new file mode 100644 index 0000000000000..d61436fd6e2a7 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtilsTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.client.utils; + +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Method; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; + +import org.bouncycastle.openssl.jcajce.JcaPEMWriter; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class ClientAuthenticationUtilsTest { + @Mock + KeyStore keyStoreMock; + + @Test + public void testGetCertificatesInputStream() throws IOException, KeyStoreException { + JcaPEMWriter pemWriterMock = mock(JcaPEMWriter.class); + Certificate certificateMock = mock(Certificate.class); + Enumeration alias = Collections.enumeration(Arrays.asList("test1", "test2")); + + Mockito.when(keyStoreMock.aliases()).thenReturn(alias); + Mockito.when(keyStoreMock.isCertificateEntry("test1")).thenReturn(true); + Mockito.when(keyStoreMock.getCertificate("test1")).thenReturn(certificateMock); + + ClientAuthenticationUtils.getCertificatesInputStream(keyStoreMock, pemWriterMock); + Mockito.verify(pemWriterMock).writeObject(certificateMock); + Mockito.verify(pemWriterMock).flush(); + } + + @Test + public void testGetKeyStoreInstance() throws IOException, + KeyStoreException, CertificateException, NoSuchAlgorithmException { + try (MockedStatic keyStoreMockedStatic = Mockito.mockStatic(KeyStore.class)) { + keyStoreMockedStatic + .when(() -> ClientAuthenticationUtils.getKeyStoreInstance(Mockito.any())) + .thenReturn(keyStoreMock); + + KeyStore receiveKeyStore = ClientAuthenticationUtils.getKeyStoreInstance("test1"); + Mockito + .verify(keyStoreMock) + .load(null, null); + + Assert.assertEquals(receiveKeyStore, keyStoreMock); + } + } + + @Test + public void testGetCertificateInputStreamFromMacSystem() throws IOException, + KeyStoreException, CertificateException, NoSuchAlgorithmException { + InputStream mock = mock(InputStream.class); + + try (MockedStatic keyStoreMockedStatic = createKeyStoreStaticMock(); + MockedStatic + clientAuthenticationUtilsMockedStatic = createClientAuthenticationUtilsStaticMock()) { + + setOperatingSystemMock(clientAuthenticationUtilsMockedStatic, false, true); + keyStoreMockedStatic.when(() -> ClientAuthenticationUtils + .getKeyStoreInstance("KeychainStore")) + .thenReturn(keyStoreMock); + keyStoreMockedStatic.when(() -> ClientAuthenticationUtils + .getCertificatesInputStream(Mockito.any())) + .thenReturn(mock); + + InputStream inputStream = ClientAuthenticationUtils.getCertificateInputStreamFromSystem("test"); + Assert.assertEquals(inputStream, mock); + } + } + + @Test + public void testGetCertificateInputStreamFromWindowsSystem() throws IOException, + KeyStoreException, CertificateException, NoSuchAlgorithmException { + InputStream mock = mock(InputStream.class); + + try (MockedStatic keyStoreMockedStatic = createKeyStoreStaticMock(); + MockedStatic + clientAuthenticationUtilsMockedStatic = createClientAuthenticationUtilsStaticMock()) { + + setOperatingSystemMock(clientAuthenticationUtilsMockedStatic, true, false); + keyStoreMockedStatic + .when(() -> ClientAuthenticationUtils.getKeyStoreInstance("Windows-ROOT")) + .thenReturn(keyStoreMock); + keyStoreMockedStatic + .when(() -> ClientAuthenticationUtils.getKeyStoreInstance("Windows-MY")) + .thenReturn(keyStoreMock); + keyStoreMockedStatic + .when(() -> ClientAuthenticationUtils.getCertificatesInputStream(Mockito.any())) + .thenReturn(mock); + + InputStream inputStream = ClientAuthenticationUtils.getCertificateInputStreamFromSystem("test"); + Assert.assertEquals(inputStream, mock); + } + } + + private MockedStatic createKeyStoreStaticMock() { + return Mockito.mockStatic(KeyStore.class); + } + + private MockedStatic createClientAuthenticationUtilsStaticMock() { + return Mockito.mockStatic(ClientAuthenticationUtils.class , invocationOnMock -> { + Method method = invocationOnMock.getMethod(); + if (method.getName().equals("getCertificateInputStreamFromSystem")) { + return invocationOnMock.callRealMethod(); + } + return invocationOnMock.getMock(); + }); + } + + private void setOperatingSystemMock(MockedStatic clientAuthenticationUtilsMockedStatic, + boolean isWindows, boolean isMac) { + clientAuthenticationUtilsMockedStatic.when(ClientAuthenticationUtils::isMac).thenReturn(isMac); + clientAuthenticationUtilsMockedStatic.when(ClientAuthenticationUtils::isWindows).thenReturn(isWindows); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/AccessorTestUtils.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/AccessorTestUtils.java new file mode 100644 index 0000000000000..bc1e8a042035b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/AccessorTestUtils.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.hamcrest.CoreMatchers.is; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.IntSupplier; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.vector.ValueVector; +import org.hamcrest.Matcher; +import org.junit.rules.ErrorCollector; + +public class AccessorTestUtils { + @FunctionalInterface + public interface CheckedFunction { + R apply(T t) throws SQLException; + } + + public interface AccessorSupplier { + T supply(ValueVector vector, IntSupplier getCurrentRow); + } + + public interface AccessorConsumer { + void accept(T accessor, int currentRow) throws Exception; + } + + public interface MatcherGetter { + Matcher get(T accessor, int currentRow); + } + + public static class Cursor { + int currentRow = 0; + int limit; + + public Cursor(int limit) { + this.limit = limit; + } + + public void next() { + currentRow++; + } + + boolean hasNext() { + return currentRow < limit; + } + + public int getCurrentRow() { + return currentRow; + } + } + + public static class AccessorIterator { + private final ErrorCollector collector; + private final AccessorSupplier accessorSupplier; + + public AccessorIterator(ErrorCollector collector, AccessorSupplier accessorSupplier) { + this.collector = collector; + this.accessorSupplier = accessorSupplier; + } + + public void iterate(ValueVector vector, AccessorConsumer accessorConsumer) throws Exception { + int valueCount = vector.getValueCount(); + if (valueCount == 0) { + throw new IllegalArgumentException("Vector is empty"); + } + + Cursor cursor = new Cursor(valueCount); + T accessor = accessorSupplier.supply(vector, cursor::getCurrentRow); + + while (cursor.hasNext()) { + accessorConsumer.accept(accessor, cursor.getCurrentRow()); + cursor.next(); + } + } + + public void iterate(ValueVector vector, Consumer accessorConsumer) throws Exception { + iterate(vector, (accessor, currentRow) -> accessorConsumer.accept(accessor)); + } + + public List toList(ValueVector vector) throws Exception { + List result = new ArrayList<>(); + iterate(vector, (accessor, currentRow) -> result.add(accessor.getObject())); + + return result; + } + + public void assertAccessorGetter(ValueVector vector, CheckedFunction getter, + MatcherGetter matcherGetter) throws Exception { + iterate(vector, (accessor, currentRow) -> { + R object = getter.apply(accessor); + boolean wasNull = accessor.wasNull(); + + collector.checkThat(object, matcherGetter.get(accessor, currentRow)); + collector.checkThat(wasNull, is(accessor.getObject() == null)); + }); + } + + public void assertAccessorGetterThrowingException(ValueVector vector, CheckedFunction getter) + throws Exception { + iterate(vector, (accessor, currentRow) -> + ThrowableAssertionUtils.simpleAssertThrowableClass(SQLException.class, () -> getter.apply(accessor))); + } + + public void assertAccessorGetter(ValueVector vector, CheckedFunction getter, + Function> matcherGetter) throws Exception { + assertAccessorGetter(vector, getter, (accessor, currentRow) -> matcherGetter.apply(accessor)); + } + + public void assertAccessorGetter(ValueVector vector, CheckedFunction getter, + Supplier> matcherGetter) throws Exception { + assertAccessorGetter(vector, getter, (accessor, currentRow) -> matcherGetter.get()); + } + + public void assertAccessorGetter(ValueVector vector, CheckedFunction getter, + Matcher matcher) throws Exception { + assertAccessorGetter(vector, getter, (accessor, currentRow) -> matcher); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java new file mode 100644 index 0000000000000..4fb07428af4ef --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static java.lang.Runtime.getRuntime; +import static java.util.Arrays.asList; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.THREAD_POOL_SIZE; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USER; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USE_ENCRYPTION; +import static org.hamcrest.CoreMatchers.is; + +import java.util.List; +import java.util.Properties; +import java.util.Random; +import java.util.function.Function; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public final class ArrowFlightConnectionConfigImplTest { + + private static final Random RANDOM = new Random(12L); + + private final Properties properties = new Properties(); + private ArrowFlightConnectionConfigImpl arrowFlightConnectionConfig; + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Parameter + public ArrowFlightConnectionProperty property; + + @Parameter(value = 1) + public Object value; + + @Parameter(value = 2) + public Function arrowFlightConnectionConfigFunction; + + @Before + public void setUp() { + arrowFlightConnectionConfig = new ArrowFlightConnectionConfigImpl(properties); + properties.put(property.camelName(), value); + } + + @Test + public void testGetProperty() { + collector.checkThat(arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), + is(value)); + } + + @Parameters(name = "<{0}> as <{1}>") + public static List provideParameters() { + return asList(new Object[][] { + {HOST, "host", + (Function) ArrowFlightConnectionConfigImpl::getHost}, + {PORT, + RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE)), + (Function) ArrowFlightConnectionConfigImpl::getPort}, + {USER, "user", + (Function) ArrowFlightConnectionConfigImpl::getUser}, + {PASSWORD, "password", + (Function) ArrowFlightConnectionConfigImpl::getPassword}, + {USE_ENCRYPTION, RANDOM.nextBoolean(), + (Function) ArrowFlightConnectionConfigImpl::useEncryption}, + {THREAD_POOL_SIZE, + RANDOM.nextInt(getRuntime().availableProcessors()), + (Function) ArrowFlightConnectionConfigImpl::threadPoolSize}, + }); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionPropertyTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionPropertyTest.java new file mode 100644 index 0000000000000..25a48612cbd0b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionPropertyTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.apache.arrow.util.AutoCloseables.close; +import static org.mockito.MockitoAnnotations.openMocks; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; +import org.mockito.Mock; + +@RunWith(Parameterized.class) +public final class ArrowFlightConnectionPropertyTest { + + @Mock + public Properties properties; + + private AutoCloseable mockitoResource; + + @Parameter + public ArrowFlightConnectionProperty arrowFlightConnectionProperty; + + @Before + public void setUp() { + mockitoResource = openMocks(this); + } + + @After + public void tearDown() throws Exception { + close(mockitoResource); + } + + @Test + public void testWrapIsUnsupported() { + ThrowableAssertionUtils.simpleAssertThrowableClass(UnsupportedOperationException.class, + () -> arrowFlightConnectionProperty.wrap(properties)); + } + + @Test + public void testRequiredPropertyThrows() { + Assume.assumeTrue(arrowFlightConnectionProperty.required()); + ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, + () -> arrowFlightConnectionProperty.get(new Properties())); + } + + @Test + public void testOptionalPropertyReturnsDefault() { + Assume.assumeTrue(!arrowFlightConnectionProperty.required()); + Assert.assertEquals(arrowFlightConnectionProperty.defaultValue(), + arrowFlightConnectionProperty.get(new Properties())); + } + + @Parameters + public static List provideParameters() { + final ArrowFlightConnectionProperty[] arrowFlightConnectionProperties = + ArrowFlightConnectionProperty.values(); + final List parameters = new ArrayList<>(arrowFlightConnectionProperties.length); + for (final ArrowFlightConnectionProperty arrowFlightConnectionProperty : arrowFlightConnectionProperties) { + parameters.add(new Object[] {arrowFlightConnectionProperty}); + } + return parameters; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapperTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapperTest.java new file mode 100644 index 0000000000000..6044f3a363c7e --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapperTest.java @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static java.lang.String.format; +import static java.util.stream.IntStream.range; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.sameInstance; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLClientInfoException; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Random; + +import org.apache.arrow.driver.jdbc.ArrowFlightConnection; +import org.apache.arrow.util.AutoCloseables; +import org.apache.calcite.avatica.AvaticaConnection; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public final class ConnectionWrapperTest { + + private static final String SCHEMA_NAME = "SCHEMA"; + private static final String PLACEHOLDER_QUERY = "SELECT * FROM DOES_NOT_MATTER"; + private static final int[] COLUMN_INDICES = range(0, 10).toArray(); + private static final String[] COLUMN_NAMES = + Arrays.stream(COLUMN_INDICES).mapToObj(i -> format("col%d", i)).toArray(String[]::new); + private static final String TYPE_NAME = "TYPE_NAME"; + private static final String SAVEPOINT_NAME = "SAVEPOINT"; + private static final String CLIENT_INFO = "CLIENT_INFO"; + private static final int RESULT_SET_TYPE = ResultSet.TYPE_FORWARD_ONLY; + private static final int RESULT_SET_CONCURRENCY = ResultSet.CONCUR_READ_ONLY; + private static final int RESULT_SET_HOLDABILITY = ResultSet.HOLD_CURSORS_OVER_COMMIT; + private static final int GENERATED_KEYS = Statement.NO_GENERATED_KEYS; + private static final Random RANDOM = new Random(Long.MAX_VALUE); + private static final int TIMEOUT = RANDOM.nextInt(Integer.MAX_VALUE); + + @Mock + public AvaticaConnection underlyingConnection; + private ConnectionWrapper connectionWrapper; + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Before + public void setUp() { + connectionWrapper = new ConnectionWrapper(underlyingConnection); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(connectionWrapper, underlyingConnection); + } + + @Test + public void testUnwrappingUnderlyingConnectionShouldReturnUnderlyingConnection() { + collector.checkThat( + collector.checkSucceeds(() -> connectionWrapper.unwrap(Object.class)), + is(sameInstance(underlyingConnection))); + collector.checkThat( + collector.checkSucceeds(() -> connectionWrapper.unwrap(Connection.class)), + is(sameInstance(underlyingConnection))); + collector.checkThat( + collector.checkSucceeds(() -> connectionWrapper.unwrap(AvaticaConnection.class)), + is(sameInstance(underlyingConnection))); + ThrowableAssertionUtils.simpleAssertThrowableClass(ClassCastException.class, + () -> connectionWrapper.unwrap(ArrowFlightConnection.class)); + ThrowableAssertionUtils.simpleAssertThrowableClass(ClassCastException.class, + () -> connectionWrapper.unwrap(ConnectionWrapper.class)); + } + + @Test + public void testCreateStatementShouldCreateStatementFromUnderlyingConnection() + throws SQLException { + collector.checkThat( + connectionWrapper.createStatement(), + is(sameInstance(verify(underlyingConnection, times(1)).createStatement()))); + collector.checkThat( + connectionWrapper.createStatement(RESULT_SET_TYPE, RESULT_SET_CONCURRENCY, + RESULT_SET_HOLDABILITY), + is(verify(underlyingConnection, times(1)) + .createStatement(RESULT_SET_TYPE, RESULT_SET_CONCURRENCY, RESULT_SET_HOLDABILITY))); + collector.checkThat( + connectionWrapper.createStatement(RESULT_SET_TYPE, RESULT_SET_CONCURRENCY), + is(verify(underlyingConnection, times(1)) + .createStatement(RESULT_SET_TYPE, RESULT_SET_CONCURRENCY))); + } + + @Test + public void testPrepareStatementShouldPrepareStatementFromUnderlyingConnection() + throws SQLException { + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY), + is(sameInstance( + verify(underlyingConnection, times(1)).prepareStatement(PLACEHOLDER_QUERY)))); + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY, COLUMN_INDICES), + is(allOf(sameInstance(verify(underlyingConnection, times(1)) + .prepareStatement(PLACEHOLDER_QUERY, COLUMN_INDICES)), + nullValue()))); + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY, COLUMN_NAMES), + is(allOf(sameInstance(verify(underlyingConnection, times(1)) + .prepareStatement(PLACEHOLDER_QUERY, COLUMN_NAMES)), + nullValue()))); + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY, RESULT_SET_TYPE, + RESULT_SET_CONCURRENCY), + is(allOf(sameInstance(verify(underlyingConnection, times(1)) + .prepareStatement(PLACEHOLDER_QUERY, RESULT_SET_TYPE, RESULT_SET_CONCURRENCY)), + nullValue()))); + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY, GENERATED_KEYS), + is(allOf(sameInstance(verify(underlyingConnection, times(1)) + .prepareStatement(PLACEHOLDER_QUERY, GENERATED_KEYS)), + nullValue()))); + } + + @Test + public void testPrepareCallShouldPrepareCallFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.prepareCall(PLACEHOLDER_QUERY), + is(sameInstance( + verify(underlyingConnection, times(1)).prepareCall(PLACEHOLDER_QUERY)))); + collector.checkThat( + connectionWrapper.prepareCall(PLACEHOLDER_QUERY, RESULT_SET_TYPE, RESULT_SET_CONCURRENCY), + is(verify(underlyingConnection, times(1)) + .prepareCall(PLACEHOLDER_QUERY, RESULT_SET_TYPE, RESULT_SET_CONCURRENCY))); + } + + @Test + public void testNativeSqlShouldGetNativeSqlFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.nativeSQL(PLACEHOLDER_QUERY), + is(sameInstance( + verify(underlyingConnection, times(1)).nativeSQL(PLACEHOLDER_QUERY)))); + } + + @Test + public void testSetAutoCommitShouldSetAutoCommitInUnderlyingConnection() throws SQLException { + connectionWrapper.setAutoCommit(true); + verify(underlyingConnection, times(1)).setAutoCommit(true); + connectionWrapper.setAutoCommit(false); + verify(underlyingConnection, times(1)).setAutoCommit(false); + } + + @Test + public void testGetAutoCommitShouldGetAutoCommitFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getAutoCommit(), + is(verify(underlyingConnection, times(1)).getAutoCommit())); + } + + @Test + public void testCommitShouldCommitToUnderlyingConnection() throws SQLException { + connectionWrapper.commit(); + verify(underlyingConnection, times(1)).commit(); + } + + @Test + public void testRollbackShouldRollbackFromUnderlyingConnection() throws SQLException { + connectionWrapper.rollback(); + verify(underlyingConnection, times(1)).rollback(); + } + + @Test + public void testCloseShouldCloseUnderlyingConnection() throws SQLException { + connectionWrapper.close(); + verify(underlyingConnection, times(1)).close(); + } + + @Test + public void testIsClosedShouldGetStatusFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.isClosed(), is(verify(underlyingConnection, times(1)).isClosed())); + } + + @Test + public void testGetMetadataShouldGetMetadataFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getMetaData(), is(verify(underlyingConnection, times(1)).getMetaData())); + } + + @Test + public void testSetReadOnlyShouldSetUnderlyingConnectionAsReadOnly() throws SQLException { + connectionWrapper.setReadOnly(false); + verify(underlyingConnection, times(1)).setReadOnly(false); + connectionWrapper.setReadOnly(true); + verify(underlyingConnection, times(1)).setReadOnly(true); + } + + @Test + public void testSetIsReadOnlyShouldGetStatusFromUnderlyingConnection() throws SQLException { + collector.checkThat(connectionWrapper.isReadOnly(), + is(verify(underlyingConnection).isReadOnly())); + } + + @Test + public void testSetCatalogShouldSetCatalogInUnderlyingConnection() throws SQLException { + final String catalog = "CATALOG"; + connectionWrapper.setCatalog(catalog); + verify(underlyingConnection, times(1)).setCatalog(catalog); + } + + @Test + public void testGetCatalogShouldGetCatalogFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getCatalog(), + is(allOf(sameInstance(verify(underlyingConnection, times(1)).getCatalog()), nullValue()))); + } + + @Test + public void setTransactionIsolationShouldSetUnderlyingTransactionIsolation() throws SQLException { + final int transactionIsolation = Connection.TRANSACTION_NONE; + connectionWrapper.setTransactionIsolation(Connection.TRANSACTION_NONE); + verify(underlyingConnection, times(1)).setTransactionIsolation(transactionIsolation); + } + + @Test + public void getTransactionIsolationShouldGetUnderlyingConnectionIsolation() throws SQLException { + collector.checkThat( + connectionWrapper.getTransactionIsolation(), + is(equalTo(verify(underlyingConnection, times(1)).getTransactionIsolation()))); + } + + @Test + public void getWarningShouldGetWarningsFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getWarnings(), + is(allOf( + sameInstance(verify(underlyingConnection, times(1)).getWarnings()), + nullValue()))); + } + + @Test + public void testClearWarningShouldClearWarningsFromUnderlyingConnection() throws SQLException { + connectionWrapper.clearWarnings(); + verify(underlyingConnection, times(1)).clearWarnings(); + } + + @Test + public void getTypeMapShouldGetTypeMapFromUnderlyingConnection() throws SQLException { + when(underlyingConnection.getTypeMap()).thenReturn(null); + collector.checkThat( + connectionWrapper.getTypeMap(), + is(verify(underlyingConnection, times(1)).getTypeMap())); + } + + @Test + public void testSetTypeMapShouldSetTypeMapFromUnderlyingConnection() throws SQLException { + connectionWrapper.setTypeMap(null); + verify(underlyingConnection, times(1)).setTypeMap(null); + } + + @Test + public void testSetHoldabilityShouldSetUnderlyingConnection() throws SQLException { + connectionWrapper.setHoldability(RESULT_SET_HOLDABILITY); + verify(underlyingConnection, times(1)).setHoldability(RESULT_SET_HOLDABILITY); + } + + @Test + public void testGetHoldabilityShouldGetHoldabilityFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getHoldability(), + is(equalTo(verify(underlyingConnection, times(1)).getHoldability()))); + } + + @Test + public void testSetSavepointShouldSetSavepointInUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.setSavepoint(), + is(allOf( + sameInstance(verify(underlyingConnection, times(1)).setSavepoint()), + nullValue()))); + collector.checkThat( + connectionWrapper.setSavepoint(SAVEPOINT_NAME), + is(sameInstance( + verify(underlyingConnection, times(1)).setSavepoint(SAVEPOINT_NAME)))); + } + + @Test + public void testRollbackShouldRollbackInUnderlyingConnection() throws SQLException { + connectionWrapper.rollback(null); + verify(underlyingConnection, times(1)).rollback(null); + } + + @Test + public void testReleaseSavepointShouldReleaseSavepointFromUnderlyingConnection() + throws SQLException { + connectionWrapper.releaseSavepoint(null); + verify(underlyingConnection, times(1)).releaseSavepoint(null); + } + + @Test + public void testCreateClobShouldCreateClobFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.createClob(), + is(allOf(sameInstance( + verify(underlyingConnection, times(1)).createClob()), nullValue()))); + } + + @Test + public void testCreateBlobShouldCreateBlobFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.createBlob(), + is(allOf(sameInstance( + verify(underlyingConnection, times(1)).createBlob()), nullValue()))); + } + + @Test + public void testCreateNClobShouldCreateNClobFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.createNClob(), + is(allOf(sameInstance( + verify(underlyingConnection, times(1)).createNClob()), nullValue()))); + } + + @Test + public void testCreateSQLXMLShouldCreateSQLXMLFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.createSQLXML(), + is(allOf(sameInstance( + verify(underlyingConnection, times(1)).createSQLXML()), nullValue()))); + } + + @Test + public void testIsValidShouldReturnWhetherUnderlyingConnectionIsValid() throws SQLException { + collector.checkThat( + connectionWrapper.isValid(TIMEOUT), + is(verify(underlyingConnection, times(1)).isValid(TIMEOUT))); + } + + @Test + public void testSetClientInfoShouldSetClientInfoInUnderlyingConnection() + throws SQLClientInfoException { + connectionWrapper.setClientInfo(null); + verify(underlyingConnection, times(1)).setClientInfo(null); + } + + @Test + public void testGetClientInfoShouldGetClientInfoFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getClientInfo(CLIENT_INFO), + is(allOf( + sameInstance( + verify(underlyingConnection, times(1)).getClientInfo(CLIENT_INFO)), + nullValue()))); + collector.checkThat( + connectionWrapper.getClientInfo(), + is(allOf( + sameInstance( + verify(underlyingConnection, times(1)).getClientInfo()), + nullValue()))); + } + + @Test + public void testCreateArrayOfShouldCreateArrayFromUnderlyingConnection() throws SQLException { + final Object[] elements = range(0, 100).boxed().toArray(); + collector.checkThat( + connectionWrapper.createArrayOf(TYPE_NAME, elements), + is(allOf( + sameInstance( + verify(underlyingConnection, times(1)).createArrayOf(TYPE_NAME, elements)), + nullValue()))); + } + + @Test + public void testCreateStructShouldCreateStructFromUnderlyingConnection() throws SQLException { + final Object[] attributes = range(0, 120).boxed().toArray(); + collector.checkThat( + connectionWrapper.createStruct(TYPE_NAME, attributes), + is(allOf( + sameInstance( + verify(underlyingConnection, times(1)).createStruct(TYPE_NAME, attributes)), + nullValue()))); + } + + @Test + public void testSetSchemaShouldSetSchemaInUnderlyingConnection() throws SQLException { + connectionWrapper.setSchema(SCHEMA_NAME); + verify(underlyingConnection, times(1)).setSchema(SCHEMA_NAME); + } + + @Test + public void testGetSchemaShouldGetSchemaFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getSchema(), + is(allOf( + sameInstance(verify(underlyingConnection, times(1)).getSchema()), + nullValue()))); + } + + @Test + public void testAbortShouldAbortUnderlyingConnection() throws SQLException { + connectionWrapper.abort(null); + verify(underlyingConnection, times(1)).abort(null); + } + + @Test + public void testSetNetworkTimeoutShouldSetNetworkTimeoutInUnderlyingConnection() + throws SQLException { + connectionWrapper.setNetworkTimeout(null, TIMEOUT); + verify(underlyingConnection, times(1)).setNetworkTimeout(null, TIMEOUT); + } + + @Test + public void testGetNetworkTimeoutShouldGetNetworkTimeoutFromUnderlyingConnection() + throws SQLException { + collector.checkThat( + connectionWrapper.getNetworkTimeout(), + is(equalTo(verify(underlyingConnection, times(1)).getNetworkTimeout()))); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConvertUtilsTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConvertUtilsTest.java new file mode 100644 index 0000000000000..5cea3749283d7 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConvertUtilsTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.hamcrest.CoreMatchers.equalTo; + +import java.util.List; + +import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.proto.Common; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +import com.google.common.collect.ImmutableList; + +public class ConvertUtilsTest { + + @Rule + public ErrorCollector collector = new ErrorCollector(); + + @Test + public void testShouldSetOnColumnMetaDataBuilder() { + + final Common.ColumnMetaData.Builder builder = Common.ColumnMetaData.newBuilder(); + final FlightSqlColumnMetadata expectedColumnMetaData = new FlightSqlColumnMetadata.Builder() + .catalogName("catalog1") + .schemaName("schema1") + .tableName("table1") + .isAutoIncrement(true) + .isCaseSensitive(true) + .isReadOnly(true) + .isSearchable(true) + .precision(20) + .scale(10) + .build(); + ConvertUtils.setOnColumnMetaDataBuilder(builder, expectedColumnMetaData.getMetadataMap()); + assertBuilder(builder, expectedColumnMetaData); + } + + @Test + public void testShouldConvertArrowFieldsToColumnMetaDataList() { + + final List listField = ImmutableList.of( + new Field("col1", + new FieldType(true, ArrowType.Utf8.INSTANCE, null, + new FlightSqlColumnMetadata.Builder() + .catalogName("catalog1") + .schemaName("schema1") + .tableName("table1") + .build().getMetadataMap() + ), null)); + + final List expectedColumnMetaData = ImmutableList.of( + ColumnMetaData.fromProto( + Common.ColumnMetaData.newBuilder() + .setCatalogName("catalog1") + .setSchemaName("schema1") + .setTableName("table1") + .build())); + + final List actualColumnMetaData = ConvertUtils.convertArrowFieldsToColumnMetaDataList(listField); + assertColumnMetaData(expectedColumnMetaData, actualColumnMetaData); + } + + private void assertColumnMetaData(final List expected, final List actual) { + collector.checkThat(expected.size(), equalTo(actual.size())); + int size = expected.size(); + for (int i = 0; i < size; i++) { + final ColumnMetaData expectedColumnMetaData = expected.get(i); + final ColumnMetaData actualColumnMetaData = actual.get(i); + collector.checkThat(expectedColumnMetaData.catalogName, equalTo(actualColumnMetaData.catalogName)); + collector.checkThat(expectedColumnMetaData.schemaName, equalTo(actualColumnMetaData.schemaName)); + collector.checkThat(expectedColumnMetaData.tableName, equalTo(actualColumnMetaData.tableName)); + collector.checkThat(expectedColumnMetaData.readOnly, equalTo(actualColumnMetaData.readOnly)); + collector.checkThat(expectedColumnMetaData.autoIncrement, equalTo(actualColumnMetaData.autoIncrement)); + collector.checkThat(expectedColumnMetaData.precision, equalTo(actualColumnMetaData.precision)); + collector.checkThat(expectedColumnMetaData.scale, equalTo(actualColumnMetaData.scale)); + collector.checkThat(expectedColumnMetaData.caseSensitive, equalTo(actualColumnMetaData.caseSensitive)); + collector.checkThat(expectedColumnMetaData.searchable, equalTo(actualColumnMetaData.searchable)); + } + } + + private void assertBuilder(final Common.ColumnMetaData.Builder builder, + final FlightSqlColumnMetadata flightSqlColumnMetaData) { + + final Integer precision = flightSqlColumnMetaData.getPrecision(); + final Integer scale = flightSqlColumnMetaData.getScale(); + + collector.checkThat(flightSqlColumnMetaData.getCatalogName(), equalTo(builder.getCatalogName())); + collector.checkThat(flightSqlColumnMetaData.getSchemaName(), equalTo(builder.getSchemaName())); + collector.checkThat(flightSqlColumnMetaData.getTableName(), equalTo(builder.getTableName())); + collector.checkThat(flightSqlColumnMetaData.isAutoIncrement(), equalTo(builder.getAutoIncrement())); + collector.checkThat(flightSqlColumnMetaData.isCaseSensitive(), equalTo(builder.getCaseSensitive())); + collector.checkThat(flightSqlColumnMetaData.isSearchable(), equalTo(builder.getSearchable())); + collector.checkThat(flightSqlColumnMetaData.isReadOnly(), equalTo(builder.getReadOnly())); + collector.checkThat(precision == null ? 0 : precision, equalTo(builder.getPrecision())); + collector.checkThat(scale == null ? 0 : scale, equalTo(builder.getScale())); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java new file mode 100644 index 0000000000000..cf359849a7105 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static java.lang.String.format; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Date; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.IntStream; + +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.rules.ErrorCollector; + +import com.google.common.collect.ImmutableList; + +/** + * Standard {@link MockFlightSqlProducer} instances for tests. + */ +// TODO Remove this once all tests are refactor to use only the queries they need. +public final class CoreMockedSqlProducers { + + public static final String LEGACY_REGULAR_SQL_CMD = "SELECT * FROM TEST"; + public static final String LEGACY_METADATA_SQL_CMD = "SELECT * FROM METADATA"; + public static final String LEGACY_CANCELLATION_SQL_CMD = "SELECT * FROM TAKES_FOREVER"; + + private CoreMockedSqlProducers() { + // Prevent instantiation. + } + + /** + * Gets the {@link MockFlightSqlProducer} for legacy tests and backward compatibility. + * + * @return a new producer. + */ + public static MockFlightSqlProducer getLegacyProducer() { + + final MockFlightSqlProducer producer = new MockFlightSqlProducer(); + addLegacyRegularSqlCmdSupport(producer); + addLegacyMetadataSqlCmdSupport(producer); + addLegacyCancellationSqlCmdSupport(producer); + return producer; + } + + private static void addLegacyRegularSqlCmdSupport(final MockFlightSqlProducer producer) { + final Schema querySchema = new Schema(ImmutableList.of( + new Field( + "ID", + new FieldType(true, new ArrowType.Int(64, true), + null), + null), + new Field( + "Name", + new FieldType(true, new ArrowType.Utf8(), null), + null), + new Field( + "Age", + new FieldType(true, new ArrowType.Int(32, false), + null), + null), + new Field( + "Salary", + new FieldType(true, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), + null), + null), + new Field( + "Hire Date", + new FieldType(true, new ArrowType.Date(DateUnit.DAY), null), + null), + new Field( + "Last Sale", + new FieldType(true, new ArrowType.Timestamp(TimeUnit.MILLISECOND, null), + null), + null) + )); + final List> resultProducers = new ArrayList<>(); + IntStream.range(0, 10).forEach(page -> { + resultProducers.add(listener -> { + final int rowsPerPage = 5000; + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(querySchema, allocator)) { + root.allocateNew(); + listener.start(root); + int batchSize = 500; + int indexOnBatch = 0; + int resultsOffset = page * rowsPerPage; + for (int i = 0; i < rowsPerPage; i++) { + ((BigIntVector) root.getVector("ID")) + .setSafe(indexOnBatch, (long) Integer.MAX_VALUE + 1 + i + resultsOffset); + ((VarCharVector) root.getVector("Name")) + .setSafe(indexOnBatch, new Text("Test Name #" + (resultsOffset + i))); + ((UInt4Vector) root.getVector("Age")) + .setSafe(indexOnBatch, (int) Short.MAX_VALUE + 1 + i + resultsOffset); + ((Float8Vector) root.getVector("Salary")) + .setSafe(indexOnBatch, + Math.scalb((double) (i + resultsOffset) / 2, i + resultsOffset)); + ((DateDayVector) root.getVector("Hire Date")) + .setSafe(indexOnBatch, i + resultsOffset); + ((TimeStampMilliVector) root.getVector("Last Sale")) + .setSafe(indexOnBatch, Long.MAX_VALUE - i - resultsOffset); + indexOnBatch++; + if (indexOnBatch == batchSize) { + root.setRowCount(indexOnBatch); + if (listener.isCancelled()) { + return; + } + listener.putNext(); + root.allocateNew(); + indexOnBatch = 0; + } + } + if (listener.isCancelled()) { + return; + } + root.setRowCount(indexOnBatch); + listener.putNext(); + } finally { + listener.completed(); + } + }); + }); + producer.addSelectQuery(LEGACY_REGULAR_SQL_CMD, querySchema, resultProducers); + } + + private static void addLegacyMetadataSqlCmdSupport(final MockFlightSqlProducer producer) { + final Schema metadataSchema = new Schema(ImmutableList.of( + new Field( + "integer0", + new FieldType(true, new ArrowType.Int(64, true), + null, new FlightSqlColumnMetadata.Builder() + .catalogName("CATALOG_NAME_1") + .schemaName("SCHEMA_NAME_1") + .tableName("TABLE_NAME_1") + .typeName("TYPE_NAME_1") + .precision(10) + .scale(0) + .isAutoIncrement(true) + .isCaseSensitive(false) + .isReadOnly(true) + .isSearchable(true) + .build().getMetadataMap()), + null), + new Field( + "string1", + new FieldType(true, new ArrowType.Utf8(), + null, new FlightSqlColumnMetadata.Builder() + .catalogName("CATALOG_NAME_2") + .schemaName("SCHEMA_NAME_2") + .tableName("TABLE_NAME_2") + .typeName("TYPE_NAME_2") + .precision(65535) + .scale(0) + .isAutoIncrement(false) + .isCaseSensitive(true) + .isReadOnly(false) + .isSearchable(true) + .build().getMetadataMap()), + null), + new Field( + "float2", + new FieldType(true, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), + null, new FlightSqlColumnMetadata.Builder() + .catalogName("CATALOG_NAME_3") + .schemaName("SCHEMA_NAME_3") + .tableName("TABLE_NAME_3") + .typeName("TYPE_NAME_3") + .precision(15) + .scale(20) + .isAutoIncrement(false) + .isCaseSensitive(false) + .isReadOnly(false) + .isSearchable(true) + .build().getMetadataMap()), + null))); + final Consumer formula = listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(metadataSchema, allocator)) { + root.allocateNew(); + ((BigIntVector) root.getVector("integer0")).setSafe(0, 1); + ((VarCharVector) root.getVector("string1")).setSafe(0, new Text("teste")); + ((Float4Vector) root.getVector("float2")).setSafe(0, (float) 4.1); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } finally { + listener.completed(); + } + }; + producer.addSelectQuery(LEGACY_METADATA_SQL_CMD, metadataSchema, + Collections.singletonList(formula)); + } + + private static void addLegacyCancellationSqlCmdSupport(final MockFlightSqlProducer producer) { + producer.addSelectQuery( + LEGACY_CANCELLATION_SQL_CMD, + new Schema(Collections.singletonList(new Field( + "integer0", + new FieldType(true, new ArrowType.Int(64, true), null), + null))), + Collections.singletonList(listener -> { + // Should keep hanging until canceled. + })); + } + + /** + * Asserts that the values in the provided {@link ResultSet} are expected for the + * legacy {@link MockFlightSqlProducer}. + * + * @param resultSet the result set. + * @param collector the {@link ErrorCollector} to use. + * @throws SQLException on error. + */ + public static void assertLegacyRegularSqlResultSet(final ResultSet resultSet, + final ErrorCollector collector) + throws SQLException { + final int expectedRowCount = 50_000; + + final long[] expectedIds = new long[expectedRowCount]; + final List expectedNames = new ArrayList<>(expectedRowCount); + final int[] expectedAges = new int[expectedRowCount]; + final double[] expectedSalaries = new double[expectedRowCount]; + final List expectedHireDates = new ArrayList<>(expectedRowCount); + final List expectedLastSales = new ArrayList<>(expectedRowCount); + + final long[] actualIds = new long[expectedRowCount]; + final List actualNames = new ArrayList<>(expectedRowCount); + final int[] actualAges = new int[expectedRowCount]; + final double[] actualSalaries = new double[expectedRowCount]; + final List actualHireDates = new ArrayList<>(expectedRowCount); + final List actualLastSales = new ArrayList<>(expectedRowCount); + + int actualRowCount = 0; + + for (; resultSet.next(); actualRowCount++) { + expectedIds[actualRowCount] = (long) Integer.MAX_VALUE + 1 + actualRowCount; + expectedNames.add(format("Test Name #%d", actualRowCount)); + expectedAges[actualRowCount] = (int) Short.MAX_VALUE + 1 + actualRowCount; + expectedSalaries[actualRowCount] = Math.scalb((double) actualRowCount / 2, actualRowCount); + expectedHireDates.add(new Date(86_400_000L * actualRowCount)); + expectedLastSales.add(new Timestamp(Long.MAX_VALUE - actualRowCount)); + + actualIds[actualRowCount] = (long) resultSet.getObject(1); + actualNames.add((String) resultSet.getObject(2)); + actualAges[actualRowCount] = (int) resultSet.getObject(3); + actualSalaries[actualRowCount] = (double) resultSet.getObject(4); + actualHireDates.add((Date) resultSet.getObject(5)); + actualLastSales.add((Timestamp) resultSet.getObject(6)); + } + collector.checkThat(actualRowCount, is(equalTo(expectedRowCount))); + collector.checkThat(actualIds, is(expectedIds)); + collector.checkThat(actualNames, is(expectedNames)); + collector.checkThat(actualAges, is(expectedAges)); + collector.checkThat(actualSalaries, is(expectedSalaries)); + collector.checkThat(actualHireDates, is(expectedHireDates)); + collector.checkThat(actualLastSales, is(expectedLastSales)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtilsTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtilsTest.java new file mode 100644 index 0000000000000..adb892fcdc76b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtilsTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Calendar; +import java.util.TimeZone; + +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class DateTimeUtilsTest { + + @ClassRule + public static final ErrorCollector collector = new ErrorCollector(); + private final TimeZone defaultTimezone = TimeZone.getTimeZone("UTC"); + private final TimeZone alternateTimezone = TimeZone.getTimeZone("America/Vancouver"); + private final long positiveEpochMilli = 959817600000L; // 2000-06-01 00:00:00 UTC + private final long negativeEpochMilli = -618105600000L; // 1950-06-01 00:00:00 UTC + + @Test + public void testShouldGetOffsetWithSameTimeZone() { + final TimeZone currentTimezone = TimeZone.getDefault(); + + final long epochMillis = positiveEpochMilli; + final long offset = defaultTimezone.getOffset(epochMillis); + + TimeZone.setDefault(defaultTimezone); + + try { // Trying to guarantee timezone returns to its original value + final long expected = epochMillis + offset; + final long actual = DateTimeUtils.applyCalendarOffset(epochMillis, Calendar.getInstance(defaultTimezone)); + + collector.checkThat(actual, is(expected)); + } finally { + // Reset Timezone + TimeZone.setDefault(currentTimezone); + } + } + + @Test + public void testShouldGetOffsetWithDifferentTimeZone() { + final TimeZone currentTimezone = TimeZone.getDefault(); + + final long epochMillis = negativeEpochMilli; + final long offset = alternateTimezone.getOffset(epochMillis); + + TimeZone.setDefault(alternateTimezone); + + try { // Trying to guarantee timezone returns to its original value + final long expectedEpochMillis = epochMillis + offset; + final long actualEpochMillis = DateTimeUtils.applyCalendarOffset(epochMillis, Calendar.getInstance( + defaultTimezone)); + + collector.checkThat(actualEpochMillis, is(expectedEpochMillis)); + } finally { + // Reset Timezone + TimeZone.setDefault(currentTimezone); + } + } + + @Test + public void testShouldGetTimestampPositive() { + long epochMilli = positiveEpochMilli; + final Instant instant = Instant.ofEpochMilli(epochMilli); + + final Timestamp expected = Timestamp.from(instant); + final Timestamp actual = DateTimeUtils.getTimestampValue(epochMilli); + + collector.checkThat(expected, is(actual)); + } + + @Test + public void testShouldGetTimestampNegative() { + final long epochMilli = negativeEpochMilli; + final Instant instant = Instant.ofEpochMilli(epochMilli); + + final Timestamp expected = Timestamp.from(instant); + final Timestamp actual = DateTimeUtils.getTimestampValue(epochMilli); + + collector.checkThat(expected, is(actual)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightSqlTestCertificates.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightSqlTestCertificates.java new file mode 100644 index 0000000000000..a2b1864c02657 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightSqlTestCertificates.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.io.File; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/** + * Utility class for unit tests that need to reference the certificate params. + */ +public class FlightSqlTestCertificates { + + public static final String TEST_DATA_ENV_VAR = "ARROW_TEST_DATA"; + public static final String TEST_DATA_PROPERTY = "arrow.test.dataRoot"; + + static Path getTestDataRoot() { + String path = System.getenv(TEST_DATA_ENV_VAR); + if (path == null) { + path = System.getProperty(TEST_DATA_PROPERTY); + } + return Paths.get(Objects.requireNonNull(path, + String.format("Could not find test data path. Set the environment variable %s or the JVM property %s.", + TEST_DATA_ENV_VAR, TEST_DATA_PROPERTY))); + } + + /** + * Get the Path from the Files to be used in the encrypted test of Flight. + * + * @return the Path from the Files with certificates and keys. + */ + static Path getFlightTestDataRoot() { + return getTestDataRoot().resolve("flight"); + } + + /** + * Create CertKeyPair object with the certificates and keys. + * + * @return A list with CertKeyPair. + */ + public static List exampleTlsCerts() { + final Path root = getFlightTestDataRoot(); + return Arrays.asList(new CertKeyPair(root.resolve("cert0.pem") + .toFile(), root.resolve("cert0.pkcs1").toFile()), + new CertKeyPair(root.resolve("cert1.pem") + .toFile(), root.resolve("cert1.pkcs1").toFile())); + } + + public static class CertKeyPair { + + public final File cert; + public final File key; + + public CertKeyPair(File cert, File key) { + this.cert = cert; + this.key = key; + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java new file mode 100644 index 0000000000000..b474da55a7f1f --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.mockito.Mockito.mock; + +import java.util.concurrent.CompletionService; + +import org.apache.arrow.flight.FlightStream; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +/** + * Tests for {@link FlightStreamQueue}. + */ +@RunWith(MockitoJUnitRunner.class) +public class FlightStreamQueueTest { + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + @Mock + private CompletionService mockedService; + private FlightStreamQueue queue; + + @Before + public void setUp() { + queue = new FlightStreamQueue(mockedService); + } + + @Test + public void testNextShouldRetrieveNullIfEmpty() throws Exception { + collector.checkThat(queue.next(), is(nullValue())); + } + + @Test + public void testNextShouldThrowExceptionUponClose() throws Exception { + queue.close(); + ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, () -> queue.next()); + } + + @Test + public void testEnqueueShouldThrowExceptionUponClose() throws Exception { + queue.close(); + ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, + () -> queue.enqueue(mock(FlightStream.class))); + } + + @Test + public void testCheckOpen() throws Exception { + collector.checkSucceeds(() -> { + queue.checkOpen(); + return true; + }); + queue.close(); + ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, () -> queue.checkOpen()); + } + + @Test + public void testShouldCloseQueue() throws Exception { + collector.checkThat(queue.isClosed(), is(false)); + queue.close(); + collector.checkThat(queue.isClosed(), is(true)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java new file mode 100644 index 0000000000000..cc8fae9722f9a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java @@ -0,0 +1,539 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static com.google.protobuf.Any.pack; +import static com.google.protobuf.ByteString.copyFrom; +import static java.lang.String.format; +import static java.util.UUID.randomUUID; +import static java.util.stream.Collectors.toList; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.IntStream; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.SqlInfoBuilder; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; +import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.Meta.StatementType; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; + +/** + * An ad-hoc {@link FlightSqlProducer} for tests. + */ +public final class MockFlightSqlProducer implements FlightSqlProducer { + + private final Map>> queryResults = new HashMap<>(); + private final Map> selectResultProviders = new HashMap<>(); + private final Map preparedStatements = new HashMap<>(); + private final Map> catalogQueriesResults = + new HashMap<>(); + private final Map>> + updateResultProviders = + new HashMap<>(); + private SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder(); + + private static FlightInfo getFightInfoExportedAndImportedKeys(final Message message, + final FlightDescriptor descriptor) { + return getFlightInfo(message, Schemas.GET_IMPORTED_KEYS_SCHEMA, descriptor); + } + + private static FlightInfo getFlightInfo(final Message message, final Schema schema, + final FlightDescriptor descriptor) { + return new FlightInfo( + schema, + descriptor, + Collections.singletonList(new FlightEndpoint(new Ticket(Any.pack(message).toByteArray()))), + -1, -1); + } + + public static ByteBuffer serializeSchema(final Schema schema) { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(outputStream)), schema); + + return ByteBuffer.wrap(outputStream.toByteArray()); + } catch (final IOException e) { + throw new RuntimeException("Failed to serialize schema", e); + } + } + + /** + * Registers a new {@link StatementType#SELECT} SQL query. + * + * @param sqlCommand the SQL command under which to register the new query. + * @param schema the schema to use for the query result. + * @param resultProviders the result provider for this query. + */ + public void addSelectQuery(final String sqlCommand, final Schema schema, + final List> resultProviders) { + final int providers = resultProviders.size(); + final List uuids = + IntStream.range(0, providers) + .mapToObj(index -> new UUID(sqlCommand.hashCode(), Integer.hashCode(index))) + .collect(toList()); + queryResults.put(sqlCommand, new SimpleImmutableEntry<>(schema, uuids)); + IntStream.range(0, providers) + .forEach( + index -> this.selectResultProviders.put(uuids.get(index), resultProviders.get(index))); + } + + /** + * Registers a new {@link StatementType#UPDATE} SQL query. + * + * @param sqlCommand the SQL command. + * @param updatedRows the number of rows affected. + */ + public void addUpdateQuery(final String sqlCommand, final long updatedRows) { + addUpdateQuery(sqlCommand, ((flightStream, putResultStreamListener) -> { + final DoPutUpdateResult result = + DoPutUpdateResult.newBuilder().setRecordCount(updatedRows).build(); + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final ArrowBuf buffer = allocator.buffer(result.getSerializedSize())) { + buffer.writeBytes(result.toByteArray()); + putResultStreamListener.onNext(PutResult.metadata(buffer)); + } catch (final Throwable throwable) { + putResultStreamListener.onError(throwable); + } finally { + putResultStreamListener.onCompleted(); + } + })); + } + + /** + * Adds a catalog query to the results. + * + * @param message the {@link Message} corresponding to the catalog query request type to register. + * @param resultsProvider the results provider. + */ + public void addCatalogQuery(final Message message, + final Consumer resultsProvider) { + catalogQueriesResults.put(message, resultsProvider); + } + + /** + * Registers a new {@link StatementType#UPDATE} SQL query. + * + * @param sqlCommand the SQL command. + * @param resultsProvider consumer for producing update results. + */ + void addUpdateQuery(final String sqlCommand, + final BiConsumer> resultsProvider) { + Preconditions.checkState( + updateResultProviders.putIfAbsent(sqlCommand, resultsProvider) == null, + format("Attempted to overwrite pre-existing query: <%s>.", sqlCommand)); + } + + @Override + public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, + final CallContext callContext, + final StreamListener listener) { + try { + final ByteString preparedStatementHandle = + copyFrom(randomUUID().toString().getBytes(StandardCharsets.UTF_8)); + final String query = request.getQuery(); + + final ActionCreatePreparedStatementResult.Builder resultBuilder = + ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(preparedStatementHandle); + + final Entry> entry = queryResults.get(query); + if (entry != null) { + preparedStatements.put(preparedStatementHandle, query); + + final Schema datasetSchema = entry.getKey(); + final ByteString datasetSchemaBytes = + ByteString.copyFrom(serializeSchema(datasetSchema)); + + resultBuilder.setDatasetSchema(datasetSchemaBytes); + } else if (updateResultProviders.containsKey(query)) { + preparedStatements.put(preparedStatementHandle, query); + + } else { + listener.onError( + CallStatus.INVALID_ARGUMENT.withDescription("Query not found").toRuntimeException()); + return; + } + + listener.onNext(new Result(pack(resultBuilder.build()).toByteArray())); + } catch (final Throwable t) { + listener.onError(t); + } finally { + listener.onCompleted(); + } + } + + @Override + public void closePreparedStatement( + final ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, + final CallContext callContext, final StreamListener streamListener) { + // TODO Implement this method. + streamListener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoStatement(final CommandStatementQuery commandStatementQuery, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + final String query = commandStatementQuery.getQuery(); + final Entry> queryInfo = + Preconditions.checkNotNull(queryResults.get(query), + format("Query not registered: <%s>.", query)); + final List endpoints = + queryInfo.getValue().stream() + .map(TicketConversionUtils::getTicketBytesFromUuid) + .map(TicketConversionUtils::getTicketStatementQueryFromHandle) + .map(TicketConversionUtils::getEndpointFromMessage) + .collect(toList()); + return new FlightInfo(queryInfo.getKey(), flightDescriptor, endpoints, -1, -1); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement( + final CommandPreparedStatementQuery commandPreparedStatementQuery, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + final ByteString preparedStatementHandle = + commandPreparedStatementQuery.getPreparedStatementHandle(); + + final String query = Preconditions.checkNotNull( + preparedStatements.get(preparedStatementHandle), + format("No query registered under handle: <%s>.", preparedStatementHandle)); + final Entry> queryInfo = + Preconditions.checkNotNull(queryResults.get(query), + format("Query not registered: <%s>.", query)); + final List endpoints = + queryInfo.getValue().stream() + .map(TicketConversionUtils::getTicketBytesFromUuid) + .map(TicketConversionUtils::getCommandPreparedStatementQueryFromHandle) + .map(TicketConversionUtils::getEndpointFromMessage) + .collect(toList()); + return new FlightInfo(queryInfo.getKey(), flightDescriptor, endpoints, -1, -1); + } + + @Override + public SchemaResult getSchemaStatement(final CommandStatementQuery commandStatementQuery, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + final String query = commandStatementQuery.getQuery(); + final Entry> queryInfo = + Preconditions.checkNotNull(queryResults.get(query), + format("Query not registered: <%s>.", query)); + + return new SchemaResult(queryInfo.getKey()); + } + + @Override + public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + final UUID uuid = UUID.fromString(ticketStatementQuery.getStatementHandle().toStringUtf8()); + Preconditions.checkNotNull( + selectResultProviders.get(uuid), + "No consumer was registered for the specified UUID: <%s>.", uuid) + .accept(serverStreamListener); + } + + @Override + public void getStreamPreparedStatement( + final CommandPreparedStatementQuery commandPreparedStatementQuery, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + final UUID uuid = + UUID.fromString(commandPreparedStatementQuery.getPreparedStatementHandle().toStringUtf8()); + Preconditions.checkNotNull( + selectResultProviders.get(uuid), + "No consumer was registered for the specified UUID: <%s>.", uuid) + .accept(serverStreamListener); + } + + @Override + public Runnable acceptPutStatement(final CommandStatementUpdate commandStatementUpdate, + final CallContext callContext, + final FlightStream flightStream, + final StreamListener streamListener) { + return () -> { + final String query = commandStatementUpdate.getQuery(); + final BiConsumer> resultProvider = + Preconditions.checkNotNull( + updateResultProviders.get(query), + format("No consumer found for query: <%s>.", query)); + resultProvider.accept(flightStream, streamListener); + }; + } + + @Override + public Runnable acceptPutPreparedStatementUpdate( + final CommandPreparedStatementUpdate commandPreparedStatementUpdate, + final CallContext callContext, final FlightStream flightStream, + final StreamListener streamListener) { + final ByteString handle = commandPreparedStatementUpdate.getPreparedStatementHandle(); + final String query = Preconditions.checkNotNull( + preparedStatements.get(handle), + format("No query registered under handle: <%s>.", handle)); + return acceptPutStatement( + CommandStatementUpdate.newBuilder().setQuery(query).build(), callContext, flightStream, + streamListener); + } + + @Override + public Runnable acceptPutPreparedStatementQuery( + final CommandPreparedStatementQuery commandPreparedStatementQuery, + final CallContext callContext, final FlightStream flightStream, + final StreamListener streamListener) { + // TODO Implement this method. + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoSqlInfo(final CommandGetSqlInfo commandGetSqlInfo, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFlightInfo(commandGetSqlInfo, Schemas.GET_SQL_INFO_SCHEMA, flightDescriptor); + } + + @Override + public void getStreamSqlInfo(final CommandGetSqlInfo commandGetSqlInfo, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + sqlInfoBuilder.send(commandGetSqlInfo.getInfoList(), serverStreamListener); + } + + @Override + public FlightInfo getFlightInfoTypeInfo(FlightSql.CommandGetXdbcTypeInfo request, CallContext context, + FlightDescriptor descriptor) { + // TODO Implement this + return null; + } + + @Override + public void getStreamTypeInfo(FlightSql.CommandGetXdbcTypeInfo request, CallContext context, + ServerStreamListener listener) { + // TODO Implement this + } + + @Override + public FlightInfo getFlightInfoCatalogs(final CommandGetCatalogs commandGetCatalogs, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFlightInfo(commandGetCatalogs, Schemas.GET_CATALOGS_SCHEMA, flightDescriptor); + } + + @Override + public void getStreamCatalogs(final CallContext callContext, + final ServerStreamListener serverStreamListener) { + final CommandGetCatalogs command = CommandGetCatalogs.getDefaultInstance(); + getStreamCatalogFunctions(command, serverStreamListener); + } + + @Override + public FlightInfo getFlightInfoSchemas(final CommandGetDbSchemas commandGetSchemas, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFlightInfo(commandGetSchemas, Schemas.GET_SCHEMAS_SCHEMA, flightDescriptor); + } + + @Override + public void getStreamSchemas(final CommandGetDbSchemas commandGetSchemas, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + getStreamCatalogFunctions(commandGetSchemas, serverStreamListener); + } + + @Override + public FlightInfo getFlightInfoTables(final CommandGetTables commandGetTables, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFlightInfo(commandGetTables, Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, flightDescriptor); + } + + @Override + public void getStreamTables(final CommandGetTables commandGetTables, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + getStreamCatalogFunctions(commandGetTables, serverStreamListener); + } + + @Override + public FlightInfo getFlightInfoTableTypes(final CommandGetTableTypes commandGetTableTypes, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFlightInfo(commandGetTableTypes, Schemas.GET_TABLE_TYPES_SCHEMA, flightDescriptor); + } + + @Override + public void getStreamTableTypes(final CallContext callContext, + final ServerStreamListener serverStreamListener) { + final CommandGetTableTypes command = CommandGetTableTypes.getDefaultInstance(); + getStreamCatalogFunctions(command, serverStreamListener); + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys(final CommandGetPrimaryKeys commandGetPrimaryKeys, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFlightInfo(commandGetPrimaryKeys, Schemas.GET_PRIMARY_KEYS_SCHEMA, flightDescriptor); + } + + @Override + public void getStreamPrimaryKeys(final CommandGetPrimaryKeys commandGetPrimaryKeys, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + getStreamCatalogFunctions(commandGetPrimaryKeys, serverStreamListener); + } + + @Override + public FlightInfo getFlightInfoExportedKeys(final CommandGetExportedKeys commandGetExportedKeys, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFightInfoExportedAndImportedKeys(commandGetExportedKeys, flightDescriptor); + } + + @Override + public FlightInfo getFlightInfoImportedKeys(final CommandGetImportedKeys commandGetImportedKeys, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFightInfoExportedAndImportedKeys(commandGetImportedKeys, flightDescriptor); + } + + @Override + public FlightInfo getFlightInfoCrossReference( + final CommandGetCrossReference commandGetCrossReference, + final CallContext callContext, + final FlightDescriptor flightDescriptor) { + return getFightInfoExportedAndImportedKeys(commandGetCrossReference, flightDescriptor); + } + + @Override + public void getStreamExportedKeys(final CommandGetExportedKeys commandGetExportedKeys, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + getStreamCatalogFunctions(commandGetExportedKeys, serverStreamListener); + } + + @Override + public void getStreamImportedKeys(final CommandGetImportedKeys commandGetImportedKeys, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + getStreamCatalogFunctions(commandGetImportedKeys, serverStreamListener); + } + + @Override + public void getStreamCrossReference(final CommandGetCrossReference commandGetCrossReference, + final CallContext callContext, + final ServerStreamListener serverStreamListener) { + getStreamCatalogFunctions(commandGetCrossReference, serverStreamListener); + } + + @Override + public void close() { + // TODO No-op. + } + + @Override + public void listFlights(final CallContext callContext, final Criteria criteria, + final StreamListener streamListener) { + // TODO Implement this method. + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + + private void getStreamCatalogFunctions(final Message ticket, + final ServerStreamListener serverStreamListener) { + Preconditions.checkNotNull( + catalogQueriesResults.get(ticket), + format("Query not registered for ticket: <%s>", ticket)) + .accept(serverStreamListener); + } + + public SqlInfoBuilder getSqlInfoBuilder() { + return sqlInfoBuilder; + } + + private static final class TicketConversionUtils { + private TicketConversionUtils() { + // Prevent instantiation. + } + + private static ByteString getTicketBytesFromUuid(final UUID uuid) { + return ByteString.copyFromUtf8(uuid.toString()); + } + + private static TicketStatementQuery getTicketStatementQueryFromHandle(final ByteString handle) { + return TicketStatementQuery.newBuilder().setStatementHandle(handle).build(); + } + + private static CommandPreparedStatementQuery getCommandPreparedStatementQueryFromHandle( + final ByteString handle) { + return CommandPreparedStatementQuery.newBuilder().setPreparedStatementHandle(handle).build(); + } + + private static FlightEndpoint getEndpointFromMessage(final Message message) { + return new FlightEndpoint(new Ticket(Any.pack(message).toByteArray())); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ResultSetTestUtils.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ResultSetTestUtils.java new file mode 100644 index 0000000000000..d5ce7fb8fb36f --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ResultSetTestUtils.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static java.util.stream.IntStream.range; +import static org.hamcrest.CoreMatchers.is; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +import org.apache.arrow.util.Preconditions; +import org.junit.rules.ErrorCollector; + +/** + * Utility class for testing that require asserting that the values in a {@link ResultSet} are expected. + */ +public final class ResultSetTestUtils { + private final ErrorCollector collector; + + public ResultSetTestUtils(final ErrorCollector collector) { + this.collector = + Preconditions.checkNotNull(collector, "Error collector cannot be null."); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} + * is expected to have. + * @param collector the {@link ErrorCollector} to use for asserting that the {@code resultSet} + * has the expected values. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + public static void testData(final ResultSet resultSet, final List> expectedResults, + final ErrorCollector collector) + throws SQLException { + testData( + resultSet, + range(1, resultSet.getMetaData().getColumnCount() + 1).toArray(), + expectedResults, + collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param columnNames the column names to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} + * is expected to have. + * @param collector the {@link ErrorCollector} to use for asserting that the {@code resultSet} + * has the expected values. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + @SuppressWarnings("unchecked") + public static void testData(final ResultSet resultSet, final List columnNames, + final List> expectedResults, + final ErrorCollector collector) + throws SQLException { + testData( + resultSet, + data -> { + final List columns = new ArrayList<>(); + for (final String columnName : columnNames) { + try { + columns.add((T) resultSet.getObject(columnName)); + } catch (final SQLException e) { + collector.addError(e); + } + } + return columns; + }, + expectedResults, + collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param columnIndices the column indices to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} + * is expected to have. + * @param collector the {@link ErrorCollector} to use for asserting that the {@code resultSet} + * has the expected values. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + @SuppressWarnings("unchecked") + public static void testData(final ResultSet resultSet, final int[] columnIndices, + final List> expectedResults, + final ErrorCollector collector) + throws SQLException { + testData( + resultSet, + data -> { + final List columns = new ArrayList<>(); + for (final int columnIndex : columnIndices) { + try { + columns.add((T) resultSet.getObject(columnIndex)); + } catch (final SQLException e) { + collector.addError(e); + } + } + return columns; + }, + expectedResults, + collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param dataConsumer the column indices to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} + * is expected to have. + * @param collector the {@link ErrorCollector} to use for asserting that the {@code resultSet} + * has the expected values. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + public static void testData(final ResultSet resultSet, + final Function> dataConsumer, + final List> expectedResults, + final ErrorCollector collector) + throws SQLException { + final List> actualResults = new ArrayList<>(); + while (resultSet.next()) { + actualResults.add(dataConsumer.apply(resultSet)); + } + collector.checkThat(actualResults, is(expectedResults)); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} is expected to have. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + public void testData(final ResultSet resultSet, final List> expectedResults) + throws SQLException { + testData(resultSet, expectedResults, collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param columnNames the column names to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} is expected to have. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + @SuppressWarnings("unchecked") + public void testData(final ResultSet resultSet, final List columnNames, + final List> expectedResults) throws SQLException { + testData(resultSet, columnNames, expectedResults, collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param columnIndices the column indices to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} is expected to have. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + @SuppressWarnings("unchecked") + public void testData(final ResultSet resultSet, final int[] columnIndices, + final List> expectedResults) throws SQLException { + testData(resultSet, columnIndices, expectedResults, collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param dataConsumer the column indices to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} is expected to have. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + public void testData(final ResultSet resultSet, + final Function> dataConsumer, + final List> expectedResults) throws SQLException { + testData(resultSet, dataConsumer, expectedResults, collector); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java new file mode 100644 index 0000000000000..a200fc8d39c15 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java @@ -0,0 +1,820 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.math.BigDecimal; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionFixedSizeListWriter; +import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +public class RootAllocatorTestRule implements TestRule, AutoCloseable { + + public static final byte MAX_VALUE = Byte.MAX_VALUE; + private final BufferAllocator rootAllocator = new RootAllocator(); + + private final Random random = new Random(10); + + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + try { + base.evaluate(); + } finally { + close(); + } + } + }; + } + + public BufferAllocator getRootAllocator() { + return rootAllocator; + } + + @Override + public void close() throws Exception { + this.rootAllocator.getChildAllocators().forEach(BufferAllocator::close); + AutoCloseables.close(this.rootAllocator); + } + + /** + * Create a Float8Vector to be used in the accessor tests. + * + * @return Float8Vector + */ + public Float8Vector createFloat8Vector() { + double[] doubleVectorValues = new double[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE, + Float.MAX_VALUE, + -Float.MAX_VALUE, + Float.NEGATIVE_INFINITY, + Float.POSITIVE_INFINITY, + Float.MIN_VALUE, + -Float.MIN_VALUE, + Double.MAX_VALUE, + -Double.MAX_VALUE, + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + Double.MIN_VALUE, + -Double.MIN_VALUE, + }; + + Float8Vector result = new Float8Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < doubleVectorValues.length) { + result.setSafe(i, doubleVectorValues[i]); + } else { + result.setSafe(i, random.nextDouble()); + } + } + + return result; + } + + public Float8Vector createFloat8VectorForNullTests() { + final Float8Vector float8Vector = new Float8Vector("ID", this.getRootAllocator()); + float8Vector.allocateNew(1); + float8Vector.setNull(0); + float8Vector.setValueCount(1); + + return float8Vector; + } + + /** + * Create a Float4Vector to be used in the accessor tests. + * + * @return Float4Vector + */ + public Float4Vector createFloat4Vector() { + + float[] floatVectorValues = new float[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE, + Float.MAX_VALUE, + -Float.MAX_VALUE, + Float.NEGATIVE_INFINITY, + Float.POSITIVE_INFINITY, + Float.MIN_VALUE, + -Float.MIN_VALUE, + }; + + Float4Vector result = new Float4Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < floatVectorValues.length) { + result.setSafe(i, floatVectorValues[i]); + } else { + result.setSafe(i, random.nextFloat()); + } + } + + return result; + } + + /** + * Create a IntVector to be used in the accessor tests. + * + * @return IntVector + */ + public IntVector createIntVector() { + + int[] intVectorValues = new int[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + }; + + IntVector result = new IntVector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < intVectorValues.length) { + result.setSafe(i, intVectorValues[i]); + } else { + result.setSafe(i, random.nextInt()); + } + } + + return result; + } + + /** + * Create a SmallIntVector to be used in the accessor tests. + * + * @return SmallIntVector + */ + public SmallIntVector createSmallIntVector() { + + short[] smallIntVectorValues = new short[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + }; + + SmallIntVector result = new SmallIntVector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < smallIntVectorValues.length) { + result.setSafe(i, smallIntVectorValues[i]); + } else { + result.setSafe(i, random.nextInt(Short.MAX_VALUE)); + } + } + + return result; + } + + /** + * Create a TinyIntVector to be used in the accessor tests. + * + * @return TinyIntVector + */ + public TinyIntVector createTinyIntVector() { + + byte[] byteVectorValues = new byte[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + }; + + TinyIntVector result = new TinyIntVector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < byteVectorValues.length) { + result.setSafe(i, byteVectorValues[i]); + } else { + result.setSafe(i, random.nextInt(Byte.MAX_VALUE)); + } + } + + return result; + } + + /** + * Create a BigIntVector to be used in the accessor tests. + * + * @return BigIntVector + */ + public BigIntVector createBigIntVector() { + + long[] longVectorValues = new long[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE, + }; + + BigIntVector result = new BigIntVector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < longVectorValues.length) { + result.setSafe(i, longVectorValues[i]); + } else { + result.setSafe(i, random.nextLong()); + } + } + + return result; + } + + /** + * Create a UInt1Vector to be used in the accessor tests. + * + * @return UInt1Vector + */ + public UInt1Vector createUInt1Vector() { + + short[] uInt1VectorValues = new short[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + }; + + UInt1Vector result = new UInt1Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < uInt1VectorValues.length) { + result.setSafe(i, uInt1VectorValues[i]); + } else { + result.setSafe(i, random.nextInt(0x100)); + } + } + + return result; + } + + /** + * Create a UInt2Vector to be used in the accessor tests. + * + * @return UInt2Vector + */ + public UInt2Vector createUInt2Vector() { + + int[] uInt2VectorValues = new int[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + }; + + UInt2Vector result = new UInt2Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < uInt2VectorValues.length) { + result.setSafe(i, uInt2VectorValues[i]); + } else { + result.setSafe(i, random.nextInt(0x10000)); + } + } + + return result; + } + + /** + * Create a UInt4Vector to be used in the accessor tests. + * + * @return UInt4Vector + */ + public UInt4Vector createUInt4Vector() { + + + int[] uInt4VectorValues = new int[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE + }; + + UInt4Vector result = new UInt4Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < uInt4VectorValues.length) { + result.setSafe(i, uInt4VectorValues[i]); + } else { + result.setSafe(i, random.nextInt(Integer.MAX_VALUE)); + } + } + + return result; + } + + /** + * Create a UInt8Vector to be used in the accessor tests. + * + * @return UInt8Vector + */ + public UInt8Vector createUInt8Vector() { + + long[] uInt8VectorValues = new long[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE + }; + + UInt8Vector result = new UInt8Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < uInt8VectorValues.length) { + result.setSafe(i, uInt8VectorValues[i]); + } else { + result.setSafe(i, random.nextLong()); + } + } + + return result; + } + + /** + * Create a VarBinaryVector to be used in the accessor tests. + * + * @return VarBinaryVector + */ + public VarBinaryVector createVarBinaryVector() { + return createVarBinaryVector(""); + } + + /** + * Create a VarBinaryVector to be used in the accessor tests. + * + * @return VarBinaryVector + */ + public VarBinaryVector createVarBinaryVector(final String fieldName) { + VarBinaryVector valueVector = new VarBinaryVector(fieldName, this.getRootAllocator()); + valueVector.allocateNew(3); + valueVector.setSafe(0, (fieldName + "__BINARY_DATA_0001").getBytes()); + valueVector.setSafe(1, (fieldName + "__BINARY_DATA_0002").getBytes()); + valueVector.setSafe(2, (fieldName + "__BINARY_DATA_0003").getBytes()); + valueVector.setValueCount(3); + + return valueVector; + } + + /** + * Create a LargeVarBinaryVector to be used in the accessor tests. + * + * @return LargeVarBinaryVector + */ + public LargeVarBinaryVector createLargeVarBinaryVector() { + LargeVarBinaryVector valueVector = new LargeVarBinaryVector("", this.getRootAllocator()); + valueVector.allocateNew(3); + valueVector.setSafe(0, "BINARY_DATA_0001".getBytes()); + valueVector.setSafe(1, "BINARY_DATA_0002".getBytes()); + valueVector.setSafe(2, "BINARY_DATA_0003".getBytes()); + valueVector.setValueCount(3); + + return valueVector; + } + + /** + * Create a FixedSizeBinaryVector to be used in the accessor tests. + * + * @return FixedSizeBinaryVector + */ + public FixedSizeBinaryVector createFixedSizeBinaryVector() { + FixedSizeBinaryVector valueVector = new FixedSizeBinaryVector("", this.getRootAllocator(), 16); + valueVector.allocateNew(3); + valueVector.setSafe(0, "BINARY_DATA_0001".getBytes()); + valueVector.setSafe(1, "BINARY_DATA_0002".getBytes()); + valueVector.setSafe(2, "BINARY_DATA_0003".getBytes()); + valueVector.setValueCount(3); + + return valueVector; + } + + /** + * Create a UInt8Vector to be used in the accessor tests. + * + * @return UInt8Vector + */ + public DecimalVector createDecimalVector() { + + BigDecimal[] bigDecimalValues = new BigDecimal[] { + new BigDecimal(0), + new BigDecimal(1), + new BigDecimal(-1), + new BigDecimal(Byte.MIN_VALUE), + new BigDecimal(Byte.MAX_VALUE), + new BigDecimal(-Short.MAX_VALUE), + new BigDecimal(Short.MIN_VALUE), + new BigDecimal(Integer.MIN_VALUE), + new BigDecimal(Integer.MAX_VALUE), + new BigDecimal(Long.MIN_VALUE), + new BigDecimal(-Long.MAX_VALUE), + new BigDecimal("170141183460469231731687303715884105727") + }; + + DecimalVector result = new DecimalVector("ID", this.getRootAllocator(), 39, 0); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < bigDecimalValues.length) { + result.setSafe(i, bigDecimalValues[i]); + } else { + result.setSafe(i, random.nextLong()); + } + } + + return result; + } + + /** + * Create a UInt8Vector to be used in the accessor tests. + * + * @return UInt8Vector + */ + public Decimal256Vector createDecimal256Vector() { + + BigDecimal[] bigDecimalValues = new BigDecimal[] { + new BigDecimal(0), + new BigDecimal(1), + new BigDecimal(-1), + new BigDecimal(Byte.MIN_VALUE), + new BigDecimal(Byte.MAX_VALUE), + new BigDecimal(-Short.MAX_VALUE), + new BigDecimal(Short.MIN_VALUE), + new BigDecimal(Integer.MIN_VALUE), + new BigDecimal(Integer.MAX_VALUE), + new BigDecimal(Long.MIN_VALUE), + new BigDecimal(-Long.MAX_VALUE), + new BigDecimal("170141183460469231731687303715884105727"), + new BigDecimal("17014118346046923173168234157303715884105727"), + new BigDecimal("1701411834604692317316823415265417303715884105727"), + new BigDecimal("-17014118346046923173168234152654115451237303715884105727"), + new BigDecimal("-17014118346046923173168234152654115451231545157303715884105727"), + new BigDecimal("1701411834604692315815656534152654115451231545157303715884105727"), + new BigDecimal("30560141183460469231581565634152654115451231545157303715884105727"), + new BigDecimal( + "57896044618658097711785492504343953926634992332820282019728792003956564819967"), + new BigDecimal( + "-56896044618658097711785492504343953926634992332820282019728792003956564819967") + }; + + Decimal256Vector result = new Decimal256Vector("ID", this.getRootAllocator(), 77, 0); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < bigDecimalValues.length) { + result.setSafe(i, bigDecimalValues[i]); + } else { + result.setSafe(i, random.nextLong()); + } + } + + return result; + } + + public TimeStampNanoVector createTimeStampNanoVector() { + TimeStampNanoVector valueVector = new TimeStampNanoVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toNanos(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toNanos(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampNanoTZVector createTimeStampNanoTZVector(String timeZone) { + TimeStampNanoTZVector valueVector = + new TimeStampNanoTZVector("", this.getRootAllocator(), timeZone); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toNanos(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toNanos(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampMicroVector createTimeStampMicroVector() { + TimeStampMicroVector valueVector = new TimeStampMicroVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toMicros(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toMicros(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampMicroTZVector createTimeStampMicroTZVector(String timeZone) { + TimeStampMicroTZVector valueVector = + new TimeStampMicroTZVector("", this.getRootAllocator(), timeZone); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toMicros(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toMicros(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampMilliVector createTimeStampMilliVector() { + TimeStampMilliVector valueVector = new TimeStampMilliVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, 1625702400000L); + valueVector.setSafe(1, 1625788800000L); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampMilliTZVector createTimeStampMilliTZVector(String timeZone) { + TimeStampMilliTZVector valueVector = + new TimeStampMilliTZVector("", this.getRootAllocator(), timeZone); + valueVector.allocateNew(2); + valueVector.setSafe(0, 1625702400000L); + valueVector.setSafe(1, 1625788800000L); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampSecVector createTimeStampSecVector() { + TimeStampSecVector valueVector = new TimeStampSecVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toSeconds(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toSeconds(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampSecTZVector createTimeStampSecTZVector(String timeZone) { + TimeStampSecTZVector valueVector = + new TimeStampSecTZVector("", this.getRootAllocator(), timeZone); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toSeconds(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toSeconds(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public BitVector createBitVector() { + BitVector valueVector = new BitVector("Value", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1); + valueVector.setValueCount(2); + + return valueVector; + } + + public BitVector createBitVectorForNullTests() { + final BitVector bitVector = new BitVector("ID", this.getRootAllocator()); + bitVector.allocateNew(2); + bitVector.setNull(0); + bitVector.setValueCount(1); + + return bitVector; + } + + public TimeNanoVector createTimeNanoVector() { + TimeNanoVector valueVector = new TimeNanoVector("", this.getRootAllocator()); + valueVector.allocateNew(5); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1_000_000_000L); // 1 second + valueVector.setSafe(2, 60 * 1_000_000_000L); // 1 minute + valueVector.setSafe(3, 60 * 60 * 1_000_000_000L); // 1 hour + valueVector.setSafe(4, (24 * 60 * 60 - 1) * 1_000_000_000L); // 23:59:59 + valueVector.setValueCount(5); + + return valueVector; + } + + public TimeMicroVector createTimeMicroVector() { + TimeMicroVector valueVector = new TimeMicroVector("", this.getRootAllocator()); + valueVector.allocateNew(5); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1_000_000L); // 1 second + valueVector.setSafe(2, 60 * 1_000_000L); // 1 minute + valueVector.setSafe(3, 60 * 60 * 1_000_000L); // 1 hour + valueVector.setSafe(4, (24 * 60 * 60 - 1) * 1_000_000L); // 23:59:59 + valueVector.setValueCount(5); + + return valueVector; + } + + public TimeMilliVector createTimeMilliVector() { + TimeMilliVector valueVector = new TimeMilliVector("", this.getRootAllocator()); + valueVector.allocateNew(5); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1_000); // 1 second + valueVector.setSafe(2, 60 * 1_000); // 1 minute + valueVector.setSafe(3, 60 * 60 * 1_000); // 1 hour + valueVector.setSafe(4, (24 * 60 * 60 - 1) * 1_000); // 23:59:59 + valueVector.setValueCount(5); + + return valueVector; + } + + public TimeSecVector createTimeSecVector() { + TimeSecVector valueVector = new TimeSecVector("", this.getRootAllocator()); + valueVector.allocateNew(5); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1); // 1 second + valueVector.setSafe(2, 60); // 1 minute + valueVector.setSafe(3, 60 * 60); // 1 hour + valueVector.setSafe(4, (24 * 60 * 60 - 1)); // 23:59:59 + valueVector.setValueCount(5); + + return valueVector; + } + + public DateDayVector createDateDayVector() { + DateDayVector valueVector = new DateDayVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, (int) TimeUnit.MILLISECONDS.toDays(1625702400000L)); + valueVector.setSafe(1, (int) TimeUnit.MILLISECONDS.toDays(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public DateMilliVector createDateMilliVector() { + DateMilliVector valueVector = new DateMilliVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, 1625702400000L); + valueVector.setSafe(1, 1625788800000L); + valueVector.setValueCount(2); + + return valueVector; + } + + public ListVector createListVector() { + return createListVector(""); + } + + public ListVector createListVector(String fieldName) { + ListVector valueVector = ListVector.empty(fieldName, this.getRootAllocator()); + valueVector.setInitialCapacity(MAX_VALUE); + + UnionListWriter writer = valueVector.getWriter(); + + IntStream range = IntStream.range(0, MAX_VALUE); + + range.forEach(row -> { + writer.startList(); + writer.setPosition(row); + IntStream.range(0, 5).map(j -> j * row).forEach(writer::writeInt); + writer.setValueCount(5); + writer.endList(); + }); + + valueVector.setValueCount(MAX_VALUE); + + return valueVector; + } + + public LargeListVector createLargeListVector() { + LargeListVector valueVector = LargeListVector.empty("", this.getRootAllocator()); + valueVector.setInitialCapacity(MAX_VALUE); + + UnionLargeListWriter writer = valueVector.getWriter(); + + IntStream range = IntStream.range(0, MAX_VALUE); + + range.forEach(row -> { + writer.startList(); + writer.setPosition(row); + IntStream.range(0, 5).map(j -> j * row).forEach(writer::writeInt); + writer.setValueCount(5); + writer.endList(); + }); + + valueVector.setValueCount(MAX_VALUE); + + return valueVector; + } + + public FixedSizeListVector createFixedSizeListVector() { + FixedSizeListVector valueVector = FixedSizeListVector.empty("", 5, this.getRootAllocator()); + valueVector.setInitialCapacity(MAX_VALUE); + + UnionFixedSizeListWriter writer = valueVector.getWriter(); + + IntStream range = IntStream.range(0, MAX_VALUE); + + range.forEach(row -> { + writer.startList(); + writer.setPosition(row); + IntStream.range(0, 5).map(j -> j * row).forEach(writer::writeInt); + writer.setValueCount(5); + writer.endList(); + }); + + valueVector.setValueCount(MAX_VALUE); + + return valueVector; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java new file mode 100644 index 0000000000000..5c7c873e55c41 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.apache.arrow.driver.jdbc.utils.SqlTypes.getSqlTypeIdFromArrowType; +import static org.apache.arrow.driver.jdbc.utils.SqlTypes.getSqlTypeNameFromArrowType; +import static org.junit.Assert.assertEquals; + +import java.sql.Types; + +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.UnionMode; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.Test; + +public class SqlTypesTest { + + @Test + public void testGetSqlTypeIdFromArrowType() { + assertEquals(Types.TINYINT, getSqlTypeIdFromArrowType(new ArrowType.Int(8, true))); + assertEquals(Types.SMALLINT, getSqlTypeIdFromArrowType(new ArrowType.Int(16, true))); + assertEquals(Types.INTEGER, getSqlTypeIdFromArrowType(new ArrowType.Int(32, true))); + assertEquals(Types.BIGINT, getSqlTypeIdFromArrowType(new ArrowType.Int(64, true))); + + assertEquals(Types.BINARY, getSqlTypeIdFromArrowType(new ArrowType.FixedSizeBinary(1024))); + assertEquals(Types.VARBINARY, getSqlTypeIdFromArrowType(new ArrowType.Binary())); + assertEquals(Types.LONGVARBINARY, getSqlTypeIdFromArrowType(new ArrowType.LargeBinary())); + + assertEquals(Types.VARCHAR, getSqlTypeIdFromArrowType(new ArrowType.Utf8())); + assertEquals(Types.LONGVARCHAR, getSqlTypeIdFromArrowType(new ArrowType.LargeUtf8())); + + assertEquals(Types.DATE, getSqlTypeIdFromArrowType(new ArrowType.Date(DateUnit.MILLISECOND))); + assertEquals(Types.TIME, + getSqlTypeIdFromArrowType(new ArrowType.Time(TimeUnit.MILLISECOND, 32))); + assertEquals(Types.TIMESTAMP, + getSqlTypeIdFromArrowType(new ArrowType.Timestamp(TimeUnit.MILLISECOND, ""))); + + assertEquals(Types.BOOLEAN, getSqlTypeIdFromArrowType(new ArrowType.Bool())); + + assertEquals(Types.DECIMAL, getSqlTypeIdFromArrowType(new ArrowType.Decimal(0, 0, 64))); + assertEquals(Types.DOUBLE, + getSqlTypeIdFromArrowType(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))); + assertEquals(Types.FLOAT, + getSqlTypeIdFromArrowType(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); + + assertEquals(Types.ARRAY, getSqlTypeIdFromArrowType(new ArrowType.List())); + assertEquals(Types.ARRAY, getSqlTypeIdFromArrowType(new ArrowType.LargeList())); + assertEquals(Types.ARRAY, getSqlTypeIdFromArrowType(new ArrowType.FixedSizeList(10))); + + assertEquals(Types.JAVA_OBJECT, getSqlTypeIdFromArrowType(new ArrowType.Struct())); + assertEquals(Types.JAVA_OBJECT, + getSqlTypeIdFromArrowType(new ArrowType.Duration(TimeUnit.MILLISECOND))); + assertEquals(Types.JAVA_OBJECT, + getSqlTypeIdFromArrowType(new ArrowType.Interval(IntervalUnit.DAY_TIME))); + assertEquals(Types.JAVA_OBJECT, + getSqlTypeIdFromArrowType(new ArrowType.Union(UnionMode.Dense, null))); + assertEquals(Types.JAVA_OBJECT, getSqlTypeIdFromArrowType(new ArrowType.Map(true))); + + assertEquals(Types.NULL, getSqlTypeIdFromArrowType(new ArrowType.Null())); + } + + @Test + public void testGetSqlTypeNameFromArrowType() { + assertEquals("TINYINT", getSqlTypeNameFromArrowType(new ArrowType.Int(8, true))); + assertEquals("SMALLINT", getSqlTypeNameFromArrowType(new ArrowType.Int(16, true))); + assertEquals("INTEGER", getSqlTypeNameFromArrowType(new ArrowType.Int(32, true))); + assertEquals("BIGINT", getSqlTypeNameFromArrowType(new ArrowType.Int(64, true))); + + assertEquals("BINARY", getSqlTypeNameFromArrowType(new ArrowType.FixedSizeBinary(1024))); + assertEquals("VARBINARY", getSqlTypeNameFromArrowType(new ArrowType.Binary())); + assertEquals("LONGVARBINARY", getSqlTypeNameFromArrowType(new ArrowType.LargeBinary())); + + assertEquals("VARCHAR", getSqlTypeNameFromArrowType(new ArrowType.Utf8())); + assertEquals("LONGVARCHAR", getSqlTypeNameFromArrowType(new ArrowType.LargeUtf8())); + + assertEquals("DATE", getSqlTypeNameFromArrowType(new ArrowType.Date(DateUnit.MILLISECOND))); + assertEquals("TIME", getSqlTypeNameFromArrowType(new ArrowType.Time(TimeUnit.MILLISECOND, 32))); + assertEquals("TIMESTAMP", + getSqlTypeNameFromArrowType(new ArrowType.Timestamp(TimeUnit.MILLISECOND, ""))); + + assertEquals("BOOLEAN", getSqlTypeNameFromArrowType(new ArrowType.Bool())); + + assertEquals("DECIMAL", getSqlTypeNameFromArrowType(new ArrowType.Decimal(0, 0, 64))); + assertEquals("DOUBLE", + getSqlTypeNameFromArrowType(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))); + assertEquals("FLOAT", + getSqlTypeNameFromArrowType(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); + + assertEquals("ARRAY", getSqlTypeNameFromArrowType(new ArrowType.List())); + assertEquals("ARRAY", getSqlTypeNameFromArrowType(new ArrowType.LargeList())); + assertEquals("ARRAY", getSqlTypeNameFromArrowType(new ArrowType.FixedSizeList(10))); + + assertEquals("JAVA_OBJECT", getSqlTypeNameFromArrowType(new ArrowType.Struct())); + + assertEquals("JAVA_OBJECT", + getSqlTypeNameFromArrowType(new ArrowType.Duration(TimeUnit.MILLISECOND))); + assertEquals("JAVA_OBJECT", + getSqlTypeNameFromArrowType(new ArrowType.Interval(IntervalUnit.DAY_TIME))); + assertEquals("JAVA_OBJECT", + getSqlTypeNameFromArrowType(new ArrowType.Union(UnionMode.Dense, null))); + assertEquals("JAVA_OBJECT", getSqlTypeNameFromArrowType(new ArrowType.Map(true))); + + assertEquals("NULL", getSqlTypeNameFromArrowType(new ArrowType.Null())); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java new file mode 100644 index 0000000000000..f1bd44539ac58 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +/** + * Utility class to avoid upgrading JUnit to version >= 4.13 and keep using code to assert a {@link Throwable}. + * This should be removed as soon as we can use the proper assertThrows/checkThrows. + */ +public class ThrowableAssertionUtils { + private ThrowableAssertionUtils() { + } + + public static void simpleAssertThrowableClass( + final Class expectedThrowable, final ThrowingRunnable runnable) { + try { + runnable.run(); + } catch (Throwable actualThrown) { + if (expectedThrowable.isInstance(actualThrown)) { + return; + } else { + final String mismatchMessage = String.format("unexpected exception type thrown;\nexpected: %s\nactual: %s", + formatClass(expectedThrowable), + formatClass(actualThrown.getClass())); + + throw new AssertionError(mismatchMessage, actualThrown); + } + } + final String notThrownMessage = String.format("expected %s to be thrown, but nothing was thrown", + formatClass(expectedThrowable)); + throw new AssertionError(notThrownMessage); + } + + private static String formatClass(final Class value) { + // Fallback for anonymous inner classes + final String className = value.getCanonicalName(); + return className == null ? value.getName() : className; + } + + public interface ThrowingRunnable { + void run() throws Throwable; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/UrlParserTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/UrlParserTest.java new file mode 100644 index 0000000000000..4e764ab322c69 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/UrlParserTest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +class UrlParserTest { + @Test + void parse() { + final Map parsed = UrlParser.parse("foo=bar&123=456", "&"); + assertEquals(parsed.get("foo"), "bar"); + assertEquals(parsed.get("123"), "456"); + } + + @Test + void parseEscaped() { + final Map parsed = UrlParser.parse("foo=bar%26&%26123=456", "&"); + assertEquals(parsed.get("foo"), "bar&"); + assertEquals(parsed.get("&123"), "456"); + } + + @Test + void parseEmpty() { + final Map parsed = UrlParser.parse("a=&b&foo=bar&123=456", "&"); + assertEquals(parsed.get("a"), ""); + assertNull(parsed.get("b")); + assertEquals(parsed.get("foo"), "bar"); + assertEquals(parsed.get("123"), "456"); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformerTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformerTest.java new file mode 100644 index 0000000000000..1804b42cecb88 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformerTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +public class VectorSchemaRootTransformerTest { + + @Rule + public RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + private final BufferAllocator rootAllocator = rootAllocatorTestRule.getRootAllocator(); + + @Test + public void testTransformerBuilderWorksCorrectly() throws Exception { + final VarBinaryVector field1 = rootAllocatorTestRule.createVarBinaryVector("FIELD_1"); + final VarBinaryVector field2 = rootAllocatorTestRule.createVarBinaryVector("FIELD_2"); + final VarBinaryVector field3 = rootAllocatorTestRule.createVarBinaryVector("FIELD_3"); + + try (final VectorSchemaRoot originalRoot = VectorSchemaRoot.of(field1, field2, field3); + final VectorSchemaRoot clonedRoot = cloneVectorSchemaRoot(originalRoot)) { + + final VectorSchemaRootTransformer.Builder builder = + new VectorSchemaRootTransformer.Builder(originalRoot.getSchema(), + rootAllocator); + + builder.renameFieldVector("FIELD_3", "FIELD_3_RENAMED"); + builder.addEmptyField("EMPTY_FIELD", new ArrowType.Bool()); + builder.renameFieldVector("FIELD_2", "FIELD_2_RENAMED"); + builder.renameFieldVector("FIELD_1", "FIELD_1_RENAMED"); + + final VectorSchemaRootTransformer transformer = builder.build(); + + final Schema transformedSchema = new Schema(ImmutableList.of( + Field.nullable("FIELD_3_RENAMED", new ArrowType.Binary()), + Field.nullable("EMPTY_FIELD", new ArrowType.Bool()), + Field.nullable("FIELD_2_RENAMED", new ArrowType.Binary()), + Field.nullable("FIELD_1_RENAMED", new ArrowType.Binary()) + )); + try (final VectorSchemaRoot transformedRoot = createVectorSchemaRoot(transformedSchema)) { + Assert.assertSame(transformedRoot, transformer.transform(clonedRoot, transformedRoot)); + Assert.assertEquals(transformedSchema, transformedRoot.getSchema()); + + final int rowCount = originalRoot.getRowCount(); + Assert.assertEquals(rowCount, transformedRoot.getRowCount()); + + final VarBinaryVector originalField1 = + (VarBinaryVector) originalRoot.getVector("FIELD_1"); + final VarBinaryVector originalField2 = + (VarBinaryVector) originalRoot.getVector("FIELD_2"); + final VarBinaryVector originalField3 = + (VarBinaryVector) originalRoot.getVector("FIELD_3"); + + final VarBinaryVector transformedField1 = + (VarBinaryVector) transformedRoot.getVector("FIELD_1_RENAMED"); + final VarBinaryVector transformedField2 = + (VarBinaryVector) transformedRoot.getVector("FIELD_2_RENAMED"); + final VarBinaryVector transformedField3 = + (VarBinaryVector) transformedRoot.getVector("FIELD_3_RENAMED"); + final FieldVector emptyField = transformedRoot.getVector("EMPTY_FIELD"); + + for (int i = 0; i < rowCount; i++) { + Assert.assertArrayEquals(originalField1.getObject(i), transformedField1.getObject(i)); + Assert.assertArrayEquals(originalField2.getObject(i), transformedField2.getObject(i)); + Assert.assertArrayEquals(originalField3.getObject(i), transformedField3.getObject(i)); + Assert.assertNull(emptyField.getObject(i)); + } + } + } + } + + private VectorSchemaRoot cloneVectorSchemaRoot(final VectorSchemaRoot originalRoot) { + final VectorUnloader vectorUnloader = new VectorUnloader(originalRoot); + try (final ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch()) { + final VectorSchemaRoot clonedRoot = createVectorSchemaRoot(originalRoot.getSchema()); + final VectorLoader vectorLoader = new VectorLoader(clonedRoot); + vectorLoader.load(recordBatch); + return clonedRoot; + } + } + + private VectorSchemaRoot createVectorSchemaRoot(final Schema schema) { + final List fieldVectors = schema.getFields().stream() + .map(field -> field.createVector(rootAllocator)) + .collect(Collectors.toList()); + return new VectorSchemaRoot(fieldVectors); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/keyStore.jks b/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/keyStore.jks new file mode 100644 index 0000000000000..32a9bedea500a Binary files /dev/null and b/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/keyStore.jks differ diff --git a/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/noCertificate.jks b/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/noCertificate.jks new file mode 100644 index 0000000000000..071a1ebf97b3e Binary files /dev/null and b/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/noCertificate.jks differ diff --git a/java/flight/flight-sql-jdbc-driver/src/test/resources/logback.xml b/java/flight/flight-sql-jdbc-driver/src/test/resources/logback.xml new file mode 100644 index 0000000000000..ce66f8d82acda --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/resources/logback.xml @@ -0,0 +1,27 @@ + + + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java index 06b3c9dbe202b..1d7305fcf2f3c 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java @@ -260,6 +260,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception { new Field("ID", new FieldType(false, MinorType.INT.getType(), null, new FlightSqlColumnMetadata.Builder() .catalogName("") + .typeName("INTEGER") .schemaName("APP") .tableName("FOREIGNTABLE") .precision(10) @@ -269,6 +270,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception { new Field("FOREIGNNAME", new FieldType(true, MinorType.VARCHAR.getType(), null, new FlightSqlColumnMetadata.Builder() .catalogName("") + .typeName("VARCHAR") .schemaName("APP") .tableName("FOREIGNTABLE") .precision(100) @@ -278,6 +280,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception { new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null, new FlightSqlColumnMetadata.Builder() .catalogName("") + .typeName("INTEGER") .schemaName("APP") .tableName("FOREIGNTABLE") .precision(10) @@ -293,6 +296,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception { new Field("ID", new FieldType(false, MinorType.INT.getType(), null, new FlightSqlColumnMetadata.Builder() .catalogName("") + .typeName("INTEGER") .schemaName("APP") .tableName("INTTABLE") .precision(10) @@ -302,6 +306,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception { new Field("KEYNAME", new FieldType(true, MinorType.VARCHAR.getType(), null, new FlightSqlColumnMetadata.Builder() .catalogName("") + .typeName("VARCHAR") .schemaName("APP") .tableName("INTTABLE") .precision(100) @@ -311,6 +316,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception { new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null, new FlightSqlColumnMetadata.Builder() .catalogName("") + .typeName("INTEGER") .schemaName("APP") .tableName("INTTABLE") .precision(10) @@ -320,6 +326,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception { new Field("FOREIGNID", new FieldType(true, MinorType.INT.getType(), null, new FlightSqlColumnMetadata.Builder() .catalogName("") + .typeName("INTEGER") .schemaName("APP") .tableName("INTTABLE") .precision(10) diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index baf162cb919fc..d66b8df9283bf 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -576,6 +576,7 @@ private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMet final String catalogName = columnsData.getString("TABLE_CAT"); final String schemaName = columnsData.getString("TABLE_SCHEM"); final String tableName = columnsData.getString("TABLE_NAME"); + final String typeName = columnsData.getString("TYPE_NAME"); final String fieldName = columnsData.getString("COLUMN_NAME"); final int dataType = columnsData.getInt("DATA_TYPE"); final boolean isNullable = columnsData.getInt("NULLABLE") != DatabaseMetaData.columnNoNulls; @@ -590,6 +591,7 @@ private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMet .catalogName(catalogName) .schemaName(schemaName) .tableName(tableName) + .typeName(typeName) .precision(precision) .scale(scale) .isAutoIncrement(isAutoIncrement) diff --git a/java/flight/pom.xml b/java/flight/pom.xml index dad0f05d7afd7..d8b02bee7ab5c 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -28,6 +28,7 @@ flight-core flight-grpc flight-sql + flight-sql-jdbc-driver flight-integration-tests diff --git a/java/pom.xml b/java/pom.xml index 0d123977fc634..c1710f57b424e 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -150,6 +150,7 @@ **/client/build/** **/*.tbl **/*.iml + **/flight.properties