Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix/tifr #260

Merged
merged 7 commits into from
May 28, 2015
30 changes: 28 additions & 2 deletions include/gauge_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,15 @@ namespace quda {
#if __COMPUTE_CAPABILITY__ >= 200
const int hasPhase;
const size_t phaseOffset;
void *backup_h; //! host memory for backing up the field when tuning
size_t bytes;
#endif

FloatNOrder(const GaugeField &u, Float *gauge_=0, Float **ghost_=0) :
reconstruct(u), volumeCB(u.VolumeCB()), stride(u.Stride()), geometry(u.Geometry())
#if __COMPUTE_CAPABILITY__ >= 200
, hasPhase((u.Reconstruct() == QUDA_RECONSTRUCT_9 || u.Reconstruct() == QUDA_RECONSTRUCT_13) ? 1 : 0),
phaseOffset(u.PhaseOffset())
phaseOffset(u.PhaseOffset()), backup_h(0), bytes(u.Bytes())
#endif
{
if (gauge_) { gauge[0] = gauge_; gauge[1] = (Float*)((char*)gauge_ + u.Bytes()/2);
Expand All @@ -467,7 +469,7 @@ namespace quda {
: reconstruct(order.reconstruct), volumeCB(order.volumeCB), stride(order.stride),
geometry(order.geometry)
#if __COMPUTE_CAPABILITY__ >= 200
, hasPhase(order.hasPhase), phaseOffset(order.phaseOffset)
, hasPhase(order.hasPhase), phaseOffset(order.phaseOffset), backup_h(0), bytes(order.bytes)
#endif
{
gauge[0] = order.gauge[0];
Expand Down Expand Up @@ -618,6 +620,30 @@ namespace quda {
}
}

/**
used to backup the field to the host when tuning
*/
void save() {
#if __COMPUTE_CAPABILITY__ >= 200
if (backup_h) errorQuda("Already allocated host backup");
backup_h = safe_malloc(bytes);
cudaMemcpy(backup_h, gauge[0], bytes, cudaMemcpyDeviceToHost);
checkCudaError();
#endif
}

/**
restore the field from the host after tuning
*/
void load() {
#if __COMPUTE_CAPABILITY__ >= 200
cudaMemcpy(gauge[0], backup_h, bytes, cudaMemcpyHostToDevice);
host_free(backup_h);
backup_h = 0;
checkCudaError();
#endif
}

size_t Bytes() const { return reconLen * sizeof(Float); }
};

Expand Down
67 changes: 55 additions & 12 deletions include/quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,31 +635,74 @@ extern "C" {


/**
* Take a gauge field on the host, extend it and load it onto the device.
* Take a gauge field on the host, load it onto the device and extend it.
* Return a pointer to the extended gauge field.
*
* @param gauge The CPU gauge field (optional - if set to 0 then the gauge field zeroed)
* @param geometry The geometry of the matrix field to create (1 - scaler, 4 - vector, 6 - tensor)
* @param param The parameters of the external field and the field to be created
* @return Pointer to the gauge field (cast as a void*)
*/
void* createExtendedGaugeField(void* gauge, int geometry, QudaGaugeParam* param);

void* createGaugeField(void* gauge, int geometry, QudaGaugeParam* param);
void* createExtendedGaugeFieldQuda(void* gauge, int geometry, QudaGaugeParam* param);

void saveGaugeField(void* outGauge, void* inGauge, QudaGaugeParam* param);
/**
* Allocate a gauge (matrix) field on the device and optionally download a host gauge field.
*
* @param gauge The host gauge field (optional - if set to 0 then the gauge field zeroed)
* @param geometry The geometry of the matrix field to create (1 - scaler, 4 - vector, 6 - tensor)
* @param param The parameters of the external field and the field to be created
* @return Pointer to the gauge field (cast as a void*)
*/
void* createGaugeFieldQuda(void* gauge, int geometry, QudaGaugeParam* param);

void extendGaugeField(void* outGauge, void* inGauge);
/**
* Store a gauge (matrix) field on the device and optionally download a CPU gauge field.
*
* @param outGauge Pointer to the host gauge field
* @param inGauge Pointer to the device gauge field
* @param param The parameters of the host and device fields
*/
void saveGaugeFieldQuda(void* outGauge, void* inGauge, QudaGaugeParam* param);

/**
* Take a gauge field on the device and extend it
*
* @param outGauge Pointer to the output extended device gauge field
* @param inGauge Pointer to the input device gauge field
*/
void extendGaugeFieldQuda(void* outGauge, void* inGauge);

/**
* Reinterpret gauge as a pointer to cudaGaugeField and call destructor.
* Reinterpret gauge as a pointer to cudaGaugeField and call destructor.
*
* @param gauge Gauge field to be freed
*/
void destroyQudaGaugeField(void* gauge);
void destroyGaugeFieldQuda(void* gauge);

/**
* Compute the clover field and its inverse from the resident gauge field
*
* @param param The parameters of the clover field to create
*/
void createCloverQuda(QudaInvertParam* param);

/**
* Compute the sigma trace field (part of HMC)
*
* @param out Sigma trace field
* @param dummy (not used)
* @param mu mu direction
* @param nu nu direction
* @param dim array of local field dimensions
*/
void computeCloverTraceQuda(void* out, void* dummy, int mu, int nu, int dim[4]);

void computeCloverTraceQuda(void* out, void* clover, int mu, int nu, int dim[4]);

/**
* Compute the derivative of the clover term
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing @param, might also use '@sa' ?

*/
void computeCloverDerivativeQuda(void* out, void* gauge, void* oprod, int mu, int nu,
double coeff,
QudaParity parity, QudaGaugeParam* param, int conjugate);
double coeff,
QudaParity parity, QudaGaugeParam* param, int conjugate);

/**
* Compute the quark-field outer product needed for gauge generation
Expand Down
2 changes: 2 additions & 0 deletions include/quda_milc_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ extern "C" {

void* qudaCreateExtendedGaugeField(void* gauge, int geometry, int precision);

void* qudaResidentExtendedGaugeField(void* gauge, int geometry, int precision);

void* qudaCreateGaugeField(void* gauge, int geometry, int precision);

void qudaSaveGaugeField(void* gauge, void* inGauge);
Expand Down
Loading