Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature & Refactor: read and write Hexx(R) in CSR format #3727

Merged
merged 9 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,14 +1024,14 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
#ifdef __EXX
if (GlobalC::exx_info.info_global.cal_exx) // Peize Lin add if 2022.11.14
{
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR_" + std::to_string(GlobalV::MY_RANK);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
maki49 marked this conversation as resolved.
Show resolved Hide resolved
if (GlobalC::exx_info.info_ri.real_number)
{
this->exd->write_Hexxs(file_name_exx);
this->exd->write_Hexxs_csr(file_name_exx, GlobalC::ucell);
}
else
{
this->exc->write_Hexxs(file_name_exx);
this->exc->write_Hexxs_csr(file_name_exx, GlobalC::ucell);
}
}
#endif
Expand Down
6 changes: 3 additions & 3 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,11 @@ void ESolver_KS_LCAO<TK, TR>::nscf()
if (GlobalC::exx_info.info_global.cal_exx)
{
// GlobalC::exx_lcao.cal_exx_elec_nscf(this->LOWF.ParaV[0]);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR_" + std::to_string(GlobalV::MY_RANK);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
this->exd->read_Hexxs(file_name_exx);
this->exd->read_Hexxs_csr(file_name_exx, GlobalC::ucell);
else
this->exc->read_Hexxs(file_name_exx);
this->exc->read_Hexxs_csr(file_name_exx, GlobalC::ucell);

hamilt::HamiltLCAO<TK, TR>* hamilt_lcao = dynamic_cast<hamilt::HamiltLCAO<TK, TR>*>(this->p_hamilt);
auto exx = new hamilt::OperatorEXX<hamilt::OperatorLCAO<TK, TR>>(&this->LM,
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/csr_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,6 @@ int csrFileReader<T>::getStep() const
// T of AtomPair can be double
template class csrFileReader<double>;
// ToDo: T of AtomPair can be std::complex<double>
// template class csrFileReader<std::complex<double>>;
template class csrFileReader<std::complex<double>>;

} // namespace ModuleIO
161 changes: 37 additions & 124 deletions source/module_io/single_R_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,132 +3,35 @@
#include "module_base/global_function.h"
#include "module_base/global_variable.h"

void ModuleIO::output_single_R(std::ofstream &ofs, const std::map<size_t, std::map<size_t, double>> &XR, const double &sparse_threshold, const bool &binary, const Parallel_Orbitals &pv)
inline void write_data(std::ofstream& ofs, const double& data)
{
double *line = nullptr;
std::vector<int> indptr;
indptr.reserve(GlobalV::NLOCAL + 1);
indptr.push_back(0);

std::stringstream tem1;
tem1 << GlobalV::global_out_dir << "temp_sparse_indices.dat";
std::ofstream ofs_tem1;
std::ifstream ifs_tem1;

if (GlobalV::DRANK == 0)
{
if (binary)
{
ofs_tem1.open(tem1.str().c_str(), std::ios::binary);
}
else
{
ofs_tem1.open(tem1.str().c_str());
}
}

line = new double[GlobalV::NLOCAL];
for(int row = 0; row < GlobalV::NLOCAL; ++row)
{
// line = new double[GlobalV::NLOCAL];
ModuleBase::GlobalFunc::ZEROS(line, GlobalV::NLOCAL);

if(pv.global2local_row(row) >= 0)
{
auto iter = XR.find(row);
if (iter != XR.end())
{
for (auto &value : iter->second)
{
line[value.first] = value.second;
}
}
}

Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);

if(GlobalV::DRANK == 0)
{
int nonzeros_count = 0;
for (int col = 0; col < GlobalV::NLOCAL; ++col)
{
if (std::abs(line[col]) > sparse_threshold)
{
if (binary)
{
ofs.write(reinterpret_cast<char *>(&line[col]), sizeof(double));
ofs_tem1.write(reinterpret_cast<char *>(&col), sizeof(int));
}
else
{
ofs << " " << std::fixed << std::scientific << std::setprecision(8) << line[col];
ofs_tem1 << " " << col;
}

nonzeros_count++;

}

}
nonzeros_count += indptr.back();
indptr.push_back(nonzeros_count);
}

// delete[] line;
// line = nullptr;

}

delete[] line;
line = nullptr;

if (GlobalV::DRANK == 0)
{
if (binary)
{
ofs_tem1.close();
ifs_tem1.open(tem1.str().c_str(), std::ios::binary);
ofs << ifs_tem1.rdbuf();
ifs_tem1.close();
for (auto &i : indptr)
{
ofs.write(reinterpret_cast<char *>(&i), sizeof(int));
}
}
else
{
ofs << std::endl;
ofs_tem1 << std::endl;
ofs_tem1.close();
ifs_tem1.open(tem1.str().c_str());
ofs << ifs_tem1.rdbuf();
ifs_tem1.close();
for (auto &i : indptr)
{
ofs << " " << i;
}
ofs << std::endl;
}

std::remove(tem1.str().c_str());

}

ofs << " " << std::fixed << std::scientific << std::setprecision(8) << data;
}
inline void write_data(std::ofstream& ofs, const std::complex<double>& data)
{
ofs << " (" << std::fixed << std::scientific << std::setprecision(8) << data.real() << ","
<< std::fixed << std::scientific << std::setprecision(8) << data.imag() << ")";
}

void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, std::map<size_t, std::complex<double>>> &XR, const double &sparse_threshold, const bool &binary, const Parallel_Orbitals &pv)
template<typename T>
void ModuleIO::output_single_R(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, T>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& reduce)
{
std::complex<double> *line = nullptr;
T* line = nullptr;
std::vector<int> indptr;
indptr.reserve(GlobalV::NLOCAL + 1);
indptr.push_back(0);

std::stringstream tem1;
tem1 << GlobalV::global_out_dir << "temp_sparse_indices.dat";
tem1 << GlobalV::global_out_dir << std::to_string(GlobalV::DRANK) + "temp_sparse_indices.dat";
std::ofstream ofs_tem1;
std::ifstream ifs_tem1;

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
Expand All @@ -140,13 +43,12 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st
}
}

line = new std::complex<double>[GlobalV::NLOCAL];
line = new T[GlobalV::NLOCAL];
for(int row = 0; row < GlobalV::NLOCAL; ++row)
{
// line = new std::complex<double>[GlobalV::NLOCAL];
ModuleBase::GlobalFunc::ZEROS(line, GlobalV::NLOCAL);

if(pv.global2local_row(row) >= 0)
if (!reduce || pv.global2local_row(row) >= 0)
{
auto iter = XR.find(row);
if (iter != XR.end())
Expand All @@ -158,9 +60,9 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st
}
}

Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);
if (reduce)Parallel_Reduce::reduce_all(line, GlobalV::NLOCAL);

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
int nonzeros_count = 0;
for (int col = 0; col < GlobalV::NLOCAL; ++col)
Expand All @@ -169,13 +71,12 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st
{
if (binary)
{
ofs.write(reinterpret_cast<char *>(&line[col]), sizeof(std::complex<double>));
ofs.write(reinterpret_cast<char*>(&line[col]), sizeof(T));
ofs_tem1.write(reinterpret_cast<char *>(&col), sizeof(int));
}
else
{
ofs << " (" << std::fixed << std::scientific << std::setprecision(8) << line[col].real() << ","
<< std::fixed << std::scientific << std::setprecision(8) << line[col].imag() << ")";
write_data(ofs, line[col]);
ofs_tem1 << " " << col;
}

Expand All @@ -196,7 +97,7 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st
delete[] line;
line = nullptr;

if (GlobalV::DRANK == 0)
if (!reduce || GlobalV::DRANK == 0)
{
if (binary)
{
Expand Down Expand Up @@ -226,5 +127,17 @@ void ModuleIO::output_soc_single_R(std::ofstream &ofs, const std::map<size_t, st

std::remove(tem1.str().c_str());
}

}

template void ModuleIO::output_single_R<double>(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, double>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& reduce);
template void ModuleIO::output_single_R<std::complex<double>>(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, std::complex<double>>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& reduce);
9 changes: 7 additions & 2 deletions source/module_io/single_R_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

namespace ModuleIO
{
void output_single_R(std::ofstream &ofs, const std::map<size_t, std::map<size_t, double>> &XR, const double &sparse_threshold, const bool &binary, const Parallel_Orbitals &pv);
void output_soc_single_R(std::ofstream &ofs, const std::map<size_t, std::map<size_t, std::complex<double>>> &XR, const double &sparse_threshold, const bool &binary, const Parallel_Orbitals &pv);
template <typename T>
void output_single_R(std::ofstream& ofs,
const std::map<size_t, std::map<size_t, T>>& XR,
const double& sparse_threshold,
const bool& binary,
const Parallel_Orbitals& pv,
const bool& reduce = true);
}

#endif
2 changes: 1 addition & 1 deletion source/module_io/sparse_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void SparseMatrix<T>::readCSR(const std::vector<T>& values,

// define the operator to index a matrix element
template <typename T>
T SparseMatrix<T>::operator()(int row, int col)
T SparseMatrix<T>::operator()(int row, int col) const
{
if (row < 0 || row >= _rows || col < 0 || col >= _cols)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/sparse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class SparseMatrix
}

// define the operator to index a matrix element
T operator()(int row, int col);
T operator()(int row, int col)const;

// set the threshold
void setSparseThreshold(double sparse_threshold)
Expand Down
4 changes: 2 additions & 2 deletions source/module_io/write_HS_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void ModuleIO::output_S_R(
ModuleBase::timer::tick("ModuleIO","output_S_R");

UHM.cal_SR_sparse(sparse_threshold, p_ham);
ModuleIO::save_SR_sparse(*UHM.LM, sparse_threshold, binary, SR_filename);
ModuleIO::save_sparse(UHM.LM->SR_sparse, UHM.LM->all_R_coor, sparse_threshold, binary, SR_filename, *UHM.LM->ParaV, "S", 0);
UHM.destroy_all_HSR_sparse();

ModuleBase::timer::tick("ModuleIO","output_S_R");
Expand Down Expand Up @@ -149,7 +149,7 @@ void ModuleIO::output_T_R(
}

UHM.cal_TR_sparse(sparse_threshold);
ModuleIO::save_TR_sparse(istep, *UHM.LM, sparse_threshold, binary, sst.str().c_str());
ModuleIO::save_sparse(UHM.LM->TR_sparse, UHM.LM->all_R_coor, sparse_threshold, binary, sst.str().c_str(), *UHM.LM->ParaV, "T", istep);
UHM.destroy_TR_sparse();

ModuleBase::timer::tick("ModuleIO","output_T_R");
Expand Down
Loading