diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index d16289d4e7..23639d45a3 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -106,15 +106,18 @@ ESolver_KS_PW::~ESolver_KS_PW() container::kernels::destroyGpuBlasHandle(); container::kernels::destroyGpuSolverHandle(); #endif - delete reinterpret_cast*>(this->kspw_psi); } #ifdef __DSP std::cout << " ** Closing DSP Hardware..." << std::endl; dspDestoryHandle(GlobalV::MY_RANK); #endif + if(PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") + { + delete this->kspw_psi; + } if (PARAM.inp.precision == "single") { - delete reinterpret_cast, Device>*>(this->__kspw_psi); + delete this->__kspw_psi; } delete this->psi; diff --git a/source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp index c33d68470b..b41c8f476e 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp @@ -532,9 +532,12 @@ void pseudopot_cell_vnl::getvnl(Device* ctx, delmem_var_op()(ctx, ylm); delmem_var_op()(ctx, vkb1); delmem_complex_op()(ctx, sk); - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { delmem_var_op()(ctx, gk); + } + if (PARAM.inp.device == "gpu") + { delmem_int_op()(ctx, atom_nh); delmem_int_op()(ctx, atom_nb); delmem_int_op()(ctx, atom_na); diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index 97b5875f96..c95e785885 100644 --- a/source/module_io/read_input_item_system.cpp +++ b/source/module_io/read_input_item_system.cpp @@ -775,12 +775,28 @@ void ReadInput::item_system() para.input.device=base_device::information::get_device_flag( para.inp.device, para.inp.basis_type); }; + item.check_value = [](const Input_Item& item, const Parameter& para) { + std::vector avail_list = {"cpu", "gpu"}; + if (std::find(avail_list.begin(), avail_list.end(), para.input.device) == avail_list.end()) + { + const std::string warningstr = nofound_str(avail_list, "device"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); + } + }; this->add_item(item); } { Input_Item item("precision"); item.annotation = "the computing precision for ABACUS"; read_sync_string(input.precision); + item.check_value = [](const Input_Item& item, const Parameter& para) { + std::vector avail_list = {"single", "double"}; + if (std::find(avail_list.begin(), avail_list.end(), para.input.precision) == avail_list.end()) + { + const std::string warningstr = nofound_str(avail_list, "precision"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); + } + }; this->add_item(item); } } diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index f5a0fa6595..24708c9665 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -53,7 +53,7 @@ void PSIInit::prepare_init(const int& random_seed) this->psi_initer = std::unique_ptr>(new psi_initializer_random()); } else if (this->init_wfc == "atomic" - || (this->init_wfc == "atomic+random" && this->ucell.natomwfc != PARAM.inp.nbands)) + || (this->init_wfc == "atomic+random" && this->ucell.natomwfc < PARAM.inp.nbands)) { this->psi_initer = std::unique_ptr>(new psi_initializer_atomic()); } @@ -99,17 +99,30 @@ void PSIInit::initialize_psi(Psi>* psi, const int nbands_start = this->psi_initer->nbands_start(); const int nbands = psi->get_nbands(); const int nbasis = psi->get_nbasis(); - const bool another_psi_space = (nbands_start != nbands || PARAM.inp.precision == "single"); + const bool not_equal = (nbands_start != nbands); Psi* psi_cpu = reinterpret_cast*>(psi); Psi* psi_device = kspw_psi; - if (another_psi_space) + if (not_equal) { psi_cpu = new Psi(1, nbands_start, nbasis, nullptr); psi_device = PARAM.inp.device == "gpu" ? new psi::Psi(psi_cpu[0]) : reinterpret_cast*>(psi_cpu); } + else if (PARAM.inp.precision == "single") + { + if (PARAM.inp.device == "cpu") + { + psi_cpu = reinterpret_cast*>(kspw_psi); + psi_device = kspw_psi; + } + else + { + psi_cpu = new Psi(1, nbands_start, nbasis, nullptr); + psi_device = kspw_psi; + } + } // loop over kpoints, make it possible to only allocate memory for psig at the only one kpt // like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints @@ -126,16 +139,16 @@ void PSIInit::initialize_psi(Psi>* psi, this->psi_initer->init_psig(psi_cpu->get_pointer(), ik); if (psi_device->get_pointer() != psi_cpu->get_pointer()) { - castmem_h2d_op()(ctx, cpu_ctx, psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis); + syncmem_h2d_op()(ctx, cpu_ctx, psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis); } std::vector::type> etatom(nbands_start, 0.0); if (this->ks_solver == "cg") { - if (another_psi_space) + if (not_equal) { - // for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() should be different + // for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different hsolver::DiagoIterAssist::diagH_subspace_init(p_hamilt, psi_device->get_pointer(), nbands_start, @@ -145,7 +158,7 @@ void PSIInit::initialize_psi(Psi>* psi, } else { - // for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() can be the same + // for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same hsolver::DiagoIterAssist::diagH_subspace(p_hamilt, *psi_device, *kspw_psi, @@ -155,14 +168,14 @@ void PSIInit::initialize_psi(Psi>* psi, } else // dav, bpcg { - if (another_psi_space) + if (psi_device->get_pointer() != kspw_psi->get_pointer()) { syncmem_complex_op()(ctx, ctx, kspw_psi->get_pointer(), psi_device->get_pointer(), nbands * nbasis); } } } // end k-point loop - if (another_psi_space) + if (not_equal) { delete psi_cpu; if(PARAM.inp.device == "gpu") @@ -170,6 +183,10 @@ void PSIInit::initialize_psi(Psi>* psi, delete psi_device; } } + else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu") + { + delete psi_cpu; + } ModuleBase::timer::tick("PSIInit", "initialize_psi"); } diff --git a/source/module_psi/psi_init.h b/source/module_psi/psi_init.h index 453d27b072..712952e634 100644 --- a/source/module_psi/psi_init.h +++ b/source/module_psi/psi_init.h @@ -82,8 +82,7 @@ class PSIInit //-------------------------OP-------------------------------------------- using syncmem_complex_op = base_device::memory::synchronize_memory_op; - using castmem_h2d_op - = base_device::memory::cast_memory_op; + using syncmem_h2d_op = base_device::memory::synchronize_memory_op; }; ///@brief allocate the wavefunction