From 0dfed4b184992063553eba302acba87bbf5a87d1 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Tue, 14 May 2024 11:03:29 +0800 Subject: [PATCH 1/2] modify hpsi interface in diagH_subspace_init --- source/module_hsolver/diago_iter_assist.cpp | 48 ++++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index fa3edaf8c2..7221c4c9c5 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -240,23 +240,49 @@ void DiagoIterAssist::diagH_subspace_init( // std::vector hpsi(psi_temp.get_nbands() * psi_temp.get_nbasis()); // do hPsi for all bands - psi::Range all_bands_range(1, psi_temp.get_current_k(), 0, psi_temp.get_nbands()-1); - hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); - if(pHamilt->ops == nullptr) + if (psi::device::get_device_type(ctx) == psi::GpuDevice) { - ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init", - "Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n"); - for(int iband = 0; iband < evc.get_nbands(); iband++) + for (int i = 0; i < psi_temp.get_nbands(); i++) { - for(int ig = 0; ig < evc.get_nbasis(); ig++) + psi::Range band_by_band_range(1, psi_temp.get_current_k(), i, i); + hpsi_info hpsi_in(&psi_temp, band_by_band_range, hpsi + i * psi_temp.get_nbasis()); + if(pHamilt->ops == nullptr) { - evc(iband, ig) = psi[iband * evc.get_nbasis() + ig]; + ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init", + "Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n"); + for(int iband = 0; iband < evc.get_nbands(); iband++) + { + for(int ig = 0; ig < evc.get_nbasis(); ig++) + { + evc(iband, ig) = psi[iband * evc.get_nbasis() + ig]; + } + en[iband] = 0.0; + } + return; } - en[iband] = 0.0; + pHamilt->ops->hPsi(hpsi_in); } - return; } - pHamilt->ops->hPsi(hpsi_in); + else if (psi::device::get_device_type(ctx) == psi::CpuDevice) + { + psi::Range all_bands_range(1, psi_temp.get_current_k(), 0, psi_temp.get_nbands()-1); + hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); + if(pHamilt->ops == nullptr) + { + ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init", + "Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n"); + for(int iband = 0; iband < evc.get_nbands(); iband++) + { + for(int ig = 0; ig < evc.get_nbasis(); ig++) + { + evc(iband, ig) = psi[iband * evc.get_nbasis() + ig]; + } + en[iband] = 0.0; + } + return; + } + pHamilt->ops->hPsi(hpsi_in); + } gemm_op()( ctx, From c8037784b0948c4d1576a830fc54150cc2eab720 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 22 May 2024 14:18:07 +0800 Subject: [PATCH 2/2] fix build bug --- source/module_hsolver/diago_iter_assist.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index c8015da0c7..54b90e4f33 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -240,7 +240,7 @@ void DiagoIterAssist::diagH_subspace_init( // std::vector hpsi(psi_temp.get_nbands() * psi_temp.get_nbasis()); // do hPsi for all bands - if (psi::device::get_device_type(ctx) == psi::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { for (int i = 0; i < psi_temp.get_nbands(); i++) { @@ -263,7 +263,7 @@ void DiagoIterAssist::diagH_subspace_init( pHamilt->ops->hPsi(hpsi_in); } } - else if (psi::device::get_device_type(ctx) == psi::CpuDevice) + else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { psi::Range all_bands_range(1, psi_temp.get_current_k(), 0, psi_temp.get_nbands()-1); hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi);