Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Use memory_op to set diag_const_nums #5246

Merged
merged 12 commits into from
Oct 18, 2024
45 changes: 39 additions & 6 deletions source/module_hsolver/diag_const_nums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,60 @@ template class const_nums<std::complex<float>>;

// Specialize templates to support double types
template <>
const_nums<double>::const_nums() : zero(0.0), one(1.0), neg_one(-1.0)
const_nums<double>::const_nums()
{
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->zero, 1);
this->zero[0] = 0.0;
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->one, 1);
this->one[0] = 1.0;
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->neg_one, 1);
this->neg_one[0] = -1.0;
}

// Specialize templates to support double types
template <>
const_nums<float>::const_nums() : zero(0.0), one(1.0), neg_one(-1.0)
const_nums<float>::const_nums()
{
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->zero, 1);
this->zero[0] = 0.0;
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->one, 1);
this->one[0] = 1.0;
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->neg_one, 1);
this->neg_one[0] = -1.0;
}

// Specialized templates to support std:: complex<double>types
template <>
const_nums<std::complex<double>>::const_nums()
: zero(std::complex<double>(0.0, 0.0)), one(std::complex<double>(1.0, 0.0)),
neg_one(std::complex<double>(-1.0, 0.0))
{
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->zero, 1);
this->zero[0] = std::complex<double>(0.0, 0.0);
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->one, 1);
this->one[0] = std::complex<double>(1.0, 0.0);
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->neg_one, 1);
this->neg_one[0] = std::complex<double>(-1.0, 0.0);
}

// Specialized templates to support std:: complex<float>types
template <>
const_nums<std::complex<float>>::const_nums()
: zero(std::complex<float>(0.0, 0.0)), one(std::complex<float>(1.0, 0.0)), neg_one(std::complex<float>(-1.0, 0.0))
{
}
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->zero, 1);
this->zero[0] = std::complex<float>(0.0, 0.0);
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->one, 1);
this->one[0] = std::complex<float>(1.0, 0.0);
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->neg_one, 1);
this->neg_one[0] = std::complex<float>(-1.0, 0.0);
}
8 changes: 5 additions & 3 deletions source/module_hsolver/diag_const_nums.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#ifndef DIAG_CONST_NUMS
#define DIAG_CONST_NUMS
#include "module_base/module_device/memory_op.h"

template <typename T>
struct const_nums
{
const_nums();
T zero;
T one;
T neg_one;
base_device::DEVICE_CPU* cpu_ctx = {};
T* zero = nullptr;
T* one = nullptr;
T* neg_one = nullptr;
};

#endif
18 changes: 9 additions & 9 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
{
this->device = base_device::get_device_type<Device>(this->ctx);

this->one = &this->cs.one;
this->zero = &this->cs.zero;
this->neg_one = &this->cs.neg_one;
this->one = this->cs.one;
this->zero = this->cs.zero;
this->neg_one = this->cs.neg_one;

assert(david_ndim_in > 1);
assert(david_ndim_in * nband_in < nbasis_in * this->diag_comm.nproc);
Expand Down Expand Up @@ -534,8 +534,8 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
}
else
{
std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, cs.zero));
std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, cs.zero));
std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, cs.zero[0]));
std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, cs.zero[0]));

for (size_t i = 0; i < nbase; i++)
{
Expand Down Expand Up @@ -564,10 +564,10 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,

for (size_t j = nbase; j < this->nbase_x; j++)
{
hcc[i * this->nbase_x + j] = cs.zero;
hcc[j * this->nbase_x + i] = cs.zero;
scc[i * this->nbase_x + j] = cs.zero;
scc[j * this->nbase_x + i] = cs.zero;
hcc[i * this->nbase_x + j] = cs.zero[0];
hcc[j * this->nbase_x + i] = cs.zero[0];
scc[i * this->nbase_x + j] = cs.zero[0];
scc[j * this->nbase_x + i] = cs.zero[0];
}
}
}
Expand Down
Loading