From d54706a99c1371cff626c8f2bc7b4e55961fafdd Mon Sep 17 00:00:00 2001 From: haozhihan Date: Mon, 17 Jun 2024 23:03:52 +0800 Subject: [PATCH] remove psi::Psi from Diago_DavSubspace --- source/module_hsolver/diago_dav_subspace.cpp | 149 ++++++++----------- source/module_hsolver/diago_dav_subspace.h | 5 +- 2 files changed, 61 insertions(+), 93 deletions(-) diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 42e4601ae9..83ad2daf9b 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -83,15 +83,11 @@ Diago_DavSubspace::~Diago_DavSubspace() } template -int Diago_DavSubspace::diag_once( - - const Func& hpsi_func, - T* psi_in, - - psi::Psi& psi, - - Real* eigenvalue_in_hsolver, - const std::vector& is_occupied) +int Diago_DavSubspace::diag_once(const Func& hpsi_func, + T* psi_in, + const int psi_in_dmax, + Real* eigenvalue_in_hsolver, + const std::vector& is_occupied) { ModuleBase::timer::tick("Diago_DavSubspace", "diag_once"); @@ -119,15 +115,10 @@ int Diago_DavSubspace::diag_once( syncmem_complex_op()(this->ctx, this->ctx, this->psi_in_iter + m * this->dim, - psi.get_k_first() ? &psi(m, 0) : &psi(m, 0, 0), + psi_in + m * psi_in_dmax, this->dim); } - // auto psi_iter_wrapper = psi::Psi(this->psi_in_iter, 1, this->nbase_x, this->dim); - // // calculate H|psi> - // hpsi_info dav_hpsi_in(&psi_iter_wrapper, psi::Range(1, 0, 0, psi_iter_wrapper.get_nbands() - 1), this->hphi); - // phm_in->ops->hPsi(dav_hpsi_in); - hpsi_func(this->hphi, this->psi_in_iter, this->nbase_x, this->dim, 0, this->nbase_x - 1); this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc); @@ -155,18 +146,15 @@ int Diago_DavSubspace::diag_once( { dav_iter++; - this->cal_grad( - - hpsi_func, - - this->dim, - nbase, - this->notconv, - this->psi_in_iter, - this->hphi, - this->vcc, - unconv.data(), - &eigenvalue_iter); + this->cal_grad(hpsi_func, + this->dim, + nbase, + this->notconv, + this->psi_in_iter, + this->hphi, + this->vcc, + unconv.data(), + &eigenvalue_iter); this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc); @@ -212,23 +200,22 @@ int Diago_DavSubspace::diag_once( ModuleBase::timer::tick("Diago_DavSubspace", "last"); // updata eigenvectors of Hamiltonian - setmem_complex_op()(this->ctx, psi.get_pointer(), 0, n_band * psi.get_nbasis()); - //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - // haozhihan repalce 2022-10-18 + setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax); + gemm_op()(this->ctx, 'N', 'N', - this->dim, // m: row of A,C - this->n_band, // n: col of B,C - nbase, // k: col of A, row of B + this->dim, + this->n_band, + nbase, this->one, - this->psi_in_iter, // A dim * nbase + this->psi_in_iter, this->dim, - this->vcc, // B nbase * n_band + this->vcc, this->nbase_x, this->zero, - psi.get_pointer(), // C dim * n_band - psi.get_nbasis()); + psi_in, + psi_in_dmax); if (!this->notconv || (dav_iter == this->iter_nmax)) { @@ -243,16 +230,26 @@ int Diago_DavSubspace::diag_once( // then replace the first N (=nband) basis vectors with the current // estimate of the eigenvectors and set the basis dimension to N; + // update this->psi_in_iter according to psi_in + for (size_t i = 0; i < this->n_band; i++) + { + syncmem_complex_op()(this->ctx, + this->ctx, + this->psi_in_iter + i * this->dim, + psi_in + i * psi_in_dmax, + this->dim); + } + this->refresh(this->dim, this->n_band, nbase, eigenvalue_in_hsolver, - psi, this->psi_in_iter, this->hphi, this->hcc, this->scc, this->vcc); + ModuleBase::timer::tick("Diago_DavSubspace", "last"); } } @@ -289,18 +286,17 @@ void Diago_DavSubspace::cal_grad(const Func& hpsi_func, gemm_op()(this->ctx, 'N', 'N', - this->dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - psi_iter, // A - this->dim, // LDA - vcc, // B - this->nbase_x, // LDB - this->zero, // belta - psi_iter + nbase * this->dim, // C dim * notconv - this->dim // LDC - ); + this->dim, + notconv, + nbase, + this->one, + psi_iter, + this->dim, + vcc, + this->nbase_x, + this->zero, + psi_iter + nbase * this->dim, + this->dim); for (int m = 0; m < notconv; m++) { @@ -317,18 +313,17 @@ void Diago_DavSubspace::cal_grad(const Func& hpsi_func, gemm_op()(this->ctx, 'N', 'N', - this->dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - hphi, // A dim * nbase - this->dim, // LDA - vcc, // B nbase * notconv - this->nbase_x, // LDB - this->one, // belta + this->dim, + notconv, + nbase, + this->one, + hphi, + this->dim, + vcc, + this->nbase_x, + this->one, psi_iter + (nbase) * this->dim, - this->dim // LDC - ); + this->dim); // "precondition!!!" std::vector pre(this->dim, 0.0); @@ -365,10 +360,6 @@ void Diago_DavSubspace::cal_grad(const Func& hpsi_func, psi_norm[i]); } - // auto psi_iter_wrapper = psi::Psi(psi_iter, 1, this->nbase_x, this->dim); - // // "calculate H|psi>" for not convergence bands - // hpsi_info dav_hpsi_in(&psi_iter_wrapper, psi::Range(1, 0, nbase, nbase + notconv - 1), &hphi[nbase * this->dim]); - // phm_in->ops->hPsi(dav_hpsi_in); hpsi_func(&hphi[nbase * this->dim], psi_iter, this->nbase_x, this->dim, nbase, nbase + notconv - 1); ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad"); @@ -516,7 +507,6 @@ void Diago_DavSubspace::diag_zhegvx(const int& nbase, T* scc, const int& nbase_x, std::vector* eigenvalue_iter, - // Real* eigenvalue_iter, T* vcc, bool init, bool is_subspace) @@ -647,7 +637,7 @@ void Diago_DavSubspace::refresh(const int& dim, const int& nband, int& nbase, const Real* eigenvalue_in_hsolver, - const psi::Psi& psi, + // const psi::Psi& psi, T* psi_iter, T* hp, T* sp, @@ -656,15 +646,6 @@ void Diago_DavSubspace::refresh(const int& dim, { ModuleBase::timer::tick("Diago_DavSubspace", "refresh"); - // update psi - for (size_t i = 0; i < nband; i++) - { - syncmem_complex_op()(this->ctx, - this->ctx, - psi_iter + i * this->dim, - &psi(i, 0), - this->dim); - } gemm_op()(this->ctx, 'N', 'N', @@ -681,11 +662,7 @@ void Diago_DavSubspace::refresh(const int& dim, this->dim); // update hphi - syncmem_complex_op()(this->ctx, - this->ctx, - hphi, - psi_iter + nband * this->dim, - this->dim * nband); + syncmem_complex_op()(this->ctx, this->ctx, hphi, psi_iter + nband * this->dim, this->dim * nband); nbase = nband; @@ -816,15 +793,7 @@ int Diago_DavSubspace::diag(const Func& hpsi_func, DiagoIterAssist::diagH_subspace(phm_in, psi, psi, eigenvalue_in_hsolver, psi.get_nbands()); } - sum_iter += this->diag_once( - - hpsi_func, - psi_in, - - psi, - - eigenvalue_in_hsolver, - is_occupied); + sum_iter += this->diag_once(hpsi_func, psi_in, psi.get_nbasis(), eigenvalue_in_hsolver, is_occupied); ++ntry; diff --git a/source/module_hsolver/diago_dav_subspace.h b/source/module_hsolver/diago_dav_subspace.h index de702a1e27..880962827c 100644 --- a/source/module_hsolver/diago_dav_subspace.h +++ b/source/module_hsolver/diago_dav_subspace.h @@ -36,7 +36,7 @@ class Diago_DavSubspace : public DiagH hamilt::Hamilt* phm_in, psi::Psi& phi, - + Real* eigenvalue_in, const std::vector& is_occupied, const bool& scf_type); @@ -105,7 +105,6 @@ class Diago_DavSubspace : public DiagH const int& nband, int& nbase, const Real* eigenvalue, - const psi::Psi& psi, T* psi_iter, T* hphi, T* hcc, @@ -124,7 +123,7 @@ class Diago_DavSubspace : public DiagH int diag_once(const Func& hpsi_func, T* psi_in, - psi::Psi& psi, + const int psi_in_dmax, Real* eigenvalue_in, const std::vector& is_occupied);