Skip to content

Commit

Permalink
Merge pull request #5234 from ye-luo/reduce-overload
Browse files Browse the repository at this point in the history
Remove custom real/imag() overload functions
  • Loading branch information
prckent authored Nov 20, 2024
2 parents 2f8ef74 + 36c3a95 commit 73672a7
Show file tree
Hide file tree
Showing 16 changed files with 84 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,10 @@ class KP3IndexFactorization_batched
for (int n = 0; n < nwalk; n++)
fill_n(E[n].origin(), 1, ComplexType(E0));
// must use Gc since GKK is is SP
int na = 0, nk = 0, nb = 0;
#if defined(MIXED_PRECISION)
int na = 0, nk = 0;
#endif
int nb = 0;
for (int K = 0; K < nkpts; ++K)
{
#if defined(MIXED_PRECISION)
Expand All @@ -591,16 +594,13 @@ class KP3IndexFactorization_batched
}
nk += ni;
#else
nk = nopk[K];
{
na = nelpk[nd][K];
CVector_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()), {nocc_max * npol * nmo_max});
SpMatrix_ref Gaj(GKK[0][K][K].origin(), {nwalk, nocc_max * npol * nmo_max});
ma::product(ComplexType(1.), Gaj, haj_K, ComplexType(1.), E({0, nwalk}, 0));
}
if (walker_type == COLLINEAR)
{
na = nelpk[nd][nkpts + K];
CVector_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()) + nocc_max * nmo_max, {nocc_max * nmo_max});
SpMatrix_ref Gaj(GKK[1][K][K].origin(), {nwalk, nocc_max * nmo_max});
ma::product(ComplexType(1.), Gaj, haj_K, ComplexType(1.), E({0, nwalk}, 0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,15 @@ void ham_ops_basic_serial(boost::mpi3::communicator& world)
Eloc[0][2] = (TG.Node() += ComplexType(Eloc[0][2]));
if (std::abs(file_data.E0 + file_data.E1) > 1e-8)
{
CHECK(real(Eloc[0][0]) == Approx(real(file_data.E0 + file_data.E1)));
CHECK(imag(Eloc[0][0]) == Approx(imag(file_data.E0 + file_data.E1)));
CHECK(ComplexType(Eloc[0][0]) == ComplexApprox(file_data.E0 + file_data.E1));
}
else
{
app_log() << " E1: " << setprecision(12) << Eloc[0][0] << std::endl;
}
if (std::abs(file_data.E2) > 1e-8)
{
CHECK(real(Eloc[0][1] + Eloc[0][2]) == Approx(real(file_data.E2)));
CHECK(imag(Eloc[0][1] + Eloc[0][2]) == Approx(imag(file_data.E2)));
CHECK(Eloc[0][1] + Eloc[0][2] == ComplexApprox(file_data.E2));
}
else
{
Expand Down Expand Up @@ -242,8 +240,7 @@ void ham_ops_basic_serial(boost::mpi3::communicator& world)
}
if (std::abs(file_data.Xsum) > 1e-8)
{
CHECK(real(Xsum) == Approx(real(file_data.Xsum)));
CHECK(imag(Xsum) == Approx(imag(file_data.Xsum)));
CHECK(Xsum == ComplexApprox(file_data.Xsum));
}
else
{
Expand Down Expand Up @@ -271,8 +268,7 @@ void ham_ops_basic_serial(boost::mpi3::communicator& world)
}
if (std::abs(file_data.Vsum) > 1e-8)
{
CHECK(real(Vsum) == Approx(real(file_data.Vsum)));
CHECK(imag(Vsum) == Approx(imag(file_data.Vsum)));
CHECK(Vsum == ComplexApprox(file_data.Vsum));
}
else
{
Expand Down Expand Up @@ -352,10 +348,9 @@ void ham_ops_basic_serial(boost::mpi3::communicator& world)
{
for (int j = 0; j < NMO; j++)
{
if (std::abs(Mat[i][j] - real(GFock[1][0][i * NMO + j])) > 1e-5)
if (auto gfock = ComplexType(GFock[1][0][i * NMO + j]); std::abs(Mat[i][j] - std::real(gfock)) > 1e-5)
{
std::cout << "DELTAA: " << i << " " << j << " " << Mat[i][j] << " " << real(GFock[1][0][i * NMO + j])
<< std::endl;
std::cout << "DELTAA: " << i << " " << j << " " << Mat[i][j] << " " << std::real(gfock) << std::endl;
}
//if(std::abs(real(GFock[1][0][i*NMO+j]))>1e-6)
//std::cout << i << " " << j << " " << real(GFock[1][0][i*NMO+j]) << " " << std::endl;
Expand All @@ -370,10 +365,9 @@ void ham_ops_basic_serial(boost::mpi3::communicator& world)
{
//std::cout << Mat[i][j] << std::endl;
//std::cout << Mat[i][j]-real(GFock[0][0][i*NMO+j]) << std::endl;
if (std::abs(Mat[i][j] - real(GFock[0][0][i * NMO + j])) > 1e-5)
if (auto gfock = ComplexType(GFock[0][0][i * NMO + j]); std::abs(Mat[i][j] - std::real(gfock)) > 1e-5)
{
std::cout << "DELTAB: " << i << " " << j << " " << Mat[i][j] << " " << real(GFock[0][0][i * NMO + j])
<< std::endl;
std::cout << "DELTAB: " << i << " " << j << " " << Mat[i][j] << " " << std::real(gfock) << std::endl;
}
//if(std::abs(real(GFock[0][0][i*NMO+j]))>1e-6)
//std::cout << i << " " << j << " " << real(GFock[0][0][i*NMO+j]) << " " << real(GFock[0][1][i*NMO+j]) << " " << real(GFock[0][2][i*NMO+j]) << std::endl;
Expand Down
22 changes: 11 additions & 11 deletions src/AFQMC/Numerics/detail/CPU/lapack_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ inline void gesvd_bufferSize(const int m, const int n, std::complex<T>* a, int&
int status;
lwork = -1;
gesvd('A', 'A', m, n, a, m, nullptr, nullptr, m, nullptr, m, &work, lwork, &rwork, status);
lwork = int(real(work));
lwork = int(std::real(work));
}

inline void geev(char* jobvl,
Expand Down Expand Up @@ -332,7 +332,7 @@ inline void hevr(char& JOBZ,
LIWORK, INFO);
if (query)
{
LWORK = int(real(WORK[0]));
LWORK = int(std::real(WORK[0]));
LRWORK = int(RWORK[0]);
LIWORK = int(IWORK[0]);
}
Expand Down Expand Up @@ -373,7 +373,7 @@ inline void hevr(char& JOBZ,
LIWORK, INFO);
if (query)
{
LWORK = int(real(WORK[0]));
LWORK = int(std::real(WORK[0]));
LRWORK = int(RWORK[0]);
LIWORK = int(IWORK[0]);
}
Expand Down Expand Up @@ -489,7 +489,7 @@ inline void gvx(int ITYPE,
IFAIL, INFO);
if (query)
{
LWORK = int(real(WORK[0]));
LWORK = int(std::real(WORK[0]));
}
}

Expand Down Expand Up @@ -527,7 +527,7 @@ inline void gvx(int ITYPE,
IFAIL, INFO);
if (query)
{
LWORK = int(real(WORK[0]));
LWORK = int(std::real(WORK[0]));
}
}

Expand Down Expand Up @@ -622,7 +622,7 @@ inline void getri_bufferSize(int n, std::complex<float> const* a, int lda, int&
int status;
lwork = -1;
cgetri(n, nullptr, lda, nullptr, &work, lwork, status);
lwork = int(real(work));
lwork = int(std::real(work));
}

inline void getri_bufferSize(int n, std::complex<double> const* a, int lda, int& lwork)
Expand All @@ -631,7 +631,7 @@ inline void getri_bufferSize(int n, std::complex<double> const* a, int lda, int&
int status;
lwork = -1;
zgetri(n, nullptr, lda, nullptr, &work, lwork, status);
lwork = int(real(work));
lwork = int(std::real(work));
}


Expand Down Expand Up @@ -745,7 +745,7 @@ inline void geqrf_bufferSize(int m, int n, T* a, int lda, int& lwork)
int status;
lwork = -1;
geqrf(m, n, nullptr, lda, nullptr, &work, lwork, status);
lwork = int(real(work));
lwork = int(std::real(work));
}

inline void gelqf(int M,
Expand Down Expand Up @@ -809,7 +809,7 @@ inline void gelqf_bufferSize(int m, int n, T* a, int lda, int& lwork)
int status;
lwork = -1;
gelqf(m, n, nullptr, lda, nullptr, &work, lwork, status);
lwork = int(real(work));
lwork = int(std::real(work));
}

inline void gqr(int M,
Expand Down Expand Up @@ -875,7 +875,7 @@ inline void gqr_bufferSize(int m, int n, int k, T* a, int lda, int& lwork)
int status;
lwork = -1;
gqr(m, n, k, nullptr, lda, nullptr, &work, lwork, status);
lwork = int(real(work));
lwork = int(std::real(work));
}

inline void glq(int M,
Expand Down Expand Up @@ -941,7 +941,7 @@ inline void glq_bufferSize(int m, int n, int k, T* a, int lda, int& lwork)
int status;
lwork = -1;
glq(m, n, k, nullptr, lda, nullptr, &work, lwork, status);
lwork = int(real(work));
lwork = int(std::real(work));
}

template<typename T1, typename T2>
Expand Down
14 changes: 2 additions & 12 deletions src/AFQMC/Numerics/detail/utilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <complex>
#include "AFQMC/config.0.h"
#include "type_traits/complex_help.hpp"
#include "AFQMC/Memory/raw_pointers.hpp"
#include "AFQMC/Memory/SharedMemory/shm_ptr_with_raw_ptr_dispatch.hpp"

Expand Down Expand Up @@ -53,18 +54,7 @@ using CBLAS_TRANSPOSE = enum {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=1
#endif
*/

inline double const& real(double const& d) { return d; }
inline float const& real(float const& f) { return f; }

inline double conj(double const& d) { return d; }
inline float conj(float const& f) { return f; }

inline std::complex<double> conj(std::complex<double> const& d) { return std::conj(d); }
inline std::complex<float> conj(std::complex<float> const& f) { return std::conj(f); }
//template<typename T>
//T conj(T const& v) { return v; }
//template<typename T>
//std::complex<T> conj(std::complex<T> const& v) { return std::conj(v); }
using qmcplusplus::conj;

template<class Ptr>
auto pointer_dispatch(Ptr p)
Expand Down
8 changes: 4 additions & 4 deletions src/AFQMC/Numerics/determinant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ inline T determinant_from_getrf(int n, T* M, int lda, int* pivot, T LogOverlapFa
T sg(1.0);
for (int i = 0, ip = 1; i != n; i++, ip++)
{
if (real(M[i * lda + i]) < 0.0)
if (std::real(M[i * lda + i]) < 0.0)
{
res += std::log(-static_cast<T>(M[i * lda + i]));
sg *= -1.0;
Expand All @@ -50,7 +50,7 @@ inline void determinant_from_getrf(int n, T* M, int lda, int* pivot, T LogOverla
T sg(1.0);
for (int i = 0, ip = 1; i != n; i++, ip++)
{
if (real(M[i * lda + i]) < 0.0)
if (std::real(M[i * lda + i]) < 0.0)
{
*res += std::log(-static_cast<T>(M[i * lda + i]));
sg *= -1.0;
Expand Down Expand Up @@ -98,7 +98,7 @@ T determinant_from_geqrf(int n, T* M, int lda, T* buff, T LogOverlapFactor)
T res(0.0);
for (int i = 0; i < n; i++)
{
if (real(M[i * lda + i]) < 0.0)
if (std::real(M[i * lda + i]) < 0.0)
buff[i] = T(-1.0);
else
buff[i] = T(1.0);
Expand Down Expand Up @@ -158,7 +158,7 @@ inline void determinant_from_geqrf(int n, T* M, int lda, T* buff)
{
for (int i = 0; i < n; i++)
{
if (real(M[i * lda + i]) < 0)
if (std::real(M[i * lda + i]) < 0)
buff[i] = T(-1.0);
else
buff[i] = T(1.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,9 @@ Tp MixedDensityMatrix_noHerm_wSVD(const MatA& A,
assert( get<1>(C.sizes()) == get<0>(A.sizes()) );
}

using std::real;
using ma::determinant_from_geqrf;
using ma::H;
using ma::real;
using ma::T;
using ma::term_by_term_matrix_vector;

Expand Down
2 changes: 1 addition & 1 deletion src/AFQMC/Wavefunctions/NOMSD.icc
Original file line number Diff line number Diff line change
Expand Up @@ -2443,7 +2443,7 @@ void NOMSD<devPsiT>::vMF(Vec&& v)
// which should be exactly zero, suffers from truncation errors.
// Set it to zero.
for (int i = 0; i < v.num_elements(); i++)
v[i] = ComplexType(real(v[i]), 0.0);
v[i] = ComplexType(real(ComplexType(v[i])), 0.0);
}

template<class devPsiT>
Expand Down
4 changes: 2 additions & 2 deletions src/AFQMC/Wavefunctions/tests/test_phmsd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ void test_phmsd(boost::mpi3::communicator& world)
// coefficients as using factory setup.
for (auto it = wset.begin(); it != wset.end(); ++it)
{
CHECK(std::abs(real(*it->overlap())) == Approx(std::abs(real(ovlp_sum))));
CHECK(std::abs(imag(*it->overlap())) == Approx(std::abs(imag(ovlp_sum))));
CHECK(std::abs(std::real(ComplexType(*it->overlap()))) == Approx(std::abs(std::real(ovlp_sum))));
CHECK(std::abs(std::imag(ComplexType(*it->overlap()))) == Approx(std::abs(std::imag(ovlp_sum))));
}
// It's not straightforward to calculate energy directly in unit test due to half
// rotation.
Expand Down
Loading

0 comments on commit 73672a7

Please sign in to comment.