Skip to content

Commit

Permalink
fix some shared bytes amounts
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Nov 30, 2023
1 parent 51d75ae commit 92ca04d
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 10 deletions.
1 change: 1 addition & 0 deletions lib/clover_deriv_quda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace quda {
double coeff;
int parity;
unsigned int minThreads() const { return gauge.LocalVolumeCB(); }
unsigned int sharedBytesPerThread() const { return 4 * sizeof(int); } // for thread_array

public:
DerivativeClover(GaugeField &force, GaugeField &gauge, GaugeField &oprod, double coeff, int parity) :
Expand Down
4 changes: 1 addition & 3 deletions lib/coarse_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ namespace quda {

unsigned int sharedBytesPerBlock(const TuneParam &param) const override
{
if (type == COMPUTE_VUV || type == COMPUTE_VLV)
if (arg.shared_atomic && (type == COMPUTE_VUV || type == COMPUTE_VLV))
return 4*sizeof(storeType)*arg.max_color_height_per_block*arg.max_color_width_per_block*4*coarseSpin*coarseSpin;
return TunableKernel3D::sharedBytesPerBlock(param);
}
Expand Down Expand Up @@ -577,9 +577,7 @@ namespace quda {
if (type == COMPUTE_VUV || type == COMPUTE_VLV || type == COMPUTE_CONVERT || type == COMPUTE_RESCALE) arg.dim_index = 4*(dir==QUDA_BACKWARDS ? 0 : 1) + dim;
arg.kd_dagger = kd_dagger;

if (type == COMPUTE_VUV || type == COMPUTE_VLV) tp.shared_bytes -= sharedBytesPerBlock(tp); // shared memory is static so don't include it in launch
Launch<location_template>(arg, tp, type, stream);
if (type == COMPUTE_VUV || type == COMPUTE_VLV) tp.shared_bytes += sharedBytesPerBlock(tp); // restore shared memory
};

/**
Expand Down
12 changes: 9 additions & 3 deletions lib/dslash5_domain_wall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ namespace quda
int blockMin() const { return 4; }
unsigned int sharedBytesPerThread() const
{
if (mobius_m5::shared()) {
if (mobius_m5::shared()
&& (type == Dslash5Type::M5_INV_DWF || type == Dslash5Type::M5_INV_MOBIUS
|| type == Dslash5Type::M5_INV_ZMOBIUS)) {
// spin components in shared depend on inversion algorithm
bool isInv = type == Dslash5Type::M5_INV_DWF || type == Dslash5Type::M5_INV_MOBIUS || type == Dslash5Type::M5_INV_ZMOBIUS;
int nSpin = (!isInv || mobius_m5::var_inverse()) ? mobius_m5::use_half_vector() ? in.Nspin() / 2 : in.Nspin() : in.Nspin();
Expand All @@ -81,7 +83,9 @@ namespace quda
// overloaded to return max dynamic shared memory if doing shared-memory inverse
unsigned int maxSharedBytesPerBlock() const
{
if (mobius_m5::shared()) {
if (mobius_m5::shared()
&& (type == Dslash5Type::M5_INV_DWF || type == Dslash5Type::M5_INV_MOBIUS
|| type == Dslash5Type::M5_INV_ZMOBIUS)) {
return maxDynamicSharedBytesPerBlock();
} else {
return TunableKernel3D::maxSharedBytesPerBlock();
Expand All @@ -104,7 +108,9 @@ namespace quda
xpay(a == 0.0 ? false : true),
type(type)
{
if (mobius_m5::shared()) {
if (mobius_m5::shared()
&& (type == Dslash5Type::M5_INV_DWF || type == Dslash5Type::M5_INV_MOBIUS
|| type == Dslash5Type::M5_INV_ZMOBIUS)) {
TunableKernel2D_base<false>::resizeStep(in.X(4)); // Ls must be contained in the block
}

Expand Down
1 change: 1 addition & 0 deletions lib/gauge_ape.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace quda {
const GaugeField &in;
const Float alpha;
unsigned int minThreads() const { return in.LocalVolumeCB(); }
unsigned int sharedBytesPerThread() const { return 4 * sizeof(int); } // for thread_array

public:
// (2,3): 2 for parity in the y thread dim, 3 corresponds to mapping direction to the z thread dim
Expand Down
5 changes: 3 additions & 2 deletions lib/gauge_stout.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ namespace quda {
unsigned int maxSharedBytesPerBlock() const { return maxDynamicSharedBytesPerBlock(); }
unsigned int sharedBytesPerThread() const
{
// use SharedMemoryCache if using over improvement for two link fields
return improved ? 2 * in.Ncolor() * in.Ncolor() * 2 * sizeof(typename mapper<Float>::type) : 0;
// use ThreadLocalCache if using over improvement for two link fields
return (improved ? 2 * in.Ncolor() * in.Ncolor() * 2 * sizeof(typename mapper<Float>::type) : 0)
+ 4 * sizeof(int); // for thread_array
}

public:
Expand Down
5 changes: 3 additions & 2 deletions lib/gauge_wilson_flow.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ namespace quda {

unsigned int sharedBytesPerThread() const
{
// use SharedMemoryCache if using Symanzik improvement for two Link fields
return 4*sizeof(int) + (wflow_type == QUDA_GAUGE_SMEAR_SYMANZIK_FLOW ? 2 * in.Ncolor() * in.Ncolor() * 2 * sizeof(typename mapper<Float>::type) : 0);
// use ThreadLocalCache if using Symanzik improvement for two Link fields
return (wflow_type == QUDA_GAUGE_SMEAR_SYMANZIK_FLOW ? 2 * in.Ncolor() * in.Ncolor() * 2 * sizeof(typename mapper<Float>::type) : 0)
+ 4 * sizeof(int); // for thread_array
}

public:
Expand Down

0 comments on commit 92ca04d

Please sign in to comment.