Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: add new init_chg method with wavefunctions #5082

Merged
merged 16 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ These variables are used to control general system parameters.

- atomic: the density is starting from the summation of the atomic density of single atoms.
- file: the density will be read in from a binary file `charge-density.dat` first. If it does not exist, the charge density will be read in from cube files. Besides, when you do `nspin=1` calculation, you only need the density file SPIN1_CHG.cube. However, if you do `nspin=2` calculation, you also need the density file SPIN2_CHG.cube. The density file should be output with these names if you set out_chg = 1 in INPUT file.
- wfc: the density will be calculated by wavefunctions and occupations. Wavefunctions are read in from binary files `WAVEFUNC*.dat` while occupations are read in from file `istate.info`.
- auto: Abacus first attempts to read the density from a file; if not found, it defaults to using atomic density.
- **Default**: atomic

Expand Down
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ OBJS_IO=input_conv.o\
rhog_io.o\
read_exit_file.o\
read_wfc_pw.o\
read_wfc_to_rho.o\
restart.o\
binstream.o\
to_wannier90.o\
Expand Down
9 changes: 5 additions & 4 deletions source/module_cell/klist.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class K_Vectors
* This function gets the global index of a k-point based on its local index and the process pool ID.
* The global index is used when the k-points are distributed among multiple process pools.
*
* @param nkstot The total number of k-points.
* @param ik The local index of the k-point.
*
* @return int Returns the global index of the k-point.
Expand All @@ -116,7 +117,7 @@ class K_Vectors
* process pools (KPAR), and adding the remainder if the process pool ID (MY_POOL) is less than the remainder.
* @note The function is declared as inline for efficiency.
*/
inline int getik_global(const int& ik) const;
static inline int getik_global(const int& ik, const int& nkstot);
Qianruipku marked this conversation as resolved.
Show resolved Hide resolved

int get_nks() const
{
Expand Down Expand Up @@ -390,10 +391,10 @@ class K_Vectors
void print_klists(std::ofstream& fn);
};

inline int K_Vectors::getik_global(const int& ik) const
inline int K_Vectors::getik_global(const int& ik, const int& nkstot)
{
int nkp = this->nkstot / GlobalV::KPAR;
int rem = this->nkstot % GlobalV::KPAR;
int nkp = nkstot / GlobalV::KPAR;
Qianruipku marked this conversation as resolved.
Show resolved Hide resolved
int rem = nkstot % GlobalV::KPAR;
if (GlobalV::MY_POOL < rem)
{
return GlobalV::MY_POOL * nkp + GlobalV::MY_POOL + ik;
Expand Down
4 changes: 2 additions & 2 deletions source/module_elecstate/elecstate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void ElecState::calEBand()
return;
}

void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac)
void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac, const void* wfcpw)
{
//---------Charge part-----------------
// core correction potential.
Expand All @@ -225,7 +225,7 @@ void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& struc
//--------------------------------------------------------------------
if (istep == 0)
{
this->charge->init_rho(this->eferm, strucfac, this->bigpw->nbz, this->bigpw->bz);
this->charge->init_rho(this->eferm, strucfac, (const void*)this->klist, wfcpw);
this->charge->check_rho(); // check the rho
}

Expand Down
9 changes: 8 additions & 1 deletion source/module_elecstate/elecstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,14 @@ class ElecState
return;
}

void init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac);
/**
* @brief Init rho_core, init rho, renormalize rho, init pot
*
* @param istep i-th step
* @param strucfac structure factor
* @param wfcpw PW basis for wave function if needed
*/
void init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac, const void* wfcpw = nullptr);
std::string classname = "elecstate";

int iter = 0; ///< scf iteration
Expand Down
21 changes: 11 additions & 10 deletions source/module_elecstate/module_charge/charge.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,17 @@ class Charge

/**
* @brief Init charge density from file or atomic pseudo-wave-functions
*
* @param eferm_iout fermi energy to be initialized
* @param strucFac [in] structure factor
* @param nbz [in] number of big grids in z direction
* @param bz [in] number of small grids in big grids for z dirction
*
* @param eferm_iout [out] fermi energy to be initialized
* @param strucFac [in] structure factor
* @param klist [in] k points list if needed
* @param wfcpw [in] PW basis for wave function if needed
*/
void init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMatrix& strucFac, const int& nbz, const int& bz);

void init_rho(elecstate::efermi& eferm_iout,
const ModuleBase::ComplexMatrix& strucFac,
const void* klist = nullptr,
const void* wfcpw = nullptr);

void allocate(const int &nspin_in);

void atomic_rho(const int spin_number_need,
Expand Down Expand Up @@ -108,10 +111,8 @@ class Charge
/**
* @brief init some arrays for mpi_inter_pools, rho_mpi
*
* @param nbz number of bigz in big grids
* @param bz number of z for each bigz
*/
void init_chgmpi(const int& nbz, const int& bz);
void init_chgmpi();

/**
* @brief Sum rho at different pools (k-point parallelism).
Expand Down
20 changes: 18 additions & 2 deletions source/module_elecstate/module_charge/charge_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
#include "module_hamilt_pw/hamilt_pwdft/parallel_grid.h"
#include "module_io/rho_io.h"
#include "module_io/rhog_io.h"
#include "module_io/read_wfc_to_rho.h"
#ifdef USE_PAW
#include "module_cell/module_paw/paw_cell.h"
#endif

void Charge::init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMatrix& strucFac, const int& nbz, const int& bz)
void Charge::init_rho(elecstate::efermi& eferm_iout,
const ModuleBase::ComplexMatrix& strucFac,
const void* klist,
const void* wfcpw)
{
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "init_chg", PARAM.inp.init_chg);

Expand Down Expand Up @@ -195,8 +199,20 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMa
GlobalC::restart.info_load.load_charge_finish = true;
}
#ifdef __MPI
this->init_chgmpi(nbz, bz);
this->init_chgmpi();
#endif
if (PARAM.inp.init_chg == "wfc")
{
if (wfcpw == nullptr)
{
ModuleBase::WARNING_QUIT("Charge::init_rho", "wfc is only supported for PW-KSDFT.");
}
const ModulePW::PW_Basis_K* pw_wfc = reinterpret_cast<ModulePW::PW_Basis_K*>(const_cast<void*>(wfcpw));
const K_Vectors* kv = reinterpret_cast<const K_Vectors*>(klist);
const int nkstot = kv->get_nkstot();
const std::vector<int>& isk = kv->isk;
ModuleIO::read_wfc_to_rho(pw_wfc, nkstot, isk, *this);
}
}

//==========================================================
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/module_charge/charge_mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "module_elecstate/elecstate_getters.h"
#include "module_parameter/parameter.h"
#ifdef __MPI
void Charge::init_chgmpi(const int& nbz, const int& bz)
void Charge::init_chgmpi()
{
if (GlobalV::NPROC_IN_STOGROUP % GlobalV::KPAR == 0)
{
Expand Down
5 changes: 4 additions & 1 deletion source/module_elecstate/test/elecstate_base_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ void Charge::set_rho_core(ModuleBase::ComplexMatrix const&)
void Charge::set_rho_core_paw()
{
}
void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&, const int&, const int&)
void Charge::init_rho(elecstate::efermi&,
ModuleBase::ComplexMatrix const&,
const void*,
const void*)
{
}
void Charge::set_rhopw(ModulePW::PW_Basis*)
Expand Down
5 changes: 4 additions & 1 deletion source/module_elecstate/test/elecstate_pw_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ void Charge::set_rho_core(ModuleBase::ComplexMatrix const&)
void Charge::set_rho_core_paw()
{
}
void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&, const int&, const int&)
void Charge::init_rho(elecstate::efermi&,
ModuleBase::ComplexMatrix const&,
const void*,
const void*)
{
}
void Charge::set_rhopw(ModulePW::PW_Basis*)
Expand Down
6 changes: 3 additions & 3 deletions source/module_elecstate/test_mpi/charge_mpi_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1)
}
double refsum = sum_array(array_rho, nrxx);

charge->init_chgmpi(nz, 1);
charge->init_chgmpi();
charge->reduce_diff_pools(array_rho);
double sum = sum_array(array_rho, nrxx);
EXPECT_EQ(sum, refsum * GlobalV::KPAR);
Expand Down Expand Up @@ -150,7 +150,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2)
}
}

charge->init_chgmpi(nz, 1);
charge->init_chgmpi();
charge->reduce_diff_pools(array_rho);
double sum = sum_array(array_rho, nrxx);
MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
Expand Down Expand Up @@ -194,7 +194,7 @@ TEST_F(ChargeMpiTest, rho_mpi)
charge->nrxx = nrxx;
charge->rho[0] = new double[nrxx];
charge->kin_r[0] = new double[nrxx];
charge->init_chgmpi(nz, 1);
charge->init_chgmpi();
charge->rho_mpi();

delete[] charge->rho[0];
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ void ESolver_KS_PW<T, Device>::before_scf(const int istep)
}

//! calculate the total local pseudopotential in real space
this->pelec->init_scf(istep, this->sf.strucFac);
this->pelec->init_scf(istep, this->sf.strucFac, (void*)this->pw_wfc);

//! output the initial charge density
if (PARAM.inp.out_chg[0] == 2)
Expand Down
11 changes: 7 additions & 4 deletions source/module_hamilt_pw/hamilt_pwdft/elecond.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ void EleCond::jjresponse_ks(const int ik, const int nt, const double dt, const d
hamilt::Velocity& velop, double* ct11, double* ct12, double* ct22)
{
const int nbands = GlobalV::NBANDS;
if (wg(ik, 0) - wg(ik, nbands - 1) < 1e-8 || nbands == 0)
if (wg(ik, 0) - wg(ik, nbands - 1) < 1e-8 || nbands == 0) {
return;
}
const char transn = 'N';
const char transc = 'C';
const int ndim = 3;
Expand All @@ -121,20 +122,21 @@ void EleCond::jjresponse_ks(const int ik, const int nt, const double dt, const d
#ifdef __MPI
MPI_Allreduce(MPI_IN_PLACE, pij.data(), nbands * nbands, MPI_DOUBLE_COMPLEX, MPI_SUM, POOL_WORLD);
#endif
if (!gamma_only)
if (!gamma_only) {
for (int ib = 0, ijb = 0; ib < nbands; ++ib)
{
for (int jb = ib + 1; jb < nbands; ++jb, ++ijb)
{
pij2[ijb] += norm(pij[ib * nbands + jb]);
}
}
}
}

if (GlobalV::RANK_IN_POOL == 0)
{
int nkstot = this->p_kv->get_nkstot();
int ikglobal = this->p_kv->getik_global(ik);
int ikglobal = K_Vectors::getik_global(ik, nkstot);
std::stringstream ss;
ss << GlobalV::global_out_dir << "vmatrix" << ikglobal + 1 << ".dat";
Binstream binpij(ss.str(), "w");
Expand Down Expand Up @@ -168,8 +170,9 @@ void EleCond::jjresponse_ks(const int ik, const int nt, const double dt, const d
for (int jb = ib + 1; jb < nbands; ++jb, ++ijb)
{
double ej = enb[jb];
if (ej - ei > decut)
if (ej - ei > decut) {
continue;
}
double fj = wg(ik, jb);
double tmct = sin((ej - ei) * (it)*dt) * (fi - fj) * pij2[ijb];
tmct11 += tmct;
Expand Down
Loading
Loading