Skip to content

Commit

Permalink
Refactor: add smooth threshold support for bpcg method (#5709)
Browse files Browse the repository at this point in the history
* add smooth threshold support for bpcg method

* fix build test bug

* fix build bug
  • Loading branch information
haozhihan authored Dec 10, 2024
1 parent c5d5eab commit ab2a789
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 18 deletions.
16 changes: 8 additions & 8 deletions source/module_hsolver/diago_bpcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
}

template<typename T, typename Device>
bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, Real thr_in)
bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, const std::vector<double>& ethr_band)
{
const Real * _err_st = err_in.data<Real>();
if (err_in.device_type() == ct::DeviceType::GpuDevice) {
ct::Tensor h_err_in = err_in.to_device<ct::DEVICE_CPU>();
_err_st = h_err_in.data<Real>();
}
for (int ii = 0; ii < this->n_band; ii++) {
if (_err_st[ii] > thr_in) {
if (_err_st[ii] > ethr_band[ii]) {
return true;
}
}
Expand Down Expand Up @@ -242,11 +242,11 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
return;
}

template<typename T, typename Device>
void DiagoBPCG<T, Device>::diag(
const HPsiFunc& hpsi_func,
T *psi_in,
Real* eigenvalue_in)
template <typename T, typename Device>
void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
T* psi_in,
Real* eigenvalue_in,
const std::vector<double>& ethr_band)
{
const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
// Get the pointer of the input psi
Expand Down Expand Up @@ -301,7 +301,7 @@ void DiagoBPCG<T, Device>::diag(
if (current_scf_iter == 1 && ntry % this->nline == 0) {
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
}
} while (ntry < max_iter && this->test_error(this->err_st, this->all_band_cg_thr));
} while (ntry < max_iter && this->test_error(this->err_st, ethr_band));

this->calc_hsub_with_block_exit(this->psi, this->hpsi, this->hsub, this->work, this->eigen);

Expand Down
10 changes: 5 additions & 5 deletions source/module_hsolver/diago_bpcg.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ class DiagoBPCG
* @param psi_in Pointer to input wavefunction psi matrix with [dim: n_basis x n_band, column major].
* @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major].
*/
void diag(const HPsiFunc& hpsi_func, T *psi_in, Real *eigenvalue_in);

void diag(const HPsiFunc& hpsi_func,
T* psi_in,
Real* eigenvalue_in,
const std::vector<double>& ethr_band);

private:
/// the number of rows of the input psi
Expand All @@ -77,8 +79,6 @@ class DiagoBPCG
int n_basis = 0;
/// max iter steps for all-band cg loop
int nline = 4;
/// cg convergence thr
Real all_band_cg_thr = 1E-5;

ct::DataType r_type = ct::DataType::DT_INVALID;
ct::DataType t_type = ct::DataType::DT_INVALID;
Expand Down Expand Up @@ -316,7 +316,7 @@ class DiagoBPCG
* @param thr_in The threshold.
* @return Returns true if all error values are less than or equal to the threshold, false otherwise.
*/
bool test_error(const ct::Tensor& err_in, Real thr_in);
bool test_error(const ct::Tensor& err_in, const std::vector<double>& ethr_band);

using ct_Device = typename ct::PsiToContainer<Device>::type;
using setmem_var_op = ct::kernels::set_memory<Real, ct_Device>;
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
};
DiagoBPCG<T, Device> bpcg(pre_condition.data());
bpcg.init_iter(nband, nbasis);
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue);
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band);
}
else if (this->method == "dav_subspace")
{
Expand Down
9 changes: 5 additions & 4 deletions source/module_hsolver/test/diago_bpcg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,11 @@ class DiagoBPCGPrepare
hpsi_out, ld_psi);
};
bpcg.init_iter(nband, npw);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
std::vector<double> ethr_band(nband, 1e-5);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
end = MPI_Wtime();
//if(mypnum == 0) printf("diago time:%7.3f\n",end-start);
delete [] DIAGOTEST::npw_local;
Expand Down

0 comments on commit ab2a789

Please sign in to comment.