DID loop split on reshaped IDs #3875

Merged · 14 commits · Mar 6, 2025
53 changes: 43 additions & 10 deletions csrc/scheduler/utils.cpp
@@ -2227,25 +2227,58 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {

std::unordered_map<int64_t, int64_t> old2new;
// Make sure rfactor dims we need are in domain, and reorder them in domain
// so they're consecutive starting from the left of domain.
// The reordering is to limit the propagation to only the view
// transformations.
for (auto logical_id : tv->getLogicalDomain()) {
if (terminating_reshape_dims.find(logical_id) !=
terminating_reshape_dims.end()) {
// The rfactor dims are not directly in the loop domain if they are
// sharded. For example, consider the split reshape `[h] -> [a, h/a]`
// where `h` and `a` are both sharded by `d`. The loop domain of the
// consumer is `[DIDx(d), a/d, h/a]`, so we cannot directly find the
// logical ID `a` in the loop domain. Similarly, for the merge reshape
// `[a, h/a] -> [h]`, we cannot directly find `h` in the loop domain
// when `h` is sharded by `d`.

// Find all reachable ids between the logical id and the loop domain.
// If the ids are in the loop domain, reorder them to the front.
auto transforms = DependencyCheck::getAllExprsBetween(
{logical_id},
{tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
std::vector<IterDomain*> reachable_ids;
// Add the logical id for the case where it is directly in the loop
// domain.
reachable_ids.push_back(logical_id);

for (auto expr : transforms) {
auto outputs = ir_utils::filterByType<IterDomain>(expr->outputs());
std::copy(
outputs.begin(),
outputs.end(),
std::back_inserter(reachable_ids));
}

bool has_reachable_loop_id = false;
for (auto id : reachable_ids) {
Collaborator review comment:

I don't know whether reachable_ids will be ordered -- it really depends on the implementation of getAllExprsBetween. Instead, I'd loop over the loop domain and try to find a match in reachable_ids (which should probably be a set instead of a vector). It's roughly the same logic, but more deterministic and better aligned with the existing order in the loop domain.
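
For concreteness, a minimal sketch of that suggestion (assuming the surrounding nvFuser types and the `transforms`, `logical_id`, `tv`, and `old2new` variables from this diff; `<unordered_set>` would also need to be included):

// Collect the reachable IDs into a set; their traversal order no longer
// matters because the loop domain drives the iteration below.
std::unordered_set<IterDomain*> reachable_ids{logical_id};
for (Expr* expr : transforms) {
  for (IterDomain* out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
    reachable_ids.insert(out);
  }
}

// Iterate over the loop domain so positions are assigned in loop-domain
// order, independent of getAllExprsBetween's traversal order.
bool has_reachable_loop_id = false;
const auto& loop_domain = tv->getLoopDomain();
for (int64_t pos = 0; pos < (int64_t)loop_domain.size(); ++pos) {
  if (reachable_ids.count(loop_domain[pos]) == 0) {
    continue;
  }
  has_reachable_loop_id = true;
  old2new[pos] = (int64_t)old2new.size();
}

The diff as merged keeps the vector-based loop below; the set-based variant trades a little memory for an iteration order that matches the loop domain.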

auto find_it = std::find(
tv->getLoopDomain().begin(), tv->getLoopDomain().end(), id);
if (find_it == tv->getLoopDomain().end()) {
continue;
}
has_reachable_loop_id = true;
// Reorder the reshape dimensions to the front of the domain
int64_t old_pos = std::distance(tv->getLoopDomain().begin(), find_it);
old2new[old_pos] = (int64_t)old2new.size();
}

NVF_ERROR(
has_reachable_loop_id,
"Require ",
logical_id,
" is in the active domain of ",
tv->toString(),
" for view propagation.");
}
}

62 changes: 62 additions & 0 deletions tests/cpp/test_multidevice_sharding.cpp
@@ -745,4 +745,66 @@ TEST_F(MultiDeviceTest, ReorderDIDToFront) {
__FILE__);
}

TEST_F(MultiDeviceTest, ShardedSplitReshapeIds) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const int d = communicator_->size();
const int64_t b = 2, s = 2, h = 4, e = 3;

TensorView* tv0 = makeContigConcreteTensor(
{b, s, d * h * e}); // in: loop domain: {b, s, d*h*e}
TensorView* tv1 = reshape(
tv0,
{b, s, d * h * e},
{b, s, d * h, e}); // out: loop domain: {b, s, d*h, e}

fusion->addInput(tv0);
fusion->addOutput(tv1);

auto mesh = DeviceMesh::createForNumDevices(d);

// Propagate the reshape transform from the output to the input.
// Without this propagation, the two DID axes on `in` and `out` would not
// be mapped together in the ID model, causing scheduling to fail due to
// resharding.
TransformPropagator propagator_c2p(tv1);
MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator_c2p);
// in: loop domain: {b, s, d*h, e} after transform propagation

// Loop split and parallelize input
tv0->setDeviceMesh(mesh);
tv0->split(-2, d, /*inner_split=*/false);
tv0->axis(-3)->parallelize(ParallelType::DIDx);
// in: loop domain: {b, s, DIDx{d}, h, e}

// Propagate DID loop split to output
TransformPropagator propagator_p2c(tv0);
MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator_p2c);
// out: loop domain: {b, s, d, h, e} after transform propagation

// Parallelize output
scheduler_utils::parallelizeAllLike(
tv0,
/*pos=*/-1,
/*selected_tv=*/{tv1});
// out: loop domain: {b, s, DIDx{d}, h, e} after parallelization

tv0->setAllocationDomain(tv0->getLoopDomain(), true);
tv1->setAllocationDomain(tv1->getLoopDomain(), true);

FusionExecutorCache executor_cache(std::move(fusion));
at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options);
at::Tensor sharded_inp = shardTensor(inp, tv0);
at::Tensor nvf_out =
executor_cache.runFusionWithInputs({sharded_inp})[0].as<at::Tensor>();
testValidate(
executor_cache.fusion(),
{nvf_out},
{sharded_inp},
{sharded_inp.view({b, s, h, e})},
__LINE__,
__FILE__);
}
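
For local verification, multidevice gtests in this repository are typically launched under MPI, e.g. `mpirun -np 2 test_multidevice --gtest_filter=MultiDeviceTest.ShardedSplitReshapeIds`; the binary name and path are assumptions about the build layout rather than something this diff specifies.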

} // namespace nvfuser