diff --git a/cpp/include/cudf/ast/detail/expression_evaluator.cuh b/cpp/include/cudf/ast/detail/expression_evaluator.cuh index a179fe10774..105d87ff96f 100644 --- a/cpp/include/cudf/ast/detail/expression_evaluator.cuh +++ b/cpp/include/cudf/ast/detail/expression_evaluator.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index ace60b70bf9..84fb7cfbd5a 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -104,7 +104,7 @@ using IntermediateDataType = possibly_null_value_t; */ struct expression_device_view { device_span data_references; - device_span literals; + device_span literals; device_span operators; device_span operator_source_indices; cudf::size_type num_intermediates; @@ -281,11 +281,10 @@ class expression_parser { reinterpret_cast(device_data_buffer_ptr + buffer_offsets[0]), _data_references.size()); - device_expression_data.literals = - device_span( - reinterpret_cast( - device_data_buffer_ptr + buffer_offsets[1]), - _literals.size()); + device_expression_data.literals = device_span( + reinterpret_cast(device_data_buffer_ptr + + buffer_offsets[1]), + _literals.size()); device_expression_data.operators = device_span( reinterpret_cast(device_data_buffer_ptr + buffer_offsets[2]), _operators.size()); @@ -335,7 +334,7 @@ class expression_parser { std::vector _data_references; std::vector _operators; std::vector _operator_source_indices; - std::vector _literals; + std::vector _literals; }; } // namespace detail diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index d29b0787e8e..6df6ba71b4c 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -150,6 +150,96 @@ enum class table_reference { OUTPUT ///< Column index in the output table }; +/** + * @brief A type-erased scalar_device_view where the value is a fixed width type or a string + */ +class generic_scalar_device_view : public cudf::detail::scalar_device_view_base { + public: + /** + * @brief Returns the stored value interpreted as type `T`. 
+ * + * @tparam T The desired type + * @returns The stored value: a `string_view` over `_data` with length `_size` when the `if constexpr` branch is taken, otherwise the object `_data` points to + */ + template + __device__ T const value() const noexcept + { + if constexpr (std::is_same_v) { + return string_view(static_cast(_data), _size); + } + return *static_cast(_data); + } + + /** @brief Construct a new generic scalar device view object from a numeric scalar + * + * @param s The numeric scalar to construct from + */ + template + generic_scalar_device_view(numeric_scalar& s) + : generic_scalar_device_view(s.type(), s.data(), s.validity_data()) + { + } + + /** @brief Construct a new generic scalar device view object from a timestamp scalar + * + * @param s The timestamp scalar to construct from + */ + template + generic_scalar_device_view(timestamp_scalar& s) + : generic_scalar_device_view(s.type(), s.data(), s.validity_data()) + { + } + + /** @brief Construct a new generic scalar device view object from a duration scalar + * + * @param s The duration scalar to construct from + */ + template + generic_scalar_device_view(duration_scalar& s) + : generic_scalar_device_view(s.type(), s.data(), s.validity_data()) + { + } + + /** @brief Construct a new generic scalar device view object from a string scalar + * + * @param s The string scalar to construct from + */ + generic_scalar_device_view(string_scalar& s) + : generic_scalar_device_view(s.type(), s.data(), s.validity_data(), s.size()) + { + } + + protected: + void const* _data{}; ///< Pointer to device memory containing the value + size_type const _size{}; ///< Size of the string in bytes for string scalar; stays value-initialized when constructed without a size + + /** + * @brief Construct a new generic scalar device view object for a fixed width value + * + * @param type The data type of the value + * @param data The pointer to the data in device memory + * @param is_valid The pointer to the bool in device memory that indicates the + * validity of the stored value + */ + generic_scalar_device_view(data_type type, void const* data, bool* is_valid) + : cudf::detail::scalar_device_view_base(type, is_valid), _data(data) 
+ { + } + + /** @brief Construct a new generic scalar device view object for a string value + * + * @param type The data type of the value + * @param data The pointer to the data in device memory + * @param is_valid The pointer to the bool in device memory that indicates the + * validity of the stored value + * @param size The size of the string in bytes + */ + generic_scalar_device_view(data_type type, void const* data, bool* is_valid, size_type size) + : cudf::detail::scalar_device_view_base(type, is_valid), _data(data), _size(size) + { + } +}; + /** * @brief A literal value used in an abstract syntax tree. */ @@ -162,8 +252,7 @@ class literal : public expression { * @param value A numeric scalar value */ template - literal(cudf::numeric_scalar& value) - : scalar(value), value(cudf::get_scalar_device_view(value)) + literal(cudf::numeric_scalar& value) : scalar(value), value(value) { } @@ -174,8 +263,7 @@ class literal : public expression { * @param value A timestamp scalar value */ template - literal(cudf::timestamp_scalar& value) - : scalar(value), value(cudf::get_scalar_device_view(value)) + literal(cudf::timestamp_scalar& value) : scalar(value), value(value) { } @@ -186,11 +274,17 @@ class literal : public expression { * @param value A duration scalar value */ template - literal(cudf::duration_scalar& value) - : scalar(value), value(cudf::get_scalar_device_view(value)) + literal(cudf::duration_scalar& value) : scalar(value), value(value) { } + /** + * @brief Construct a new literal object. + * + * @param value A string scalar value + */ + literal(cudf::string_scalar& value) : scalar(value), value(value) {} + /** * @brief Get the data type. * @@ -203,10 +297,7 @@ class literal : public expression { * * @return The device scalar object */ - [[nodiscard]] cudf::detail::fixed_width_scalar_device_view_base get_value() const - { - return value; - } + [[nodiscard]] generic_scalar_device_view get_value() const { return value; } /** * @brief Accepts a visitor class. 
@@ -236,7 +327,7 @@ class literal : public expression { private: cudf::scalar const& scalar; - cudf::detail::fixed_width_scalar_device_view_base const value; + generic_scalar_device_view const value; }; /** diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index 6845d7990b8..745fa44d45e 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -396,6 +396,48 @@ TEST_F(TransformTest, StringComparison) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); } +TEST_F(TransformTest, StringScalarComparison) +{ + auto c_0 = + cudf::test::strings_column_wrapper({"1", "12", "123", "23"}, {true, true, false, true}); + auto table = cudf::table_view{{c_0}}; + + auto literal_value = cudf::string_scalar("2"); + auto literal = cudf::ast::literal(literal_value); + + auto col_ref_0 = cudf::ast::column_reference(0); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, literal); + + auto expected = column_wrapper{{true, true, true, false}, {true, true, false, true}}; + auto result = cudf::compute_column(table, expression); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); + + // compare with null literal: every output row is expected to be null + literal_value.set_valid_async(false); + auto expected2 = column_wrapper{{false, false, false, false}, {false, false, false, false}}; + auto result2 = cudf::compute_column(table, expression); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, result2->view(), verbosity); +} + +TEST_F(TransformTest, NumericScalarComparison) +{ + auto c_0 = column_wrapper{1, 12, 123, 23}; + auto table = cudf::table_view{{c_0}}; + + auto literal_value = cudf::numeric_scalar(2); + auto literal = cudf::ast::literal(literal_value); + + auto col_ref_0 = cudf::ast::column_reference(0); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, literal); + + auto expected = column_wrapper{true, false, false, false}; + auto result = cudf::compute_column(table, 
expression); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); +} + TEST_F(TransformTest, CopyColumn) { auto c_0 = column_wrapper{3, 0, 1, 50};