Skip to content

Commit

Permalink
[SYCL] Add prototype of atomic_ref<T*> (#2177)
Browse files Browse the repository at this point in the history
Enables partial specialization of atomic_ref for pointer types.

Implementation assumes that both host and device pointers can be stored
in a uintptr_t, but uses compare_exchange to implement pointer arithmetic
rather than make assumptions about how pointers will be represented on
different devices.

Signed-off-by: John Pennycook <[email protected]>
  • Loading branch information
Pennycook authored Jul 30, 2020
1 parent 8ac87a3 commit a3c3425
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 97 deletions.
2 changes: 1 addition & 1 deletion sycl/doc/extensions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ DPC++ extensions status:
| [SYCL_INTEL_deduction_guides](deduction_guides/SYCL_INTEL_deduction_guides.asciidoc) | Supported | |
| [SYCL_INTEL_device_specific_kernel_queries](DeviceSpecificKernelQueries/SYCL_INTEL_device_specific_kernel_queries.asciidoc) | Proposal | |
| [SYCL_INTEL_enqueue_barrier](EnqueueBarrier/enqueue_barrier.asciidoc) | Supported(OpenCL, Level Zero) | |
| [SYCL_INTEL_extended_atomics](ExtendedAtomics/SYCL_INTEL_extended_atomics.asciidoc) | Partially supported(OpenCL: CPU, GPU) | Not supported: pointer types |
| [SYCL_INTEL_extended_atomics](ExtendedAtomics/SYCL_INTEL_extended_atomics.asciidoc) | Supported(OpenCL: CPU, GPU) | |
| [SYCL_INTEL_group_algorithms](GroupAlgorithms/SYCL_INTEL_group_algorithms.asciidoc) | Supported(OpenCL) | |
| [SYCL_INTEL_group_mask](./GroupMask/SYCL_INTEL_group_mask.asciidoc) | Proposal | |
| [FPGA selector](IntelFPGA/FPGASelector.md) | Supported | |
Expand Down
136 changes: 130 additions & 6 deletions sycl/include/CL/sycl/intel/atomic_ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ class atomic_ref_base {
static_assert(!(std::is_same<T, short>::value ||
std::is_same<T, unsigned short>::value),
"intel::atomic_ref does not support short type");
static_assert(!std::is_pointer<T>::value,
"intel::atomic_ref does not yet support pointer types");
static_assert(detail::IsValidAtomicAddressSpace<AddressSpace>::value,
"Invalid atomic address_space. Valid address spaces are: "
"global_space, local_space, global_device_space");
Expand Down Expand Up @@ -508,12 +506,138 @@ class atomic_ref_impl<
};

// Partial specialization for pointer types
// Arithmetic is emulated because target's representation of T* is unknown
// TODO: Find a way to use intptr_t or uintptr_t atomics instead
template <typename T, memory_order DefaultOrder, memory_scope DefaultScope,
access::address_space AddressSpace>
class atomic_ref_impl<T *, DefaultOrder, DefaultScope, AddressSpace,
typename detail::enable_if_t<std::is_pointer<T>::value>>
: public atomic_ref_base<T *, DefaultOrder, DefaultScope, AddressSpace> {
// TODO: Implement partial specialization for pointer types
class atomic_ref_impl<T *, DefaultOrder, DefaultScope, AddressSpace>
: public atomic_ref_base<uintptr_t, DefaultOrder, DefaultScope,
AddressSpace> {

private:
using base_type =
atomic_ref_base<uintptr_t, DefaultOrder, DefaultScope, AddressSpace>;

public:
using value_type = T *;
using difference_type = ptrdiff_t;
static constexpr size_t required_alignment = sizeof(T *);
static constexpr bool is_always_lock_free =
detail::IsValidAtomicType<T>::value;
static constexpr memory_order default_read_order =
detail::memory_order_traits<DefaultOrder>::read_order;
static constexpr memory_order default_write_order =
detail::memory_order_traits<DefaultOrder>::write_order;
static constexpr memory_order default_read_modify_write_order = DefaultOrder;
static constexpr memory_scope default_scope = DefaultScope;

using base_type::is_lock_free;

atomic_ref_impl(T *&ref) : base_type(reinterpret_cast<uintptr_t &>(ref)) {}

void store(T *operand, memory_order order = default_write_order,
memory_scope scope = default_scope) const noexcept {
base_type::store(reinterpret_cast<uintptr_t>(operand), order, scope);
}

T *operator=(T *desired) const noexcept {
store(desired);
return desired;
}

T *load(memory_order order = default_read_order,
memory_scope scope = default_scope) const noexcept {
return reinterpret_cast<T *>(base_type::load(order, scope));
}

operator T *() const noexcept { return load(); }

T *exchange(T *operand, memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
return reinterpret_cast<T *>(base_type::exchange(
reinterpret_cast<uintptr_t>(operand), order, scope));
}

T *fetch_add(difference_type operand,
memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
// TODO: Find a way to avoid compare_exchange here
auto load_order = detail::getLoadOrder(order);
T *expected = load(load_order, scope);
T *desired;
do {
desired = expected + operand;
} while (!compare_exchange_weak(expected, desired, order, scope));
return expected;
}

T *operator+=(difference_type operand) const noexcept {
return fetch_add(operand) + operand;
}

T *operator++(int) const noexcept { return fetch_add(difference_type(1)); }

T *operator++() const noexcept {
return fetch_add(difference_type(1)) + difference_type(1);
}

T *fetch_sub(difference_type operand,
memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
// TODO: Find a way to avoid compare_exchange here
auto load_order = detail::getLoadOrder(order);
T *expected = load(load_order, scope);
T *desired;
do {
desired = expected - operand;
} while (!compare_exchange_weak(expected, desired, order, scope));
return expected;
}

T *operator-=(difference_type operand) const noexcept {
return fetch_sub(operand) - operand;
}

T *operator--(int) const noexcept { return fetch_sub(difference_type(1)); }

T *operator--() const noexcept {
return fetch_sub(difference_type(1)) - difference_type(1);
}

bool
compare_exchange_strong(T *&expected, T *desired, memory_order success,
memory_order failure,
memory_scope scope = default_scope) const noexcept {
return base_type::compare_exchange_strong(
reinterpret_cast<uintptr_t &>(expected),
reinterpret_cast<uintptr_t>(desired), success, failure, scope);
}

bool
compare_exchange_strong(T *&expected, T *desired,
memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
return compare_exchange_strong(expected, desired, order, order, scope);
}

bool
compare_exchange_weak(T *&expected, T *desired, memory_order success,
memory_order failure,
memory_scope scope = default_scope) const noexcept {
return base_type::compare_exchange_weak(
reinterpret_cast<uintptr_t &>(expected),
reinterpret_cast<uintptr_t>(desired), success, failure, scope);
}

bool
compare_exchange_weak(T *&expected, T *desired,
memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
return compare_exchange_weak(expected, desired, order, order, scope);
}

private:
using base_type::ptr;
};

} // namespace detail
Expand Down
50 changes: 24 additions & 26 deletions sycl/test/atomic_ref/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
using namespace sycl;
using namespace sycl::intel;

template <typename T>
template <typename T, typename Difference = T>
void add_fetch_test(queue q, size_t N) {
T sum = 0;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> sum_buf(&sum, 1);
buffer<T> output_buf(output.data(), output.size());
Expand All @@ -27,29 +27,29 @@ void add_fetch_test(queue q, size_t N) {
cgh.parallel_for(range<1>(N), [=](item<1> it) {
int gid = it.get_id(0);
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(sum[0]);
out[gid] = atm.fetch_add(T(1));
out[gid] = atm.fetch_add(Difference(1));
});
});
}

// All work-items increment by 1, so final value should be equal to N
assert(sum == N);
assert(sum == T(N));

// Fetch returns original value: will be in [0, N-1]
auto min_e = std::min_element(output.begin(), output.end());
auto max_e = std::max_element(output.begin(), output.end());
assert(*min_e == 0 && *max_e == N - 1);
assert(*min_e == T(0) && *max_e == T(N - 1));

// Intermediate values should be unique
std::sort(output.begin(), output.end());
assert(std::unique(output.begin(), output.end()) == output.end());
}

template <typename T>
template <typename T, typename Difference = T>
void add_plus_equal_test(queue q, size_t N) {
T sum = 0;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> sum_buf(&sum, 1);
buffer<T> output_buf(output.data(), output.size());
Expand All @@ -60,29 +60,29 @@ void add_plus_equal_test(queue q, size_t N) {
cgh.parallel_for(range<1>(N), [=](item<1> it) {
int gid = it.get_id(0);
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(sum[0]);
out[gid] = atm += T(1);
out[gid] = atm += Difference(1);
});
});
}

// All work-items increment by 1, so final value should be equal to N
assert(sum == N);
assert(sum == T(N));

// += returns updated value: will be in [1, N]
auto min_e = std::min_element(output.begin(), output.end());
auto max_e = std::max_element(output.begin(), output.end());
assert(*min_e == 1 && *max_e == N);
assert(*min_e == T(1) && *max_e == T(N));

// Intermediate values should be unique
std::sort(output.begin(), output.end());
assert(std::unique(output.begin(), output.end()) == output.end());
}

template <typename T>
template <typename T, typename Difference = T>
void add_pre_inc_test(queue q, size_t N) {
T sum = 0;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> sum_buf(&sum, 1);
buffer<T> output_buf(output.data(), output.size());
Expand All @@ -99,23 +99,23 @@ void add_pre_inc_test(queue q, size_t N) {
}

// All work-items increment by 1, so final value should be equal to N
assert(sum == N);
assert(sum == T(N));

// Pre-increment returns updated value: will be in [1, N]
auto min_e = std::min_element(output.begin(), output.end());
auto max_e = std::max_element(output.begin(), output.end());
assert(*min_e == 1 && *max_e == N);
assert(*min_e == T(1) && *max_e == T(N));

// Intermediate values should be unique
std::sort(output.begin(), output.end());
assert(std::unique(output.begin(), output.end()) == output.end());
}

template <typename T>
template <typename T, typename Difference = T>
void add_post_inc_test(queue q, size_t N) {
T sum = 0;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> sum_buf(&sum, 1);
buffer<T> output_buf(output.data(), output.size());
Expand All @@ -132,24 +132,24 @@ void add_post_inc_test(queue q, size_t N) {
}

// All work-items increment by 1, so final value should be equal to N
assert(sum == N);
assert(sum == T(N));

// Post-increment returns original value: will be in [0, N-1]
auto min_e = std::min_element(output.begin(), output.end());
auto max_e = std::max_element(output.begin(), output.end());
assert(*min_e == 0 && *max_e == N - 1);
assert(*min_e == T(0) && *max_e == T(N - 1));

// Intermediate values should be unique
std::sort(output.begin(), output.end());
assert(std::unique(output.begin(), output.end()) == output.end());
}

template <typename T>
template <typename T, typename Difference = T>
void add_test(queue q, size_t N) {
add_fetch_test<T>(q, N);
add_plus_equal_test<T>(q, N);
add_pre_inc_test<T>(q, N);
add_post_inc_test<T>(q, N);
add_fetch_test<T, Difference>(q, N);
add_plus_equal_test<T, Difference>(q, N);
add_pre_inc_test<T, Difference>(q, N);
add_post_inc_test<T, Difference>(q, N);
}

// Floating-point types do not support pre- or post-increment
Expand All @@ -173,8 +173,6 @@ int main() {
}

constexpr int N = 32;

// TODO: Enable missing tests when supported
add_test<int>(q, N);
add_test<unsigned int>(q, N);
add_test<long>(q, N);
Expand All @@ -183,7 +181,7 @@ int main() {
add_test<unsigned long long>(q, N);
add_test<float>(q, N);
add_test<double>(q, N);
//add_test<char*>(q, N);
add_test<char *, ptrdiff_t>(q, N);

std::cout << "Test passed." << std::endl;
}
19 changes: 9 additions & 10 deletions sycl/test/atomic_ref/compare_exchange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,27 @@ class compare_exchange_kernel;

template <typename T>
void compare_exchange_test(queue q, size_t N) {
const T initial = std::numeric_limits<T>::max();
const T initial = T(N);
T compare_exchange = initial;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> compare_exchange_buf(&compare_exchange, 1);
buffer<T> output_buf(output.data(), output.size());

q.submit([&](handler &cgh) {
auto exc = compare_exchange_buf.template get_access<access::mode::read_write>(cgh);
auto out = output_buf.template get_access<access::mode::discard_write>(cgh);
cgh.parallel_for<compare_exchange_kernel<T>>(range<1>(N), [=](item<1> it) {
int gid = it.get_id(0);
cgh.parallel_for<compare_exchange_kernel<T>>(range<1>(N), [=](item<1>
it) {
size_t gid = it.get_id(0);
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(exc[0]);
T result = initial;
T result = T(N); // Avoid copying pointer
bool success = atm.compare_exchange_strong(result, (T)gid);
if (success) {
out[gid] = result;
} else {
out[gid] = gid;
out[gid] = T(gid);
}
});
});
Expand All @@ -45,7 +46,7 @@ void compare_exchange_test(queue q, size_t N) {
assert(std::count(output.begin(), output.end(), initial) == 1);

// All other values should be the index itself or the sentinel value
for (int i = 0; i < N; ++i) {
for (size_t i = 0; i < N; ++i) {
assert(output[i] == T(i) || output[i] == initial);
}
}
Expand All @@ -59,8 +60,6 @@ int main() {
}

constexpr int N = 32;

// TODO: Enable missing tests when supported
compare_exchange_test<int>(q, N);
compare_exchange_test<unsigned int>(q, N);
compare_exchange_test<long>(q, N);
Expand All @@ -69,7 +68,7 @@ int main() {
compare_exchange_test<unsigned long long>(q, N);
compare_exchange_test<float>(q, N);
compare_exchange_test<double>(q, N);
//compare_exchange_test<char*>(q, N);
compare_exchange_test<char *>(q, N);

std::cout << "Test passed." << std::endl;
}
Loading

0 comments on commit a3c3425

Please sign in to comment.