Skip to content

Commit

Permalink
Merge pull request #1338 from lattice/feature/tm_force
Browse files Browse the repository at this point in the history
minimal trial implementation for a twisted clover determinant derivative (pre-draft)
  • Loading branch information
weinbe2 authored Dec 21, 2023
2 parents 1914dc3 + e545e36 commit fd50676
Show file tree
Hide file tree
Showing 38 changed files with 2,524 additions and 1,029 deletions.
53 changes: 43 additions & 10 deletions include/clover_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@

namespace quda {

/**
@brief Compile-time query for clover fermion support.
@return true when the GPU_CLOVER_DIRAC macro is defined for this
translation unit, false otherwise.
*/
constexpr bool is_enabled_clover()
{
#if defined(GPU_CLOVER_DIRAC)
  constexpr bool clover_enabled = true;
#else
  constexpr bool clover_enabled = false;
#endif
  return clover_enabled;
}

namespace clover
{

Expand Down Expand Up @@ -463,6 +476,29 @@ namespace quda {
*/
void cloverInvert(CloverField &clover, bool computeTraceLog);

/**
@brief Driver for the clover force computation. Eventually the
construction of the x and p fields will be delegated to this
function, but for now, we pre-compute these and pass them in.
@param mom[in,out] Momentum field to be updated
@param gaugeEx[in] Extended gauge field
@param gauge[in] Gauge field
@param clover[in] Clover field
@param x[in] Vector of quark solution fields
@param x0[in] Vector of auxiliary quark fields for determinant ratio
@param coeff[in] Vector of coefficients for the quark field outer
products
@param epsilon[in] Vector of scalar coefficient pairs (one per
parity) for the clover sigma outer product
@param sigma_coeff[in] Coefficient for the tr log clover force
@param detratio[in] Whether to compute determinant ratio
@param parity[in] Which parity we need to compute the tr log clover force for
*/
void computeCloverForce(GaugeField &mom, const GaugeField &gaugeEx, const GaugeField &gauge,
const CloverField &clover, cvector_ref<ColorSpinorField> &x, cvector_ref<ColorSpinorField> &x0,
const std::vector<double> &coeff, const std::vector<array<double, 2>> &epsilon,
double sigma_coeff, bool detratio, QudaInvertParam &param);

/**
@brief Compute the force contribution from the solver solution fields
Expand All @@ -480,9 +516,8 @@ namespace quda {
@param p Intermediate vectors (both parities)
@param coeff Multiplicative coefficient (e.g., dt * residue)
*/
void computeCloverForce(GaugeField& force, const GaugeField& U,
std::vector<ColorSpinorField*> &x, std::vector<ColorSpinorField*> &p,
std::vector<double> &coeff);
void computeCloverForce(GaugeField &force, const GaugeField &U, cvector_ref<const ColorSpinorField> &x,
cvector_ref<const ColorSpinorField> &p, const std::vector<double> &coeff);
/**
@brief Compute the outer product from the solver solution fields
arising from the diagonal term of the fermion bilinear in
Expand All @@ -493,19 +528,18 @@ namespace quda {
@param p[in] Intermediate vectors (both parities)
@param coeff[in] Multiplicative coefficient (e.g., dt * residue), one for each parity
*/
void computeCloverSigmaOprod(GaugeField& oprod,
std::vector<ColorSpinorField*> &x,
std::vector<ColorSpinorField*> &p,
std::vector< std::vector<double> > &coeff);
void computeCloverSigmaOprod(GaugeField &oprod, cvector_ref<const ColorSpinorField> &x,
cvector_ref<const ColorSpinorField> &p, const std::vector<array<double, 2>> &coeff);
/**
@brief Compute the matrix tensor field necessary for the force calculation from
the clover trace action. This computes a tensor field [mu,nu].
@param output The computed matrix field (tensor matrix field)
@param clover The input clover field
@param coeff Scalar coefficient multiplying the result (e.g., stepsize)
@param parity The field parity we are working on
*/
void computeCloverSigmaTrace(GaugeField &output, const CloverField &clover, double coeff);
void computeCloverSigmaTrace(GaugeField &output, const CloverField &clover, double coeff, int parity);

/**
@brief Compute the derivative of the clover matrix in the direction
Expand All @@ -516,9 +550,8 @@ namespace quda {
@param gauge The input gauge field
@param oprod The input outer-product field (tensor matrix field)
@param coeff Multiplicative coefficient (e.g., clover coefficient)
@param parity The field parity we are working on
*/
void cloverDerivative(GaugeField &force, GaugeField &gauge, GaugeField &oprod, double coeff, QudaParity parity);
void cloverDerivative(GaugeField &force, const GaugeField &gauge, const GaugeField &oprod, double coeff);

/**
@brief This function is used for copying from a source clover field to a destination clover field
Expand Down
29 changes: 16 additions & 13 deletions include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ namespace quda
size_t norm_offset = 0; /** offset to the norm (if applicable) */

// multi-GPU parameters
array_2d<void *, 2, QUDA_MAX_DIM> ghost = {}; // pointers to the ghost regions - NULL by default
mutable array_2d<void *, 2, QUDA_MAX_DIM> ghost = {}; // pointers to the ghost regions - NULL by default
mutable lat_dim_t ghostFace = {}; // the size of each face
mutable lat_dim_t ghostFaceCB = {}; // the size of each checkboarded face
mutable array<void *, 2 *QUDA_MAX_DIM> ghost_buf = {}; // wrapper that points to current ghost zone
Expand Down Expand Up @@ -510,7 +510,7 @@ namespace quda
@param[in] nFace Depth of each halo
@param[in] spin_project Whether the halos are spin projected (Wilson-type fermions only)
*/
void createComms(int nFace, bool spin_project = true);
void createComms(int nFace, bool spin_project = true) const;

/**
@brief Packs the ColorSpinorField's ghost zone
Expand All @@ -530,7 +530,7 @@ namespace quda
*/
void packGhost(const int nFace, const QudaParity parity, const int dagger, const qudaStream_t &stream,
MemoryLocation location[2 * QUDA_MAX_DIM], MemoryLocation location_label, bool spin_project,
double a = 0, double b = 0, double c = 0, int shmem = 0);
double a = 0, double b = 0, double c = 0, int shmem = 0) const;

/**
Pack the field halos in preparation for halo exchange, e.g., for Dslash
Expand All @@ -550,7 +550,7 @@ namespace quda
*/
void pack(int nFace, int parity, int dagger, const qudaStream_t &stream, MemoryLocation location[2 * QUDA_MAX_DIM],
MemoryLocation location_label, bool spin_project = true, double a = 0, double b = 0, double c = 0,
int shmem = 0);
int shmem = 0) const;

/**
@brief Initiate the gpu to cpu send of the ghost zone (halo)
Expand All @@ -559,7 +559,7 @@ namespace quda
@param dir The direction (QUDA_BACKWARDS or QUDA_FORWARDS)
@param stream The array of streams to use
*/
void sendGhost(void *ghost_spinor, const int dim, const QudaDirection dir, const qudaStream_t &stream);
void sendGhost(void *ghost_spinor, const int dim, const QudaDirection dir, const qudaStream_t &stream) const;

/**
Initiate the cpu to gpu send of the ghost zone (halo)
Expand All @@ -568,7 +568,7 @@ namespace quda
@param dir The direction (QUDA_BACKWARDS or QUDA_FORWARDS)
@param stream The array of streams to use
*/
void unpackGhost(const void *ghost_spinor, const int dim, const QudaDirection dir, const qudaStream_t &stream);
void unpackGhost(const void *ghost_spinor, const int dim, const QudaDirection dir, const qudaStream_t &stream) const;

/**
@brief Copies the ghost to the host from the device, prior to
Expand All @@ -577,15 +577,15 @@ namespace quda
the scatter-centric direction (0=backwards,1=forwards)
@param[in] stream The stream in which to do the copy
*/
void gather(int dir, const qudaStream_t &stream);
void gather(int dir, const qudaStream_t &stream) const;

/**
@brief Initiate halo communication receive
@param[in] d d=[2*dim+dir], where dim is dimension and dir is
the scatter-centric direction (0=backwards,1=forwards)
@param[in] gdr Whether we are using GDR on the receive side
*/
void recvStart(int dir, const qudaStream_t &stream, bool gdr = false);
void recvStart(int dir, const qudaStream_t &stream, bool gdr = false) const;

/**
@brief Initiate halo communication sending
Expand All @@ -596,7 +596,7 @@ namespace quda
@param[in] gdr Whether we are using GDR on the send side
@param[in] remote_write Whether we are writing direct to remote memory (or using copy engines)
*/
void sendStart(int d, const qudaStream_t &stream, bool gdr = false, bool remote_write = false);
void sendStart(int d, const qudaStream_t &stream, bool gdr = false, bool remote_write = false) const;

/**
@brief Initiate halo communication
Expand All @@ -606,7 +606,7 @@ namespace quda
@param[in] gdr_send Whether we are using GDR on the send side
@param[in] gdr_recv Whether we are using GDR on the receive side
*/
void commsStart(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false);
void commsStart(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false) const;

/**
@brief Non-blocking query if the halo communication has completed
Expand All @@ -616,7 +616,7 @@ namespace quda
@param[in] gdr_send Whether we are using GDR on the send side
@param[in] gdr_recv Whether we are using GDR on the receive side
*/
int commsQuery(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false);
int commsQuery(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false) const;

/**
@brief Wait on halo communication to complete
Expand All @@ -626,7 +626,7 @@ namespace quda
@param[in] gdr_send Whether we are using GDR on the send side
@param[in] gdr_recv Whether we are using GDR on the receive side
*/
void commsWait(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false);
void commsWait(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false) const;

/**
@brief Unpacks the ghost from host to device after
Expand All @@ -636,7 +636,7 @@ namespace quda
@param[in] stream The stream in which to do the copy. If
-1 is passed then the copy will be issued to the d^th stream
*/
void scatter(int d, const qudaStream_t &stream);
void scatter(int d, const qudaStream_t &stream) const;

/**
Do the exchange between neighbouring nodes of the data in
Expand Down Expand Up @@ -725,6 +725,9 @@ namespace quda
ColorSpinorField &Even();
ColorSpinorField &Odd();

/**
@brief Read-only parity accessor.
@param[in] parity Requested parity
@return Even() when parity is QUDA_EVEN_PARITY, otherwise Odd()
*/
const ColorSpinorField &operator[](QudaParity parity) const
{
  if (parity == QUDA_EVEN_PARITY) return Even();
  return Odd();
}
/**
@brief Mutable parity accessor.
@param[in] parity Requested parity
@return Even() when parity is QUDA_EVEN_PARITY, otherwise Odd()
*/
ColorSpinorField &operator[](QudaParity parity)
{
  if (parity == QUDA_EVEN_PARITY) return Even();
  return Odd();
}

CompositeColorSpinorField &Components() { return components; };

/**
Expand Down
2 changes: 1 addition & 1 deletion include/gauge_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ namespace quda {
@param recon The reconstruction type
@return the pointer to the extended gauge field
*/
GaugeField *createExtendedGauge(GaugeField &in, const lat_dim_t &R, TimeProfile &profile,
GaugeField *createExtendedGauge(const GaugeField &in, const lat_dim_t &R, TimeProfile &profile = getProfile(),
bool redundant_comms = false, QudaReconstructType recon = QUDA_RECONSTRUCT_INVALID);

/**
Expand Down
Loading

0 comments on commit fd50676

Please sign in to comment.