Skip to content

Commit

Permalink
Masks: restrict rdivide field to powers of 2
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad authored and karturov committed Nov 18, 2024
1 parent 05d68df commit 280bd28
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
21 changes: 11 additions & 10 deletions src/gpu/intel/jit/gemm/generator/pieces/layout_setup.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
vymask.bitRep = consecutive;
vymask.maskRep = 1;
vymask.rsize = *yblock;
vymask.rdivide = 1;
vymask.rshift = 0;
} else if (logicalSlots < slots) {
auto &fymask = block.colMajor ? block.rowMask.fixed : block.colMask.fixed;
fymask.isFixed = true;
Expand All @@ -279,7 +279,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
vxmask.bitRep = (block.simdSize > 16) ? 32 : 16;
vxmask.maskRep = 1;
vxmask.rsize = 1;
vxmask.rdivide = 1;
vxmask.rshift = 0;
} else if (allowDesc && (channelScattered || astrategy.newDP) && *xblock > 1 && !byte) {
fragment = std::min(*xblock, 4 * width / T);
if (block.colMajor) // Clang can't handle the ternary operator equivalent of this.
Expand Down Expand Up @@ -482,7 +482,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
vrmask.rsize = rblock;
vrmask.bitRep = std::max<int>(T.paddedSize() / maskGranularity, 1);
vrmask.maskRep = cblock;
vrmask.rdivide = std::max<int>(maskGranularity / T, 1);
vrmask.rshift = ilog2(std::max<int>(maskGranularity / T, 1));
}
} else {
if (avoidFragment) {
Expand All @@ -491,8 +491,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
vrmask.isFixed = false;
vrmask.bitRep = 0; /* will be filled in later */
vrmask.maskRep = 1;
vrmask.rdivide = 1;
vrmask.rsize = 1;
vrmask.rshift = 0;
} else {
// Fragment it. Could actually handle rowFragment = 2 by changing descriptor.
block.rowFragment = 1;
Expand Down Expand Up @@ -520,7 +520,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
vcmask.rsize = cblock;
vcmask.bitRep = std::max<int>(T.paddedSize() / maskGranularity, 1);
vcmask.maskRep = rblock;
vcmask.rdivide = std::max<int>(maskGranularity / T, 1);
vcmask.rshift = ilog2(std::max<int>(maskGranularity / T, 1));
}
} else {
if (avoidFragment) {
Expand All @@ -529,8 +529,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
vcmask.isFixed = false;
vcmask.bitRep = 0;
vcmask.maskRep = 1;
vcmask.rdivide = 1;
vcmask.rsize = 1;
vcmask.rshift = 0;
} else {
// Fragment it. Could actually handle colFragment = 2 by changing descriptor.
block.colFragment = 1;
Expand Down Expand Up @@ -719,7 +719,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
auto &vxmask = block.colMajor ? block.rowMask.variable : block.colMask.variable;
vxmask.isFixed = false;
vxmask.bitRep = block.simdSize;
vxmask.maskRep = vxmask.rdivide = vxmask.rsize = 1;
vxmask.maskRep = vxmask.rsize = 1;
vxmask.rshift = 0;
}

if (remainderY) {
Expand All @@ -728,7 +729,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
vymask.bitRep = xCacheLines;
vymask.maskRep = 1;
vymask.rsize = yblock;
vymask.rdivide = 1;
vymask.rshift = 0;
}
break;
}
Expand All @@ -739,13 +740,13 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
if (block.rowMask && !block.rowMask.fixed.isFixed) {
if (vrmask.rsize == 0)
vrmask.rsize = rblock;
vrmask.maskRep = std::min<int>(vrmask.maskRep, std::max<int>(1, vrmask.rdivide * block.simdSize / (vrmask.bitRep * vrmask.rsize)));
vrmask.maskRep = std::min<int>(vrmask.maskRep, std::max<int>(1, (block.simdSize << vrmask.rshift) / (vrmask.bitRep * vrmask.rsize)));
block.noRowsOK = true; // All-zero masks are always OK.
}
if (block.colMask && !block.colMask.fixed.isFixed) {
if (vcmask.rsize == 0)
vcmask.rsize = cblock;
vcmask.maskRep = std::min<int>(vcmask.maskRep, std::max<int>(1, vcmask.rdivide * block.simdSize / (vcmask.bitRep * vcmask.rsize)));
vcmask.maskRep = std::min<int>(vcmask.maskRep, std::max<int>(1, (block.simdSize << vcmask.rshift) / (vcmask.bitRep * vcmask.rsize)));
block.noColsOK = true;
}

Expand Down
14 changes: 7 additions & 7 deletions src/gpu/intel/jit/gemm/generator/pieces/masks.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
// Load a variable mask, which requires some minor bit-twiddling.
auto &vmask = assignment.mask.variable;

uint32_t rsizeScaled = vmask.rsize / vmask.rdivide;
uint32_t rsizeScaled = vmask.rsize >> vmask.rshift;
uint32_t maskLen = vmask.bitRep * vmask.maskRep * rsizeScaled;
uint32_t fullMask = (uint64_t(1) << maskLen) - 1;
uint32_t rep1Mask = (uint64_t(1) << (vmask.bitRep * rsizeScaled)) - 1;
Expand All @@ -136,7 +136,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
auto flagType = flag.getType();
auto mask0Type = getBytes(flagType) >= 4 ? DataType::uq : flagType;

if (vmask.rsize == 1 && vmask.rdivide == 1) {
if (vmask.rsize == 1 && vmask.rshift == 0) {
// Simple threshold comparison.
offset += assignment.offset;
if (flag.isARF())
Expand All @@ -152,11 +152,11 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
auto mask0 = state.ra.alloc_sub(mask0Type, getHint(HintType::Bank1));
auto mask = mask0.reinterpret(0, flagType);
auto mindex = index;
auto rdivide = 1 << vmask.rshift;

if (vmask.rdivide > 1) {
if (!is_zero_or_pow2(vmask.rdivide)) stub();
add(1 | sat, temp, mindex, -offset + vmask.rdivide - 1);
shr(1, temp, temp, uint16_t(ilog2(vmask.rdivide)));
if (vmask.rshift) {
add(1 | sat, temp, mindex, -offset + rdivide - 1);
shr(1, temp, temp, uint16_t(vmask.rshift));
mindex = temp;
offset = 0;
}
Expand All @@ -169,7 +169,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
mulConstant(1, temp, mindex, vmask.bitRep);
mindex = temp;
}
uint16_t tshift = vmask.bitRep * (rsizeScaled + div_up(assignment.offset + offset, vmask.rdivide));
uint16_t tshift = vmask.bitRep * (rsizeScaled + div_up(assignment.offset + offset, rdivide));
add(1 | sat, temp, -mindex, tshift);
if (tshift >= 32)
min_(1, temp, temp, vmask.bitRep * rsizeScaled); // Ensure shift count doesn't overflow.
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/register_block.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct MaskInfo {
struct {
uint8_t isFixed : 1; // = false (variable mask)
uint8_t reverse : 1; // True to reverse mask.
uint8_t rdivide : 6; // Amount by which to divide index before forming mask. Fractions are rounded up.
uint8_t rshift : 6;     // Log2 of the power-of-2 amount by which to divide index before forming mask. Fractions are rounded up.
// Note maskRep * bitRep * (rsize >> rshift) = # mask bits.
uint8_t rsize; // Maximum remainder value. (e.g. 16 if we need the last 4 bits of the index).
uint8_t maskRep; // # of repetitions of mask pattern.
Expand Down

0 comments on commit 280bd28

Please sign in to comment.