Skip to content

Commit

Permalink
Add a query with WHERE clause to DataTypeTest
Browse files Browse the repository at this point in the history
This is useful for testing predicate pushdown.

Please note that this doesn't fail if predicate evaluation happens
on Presto side, you still need an assertion that checks for that.

Debug one column at a time if DataTypeTest fails on WHERE query

Add WHERE clause to queries in Postgres type tests
  • Loading branch information
MiguelWeezardo authored and findepi committed Jun 16, 2020
1 parent 2356edb commit d26aab9
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public void setUp()
@Test
public void testBasicTypes()
{
DataTypeTest.create()
DataTypeTest.create(true)
.addRoundTrip(booleanDataType(), true)
.addRoundTrip(booleanDataType(), false)
.addRoundTrip(bigintDataType(), 123_456_789_012L)
Expand Down Expand Up @@ -610,7 +610,7 @@ public void testDecimalUnspecifiedPrecisionWithExceedingValue()
public void testArray()
{
// basic types
DataTypeTest.create()
DataTypeTest.create(true)
.addRoundTrip(arrayDataType(booleanDataType()), asList(true, false))
.addRoundTrip(arrayDataType(bigintDataType()), asList(123_456_789_012L))
.addRoundTrip(arrayDataType(integerDataType()), asList(1, 2, 1_234_567_890))
Expand Down Expand Up @@ -853,6 +853,7 @@ private static <E> DataType<List<E>> arrayDataType(DataType<E> elementType, Stri
insertType,
new ArrayType(elementType.getPrestoResultType()),
valuesList -> "ARRAY" + valuesList.stream().map(elementType::toLiteral).collect(toList()),
valuesList -> "ARRAY" + valuesList.stream().map(elementType::toPrestoLiteral).collect(toList()),
valuesList -> valuesList == null ? null : valuesList.stream().map(elementType::toPrestoQueryResult).collect(toList()));
}

Expand Down Expand Up @@ -896,7 +897,7 @@ public void testDate()
LocalDate dateOfLocalTimeChangeBackwardAtMidnightInSomeZone = LocalDate.of(1983, 10, 1);
checkIsDoubled(someZone, dateOfLocalTimeChangeBackwardAtMidnightInSomeZone.atStartOfDay().minusMinutes(1));

DataTypeTest testCases = DataTypeTest.create()
DataTypeTest testCases = DataTypeTest.create(true)
.addRoundTrip(dateDataType(), LocalDate.of(1952, 4, 3)) // before epoch
.addRoundTrip(dateDataType(), LocalDate.of(1970, 1, 1))
.addRoundTrip(dateDataType(), LocalDate.of(1970, 2, 3))
Expand Down Expand Up @@ -984,7 +985,7 @@ private void addTimeTestIfSupported(DataTypeTest tests, boolean legacyTimestamp,
public void testTimestamp(boolean legacyTimestamp, boolean insertWithPresto, ZoneId sessionZone)
{
// using two non-JVM zones so that we don't need to worry what Postgres system zone is
DataTypeTest tests = DataTypeTest.create()
DataTypeTest tests = DataTypeTest.create(true)
.addRoundTrip(timestampDataType(), beforeEpoch)
.addRoundTrip(timestampDataType(), afterEpoch)
.addRoundTrip(timestampDataType(), timeDoubledInJvmZone)
Expand Down Expand Up @@ -1032,7 +1033,7 @@ public void testArrayTimestamp(boolean legacyTimestamp, boolean insertWithPresto
dataType = arrayDataType(timestampDataType(), "timestamp[]");
dataSetup = postgresCreateAndInsert("tpch.test_array_timestamp");
}
DataTypeTest tests = DataTypeTest.create()
DataTypeTest tests = DataTypeTest.create(true)
.addRoundTrip(dataType, asList(beforeEpoch))
.addRoundTrip(dataType, asList(afterEpoch))
.addRoundTrip(dataType, asList(timeDoubledInJvmZone))
Expand Down Expand Up @@ -1108,7 +1109,7 @@ public void testTimestampWithTimeZone(boolean insertWithPresto)
dataSetup = postgresCreateAndInsert("tpch.test_timestamp_with_time_zone");
}

DataTypeTest tests = DataTypeTest.create()
DataTypeTest tests = DataTypeTest.create(true)
.addRoundTrip(dataType, epoch.atZone(UTC))
.addRoundTrip(dataType, epoch.atZone(kathmandu))
.addRoundTrip(dataType, epoch.atZone(fixedOffsetEast))
Expand Down Expand Up @@ -1202,7 +1203,7 @@ public void testJsonb()

private DataTypeTest jsonTestCases(DataType<String> jsonDataType)
{
return DataTypeTest.create()
return DataTypeTest.create(true)
.addRoundTrip(jsonDataType, "{}")
.addRoundTrip(jsonDataType, null)
.addRoundTrip(jsonDataType, "null")
Expand Down Expand Up @@ -1245,15 +1246,15 @@ public void testUuid()

private DataTypeTest uuidTestCases(DataType<java.util.UUID> uuidDataType)
{
return DataTypeTest.create()
return DataTypeTest.create(true)
.addRoundTrip(uuidDataType, java.util.UUID.fromString("00000000-0000-0000-0000-000000000000"))
.addRoundTrip(uuidDataType, java.util.UUID.fromString("123e4567-e89b-12d3-a456-426655440000"));
}

@Test
public void testMoney()
{
DataTypeTest.create()
DataTypeTest.create(true)
.addRoundTrip(moneyDataType(), null)
.addRoundTrip(moneyDataType(), 10.)
.addRoundTrip(moneyDataType(), 10.54)
Expand Down Expand Up @@ -1283,7 +1284,7 @@ public void testDouble()

private static DataTypeTest singlePrecisionFloatingPointTests(DataType<Float> floatType)
{
return DataTypeTest.create()
return DataTypeTest.create(true)
.addRoundTrip(floatType, 3.14f)
.addRoundTrip(floatType, 3.1415927f)
.addRoundTrip(floatType, Float.NaN)
Expand All @@ -1294,7 +1295,7 @@ private static DataTypeTest singlePrecisionFloatingPointTests(DataType<Float> fl

private static DataTypeTest doublePrecisionFloatinPointTests(DataType<Double> doubleType)
{
return DataTypeTest.create()
return DataTypeTest.create(true)
.addRoundTrip(doubleType, 1.0e100d)
.addRoundTrip(doubleType, Double.NaN)
.addRoundTrip(doubleType, Double.POSITIVE_INFINITY)
Expand Down Expand Up @@ -1391,6 +1392,7 @@ public static DataType<ZonedDateTime> postgreSqlTimestampWithTimeZoneDataType()
// PostgreSQL never examines the content of a literal string before determining its type, so `TIMESTAMP '.... {zone}'` won't work.
// PostgreSQL does not store zone, only the point in time
zonedDateTime -> DateTimeFormatter.ofPattern("'TIMESTAMP WITH TIME ZONE '''yyyy-MM-dd HH:mm:ss.SSS VV''").format(zonedDateTime.withZoneSameInstant(UTC)),
DateTimeFormatter.ofPattern("'TIMESTAMP '''yyyy-MM-dd HH:mm:ss.SSS VV''")::format,
zonedDateTime -> zonedDateTime.withZoneSameInstant(ZoneId.of("UTC")));
}

Expand Down Expand Up @@ -1461,6 +1463,7 @@ private static DataType<byte[]> byteaDataType()
"bytea",
VARBINARY,
bytes -> format("bytea E'\\\\x%s'", base16().encode(bytes)),
DataType::binaryLiteral,
identity());
}

Expand All @@ -1470,6 +1473,10 @@ private static DataType<Double> moneyDataType()
"money",
VARCHAR,
String::valueOf,
amount -> {
NumberFormat numberFormat = NumberFormat.getCurrencyInstance(Locale.US);
return "'" + numberFormat.format(amount) + "'";
},
amount -> {
NumberFormat numberFormat = NumberFormat.getCurrencyInstance(Locale.US);
return numberFormat.format(amount);
Expand Down Expand Up @@ -1546,7 +1553,9 @@ private static DataType<Float> postgreSqlRealDataType()
return "'NaN'::real";
}
return format("'%sInfinity'::real", value > 0 ? "+" : "-");
});
},
realDataType()::toPrestoLiteral,
Function.identity());
}

private static DataType<Double> postgreSqlDoubleDataType()
Expand All @@ -1560,6 +1569,8 @@ private static DataType<Double> postgreSqlDoubleDataType()
return "'NaN'::double precision";
}
return format("'%sInfinity'::double precision", value > 0 ? "+" : "-");
});
},
doubleDataType()::toPrestoLiteral,
Function.identity());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class DataType<T>
private final String insertType;
private final Type prestoResultType;
private final Function<T, String> toLiteral;
private final Function<T, String> toPrestoLiteral;
private final Function<T, ?> toPrestoQueryResult;

public static DataType<Boolean> booleanDataType()
Expand Down Expand Up @@ -217,24 +218,30 @@ public static String binaryLiteral(byte[] value)

private static <T> DataType<T> dataType(String insertType, Type prestoResultType)
{
return new DataType<>(insertType, prestoResultType, Object::toString, Function.identity());
return new DataType<>(insertType, prestoResultType, Object::toString, Object::toString, Function.identity());
}

public static <T> DataType<T> dataType(String insertType, Type prestoResultType, Function<T, String> toLiteral)
{
return new DataType<>(insertType, prestoResultType, toLiteral, Function.identity());
return new DataType<>(insertType, prestoResultType, toLiteral, toLiteral, Function.identity());
}

public static <T> DataType<T> dataType(String insertType, Type prestoResultType, Function<T, String> toLiteral, Function<T, ?> toPrestoQueryResult)
{
return new DataType<>(insertType, prestoResultType, toLiteral, toPrestoQueryResult);
return new DataType<>(insertType, prestoResultType, toLiteral, toLiteral, toPrestoQueryResult);
}

private DataType(String insertType, Type prestoResultType, Function<T, String> toLiteral, Function<T, ?> toPrestoQueryResult)
public static <T> DataType<T> dataType(String insertType, Type prestoResultType, Function<T, String> toLiteral, Function<T, String> toPrestoLiteral, Function<T, ?> toPrestoQueryResult)
{
return new DataType<>(insertType, prestoResultType, toLiteral, toPrestoLiteral, toPrestoQueryResult);
}

private DataType(String insertType, Type prestoResultType, Function<T, String> toLiteral, Function<T, String> toPrestoLiteral, Function<T, ?> toPrestoQueryResult)
{
this.insertType = insertType;
this.prestoResultType = prestoResultType;
this.toLiteral = toLiteral;
this.toPrestoLiteral = toPrestoLiteral;
this.toPrestoQueryResult = toPrestoQueryResult;
}

Expand All @@ -246,6 +253,14 @@ public String toLiteral(T inputValue)
return toLiteral.apply(inputValue);
}

public String toPrestoLiteral(T inputValue)
{
if (inputValue == null) {
return "NULL";
}
return toPrestoLiteral.apply(inputValue);
}

public Object toPrestoQueryResult(T inputValue)
{
if (inputValue == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.prestosql.testing.datatype;

import io.airlift.log.Logger;
import io.prestosql.Session;
import io.prestosql.spi.type.Type;
import io.prestosql.testing.MaterializedResult;
Expand All @@ -23,24 +24,43 @@
import java.util.List;

import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format;
import static java.lang.String.join;
import static java.util.Collections.unmodifiableList;
import static java.util.stream.Collectors.toList;
import static org.testng.Assert.assertEquals;

public class DataTypeTest
{
private static final Logger log = Logger.get(DataTypeTest.class);

private final List<Input<?>> inputs = new ArrayList<>();

private DataTypeTest() {}
private boolean runSelectWithWhere;

private DataTypeTest(boolean runSelectWithWhere)
{
this.runSelectWithWhere = runSelectWithWhere;
}

public static DataTypeTest create()
{
return new DataTypeTest();
return new DataTypeTest(false);
}

public static DataTypeTest create(boolean runSelectWithWhere)
{
return new DataTypeTest(runSelectWithWhere);
}

public <T> DataTypeTest addRoundTrip(DataType<T> dataType, T value)
{
inputs.add(new Input<>(dataType, value));
return addRoundTrip(dataType, value, true);
}

public <T> DataTypeTest addRoundTrip(DataType<T> dataType, T value, boolean useInWhereClause)
{
inputs.add(new Input<>(dataType, value, useInWhereClause));
return this;
}

Expand All @@ -55,27 +75,80 @@ public void execute(QueryRunner prestoExecutor, Session session, DataSetup dataS
List<Object> expectedResults = inputs.stream().map(Input::toPrestoQueryResult).collect(toList());
try (TestTable testTable = dataSetup.setupTestTable(unmodifiableList(inputs))) {
MaterializedResult materializedRows = prestoExecutor.execute(session, "SELECT * from " + testTable.getName());
assertEquals(materializedRows.getTypes(), expectedTypes);
List<Object> actualResults = getOnlyElement(materializedRows).getFields();
assertEquals(actualResults.size(), expectedResults.size(), "lists don't have the same size");
for (int i = 0; i < expectedResults.size(); i++) {
assertEquals(actualResults.get(i), expectedResults.get(i), "Element " + i);
checkResults(expectedTypes, expectedResults, materializedRows);
if (runSelectWithWhere) {
queryWithWhere(prestoExecutor, session, expectedTypes, expectedResults, testTable);
}
}
}

private void queryWithWhere(QueryRunner prestoExecutor, Session session, List<Type> expectedTypes, List<Object> expectedResults, TestTable testTable)
{
String prestoQuery = buildPrestoQueryWithWhereClauses(testTable);
try {
MaterializedResult filteredRows = prestoExecutor.execute(session, prestoQuery);
checkResults(expectedTypes, expectedResults, filteredRows);
}
catch (RuntimeException e) {
log.error("Exception caught during query with merged WHERE clause, querying one column at a time", e);
debugTypes(prestoExecutor, session, expectedTypes, expectedResults, testTable);
}
}

private void debugTypes(QueryRunner prestoExecutor, Session session, List<Type> expectedTypes, List<Object> expectedResults, TestTable testTable)
{
for (int i = 0; i < inputs.size(); i++) {
Input<?> input = inputs.get(i);
if (input.isUseInWhereClause()) {
String debugQuery = String.format("SELECT col_%d FROM %s WHERE col_%d IS NOT DISTINCT FROM %s", i, testTable.getName(), i, input.toPrestoLiteral());
log.info("Querying input: %d (expected type: %s, expectedResult: %s) using: %s", i, expectedTypes.get(i), expectedResults.get(i), debugQuery);
MaterializedResult debugRows = prestoExecutor.execute(session, debugQuery);
checkResults(expectedTypes.subList(i, i + 1), expectedResults.subList(i, i + 1), debugRows);
}
}
}

private String buildPrestoQueryWithWhereClauses(TestTable testTable)
{
List<String> predicates = new ArrayList<>();
for (int i = 0; i < inputs.size(); i++) {
Input<?> input = inputs.get(i);
if (input.isUseInWhereClause()) {
predicates.add(format("col_%d IS NOT DISTINCT FROM %s", i, input.toPrestoLiteral()));
}
}
return "SELECT * FROM " + testTable.getName() + " WHERE " + join(" AND ", predicates);
}

private void checkResults(List<Type> expectedTypes, List<Object> expectedResults, MaterializedResult materializedRows)
{
assertEquals(materializedRows.getTypes(), expectedTypes);
List<Object> actualResults = getOnlyElement(materializedRows).getFields();
assertEquals(actualResults.size(), expectedResults.size(), "lists don't have the same size");
for (int i = 0; i < expectedResults.size(); i++) {
assertEquals(actualResults.get(i), expectedResults.get(i), "Element " + i);
}
}

public static class Input<T>
{
private final DataType<T> dataType;
private final T value;
private final boolean useInWhereClause;

public Input(DataType<T> dataType, T value)
public Input(DataType<T> dataType, T value, boolean useInWhereClause)
{
this.dataType = dataType;
this.value = value;
this.useInWhereClause = useInWhereClause;
}

public boolean isUseInWhereClause()
{
return useInWhereClause;
}

String getInsertType()
public String getInsertType()
{
return dataType.getInsertType();
}
Expand All @@ -90,9 +163,14 @@ Object toPrestoQueryResult()
return dataType.toPrestoQueryResult(value);
}

String toLiteral()
public String toLiteral()
{
return dataType.toLiteral(value);
}

public String toPrestoLiteral()
{
return dataType.toPrestoLiteral(value);
}
}
}

0 comments on commit d26aab9

Please sign in to comment.