diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java index 8e1f134d807a..07ccfa090fc7 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java @@ -16,18 +16,13 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; -import io.prestosql.block.BlockEncodingManager; -import io.prestosql.block.BlockSerdeUtil; import io.prestosql.plugin.hive.metastore.StorageFormat; import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.block.BlockEncodingSerde; import io.prestosql.spi.connector.ConnectorPageSource; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.RecordCursor; @@ -46,7 +41,6 @@ import io.prestosql.testing.MaterializedResult; import io.prestosql.testing.MaterializedRow; import io.prestosql.tests.StructuralTestUtil; -import io.prestosql.type.TypeRegistry; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.common.type.HiveChar; @@ -73,6 +67,7 @@ import java.io.File; import java.io.IOException; +import java.lang.invoke.MethodHandle; import java.math.BigDecimal; import java.math.BigInteger; import java.sql.Date; @@ -89,6 +84,8 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Predicates.not; import static com.google.common.base.Strings.padEnd; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.filter; import static com.google.common.collect.Iterables.transform; import static io.prestosql.plugin.hive.HdfsConfigurationUpdater.configureCompression; @@ -97,6 +94,7 @@ import static io.prestosql.plugin.hive.HivePartitionKey.HIVE_DEFAULT_DYNAMIC_PARTITION; import static io.prestosql.plugin.hive.HiveTestUtils.SESSION; import static io.prestosql.plugin.hive.HiveTestUtils.TYPE_MANAGER; +import static io.prestosql.plugin.hive.HiveTestUtils.isDistinctFrom; import static io.prestosql.plugin.hive.HiveTestUtils.mapType; import static io.prestosql.plugin.hive.HiveUtil.isStructuralType; import static io.prestosql.plugin.hive.util.SerDeUtils.serializeObject; @@ -124,6 +122,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.fill; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; import static java.util.stream.Collectors.toList; import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardListObjectInspector; import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardMapObjectInspector; @@ -464,8 +463,6 @@ public abstract class AbstractTestHiveFileFormats .add(new TestColumn("t_array_string_all_nulls", getStandardListObjectInspector(javaStringObjectInspector), Arrays.asList(null, null, null), arrayBlockOf(createUnboundedVarcharType(), null, null, null))) .build(); - private final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(new TypeRegistry()); - private static Map asMap(K[] keys, V[] values) { checkArgument(keys.length == values.length, "array lengths don't match"); @@ -671,12 +668,20 @@ else if (type instanceof DecimalType) { protected void checkCursor(RecordCursor cursor, List testColumns, int rowCount) { + List types = testColumns.stream() + .map(column -> column.getObjectInspector().getTypeName()) + .map(type -> HiveType.valueOf(type).getType(TYPE_MANAGER)) + .collect(toImmutableList()); + + Map distinctFromOperators = types.stream().distinct() + .collect(toImmutableMap(identity(), HiveTestUtils::distinctFromOperator)); + for (int row = 0; row < rowCount; row++) { assertTrue(cursor.advanceNextPosition()); for (int i = 0, testColumnsSize = testColumns.size(); i < testColumnsSize; i++) { TestColumn testColumn = testColumns.get(i); - Type type = HiveType.valueOf(testColumn.getObjectInspector().getTypeName()).getType(TYPE_MANAGER); + Type type = types.get(i); Object fieldFromCursor = getFieldFromCursor(cursor, type, i); if (fieldFromCursor == null) { assertEquals(null, testColumn.getExpectedValue(), String.format("Expected null for column %s", testColumn.getName())); @@ -707,7 +712,8 @@ else if (testColumn.getObjectInspector().getCategory() == Category.PRIMITIVE) { else { Block expected = (Block) testColumn.getExpectedValue(); Block actual = (Block) fieldFromCursor; - assertBlockEquals(actual, expected, String.format("Wrong value for column %s", testColumn.getName())); + boolean distinct = isDistinctFrom(distinctFromOperators.get(type), expected, actual); + assertFalse(distinct, "Wrong value for column: " + testColumn.getName()); } } } @@ -788,19 +794,6 @@ else if (testColumn.getObjectInspector().getCategory() == Category.PRIMITIVE) { } } - private void assertBlockEquals(Block actual, Block expected, String message) - { - assertEquals(blockToSlice(actual), blockToSlice(expected), message); - } - - private Slice blockToSlice(Block block) - { - // This function is strictly for testing use only - SliceOutput sliceOutput = new DynamicSliceOutput(1000); - BlockSerdeUtil.writeBlock(blockEncodingSerde, sliceOutput, block); - return sliceOutput.slice(); - } - public static final class TestColumn { private final String name; diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/HiveTestUtils.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/HiveTestUtils.java index 97980e36274b..41df9d94968c 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/HiveTestUtils.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/HiveTestUtils.java @@ -17,8 +17,9 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.prestosql.PagesIndexPageSorter; -import io.prestosql.block.BlockEncodingManager; import io.prestosql.metadata.FunctionRegistry; +import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.Signature; import io.prestosql.operator.PagesIndex; import io.prestosql.plugin.hive.authentication.NoHdfsAuthentication; import io.prestosql.plugin.hive.orc.DwrfPageSourceFactory; @@ -28,6 +29,7 @@ import io.prestosql.plugin.hive.s3.HiveS3Config; import io.prestosql.plugin.hive.s3.PrestoS3ConfigurationUpdater; import io.prestosql.spi.PageSorter; +import io.prestosql.spi.block.Block; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.type.ArrayType; @@ -36,15 +38,17 @@ import io.prestosql.spi.type.RowType; import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignatureParameter; -import io.prestosql.sql.analyzer.FeaturesConfig; import io.prestosql.testing.TestingConnectorSession; -import io.prestosql.type.TypeRegistry; +import java.lang.invoke.MethodHandle; import java.math.BigDecimal; import java.util.List; import java.util.Set; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.function.OperatorType.IS_DISTINCT_FROM; import static io.prestosql.spi.type.Decimals.encodeScaledValue; import static java.util.stream.Collectors.toList; @@ -57,12 +61,9 @@ private HiveTestUtils() public static final ConnectorSession SESSION = new TestingConnectorSession( new HiveSessionProperties(new HiveClientConfig(), new OrcFileWriterConfig(), new ParquetFileWriterConfig()).getSessionProperties()); - public static final TypeRegistry TYPE_MANAGER = new TypeRegistry(); - - static { - // associate TYPE_MANAGER with a function registry - new FunctionRegistry(TYPE_MANAGER, new BlockEncodingManager(TYPE_MANAGER), new FeaturesConfig()); - } + private static final Metadata METADATA = createTestMetadataManager(); + private static final FunctionRegistry FUNCTION_REGISTRY = METADATA.getFunctionRegistry(); + public static final TypeManager TYPE_MANAGER = METADATA.getTypeManager(); public static final HdfsEnvironment HDFS_ENVIRONMENT = createTestHdfsEnvironment(new HiveClientConfig()); @@ -156,4 +157,20 @@ public static Slice longDecimal(String value) { return encodeScaledValue(new BigDecimal(value)); } + + public static MethodHandle distinctFromOperator(Type type) + { + Signature signature = FUNCTION_REGISTRY.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(type, type)); + return FUNCTION_REGISTRY.getScalarFunctionImplementation(signature).getMethodHandle(); + } + + public static boolean isDistinctFrom(MethodHandle handle, Block left, Block right) + { + try { + return (boolean) handle.invokeExact(left, left == null, right, right == null); + } + catch (Throwable t) { + throw new AssertionError(t); + } + } }