Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Remove global dependence of some functions in DeePKS. #5778

Merged
merged 3 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ OBJS_CELL=atom_pseudo.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_force.o\
LCAO_deepks_odelta.o\
deepks_orbital.o\
LCAO_deepks_io.o\
LCAO_deepks_mpi.o\
LCAO_deepks_pdm.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Force_LCAO
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
typename TGint<T>::type& gint,
Expand Down
18 changes: 14 additions & 4 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
ModuleBase::matrix fewalds;
ModuleBase::matrix fcc;
ModuleBase::matrix fscc;
#ifdef __DEEPKS
ModuleBase::matrix fvnl_dalpha; // deepks
#endif

fvl_dphi.create(nat, 3); // must do it now, update it later, noted by zhengdy

Expand All @@ -93,6 +96,9 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
fewalds.create(nat, 3);
fcc.create(nat, 3);
fscc.create(nat, 3);
#ifdef __DEEPKS
fvnl_dalpha.create(nat, 3); // deepks
#endif

// calculate basic terms in Force, same method with PW base
this->calForcePwPart(ucell,
Expand Down Expand Up @@ -172,6 +178,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
svnl_dbeta,
svl_dphi,
#ifdef __DEEPKS
fvnl_dalpha,
svnl_dalpha,
#endif
gint_gamma,
Expand Down Expand Up @@ -454,7 +461,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
// mohan add 2021-08-04
if (PARAM.inp.deepks_scf)
{
fcs(iat, i) += GlobalC::ld.F_delta(iat, i);
fcs(iat, i) += fvnl_dalpha(iat, i);
}
#endif
// sum total force for correction
Expand Down Expand Up @@ -499,7 +506,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
if (PARAM.inp.deepks_scf)
{
const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
LCAO_deepks_io::save_npy_f(fcs - GlobalC::ld.F_delta,
LCAO_deepks_io::save_npy_f(fcs - fvnl_dalpha,
file_fbase,
ucell.nat,
GlobalV::MY_RANK); // Ry/Bohr, F_base
Expand Down Expand Up @@ -636,8 +643,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
// caoyu add 2021-06-03
if (PARAM.inp.deepks_scf)
{
ModuleIO::print_force(GlobalV::ofs_running, ucell, "DeePKS FORCE", GlobalC::ld.F_delta, true);
// this->print_force("DeePKS FORCE", GlobalC::ld.F_delta, 1, ry);
ModuleIO::print_force(GlobalV::ofs_running, ucell, "DeePKS FORCE", fvnl_dalpha, true);
}
#endif
}
Expand Down Expand Up @@ -891,6 +897,7 @@ void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
Gint_Gamma& gint_gamma, // mohan add 2024-04-01
Expand All @@ -917,6 +924,7 @@ void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
svnl_dbeta,
svl_dphi,
#if __DEEPKS
fvnl_dalpha,
svnl_dalpha,
#endif
gint_gamma,
Expand Down Expand Up @@ -944,6 +952,7 @@ void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOn
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
Gint_Gamma& gint_gamma,
Expand All @@ -969,6 +978,7 @@ void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOn
svnl_dbeta,
svl_dphi,
#if __DEEPKS
fvnl_dalpha,
svnl_dalpha,
#endif
gint_k,
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class Force_Stress_LCAO
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
Gint_Gamma& gint_gamma,
Expand Down
70 changes: 35 additions & 35 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ void Force_LCAO<double>::ftable(const bool isforce,
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
TGint<double>::type& gint,
Expand Down Expand Up @@ -246,15 +247,13 @@ void Force_LCAO<double>::ftable(const bool isforce,
false /*reset dm to gint*/);

#ifdef __DEEPKS
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
if (PARAM.inp.deepks_scf)
{
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();

// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);

GlobalC::ld.cal_descriptor(ucell.nat);

GlobalC::ld.cal_gedm(ucell.nat);

const int nks = 1;
Expand All @@ -269,40 +268,9 @@ void Force_LCAO<double>::ftable(const bool isforce,
GlobalC::ld.phialpha,
GlobalC::ld.gedm,
GlobalC::ld.inl_index,
GlobalC::ld.F_delta,
fvnl_dalpha,
isstress,
svnl_dalpha);

#ifdef __MPI
Parallel_Reduce::reduce_all(GlobalC::ld.F_delta.c, GlobalC::ld.F_delta.nr * GlobalC::ld.F_delta.nc);

if (isstress)
{
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
}
#endif

if (PARAM.inp.deepks_out_unittest)
{
const int nks = 1; // 1 for gamma-only
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);

GlobalC::ld.check_projected_dm();

GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);

GlobalC::ld.check_gedm();

GlobalC::ld.cal_e_delta_band(dm_gamma, nks);

std::ofstream ofs("E_delta_bands.dat");
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;

std::ofstream ofs1("E_delta.dat");
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;

DeePKS_domain::check_f_delta(ucell.nat, GlobalC::ld.F_delta, svnl_dalpha);
}
}
#endif

Expand All @@ -312,14 +280,46 @@ void Force_LCAO<double>::ftable(const bool isforce,
Parallel_Reduce::reduce_pool(ftvnl_dphi.c, ftvnl_dphi.nr * ftvnl_dphi.nc);
Parallel_Reduce::reduce_pool(fvnl_dbeta.c, fvnl_dbeta.nr * fvnl_dbeta.nc);
Parallel_Reduce::reduce_pool(fvl_dphi.c, fvl_dphi.nr * fvl_dphi.nc);
#ifdef __DEEPKS
Parallel_Reduce::reduce_pool(fvnl_dalpha.c, fvnl_dalpha.nr * fvnl_dalpha.nc);
#endif
}
if (isstress)
{
Parallel_Reduce::reduce_pool(soverlap.c, soverlap.nr * soverlap.nc);
Parallel_Reduce::reduce_pool(stvnl_dphi.c, stvnl_dphi.nr * stvnl_dphi.nc);
Parallel_Reduce::reduce_pool(svnl_dbeta.c, svnl_dbeta.nr * svnl_dbeta.nc);
Parallel_Reduce::reduce_pool(svl_dphi.c, svl_dphi.nr * svl_dphi.nc);
#ifdef __DEEPKS
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
#endif
}

#ifdef __DEEPKS
// It seems these test should not all be here, should be moved in the future
// Also, these test are not in multi-k case now
if (PARAM.inp.deepks_scf && PARAM.inp.deepks_out_unittest)
{
const int nks = 1; // 1 for gamma-only
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);

GlobalC::ld.check_projected_dm();

GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);

GlobalC::ld.check_gedm();

GlobalC::ld.cal_e_delta_band(dm_gamma, nks);

std::ofstream ofs("E_delta_bands.dat");
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;

std::ofstream ofs1("E_delta.dat");
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;

DeePKS_domain::check_f_delta(ucell.nat, fvnl_dalpha, svnl_dalpha);
}
#endif

// delete DSloc_x, DSloc_y, DSloc_z
// delete DHloc_fixed_x, DHloc_fixed_y, DHloc_fixed_z
Expand Down
17 changes: 8 additions & 9 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
TGint<std::complex<double>>::type& gint,
Expand Down Expand Up @@ -363,17 +364,9 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
GlobalC::ld.phialpha,
GlobalC::ld.gedm,
GlobalC::ld.inl_index,
GlobalC::ld.F_delta,
fvnl_dalpha,
isstress,
svnl_dalpha);

#ifdef __MPI
Parallel_Reduce::reduce_all(GlobalC::ld.F_delta.c, GlobalC::ld.F_delta.nr * GlobalC::ld.F_delta.nc);
if (isstress)
{
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
}
#endif
}
#endif

Expand All @@ -386,13 +379,19 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
Parallel_Reduce::reduce_pool(ftvnl_dphi.c, ftvnl_dphi.nr * ftvnl_dphi.nc);
Parallel_Reduce::reduce_pool(fvnl_dbeta.c, fvnl_dbeta.nr * fvnl_dbeta.nc);
Parallel_Reduce::reduce_pool(fvl_dphi.c, fvl_dphi.nr * fvl_dphi.nc);
#ifdef __DEEPKS
Parallel_Reduce::reduce_pool(fvnl_dalpha.c, fvnl_dalpha.nr * fvnl_dalpha.nc);
#endif
}
if (isstress)
{
Parallel_Reduce::reduce_pool(soverlap.c, soverlap.nr * soverlap.nc);
Parallel_Reduce::reduce_pool(stvnl_dphi.c, stvnl_dphi.nr * stvnl_dphi.nc);
Parallel_Reduce::reduce_pool(svnl_dbeta.c, svnl_dbeta.nr * svnl_dbeta.nc);
Parallel_Reduce::reduce_pool(svl_dphi.c, svl_dphi.nr * svl_dphi.nc);
#ifdef __DEEPKS
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
#endif
}

ModuleBase::timer::tick("Force_LCAO", "ftable");
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_lcao/module_deepks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ if(ENABLE_DEEPKS)
list(APPEND objects
LCAO_deepks.cpp
deepks_force.cpp
LCAO_deepks_odelta.cpp
deepks_orbital.cpp
LCAO_deepks_io.cpp
LCAO_deepks_mpi.cpp
LCAO_deepks_pdm.cpp
Expand Down
34 changes: 9 additions & 25 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,6 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
}
if (PARAM.inp.cal_force)
{
// init F_delta
F_delta.create(nat, 3);
if (PARAM.inp.deepks_out_labels)
{
this->init_gdmx(nat);
Expand All @@ -342,34 +340,24 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
// gdmx is used only in calculating gvx
}

if (PARAM.inp.deepks_bandgap)
{
// init o_delta
o_delta.create(nks, 1);
}

return;
}

void LCAO_Deepks::init_orbital_pdm_shell(const int nks)
{

this->orbital_pdm_shell = new double***[nks];
this->orbital_pdm_shell = new double**[nks];

for (int iks = 0; iks < nks; iks++)
{
this->orbital_pdm_shell[iks] = new double**[1];
for (int hl = 0; hl < 1; hl++)
this->orbital_pdm_shell[iks] = new double*[this->inlmax];
for (int inl = 0; inl < this->inlmax; inl++)
{
this->orbital_pdm_shell[iks][hl] = new double*[this->inlmax];

for (int inl = 0; inl < this->inlmax; inl++)
{
this->orbital_pdm_shell[iks][hl][inl] = new double[(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1)];
ModuleBase::GlobalFunc::ZEROS(orbital_pdm_shell[iks][hl][inl],
(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1));
}
this->orbital_pdm_shell[iks][inl] = new double[(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1)];
ModuleBase::GlobalFunc::ZEROS(orbital_pdm_shell[iks][inl],
(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1));
}

}

return;
Expand All @@ -379,13 +367,9 @@ void LCAO_Deepks::del_orbital_pdm_shell(const int nks)
{
for (int iks = 0; iks < nks; iks++)
{
for (int hl = 0; hl < 1; hl++)
for (int inl = 0; inl < this->inlmax; inl++)
{
for (int inl = 0; inl < this->inlmax; inl++)
{
delete[] this->orbital_pdm_shell[iks][hl][inl];
}
delete[] this->orbital_pdm_shell[iks][hl];
delete[] this->orbital_pdm_shell[iks][inl];
}
delete[] this->orbital_pdm_shell[iks];
}
Expand Down
Loading
Loading