[GPUHeuristic] Modify schedule generator to consider distribution of transfer_read layout anchor #17636
Conversation
…r_read. Currently we generate invalid schedules whose transfer_read cannot be distributed because the sizes do not match up. For example, in our case [wgTileSize, elemPerThread, threadSize] = [192, 8, 128]. There is no good layout for this: the number of threads needed would be 192/8 == 24, and since the threadSize predetermined by the schedule is 128, 128 % 24 != 0. Hence we cannot distribute it. This patch teaches the schedule generator about these constraints. Signed-off-by: stanley-nod <[email protected]>
int64_t nTileSize =
    schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
bool isDistributableN = (nTileSize / elemsPerThread) % wgThreads == 0 ||
                        wgThreads % (nTileSize / elemsPerThread) == 0;
You commented that this will work for matmul_transpose_b but not for matmul, because it depends on which dimension is innermost. Can we add identifiers for which dimension is innermost to GPUMatmulShapeType to inform this heuristic about when to check for this?
Yeah, we can do that, let me think of a nice way to shuttle this data around. :)
Still no guarantee (though unlikely) that this information won't change somewhere down the line, but definitely better than no heuristics haha.
Signed-off-by: stanley-nod <[email protected]>
It's worth noting that the heuristic logic is starting to get very involved/opinionated due to quirks of the lowering pipelines. It will be hard to maintain this moving forward unless we can find time to start cleaning up the tech debt of unhandled cases. Approving for now because I don't have a better suggestion and don't want to block, but we should fix codegen to not fail on certain valid lowering configs.
  return op.emitError("kDim or nDim not found in RHS indexing map.");
}
bool transposedLhs = lhsMDim.value() > lhsKDim.value();
bool transposedRhs = rhsKDim.value() > rhsNDim.value();
Slightly simpler could be to just compare cast<AffineDimExpr>(maps[0].getResults().back()).getPosition() with mIndex and kIndex. We can bail out if neither of them is innermost for now (I do not think we've hit that case, and the pipeline more or less assumes that we don't). Then we don't need any calls to getResultPosition, which scans the whole map.
nice idea, thanks :)
auto maps = op.getIndexingMapsArray();
OpBuilder b(op);
auto lhsMDim = maps[0].getResultPosition(b.getAffineDimExpr(mDim));
auto lhsKDim = maps[0].getResultPosition(b.getAffineDimExpr(kDim));
if (!lhsMDim.has_value() || !lhsKDim.has_value()) {
  return op.emitError("mDim or kDim not found in LHS indexing map.");
}
auto rhsKDim = maps[1].getResultPosition(b.getAffineDimExpr(kDim));
auto rhsNDim = maps[1].getResultPosition(b.getAffineDimExpr(nDim));
nit: Don't use auto when the type is not obvious from the RHS alone.
done, after simplification :)
auto rhsKDim = maps[1].getResultPosition(b.getAffineDimExpr(kDim));
auto rhsNDim = maps[1].getResultPosition(b.getAffineDimExpr(nDim));
if (!rhsKDim.has_value() || !rhsNDim.has_value()) {
  return op.emitError("kDim or nDim not found in RHS indexing map.");
Is this an op error (i.e. bad input IR) or something we should use an assertion for (a logic error)?
I see I see, it is a logic error, I'll change it to an assert :)
done, disappears after simplification :)
auto maps = op.getIndexingMapsArray();
OpBuilder b(op);
auto lhsMDim = maps[0].getResultPosition(b.getAffineDimExpr(mIndex));
auto lhsKDim = maps[0].getResultPosition(b.getAffineDimExpr(kIndex));
if (!lhsMDim.has_value() || !lhsKDim.has_value()) {
  return op.emitError("mDim or kDim not found in LHS indexing map.");
}
auto rhsKDim = maps[1].getResultPosition(b.getAffineDimExpr(kIndex));
auto rhsNDim = maps[1].getResultPosition(b.getAffineDimExpr(nIndex));
if (!rhsKDim.has_value() || !rhsNDim.has_value()) {
  return op.emitError("kDim or nDim not found in RHS indexing map.");
}
Same here.
done, disappears after simplification :)
Signed-off-by: stanley-nod <[email protected]>
Looks good, just two remaining nits.
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes);

// Infer if lhs or rhs is transposed to help generate better schedule.
auto maps = op.getIndexingMapsArray();
here
thanks, done :)
@@ -928,6 +921,23 @@ LogicalResult setCooperativeMatrixConfig(
  subgroupSize = *minSize;
}

// Infer if lhs or rhs is transposed to help generate better schedule.
auto maps = op.getIndexingMapsArray();
also here
Signed-off-by: stanley-nod <[email protected]>
…transfer_read layout anchor (iree-org#17636) Modify the heuristic to take the layout of transfer reads into account, such that we will not generate invalid schedules whose transfer_read cannot be distributed because the sizes do not match up. For example, in one matmul the N-dim has [wgTileSize, elemPerThread, threadSize] = [192, 8, 128]. There is no good layout for this because the number of threads needed would be 192/8 == 24, and since the threadSize predetermined by the schedule is 128, we have 128 % 24 != 0. Hence we cannot distribute it. This patch introduces constraints in our heuristic to solve these cases. --------- Signed-off-by: stanley-nod <[email protected]> Signed-off-by: Lubo Litchev <[email protected]>