Skip to content

Commit

Permalink
Check Parquet schema mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenxiao authored and arhimondr committed Aug 30, 2019
1 parent d960aef commit 5d18a1c
Show file tree
Hide file tree
Showing 3 changed files with 324 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public ParquetPageSource(
typesBuilder.add(type);
hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex();

if (getParquetType(column, fileSchema, useParquetColumnNames) == null) {
if (getParquetType(type, fileSchema, useParquetColumnNames, column.getName(), column.getHiveColumnIndex(), column.getHiveType()) == null) {
constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_VECTOR_LENGTH);
fieldsBuilder.add(Optional.empty());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.hive.HdfsEnvironment;
import com.facebook.presto.hive.HiveBatchPageSourceFactory;
import com.facebook.presto.hive.HiveColumnHandle;
import com.facebook.presto.hive.HiveType;
import com.facebook.presto.memory.context.AggregatedMemoryContext;
import com.facebook.presto.parquet.ParquetCorruptionException;
import com.facebook.presto.parquet.ParquetDataSource;
Expand All @@ -29,6 +30,8 @@
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.predicate.Domain;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -42,7 +45,9 @@
import org.apache.parquet.hadoop.metadata.FileMetaData;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.joda.time.DateTimeZone;

import javax.inject.Inject;
Expand All @@ -61,6 +66,7 @@
import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA;
import static com.facebook.presto.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT;
import static com.facebook.presto.hive.HiveErrorCode.HIVE_MISSING_DATA;
import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_SCHEMA_MISMATCH;
import static com.facebook.presto.hive.HiveSessionProperties.isFailOnCorruptedParquetStatistics;
import static com.facebook.presto.hive.HiveSessionProperties.isUseParquetColumnNames;
import static com.facebook.presto.hive.HiveUtil.getDeserializerClassName;
Expand All @@ -71,11 +77,34 @@
import static com.facebook.presto.parquet.ParquetTypeUtils.getParquetTypeByName;
import static com.facebook.presto.parquet.predicate.PredicateUtils.buildPredicate;
import static com.facebook.presto.parquet.predicate.PredicateUtils.predicateMatches;
import static com.facebook.presto.spi.type.StandardTypes.ARRAY;
import static com.facebook.presto.spi.type.StandardTypes.BIGINT;
import static com.facebook.presto.spi.type.StandardTypes.CHAR;
import static com.facebook.presto.spi.type.StandardTypes.DATE;
import static com.facebook.presto.spi.type.StandardTypes.DECIMAL;
import static com.facebook.presto.spi.type.StandardTypes.INTEGER;
import static com.facebook.presto.spi.type.StandardTypes.MAP;
import static com.facebook.presto.spi.type.StandardTypes.REAL;
import static com.facebook.presto.spi.type.StandardTypes.ROW;
import static com.facebook.presto.spi.type.StandardTypes.SMALLINT;
import static com.facebook.presto.spi.type.StandardTypes.TIMESTAMP;
import static com.facebook.presto.spi.type.StandardTypes.TINYINT;
import static com.facebook.presto.spi.type.StandardTypes.VARBINARY;
import static com.facebook.presto.spi.type.StandardTypes.VARCHAR;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.nullToEmpty;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.PRIMITIVE;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT96;

public class ParquetPageSourceFactory
implements HiveBatchPageSourceFactory
Expand Down Expand Up @@ -160,7 +189,7 @@ public static ParquetPageSource createParquetPageSource(

List<org.apache.parquet.schema.Type> fields = columns.stream()
.filter(column -> column.getColumnType() == REGULAR)
.map(column -> getParquetType(column, fileSchema, useParquetColumnNames))
.map(column -> getParquetType(typeManager.getType(column.getTypeSignature()), fileSchema, useParquetColumnNames, column.getName(), column.getHiveColumnIndex(), column.getHiveType()))
.filter(Objects::nonNull)
.collect(toList());

Expand Down Expand Up @@ -249,15 +278,116 @@ public static TupleDomain<ColumnDescriptor> getParquetTupleDomain(Map<List<Strin
return TupleDomain.withColumnDomains(predicate.build());
}

public static org.apache.parquet.schema.Type getParquetType(HiveColumnHandle column, MessageType messageType, boolean useParquetColumnNames)
public static org.apache.parquet.schema.Type getParquetType(Type prestoType, MessageType messageType, boolean useParquetColumnNames, String columnName, int columnHiveIndex, HiveType hiveType)
{
org.apache.parquet.schema.Type type = null;
if (useParquetColumnNames) {
return getParquetTypeByName(column.getName(), messageType);
type = getParquetTypeByName(columnName, messageType);
}
else if (columnHiveIndex < messageType.getFieldCount()) {
type = messageType.getType(columnHiveIndex);
}

if (type == null) {
return null;
}

if (!checkSchemaMatch(type, prestoType)) {
String parquetTypeName;
if (type.isPrimitive()) {
parquetTypeName = type.asPrimitiveType().getPrimitiveTypeName().toString();
}
else {
GroupType group = type.asGroupType();
StringBuilder builder = new StringBuilder();
group.writeToStringBuilder(builder, "");
parquetTypeName = builder.toString();
}
throw new PrestoException(HIVE_PARTITION_SCHEMA_MISMATCH, format("The column %s is declared as type %s, but the Parquet file declares the column as type %s",
columnName,
hiveType,
parquetTypeName));
}
return type;
}

private static boolean checkSchemaMatch(org.apache.parquet.schema.Type parquetType, Type type)
{
String prestoType = type.getTypeSignature().getBase();
if (parquetType instanceof GroupType) {
GroupType groupType = parquetType.asGroupType();
switch (prestoType) {
case ROW:
if (groupType.getFields().size() == type.getTypeParameters().size()) {
for (int i = 0; i < groupType.getFields().size(); i++) {
if (!checkSchemaMatch(groupType.getFields().get(i), type.getTypeParameters().get(i))) {
return false;
}
}
return true;
}
return false;
case MAP:
if (groupType.getFields().size() != 1) {
return false;
}
org.apache.parquet.schema.Type mapKeyType = groupType.getFields().get(0);
if (mapKeyType instanceof GroupType) {
GroupType mapGroupType = mapKeyType.asGroupType();
return mapGroupType.getFields().size() == 2 &&
checkSchemaMatch(mapGroupType.getFields().get(0), type.getTypeParameters().get(0)) &&
checkSchemaMatch(mapGroupType.getFields().get(1), type.getTypeParameters().get(1));
}
return false;
case ARRAY:
/* array has a standard 3-level structure with middle level repeated group with a single field:
* optional group my_list (LIST) {
* repeated group element {
* required type field;
* };
* }
* Backward-compatibility support for 2-level arrays:
* optional group my_list (LIST) {
* repeated type field;
* }
* field itself could be primitive or group
*/
if (groupType.getFields().size() != 1) {
return false;
}
org.apache.parquet.schema.Type bagType = groupType.getFields().get(0);
if (bagType.isPrimitive()) {
return checkSchemaMatch(bagType.asPrimitiveType(), type.getTypeParameters().get(0));
}
GroupType bagGroupType = bagType.asGroupType();
return checkSchemaMatch(bagGroupType, type.getTypeParameters().get(0)) ||
(bagGroupType.getFields().size() == 1 && checkSchemaMatch(bagGroupType.getFields().get(0), type.getTypeParameters().get(0)));
default:
return false;
}
}

if (column.getHiveColumnIndex() < messageType.getFieldCount()) {
return messageType.getType(column.getHiveColumnIndex());
checkArgument(parquetType.isPrimitive(), "Unexpected parquet type for column: %s " + parquetType.getName());
PrimitiveTypeName parquetTypeName = parquetType.asPrimitiveType().getPrimitiveTypeName();
switch (parquetTypeName) {
case INT64:
return prestoType.equals(BIGINT) || prestoType.equals(DECIMAL);
case INT32:
return prestoType.equals(INTEGER) || prestoType.equals(SMALLINT) || prestoType.equals(DATE) || prestoType.equals(DECIMAL) || prestoType.equals(TINYINT);
case BOOLEAN:
return prestoType.equals(StandardTypes.BOOLEAN);
case FLOAT:
return prestoType.equals(REAL);
case DOUBLE:
return prestoType.equals(StandardTypes.DOUBLE);
case BINARY:
return prestoType.equals(VARBINARY) || prestoType.equals(VARCHAR) || prestoType.startsWith(CHAR) || prestoType.equals(DECIMAL);
case INT96:
return prestoType.equals(TIMESTAMP);
case FIXED_LEN_BYTE_ARRAY:
return prestoType.equals(DECIMAL);
default:
throw new IllegalArgumentException("Unexpected parquet type name: " + parquetTypeName);
}
return null;
}
}
Loading

0 comments on commit 5d18a1c

Please sign in to comment.