Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Java bindings for string literal support in AST #13072

Merged
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
982af8a
string scalar support in AST - proof of concept
karthikeyann Mar 30, 2023
0a9eb86
Add cudf::ast::generic_scalar_device_view
karthikeyann Apr 4, 2023
50ee55d
remove filter by range example from test code
karthikeyann Apr 4, 2023
9735d51
cleanup docs
karthikeyann Apr 4, 2023
8653e61
Merge branch 'branch-23.06' of github.com:rapidsai/cudf into fea-stri…
karthikeyann Apr 4, 2023
1037b40
add jni bindings for string literal in AST
karthikeyann Apr 5, 2023
920adad
add jni string literal and column comparison tests
karthikeyann Apr 5, 2023
35fb4bd
reduce benchmark runtime by skipping unrequired combinations
karthikeyann Apr 16, 2023
a2c2004
Optimize List and Struct joining methods
karthikeyann Apr 16, 2023
cb24134
update default rows_per_chunk in cython
karthikeyann Apr 16, 2023
d43cad3
cleanup comments
karthikeyann Apr 18, 2023
3d9acc9
Merge branch 'branch-23.06' into enh-json_writer_opt
karthikeyann Apr 18, 2023
d1d6ccb
default value for append_colon
karthikeyann Apr 18, 2023
9ec93b9
Merge branch 'enh-json_writer_opt' of github.com:karthikeyann/cudf in…
karthikeyann Apr 18, 2023
f8bea99
add null string literal test cases
karthikeyann Apr 18, 2023
1babddc
Merge branch 'branch-23.06' into fea-jni-string_scalar_ast_compare
karthikeyann Apr 20, 2023
c2b7c63
fix merge mistake
karthikeyann Apr 20, 2023
ed4c73d
Merge branch 'branch-23.06' into fea-jni-string_scalar_ast_compare
karthikeyann Apr 21, 2023
a39f085
Merge branch 'branch-23.06' into fea-jni-string_scalar_ast_compare
karthikeyann Apr 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ast/Literal.java
karthikeyann marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -200,6 +200,18 @@ public static Literal ofDurationFromLong(DType type, Long value) {
return ofDurationFromLong(type, value.longValue());
}

/** Construct a string literal with the specified value or null. */
public static Literal ofString(String value) {
if (value == null) {
return ofNull(DType.STRING);
}
byte[] stringBytes = value.getBytes();
karthikeyann marked this conversation as resolved.
Show resolved Hide resolved
byte[] serializedValue = new byte[stringBytes.length + Integer.BYTES];
ByteBuffer.wrap(serializedValue).order(ByteOrder.nativeOrder()).putInt(stringBytes.length);
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
System.arraycopy(stringBytes, 0, serializedValue, Integer.BYTES, stringBytes.length);
return new Literal(DType.STRING, serializedValue);
}

Literal(DType type, byte[] serializedValue) {
this.type = type;
this.serializedValue = serializedValue;
Expand Down
45 changes: 37 additions & 8 deletions java/src/main/native/src/CompiledExpression.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
*/

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

Expand Down Expand Up @@ -56,12 +57,20 @@ class jni_serialized_ast {

/**
 * Read a value of type T from the serialized AST data buffer and advance past it.
 *
 * A std::string is decoded as a cudf::size_type byte count followed by that many
 * raw bytes. Any other T is copied verbatim from the next sizeof(T) bytes.
 * check_for_eof() is called with the number of bytes about to be consumed before
 * they are read.
 *
 * @throws std::runtime_error if the decoded string length is negative
 */
template <typename T> T read() {
  if constexpr (std::is_same_v<T, std::string>) {
    auto const size = read<cudf::size_type>();
    // A negative length can only come from corrupt input; reject it here rather
    // than letting it be converted to a huge unsigned byte count below.
    if (size < 0) {
      throw std::runtime_error("Invalid string length in serialized data");
    }
    check_for_eof(size);
    // non-const so the return statement can move it if copy elision does not apply
    std::string result(reinterpret_cast<char const *>(data_ptr), size);
    data_ptr += size;
    return result;
  } else {
    check_for_eof(sizeof(T));
    // use memcpy since data may be misaligned
    T result;
    memcpy(reinterpret_cast<jbyte *>(&result), data_ptr, sizeof(T));
    data_ptr += sizeof(T);
    return result;
  }
}

/** Decode a libcudf data type from the serialized AST data buffer */
Expand Down Expand Up @@ -254,9 +263,29 @@ struct make_literal {
std::move(scalar_ptr));
}

/** Construct an AST literal from a string value */
template <typename T, std::enable_if_t<std::is_same_v<T, cudf::string_view>> * = nullptr>
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
cudf::ast::literal &operator()(cudf::data_type dtype, bool is_valid,
cudf::jni::ast::compiled_expr &compiled_expr,
jni_serialized_ast &jni_ast) {
std::unique_ptr<cudf::scalar> scalar_ptr = [&]() {
if (is_valid) {
std::string val = jni_ast.read<std::string>();
return std::make_unique<cudf::string_scalar>(val, is_valid);
} else {
return std::make_unique<cudf::string_scalar>(rmm::device_buffer{}, is_valid);
}
}();

auto &str_scalar = static_cast<cudf::string_scalar &>(*scalar_ptr);
return compiled_expr.add_literal(std::make_unique<cudf::ast::literal>(str_scalar),
std::move(scalar_ptr));
}

/** Default functor implementation to catch type dispatch errors */
template <typename T, std::enable_if_t<!cudf::is_numeric<T>() && !cudf::is_timestamp<T>() &&
!cudf::is_duration<T>()> * = nullptr>
!cudf::is_duration<T>() &&
!std::is_same_v<T, cudf::string_view>> * = nullptr>
cudf::ast::literal &operator()(cudf::data_type dtype, bool is_valid,
cudf::jni::ast::compiled_expr &compiled_expr,
jni_serialized_ast &jni_ast) {
Expand Down
karthikeyann marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -476,6 +476,69 @@ void testBinaryComparisonOperationTransform(BinaryOperator op, Integer[] in1, In
}
}

/**
 * Parameter source for testStringLiteralComparison: one string column compared
 * against a single string literal, first with a non-null literal and then with
 * a null one.
 */
private static Stream<Arguments> createStringLiteralComparisonParams() {
  final String[] column = {"a", "bb", null, "ccc", "dddd"};
  final String literal = "ccc";
  return Stream.<Arguments>builder()
      // non-null literal; nulls compare as equal by default
      .add(Arguments.of(BinaryOperator.NULL_EQUAL, column, literal,
          Arrays.asList(false, false, false, true, false)))
      .add(Arguments.of(BinaryOperator.NOT_EQUAL, column, literal,
          mapArray(column, (s) -> !s.equals(literal))))
      .add(Arguments.of(BinaryOperator.LESS, column, literal,
          mapArray(column, (s) -> s.compareTo(literal) < 0)))
      .add(Arguments.of(BinaryOperator.GREATER, column, literal,
          mapArray(column, (s) -> s.compareTo(literal) > 0)))
      .add(Arguments.of(BinaryOperator.LESS_EQUAL, column, literal,
          mapArray(column, (s) -> s.compareTo(literal) <= 0)))
      .add(Arguments.of(BinaryOperator.GREATER_EQUAL, column, literal,
          mapArray(column, (s) -> s.compareTo(literal) >= 0)))
      // null literal
      .add(Arguments.of(BinaryOperator.NULL_EQUAL, column, null,
          Arrays.asList(false, false, true, false, false)))
      .add(Arguments.of(BinaryOperator.NOT_EQUAL, column, null,
          Arrays.asList(null, null, null, null, null)))
      .add(Arguments.of(BinaryOperator.LESS, column, null,
          Arrays.asList(null, null, null, null, null)))
      .build();
}

@ParameterizedTest
@MethodSource("createStringLiteralComparisonParams")
void testStringLiteralComparison(BinaryOperator op, String[] in1, String in2,
List<Boolean> expectedValues) {
Literal lit = Literal.ofString(in2);
jlowe marked this conversation as resolved.
Show resolved Hide resolved
BinaryOperation expr = new BinaryOperation(op,
new ColumnReference(0),
lit);
try (Table t = new Table.TestBuilder().column(in1).build();
CompiledExpression compiledExpr = expr.compile();
ColumnVector actual = compiledExpr.computeColumn(t);
ColumnVector expected = ColumnVector.fromBoxedBooleans(
expectedValues.toArray(new Boolean[0]))) {
assertColumnsAreEqual(expected, actual);
}
}

/**
 * Parameter source for testBinaryComparisonOperationStringTransform: two string
 * columns compared row by row with each comparison operator.
 */
private static Stream<Arguments> createBinaryComparisonOperationStringParams() {
  final String[] lhs = {"a", "bb", null, "ccc", "dddd"};
  final String[] rhs = {"aa", "b", null, "ccc", "ddd"};
  return Stream.<Arguments>builder()
      // nulls compare as equal by default
      .add(Arguments.of(BinaryOperator.NULL_EQUAL, lhs, rhs,
          Arrays.asList(false, false, true, true, false)))
      .add(Arguments.of(BinaryOperator.NOT_EQUAL, lhs, rhs,
          mapArray(lhs, rhs, (x, y) -> !x.equals(y))))
      .add(Arguments.of(BinaryOperator.LESS, lhs, rhs,
          mapArray(lhs, rhs, (x, y) -> x.compareTo(y) < 0)))
      .add(Arguments.of(BinaryOperator.GREATER, lhs, rhs,
          mapArray(lhs, rhs, (x, y) -> x.compareTo(y) > 0)))
      .add(Arguments.of(BinaryOperator.LESS_EQUAL, lhs, rhs,
          mapArray(lhs, rhs, (x, y) -> x.compareTo(y) <= 0)))
      .add(Arguments.of(BinaryOperator.GREATER_EQUAL, lhs, rhs,
          mapArray(lhs, rhs, (x, y) -> x.compareTo(y) >= 0)))
      .build();
}

/**
 * Evaluates "column 0 (op) column 1" over a two-column string table and checks
 * the resulting boolean column against the expected values.
 */
@ParameterizedTest
@MethodSource("createBinaryComparisonOperationStringParams")
void testBinaryComparisonOperationStringTransform(BinaryOperator op, String[] in1, String[] in2,
    List<Boolean> expectedValues) {
  BinaryOperation comparison =
      new BinaryOperation(op, new ColumnReference(0), new ColumnReference(1));
  Boolean[] expectedRows = expectedValues.toArray(new Boolean[0]);
  try (Table input = new Table.TestBuilder().column(in1).column(in2).build();
       CompiledExpression compiled = comparison.compile();
       ColumnVector result = compiled.computeColumn(input);
       ColumnVector wanted = ColumnVector.fromBoxedBooleans(expectedRows)) {
    assertColumnsAreEqual(wanted, result);
  }
}

private static Stream<Arguments> createBinaryBitwiseOperationParams() {
Integer[] in1 = new Integer[] { -5, 4, null, 2, -3 };
Integer[] in2 = new Integer[] { 123, -456, null, 0, -3 };
Expand Down