Skip to content

Commit

Permalink
Fixed bugs in clover application: dslash_test now supports full clover testing (part 3 of #19).
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Jun 10, 2016
1 parent af6d23a commit b6de874
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 25 deletions.
2 changes: 1 addition & 1 deletion include/clover_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ namespace quda {
complex<Float> tmp(a[parity][idx], a[parity][idx+1]);
return tmp;
} else {
// requesting upper triangular so return conjuate transpose
// requesting upper triangular so return conjugate transpose
return conj(operator()(parity,x,s_col,s_row,c_col,c_row) );
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ endif()
#define tests

if(${QUDA_DIRAC_WILSON} OR ${QUDA_DIRAC_CLOVER} OR ${QUDA_DIRAC_TWISTED_MASS} OR ${QUDA_DIRAC_TWISTED_CLOVER} OR ${QUDA_DIRAC_DOMAIN_WALL})
cuda_add_executable(dslash_test dslash_test.cpp wilson_dslash_reference.cpp domain_wall_dslash_reference.cpp clover_reference.cpp)
cuda_add_executable(dslash_test dslash_test.cpp wilson_dslash_reference.cpp domain_wall_dslash_reference.cpp clover_reference.cpp blas_reference.cpp)
target_link_libraries(dslash_test ${TEST_LIBS} )

cuda_add_executable(invert_test invert_test.cpp wilson_dslash_reference.cpp domain_wall_dslash_reference.cpp clover_reference.cpp blas_reference.cpp)
Expand Down
71 changes: 51 additions & 20 deletions tests/clover_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,36 @@ template <typename sFloat, typename cFloat>
void cloverReference(sFloat *out, cFloat *clover, sFloat *in, int parity) {
int nSpin = 4;
int nColor = 3;
int N = nSpin * nColor / 2;
int chiralBlock = (nSpin/2) * (nSpin/2) * nColor * nColor;
int N = nColor * nSpin / 2;
int chiralBlock = N + 2*(N-1)*N/2;

for (int i = 0; i < Vh; i++) {
for (int i=0; i<Vh; i++) {
std::complex<sFloat> *In = reinterpret_cast<std::complex<sFloat>*>(&in[i*nSpin*nColor*2]);
std::complex<sFloat> *Out = reinterpret_cast<std::complex<sFloat>*>(&out[i*nSpin*nColor*2]);
for (int i=0; i<nSpin*nColor; i++) Out[i] = 0.0;

for (int chi=0; chi<nSpin/2; chi++) {
cFloat *D = &clover[((parity*Vh + i)*2 + chi)*chiralBlock];
std::complex<cFloat> *L = reinterpret_cast<std::complex<cFloat>*>(&clover[((parity*Vh + i)*2 + chi)*chiralBlock+N]);
std::complex<cFloat> *L = reinterpret_cast<std::complex<cFloat>*>(&D[N]);

for (int s_col=0; s_col<nSpin/2; s_col++) { // 2 spins per chiral block
for (int c_col=0; c_col<nColor; c_col++) {
const int col = s_col * nColor + c_col;
const int Col = chi*chiralBlock + col;
const int Col = chi*N + col;
Out[Col] = 0.0;

for (int s_row=0; s_row<nSpin/2; s_row++) { // 2 spins per chiral block
for (int c_row=0; c_row<nColor; c_row++) {
const int row = s_row * nColor + c_row;
const int Row = chi*chiralBlock + row;
const int Row = chi*N + row;

if (row == col) {
Out[Col] += D[row] * In[Row];
} else if (col < row) {
int k = N*(N-1)/2 - (N-col)*(N-col-1)/2 + row - col - 1;
Out[Col] += L[k] * In[Row];
Out[Col] += conj(L[k]) * In[Row];
} else if (row < col) {
int k = N*(N-1)/2 - (N-row)*(N-row-1)/2 + col - row - 1;
Out[Col] += conj(L[k]) * In[Row];
Out[Col] += L[k] * In[Row];
}
}
}
Expand All @@ -68,8 +68,10 @@ void apply_clover(void *out, void *clover, void *in, int parity, QudaPrecision p
switch (precision) {
case QUDA_DOUBLE_PRECISION:
cloverReference(static_cast<double*>(out), static_cast<double*>(clover), static_cast<double*>(in), parity);
break;
case QUDA_SINGLE_PRECISION:
cloverReference(static_cast<float*>(out), static_cast<float*>(clover), static_cast<float*>(in), parity);
break;
default:
errorQuda("Unsupported precision %d", precision);
}
Expand All @@ -95,29 +97,47 @@ void clover_matpc(void *out, void **gauge, void *clover, void *clover_inv, void

switch(matpc_type) {
case QUDA_MATPC_EVEN_EVEN:
wil_dslash(tmp, gauge, in, 1, dagger, precision, gauge_param);
apply_clover(out, clover_inv, tmp, 1, precision);
wil_dslash(tmp, gauge, out, 0, dagger, precision, gauge_param);
apply_clover(out, clover_inv, tmp, 0, precision);
if (!dagger) {
wil_dslash(tmp, gauge, in, 1, dagger, precision, gauge_param);
apply_clover(out, clover_inv, tmp, 1, precision);
wil_dslash(tmp, gauge, out, 0, dagger, precision, gauge_param);
apply_clover(out, clover_inv, tmp, 0, precision);
} else {
apply_clover(tmp, clover_inv, in, 0, precision);
wil_dslash(out, gauge, tmp, 1, dagger, precision, gauge_param);
apply_clover(tmp, clover_inv, out, 1, precision);
wil_dslash(out, gauge, tmp, 0, dagger, precision, gauge_param);
}
xpay(in, kappa2, out, Vh*spinorSiteSize, precision);
break;
case QUDA_MATPC_EVEN_EVEN_ASYMMETRIC:
wil_dslash(out, gauge, in, 1, dagger, precision, gauge_param);
apply_clover(tmp, clover_inv, out, 1, precision);
wil_dslash(out, gauge, tmp, 0, dagger, precision, gauge_param);
apply_clover(tmp, clover, in, 0, precision);
xpay(tmp, kappa2, out, Vh*spinorSiteSize, precision);
break;
case QUDA_MATPC_ODD_ODD:
wil_dslash(tmp, gauge, in, 0, dagger, precision, gauge_param);
apply_clover(out, clover_inv, tmp, 0, precision);
wil_dslash(tmp, gauge, out, 1, dagger, precision, gauge_param);
apply_clover(out, clover_inv, tmp, 1, precision);
if (!dagger) {
wil_dslash(tmp, gauge, in, 0, dagger, precision, gauge_param);
apply_clover(out, clover_inv, tmp, 0, precision);
wil_dslash(tmp, gauge, out, 1, dagger, precision, gauge_param);
apply_clover(out, clover_inv, tmp, 1, precision);
} else {
apply_clover(tmp, clover_inv, in, 1, precision);
wil_dslash(out, gauge, tmp, 0, dagger, precision, gauge_param);
apply_clover(tmp, clover_inv, out, 0, precision);
wil_dslash(out, gauge, tmp, 1, dagger, precision, gauge_param);
}
xpay(in, kappa2, out, Vh*spinorSiteSize, precision);
break;
case QUDA_MATPC_ODD_ODD_ASYMMETRIC:
wil_dslash(out, gauge, in, 0, dagger, precision, gauge_param);
apply_clover(tmp, clover_inv, out, 0, precision);
wil_dslash(out, gauge, tmp, 1, dagger, precision, gauge_param);
apply_clover(tmp, clover, in, 1, precision);
xpay(tmp, kappa2, out, Vh*spinorSiteSize, precision);
break;
default:
errorQuda("Unsupoorted matpc=%d", matpc_type);
}
Expand All @@ -129,14 +149,25 @@ void clover_matpc(void *out, void **gauge, void *clover, void *clover_inv, void
// Builds `out` from `in` across both parity halves using wil_dslash / apply_clover
// (and clover_dslash), followed by a final xpay call scaled by -kappa.
//
// NOTE(review): this text comes from a rendered commit diff without +/- markers, so
// statements from the pre-change and post-change versions appear interleaved in the
// body (e.g. outOdd/outEven are written by more than one call, and xpay is called
// twice). Confirm against the actual file which statements are current.
//
// out, in      : each is treated as two consecutive halves of
//                Vh*spinorSiteSize*precision bytes; the first half is named "even",
//                the second "odd".
// gauge, clover, dagger, precision, gauge_param : forwarded unchanged to the helpers.
// kappa        : passed to xpay negated.
// precision    : also used as a per-element byte count in the size/offset arithmetic.
void clover_mat(void *out, void **gauge, void *clover, void *in, double kappa,
int dagger, QudaPrecision precision, QudaGaugeParam &gauge_param) {

// scratch buffer of V*spinorSiteSize*precision bytes, released by free(tmp) below
// NOTE(review): the malloc result is used without a null check
void *tmp = malloc(V*spinorSiteSize*precision);

// split in/out/tmp into halves; the (char*) cast makes the offset a byte offset
void *inEven = in;
void *inOdd = (char*)in + Vh*spinorSiteSize*precision;
void *outEven = out;
void *outOdd = (char*)out + Vh*spinorSiteSize*precision;
void *tmpEven = tmp;
void *tmpOdd = (char*)tmp + Vh*spinorSiteSize*precision;

// Odd part
// parity argument 1: outOdd is produced from inEven, tmpOdd from inOdd via `clover`
wil_dslash(outOdd, gauge, inEven, 1, dagger, precision, gauge_param);
apply_clover(tmpOdd, clover, inOdd, 1, precision);

// NOTE(review): these two calls also target outOdd/outEven; presumably they are the
// pre-change lines of the diff that the wil_dslash/apply_clover pairs replace — verify
clover_dslash(outOdd, gauge, clover, inEven, 1, dagger, precision, gauge_param);
clover_dslash(outEven, gauge, clover, inOdd, 0, dagger, precision, gauge_param);
// Even part
// parity argument 0: outEven is produced from inOdd, tmpEven from inEven via `clover`
wil_dslash(outEven, gauge, inOdd, 0, dagger, precision, gauge_param);
apply_clover(tmpEven, clover, inEven, 0, precision);

// lastly apply the kappa term
// NOTE(review): two xpay calls over the full V*spinorSiteSize length, one reading `in`
// and one reading `tmp`; presumably only one belongs to the current version — confirm
// xpay's exact update rule against its definition (not visible here)
xpay(in, -kappa, out, V*spinorSiteSize, precision);
xpay(tmp, -kappa, out, V*spinorSiteSize, precision);

free(tmp);
}
6 changes: 3 additions & 3 deletions tests/dslash_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,10 @@ void init(int argc, char **argv) {
construct_gauge_field(hostGauge, 1, gauge_param.cpu_prec, &gauge_param);
}

spinor->Source(QUDA_RANDOM_SOURCE);
spinor->Source(QUDA_RANDOM_SOURCE, 0);

if (dslash_type == QUDA_CLOVER_WILSON_DSLASH) {
double norm = 0.0; // clover components are random numbers in the range (-norm, norm)
double norm = 1.0; // clover components are random numbers in the range (-norm, norm)
double diag = 1.0; // constant added to the diagonal

if (test_type == 2 || test_type == 4) {
Expand Down Expand Up @@ -647,7 +647,7 @@ void dslashRef() {
} else if (dslash_type == QUDA_CLOVER_WILSON_DSLASH) {
switch (test_type) {
case 0:
clover_dslash(spinorTmp->V(), hostGauge, hostClover, spinor->V(), parity, dagger, inv_param.cpu_prec, gauge_param);
clover_dslash(spinorRef->V(), hostGauge, hostCloverInv, spinor->V(), parity, dagger, inv_param.cpu_prec, gauge_param);
break;
case 1:
clover_matpc(spinorRef->V(), hostGauge, hostClover, hostCloverInv, spinor->V(), inv_param.kappa, inv_param.matpc_type,
Expand Down

0 comments on commit b6de874

Please sign in to comment.