Skip to content

Commit

Permalink
Enabled support for initial guess in BiCGstab and GCR. Closes #7.
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Aug 8, 2012
1 parent 17c650b commit 0eaa39a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
7 changes: 7 additions & 0 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,12 @@ void invertQuda(void *hp_x, void *hp_b, QudaInvertParam *param)


if (param->use_init_guess == QUDA_USE_INIT_GUESS_YES) { // download initial guess
// initial guess only supported for single-pass solvers
if ((param->solution_type == QUDA_MATDAG_MAT_SOLUTION || param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION) &&
(param->inv_type == QUDA_BICGSTAB_INVERTER || param->inv_type == QUDA_GCR_INVERTER)) {
errorQuda("Initial guess not supported for two-pass solver");
}

x = new cudaColorSpinorField(*h_x, cudaParam); // solution
} else { // zero initial guess
cudaParam.create = QUDA_ZERO_FIELD_CREATE;
Expand Down Expand Up @@ -917,6 +923,7 @@ void invertQuda(void *hp_x, void *hp_b, QudaInvertParam *param)

switch (param->inv_type) {
case QUDA_CG_INVERTER:
// prepare source if we are doing CGNR
if (param->solution_type != QUDA_MATDAG_MAT_SOLUTION && param->solution_type != QUDA_MATPCDAG_MATPC_SOLUTION) {
copyCuda(*out, *in);
dirac.Mdag(*in, *out);
Expand Down
23 changes: 17 additions & 6 deletions lib/inv_bicgstab_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,39 @@ void BiCGstab::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b)
cudaColorSpinorField &v = *vp;
cudaColorSpinorField &tmp = *tmpp;
cudaColorSpinorField &t = *tp;

cudaColorSpinorField &w = *wp;
cudaColorSpinorField &z = *zp;

cudaColorSpinorField *x_sloppy, *r_sloppy, *r_0;

double b2; // norm sq of source
double r2; // norm sq of residual

// compute initial residual depending on whether we have an initial guess or not
if (invParam.use_init_guess == QUDA_USE_INIT_GUESS_YES) {
mat(r, x, y);
r2 = xmyNormCuda(b, r);
b2 = normCuda(b);
copyCuda(y, x);
} else {
copyCuda(r, b);
r2 = normCuda(b);
b2 = r2;
}

// set field aliasing according to whether we are doing mixed precision or not
if (invParam.cuda_prec_sloppy == x.Precision()) {
x_sloppy = &x;
r_sloppy = &r;
r_0 = &b;
zeroCuda(*x_sloppy);
copyCuda(*r_sloppy, b);
} else {
ColorSpinorParam csParam(x);
csParam.create = QUDA_ZERO_FIELD_CREATE;
csParam.precision = invParam.cuda_prec_sloppy;
x_sloppy = new cudaColorSpinorField(x, csParam);
csParam.create = QUDA_COPY_FIELD_CREATE;
r_sloppy = new cudaColorSpinorField(b, csParam);
r_sloppy = new cudaColorSpinorField(r, csParam);
r_0 = new cudaColorSpinorField(b, csParam);
}

Expand All @@ -103,9 +117,6 @@ void BiCGstab::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b)
QudaInvertParam invert_param_inner = newQudaInvertParam();
fillInnerInvertParam(invert_param_inner, invParam);

double b2 = normCuda(b);

double r2 = b2;
double stop = b2*invParam.tol*invParam.tol; // stopping condition of solver
double delta = invParam.reliable_delta;

Expand Down
16 changes: 8 additions & 8 deletions lib/inv_gcr_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,18 @@ void GCR::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b)
for (int i=0; i<4; i++) parity += commCoords(i);
parity = parity % 2;

// calculate initial residual
mat(r, x);
double r2 = xmyNormCuda(b, r);
copyCuda(rSloppy, r);

quda::blas_flops = 0;

cudaColorSpinorField rM(rSloppy);
cudaColorSpinorField xM(rSloppy);

quda::blas_flops = 0;

stopwatchStart();

// calculate initial residual
mat(r, x);
double r2 = xmyNormCuda(b, r);
copyCuda(rSloppy, r);

int total_iter = 0;
int restart = 0;
double r2_old = r2;
Expand Down Expand Up @@ -341,7 +341,7 @@ void GCR::operator()(cudaColorSpinorField &x, cudaColorSpinorField &b)

}

copyCuda(x, y);
if (total_iter > 0) copyCuda(x, y);

if (k>=invParam.maxiter && invParam.verbosity >= QUDA_SUMMARIZE)
warningQuda("Exceeded maximum iterations %d", invParam.maxiter);
Expand Down

0 comments on commit 0eaa39a

Please sign in to comment.