Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Add prototype of atomic_ref<T*> #2177

Merged
merged 4 commits into from
Jul 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/doc/extensions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ DPC++ extensions status:
| [SYCL_INTEL_deduction_guides](deduction_guides/SYCL_INTEL_deduction_guides.asciidoc) | Supported | |
| [SYCL_INTEL_device_specific_kernel_queries](DeviceSpecificKernelQueries/SYCL_INTEL_device_specific_kernel_queries.asciidoc) | Proposal | |
| [SYCL_INTEL_enqueue_barrier](EnqueueBarrier/enqueue_barrier.asciidoc) | Supported(OpenCL, Level Zero) | |
| [SYCL_INTEL_extended_atomics](ExtendedAtomics/SYCL_INTEL_extended_atomics.asciidoc) | Partially supported(OpenCL: CPU, GPU) | Not supported: pointer types |
| [SYCL_INTEL_extended_atomics](ExtendedAtomics/SYCL_INTEL_extended_atomics.asciidoc) | Supported(OpenCL: CPU, GPU) | |
| [SYCL_INTEL_group_algorithms](GroupAlgorithms/SYCL_INTEL_group_algorithms.asciidoc) | Supported(OpenCL) | |
| [SYCL_INTEL_group_mask](./GroupMask/SYCL_INTEL_group_mask.asciidoc) | Proposal | |
| [FPGA selector](IntelFPGA/FPGASelector.md) | Supported | |
Expand Down
136 changes: 130 additions & 6 deletions sycl/include/CL/sycl/intel/atomic_ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ class atomic_ref_base {
static_assert(!(std::is_same<T, short>::value ||
std::is_same<T, unsigned short>::value),
"intel::atomic_ref does not support short type");
static_assert(!std::is_pointer<T>::value,
"intel::atomic_ref does not yet support pointer types");
static_assert(detail::IsValidAtomicAddressSpace<AddressSpace>::value,
"Invalid atomic address_space. Valid address spaces are: "
"global_space, local_space, global_device_space");
Expand Down Expand Up @@ -508,12 +506,138 @@ class atomic_ref_impl<
};

// Partial specialization for pointer types
// Arithmetic is emulated because target's representation of T* is unknown
// TODO: Find a way to use intptr_t or uintptr_t atomics instead
template <typename T, memory_order DefaultOrder, memory_scope DefaultScope,
access::address_space AddressSpace>
class atomic_ref_impl<T *, DefaultOrder, DefaultScope, AddressSpace,
typename detail::enable_if_t<std::is_pointer<T>::value>>
: public atomic_ref_base<T *, DefaultOrder, DefaultScope, AddressSpace> {
// TODO: Implement partial specialization for pointer types
class atomic_ref_impl<T *, DefaultOrder, DefaultScope, AddressSpace>
: public atomic_ref_base<uintptr_t, DefaultOrder, DefaultScope,
AddressSpace> {

private:
using base_type =
atomic_ref_base<uintptr_t, DefaultOrder, DefaultScope, AddressSpace>;

public:
using value_type = T *;
using difference_type = ptrdiff_t;
static constexpr size_t required_alignment = sizeof(T *);
static constexpr bool is_always_lock_free =
detail::IsValidAtomicType<T>::value;
static constexpr memory_order default_read_order =
detail::memory_order_traits<DefaultOrder>::read_order;
static constexpr memory_order default_write_order =
detail::memory_order_traits<DefaultOrder>::write_order;
static constexpr memory_order default_read_modify_write_order = DefaultOrder;
static constexpr memory_scope default_scope = DefaultScope;

using base_type::is_lock_free;

atomic_ref_impl(T *&ref) : base_type(reinterpret_cast<uintptr_t &>(ref)) {}

void store(T *operand, memory_order order = default_write_order,
memory_scope scope = default_scope) const noexcept {
base_type::store(reinterpret_cast<uintptr_t>(operand), order, scope);
}

T *operator=(T *desired) const noexcept {
store(desired);
return desired;
}

T *load(memory_order order = default_read_order,
memory_scope scope = default_scope) const noexcept {
return reinterpret_cast<T *>(base_type::load(order, scope));
}

operator T *() const noexcept { return load(); }

T *exchange(T *operand, memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
return reinterpret_cast<T *>(base_type::exchange(
reinterpret_cast<uintptr_t>(operand), order, scope));
}

T *fetch_add(difference_type operand,
memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
// TODO: Find a way to avoid compare_exchange here
auto load_order = detail::getLoadOrder(order);
T *expected = load(load_order, scope);
T *desired;
do {
desired = expected + operand;
} while (!compare_exchange_weak(expected, desired, order, scope));
return expected;
}

T *operator+=(difference_type operand) const noexcept {
return fetch_add(operand) + operand;
}

T *operator++(int) const noexcept { return fetch_add(difference_type(1)); }

T *operator++() const noexcept {
return fetch_add(difference_type(1)) + difference_type(1);
}

T *fetch_sub(difference_type operand,
memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
// TODO: Find a way to avoid compare_exchange here
auto load_order = detail::getLoadOrder(order);
T *expected = load(load_order, scope);
T *desired;
do {
desired = expected - operand;
} while (!compare_exchange_weak(expected, desired, order, scope));
return expected;
}

T *operator-=(difference_type operand) const noexcept {
return fetch_sub(operand) - operand;
}

T *operator--(int) const noexcept { return fetch_sub(difference_type(1)); }

T *operator--() const noexcept {
return fetch_sub(difference_type(1)) - difference_type(1);
}

bool
compare_exchange_strong(T *&expected, T *desired, memory_order success,
memory_order failure,
memory_scope scope = default_scope) const noexcept {
return base_type::compare_exchange_strong(
reinterpret_cast<uintptr_t &>(expected),
reinterpret_cast<uintptr_t>(desired), success, failure, scope);
}

bool
compare_exchange_strong(T *&expected, T *desired,
memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
return compare_exchange_strong(expected, desired, order, order, scope);
}

bool
compare_exchange_weak(T *&expected, T *desired, memory_order success,
memory_order failure,
memory_scope scope = default_scope) const noexcept {
return base_type::compare_exchange_weak(
reinterpret_cast<uintptr_t &>(expected),
reinterpret_cast<uintptr_t>(desired), success, failure, scope);
}

bool
compare_exchange_weak(T *&expected, T *desired,
memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
return compare_exchange_weak(expected, desired, order, order, scope);
}

private:
using base_type::ptr;
};

} // namespace detail
Expand Down
50 changes: 24 additions & 26 deletions sycl/test/atomic_ref/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
using namespace sycl;
using namespace sycl::intel;

template <typename T>
template <typename T, typename Difference = T>
void add_fetch_test(queue q, size_t N) {
T sum = 0;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> sum_buf(&sum, 1);
buffer<T> output_buf(output.data(), output.size());
Expand All @@ -27,29 +27,29 @@ void add_fetch_test(queue q, size_t N) {
cgh.parallel_for(range<1>(N), [=](item<1> it) {
int gid = it.get_id(0);
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(sum[0]);
out[gid] = atm.fetch_add(T(1));
out[gid] = atm.fetch_add(Difference(1));
});
});
}

// All work-items increment by 1, so final value should be equal to N
assert(sum == N);
assert(sum == T(N));

// Fetch returns original value: will be in [0, N-1]
auto min_e = std::min_element(output.begin(), output.end());
auto max_e = std::max_element(output.begin(), output.end());
assert(*min_e == 0 && *max_e == N - 1);
assert(*min_e == T(0) && *max_e == T(N - 1));

// Intermediate values should be unique
std::sort(output.begin(), output.end());
assert(std::unique(output.begin(), output.end()) == output.end());
}

template <typename T>
template <typename T, typename Difference = T>
void add_plus_equal_test(queue q, size_t N) {
T sum = 0;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> sum_buf(&sum, 1);
buffer<T> output_buf(output.data(), output.size());
Expand All @@ -60,29 +60,29 @@ void add_plus_equal_test(queue q, size_t N) {
cgh.parallel_for(range<1>(N), [=](item<1> it) {
int gid = it.get_id(0);
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(sum[0]);
out[gid] = atm += T(1);
out[gid] = atm += Difference(1);
});
});
}

// All work-items increment by 1, so final value should be equal to N
assert(sum == N);
assert(sum == T(N));

// += returns updated value: will be in [1, N]
auto min_e = std::min_element(output.begin(), output.end());
auto max_e = std::max_element(output.begin(), output.end());
assert(*min_e == 1 && *max_e == N);
assert(*min_e == T(1) && *max_e == T(N));

// Intermediate values should be unique
std::sort(output.begin(), output.end());
assert(std::unique(output.begin(), output.end()) == output.end());
}

template <typename T>
template <typename T, typename Difference = T>
void add_pre_inc_test(queue q, size_t N) {
T sum = 0;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> sum_buf(&sum, 1);
buffer<T> output_buf(output.data(), output.size());
Expand All @@ -99,23 +99,23 @@ void add_pre_inc_test(queue q, size_t N) {
}

// All work-items increment by 1, so final value should be equal to N
assert(sum == N);
assert(sum == T(N));

// Pre-increment returns updated value: will be in [1, N]
auto min_e = std::min_element(output.begin(), output.end());
auto max_e = std::max_element(output.begin(), output.end());
assert(*min_e == 1 && *max_e == N);
assert(*min_e == T(1) && *max_e == T(N));

// Intermediate values should be unique
std::sort(output.begin(), output.end());
assert(std::unique(output.begin(), output.end()) == output.end());
}

template <typename T>
template <typename T, typename Difference = T>
void add_post_inc_test(queue q, size_t N) {
T sum = 0;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> sum_buf(&sum, 1);
buffer<T> output_buf(output.data(), output.size());
Expand All @@ -132,24 +132,24 @@ void add_post_inc_test(queue q, size_t N) {
}

// All work-items increment by 1, so final value should be equal to N
assert(sum == N);
assert(sum == T(N));

// Post-increment returns original value: will be in [0, N-1]
auto min_e = std::min_element(output.begin(), output.end());
auto max_e = std::max_element(output.begin(), output.end());
assert(*min_e == 0 && *max_e == N - 1);
assert(*min_e == T(0) && *max_e == T(N - 1));

// Intermediate values should be unique
std::sort(output.begin(), output.end());
assert(std::unique(output.begin(), output.end()) == output.end());
}

template <typename T>
template <typename T, typename Difference = T>
void add_test(queue q, size_t N) {
add_fetch_test<T>(q, N);
add_plus_equal_test<T>(q, N);
add_pre_inc_test<T>(q, N);
add_post_inc_test<T>(q, N);
add_fetch_test<T, Difference>(q, N);
add_plus_equal_test<T, Difference>(q, N);
add_pre_inc_test<T, Difference>(q, N);
add_post_inc_test<T, Difference>(q, N);
}

// Floating-point types do not support pre- or post-increment
Expand All @@ -173,8 +173,6 @@ int main() {
}

constexpr int N = 32;

// TODO: Enable missing tests when supported
add_test<int>(q, N);
add_test<unsigned int>(q, N);
add_test<long>(q, N);
Expand All @@ -183,7 +181,7 @@ int main() {
add_test<unsigned long long>(q, N);
add_test<float>(q, N);
add_test<double>(q, N);
//add_test<char*>(q, N);
add_test<char *, ptrdiff_t>(q, N);

std::cout << "Test passed." << std::endl;
}
19 changes: 9 additions & 10 deletions sycl/test/atomic_ref/compare_exchange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,27 @@ class compare_exchange_kernel;

template <typename T>
void compare_exchange_test(queue q, size_t N) {
const T initial = std::numeric_limits<T>::max();
const T initial = T(N);
T compare_exchange = initial;
std::vector<T> output(N);
std::fill(output.begin(), output.end(), 0);
std::fill(output.begin(), output.end(), T(0));
{
buffer<T> compare_exchange_buf(&compare_exchange, 1);
buffer<T> output_buf(output.data(), output.size());

q.submit([&](handler &cgh) {
auto exc = compare_exchange_buf.template get_access<access::mode::read_write>(cgh);
auto out = output_buf.template get_access<access::mode::discard_write>(cgh);
cgh.parallel_for<compare_exchange_kernel<T>>(range<1>(N), [=](item<1> it) {
int gid = it.get_id(0);
cgh.parallel_for<compare_exchange_kernel<T>>(range<1>(N), [=](item<1>
it) {
size_t gid = it.get_id(0);
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(exc[0]);
T result = initial;
T result = T(N); // Avoid copying pointer
bool success = atm.compare_exchange_strong(result, (T)gid);
if (success) {
out[gid] = result;
} else {
out[gid] = gid;
out[gid] = T(gid);
}
});
});
Expand All @@ -45,7 +46,7 @@ void compare_exchange_test(queue q, size_t N) {
assert(std::count(output.begin(), output.end(), initial) == 1);

// All other values should be the index itself or the sentinel value
for (int i = 0; i < N; ++i) {
for (size_t i = 0; i < N; ++i) {
assert(output[i] == T(i) || output[i] == initial);
}
}
Expand All @@ -59,8 +60,6 @@ int main() {
}

constexpr int N = 32;

// TODO: Enable missing tests when supported
compare_exchange_test<int>(q, N);
compare_exchange_test<unsigned int>(q, N);
compare_exchange_test<long>(q, N);
Expand All @@ -69,7 +68,7 @@ int main() {
compare_exchange_test<unsigned long long>(q, N);
compare_exchange_test<float>(q, N);
compare_exchange_test<double>(q, N);
//compare_exchange_test<char*>(q, N);
compare_exchange_test<char *>(q, N);

std::cout << "Test passed." << std::endl;
}
Loading