diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 04bea130e1..bfd8d08ef8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -38,10 +38,10 @@ cuda_add_executable(deflation_test deflation_test.cpp wilson_dslash_reference.cp target_link_libraries(deflation_test ${TEST_LIBS}) if(${QUDA_DIRAC_STAGGERED}) - cuda_add_executable(staggered_dslash_test staggered_dslash_test.cpp staggered_dslash_reference.cpp) + cuda_add_executable(staggered_dslash_test staggered_dslash_test.cpp staggered_dslash_reference.cpp blas_reference.cpp) target_link_libraries(staggered_dslash_test ${TEST_LIBS}) - cuda_add_executable(staggered_invert_test staggered_invert_test.cpp staggered_dslash_reference.cpp blas_reference.cpp) + cuda_add_executable(staggered_invert_test staggered_invert_test.cpp staggered_dslash_reference.cpp blas_reference.cpp) target_link_libraries(staggered_invert_test ${TEST_LIBS}) endif() diff --git a/tests/Makefile b/tests/Makefile index 13479c38e1..f948fa13ba 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -71,7 +71,7 @@ multigrid_benchmark_test: multigrid_benchmark_test.o test_util.o misc.o $(QUDA) deflation_test: deflation_test.o test_util.o wilson_dslash_reference.o domain_wall_dslash_reference.o blas_reference.o misc.o $(QUDA) $(CXX) $(LDFLAGS) $^ -o $@ $(LDFLAGS) -staggered_dslash_test: staggered_dslash_test.o gtest-all.o test_util.o staggered_dslash_reference.o misc.o $(QUDA) +staggered_dslash_test: staggered_dslash_test.o gtest-all.o test_util.o staggered_dslash_reference.o misc.o blas_reference.o $(QUDA) $(CXX) $(LDFLAGS) $^ -o $@ $(LDFLAGS) staggered_invert_test: staggered_invert_test.o test_util.o staggered_dslash_reference.o misc.o blas_reference.o $(QUDA) diff --git a/tests/staggered_dslash_reference.cpp b/tests/staggered_dslash_reference.cpp index 6932e07c7b..ea845cb339 100644 --- a/tests/staggered_dslash_reference.cpp +++ b/tests/staggered_dslash_reference.cpp @@ -12,6 +12,7 @@ #include #include +#include extern void *memset(void *s, int c, size_t n); @@ -126,10 +127,7 @@ void Mat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat ka // full dslash operator dslashReference(outOdd, fatlink, longlink, inEven, 1, daggerBit); dslashReference(outEven, fatlink, longlink, inOdd, 0, daggerBit); - - // lastly apply the kappa term - xpay(in, -kappa, out, V*mySpinorSiteSize); -} + } void @@ -150,6 +148,9 @@ mat(void *out, void **fatlink, void** longlink, void *in, double kappa, int dagg Mat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)kappa, dagger_bit); } } + + // lastly apply the kappa term + xpay(in, -kappa, out, V*mySpinorSiteSize, sPrecision); } @@ -220,27 +221,19 @@ matdagmat(void *out, void **fatlink, void** longlink, void *in, double mass, int // Apply the even-odd preconditioned Dirac operator template -static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, sFloat kappa, - int daggerBit, QudaMatPCType matpc_type) { +static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, int dagger, QudaMatPCType matpc_type) { sFloat *tmp = (sFloat*)malloc(Vh*mySpinorSiteSize*sizeof(sFloat)); // full dslash operator if (matpc_type == QUDA_MATPC_EVEN_EVEN) { - dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit); - dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit); - - //dslashReference(outEven, fatlink, longlink, inEven, 1, daggerBit); + dslashReference(tmp, fatlink, longlink, inEven, 1, dagger); + dslashReference(outEven, fatlink, longlink, tmp, 0, dagger); } else { - dslashReference(tmp, fatlink, longlink, inEven, 0, daggerBit); - dslashReference(outEven, fatlink, longlink, tmp, 1, daggerBit); + dslashReference(tmp, fatlink, longlink, inEven, 0, dagger); + dslashReference(outEven, fatlink, longlink, tmp, 1, dagger); } - // lastly apply the kappa term - - sFloat kappa2 = -kappa*kappa; - xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize); - free(tmp); } @@ -252,18 +245,22 @@ staggered_matpc(void *outEven, void **fatlink, void**longlink, void *inEven, dou if (sPrecision == QUDA_DOUBLE_PRECISION) if (gPrecision == QUDA_DOUBLE_PRECISION) { - MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type); + MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, dagger_bit, matpc_type); } else{ - MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type); + MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, dagger_bit, matpc_type); } else { if (gPrecision == QUDA_DOUBLE_PRECISION){ - MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type); + MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, dagger_bit, matpc_type); }else{ - MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type); + MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, dagger_bit, matpc_type); } } + + // lastly apply the kappa term + double kappa2 = -kappa*kappa; + xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize, sPrecision); } #ifdef MULTI_GPU