Skip to content

Commit

Permalink
Add support for filters on subfields to OrcTester
Browse files Browse the repository at this point in the history
  • Loading branch information
mbasmanova committed Sep 5, 2019
1 parent 6cffd16 commit 7b9b0ce
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 110 deletions.
159 changes: 93 additions & 66 deletions presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.facebook.presto.spi.Subfield;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.ArrayType;
import com.facebook.presto.spi.type.CharType;
import com.facebook.presto.spi.type.DecimalType;
import com.facebook.presto.spi.type.Decimals;
Expand All @@ -46,7 +47,6 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
Expand Down Expand Up @@ -142,6 +142,7 @@
import static com.facebook.presto.spi.type.Varchars.truncateToLength;
import static com.facebook.presto.testing.DateTimeTestingUtils.sqlTimestampOf;
import static com.facebook.presto.testing.TestingConnectorSession.SESSION;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.Lists.newArrayList;
import static io.airlift.slice.Slices.utf8Slice;
Expand Down Expand Up @@ -308,11 +309,19 @@ public void testRoundTrip(Type type, List<?> readValues)
testRoundTrip(type, readValues, ImmutableList.of());
}

public void testRoundTrip(Type type, List<?> readValues, List<Map<Integer, TupleDomainFilter>> filters)
/**
 * Round-trip tests a single column (named "c") of the given type, applying each
 * supplied filter to the whole column value, one filter per test pass.
 */
public void testRoundTrip(Type type, List<?> readValues, TupleDomainFilter... filters)
throws Exception
{
// Wrap each top-level filter as a subfield filter on column "c" before delegating.
List<Map<Subfield, TupleDomainFilter>> subfieldFilters = Arrays.stream(filters)
.map(filter -> ImmutableMap.of(new Subfield("c"), filter))
.collect(toImmutableList());
testRoundTrip(type, readValues, subfieldFilters);
}

public void testRoundTrip(Type type, List<?> readValues, List<Map<Subfield, TupleDomainFilter>> filters)
throws Exception
{
List<Map<Integer, Map<Subfield, TupleDomainFilter>>> columnFilters = filters.stream().map(filter -> ImmutableMap.of(0, filter)).collect(toImmutableList());

// just the values
testRoundTripTypes(ImmutableList.of(type), ImmutableList.of(readValues), filters);
testRoundTripTypes(ImmutableList.of(type), ImmutableList.of(readValues), columnFilters);

// all nulls
if (nullTestsEnabled) {
Expand All @@ -321,7 +330,7 @@ public void testRoundTrip(Type type, List<?> readValues, List<Map<Integer, Tuple
readValues.stream()
.map(value -> null)
.collect(toList()),
filters);
columnFilters);
}

// values wrapped in struct
Expand Down Expand Up @@ -467,7 +476,7 @@ private void testRoundTripType(Type type, List<?> readValues)
testRoundTripTypes(ImmutableList.of(type), ImmutableList.of(readValues), ImmutableList.of());
}

public void testRoundTripTypes(List<Type> types, List<List<?>> readValues, List<Map<Integer, TupleDomainFilter>> filters)
public void testRoundTripTypes(List<Type> types, List<List<?>> readValues, List<Map<Integer, Map<Subfield, TupleDomainFilter>>> filters)
throws Exception
{
assertEquals(types.size(), readValues.size());
Expand Down Expand Up @@ -506,7 +515,7 @@ public void assertRoundTrip(Type type, List<?> readValues)
assertRoundTrip(type, type, readValues, readValues, true, ImmutableList.of());
}

public void assertRoundTrip(Type type, List<?> readValues, List<Map<Integer, TupleDomainFilter>> filters)
public void assertRoundTrip(Type type, List<?> readValues, List<Map<Integer, Map<Subfield, TupleDomainFilter>>> filters)
throws Exception
{
assertRoundTrip(type, type, readValues, readValues, true, filters);
Expand All @@ -518,19 +527,19 @@ public void assertRoundTrip(Type type, List<?> readValues, boolean verifyWithHiv
assertRoundTrip(type, type, readValues, readValues, verifyWithHiveReader, ImmutableList.of());
}

public void assertRoundTrip(List<Type> types, List<List<?>> readValues, List<Map<Integer, TupleDomainFilter>> filters)
public void assertRoundTrip(List<Type> types, List<List<?>> readValues, List<Map<Integer, Map<Subfield, TupleDomainFilter>>> filters)
throws Exception
{
assertRoundTrip(types, types, readValues, readValues, true, filters);
}

private void assertRoundTrip(Type writeType, Type readType, List<?> writeValues, List<?> readValues, boolean verifyWithHiveReader, List<Map<Integer, TupleDomainFilter>> filters)
private void assertRoundTrip(Type writeType, Type readType, List<?> writeValues, List<?> readValues, boolean verifyWithHiveReader, List<Map<Integer, Map<Subfield, TupleDomainFilter>>> filters)
throws Exception
{
assertRoundTrip(ImmutableList.of(writeType), ImmutableList.of(readType), ImmutableList.of(writeValues), ImmutableList.of(readValues), verifyWithHiveReader, filters);
}

private void assertRoundTrip(List<Type> writeTypes, List<Type> readTypes, List<List<?>> writeValues, List<List<?>> readValues, boolean verifyWithHiveReader, List<Map<Integer, TupleDomainFilter>> filters)
private void assertRoundTrip(List<Type> writeTypes, List<Type> readTypes, List<List<?>> writeValues, List<List<?>> readValues, boolean verifyWithHiveReader, List<Map<Integer, Map<Subfield, TupleDomainFilter>>> filters)
throws Exception
{
assertEquals(writeTypes.size(), readTypes.size());
Expand Down Expand Up @@ -585,7 +594,7 @@ private static void assertFileContentsPresto(
List<List<?>> expectedValues,
OrcEncoding orcEncoding,
OrcPredicate orcPredicate,
Optional<Map<Integer, TupleDomainFilter>> filters)
Optional<Map<Integer, Map<Subfield, TupleDomainFilter>>> filters)
throws IOException
{
try (OrcSelectiveRecordReader recordReader = createCustomOrcSelectiveRecordReader(tempFile, orcEncoding, orcPredicate, types, MAX_BATCH_SIZE, filters.orElse(ImmutableMap.of()))) {
Expand Down Expand Up @@ -617,8 +626,8 @@ private static void assertFileContentsPresto(
data.add(type.getObjectValue(SESSION, block, position));
}

for (int j = 0; j < positionCount; j++) {
assertColumnValueEquals(type, data.get(j), expectedValues.get(i).get(rowsProcessed + j));
for (int position = 0; position < positionCount; position++) {
assertColumnValueEquals(type, data.get(position), expectedValues.get(i).get(rowsProcessed + position));
}
}

Expand All @@ -639,14 +648,14 @@ private static void assertFileContentsPresto(
Format format,
boolean isHiveWriter,
boolean useSelectiveOrcReader,
List<Map<Integer, TupleDomainFilter>> filters)
List<Map<Integer, Map<Subfield, TupleDomainFilter>>> filters)
throws IOException
{
OrcPredicate orcPredicate = createOrcPredicate(types, expectedValues, format, isHiveWriter);
if (useSelectiveOrcReader) {
assertFileContentsPresto(types, tempFile, expectedValues, orcEncoding, orcPredicate, Optional.empty());

for (Map<Integer, TupleDomainFilter> columnFilters : filters) {
for (Map<Integer, Map<Subfield, TupleDomainFilter>> columnFilters : filters) {
assertFileContentsPresto(types, tempFile, filterRows(types, expectedValues, columnFilters), orcEncoding, orcPredicate, Optional.of(columnFilters));
}

Expand Down Expand Up @@ -694,7 +703,7 @@ else if (skipFirstBatch && isFirst) {
}
}

private static List<List<?>> filterRows(List<Type> types, List<List<?>> values, Map<Integer, TupleDomainFilter> columnFilters)
private static List<List<?>> filterRows(List<Type> types, List<List<?>> values, Map<Integer, Map<Subfield, TupleDomainFilter>> columnFilters)
{
List<Integer> passingRows = IntStream.range(0, values.get(0).size())
.filter(row -> testRow(types, values, row, columnFilters))
Expand All @@ -705,70 +714,88 @@ private static List<List<?>> filterRows(List<Type> types, List<List<?>> values,
.collect(toList());
}

private static boolean testRow(List<Type> types, List<List<?>> values, int row, Map<Integer, TupleDomainFilter> columnFilters)
private static boolean testRow(List<Type> types, List<List<?>> values, int row, Map<Integer, Map<Subfield, TupleDomainFilter>> columnFilters)
{
for (int column = 0; column < types.size(); column++) {
TupleDomainFilter filter = columnFilters.get(column);
if (filter == null) {
Map<Subfield, TupleDomainFilter> filters = columnFilters.get(column);

if (filters == null) {
continue;
}

Type type = types.get(column);
Object value = values.get(column).get(row);
if (filter == IS_NULL) {
if (value != null) {
for (Map.Entry<Subfield, TupleDomainFilter> entry : filters.entrySet()) {
if (!testSubfieldValue(type, value, entry.getKey(), entry.getValue())) {
return false;
}
}
else if (filter == IS_NOT_NULL) {
if (value == null) {
return false;
}
}
else if (value == null) {
if (!filter.testNull()) {
return false;
}

return true;
}

private static boolean testSubfieldValue(Type type, Object value, Subfield subfield, TupleDomainFilter filter)
{
Type nestedType = type;
Object nestedValue = value;
for (Subfield.PathElement pathElement : subfield.getPath()) {
if (nestedType instanceof ArrayType) {
assertTrue(pathElement instanceof Subfield.LongSubscript);
if (nestedValue == null) {
return filter == IS_NULL;
}
int index = toIntExact(((Subfield.LongSubscript) pathElement).getIndex());
nestedType = ((ArrayType) nestedType).getElementType();
nestedValue = ((List) nestedValue).get(index - 1);
}
else {
Type type = types.get(column);
if (type == BOOLEAN) {
if (!filter.testBoolean((Boolean) value)) {
return false;
}
}
else if (type == TINYINT || type == BIGINT || type == INTEGER || type == SMALLINT) {
if (!filter.testLong(((Number) value).longValue())) {
return false;
}
}
else if (type == DOUBLE) {
if (!filter.testDouble((double) value)) {
return false;
}
}
else if (type == DATE) {
if (!filter.testLong(((SqlDate) value).getDays())) {
return false;
}
}
else if (type == REAL) {
if (!filter.testFloat(((Number) value).floatValue())) {
return false;
}
}
else if (type == TIMESTAMP) {
if (!filter.testLong(((SqlTimestamp) value).getMillisUtc())) {
return false;
}
}

else {
fail("Unsupported type: " + type);
}
fail("Unsupported type: " + type);
}
}
return testValue(nestedType, nestedValue, filter);
}

return true;
/**
 * Evaluates a single scalar value against a tuple-domain filter, mirroring the
 * semantics the selective ORC reader applies, so expected results can be computed
 * for comparison. Fails the test (via {@code fail}) on types this helper does not cover.
 */
private static boolean testValue(Type type, Object value, TupleDomainFilter filter)
{
// Nulls are decided purely by the filter's null policy; type is irrelevant.
if (value == null) {
return filter.testNull();
}

// Non-null value: the IS_NULL sentinel rejects it outright...
if (filter == IS_NULL) {
return false;
}

// ...and the IS_NOT_NULL sentinel accepts it without inspecting the value.
if (filter == IS_NOT_NULL) {
return true;
}

// Type objects are singletons here, so identity comparison is intentional.
if (type == BOOLEAN) {
return filter.testBoolean((Boolean) value);
}

// All integral widths are widened to long for filtering.
if (type == TINYINT || type == BIGINT || type == INTEGER || type == SMALLINT) {
return filter.testLong(((Number) value).longValue());
}

if (type == REAL) {
return filter.testFloat(((Number) value).floatValue());
}

if (type == DOUBLE) {
return filter.testDouble((double) value);
}

// DATE filters on days since epoch; TIMESTAMP on milliseconds UTC.
if (type == DATE) {
return filter.testLong(((SqlDate) value).getDays());
}

if (type == TIMESTAMP) {
return filter.testLong(((SqlTimestamp) value).getMillisUtc());
}

fail("Unsupported type: " + type);
return false; // unreachable: fail() throws AssertionError
}

private static void assertColumnValueEquals(Type type, Object actual, Object expected)
Expand Down Expand Up @@ -908,7 +935,7 @@ static OrcSelectiveRecordReader createCustomOrcSelectiveRecordReader(
OrcPredicate predicate,
List<Type> types,
int initialBatchSize,
Map<Integer, TupleDomainFilter> filters)
Map<Integer, Map<Subfield, TupleDomainFilter>> filters)
throws IOException
{
OrcDataSource orcDataSource = new FileOrcDataSource(tempFile.getFile(), new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), true);
Expand All @@ -924,7 +951,7 @@ static OrcSelectiveRecordReader createCustomOrcSelectiveRecordReader(
return orcReader.createSelectiveRecordReader(
columnTypes,
IntStream.range(0, types.size()).boxed().collect(toList()),
Maps.transformValues(filters, v -> ImmutableMap.of(new Subfield("c"), v)),
filters,
ImmutableList.of(),
ImmutableMap.of(),
ImmutableMap.of(),
Expand Down
Loading

0 comments on commit 7b9b0ce

Please sign in to comment.