
BFS traversal could not visit some nodes when applying grid swizzle with matmul scheduler. #3962

Closed
rdspring1 opened this issue Feb 25, 2025 · 15 comments · Fixed by #4012

@rdspring1 (Collaborator)

To Reproduce:
NOTE: grid_swizzle_factor is set to 8 in Matmul Parameters.

TEST_F(HopperMatmulTest, MLPGemmPersistentBroadcastInputs) {
  EnableOptionsGuard eog;
  EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMultipleMatmuls);

  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int64_t M = 8192, N = 8192, K = 8192;
  const auto dtype = DataType::BFloat16;

  auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, 1, K
  auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
  auto tv2 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);

  auto tv4 = fusedMultiplySum(tv0, tv1, {2});
  auto tv3 = castOp(dtype, tv4);
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
  auto a_ref = at::randn({M, 1, K}, options);
  auto b_ref = at::randn({1, N, K}, options);
  auto c_ref = at::randn({1, N, K}, options);
  clearL2Cache();

  auto tv3_ref = at::linear(a_ref.squeeze(), b_ref.squeeze());
  clearL2Cache();

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 256, 64);
  gemm_tile.warp_tile = GemmTile(64, 256, 64);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 8};
  mparams.mma_macro = MmaMacro::Hopper_64_256_16;
  mparams.tile_sizes = gemm_tile;
  mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffering_strategy =
      MatmulParams::CircularBufferingStrategy::WarpSpecialized;
  mparams.tiling_strategy =
      MatmulParams::TilingStrategy::DistributeTilesAcrossSMs;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = false;
  mparams.grid_swizzle_factor = 8;
  // TODO: reduced shared memory aliasing because of persistent scheduling
  mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
  mparams.splitk_factor = 1;
  mparams.use_smem_epilogue = true;
  mparams.cluster_dims = {2, 1, 1};
  mparams.promote_prologue_smem_reuse = true;

  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);
  std::vector<c10::IValue> inputs = {a_ref, b_ref, c_ref};

  KernelExecutor ke;
  ke.compile(&fusion, inputs);
  EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
  auto cg_outputs = ke.run(inputs);
  ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
      ke.compiledKernel()->kernel()));

  // Relax tolerance for larger sum due to large K
  EXPECT_TRUE(cg_outputs[0].allclose(tv3_ref, 1e-6 * K, 1e-6 * K));
}

Error

C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/bfs.h":241, 
please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. BFS traversal could not visit some nodes:  idg{32} idg{33} (from:  idg{80 117 127 137 239 241 261 263} idg{81 118 128 138 240 242 262 264} idg{22 46 70} idg{82 92 139 147 148 156 157 165} idg{84 141 150 159} idg{86 88 143 145 152 154 161 163} idg{90} idg{83 85 140 142 149 151 158 160} idg{49 59 73 87 89 110 120 130 144 146 153 155 162 164 230 244 252 254} idg{91}), visited: ( idg{56} idg{51} idg{41} idg{48 72 109 119 129 229 243 251 253} idg{55} idg{57} idg{212 214 220 223 225 227} idg{195 209} idg{194 208} idg{101 174 184} idg{69} idg{68} idg{53 54 77 114 124 134 234 256} idg{198} idg{39 40} idg{100 173 183} idg{65} idg{3 17 50} idg{197} idg{186 200} idg{36 60} idg{90} idg{216 218 221 224 226 228} idg{44} idg{79 116 126 136 237 238 259 260} idg{199} idg{45} idg{35} idg{96 169 179} idg{105} idg{187 201} idg{84 141 150 159} idg{192 206} idg{23 37 47 61 71} idg{43 67} idg{86 88 143 145 152 154 161 163} idg{213 215 217 219 222} idg{91} idg{81 118 128 138 240 242 262 264} idg{83 85 140 142 149 151 158 160} idg{191 205} idg{80 117 127 137 239 241 261 263} idg{27 75 112 122 132 232 246 248 250} idg{58} idg{103} idg{22 46 70} idg{93 166 176 188 202} idg{62 97 170 180} idg{63 64 98 171 181} idg{104} idg{106} idg{102 175 185} idg{95 168 178} idg{49 59 73 87 89 110 120 130 144 146 153 155 162 164 230 244 252 254} idg{26 74 111 121 131 231 245 247 249} idg{38} idg{94 167 177 189 190 203 204} idg{82 92 139 147 148 156 157 165} idg{196} idg{99 172 182} idg{78 115 125 135 235 236 257 258} idg{34} idg{52 76 113 123 133 233 255} idg{42 66} idg{193 207})

Backtrace

#0  0x0000ffffa17ea8c0 in __cxa_throw () from /usr/lib/aarch64-linux-gnu/libstdc++.so.6
#1  0x0000aaaaaab7b874 [PAC] in nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) ()
#2  0x0000aaaaaadd0e50 in nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) ()
#3  0x0000aaaaaacdd998 in std::pair<nvfuser::ValGraphBFS::ExprPath, bool> nvfuser::getExprsBetween<nvfuser::ValGraphBFS, nvfuser::ValGraph>(std::vector<nvfuser::ValGraphBFS::ValType, std::allocator<nvfuser::ValGraphBFS::ValType> > const&, std::vector<nvfuser::ValGraphBFS::ValType, std::allocator<nvfuser::ValGraphBFS::ValType> > const&, bool, nvfuser::Direction, nvfuser::ValGraph const&) ()
#4  0x0000aaaaaacddd4c in nvfuser::getInnerMmaLoopGroup(nvfuser::TensorView*, nvfuser::MmaOp const*) ()
#5  0x0000aaaaaacde6c4 in nvfuser::getInnerStrideBytes(nvfuser::TensorView*, nvfuser::MmaOp const*) ()
#6  0x0000aaaaaacdf848 in nvfuser::IndexLowering::handle(nvfuser::MmaOp const*) ()
#7  0x0000aaaaaacdbfc4 in nvfuser::IndexLowering::handle(nvfuser::kir::IfThenElse const*) ()
#8  0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#9  0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#10 0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#11 0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#12 0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#13 0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#14 0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#15 0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#16 0x0000aaaaaacdc004 in nvfuser::IndexLowering::handle(nvfuser::kir::IfThenElse const*) ()
#17 0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#18 0x0000aaaaaacc7c74 in nvfuser::IndexLowering::handle(nvfuser::ForLoop const*) ()
#19 0x0000aaaaaacc6424 in nvfuser::IndexLowering::getIndexedExprs(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >) ()
#20 0x0000aaaaaac89af4 in std::_Function_handler<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&), std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (*)(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >)>::_M_invoke(std::_Any_data const&, std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&) ()
#21 0x0000aaaaaac8396c in nvfuser::GpuLower::run() ()
#22 0x0000aaaaab138078 in nvfuser::CompiledKernel::CompiledKernel(nvfuser::Fusion*, nvfuser::CompileParams, c10::Device, nvfuser::SchedulerType, long, long, long, long, std::vector<std::function<void (nvfuser::GpuLower*)>, std::allocator<std::function<void (nvfuser::GpuLower*)> > > const&, std::vector<std::function<void (nvfuser::kir::Kernel*)>, std::allocator<std::function<void (nvfuser::kir::Kernel*)> > > const&) ()
#23 0x0000aaaaab1493b8 in nvfuser::KernelExecutor::compile(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType) ()
#24 0x0000aaaaab405460 in nvfuser::HopperMatmulTest_MLPGemmPersistentBroadcastInputs_Test::TestBody() ()
#25 0x0000aaaaab50f5e0 in void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) ()
#26 0x0000aaaaab4f83a4 in testing::Test::Run() [clone .part.0] ()
#27 0x0000aaaaab4f8898 in testing::TestInfo::Run() ()
#28 0x0000aaaaab4f8e94 in testing::TestSuite::Run() [clone .part.0] ()
#29 0x0000aaaaab506140 in testing::internal::UnitTestImpl::RunAllTests() ()
#30 0x0000aaaaab4f9070 in testing::UnitTest::Run() ()
#31 0x0000aaaaaab87520 in main ()
rdspring1 self-assigned this Feb 25, 2025
@jacobhinkle (Collaborator)

This currently happens with non-persistent scheduling too (i.e. OneTilePerCTA instead of DistributeTilesAcrossSMs).

We start with something like this:

A[ iS0{i0} bS1{1} ... ]
B[ bS2{1} iS3{i1} ... ]
C[ iS4{i0} iS5{i1} ... ] = mma(A, B)

(The actual extents are more complicated because, at this point, we have already done the CTA tile split.) When we apply the grid swizzle with RowMajor order, we split bS1 and iS3 by the swizzle factor, then merge the resulting inner dims with iS0 and bS2, respectively:

Split bS1{1} by 8 -> bS6{1}, bS7{8}
Merge iS0{i0}, bS7{8} -> iS8{i0}
Split iS3{i1} by 8 -> iS9{ceilDiv(i1, 8)}, iS10{8}
Merge bS2{1}, iS10{8} -> iS11{8}
Split iS5{i1} by 8 -> iS12{ceilDiv(i1, 8)}, iS13{8}
Merge iS4{i0}, iS13{8} -> iS14{i0*8}

A[ iS8{i0} bS6{1} ... ]
B[ iS11{8} iS9{ceilDiv(i1, 8)} ... ]
C[ iS14{i0*8} iS12{ceilDiv(i1, 8)} ... ] = mma(A, B)
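
For reference, a minimal sketch of the same RowMajor swizzle on C expressed with TensorView scheduling primitives (the variable name c and the exact axis positions are illustrative assumptions, not the scheduler's actual code):

  // Sketch of the C transforms listed above, as scheduling calls.
  constexpr int64_t swizzle_factor = 8;
  // Split iS5{i1} -> iS12{ceilDiv(i1, 8)}, iS13{8}
  c->split(/*axis=*/1, swizzle_factor);
  // Merge iS4{i0} with the inner split output iS13{8} -> iS14{i0*8}
  c->merge(/*axis_o=*/0, /*axis_i=*/2);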

The indexing traversal graph is the ALMOSTEXACT graph:

// IdMappingMode::ALMOSTEXACT
// Forward through broadcast axes, but not through to a non-broadcast axis
// i.e. id{b1*i0}, id{i0} are mapped
// id{i1*i0}, id{i0} are not mapped (this part is the difference from
// PERMISSIVE)
// Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped

I think the problem is the mixing of broadcast and iteration domains here. When we don't do grid swizzling, there is a path from iS4 to iS0 and from iS5 to iS3, and we do not need any path to bS1 or bS2 since those are broadcast dimensions.

However, we do not find a path from C's loop domain to the allocation domain of B above; in particular, we can't get to iS9. I don't see how this is possible, though, since the IDs are connected to their logical domains, and those are mapped together in the ALMOSTEXACT graph...

For what it's worth, grid swizzling is done the same way on Ampere, but we do not encounter this bug there (I set mparams.grid_swizzle_factor = 8 in MatmulTestWithLayout.AmpereMatmul to test it).

@naoyam (Collaborator) commented Mar 3, 2025

However, we do not find a path from C's loop domain to the alloc dom of B above, in particular we can't get to iS9.

What does the loop domain look like? Is it C[ iS14{i0*8} iS12{ceilDiv(i1, 8)} ... ]? If so, isn't iS12 mapped with iS9?

@jacobhinkle (Collaborator)

However, we do not find a path from C's loop domain to the alloc dom of B above, in particular we can't get to iS9.

What does the loop domain look like? Is it C[ iS14{i0*8} iS12{ceilDiv(i1, 8)} ... ]? If so, isn't iS12 mapped with iS9?

These are the two actual tensors. T4_s is the "A" input and T2_l is the mma output.

T4_s___bfloat[iblockIdx.x27{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )}, bblockIdx.y25{1}, iS19{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[2] ), 64) )}, bS22{256}, iS28{1}, iB30{16}, iB36{8}, iB33{1}, iB37{8}, iB35{8}] ca_pos( 3 )
 logical domain : (iS187{( (( (( getMetaData(T0) )).logical_size ))[0] )}, bS12{1}, iS188{( (( (( getMetaData(T0) )).logical_size ))[2] )})
 allocation domain : (iblockIdx.x27{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )}, bblockIdx.y25{1}, iS19{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[2] ), 64) )}, bS22{256}, iS28{1}, iB30{16}, iB36{8}, iB33{1}, iB37{8}, iB35{8})
 contiguity: t n t n t t t t t t
  Split: iS187{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 128 -> iS23{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) )}, iS24{128}
  Split: bS12{1} by factor 256 -> bS21{1}, bS22{256}
  Split: bS21{1} by factor 8 -> bblockIdx.y25{1}, bS26{8}
  Merge: iS23{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) )} and bS26{8} -> iblockIdx.x27{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )}
  Split: iS188{( (( (( getMetaData(T0) )).logical_size ))[2] )} by factor 64 -> iS19{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[2] ), 64) )}, iS20{64}
  Split: iS20{64} by factor 64 -> iS28{1}, iS29{64}
  Split: iS24{128} by factor 8 -> iB30{16}, iS31{8}
  Split: iS31{8} by factor 1 -> iS32{8}, iB33{1}
  Split: iS29{64} by factor 8 -> iS34{8}, iB35{8}
  Xor(2D): iS32{8} , iS34{8} -> iB36{8} , iB37{8}
 loop domain : (iblockIdx.x27{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )}, bblockIdx.y25{1}, iS19{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[2] ), 64) )}, bS22{256}, iS28{1}, iB30{16}, iB36{8}, iB33{1}, iB37{8}, iB35{8})

T2_l_float[iblockIdx.x65{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )}, iblockIdx.y63{( ceilDiv(( ceilDiv(( (( (( getMetaData(T1) )).logical_size ))[1] ), 256) ), 8) )}, rS57{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[2] ), 64) )}, ithreadIdx.y76{2}, iS68{1}, iS72{1}, rS74{4}, iMMA69{64}, iMMA73{256}, rMMA75{16}] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS193{( (( (( getMetaData(T0) )).logical_size ))[0] )}, iS194{( (( (( getMetaData(T1) )).logical_size ))[1] )}, rS195{( (( (( getMetaData(T0) )).logical_size ))[2] )})
 allocation domain : (iblockIdx.x65{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )}, iblockIdx.y63{( ceilDiv(( ceilDiv(( (( (( getMetaData(T1) )).logical_size ))[1] ), 256) ), 8) )}, rS57{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[2] ), 64) )}, ithreadIdx.y76{2}, iS68{1}, iS72{1}, rS74{4}, ithreadIdx.x86{128}, iMMA81{32}, iMMA80{2}, iMMA84{2}, rMMA89{2}, rMMA90{4}, rMMA88{2})
 contiguity: t t n t t t n t t t t n n n
  Split: iS193{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 128 -> iS61{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) )}, iS62{128}
  Split: iS194{( (( (( getMetaData(T1) )).logical_size ))[1] )} by factor 256 -> iS59{( ceilDiv(( (( (( getMetaData(T1) )).logical_size ))[1] ), 256) )}, iS60{256}
  Split: iS59{( ceilDiv(( (( (( getMetaData(T1) )).logical_size ))[1] ), 256) )} by factor 8 -> iblockIdx.y63{( ceilDiv(( ceilDiv(( (( (( getMetaData(T1) )).logical_size ))[1] ), 256) ), 8) )}, iS64{8}
  Merge: iS61{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) )} and iS64{8} -> iblockIdx.x65{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )}
  Split: rS195{( (( (( getMetaData(T0) )).logical_size ))[2] )} by factor 64 -> rS57{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[2] ), 64) )}, rS58{64}
  Split: iS62{128} by factor 64 -> iS66{2}, iS67{64}
  Split: iS60{256} by factor 256 -> iS70{1}, iS71{256}
  Merge: iS66{2} and iS70{1} -> ithreadIdx.y76{2}
  Split: iS67{64} by factor 64 -> iS68{1}, iMMA69{64}
  Split: iS71{256} by factor 256 -> iS72{1}, iMMA73{256}
  Split: rS58{64} by factor 16 -> rS74{4}, rMMA75{16}
 loop domain : (iblockIdx.x65{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )}, iblockIdx.y63{( ceilDiv(( ceilDiv(( (( (( getMetaData(T1) )).logical_size ))[1] ), 256) ), 8) )}, rS57{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[2] ), 64) )}, ithreadIdx.y76{2}, iS68{1}, iS72{1}, rS74{4}, iMMA69{64}, iMMA73{256}, rMMA75{16})

The ALMOSTEXACT graph looks like this:

IdGraph {
Disjoint Ids:
  (idgs){
    idg{1 12 21 25}
    idg{3 14 42}
    idg{19 38 57}
    idg{20 29 39 48 58}
    idg{22}
    idg{23 61 95 102 109 204 212 214 216}
    idg{24 62 96 103 110 205 213 215 217}
    idg{26}
    idg{27}
    idg{28 47}
    idg{30}
    idg{31 32}
    idg{33}
    idg{34 53}
    idg{35 54}
    idg{36}
    idg{37}
    idg{40 59 93 100 107 202 210 218 220}
    idg{41 60 71 73 94 101 108 119 121 128 130 137 139 203 211 219 221}
    idg{43}
    idg{44 63 97 104 111 206 222}
    idg{45 46 64 98 105 112 207 223}
    idg{49 81 145 155}
    idg{50 51 82 146 156}
    idg{52}
    idg{55}
    idg{56}
    idg{65 99 106 113 208 209 224 225}
    idg{66 76 114 122 123 131 132 140}
    idg{67 69 115 117 124 126 133 135}
    idg{68 116 125 134}
    idg{70 72 118 120 127 129 136 138}
    idg{74}
    idg{75}
    idg{77 141 151 163 177}
    idg{78 142 152 164 165 178 179}
    idg{79 143 153}
    idg{80 144 154}
    idg{83 147 157}
    idg{84 148 158}
    idg{85 149 159}
    idg{86 150 160}
    idg{87}
    idg{88}
    idg{89}
    idg{90}
    idg{161 175}
    idg{162 176}
    idg{166 180}
    idg{167 181}
    idg{168 182}
    idg{169 183}
    idg{170 184}
    idg{171}
    idg{172}
    idg{173}
    idg{174}
    idg{185 187 193 196 198 200}
    idg{186 188 190 192 195}
    idg{189 191 194 197 199 201}
}

Looking at the two logical domains, 193 and 187 are mapped but 12 and 194 are not mapped.

@naoyam (Collaborator) commented Mar 3, 2025

In this case, which ID of T4 is not reachable?

@jacobhinkle (Collaborator)

In this case, which ID of T4 is not reachable?

Only iblockIdx.x27{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )} is not reachable.

@jacobhinkle (Collaborator) commented Mar 3, 2025

The subgraph for T4_s looks like this:

graph TD;

iS187 --> iS23 & iS24
iS23 --> iblockIdx.x27
bS21 --> bblockIdx.y25 & bS26
bS26 --> iblockIdx.x27

So the paths go through the logical IDs iS187 and bS21.

For T2_l we have

graph TD;

iS193 --> iS61 & iS62
iS61 -->  iblockIdx.x65
iS194 --> iS59 & iS60
iS59 --> iblockIdx.y63 & iS64
iS64 --> iblockIdx.x65


where 193 and 194 are the M and N logical domains.

193 and 187 are mapped but not 194 and 21.

I think in the ALMOSTEXACT map, we do not map these iteration and broadcast domains because they are not directly used in merges with matching IDs, but only in a split.

@jacobhinkle (Collaborator)

Actually now that I draw it, the problem could be the extra split on iS194 by 256. Maybe we are skipping the CTA tile split on broadcast dims?? I will check that and report back.

@jacobhinkle (Collaborator)

Actually now that I draw it, the problem could be the extra split on iS194 by 256. Maybe we are skipping the CTA tile split on broadcast dims?? I will check that and report back.

This explanation makes sense. When we started respecting the warp tile split, I believe we only modified the mma result: #3636. That was fine without swizzling because the scheduling of the TMA load's broadcast dims didn't matter, but with grid swizzling there are transforms mixing the two, so it's important to schedule the broadcast dims too. I will do this!
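
A rough sketch of that direction, purely hypothetical (a_smem and b_smem are invented names for the smem operand TensorViews, and the axis positions are placeholders, not the actual scheduler change):

  // Hypothetical sketch: mirror the grid-swizzle split/merge on the operands'
  // broadcast dims so transforms mixing broadcast and iteration IDs stay
  // connected in the indexing traversal graph.
  for (TensorView* operand : {a_smem, b_smem}) {
    operand->split(1, grid_swizzle_factor); // split broadcast dims too
    operand->merge(0, 2);                   // mirror the swizzle merge on C
  }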

@naoyam (Collaborator) commented Mar 3, 2025

In this case, which ID of T4 is not reachable?

Only iblockIdx.x27{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )} is not reachable.

Something seems off. This ID is parallelized with BIDx, so it should not be part of T4's allocation domain as it's on shared memory. I haven't looked at the details of the graph yet, but indexing shouldn't try to reach this node.

@jacobhinkle (Collaborator)

In this case, which ID of T4 is not reachable?

Only iblockIdx.x27{( ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) * 8 )} is not reachable.

Something seems off. This ID is parallelized with BIDx, so it should not be part of T4's allocation domain as it's on shared memory. I haven't looked at the details of the graph yet, but indexing shouldn't try to reach this node.

That is a good point. This occurs in getInnerMmaLoopGroup. I could pre-filter the allocation domain by ir_utils::isMemorySharedAcross(tv->getMemoryType(), id->getParallelType()).

@naoyam (Collaborator) commented Mar 3, 2025

The subgraph for T4_s looks like this:

So the paths go through the logical IDs iS187 and bS21.

For T2_l we have

where 193 and 194 are the M and N logical domains.

193 and 187 are mapped but not 194 and 21.

I think in the ALMOSTEXACT map, we do not map these Iteration and Broadcast domains because they are not directly used in merges with matching IDs, but only in a split.

Didn't know you can embed graph diagrams 👍

@jacobhinkle (Collaborator) commented Mar 3, 2025

I could pre-filter the allocation domain by ir_utils::isMemorySharedAcross(tv->getMemoryType(), id->getParallelType()).

I confirmed that changing the definition of alloc_domain to the following in getInnerMmaLoopGroup avoids the error and the test passes:

  // Keep only allocation IDs that are actually allocated for this tensor:
  // skip broadcast IDs and IDs parallelized across dimensions over which the
  // memory is not shared (e.g. BIDx for shared memory).
  std::vector<IterDomain*> indexed_alloc_ids;
  for (IterDomain* id : tv->getMaybeAllocationDomain()) {
    const ParallelType ptype = id->getParallelType();
    if (!id->isBroadcast() &&
        (ptype == ParallelType::Serial || ptype == ParallelType::Bulk ||
         ir_utils::isMemorySharedAcross(tv->getMemoryType(), ptype))) {
      indexed_alloc_ids.push_back(id);
    }
  }
  auto alloc_domain = id_graph.toGroups(indexed_alloc_ids);

EDIT: This fixes the non-persistent case, but the persistent case still fails. I'm going to look into just scheduling those broadcast IDs again.

@jacobhinkle (Collaborator)

Another way to go would be to enforce that the indexing traversal graph maps concretized broadcasts, i.e., adding a mapping between bS21 and iS194.

@naoyam (Collaborator) commented Mar 4, 2025

Another way to go would be to enforce that the indexing traversal graph maps concretized broadcasts, i.e., adding a mapping between bS21 and iS194.

Broadcast IDs shouldn't be mapped for indexing. For example, consider:

// t0: [i0, b1]
// t1: [i2, i3]
t2 = t0 + t1; // [i4, i5]

for (auto tv: {t0, t1, t2}) {
  tv->merge(0, 1);
}

// t0: [i6]
// t1: [i7]
// t2: [i8]

inlineMost();

Then, if broadcast IDs were mapped, that is, if {b1, i3, i5} were mapped together, the merge outputs would also be mapped, i.e., {i6, i7, i8}. That would mean the same index would be assigned to all of the tensors, which shouldn't be the case.

@jacobhinkle (Collaborator) commented Mar 4, 2025

I wonder if it is safe to just use /*require_all_to_visited=*/false in this instance. The IDs we would miss are not going to be the innermost alloc IDs anyway, and we already error out if we cannot find the innermost allocation ID in the loop domain.
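
Roughly, based on the getExprsBetween signature visible in the backtrace above (the argument names here are assumptions):

  // Sketch: tolerate unreachable outer allocation groups in
  // getInnerMmaLoopGroup; we only need the path that covers the innermost
  // allocation group.
  auto [path, all_visited] = getExprsBetween<ValGraphBFS>(
      loop_groups,                      // from: consumer loop ValGroups
      alloc_groups,                     // to: producer allocation ValGroups
      /*require_all_to_visited=*/false, // do not assert full coverage
      Direction::Undefined,
      id_graph);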

jacobhinkle added a commit that referenced this issue Mar 5, 2025
In the indexing lowering pass, getInnerMmaLoopGroup is used to determine
the inner and outer MMA strides, which are used to create the wgmma
descriptor. The purpose of that function is to obtain the ValGroup of
the consumer loop ID corresponding to the innermost allocation domain of
the producer. So we do need to be able to traverse from consumer loop to
the innermost allocation group in the ValGraph, but we do not care about
visiting any other groups.

See #3962 (comment)
for an example where we cannot currently reach some of the outer
allocation dimensions if grid swizzling is used.

This PR just sets `require_all_to_visited` to false when performing the
BFS. A more involved fix might update
`ValGraphBFS::getExprGroupsBetween` to accept a vector of required
groups to be visited. Even better would be to understand fully why we
are unable to visit the grid swizzled allocation domains and address
that instead.

Fixes #3962