Skip to content

Commit

Permalink
[Gluten-core] Add struct literal support (facebookincubator#1048)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh authored Mar 3, 2023
1 parent 1a4f23b commit 6a68e8c
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.glutenproject.substrait.expression;

import com.google.protobuf.ByteString;
import io.substrait.proto.Expression;
import org.apache.spark.sql.types.StructType;

import java.io.Serializable;

/**
 * Expression node representing a struct literal whose fields are all raw binary
 * values. {@code values[i]} holds the serialized bytes for {@code type.fields()[i]}.
 */
public class BinaryStructNode implements ExpressionNode, Serializable {
  // first is index, second is value: values[i] is the binary payload of field i
  private final byte[][] values;
  // Spark schema carrying each field's data type and nullability
  private final StructType type;

  public BinaryStructNode(byte[][] values, StructType type) {
    this.values = values;
    this.type = type;
  }

  /**
   * Builds a standalone literal node for the field at {@code index}, using that
   * field's declared data type and nullability from the struct schema.
   *
   * @param index zero-based field ordinal; must be a valid index into both
   *              {@code values} and {@code type.fields()}
   * @return a literal expression node wrapping the field's bytes
   */
  public ExpressionNode getFieldLiteral(int index) {
    return ExpressionBuilder.makeLiteral(values[index], type.fields()[index].dataType(),
        type.fields()[index].nullable());
  }

  @Override
  public Expression toProtobuf() {
    Expression.Literal.Struct.Builder structBuilder = Expression.Literal.Struct.newBuilder();
    for (byte[] value : values) {
      // Use a fresh builder per field. The original reused one builder for every
      // field and for the enclosing struct literal, which only worked because
      // setStruct() clears the oneof's binary value — fragile and hard to read.
      // TODO, here we copy the binary literal, if it is long such as BloomFilter binary,
      // it will cost much time
      structBuilder.addFields(
          Expression.Literal.newBuilder().setBinary(ByteString.copyFrom(value)));
    }

    Expression.Literal.Builder literalBuilder = Expression.Literal.newBuilder();
    literalBuilder.setStruct(structBuilder.build());

    Expression.Builder builder = Expression.newBuilder();
    builder.setLiteral(literalBuilder.build());

    return builder.build();
  }
}

Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
import io.glutenproject.expression.ConverterUtils;
import io.glutenproject.substrait.type.TypeBuilder;
import io.glutenproject.substrait.type.TypeNode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.UTF8String;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/** Contains helper functions for constructing substrait relations. */
public class ExpressionBuilder {
Expand Down Expand Up @@ -106,6 +109,10 @@ public static StringListNode makeStringList(ArrayList<String> strConstants) {
return new StringListNode(strConstants);
}

/**
 * Creates a struct-literal node whose fields are all raw binary values.
 *
 * @param binary per-field serialized bytes; binary[i] corresponds to type.fields()[i]
 * @param type   Spark struct schema describing each field's data type and nullability
 * @return a new BinaryStructNode wrapping the given values and schema
 */
public static BinaryStructNode makeBinaryStruct(byte[][] binary, StructType type) {
return new BinaryStructNode(binary, type);
}

/**
 * Creates a literal node for a single binary (byte-array) constant.
 *
 * @param bytesConstant the raw bytes of the literal value
 * @return a new BinaryLiteralNode wrapping the bytes
 */
public static BinaryLiteralNode makeBinaryLiteral(byte[] bytesConstant) {
return new BinaryLiteralNode(bytesConstant);
}
Expand Down Expand Up @@ -223,6 +230,26 @@ public static ExpressionNode makeLiteral(Object obj, DataType dataType, Boolean
}
} else if (dataType instanceof NullType) {
return makeNullLiteral(TypeBuilder.makeNothing());
} else if (dataType instanceof StructType) {
StructType type = (StructType) dataType;
if (obj == null) {
List<TypeNode> typeNodes = Arrays.stream(type.fields())
.map(f -> ConverterUtils.getTypeNode(f.dataType(), f.nullable()))
.collect(Collectors.toList());
return makeNullLiteral(TypeBuilder.makeStruct(nullable, new ArrayList<>(typeNodes)));
} else {
if (Arrays.stream(type.fields()).anyMatch(f -> !(f.dataType() instanceof BinaryType))) {
throw new UnsupportedOperationException(
String.format("Type not supported in struct: %s, obj: %s, class: %s",
dataType, obj, obj.getClass().toString()));
}
InternalRow row = (InternalRow) obj;
byte[][] binarys = new byte[row.numFields()][];
for (int i = 0; i < row.numFields(); i++) {
binarys[i] = row.getBinary(i);
}
return makeBinaryStruct(binarys, type);
}
} else {
/// TODO(taiyang-li) implement Literal Node for Struct/Map/Array
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ package io.glutenproject.expression

import com.google.common.collect.Lists
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.expression.{BinaryStructNode, ExpressionBuilder, ExpressionNode}

import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.types.{IntegerType}
import org.apache.spark.sql.types.IntegerType

class GetStructFieldTransformer(
substraitExprName: String,
Expand All @@ -31,6 +32,9 @@ class GetStructFieldTransformer(

override def doTransform(args: Object): ExpressionNode = {
val childNode = childTransformer.doTransform(args)
if (childNode.isInstanceOf[BinaryStructNode]) {
return childNode.asInstanceOf[BinaryStructNode].getFieldLiteral(ordinal)
}
val ordinalNode = ExpressionBuilder.makeLiteral(ordinal, IntegerType, false)
val exprNodes = Lists.newArrayList(childNode, ordinalNode)
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ class TpcRunner(val queryResourceFolder: String, val dataPath: String) {
println(s"Executing SQL query from resource path $path...")
val sql = TpcRunner.resourceToString(path)
val prev = System.nanoTime()
if (caseId.equals("q8")) {
// because q8 fallback might_contain and offload bloomfilter
print("set bloomFilter false\n")
spark.conf.set("spark.gluten.sql.native.bloomFilter", "false")
}
val df = spark.sql(sql)
if (explain) {
df.explain(extended = true)
Expand Down

0 comments on commit 6a68e8c

Please sign in to comment.