BFS traversal could not visit some nodes when applying grid swizzle with matmul scheduler. #3962
This happens with the non-persistent scheduler too currently (i.e. …). We start with something like this:
(The actual extents are more complicated because at this point we have already done the CTA tile split.) When we apply the grid swizzle with RowMajor order, we split …
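For background, grid swizzling reorders which output tile each CTA works on so that consecutive CTAs stay within a narrow band of rows, which improves L2 reuse. Here is a generic, self-contained sketch of a row-major swizzle (hypothetical names and layout, not nvFuser's actual split-based transform):

```cpp
#include <cstdio>
#include <utility>

// Map a linear CTA index to a (row, col) output tile. Consecutive CTAs
// walk down a band of `factor` rows before moving to the next column,
// instead of sweeping an entire row at a time.
std::pair<int, int> swizzledTile(int cta, int grid_cols, int factor) {
  const int group_size = factor * grid_cols;  // tiles per row-band
  const int band = cta / group_size;
  const int in_band = cta % group_size;
  return {band * factor + in_band % factor, in_band / factor};
}

int main() {
  // 8x4 tile grid with swizzle factor 2: print the traversal order.
  for (int cta = 0; cta < 8 * 4; ++cta) {
    auto [row, col] = swizzledTile(cta, /*grid_cols=*/4, /*factor=*/2);
    std::printf("cta %2d -> tile (%d, %d)\n", cta, row, col);
  }
  return 0;
}
```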
The indexing traversal graph is the one defined in Fuser/csrc/id_model/id_model.h, lines 98 to 103 (at commit 70e1d5d).
I think the problem is the mixing of broadcast and iteration domains here. When we don't do grid swizzling, there is a path from C's loop domain to the alloc domain of B. However, with grid swizzling we do not find a path from C's loop domain to the alloc dom of B above; in particular we can't get to … For what it's worth, grid swizzling is done the same way on Ampere but we do not encounter this bug there (I set …).
What does the loop domain look like? Is it …
These are the two actual tensors.
The ALMOSTEXACT graph looks like this:
Looking at the two logical domains, 193 and 187 are mapped but 21 and 194 are not mapped.
In this case, which ID of T4 is not reachable?
Only …
The subgraph for one of the two tensors is:

```mermaid
graph TD;
iS187 --> iS23 & iS24
iS23 --> iblockIdx.x27
bS21 --> bblockIdx.y25 & bS26
bS26 --> iblockIdx.x27
```
So the paths go through the logical IDs iS187 and bS21. For the other tensor:

```mermaid
graph TD;
iS193 --> iS61 & iS62
iS61 --> iblockIdx.x65
iS194 --> iS59 & iS60
iS59 --> iblockIdx.y63 & iS64
iS64 --> iblockIdx.x65
```
where 193 and 194 are the M and N logical domains. 193 and 187 are mapped but not 194 and 21. I think in the ALMOSTEXACT map, we do not map these Iteration and Broadcast domains because they are not directly used in merges with matching IDs, but only in a split.
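To make that propagation rule concrete, here is a toy, self-contained sketch (hypothetical names and data structure, not nvFuser's ValGraph code): mappings join merge outputs only when both input pairs are already mapped, and a split's single input never gives that opportunity.

```cpp
#include <cstdio>
#include <map>
#include <string>

// Toy disjoint-set over ID names (illustrative only, not nvFuser code).
std::map<std::string, std::string> parent;

std::string find(const std::string& x) {
  if (!parent.count(x)) parent[x] = x;
  return parent[x] == x ? x : parent[x] = find(parent[x]);
}

void unite(const std::string& a, const std::string& b) {
  parent[find(a)] = find(b);
}

bool mapped(const std::string& a, const std::string& b) {
  return find(a) == find(b);
}

// Exact-style propagation: two merge outputs are joined only if both
// corresponding input pairs are already joined. A split has one input,
// so it never joins two previously-unmapped IDs like this.
void propagateMerge(
    const std::string& a_outer, const std::string& a_inner, const std::string& a_out,
    const std::string& b_outer, const std::string& b_inner, const std::string& b_out) {
  if (mapped(a_outer, b_outer) && mapped(a_inner, b_inner)) {
    unite(a_out, b_out);
  }
}

int main() {
  // Producer-consumer mapping of the logical domains: 187 <-> 193.
  unite("iS187", "iS193");
  // bS21 and iS194 start unmapped; the schedules only ever split them,
  // so nothing joins them later. Even a merge would not help, because
  // only one of the two input pairs (the M pair) is mapped:
  propagateMerge("iS187", "bS21", "mergeA", "iS193", "iS194", "mergeB");
  std::printf("iS187 ~ iS193: %d\n", mapped("iS187", "iS193"));
  std::printf(" bS21 ~ iS194: %d\n", mapped("bS21", "iS194"));
  std::printf("mergeA ~ mergeB: %d\n", mapped("mergeA", "mergeB"));
}
```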
Actually now that I draw it, the problem could be the extra split on …
This explanation makes sense. When we started respecting the warp tile split, I believe we only modified the mma result: #3636. That was fine without swizzling because the scheduling of the TMA load's broadcast dims didn't matter, but with grid swizzling there are transforms mixing the two, so it's important to schedule the broadcast dims too. I will do this!
Something seems off. This ID is parallelized with BIDx, so it should not be part of T4's allocation domain as it's on shared memory. I haven't looked at the details of the graph yet, but indexing shouldn't try to reach this node.
That is a good point. This occurs in …
Didn't know you can embed graph diagrams 👍
I confirmed that changing the definition of `alloc_domain` to the following fixes it:

```cpp
// Collect only the allocation IDs that need indexing: skip broadcasts,
// and keep IDs that are serial, Bulk (TMA), or parallelized over a
// dimension across which this memory type is shared (those are still
// allocated, so they must be indexed).
std::vector<IterDomain*> indexed_alloc_ids;
for (IterDomain* id : tv->getMaybeAllocationDomain()) {
  const ParallelType ptype = id->getParallelType();
  if (!id->isBroadcast() &&
      (ptype == ParallelType::Serial || ptype == ParallelType::Bulk ||
       ir_utils::isMemorySharedAcross(tv->getMemoryType(), ptype))) {
    indexed_alloc_ids.push_back(id);
  }
}
auto alloc_domain = id_graph.toGroups(indexed_alloc_ids);
```

EDIT: This fixes the non-persistent case but persistent still fails. I'm going to look into just scheduling those broadcast IDs again.
Another way to go would be to enforce that the indexing traversal graph maps concretized broadcasts, i.e. adding a mapping between …
Broadcast IDs shouldn't be mapped for indexing. For example, consider:
Then, if broadcast IDs were mapped, that is, …
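The concrete example here was elided from the thread, but the general hazard can be illustrated with plain index arithmetic (a hypothetical sketch, not nvFuser code): a broadcast ID contributes index 0 with stride 0 to an address, so forcing it into the same indexing group as a concretized iteration domain would give it a real loop index and a nonzero stride.

```cpp
#include <cstdio>

int main() {
  // t0 has shape [4, broadcast]; t1 has shape [4, 3]; out = t0 + t1.
  float t0[4] = {1.f, 2.f, 3.f, 4.f};
  float t1[4][3] = {{0.f}};
  float out[4][3];
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 3; ++j) {
      // Correct indexing: the broadcast dimension of t0 contributes
      // nothing to the offset (index 0, stride 0), i.e. t0[i*1 + j*0].
      out[i][j] = t0[i] + t1[i][j];
      // If the broadcast ID were mapped to the concrete ID of j, the
      // loop index j would leak into t0's offset, e.g. t0[i + j],
      // which reads the wrong element (and out of bounds at the edge).
    }
  }
  std::printf("out[3][2] = %f\n", out[3][2]);
}
```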
I wonder if it is safe to just use …
In the indexing lowering pass, getInnerMmaLoopGroup is used to determine the inner and outer MMA strides, which are used to create the wgmma descriptor. The purpose of that function is to obtain the ValGroup of the consumer loop ID corresponding to the innermost allocation domain of the producer. So we do need to be able to traverse from the consumer loop to the innermost allocation group in the ValGraph, but we do not care about visiting any other groups. See #3962 (comment) for an example where we cannot currently reach some of the outer allocation dimensions if grid swizzling is used.

This PR just sets `require_all_to_visited` to false when performing the BFS. A more involved fix might update `ValGraphBFS::getExprGroupsBetween` to accept a vector of required groups to be visited. Even better would be to understand fully why we are unable to visit the grid-swizzled allocation domains and address that instead.

Fixes #3962
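For illustration, here is a minimal, generic sketch of the flag's effect (hypothetical graph and function names; nvFuser's ValGraphBFS::getExprGroupsBetween is more involved): with `require_all_to_visited` set we assert that every target was reached, and with it cleared we simply return the reachable subset.

```cpp
#include <cassert>
#include <cstdio>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// BFS between node sets with an optional requirement that every target
// be reached (illustrative only, not nvFuser's ValGraphBFS).
std::set<std::string> bfsReachable(
    const std::map<std::string, std::vector<std::string>>& adj,
    const std::vector<std::string>& from,
    const std::vector<std::string>& to,
    bool require_all_to_visited) {
  std::set<std::string> visited(from.begin(), from.end());
  std::queue<std::string> q;
  for (const auto& n : from) {
    q.push(n);
  }
  while (!q.empty()) {
    auto n = q.front();
    q.pop();
    auto it = adj.find(n);
    if (it == adj.end()) {
      continue;
    }
    for (const auto& next : it->second) {
      if (visited.insert(next).second) {
        q.push(next);
      }
    }
  }
  std::set<std::string> reached;
  for (const auto& t : to) {
    if (visited.count(t)) {
      reached.insert(t);
    } else {
      assert(!require_all_to_visited && "unreachable target");
    }
  }
  return reached;
}

int main() {
  // Loop domain -> allocation domain, with one alloc ID cut off from the
  // loop domain (as with the grid-swizzled outer dims in this issue).
  std::map<std::string, std::vector<std::string>> adj = {
      {"loop", {"inner_alloc"}}};
  auto reached = bfsReachable(adj, {"loop"}, {"inner_alloc", "outer_alloc"},
                              /*require_all_to_visited=*/false);
  std::printf("reached %zu of 2 targets\n", reached.size());
}
```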
To Reproduce:
NOTE: `grid_swizzle_factor` is set to 8 in the Matmul Parameters.

Error
Backtrace