diff --git a/CHANGELOG.md b/CHANGELOG.md index b5709856e..7a9ebe912 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,10 @@ Full documentation for rocSOLVER is available at the [rocSOLVER documentation](h * Improved the performance of SYEVJ * Improved the performance of GEQRF +### Resolved issues +### Known issues +### Upcoming changes + ## rocSOLVER 3.27.0 for ROCm 6.3.0 diff --git a/library/src/lapack/roclapack_syevj_heevj.hpp b/library/src/lapack/roclapack_syevj_heevj.hpp index 86913f201..007bdf17c 100644 --- a/library/src/lapack/roclapack_syevj_heevj.hpp +++ b/library/src/lapack/roclapack_syevj_heevj.hpp @@ -46,6 +46,14 @@ ROCSOLVER_BEGIN_NAMESPACE #define SYEVJ_BDIM 1024 // Max number of threads per thread-block used in syevj_small kernel +static int get_num_cu(int deviceId = 0) +{ + int ival = 0; + auto const attr = hipDeviceAttributeMultiprocessorCount; + HIP_CHECK(hipDeviceGetAttribute(&ival, attr, deviceId)); + return (ival); +} + /** SYEVJ_SMALL_KERNEL/RUN_SYEVJ applies the Jacobi eigenvalue algorithm to matrices of size n <= SYEVJ_BLOCKED_SWITCH. For each off-diagonal element A[i,j], a Jacobi rotation J is calculated so that (J'AJ)[i,j] = 0. J only affects rows i and j, and J' only affects @@ -592,14 +600,14 @@ ROCSOLVER_KERNEL void syevj_init(const rocblas_evect evect, will work on a separate diagonal block; for a matrix consisting of b * b blocks, use b thread blocks in x. **/ template -ROCSOLVER_KERNEL void syevj_diag_kernel(const rocblas_int n, - U AA, - const rocblas_int shiftA, - const rocblas_int lda, - const rocblas_stride strideA, - const S eps, - T* JA, - rocblas_int* completed) +ROCSOLVER_KERNEL void syevj_diag_kernel_org(const rocblas_int n, + U AA, + const rocblas_int shiftA, + const rocblas_int lda, + const rocblas_stride strideA, + const S eps, + T* JA, + rocblas_int* completed) { rocblas_int tix = hipThreadIdx_x; rocblas_int tiy = hipThreadIdx_y; @@ -675,101 +683,1173 @@ ROCSOLVER_KERNEL void syevj_diag_kernel(const rocblas_int n, } else { - g = 2 * mag; - f = std::real(A[j + j * lda] - A[i + i * lda]); - f += (f < 0) ? -std::hypot(f, g) : std::hypot(f, g); - lartg(f, g, c, s, r); - s1 = s * aij / mag; - } + g = 2 * mag; + f = std::real(A[j + j * lda] - A[i + i * lda]); + f += (f < 0) ? -std::hypot(f, g) : std::hypot(f, g); + lartg(f, g, c, s, r); + s1 = s * aij / mag; + } + + sh_cosines[tix] = c; + sh_sines[tix] = s1; + } + __syncthreads(); + + if(i < n && j < n) + { + c = sh_cosines[tix]; + s1 = sh_sines[tix]; + s2 = conj(s1); + + // store J row-wise + if(J) + { + xx1 = i - offset; + xx2 = j - offset; + temp1 = J[xx1 + yy1 * nb_max]; + temp2 = J[xx2 + yy1 * nb_max]; + J[xx1 + yy1 * nb_max] = c * temp1 + s2 * temp2; + J[xx2 + yy1 * nb_max] = -s1 * temp1 + c * temp2; + + if(y2 < n) + { + temp1 = J[xx1 + yy2 * nb_max]; + temp2 = J[xx2 + yy2 * nb_max]; + J[xx1 + yy2 * nb_max] = c * temp1 + s2 * temp2; + J[xx2 + yy2 * nb_max] = -s1 * temp1 + c * temp2; + } + } + + // apply J from the right + temp1 = A[y1 + i * lda]; + temp2 = A[y1 + j * lda]; + A[y1 + i * lda] = c * temp1 + s2 * temp2; + A[y1 + j * lda] = -s1 * temp1 + c * temp2; + + if(y2 < n) + { + temp1 = A[y2 + i * lda]; + temp2 = A[y2 + j * lda]; + A[y2 + i * lda] = c * temp1 + s2 * temp2; + A[y2 + j * lda] = -s1 * temp1 + c * temp2; + } + } + __syncthreads(); + + if(i < n && j < n) + { + // apply J' from the left + temp1 = A[i + y1 * lda]; + temp2 = A[j + y1 * lda]; + A[i + y1 * lda] = c * temp1 + s1 * temp2; + A[j + y1 * lda] = -s2 * temp1 + c * temp2; + + if(y2 < n) + { + temp1 = A[i + y2 * lda]; + temp2 = A[j + y2 * lda]; + A[i + y2 * lda] = c * temp1 + s1 * temp2; + A[j + y2 * lda] = -s2 * temp1 + c * temp2; + } + } + __syncthreads(); + + if(tiy == 0 && i < n && j < n) + { + // round aij and aji to zero + A[i + j * lda] = 0; + A[j + i * lda] = 0; + } + + // cycle top/bottom pairs + if(tix == 1) + i = sh_bottom[0]; + else if(tix > 1) + i = sh_top[tix - 1]; + if(tix == half_nb - 1) + j = sh_top[half_nb - 1]; + else + j = sh_bottom[tix + 1]; + __syncthreads(); + + if(tiy == 0) + { + sh_top[tix] = i; + sh_bottom[tix] = j; + } + } +} + +template +ROCSOLVER_KERNEL void syevj_diag_kernel(const rocblas_int n, + const rocblas_int nb_max, + U AA, + const rocblas_int shiftA, + const rocblas_int lda, + const rocblas_stride strideA, + const S eps, + T* JA, + rocblas_int* completed, + const rocblas_int batch_count, + size_t lmem_size = 64 * 1024) +{ + typedef rocblas_int I; + + I const bid_start = hipBlockIdx_z; + I const bid_inc = hipGridDim_z; + + I const tix_start = hipThreadIdx_x; + I const tix_inc = hipBlockDim_x; + I const tiy_start = hipThreadIdx_y; + I const tiy_inc = hipBlockDim_y; + + // ----------------------------------- + // tixy is 1D composite [tix,tiy] index + // for all threads in the thread block + // ----------------------------------- + I const tixy_start = tix_start + tiy_start * tix_inc; + I const tixy_inc = tix_inc * tiy_inc; + + I const ibx_start = hipBlockIdx_x; + I const ibx_inc = hipGridDim_x; + + auto ceil = [](auto n, auto nb) { return (((n - 1) / nb) + 1); }; + auto const blocks = ceil(n, nb_max); + + auto const half_n = ceil(n, 2); + + // -------------------------------------- + // return size of the i-th diagonal block + // -------------------------------------- + auto bsize = [=](auto iblock) { + auto const nb_last = n - (blocks - 1) * nb_max; + bool const is_last_block = (iblock == (blocks - 1)); + return ((is_last_block) ? nb_last : nb_max); + }; + + // ----------------------- + // arrays in shared memory + // ----------------------- + + extern __shared__ double lmem[]; + std::byte* pfree = reinterpret_cast(&(lmem[0])); + + auto const max_lds = lmem_size; + auto const max_npairs = ceil(nb_max, 2); + + size_t total_bytes = 0; + S* const sh_cosines = reinterpret_cast(pfree); + pfree += sizeof(S) * max_npairs; + total_bytes += sizeof(S) * max_npairs; + + T* const sh_sines = reinterpret_cast(pfree); + pfree += sizeof(T) * max_npairs; + total_bytes += sizeof(T) * max_npairs; + + I* const sh_top = reinterpret_cast(pfree); + pfree += sizeof(I) * max_npairs; + total_bytes += sizeof(I) * max_npairs; + + I* const sh_bottom = reinterpret_cast(pfree); + pfree += sizeof(I) * max_npairs; + total_bytes += sizeof(I) * max_npairs; + + assert(total_bytes <= max_lds); + + // ------------ + // alocate Ash[] + // ------------ + auto const ldAsh = nb_max; + size_t const size_Ash = sizeof(T) * ldAsh * nb_max; + T* const Ash_ = reinterpret_cast(pfree); + pfree += size_Ash; + total_bytes += size_Ash; + bool const use_Ash = (total_bytes <= max_lds); + auto Ash = [=](auto i, auto j) -> T& { return (Ash_[i + j * ldAsh]); }; + + // -------------- + // allocate Jsh[] + // -------------- + auto const ldJsh = nb_max; + size_t const size_Jsh = sizeof(T) * ldJsh * nb_max; + T* const Jsh_ = reinterpret_cast(pfree); + pfree += size_Jsh; + total_bytes += size_Jsh; + bool const use_Jsh = (total_bytes <= max_lds); + auto Jsh = [=](auto i, auto j) -> T& { return (Jsh_[i + j * ldJsh]); }; + + S const small_num = get_safemin() / eps; + S const sqrt_small_num = std::sqrt(small_num); + + for(auto bid = bid_start; bid < batch_count; bid += bid_inc) + { + if(completed[bid + 1]) + continue; + + T* const A_ = load_ptr_batch(AA, bid, shiftA, strideA); + auto A = [=](auto ia, auto ja) -> T& { return (A_[ia + ja * static_cast(lda)]); }; + + for(auto iblock = ibx_start; iblock < blocks; iblock += ibx_inc) + { + auto const jid = iblock + bid * blocks; + T* const J_ = (JA ? JA + (jid * nb_max * nb_max) : nullptr); + bool const use_J = (J_ != nullptr); + + auto const ldj = nb_max; + auto J = [=](auto i, auto j) -> T& { return (J_[i + j * ldj]); }; + + T* const Jmat_ = (use_Jsh) ? Jsh_ : J_; + auto const ldJmat = (use_Jsh) ? ldJsh : ldj; + auto Jmat = [=](auto i, auto j) -> T& { return (Jmat_[i + j * ldJmat]); }; + + auto const offset = iblock * nb_max; + // auto const nb = std::min(2 * half_n - offset, nb_max); + auto const half_nb = ceil(bsize(iblock), 2); + auto const npairs = half_nb; + + // ---------------------------------- + // Note: (i,j) are local index values + // ---------------------------------- + auto const Amat_ = (use_Ash) ? Ash_ : A_ + (offset + offset * static_cast(lda)); + auto const ldAmat = (use_Ash) ? ldAsh : lda; + auto Amat = [=](auto i, auto j) -> T& { return (Amat_[i + j * ldAmat]); }; + + // --------------------------- + // set J to be identity matrix + // --------------------------- + auto const nrowsJ = bsize(iblock); + auto const ncolsJ = nrowsJ; + auto const nrowsAmat = nrowsJ; + auto const ncolsAmat = ncolsJ; + + if(use_J) + { + for(auto tiy = tiy_start; tiy < ncolsJ; tiy += tiy_inc) + { + for(auto tix = tix_start; tix < nrowsJ; tix += tix_inc) + { + bool const is_diag = (tix == tiy); + Jmat(tix, tiy) = is_diag ? 1 : 0; + } + } + __syncthreads(); + } + + if(use_Ash) + { + __syncthreads(); + for(auto tiy = tiy_start; tiy < ncolsAmat; tiy += tiy_inc) + { + for(auto tix = tix_start; tix < nrowsAmat; tix += tix_inc) + { + auto const ia = tix + offset; + auto const ja = tiy + offset; + + Ash(tix, tiy) = A(ia, ja); + } + } + __syncthreads(); + } + + __syncthreads(); + + // --------------------- + // initialize top/bottom + // + // Note: sh_top[], sh_bottom[] contain local index + // --------------------- + for(auto ipair = tixy_start; ipair < npairs; ipair += tixy_inc) + { + sh_top[ipair] = 2 * ipair; + sh_bottom[ipair] = 2 * ipair + 1; + } + __syncthreads(); + + auto const num_rounds = (2 * npairs - 1); + for(I iround = 0; iround < num_rounds; iround++) + { + // -------------------------------------------------------------- + // for each off-diagonal element (indexed using top/bottom pairs), + // calculate the Jacobi rotation and apply it to A + // -------------------------------------------------------------- + + // ------------------------------ + // compute the sine, cosine values + // ------------------------------ + for(auto ipair = tixy_start; ipair < npairs; ipair += tixy_inc) + { + auto const i = std::min(sh_top[ipair], sh_bottom[ipair]); + auto const j = std::max(sh_top[ipair], sh_bottom[ipair]); + auto const ia = i + offset; + auto const ja = j + offset; + + S c = 1; + T s1 = 0; + + sh_cosines[ipair] = c; + sh_sines[ipair] = s1; + + bool const is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + + auto const aij = Amat(i, j); + + auto const mag = std::abs(aij); + bool const is_small = (mag < sqrt_small_num); + if(!is_small) + { + auto const real_aii = std::real(Amat(i, i)); + auto const real_ajj = std::real(Amat(j, j)); + S g = 2 * mag; + S f = real_ajj - real_aii; + f += (f < 0) ? -std::hypot(f, g) : std::hypot(f, g); + + S r = 0; + S s = 0; + lartg(f, g, c, s, r); + + s1 = aij * (s / mag); + + sh_cosines[ipair] = c; + sh_sines[ipair] = s1; + } + } // end for ipair + + __syncthreads(); + + for(auto ipair = tiy_start; ipair < npairs; ipair += tiy_inc) + { + auto const i = std::min(sh_top[ipair], sh_bottom[ipair]); + auto const j = std::max(sh_top[ipair], sh_bottom[ipair]); + auto const ia = i + offset; + auto const ja = j + offset; + + bool const is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + + auto const c = sh_cosines[ipair]; + auto const s1 = sh_sines[ipair]; + auto const s2 = conj(s1); + + for(auto tix = tix_start; tix < nrowsJ; tix += tix_inc) + { + if(use_J) + { + // ---------------- + // store J row-wise + // ---------------- + auto const temp1 = Jmat(i, tix); + auto const temp2 = Jmat(j, tix); + Jmat(i, tix) = c * temp1 + s2 * temp2; + Jmat(j, tix) = -s1 * temp1 + c * temp2; + } + + // -------- + // update A + // -------- + { + auto const temp1 = Amat(tix, i); + auto const temp2 = Amat(tix, j); + + Amat(tix, i) = c * temp1 + s2 * temp2; + Amat(tix, j) = -s1 * temp1 + c * temp2; + } + + } // end for tix + } // end for ipair + + __syncthreads(); + + for(auto ipair = tiy_start; ipair < npairs; ipair += tiy_inc) + { + auto const i = std::min(sh_top[ipair], sh_bottom[ipair]); + auto const j = std::max(sh_top[ipair], sh_bottom[ipair]); + auto const ia = i + offset; + auto const ja = j + offset; + + bool const is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + + auto const c = sh_cosines[ipair]; + auto const s1 = sh_sines[ipair]; + auto const s2 = conj(s1); + + // ------------------ + // apply J' from left + // ------------------ + for(auto tix = tix_start; tix < nrowsJ; tix += tix_inc) + { + auto const temp1 = Amat(i, tix); + auto const temp2 = Amat(j, tix); + + Amat(i, tix) = c * temp1 + s1 * temp2; + Amat(j, tix) = -s2 * temp1 + c * temp2; + } + } // end for ipair + __syncthreads(); + + for(auto ipair = tixy_start; ipair < npairs; ipair += tixy_inc) + { + auto const i = std::min(sh_top[ipair], sh_bottom[ipair]); + auto const j = std::max(sh_top[ipair], sh_bottom[ipair]); + auto const ia = i + offset; + auto const ja = j + offset; + + bool const is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + + // --------------- + // set aij to zero + // --------------- + + Amat(i, j) = 0; + Amat(j, i) = 0; + } + + __syncthreads(); + + { + assert(tixy_inc >= half_nb); + + I const ipair = tixy_start; + + I i = 0; + I j = 0; + + // cycle top/bottom pairs + if(ipair == 1) + i = sh_bottom[0]; + else if(ipair > 1) + i = ((ipair - 1) < npairs) ? sh_top[ipair - 1] : 0; + + if(ipair == (npairs - 1)) + j = sh_top[npairs - 1]; + else + j = ((ipair + 1) < npairs) ? sh_bottom[ipair + 1] : 0; + + __syncthreads(); + + if(ipair < npairs) + { + sh_top[ipair] = i; + sh_bottom[ipair] = j; + } + } + + __syncthreads(); + } // end for iround + + __syncthreads(); + // ---------------------------------------- + // write out data from LDS to device memory + // ---------------------------------------- + + if(use_Ash) + { + for(auto tiy = tiy_start; tiy < ncolsAmat; tiy += tiy_inc) + { + for(auto tix = tix_start; tix < nrowsAmat; tix += tix_inc) + { + auto const ia = tix + offset; + auto const ja = tiy + offset; + A(ia, ja) = Ash(tix, tiy); + } + } + __syncthreads(); + } + + if(use_J && use_Jsh) + { + for(auto tiy = tiy_start; tiy < ncolsJ; tiy += tiy_inc) + { + for(auto tix = tix_start; tix < nrowsJ; tix += tix_inc) + { + J(tix, tiy) = Jsh(tix, tiy); + } + } + __syncthreads(); + } + + } // end for iblock + } // end for bid +} + +template +ROCSOLVER_KERNEL void syevj_diag_kernel_dev(const rocblas_int n, + const rocblas_int nb_max, + U AA, + const rocblas_int shiftA, + const rocblas_int lda, + const rocblas_stride strideA, + const S eps, + T* JA, + rocblas_int* completed, + const rocblas_int batch_count) +{ + typedef rocblas_int I; + + auto ceil = [](auto n, auto nb) { return ((n - 1) / nb + 1); }; + + auto const blocks = ceil(n, nb_max); + // ------------------------------------------------- + // function to calculation the size of the i-th block + // ------------------------------------------------- + auto bsize = [=](auto iblock) { + auto const nb_last = n - (blocks - 1) * nb_max; + bool const is_last_block = (iblock == (blocks - 1)); + return ((is_last_block) ? nb_last : nb_max); + }; + + auto const bid_start = hipBlockIdx_z; + auto const bid_inc = hipGridDim_z; + + auto const iblock_start = hipBlockIdx_x; + auto const iblock_inc = hipGridDim_x; + + auto const i_start = hipThreadIdx_x; + auto const i_inc = hipBlockDim_x; + auto const j_start = hipThreadIdx_y; + auto const j_inc = hipBlockDim_y; + + // -------------------------------------- + // combined (i,j) into single index "ij" in 1D + // -------------------------------------- + auto const ij_start = i_start + j_start * i_inc; + auto const ij_inc = i_inc * j_inc; + + auto const tix_start = i_start; + auto const tix_inc = i_inc; + auto const tiy_start = j_start; + auto const tiy_inc = j_inc; + + auto const n_even = n + (n % 2); + auto const half_n = n_even / 2; + + auto const nb_max_even = nb_max + (nb_max % 2); + auto const half_nb_max = nb_max_even / 2; + auto const max_npairs = half_nb_max; + + // shared memory + size_t const size_lds = 64 * 1024; + extern __shared__ double lmem[]; + std::byte* pfree = reinterpret_cast(&(lmem[0])); + size_t total_bytes = 0; + + // --------------------------------- + // allocate sh_sines[], sh_cosines[] + // --------------------------------- + size_t const size_sh_cosines = sizeof(S) * max_npairs; + S* sh_cosines = reinterpret_cast(pfree); + pfree += size_sh_cosines; + total_bytes += size_sh_cosines; + + size_t const size_sh_sines = sizeof(T) * max_npairs; + T* sh_sines = reinterpret_cast(pfree); + pfree += size_sh_sines; + total_bytes += size_sh_sines; + + // ------------------------------------- + // allocate arrays for independent pairs + // ------------------------------------- + + size_t const len_vec = 2 * max_npairs; + size_t const size_vec = sizeof(I) * len_vec; + I* vec = reinterpret_cast(pfree); + pfree += size_vec; + total_bytes += size_vec; + + assert(total_bytes <= size_lds); + + // ------------------------ + // allocate Ash_[] + // ------------------------ + + auto const ldAsh = nb_max; + size_t const size_Ash = sizeof(T) * (ldAsh * nb_max); + T* const Ash_ = reinterpret_cast(pfree); + pfree += size_Ash; + total_bytes += size_Ash; + // bool const use_Ash = (total_bytes <= size_lds); + bool const use_Ash = false; + + auto Ash = [=](auto i, auto j) -> T& { return (Ash_[i + j * ldAsh]); }; + + // ------------------------ + // allocate Jsh_[] + // ------------------------ + auto const ldj = nb_max; + auto const len_Jsh = (ldj * nb_max); + size_t const size_Jsh = sizeof(T) * len_Jsh; + T* const Jsh_ = reinterpret_cast(pfree); + pfree += size_Jsh; + total_bytes += size_Jsh; + // bool const use_Jsh = (total_bytes <= size_lds); + bool const use_Jsh = false; + + S const small_num = get_safemin() / eps; + S const sqrt_small_num = std::sqrt(small_num); + + for(auto bid = bid_start; bid < batch_count; bid += bid_inc) + { + if(completed[bid + 1]) + continue; + + T* const A_ = load_ptr_batch(AA, bid, shiftA, strideA); + auto A = [=](auto ia, auto ja) -> T& { return (A_[ia + ja * static_cast(lda)]); }; + + // ------------------------------------------- + // work on the diagonal block A[iblock,iblock] + // ------------------------------------------- + for(auto iblock = iblock_start; iblock < blocks; iblock += iblock_inc) + { + auto const offset = iblock * nb_max; + auto const jid = iblock + bid * blocks; + T* const J_ = (JA ? JA + (jid * (nb_max * nb_max)) : nullptr); + bool const use_J = (J_ != nullptr); + + auto const nb = bsize(iblock); + auto const nb_even = nb + (nb % 2); + auto const half_nb = nb_even / 2; + auto const npairs = half_nb; + + auto const nrowsJ = nb_max; + auto const ncolsJ = nrowsJ; + auto const nrowsAmat = nb; + auto const ncolsAmat = nb; + + auto J = [=](auto i, auto j) -> T& { + assert(J_ != nullptr); + + return (J_[i + j * ldj]); + }; + + auto Jsh = [=](auto i, auto j) -> T& { + assert(use_Jsh); + + return (Jsh_[i + j * ldj]); + }; + + T* const Jmat_ = (use_Jsh) ? Jsh_ : J_; + auto Jmat = [=](auto i, auto j) -> T& { + assert(use_J); + assert(Jmat_ != nullptr); + + return (Jmat_[i + j * ldj]); + }; + + auto tb_pair = [=](auto i, auto ipair) { + assert((0 <= i) && (i <= 1)); + assert((0 <= ipair) && (ipair < npairs)); + + auto const m = 2 * npairs; + + auto map = [=](auto ip) { + auto const ival0 = (m - 1) - 1; + bool const is_even = ((ip % 2) == 0); + bool const is_last = (ip == (m - 1)); + auto const j = (ip - 1) / 2; + + return (is_last ? (m - 1) : is_even ? (ip / 2) : ival0 - j); + }; + + auto const ip = map(i + 2 * ipair); + assert((0 <= ip) && (ip < nb_even)); + + return (vec[ip]); + }; + + auto rotate = [=](auto const npairs) { + // ------------------------------------------ + // parallel algorithms need to have + // sufficient number of threads and registers + // + // note the last position vec[m-1] stays fixed + // ------------------------------------------ + // bool use_serial = ( ij_inc < (m-1) ); + + auto const m = 2 * npairs; + bool const use_parallel = (ij_inc >= (m - 1)); + // bool const use_serial = (!use_parallel); + bool const use_serial = true; + + if(use_serial) + { + bool const is_root = (ij_start == 0); + + if(is_root) + { + auto const v0 = vec[0]; + for(auto i = 1; i <= (m - 2); i++) + { + vec[i - 1] = vec[i]; + }; + vec[m - 2] = v0; + } + } + else + { + assert(ij_inc >= (m - 1)); + + auto const ij = ij_start; + auto const v_ij = (ij <= (m - 2)) ? vec[ij] : 0; + + // for(auto ij=ij_start; ij < (m-1); ij += ij_inc) { v_ij = vec[ ij ]; } + + __syncthreads(); + + for(auto ij = ij_start; ij <= (m - 2); ij += ij_inc) + { + if(ij >= 1) + { + vec[ij - 1] = v_ij; + } + else + { + vec[(m - 2)] = v_ij; + } + } + + } // end if use_serial + __syncthreads(); + }; + + auto init_tb_pair = [=](auto npairs) { + auto const m = 2 * npairs; + __syncthreads(); + for(auto ij = ij_start; ij < m; ij += ij_inc) + { + vec[ij] = ij; + }; + __syncthreads(); + }; + + // ----------------------------------- + // note Amat looks like nb by nb matrix + // ----------------------------------- + T* const Amat_ = (use_Ash) ? Ash_ : A_ + idx2D(offset, offset, lda); + auto const ldAmat = (use_Ash) ? ldAsh : lda; + auto Amat0 = [=](auto i, auto j) -> T& { + assert((0 <= i) && (i < nrowsAmat)); + assert((0 <= j) && (j < ncolsAmat)); + + return (Amat_[i + j * ldAmat]); + }; + + auto Amat = [=](auto i, auto j) -> T& { + assert((0 <= i) && (i < nrowsAmat)); + assert((0 <= j) && (j < ncolsAmat)); + if(use_Ash) + { + return (Ash_[i + j * ldAsh]); + } + else + { + auto const ia = i + offset; + auto const ja = j + offset; + return (A(ia, ja)); + } + }; + + if(use_J) + { + // ----------------------------- + // initialize to identity matrix + // ----------------------------- + + T const one = 1; + T const zero = 0; + + for(auto j = j_start; j < ncolsJ; j += j_inc) + { + for(auto i = i_start; i < nrowsJ; i += i_inc) + { + bool const is_diagonal = (i == j); + Jmat(i, j) = (is_diagonal) ? one : zero; + } + } + __syncthreads(); + } + + if(use_Ash) + { + // ----------------------------- + // load A into LDS shared memory + // ----------------------------- + + for(auto j = j_start; j < ncolsAmat; j += j_inc) + { + for(auto i = i_start; i < nrowsAmat; i += i_inc) + { + auto const ia = offset + i; + auto const ja = offset + j; + Ash(i, j) = A(ia, ja); + } + } + __syncthreads(); + } + +#ifdef NDEBUG +#else + auto cal_offd_norm = [=](auto& dnorm) { + dnorm = 0; + for(auto j = 0; j < ncolsAmat; j++) + { + for(auto i = 0; i < nrowsAmat; i++) + { + bool const is_diagonal = (i == j); + auto const ia = offset + i; + auto const ja = offset + j; + T const aij = (is_diagonal) ? 0 : A(ia, ja); + dnorm += std::norm(aij); + } + } + dnorm = std::sqrt(dnorm); + }; + + double offdiag_norm_init = 0; + + if(ij_start == 0) + { + cal_offd_norm(offdiag_norm_init); + } + + for(auto j = j_start; j < ncolsJ; j += j_inc) + { + for(auto i = i_start; i < nrowsJ; i += i_inc) + { + bool const is_diag = (i == j); + T const id_ij = (is_diag) ? 1 : 0; + bool const isok = (Jmat(i, j) == id_ij); + assert(isok); + } + } + __syncthreads(); +#endif + + init_tb_pair(npairs); + + auto const num_pass = (2 * npairs - 1); + for(auto ipass = 0; ipass < num_pass; ipass++) + { + // --------------------------------- + // generate sh_cosines[], sh_sines[] + // --------------------------------- +#ifdef NDEBUG +#else + if(iblock == (blocks - 1)) + { + if(ij_start == 0) + { + printf("ipass=%d,n=%d,blocks=%d,bsize=%d,npairs=%d\n", ipass, n, blocks, + bsize(iblock), npairs); + for(auto ipair = 0; ipair < npairs; ipair++) + { + printf("(%d,%d) ", tb_pair(0, ipair), tb_pair(1, ipair)); + }; + printf("\n"); + } + } +#endif + + for(auto ipair = ij_start; ipair < npairs; ipair += ij_inc) + { + auto const i = std::min(tb_pair(0, ipair), tb_pair(1, ipair)); + auto const j = std::max(tb_pair(0, ipair), tb_pair(1, ipair)); + + // ---------------------------------------------- + // default initialized value as identity rotation + // ---------------------------------------------- + double c = 1; + T s1 = 0; + + sh_cosines[ipair] = c; + sh_sines[ipair] = s1; + + // ---------------------------------------------------------- + // for each off-diagonal element (indexed using top/bottom pairs), + // calculate the Jacobi rotation and apply it to A + // ---------------------------------------------------------- + + { + auto const ia = i + offset; + auto const ja = j + offset; + bool is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + } + + auto const aij = Amat(i, j); + double const mag = std::abs(aij); + + // bool const is_small = (mag < sqrt_small_num); + bool const is_small = (mag * mag < small_num); + // calculate rotation J + if(!is_small) + { + double const real_aii = std::real(Amat(i, i)); + double const real_ajj = std::real(Amat(j, j)); + + double g = 2 * mag; + // S f = std::real(Amat(ja, ja) - Amat(ia, ia)); + double f = real_ajj - real_aii; + f += (f < 0) ? -std::hypot(f, g) : std::hypot(f, g); + + double r = 1; + double s = 0; + lartg(f, g, c, s, r); + + s1 = aij * (s / mag); + } + + sh_cosines[ipair] = c; + sh_sines[ipair] = s1; + +#ifdef NDEBUG +#else + { + double const tol = 1e-6; + assert(std::abs((c * c + s1 * conj(s1)) - 1) <= tol); + } +#endif + + } // end for ij + + __syncthreads(); + + if(use_J) + { + for(auto ipair = tiy_start; ipair < npairs; ipair += tiy_inc) + { + auto const i = std::min(tb_pair(0, ipair), tb_pair(1, ipair)); + auto const j = std::max(tb_pair(0, ipair), tb_pair(1, ipair)); + + { + auto const ia = i + offset; + auto const ja = j + offset; + bool const is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + } + + S const c = sh_cosines[ipair]; + T const s1 = sh_sines[ipair]; + T const s2 = conj(s1); + + for(auto tix = tix_start; tix < ncolsJ; tix += tix_inc) + { + // ---------------- + // store J row-wise + // ---------------- + { + T const temp1 = Jmat(i, tix); + T const temp2 = Jmat(j, tix); + Jmat(i, tix) = c * temp1 + s2 * temp2; + Jmat(j, tix) = -s1 * temp1 + c * temp2; + } + } + } // end for ipair + } + + __syncthreads(); + + for(auto ipair = tiy_start; ipair < npairs; ipair += tiy_inc) + { + auto const i = std::min(tb_pair(0, ipair), tb_pair(1, ipair)); + auto const j = std::max(tb_pair(0, ipair), tb_pair(1, ipair)); + + { + auto const ia = i + offset; + auto const ja = j + offset; + bool const is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + } + + S const c = sh_cosines[ipair]; + T const s1 = sh_sines[ipair]; + T const s2 = conj(s1); + + for(auto tix = tix_start; tix < nrowsAmat; tix += tix_inc) + { + // ----------------------------- + // apply rotation from the right + // ----------------------------- + { + auto const temp1 = A(tix, i); + auto const temp2 = A(tix, j); + Amat(tix, i) = c * temp1 + s2 * temp2; + Amat(tix, j) = -s1 * temp1 + c * temp2; + } + } + } // end for ipair + + __syncthreads(); + + for(auto ipair = tiy_start; ipair < npairs; ipair += tiy_inc) + { + auto const i = std::min(tb_pair(0, ipair), tb_pair(1, ipair)); + auto const j = std::max(tb_pair(0, ipair), tb_pair(1, ipair)); + + { + auto const ia = i + offset; + auto const ja = j + offset; + bool const is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + } + + S const c = sh_cosines[ipair]; + T const s1 = sh_sines[ipair]; + T const s2 = conj(s1); + + for(auto tix = tix_start; tix < ncolsAmat; tix += tix_inc) + { + // ---------------------- + // apply J' from the left + // ---------------------- + auto const temp1 = Amat(i, tix); + auto const temp2 = Amat(j, tix); + Amat(i, tix) = c * temp1 + s1 * temp2; + Amat(j, tix) = -s2 * temp1 + c * temp2; + } // end for tix + } // end for tix + + __syncthreads(); + + bool const round_aij_aji_to_zero = false; + if(round_aij_aji_to_zero) + { + for(auto ipair = ij_start; ipair < npairs; ipair += ij_inc) + { + auto const i = tb_pair(0, ipair); + auto const j = tb_pair(1, ipair); + + { + auto const ia = i + offset; + auto const ja = j + offset; + bool const is_valid = (ia < n) && (ja < n); + if(!is_valid) + continue; + } + + // round aij and aji to zero + Amat(i, j) = 0; + Amat(j, i) = 0; + } // end for ij + + __syncthreads(); + } + + // --------------------- + // rotate cycle pairs + // --------------------- + + rotate(npairs); + + __syncthreads(); + + } // end for ipass + + __syncthreads(); + + if(use_J && use_Jsh) + { + // --------------------------------------------------- + // write out rotation matrix from LDS to device memory + // --------------------------------------------------- - sh_cosines[tix] = c; - sh_sines[tix] = s1; - } - __syncthreads(); + for(auto ij = ij_start; ij < len_Jsh; ij += ij_inc) + { + J_[ij] = Jsh_[ij]; + }; - if(i < n && j < n) - { - c = sh_cosines[tix]; - s1 = sh_sines[tix]; - s2 = conj(s1); + __syncthreads(); + } - // store J row-wise - if(J) + if(use_Ash) { - xx1 = i - offset; - xx2 = j - offset; - temp1 = J[xx1 + yy1 * nb_max]; - temp2 = J[xx2 + yy1 * nb_max]; - J[xx1 + yy1 * nb_max] = c * temp1 + s2 * temp2; - J[xx2 + yy1 * nb_max] = -s1 * temp1 + c * temp2; + // ------------------------------------------ + // write out modified diagonal submatrix of A + // from LDS to device memory + // ------------------------------------------ - if(y2 < n) + for(auto j = j_start; j < ncolsAmat; j += j_inc) { - temp1 = J[xx1 + yy2 * nb_max]; - temp2 = J[xx2 + yy2 * nb_max]; - J[xx1 + yy2 * nb_max] = c * temp1 + s2 * temp2; - J[xx2 + yy2 * nb_max] = -s1 * temp1 + c * temp2; + for(auto i = i_start; i < nrowsAmat; i += i_inc) + { + auto const ia = i + offset; + auto const ja = j + offset; + A(ia, ja) = Ash(i, j); + } } + __syncthreads(); } + __syncthreads(); - // apply J from the right - temp1 = A[y1 + i * lda]; - temp2 = A[y1 + j * lda]; - A[y1 + i * lda] = c * temp1 + s2 * temp2; - A[y1 + j * lda] = -s1 * temp1 + c * temp2; +#ifdef NDEBUG +#else - if(y2 < n) + bool const check_J = true; + if(check_J) { - temp1 = A[y2 + i * lda]; - temp2 = A[y2 + j * lda]; - A[y2 + i * lda] = c * temp1 + s2 * temp2; - A[y2 + j * lda] = -s1 * temp1 + c * temp2; - } - } - __syncthreads(); + // ---------------------------- + // double check J is orthogonal + // check J' * J == identity + // ---------------------------- + __syncthreads(); - if(i < n && j < n) - { - // apply J' from the left - temp1 = A[i + y1 * lda]; - temp2 = A[j + y1 * lda]; - A[i + y1 * lda] = c * temp1 + s1 * temp2; - A[j + y1 * lda] = -s2 * temp1 + c * temp2; + if(use_J) + { + double const tol = nrowsJ * eps; + auto nerrors = 0; + for(auto j = j_start; j < ncolsJ; j += j_inc) + { + for(auto i = i_start; i < nrowsJ; i += i_inc) + { + T eij = 0; + for(auto k = 0; k < nrowsJ; k++) + { + auto const Jt_ik = conj(Jmat(i, k)); + auto const J_kj = Jmat(j, k); + eij += Jt_ik * J_kj; + } + T const id_ij = (i == j) ? 1 : 0; + double const diff = std::abs(id_ij - eij); - if(y2 < n) - { - temp1 = A[i + y2 * lda]; - temp2 = A[j + y2 * lda]; - A[i + y2 * lda] = c * temp1 + s1 * temp2; - A[j + y2 * lda] = -s2 * temp1 + c * temp2; - } - } - __syncthreads(); + bool const isok = (diff <= tol); + if(!isok) + { + printf("iblock=%d,i=%d,j=%d,diff=%le,eij=%le\n", iblock, i, j, diff, + (double)std::real(eij)); + nerrors += 1; + } + } + } + assert(nerrors == 0); + } - if(tiy == 0 && i < n && j < n) - { - // round aij and aji to zero - A[i + j * lda] = 0; - A[j + i * lda] = 0; - } + __syncthreads(); - // cycle top/bottom pairs - if(tix == 1) - i = sh_bottom[0]; - else if(tix > 1) - i = sh_top[tix - 1]; - if(tix == half_nb - 1) - j = sh_top[half_nb - 1]; - else - j = sh_bottom[tix + 1]; - __syncthreads(); + // compute norm of off diagonal - if(tiy == 0) - { - sh_top[tix] = i; - sh_bottom[tix] = j; - } - } + double offdiag_norm_final = 0; + if(ij_start == 0) + { + cal_offd_norm(offdiag_norm_final); + + printf("iblock=%d,offdiag_norm_init=%le,offdiag_norm_final=%le\n", iblock, + offdiag_norm_init, offdiag_norm_final); + } + __syncthreads(); + } +#endif + + } // end for iblock + } // end for bid } /** SYEVJ_DIAG_ROTATE rotates off-diagonal blocks of size nb <= BS2 using the rotations calculated @@ -1259,27 +2339,33 @@ ROCSOLVER_KERNEL void syevj_offd_kernel(const rocblas_int nb_max, rocblas_int* top, rocblas_int* bottom, rocblas_int* completed, - rocblas_int batch_count) + rocblas_int batch_count, + size_t lmem_size = 64 * 1024) { auto const blocks = ceil(n, nb_max); auto const even_blocks = blocks + (blocks % 2); auto const half_blocks = even_blocks / 2; - auto const ibx = hipBlockIdx_x; - auto const iby = hipBlockIdx_y; - auto const ibz = hipBlockIdx_z; + auto bsize = [=](auto iblock) { + auto const nb_last = n - (blocks - 1) * nb_max; + bool const is_last_block = (iblock == (blocks - 1)); + return ((is_last_block) ? nb_last : nb_max); + }; + + auto const ibpair_start = hipBlockIdx_x; + auto const ibpair_inc = hipGridDim_x; - auto const nbx = hipGridDim_x; - auto const nby = hipGridDim_y; - auto const nbz = hipGridDim_z; + auto const bid_start = hipBlockIdx_z; + auto const bid_inc = hipGridDim_z; auto const tix_start = hipThreadIdx_x; auto const tiy_start = hipThreadIdx_y; + auto const tix_inc = hipBlockDim_x; auto const tiy_inc = hipBlockDim_y; - auto const bid_start = ibz; - auto const bid_inc = nbz; + auto const tixy_start = tix_start + tiy_start * tix_inc; + auto const tixy_inc = tix_inc * tiy_inc; // shared memory extern __shared__ double lmem[]; @@ -1287,6 +2373,10 @@ ROCSOLVER_KERNEL void syevj_offd_kernel(const rocblas_int nb_max, size_t total_bytes = 0; std::byte* pfree = reinterpret_cast(&(lmem[0])); + // ----------------------------- + // allocate sh_cosines, sh_sines + // ----------------------------- + size_t const size_sh_cosines = sizeof(S) * nb_max; S* const sh_cosines = reinterpret_cast(pfree); pfree += size_sh_cosines; @@ -1297,23 +2387,33 @@ ROCSOLVER_KERNEL void syevj_offd_kernel(const rocblas_int nb_max, pfree += size_sh_sines; total_bytes += size_sh_sines; - size_t const size_Jsh = sizeof(T) * (2 * nb_max) * (2 * nb_max); + // ------------ + // allocate Ash + // ------------ + auto const ldAsh = (2 * nb_max); + auto const len_Ash = ldAsh * (2 * nb_max); + size_t const size_Ash = sizeof(T) * len_Ash; + + T* const Ash_ = reinterpret_cast(pfree); + pfree += size_Ash; + total_bytes += size_Ash; + bool const use_Ash = (total_bytes <= lmem_size); + auto Ash = [=](auto i, auto j) -> T& { return (Ash_[i + j * ldAsh]); }; + + // ------------ + // allocate Jsh + // ------------ + auto const ldJsh = (2 * nb_max); + auto const len_Jsh = ldJsh * (2 * nb_max); + size_t const size_Jsh = sizeof(T) * len_Jsh; T* const Jsh_ = reinterpret_cast(pfree); pfree += size_Jsh; total_bytes += size_Jsh; - size_t const max_lds = 64 * 1024; - bool const use_Jsh = (total_bytes <= max_lds); - - auto bsize = [=](auto iblock) { - auto const nb_last = n - (blocks - 1) * nb_max; - bool const is_last_block = (iblock == (blocks - 1)); - return ((is_last_block) ? nb_last : nb_max); - }; + bool const use_Jsh = (total_bytes <= lmem_size); + auto Jsh = [=](auto i, auto j) -> T& { return (Jsh_[i + j * ldJsh]); }; - auto const ipair_start = ibx; - auto const ipair_inc = nbx; - auto const npairs = half_blocks; + auto const nbpairs = half_blocks; for(auto bid = bid_start; bid < batch_count; bid += bid_inc) { @@ -1323,10 +2423,10 @@ ROCSOLVER_KERNEL void syevj_offd_kernel(const rocblas_int nb_max, T* const A_ = load_ptr_batch(AA, bid, shiftA, strideA); auto A = [=](auto i, auto j) -> T& { return (A_[i + j * static_cast(lda)]); }; - for(auto ipair = ipair_start; ipair < npairs; ipair += ipair_inc) + for(auto ibpair = ibpair_start; ibpair < nbpairs; ibpair += ibpair_inc) { - auto const iblock = std::min(top[ipair], bottom[ipair]); - auto const jblock = std::max(top[ipair], bottom[ipair]); + auto const iblock = std::min(top[ibpair], bottom[ibpair]); + auto const jblock = std::max(top[ibpair], bottom[ibpair]); bool const is_valid_block = (iblock < blocks) && (jblock < blocks); if(!is_valid_block) @@ -1339,31 +2439,59 @@ ROCSOLVER_KERNEL void syevj_offd_kernel(const rocblas_int nb_max, auto const nj = bsize(jblock); auto const nrowsJ = ni + nj; auto const ncolsJ = nrowsJ; + auto const nrowsAmat = nrowsJ; + auto const ncolsAmat = ncolsJ; + + auto const jid = ibpair + bid * nbpairs; - auto const jid = ipair + bid * npairs; T* const J_ = ((JA != nullptr) ? JA + (jid * (4 * nb_max * nb_max)) : nullptr); + bool const use_J = (J_ != nullptr); auto const ldj = (2 * nb_max); auto J = [=](auto i, auto j) -> T& { return (J_[i + j * ldj]); }; T* const Jmat_ = (use_Jsh) ? Jsh_ : J_; - auto Jmat = [=](auto i, auto j) -> T& { return (Jmat_[i + j * ldj]); }; + auto const ldJmat = (use_Jsh) ? ldJsh : ldj; + auto Jmat = [=](auto i, auto j) -> T& { return (Jmat_[i + j * ldJmat]); }; + + auto l2g_index = [=](auto i) { return ((i < ni) ? i + offseti : (i - ni) + offsetj); }; + + auto const Amat = [=](auto i, auto j) -> T& { + if(use_Ash) + { + return (Ash(i, j)); + } + auto const ia = l2g_index(i); + auto const ja = l2g_index(j); + + return (A(ia, ja)); + }; + + // --------------- + // initialize Amat + // --------------- + if(use_Ash) + { + for(auto j = tiy_start; j < ncolsAmat; j += tiy_inc) + { + for(auto i = tix_start; i < nrowsAmat; i += tix_inc) + { + auto const ia = l2g_index(i); + auto const ja = l2g_index(j); + + Ash(i, j) = A(ia, ja); + } + } + } + // ------------------------------------------ // initialize Jmat to be the identity matrix // ------------------------------------------ - if(J_ != nullptr) + if(use_J) { - __syncthreads(); - - auto const i_start = tix_start; - auto const i_inc = tix_inc; - - auto const j_start = tiy_start; - auto const j_inc = tiy_inc; - - for(auto j = j_start; j < ncolsJ; j += j_inc) + for(auto j = tiy_start; j < ncolsJ; j += tiy_inc) { - for(auto i = i_start; i < nrowsJ; i += i_inc) + for(auto i = tix_start; i < nrowsJ; i += tix_inc) { bool const is_diagonal = (i == j); Jmat(i, j) = (is_diagonal) ? 1 : 0; @@ -1375,204 +2503,226 @@ ROCSOLVER_KERNEL void syevj_offd_kernel(const rocblas_int nb_max, S const small_num = get_safemin() / eps; S const sqrt_small_num = std::sqrt(small_num); + __syncthreads(); + // for each element, calculate the Jacobi rotation and apply it to A for(rocblas_int k = 0; k < nb_max; k++) { - for(auto tiy = tiy_start; tiy < nb_max; tiy += tiy_inc) + // ------------------------------- + // generate the sine/cosine values + // ------------------------------- + for(auto tixy = tixy_start; tixy < nb_max; tixy += tixy_inc) { - for(auto tix = tix_start; tix < nb_max; tix += tix_inc) - { - auto const x1 = tix + offseti; - auto const x2 = tix + offsetj; - auto const y1 = tiy + offseti; - auto const y2 = tiy + offsetj; + auto const i = tixy; + auto const j = (tixy + k) % nb_max + nb_max; + S c = 1; + T s1 = 0; + sh_cosines[tixy] = c; + sh_sines[tixy] = s1; - // get element indices - auto const i = x1; - auto const j = (tix + k) % nb_max + offsetj; + bool const is_valid = (i < nrowsAmat) && (j < nrowsAmat); + if(!is_valid) + continue; - if((tiy == 0) && (i < n) && (j < n)) - { - auto const aij = A(i, j); - auto const mag = std::abs(aij); + auto const aij = Amat(i, j); + auto const mag = std::abs(aij); - // identity rotation - S c = 1; - T s1 = 0; + // calculate rotation J + bool const is_small_aij = (mag < sqrt_small_num); - // calculate rotation J - // bool const is_small_aij = (mag < sqrt_small_num); - bool const is_small_aij = (mag * mag < small_num); + if(!is_small_aij) + { + S g = 2 * mag; - if(!is_small_aij) - { - S g = 2 * mag; - S f = std::real(A(j, j) - A(i, i)); - f += (f < 0) ? -std::hypot(f, g) : std::hypot(f, g); + auto const real_ajj = std::real(Amat(j, j)); + auto const real_aii = std::real(Amat(i, i)); + S f = real_ajj - real_aii; + S const hypot_f_g = std::hypot(f, g); - S s = 0; - S r = 1; + // f += (f < 0) ? -std::hypot(f, g) : std::hypot(f, g); + f += (f < 0) ? -hypot_f_g : hypot_f_g; - lartg(f, g, c, s, r); - s1 = s * aij / mag; - } + S s = 0; + S r = 1; - sh_cosines[tix] = c; - sh_sines[tix] = s1; - } + lartg(f, g, c, s, r); + // s1 = s * aij / mag; + s1 = aij * (s / mag); } + + sh_cosines[tixy] = c; + sh_sines[tixy] = s1; } __syncthreads(); - for(auto tiy = tiy_start; tiy < nb_max; tiy += tiy_inc) + // ---------------------------------- + // apply rotation J on block columns + // ---------------------------------- + if(use_J) { for(auto tix = tix_start; tix < nb_max; tix += tix_inc) { - auto const x1 = tix + offseti; - auto const x2 = tix + offsetj; - auto const y1 = tiy + offseti; - auto const y2 = tiy + offsetj; + auto const i = tix; + auto const j = (tix + k) % nb_max + nb_max; + bool const is_valid = (i < nrowsAmat) && (j < nrowsAmat); + if(!is_valid) + continue; - // get element indices - auto const i = x1; - auto const j = (tix + k) % nb_max + offsetj; + auto const c = sh_cosines[tix]; + auto const s1 = sh_sines[tix]; + auto const s2 = conj(s1); - if((i < n) && (j < n)) + for(auto tiy = tiy_start; tiy < ncolsJ; tiy += tiy_inc) { - auto const c = sh_cosines[tix]; - auto const s1 = sh_sines[tix]; - auto const s2 = conj(s1); - - // store J row-wise - if(J_ != nullptr) - { - auto const xx1 = i - offseti; - auto const xx2 = j - offsetj + nb_max; - auto const yy1 = tiy; - auto const yy2 = tiy + nb_max; - - { - auto const temp1 = Jmat(xx1, yy1); - auto const temp2 = Jmat(xx2, yy1); - - Jmat(xx1, yy1) = c * temp1 + s2 * temp2; - Jmat(xx2, yy1) = -s1 * temp1 + c * temp2; - } - - if(y2 < n) - { - auto const temp1 = Jmat(xx1, yy2); - auto const temp2 = Jmat(xx2, yy2); - - Jmat(xx1, yy2) = c * temp1 + s2 * temp2; - Jmat(xx2, yy2) = -s1 * temp1 + c * temp2; - } - } + auto const temp1 = Jmat(i, tiy); + auto const temp2 = Jmat(j, tiy); - // apply J from the right - { - auto const temp1 = A(y1, i); - auto const temp2 = A(y1, j); - A(y1, i) = c * temp1 + s2 * temp2; - A(y1, j) = -s1 * temp1 + c * temp2; - } - - if(y2 < n) - { - auto const temp1 = A(y2, i); - auto const temp2 = A(y2, j); - A(y2, i) = c * temp1 + s2 * temp2; - A(y2, j) = -s1 * temp1 + c * temp2; - } + Jmat(i, tiy) = c * temp1 + s2 * temp2; + Jmat(j, tiy) = -s1 * temp1 + c * temp2; } } } - __syncthreads(); - for(auto tiy = tiy_start; tiy < nb_max; tiy += tiy_inc) { - for(auto tix = tix_start; tix < nb_max; tix += tix_inc) + auto const i = tiy; + auto const j = (tiy + k) % nb_max + nb_max; + bool const is_valid = (i < nrowsAmat) && (j < nrowsAmat); + if(!is_valid) + continue; + + auto const c = sh_cosines[tiy]; + auto const s1 = sh_sines[tiy]; + auto const s2 = conj(s1); + // -------------------------------------------------- + // apply J from the right on columns A(:,i), A(:,j) + // -------------------------------------------------- + if(use_Ash) { - auto const x1 = tix + offseti; - auto const x2 = tix + offsetj; - auto const y1 = tiy + offseti; - auto const y2 = tiy + offsetj; - - // get element indices - auto const i = x1; - auto const j = (tix + k) % nb_max + offsetj; - if(i < n && j < n) + for(auto tix = tix_start; tix < nrowsAmat; tix += tix_inc) { - auto const c = sh_cosines[tix]; - auto const s1 = sh_sines[tix]; - auto const s2 = conj(s1); - // apply J' from the left - { - auto const temp1 = A(i, y1); - auto const temp2 = A(j, y1); - A(i, y1) = c * temp1 + s1 * temp2; - A(j, y1) = -s2 * temp1 + c * temp2; - } - - if(y2 < n) - { - auto const temp1 = A(i, y2); - auto const temp2 = A(j, y2); - A(i, y2) = c * temp1 + s1 * temp2; - A(j, y2) = -s2 * temp1 + c * temp2; - } + auto const temp1 = Ash(tix, i); + auto const temp2 = Ash(tix, j); + Ash(tix, i) = c * temp1 + s2 * temp2; + Ash(tix, j) = -s1 * temp1 + c * temp2; + } + } + else + { + auto const ia = l2g_index(i); + auto const ja = l2g_index(j); + for(auto tix = tix_start; tix < nrowsAmat; tix += tix_inc) + { + auto const gtix = l2g_index(tix); + auto const temp1 = A(gtix, ia); + auto const temp2 = A(gtix, ja); + A(gtix, ia) = c * temp1 + s2 * temp2; + A(gtix, ja) = -s1 * temp1 + c * temp2; } } } __syncthreads(); - for(auto tiy = tiy_start; tiy < nb_max; tiy += tiy_inc) + // ------------------------------------------- + // apply J' from the left to rows A(i,:) and A(j,:) + // ------------------------------------------- + for(auto ipair = tix_start; ipair < nb_max; ipair += tix_inc) { - for(auto tix = tix_start; tix < nb_max; tix += tix_inc) + auto const i = ipair; + auto const j = (ipair + k) % nb_max + nb_max; + bool const is_valid = (i < nrowsAmat) && (j < nrowsAmat); + if(!is_valid) + continue; + + auto const c = sh_cosines[ipair]; + auto const s1 = sh_sines[ipair]; + auto const s2 = conj(s1); + + if(use_Ash) { - auto const x1 = tix + offseti; - auto const x2 = tix + offsetj; - auto const y1 = tiy + offseti; - auto const y2 = tiy + offsetj; - - // get element indices - auto const i = x1; - auto const j = (tix + k) % nb_max + offsetj; - if((tiy == 0) && (j < n)) + for(auto tiy = tiy_start; tiy < ncolsAmat; tiy += tiy_inc) { - // round aij and aji to zero - A(i, j) = 0; - A(j, i) = 0; + auto const temp1 = Ash(i, tiy); + auto const temp2 = Ash(j, tiy); + Ash(i, tiy) = c * temp1 + s1 * temp2; + Ash(j, tiy) = -s2 * temp1 + c * temp2; } } + else + { + auto const ia = l2g_index(i); + auto const ja = l2g_index(j); + for(auto tiy = tiy_start; tiy < ncolsAmat; tiy += tiy_inc) + { + auto const gtiy = l2g_index(tiy); + auto const temp1 = A(ia, gtiy); + auto const temp2 = A(ja, gtiy); + A(ia, gtiy) = c * temp1 + s1 * temp2; + A(ja, gtiy) = -s2 * temp1 + c * temp2; + } + } + } // end for ipair + + __syncthreads(); + + for(auto tixy = tixy_start; tixy < nb_max; tixy += tixy_inc) + { + auto const i = tixy; + auto const j = (tixy + k) % nb_max + nb_max; + bool const is_valid = (i < nrowsAmat) && (j < nrowsAmat); + if(!is_valid) + continue; + + if(use_Ash) + { + Ash(i, j) = 0; + Ash(j, i) = 0; + } + else + { + auto const ia = l2g_index(i); + auto const ja = l2g_index(j); + A(ia, ja) = 0; + A(ja, ia) = 0; + } } + __syncthreads(); } // end for k // ----------------------------------- - // write out Jsh to J in device memory + // write out Ash to A in device memory // ----------------------------------- - if((J_ != nullptr) && use_Jsh) + if(use_Ash) { + for(auto j = tiy_start; j < ncolsAmat; j += tiy_inc) + { + for(auto i = tix_start; i < nrowsAmat; i += tix_inc) + { + auto const ia = l2g_index(i); + auto const ja = l2g_index(j); + A(ia, ja) = Ash(i, j); + } + } __syncthreads(); - - auto const i_start = tix_start; - auto const i_inc = tix_inc; - - auto const j_start = tiy_start; - auto const j_inc = tiy_inc; - - for(auto j = j_start; j < ncolsJ; j += j_inc) + } + // ----------------------------------- + // write out Jsh to J in device memory + // ----------------------------------- + if(use_J && use_Jsh) + { + for(auto j = tiy_start; j < ncolsJ; j += tiy_inc) { - for(auto i = i_start; i < nrowsJ; i += i_inc) + for(auto i = tix_start; i < nrowsJ; i += tix_inc) { - J(i, j) = Jmat(i, j); + J(i, j) = Jsh(i, j); } } + __syncthreads(); } - } // end for ipair + } // end for ibpair + } // end for bid } @@ -1669,7 +2819,8 @@ ROCSOLVER_KERNEL void syevj_offd_rotate(const bool skip_block, rocblas_int* top, rocblas_int* bottom, rocblas_int* completed, - rocblas_int const batch_count) + rocblas_int const batch_count, + size_t const lmem_size = 64 * 1024) { bool constexpr APPLY_RIGHT = (!APPLY_LEFT); @@ -1709,26 +2860,36 @@ ROCSOLVER_KERNEL void syevj_offd_rotate(const bool skip_block, auto const kb_start = jb_start; auto const kb_inc = jb_inc; - auto constexpr max_lds = 64 * 1024; - auto constexpr len_shmem = max_lds / sizeof(T); - auto const ldj = 2 * nb_max; auto const len_Jsh = ldj * (2 * nb_max); auto const len_Ash = (2 * nb_max) * nb_max; + auto const len_shmem = lmem_size / sizeof(T); extern __shared__ double lmem[]; T* pfree = reinterpret_cast(&(lmem[0])); + size_t total_len = 0; - T* const Ash_ = pfree; + T* const __restrict__ Ash_ = pfree; pfree += len_Ash; - T* const Jsh_ = pfree; + total_len += len_Ash; + + // ---------------------------- + // in-place update requires storing + // a copy of submatrix (2*nb_max) by nb_max + // in LDS shared memory for correctness + // ---------------------------- + assert(total_len <= len_shmem); + + T* const __restrict__ Jsh_ = pfree; pfree += len_Jsh; + total_len += len_Jsh; // --------------------------------- // store J into shared memory only if - // there is sufficient space + // there is sufficient space in LDS + // shared memory // --------------------------------- - bool const use_Jsh = ((len_Jsh + len_Ash) <= len_shmem); + bool const use_Jsh = (total_len <= len_shmem); for(auto bid = bid_start; bid < batch_count; bid += bid_inc) { @@ -1762,7 +2923,7 @@ ROCSOLVER_KERNEL void syevj_offd_rotate(const bool skip_block, auto const ncolsJ = nrowsJ; auto const jid = ipair + bid * half_blocks; - T const* const __restrict__ J_ = JA + (jid * 4 * nb_max * nb_max); + T const* const __restrict__ J_ = JA + (jid * (2 * nb_max) * (2 * nb_max)); // --------------------------------- // store J into shared memory only if @@ -1778,11 +2939,9 @@ ROCSOLVER_KERNEL void syevj_offd_rotate(const bool skip_block, // ------------------------- auto const ij_start = i_start + j_start * nx; - auto const ij_inc = nx * ny; + auto const ij_inc = i_inc * j_inc; auto const len_Jsh = 4 * nb_max * nb_max; - __syncthreads(); - for(auto ij = ij_start; ij < len_Jsh; ij += ij_inc) { Jsh_[ij] = J_[ij]; @@ -1841,7 +3000,6 @@ ROCSOLVER_KERNEL void syevj_offd_rotate(const bool skip_block, // ------------------------ // load A into shared memory // ------------------------ - __syncthreads(); for(auto j = j_start; j < ncols_Ash; j += j_inc) { @@ -2315,14 +3473,19 @@ rocblas_status rocsolver_syevj_heevj_template(rocblas_handle handle, else { // use original algorithm for small problems - auto const n_threshold = 1024; + auto const n_threshold = 256; bool const use_offd_kernel_org = (n <= n_threshold); bool const use_diag_rotate_org = (n <= n_threshold); bool const use_offd_rotate_org = (n <= n_threshold); + bool const use_diag_kernel_org = (n <= n_threshold); + bool const use_any_org = use_offd_kernel_org || use_diag_rotate_org || use_offd_rotate_org + || use_diag_kernel_org; // *** USE BLOCKED KERNELS *** - rocblas_int const nb_max = BS2; + rocblas_int const nb_max_org = BS2; + rocblas_int const nb_max_new = (sizeof(T) == 16) ? 22 : 32; + rocblas_int const nb_max = (use_any_org) ? nb_max_org : nb_max_new; // kernel dimensions rocblas_int const blocksReset = batch_count / BS1 + 1; @@ -2339,14 +3502,19 @@ rocblas_status rocsolver_syevj_heevj_template(rocblas_handle handle, dim3 gridPairs(1, 1, 1); dim3 threadsReset(BS1, 1, 1); dim3 threads(BS1, 1, 1); - dim3 threadsDK(BS2 / 2, BS2 / 2, 1); - dim3 threadsDR(BS2, BS2, 1); - dim3 threadsOK(BS2, BS2, 1); + dim3 threadsDK(nb_max / 2, nb_max / 2, 1); + dim3 threadsDR(nb_max, nb_max, 1); + dim3 threadsOK(nb_max, nb_max, 1); dim3 gridOR_org(half_blocks, 2 * blocks, batch_count); - dim3 threadsOR_org(2 * BS2, BS2 / 2, 1); - - dim3 gridOR_new(half_blocks, std::max(1, blocks / 4), batch_count); + dim3 threadsOR_org(2 * nb_max, nb_max / 2, 1); + + // --------------------------------------------------------------- + // number of thread blocks related to number of compute units (CU) + // --------------------------------------------------------------- + auto const num_cu = get_num_cu(); + auto const nbx = std::max(1, std::min(blocks / 4, ceil(num_cu, blocks * batch_count))); + dim3 gridOR_new(half_blocks, nbx, batch_count); dim3 threadsOR_new(BS2, BS2, 1); dim3 gridOR = (use_offd_rotate_org) ? gridOR_org : gridOR_new; @@ -2355,7 +3523,8 @@ rocblas_status rocsolver_syevj_heevj_template(rocblas_handle handle, size_t const size_lds = 64 * 1024; // shared memory sizes size_t const lmemsizeInit = 2 * sizeof(S) * BS1; - size_t const lmemsizeDK = (sizeof(S) + sizeof(T) + 2 * sizeof(rocblas_int)) * (BS2 / 2); + // size_t const lmemsizeDK = (sizeof(S) + sizeof(T) + 2 * sizeof(rocblas_int)) * (BS2 / 2); + size_t const lmemsizeDK = size_lds; size_t lmemsizeDR = std::min(size_lds, 2 * sizeof(T) * nb_max * nb_max); { @@ -2423,8 +3592,17 @@ rocblas_status rocsolver_syevj_heevj_template(rocblas_handle handle, break; // decompose diagonal blocks - ROCSOLVER_LAUNCH_KERNEL(syevj_diag_kernel, gridDK, threadsDK, lmemsizeDK, stream, n, - Acpy, 0, n, n * n, eps, J, completed); + if(use_diag_kernel_org) + { + ROCSOLVER_LAUNCH_KERNEL(syevj_diag_kernel_org, gridDK, threadsDK, lmemsizeDK, + stream, n, Acpy, 0, n, n * n, eps, J, completed); + } + else + { + ROCSOLVER_LAUNCH_KERNEL(syevj_diag_kernel, gridDK, threadsDK, lmemsizeDK, stream, + n, nb_max, Acpy, 0, n, n * n, eps, J, completed, + batch_count, lmemsizeDK); + } // apply rotations calculated by diag_kernel if(use_diag_rotate_org) @@ -2474,8 +3652,8 @@ rocblas_status rocsolver_syevj_heevj_template(rocblas_handle handle, } else { - ROCSOLVER_LAUNCH_KERNEL((syevj_offd_kernel), gridOK, threadsOK, - lmemsizeOK, stream, nb_max, n, Acpy, 0, n, n * n, eps, + ROCSOLVER_LAUNCH_KERNEL((syevj_offd_kernel), gridOK, threadsOK, size_lds, + stream, nb_max, n, Acpy, 0, n, n * n, eps, (ev ? J : nullptr), top, bottom, completed, batch_count); } @@ -2510,8 +3688,8 @@ rocblas_status rocsolver_syevj_heevj_template(rocblas_handle handle, else { ROCSOLVER_LAUNCH_KERNEL((syevj_offd_kernel), gridOK, threadsOK, - lmemsizeOK, stream, nb_max, n, Acpy, 0, n, n * n, - eps, J, top, bottom, completed, batch_count); + size_lds, stream, nb_max, n, Acpy, 0, n, n * n, eps, + J, top, bottom, completed, batch_count); } // apply rotations calculated by offd_kernel