Skip to content

Commit

Permalink
refactor: remove template on relax_driver (#4172)
Browse files Browse the repository at this point in the history
hongriTianqi authored May 21, 2024
1 parent c6f8609 commit 142adc2
Showing 7 changed files with 51 additions and 32 deletions.
14 changes: 2 additions & 12 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
@@ -57,18 +57,8 @@ void Driver::driver_run(void)
}
else //! scf; cell relaxation; nscf; etc
{
// mixed-precision should not be like this, mohan 2024-05-12,
// DEVICE should not depend on psi
if (GlobalV::precision_flag == "single")
{
Relax_Driver<float, base_device::DEVICE_CPU> rl_driver;
rl_driver.relax_driver(p_esolver);
}
else
{
Relax_Driver<double, base_device::DEVICE_CPU> rl_driver;
rl_driver.relax_driver(p_esolver);
}
Relax_Driver rl_driver;
rl_driver.relax_driver(p_esolver);
}
// "others" in ESolver should be here.

14 changes: 13 additions & 1 deletion source/module_esolver/esolver.h
Original file line number Diff line number Diff line change
@@ -47,7 +47,19 @@ class ESolver

// temporarily
// get iterstep used in current scf
virtual int getniter()
virtual int get_niter()
{
return 0;
}

// get maxniter used in current scf
virtual int get_maxniter()
{
return 0;
}

// get conv_elec used in current scf
virtual bool get_conv_elec()
{
return 0;
}
29 changes: 24 additions & 5 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
@@ -716,14 +716,33 @@ void ESolver_KS<T, Device>::write_head(std::ofstream& ofs_running, const int ist
//! mohan add 2024-05-12
//------------------------------------------------------------------------------
template<typename T, typename Device>
int ESolver_KS<T, Device>::getniter()
int ESolver_KS<T, Device>::get_niter()
{
return this->niter;
}

//------------------------------------------------------------------------------
//! the 11th function of ESolver_KS: get_maxniter
//! tqzhao add 2024-05-15
//------------------------------------------------------------------------------
template<typename T, typename Device>
int ESolver_KS<T, Device>::get_maxniter()
{
return this->maxniter;
}

//------------------------------------------------------------------------------
//! the 12th function of ESolver_KS: get_conv_elec
//! tqzhao add 2024-05-15
//------------------------------------------------------------------------------
template<typename T, typename Device>
bool ESolver_KS<T, Device>::get_conv_elec()
{
return this->conv_elec;
}

//------------------------------------------------------------------------------
//! the 11th function of ESolver_KS: create_Output_Rho
//! the 13th function of ESolver_KS: create_Output_Rho
//! mohan add 2024-05-12
//------------------------------------------------------------------------------
template<typename T, typename Device>
@@ -765,7 +784,7 @@ ModuleIO::Output_Rho ESolver_KS<T, Device>::create_Output_Rho(


//------------------------------------------------------------------------------
//! the 12th function of ESolver_KS: create_Output_Kin
//! the 14th function of ESolver_KS: create_Output_Kin
//! mohan add 2024-05-12
//------------------------------------------------------------------------------
template<typename T, typename Device>
@@ -789,7 +808,7 @@ ModuleIO::Output_Rho ESolver_KS<T, Device>::create_Output_Kin(int is, int iter,


//------------------------------------------------------------------------------
//! the 13th function of ESolver_KS: create_Output_Potential
//! the 15th function of ESolver_KS: create_Output_Potential
//! mohan add 2024-05-12
//------------------------------------------------------------------------------
template<typename T, typename Device>
@@ -814,7 +833,7 @@ ModuleIO::Output_Potential ESolver_KS<T, Device>::create_Output_Potential(int it


//------------------------------------------------------------------------------
//! the 14th-18th functions of ESolver_KS
//! the 16th-20th functions of ESolver_KS
//! mohan add 2024-05-12
//------------------------------------------------------------------------------
//! This is for mixed-precision pw/LCAO basis sets.
8 changes: 7 additions & 1 deletion source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
@@ -56,7 +56,13 @@ class ESolver_KS : public ESolver_FP
virtual void hamilt2estates(const double ethr){};

// get current step of Ionic simulation
virtual int getniter() override;
virtual int get_niter() override;

// get maxniter used in current scf
virtual int get_maxniter() override;

// get conv_elec used in current scf
virtual bool get_conv_elec() override;

protected:
//! Something to do before SCF iterations.
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_of.h
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ class ESolver_OF : public ESolver_FP

virtual void cal_stress(ModuleBase::matrix& stress) override;

virtual int getniter() override
virtual int get_niter() override
{
return this->iter_;
}
15 changes: 4 additions & 11 deletions source/module_relax/relax_driver.cpp
Original file line number Diff line number Diff line change
@@ -8,9 +8,7 @@
#include "module_io/json_output/output_info.h"



template<typename FPTYPE, typename Device>
void Relax_Driver<FPTYPE, Device>::relax_driver(ModuleESolver::ESolver *p_esolver)
void Relax_Driver::relax_driver(ModuleESolver::ESolver *p_esolver)
{
ModuleBase::TITLE("Ions", "opt_ions");
ModuleBase::timer::tick("Ions", "opt_ions");
@@ -115,12 +113,10 @@ void Relax_Driver<FPTYPE, Device>::relax_driver(ModuleESolver::ESolver *p_esolve
GlobalC::ucell.print_cell_cif("STRU_NOW.cif");
}

ModuleESolver::ESolver_KS<FPTYPE, Device>* p_esolver_ks
= dynamic_cast<ModuleESolver::ESolver_KS<FPTYPE, Device>*>(p_esolver);
if (p_esolver_ks
if (p_esolver
&& stop
&& p_esolver_ks->maxniter == p_esolver_ks->niter
&& !(p_esolver_ks->conv_elec))
&& p_esolver->get_maxniter() == p_esolver->get_niter()
&& !(p_esolver->get_conv_elec()))
{
std::cout << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
std::cout << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
@@ -165,6 +161,3 @@ void Relax_Driver<FPTYPE, Device>::relax_driver(ModuleESolver::ESolver *p_esolve
ModuleBase::timer::tick("Ions", "opt_ions");
return;
}

template class Relax_Driver<float, base_device::DEVICE_CPU>;
template class Relax_Driver<double, base_device::DEVICE_CPU>;
1 change: 0 additions & 1 deletion source/module_relax/relax_driver.h
Original file line number Diff line number Diff line change
@@ -6,7 +6,6 @@
#include "relax_new/relax.h"
#include "relax_old/relax_old.h"

template <typename FPTYPE, typename Device = base_device::DEVICE_CPU>
class Relax_Driver
{

0 comments on commit 142adc2

Please sign in to comment.