Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature & Refactor: read and write Hexx(R) in CSR format #3727

Merged
merged 9 commits into from
Apr 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
recover multiple process
  • Loading branch information
maki49 committed Mar 28, 2024
commit f5aa987a5c4c1739a026fc33de6a0af1693b0ba1
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
@@ -976,7 +976,7 @@ void ESolver_KS_LCAO<TK, TR>::afterscf(const int istep)
#ifdef __EXX
if (GlobalC::exx_info.info_global.cal_exx) // Peize Lin add if 2022.11.14
{
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR";
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
{
this->exd->write_Hexxs(file_name_exx, GlobalC::ucell);
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
@@ -524,7 +524,7 @@ void ESolver_KS_LCAO<TK, TR>::nscf()
if (GlobalC::exx_info.info_global.cal_exx)
{
// GlobalC::exx_lcao.cal_exx_elec_nscf(this->LOWF.ParaV[0]);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR";
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
this->exd->read_Hexxs(file_name_exx, GlobalC::ucell);
else
28 changes: 13 additions & 15 deletions source/module_io/single_R_io.cpp
Original file line number Diff line number Diff line change
@@ -8,19 +8,19 @@ void ModuleIO::output_single_R(std::ofstream& ofs,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& already_global)
const bool& reduce)
{
double *line = nullptr;
std::vector<int> indptr;
indptr.reserve(GlobalV::NLOCAL + 1);
indptr.push_back(0);

std::stringstream tem1;
tem1 << GlobalV::global_out_dir << "temp_sparse_indices.dat";
tem1 << GlobalV::global_out_dir << std::to_string(GlobalV::DRANK) + "temp_sparse_indices.dat";
std::ofstream ofs_tem1;
std::ifstream ifs_tem1;

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
@@ -38,7 +38,7 @@ void ModuleIO::output_single_R(std::ofstream& ofs,
// line = new double[GlobalV::NLOCAL];
ModuleBase::GlobalFunc::ZEROS(line, GlobalV::NLOCAL);

if (already_global || pv.global2local_row(row) >= 0)
if (!reduce || pv.global2local_row(row) >= 0)
{
auto iter = XR.find(row);
if (iter != XR.end())
@@ -50,7 +50,7 @@ void ModuleIO::output_single_R(std::ofstream& ofs,
}
}

if (!already_global)Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);
if (reduce)Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);

if(GlobalV::DRANK == 0)
{
@@ -87,7 +87,7 @@ void ModuleIO::output_single_R(std::ofstream& ofs,
delete[] line;
line = nullptr;

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
@@ -126,21 +126,19 @@ void ModuleIO::output_single_R(std::ofstream& ofs,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& already_global)
const bool& reduce)
{
if (already_global && GlobalV::DRANK != 0) return;

std::complex<double>* line = nullptr;
std::vector<int> indptr;
indptr.reserve(GlobalV::NLOCAL + 1);
indptr.push_back(0);

std::stringstream tem1;
tem1 << GlobalV::global_out_dir << "temp_sparse_indices.dat";
tem1 << GlobalV::global_out_dir << std::to_string(GlobalV::DRANK) + "temp_sparse_indices.dat";
std::ofstream ofs_tem1;
std::ifstream ifs_tem1;

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
@@ -158,7 +156,7 @@ void ModuleIO::output_single_R(std::ofstream& ofs,
// line = new std::complex<double>[GlobalV::NLOCAL];
ModuleBase::GlobalFunc::ZEROS(line, GlobalV::NLOCAL);

if (already_global || pv.global2local_row(row) >= 0)
if (!reduce || pv.global2local_row(row) >= 0)
{
auto iter = XR.find(row);
if (iter != XR.end())
@@ -170,9 +168,9 @@ void ModuleIO::output_single_R(std::ofstream& ofs,
}
}

if (!already_global)Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);
if (reduce)Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
int nonzeros_count = 0;
for (int col = 0; col < GlobalV::NLOCAL; ++col)
@@ -208,7 +206,7 @@ void ModuleIO::output_single_R(std::ofstream& ofs,
delete[] line;
line = nullptr;

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
4 changes: 2 additions & 2 deletions source/module_io/single_R_io.h
Original file line number Diff line number Diff line change
@@ -10,13 +10,13 @@ namespace ModuleIO
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& already_global = false);
const bool& reduce = true);
void output_single_R(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, std::complex<double>>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& already_global = false);
const bool& reduce = true);
}

#endif
14 changes: 6 additions & 8 deletions source/module_io/write_HS_sparse.cpp
Original file line number Diff line number Diff line change
@@ -684,13 +684,11 @@ void ModuleIO::save_sparse(
const Parallel_Orbitals& pv,
const std::string& label,
const int& istep,
const bool& already_global)
const bool& reduce)
{
ModuleBase::TITLE("ModuleIO", "save_sparse");
ModuleBase::timer::tick("ModuleIO", "save_sparse");

if (already_global && GlobalV::DRANK != 0) return;

int total_R_num = all_R_coor.size();
std::vector<int> nonzero_num(total_R_num, 0);
int count = 0;
@@ -702,7 +700,7 @@ void ModuleIO::save_sparse(
nonzero_num[count] += row_loop.second.size();
++count;
}
if (!already_global)Parallel_Reduce::reduce_all(nonzero_num.data(), total_R_num);
if (reduce)Parallel_Reduce::reduce_all(nonzero_num.data(), total_R_num);

int output_R_number = 0;
for (int index = 0; index < total_R_num; ++index)
@@ -711,7 +709,7 @@ void ModuleIO::save_sparse(
std::stringstream sss;
sss << filename;
std::ofstream ofs;
if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
@@ -748,7 +746,7 @@ void ModuleIO::save_sparse(
continue;
}

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
@@ -763,10 +761,10 @@ void ModuleIO::save_sparse(
}
}

output_single_R(ofs, smat.at(R_coor), sparse_threshold, binary, pv, already_global);
output_single_R(ofs, smat.at(R_coor), sparse_threshold, binary, pv, reduce);
++count;
}
if (GlobalV::DRANK == 0) ofs.close();
if (!reduce || GlobalV::DRANK == 0) ofs.close();

ModuleBase::timer::tick("ModuleIO", "save_sparse");
}
2 changes: 1 addition & 1 deletion source/module_io/write_HS_sparse.h
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ namespace ModuleIO
const Parallel_Orbitals& pv,
const std::string& label,
const int& istep = -1,
const bool& already_global = false);
const bool& reduce = true);
}

#endif
21 changes: 10 additions & 11 deletions source/module_ri/Exx_LRI_interface.hpp
Original file line number Diff line number Diff line change
@@ -52,17 +52,16 @@ void Exx_LRI_Interface<T, Tdata>::write_Hexxs(const std::string& file_name, cons
all_R_coor.insert(R);
}
}
if (GlobalV::DRANK == 0)
ModuleIO::save_sparse(
this->calculate_RI_Tensor_sparse(sparse_threshold, this->exx_ptr->Hexxs[is], ucell),
all_R_coor,
sparse_threshold,
false, //binary
file_name + "_" + std::to_string(is) + ".csr",
Parallel_Orbitals(),
"Hexxs_" + std::to_string(is),
-1,
true); //already global
ModuleIO::save_sparse(
this->calculate_RI_Tensor_sparse(sparse_threshold, this->exx_ptr->Hexxs[is], ucell),
all_R_coor,
sparse_threshold,
false, //binary
file_name + "_" + std::to_string(is) + ".csr",
Parallel_Orbitals(),
"Hexxs_" + std::to_string(is),
-1,
false); //no reduce, one file for each process
}
ModuleBase::timer::tick("Exx_LRI", "write_Hexxs");
}