diff --git a/cpp/include/cudf/strings/convert/is_valid_fixed_point.hpp b/cpp/include/cudf/strings/convert/is_valid_fixed_point.hpp new file mode 100644 index 00000000000..2da65c38a69 --- /dev/null +++ b/cpp/include/cudf/strings/convert/is_valid_fixed_point.hpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace cudf { +namespace strings { +/** + * @addtogroup strings_convert + * @{ + * @file + */ + +/** + * @brief Returns a boolean column identifying strings in which all + * characters are valid for conversion to integers. + * + * @param strings Strings instance for this operation. + * @param allow_decimal identification whether the format is decimal or not + * @param input_type input data type for check overflow + * @param mr Device memory resource used to allocate the returned column's device memory. + * @return New column of boolean results for each string. + */ +std::unique_ptr is_valid_fixed_point( + strings_column_view const& strings, + bool allow_decimal, + data_type input_type, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** @} */ // end of doxygen group +} // namespace strings +} // namespace cudf + diff --git a/cpp/src/strings/convert/is_valid_fixed_point.cu b/cpp/src/strings/convert/is_valid_fixed_point.cu new file mode 100644 index 00000000000..268e6718006 --- /dev/null +++ b/cpp/src/strings/convert/is_valid_fixed_point.cu @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace cudf { +namespace strings { + +/** + * Parses this UTF8String(trimmed if needed) to INT8/16/32/64... + * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and + * Long.MIN_VALUE is '-9223372036854775808'. + * + * This code is mostly copied from LazyLong.parseLong in Hive. + * + * @param d_str String to check. + * @param allow_decimal Decimal format or not + * @param min_value min_value that corresponds to the type that is checking + * @return true if string has valid integer characters + */ +__device__ bool is_valid_fixed_point(string_view const& d_str, bool allow_decimal, long min_value) +{ + int offset = 0; + size_type bytes = d_str.size_bytes(); + const char* data = d_str.data(); + while (offset < bytes && data[offset] == ' ') ++offset; + if (offset == bytes) return false; + + int end = bytes - 1; + while (end > offset && data[end] == ' ') --end; + + char c_sign = data[offset]; + const bool negative = c_sign == '-'; + if (negative || c_sign == '+'){ + if (end - offset == 0) return false; + ++offset; + } + + const char separator = '.'; + const int radix = 10; + const long stop_value = min_value / radix; + long result = 0; + + while (offset <= end) { + const char c = data[offset]; + ++offset; + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below. + if (c == separator && allow_decimal) break; + + int digit; + if (c >= '0' && c <= '9'){ + digit = c - '0'; + } else { + return false; + } + + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop. + if (result < stop_value) return false; + + result = result * radix - digit; + + // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we + // can just use `result > 0` to check overflow. If result overflows, we should stop. + if (result > 0) return false; + } + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. + while (offset <= end) { + char currentByte = data[offset]; + if (currentByte < '0' || currentByte > '9') return false; + ++offset; + } + + if (!negative) { + result = -result; + if (result < 0) return false; + } + + return true; +} + +namespace detail { + +std::unique_ptr is_valid_fixed_point( + strings_column_view const& strings, + bool allow_decimal, + data_type input_type, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +{ + auto strings_column = column_device_view::create(strings.parent(), stream); + auto d_column = *strings_column; + auto d_allow_decimal = allow_decimal; + + // ready a min_value corresponds to the input type in order to check overflow + long d_min_value = 0; + switch (input_type.id()) { + case type_id::INT8: d_min_value = -128; + case type_id::INT16: d_min_value = -32768; + case type_id::INT32: d_min_value = -2147483648; + case type_id::INT64: d_min_value = -9223372036854775808; + default: CUDF_FAIL("Unsupported current data type check when convert string type"); + } + + // create output column + auto results = make_numeric_column(data_type{type_id::BOOL8}, + strings.size(), + cudf::detail::copy_bitmask(strings.parent(), stream, mr), + strings.null_count(), + stream, + mr); + auto d_results = results->mutable_view().data(); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings.size()), + d_results, + [d_column,d_allow_decimal,d_min_value] __device__(size_type idx) { + if (d_column.is_null(idx)) return false; + return strings::is_valid_fixed_point(d_column.element(idx), d_allow_decimal, d_min_value); + }); + results->set_null_count(strings.null_count()); + return results; +} + +} // namespace detail + +// external API + +std::unique_ptr is_valid_fixed_point(strings_column_view const& strings, + bool allow_decimal, + data_type input_type, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::is_valid_fixed_point(strings, allow_decimal, input_type, rmm::cuda_stream_default, mr); +} + +} // namespace strings +} // namespace cudf + diff --git a/cpp/tests/strings/chars_types_tests.cpp b/cpp/tests/strings/chars_types_tests.cpp index 7a7d1e3e106..8917482ac46 100644 --- a/cpp/tests/strings/chars_types_tests.cpp +++ b/cpp/tests/strings/chars_types_tests.cpp @@ -243,6 +243,21 @@ TEST_F(StringsCharsTest, Integers) EXPECT_TRUE(cudf::strings::all_integer(cudf::strings_column_view(strings2))); } +TEST_F(StringsCharsTest, ValidIntegers) +{ + cudf::test::strings_column_wrapper strings1( + {"+175", "-34", "9.8", "17+2", "+-14", "1234567890", "67de", "", "1e10", "-", "++", ""}); + auto results = cudf::strings::is_valid_fixed_point(cudf::strings_column_view(strings1), true, data_type::INT64); + cudf::test::fixed_width_column_wrapper expected1({1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected1); + + cudf::test::strings_column_wrapper strings2( + {"0", "+0", "-0", "1234567890", "-27341132", "+012", "023", "-045", "-1.1", "+1000.1"}); + results = cudf::strings::is_valid_fixed_point(cudf::strings_column_view(strings2), true, data_type::INT64); + cudf::test::fixed_width_column_wrapper expected2({1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected2); +} + TEST_F(StringsCharsTest, Floats) { cudf::test::strings_column_wrapper strings1({"+175", @@ -390,3 +405,4 @@ TEST_F(StringsCharsTest, EmptyStringsColumn) EXPECT_EQ(cudf::type_id::STRING, results->view().type().id()); EXPECT_EQ(0, results->view().size()); } +