Skip to content

Commit

Permalink
Fixed staggered reference dslash compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Jun 10, 2016
1 parent b6de874 commit 6fd83bb
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 24 deletions.
4 changes: 2 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ cuda_add_executable(deflation_test deflation_test.cpp wilson_dslash_reference.cp
target_link_libraries(deflation_test ${TEST_LIBS})

if(${QUDA_DIRAC_STAGGERED})
cuda_add_executable(staggered_dslash_test staggered_dslash_test.cpp staggered_dslash_reference.cpp)
cuda_add_executable(staggered_dslash_test staggered_dslash_test.cpp staggered_dslash_reference.cpp blas_reference.cpp)
target_link_libraries(staggered_dslash_test ${TEST_LIBS})

cuda_add_executable(staggered_invert_test staggered_invert_test.cpp staggered_dslash_reference.cpp blas_reference.cpp)
cuda_add_executable(staggered_invert_test staggered_invert_test.cpp staggered_dslash_reference.cpp blas_reference.cpp)
target_link_libraries(staggered_invert_test ${TEST_LIBS})
endif()

Expand Down
2 changes: 1 addition & 1 deletion tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ multigrid_benchmark_test: multigrid_benchmark_test.o test_util.o misc.o $(QUDA)
deflation_test: deflation_test.o test_util.o wilson_dslash_reference.o domain_wall_dslash_reference.o blas_reference.o misc.o $(QUDA)
$(CXX) $(LDFLAGS) $^ -o $@ $(LDFLAGS)

staggered_dslash_test: staggered_dslash_test.o gtest-all.o test_util.o staggered_dslash_reference.o misc.o $(QUDA)
staggered_dslash_test: staggered_dslash_test.o gtest-all.o test_util.o staggered_dslash_reference.o misc.o blas_reference.o $(QUDA)
$(CXX) $(LDFLAGS) $^ -o $@ $(LDFLAGS)

staggered_invert_test: staggered_invert_test.o test_util.o staggered_dslash_reference.o misc.o blas_reference.o $(QUDA)
Expand Down
39 changes: 18 additions & 21 deletions tests/staggered_dslash_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <blas_quda.h>

#include <face_quda.h>
#include <blas_reference.h>

extern void *memset(void *s, int c, size_t n);

Expand Down Expand Up @@ -126,10 +127,7 @@ void Mat(sFloat *out, gFloat **fatlink, gFloat** longlink, sFloat *in, sFloat ka
// full dslash operator
dslashReference(outOdd, fatlink, longlink, inEven, 1, daggerBit);
dslashReference(outEven, fatlink, longlink, inOdd, 0, daggerBit);

// lastly apply the kappa term
xpay(in, -kappa, out, V*mySpinorSiteSize);
}
}


void
Expand All @@ -150,6 +148,9 @@ mat(void *out, void **fatlink, void** longlink, void *in, double kappa, int dagg
Mat((float*)out, (float**)fatlink, (float**)longlink, (float*)in, (float)kappa, dagger_bit);
}
}

// lastly apply the kappa term
xpay(in, -kappa, out, V*mySpinorSiteSize, sPrecision);
}


Expand Down Expand Up @@ -220,27 +221,19 @@ matdagmat(void *out, void **fatlink, void** longlink, void *in, double mass, int

// Apply the even-odd preconditioned Dirac operator
template <typename sFloat, typename gFloat>
static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, sFloat kappa,
int daggerBit, QudaMatPCType matpc_type) {
static void MatPC(sFloat *outEven, gFloat **fatlink, gFloat** longlink, sFloat *inEven, int dagger, QudaMatPCType matpc_type) {

sFloat *tmp = (sFloat*)malloc(Vh*mySpinorSiteSize*sizeof(sFloat));

// full dslash operator
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashReference(tmp, fatlink, longlink, inEven, 1, daggerBit);
dslashReference(outEven, fatlink, longlink, tmp, 0, daggerBit);

//dslashReference(outEven, fatlink, longlink, inEven, 1, daggerBit);
dslashReference(tmp, fatlink, longlink, inEven, 1, dagger);
dslashReference(outEven, fatlink, longlink, tmp, 0, dagger);
} else {
dslashReference(tmp, fatlink, longlink, inEven, 0, daggerBit);
dslashReference(outEven, fatlink, longlink, tmp, 1, daggerBit);
dslashReference(tmp, fatlink, longlink, inEven, 0, dagger);
dslashReference(outEven, fatlink, longlink, tmp, 1, dagger);
}

// lastly apply the kappa term

sFloat kappa2 = -kappa*kappa;
xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize);

free(tmp);
}

Expand All @@ -252,18 +245,22 @@ staggered_matpc(void *outEven, void **fatlink, void**longlink, void *inEven, dou

if (sPrecision == QUDA_DOUBLE_PRECISION)
if (gPrecision == QUDA_DOUBLE_PRECISION) {
MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, dagger_bit, matpc_type);
}
else{
MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, (double)kappa, dagger_bit, matpc_type);
MatPC((double*)outEven, (double**)fatlink, (double**)longlink, (double*)inEven, dagger_bit, matpc_type);
}
else {
if (gPrecision == QUDA_DOUBLE_PRECISION){
MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
MatPC((float*)outEven, (double**)fatlink, (double**)longlink, (float*)inEven, dagger_bit, matpc_type);
}else{
MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, (float)kappa, dagger_bit, matpc_type);
MatPC((float*)outEven, (float**)fatlink, (float**)longlink, (float*)inEven, dagger_bit, matpc_type);
}
}

// lastly apply the kappa term
double kappa2 = -kappa*kappa;
xpay(inEven, kappa2, outEven, Vh*mySpinorSiteSize, sPrecision);
}

#ifdef MULTI_GPU
Expand Down

0 comments on commit 6fd83bb

Please sign in to comment.