Skip to content

Commit

Permalink
Merge pull request #4733 from anbenali/VirtualParticle_MO_batched
Browse files Browse the repository at this point in the history
Specialization of functions using VirtualParticle for the MO
  • Loading branch information
prckent authored Sep 19, 2023
2 parents db96869 + 99074f3 commit 053f38a
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 23 deletions.
5 changes: 5 additions & 0 deletions src/QMCWaveFunctions/BasisSetBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#define QMCPLUSPLUS_BASISSETBASE_H

#include "Particle/ParticleSet.h"
#include "Particle/VirtualParticleSet.h"
#include "QMCWaveFunctions/OrbitalSetTraits.h"
#include "OMPTarget/OffloadAlignedAllocators.hpp"

Expand Down Expand Up @@ -153,6 +154,9 @@ struct SoaBasisSetBase
virtual void mw_evaluateVGL(const RefVectorWithLeader<ParticleSet>& P_list, int iat, OffloadMWVGLArray& vgl) = 0;
//Evaluates value for electron "iat". Places it in an offload array for batched code.
virtual void mw_evaluateValue(const RefVectorWithLeader<ParticleSet>& P_list, int iat, OffloadMWVArray& v) = 0;
//Evaluates values for all the particles of every virtual particle set in vp_list. Places them in an offload array for batched code.
virtual void mw_evaluateValueVPs(const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
OffloadMWVArray& v) = 0;
//Evaluates value, gradient, and Hessian for electron "iat". Parks them into a temporary data structure "vgh".
virtual void evaluateVGH(const ParticleSet& P, int iat, vgh_type& vgh) = 0;
//Evaluates value, gradient, and Hessian, and Gradient Hessian for electron "iat". Parks them into a temporary data structure "vghgh".
Expand All @@ -167,6 +171,7 @@ struct SoaBasisSetBase
int jion,
vghgh_type& vghgh) = 0;
virtual void evaluateV(const ParticleSet& P, int iat, value_type* restrict vals) = 0;

virtual bool is_S_orbital(int mo_idx, int ao_idx) { return false; }

/// Determine which orbitals are S-type. Used for cusp correction.
Expand Down
75 changes: 57 additions & 18 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ struct LCAOrbitalSet::LCAOMultiWalkerMem : public Resource

std::unique_ptr<Resource> makeClone() const override { return std::make_unique<LCAOMultiWalkerMem>(*this); }

OffloadMWVGLArray phi_vgl_v; // [5][NW][NumMO]
OffloadMWVGLArray basis_mw; // [5][NW][NumAO]
OffloadMWVArray phi_v; // [NW][NumMO]
OffloadMWVArray basis_v_mw; // [NW][NumMO]
OffloadMWVGLArray phi_vgl_v; // [5][NW][NumMO]
OffloadMWVGLArray basis_vgl_mw; // [5][NW][NumAO]
OffloadMWVArray phi_v; // [NW][NumMO]
OffloadMWVArray basis_v_mw; // [NW][NumAO]
OffloadMWVArray vp_phi_v; // [NVPs][NumMO]
OffloadMWVArray vp_basis_v_mw; // [NVPs][NumAO]
};

LCAOrbitalSet::LCAOrbitalSet(const std::string& my_name, std::unique_ptr<basis_type>&& bs)
Expand Down Expand Up @@ -428,13 +430,13 @@ void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& sp
OffloadMWVGLArray& phi_vgl_v) const
{
assert(this == &spo_list.getLeader());
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
auto& basis_mw = spo_leader.mw_mem_handle_.getResource().basis_mw;
basis_mw.resize(DIM_VGL, spo_list.size(), BasisSetSize);
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
auto& basis_vgl_mw = spo_leader.mw_mem_handle_.getResource().basis_vgl_mw;
basis_vgl_mw.resize(DIM_VGL, spo_list.size(), BasisSetSize);

{
ScopedTimer local(basis_timer_);
myBasisSet->mw_evaluateVGL(P_list, iat, basis_mw);
myBasisSet->mw_evaluateVGL(P_list, iat, basis_vgl_mw);
}

if (Identity)
Expand All @@ -445,7 +447,7 @@ void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& sp

for (size_t idim = 0; idim < DIM_VGL; idim++)
for (int iw = 0; iw < nw; iw++)
std::copy_n(basis_mw.data_at(idim, iw, 0), output_size, phi_vgl_v.data_at(idim, iw, 0));
std::copy_n(basis_vgl_mw.data_at(idim, iw, 0), output_size, phi_vgl_v.data_at(idim, iw, 0));
}
else
{
Expand All @@ -460,12 +462,43 @@ void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& sp
requested_orb_size, // MOs
spo_list.size() * DIM_VGL, // walkers * DIM_VGL
BasisSetSize, // AOs
1, C_partial_view.data(), BasisSetSize, basis_mw.data(), BasisSetSize, 0, phi_vgl_v.data(),
1, C_partial_view.data(), BasisSetSize, basis_vgl_mw.data(), BasisSetSize, 0, phi_vgl_v.data(),
requested_orb_size);
}
}
}

/** Batched evaluation of orbital values for all virtual particles.
 *
 * The basis set fills the leader's vp_basis_v_mw scratch array with basis function values,
 * one row per virtual particle. Those rows are then either copied straight to the output
 * (Identity) or contracted with the leading rows of the coefficient matrix C through one GEMM.
 *
 * @param spo_list  SPOSet list; this object must be its leader. The leader's multi-walker
 *                  resource provides the scratch storage.
 * @param vp_list   virtual particle sets, forwarded to the basis set
 * @param vp_phi_v  output array; its dimensions are read here, not changed:
 *                  size(0) is the number of virtual particle rows, size(1) the number of requested orbitals
 */
void LCAOrbitalSet::mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader<SPOSet>& spo_list,
                                                const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
                                                OffloadMWVArray& vp_phi_v) const
{
  assert(this == &spo_list.getLeader());
  auto& spo_leader    = spo_list.getCastedLeader<LCAOrbitalSet>();
  auto& vp_basis_v_mw = spo_leader.mw_mem_handle_.getResource().vp_basis_v_mw;

  const size_t nVPs               = vp_phi_v.size(0);
  const size_t requested_orb_size = vp_phi_v.size(1);
  assert(requested_orb_size <= OrbitalSetSize);

  // basis values for every virtual particle, one row of BasisSetSize entries each
  vp_basis_v_mw.resize(nVPs, BasisSetSize);
  myBasisSet->mw_evaluateValueVPs(vp_list, vp_basis_v_mw);

  if (Identity)
  {
    // Copy row by row: source rows hold BasisSetSize entries while destination rows hold
    // requested_orb_size entries, so a single flat copy would overrun the output and
    // misalign rows whenever fewer orbitals than OrbitalSetSize are requested.
    for (size_t ivp = 0; ivp < nVPs; ivp++)
      std::copy_n(vp_basis_v_mw.data_at(ivp, 0), requested_orb_size, vp_phi_v.data_at(ivp, 0));
  }
  else
  {
    // Only the first requested_orb_size rows of C take part in the product.
    ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize);
    BLAS::gemm('T', 'N',
               requested_orb_size, // MOs
               nVPs,               // walkers * Virtual Particles
               BasisSetSize,       // AOs
               1, C_partial_view.data(), BasisSetSize, vp_basis_v_mw.data(), BasisSetSize, 0, vp_phi_v.data(),
               requested_orb_size);
  }
}
void LCAOrbitalSet::mw_evaluateValue(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<ParticleSet>& P_list,
int iat,
Expand Down Expand Up @@ -521,15 +554,21 @@ void LCAOrbitalSet::mw_evaluateDetRatios(const RefVectorWithLeader<SPOSet>& spo_
const std::vector<const ValueType*>& invRow_ptr_list,
std::vector<std::vector<ValueType>>& ratios_list) const
{
const size_t nw = spo_list.size();
for (size_t iw = 0; iw < nw; iw++)
{
assert(this == &spo_list.getLeader());
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
auto& vp_phi_v = spo_leader.mw_mem_handle_.getResource().vp_phi_v;

const size_t nVPs = VirtualParticleSet::countVPs(vp_list);
const size_t requested_orb_size = psi_list[0].get().size();
vp_phi_v.resize(nVPs, requested_orb_size);

mw_evaluateValueVPsImplGEMM(spo_list, vp_list, vp_phi_v);

///To be computed on Device through new variable mw_ratios_list, then copied to ratios_list on host.
size_t index = 0;
for (size_t iw = 0; iw < vp_list.size(); iw++)
for (size_t iat = 0; iat < vp_list[iw].getTotalNum(); iat++)
{
spo_list[iw].evaluateValue(vp_list[iw], iat, psi_list[iw]);
ratios_list[iw][iat] = simd::dot(psi_list[iw].get().data(), invRow_ptr_list[iw], psi_list[iw].get().size());
}
}
ratios_list[iw][iat] = simd::dot(vp_phi_v.data_at(index++, 0), invRow_ptr_list[iw], requested_orb_size);
}

void LCAOrbitalSet::evaluateDetRatios(const VirtualParticleSet& VP,
Expand Down
10 changes: 5 additions & 5 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,7 @@ struct LCAOrbitalSet : public SPOSet
std::vector<ValueType>& ratios,
std::vector<GradType>& grads) const final;

void evaluateVGH(const ParticleSet& P,
int iat,
ValueVector& psi,
GradVector& dpsi,
HessVector& grad_grad_psi) final;
void evaluateVGH(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, HessVector& grad_grad_psi) final;

void evaluateVGHGH(const ParticleSet& P,
int iat,
Expand Down Expand Up @@ -308,6 +304,10 @@ struct LCAOrbitalSet : public SPOSet
int iat,
OffloadMWVArray& phi_v) const;

/// packed walker GEMM implementation with multiple virtual particle sets
/// @param spo_list SPOSet list; the calling object is expected to be its leader
/// @param vp_list virtual particle sets passed on to the basis set evaluation
/// @param phi_v output orbital values, one row per virtual particle and one column per requested orbital
void mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
OffloadMWVArray& phi_v) const;
struct LCAOMultiWalkerMem;
ResourceHandle<LCAOMultiWalkerMem> mw_mem_handle_;
/// timer for basis set
Expand Down
12 changes: 12 additions & 0 deletions src/QMCWaveFunctions/LCAO/SoaLocalizedBasisSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,18 @@ void SoaLocalizedBasisSet<COT, ORBT>::mw_evaluateValue(const RefVectorWithLeader
evaluateV(P_list[iw], iat, v.data_at(iw, 0));
}

/** Evaluate basis function values for every particle of every virtual particle set.
 *
 * Rows of v are filled consecutively: all particles of vp_list[0] first, then those of
 * vp_list[1], and so on. Each row receives the output of evaluateV for one particle.
 *
 * @param vp_list virtual particle sets to walk through
 * @param v       output array; its second dimension must equal BasisSetSize
 */
template<class COT, typename ORBT>
void SoaLocalizedBasisSet<COT, ORBT>::mw_evaluateValueVPs(const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
                                                          OffloadMWVArray& v)
{
  assert(BasisSetSize == v.size(1));
  // running row of v, shared across all virtual particle sets
  size_t row = 0;
  for (size_t iw = 0; iw < vp_list.size(); iw++)
  {
    const auto& vp     = vp_list[iw];
    const auto  num_vp = vp.getTotalNum();
    for (int iat = 0; iat < num_vp; iat++, row++)
      evaluateV(vp, iat, v.data_at(row, 0));
  }
}


template<class COT, typename ORBT>
void SoaLocalizedBasisSet<COT, ORBT>::evaluateGradSourceV(const ParticleSet& P,
int iat,
Expand Down
7 changes: 7 additions & 0 deletions src/QMCWaveFunctions/LCAO/SoaLocalizedBasisSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ class SoaLocalizedBasisSet : public SoaBasisSetBase<ORBT>
*/
void mw_evaluateValue(const RefVectorWithLeader<ParticleSet>& P_list, int iat, OffloadMWVArray& v) override;

/** compute V for all virtual particles of all walkers using a packed array
* @param vp_list list of quantum virtual particleset (one for each walker)
* @param v Array(n_VPs, BasisSetSize), one row per virtual particle; rows follow the order of vp_list, then particle index
*/
void mw_evaluateValueVPs(const RefVectorWithLeader<const VirtualParticleSet>& vp_list, OffloadMWVArray& v) override;


/** compute VGL using packed array with all walkers
* @param P_list list of quantum particleset (one for each walker)
* @param iat active particle
Expand Down

0 comments on commit 053f38a

Please sign in to comment.