Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPUHeuristic] Modify schedule generator to consider distribution of tranfer_read layout anchor #17636

Merged
merged 7 commits into from
Jun 12, 2024

Conversation

raikonenfnu
Copy link
Collaborator

Currently we are generating invalid schedules who's transfer read cannot be distributed because the sizes do not match up.

For example in our case our [wgTileSize, elemPerThread, threadSize] = [192, 8, 128]. There is no good layout for this because, the numbers of threads needed would be 192/8 == 24. And Since the threadSize pre-determined by schedule is 128, 128 % 24 != 0. Hence we cannot distribute it.

This patch teaches the schedule generator about these constraints.

…r_read.

Currently we are generating invalid schedules who's transfer read cannot
be distributed because the sizes do not match up.

For example in our case our [wgTileSize, elemPerThread, threadSize] = [192,
8, 128]. There is no good layout for this because, the numbers of
threads needed would be 192/8 == 24. And Since the threadSize pre-determined by schedule is
128, 128 % 24 != 0. Hence we cannot distribute it.

This patch teaches the schedule generator about these constraints.

Signed-off-by: stanley-nod <[email protected]>
Signed-off-by: stanley-nod <[email protected]>
int64_t nTileSize =
schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
bool isDistributableN = (nTileSize / elemsPerThread) % wgThreads == 0 ||
wgThreads % (nTileSize / elemsPerThread) == 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you commented that this will work for matmul_transpose_b but not for matmul because it depends on which dimension is inner most. Can we add identifiers for which dimension is inner most to GPUMatmulShapeType to inform this heuristic about when to check for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we can do that let me think of a nice way to shuttle this data around. :)

Still no guarantees(though unlikely) that somewhere down the line this information may change, but defo better than no heuristics haha.

@raikonenfnu raikonenfnu requested a review from qedawkins June 11, 2024 23:40
Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth noting that the heuristic logic is starting to get very involved/opinionated due to quirks of the lowering pipelines. It will be hard to maintain this state moving forward unless we can find time to start cleaning up tech debt of unhandled cases. Approving for now because I don't have a better suggestion and don't want to block, but we should fix codegen to not fail on certain valid lowering configs.

return op.emitError("kDim or nDim not found in RHS indexing map.");
}
bool transposedLhs = lhsMDim.value() > lhsKDim.value();
bool transposedRhs = rhsKDim.value() > rhsNDim.value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly simpler could be to just compare cast<AffineDimExpr>(maps[0].getResults().back()).getPosition() with mIndex and kIndex. We can bail out if neither of them are inner most for now (I do not think we've hit that case, and the pipeline more or less assumes that we don't). Then we don't need any calls to getResultPosition which scans the whole map.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice idea, thanks :)

Comment on lines 472 to 480
auto maps = op.getIndexingMapsArray();
OpBuilder b(op);
auto lhsMDim = maps[0].getResultPosition(b.getAffineDimExpr(mDim));
auto lhsKDim = maps[0].getResultPosition(b.getAffineDimExpr(kDim));
if (!lhsMDim.has_value() || !lhsKDim.has_value()) {
return op.emitError("mDim or kDim not found in LHS indexing map.");
}
auto rhsKDim = maps[1].getResultPosition(b.getAffineDimExpr(kDim));
auto rhsNDim = maps[1].getResultPosition(b.getAffineDimExpr(nDim));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Don't use auto when the type is not obvious based on the RHS only

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, after simplification :)

auto rhsKDim = maps[1].getResultPosition(b.getAffineDimExpr(kDim));
auto rhsNDim = maps[1].getResultPosition(b.getAffineDimExpr(nDim));
if (!rhsKDim.has_value() || !rhsNDim.has_value()) {
return op.emitError("kDim or nDim not found in RHS indexing map.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this an op error (IE bad input IR) or something we should use an assertion for (logic error)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see I see, it is a logic error, I'll change it to assert :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, dissappears after simplification :)

Comment on lines 925 to 936
auto maps = op.getIndexingMapsArray();
OpBuilder b(op);
auto lhsMDim = maps[0].getResultPosition(b.getAffineDimExpr(mIndex));
auto lhsKDim = maps[0].getResultPosition(b.getAffineDimExpr(kIndex));
if (!lhsMDim.has_value() || !lhsKDim.has_value()) {
return op.emitError("mDim or kDim not found in LHS indexing map.");
}
auto rhsKDim = maps[1].getResultPosition(b.getAffineDimExpr(kIndex));
auto rhsNDim = maps[1].getResultPosition(b.getAffineDimExpr(nIndex));
if (!rhsKDim.has_value() || !rhsNDim.has_value()) {
return op.emitError("kDim or nDim not found in RHS indexing map.");
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, dissappears after simplification :)

Signed-off-by: stanley-nod <[email protected]>
@raikonenfnu raikonenfnu requested review from qedawkins and kuhar June 12, 2024 00:03
Signed-off-by: stanley-nod <[email protected]>
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just 2 remaining nits

deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes);

// Infer if lhs or rhs is transposed to help generate better schedule.
auto maps = op.getIndexingMapsArray();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks2 done :)

@@ -928,6 +921,23 @@ LogicalResult setCooperativeMatrixConfig(
subgroupSize = *minSize;
}

// Infer if lhs or rhs is transposed to help generate better schedule.
auto maps = op.getIndexingMapsArray();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here

Signed-off-by: stanley-nod <[email protected]>
@raikonenfnu raikonenfnu merged commit 52b21f8 into iree-org:main Jun 12, 2024
51 checks passed
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
…tranfer_read layout anchor (iree-org#17636)

Modify heuristic to take into account layout of transfer reads, S.T we will not generate invalid schedules who's transfer read cannot be distributed because the sizes do not match up.

For example in one matmul with N-dim with these sizes
[wgTileSize, elemPerThread, threadSize] = [192, 8, 128].
There is no good layout for this because, the numbers of threads
needed would be 192/8 == 24, and Since the threadSize pre-determined by
schedule is 128, we will have 128 % 24 != 0. Hence we cannot distribute it.

This patch introduce constraints in our heuristic to solve these cases.

---------

Signed-off-by: stanley-nod <[email protected]>
Signed-off-by: Lubo Litchev <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants