Skip to content

Commit

Permalink
Merge pull request QMCPACK#4850 from ye-luo/hybridrep-builder
Browse files Browse the repository at this point in the history
Rename BsplineReaderBase to BsplineReader
  • Loading branch information
prckent authored Nov 30, 2023
2 parents 4189f84 + bc2d410 commit 30e06c9
Show file tree
Hide file tree
Showing 22 changed files with 76 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
//////////////////////////////////////////////////////////////////////////////////////


/** @file BsplineReaderBase.cpp
/** @file BsplineReader.cpp
*
* Implement super function
*/
#include "EinsplineSetBuilder.h"
#include "BsplineReaderBase.h"
#include "BsplineReader.h"
#include "OhmmsData/AttributeSet.h"
#include "Message/CommOperators.h"

Expand All @@ -27,13 +27,13 @@

namespace qmcplusplus
{
BsplineReaderBase::BsplineReaderBase(EinsplineSetBuilder* e)
BsplineReader::BsplineReader(EinsplineSetBuilder* e)
: mybuilder(e), checkNorm(true), saveSplineCoefs(false), rotate(true)
{
myComm = mybuilder->getCommunicator();
}

BsplineReaderBase::~BsplineReaderBase() = default;
BsplineReader::~BsplineReader() = default;

inline std::string make_bandinfo_filename(const std::string& root,
int spin,
Expand Down Expand Up @@ -65,7 +65,7 @@ inline std::string make_bandgroup_name(const std::string& root,
return oo.str();
}

void BsplineReaderBase::setCommon(xmlNodePtr cur)
void BsplineReader::setCommon(xmlNodePtr cur)
{
// check orbital normalization by default
std::string checkOrbNorm("yes");
Expand All @@ -84,7 +84,7 @@ void BsplineReaderBase::setCommon(xmlNodePtr cur)
saveSplineCoefs = saveCoefs == "yes";
}

std::unique_ptr<SPOSet> BsplineReaderBase::create_spline_set(int spin, xmlNodePtr cur)
std::unique_ptr<SPOSet> BsplineReader::create_spline_set(int spin, xmlNodePtr cur)
{
int ns(0);
std::string spo_object_name;
Expand Down Expand Up @@ -120,7 +120,7 @@ std::unique_ptr<SPOSet> BsplineReaderBase::create_spline_set(int spin, xmlNodePt
return create_spline_set(spo_object_name, spin, vals);
}

std::unique_ptr<SPOSet> BsplineReaderBase::create_spline_set(int spin, xmlNodePtr cur, SPOSetInputInfo& input_info)
std::unique_ptr<SPOSet> BsplineReader::create_spline_set(int spin, xmlNodePtr cur, SPOSetInputInfo& input_info)
{
std::string spo_object_name;
OhmmsAttributeSet a;
Expand Down Expand Up @@ -160,7 +160,7 @@ std::unique_ptr<SPOSet> BsplineReaderBase::create_spline_set(int spin, xmlNodePt
*
* At gamma or arbitrary kpoints with complex wavefunctions, spo2band[i]==i
*/
void BsplineReaderBase::initialize_spo2band(int spin,
void BsplineReader::initialize_spo2band(int spin,
const std::vector<BandInfo>& bigspace,
SPOSetInfo& sposet,
std::vector<int>& spo2band)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
//////////////////////////////////////////////////////////////////////////////////////


/** @file BsplineReaderBase.h
/** @file BsplineReader.h
*
* base class to read data and manage spline tables
*/
#ifndef QMCPLUSPLUS_BSPLINE_READER_BASE_H
#define QMCPLUSPLUS_BSPLINE_READER_BASE_H
#ifndef QMCPLUSPLUS_BSPLINE_READER_H
#define QMCPLUSPLUS_BSPLINE_READER_H

#include "mpi/collectives.h"
#include "mpi/point2point.h"
Expand All @@ -29,13 +29,13 @@ namespace qmcplusplus
struct SPOSetInputInfo;

/**
* Each SplineC2X needs a reader derived from BsplineReaderBase.
* Each SplineC2X needs a reader derived from BsplineReader.
* This base class handles common chores
* - check_twists : read gvectors, set twists for folded bands if needed, and set the phase for the special K
* - set_grid : create the basic grid and boundary conditions for einspline
* Note that template is abused but it works.
*/
struct BsplineReaderBase
struct BsplineReader
{
///pointer to the EinsplineSetBuilder
EinsplineSetBuilder* mybuilder;
Expand All @@ -51,9 +51,9 @@ struct BsplineReaderBase
///map from spo index to band index
std::vector<std::vector<int>> spo2band;

BsplineReaderBase(EinsplineSetBuilder* e);
BsplineReader(EinsplineSetBuilder* e);

virtual ~BsplineReaderBase();
virtual ~BsplineReader();

std::string getSplineDumpFileName(const BandInfoGroup& bandgroup) const
{
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/BsplineSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class BsplineSet : public SPOSet

template<class BSPLINESPO>
friend class SplineSetReader;
friend struct BsplineReaderBase;
friend struct BsplineReader;
};

} // namespace qmcplusplus
Expand Down
8 changes: 4 additions & 4 deletions src/QMCWaveFunctions/BsplineFactory/EinsplineSetBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class Communicate;

namespace qmcplusplus
{
///forward declaration of BsplineReaderBase
struct BsplineReaderBase;
///forward declaration of BsplineReader
struct BsplineReader;

// Helper needed for TwistMap
struct Int3less
Expand Down Expand Up @@ -128,8 +128,8 @@ class EinsplineSetBuilder : public SPOSetBuilder
*/
std::vector<std::unique_ptr<std::vector<BandInfo>>> FullBands;

/// reader to use BsplineReaderBase
std::unique_ptr<BsplineReaderBase> MixedSplineReader;
/// reader to use BsplineReader
std::unique_ptr<BsplineReader> MixedSplineReader;

///This is true if we have the orbital derivatives w.r.t. the ion positions
bool HaveOrbDerivs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "OhmmsData/AttributeSet.h"
#include "Message/CommOperators.h"
#include <Message/UniformCommunicateError.h>
#include "QMCWaveFunctions/BsplineFactory/BsplineReaderBase.h"
#include "QMCWaveFunctions/BsplineFactory/BsplineReader.h"
#include "Particle/DistanceTable.h"

#include <array>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include <fftw3.h>
#include "Utilities/ProgressReportEngine.h"
#include "einspline_helper.hpp"
#include "BsplineReaderBase.h"
#include "BsplineReader.h"
#include "BsplineSet.h"
#include "createBsplineReader.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "Message/CommOperators.h"
#include "Utilities/Timer.h"
#include "einspline_helper.hpp"
#include "BsplineReaderBase.h"
#include "BsplineReader.h"
#include "createBsplineReader.h"

namespace qmcplusplus
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/HybridRepCplx.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class HybridRepCplx : public SPLINEBASE, private HybridRepCenterOrbitals<typenam
friend class HybridRepSetReader;
template<class BSPLINESPO>
friend class SplineSetReader;
friend struct BsplineReaderBase;
friend struct BsplineReader;
};

} // namespace qmcplusplus
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/HybridRepReal.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class HybridRepReal : public SPLINEBASE, private HybridRepCenterOrbitals<typenam
friend class HybridRepSetReader;
template<class BSPLINESPO>
friend class SplineSetReader;
friend struct BsplineReaderBase;
friend struct BsplineReader;
};

} // namespace qmcplusplus
Expand Down
11 changes: 8 additions & 3 deletions src/QMCWaveFunctions/BsplineFactory/HybridRepSetReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ struct Gvectors
/** General HybridRepSetReader to handle any unitcell
*/
template<typename SA>
class HybridRepSetReader : public BsplineReaderBase
class HybridRepSetReader : public BsplineReader
{
using SplineReader = SplineSetReader<typename SA::SplineBase>;
using DataType = typename SplineReader::DataType;
SplineReader spline_reader_;

public:
HybridRepSetReader(EinsplineSetBuilder* e) : BsplineReaderBase(e), spline_reader_(e) {}
HybridRepSetReader(EinsplineSetBuilder* e) : BsplineReader(e), spline_reader_(e) {}

std::unique_ptr<SPOSet> create_spline_set(const std::string& my_name,
int spin,
Expand All @@ -153,7 +153,7 @@ class HybridRepSetReader : public BsplineReaderBase
app_log() << " ClassName = " << bspline->getClassName() << std::endl;
// set info for Hybrid
initialize_hybridrep_atomic_centers(*bspline);
bool foundspline = spline_reader_.fill_spline_set(*bspline, spin, bandgroup);
bool foundspline = spline_reader_.createSplineDataSpaceLookforDumpFile(bandgroup, *bspline);
typename SA::HYBRIDBASE& hybrid_center_orbs = *bspline;
hybrid_center_orbs.resizeStorage(bspline->myV.size());
if (foundspline)
Expand Down Expand Up @@ -619,6 +619,11 @@ class HybridRepSetReader : public BsplineReaderBase
}
}

/** transforming planewave orbitals to 3D B-spline orbitals and 1D B-spline radial orbitals in real space.
* @param spin orbital dataset spin index
* @param bandgroup band info
* @param bspline the spline object being worked on
*/
void initialize_hybrid_pio_gather(const int spin, const BandInfoGroup& bandgroup, SA& bspline) const
{
auto& mybuilder = spline_reader_.mybuilder;
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineC2C.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class SplineC2C : public BsplineSet

template<class BSPLINESPO>
friend class SplineSetReader;
friend struct BsplineReaderBase;
friend struct BsplineReader;
};

extern template class SplineC2C<float>;
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineC2COMPTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class SplineC2COMPTarget : public BsplineSet

template<class BSPLINESPO>
friend class SplineSetReader;
friend struct BsplineReaderBase;
friend struct BsplineReader;
};

extern template class SplineC2COMPTarget<float>;
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineC2R.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class SplineC2R : public BsplineSet

template<class BSPLINESPO>
friend class SplineSetReader;
friend struct BsplineReaderBase;
friend struct BsplineReader;
};

extern template class SplineC2R<float>;
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineC2ROMPTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class SplineC2ROMPTarget : public BsplineSet

template<class BSPLINESPO>
friend class SplineSetReader;
friend struct BsplineReaderBase;
friend struct BsplineReader;
};

extern template class SplineC2ROMPTarget<float>;
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineR2R.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class SplineR2R : public BsplineSet

template<class BSPLINESPO>
friend class SplineSetReader;
friend struct BsplineReaderBase;
friend struct BsplineReader;
};

extern template class SplineR2R<float>;
Expand Down
29 changes: 20 additions & 9 deletions src/QMCWaveFunctions/BsplineFactory/SplineSetReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,21 @@ class OneSplineOrbData
/** General SplineSetReader to handle any unitcell
*/
template<typename SA>
class SplineSetReader : public BsplineReaderBase
class SplineSetReader : public BsplineReader
{
public:
using DataType = typename SA::DataType;
using SplineType = typename SA::SplineType;

SplineSetReader(EinsplineSetBuilder* e) : BsplineReaderBase(e) {}
SplineSetReader(EinsplineSetBuilder* e) : BsplineReader(e) {}

std::unique_ptr<SPOSet> create_spline_set(const std::string& my_name,
int spin,
const BandInfoGroup& bandgroup) override
{
auto bspline = std::make_unique<SA>(my_name);
app_log() << " ClassName = " << bspline->getClassName() << std::endl;
bool foundspline = fill_spline_set(*bspline, spin, bandgroup);
bool foundspline = createSplineDataSpaceLookforDumpFile(bandgroup, *bspline);
if (foundspline)
{
Timer now;
Expand Down Expand Up @@ -184,11 +184,13 @@ class SplineSetReader : public BsplineReaderBase
return bspline;
}

bool fill_spline_set(SA& bspline, int spin, const BandInfoGroup& bandgroup) const
/** create data space in the spline object and try open spline dump files.
* @param bandgroup band info
* @param bspline the spline object being worked on
* @return true if dumpfile pass class name and data type size check
*/
bool createSplineDataSpaceLookforDumpFile(const BandInfoGroup& bandgroup, SA& bspline) const
{
ReportEngine PRE("SplineSetReader", "create_spline_set(spin,SPE*)");
//Timer c_prep, c_unpack,c_fft, c_phase, c_spline, c_newphase, c_h5, c_init;
//double t_prep=0.0, t_unpack=0.0, t_fft=0.0, t_phase=0.0, t_spline=0.0, t_newphase=0.0, t_h5=0.0, t_init=0.0;
if (bspline.isComplex())
app_log() << " Using complex einspline table" << std::endl;
else
Expand Down Expand Up @@ -230,6 +232,11 @@ class SplineSetReader : public BsplineReaderBase
return foundspline;
}

/** read planewave coefficients from h5 file
* @param s data set full path in h5
* @param h5f hdf5 file handle
* @param cG vector to store coefficients
*/
void readOneOrbitalCoefs(const std::string& s, hdf_archive& h5f, Vector<std::complex<double>>& cG) const
{
if (!h5f.readEntry(cG, s))
Expand All @@ -251,7 +258,10 @@ class SplineSetReader : public BsplineReaderBase
}
}

/** initialize the splines
/** transforming planewave orbitals to 3D B-spline orbitals in real space.
* @param spin orbital dataset spin index
* @param bandgroup band info
* @param bspline the spline object being worked on
*/
void initialize_spline_pio_gather(const int spin, const BandInfoGroup& bandgroup, SA& bspline) const
{
Expand Down Expand Up @@ -281,8 +291,9 @@ class SplineSetReader : public BsplineReaderBase
oneband.fft_spline(cG, mybuilder->Gvecs[0], mybuilder->primcell_kpoints[ti], rotate);
bspline.set_spline(&oneband.get_spline_r(), &oneband.get_spline_i(), cur_band.TwistIndex, iorb, 0);
}
band_group_comm.getGroupLeaderComm()->barrier();

{
band_group_comm.getGroupLeaderComm()->barrier();
Timer now;
bspline.gather_tables(band_group_comm.getGroupLeaderComm());
app_log() << " Time to gather the table = " << now.elapsed() << std::endl;
Expand Down
10 changes: 5 additions & 5 deletions src/QMCWaveFunctions/BsplineFactory/createBsplineReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,34 @@
namespace qmcplusplus
{
///forward declaration
struct BsplineReaderBase;
struct BsplineReader;
class EinsplineSetBuilder;

/** create a reader which handles complex (double size real) splines, C2R or C2C case
* spline storage and computation precision is double
*/
std::unique_ptr<BsplineReaderBase> createBsplineComplexDouble(EinsplineSetBuilder* e,
std::unique_ptr<BsplineReader> createBsplineComplexDouble(EinsplineSetBuilder* e,
bool hybrid_rep,
const std::string& useGPU);

/** create a reader which handles complex (double size real) splines, C2R or C2C case
* spline storage and computation precision is float
*/
std::unique_ptr<BsplineReaderBase> createBsplineComplexSingle(EinsplineSetBuilder* e,
std::unique_ptr<BsplineReader> createBsplineComplexSingle(EinsplineSetBuilder* e,
bool hybrid_rep,
const std::string& useGPU);

/** create a reader which handles real splines, R2R case
* spline storage and computation precision is double
*/
std::unique_ptr<BsplineReaderBase> createBsplineRealDouble(EinsplineSetBuilder* e,
std::unique_ptr<BsplineReader> createBsplineRealDouble(EinsplineSetBuilder* e,
bool hybrid_rep,
const std::string& useGPU);

/** create a reader which handles real splines, R2R case
* spline storage and computation precision is float
*/
std::unique_ptr<BsplineReaderBase> createBsplineRealSingle(EinsplineSetBuilder* e,
std::unique_ptr<BsplineReader> createBsplineRealSingle(EinsplineSetBuilder* e,
bool hybrid_rep,
const std::string& useGPU);

Expand Down
Loading

0 comments on commit 30e06c9

Please sign in to comment.