Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Support outputting partial charge densities for different k-points and spins separately when using PW basis set #4829

Merged
merged 17 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -520,16 +520,16 @@ OBJS_IO=input_conv.o\
read_input_item_exx_dftu.o\
read_input_item_other.o\
read_input_item_output.o\
bcast_globalv.o
bcast_globalv.o\

OBJS_IO_LCAO=cal_r_overlap_R.o\
write_orb_info.o\
write_dos_lcao.o\
write_proj_band_lcao.o\
write_istate_info.o\
nscf_fermi_surf.o\
get_pchg.o\
get_wf.o\
get_pchg_lcao.o\
get_wf_lcao.o\
io_dmk.o\
unk_overlap_lcao.o\
read_wfc_nao.o\
Expand Down
145 changes: 25 additions & 120 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "module_hsolver/kernels/math_kernel_op.h"
#include "module_io/berryphase.h"
#include "module_io/cube_io.h"
#include "module_io/get_pchg_pw.h"
#include "module_io/numerical_basis.h"
#include "module_io/numerical_descriptor.h"
#include "module_io/to_wannier90_pw.h"
Expand Down Expand Up @@ -99,7 +100,8 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
}

template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCell& ucell) {
void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCell& ucell)
{
// 1) call before_all_runners() of ESolver_KS
ESolver_KS<T, Device>::before_all_runners(inp, ucell);

Expand Down Expand Up @@ -166,7 +168,6 @@ void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCel
}
}


template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::before_scf(const int istep)
{
Expand Down Expand Up @@ -658,127 +659,31 @@ void ESolver_KS_PW<T, Device>::after_scf(const int istep)
this->psi[0].size());
}

// Get bands_to_print through public function of INPUT (returns a const
// pointer to string)
// Calculate band-decomposed (partial) charge density
const std::vector<int> bands_to_print = PARAM.inp.bands_to_print;
if (bands_to_print.size() > 0)
{
// bands_picked is a vector of 0s and 1s, where 1 means the band is
// picked to output
std::vector<int> bands_picked;
bands_picked.resize(this->kspw_psi->get_nbands());
ModuleBase::GlobalFunc::ZEROS(bands_picked.data(), this->kspw_psi->get_nbands());

// Check if length of bands_to_print is valid
if (static_cast<int>(bands_to_print.size()) > this->kspw_psi->get_nbands())
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW::after_scf",
"The number of bands specified by `bands_to_print` in the "
"INPUT file exceeds `nbands`!");
}

// Check if all elements in bands_picked are 0 or 1
for (int value: bands_to_print)
{
if (value != 0 && value != 1)
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW::after_scf",
"The elements of `bands_to_print` must be either 0 or 1. "
"Invalid values found!");
}
}

// Fill bands_picked with values from bands_to_print
// Remaining bands are already set to 0
int length = std::min(static_cast<int>(bands_to_print.size()), this->kspw_psi->get_nbands());
for (int i = 0; i < length; ++i)
{
// bands_to_print rely on function parse_expression
// Initially designed for ocp_set, which can be double
bands_picked[i] = static_cast<int>(bands_to_print[i]);
}

std::complex<double>* wfcr = new std::complex<double>[this->pw_rho->nxyz];
double* rho_band = new double[this->pw_rho->nxyz];

for (int ib = 0; ib < this->kspw_psi->get_nbands(); ++ib)
{
// Skip the loop iteration if bands_picked[ib] is 0
if (!bands_picked[ib])
{
continue;
}

for (int i = 0; i < this->pw_rho->nxyz; i++)
{
// Initialize rho_band to zero for each band
rho_band[i] = 0.0;
}

for (int ik = 0; ik < this->kv.get_nks(); ik++)
{
this->psi->fix_k(ik);
this->pw_wfc->recip_to_real(this->ctx, &psi[0](ib, 0), wfcr, ik);

double w1 = static_cast<double>(this->kv.wk[ik] / GlobalC::ucell.omega);

for (int i = 0; i < this->pw_rho->nxyz; i++)
{
rho_band[i] += std::norm(wfcr[i]) * w1;
}
}

// Symmetrize the charge density, otherwise the results are incorrect if the symmetry is on
std::cout << " Symmetrizing band-decomposed charge density..." << std::endl;
Symmetry_rho srho;
for (int is = 0; is < GlobalV::NSPIN; is++)
{
// Use vector instead of raw pointers
std::vector<double*> rho_save_pointers(GlobalV::NSPIN, rho_band);
std::vector<std::vector<std::complex<double>>> rhog(
GlobalV::NSPIN,
std::vector<std::complex<double>>(this->pelec->charge->ngmc));

// Convert vector of vectors to vector of pointers
std::vector<std::complex<double>*> rhog_pointers(GlobalV::NSPIN);
for (int s = 0; s < GlobalV::NSPIN; s++)
{
rhog_pointers[s] = rhog[s].data();
}

srho.begin(is,
rho_save_pointers.data(),
rhog_pointers.data(),
this->pelec->charge->ngmc,
nullptr,
this->pw_rhod,
GlobalC::Pgrid,
GlobalC::ucell.symm);
}

std::stringstream ssc;
ssc << GlobalV::global_out_dir << "BAND" << ib + 1 << "_CHG.cube"; // band index starts from 1

ModuleIO::write_cube(
#ifdef __MPI
this->pw_big->bz,
this->pw_big->nbz,
this->pw_rhod->nplane,
this->pw_rhod->startz_current,
#endif
rho_band,
0,
GlobalV::NSPIN,
0,
ssc.str(),
this->pw_rhod->nx,
this->pw_rhod->ny,
this->pw_rhod->nz,
0.0,
&(GlobalC::ucell));
}
delete[] wfcr;
delete[] rho_band;
ModuleIO::get_pchg_pw(bands_to_print,
this->kspw_psi->get_nbands(),
GlobalV::NSPIN,
this->pw_rhod->nx,
this->pw_rhod->ny,
this->pw_rhod->nz,
this->pw_rhod->nxyz,
this->kv.get_nks(),
this->kv.isk,
this->kv.wk,
this->pw_big->bz,
this->pw_big->nbz,
this->pelec->charge->ngmc,
&GlobalC::ucell,
this->psi,
this->pw_rhod,
this->pw_wfc,
this->ctx,
GlobalC::Pgrid,
GlobalV::global_out_dir,
PARAM.inp.if_separate_k);
}
}

Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/lcao_before_scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_io/berryphase.h"
#include "module_io/get_pchg.h"
#include "module_io/get_wf.h"
#include "module_io/get_pchg_lcao.h"
#include "module_io/get_wf_lcao.h"
#include "module_io/to_wannier90_lcao.h"
#include "module_io/to_wannier90_lcao_in_pw.h"
#include "module_io/write_HS_R.h"
Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/lcao_nscf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_io/berryphase.h"
#include "module_io/cube_io.h"
#include "module_io/get_pchg.h"
#include "module_io/get_wf.h"
#include "module_io/get_pchg_lcao.h"
#include "module_io/get_wf_lcao.h"
#include "module_io/to_wannier90_lcao.h"
#include "module_io/to_wannier90_lcao_in_pw.h"
#include "module_io/write_HS_R.h"
Expand Down
20 changes: 11 additions & 9 deletions source/module_esolver/lcao_others.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_io/berryphase.h"
#include "module_io/get_pchg.h"
#include "module_io/get_wf.h"
#include "module_io/get_pchg_lcao.h"
#include "module_io/get_wf_lcao.h"
#include "module_io/to_wannier90_lcao.h"
#include "module_io/to_wannier90_lcao_in_pw.h"
#include "module_io/write_HS_R.h"
Expand Down Expand Up @@ -45,7 +45,8 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)

const std::string cal_type = GlobalV::CALCULATION;

if (cal_type == "get_S") {
if (cal_type == "get_S")
{
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "writing the overlap matrix");
this->get_S();
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "writing the overlap matrix");
Expand All @@ -54,7 +55,9 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)

// return; // use 'return' will cause segmentation fault. by mohan
// 2024-06-09
} else if (cal_type == "test_memory") {
}
else if (cal_type == "test_memory")
{
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "testing memory");
Cal_Test::test_memory(this->pw_rho,
this->pw_wfc,
Expand All @@ -67,9 +70,9 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
{
// test_search_neighbor();
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "testing neighbour");
if (GlobalV::SEARCH_RADIUS < 0) {
std::cout << " SEARCH_RADIUS : " << GlobalV::SEARCH_RADIUS
<< std::endl;
if (GlobalV::SEARCH_RADIUS < 0)
{
std::cout << " SEARCH_RADIUS : " << GlobalV::SEARCH_RADIUS << std::endl;
std::cout << " please make sure search_radius > 0" << std::endl;
}

Expand Down Expand Up @@ -205,8 +208,7 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
}
else
{
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO<TK, TR>::others",
"CALCULATION type not supported");
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO<TK, TR>::others", "CALCULATION type not supported");
}

ModuleBase::timer::tick("ESolver_KS_LCAO", "others");
Expand Down
4 changes: 2 additions & 2 deletions source/module_io/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ if(ENABLE_LCAO)
write_orb_info.cpp
write_proj_band_lcao.cpp
nscf_fermi_surf.cpp
get_pchg.cpp
get_wf.cpp
get_pchg_lcao.cpp
get_wf_lcao.cpp
read_wfc_nao.cpp
read_wfc_lcao.cpp
write_wfc_nao.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "get_pchg.h"
#include "get_pchg_lcao.h"

#include "module_base/blas_connector.h"
#include "module_base/global_function.h"
Expand Down Expand Up @@ -120,7 +120,7 @@ void IState_Charge::begin(Gint_Gamma& gg,
ModuleBase::GlobalFunc::DCOPY(rho[is], rho_save[is].data(), rhopw_nrxx); // Copy data
}

std::cout << " Writting cube files...";
std::cout << " Writing cube files...";

for (int is = 0; is < nspin; ++is)
{
Expand Down Expand Up @@ -257,7 +257,7 @@ void IState_Charge::begin(Gint_k& gk,
ModuleBase::GlobalFunc::DCOPY(rho[is], rho_save[is].data(), rhopw_nrxx); // Copy data
}

std::cout << " Writting cube files...";
std::cout << " Writing cube files...";

for (int is = 0; is < nspin; ++is)
{
Expand Down Expand Up @@ -332,7 +332,7 @@ void IState_Charge::begin(Gint_k& gk,
ucell_in->symm);
}

std::cout << " Writting cube files...";
std::cout << " Writing cube files...";

for (int is = 0; is < nspin; ++is)
{
Expand Down
File renamed without changes.
Loading
Loading