Skip to content

Commit

Permalink
apacheGH-40055: [Java][Docs] Simplify use of Filter and Expression in…
Browse files Browse the repository at this point in the history
…to Dataset Substrait (apache#40056)

### Rationale for this change

Simplify creation of SQL Expression Filter and Projections into Arrow Java Dataset module using new [Substrait Feature for SQL Expressions](https://github.com/substrait-io/substrait-java/releases/tag/v0.26.0).

### What changes are included in this PR?

Update Apache Arrow Java Dataset Substrait documentation

### Are these changes tested?

Yes

### Are there any user-facing changes?

No
* Closes: apache#40055

Authored-by: david dali susanibar arce <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
davisusanibar authored Feb 15, 2024
1 parent 621f707 commit a03d957
Showing 1 changed file with 42 additions and 291 deletions.
333 changes: 42 additions & 291 deletions docs/source/java/substrait.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,31 +113,19 @@ This requires the substrait-java library.
This Java program:

- Loads a Parquet file containing the "nation" table from the TPC-H benchmark.
- Applies a filter:
- ``N_NATIONKEY > 18``
- Projects two new columns:
- ``N_NAME || ' - ' || N_COMMENT``
- ``N_REGIONKEY + 10``
- Applies a filter: ``N_NATIONKEY > 18``
- ``N_NAME || ' - ' || N_COMMENT``



.. code-block:: Java
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.Expression;
import io.substrait.proto.ExpressionReference;
import com.google.common.collect.ImmutableList;
import io.substrait.isthmus.SqlExpressionToSubstrait;
import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import io.substrait.type.proto.TypeProtoConverter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import org.apache.arrow.dataset.file.FileFormat;
import org.apache.arrow.dataset.file.FileSystemDatasetFactory;
import org.apache.arrow.dataset.jni.NativeMemoryPool;
Expand All @@ -148,297 +136,60 @@ This Java program:
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.calcite.sql.parser.SqlParseException;
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.Optional;
public class ClientSubstraitExtendedExpressionsCookbook {
public static void main(String[] args) throws Exception {
// project and filter dataset using extended expression definition - 03 Expressions:
// Expression 01 - CONCAT: N_NAME || ' - ' || N_COMMENT = col 1 || ' - ' || col 3
// Expression 02 - ADD: N_REGIONKEY + 10 = col 1 + 10
// Expression 03 - FILTER: N_NATIONKEY > 18 = col 3 > 18
public static void main(String[] args) throws SqlParseException {
projectAndFilterDataset();
}
public static void projectAndFilterDataset() {
private static void projectAndFilterDataset() throws SqlParseException {
String uri = "file:///Users/data/tpch_parquet/nation.parquet";
ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768)
.columns(Optional.empty())
.substraitFilter(getSubstraitExpressionFilter())
.substraitProjection(getSubstraitExpressionProjection())
.build();
try (
BufferAllocator allocator = new RootAllocator();
DatasetFactory datasetFactory = new FileSystemDatasetFactory(
allocator, NativeMemoryPool.getDefault(),
FileFormat.PARQUET, uri);
Dataset dataset = datasetFactory.finish();
Scanner scanner = dataset.newScan(options);
ArrowReader reader = scanner.scanBatches()
) {
ScanOptions options =
new ScanOptions.Builder(/*batchSize*/ 32768)
.columns(Optional.empty())
.substraitFilter(getByteBuffer(new String[]{"N_NATIONKEY > 18"}))
.substraitProjection(getByteBuffer(new String[]{"N_REGIONKEY + 10",
"N_NAME || CAST(' - ' as VARCHAR) || N_COMMENT"}))
.build();
try (BufferAllocator allocator = new RootAllocator();
DatasetFactory datasetFactory =
new FileSystemDatasetFactory(
allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, uri);
Dataset dataset = datasetFactory.finish();
Scanner scanner = dataset.newScan(options);
ArrowReader reader = scanner.scanBatches()) {
while (reader.loadNextBatch()) {
System.out.println(
reader.getVectorSchemaRoot().contentToTSVString());
System.out.println(reader.getVectorSchemaRoot().contentToTSVString());
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static ByteBuffer getSubstraitExpressionProjection() {
// Expression: N_REGIONKEY + 10 = col 3 + 10
Expression.Builder selectionBuilderProjectOne = Expression.newBuilder().
setSelection(
Expression.FieldReference.newBuilder().
setDirectReference(
Expression.ReferenceSegment.newBuilder().
setStructField(
Expression.ReferenceSegment.StructField.newBuilder().setField(
2)
)
)
);
Expression.Builder literalBuilderProjectOne = Expression.newBuilder()
.setLiteral(
Expression.Literal.newBuilder().setI32(10)
);
io.substrait.proto.Type outputProjectOne = TypeCreator.NULLABLE.I32.accept(
new TypeProtoConverter(new ExtensionCollector()));
Expression.Builder expressionBuilderProjectOne = Expression.
newBuilder().
setScalarFunction(
Expression.
ScalarFunction.
newBuilder().
setFunctionReference(0).
setOutputType(outputProjectOne).
addArguments(
0,
FunctionArgument.newBuilder().setValue(
selectionBuilderProjectOne)
).
addArguments(
1,
FunctionArgument.newBuilder().setValue(
literalBuilderProjectOne)
)
);
ExpressionReference.Builder expressionReferenceBuilderProjectOne = ExpressionReference.newBuilder().
setExpression(expressionBuilderProjectOne)
.addOutputNames("ADD_TEN_TO_COLUMN_N_REGIONKEY");
// Expression: name || name = N_NAME || "-" || N_COMMENT = col 1 || col 3
Expression.Builder selectionBuilderProjectTwo = Expression.newBuilder().
setSelection(
Expression.FieldReference.newBuilder().
setDirectReference(
Expression.ReferenceSegment.newBuilder().
setStructField(
Expression.ReferenceSegment.StructField.newBuilder().setField(
1)
)
)
);
Expression.Builder selectionBuilderProjectTwoConcatLiteral = Expression.newBuilder()
.setLiteral(
Expression.Literal.newBuilder().setString(" - ")
);
Expression.Builder selectionBuilderProjectOneToConcat = Expression.newBuilder().
setSelection(
Expression.FieldReference.newBuilder().
setDirectReference(
Expression.ReferenceSegment.newBuilder().
setStructField(
Expression.ReferenceSegment.StructField.newBuilder().setField(
3)
)
)
);
io.substrait.proto.Type outputProjectTwo = TypeCreator.NULLABLE.STRING.accept(
new TypeProtoConverter(new ExtensionCollector()));
Expression.Builder expressionBuilderProjectTwo = Expression.
newBuilder().
setScalarFunction(
Expression.
ScalarFunction.
newBuilder().
setFunctionReference(1).
setOutputType(outputProjectTwo).
addArguments(
0,
FunctionArgument.newBuilder().setValue(
selectionBuilderProjectTwo)
).
addArguments(
1,
FunctionArgument.newBuilder().setValue(
selectionBuilderProjectTwoConcatLiteral)
).
addArguments(
2,
FunctionArgument.newBuilder().setValue(
selectionBuilderProjectOneToConcat)
)
);
ExpressionReference.Builder expressionReferenceBuilderProjectTwo = ExpressionReference.newBuilder().
setExpression(expressionBuilderProjectTwo)
.addOutputNames("CONCAT_COLUMNS_N_NAME_AND_N_COMMENT");
List<String> columnNames = Arrays.asList("N_NATIONKEY", "N_NAME",
"N_REGIONKEY", "N_COMMENT");
List<Type> dataTypes = Arrays.asList(
TypeCreator.NULLABLE.I32,
TypeCreator.NULLABLE.STRING,
TypeCreator.NULLABLE.I32,
TypeCreator.NULLABLE.STRING
);
NamedStruct of = NamedStruct.of(
columnNames,
Type.Struct.builder().fields(dataTypes).nullable(false).build()
);
// Extensions URI
HashMap<String, SimpleExtensionURI> extensionUris = new HashMap<>();
extensionUris.put(
"key-001",
SimpleExtensionURI.newBuilder()
.setExtensionUriAnchor(1)
.setUri("/functions_arithmetic.yaml")
.build()
);
// Extensions
ArrayList<SimpleExtensionDeclaration> extensions = new ArrayList<>();
SimpleExtensionDeclaration extensionFunctionAdd = SimpleExtensionDeclaration.newBuilder()
.setExtensionFunction(
SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
.setFunctionAnchor(0)
.setName("add:i32_i32")
.setExtensionUriReference(1))
.build();
SimpleExtensionDeclaration extensionFunctionGreaterThan = SimpleExtensionDeclaration.newBuilder()
.setExtensionFunction(
SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
.setFunctionAnchor(1)
.setName("concat:vchar")
.setExtensionUriReference(2))
.build();
extensions.add(extensionFunctionAdd);
extensions.add(extensionFunctionGreaterThan);
// Extended Expression
ExtendedExpression.Builder extendedExpressionBuilder =
ExtendedExpression.newBuilder().
addReferredExpr(0,
expressionReferenceBuilderProjectOne).
addReferredExpr(1,
expressionReferenceBuilderProjectTwo).
setBaseSchema(of.toProto(new TypeProtoConverter(
new ExtensionCollector())));
extendedExpressionBuilder.addAllExtensionUris(extensionUris.values());
extendedExpressionBuilder.addAllExtensions(extensions);
ExtendedExpression extendedExpression = extendedExpressionBuilder.build();
byte[] extendedExpressions = Base64.getDecoder().decode(
Base64.getEncoder().encodeToString(
extendedExpression.toByteArray()));
ByteBuffer substraitExpressionProjection = ByteBuffer.allocateDirect(
extendedExpressions.length);
substraitExpressionProjection.put(extendedExpressions);
return substraitExpressionProjection;
}
private static ByteBuffer getSubstraitExpressionFilter() {
// Expression: Filter: N_NATIONKEY > 18 = col 1 > 18
Expression.Builder selectionBuilderFilterOne = Expression.newBuilder().
setSelection(
Expression.FieldReference.newBuilder().
setDirectReference(
Expression.ReferenceSegment.newBuilder().
setStructField(
Expression.ReferenceSegment.StructField.newBuilder().setField(
0)
)
)
);
Expression.Builder literalBuilderFilterOne = Expression.newBuilder()
.setLiteral(
Expression.Literal.newBuilder().setI32(18)
);
io.substrait.proto.Type outputFilterOne = TypeCreator.NULLABLE.BOOLEAN.accept(
new TypeProtoConverter(new ExtensionCollector()));
Expression.Builder expressionBuilderFilterOne = Expression.
newBuilder().
setScalarFunction(
Expression.
ScalarFunction.
newBuilder().
setFunctionReference(1).
setOutputType(outputFilterOne).
addArguments(
0,
FunctionArgument.newBuilder().setValue(
selectionBuilderFilterOne)
).
addArguments(
1,
FunctionArgument.newBuilder().setValue(
literalBuilderFilterOne)
)
);
ExpressionReference.Builder expressionReferenceBuilderFilterOne = ExpressionReference.newBuilder().
setExpression(expressionBuilderFilterOne)
.addOutputNames("COLUMN_N_NATIONKEY_GREATER_THAN_18");
List<String> columnNames = Arrays.asList("N_NATIONKEY", "N_NAME",
"N_REGIONKEY", "N_COMMENT");
List<Type> dataTypes = Arrays.asList(
TypeCreator.NULLABLE.I32,
TypeCreator.NULLABLE.STRING,
TypeCreator.NULLABLE.I32,
TypeCreator.NULLABLE.STRING
);
NamedStruct of = NamedStruct.of(
columnNames,
Type.Struct.builder().fields(dataTypes).nullable(false).build()
);
// Extensions URI
HashMap<String, SimpleExtensionURI> extensionUris = new HashMap<>();
extensionUris.put(
"key-001",
SimpleExtensionURI.newBuilder()
.setExtensionUriAnchor(1)
.setUri("/functions_comparison.yaml")
.build()
);
// Extensions
ArrayList<SimpleExtensionDeclaration> extensions = new ArrayList<>();
SimpleExtensionDeclaration extensionFunctionLowerThan = SimpleExtensionDeclaration.newBuilder()
.setExtensionFunction(
SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
.setFunctionAnchor(1)
.setName("gt:any_any")
.setExtensionUriReference(1))
.build();
extensions.add(extensionFunctionLowerThan);
// Extended Expression
ExtendedExpression.Builder extendedExpressionBuilder =
ExtendedExpression.newBuilder().
addReferredExpr(0,
expressionReferenceBuilderFilterOne).
setBaseSchema(of.toProto(new TypeProtoConverter(
new ExtensionCollector())));
extendedExpressionBuilder.addAllExtensionUris(extensionUris.values());
extendedExpressionBuilder.addAllExtensions(extensions);
ExtendedExpression extendedExpression = extendedExpressionBuilder.build();
byte[] extendedExpressions = Base64.getDecoder().decode(
Base64.getEncoder().encodeToString(
extendedExpression.toByteArray()));
ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect(
extendedExpressions.length);
substraitExpressionFilter.put(extendedExpressions);
return substraitExpressionFilter;
private static ByteBuffer getByteBuffer(String[] sqlExpression) throws SqlParseException {
String schema =
"CREATE TABLE NATION (N_NATIONKEY INT NOT NULL, N_NAME VARCHAR, "
+ "N_REGIONKEY INT NOT NULL, N_COMMENT VARCHAR)";
SqlExpressionToSubstrait expressionToSubstrait = new SqlExpressionToSubstrait();
ExtendedExpression expression =
expressionToSubstrait.convert(sqlExpression, ImmutableList.of(schema));
byte[] expressionToByte =
Base64.getDecoder().decode(Base64.getEncoder().encodeToString(expression.toByteArray()));
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(expressionToByte.length);
byteBuffer.put(expressionToByte);
return byteBuffer;
}
}
.. code-block:: text
ADD_TEN_TO_COLUMN_N_REGIONKEY CONCAT_COLUMNS_N_NAME_AND_N_COMMENT
column-1 column-2
13 ROMANIA - ular asymptotes are about the furious multipliers. express dependencies nag above the ironically ironic account
14 SAUDI ARABIA - ts. silent requests haggle. closely express packages sleep across the blithely
12 VIETNAM - hely enticingly express accounts. even, final
Expand Down

0 comments on commit a03d957

Please sign in to comment.