From 17e87355b2f8a7dce899461fe51e575477c48922 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Sat, 7 May 2016 17:48:10 -0700 Subject: [PATCH] Fixed single-precision hisq force unitarization and added bytes measure --- lib/unitarize_force_quda.cu | 43 ++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/lib/unitarize_force_quda.cu b/lib/unitarize_force_quda.cu index 844e183402..4bf5b07d42 100644 --- a/lib/unitarize_force_quda.cu +++ b/lib/unitarize_force_quda.cu @@ -399,7 +399,7 @@ namespace { // anonymous const Float& denominator = g[2]*(g[0]*g[1]-g[2]); c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator; c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator; - c[2] = g[0]/denominator; + c[2] = g[0]/denominator; tempq = c[1]*q + c[2]*qsq; // Add a real scalar @@ -451,18 +451,13 @@ namespace { // anonymous Link vqv_dagger = v*qv_dagger; Link temp = f[1]*vv_dagger + f[2]*vqv_dagger; - temp = f[1]*v_dagger + f[2]*qv_dagger; Link conj_outer_prod = conj(outer_prod); - temp = f[1]*v + f[2]*v*q; result = result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger; - result = result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger; - - // now done with vv_dagger, I think Link qsqv_dagger = q*qv_dagger; Link pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger; accumBothDerivatives(&result, v, pv_dagger, outer_prod); // 41 flops @@ -474,6 +469,7 @@ namespace { // anonymous Link sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger; Link vqsq = vq*q; accumBothDerivatives(&result, vqsq, sv_dagger, outer_prod); // 41 flops + return; // 4528 flops - 17 matrix multiplies (198 flops each) + reciprocal root (approx 529 flops) + accumBothDerivatives (41 each) + miscellaneous } // get unit force term @@ -492,14 +488,18 @@ namespace { // anonymous // This part of the calculation is always done in double precision Matrix,3> v, result, oprod; - + Matrix,3> v_tmp, result_tmp, oprod_tmp; + for(int dir=0; dir<4; ++dir){ - arg.force_old.load((Float*)(oprod.data), idx, dir, parity); - arg.gauge.load((Float*)(v.data), idx, dir, parity); + arg.force_old.load((Float*)(oprod_tmp.data), idx, dir, parity); + arg.gauge.load((Float*)(v_tmp.data), idx, dir, parity); + v = v_tmp; + oprod = oprod_tmp; getUnitarizeForceSite(result, v, oprod, arg); + result_tmp = result; - arg.force.save((Float*)(result.data), idx, dir, parity); + arg.force.save((Float*)(result_tmp.data), idx, dir, parity); } // 4*4528 flops per site return; } // getUnitarizeForceField @@ -508,16 +508,20 @@ namespace { // anonymous template void unitarizeForceCPU(Arg &arg) { Matrix,3> v, result, oprod; + Matrix,3> v_tmp, result_tmp, oprod_tmp; for (int parity=0; parity<2; parity++) { for (int i=0; i(result, v, oprod, arg); - arg.force.save((Float*)(result.data), i, dir, parity); + result_tmp = result; + arg.force.save((Float*)(result_tmp.data), i, dir, parity); } } } @@ -534,11 +538,13 @@ namespace { // anonymous UnitarizeForceArg arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error); unitarizeForceCPU(arg); - } else { + } else if (gauge.Precision() == QUDA_SINGLE_PRECISION) { typedef gauge::MILCOrder G; UnitarizeForceArg arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error); unitarizeForceCPU(arg); + } else { + errorQuda("Precision = %d not supported", gauge.Precision()); } } else if (gauge.Order() == QUDA_QDP_GAUGE_ORDER) { if (gauge.Precision() == QUDA_DOUBLE_PRECISION) { @@ -546,11 +552,13 @@ namespace { // anonymous UnitarizeForceArg arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error); unitarizeForceCPU(arg); - } else { + } else if (gauge.Precision() == QUDA_SINGLE_PRECISION) { typedef gauge::QDPOrder G; UnitarizeForceArg arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter, max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error); unitarizeForceCPU(arg); + } else { + errorQuda("Precision = %d not supported", gauge.Precision()); } } else { errorQuda("Only MILC and QDP gauge orders supported\n"); @@ -588,7 +596,8 @@ namespace { // anonymous void postTune() { cudaMemset(arg.fails, 0, sizeof(int)); } // reset fails counter long long flops() const { return 4ll*4528*meta.Volume(); } - + long long bytes() const { return 4ll * arg.threads * (arg.force.Bytes() + arg.force_old.Bytes() + arg.gauge.Bytes()); } + TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); } }; // UnitarizeForce @@ -621,7 +630,7 @@ namespace { // anonymous errorQuda("Mixed precision not supported"); if (gauge.Order() != oldForce.Order() || gauge.Order() != newForce.Order()) - errorQuda("Mixed data ordering not supported not supported"); + errorQuda("Mixed data ordering not supported"); if (gauge.Order() == QUDA_FLOAT2_GAUGE_ORDER) { if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {