Skip to content

Commit

Permalink
Merge pull request #468 from lattice/feature/milc-rhmc-cleanup
Browse files Browse the repository at this point in the history
Feature/milc rhmc cleanup
  • Loading branch information
mathiaswagner committed May 10, 2016
2 parents 2be9efe + 8e1daf1 commit 97672d3
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 45 deletions.
53 changes: 26 additions & 27 deletions lib/blas_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,40 +275,39 @@ template <template <typename Float, typename FloatN> class Functor,
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
} else if (x.Precision() == QUDA_SINGLE_PRECISION) {
if (x.Nspin() == 4) {
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
const int M = 1;
#endif
if (x.Nspin() == 4) {
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
Spinor<float4,float4,float4,M,writeX,0> X(x);
Spinor<float4,float4,float4,M,writeY,1> Y(y);
Spinor<float4,float4,float4,M,writeZ,2> Z(z);
Spinor<float4,float4,float4,M,writeW,3> W(w);
Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
BlasCuda<float4,M,
Spinor<float4,float4,float4,M,writeX,0>, Spinor<float4,float4,float4,M,writeY,1>,
Spinor<float4,float4,float4,M,writeZ,2>, Spinor<float4,float4,float4,M,writeW,3>,
Functor<float2, float4> > blas(X, Y, Z, W, f, x.Length()/(4*M), bytes, norm_bytes);
blas.apply(*blasStream);
Spinor<float4,float4,float4,M,writeX,0> X(x);
Spinor<float4,float4,float4,M,writeY,1> Y(y);
Spinor<float4,float4,float4,M,writeZ,2> Z(z);
Spinor<float4,float4,float4,M,writeW,3> W(w);
Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
BlasCuda<float4,M,
Spinor<float4,float4,float4,M,writeX,0>, Spinor<float4,float4,float4,M,writeY,1>,
Spinor<float4,float4,float4,M,writeZ,2>, Spinor<float4,float4,float4,M,writeW,3>,
Functor<float2, float4> > blas(X, Y, Z, W, f, x.Length()/(4*M), bytes, norm_bytes);
blas.apply(*blasStream);
#else
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
} else if (x.Nspin()==2 || x.Nspin()==1) {
} else if (x.Nspin()==2 || x.Nspin()==1) {
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
Spinor<float2,float2,float2,M,writeX,0> X(x);
Spinor<float2,float2,float2,M,writeY,1> Y(y);
Spinor<float2,float2,float2,M,writeZ,2> Z(z);
Spinor<float2,float2,float2,M,writeW,3> W(w);
Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
BlasCuda<float2,M,
Spinor<float2,float2,float2,M,writeX,0>, Spinor<float2,float2,float2,M,writeY,1>,
Spinor<float2,float2,float2,M,writeZ,2>, Spinor<float2,float2,float2,M,writeW,3>,
Functor<float2, float2> > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes);
blas.apply(*blasStream);
const int M = 1;
Spinor<float2,float2,float2,M,writeX,0> X(x);
Spinor<float2,float2,float2,M,writeY,1> Y(y);
Spinor<float2,float2,float2,M,writeZ,2> Z(z);
Spinor<float2,float2,float2,M,writeW,3> W(w);
Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
BlasCuda<float2,M,
Spinor<float2,float2,float2,M,writeX,0>, Spinor<float2,float2,float2,M,writeY,1>,
Spinor<float2,float2,float2,M,writeZ,2>, Spinor<float2,float2,float2,M,writeW,3>,
Functor<float2, float2> > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes);
blas.apply(*blasStream);
#else
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
} else { errorQuda("nSpin=%d is not supported\n", x.Nspin()); }
} else { errorQuda("nSpin=%d is not supported\n", x.Nspin()); }
} else {
if (x.Ncolor() != 3) { errorQuda("nColor = %d is not supported", x.Ncolor()); }
if (x.Nspin() == 4){ //wilson
Expand Down
43 changes: 26 additions & 17 deletions lib/unitarize_force_quda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ namespace { // anonymous
const Float& denominator = g[2]*(g[0]*g[1]-g[2]);
c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
c[2] = g[0]/denominator;
c[2] = g[0]/denominator;

tempq = c[1]*q + c[2]*qsq;
// Add a real scalar
Expand Down Expand Up @@ -451,18 +451,13 @@ namespace { // anonymous
Link vqv_dagger = v*qv_dagger;
Link temp = f[1]*vv_dagger + f[2]*vqv_dagger;


temp = f[1]*v_dagger + f[2]*qv_dagger;
Link conj_outer_prod = conj(outer_prod);


temp = f[1]*v + f[2]*v*q;
result = result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;

result = result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;


// now done with vv_dagger, I think
Link qsqv_dagger = q*qv_dagger;
Link pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
accumBothDerivatives(&result, v, pv_dagger, outer_prod); // 41 flops
Expand All @@ -474,6 +469,7 @@ namespace { // anonymous
Link sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
Link vqsq = vq*q;
accumBothDerivatives(&result, vqsq, sv_dagger, outer_prod); // 41 flops

return;
// 4528 flops - 17 matrix multiplies (198 flops each) + reciprocal root (approx 529 flops) + accumBothDerivatives (41 each) + miscellaneous
} // get unit force term
Expand All @@ -492,14 +488,18 @@ namespace { // anonymous

// This part of the calculation is always done in double precision
Matrix<complex<double>,3> v, result, oprod;

Matrix<complex<Float>,3> v_tmp, result_tmp, oprod_tmp;

for(int dir=0; dir<4; ++dir){
arg.force_old.load((Float*)(oprod.data), idx, dir, parity);
arg.gauge.load((Float*)(v.data), idx, dir, parity);
arg.force_old.load((Float*)(oprod_tmp.data), idx, dir, parity);
arg.gauge.load((Float*)(v_tmp.data), idx, dir, parity);
v = v_tmp;
oprod = oprod_tmp;

getUnitarizeForceSite<double>(result, v, oprod, arg);
result_tmp = result;

arg.force.save((Float*)(oprod.data), idx, dir, parity);
arg.force.save((Float*)(result_tmp.data), idx, dir, parity);
} // 4*4528 flops per site
return;
} // getUnitarizeForceField
Expand All @@ -508,16 +508,20 @@ namespace { // anonymous
template <typename Float, typename Arg>
void unitarizeForceCPU(Arg &arg) {
Matrix<complex<double>,3> v, result, oprod;
Matrix<complex<Float>,3> v_tmp, result_tmp, oprod_tmp;

for (int parity=0; parity<2; parity++) {
for (int i=0; i<arg.threads/2; i++) {
for (int dir=0; dir<4; dir++) {
arg.force_old.load((Float*)(oprod.data), i, dir, parity);
arg.gauge.load((Float*)(v.data), i, dir, parity);
arg.force_old.load((Float*)(oprod_tmp.data), i, dir, parity);
arg.gauge.load((Float*)(v_tmp.data), i, dir, parity);
v = v_tmp;
oprod = oprod_tmp;

getUnitarizeForceSite<double>(result, v, oprod, arg);

arg.force.save((Float*)(oprod.data), i, dir, parity);
result_tmp = result;
arg.force.save((Float*)(result_tmp.data), i, dir, parity);
}
}
}
Expand All @@ -534,23 +538,27 @@ namespace { // anonymous
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<double>(arg);
} else {
} else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
typedef gauge::MILCOrder<float,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<float>(arg);
} else {
errorQuda("Precision = %d not supported", gauge.Precision());
}
} else if (gauge.Order() == QUDA_QDP_GAUGE_ORDER) {
if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
typedef gauge::QDPOrder<double,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<double>(arg);
} else {
} else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
typedef gauge::QDPOrder<float,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<float>(arg);
} else {
errorQuda("Precision = %d not supported", gauge.Precision());
}
} else {
errorQuda("Only MILC and QDP gauge orders supported\n");
Expand Down Expand Up @@ -588,7 +596,8 @@ namespace { // anonymous
void postTune() { cudaMemset(arg.fails, 0, sizeof(int)); } // reset fails counter

long long flops() const { return 4ll*4528*meta.Volume(); }

long long bytes() const { return 4ll * arg.threads * (arg.force.Bytes() + arg.force_old.Bytes() + arg.gauge.Bytes()); }

TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
}; // UnitarizeForce

Expand Down Expand Up @@ -621,7 +630,7 @@ namespace { // anonymous
errorQuda("Mixed precision not supported");

if (gauge.Order() != oldForce.Order() || gauge.Order() != newForce.Order())
errorQuda("Mixed data ordering not supported not supported");
errorQuda("Mixed data ordering not supported");

if (gauge.Order() == QUDA_FLOAT2_GAUGE_ORDER) {
if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
Expand Down
2 changes: 1 addition & 1 deletion tests/hisq_unitarize_force_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ hisq_force_init()

setDims(gaugeParam.X);

gaugeParam.cpu_prec = link_prec;
gaugeParam.cpu_prec = QUDA_DOUBLE_PRECISION;
gaugeParam.cuda_prec = link_prec;
gaugeParam.reconstruct = link_recon;
gaugeParam.gauge_order = QUDA_QDP_GAUGE_ORDER;
Expand Down

0 comments on commit 97672d3

Please sign in to comment.