Skip to content

Commit

Permalink
Clean up of extended field handling used in clover force reference code
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Dec 15, 2023
1 parent b752a97 commit 7bd79ed
Showing 1 changed file with 15 additions and 87 deletions.
102 changes: 15 additions & 87 deletions tests/host_reference/clover_force_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,10 @@ void computeCloverSigmaTrace_reference(void *oprod, void *clover, double coeff,
// FIXME: dispatch to the appropriate template instantiation according to the gauge precision (the call below is currently hard-coded to double)
cloverSigmaTraceCompute_host((double *)oprod, (double *)clover, coeff, parity, mu2, eps2, twist);
}

template <typename gFloat>
void get_su3FromOprod(gFloat *oprod_out, gFloat *oprod, int munu, size_t nbr_idx, const lattice_t &lat)
{

int x_cb = nbr_idx % (lat.volume_ex / 2);
int OddBit = nbr_idx / (lat.volume_ex / 2);

Expand Down Expand Up @@ -778,6 +778,7 @@ void cloverDerivative_reference(void *h_mom, void **gauge, void *oprod, int pari
// create the extended gauge field
quda::lat_dim_t R;
for (int d = 0; d < 4; d++) R[d] = 2 * quda::comm_dim_partitioned(d);

QudaGaugeParam param = newQudaGaugeParam();
setGaugeParam(param);
param.gauge_order = QUDA_QDP_GAUGE_ORDER;
Expand All @@ -786,43 +787,11 @@ void cloverDerivative_reference(void *h_mom, void **gauge, void *oprod, int pari
auto qdp_ex = quda::createExtendedGauge(gauge, param, R);
lattice_t lat(*qdp_ex);

/// the following does not work: segmentation fault
// param.geometry = QUDA_TENSOR_GEOMETRY;
// create a qdp gauge
// std::vector<char> oprod_qdp_;
// std::array<void *, 6> oprod_qdp;
// oprod_qdp_.resize(6 * V * gauge_site_size * host_gauge_data_type_size);
// for (int i = 0; i < 6; i++) oprod_qdp[i] = oprod_qdp_.data() + i * V * gauge_site_size * host_gauge_data_type_size;
// int T = param.X[3];
// int LX = param.X[0];
// int LY = param.X[1];
// int LZ = param.X[2];

// for (int x0 = 0; x0 < T; x0++) {
// for (int x1 = 0; x1 < LX; x1++) {
// for (int x2 = 0; x2 < LY; x2++) {
// for (int x3 = 0; x3 < LZ; x3++) {
// int j = (x1 + LX * x2 + LY * LX * x3 + LZ * LY * LX * x0) / 2;
// int oddBit = (x0 + x1 + x2 + x3) & 1;
// for (int munu = 0; munu < 6; munu++) {
// double *out = (double *)oprod_qdp[munu];
// double *in = (double *)oprod;
// for (int i = 0; i < 9; i++) {
// for (int reim = 0; reim < 2; reim++) {
// out[reim + 2 * (i + 9 * (j + Vh * (oddBit)))]
// = in[reim + 2 * (j / 2 + Vh * (i + 9 * (munu + 6 * (oddBit))))];
// }
// }
// }
// }
// }
// }
// }
// auto oprod_ex = quda::createExtendedTensorGauge(oprod_qdp.data(), param, R);
// printf("HERE before oprod_ex created\n");

void *u_array[QUDA_MAX_DIM];
for (int d = 0; d < 4; d++) u_array[d] = qdp_ex->data(d);
quda::GaugeFieldParam gparam(gauge_param, oprod, QUDA_GENERAL_LINKS);
gparam.create = QUDA_REFERENCE_FIELD_CREATE;
gparam.order = QUDA_FLOAT2_GAUGE_ORDER;
gparam.geometry = QUDA_TENSOR_GEOMETRY;
auto oprod_ex = quda::createExtendedGauge(quda::GaugeField(gparam), R);

#pragma omp parallel for
for (int i = 0; i < Vh; i++) {
Expand All @@ -832,16 +801,17 @@ void cloverDerivative_reference(void *h_mom, void **gauge, void *oprod, int pari
if (nu == mu)
continue;
else if (gauge_param.cpu_prec == QUDA_DOUBLE_PRECISION)
computeForce_reference<double>(h_mom, u_array, lat, oprod, i, yIndex, parity, mu, nu);
computeForce_reference<double>(h_mom, (void**)qdp_ex->raw_pointer(), lat, oprod_ex->data(), i, yIndex, parity, mu, nu);
else if (gauge_param.cpu_prec == QUDA_SINGLE_PRECISION)
computeForce_reference<float>(h_mom, u_array, lat, oprod, i, yIndex, parity, mu, nu);
computeForce_reference<float>(h_mom, (void**)qdp_ex->raw_pointer(), lat, oprod_ex->data(), i, yIndex, parity, mu, nu);
else
errorQuda("Unsupported precision %d", gauge_param.cpu_prec);
}
}
}
}

delete oprod_ex;
delete qdp_ex;
}

Expand Down Expand Up @@ -1056,17 +1026,15 @@ void TMCloverForce_reference(void *h_mom, void **h_x, void **h_x0, double *coeff
CloverForce_reference(refmom, gauge, x, p, force_coeff);

// create oprod and trace field
void *oprod;
std::vector<char> oprod_;
oprod_.resize(V * 6 * gauge_site_size * host_gauge_data_type_size);
oprod = oprod_.data();
std::vector<char> oprod_(V * 6 * gauge_site_size * host_gauge_data_type_size);
void *oprod = oprod_.data();

if (gauge_param->cpu_prec == QUDA_DOUBLE_PRECISION)
set_to_zero<double>(oprod);
else if (gauge_param->cpu_prec == QUDA_SINGLE_PRECISION)
set_to_zero<float>(oprod);
else
errorQuda("precision not valid\n");
errorQuda("precision not valid");

double k_csw_ov_8 = inv_param->kappa * inv_param->clover_csw / 8.0;
size_t twist_flavor = inv_param->dslash_type == QUDA_TWISTED_CLOVER_DSLASH ? inv_param->twist_flavor : QUDA_TWIST_NO;
Expand All @@ -1089,50 +1057,10 @@ void TMCloverForce_reference(void *h_mom, void **h_x, void **h_x0, double *coeff
// sw_spinor_eo(OO,..) in tmLQCD
computeCloverSigmaOprod_reference(oprod, p, x, ferm_epsilon, *gauge_param);

// create extended field
quda::GaugeFieldParam gParamMom(*gauge_param, h_mom, QUDA_ASQTAD_MOM_LINKS);
gParamMom.location = QUDA_CUDA_FIELD_LOCATION;
gParamMom.link_type = QUDA_GENERAL_LINKS;
gParamMom.create = QUDA_ZERO_FIELD_CREATE;
gParamMom.order = QUDA_FLOAT2_GAUGE_ORDER;
gParamMom.reconstruct = QUDA_RECONSTRUCT_NO;
gParamMom.geometry = QUDA_TENSOR_GEOMETRY;
quda::GaugeField cudaOprod(gParamMom);
cudaOprod.copy_from_buffer(oprod);

quda::lat_dim_t R;
for (int d = 0; d < 4; d++) R[d] = 2 * quda::comm_dim_partitioned(d);
quda::TimeProfile profile_host("profile_host");
quda::GaugeField *cudaOprodEx = createExtendedGauge(cudaOprod, R, profile_host);

int ghostFace[4];
int ghost_size = 0;
for (int i = 0; i < 4; i++) {
ghostFace[i] = 0;
if (quda::comm_dim_partitioned(i)) {
ghostFace[i] = 1;
for (int j = 0; j < 4; j++) {
if (i == j)
continue;
else if (j == 0)
ghostFace[i] *= qParam.x[j] * 2 + 2 * R[i];
else
ghostFace[i] *= qParam.x[j] + 2 * R[i];
}
}
ghost_size += 2 * R[i] * ghostFace[i];
}
std::vector<char> oprod_ex_;
oprod_ex_.resize((V + ghost_size) * 6 * gauge_site_size * host_gauge_data_type_size);
void *oprod_ex = oprod_ex_.data();
cudaOprodEx->copy_to_buffer(oprod_ex);

// oprod = (A12) of hep-lat/0112051
// compute the insertion of oprod in Fig.27 of hep-lat/0112051
cloverDerivative_reference(refmom, gauge.data(), oprod_ex, QUDA_ODD_PARITY, *gauge_param);
cloverDerivative_reference(refmom, gauge.data(), oprod_ex, QUDA_EVEN_PARITY, *gauge_param);
cloverDerivative_reference(refmom, gauge.data(), oprod, QUDA_ODD_PARITY, *gauge_param);
cloverDerivative_reference(refmom, gauge.data(), oprod, QUDA_EVEN_PARITY, *gauge_param);

add_mom((double *)h_mom, (double *)mom.data(), 4 * V * mom_site_size, -1.0);

delete cudaOprodEx;
}

0 comments on commit 7bd79ed

Please sign in to comment.