diff --git a/cpp/include/cudf/strings/contains.hpp b/cpp/include/cudf/strings/contains.hpp index 9f408a40314..5b8b2f56bae 100644 --- a/cpp/include/cudf/strings/contains.hpp +++ b/cpp/include/cudf/strings/contains.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ namespace strings { */ std::unique_ptr contains_re( strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -78,7 +78,7 @@ std::unique_ptr contains_re( */ std::unique_ptr matches_re( strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -105,7 +105,7 @@ std::unique_ptr matches_re( */ std::unique_ptr count_re( strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/include/cudf/strings/extract.hpp b/cpp/include/cudf/strings/extract.hpp index 94e9f36d7d3..680d0f5b7bc 100644 --- a/cpp/include/cudf/strings/extract.hpp +++ b/cpp/include/cudf/strings/extract.hpp @@ -55,7 +55,7 @@ namespace strings { */ std::unique_ptr extract( strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -90,7 +90,7 @@ std::unique_ptr
extract( */ std::unique_ptr extract_all_record( strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/include/cudf/strings/findall.hpp b/cpp/include/cudf/strings/findall.hpp index 25ebdc61673..25c6d523250 100644 --- a/cpp/include/cudf/strings/findall.hpp +++ b/cpp/include/cudf/strings/findall.hpp @@ -56,7 +56,7 @@ namespace strings { */ std::unique_ptr
findall( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -90,7 +90,7 @@ std::unique_ptr
findall( */ std::unique_ptr findall_record( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/include/cudf/strings/replace_re.hpp b/cpp/include/cudf/strings/replace_re.hpp index 0ab3953470d..36c287009d0 100644 --- a/cpp/include/cudf/strings/replace_re.hpp +++ b/cpp/include/cudf/strings/replace_re.hpp @@ -50,7 +50,7 @@ namespace strings { */ std::unique_ptr replace_re( strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, string_scalar const& replacement = string_scalar(""), std::optional max_replace_count = std::nullopt, regex_flags const flags = regex_flags::DEFAULT, @@ -98,8 +98,8 @@ std::unique_ptr replace_re( */ std::unique_ptr replace_with_backrefs( strings_column_view const& strings, - std::string const& pattern, - std::string const& replacement, + std::string_view pattern, + std::string_view replacement, regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/include/cudf/strings/split/split_re.hpp b/cpp/include/cudf/strings/split/split_re.hpp index 9f40956722d..57246bd91d2 100644 --- a/cpp/include/cudf/strings/split/split_re.hpp +++ b/cpp/include/cudf/strings/split/split_re.hpp @@ -71,7 +71,7 @@ namespace strings { */ std::unique_ptr
split_re( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit = -1, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -121,7 +121,7 @@ std::unique_ptr
split_re( */ std::unique_ptr
rsplit_re( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit = -1, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -173,7 +173,7 @@ std::unique_ptr
rsplit_re( */ std::unique_ptr split_record_re( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit = -1, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -227,7 +227,7 @@ std::unique_ptr split_record_re( */ std::unique_ptr rsplit_record_re( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit = -1, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index 987cd076fd0..d75d914bb8e 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -56,7 +56,7 @@ struct contains_fn { }; std::unique_ptr contains_impl(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, bool const beginning_only, rmm::cuda_stream_view stream, @@ -85,7 +85,7 @@ std::unique_ptr contains_impl(strings_column_view const& input, std::unique_ptr contains_re( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) @@ -95,7 +95,7 @@ std::unique_ptr contains_re( std::unique_ptr matches_re( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) @@ -105,7 +105,7 @@ std::unique_ptr matches_re( std::unique_ptr count_re( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) @@ -128,7 +128,7 @@ std::unique_ptr count_re( // external APIs std::unique_ptr contains_re(strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::mr::device_memory_resource* mr) { @@ -137,7 +137,7 @@ std::unique_ptr contains_re(strings_column_view const& strings, } std::unique_ptr matches_re(strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::mr::device_memory_resource* mr) { @@ -146,7 +146,7 @@ std::unique_ptr matches_re(strings_column_view const& strings, } std::unique_ptr count_re(strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/src/strings/extract/extract.cu b/cpp/src/strings/extract/extract.cu index 59b90952d97..018fb7ba2fb 100644 --- a/cpp/src/strings/extract/extract.cu +++ b/cpp/src/strings/extract/extract.cu @@ -85,7 +85,7 @@ struct extract_fn { // std::unique_ptr
extract(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -130,7 +130,7 @@ std::unique_ptr
extract(strings_column_view const& input, // external API std::unique_ptr
extract(strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index 95b8a43a9d4..60c28027833 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -96,7 +96,7 @@ struct extract_fn { */ std::unique_ptr extract_all_record( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) @@ -165,7 +165,7 @@ std::unique_ptr extract_all_record( // external API std::unique_ptr extract_all_record(strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index f99acc3448a..dd4b4116994 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -21,7 +21,9 @@ #include #include -#include +#include +#include +#include namespace cudf { namespace strings { @@ -862,8 +864,7 @@ class regex_compiler { ; // "unmatched left paren"; /* points to first and only operand */ m_prog.set_start_inst(andstack[andstack.size() - 1].id_first); - m_prog.optimize1(); - m_prog.optimize2(); + m_prog.finalize(); m_prog.check_for_errors(); m_prog.set_groups_count(cursubid); } @@ -880,81 +881,78 @@ reprog reprog::create_from(std::string_view pattern, regex_flags const flags) return rtn; } -// -void reprog::optimize1() +void reprog::finalize() { - // Treat non-capturing LBRAs/RBRAs as NOOP - for (int i = 0; i < static_cast(_insts.size()); i++) { - if (_insts[i].type == LBRA || _insts[i].type == RBRA) { - if (_insts[i].u1.subid < 1) { _insts[i].type = NOP; } + collapse_nops(); + build_start_ids(); +} + +void reprog::collapse_nops() +{ + // treat non-capturing LBRAs/RBRAs as NOP + std::transform(_insts.begin(), _insts.end(), _insts.begin(), [](auto inst) { + if ((inst.type == LBRA || inst.type == RBRA) && (inst.u1.subid < 1)) { inst.type = NOP; } + return inst; + }); + + // functor for finding the next valid op + auto find_next_op = [insts = _insts](int id) { + while (insts[id].type == NOP) { + id = insts[id].u2.next_id; } - } + return id; + }; - // get rid of NOP chains - for (int i = 0; i < insts_count(); i++) { - if (_insts[i].type != NOP) { - { - int target_id = _insts[i].u2.next_id; - while (_insts[target_id].type == NOP) - target_id = _insts[target_id].u2.next_id; - _insts[i].u2.next_id = target_id; - } - if (_insts[i].type == OR) { - int target_id = _insts[i].u1.right_id; - while (_insts[target_id].type == NOP) - target_id = _insts[target_id].u2.next_id; - _insts[i].u1.right_id = target_id; - } + // create new routes around NOP chains + std::transform(_insts.begin(), _insts.end(), _insts.begin(), [find_next_op](auto inst) { + if (inst.type != NOP) { + inst.u2.next_id = find_next_op(inst.u2.next_id); + if (inst.type == OR) { inst.u1.right_id = find_next_op(inst.u1.right_id); } } - } - // skip NOPs from the beginning - { - int target_id = _startinst_id; - while (_insts[target_id].type == NOP) - target_id = _insts[target_id].u2.next_id; - _startinst_id = target_id; - } - // actually remove the no-ops + return inst; + }); + + // find starting op + _startinst_id = find_next_op(_startinst_id); + + // build a map of op ids + // these are used to fix up the ids after the NOPs are removed std::vector id_map(insts_count()); - int j = 0; // compact the ops (non no-ops) - for (int i = 0; i < insts_count(); i++) { - id_map[i] = j; - if (_insts[i].type != NOP) { - _insts[j] = _insts[i]; - j++; - } - } - _insts.resize(j); - // fix up the ORs - for (int i = 0; i < insts_count(); i++) { - { - int target_id = _insts[i].u2.next_id; - _insts[i].u2.next_id = id_map[target_id]; - } - if (_insts[i].type == OR) { - int target_id = _insts[i].u1.right_id; - _insts[i].u1.right_id = id_map[target_id]; - } - } - // set the new start id + std::transform_exclusive_scan( + _insts.begin(), _insts.end(), id_map.begin(), 0, std::plus{}, [](auto inst) { + return static_cast(inst.type != NOP); + }); + + // remove the NOP instructions + auto end = std::remove_if(_insts.begin(), _insts.end(), [](auto i) { return i.type == NOP; }); + _insts.resize(std::distance(_insts.begin(), end)); + + // fix up the ids on the remaining instructions using the id_map + std::transform(_insts.begin(), _insts.end(), _insts.begin(), [id_map](auto inst) { + inst.u2.next_id = id_map[inst.u2.next_id]; + if (inst.type == OR) { inst.u1.right_id = id_map[inst.u1.right_id]; } + return inst; + }); + + // fix up the start instruction id too _startinst_id = id_map[_startinst_id]; } // expand leading ORs to multiple startinst_ids -void reprog::optimize2() +void reprog::build_start_ids() { _startinst_ids.clear(); - std::vector stack; - stack.push_back(_startinst_id); - while (!stack.empty()) { - int id = stack.back(); - stack.pop_back(); + std::stack ids; + ids.push(_startinst_id); + while (!ids.empty()) { + int id = ids.top(); + ids.pop(); const reinst& inst = _insts[id]; if (inst.type == OR) { if (inst.u2.left_id != id) // prevents infinite while-loop here - stack.push_back(inst.u2.left_id); + ids.push(inst.u2.left_id); if (inst.u1.right_id != id) // prevents infinite while-loop here - stack.push_back(inst.u1.right_id); + ids.push(inst.u1.right_id); } else { _startinst_ids.push_back(id); } diff --git a/cpp/src/strings/regex/regcomp.h b/cpp/src/strings/regex/regcomp.h index 162a2090268..ed87660f106 100644 --- a/cpp/src/strings/regex/regcomp.h +++ b/cpp/src/strings/regex/regcomp.h @@ -124,8 +124,7 @@ class reprog { void set_start_inst(int32_t id); [[nodiscard]] int32_t get_start_inst() const; - void optimize1(); - void optimize2(); + void finalize(); void check_for_errors(); #ifndef NDEBUG void print(regex_flags const flags); @@ -139,6 +138,8 @@ class reprog { int32_t _num_capturing_groups{}; reprog() = default; + void collapse_nops(); + void build_start_ids(); void check_for_errors(int32_t id, int32_t next_id); }; diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index 5ccc70222d5..2ee195a2c5e 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -88,7 +88,7 @@ class reprog_device { * @return The program device object. */ static std::unique_ptr> create( - std::string const& pattern, rmm::cuda_stream_view stream); + std::string_view pattern, rmm::cuda_stream_view stream); /** * @brief Create the device program instance from a regex pattern. @@ -99,7 +99,7 @@ class reprog_device { * @return The program device object. */ static std::unique_ptr> create( - std::string const& pattern, regex_flags const re_flags, rmm::cuda_stream_view stream); + std::string_view pattern, regex_flags const re_flags, rmm::cuda_stream_view stream); /** * @brief Called automatically by the unique_ptr returned from create(). diff --git a/cpp/src/strings/regex/regexec.cu b/cpp/src/strings/regex/regexec.cu index 4b58d9d8a88..16f5b6fa03d 100644 --- a/cpp/src/strings/regex/regexec.cu +++ b/cpp/src/strings/regex/regexec.cu @@ -43,14 +43,14 @@ reprog_device::reprog_device(reprog& prog) } std::unique_ptr> reprog_device::create( - std::string const& pattern, rmm::cuda_stream_view stream) + std::string_view pattern, rmm::cuda_stream_view stream) { return reprog_device::create(pattern, regex_flags::MULTILINE, stream); } // Create instance of the reprog that can be passed into a device kernel std::unique_ptr> reprog_device::create( - std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream) + std::string_view pattern, regex_flags const flags, rmm::cuda_stream_view stream) { // compile pattern into host object reprog h_prog = reprog::create_from(pattern, flags); diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index 107adf07263..55498e760ff 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -46,13 +46,14 @@ namespace { * * Reference: https://www.regular-expressions.info/refreplacebackref.html */ -std::string get_backref_pattern(std::string const& repl) +std::string get_backref_pattern(std::string_view repl) { std::string const backslash_pattern = "\\\\(\\d+)"; std::string const bracket_pattern = "\\$\\{(\\d+)\\}"; + std::string const r{repl}; std::smatch m; - return std::regex_search(repl, m, std::regex(backslash_pattern)) ? backslash_pattern - : bracket_pattern; + return std::regex_search(r, m, std::regex(backslash_pattern)) ? backslash_pattern + : bracket_pattern; } /** * @brief Parse the back-ref index and position values from a given replace format. @@ -66,11 +67,11 @@ std::string get_backref_pattern(std::string const& repl) * For example, for input string 'hello \2 and \1' the returned `backref_type` vector * contains `[(2,6),(1,11)]` and the returned string is 'hello and '. */ -std::pair> parse_backrefs(std::string const& repl, +std::pair> parse_backrefs(std::string_view repl, int const group_count) { std::vector backrefs; - std::string str = repl; // make a modifiable copy + std::string str{repl}; // make a modifiable copy std::smatch m; std::regex ex(get_backref_pattern(repl)); std::string rtn; @@ -100,8 +101,8 @@ std::pair> parse_backrefs(std::string con // std::unique_ptr replace_with_backrefs(strings_column_view const& input, - std::string const& pattern, - std::string const& replacement, + std::string_view pattern, + std::string_view replacement, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -144,8 +145,8 @@ std::unique_ptr replace_with_backrefs(strings_column_view const& input, // external API std::unique_ptr replace_with_backrefs(strings_column_view const& strings, - std::string const& pattern, - std::string const& replacement, + std::string_view pattern, + std::string_view replacement, regex_flags const flags, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index 159f83453bd..1ed29587ac7 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -101,7 +101,7 @@ struct replace_regex_fn { // std::unique_ptr replace_re( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, string_scalar const& replacement, std::optional max_replace_count, regex_flags const flags, @@ -135,7 +135,7 @@ std::unique_ptr replace_re( // external API std::unique_ptr replace_re(strings_column_view const& strings, - std::string const& pattern, + std::string_view pattern, string_scalar const& replacement, std::optional max_replace_count, regex_flags const flags, diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 64e46d07e25..c92e1e7bbd9 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -86,7 +86,7 @@ struct findall_fn { } // namespace std::unique_ptr
findall(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -147,7 +147,7 @@ std::unique_ptr
findall(strings_column_view const& input, // external API std::unique_ptr
findall(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/src/strings/search/findall_record.cu b/cpp/src/strings/search/findall_record.cu index 2f4b9ce5b24..e4cf4dad618 100644 --- a/cpp/src/strings/search/findall_record.cu +++ b/cpp/src/strings/search/findall_record.cu @@ -93,7 +93,7 @@ std::unique_ptr findall_util(column_device_view const& d_strings, // std::unique_ptr findall_record( strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) @@ -133,7 +133,7 @@ std::unique_ptr findall_record( // external API std::unique_ptr findall_record(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, regex_flags const flags, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/src/strings/split/split_re.cu b/cpp/src/strings/split/split_re.cu index 16edd0606e9..750f5fbe942 100644 --- a/cpp/src/strings/split/split_re.cu +++ b/cpp/src/strings/split/split_re.cu @@ -184,7 +184,7 @@ struct tokens_transform_fn { }; std::unique_ptr
split_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, split_direction direction, size_type maxsplit, rmm::cuda_stream_view stream, @@ -252,7 +252,7 @@ std::unique_ptr
split_re(strings_column_view const& input, } std::unique_ptr split_record_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, split_direction direction, size_type maxsplit, rmm::cuda_stream_view stream, @@ -289,7 +289,7 @@ std::unique_ptr split_record_re(strings_column_view const& input, } // namespace std::unique_ptr
split_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -298,7 +298,7 @@ std::unique_ptr
split_re(strings_column_view const& input, } std::unique_ptr split_record_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -307,7 +307,7 @@ std::unique_ptr split_record_re(strings_column_view const& input, } std::unique_ptr
rsplit_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -316,7 +316,7 @@ std::unique_ptr
rsplit_re(strings_column_view const& input, } std::unique_ptr rsplit_record_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -329,7 +329,7 @@ std::unique_ptr rsplit_record_re(strings_column_view const& input, // external APIs std::unique_ptr
split_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit, rmm::mr::device_memory_resource* mr) { @@ -338,7 +338,7 @@ std::unique_ptr
split_re(strings_column_view const& input, } std::unique_ptr split_record_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit, rmm::mr::device_memory_resource* mr) { @@ -347,7 +347,7 @@ std::unique_ptr split_record_re(strings_column_view const& input, } std::unique_ptr
rsplit_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit, rmm::mr::device_memory_resource* mr) { @@ -356,7 +356,7 @@ std::unique_ptr
rsplit_re(strings_column_view const& input, } std::unique_ptr rsplit_record_re(strings_column_view const& input, - std::string const& pattern, + std::string_view pattern, size_type maxsplit, rmm::mr::device_memory_resource* mr) { diff --git a/java/src/main/java/ai/rapids/cudf/Cuda.java b/java/src/main/java/ai/rapids/cudf/Cuda.java index 21843527fc2..56a754279fc 100755 --- a/java/src/main/java/ai/rapids/cudf/Cuda.java +++ b/java/src/main/java/ai/rapids/cudf/Cuda.java @@ -596,4 +596,10 @@ public static void multiBufferCopyAsync(long [] destAddrs, * no effect. */ public static native void profilerStop(); + + /** + * Synchronizes the whole device using cudaDeviceSynchronize. + * @note this is very expensive and should almost never be used + */ + public static native void deviceSynchronize(); } diff --git a/java/src/main/native/src/CudaJni.cpp b/java/src/main/native/src/CudaJni.cpp index 926521c55f9..ce1ad1b1671 100644 --- a/java/src/main/native/src/CudaJni.cpp +++ b/java/src/main/native/src/CudaJni.cpp @@ -390,4 +390,12 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Cuda_profilerStop(JNIEnv *env, jclass CATCH_STD(env, ); } +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Cuda_deviceSynchronize(JNIEnv *env, jclass clazz) { + try { + cudf::jni::auto_set_device(env); + CUDF_CUDA_TRY(cudaDeviceSynchronize()); + } + CATCH_STD(env, ); +} + } // extern "C" diff --git a/python/cudf/cudf/core/resample.py b/python/cudf/cudf/core/resample.py index 2bed71ea751..57630e7d4a9 100644 --- a/python/cudf/cudf/core/resample.py +++ b/python/cudf/cudf/core/resample.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & +# SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & # AFFILIATES. All rights reserved. SPDX-License-Identifier: # Apache-2.0 # diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index a64aabe1a6b..22705c2b83b 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -245,6 +245,7 @@ def last(self, split_every=None, split_out=1): def aggregate(self, arg, split_every=None, split_out=1): if arg == "size": return self.size() + arg = _redirect_aggs(arg) if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS): @@ -431,6 +432,7 @@ def last(self, split_every=None, split_out=1): def aggregate(self, arg, split_every=None, split_out=1): if arg == "size": return self.size() + arg = _redirect_aggs(arg) if not isinstance(arg, dict): @@ -503,7 +505,7 @@ def groupby_agg( if isinstance(gb_cols, str): gb_cols = [gb_cols] columns = [c for c in ddf.columns if c not in gb_cols] - if isinstance(aggs, list): + if not isinstance(aggs, dict): aggs = {col: aggs for col in columns} # Assert if our output will have a MultiIndex; this will be the case if @@ -665,6 +667,8 @@ def _aggs_supported(arg, supported: set): _global_set = set(arg) return bool(_global_set.issubset(supported)) + elif isinstance(arg, str): + return arg in supported return False diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 5aa9cffb789..2b7f2bdae36 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -58,7 +58,10 @@ def test_groupby_basic(series, aggregation): "func", [ lambda df: df.groupby("x").agg({"y": "max"}), + lambda df: df.groupby("x").agg(["sum", "max"]), lambda df: df.groupby("x").y.agg(["sum", "max"]), + lambda df: df.groupby("x").agg("sum"), + lambda df: df.groupby("x").y.agg("sum"), ], ) def test_groupby_agg(func): @@ -663,11 +666,20 @@ def test_groupby_agg_redirect(aggregations): @pytest.mark.parametrize( - "arg", - [["not_supported"], {"a": "not_supported"}, {"a": ["not_supported"]}], + "arg,supported", + [ + ("sum", True), + (["sum"], True), + ({"a": "sum"}, True), + ({"a": ["sum"]}, True), + ("not_supported", False), + (["not_supported"], False), + ({"a": "not_supported"}, False), + ({"a": ["not_supported"]}, False), + ], ) -def test_is_supported(arg): - assert _aggs_supported(arg, {"supported"}) is False +def test_is_supported(arg, supported): + assert _aggs_supported(arg, SUPPORTED_AGGS) is supported def test_groupby_unique_lists():