Skip to content

Commit

Permalink
fix ut errors
Browse files Browse the repository at this point in the history
  • Loading branch information
denghuilu committed Dec 4, 2023
1 parent 5ac9321 commit 14e2247
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void HSolverPW::update()
return;
}*/
template<typename T, typename Device>
void HSolverPW<T, Device>::initDiagh(const psi::Psi<T, Device>& psi_in)
void HSolverPW<T, Device>::initDiagh(const psi::Psi<T, Device>& psi)
{
if (this->method == "cg")
{
Expand All @@ -53,7 +53,8 @@ void HSolverPW<T, Device>::initDiagh(const psi::Psi<T, Device>& psi_in)
}
// this->pdiagh = new DiagoCG<T, Device>(precondition.data());
// warp the subspace_func into a lambda function
auto subspace_func = [this](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
auto ngk_pointer = psi.get_ngk_pointer();
auto subspace_func = [this, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
const auto ndim = psi_in.shape().ndim();
Expand All @@ -62,11 +63,13 @@ void HSolverPW<T, Device>::initDiagh(const psi::Psi<T, Device>& psi_in)
auto psi_in_wrapper = psi::Psi<T, Device>(
psi_in.data<T>(), 1,
psi_in.shape().dim_size(0),
psi_in.shape().dim_size(1));
psi_in.shape().dim_size(1),
ngk_pointer);
auto psi_out_wrapper = psi::Psi<T, Device>(
psi_out.data<T>(), 1,
psi_out.shape().dim_size(0),
psi_out.shape().dim_size(1));
psi_out.shape().dim_size(1),
ngk_pointer);
auto eigen = ct::Tensor(
ct::DataTypeToEnum<Real>::value,
ct::DeviceType::CpuDevice,
Expand Down Expand Up @@ -108,13 +111,13 @@ void HSolverPW<T, Device>::initDiagh(const psi::Psi<T, Device>& psi_in)
delete (DiagoBPCG<T, Device>*)this->pdiagh;
this->pdiagh = new DiagoBPCG<T, Device>(precondition.data());
this->pdiagh->method = this->method;
reinterpret_cast<DiagoBPCG<T, Device>*>(this->pdiagh)->init_iter(psi_in);
reinterpret_cast<DiagoBPCG<T, Device>*>(this->pdiagh)->init_iter(psi);
}
}
else {
this->pdiagh = new DiagoBPCG<T, Device>(precondition.data());
this->pdiagh->method = this->method;
reinterpret_cast<DiagoBPCG<T, Device>*>(this->pdiagh)->init_iter(psi_in);
reinterpret_cast<DiagoBPCG<T, Device>*>(this->pdiagh)->init_iter(psi);
}
}
else
Expand Down Expand Up @@ -377,7 +380,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
using ct_Device = typename ct::PsiToContainer<Device>::type;
auto cg = reinterpret_cast<DiagoCG<T, Device>*>(this->pdiagh);
// warp the hpsi_func and spsi_func into a lambda function
auto hpsi_func = [hm](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
auto ngk_pointer = psi.get_ngk_pointer();
auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
Expand All @@ -387,7 +391,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
auto psi_wrapper = psi::Psi<T, Device>(
psi_in.data<T>(), 1,
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
ngk_pointer);
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
Expand Down

0 comments on commit 14e2247

Please sign in to comment.