Skip to content

Commit

Permalink
Pass stream and user resource to make_default_constructed_scalar (#7469)
Browse files Browse the repository at this point in the history
The `make_default_constructed_scalar` factory currently [doesn't take streams or user resources](https://github.com/rapidsai/cudf/blob/c929ba1fe85c152d6e8b4c868cd36f0802dafa51/cpp/src/scalar/scalar_factories.cpp#L100-L133), which could lead to inconsistent results if the constructed scalar is used with a non-default non-blocking stream later on. This PR is to fix the issue.

Authors:
  - Wonchan Lee (@magnatelee)

Approvers:
  - Mark Harris (@harrism)
  - Vukasin Milovanovic (@vuule)

URL: #7469
  • Loading branch information
magnatelee authored Mar 3, 2021
1 parent 8341db4 commit 2a0be16
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 16 deletions.
7 changes: 6 additions & 1 deletion cpp/include/cudf/scalar/scalar_factories.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,13 @@ std::unique_ptr<scalar> make_string_scalar(
* @throws std::bad_alloc if device memory allocation fails
*
* @param type The desired element type
* @param stream CUDA stream used for device memory operations.
* @param mr Device memory resource used to allocate the scalar's `data` and `is_valid` bool.
*/
std::unique_ptr<scalar> make_default_constructed_scalar(data_type type);
std::unique_ptr<scalar> make_default_constructed_scalar(
data_type type,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Construct scalar using the given value of fixed width type
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/copying/get_element.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ struct get_element_functor {
stream);

if (!key_index_scalar.is_valid(stream)) {
auto null_result = make_default_constructed_scalar(dict_view.keys().type());
auto null_result = make_default_constructed_scalar(dict_view.keys().type(), stream, mr);
null_result->set_valid(false, stream);
return null_result;
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/reductions/minmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ std::pair<std::unique_ptr<scalar>, std::unique_ptr<scalar>> minmax(
if (col.null_count() == col.size()) {
// this handles empty and all-null columns
// return scalars with valid==false
return {make_default_constructed_scalar(col.type()),
make_default_constructed_scalar(col.type())};
return {make_default_constructed_scalar(col.type(), stream, mr),
make_default_constructed_scalar(col.type(), stream, mr)};
}

return type_dispatcher(col.type(), minmax_functor{}, col, stream, mr);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ std::unique_ptr<scalar> reduce(
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource())
{
std::unique_ptr<scalar> result = make_default_constructed_scalar(output_dtype);
std::unique_ptr<scalar> result = make_default_constructed_scalar(output_dtype, stream, mr);
result->set_valid(false, stream);

// check if input column is empty
Expand Down
28 changes: 20 additions & 8 deletions cpp/src/scalar/scalar_factories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,36 +100,48 @@ std::unique_ptr<scalar> make_fixed_width_scalar(data_type type,
namespace {
struct default_scalar_functor {
template <typename T>
std::unique_ptr<cudf::scalar> operator()()
std::unique_ptr<cudf::scalar> operator()(rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
using ScalarType = scalar_type_t<T>;
return std::unique_ptr<scalar>(new ScalarType);
return make_fixed_width_scalar(data_type(type_to_id<T>()), stream, mr);
}
};

template <>
std::unique_ptr<cudf::scalar> default_scalar_functor::operator()<dictionary32>()
std::unique_ptr<cudf::scalar> default_scalar_functor::operator()<string_view>(
rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr)
{
return std::unique_ptr<scalar>(new string_scalar("", false, stream, mr));
}

template <>
std::unique_ptr<cudf::scalar> default_scalar_functor::operator()<dictionary32>(
rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr)
{
CUDF_FAIL("dictionary type not supported");
}

template <>
std::unique_ptr<cudf::scalar> default_scalar_functor::operator()<list_view>()
std::unique_ptr<cudf::scalar> default_scalar_functor::operator()<list_view>(
rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr)
{
CUDF_FAIL("list_view type not supported");
}

template <>
std::unique_ptr<cudf::scalar> default_scalar_functor::operator()<struct_view>()
std::unique_ptr<cudf::scalar> default_scalar_functor::operator()<struct_view>(
rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr)
{
CUDF_FAIL("struct_view type not supported");
}

} // namespace

std::unique_ptr<scalar> make_default_constructed_scalar(data_type type)
std::unique_ptr<scalar> make_default_constructed_scalar(data_type type,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return type_dispatcher(type, default_scalar_functor{});
return type_dispatcher(type, default_scalar_functor{}, stream, mr);
}

} // namespace cudf
8 changes: 5 additions & 3 deletions cpp/tests/scalar/factories_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ TYPED_TEST(TimestampScalarFactory, TypeCast)
}

template <typename T>
struct DefaultScalarFactory : public cudf::test::BaseFixture {
struct DefaultScalarFactory : public ScalarFactoryTest {
static constexpr auto factory = cudf::make_default_constructed_scalar;
};

Expand All @@ -98,15 +98,17 @@ TYPED_TEST_CASE(DefaultScalarFactory, MixedTypes);

TYPED_TEST(DefaultScalarFactory, FactoryDefault)
{
std::unique_ptr<cudf::scalar> s = this->factory(cudf::data_type{cudf::type_to_id<TypeParam>()});
std::unique_ptr<cudf::scalar> s =
this->factory(cudf::data_type{cudf::type_to_id<TypeParam>()}, this->stream(), this->mr());

EXPECT_EQ(s->type(), cudf::data_type{cudf::type_to_id<TypeParam>()});
EXPECT_FALSE(s->is_valid());
}

TYPED_TEST(DefaultScalarFactory, TypeCast)
{
std::unique_ptr<cudf::scalar> s = this->factory(cudf::data_type{cudf::type_to_id<TypeParam>()});
std::unique_ptr<cudf::scalar> s =
this->factory(cudf::data_type{cudf::type_to_id<TypeParam>()}, this->stream(), this->mr());

auto numeric_s = static_cast<cudf::scalar_type_t<TypeParam>*>(s.get());

Expand Down

0 comments on commit 2a0be16

Please sign in to comment.