Get rid of cudf device atomics
PointKernel committed Jun 14, 2023
1 parent 02be87b commit e9e11f2
Showing 19 changed files with 149 additions and 1,102 deletions.
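
For orientation before the per-file hunks: the change is largely mechanical, replacing raw CUDA atomic intrinsics (atomicAdd, atomicMin, atomicMax, atomicOr, atomicCAS) with cuda::atomic_ref from libcu++'s <cuda/atomic> header, using device scope and relaxed memory ordering. The standalone kernel below is a minimal sketch of that before/after pattern; it is not taken from the commit, and the kernel and variable names are invented for illustration.

```cuda
#include <cuda/atomic>

// Before: raw intrinsic, implicitly device-scoped, no explicit memory ordering.
__global__ void histogram_legacy(int const* values, int* bins, int num_values)
{
  int const i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_values) { atomicAdd(&bins[values[i]], 1); }
}

// After: cuda::atomic_ref makes both the scope and the memory order explicit.
__global__ void histogram_atomic_ref(int const* values, int* bins, int num_values)
{
  int const i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_values) {
    cuda::atomic_ref<int, cuda::thread_scope_device> ref(bins[values[i]]);
    ref.fetch_add(1, cuda::std::memory_order_relaxed);
  }
}
```

fetch_min, fetch_max, and fetch_or follow the same shape, which is why most hunks below look alike.
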
3 changes: 1 addition & 2 deletions cpp/benchmarks/join/generate_input_tables.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@

#pragma once

#include <cudf/detail/utilities/device_atomics.cuh>
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/error.hpp>

58 changes: 38 additions & 20 deletions cpp/include/cudf/detail/aggregation/aggregation.cuh
@@ -20,7 +20,7 @@
#include <cudf/column/column_device_view.cuh>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/detail/utilities/device_atomics.cuh>
#include <cudf/detail/utilities/device_operators.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/table/table_device_view.cuh>
#include <cudf/utilities/traits.cuh>
@@ -30,6 +30,8 @@

#include <thrust/fill.h>

#include <cuda/atomic>

namespace cudf {
namespace detail {
/**
@@ -144,8 +146,9 @@ struct update_target_element<
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::MIN>;
atomicMin(&target.element<Target>(target_index),
static_cast<Target>(source.element<Source>(source_index)));
cuda::atomic_ref<Target, cuda::thread_scope_device> ref(target.element<Target>(target_index));
ref.fetch_min(static_cast<Target>(source.element<Source>(source_index)),
cuda::std::memory_order_relaxed);

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
}
@@ -170,8 +173,10 @@
using DeviceTarget = device_storage_type_t<Target>;
using DeviceSource = device_storage_type_t<Source>;

atomicMin(&target.element<DeviceTarget>(target_index),
static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)));
cuda::atomic_ref<DeviceTarget, cuda::thread_scope_device> ref(
target.element<DeviceTarget>(target_index));
ref.fetch_min(static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)),
cuda::std::memory_order_relaxed);

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
}
@@ -193,8 +198,9 @@
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::MAX>;
atomicMax(&target.element<Target>(target_index),
static_cast<Target>(source.element<Source>(source_index)));
cuda::atomic_ref<Target, cuda::thread_scope_device> ref(target.element<Target>(target_index));
ref.fetch_max(static_cast<Target>(source.element<Source>(source_index)),
cuda::std::memory_order_relaxed);

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
}
@@ -219,8 +225,10 @@
using DeviceTarget = device_storage_type_t<Target>;
using DeviceSource = device_storage_type_t<Source>;

atomicMax(&target.element<DeviceTarget>(target_index),
static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)));
cuda::atomic_ref<DeviceTarget, cuda::thread_scope_device> ref(
target.element<DeviceTarget>(target_index));
ref.fetch_max(static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)),
cuda::std::memory_order_relaxed);

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
}
@@ -242,8 +250,9 @@
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::SUM>;
atomicAdd(&target.element<Target>(target_index),
static_cast<Target>(source.element<Source>(source_index)));
cuda::atomic_ref<Target, cuda::thread_scope_device> ref(target.element<Target>(target_index));
ref.fetch_add(static_cast<Target>(source.element<Source>(source_index)),
cuda::std::memory_order_relaxed);

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
}
@@ -268,8 +277,10 @@
using DeviceTarget = device_storage_type_t<Target>;
using DeviceSource = device_storage_type_t<Source>;

atomicAdd(&target.element<DeviceTarget>(target_index),
static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)));
cuda::atomic_ref<DeviceTarget, cuda::thread_scope_device> ref(
target.element<DeviceTarget>(target_index));
ref.fetch_add(static_cast<DeviceTarget>(source.element<DeviceSource>(source_index)),
cuda::std::memory_order_relaxed);

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
}
@@ -368,7 +379,8 @@ struct update_target_element<Source,

using Target = target_type_t<Source, aggregation::SUM_OF_SQUARES>;
auto value = static_cast<Target>(source.element<Source>(source_index));
atomicAdd(&target.element<Target>(target_index), value * value);
cuda::atomic_ref<Target, cuda::thread_scope_device> ref(target.element<Target>(target_index));
ref.fetch_add(value * value, cuda::std::memory_order_relaxed);
if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
}
};
@@ -408,7 +420,8 @@ struct update_target_element<
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::COUNT_VALID>;
atomicAdd(&target.element<Target>(target_index), Target{1});
cuda::atomic_ref<Target, cuda::thread_scope_device> ref(target.element<Target>(target_index));
ref.fetch_add(1, cuda::std::memory_order_relaxed);

// It is assumed the output for COUNT_VALID is initialized to be all valid
}
@@ -427,7 +440,8 @@
size_type source_index) const noexcept
{
using Target = target_type_t<Source, aggregation::COUNT_ALL>;
atomicAdd(&target.element<Target>(target_index), Target{1});
cuda::atomic_ref<Target, cuda::thread_scope_device> ref(target.element<Target>(target_index));
ref.fetch_add(1, cuda::std::memory_order_relaxed);

// It is assumed the output for COUNT_ALL is initialized to be all valid
}
@@ -449,10 +463,12 @@
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::ARGMAX>;
auto old = atomicCAS(&target.element<Target>(target_index), ARGMAX_SENTINEL, source_index);
cuda::atomic_ref<Target, cuda::thread_scope_device> ref(target.element<Target>(target_index));
auto old =
ref.compare_exchange_strong(ARGMAX_SENTINEL, source_index, cuda::std::memory_order_relaxed);
if (old != ARGMAX_SENTINEL) {
while (source.element<Source>(source_index) > source.element<Source>(old)) {
old = atomicCAS(&target.element<Target>(target_index), old, source_index);
old = ref.compare_exchange_strong(old, source_index, cuda::std::memory_order_relaxed);
}
}

@@ -476,10 +492,12 @@
if (source_has_nulls and source.is_null(source_index)) { return; }

using Target = target_type_t<Source, aggregation::ARGMIN>;
auto old = atomicCAS(&target.element<Target>(target_index), ARGMIN_SENTINEL, source_index);
cuda::atomic_ref<Target, cuda::thread_scope_device> ref(target.element<Target>(target_index));
auto old =
ref.compare_exchange_strong(ARGMIN_SENTINEL, source_index, cuda::std::memory_order_relaxed);
if (old != ARGMIN_SENTINEL) {
while (source.element<Source>(source_index) < source.element<Source>(old)) {
old = atomicCAS(&target.element<Target>(target_index), old, source_index);
old = ref.compare_exchange_strong(old, source_index, cuda::std::memory_order_relaxed);
}
}

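
The ARGMIN/ARGMAX hunks above replace atomicCAS loops. One API difference worth keeping in mind: cuda::atomic_ref::compare_exchange_strong takes the expected value by reference, returns a bool, and on failure writes the currently stored value back into expected, whereas atomicCAS returns the old value directly. The sketch below shows an argmax-style update written around those semantics; it is illustrative only (the function, SENTINEL constant, and parameters are invented and are not cudf code), assuming larger keys win and ties go to whichever thread lands its exchange first.

```cuda
#include <cuda/atomic>

constexpr int SENTINEL = -1;  // hypothetical "no index written yet" marker

// Sketch: atomically keep *out_idx pointing at the index of the largest key seen so far.
__device__ void argmax_update(int* out_idx, float const* keys, int idx)
{
  cuda::atomic_ref<int, cuda::thread_scope_device> ref(*out_idx);
  int expected = SENTINEL;
  // Try to claim an empty slot first; on failure, `expected` holds the current winner.
  if (!ref.compare_exchange_strong(expected, idx, cuda::std::memory_order_relaxed)) {
    while (keys[idx] > keys[expected] &&
           !ref.compare_exchange_strong(expected, idx, cuda::std::memory_order_relaxed)) {
      // CAS failed: `expected` was refreshed with the new winner; re-check and retry.
    }
  }
}
```
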
13 changes: 9 additions & 4 deletions cpp/include/cudf/detail/copy_if.cuh
@@ -22,7 +22,6 @@
#include <cudf/detail/gather.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/detail/utilities/device_atomics.cuh>
#include <cudf/null_mask.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/table/table.hpp>
@@ -44,6 +43,8 @@

#include <cub/cub.cuh>

#include <cuda/atomic>

#include <algorithm>

namespace cudf {
@@ -181,7 +182,8 @@ __launch_bounds__(block_size) __global__
if (wid > 0 && wid < last_warp)
output_valid[valid_index] = valid_warp;
else {
atomicOr(&output_valid[valid_index], valid_warp);
cuda::atomic_ref<uint32_t, cuda::thread_scope_device> ref(output_valid[valid_index]);
ref.fetch_or(valid_warp, cuda::std::memory_order_relaxed);
}
}

@@ -190,7 +192,9 @@
uint32_t valid_warp = __ballot_sync(0xffff'ffffu, temp_valids[block_size + threadIdx.x]);
if (lane == 0 && valid_warp != 0) {
tmp_warp_valid_counts += __popc(valid_warp);
atomicOr(&output_valid[valid_index + num_warps], valid_warp);
cuda::atomic_ref<uint32_t, cuda::thread_scope_device> ref(
output_valid[valid_index + num_warps]);
ref.fetch_or(valid_warp, cuda::std::memory_order_relaxed);
}
}
}
@@ -206,7 +210,8 @@
cudf::detail::single_lane_block_sum_reduce<block_size, leader_lane>(warp_valid_counts);

if (threadIdx.x == 0) { // one thread computes and adds to null count
atomicAdd(output_null_count, block_sum - block_valid_count);
cuda::atomic_ref<cudf::size_type, cuda::thread_scope_device> ref(*output_null_count);
ref.fetch_add(block_sum - block_valid_count, cuda::std::memory_order_relaxed);
}
}

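
The copy_if.cuh hunks turn atomicOr on validity-bitmask words and atomicAdd on the null count into the same cuda::atomic_ref idiom. The kernel below is a small, self-contained sketch of those two operations, assuming a 32-bit-word bitmask, a block size that is a multiple of the warp size, and an element count n that is a multiple of 32 so the full-mask ballot is safe; the kernel and parameter names are invented and not part of cudf.

```cuda
#include <cuda/atomic>
#include <cstdint>

// Sketch: publish per-element validity bits and a global null count with
// device-scoped relaxed atomics instead of atomicOr / atomicAdd.
__global__ void mark_valid(uint32_t* valid_mask, int* null_count, bool const* is_valid, int n)
{
  int const i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) { return; }  // warp-uniform exit given the stated assumptions

  // One bit per lane; lane 0 writes the whole 32-bit word for its warp.
  uint32_t const warp_bits = __ballot_sync(0xffff'ffffu, is_valid[i]);
  if (threadIdx.x % 32 == 0) {
    cuda::atomic_ref<uint32_t, cuda::thread_scope_device> mask_ref(valid_mask[i / 32]);
    mask_ref.fetch_or(warp_bits, cuda::std::memory_order_relaxed);

    cuda::atomic_ref<int, cuda::thread_scope_device> count_ref(*null_count);
    count_ref.fetch_add(32 - __popc(warp_bits), cuda::std::memory_order_relaxed);
  }
}
```
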