sycl does not support 16 bit atomic. throw error or fallback to working version
yhmtsai committed Nov 25, 2024
1 parent 5c9999f commit cef03ac
Showing 5 changed files with 155 additions and 128 deletions.
4 changes: 0 additions & 4 deletions dpcpp/components/atomic.dp.hpp
@@ -169,8 +169,6 @@ __dpct_inline__ ResultType reinterpret(ValueType val)
GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned long long int);
// Support 32-bit ATOMIC_ADD
GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned int);
// Support 16-bit ATOMIC_ADD
// GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned short);


#undef GKO_BIND_ATOMIC_HELPER_STRUCTURE
@@ -239,8 +237,6 @@ struct atomic_helper<
GKO_BIND_ATOMIC_MAX_STRUCTURE(unsigned long long int);
// Support 32-bit ATOMIC_MAX
GKO_BIND_ATOMIC_MAX_STRUCTURE(unsigned int);
// Support 16-bit ATOMIC_MAX
// GKO_BIND_ATOMIC_MAX_STRUCTURE(unsigned short);


#undef GKO_BIND_ATOMIC_MAX_STRUCTURE
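The commented-out unsigned short bindings are deleted rather than enabled because SYCL 2020 only guarantees atomic operations on 32- and 64-bit types. A minimal sketch of how a 16-bit atomic add could be emulated on top of a 32-bit sycl::atomic_ref and compare-exchange, purely illustrative and not part of this commit; the function name is hypothetical and it assumes device code with relaxed ordering:

// Illustrative sketch (not part of this commit): emulating a 16-bit
// atomic add with a 32-bit compare-exchange, since sycl::atomic_ref
// only supports 4- and 8-byte types.
#include <sycl/sycl.hpp>
#include <cstdint>

inline unsigned short emulated_atomic_add_16(unsigned short* addr,
                                             unsigned short val)
{
    // Address of the aligned 32-bit word containing the 16-bit target.
    const auto uaddr = reinterpret_cast<std::uintptr_t>(addr);
    auto* word_ptr =
        reinterpret_cast<std::uint32_t*>(uaddr & ~std::uintptr_t{3});
    // Whether the target sits in the high or the low half of the word.
    const int shift = (uaddr & 2) ? 16 : 0;
    sycl::atomic_ref<std::uint32_t, sycl::memory_order::relaxed,
                     sycl::memory_scope::device,
                     sycl::access::address_space::global_space>
        word(*word_ptr);
    std::uint32_t expected = word.load();
    std::uint32_t desired;
    do {
        const auto old16 = static_cast<unsigned short>(expected >> shift);
        const auto new16 = static_cast<unsigned short>(old16 + val);
        desired = (expected & ~(std::uint32_t{0xffff} << shift)) |
                  (std::uint32_t{new16} << shift);
        // compare_exchange_strong reloads `expected` on failure.
    } while (!word.compare_exchange_strong(expected, desired));
    return static_cast<unsigned short>(expected >> shift);
}

The commit takes the simpler route instead: wherever the half path would need an atomic, it reports the case as unsupported or falls back to a working kernel, as the remaining files show.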
98 changes: 54 additions & 44 deletions dpcpp/matrix/coo_kernels.dp.cpp
@@ -291,27 +291,32 @@ void spmv2(std::shared_ptr<const DpcppExecutor> exec,
const dim3 coo_block(config::warp_size, warps_in_block, 1);
const auto nwarps = host_kernel::calculate_nwarps(exec, nnz);

if (nwarps > 0) {
if (b_ncols < 4) {
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
abstract_spmv(coo_grid, coo_block, 0, exec->get_queue(), nnz,
num_lines, as_device_type(a->get_const_values()),
a->get_const_col_idxs(), a->get_const_row_idxs(),
as_device_type(b->get_const_values()),
b->get_stride(), as_device_type(c->get_values()),
c->get_stride());
} else {
int num_elems =
ceildiv(nnz, nwarps * config::warp_size) * config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));
abstract_spmm(coo_grid, coo_block, 0, exec->get_queue(), nnz,
num_elems, as_device_type(a->get_const_values()),
a->get_const_col_idxs(), a->get_const_row_idxs(),
b_ncols, as_device_type(b->get_const_values()),
b->get_stride(), as_device_type(c->get_values()),
c->get_stride());
// 16-bit atomics are not supported
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(c);
} else {
if (nwarps > 0) {
if (b_ncols < 4) {
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
abstract_spmv(coo_grid, coo_block, 0, exec->get_queue(), nnz,
num_lines, as_device_type(a->get_const_values()),
a->get_const_col_idxs(), a->get_const_row_idxs(),
as_device_type(b->get_const_values()),
b->get_stride(), as_device_type(c->get_values()),
c->get_stride());
} else {
int num_elems = ceildiv(nnz, nwarps * config::warp_size) *
config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));
abstract_spmm(coo_grid, coo_block, 0, exec->get_queue(), nnz,
num_elems, as_device_type(a->get_const_values()),
a->get_const_col_idxs(), a->get_const_row_idxs(),
b_ncols, as_device_type(b->get_const_values()),
b->get_stride(), as_device_type(c->get_values()),
c->get_stride());
}
}
}
}
@@ -332,29 +337,34 @@ void advanced_spmv2(std::shared_ptr<const DpcppExecutor> exec,
const dim3 coo_block(config::warp_size, warps_in_block, 1);
const auto b_ncols = b->get_size()[1];

if (nwarps > 0) {
if (b_ncols < 4) {
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
abstract_spmv(coo_grid, coo_block, 0, exec->get_queue(), nnz,
num_lines, as_device_type(alpha->get_const_values()),
as_device_type(a->get_const_values()),
a->get_const_col_idxs(), a->get_const_row_idxs(),
as_device_type(b->get_const_values()),
b->get_stride(), as_device_type(c->get_values()),
c->get_stride());
} else {
int num_elems =
ceildiv(nnz, nwarps * config::warp_size) * config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));
abstract_spmm(coo_grid, coo_block, 0, exec->get_queue(), nnz,
num_elems, as_device_type(alpha->get_const_values()),
as_device_type(a->get_const_values()),
a->get_const_col_idxs(), a->get_const_row_idxs(),
b_ncols, as_device_type(b->get_const_values()),
b->get_stride(), as_device_type(c->get_values()),
c->get_stride());
// 16-bit atomics are not supported
if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
GKO_NOT_SUPPORTED(c);
} else {
if (nwarps > 0) {
if (b_ncols < 4) {
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
abstract_spmv(
coo_grid, coo_block, 0, exec->get_queue(), nnz, num_lines,
as_device_type(alpha->get_const_values()),
as_device_type(a->get_const_values()),
a->get_const_col_idxs(), a->get_const_row_idxs(),
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
} else {
int num_elems = ceildiv(nnz, nwarps * config::warp_size) *
config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));
abstract_spmm(
coo_grid, coo_block, 0, exec->get_queue(), nnz, num_elems,
as_device_type(alpha->get_const_values()),
as_device_type(a->get_const_values()),
a->get_const_col_idxs(), a->get_const_row_idxs(), b_ncols,
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
}
}
}
}
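Both COO kernels rely on the same guard (checking remove_complex<ValueType>, so complex<half> is rejected as well), and the choice of `if constexpr` over a plain `if` matters: the discarded branch is never instantiated, so no atomic kernel is ever compiled for half. A self-contained sketch of the pattern, with hypothetical names standing in for the Ginkgo kernels:

#include <stdexcept>
#include <type_traits>

struct half {};  // stand-in for gko::half

template <typename ValueType>
void launch_kernel_with_atomics()
{
    // stands in for a kernel whose body requires atomic_add<ValueType>
}

template <typename ValueType>
void spmv_dispatch()
{
    if constexpr (std::is_same_v<ValueType, half>) {
        // A runtime `if` would still instantiate
        // launch_kernel_with_atomics<half> and fail to compile;
        // `if constexpr` discards that instantiation entirely.
        throw std::runtime_error("16-bit atomics are not supported");
    } else {
        launch_kernel_with_atomics<ValueType>();
    }
}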
92 changes: 51 additions & 41 deletions dpcpp/matrix/csr_kernels.dp.cpp
@@ -1353,7 +1353,7 @@ GKO_ENABLE_IMPLEMENTATION_SELECTION(select_classical_spmv, classical_spmv);

template <typename MatrixValueType, typename InputValueType,
typename OutputValueType, typename IndexType>
void load_balance_spmv(std::shared_ptr<const DpcppExecutor> exec,
bool load_balance_spmv(std::shared_ptr<const DpcppExecutor> exec,
const matrix::Csr<MatrixValueType, IndexType>* a,
const matrix::Dense<InputValueType>* b,
matrix::Dense<OutputValueType>* c,
@@ -1363,40 +1363,49 @@ void load_balance_spmv(std::shared_ptr<const DpcppExecutor> exec,
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;

if (beta) {
dense::scale(exec, beta, c);
// 16-bit atomics are not supported
if constexpr (std::is_same_v<remove_complex<OutputValueType>, half>) {
return false;
} else {
dense::fill(exec, c, zero<OutputValueType>());
}
const IndexType nwarps = a->get_num_srow_elements();
if (nwarps > 0) {
const dim3 csr_block(config::warp_size, warps_in_block, 1);
const dim3 csr_grid(ceildiv(nwarps, warps_in_block), b->get_size()[1]);
const auto a_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(a);
const auto b_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(b);
auto c_vals = acc::helper::build_rrm_accessor<arithmetic_type>(c);
if (alpha) {
if (csr_grid.x > 0 && csr_grid.y > 0) {
csr::kernel::abstract_spmv(
csr_grid, csr_block, 0, exec->get_queue(), nwarps,
static_cast<IndexType>(a->get_size()[0]),
as_device_type(alpha->get_const_values()),
acc::as_device_range(a_vals), a->get_const_col_idxs(),
a->get_const_row_ptrs(), a->get_const_srow(),
acc::as_device_range(b_vals), acc::as_device_range(c_vals));
}
if (beta) {
dense::scale(exec, beta, c);
} else {
if (csr_grid.x > 0 && csr_grid.y > 0) {
csr::kernel::abstract_spmv(
csr_grid, csr_block, 0, exec->get_queue(), nwarps,
static_cast<IndexType>(a->get_size()[0]),
acc::as_device_range(a_vals), a->get_const_col_idxs(),
a->get_const_row_ptrs(), a->get_const_srow(),
acc::as_device_range(b_vals), acc::as_device_range(c_vals));
dense::fill(exec, c, zero<OutputValueType>());
}
const IndexType nwarps = a->get_num_srow_elements();
if (nwarps > 0) {
const dim3 csr_block(config::warp_size, warps_in_block, 1);
const dim3 csr_grid(ceildiv(nwarps, warps_in_block),
b->get_size()[1]);
const auto a_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(a);
const auto b_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(b);
auto c_vals = acc::helper::build_rrm_accessor<arithmetic_type>(c);
if (alpha) {
if (csr_grid.x > 0 && csr_grid.y > 0) {
csr::kernel::abstract_spmv(
csr_grid, csr_block, 0, exec->get_queue(), nwarps,
static_cast<IndexType>(a->get_size()[0]),
as_device_type(alpha->get_const_values()),
acc::as_device_range(a_vals), a->get_const_col_idxs(),
a->get_const_row_ptrs(), a->get_const_srow(),
acc::as_device_range(b_vals),
acc::as_device_range(c_vals));
}
} else {
if (csr_grid.x > 0 && csr_grid.y > 0) {
csr::kernel::abstract_spmv(
csr_grid, csr_block, 0, exec->get_queue(), nwarps,
static_cast<IndexType>(a->get_size()[0]),
acc::as_device_range(a_vals), a->get_const_col_idxs(),
a->get_const_row_ptrs(), a->get_const_srow(),
acc::as_device_range(b_vals),
acc::as_device_range(c_vals));
}
}
}
return true;
}
}

@@ -1502,9 +1511,7 @@ void spmv(std::shared_ptr<const DpcppExecutor> exec,
dense::fill(exec, c, zero<OutputValueType>());
return;
}
if (a->get_strategy()->get_name() == "load_balance") {
host_kernel::load_balance_spmv(exec, a, b, c);
} else if (a->get_strategy()->get_name() == "merge_path") {
if (a->get_strategy()->get_name() == "merge_path") {
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;
int items_per_thread =
@@ -1518,8 +1525,10 @@ void spmv(std::shared_ptr<const DpcppExecutor> exec,
syn::value_list<int>(), syn::type_list<>(), exec, a, b, c);
} else {
bool use_classical = true;
if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
if (a->get_strategy()->get_name() == "load_balance") {
use_classical = !host_kernel::load_balance_spmv(exec, a, b, c);
} else if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
use_classical = !host_kernel::try_sparselib_spmv(exec, a, b, c);
}
if (use_classical) {
@@ -1571,9 +1580,7 @@ void advanced_spmv(std::shared_ptr<const DpcppExecutor> exec,
dense::scale(exec, beta, c);
return;
}
if (a->get_strategy()->get_name() == "load_balance") {
host_kernel::load_balance_spmv(exec, a, b, c, alpha, beta);
} else if (a->get_strategy()->get_name() == "merge_path") {
if (a->get_strategy()->get_name() == "merge_path") {
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;
int items_per_thread =
@@ -1588,8 +1595,11 @@ void advanced_spmv(std::shared_ptr<const DpcppExecutor> exec,
beta);
} else {
bool use_classical = true;
if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
if (a->get_strategy()->get_name() == "load_balance") {
use_classical =
!host_kernel::load_balance_spmv(exec, a, b, c, alpha, beta);
} else if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
use_classical =
!host_kernel::try_sparselib_spmv(exec, a, b, c, alpha, beta);
}
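With load_balance_spmv now reporting success through its bool return, both callers treat it the same way as try_sparselib_spmv: a strategy that cannot run (here, half precision) falls back to the classical kernel instead of aborting. The resulting control flow, sketched with hypothetical signatures; the real functions also receive the executor and the matrix and vector arguments:

#include <string>

// Hypothetical stand-ins for the kernels in this file:
bool load_balance_spmv() { return false; }   // false: cannot run (half)
bool try_sparselib_spmv() { return true; }   // false: oneMKL declines
void classical_spmv() {}                     // always-working fallback

void spmv_dispatch(const std::string& strategy)
{
    bool handled = false;
    if (strategy == "load_balance") {
        handled = load_balance_spmv();
    } else if (strategy == "sparselib" || strategy == "cusparse") {
        handled = try_sparselib_spmv();
    }
    if (!handled) {
        // anything unhandled ends up in the classical kernel
        classical_spmv();
    }
}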
74 changes: 43 additions & 31 deletions dpcpp/matrix/ell_kernels.dp.cpp
@@ -97,7 +97,7 @@ void spmv_kernel(
using arithmetic_type = typename a_accessor::arithmetic_type;
const auto tidx = thread::get_thread_id_flat(item_ct1);
const decltype(tidx) column_id = item_ct1.get_group(1);
if (num_thread_per_worker == 1) {
if constexpr (num_thread_per_worker == 1) {
// Specialization for num_thread_per_worker == 1: it needs no shared
// memory, __syncthreads, or atomic_add
if (tidx < num_rows) {
@@ -311,37 +311,49 @@ void abstract_spmv(syn::value_list<int, info>,
const dim3 grid_size(ceildiv(nrows * num_worker_per_row, block_size.x),
b->get_size()[1], 1);

const auto a_vals = gko::acc::range<a_accessor>(
std::array<acc::size_type, 1>{{static_cast<acc::size_type>(
num_stored_elements_per_row * stride)}},
a->get_const_values());
const auto b_vals = gko::acc::range<b_accessor>(
std::array<acc::size_type, 2>{
{static_cast<acc::size_type>(b->get_size()[0]),
static_cast<acc::size_type>(b->get_size()[1])}},
b->get_const_values(),
std::array<acc::size_type, 1>{
{static_cast<acc::size_type>(b->get_stride())}});

if (alpha == nullptr && beta == nullptr) {
kernel::spmv<num_thread_per_worker, atomic>(
grid_size, block_size, 0, exec->get_queue(), nrows,
num_worker_per_row, acc::as_device_range(a_vals),
a->get_const_col_idxs(), stride, num_stored_elements_per_row,
acc::as_device_range(b_vals), as_device_type(c->get_values()),
c->get_stride());
} else if (alpha != nullptr && beta != nullptr) {
const auto alpha_val = gko::acc::range<a_accessor>(
std::array<acc::size_type, 1>{1}, alpha->get_const_values());
kernel::spmv<num_thread_per_worker, atomic>(
grid_size, block_size, 0, exec->get_queue(), nrows,
num_worker_per_row, acc::as_device_range(alpha_val),
acc::as_device_range(a_vals), a->get_const_col_idxs(), stride,
num_stored_elements_per_row, acc::as_device_range(b_vals),
as_device_type(beta->get_const_values()),
as_device_type(c->get_values()), c->get_stride());
} else {
// 16-bit atomics are not supported.
// Shared-memory atomics are used when num_thread_per_worker is not 1;
// if atomic is also true, atomics are used on out_vector as well.
constexpr bool shared_half =
std::is_same_v<remove_complex<arithmetic_type>, half>;
constexpr bool atomic_half_out =
atomic && std::is_same_v<remove_complex<OutputValueType>, half>;
if constexpr (num_thread_per_worker != 1 &&
(shared_half || atomic_half_out)) {
GKO_KERNEL_NOT_FOUND;
} else {
const auto a_vals = gko::acc::range<a_accessor>(
std::array<acc::size_type, 1>{{static_cast<acc::size_type>(
num_stored_elements_per_row * stride)}},
a->get_const_values());
const auto b_vals = gko::acc::range<b_accessor>(
std::array<acc::size_type, 2>{
{static_cast<acc::size_type>(b->get_size()[0]),
static_cast<acc::size_type>(b->get_size()[1])}},
b->get_const_values(),
std::array<acc::size_type, 1>{
{static_cast<acc::size_type>(b->get_stride())}});

if (alpha == nullptr && beta == nullptr) {
kernel::spmv<num_thread_per_worker, atomic>(
grid_size, block_size, 0, exec->get_queue(), nrows,
num_worker_per_row, acc::as_device_range(a_vals),
a->get_const_col_idxs(), stride, num_stored_elements_per_row,
acc::as_device_range(b_vals), as_device_type(c->get_values()),
c->get_stride());
} else if (alpha != nullptr && beta != nullptr) {
const auto alpha_val = gko::acc::range<a_accessor>(
std::array<acc::size_type, 1>{1}, alpha->get_const_values());
kernel::spmv<num_thread_per_worker, atomic>(
grid_size, block_size, 0, exec->get_queue(), nrows,
num_worker_per_row, acc::as_device_range(alpha_val),
acc::as_device_range(a_vals), a->get_const_col_idxs(), stride,
num_stored_elements_per_row, acc::as_device_range(b_vals),
as_device_type(beta->get_const_values()),
as_device_type(c->get_values()), c->get_stride());
} else {
GKO_KERNEL_NOT_FOUND;
}
}
}

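The ELL guard is finer-grained than the COO one: only configurations that actually need atomics are rejected, so the single-thread-per-worker case remains available for half. The condition, restated as a standalone predicate under simplified types (a plain struct stands in for gko::half, and the remove_complex step is omitted):

#include <type_traits>

struct half {};  // stand-in for gko::half

// Sketch of the gate above: a worker of more than one thread reduces
// through shared-memory atomics, and `atomic` additionally writes the
// output vector atomically; neither works for 16-bit types in SYCL.
template <int num_thread_per_worker, bool atomic, typename ArithmeticType,
          typename OutputValueType>
constexpr bool ell_config_supported()
{
    constexpr bool shared_half = std::is_same_v<ArithmeticType, half>;
    constexpr bool atomic_half_out =
        atomic && std::is_same_v<OutputValueType, half>;
    return num_thread_per_worker == 1 || !(shared_half || atomic_half_out);
}

static_assert(ell_config_supported<1, false, half, half>());
static_assert(!ell_config_supported<8, false, half, float>());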
15 changes: 7 additions & 8 deletions dpcpp/solver/idr_kernels.dp.cpp
@@ -682,11 +682,11 @@ void update_g_and_u(std::shared_ptr<const DpcppExecutor> exec,
as_device_type(alpha->get_values()),
stop_status->get_const_data());
};
if constexpr (std::is_same_v<ValueType, half> ||
is_complex<ValueType>()) {
gko_impl();
// 16-bit atomics are not supported
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(alpha);
} else {
if (nrhs > 1) {
if (nrhs > 1 || is_complex<ValueType>()) {
gko_impl();
} else {
onemkl::dot(*exec->get_queue(), size, p_i, 1, g_k->get_values(),
@@ -739,11 +739,10 @@ void update_m(std::shared_ptr<const DpcppExecutor> exec, const size_type nrhs,
g_k->get_stride(), as_device_type(m_i),
stop_status->get_const_data());
};
if constexpr (std::is_same_v<ValueType, half> ||
is_complex<ValueType>()) {
gko_impl();
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(m_i);
} else {
if (nrhs > 1) {
if (nrhs > 1 || is_complex<ValueType>()) {
gko_impl();
} else {
onemkl::dot(*exec->get_queue(), size, as_device_type(p_i), 1,
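After this change both IDR kernels dispatch the same way: half is rejected outright (the custom kernel needs atomic_add), complex values and multiple right-hand sides take the custom reduction kernel, and only the real single-rhs case goes through onemkl::dot, presumably because oneMKL's plain dot routine covers real types only. A sketch of the branch order, assuming std::complex in place of Ginkgo's complex type and a plain struct for gko::half:

#include <complex>
#include <stdexcept>
#include <type_traits>

struct half {};  // stand-in for gko::half

template <typename T>
constexpr bool is_complex_value =
    std::is_same_v<T, std::complex<float>> ||
    std::is_same_v<T, std::complex<double>>;

template <typename ValueType>
void update_dispatch(int nrhs)
{
    if constexpr (std::is_same_v<ValueType, half>) {
        // the custom kernel needs atomic_add, which has no 16-bit version
        throw std::runtime_error("16-bit atomics are not supported");
    } else {
        if (nrhs > 1 || is_complex_value<ValueType>) {
            // custom reduction kernel (gko_impl in the diff above)
        } else {
            // real value type, single rhs: onemkl::dot
        }
    }
}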
