Skip to content

Commit

Permalink
Fixed single-precision HISQ force unitarization and added bytes measure
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed May 8, 2016
1 parent aa9b0be commit 17e8735
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions lib/unitarize_force_quda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ namespace { // anonymous
const Float& denominator = g[2]*(g[0]*g[1]-g[2]);
c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
c[2] = g[0]/denominator;
c[2] = g[0]/denominator;

tempq = c[1]*q + c[2]*qsq;
// Add a real scalar
Expand Down Expand Up @@ -451,18 +451,13 @@ namespace { // anonymous
Link vqv_dagger = v*qv_dagger;
Link temp = f[1]*vv_dagger + f[2]*vqv_dagger;


temp = f[1]*v_dagger + f[2]*qv_dagger;
Link conj_outer_prod = conj(outer_prod);


temp = f[1]*v + f[2]*v*q;
result = result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;

result = result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;


// now done with vv_dagger, I think
Link qsqv_dagger = q*qv_dagger;
Link pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
accumBothDerivatives(&result, v, pv_dagger, outer_prod); // 41 flops
Expand All @@ -474,6 +469,7 @@ namespace { // anonymous
Link sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
Link vqsq = vq*q;
accumBothDerivatives(&result, vqsq, sv_dagger, outer_prod); // 41 flops

return;
// 4528 flops - 17 matrix multiplies (198 flops each) + reciprocal root (approx 529 flops) + accumBothDerivatives (41 each) + miscellaneous
} // get unit force term
Expand All @@ -492,14 +488,18 @@ namespace { // anonymous

// This part of the calculation is always done in double precision
Matrix<complex<double>,3> v, result, oprod;

Matrix<complex<Float>,3> v_tmp, result_tmp, oprod_tmp;

for(int dir=0; dir<4; ++dir){
arg.force_old.load((Float*)(oprod.data), idx, dir, parity);
arg.gauge.load((Float*)(v.data), idx, dir, parity);
arg.force_old.load((Float*)(oprod_tmp.data), idx, dir, parity);
arg.gauge.load((Float*)(v_tmp.data), idx, dir, parity);
v = v_tmp;
oprod = oprod_tmp;

getUnitarizeForceSite<double>(result, v, oprod, arg);
result_tmp = result;

arg.force.save((Float*)(result.data), idx, dir, parity);
arg.force.save((Float*)(result_tmp.data), idx, dir, parity);
} // 4*4528 flops per site
return;
} // getUnitarizeForceField
Expand All @@ -508,16 +508,20 @@ namespace { // anonymous
template <typename Float, typename Arg>
void unitarizeForceCPU(Arg &arg) {
Matrix<complex<double>,3> v, result, oprod;
Matrix<complex<Float>,3> v_tmp, result_tmp, oprod_tmp;

for (int parity=0; parity<2; parity++) {
for (int i=0; i<arg.threads/2; i++) {
for (int dir=0; dir<4; dir++) {
arg.force_old.load((Float*)(oprod.data), i, dir, parity);
arg.gauge.load((Float*)(v.data), i, dir, parity);
arg.force_old.load((Float*)(oprod_tmp.data), i, dir, parity);
arg.gauge.load((Float*)(v_tmp.data), i, dir, parity);
v = v_tmp;
oprod = oprod_tmp;

getUnitarizeForceSite<double>(result, v, oprod, arg);

arg.force.save((Float*)(result.data), i, dir, parity);
result_tmp = result;
arg.force.save((Float*)(result_tmp.data), i, dir, parity);
}
}
}
Expand All @@ -534,23 +538,27 @@ namespace { // anonymous
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<double>(arg);
} else {
} else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
typedef gauge::MILCOrder<float,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<float>(arg);
} else {
errorQuda("Precision = %d not supported", gauge.Precision());
}
} else if (gauge.Order() == QUDA_QDP_GAUGE_ORDER) {
if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
typedef gauge::QDPOrder<double,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<double>(arg);
} else {
} else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
typedef gauge::QDPOrder<float,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<float>(arg);
} else {
errorQuda("Precision = %d not supported", gauge.Precision());
}
} else {
errorQuda("Only MILC and QDP gauge orders supported\n");
Expand Down Expand Up @@ -588,7 +596,8 @@ namespace { // anonymous
void postTune() { cudaMemset(arg.fails, 0, sizeof(int)); } // reset fails counter

long long flops() const { return 4ll*4528*meta.Volume(); }

long long bytes() const { return 4ll * arg.threads * (arg.force.Bytes() + arg.force_old.Bytes() + arg.gauge.Bytes()); }

TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
}; // UnitarizeForce

Expand Down Expand Up @@ -621,7 +630,7 @@ namespace { // anonymous
errorQuda("Mixed precision not supported");

if (gauge.Order() != oldForce.Order() || gauge.Order() != newForce.Order())
errorQuda("Mixed data ordering not supported not supported");
errorQuda("Mixed data ordering not supported");

if (gauge.Order() == QUDA_FLOAT2_GAUGE_ORDER) {
if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
Expand Down

0 comments on commit 17e8735

Please sign in to comment.