diff --git a/lib/blas_core.h b/lib/blas_core.h index 2215c6432c..b2a84dc187 100644 --- a/lib/blas_core.h +++ b/lib/blas_core.h @@ -51,6 +51,8 @@ class BlasCuda : public Tunable { // these can't be curried into the Spinors because of Tesla argument length restriction char *X_h, *Y_h, *Z_h, *W_h; char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h; + const size_t *bytes_; + const size_t *norm_bytes_; unsigned int sharedBytesPerThread() const { return 0; } unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; } @@ -67,9 +69,9 @@ class BlasCuda : public Tunable { public: BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f, - int length) : + int length, const size_t *bytes, const size_t *norm_bytes) : arg(X, Y, Z, W, f, length), X_h(0), Y_h(0), Z_h(0), W_h(0), - Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0) { } + Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), bytes_(bytes), norm_bytes_(norm_bytes) { } virtual ~BlasCuda() { } @@ -82,21 +84,18 @@ class BlasCuda : public Tunable { blasKernel <<>>(arg); } -#define BYTES(X) ( arg.X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M*arg.X.Stride() ) -#define NORM_BYTES(X) ( (arg.X.Precision() == QUDA_HALF_PRECISION) ? 
sizeof(float)*arg.length : 0 ) - void preTune() { - arg.X.save(&X_h, &Xnorm_h, BYTES(X), NORM_BYTES(X)); - arg.Y.save(&Y_h, &Ynorm_h, BYTES(Y), NORM_BYTES(Y)); - arg.Z.save(&Z_h, &Znorm_h, BYTES(Z), NORM_BYTES(Z)); - arg.W.save(&W_h, &Wnorm_h, BYTES(W), NORM_BYTES(W)); + arg.X.save(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]); + arg.Y.save(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]); + arg.Z.save(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]); + arg.W.save(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]); } void postTune() { - arg.X.load(&X_h, &Xnorm_h, BYTES(X), NORM_BYTES(X)); - arg.Y.load(&Y_h, &Ynorm_h, BYTES(Y), NORM_BYTES(Y)); - arg.Z.load(&Z_h, &Znorm_h, BYTES(Z), NORM_BYTES(Z)); - arg.W.load(&W_h, &Wnorm_h, BYTES(W), NORM_BYTES(W)); + arg.X.load(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]); + arg.Y.load(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]); + arg.Z.load(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]); + arg.W.load(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]); } long long flops() const { return arg.f.flops()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*arg.length*M; } @@ -141,6 +140,9 @@ inline void blasCuda(const double2 &a, const double2 &b, const double2 &c, // FIXME: use traits to encapsulate register type for shorts - // will reduce template type parameters from 3 to 2 + size_t bytes[] = {x.Bytes(), y.Bytes(), z.Bytes(), w.Bytes()}; + size_t norm_bytes[] = {x.NormBytes(), y.NormBytes(), z.NormBytes(), w.NormBytes()}; + if (x.Precision() == QUDA_DOUBLE_PRECISION) { const int M = 1; Spinor X(x); @@ -151,7 +153,7 @@ inline void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, x.Length()/(2*M)); + Functor > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes); blas.apply(*blasStream); } else if (x.Precision() == QUDA_SINGLE_PRECISION) { const int M = 1; @@ -165,7 +167,7 @@ inline void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > 
blas(X, Y, Z, W, f, x.Length()/(4*M)); + Functor > blas(X, Y, Z, W, f, x.Length()/(4*M), bytes, norm_bytes); blas.apply(*blasStream); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -180,7 +182,7 @@ inline void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, x.Length()/(2*M)); + Functor > blas(X, Y, Z, W, f, x.Length()/(2*M), bytes, norm_bytes); blas.apply(*blasStream); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -197,7 +199,7 @@ inline void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, y.Volume()); + Functor > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes); blas.apply(*blasStream); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -212,7 +214,7 @@ inline void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, y.Volume()); + Functor > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes); blas.apply(*blasStream); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); diff --git a/lib/blas_mixed_core.h b/lib/blas_mixed_core.h index ca5bcad333..588bbcd2a4 100644 --- a/lib/blas_mixed_core.h +++ b/lib/blas_mixed_core.h @@ -53,6 +53,8 @@ class BlasCuda : public Tunable { // these can't be curried into the Spinors because of Tesla argument length restriction char *X_h, *Y_h, *Z_h, *W_h; char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h; + const size_t *bytes_; + const size_t *norm_bytes_; unsigned int sharedBytesPerThread() const { return 0; } unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; } @@ -69,9 +71,10 @@ class BlasCuda : public Tunable { public: BlasCuda(SpinorX &X, SpinorY &Y, SpinorZ &Z, SpinorW &W, Functor &f, - int length) : + int length, const size_t *bytes, const size_t *norm_bytes) : arg(X, 
Y, Z, W, f, length), X_h(0), Y_h(0), Z_h(0), W_h(0), - Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0) + Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), + bytes_(bytes), norm_bytes_(norm_bytes) { ; } virtual ~BlasCuda() { } @@ -84,21 +87,18 @@ class BlasCuda : public Tunable { blasKernel <<>>(arg); } -#define BYTES(X) ( arg.X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M*arg.X.Stride() ) -#define NORM_BYTES(X) ( (arg.X.Precision() == QUDA_HALF_PRECISION) ? sizeof(float)*arg.length : 0 ) - void preTune() { - arg.X.save(&X_h, &Xnorm_h, BYTES(X), NORM_BYTES(X)); - arg.Y.save(&Y_h, &Ynorm_h, BYTES(Y), NORM_BYTES(Y)); - arg.Z.save(&Z_h, &Znorm_h, BYTES(Z), NORM_BYTES(Z)); - arg.W.save(&W_h, &Wnorm_h, BYTES(W), NORM_BYTES(W)); + arg.X.save(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]); + arg.Y.save(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]); + arg.Z.save(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]); + arg.W.save(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]); } void postTune() { - arg.X.load(&X_h, &Xnorm_h, BYTES(X), NORM_BYTES(X)); - arg.Y.load(&Y_h, &Ynorm_h, BYTES(Y), NORM_BYTES(Y)); - arg.Z.load(&Z_h, &Znorm_h, BYTES(Z), NORM_BYTES(Z)); - arg.W.load(&W_h, &Wnorm_h, BYTES(W), NORM_BYTES(W)); + arg.X.load(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]); + arg.Y.load(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]); + arg.Z.load(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]); + arg.W.load(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]); } long long flops() const { return arg.f.flops()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*arg.length*M; } @@ -142,6 +142,9 @@ void blasCuda(const double2 &a, const double2 &b, const double2 &c, // FIXME: use traits to encapsulate register type for shorts - // will reduce template type parameters from 3 to 2 + size_t bytes[] = {x.Bytes(), y.Bytes(), z.Bytes(), w.Bytes()}; + size_t norm_bytes[] = {x.NormBytes(), y.NormBytes(), z.NormBytes(), w.NormBytes()}; + if (x.Precision() == QUDA_SINGLE_PRECISION && y.Precision() == QUDA_DOUBLE_PRECISION) { if (x.Nspin() 
== 4) { const int M = 12; @@ -153,7 +156,7 @@ void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, Functor > - blas(X, Y, Z, W, f, y.Volume()); + blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes); blas.apply(*blasStream); } else if (x.Nspin() == 1) { const int M = 3; @@ -165,7 +168,7 @@ void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, y.Volume()); + Functor > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes); blas.apply(*blasStream); } } else if (x.Precision() == QUDA_HALF_PRECISION && y.Precision() == QUDA_DOUBLE_PRECISION) { @@ -179,7 +182,7 @@ void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, y.Volume()); + Functor > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes); blas.apply(*blasStream); } else if (x.Nspin() == 1) { const int M = 3; @@ -191,7 +194,7 @@ void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, y.Volume()); + Functor > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes); blas.apply(*blasStream); } } else if (y.Precision() == QUDA_SINGLE_PRECISION) { @@ -205,7 +208,7 @@ void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, y.Volume()); + Functor > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes); blas.apply(*blasStream); } else if (x.Nspin() == 1) { const int M = 3; @@ -217,7 +220,7 @@ void blasCuda(const double2 &a, const double2 &b, const double2 &c, BlasCuda, Spinor, Spinor, Spinor, - Functor > blas(X, Y, Z, W, f, y.Volume()); + Functor > blas(X, Y, Z, W, f, y.Volume(), bytes, norm_bytes); blas.apply(*blasStream); } } else { diff --git a/lib/reduce_core.h b/lib/reduce_core.h index e32db2b3a0..641e366b21 100644 --- a/lib/reduce_core.h +++ b/lib/reduce_core.h @@ -281,6 
+281,8 @@ class ReduceCuda : public Tunable { // these can't be curried into the Spinors because of Tesla argument length restriction char *X_h, *Y_h, *Z_h, *W_h, *V_h; char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h, *Vnorm_h; + const size_t *bytes_; + const size_t *norm_bytes_; unsigned int sharedBytesPerThread() const { return sizeof(ReduceType); } @@ -303,10 +305,12 @@ class ReduceCuda : public Tunable { public: ReduceCuda(doubleN &result, SpinorX &X, SpinorY &Y, SpinorZ &Z, - SpinorW &W, SpinorV &V, Reducer &r, int length) : + SpinorW &W, SpinorV &V, Reducer &r, int length, + const size_t *bytes, const size_t *norm_bytes) : arg(X, Y, Z, W, V, r, (ReduceType*)d_reduce, (ReduceType*)hd_reduce, length), result(result), X_h(0), Y_h(0), Z_h(0), W_h(0), V_h(0), - Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), Vnorm_h(0) { } + Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), Vnorm_h(0), + bytes_(bytes), norm_bytes_(norm_bytes) { } virtual ~ReduceCuda() { } inline TuneKey tuneKey() const { @@ -318,23 +322,20 @@ class ReduceCuda : public Tunable { result = reduceLaunch(arg, tp, stream); } -#define BYTES(X) ( arg.X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M*arg.X.Stride() ) -#define NORM_BYTES(X) ( (arg.X.Precision() == QUDA_HALF_PRECISION) ? 
sizeof(float)*arg.length : 0 ) - void preTune() { - arg.X.save(&X_h, &Xnorm_h, BYTES(X), NORM_BYTES(X)); - arg.Y.save(&Y_h, &Ynorm_h, BYTES(Y), NORM_BYTES(Y)); - arg.Z.save(&Z_h, &Znorm_h, BYTES(Z), NORM_BYTES(Z)); - arg.W.save(&W_h, &Wnorm_h, BYTES(W), NORM_BYTES(W)); - arg.V.save(&V_h, &Vnorm_h, BYTES(V), NORM_BYTES(V)); + arg.X.save(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]); + arg.Y.save(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]); + arg.Z.save(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]); + arg.W.save(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]); + arg.V.save(&V_h, &Vnorm_h, bytes_[4], norm_bytes_[4]); } void postTune() { - arg.X.load(&X_h, &Xnorm_h, BYTES(X), NORM_BYTES(X)); - arg.Y.load(&Y_h, &Ynorm_h, BYTES(Y), NORM_BYTES(Y)); - arg.Z.load(&Z_h, &Znorm_h, BYTES(Z), NORM_BYTES(Z)); - arg.W.load(&W_h, &Wnorm_h, BYTES(W), NORM_BYTES(W)); - arg.V.load(&V_h, &Vnorm_h, BYTES(V), NORM_BYTES(V)); + arg.X.load(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]); + arg.Y.load(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]); + arg.Z.load(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]); + arg.W.load(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]); + arg.V.load(&V_h, &Vnorm_h, bytes_[4], norm_bytes_[4]); } long long flops() const { return arg.r.flops()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*arg.length*M; } @@ -402,6 +403,9 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, // FIXME: use traits to encapsulate register type for shorts - // will reduce template type parameters from 3 to 2 + size_t bytes[] = {x.Bytes(), y.Bytes(), z.Bytes(), w.Bytes(), v.Bytes()}; + size_t norm_bytes[] = {x.NormBytes(), y.NormBytes(), z.NormBytes(), w.NormBytes(), v.NormBytes()}; + if (x.Precision() == QUDA_DOUBLE_PRECISION) { if (x.Nspin() == 4){ //wilson const int M = siteUnroll ? 
12 : 1; // determines how much work per thread to do @@ -415,7 +419,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M)); + reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M), bytes, norm_bytes); reduce.apply(*getBlasStream()); } else if (x.Nspin() == 1){ //staggered const int M = siteUnroll ? 3 : 1; // determines how much work per thread to do @@ -429,7 +433,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M)); + reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M), bytes, norm_bytes); reduce.apply(*getBlasStream()); } else { errorQuda("ERROR: nSpin=%d is not supported\n", x.Nspin()); } } else if (x.Precision() == QUDA_SINGLE_PRECISION) { @@ -446,7 +450,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, reduce_length/(4*M)); + reduce(value, X, Y, Z, W, V, r, reduce_length/(4*M), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -464,7 +468,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M)); + reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -483,7 +487,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, y.Volume()); + reduce(value, X, Y, Z, W, V, r, y.Volume(), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else 
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -500,7 +504,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, y.Volume()); + reduce(value, X, Y, Z, W, V, r, y.Volume(), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); diff --git a/lib/reduce_mixed_core.h b/lib/reduce_mixed_core.h index dd4bac6390..b2035a5235 100644 --- a/lib/reduce_mixed_core.h +++ b/lib/reduce_mixed_core.h @@ -289,6 +289,8 @@ class ReduceCuda : public Tunable { // these can't be curried into the Spinors because of Tesla argument length restriction char *X_h, *Y_h, *Z_h, *W_h, *V_h; char *Xnorm_h, *Ynorm_h, *Znorm_h, *Wnorm_h, *Vnorm_h; + const size_t *bytes_; + const size_t *norm_bytes_; unsigned int sharedBytesPerThread() const { return sizeof(ReduceType); } @@ -311,10 +313,12 @@ class ReduceCuda : public Tunable { public: ReduceCuda(doubleN &result, SpinorX &X, SpinorY &Y, SpinorZ &Z, - SpinorW &W, SpinorV &V, Reducer &r, int length) : + SpinorW &W, SpinorV &V, Reducer &r, int length, + const size_t *bytes, const size_t *norm_bytes) : arg(X, Y, Z, W, V, r, (ReduceType*)d_reduce, (ReduceType*)hd_reduce, length), result(result), X_h(0), Y_h(0), Z_h(0), W_h(0), V_h(0), - Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), Vnorm_h(0) { } + Xnorm_h(0), Ynorm_h(0), Znorm_h(0), Wnorm_h(0), Vnorm_h(0), + bytes_(bytes), norm_bytes_(norm_bytes) { } virtual ~ReduceCuda() { } inline TuneKey tuneKey() const { @@ -326,24 +330,20 @@ class ReduceCuda : public Tunable { result = reduceLaunch(arg, tp, stream); } - -#define BYTES(X) ( arg.X.Precision()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*M*arg.X.Stride() ) -#define NORM_BYTES(X) ( (arg.X.Precision() == QUDA_HALF_PRECISION) ? 
sizeof(float)*arg.length : 0 ) - void preTune() { - arg.X.save(&X_h, &Xnorm_h, BYTES(X), NORM_BYTES(X)); - arg.Y.save(&Y_h, &Ynorm_h, BYTES(Y), NORM_BYTES(Y)); - arg.Z.save(&Z_h, &Znorm_h, BYTES(Z), NORM_BYTES(Z)); - arg.W.save(&W_h, &Wnorm_h, BYTES(W), NORM_BYTES(W)); - arg.V.save(&V_h, &Vnorm_h, BYTES(V), NORM_BYTES(V)); + arg.X.save(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]); + arg.Y.save(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]); + arg.Z.save(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]); + arg.W.save(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]); + arg.V.save(&V_h, &Vnorm_h, bytes_[4], norm_bytes_[4]); } void postTune() { - arg.X.load(&X_h, &Xnorm_h, BYTES(X), NORM_BYTES(X)); - arg.Y.load(&Y_h, &Ynorm_h, BYTES(Y), NORM_BYTES(Y)); - arg.Z.load(&Z_h, &Znorm_h, BYTES(Z), NORM_BYTES(Z)); - arg.W.load(&W_h, &Wnorm_h, BYTES(W), NORM_BYTES(W)); - arg.V.load(&V_h, &Vnorm_h, BYTES(V), NORM_BYTES(V)); + arg.X.load(&X_h, &Xnorm_h, bytes_[0], norm_bytes_[0]); + arg.Y.load(&Y_h, &Ynorm_h, bytes_[1], norm_bytes_[1]); + arg.Z.load(&Z_h, &Znorm_h, bytes_[2], norm_bytes_[2]); + arg.W.load(&W_h, &Wnorm_h, bytes_[3], norm_bytes_[3]); + arg.V.load(&V_h, &Vnorm_h, bytes_[4], norm_bytes_[4]); } long long flops() const { return arg.r.flops()*(sizeof(FloatN)/sizeof(((FloatN*)0)->x))*arg.length*M; } @@ -405,13 +405,16 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, blasStrings.vol_str = x.VolString(); strcpy(blasStrings.aux_tmp, x.AuxString()); strcat(blasStrings.aux_tmp, ","); - strcat(blasStrings.aux_tmp, y.AuxString()); + strcat(blasStrings.aux_tmp, z.AuxString()); doubleN value; // FIXME: use traits to encapsulate register type for shorts - // will reduce template type parameters from 3 to 2 + size_t bytes[] = {x.Bytes(), y.Bytes(), z.Bytes(), w.Bytes(), v.Bytes()}; + size_t norm_bytes[] = {x.NormBytes(), y.NormBytes(), z.NormBytes(), w.NormBytes(), v.NormBytes()}; + if (x.Precision() == QUDA_SINGLE_PRECISION && z.Precision() == QUDA_DOUBLE_PRECISION) 
{ if (x.Nspin() == 4){ //wilson #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) @@ -426,7 +429,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, y.Volume()); + reduce(value, X, Y, Z, W, V, r, y.Volume(), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -445,7 +448,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M)); + reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -465,7 +468,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, y.Volume()); + reduce(value, X, Y, Z, W, V, r, y.Volume(), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -484,7 +487,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M)); + reduce(value, X, Y, Z, W, V, r, reduce_length/(2*M), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -503,7 +506,7 @@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, y.Volume()); + reduce(value, X, Y, Z, W, V, r, y.Volume(), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); @@ -520,7 +523,7 
@@ doubleN reduceCuda(const double2 &a, const double2 &b, cudaColorSpinorField &x, Spinor, Spinor, Spinor, Spinor, Spinor, Reducer > - reduce(value, X, Y, Z, W, V, r, y.Volume()); + reduce(value, X, Y, Z, W, V, r, y.Volume(), bytes, norm_bytes); reduce.apply(*getBlasStream()); #else errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());