Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: replace most of input varibles by parameter #4693

Merged
merged 14 commits into from
Jul 14, 2024
Merged
2 changes: 1 addition & 1 deletion docs/advanced/json/json_add.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ namespace Json
{

#ifdef __RAPIDJSON
void gen_general_info(Input *input)
void gen_general_info(const Parameter& param)
{

#ifdef VERSION
Expand Down
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ OBJS_IO=input.o\
read_input_item_postprocess.o\
read_input_item_md.o\
read_input_item_other.o\
bcast_globalv.o

OBJS_IO_LCAO=cal_r_overlap_R.o\
write_orb_info.o\
Expand Down
9 changes: 4 additions & 5 deletions source/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void Driver::init()

// (5) output the json file
// Json::create_Json(&GlobalC::ucell.symm,GlobalC::ucell.atoms,&INPUT);
Json::create_Json(&GlobalC::ucell, &INPUT);
Json::create_Json(&GlobalC::ucell, PARAM);
}

void Driver::print_start_info()
Expand All @@ -68,7 +68,6 @@ void Driver::print_start_info()
#endif
time_t time_now = time(nullptr);

INPUT.start_time = time_now;
PARAM.set_start_time(time_now);
GlobalV::ofs_running << " "
" "
Expand Down Expand Up @@ -116,11 +115,11 @@ void Driver::reading()
{
ModuleBase::timer::tick("Driver", "reading");
// temperarily
GlobalV::MY_RANK = PARAM.sys.myrank;
GlobalV::NPROC = PARAM.sys.nproc;
GlobalV::MY_RANK = PARAM.globalv.myrank;
GlobalV::NPROC = PARAM.globalv.nproc;

// (1) read the input file
ModuleIO::ReadInput read_input(PARAM.sys.myrank);
ModuleIO::ReadInput read_input(PARAM.globalv.myrank);
read_input.read_parameters(PARAM, GlobalV::global_in_card);

// (2) create the output directory, running_*.log and print info
Expand Down
5 changes: 3 additions & 2 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/input.h"
#include "module_parameter/parameter.h"
#include "module_io/para_json.h"
#include "module_io/print_info.h"
#include "module_io/winput.h"
Expand Down Expand Up @@ -45,10 +46,10 @@ void Driver::driver_run() {
GlobalV::MIN_DIST_COEF);

//! 2: initialize the ESolver (depends on a set-up ucell after `setup_cell`)
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(INPUT, PARAM.inp, GlobalC::ucell);
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, GlobalC::ucell);

//! 3: initialize Esolver and fill json-structure
p_esolver->before_all_runners(INPUT, GlobalC::ucell);
p_esolver->before_all_runners(PARAM.inp, GlobalC::ucell);

// this Json part should be moved to before_all_runners, mohan 2024-05-12
#ifdef __RAPIDJSON
Expand Down
2 changes: 1 addition & 1 deletion source/module_cell/read_atoms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ int UnitCell::read_atom_species(std::ifstream &ifa, std::ofstream &ofs_running)
// Peize Lin add 2016-09-23
#ifdef __MPI
#ifdef __EXX
if( GlobalC::exx_info.info_global.cal_exx || INPUT.rpa )
if( GlobalC::exx_info.info_global.cal_exx || PARAM.inp.rpa )
{
if( ModuleBase::GlobalFunc::SCAN_BEGIN(ifa, "ABFS_ORBITAL") )
{
Expand Down
4 changes: 2 additions & 2 deletions source/module_elecstate/elecstate_getters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include "module_cell/unitcell.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/input.h"
#include "module_parameter/parameter.h"
#include "module_hamilt_general/module_xc/xc_functional.h"

namespace elecstate
Expand All @@ -25,7 +25,7 @@ int get_xc_func_type()

std::string get_input_vdw_method()
{
return INPUT.vdw_method;
return PARAM.inp.vdw_method;
}

double get_ucell_tot_magnetization()
Expand Down
14 changes: 7 additions & 7 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ std::string determine_type()


//Some API to operate E_Solver
ESolver* init_esolver(Input& input, const Input_para& input_para, UnitCell& ucell)
ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
{
//determine type of esolver based on INPUT information
const std::string esolver_type = determine_type();
Expand Down Expand Up @@ -187,9 +187,9 @@ ESolver* init_esolver(Input& input, const Input_para& input_para, UnitCell& ucel
{
// use constructor rather than Init function to initialize reference (instead of pointers) to ucell
if (GlobalV::GAMMA_ONLY_LOCAL)
return new LR::ESolver_LR<double, double>(input_para, input, ucell);
return new LR::ESolver_LR<double, double>(inp, ucell);
else if (GlobalV::NSPIN < 2)
return new LR::ESolver_LR<std::complex<double>, double>(input_para, input, ucell);
return new LR::ESolver_LR<std::complex<double>, double>(inp, ucell);
else
throw std::runtime_error("LR-TDDFT is not implemented for spin polarized case");
}
Expand All @@ -209,7 +209,7 @@ ESolver* init_esolver(Input& input, const Input_para& input_para, UnitCell& ucel
{
p_esolver = new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
}
p_esolver->before_all_runners(input, ucell);
p_esolver->before_all_runners(inp, ucell);
p_esolver->runner(0, ucell); // scf-only
// force and stress is not needed currently,
// they will be supported after the analytical gradient
Expand All @@ -221,12 +221,12 @@ ESolver* init_esolver(Input& input, const Input_para& input_para, UnitCell& ucel
if (GlobalV::GAMMA_ONLY_LOCAL)
p_esolver_lr = new LR::ESolver_LR<double, double>(
std::move(*dynamic_cast<ModuleESolver::ESolver_KS_LCAO<double, double>*>(p_esolver)),
input_para,
inp,
ucell);
else
p_esolver_lr = new LR::ESolver_LR<std::complex<double>, double>(
std::move(*dynamic_cast<ModuleESolver::ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver)),
input_para,
inp,
ucell);
// clean the 1st ESolver_KS and swap the pointer
ModuleESolver::clean_esolver(p_esolver, false); // do not call Cblacs_exit, remain it for the 2nd ESolver
Expand All @@ -247,7 +247,7 @@ ESolver* init_esolver(Input& input, const Input_para& input_para, UnitCell& ucel
}
else if (esolver_type == "dp_pot")
{
return new ESolver_DP(INPUT.mdp.pot_file);
return new ESolver_DP(PARAM.mdp.pot_file);
}
throw std::invalid_argument("esolver_type = "+std::string(esolver_type)+". Wrong in "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
}
Expand Down
6 changes: 3 additions & 3 deletions source/module_esolver/esolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "module_base/matrix.h"
#include "module_cell/unitcell.h"
#include "module_io/input.h"
#include "module_parameter/input_parameter.h"
#include "module_parameter/parameter.h"

namespace ModuleESolver
{
Expand All @@ -21,7 +21,7 @@ class ESolver
}

//! initialize the energy solver by using input parameters and cell modules
virtual void before_all_runners(Input& inp, UnitCell& cell) = 0;
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) = 0;

//! run energy solver
virtual void runner(const int istep, UnitCell& cell) = 0;
Expand Down Expand Up @@ -85,7 +85,7 @@ std::string determine_type();
*
* @return [out] A pointer to an ESolver object that will be initialized.
*/
ESolver* init_esolver(Input& input, const Input_para& input_para, UnitCell& ucell);
ESolver* init_esolver(const Input_para& inp, UnitCell& ucell);

void clean_esolver(ESolver*& pesolver, const bool lcao_cblacs_exit = false);

Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
namespace ModuleESolver
{

void ESolver_DP::before_all_runners(Input& inp, UnitCell& ucell)
void ESolver_DP::before_all_runners(const Input_para& inp, UnitCell& ucell)
{
ucell_ = &ucell;
dp_potential = 0;
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_dp.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ESolver_DP : public ESolver
* @param inp input parameters
* @param cell unitcell information
*/
void before_all_runners(Input& inp, UnitCell& cell) override;
void before_all_runners(const Input_para& inp, UnitCell& cell) override;

/**
* @brief Run the DP solver for a given ion/md step and unit cell
Expand Down
14 changes: 7 additions & 7 deletions source/module_esolver/esolver_fp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ ESolver_FP::ESolver_FP()

// temporary, it will be removed
pw_big = static_cast<ModulePW::PW_Basis_Big*>(pw_rhod);
pw_big->setbxyz(INPUT.bx, INPUT.by, INPUT.bz);
sf.set(pw_rhod, INPUT.nbspline);
pw_big->setbxyz(PARAM.inp.bx, PARAM.inp.by, PARAM.inp.bz);
sf.set(pw_rhod, PARAM.inp.nbspline);

GlobalC::ucell.symm.epsilon = GlobalC::ucell.symm.epsilon_input = PARAM.inp.symmetry_prec;
}
Expand All @@ -39,7 +39,7 @@ ESolver_FP::~ESolver_FP()
delete this->pelec;
}

void ESolver_FP::before_all_runners(Input& inp, UnitCell& cell)
void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell)
{
ModuleBase::TITLE("ESolver_FP", "before_all_runners");

Expand Down Expand Up @@ -105,7 +105,7 @@ void ESolver_FP::before_all_runners(Input& inp, UnitCell& cell)
return;
}

void ESolver_FP::init_after_vc(Input& inp, UnitCell& cell)
void ESolver_FP::init_after_vc(const Input_para& inp, UnitCell& cell)
{
ModuleBase::TITLE("ESolver_FP", "init_after_vc");

Expand Down Expand Up @@ -174,7 +174,7 @@ void ESolver_FP::init_after_vc(Input& inp, UnitCell& cell)
return;
}

void ESolver_FP::print_rhofft(Input& inp, std::ofstream& ofs)
void ESolver_FP::print_rhofft(const Input_para& inp, std::ofstream& ofs)
{
std::cout << " UNIFORM GRID DIM : " << pw_rho->nx << " * " << pw_rho->ny << " * " << pw_rho->nz << std::endl;
std::cout << " UNIFORM GRID DIM(BIG) : " << pw_big->nbx << " * " << pw_big->nby << " * " << pw_big->nbz
Expand Down Expand Up @@ -219,7 +219,7 @@ void ESolver_FP::print_rhofft(Input& inp, std::ofstream& ofs)
ofs << "\n\n\n\n";
ofs << "\n SETUP THE PLANE WAVE BASIS" << std::endl;

double ecut = 4 * INPUT.ecutwfc;
double ecut = 4 * PARAM.inp.ecutwfc;
if (inp.nx * inp.ny * inp.nz > 0)
{
ecut = this->pw_rho->gridecut_lat * this->pw_rho->tpiba2;
Expand Down Expand Up @@ -264,7 +264,7 @@ void ESolver_FP::print_rhofft(Input& inp, std::ofstream& ofs)
ofs << std::endl;
ofs << std::endl;
ofs << std::endl;
double ecut = INPUT.ecutrho;
double ecut = PARAM.inp.ecutrho;
if (inp.ndx * inp.ndy * inp.ndz > 0)
{
ecut = this->pw_rhod->gridecut_lat * this->pw_rhod->tpiba2;
Expand Down
6 changes: 3 additions & 3 deletions source/module_esolver/esolver_fp.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ namespace ModuleESolver
virtual ~ESolver_FP();

//! Initialize of the first-principels energy solver
virtual void before_all_runners(Input& inp, UnitCell& cell) override;
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override;

virtual void init_after_vc(Input& inp, UnitCell& cell); // liuyu add 2023-03-09
virtual void init_after_vc(const Input_para& inp, UnitCell& cell); // liuyu add 2023-03-09

//! Electronic states
elecstate::ElecState* pelec = nullptr;
Expand All @@ -62,7 +62,7 @@ namespace ModuleESolver
private:

//! Print charge density using FFT
void print_rhofft(Input& inp, std::ofstream &ofs);
void print_rhofft(const Input_para& inp, std::ofstream &ofs);
};
}

Expand Down
18 changes: 9 additions & 9 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ ESolver_KS<T, Device>::ESolver_KS()
ModulePW::PW_Basis_K_Big* tmp = static_cast<ModulePW::PW_Basis_K_Big*>(pw_wfc);

// should not use INPUT here, mohan 2024-05-12
tmp->setbxyz(INPUT.bx, INPUT.by, INPUT.bz);
tmp->setbxyz(PARAM.inp.bx, PARAM.inp.by, PARAM.inp.bz);

///----------------------------------------------------------
/// charge mixing
Expand All @@ -67,10 +67,10 @@ ESolver_KS<T, Device>::ESolver_KS()
///----------------------------------------------------------
/// wavefunc
///----------------------------------------------------------
this->wf.init_wfc = INPUT.init_wfc;
this->wf.mem_saver = INPUT.mem_saver;
this->wf.out_wfc_pw = INPUT.out_wfc_pw;
this->wf.out_wfc_r = INPUT.out_wfc_r;
this->wf.init_wfc = PARAM.inp.init_wfc;
this->wf.mem_saver = PARAM.inp.mem_saver;
this->wf.out_wfc_pw = PARAM.inp.out_wfc_pw;
this->wf.out_wfc_r = PARAM.inp.out_wfc_r;
}

//------------------------------------------------------------------------------
Expand All @@ -92,7 +92,7 @@ ESolver_KS<T, Device>::~ESolver_KS()
//! mohan add 2024-05-11
//------------------------------------------------------------------------------
template <typename T, typename Device>
void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)
void ESolver_KS<T, Device>::before_all_runners(const Input_para& inp, UnitCell& ucell)
{
ModuleBase::TITLE("ESolver_KS", "before_all_runners");

Expand Down Expand Up @@ -344,7 +344,7 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)
//! mohan add 2024-05-11
//------------------------------------------------------------------------------
template <typename T, typename Device>
void ESolver_KS<T, Device>::init_after_vc(Input& inp, UnitCell& ucell)
void ESolver_KS<T, Device>::init_after_vc(const Input_para& inp, UnitCell& ucell)
{
ModuleBase::TITLE("ESolver_KS", "init_after_vc");

Expand Down Expand Up @@ -388,7 +388,7 @@ void ESolver_KS<T, Device>::hamilt2density(const int istep, const int iter, cons
//! mohan add 2024-05-11
//------------------------------------------------------------------------------
template <typename T, typename Device>
void ESolver_KS<T, Device>::print_wfcfft(Input& inp, std::ofstream& ofs)
void ESolver_KS<T, Device>::print_wfcfft(const Input_para& inp, std::ofstream& ofs)
{
ofs << "\n\n\n\n";
ofs << " >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
Expand Down Expand Up @@ -718,7 +718,7 @@ void ESolver_KS<T, Device>::print_iter(const int iter,
const double duration,
const double ethr)
{
this->pelec->print_etot(this->conv_elec, iter, drho, dkin, duration, INPUT.printe, ethr);
this->pelec->print_etot(this->conv_elec, iter, drho, dkin, duration, PARAM.inp.printe, ethr);
}

//------------------------------------------------------------------------------
Expand Down
8 changes: 4 additions & 4 deletions source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ class ESolver_KS : public ESolver_FP

int out_freq_elec;// frequency for output

virtual void before_all_runners(Input& inp, UnitCell& cell) override;
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override;

virtual void init_after_vc(Input& inp, UnitCell& cell) override; // liuyu add 2023-03-09
virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09

virtual void runner(const int istep, UnitCell& cell) override;

Expand Down Expand Up @@ -137,7 +137,7 @@ class ESolver_KS : public ESolver_FP

std::string basisname; //PW or LCAO

void print_wfcfft(Input& inp, std::ofstream& ofs);
};
void print_wfcfft(const Input_para& inp, std::ofstream &ofs);
};
} // end of namespace
#endif
Loading
Loading