diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 56aad08f3c..99815b296b 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -82,7 +82,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr); + psi::Psi> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), wfc.get_nbasis(), true); const std::complex* pwfc = wfc.get_pointer(); std::complex* pwg_wfc = wg_wfc.get_pointer(); #ifdef _OPENMP diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index f241c59db8..944bcddef4 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -271,7 +271,7 @@ void ElecStatePW::cal_becsum(const psi::Psi& psi) { const T one{1, 0}; const T zero{0, 0}; - const int npol = psi.npol; + const int npol = psi.get_npol(); const int npwx = psi.get_nbasis() / npol; const int nbands = psi.get_nbands() * npol; const int nkb = this->ppcell->nkb; diff --git a/source/module_esolver/esolver_ks_lcao_tddft.cpp b/source/module_esolver/esolver_ks_lcao_tddft.cpp index 84a4ed0c68..acff4f7bf1 100644 --- a/source/module_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/module_esolver/esolver_ks_lcao_tddft.cpp @@ -196,9 +196,9 @@ void ESolver_KS_LCAO_TDDFT::update_pot(UnitCell& ucell, const int istep, const i if (this->psi_laststep == nullptr) { #ifdef __MPI - this->psi_laststep = new psi::Psi>(kv.get_nks(), ncol_nbands, nrow, nullptr); + this->psi_laststep = new psi::Psi>(kv.get_nks(), ncol_nbands, nrow, kv.ngk, true); #else - this->psi_laststep = new psi::Psi>(kv.get_nks(), nbands, nlocal, nullptr); + this->psi_laststep = new psi::Psi>(kv.get_nks(), nbands, nlocal, kv.ngk, true); #endif } diff --git a/source/module_esolver/esolver_ks_lcaopw.cpp 
b/source/module_esolver/esolver_ks_lcaopw.cpp index 08d1043a4a..4649fb07ca 100644 --- a/source/module_esolver/esolver_ks_lcaopw.cpp +++ b/source/module_esolver/esolver_ks_lcaopw.cpp @@ -93,9 +93,10 @@ namespace ModuleESolver ESolver_KS_PW::before_all_runners(ucell, inp); delete this->psi_local; this->psi_local = new psi::Psi(this->psi->get_nk(), - this->p_psi_init->psi_initer->nbands_start(), - this->psi->get_nbasis(), - this->psi->get_ngk_pointer()); + this->p_psi_init->psi_initer->nbands_start(), + this->psi->get_nbasis(), + this->kv.ngk, + true); #ifdef __EXX if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax" diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index a96d487a5c..cf5e300537 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -212,7 +212,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->kv, this->ppcell, *this->pw_wfc); - allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max); + allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.inp.nbands, this->pw_wfc->npwk_max); this->p_psi_init->prepare_init(PARAM.inp.pw_seed); this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single" diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index 667b440916..f5f9292522 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -78,13 +78,20 @@ void ESolver_SDFT_PW::before_all_runners(UnitCell& ucell, const Input // 4) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}> size_t size = stowf.chi0->size(); this->stowf.shchi - = new psi::Psi(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data()); + = new psi::Psi(this->kv.get_nks(), + this->stowf.nchip_max, + 
this->pw_wfc->npwk_max, + this->kv.ngk, + true); ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T)); if (PARAM.inp.nbands > 0) { this->stowf.chiortho - = new psi::Psi(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data()); + = new psi::Psi(this->kv.get_nks(), + this->stowf.nchip_max, + this->pw_wfc->npwk_max, + this->kv.ngk, true); ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(T)); } diff --git a/source/module_esolver/lcao_before_scf.cpp b/source/module_esolver/lcao_before_scf.cpp index 2637fe41d8..0bf61d6947 100644 --- a/source/module_esolver/lcao_before_scf.cpp +++ b/source/module_esolver/lcao_before_scf.cpp @@ -159,7 +159,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) ncol = PARAM.inp.nbands; #endif } - this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, nullptr); + this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, this->kv.ngk, true); } // init wfc from file diff --git a/source/module_esolver/lcao_others.cpp b/source/module_esolver/lcao_others.cpp index fc0ab246d3..faca3563f0 100644 --- a/source/module_esolver/lcao_others.cpp +++ b/source/module_esolver/lcao_others.cpp @@ -165,7 +165,7 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) ncol = PARAM.inp.nbands; #endif } - this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, nullptr); + this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, this->kv.ngk, true); } // init wfc from file diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index e9020866e6..08a5ba97cc 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -63,7 +63,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp delete this->hpsi; this->hpsi = new psi::Psi(hpsi_pointer, 1, - nbands / psi_input->npol, + nbands / psi_input->get_npol(), psi_input->get_nbasis(), psi_input->get_nbasis(), true); @@ -86,7 +86,7 @@ typename Operator::hpsi_info 
Operator::hPsi(hpsi_info& inp default: op->act(nbands, psi_input->get_nbasis(), - psi_input->npol, + psi_input->get_npol(), tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas(), @@ -105,7 +105,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp } ModuleBase::timer::tick("Operator", "hPsi"); - return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); + return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->get_npol()), hpsi_pointer); } template diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp index 94c5c74db7..25b2e4e879 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp @@ -66,7 +66,7 @@ void spinconstrain::SpinConstrain>::cal_mi_pw() psi::Psi, base_device::DEVICE_CPU>* psi_t = static_cast, base_device::DEVICE_CPU>*>(this->psi); const int nbands = psi_t->get_nbands(); const int nks = psi_t->get_nk(); - const int npol = psi_t->npol; + const int npol = psi_t->get_npol(); for(int ik = 0; ik < nks; ik++) { psi_t->fix_k(ik); @@ -112,7 +112,7 @@ void spinconstrain::SpinConstrain>::cal_mi_pw() psi::Psi, base_device::DEVICE_GPU>* psi_t = static_cast, base_device::DEVICE_GPU>*>(this->psi); const int nbands = psi_t->get_nbands(); const int nks = psi_t->get_nk(); - const int npol = psi_t->npol; + const int npol = psi_t->get_npol(); for(int ik = 0; ik < nks; ik++) { psi_t->fix_k(ik); diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index 36baed7bab..d6602e6b11 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -199,7 +199,7 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int hamilt::Hamilt, base_device::DEVICE_CPU>* hamilt_t = static_cast, 
base_device::DEVICE_CPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -252,7 +252,7 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int hamilt::Hamilt, base_device::DEVICE_GPU>* hamilt_t = static_cast, base_device::DEVICE_GPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -382,7 +382,7 @@ void spinconstrain::SpinConstrain>::update_psi_charge(const hamilt::Hamilt, base_device::DEVICE_CPU>* hamilt_t = static_cast, base_device::DEVICE_CPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -454,7 +454,7 @@ void spinconstrain::SpinConstrain>::update_psi_charge(const hamilt::Hamilt, base_device::DEVICE_GPU>* hamilt_t = static_cast, base_device::DEVICE_GPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); diff --git a/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp b/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp index cc0c3a6c30..0ae2588625 100644 --- a/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp +++ b/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp @@ -29,11 +29,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr psi_p->fix_k(ik); onsite_p->tabulate_atomic(ik); - onsite_p->overlap_proj_psi(nbands*psi_p->npol, 
psi_p->get_pointer()); + onsite_p->overlap_proj_psi(nbands*psi_p->get_npol(), psi_p->get_pointer()); const std::complex* becp = onsite_p->get_h_becp(); // becp(nbands*npol , nkb) // mag = wg * \sum_{nh}becp * becp - int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol; + int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol(); int begin_ih = 0; for(int iat = 0; iat < cell.nat; iat++) { @@ -88,11 +88,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr psi_p->fix_k(ik); onsite_p->tabulate_atomic(ik); - onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer()); + onsite_p->overlap_proj_psi(nbands*psi_p->get_npol(), psi_p->get_pointer()); const std::complex* becp = onsite_p->get_h_becp(); // becp(nbands*npol , nkb) // mag = wg * \sum_{nh}becp * becp - int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol; + int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol(); int begin_ih = 0; for(int iat = 0; iat < cell.nat; iat++) { diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp index f235df15e5..47faf38797 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp @@ -165,7 +165,7 @@ void projectors::OnsiteProjector::init(const std::string& orbital_dir RadialProjection::RadialProjector::_build_backward_map(it2iproj, lproj, irow2it_, irow2iproj_, irow2m_); RadialProjection::RadialProjector::_build_forward_map(it2ia, it2iproj, lproj, itiaiprojm2irow_); //rp_._build_sbt_tab(rgrid, projs, lproj, nq, dq); - rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.npol, tab, nhtol); + rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.get_npol(), tab, nhtol); // For being compatible with present cal_force and cal_stress framework // uncomment the following code block if you want to use the Onsite_Proj_tools 
if(this->tab_atomic_ == nullptr) @@ -541,7 +541,7 @@ void projectors::OnsiteProjector::cal_occupations(const psi::Psi::allocate_chi0() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { @@ -207,7 +207,7 @@ void Stochastic_WF::init_com_orbitals() delete[] npwip; } size_t size = this->nchip_max * npwx * nks; - this->chi0_cpu = new psi::Psi>(nks, this->nchip_max, npwx, this->ngk); + this->chi0_cpu = new psi::Psi>(nks, this->nchip_max, npwx, this->ngk, true); this->chi0_cpu->zero_out(); ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex)); for (int ik = 0; ik < nks; ++ik) @@ -252,7 +252,7 @@ void Stochastic_WF::init_com_orbitals() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { @@ -266,7 +266,7 @@ void Stochastic_WF::init_com_orbitals() const int npwx = this->npwx; const int nks = this->nks; size_t size = this->nchip_max * npwx * nks; - this->chi0_cpu = new psi::Psi>(nks, npwx, npwx, this->ngk); + this->chi0_cpu = new psi::Psi>(nks, npwx, npwx, this->ngk, true); this->chi0_cpu->zero_out(); ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex)); for (int ik = 0; ik < nks; ++ik) @@ -284,7 +284,7 @@ void Stochastic_WF::init_com_orbitals() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h index a423810544..4afdeb4247 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h 
+++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h @@ -30,10 +30,10 @@ class Stochastic_WF int* nchip = nullptr; ///< The number of stochatic orbitals in current process of each k point. int nchip_max = 0; ///< Max number of stochastic orbitals among all k points. int nks = 0; ///< number of k-points - int* ngk = nullptr; ///< ngk in klist int npwx = 0; ///< max ngk[ik] in all processors int nbands_diag = 0; ///< number of bands obtained from diagonalization int nbands_total = 0; ///< number of bands in total, nbands_total=nchi+nbands_diag; + std::vector ngk; ///< ngk in klist public: // Tn(H)|chi> psi::Psi* chiallorder = nullptr; diff --git a/source/module_hsolver/hsolver_lcao.cpp b/source/module_hsolver/hsolver_lcao.cpp index 2f9a1b4313..44deac1bbd 100644 --- a/source/module_hsolver/hsolver_lcao.cpp +++ b/source/module_hsolver/hsolver_lcao.cpp @@ -219,7 +219,7 @@ void HSolverLCAO::parakSolve(hamilt::Hamilt* pHamilt, k2d.distribute_hsk(pHamilt, ik_kpar, nrow); /// global index of k point int ik_global = ik + k2d.get_pKpoints()->startk_pool[k2d.get_my_pool()]; - auto psi_pool = psi::Psi(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, nullptr); + auto psi_pool = psi::Psi(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, k2d.get_p2D_pool()->nrow, true); ModuleBase::Memory::record("HSolverLCAO::psi_pool", nrow * ncol_bands_pool * sizeof(T)); if (ik_global < psi.get_nk() && ik < k2d.get_pKpoints()->nks_pool[k2d.get_my_pool()]) { diff --git a/source/module_hsolver/test/diago_mock.h b/source/module_hsolver/test/diago_mock.h index e63022f43d..85a7750fc5 100644 --- a/source/module_hsolver/test/diago_mock.h +++ b/source/module_hsolver/test/diago_mock.h @@ -214,7 +214,7 @@ class HPsi { Structure_Factor* sf; int* ngk = nullptr; - psi::Psi psitmp(1, nband, npw, ngk); + psi::Psi psitmp(1, nband, npw, npw, true); for(int i=0;i>(nks, nbands, wfcpw->npwk_max, wfcpw->npwk); + psi = new psi::Psi>(nks, nbands, wfcpw->npwk_max, kv->ngk, true); std::complex* ptr = psi->get_pointer(); for 
(int i = 0; i < nks * nbands * wfcpw->npwk_max; i++) { diff --git a/source/module_io/to_wannier90_lcao_in_pw.cpp b/source/module_io/to_wannier90_lcao_in_pw.cpp index 78f04a8a97..e067671465 100644 --- a/source/module_io/to_wannier90_lcao_in_pw.cpp +++ b/source/module_io/to_wannier90_lcao_in_pw.cpp @@ -52,7 +52,11 @@ void toWannier90_LCAO_IN_PW::calculate( const int nks_psi = (PARAM.inp.calculation == "nscf" && PARAM.inp.mem_saver == 1)? 1 : wfcpw->nks; const int nks_psig = (PARAM.inp.basis_type == "pw")? 1 : nks_psi; const int nbands_actual = this->psi_initer_->nbands_start(); - this->psi = new psi::Psi, base_device::DEVICE_CPU>(nks_psig, nbands_actual, wfcpw->npwk_max*PARAM.globalv.npol, wfcpw->npwk); + this->psi = new psi::Psi, base_device::DEVICE_CPU>(nks_psig, + nbands_actual, + wfcpw->npwk_max*PARAM.globalv.npol, + kv.ngk, + true); read_nnkp(ucell,kv); if (PARAM.inp.nspin == 2) @@ -117,7 +121,11 @@ psi::Psi>* toWannier90_LCAO_IN_PW::get_unk_from_lcao( { // init int npwx = wfcpw->npwk_max; - psi::Psi> *unk_inLcao = new psi::Psi>(num_kpts, num_bands, npwx*PARAM.globalv.npol, kv.ngk.data()); + psi::Psi> *unk_inLcao = new psi::Psi>(num_kpts, + num_bands, + npwx*PARAM.globalv.npol, + kv.ngk, + true); unk_inLcao->zero_out(); // Orbital projection to plane wave diff --git a/source/module_io/write_vxc_lip.hpp b/source/module_io/write_vxc_lip.hpp index d57c8f2ccd..1a50c8d00e 100644 --- a/source/module_io/write_vxc_lip.hpp +++ b/source/module_io/write_vxc_lip.hpp @@ -161,7 +161,7 @@ namespace ModuleIO // psi::Psi hpsi_single_band(&hpsi_localxc(ik, ib, 0), 1, 1, hpsi_localxc.get_current_nbas()); // vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ngk(ik)); // } - vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik)); + vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.get_npol(), &psi_pw(ik, 0, 
0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik)); delete vxcs_op_pw; std::vector vxc_local_k_mo = psi_Hpsi(&psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_nbasis(), psi_pw.get_nbands()); Parallel_Reduce::reduce_pool(vxc_local_k_mo.data(), nbands * nbands); diff --git a/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp b/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp index 5601ad451d..8bcb88b525 100644 --- a/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp +++ b/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp @@ -64,18 +64,18 @@ TEST_F(AO2MOTest, DoubleSerial) { for (auto s : this->sizes) { - psi::Psi vo_for(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi vo_blas(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi oo_for(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi oo_blas(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi vv_for(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); - psi::Psi vv_blas(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vo_for(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi vo_blas(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi oo_for(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi oo_blas(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi vv_for(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); + psi::Psi vv_blas(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { std::vector temp(s.nks, s.naos); - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); std::vector V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data(), size_v); } @@ -96,18 
+96,18 @@ TEST_F(AO2MOTest, ComplexSerial) { for (auto s : this->sizes) { - psi::Psi> vo_for(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> vo_blas(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> oo_for(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> oo_blas(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> vv_for(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); - psi::Psi> vv_blas(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vo_for(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> vo_blas(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> oo_for(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> oo_blas(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> vv_for(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); + psi::Psi> vv_blas(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { std::vector temp(s.nks, s.naos); - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); std::vector V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data>(), size_v); } @@ -137,7 +137,7 @@ TEST_F(AO2MOTest, DoubleParallel) LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); std::vector ngk_temp(s.nks, pc.get_row_size()); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp.data(), true); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp, true); Parallel_2D pvo, poo, pvv; LR_Util::setup_2d_division(pvo, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(poo, s.nb, s.nocc, s.nocc, pV.blacs_ctxt); @@ -148,12 +148,12 @@ TEST_F(AO2MOTest, DoubleParallel) 
EXPECT_GE(s.nvirt, pvo.dim0); EXPECT_GE(s.nocc, pvo.dim1); EXPECT_GE(s.naos, pc.dim0); - psi::Psi vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), nullptr, false); - psi::Psi vo_gather(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi oo_pblas_loc(s.nks, nstate, poo.get_local_size(), nullptr, false); - psi::Psi oo_gather(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), nullptr, false); - psi::Psi vv_gather(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), pvo.get_local_size(), false); + psi::Psi vo_gather(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi oo_pblas_loc(s.nks, nstate, poo.get_local_size(), poo.get_local_size(), false); + psi::Psi oo_gather(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), pvv.get_local_size(), false); + psi::Psi vv_gather(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); for (int istate = 0;istate < nstate;++istate) { for (int isk = 0;isk < s.nks;++isk) @@ -174,7 +174,7 @@ TEST_F(AO2MOTest, DoubleParallel) // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); std::vector ngk_temp_1(s.nks, s.naos); - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1.data(), true); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1, true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data(), V_full.at(isk).data(), false, s.naos, s.naos); @@ -182,13 +182,13 @@ TEST_F(AO2MOTest, DoubleParallel) } if (my_rank == 0) { - psi::Psi vo_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false); + psi::Psi vo_full_istate(s.nks, 1, s.nocc * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vo_full_istate(0, 0, 0), false); check_eq(&vo_full_istate(0, 0, 0), &vo_gather(istate, 0, 
0), s.nks * s.nocc * s.nvirt); - psi::Psi oo_full_istate(s.nks, 1, s.nocc * s.nocc, nullptr, false); + psi::Psi oo_full_istate(s.nks, 1, s.nocc * s.nocc, s.nocc * s.nocc, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &oo_full_istate(0, 0, 0), false, LR::MO_TYPE::OO); check_eq(&oo_full_istate(0, 0, 0), &oo_gather(istate, 0, 0), s.nks * s.nocc * s.nocc); - psi::Psi vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vv_full_istate(0, 0, 0), false, LR::MO_TYPE::VV); check_eq(&vv_full_istate(0, 0, 0), &vv_gather(istate, 0, 0), s.nks * s.nvirt * s.nvirt); } @@ -208,18 +208,18 @@ TEST_F(AO2MOTest, ComplexParallel) LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); std::vector ngk_temp_1(s.nks, pc.get_row_size()); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1.data(), true); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1, true); Parallel_2D pvo, poo, pvv; LR_Util::setup_2d_division(pvo, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(poo, s.nb, s.nocc, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(pvv, s.nb, s.nvirt, s.nvirt, pV.blacs_ctxt); - psi::Psi> vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), nullptr, false); - psi::Psi> vo_gather(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> oo_pblas_loc(s.nks, nstate, poo.get_local_size(), nullptr, false); - psi::Psi> oo_gather(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), nullptr, false); - psi::Psi> vv_gather(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), pvo.get_local_size(), false); + psi::Psi> vo_gather(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> oo_pblas_loc(s.nks, nstate, poo.get_local_size(), 
poo.get_local_size(), false); + psi::Psi> oo_gather(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), pvv.get_local_size(), false); + psi::Psi> vv_gather(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); for (int istate = 0;istate < nstate;++istate) { for (int isk = 0;isk < s.nks;++isk) @@ -241,7 +241,7 @@ TEST_F(AO2MOTest, ComplexParallel) // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); std::vector ngk_temp_2(s.nks, s.naos); - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2, true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data>(), V_full.at(isk).data>(), false, s.naos, s.naos); @@ -249,13 +249,13 @@ TEST_F(AO2MOTest, ComplexParallel) } if (my_rank == 0) { - psi::Psi> vo_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false); + psi::Psi> vo_full_istate(s.nks, 1, s.nocc * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vo_full_istate(0, 0, 0), false); check_eq(&vo_full_istate(0, 0, 0), &vo_gather(istate, 0, 0), s.nks * s.nocc * s.nvirt); - psi::Psi> oo_full_istate(s.nks, 1, s.nocc * s.nocc, nullptr, false); + psi::Psi> oo_full_istate(s.nks, 1, s.nocc * s.nocc, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nocc, &oo_full_istate(0, 0, 0), false, LR::MO_TYPE::OO); check_eq(&oo_full_istate(0, 0, 0), &oo_gather(istate, 0, 0), s.nks * s.nocc * s.nocc); - psi::Psi> vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vv_full_istate(0, 0, 0), false, LR::MO_TYPE::VV); check_eq(&vv_full_istate(0, 0, 0), &vv_gather(istate, 0, 0), s.nks * s.nvirt * s.nvirt); } diff --git 
a/source/module_lr/dm_trans/test/dm_trans_test.cpp b/source/module_lr/dm_trans/test/dm_trans_test.cpp index 8a40f08c61..acef1e8a40 100644 --- a/source/module_lr/dm_trans/test/dm_trans_test.cpp +++ b/source/module_lr/dm_trans/test/dm_trans_test.cpp @@ -61,18 +61,18 @@ TEST_F(DMTransTest, DoubleSerial) { for (auto s : this->sizes) { - psi::Psi X_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); + psi::Psi X_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); set_rand(X_vo.get_pointer(), nstate * s.nks * s.nocc * s.nvirt); - psi::Psi X_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); + psi::Psi X_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); set_rand(X_oo.get_pointer(), nstate * s.nks * s.nocc * s.nocc); - psi::Psi X_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi X_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); set_rand(X_vv.get_pointer(), nstate * s.nks * s.nvirt * s.nvirt); for (int istate = 0;istate < nstate;++istate) { const int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; std::vector temp(s.nks, s.naos); - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); set_rand(c.get_pointer(), size_c); auto test = [&](psi::Psi& X, const LR::MO_TYPE type) { @@ -92,18 +92,18 @@ TEST_F(DMTransTest, ComplexSerial) { for (auto s : this->sizes) { - psi::Psi> X_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); + psi::Psi> X_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); set_rand(X_vo.get_pointer(), nstate * s.nks * s.nocc * s.nvirt); - psi::Psi> X_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); + psi::Psi> X_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); set_rand(X_oo.get_pointer(), nstate * s.nks * s.nocc * s.nocc); - psi::Psi> X_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> X_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); set_rand(X_vv.get_pointer(), nstate * s.nks * 
s.nvirt * s.nvirt); for (int istate = 0;istate < nstate;++istate) { const int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; std::vector temp(s.nks, s.naos); - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); set_rand(c.get_pointer(), size_c); auto test = [&](psi::Psi>& X, const LR::MO_TYPE type) { @@ -132,18 +132,18 @@ TEST_F(DMTransTest, DoubleParallel) LR_Util::setup_2d_division(px_oo, s.nb, s.nocc, s.nocc, px_vo.blacs_ctxt); LR_Util::setup_2d_division(px_vv, s.nb, s.nvirt, s.nvirt, px_vo.blacs_ctxt); - psi::Psi X_vo(s.nks, nstate, px_vo.get_local_size(), nullptr, false); + psi::Psi X_vo(s.nks, nstate, px_vo.get_local_size(), px_vo.get_local_size(), false); set_rand(X_vo.get_pointer(), nstate * s.nks * px_vo.get_local_size()); - psi::Psi X_oo(s.nks, nstate, px_oo.get_local_size(), nullptr, false); + psi::Psi X_oo(s.nks, nstate, px_oo.get_local_size(), px_oo.get_local_size(), false); set_rand(X_oo.get_pointer(), nstate * s.nks * px_oo.get_local_size()); - psi::Psi X_vv(s.nks, nstate, px_vv.get_local_size(), nullptr, false); + psi::Psi X_vv(s.nks, nstate, px_vv.get_local_size(), px_vv.get_local_size(), false); set_rand(X_vv.get_pointer(), nstate * s.nks * px_vv.get_local_size()); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px_vo.blacs_ctxt); std::vector temp_2(s.nks, pc.get_row_size()); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2.data(), true); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2, true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px_vo.blacs_ctxt); @@ -153,9 +153,9 @@ TEST_F(DMTransTest, DoubleParallel) EXPECT_GE(s.nocc, px_vo.dim1); EXPECT_GE(s.naos, pc.dim0); - psi::Psi X_full_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); // allocate X_full - psi::Psi X_full_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); // allocate X_full - psi::Psi X_full_vv(s.nks, nstate, 
s.nvirt * s.nvirt, nullptr, false); // allocate X_full + psi::Psi X_full_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); // allocate X_full + psi::Psi X_full_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); // allocate X_full + psi::Psi X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); // allocate X_full auto gather = [&](const psi::Psi& X, psi::Psi& X_full, const Parallel_2D& px, const int dim1, const int dim2) { @@ -182,7 +182,7 @@ TEST_F(DMTransTest, DoubleParallel) // gather C std::vector temp(s.nks, s.naos); - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, temp, true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); @@ -223,24 +223,24 @@ TEST_F(DMTransTest, ComplexParallel) LR_Util::setup_2d_division(px_oo, s.nb, s.nocc, s.nocc, px_vo.blacs_ctxt); LR_Util::setup_2d_division(px_vv, s.nb, s.nvirt, s.nvirt, px_vo.blacs_ctxt); - psi::Psi> X_vo(s.nks, nstate, px_vo.get_local_size(), nullptr, false); + psi::Psi> X_vo(s.nks, nstate, px_vo.get_local_size(), px_vo.get_local_size(), false); set_rand(X_vo.get_pointer(), nstate * s.nks * px_vo.get_local_size()); - psi::Psi> X_oo(s.nks, nstate, px_oo.get_local_size(), nullptr, false); + psi::Psi> X_oo(s.nks, nstate, px_oo.get_local_size(), px_oo.get_local_size(), false); set_rand(X_oo.get_pointer(), nstate * s.nks * px_oo.get_local_size()); - psi::Psi> X_vv(s.nks, nstate, px_vv.get_local_size(), nullptr, false); + psi::Psi> X_vv(s.nks, nstate, px_vv.get_local_size(), px_vv.get_local_size(), false); set_rand(X_vv.get_pointer(), nstate * s.nks * px_vv.get_local_size()); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px_vo.blacs_ctxt); std::vector temp(s.nks, pc.get_row_size()); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp.data(), true); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp, true); Parallel_2D pmat; 
LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px_vo.blacs_ctxt); - psi::Psi> X_full_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); // allocate X_full - psi::Psi> X_full_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); // allocate X_full - psi::Psi> X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); // allocate X_full + psi::Psi> X_full_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); // allocate X_full + psi::Psi> X_full_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); // allocate X_full + psi::Psi> X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); // allocate X_full auto gather = [&](const psi::Psi>& X, psi::Psi>& X_full, const Parallel_2D& px, const int dim1, const int dim2) { @@ -266,7 +266,7 @@ TEST_F(DMTransTest, ComplexParallel) set_rand(c.get_pointer(), s.nks * pc.get_local_size()); // set c // compare to global matrix std::vector ngk_temp_2(s.nks, s.naos); - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2, true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index a69635dffb..78eb202766 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -32,7 +32,6 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_ template Psi::Psi() { - this->npol = PARAM.globalv.npol; } template @@ -44,41 +43,7 @@ Psi::~Psi() } } -// Constructor 1-1: -template -Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) -{ - assert(nk_in > 0); - assert(nbd_in >= 0); // 187_PW_SDFT_ALL_GPU && 187_PW_MD_SDFT_ALL_GPU - assert(nbs_in > 0); - - this->k_first = k_first_in; - this->npol = PARAM.globalv.npol; - this->allocate_inside = true; - - this->ngk = ngk_in; // modify later - // This function will delete the psi array first(if psi exist), then malloc a new 
memory for it. - resize_memory_op()(this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); - - this->nk = nk_in; - this->nbands = nbd_in; - this->nbasis = nbs_in; - - this->current_b = 0; - this->current_k = 0; - this->current_nbasis = nbs_in; - this->psi_current = this->psi; - this->psi_bias = 0; - - // Currently only GPU's implementation is supported for device recording! - base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); - base_device::information::record_device_memory(this->ctx, - GlobalV::ofs_device, - "Psi->resize()", - sizeof(T) * nk_in * nbd_in * nbs_in); -} - -// Constructor 1-2: +// Constructor 1: template Psi::Psi(const int nk_in, const int nbd_in, @@ -87,11 +52,10 @@ Psi::Psi(const int nk_in, const bool k_first_in) { assert(nk_in > 0); - assert(nbd_in > 0); + assert(nbd_in >= 0); assert(nbs_in > 0); this->k_first = k_first_in; - this->npol = PARAM.globalv.npol; this->allocate_inside = true; this->ngk = ngk_in.data(); // modify later @@ -129,7 +93,6 @@ Psi::Psi(T* psi_pointer, // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func this->k_first = k_first_in; - this->npol = PARAM.globalv.npol; this->allocate_inside = false; this->ngk = nullptr; @@ -158,10 +121,9 @@ Psi::Psi(const int nk_in, const bool k_first_in) { // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. 
- assert(nk_in == 1); + // assert(nk_in == 1); this->k_first = k_first_in; - this->npol = PARAM.globalv.npol; this->allocate_inside = true; this->ngk = nullptr; @@ -190,8 +152,8 @@ Psi::Psi(const int nk_in, template Psi::Psi(const Psi& psi_in) { + this->ngk = psi_in.ngk; - this->npol = psi_in.npol; this->nk = psi_in.get_nk(); this->nbands = psi_in.get_nbands(); this->nbasis = psi_in.get_nbasis(); @@ -215,8 +177,8 @@ template template Psi::Psi(const Psi& psi_in) { + this->ngk = psi_in.get_ngk_pointer(); - this->npol = psi_in.npol; this->nk = psi_in.get_nk(); this->nbands = psi_in.get_nbands(); this->nbasis = psi_in.get_nbasis(); @@ -323,7 +285,7 @@ const int& Psi::get_psi_bias() const template const int& Psi::get_current_ngk() const { - if (this->npol == 1) + if (this->get_npol() == 1) { return this->current_nbasis; } @@ -333,6 +295,19 @@ const int& Psi::get_current_ngk() const } } +template +const int Psi::get_npol() const +{ + if (PARAM.inp.nspin == 4) + { + return 2; + } + else + { + return 1; + } +} + template const int& Psi::get_nk() const { @@ -511,13 +486,13 @@ std::tuple Psi::to_range(const Range& range) const else if (i1 < 0) // [r1, r2] is the range of index1 with length m { const T* p = &this->psi[r1 * (k_first ? this->nbands : this->nk) * this->nbasis]; - int m = (r2 - r1 + 1) * this->npol; + int m = (r2 - r1 + 1) * this->get_npol(); return std::tuple(p, m); } else // [r1, r2] is the range of index2 with length m { const T* p = &this->psi[(i1 * (k_first ? 
this->nbands : this->nk) + r1) * this->nbasis]; - int m = (r2 - r1 + 1) * this->npol; + int m = (r2 - r1 + 1) * this->get_npol(); return std::tuple(p, m); } } diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index d8a994377a..75e13433ea 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -39,10 +39,7 @@ class Psi // Constructor 0: basic Psi(); - // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later - Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in = true); - - // Constructor 1-2: + // Constructor 1: Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); // Constructor 2-1: initialize a new psi from the given psi_in @@ -137,7 +134,7 @@ class Psi // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; - int npol = 1; + const int get_npol() const; private: T* psi = nullptr; // avoid using C++ STL diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index 102e2d4b1a..8ef89dcfdc 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -106,7 +106,7 @@ void PSIInit::initialize_psi(Psi>* psi, if (not_equal) { - psi_cpu = new Psi(1, nbands_start, nbasis, nullptr); + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); psi_device = PARAM.inp.device == "gpu" ? 
new psi::Psi(psi_cpu[0]) : reinterpret_cast*>(psi_cpu); } @@ -119,7 +119,7 @@ void PSIInit::initialize_psi(Psi>* psi, } else { - psi_cpu = new Psi(1, nbands_start, nbasis, nullptr); + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); psi_device = kspw_psi; } } @@ -203,7 +203,7 @@ void PSIInit::initialize_lcao_in_pw(Psi* psi_local, std::ofstream& } } -void allocate_psi(Psi>*& psi, const int& nks, const int* ngk, const int& nbands, const int& npwx) +void allocate_psi(Psi>*& psi, const int& nks, const std::vector& ngk, const int& nbands, const int& npwx) { assert(npwx > 0); assert(nks > 0); @@ -215,7 +215,7 @@ void allocate_psi(Psi>*& psi, const int& nks, const int* ng { nks2 = 1; } - psi = new psi::Psi>(nks2, nbands, npwx * PARAM.globalv.npol, ngk); + psi = new psi::Psi>(nks2, nbands, npwx * PARAM.globalv.npol, ngk, true); const size_t memory_cost = sizeof(std::complex) * nks2 * nbands * (PARAM.globalv.npol * npwx); std::cout << " MEMORY FOR PSI (MB) : " << static_cast(memory_cost) / 1024.0 / 1024.0 << std::endl; ModuleBase::Memory::record("Psi_PW", memory_cost); diff --git a/source/module_psi/psi_init.h b/source/module_psi/psi_init.h index e112a71a6e..bf93e534d0 100644 --- a/source/module_psi/psi_init.h +++ b/source/module_psi/psi_init.h @@ -86,7 +86,7 @@ class PSIInit }; ///@brief allocate the wavefunction -void allocate_psi(Psi>*& psi, const int& nks, const int* ngk, const int& nbands, const int& npwx); +void allocate_psi(Psi>*& psi, const int& nks, const std::vector& ngk, const int& nbands, const int& npwx); } // namespace psi #endif \ No newline at end of file diff --git a/source/module_psi/psi_initializer_atomic_random.cpp b/source/module_psi/psi_initializer_atomic_random.cpp index f7b735f5ed..7e0652c25c 100644 --- a/source/module_psi/psi_initializer_atomic_random.cpp +++ b/source/module_psi/psi_initializer_atomic_random.cpp @@ -21,7 +21,7 @@ void psi_initializer_atomic_random::init_psig(T* psig, const int& ik) psi_initializer_atomic::init_psig(psig, 
ik); const int npol = PARAM.globalv.npol; const int nbasis = this->pw_wfc_->npwk_max * npol; - psi::Psi psi_random(1, this->nbands_start_, nbasis, nullptr); + psi::Psi psi_random(1, this->nbands_start_, nbasis, nbasis, true); psi_random.fix_k(0); this->random_t(psi_random.get_pointer(), 0, this->nbands_start_, ik, 0); for (int iband = 0; iband < this->nbands_start_; iband++) diff --git a/source/module_psi/psi_initializer_nao_random.cpp b/source/module_psi/psi_initializer_nao_random.cpp index 4f8b8d940f..ab23c4a163 100644 --- a/source/module_psi/psi_initializer_nao_random.cpp +++ b/source/module_psi/psi_initializer_nao_random.cpp @@ -21,7 +21,7 @@ void psi_initializer_nao_random::init_psig(T* psig, const int& ik) psi_initializer_nao::init_psig(psig, ik); const int npol = PARAM.globalv.npol; const int nbasis = this->pw_wfc_->npwk_max * npol; - psi::Psi psi_random(1, this->nbands_start_, nbasis, nullptr); + psi::Psi psi_random(1, this->nbands_start_, nbasis, nbasis, true); psi_random.fix_k(0); this->random_t(psi_random.get_pointer(), 0, this->nbands_start_, ik, 0); for (int iband = 0; iband < this->nbands_start_; iband++) diff --git a/source/module_psi/test/psi_initializer_unit_test.cpp b/source/module_psi/test/psi_initializer_unit_test.cpp index fd9dcd497c..b5b4180b2d 100644 --- a/source/module_psi/test/psi_initializer_unit_test.cpp +++ b/source/module_psi/test/psi_initializer_unit_test.cpp @@ -321,7 +321,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(-0.66187696761064307, psi->operator()(0,0,0).real(), 1e-4); delete psi; @@ -340,7 +340,7 @@ 
TEST_F(PsiIntializerUnitTest, CalPsigAtomic) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -363,7 +363,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicSoc) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); PARAM.input.nspin = 1; @@ -390,7 +390,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicSocHasSo) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); PARAM.input.nspin = 1; @@ -413,7 +413,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new 
psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -432,7 +432,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNao) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -451,7 +451,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -475,7 +475,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSoc) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -499,7 +499,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSocHasSo) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const 
int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -523,7 +523,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSocHasSoDOMAG) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index 0b42df63c7..598cbe21bd 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -8,12 +8,12 @@ class TestPsi : public ::testing::Test const int ink = 2; const int inbands = 4; const int inbasis = 10; - int ngk[4] = {10, 10, 10, 10}; + std::vector ngk = {10, 10, 10, 10}; - const psi::Psi>* psi_object31 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); - const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); - const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); - const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); + const psi::Psi>* psi_object31 = new psi::Psi>(ink, inbands, inbasis, ngk, true); + const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, ngk, true); + const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, ngk, true); + const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, ngk, true); }; TEST_F(TestPsi, get_val) @@ -98,7 +98,7 @@ TEST_F(TestPsi, 
get_pointer_op_zero_complex_double) EXPECT_EQ(psi_object31->get_psi_bias(), 0); std::vector temp(ink, inbasis); - psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, temp.data(), true); + psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, temp, true); psi_temp->fix_k(0); EXPECT_EQ(psi_object31->get_current_nbas(), inbasis); delete psi_temp; @@ -241,10 +241,10 @@ TEST_F(TestPsi, range) TEST_F(TestPsi, band_first) { - const psi::Psi>* psi_band_c64 = new psi::Psi>(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi* psi_band_64 = new psi::Psi(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi>* psi_band_c32 = new psi::Psi>(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi* psi_band_32 = new psi::Psi(ink, inbands, inbasis, &ngk[0], false); + const psi::Psi>* psi_band_c64 = new psi::Psi>(ink, inbands, inbasis, ngk, false); + const psi::Psi* psi_band_64 = new psi::Psi(ink, inbands, inbasis, ngk, false); + const psi::Psi>* psi_band_c32 = new psi::Psi>(ink, inbands, inbasis, ngk, false); + const psi::Psi* psi_band_32 = new psi::Psi(ink, inbands, inbasis, ngk, false); // set values: cover 4 different cases for (int ib = 0;ib < inbands;++ib) diff --git a/source/module_ri/exx_lip.hpp b/source/module_ri/exx_lip.hpp index 6be31a26b4..5e26446df4 100644 --- a/source/module_ri/exx_lip.hpp +++ b/source/module_ri/exx_lip.hpp @@ -112,7 +112,7 @@ Exx_Lip::Exx_Lip(const Exx_Info::Exx_Info_Lip& info_in, #endif this->k_pack->wf_wg.create(this->k_pack->kv_ptr->get_nks(),PARAM.inp.nbands); - this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal, kv_ptr_in->ngk.data(), true); + this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal, kv_ptr_in->ngk, true); // this->k_pack->hvec_array = new ModuleBase::ComplexMatrix[this->k_pack->kv_ptr->get_nks()]; // for( int ik=0; ikk_pack->kv_ptr->get_nks(); ++ik) // {