forked from abacusmodeling/abacus-develop
-
Notifications
You must be signed in to change notification settings - Fork 139
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add PyDiagoDavid class * add pybind11 wrapper for `hsolver::DiagoDavid` * add `PyDiagoDavid` class as the pybind wrapper for `hsolver::DiagoDavid` class * add python function signature and python tests * fix some bugs in pytest:test_hsolver.py * Add example for diagonalizing sparse matrices using Davidson method and update README.md with relevant documentation * update example for hsolver and README.md * Revert unintended changes to file
- Loading branch information
Showing
8 changed files
with
404 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
#ifndef PYTHON_PYABACUS_SRC_PY_DIAGO_DAVID_HPP | ||
#define PYTHON_PYABACUS_SRC_PY_DIAGO_DAVID_HPP | ||
|
||
#include <complex> | ||
#include <functional> | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/complex.h> | ||
#include <pybind11/functional.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "module_hsolver/diago_david.h" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace py_hsolver | ||
{ | ||
|
||
class PyDiagoDavid | ||
{ | ||
public: | ||
PyDiagoDavid(int nbasis, int nband) : nbasis(nbasis), nband(nband) | ||
{ | ||
psi = new std::complex<double>[nbasis * nband]; | ||
eigenvalue = new double[nband]; | ||
} | ||
|
||
PyDiagoDavid(const PyDiagoDavid&) = delete; | ||
PyDiagoDavid& operator=(const PyDiagoDavid&) = delete; | ||
PyDiagoDavid(PyDiagoDavid&& other) : nbasis(other.nbasis), nband(other.nband) | ||
{ | ||
psi = other.psi; | ||
eigenvalue = other.eigenvalue; | ||
|
||
other.psi = nullptr; | ||
other.eigenvalue = nullptr; | ||
} | ||
|
||
~PyDiagoDavid() | ||
{ | ||
if (psi != nullptr) | ||
{ | ||
delete[] psi; | ||
psi = nullptr; | ||
} | ||
if (eigenvalue != nullptr) | ||
{ | ||
delete[] eigenvalue; | ||
eigenvalue = nullptr; | ||
} | ||
} | ||
|
||
void set_psi(py::array_t<std::complex<double>> psi_in) | ||
{ | ||
assert(psi_in.size() == nbasis * nband); | ||
|
||
for (size_t i = 0; i < nbasis * nband; ++i) | ||
{ | ||
psi[i] = psi_in.at(i); | ||
} | ||
} | ||
|
||
py::array_t<std::complex<double>> get_psi() | ||
{ | ||
py::array_t<std::complex<double>> psi_out(nband * nbasis); | ||
py::buffer_info psi_out_buf = psi_out.request(); | ||
|
||
std::complex<double>* psi_out_ptr = static_cast<std::complex<double>*>(psi_out_buf.ptr); | ||
|
||
for (size_t i = 0; i < nband * nbasis; ++i) | ||
{ | ||
psi_out_ptr[i] = psi[i]; | ||
} | ||
|
||
return psi_out; | ||
} | ||
|
||
void init_eigenvalue() | ||
{ | ||
for (size_t i = 0; i < nband; ++i) | ||
{ | ||
eigenvalue[i] = 0.0; | ||
} | ||
} | ||
|
||
py::array_t<double> get_eigenvalue() | ||
{ | ||
py::array_t<double> eigenvalue_out(nband); | ||
py::buffer_info eigenvalue_out_buf = eigenvalue_out.request(); | ||
|
||
double* eigenvalue_out_ptr = static_cast<double*>(eigenvalue_out_buf.ptr); | ||
|
||
for (size_t i = 0; i < nband; ++i) | ||
{ | ||
eigenvalue_out_ptr[i] = eigenvalue[i]; | ||
} | ||
|
||
return eigenvalue_out; | ||
} | ||
|
||
int diag( | ||
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op, | ||
std::vector<double> precond_vec, | ||
int dav_ndim, | ||
double tol, | ||
int max_iter, | ||
bool use_paw, | ||
hsolver::diag_comm_info comm_info | ||
) { | ||
auto hpsi_func = [mm_op] ( | ||
std::complex<double> *hpsi_out, | ||
std::complex<double> *psi_in, | ||
const int nband_in, | ||
const int nbasis_in, | ||
const int band_index1, | ||
const int band_index2 | ||
) { | ||
// Note: numpy's py::array_t is row-major, but | ||
// our raw pointer-array is column-major | ||
py::array_t<std::complex<double>, py::array::f_style> psi({nbasis_in, band_index2 - band_index1 + 1}); | ||
py::buffer_info psi_buf = psi.request(); | ||
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr); | ||
std::copy(psi_in + band_index1 * nbasis_in, psi_in + (band_index2 + 1) * nbasis_in, psi_ptr); | ||
|
||
py::array_t<std::complex<double>, py::array::f_style> hpsi = mm_op(psi); | ||
|
||
py::buffer_info hpsi_buf = hpsi.request(); | ||
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr); | ||
std::copy(hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1) * nbasis_in, hpsi_out); | ||
}; | ||
|
||
auto spsi_func = [this] ( | ||
const std::complex<double> *psi_in, | ||
std::complex<double> *spsi_out, | ||
const int nrow, | ||
const int npw, | ||
const int nbands | ||
) { | ||
syncmem_op()(this->ctx, this->ctx, spsi_out, psi_in, static_cast<size_t>(nbands * nrow)); | ||
}; | ||
|
||
obj = std::make_unique<hsolver::DiagoDavid<std::complex<double>, base_device::DEVICE_CPU>>( | ||
precond_vec.data(), | ||
nband, | ||
nbasis, | ||
dav_ndim, | ||
use_paw, | ||
comm_info | ||
); | ||
|
||
return obj->diag(hpsi_func, spsi_func, nbasis, psi, eigenvalue, tol, max_iter); | ||
} | ||
|
||
private: | ||
std::complex<double>* psi = nullptr; | ||
double* eigenvalue = nullptr; | ||
|
||
int nbasis; | ||
int nband; | ||
|
||
std::unique_ptr<hsolver::DiagoDavid<std::complex<double>, base_device::DEVICE_CPU>> obj; | ||
|
||
base_device::DEVICE_CPU* ctx = {}; | ||
using syncmem_op = base_device::memory::synchronize_memory_op<std::complex<double>, base_device::DEVICE_CPU, base_device::DEVICE_CPU>; | ||
}; | ||
|
||
} // namespace py_hsolver | ||
|
||
#endif |
Oops, something went wrong.