Skip to content

Commit

Permalink
Support null literals in expressions (#9117)
Browse files Browse the repository at this point in the history
This PR resolves #8831 by propagating null values appropriately when they are provided in the form of literals. In addition to adding support at the level of expression evaluation, this PR also adds the appropriate dispatch to the correct code path at the level of calls to APIs using expressions containing nulls so that the null-supporting code paths can be triggered even when operating on tables that do not contain any nulls.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - Bradley Dice (https://github.com/bdice)
  - Mike Wilson (https://github.com/hyperbolic2346)

URL: #9117
  • Loading branch information
vyasr authored Sep 7, 2021
1 parent 91f8533 commit a02f888
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 62 deletions.
9 changes: 8 additions & 1 deletion cpp/include/cudf/ast/detail/expression_evaluator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,14 @@ struct expression_evaluator {
return ReturnType(table.column(data_index).element<Element>(row_index));
}
} else if (ref_type == detail::device_data_reference_type::LITERAL) {
return ReturnType(plan.literals[data_index].value<Element>());
if constexpr (has_nulls) {
return plan.literals[data_index].is_valid()
? ReturnType(plan.literals[data_index].value<Element>())
: ReturnType();

} else {
return ReturnType(plan.literals[data_index].value<Element>());
}
} else { // Assumes ref_type == detail::device_data_reference_type::INTERMEDIATE
// Using memcpy instead of reinterpret_cast<Element*> for safe type aliasing
// Using a temporary variable ensures that the compiler knows the result is aligned
Expand Down
55 changes: 50 additions & 5 deletions cpp/include/cudf/ast/expressions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,21 @@ class expression_parser;
/**
* @brief A generic expression that can be evaluated to return a value.
*
* This class is a part of a "visitor" pattern with the `linearizer` class.
* Nodes inheriting from this class can accept visitors.
* This class is a part of a "visitor" pattern with the `expression_parser` class.
* Expressions inheriting from this class can accept parsers as visitors.
*/
struct expression {
virtual cudf::size_type accept(detail::expression_parser& visitor) const = 0;

bool may_evaluate_null(table_view const& left, rmm::cuda_stream_view stream) const
{
return may_evaluate_null(left, left, stream);
}

virtual bool may_evaluate_null(table_view const& left,
table_view const& right,
rmm::cuda_stream_view stream) const = 0;

virtual ~expression() {}
};

Expand Down Expand Up @@ -116,7 +125,8 @@ class literal : public expression {
* @param value A numeric scalar value.
*/
template <typename T>
literal(cudf::numeric_scalar<T>& value) : value(cudf::get_scalar_device_view(value))
literal(cudf::numeric_scalar<T>& value)
: scalar(value), value(cudf::get_scalar_device_view(value))
{
}

Expand All @@ -127,7 +137,8 @@ class literal : public expression {
* @param value A timestamp scalar value.
*/
template <typename T>
literal(cudf::timestamp_scalar<T>& value) : value(cudf::get_scalar_device_view(value))
literal(cudf::timestamp_scalar<T>& value)
: scalar(value), value(cudf::get_scalar_device_view(value))
{
}

Expand All @@ -138,7 +149,8 @@ class literal : public expression {
* @param value A duration scalar value.
*/
template <typename T>
literal(cudf::duration_scalar<T>& value) : value(cudf::get_scalar_device_view(value))
literal(cudf::duration_scalar<T>& value)
: scalar(value), value(cudf::get_scalar_device_view(value))
{
}

Expand All @@ -164,7 +176,22 @@ class literal : public expression {
*/
cudf::size_type accept(detail::expression_parser& visitor) const override;

bool may_evaluate_null(table_view const& left,
table_view const& right,
rmm::cuda_stream_view stream) const override
{
return !is_valid(stream);
}

/**
* @brief Check if the underlying scalar is valid.
*
* @return bool
*/
bool is_valid(rmm::cuda_stream_view stream) const { return scalar.is_valid(stream); }

private:
cudf::scalar const& scalar;
cudf::detail::fixed_width_scalar_device_view_base const value;
};

Expand Down Expand Up @@ -240,6 +267,13 @@ class column_reference : public expression {
*/
cudf::size_type accept(detail::expression_parser& visitor) const override;

bool may_evaluate_null(table_view const& left,
table_view const& right,
rmm::cuda_stream_view stream) const override
{
return (table_source == table_reference::LEFT ? left : right).column(column_index).has_nulls();
}

private:
cudf::size_type column_index;
table_reference table_source;
Expand Down Expand Up @@ -296,6 +330,17 @@ class operation : public expression {
*/
cudf::size_type accept(detail::expression_parser& visitor) const override;

bool may_evaluate_null(table_view const& left,
table_view const& right,
rmm::cuda_stream_view stream) const override
{
return std::any_of(operands.cbegin(),
operands.cend(),
[&left, &right, &stream](std::reference_wrapper<expression const> subexpr) {
return subexpr.get().may_evaluate_null(left, right, stream);
});
};

private:
ast_operator const op;
std::vector<std::reference_wrapper<expression const>> const operands;
Expand Down
11 changes: 4 additions & 7 deletions cpp/src/join/conditional_join.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,10 @@ conditional_join(table_view const& left,
}
}

// Prepare output column. Whether or not the output column is nullable is
// determined by whether any of the columns in the input table are nullable.
// If none of the input columns actually contain nulls, we can still use the
// non-nullable version of the expression evaluation code path for
// performance, so we capture that information as well.
auto const nullable = cudf::nullable(left) || cudf::nullable(right);
auto const has_nulls = nullable && (cudf::has_nulls(left) || cudf::has_nulls(right));
// If evaluating the expression may produce null outputs we create a nullable
// output column and follow the null-supporting expression evaluation code
// path.
auto const has_nulls = binary_predicate.may_evaluate_null(left, right, stream);

auto const parser =
ast::detail::expression_parser{binary_predicate, left, right, has_nulls, stream, mr};
Expand Down
14 changes: 5 additions & 9 deletions cpp/src/transform/compute_column.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,15 @@ std::unique_ptr<column> compute_column(table_view const& table,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// Prepare output column. Whether or not the output column is nullable is
// determined by whether any of the columns in the input table are nullable.
// If none of the input columns actually contain nulls, we can still use the
// non-nullable version of the expression evaluation code path for
// performance, so we capture that information as well.
auto const nullable = cudf::nullable(table);
auto const has_nulls = nullable && cudf::has_nulls(table);
// If evaluating the expression may produce null outputs we create a nullable
// output column and follow the null-supporting expression evaluation code
// path.
auto const has_nulls = expr.may_evaluate_null(table, stream);

auto const parser = ast::detail::expression_parser{expr, table, has_nulls, stream, mr};

auto const output_column_mask_state =
nullable ? (has_nulls ? mask_state::UNINITIALIZED : mask_state::ALL_VALID)
: mask_state::UNALLOCATED;
has_nulls ? mask_state::UNINITIALIZED : mask_state::UNALLOCATED;

auto output_column = cudf::make_fixed_width_column(
parser.output_type(), table.num_rows(), output_column_mask_state, stream, mr);
Expand Down
17 changes: 17 additions & 0 deletions cpp/tests/ast/transform_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ TEST_F(TransformTest, Literal)
cudf::test::expect_columns_equal(expected, result->view(), verbosity);
}

TEST_F(TransformTest, NullLiteral)
{
auto c_0 = column_wrapper<int32_t>{0, 0, 0, 0};
auto table = cudf::table_view{{c_0}};

auto literal_value = cudf::numeric_scalar<int32_t>(-123);
literal_value.set_valid_async(false);
auto literal = cudf::ast::literal(literal_value);

auto expression = cudf::ast::operation(cudf::ast::ast_operator::IDENTITY, literal);

auto result = cudf::compute_column(table, expression);
auto expected = column_wrapper<int32_t>({-123, -123, -123, -123}, {0, 0, 0, 0});

cudf::test::expect_columns_equal(expected, result->view(), verbosity);
}

TEST_F(TransformTest, BasicAddition)
{
auto c_0 = column_wrapper<int32_t>{3, 20, 1, 50};
Expand Down
64 changes: 24 additions & 40 deletions java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junit.jupiter.params.provider.NullSource;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -77,20 +78,18 @@ public void testBooleanLiteralTransform() {
assertColumnsAreEqual(trueExprExpected, trueExprActual);
}

// Uncomment the following after https://github.com/rapidsai/cudf/issues/8831 is fixed
// Literal nullLiteral = Literal.ofBoolean(null);
// UnaryOperation nullExpr = new UnaryOperation(AstOperator.IDENTITY, nullLiteral);
// try (CompiledExpression nullCompiledExpr = nullExpr.compile();
// ColumnVector nullExprActual = nullCompiledExpr.computeColumn(t);
// ColumnVector nullExprExpected = ColumnVector.fromBoxedBooleans(null, null, null)) {
// assertColumnsAreEqual(nullExprExpected, nullExprActual);
// }
Literal nullLiteral = Literal.ofBoolean(null);
UnaryOperation nullExpr = new UnaryOperation(UnaryOperator.IDENTITY, nullLiteral);
try (CompiledExpression nullCompiledExpr = nullExpr.compile();
ColumnVector nullExprActual = nullCompiledExpr.computeColumn(t);
ColumnVector nullExprExpected = ColumnVector.fromBoxedBooleans(null, null, null)) {
assertColumnsAreEqual(nullExprExpected, nullExprActual);
}
}
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(bytes = 0x12)
public void testByteLiteralTransform(Byte value) {
Literal expr = Literal.ofByte(value);
Expand All @@ -103,8 +102,7 @@ public void testByteLiteralTransform(Byte value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(shorts = 0x1234)
public void testShortLiteralTransform(Short value) {
Literal expr = Literal.ofShort(value);
Expand All @@ -117,8 +115,7 @@ public void testShortLiteralTransform(Short value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(ints = 0x12345678)
public void testIntLiteralTransform(Integer value) {
Literal expr = Literal.ofInt(value);
Expand All @@ -131,8 +128,7 @@ public void testIntLiteralTransform(Integer value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testLongLiteralTransform(Long value) {
Literal expr = Literal.ofLong(value);
Expand All @@ -145,8 +141,7 @@ public void testLongLiteralTransform(Long value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(floats = { 123456.789f, Float.NaN, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY} )
public void testFloatLiteralTransform(Float value) {
Literal expr = Literal.ofFloat(value);
Expand All @@ -159,8 +154,7 @@ public void testFloatLiteralTransform(Float value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(doubles = { 123456.789f, Double.NaN, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY} )
public void testDoubleLiteralTransform(Double value) {
Literal expr = Literal.ofDouble(value);
Expand All @@ -173,8 +167,7 @@ public void testDoubleLiteralTransform(Double value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(ints = 0x12345678)
public void testTimestampDaysLiteralTransform(Integer value) {
Literal expr = Literal.ofTimestampDaysFromInt(value);
Expand All @@ -188,8 +181,7 @@ public void testTimestampDaysLiteralTransform(Integer value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testTimestampSecondsLiteralTransform(Long value) {
Literal expr = Literal.ofTimestampFromLong(DType.TIMESTAMP_SECONDS, value);
Expand All @@ -203,8 +195,7 @@ public void testTimestampSecondsLiteralTransform(Long value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testTimestampMilliSecondsLiteralTransform(Long value) {
Literal expr = Literal.ofTimestampFromLong(DType.TIMESTAMP_MILLISECONDS, value);
Expand All @@ -218,8 +209,7 @@ public void testTimestampMilliSecondsLiteralTransform(Long value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testTimestampMicroSecondsLiteralTransform(Long value) {
Literal expr = Literal.ofTimestampFromLong(DType.TIMESTAMP_MICROSECONDS, value);
Expand All @@ -233,8 +223,7 @@ public void testTimestampMicroSecondsLiteralTransform(Long value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testTimestampNanoSecondsLiteralTransform(Long value) {
Literal expr = Literal.ofTimestampFromLong(DType.TIMESTAMP_NANOSECONDS, value);
Expand All @@ -248,8 +237,7 @@ public void testTimestampNanoSecondsLiteralTransform(Long value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(ints = 0x12345678)
public void testDurationDaysLiteralTransform(Integer value) {
Literal expr = Literal.ofDurationDaysFromInt(value);
Expand All @@ -263,8 +251,7 @@ public void testDurationDaysLiteralTransform(Integer value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testDurationSecondsLiteralTransform(Long value) {
Literal expr = Literal.ofDurationFromLong(DType.DURATION_SECONDS, value);
Expand All @@ -278,8 +265,7 @@ public void testDurationSecondsLiteralTransform(Long value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testDurationMilliSecondsLiteralTransform(Long value) {
Literal expr = Literal.ofDurationFromLong(DType.DURATION_MILLISECONDS, value);
Expand All @@ -293,8 +279,7 @@ public void testDurationMilliSecondsLiteralTransform(Long value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testDurationMicroSecondsLiteralTransform(Long value) {
Literal expr = Literal.ofDurationFromLong(DType.DURATION_MICROSECONDS, value);
Expand All @@ -308,8 +293,7 @@ public void testDurationMicroSecondsLiteralTransform(Long value) {
}

@ParameterizedTest
// Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed
// @NullSource
@NullSource
@ValueSource(longs = 0x1234567890abcdefL)
public void testDurationNanoSecondsLiteralTransform(Long value) {
Literal expr = Literal.ofDurationFromLong(DType.DURATION_NANOSECONDS, value);
Expand Down

0 comments on commit a02f888

Please sign in to comment.