Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Initial removal of psi::Psi<T, Device> from Diago_DavSubspace #4416

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 59 additions & 90 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,11 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
}

template <typename T, typename Device>
int Diago_DavSubspace<T, Device>::diag_once(

const Func& hpsi_func,
T* psi_in,

psi::Psi<T, Device>& psi,

Real* eigenvalue_in_hsolver,
const std::vector<bool>& is_occupied)
int Diago_DavSubspace<T, Device>::diag_once(const Func& hpsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in_hsolver,
const std::vector<bool>& is_occupied)
{
ModuleBase::timer::tick("Diago_DavSubspace", "diag_once");

Expand Down Expand Up @@ -119,15 +115,10 @@ int Diago_DavSubspace<T, Device>::diag_once(
syncmem_complex_op()(this->ctx,
this->ctx,
this->psi_in_iter + m * this->dim,
psi.get_k_first() ? &psi(m, 0) : &psi(m, 0, 0),
psi_in + m * psi_in_dmax,
this->dim);
}

// auto psi_iter_wrapper = psi::Psi<T, Device>(this->psi_in_iter, 1, this->nbase_x, this->dim);
// // calculate H|psi>
// hpsi_info dav_hpsi_in(&psi_iter_wrapper, psi::Range(1, 0, 0, psi_iter_wrapper.get_nbands() - 1), this->hphi);
// phm_in->ops->hPsi(dav_hpsi_in);

hpsi_func(this->hphi, this->psi_in_iter, this->nbase_x, this->dim, 0, this->nbase_x - 1);

this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);
Expand Down Expand Up @@ -155,18 +146,15 @@ int Diago_DavSubspace<T, Device>::diag_once(
{
dav_iter++;

this->cal_grad(

hpsi_func,

this->dim,
nbase,
this->notconv,
this->psi_in_iter,
this->hphi,
this->vcc,
unconv.data(),
&eigenvalue_iter);
this->cal_grad(hpsi_func,
this->dim,
nbase,
this->notconv,
this->psi_in_iter,
this->hphi,
this->vcc,
unconv.data(),
&eigenvalue_iter);

this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);

Expand Down Expand Up @@ -212,23 +200,22 @@ int Diago_DavSubspace<T, Device>::diag_once(
ModuleBase::timer::tick("Diago_DavSubspace", "last");

// updata eigenvectors of Hamiltonian
setmem_complex_op()(this->ctx, psi.get_pointer(), 0, n_band * psi.get_nbasis());
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// haozhihan repalce 2022-10-18
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);

gemm_op<T, Device>()(this->ctx,
'N',
'N',
this->dim, // m: row of A,C
this->n_band, // n: col of B,C
nbase, // k: col of A, row of B
this->dim,
this->n_band,
nbase,
this->one,
this->psi_in_iter, // A dim * nbase
this->psi_in_iter,
this->dim,
this->vcc, // B nbase * n_band
this->vcc,
this->nbase_x,
this->zero,
psi.get_pointer(), // C dim * n_band
psi.get_nbasis());
psi_in,
psi_in_dmax);

if (!this->notconv || (dav_iter == this->iter_nmax))
{
Expand All @@ -243,16 +230,26 @@ int Diago_DavSubspace<T, Device>::diag_once(
// then replace the first N (=nband) basis vectors with the current
// estimate of the eigenvectors and set the basis dimension to N;

// update this->psi_in_iter according to psi_in
for (size_t i = 0; i < this->n_band; i++)
{
syncmem_complex_op()(this->ctx,
this->ctx,
this->psi_in_iter + i * this->dim,
psi_in + i * psi_in_dmax,
this->dim);
}

this->refresh(this->dim,
this->n_band,
nbase,
eigenvalue_in_hsolver,
psi,
this->psi_in_iter,
this->hphi,
this->hcc,
this->scc,
this->vcc);

ModuleBase::timer::tick("Diago_DavSubspace", "last");
}
}
Expand Down Expand Up @@ -289,18 +286,17 @@ void Diago_DavSubspace<T, Device>::cal_grad(const Func& hpsi_func,
gemm_op<T, Device>()(this->ctx,
'N',
'N',
this->dim, // m: row of A,C
notconv, // n: col of B,C
nbase, // k: col of A, row of B
this->one, // alpha
psi_iter, // A
this->dim, // LDA
vcc, // B
this->nbase_x, // LDB
this->zero, // belta
psi_iter + nbase * this->dim, // C dim * notconv
this->dim // LDC
);
this->dim,
notconv,
nbase,
this->one,
psi_iter,
this->dim,
vcc,
this->nbase_x,
this->zero,
psi_iter + nbase * this->dim,
this->dim);

for (int m = 0; m < notconv; m++)
{
Expand All @@ -317,18 +313,17 @@ void Diago_DavSubspace<T, Device>::cal_grad(const Func& hpsi_func,
gemm_op<T, Device>()(this->ctx,
'N',
'N',
this->dim, // m: row of A,C
notconv, // n: col of B,C
nbase, // k: col of A, row of B
this->one, // alpha
hphi, // A dim * nbase
this->dim, // LDA
vcc, // B nbase * notconv
this->nbase_x, // LDB
this->one, // belta
this->dim,
notconv,
nbase,
this->one,
hphi,
this->dim,
vcc,
this->nbase_x,
this->one,
psi_iter + (nbase) * this->dim,
this->dim // LDC
);
this->dim);

// "precondition!!!"
std::vector<Real> pre(this->dim, 0.0);
Expand Down Expand Up @@ -365,10 +360,6 @@ void Diago_DavSubspace<T, Device>::cal_grad(const Func& hpsi_func,
psi_norm[i]);
}

// auto psi_iter_wrapper = psi::Psi<T, Device>(psi_iter, 1, this->nbase_x, this->dim);
// // "calculate H|psi>" for not convergence bands
// hpsi_info dav_hpsi_in(&psi_iter_wrapper, psi::Range(1, 0, nbase, nbase + notconv - 1), &hphi[nbase * this->dim]);
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(&hphi[nbase * this->dim], psi_iter, this->nbase_x, this->dim, nbase, nbase + notconv - 1);

ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
Expand Down Expand Up @@ -516,7 +507,6 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
T* scc,
const int& nbase_x,
std::vector<Real>* eigenvalue_iter,
// Real* eigenvalue_iter,
T* vcc,
bool init,
bool is_subspace)
Expand Down Expand Up @@ -647,7 +637,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
const int& nband,
int& nbase,
const Real* eigenvalue_in_hsolver,
const psi::Psi<T, Device>& psi,
// const psi::Psi<T, Device>& psi,
T* psi_iter,
T* hp,
T* sp,
Expand All @@ -656,15 +646,6 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
{
ModuleBase::timer::tick("Diago_DavSubspace", "refresh");

// update psi
for (size_t i = 0; i < nband; i++)
{
syncmem_complex_op()(this->ctx,
this->ctx,
psi_iter + i * this->dim,
&psi(i, 0),
this->dim);
}
gemm_op<T, Device>()(this->ctx,
'N',
'N',
Expand All @@ -681,11 +662,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
this->dim);

// update hphi
syncmem_complex_op()(this->ctx,
this->ctx,
hphi,
psi_iter + nband * this->dim,
this->dim * nband);
syncmem_complex_op()(this->ctx, this->ctx, hphi, psi_iter + nband * this->dim, this->dim * nband);

nbase = nband;

Expand Down Expand Up @@ -816,15 +793,7 @@ int Diago_DavSubspace<T, Device>::diag(const Func& hpsi_func,
DiagoIterAssist<T, Device>::diagH_subspace(phm_in, psi, psi, eigenvalue_in_hsolver, psi.get_nbands());
}

sum_iter += this->diag_once(

hpsi_func,
psi_in,

psi,

eigenvalue_in_hsolver,
is_occupied);
sum_iter += this->diag_once(hpsi_func, psi_in, psi.get_nbasis(), eigenvalue_in_hsolver, is_occupied);

++ntry;

Expand Down
5 changes: 2 additions & 3 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Diago_DavSubspace : public DiagH<T, Device>

hamilt::Hamilt<T, Device>* phm_in,
psi::Psi<T, Device>& phi,

Real* eigenvalue_in,
const std::vector<bool>& is_occupied,
const bool& scf_type);
Expand Down Expand Up @@ -105,7 +105,6 @@ class Diago_DavSubspace : public DiagH<T, Device>
const int& nband,
int& nbase,
const Real* eigenvalue,
const psi::Psi<T, Device>& psi,
T* psi_iter,
T* hphi,
T* hcc,
Expand All @@ -124,7 +123,7 @@ class Diago_DavSubspace : public DiagH<T, Device>

int diag_once(const Func& hpsi_func,
T* psi_in,
psi::Psi<T, Device>& psi,
const int psi_in_dmax,
Real* eigenvalue_in,
const std::vector<bool>& is_occupied);

Expand Down
Loading