diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 50b4646f03..54b90e4f33 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -240,23 +240,49 @@ void DiagoIterAssist::diagH_subspace_init( // std::vector hpsi(psi_temp.get_nbands() * psi_temp.get_nbasis()); // do hPsi for all bands - psi::Range all_bands_range(1, psi_temp.get_current_k(), 0, psi_temp.get_nbands()-1); - hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); - if(pHamilt->ops == nullptr) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init", - "Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n"); - for(int iband = 0; iband < evc.get_nbands(); iband++) + for (int i = 0; i < psi_temp.get_nbands(); i++) { - for(int ig = 0; ig < evc.get_nbasis(); ig++) + psi::Range band_by_band_range(1, psi_temp.get_current_k(), i, i); + hpsi_info hpsi_in(&psi_temp, band_by_band_range, hpsi + i * psi_temp.get_nbasis()); + if(pHamilt->ops == nullptr) { - evc(iband, ig) = psi[iband * evc.get_nbasis() + ig]; + ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init", + "Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n"); + for(int iband = 0; iband < evc.get_nbands(); iband++) + { + for(int ig = 0; ig < evc.get_nbasis(); ig++) + { + evc(iband, ig) = psi[iband * evc.get_nbasis() + ig]; + } + en[iband] = 0.0; + } + return; } - en[iband] = 0.0; + pHamilt->ops->hPsi(hpsi_in); } - return; } - pHamilt->ops->hPsi(hpsi_in); + else if (base_device::get_device_type(ctx) == base_device::CpuDevice) + { + psi::Range all_bands_range(1, psi_temp.get_current_k(), 0, psi_temp.get_nbands()-1); + hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); + if(pHamilt->ops == nullptr) + { + ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init", + "Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n"); + for(int iband = 0; iband < evc.get_nbands(); iband++) + { + for(int ig = 0; ig < evc.get_nbasis(); ig++) + { + evc(iband, ig) = psi[iband * evc.get_nbasis() + ig]; + } + en[iband] = 0.0; + } + return; + } + pHamilt->ops->hPsi(hpsi_in); + } gemm_op()( ctx,