Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: reorganized HSolverPW<T, Device>::solve function in HSolverPW #4675

Merged
merged 13 commits into from
Jul 16, 2024
3 changes: 2 additions & 1 deletion source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
this->notconv = 0;
for (int m = 0; m < this->n_band; m++)
{
if (is_occupied[m])
if (is_occupied[m]) // always true
{
convflag[m] = (std::abs(eigenvalue_iter[m] - eigenvalue_in_hsolver[m]) < this->diag_thr);
}
Expand Down Expand Up @@ -740,6 +740,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,

int sum_iter = 0;
int ntry = 0;

do
{
if (this->is_subspace || ntry > 0)
Expand Down
226 changes: 103 additions & 123 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "module_hsolver/diago_iter_assist.h"

#include <algorithm>
#include <vector>

#ifdef USE_PAW
#include "module_cell/module_paw/paw_cell.h"
Expand All @@ -30,7 +31,6 @@ HSolverPW<T, Device>::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pw
this->wfc_basis = wfc_basis_in;
this->pwf = pwf_in;
this->diag_ethr = GlobalV::PW_DIAG_THR;
/*this->init(pbas_in);*/
}

#ifdef USE_PAW
Expand Down Expand Up @@ -213,6 +213,32 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst

#endif

template <typename T, typename Device>
void HSolverPW<T, Device>::set_isOccupied(std::vector<bool>& is_occupied,
elecstate::ElecState* pes,
const int i_scf,
const int nk,
const int nband,
const bool diago_full_acc_)
{
if (i_scf != 0 && diago_full_acc_ == false)
{
for (int i = 0; i < nk; i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < nband; j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * nband + j] = false;
}
}
}
}
}
}

template <typename T, typename Device>
void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
Expand All @@ -222,46 +248,29 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");
// prepare for the precondition of diagonalization
this->precondition.resize(psi.get_nbasis());
this->hamilt_ = pHamilt;

// select the method of diagonalization
this->method = method_in;

// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
}

std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0);

if (this->is_first_scf)
{
is_occupied.resize(psi.get_nk() * psi.get_nbands(), true);
}
else
// prepare for the precondition of diagonalization
std::vector<Real> precondition(psi.get_nbasis(), 0.0);
std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0.0);
std::vector<bool> is_occupied(psi.get_nk() * psi.get_nbands(), true);
if (this->method == "dav_subspace")
{
if (this->diago_full_acc)
{
is_occupied.assign(is_occupied.size(), true);
}
else
{
for (int i = 0; i < psi.get_nk(); i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < psi.get_nbands(); j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * psi.get_nbands() + j] = false;
}
}
}
}
}
this->set_isOccupied(is_occupied,
pes,
DiagoIterAssist<T, Device>::SCF_ITER,
psi.get_nk(),
psi.get_nbands(),
this->diago_full_acc);
}

/// Loop over k points for solve Hamiltonian to charge density
Expand All @@ -284,7 +293,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
#endif

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc);
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * pes->ekb.nc);

if (skip_charge)
{
Expand All @@ -298,54 +307,35 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
}
// END Loop over k points

// copy eigenvalues to pes->ekb in ElecState
base_device::memory::cast_memory_op<double, Real, base_device::DEVICE_CPU, base_device::DEVICE_CPU>()(
cpu_ctx,
cpu_ctx,
pes->ekb.c,
eigenvalues.data(),
pes->ekb.nr * pes->ekb.nc);

this->is_first_scf = false;

this->endDiagh();
// psi only should be initialed once for PW
if (!this->initialed_psi)
{
this->initialed_psi = true;
}

if (skip_charge)
{
ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);
else
{
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
this->paw_func_after_kloop(psi, pes);
this->paw_func_after_kloop(psi, pes);
#endif

ModuleBase::timer::tick("HSolverPW", "solve");
return;
}

template <typename T, typename Device>
void HSolverPW<T, Device>::endDiagh()
{
// in PW base, average iteration steps for each band and k-point should be
// printing
if (DiagoIterAssist<T, Device>::avg_iter > 0.0)
{
GlobalV::ofs_running << "Average iterative diagonalization steps: "
<< DiagoIterAssist<T, Device>::avg_iter / this->wfc_basis->nks
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR << " . "
<< std::endl;

// std::cout << "avg_iter == " << DiagoIterAssist<T, Device>::avg_iter
// << std::endl;

// reset avg_iter
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
// psi only should be initialed once for PW
if (!this->initialed_psi)
{
this->initialed_psi = true;
ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
}

Expand All @@ -361,13 +351,22 @@ void HSolverPW<T, Device>::updatePsiK(hamilt::Hamilt<T, Device>* pHamilt, psi::P
}

template <typename T, typename Device>
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::Psi<T, Device>& psi, Real* eigenvalue)
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
psi::Psi<T, Device>& psi,
std::vector<Real>& pre_condition,
Real* eigenvalue)
{
#ifdef __MPI
const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#else
const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#endif

if (this->method == "cg")
{
// warp the subspace_func into a lambda function
auto ngk_pointer = psi.get_ngk_pointer();
auto subspace_func = [this, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
const auto ndim = psi_in.shape().ndim();
Expand All @@ -387,7 +386,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
ct::DeviceType::CpuDevice,
ct::TensorShape({psi_in.shape().dim_size(0)}));

DiagoIterAssist<T, Device>::diagH_subspace(hamilt_, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>());
DiagoIterAssist<T, Device>::diagH_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>());
};
DiagoCG<T, Device> cg(GlobalV::BASIS_TYPE,
GlobalV::CALCULATION,
Expand Down Expand Up @@ -456,45 +455,26 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
ct::DataTypeToEnum<Real>::value,
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
ct::TensorShape({psi.get_nbands()}));
auto prec_tensor = ct::TensorMap(precondition.data(),
auto prec_tensor = ct::TensorMap(pre_condition.data(),
ct::DataTypeToEnum<Real>::value,
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
ct::TensorShape({static_cast<int>(precondition.size())}))
ct::TensorShape({static_cast<int>(pre_condition.size())}))
.to_device<ct_Device>()
.slice({0}, {psi.get_current_nbas()});

cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
// TODO: Double check tensormap's potential problem
ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
}
else if (this->method == "bpcg")
{
DiagoBPCG<T, Device> bpcg(pre_condition.data());
bpcg.init_iter(psi);
bpcg.diag(hm, psi, eigenvalue);
}
else if (this->method == "dav_subspace")
{
#ifdef __MPI
const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#else
const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#endif
Diago_DavSubspace<T, Device> dav_subspace(this->precondition,
psi.get_nbands(),
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),
GlobalV::PW_DIAG_NDIM,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
DiagoIterAssist<T, Device>::need_subspace,
comm_info);
bool scf;
if (GlobalV::CALCULATION == "nscf")
{
scf = false;
}
else
{
scf = true;
}

auto ngk_pointer = psi.get_ngk_pointer();

auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
const int nband_in,
Expand All @@ -514,40 +494,26 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P

ModuleBase::timer::tick("DavSubspace", "hpsi_func");
};
bool scf = GlobalV::CALCULATION == "nscf" ? false : true;
const std::vector<bool> is_occupied(psi.get_nbands(), true);

auto subspace_func = [hm, ngk_pointer](T* psi_out,
T* psi_in,
Real* eigenvalue_in_hsolver,
const int nband_in,
const int nbasis_max_in) {
// Convert "pointer data stucture" to a psi::Psi object
auto psi_in_wrapper = psi::Psi<T, Device>(psi_in, 1, nband_in, nbasis_max_in, ngk_pointer);
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out, 1, nband_in, nbasis_max_in, ngk_pointer);

DiagoIterAssist<T, Device>::diagH_subspace(hm,
psi_in_wrapper,
psi_out_wrapper,
eigenvalue_in_hsolver,
nband_in);
};
Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
psi.get_nbands(),
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),
GlobalV::PW_DIAG_NDIM,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
DiagoIterAssist<T, Device>::need_subspace,
comm_info);

DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, is_occupied, scf));
}
else if (this->method == "bpcg")
{
DiagoBPCG<T, Device> bpcg(precondition.data());
bpcg.init_iter(psi);
bpcg.diag(hm, psi, eigenvalue);
DiagoIterAssist<T, Device>::avg_iter
+= static_cast<double>(dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, is_occupied, scf));
}
else if (this->method == "dav")
{
#ifdef __MPI
const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#else
const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#endif
// Davidson iter parameters

// Allow 5 tries at most. If ntry > ntry_max = 5, exit diag loop.
const int ntry_max = 5;
// In non-self consistent calculation, do until totally converged. Else
Expand All @@ -561,7 +527,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
const int dim = psi.get_current_nbas();
const int nband = psi.get_nbands();
const int ldPsi = psi.get_nbasis();


auto ngk_pointer = psi.get_ngk_pointer();
/// wrap for hpsi function, Matrix \times blockvector
Expand Down Expand Up @@ -604,7 +569,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
1);
*/

DiagoDavid<T, Device> david(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info);
DiagoDavid<T, Device> david(pre_condition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info);
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
david.diag(hpsi_func, spsi_func, dim, nband, ldPsi, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max));
}
Expand Down Expand Up @@ -657,6 +622,21 @@ void HSolverPW<T, Device>::update_precondition(std::vector<Real>& h_diag, const
}
}

template <typename T, typename Device>
void HSolverPW<T, Device>::output_iterInfo()
{
// in PW base, average iteration steps for each band and k-point should be printing
if (DiagoIterAssist<T, Device>::avg_iter > 0.0)
{
GlobalV::ofs_running << "Average iterative diagonalization steps: "
<< DiagoIterAssist<T, Device>::avg_iter / this->wfc_basis->nks
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR << " . "
<< std::endl;
// reset avg_iter
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
}

template <typename T, typename Device>
typename HSolverPW<T, Device>::Real HSolverPW<T, Device>::cal_hsolerror()
{
Expand Down
Loading
Loading