Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Move the print of H(k)&S(k) and wavefunctions to after_scf. #5682

Merged
merged 6 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions source/module_esolver/esolver_gets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,9 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep)
ModuleBase::timer::tick("ESolver_GetS", "runner");
}

void ESolver_GetS::after_all_runners(UnitCell& ucell) {};
double ESolver_GetS::cal_energy() {};
void ESolver_GetS::cal_force(UnitCell& ucell, ModuleBase::matrix& force) {};
void ESolver_GetS::cal_stress(UnitCell& ucell, ModuleBase::matrix& stress) {};

} // namespace ModuleESolver
8 changes: 4 additions & 4 deletions source/module_esolver/esolver_gets.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ class ESolver_GetS : public ESolver_KS<std::complex<double>>

void before_all_runners(UnitCell& ucell, const Input_para& inp) override;

void after_all_runners(UnitCell& ucell){};
void after_all_runners(UnitCell& ucell) override;

void runner(UnitCell& ucell, const int istep) override;

//! calculate total energy of a given system
double cal_energy() {};
double cal_energy() override;

//! calcualte forces for the atoms in the given cell
void cal_force(UnitCell& ucell, ModuleBase::matrix& force) {};
void cal_force(UnitCell& ucell, ModuleBase::matrix& force) override;

//! calcualte stress of given cell
void cal_stress(UnitCell& ucell, ModuleBase::matrix& stress) {};
void cal_stress(UnitCell& ucell, ModuleBase::matrix& stress) override;

protected:
// 2d block - cyclic distribution info
Expand Down
9 changes: 9 additions & 0 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,15 @@ void ESolver_KS<T, Device>::runner(UnitCell& ucell, const int istep)
return;
};

template <typename T, typename Device>
void ESolver_KS<T, Device>::before_scf(UnitCell& ucell, const int istep)
{
ModuleBase::TITLE("ESolver_KS", "before_scf");

//! 1) call before_scf() of ESolver_FP
ESolver_FP::before_scf(ucell, istep);
}

template <typename T, typename Device>
void ESolver_KS<T, Device>::iter_init(UnitCell& ucell, const int istep, const int iter)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ESolver_KS : public ESolver_FP

protected:
//! Something to do before SCF iterations.
virtual void before_scf(UnitCell& ucell, const int istep) {};
virtual void before_scf(UnitCell& ucell, const int istep) override;

//! Something to do before hamilt2density function in each iter loop.
virtual void iter_init(UnitCell& ucell, const int istep, const int iter);
Expand Down
178 changes: 87 additions & 91 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,86 +757,13 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density_single(UnitCell& ucell, int istep,
//------------------------------------------------------------------------------
//! the 12th function of ESolver_KS_LCAO: update_pot
//! mohan add 2024-05-11
//! 1) print Hamiltonian and Overlap matrix (why related to update_pot()?)
//! 2) print wavefunctions (why related to update_pot()?)
//! 3) print potential
//! 1) print potential
//------------------------------------------------------------------------------
template <typename TK, typename TR>
void ESolver_KS_LCAO<TK, TR>::update_pot(UnitCell& ucell, const int istep, const int iter)
{
ModuleBase::TITLE("ESolver_KS_LCAO", "update_pot");

// 1) print Hamiltonian and Overlap matrix
if (this->conv_esolver || iter == PARAM.inp.scf_nmax)
{
if (!PARAM.globalv.gamma_only_local && (PARAM.inp.out_mat_hs[0] || PARAM.inp.deepks_v_delta))
{
this->GK.renew(true);
}
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
{
if (PARAM.inp.out_mat_hs[0] || PARAM.inp.deepks_v_delta)
{
this->p_hamilt->updateHk(ik);
}
bool bit = false; // LiuXh, 2017-03-21
// if set bit = true, there would be error in soc-multi-core
// calculation, noted by zhengdy-soc
if (this->psi != nullptr && (istep % PARAM.inp.out_interval == 0))
{
hamilt::MatrixBlock<TK> h_mat;
hamilt::MatrixBlock<TK> s_mat;

this->p_hamilt->matrix(h_mat, s_mat);

if (PARAM.inp.out_mat_hs[0])
{
ModuleIO::save_mat(istep,
h_mat.p,
PARAM.globalv.nlocal,
bit,
PARAM.inp.out_mat_hs[1],
1,
PARAM.inp.out_app_flag,
"H",
"data-" + std::to_string(ik),
this->pv,
GlobalV::DRANK);
ModuleIO::save_mat(istep,
s_mat.p,
PARAM.globalv.nlocal,
bit,
PARAM.inp.out_mat_hs[1],
1,
PARAM.inp.out_app_flag,
"S",
"data-" + std::to_string(ik),
this->pv,
GlobalV::DRANK);
}
#ifdef __DEEPKS
if (PARAM.inp.deepks_out_labels && PARAM.inp.deepks_v_delta)
{
DeePKS_domain::save_h_mat(h_mat.p, this->pv.nloc);
}
#endif
}
}
}

// 2) print wavefunctions
if (elecstate::ElecStateLCAO<TK>::out_wfc_lcao && (this->conv_esolver || iter == PARAM.inp.scf_nmax)
&& (istep % PARAM.inp.out_interval == 0))
{
ModuleIO::write_wfc_nao(elecstate::ElecStateLCAO<TK>::out_wfc_lcao,
this->psi[0],
this->pelec->ekb,
this->pelec->wg,
this->pelec->klist->kvec_c,
this->pv,
istep);
}

if (!this->conv_esolver)
{
elecstate::cal_ux(ucell);
Expand Down Expand Up @@ -962,7 +889,9 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(UnitCell& ucell, const int istep, int&
//! 1) call after_scf() of ESolver_KS
//! 2) write density matrix for sparse matrix
//! 4) write density matrix
//! 6) write Exx matrix
//! 5) write Exx matrix
//! 6) write Hamiltonian and Overlap matrix
//! 7) write wavefunctions
//! 11) write deepks information
//! 12) write rpa information
//! 13) write HR in npz format
Expand All @@ -982,10 +911,10 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
this->pelec->cal_tau(*(this->psi));
}

//! 5) call after_scf() of ESolver_KS
//! 2) call after_scf() of ESolver_KS
ESolver_KS<TK>::after_scf(ucell, istep);

//! 6) write density matrix for sparse matrix
//! 3) write density matrix for sparse matrix
ModuleIO::write_dmr(dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMR_vector(),
this->pv,
PARAM.inp.out_dm1,
Expand All @@ -995,7 +924,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
&ucell.nat,
istep);

//! 7) write density matrix
//! 4) write density matrix
if (PARAM.inp.out_dm)
{
std::vector<double> efermis(PARAM.inp.nspin == 2 ? 2 : 1);
Expand All @@ -1012,7 +941,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
}

#ifdef __EXX
//! 8) write Hexx matrix for NSCF (see `out_chg` in docs/advanced/input_files/input-main.md)
//! 5) write Hexx matrix for NSCF (see `out_chg` in docs/advanced/input_files/input-main.md)
if (PARAM.inp.calculation != "nscf")
{
if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.out_chg[0]
Expand All @@ -1031,7 +960,74 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
}
#endif

//! 9) Write DeePKS information
// 6) write Hamiltonian and Overlap matrix
if (!PARAM.globalv.gamma_only_local && (PARAM.inp.out_mat_hs[0] || PARAM.inp.deepks_v_delta))
{
this->GK.renew(true);
}
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
{
if (PARAM.inp.out_mat_hs[0] || PARAM.inp.deepks_v_delta)
{
this->p_hamilt->updateHk(ik);
}
bool bit = false; // LiuXh, 2017-03-21
// if set bit = true, there would be error in soc-multi-core
// calculation, noted by zhengdy-soc
if (this->psi != nullptr && (istep % PARAM.inp.out_interval == 0))
{
hamilt::MatrixBlock<TK> h_mat;
hamilt::MatrixBlock<TK> s_mat;

this->p_hamilt->matrix(h_mat, s_mat);

if (PARAM.inp.out_mat_hs[0])
{
ModuleIO::save_mat(istep,
h_mat.p,
PARAM.globalv.nlocal,
bit,
PARAM.inp.out_mat_hs[1],
1,
PARAM.inp.out_app_flag,
"H",
"data-" + std::to_string(ik),
this->pv,
GlobalV::DRANK);
ModuleIO::save_mat(istep,
s_mat.p,
PARAM.globalv.nlocal,
bit,
PARAM.inp.out_mat_hs[1],
1,
PARAM.inp.out_app_flag,
"S",
"data-" + std::to_string(ik),
this->pv,
GlobalV::DRANK);
}
#ifdef __DEEPKS
if (PARAM.inp.deepks_out_labels && PARAM.inp.deepks_v_delta)
{
DeePKS_domain::save_h_mat(h_mat.p, this->pv.nloc);
}
#endif
}
}

// 7) write wavefunctions
if (elecstate::ElecStateLCAO<TK>::out_wfc_lcao && (istep % PARAM.inp.out_interval == 0))
{
ModuleIO::write_wfc_nao(elecstate::ElecStateLCAO<TK>::out_wfc_lcao,
this->psi[0],
this->pelec->ekb,
this->pelec->wg,
this->pelec->klist->kvec_c,
this->pv,
istep);
}

//! 8) Write DeePKS information
#ifdef __DEEPKS
std::shared_ptr<LCAO_Deepks> ld_shared_ptr(&GlobalC::ld, [](LCAO_Deepks*) {});
LCAO_Deepks_Interface LDI = LCAO_Deepks_Interface(ld_shared_ptr);
Expand All @@ -1053,7 +1049,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
#endif

//! 10) Perform RDMFT calculations
//! 9) Perform RDMFT calculations
/******** test RDMFT *********/
if ( PARAM.inp.rdmft == true ) // rdmft, added by jghan, 2024-10-17
{
Expand All @@ -1079,7 +1075,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)


#ifdef __EXX
// 11) Write RPA information.
// 10) Write RPA information.
if (PARAM.inp.rpa)
{
// ModuleRPA::DFT_RPA_interface
Expand All @@ -1094,7 +1090,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
}
#endif

// 12) write HR in npz format.
// 11) write HR in npz format.
if (PARAM.inp.out_hr_npz)
{
this->p_hamilt->updateHk(0); // first k point, up spin
Expand All @@ -1113,7 +1109,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
}
}

// 13) write density matrix in the 'npz' format.
// 12) write density matrix in the 'npz' format.
if (PARAM.inp.out_dm_npz)
{
const elecstate::DensityMatrix<TK, double>* dm
Expand All @@ -1128,7 +1124,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
}
}

//! 14) Print out information every 'out_interval' steps.
//! 13) Print out information every 'out_interval' steps.
if (PARAM.inp.calculation != "md" || istep % PARAM.inp.out_interval == 0)
{
//! Print out sparse matrix
Expand All @@ -1154,7 +1150,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
}
}

//! 15) Print out atomic magnetization only when 'spin_constraint' is on.
//! 14) Print out atomic magnetization only when 'spin_constraint' is on.
if (PARAM.inp.sc_mag_switch)
{
spinconstrain::SpinConstrain<TK>& sc = spinconstrain::SpinConstrain<TK>::getScInstance();
Expand All @@ -1163,14 +1159,14 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
sc.print_Mag_Force(GlobalV::ofs_running);
}

//! 16) Clean up RA.
//! 15) Clean up RA.
//! this should be last function and put it in the end, mohan request 2024-11-28
if (!PARAM.inp.cal_force && !PARAM.inp.cal_stress)
{
RA.delete_grid();
}

//! 17) Print out quasi-orbitals.
//! 16) Print out quasi-orbitals.
if (PARAM.inp.qo_switch)
{
toQO tqo(PARAM.inp.qo_basis, PARAM.inp.qo_strategy, PARAM.inp.qo_thr, PARAM.inp.qo_screening_coeff);
Expand All @@ -1185,7 +1181,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
tqo.calculate();
}

//! 18) Print out kinetic matrix.
//! 17) Print out kinetic matrix.
if (PARAM.inp.out_mat_tk[0])
{
hamilt::HS_Matrix_K<TK> hsk(&pv, true);
Expand Down Expand Up @@ -1220,7 +1216,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
delete ekinetic;
}

//! 19) Wannier 90 function, added by jingan in 2018.11.7
//! 18) Wannier 90 function, added by jingan in 2018.11.7
if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90)
{
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wave function to Wannier90");
Expand Down Expand Up @@ -1259,7 +1255,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wave function to Wannier90");
}

//! 20) berry phase calculations, added by jingan
//! 19) berry phase calculations, added by jingan
if (PARAM.inp.calculation == "nscf" &&
berryphase::berry_phase_flag &&
ModuleSymmetry::Symmetry::symm_flag != 1)
Expand Down
Loading
Loading