From 74f25822ccd5c395ddf7b88d5b089e9689c97e07 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Fri, 12 Jul 2024 23:37:33 +0800 Subject: [PATCH 1/5] optimize function pow in module_gint --- .../module_gint/gint_tools.h | 26 +++++++++++++++++++ .../module_gint/grid_meshball.cpp | 3 ++- .../module_gint/kernels/cuda/interp.cuh | 26 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/source/module_hamilt_lcao/module_gint/gint_tools.h b/source/module_hamilt_lcao/module_gint/gint_tools.h index 178ccf6fa5..632274b604 100644 --- a/source/module_hamilt_lcao/module_gint/gint_tools.h +++ b/source/module_hamilt_lcao/module_gint/gint_tools.h @@ -136,6 +136,32 @@ class Gint_inout namespace Gint_Tools { + +inline double pow(const double base, const int exp) +{ + double result = 1.0; + switch (exp) + { + case 0: + return 1.0; + case 1: + return base; + case 2: + return base * base; + case 3: + return base * base * base; + case 4: + return base * base * base * base; + case 5: + return base * base * base * base * base; + default: + for (int i = 0; i < exp; i++) + { + result *= base; + } + return result; + } +} // vindex[pw.bxyz] int* get_vindex(const int bxyz, const int bx, diff --git a/source/module_hamilt_lcao/module_gint/grid_meshball.cpp b/source/module_hamilt_lcao/module_gint/grid_meshball.cpp index e01518c668..d8e86864fd 100644 --- a/source/module_hamilt_lcao/module_gint/grid_meshball.cpp +++ b/source/module_hamilt_lcao/module_gint/grid_meshball.cpp @@ -1,4 +1,5 @@ #include "grid_meshball.h" +#include "gint_tools.h" // for pow(double, int) #include "module_base/memory.h" Grid_MeshBall::Grid_MeshBall() @@ -127,7 +128,7 @@ double Grid_MeshBall::deal_with_atom_spillage(const double *pos) cell[ip] = i*this->bigcell_vec1[ip] + j*this->bigcell_vec2[ip] + k*this->bigcell_vec3[ip]; - dx += std::pow(cell[ip] - pos[ip], 2); + dx += pow(cell[ip] - pos[ip], 2); } r2 = std::min(dx, r2); } diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh index 5c5882be8f..c02043120f 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh @@ -5,6 +5,32 @@ namespace GintKernel { +static __device__ double pow(double base, int exp) +{ + double result = 1.0; + switch (exp) + { + case 0: + return 1.0; + case 1: + return base; + case 2: + return base * base; + case 3: + return base * base * base; + case 4: + return base * base * base * base; + case 5: + return base * base * base * base * base; + default: + for (int i = 0; i < exp; i++) + { + result *= base; + } + return result; + } +} + static __device__ void interp_rho(const double dist, const double delta_r, const int atype, From bb4f453df718ff021c210930b69bff741d168cd1 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Sat, 13 Jul 2024 12:03:03 +0800 Subject: [PATCH 2/5] fix a bug --- source/module_hamilt_lcao/module_gint/gint_tools.h | 6 +----- .../module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/source/module_hamilt_lcao/module_gint/gint_tools.h b/source/module_hamilt_lcao/module_gint/gint_tools.h index 632274b604..0e33efa1e2 100644 --- a/source/module_hamilt_lcao/module_gint/gint_tools.h +++ b/source/module_hamilt_lcao/module_gint/gint_tools.h @@ -139,7 +139,6 @@ namespace Gint_Tools inline double pow(const double base, const int exp) { - double result = 1.0; switch (exp) { case 0: @@ -155,10 +154,7 @@ inline double pow(const double base, const int exp) case 5: return base * base * base * base * base; default: - for (int i = 0; i < exp; i++) - { - result *= base; - } + double result = std::pow(base, exp); return result; } } diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh index c02043120f..fcd9d8fd87 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh @@ -7,7 +7,6 @@ namespace GintKernel { static __device__ double pow(double base, int exp) { - double result = 1.0; switch (exp) { case 0: @@ -23,10 +22,7 @@ static __device__ double pow(double base, int exp) case 5: return base * base * base * base * base; default: - for (int i = 0; i < exp; i++) - { - result *= base; - } + double result = std::pow(base, exp); return result; } } From f453a58bc58268d25a66c73f407700599354f2e1 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Sun, 14 Jul 2024 09:11:00 +0800 Subject: [PATCH 3/5] rename function pow in module_gint --- source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp | 6 +++--- source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp | 2 +- source/module_hamilt_lcao/module_gint/gint_tools.h | 5 +++-- source/module_hamilt_lcao/module_gint/grid_meshball.cpp | 3 +-- .../module_gint/kernels/cuda/interp.cuh | 8 +++++--- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp b/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp index ac5e92616b..ecb7883842 100644 --- a/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp +++ b/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp @@ -160,7 +160,7 @@ void cal_ddpsir_ylm( const int ll = atom->iw2l[iw]; const int idx_lm = atom->iw2_ylm[iw]; - const double rl = pow(distance1, ll); + const double rl = pow_int(distance1, ll); // derivative of wave functions with respect to atom positions. const double tmpdphi_rly = (dtmp - tmp * ll / distance1) / rl * rly[idx_lm] / distance1; @@ -268,8 +268,8 @@ void cal_ddpsir_ylm( const int ll = atom->iw2l[iw]; const int idx_lm = atom->iw2_ylm[iw]; - const double rl = pow(distance, ll); - const double r_lp2 = pow(distance, ll + 2); + const double rl = pow_int(distance, ll); + const double r_lp2 = pow_int(distance, ll + 2); // d/dr (R_l / r^l) const double tmpdphi = (dtmp - tmp * ll / distance) / rl; diff --git a/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp b/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp index 87644dbcfa..163a944c41 100644 --- a/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp +++ b/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp @@ -115,7 +115,7 @@ void cal_dpsir_ylm( const int ll = atom->iw2l[iw]; const int idx_lm = atom->iw2_ylm[iw]; - const double rl = pow(distance, ll); + const double rl = pow_int(distance, ll); // 3D wave functions p_psi[iw] = tmp * rly[idx_lm] / rl; diff --git a/source/module_hamilt_lcao/module_gint/gint_tools.h b/source/module_hamilt_lcao/module_gint/gint_tools.h index 0e33efa1e2..c8771dfbe6 100644 --- a/source/module_hamilt_lcao/module_gint/gint_tools.h +++ b/source/module_hamilt_lcao/module_gint/gint_tools.h @@ -136,8 +136,9 @@ class Gint_inout namespace Gint_Tools { - -inline double pow(const double base, const int exp) +// if exponent is an integer between 0 and 5 (the most common cases in gint), +// pow_int is much faster than std::pow +inline double pow_int(const double base, const int exp) { switch (exp) { diff --git a/source/module_hamilt_lcao/module_gint/grid_meshball.cpp b/source/module_hamilt_lcao/module_gint/grid_meshball.cpp index d8e86864fd..1868559d16 100644 --- a/source/module_hamilt_lcao/module_gint/grid_meshball.cpp +++ b/source/module_hamilt_lcao/module_gint/grid_meshball.cpp @@ -1,5 +1,4 @@ #include "grid_meshball.h" -#include "gint_tools.h" // for pow(double, int) #include "module_base/memory.h" Grid_MeshBall::Grid_MeshBall() @@ -128,7 +127,7 @@ double Grid_MeshBall::deal_with_atom_spillage(const double *pos) cell[ip] = i*this->bigcell_vec1[ip] + j*this->bigcell_vec2[ip] + k*this->bigcell_vec3[ip]; - dx += pow(cell[ip] - pos[ip], 2); + dx += (cell[ip] - pos[ip]) * (cell[ip] - pos[ip]); } r2 = std::min(dx, r2); } diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh index fcd9d8fd87..31ccf3ca2c 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh @@ -5,7 +5,9 @@ namespace GintKernel { -static __device__ double pow(double base, int exp) +// if exponent is an integer between 0 and 5 (the most common cases in gint), +// pow_int is much faster than std::pow +static __device__ double pow_int(double base, int exp) { switch (exp) { @@ -22,7 +24,7 @@ static __device__ double pow(double base, int exp) case 5: return base * base * base * base * base; default: - double result = std::pow(base, exp); + double result = pow(base, exp); return result; } } @@ -167,7 +169,7 @@ static __device__ void interp_f(const double dist, // Extract information from atom_iw2_* arrays const int ll = atom_iw2_l[it_nw_iw]; const int idx_lm = atom_iw2_ylm[it_nw_iw]; - const double rl = pow(dist, ll); + const double rl = pow_int(dist, ll); const double rl_r = 1.0 / rl; const double dist_r = 1 / dist; const int dpsi_idx = psi_idx * 3; From 246051c84abcd2b58a62f90459ed60f4d8717978 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Sun, 14 Jul 2024 09:18:15 +0800 Subject: [PATCH 4/5] optimize the calculation of r_lp2 --- source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp b/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp index ecb7883842..ea785361f1 100644 --- a/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp +++ b/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp @@ -269,7 +269,7 @@ void cal_ddpsir_ylm( const int idx_lm = atom->iw2_ylm[iw]; const double rl = pow_int(distance, ll); - const double r_lp2 = pow_int(distance, ll + 2); + const double r_lp2 =rl * distance * distance; // d/dr (R_l / r^l) const double tmpdphi = (dtmp - tmp * ll / distance) / rl; From 464d5204fd01b34071917636f5aabfc339d5c832 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Sun, 14 Jul 2024 01:50:38 +0000 Subject: [PATCH 5/5] [pre-commit.ci lite] apply automatic fixes --- source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp b/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp index 163a944c41..8c6a4ce637 100644 --- a/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp +++ b/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp @@ -68,8 +68,9 @@ void cal_dpsir_ylm( double distance = std::sqrt(dr[0] * dr[0] + dr[1] * dr[1] + dr[2] * dr[2]); ModuleBase::Ylm::grad_rl_sph_harm(ucell.atoms[it].nwl, dr[0], dr[1], dr[2], rly, grly.get_ptr_2D()); - if (distance < 1e-9) + if (distance < 1e-9) { distance = 1e-9; +} const double position = distance / delta_r;