Skip to content

Commit

Permalink
Refactor:Remove GloblaC::ucell in module_hsolver (#5657)
Browse files Browse the repository at this point in the history
* change ucell in the hsolver_pw

* change test hsolver_pw

* Revert "change ucell in the hsolver_pw"

This reverts commit 51ba921.

* Revert "change test hsolver_pw"

This reverts commit a815fdc.

* use parameter trans instead of ucell
  • Loading branch information
A-006 authored Dec 2, 2024
1 parent a43cbfb commit 48065a3
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 29 deletions.
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ namespace ModuleESolver
}

hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge,ucell.tpiba,ucell.nat);

// add exx
#ifdef __EXX
Expand Down
4 changes: 3 additions & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,9 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(UnitCell& ucell,
this->pelec->ekb.c,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,
skip_charge);
skip_charge,
ucell.tpiba,
ucell.nat);

Symmetry_rho srho;
for (int is = 0; is < PARAM.inp.nspin; is++)
Expand Down
23 changes: 14 additions & 9 deletions source/module_hsolver/hsolver_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace hsolver

#ifdef USE_PAW
template <typename T>
void HSolverLIP<T>::paw_func_in_kloop(const int ik)
void HSolverLIP<T>::paw_func_in_kloop(const int ik,const double tpiba)
{
if (PARAM.inp.use_paw)
{
Expand Down Expand Up @@ -64,7 +64,7 @@ void HSolverLIP<T>::paw_func_in_kloop(const int ik)
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
Expand All @@ -83,7 +83,10 @@ void HSolverLIP<T>::paw_func_in_kloop(const int ik)
}

template <typename T>
void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState* pes)
void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi,
elecstate::ElecState* pes,
const double tpiba,
const int nat)
{
if (PARAM.inp.use_paw)
{
Expand Down Expand Up @@ -131,7 +134,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
Expand Down Expand Up @@ -164,7 +167,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
{
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < nat; iat++)
{
GlobalC::paw_cell.set_rhoij(iat,
nrhoijsel[iat],
Expand All @@ -176,7 +179,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
#else
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < nat; iat++)
{
GlobalC::paw_cell.set_rhoij(iat,
nrhoijsel[iat],
Expand All @@ -201,7 +204,9 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
psi::Psi<T>& psi, // ESolver_KS_PW::kspw_psi
elecstate::ElecState* pes, // ESolver_KS_PW::pes
psi::Psi<T>& transform,
const bool skip_charge)
const bool skip_charge,
const double tpiba,
const int nat)
{
ModuleBase::TITLE("HSolverLIP", "solve");
ModuleBase::timer::tick("HSolverLIP", "solve");
Expand All @@ -212,7 +217,7 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
pHamilt->updateHk(ik);

#ifdef USE_PAW
this->paw_func_in_kloop(ik);
this->paw_func_in_kloop(ik,tpiba);
#endif

psi.fix_k(ik);
Expand Down Expand Up @@ -282,7 +287,7 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
this->paw_func_after_kloop(psi, pes);
this->paw_func_after_kloop(psi, pes,tpiba,nat);
#endif

ModuleBase::timer::tick("HSolverLIP", "solve");
Expand Down
12 changes: 9 additions & 3 deletions source/module_hsolver/hsolver_lcaopw.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,21 @@ class HSolverLIP
psi::Psi<T>& psi,
elecstate::ElecState* pes,
psi::Psi<T>& transform,
const bool skip_charge);
const bool skip_charge,
const double tpiba,
const int nat);

private:
ModulePW::PW_Basis_K* wfc_basis;

#ifdef USE_PAW
void paw_func_in_kloop(const int ik);
void paw_func_in_kloop(const int ik,
const double tpiba);

void paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState* pes);
void paw_func_after_kloop(psi::Psi<T>& psi,
elecstate::ElecState* pes,
const double tpiba,
const int nat);
#endif
};

Expand Down
24 changes: 15 additions & 9 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace hsolver

#ifdef USE_PAW
template <typename T, typename Device>
void HSolverPW<T, Device>::paw_func_in_kloop(const int ik)
void HSolverPW<T, Device>::paw_func_in_kloop(const int ik,
const double tpiba)
{
if (this->use_paw)
{
Expand Down Expand Up @@ -68,7 +69,7 @@ void HSolverPW<T, Device>::paw_func_in_kloop(const int ik)
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
Expand Down Expand Up @@ -96,7 +97,10 @@ void HSolverPW<T, Device>::call_paw_cell_set_currentk(const int ik)
}

template <typename T, typename Device>
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes)
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,
const double tpiba,
const int nat)
{
if (this->use_paw)
{
Expand Down Expand Up @@ -144,7 +148,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
Expand Down Expand Up @@ -177,7 +181,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
{
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < nat; iat++)
{
GlobalC::paw_cell.set_rhoij(iat,
nrhoijsel[iat],
Expand All @@ -189,7 +193,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
#else
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < nat; iat++)
{
GlobalC::paw_cell.set_rhoij(iat,
nrhoijsel[iat],
Expand Down Expand Up @@ -255,7 +259,9 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
double* out_eigenvalues,
const int rank_in_pool_in,
const int nproc_in_pool_in,
const bool skip_charge)
const bool skip_charge,
const double tpiba,
const int nat)
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");
Expand All @@ -282,7 +288,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
pHamilt->updateHk(ik);

#ifdef USE_PAW
this->paw_func_in_kloop(ik);
this->paw_func_in_kloop(ik,tpiba);
#endif

/// update psi pointer for each k point
Expand Down Expand Up @@ -341,7 +347,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
this->paw_func_after_kloop(psi, pes);
this->paw_func_after_kloop(psi, pes,tpiba,nat);
#endif

ModuleBase::timer::tick("HSolverPW", "solve");
Expand Down
9 changes: 6 additions & 3 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class HSolverPW
double* out_eigenvalues,
const int rank_in_pool_in,
const int nproc_in_pool_in,
const bool skip_charge);
const bool skip_charge,
const double tpiba,
const int nat);

protected:
// diago caller
Expand Down Expand Up @@ -89,11 +91,12 @@ class HSolverPW
std::vector<double> ethr_band;

#ifdef USE_PAW
void paw_func_in_kloop(const int ik);
void paw_func_in_kloop(const int ik,
const double tpiba);

void call_paw_cell_set_currentk(const int ik);

void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes);
void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes,const double tpiba,const int nat);
#endif
};

Expand Down
5 changes: 2 additions & 3 deletions source/module_hsolver/test/test_hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) {
= hsolver::HSolverLIP<std::complex<float>>(&pwbk);
hsolver::HSolverLIP<std::complex<double>> hs_d_lip
= hsolver::HSolverLIP<std::complex<double>>(&pwbk);
hs_f_lip.solve(&hamilt_test_f, psi_test_cf, &elecstate_test,
transform_test_cf, true);
hs_f_lip.solve(&hamilt_test_f, psi_test_cf, &elecstate_test,transform_test_cf, true,0.0,0);
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<float>>::avg_iter, 0.0);
for (int i = 0; i < psi_test_cf.size(); i++)
{
Expand All @@ -261,7 +260,7 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) {

elecstate_test.ekb.c[0] = 1.0;
elecstate_test.ekb.c[1] = 2.0;
hs_d_lip.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, transform_test_cd, true);
hs_d_lip.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, transform_test_cd, true,0.0,0);
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<double>>::avg_iter, 0.0);
for (int i = 0; i < psi_test_cd.size(); i++)
{
Expand Down

0 comments on commit 48065a3

Please sign in to comment.