Skip to content

Commit

Permalink
fix: segment fault in GPU-Davidson (#5763)
Browse files Browse the repository at this point in the history
The segmentation fault comes from a compatibility problem between the host C++ compiler and nvcc.
  • Loading branch information
Qianruipku authored Dec 25, 2024
1 parent 91b0281 commit 28cb769
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions source/module_hsolver/kernels/cuda/math_kernel_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -760,35 +760,41 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
const char& trans,
const int& m,
const int& n,
const std::complex<float>* alpha,
const std::complex<float>* alpha_in,
const std::complex<float>* A,
const int& lda,
const std::complex<float>* X,
const int& incx,
const std::complex<float>* beta,
const std::complex<float>* beta_in,
std::complex<float>* Y,
const int& incy)
{
cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)alpha, (float2*)A, lda, (float2*)X, incx, (float2*)beta, (float2*)Y, incx));
cuFloatComplex alpha = make_cuFloatComplex(alpha_in->real(), alpha_in->imag());
cuFloatComplex beta = make_cuFloatComplex(beta_in->real(), beta_in->imag());
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, &alpha, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta, (cuFloatComplex*)Y, incx));
}

/**
 * @brief Double-precision complex matrix-vector product on the GPU via cublasZgemv.
 *
 * Forwards to cuBLAS using the file-level `cublas_handle`.
 *
 * @param d         Device tag; not used inside this function.
 * @param trans     Transpose selector character, translated by judge_trans_op().
 * @param m         Row count passed to cublasZgemv.
 * @param n         Column count passed to cublasZgemv.
 * @param alpha_in  Pointer to the scalar alpha. It is dereferenced here on the host,
 *                  so it must be host-readable.
 * @param A         Matrix data, reinterpreted as cuDoubleComplex.
 *                  NOTE(review): presumably device memory -- confirm against callers.
 * @param lda       Leading dimension of A.
 * @param X         Input vector, reinterpreted as cuDoubleComplex.
 * @param incx      Stride between consecutive elements of X.
 * @param beta_in   Pointer to the scalar beta. It is dereferenced here on the host,
 *                  so it must be host-readable.
 * @param Y         Input/output vector, reinterpreted as cuDoubleComplex.
 * @param incy      Stride between consecutive elements of Y.
 */
template <>
void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
                                                                        const char& trans,
                                                                        const int& m,
                                                                        const int& n,
                                                                        const std::complex<double>* alpha_in,
                                                                        const std::complex<double>* A,
                                                                        const int& lda,
                                                                        const std::complex<double>* X,
                                                                        const int& incx,
                                                                        const std::complex<double>* beta_in,
                                                                        std::complex<double>* Y,
                                                                        const int& incy)
{
    // NOTE(review): the leading `true` presumably marks the operands as complex -- confirm in judge_trans_op.
    const cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");

    // The host compiler (e.g. icpc) and nvcc have compatibility problems when a
    // std::complex<double>* scalar is cast directly to cuDoubleComplex*; that cast
    // caused a segmentation fault. Build real cuDoubleComplex scalars instead.
    const cuDoubleComplex alpha = make_cuDoubleComplex(alpha_in->real(), alpha_in->imag());
    const cuDoubleComplex beta = make_cuDoubleComplex(beta_in->real(), beta_in->imag());

    // The final argument is the stride of Y, so it must be incy (it was incx before).
    cublasErrcheck(cublasZgemv(cublas_handle,
                               cutrans,
                               m,
                               n,
                               &alpha,
                               reinterpret_cast<const cuDoubleComplex*>(A),
                               lda,
                               reinterpret_cast<const cuDoubleComplex*>(X),
                               incx,
                               &beta,
                               reinterpret_cast<cuDoubleComplex*>(Y),
                               incy));
}

template <>
Expand Down

0 comments on commit 28cb769

Please sign in to comment.