Skip to content

Commit

Permalink
Final changes
Browse files Browse the repository at this point in the history
  • Loading branch information
codereport committed Feb 3, 2021
1 parent cd04a8e commit d22fb16
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 14 deletions.
91 changes: 81 additions & 10 deletions cpp/include/cudf/detail/aggregation/aggregation.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,34 @@ struct update_target_element<
}
};

template <typename Source, bool target_has_nulls, bool source_has_nulls>
struct update_target_element<Source,
aggregation::MIN,
target_has_nulls,
source_has_nulls,
std::enable_if_t<is_fixed_point<Source>()>> {
__device__ void operator()(mutable_column_device_view target,
size_type target_index,
column_device_view source,
size_type source_index) const noexcept
{
#if (__CUDACC_VER_MAJOR__ != 10) or (__CUDACC_VER_MINOR__ != 2)

if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::MIN>;
using DeviceTarget = device_storage_type_t<Target>;
using DeviceSource = device_storage_type_t<Source>;

atomicMin(&target.element<DeviceTarget>(target_index),
static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)));

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }

#endif
}
};

template <typename Source, bool target_has_nulls, bool source_has_nulls>
struct update_target_element<
Source,
Expand All @@ -154,35 +182,78 @@ struct update_target_element<

template <typename Source, bool target_has_nulls, bool source_has_nulls>
struct update_target_element<Source,
aggregation::SUM,
aggregation::MAX,
target_has_nulls,
source_has_nulls,
std::enable_if_t<is_fixed_width<Source>()>> {
std::enable_if_t<is_fixed_point<Source>()>> {
__device__ void operator()(mutable_column_device_view target,
size_type target_index,
column_device_view source,
size_type source_index) const noexcept
{
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::SUM>;
#if (__CUDACC_VER_MAJOR__ != 10) or (__CUDACC_VER_MINOR__ != 2)

// #if (__CUDACC_VER_MAJOR__ != 10) or (__CUDACC_VER_MINOR__ != 2)
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::MAX>;
using DeviceTarget = device_storage_type_t<Target>;
using DeviceSource = device_storage_type_t<Source>;

// #else
atomicMax(&target.element<DeviceTarget>(target_index),
static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)));

// using DeviceTarget = Target;
// using DeviceSource = Source;
if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }

// #endif
#endif
}
};

template <typename Source, bool target_has_nulls, bool source_has_nulls>
struct update_target_element<
Source,
aggregation::SUM,
target_has_nulls,
source_has_nulls,
std::enable_if_t<is_fixed_width<Source>() && !is_fixed_point<Source>()>> {
__device__ void operator()(mutable_column_device_view target,
size_type target_index,
column_device_view source,
size_type source_index) const noexcept
{
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::SUM>;
atomicAdd(&target.element<Target>(target_index),
static_cast<Target>(source.element<Source>(source_index)));

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
}
};

template <typename Source, bool target_has_nulls, bool source_has_nulls>
struct update_target_element<Source,
aggregation::SUM,
target_has_nulls,
source_has_nulls,
std::enable_if_t<is_fixed_point<Source>()>> {
__device__ void operator()(mutable_column_device_view target,
size_type target_index,
column_device_view source,
size_type source_index) const noexcept
{
#if (__CUDACC_VER_MAJOR__ != 10) or (__CUDACC_VER_MINOR__ != 2)

if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::SUM>;
using DeviceTarget = device_storage_type_t<Target>;
using DeviceSource = device_storage_type_t<Source>;

atomicAdd(&target.element<DeviceTarget>(target_index),
static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)));

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
#endif
}
};

Expand Down
6 changes: 2 additions & 4 deletions cpp/tests/groupby/group_count_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,11 @@ TYPED_TEST(FixedPointTestBothReps, GroupBySumProductMinMaxDecimalAsValue)
test_single_agg(keys, vals, expect_keys, {}, std::move(agg4), force_use_sort_impl::YES),
cudf::logic_error);

#if !((__CUDACC_VER_MAJOR__ == 10) and (__CUDACC_VER_MINOR__ == 2))

// group_by hash tests

auto agg5 = cudf::make_sum_aggregation();
test_single_agg(keys, vals, expect_keys, expect_vals_sum, std::move(agg5));
// auto agg5 = cudf::make_sum_aggregation();
// test_single_agg(keys, vals, expect_keys, expect_vals_sum, std::move(agg5));

// auto agg6 = cudf::make_min_aggregation();
// test_single_agg(keys, vals, expect_keys, expect_vals_min, std::move(agg6));
Expand All @@ -257,7 +256,6 @@ TYPED_TEST(FixedPointTestBothReps, GroupBySumProductMinMaxDecimalAsValue)
// EXPECT_THROW(test_single_agg(keys, vals, expect_keys, {}, std::move(agg8)),
// cudf::logic_error);

#endif
}
}

Expand Down

0 comments on commit d22fb16

Please sign in to comment.