diff --git a/cpp/include/cudf/strings/convert/is_valid_element.hpp b/cpp/include/cudf/strings/convert/is_valid_element.hpp
new file mode 100644
index 00000000000..3845b5f5675
--- /dev/null
+++ b/cpp/include/cudf/strings/convert/is_valid_element.hpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2021, Baidu CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cudf/strings/strings_column_view.hpp>  // NOTE(review): name lost in extraction -- confirm
+
+namespace cudf {
+namespace strings {
+/**
+ * @addtogroup strings_convert
+ * @{
+ * @file
+ */
+
+/**
+ * @brief Returns a boolean column identifying strings that are valid numbers for `input_type`.
+ *
+ * The boolean `allow_decimal` indicates whether the input strings may be decimals. If
+ * `allow_decimal` is false, each string must match [+-]?[0-9]+ like `is_integer`. If it is
+ * true, a decimal point followed by digits is accepted as well, which is
+ * similar to `is_float` but without some of the special cases for float (E, Inf, -Inf, NaN).
+ *
+ * `input_type` is used to check whether the value overflows. For example, if `input_type`
+ * is `int8_t` and the input string is `128`, the result is false because the value is
+ * outside the range [-128, 127].
+ *
+ * @param strings Strings instance for this operation.
+ * @param allow_decimal Whether decimal strings are accepted.
+ * @param input_type Target integer type used for the overflow check.
+ * @param mr Device memory resource used to allocate the returned column's device memory.
+ * @return New column of boolean results for each string.
+ */
+std::unique_ptr<column> is_valid_element(
+  strings_column_view const& strings,
+  bool allow_decimal,
+  data_type input_type,
+  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
+
+/** @} */  // end of doxygen group
+}  // namespace strings
+}  // namespace cudf
+
diff --git a/cpp/src/strings/convert/is_valid_element.cu b/cpp/src/strings/convert/is_valid_element.cu
new file mode 100644
index 00000000000..2f4f71a9564
--- /dev/null
+++ b/cpp/src/strings/convert/is_valid_element.cu
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) 2021, Baidu CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// NOTE(review): the header names in this patch were lost in extraction. The list below is
+// reconstructed from the symbols this file uses -- confirm against the source tree.
+#include <cudf/column/column_device_view.cuh>
+#include <cudf/column/column_factories.hpp>
+#include <cudf/detail/null_mask.hpp>
+#include <cudf/detail/nvtx/ranges.hpp>
+#include <cudf/strings/convert/is_valid_element.hpp>
+#include <cudf/strings/string_view.cuh>
+#include <cudf/strings/strings_column_view.hpp>
+#include <cudf/utilities/error.hpp>
+#include <cudf/utilities/type_dispatcher.hpp>
+
+#include <rmm/cuda_stream_view.hpp>
+#include <rmm/exec_policy.hpp>
+
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/transform.h>
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+
+namespace cudf {
+namespace strings {
+namespace detail {
+namespace {
+/**
+ * @brief Checks whether a string can be parsed as a signed integer of a given width.
+ *
+ * Leading and trailing spaces are ignored and one leading '+' or '-' is accepted.
+ * If `allow_decimal` is true, a '.' ends the integral part and every character after
+ * it must be a digit, so `['1.23', '123']` gives `[true, true]`. If `allow_decimal`
+ * is false, the same input gives `[false, true]`.
+ *
+ * The value is accumulated in negative form because the minimum of a signed type has
+ * a larger magnitude than its maximum, e.g. INT64 spans -9223372036854775808 to
+ * 9223372036854775807. Every step is range-checked against `min_value` before it is
+ * applied, so the accumulator stays within [min_value, 0].
+ *
+ * This code is heavily based off of LazyLong.parseLong from Hive, but updated for C++.
+ *
+ * @param d_str String to check.
+ * @param allow_decimal Whether a fractional part after '.' is accepted.
+ * @param min_value Minimum value of the target integer type; expected to be negative.
+ * @return true if the string is a well-formed number that fits the target type.
+ */
+__device__ bool is_valid_element(string_view const& d_str, bool allow_decimal, long min_value)
+{
+  int offset       = 0;
+  size_type bytes  = d_str.size_bytes();
+  const char* data = d_str.data();
+  // strip leading white space
+  while (offset < bytes && data[offset] == ' ') ++offset;
+  if (offset == bytes) return false;
+
+  int end = bytes - 1;
+  // strip trailing white space
+  while (end > offset && data[end] == ' ') --end;
+
+  char c_sign         = data[offset];
+  const bool negative = c_sign == '-';
+  if (negative || c_sign == '+') {
+    // A lone sign is not a number.
+    if (end - offset == 0) return false;
+    ++offset;
+  }
+
+  const char separator  = '.';
+  const int radix       = 10;
+  const long stop_value = min_value / radix;
+  long result           = 0;
+
+  while (offset <= end) {
+    const char c = data[offset];
+    ++offset;
+    // Only the integral part is range-checked; once a decimal separator is seen the
+    // remaining characters are validated by the loop below.
+    if (c == separator && allow_decimal) break;
+
+    if (c < '0' || c > '9') return false;
+    const int digit = c - '0';
+
+    // If result is already below stop_value (min_value / radix), result * radix would
+    // fall below min_value, so the number does not fit.
+    if (result < stop_value) return false;
+    result *= radix;
+
+    // result - digit must not fall below min_value. Comparing against
+    // min_value + digit keeps the check itself free of overflow and is correct for
+    // every target width, not only when min_value is the minimum of `long`.
+    if (result < min_value + digit) return false;
+    result -= digit;
+  }
+
+  // The fractional part does not change the integral value, but it must consist of
+  // digits only, up to and including data[end].
+  for (; offset <= end; ++offset) {
+    if (data[offset] < '0' || data[offset] > '9') return false;
+  }
+
+  // The largest positive value is -(min_value + 1), so a number without a '-' sign
+  // whose magnitude reached -min_value is out of range.
+  if (!negative && result == min_value) return false;
+
+  return true;
+}
+
+}  // namespace
+
+/**
+ * @brief Type-dispatched functor returning the minimum value of the input data type,
+ * used for checking overflow.
+ *
+ * Only the signed integer types INT8, INT16, INT32 and INT64 are specialized; every
+ * other type fails.
+ */
+struct min_value_of_type {
+  template <typename T>
+  long operator()()
+  {
+    CUDF_FAIL("Unsupported current data type check.");
+  }
+};
+
+template <>
+long min_value_of_type::operator()<int8_t>() { return std::numeric_limits<int8_t>::min(); }
+
+template <>
+long min_value_of_type::operator()<int16_t>() { return std::numeric_limits<int16_t>::min(); }
+
+template <>
+long min_value_of_type::operator()<int32_t>() { return std::numeric_limits<int32_t>::min(); }
+
+template <>
+long min_value_of_type::operator()<int64_t>() { return std::numeric_limits<int64_t>::min(); }
+
+/**
+ * @brief Builds a BOOL8 column flagging which strings are valid numbers for `input_type`.
+ *
+ * Null rows keep their null mask from the input; their data value is false.
+ *
+ * @param strings Strings instance for this operation.
+ * @param allow_decimal Whether a fractional part after '.' is accepted.
+ * @param input_type Target integer type used for the overflow check.
+ * @param stream CUDA stream used for device memory operations and kernel launches.
+ * @param mr Device memory resource used to allocate the returned column's device memory.
+ * @return New column of boolean results for each string.
+ */
+std::unique_ptr<column> is_valid_element(
+  strings_column_view const& strings,
+  bool allow_decimal,
+  data_type input_type,
+  rmm::cuda_stream_view stream,
+  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
+{
+  auto strings_column  = column_device_view::create(strings.parent(), stream);
+  auto d_column        = *strings_column;
+  auto d_allow_decimal = allow_decimal;
+
+  // look up the minimum value of the target type for the overflow check
+  long d_min_value = cudf::type_dispatcher(input_type, min_value_of_type{});
+
+  // create output column
+  auto results   = make_numeric_column(data_type{type_id::BOOL8},
+                                       strings.size(),
+                                       cudf::detail::copy_bitmask(strings.parent(), stream, mr),
+                                       strings.null_count(),
+                                       stream,
+                                       mr);
+  auto d_results = results->mutable_view().data<bool>();
+  thrust::transform(rmm::exec_policy(stream),
+                    thrust::make_counting_iterator<size_type>(0),
+                    thrust::make_counting_iterator<size_type>(strings.size()),
+                    d_results,
+                    [d_column, d_allow_decimal, d_min_value] __device__(size_type idx) {
+                      if (d_column.is_null(idx)) return false;
+                      return is_valid_element(
+                        d_column.element<string_view>(idx), d_allow_decimal, d_min_value);
+                    });
+  results->set_null_count(strings.null_count());
+  return results;
+}
+
+}  // namespace detail
+
+// external API
+
+std::unique_ptr<column> is_valid_element(strings_column_view const& strings,
+                                         bool allow_decimal,
+                                         data_type input_type,
+                                         rmm::mr::device_memory_resource* mr)
+{
+  CUDF_FUNC_RANGE();
+  return detail::is_valid_element(strings, allow_decimal, input_type, rmm::cuda_stream_default, mr);
+}
+
+}  // namespace strings
+}  // namespace cudf
+
diff --git a/cpp/tests/strings/chars_types_tests.cpp b/cpp/tests/strings/chars_types_tests.cpp
index 7a7d1e3e106..728b79626e4 100644
--- a/cpp/tests/strings/chars_types_tests.cpp
+++ b/cpp/tests/strings/chars_types_tests.cpp
@@ -390,3 +390,4 @@ TEST_F(StringsCharsTest, EmptyStringsColumn)
   EXPECT_EQ(cudf::type_id::STRING, results->view().type().id());
   EXPECT_EQ(0, results->view().size());
 }
+
diff --git a/cpp/tests/strings/valid_element.cpp b/cpp/tests/strings/valid_element.cpp
new file mode 100644
index 00000000000..8da464f3b2a
--- /dev/null
+++ b/cpp/tests/strings/valid_element.cpp
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2021, Baidu CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include  // NOTE(review): header names and template arguments are missing in this patch text -- restore
+#include
+
+#include
+#include
+
+struct ValidStringCharsTest : public cudf::test::BaseFixture {  // fixture for the is_valid_element tests
+};
+// NOTE(review): no type-boundary inputs (e.g. "127"/"128" for INT8) or malformed fractions (e.g. "1.2x") are covered.
+TEST_F(ValidStringCharsTest, ValidFixedPoint)
+{
+  // allow_decimal = true: "9.8" is accepted; the other verdicts depend on the range of the target type
+  cudf::test::strings_column_wrapper strings1(
+    {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", "", "21474836482222"});
+  auto results = cudf::strings::is_valid_element(cudf::strings_column_view(strings1), true, cudf::data_type{cudf::type_id::INT8});
+  cudf::test::fixed_width_column_wrapper expected1({0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});  // INT8: "+175" is out of range
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected1);
+
+  cudf::test::strings_column_wrapper strings2(
+    {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", "", "21474836482222"});
+  results = cudf::strings::is_valid_element(cudf::strings_column_view(strings2), true, cudf::data_type{cudf::type_id::INT16});
+  cudf::test::fixed_width_column_wrapper expected2({1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});  // INT16: "+175" now fits
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected2);
+
+  cudf::test::strings_column_wrapper strings3(
+    {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", "", "21474836482222"});
+  results = cudf::strings::is_valid_element(cudf::strings_column_view(strings3), true, cudf::data_type{cudf::type_id::INT32});
+  cudf::test::fixed_width_column_wrapper expected3({1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0});  // INT32: "1234567890" now fits
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected3);
+
+  cudf::test::strings_column_wrapper strings4(
+    {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", "", "21474836482222"});
+  results = cudf::strings::is_valid_element(cudf::strings_column_view(strings4), true, cudf::data_type{cudf::type_id::INT64});
+  cudf::test::fixed_width_column_wrapper expected4({1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1});  // INT64: "21474836482222" now fits
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected4);
+
+  // allow_decimal = false: same inputs and types as above, but "9.8" is now rejected for every type
+  cudf::test::strings_column_wrapper strings5(
+    {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", "", "21474836482222"});
+  results = cudf::strings::is_valid_element(cudf::strings_column_view(strings5), false, cudf::data_type{cudf::type_id::INT8});
+  cudf::test::fixed_width_column_wrapper expected5({0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});  // INT8
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected5);
+
+  cudf::test::strings_column_wrapper strings6(
+    {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", "", "21474836482222"});
+  results = cudf::strings::is_valid_element(cudf::strings_column_view(strings6), false, cudf::data_type{cudf::type_id::INT16});
+  cudf::test::fixed_width_column_wrapper expected6({1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});  // INT16
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected6);
+
+  cudf::test::strings_column_wrapper strings7(
+    {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", "", "21474836482222"});
+  results = cudf::strings::is_valid_element(cudf::strings_column_view(strings7), false, cudf::data_type{cudf::type_id::INT32});
+  cudf::test::fixed_width_column_wrapper expected7({1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0});  // INT32
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected7);
+
+  cudf::test::strings_column_wrapper strings8(
+    {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", "", "21474836482222"});
+  results = cudf::strings::is_valid_element(cudf::strings_column_view(strings8), false, cudf::data_type{cudf::type_id::INT64});
+  cudf::test::fixed_width_column_wrapper expected8({1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1});  // INT64
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected8);
+
+  // second test: zeros, signed zeros, leading zeros and signed decimals are all expected valid for INT64
+  cudf::test::strings_column_wrapper strings0(
+    {"0", "+0", "-0", "1234567890", "-27341132", "+012", "023", "-045", "-1.1", "+1000.1"});
+  results = cudf::strings::is_valid_element(cudf::strings_column_view(strings0), true, cudf::data_type{cudf::type_id::INT64});
+  cudf::test::fixed_width_column_wrapper expected0({1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected0);
+
+}
+