Skip to content

Commit

Permalink
Fix value comparisons in TestHiveFileFormats
Browse files Browse the repository at this point in the history
Maps need to be compared without regard to order.
  • Loading branch information
electrum committed Jan 26, 2019
1 parent a63829d commit 486a8f1
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,13 @@
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.prestosql.block.BlockEncodingManager;
import io.prestosql.block.BlockSerdeUtil;
import io.prestosql.plugin.hive.metastore.StorageFormat;
import io.prestosql.spi.Page;
import io.prestosql.spi.PageBuilder;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.block.BlockEncodingSerde;
import io.prestosql.spi.connector.ConnectorPageSource;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.RecordCursor;
Expand All @@ -46,7 +41,6 @@
import io.prestosql.testing.MaterializedResult;
import io.prestosql.testing.MaterializedRow;
import io.prestosql.tests.StructuralTestUtil;
import io.prestosql.type.TypeRegistry;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.common.type.HiveChar;
Expand All @@ -73,6 +67,7 @@

import java.io.File;
import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.Date;
Expand All @@ -89,6 +84,8 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.not;
import static com.google.common.base.Strings.padEnd;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.Iterables.filter;
import static com.google.common.collect.Iterables.transform;
import static io.prestosql.plugin.hive.HdfsConfigurationUpdater.configureCompression;
Expand All @@ -97,6 +94,7 @@
import static io.prestosql.plugin.hive.HivePartitionKey.HIVE_DEFAULT_DYNAMIC_PARTITION;
import static io.prestosql.plugin.hive.HiveTestUtils.SESSION;
import static io.prestosql.plugin.hive.HiveTestUtils.TYPE_MANAGER;
import static io.prestosql.plugin.hive.HiveTestUtils.isDistinctFrom;
import static io.prestosql.plugin.hive.HiveTestUtils.mapType;
import static io.prestosql.plugin.hive.HiveUtil.isStructuralType;
import static io.prestosql.plugin.hive.util.SerDeUtils.serializeObject;
Expand Down Expand Up @@ -124,6 +122,7 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.fill;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardListObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardMapObjectInspector;
Expand Down Expand Up @@ -464,8 +463,6 @@ public abstract class AbstractTestHiveFileFormats
.add(new TestColumn("t_array_string_all_nulls", getStandardListObjectInspector(javaStringObjectInspector), Arrays.asList(null, null, null), arrayBlockOf(createUnboundedVarcharType(), null, null, null)))
.build();

private final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(new TypeRegistry());

private static <K, V> Map<K, V> asMap(K[] keys, V[] values)
{
checkArgument(keys.length == values.length, "array lengths don't match");
Expand Down Expand Up @@ -671,12 +668,20 @@ else if (type instanceof DecimalType) {

protected void checkCursor(RecordCursor cursor, List<TestColumn> testColumns, int rowCount)
{
List<Type> types = testColumns.stream()
.map(column -> column.getObjectInspector().getTypeName())
.map(type -> HiveType.valueOf(type).getType(TYPE_MANAGER))
.collect(toImmutableList());

Map<Type, MethodHandle> distinctFromOperators = types.stream().distinct()
.collect(toImmutableMap(identity(), HiveTestUtils::distinctFromOperator));

for (int row = 0; row < rowCount; row++) {
assertTrue(cursor.advanceNextPosition());
for (int i = 0, testColumnsSize = testColumns.size(); i < testColumnsSize; i++) {
TestColumn testColumn = testColumns.get(i);

Type type = HiveType.valueOf(testColumn.getObjectInspector().getTypeName()).getType(TYPE_MANAGER);
Type type = types.get(i);
Object fieldFromCursor = getFieldFromCursor(cursor, type, i);
if (fieldFromCursor == null) {
assertEquals(null, testColumn.getExpectedValue(), String.format("Expected null for column %s", testColumn.getName()));
Expand Down Expand Up @@ -707,7 +712,8 @@ else if (testColumn.getObjectInspector().getCategory() == Category.PRIMITIVE) {
else {
Block expected = (Block) testColumn.getExpectedValue();
Block actual = (Block) fieldFromCursor;
assertBlockEquals(actual, expected, String.format("Wrong value for column %s", testColumn.getName()));
boolean distinct = isDistinctFrom(distinctFromOperators.get(type), expected, actual);
assertFalse(distinct, "Wrong value for column: " + testColumn.getName());
}
}
}
Expand Down Expand Up @@ -788,19 +794,6 @@ else if (testColumn.getObjectInspector().getCategory() == Category.PRIMITIVE) {
}
}

private void assertBlockEquals(Block actual, Block expected, String message)
{
assertEquals(blockToSlice(actual), blockToSlice(expected), message);
}

private Slice blockToSlice(Block block)
{
// This function is strictly for testing use only
SliceOutput sliceOutput = new DynamicSliceOutput(1000);
BlockSerdeUtil.writeBlock(blockEncodingSerde, sliceOutput, block);
return sliceOutput.slice();
}

public static final class TestColumn
{
private final String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.prestosql.PagesIndexPageSorter;
import io.prestosql.block.BlockEncodingManager;
import io.prestosql.metadata.FunctionRegistry;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.Signature;
import io.prestosql.operator.PagesIndex;
import io.prestosql.plugin.hive.authentication.NoHdfsAuthentication;
import io.prestosql.plugin.hive.orc.DwrfPageSourceFactory;
Expand All @@ -28,6 +29,7 @@
import io.prestosql.plugin.hive.s3.HiveS3Config;
import io.prestosql.plugin.hive.s3.PrestoS3ConfigurationUpdater;
import io.prestosql.spi.PageSorter;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.type.ArrayType;
Expand All @@ -36,15 +38,17 @@
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.StandardTypes;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeManager;
import io.prestosql.spi.type.TypeSignatureParameter;
import io.prestosql.sql.analyzer.FeaturesConfig;
import io.prestosql.testing.TestingConnectorSession;
import io.prestosql.type.TypeRegistry;

import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
import java.util.List;
import java.util.Set;

import static io.prestosql.metadata.MetadataManager.createTestMetadataManager;
import static io.prestosql.spi.function.OperatorType.IS_DISTINCT_FROM;
import static io.prestosql.spi.type.Decimals.encodeScaledValue;
import static java.util.stream.Collectors.toList;

Expand All @@ -57,12 +61,9 @@ private HiveTestUtils()
public static final ConnectorSession SESSION = new TestingConnectorSession(
new HiveSessionProperties(new HiveClientConfig(), new OrcFileWriterConfig(), new ParquetFileWriterConfig()).getSessionProperties());

public static final TypeRegistry TYPE_MANAGER = new TypeRegistry();

static {
// associate TYPE_MANAGER with a function registry
new FunctionRegistry(TYPE_MANAGER, new BlockEncodingManager(TYPE_MANAGER), new FeaturesConfig());
}
private static final Metadata METADATA = createTestMetadataManager();
private static final FunctionRegistry FUNCTION_REGISTRY = METADATA.getFunctionRegistry();
public static final TypeManager TYPE_MANAGER = METADATA.getTypeManager();

public static final HdfsEnvironment HDFS_ENVIRONMENT = createTestHdfsEnvironment(new HiveClientConfig());

Expand Down Expand Up @@ -156,4 +157,20 @@ public static Slice longDecimal(String value)
{
return encodeScaledValue(new BigDecimal(value));
}

public static MethodHandle distinctFromOperator(Type type)
{
Signature signature = FUNCTION_REGISTRY.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(type, type));
return FUNCTION_REGISTRY.getScalarFunctionImplementation(signature).getMethodHandle();
}

public static boolean isDistinctFrom(MethodHandle handle, Block left, Block right)
{
try {
return (boolean) handle.invokeExact(left, left == null, right, right == null);
}
catch (Throwable t) {
throw new AssertionError(t);
}
}
}

0 comments on commit 486a8f1

Please sign in to comment.