Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/milc rhmc cleanup #468

Merged
merged 4 commits into from
May 10, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions lib/blas_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,40 +275,39 @@ template <template <typename Float, typename FloatN> class Functor,
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
} else if (x.Precision() == QUDA_SINGLE_PRECISION) {
if (x.Nspin() == 4) {
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
const int M = 1;
#endif
if (x.Nspin() == 4) {
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC)
Spinor<float4,float4,float4,M,writeX,0> X(x);
Spinor<float4,float4,float4,M,writeY,1> Y(y);
Spinor<float4,float4,float4,M,writeZ,2> Z(z);
Spinor<float4,float4,float4,M,writeW,3> W(w);
Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
BlasCuda<float4,M,
Spinor<float4,float4,float4,M,writeX,0>, Spinor<float4,float4,float4,M,writeY,1>,
Spinor<float4,float4,float4,M,writeZ,2>, Spinor<float4,float4,float4,M,writeW,3>,
Functor<float2, float4> > blas(X, Y, Z, W, f, x.Length()/(4*M), bytes, norm_bytes);
blas.apply(*blasStream);
Spinor<float4,float4,float4,M,writeX,0> X(x);
Spinor<float4,float4,float4,M,writeY,1> Y(y);
Spinor<float4,float4,float4,M,writeZ,2> Z(z);
Spinor<float4,float4,float4,M,writeW,3> W(w);
Functor<float2, float4> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
BlasCuda<float4,M,
Spinor<float4,float4,float4,M,writeX,0>, Spinor<float4,float4,float4,M,writeY,1>,
Spinor<float4,float4,float4,M,writeZ,2>, Spinor<float4,float4,float4,M,writeW,3>,
Functor<float2, float4> > blas(X, Y, Z, W, f, x.Length()/(4*M), bytes, norm_bytes);
blas.apply(*blasStream);
#else
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
} else if (x.Nspin()==2 || x.Nspin()==1) {
} else if (x.Nspin()==2 || x.Nspin()==1) {
#if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_STAGGERED_DIRAC)
Spinor<float2,float2,float2,M,writeX,0> X(x);
Spinor<float2,float2,float2,M,writeY,1> Y(y);
Spinor<float2,float2,float2,M,writeZ,2> Z(z);
Spinor<float2,float2,float2,M,writeW,3> W(w);
Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
BlasCuda<float2,M,
Spinor<float2,float2,float2,M,writeX,0>, Spinor<float2,float2,float2,M,writeY,1>,
Spinor<float2,float2,float2,M,writeZ,2>, Spinor<float2,float2,float2,M,writeW,3>,
Functor<float2, float2> > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes);
blas.apply(*blasStream);
const int M = 1;
Spinor<float2,float2,float2,M,writeX,0> X(x);
Spinor<float2,float2,float2,M,writeY,1> Y(y);
Spinor<float2,float2,float2,M,writeZ,2> Z(z);
Spinor<float2,float2,float2,M,writeW,3> W(w);
Functor<float2, float2> f(make_float2(a.x, a.y), make_float2(b.x, b.y), make_float2(c.x, c.y));
BlasCuda<float2,M,
Spinor<float2,float2,float2,M,writeX,0>, Spinor<float2,float2,float2,M,writeY,1>,
Spinor<float2,float2,float2,M,writeZ,2>, Spinor<float2,float2,float2,M,writeW,3>,
Functor<float2, float2> > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes);
blas.apply(*blasStream);
#else
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
#endif
} else { errorQuda("nSpin=%d is not supported\n", x.Nspin()); }
} else { errorQuda("nSpin=%d is not supported\n", x.Nspin()); }
} else {
if (x.Ncolor() != 3) { errorQuda("nColor = %d is not supported", x.Ncolor()); }
if (x.Nspin() == 4){ //wilson
Expand Down
43 changes: 26 additions & 17 deletions lib/unitarize_force_quda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ namespace { // anonymous
const Float& denominator = g[2]*(g[0]*g[1]-g[2]);
c[0] = (g[0]*g[1]*g[1] - g[2]*(g[0]*g[0]+g[1]))/denominator;
c[1] = (-g[0]*g[0]*g[0] - g[2] + 2.*g[0]*g[1])/denominator;
c[2] = g[0]/denominator;
c[2] = g[0]/denominator;

tempq = c[1]*q + c[2]*qsq;
// Add a real scalar
Expand Down Expand Up @@ -451,18 +451,13 @@ namespace { // anonymous
Link vqv_dagger = v*qv_dagger;
Link temp = f[1]*vv_dagger + f[2]*vqv_dagger;


temp = f[1]*v_dagger + f[2]*qv_dagger;
Link conj_outer_prod = conj(outer_prod);


temp = f[1]*v + f[2]*v*q;
result = result + outer_prod*temp*v_dagger + f[2]*q*outer_prod*vv_dagger;

result = result + v_dagger*conj_outer_prod*conj(temp) + f[2]*qv_dagger*conj_outer_prod*v_dagger;


// now done with vv_dagger, I think
Link qsqv_dagger = q*qv_dagger;
Link pv_dagger = b[0]*v_dagger + b[1]*qv_dagger + b[2]*qsqv_dagger;
accumBothDerivatives(&result, v, pv_dagger, outer_prod); // 41 flops
Expand All @@ -474,6 +469,7 @@ namespace { // anonymous
Link sv_dagger = b[2]*v_dagger + b[4]*qv_dagger + b[5]*qsqv_dagger;
Link vqsq = vq*q;
accumBothDerivatives(&result, vqsq, sv_dagger, outer_prod); // 41 flops

return;
// 4528 flops - 17 matrix multiplies (198 flops each) + reciprocal root (approx 529 flops) + accumBothDerivatives (41 each) + miscellaneous
} // get unit force term
Expand All @@ -492,14 +488,18 @@ namespace { // anonymous

// This part of the calculation is always done in double precision
Matrix<complex<double>,3> v, result, oprod;

Matrix<complex<Float>,3> v_tmp, result_tmp, oprod_tmp;

for(int dir=0; dir<4; ++dir){
arg.force_old.load((Float*)(oprod.data), idx, dir, parity);
arg.gauge.load((Float*)(v.data), idx, dir, parity);
arg.force_old.load((Float*)(oprod_tmp.data), idx, dir, parity);
arg.gauge.load((Float*)(v_tmp.data), idx, dir, parity);
v = v_tmp;
oprod = oprod_tmp;

getUnitarizeForceSite<double>(result, v, oprod, arg);
result_tmp = result;

arg.force.save((Float*)(oprod.data), idx, dir, parity);
arg.force.save((Float*)(result_tmp.data), idx, dir, parity);
} // 4*4528 flops per site
return;
} // getUnitarizeForceField
Expand All @@ -508,16 +508,20 @@ namespace { // anonymous
template <typename Float, typename Arg>
void unitarizeForceCPU(Arg &arg) {
Matrix<complex<double>,3> v, result, oprod;
Matrix<complex<Float>,3> v_tmp, result_tmp, oprod_tmp;

for (int parity=0; parity<2; parity++) {
for (int i=0; i<arg.threads/2; i++) {
for (int dir=0; dir<4; dir++) {
arg.force_old.load((Float*)(oprod.data), i, dir, parity);
arg.gauge.load((Float*)(v.data), i, dir, parity);
arg.force_old.load((Float*)(oprod_tmp.data), i, dir, parity);
arg.gauge.load((Float*)(v_tmp.data), i, dir, parity);
v = v_tmp;
oprod = oprod_tmp;

getUnitarizeForceSite<double>(result, v, oprod, arg);

arg.force.save((Float*)(oprod.data), i, dir, parity);
result_tmp = result;
arg.force.save((Float*)(result_tmp.data), i, dir, parity);
}
}
}
Expand All @@ -534,23 +538,27 @@ namespace { // anonymous
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<double>(arg);
} else {
} else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
typedef gauge::MILCOrder<float,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<float>(arg);
} else {
errorQuda("Precision = %d not supported", gauge.Precision());
}
} else if (gauge.Order() == QUDA_QDP_GAUGE_ORDER) {
if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
typedef gauge::QDPOrder<double,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<double>(arg);
} else {
} else if (gauge.Precision() == QUDA_SINGLE_PRECISION) {
typedef gauge::QDPOrder<float,18> G;
UnitarizeForceArg<G,G> arg(G(newForce), G(oldForce), G(gauge), gauge, &num_failures, unitarize_eps, force_filter,
max_det_error, allow_svd, svd_only, svd_rel_error, svd_abs_error);
unitarizeForceCPU<float>(arg);
} else {
errorQuda("Precision = %d not supported", gauge.Precision());
}
} else {
errorQuda("Only MILC and QDP gauge orders supported\n");
Expand Down Expand Up @@ -588,7 +596,8 @@ namespace { // anonymous
void postTune() { cudaMemset(arg.fails, 0, sizeof(int)); } // reset fails counter

long long flops() const { return 4ll*4528*meta.Volume(); }

long long bytes() const { return 4ll * arg.threads * (arg.force.Bytes() + arg.force_old.Bytes() + arg.gauge.Bytes()); }

TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
}; // UnitarizeForce

Expand Down Expand Up @@ -621,7 +630,7 @@ namespace { // anonymous
errorQuda("Mixed precision not supported");

if (gauge.Order() != oldForce.Order() || gauge.Order() != newForce.Order())
errorQuda("Mixed data ordering not supported not supported");
errorQuda("Mixed data ordering not supported");

if (gauge.Order() == QUDA_FLOAT2_GAUGE_ORDER) {
if (gauge.Precision() == QUDA_DOUBLE_PRECISION) {
Expand Down
2 changes: 1 addition & 1 deletion tests/hisq_unitarize_force_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ hisq_force_init()

setDims(gaugeParam.X);

gaugeParam.cpu_prec = link_prec;
gaugeParam.cpu_prec = QUDA_DOUBLE_PRECISION;
gaugeParam.cuda_prec = link_prec;
gaugeParam.reconstruct = link_recon;
gaugeParam.gauge_order = QUDA_QDP_GAUGE_ORDER;
Expand Down