Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Sorting out the calculation logic of pexsi in hsolver-lcao #5299

Merged
merged 2 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 43 additions & 41 deletions source/module_hsolver/hsolver_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,50 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
ModuleBase::TITLE("HSolverLCAO", "solve");
ModuleBase::timer::tick("HSolverLCAO", "solve");

#ifdef __PEXSI // other purification methods should follow this routine
// Zhang Xiaoyang : Please modify Pesxi usage later
if (this->method == "pexsi")
if (this->method != "pexsi")
{
if (GlobalV::KPAR_LCAO > 1
&& (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx"))
{
#ifdef __MPI
this->parakSolve(pHamilt, psi, pes, GlobalV::KPAR_LCAO);
#endif
}
else if (GlobalV::KPAR_LCAO == 1)
{
/// Loop over k points for solve Hamiltonian to eigenpairs(eigenvalues and eigenvectors).
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
/// update H(k) for each k point
pHamilt->updateHk(ik);

/// find psi pointer for each k point
psi.fix_k(ik);

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, &(pes->ekb(ik, 0)));
}
}
else
{
ModuleBase::WARNING_QUIT("HSolverLCAO::solve",
"This method and KPAR setting is not supported for lcao basis in ABACUS!");
}

if (!skip_charge)
{
// used in scf calculation
// calculate charge by eigenpairs(eigenvalues and eigenvectors)
pes->psiToRho(psi);
}
else
{
// used in nscf calculation
}
}
else if (this->method == "pexsi")
{
#ifdef __PEXSI // other purification methods should follow this routine
DiagoPexsi<T> pe(ParaV);
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
Expand All @@ -60,41 +100,7 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
pes->f_en.eband = pe.totalFreeEnergy;
// maybe eferm could be dealt with in the future
_pes->dmToRho(pe.DM, pe.EDM);
ModuleBase::timer::tick("HSolverLCAO", "solve");
return;
}
#endif

if (GlobalV::KPAR_LCAO > 1
&& (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx"))
{
#ifdef __MPI
this->parakSolve(pHamilt, psi, pes, GlobalV::KPAR_LCAO);
#endif
}
else if (GlobalV::KPAR_LCAO == 1)
{
/// Loop over k points for solve Hamiltonian to eigenpairs(eigenvalues and eigenvectors).
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
/// update H(k) for each k point
pHamilt->updateHk(ik);

/// find psi pointer for each k point
psi.fix_k(ik);

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, &(pes->ekb(ik, 0)));
}
}

if (!skip_charge) // used in scf calculation
{
// calculate charge by eigenpairs(eigenvalues and eigenvectors)
pes->psiToRho(psi);
}
else // used in nscf calculation
{
}

ModuleBase::timer::tick("HSolverLCAO", "solve");
Expand All @@ -114,7 +120,6 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
sa.diag(hm, psi, eigenvalue);
#endif
}

#ifdef __ELPA
else if (this->method == "genelpa")
{
Expand All @@ -127,7 +132,6 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
el.diag(hm, psi, eigenvalue);
}
#endif

#ifdef __CUDA
else if (this->method == "cusolver")
{
Expand All @@ -142,15 +146,13 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
}
#endif
#endif

#ifndef __MPI
else if (this->method == "lapack") // only for single core
{
DiagoLapack<T> la;
la.diag(hm, psi, eigenvalue);
}
#endif

else
{
ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "This method is not supported for lcao basis in ABACUS!");
Expand Down
8 changes: 2 additions & 6 deletions source/module_hsolver/hsolver_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,13 @@ class HSolverLCAO
const bool skip_charge);

private:
void hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>& psi, double* eigenvalue);
void hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>& psi, double* eigenvalue); // for kpar_lcao == 1

void parakSolve(hamilt::Hamilt<T>* pHamilt, psi::Psi<T>& psi, elecstate::ElecState* pes, int kpar);
void parakSolve(hamilt::Hamilt<T>* pHamilt, psi::Psi<T>& psi, elecstate::ElecState* pes, int kpar); // for kpar_lcao > 1

const Parallel_Orbitals* ParaV;

const std::string method;

// for cg_in_lcao
using Real = typename GetTypeReal<T>::type;
std::vector<Real> precondition_lcao;
};

} // namespace hsolver
Expand Down
Loading