Skip to content

Commit

Permalink
fix reshard dist_attr (#60535)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio authored Jan 3, 2024
1 parent 5d01382 commit 2ad9e24
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
3 changes: 3 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ DistTensor::DistTensor() : value_(std::make_shared<DenseTensor>()) {}
DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
const TensorDistAttr& dist_attr)
: global_dims_(global_value->dims()), dist_attr_(dist_attr) {
process_mesh_ = dist_attr_.process_mesh();
placements_ = ToPlacements(dist_attr);

// If the current rank doesn't in process_mesh, we should create an
// uninitialized tensor only with tensor_meta.
if (IsCurRankInMesh(dist_attr.process_mesh())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ void ReshardFunction::SetDistProps(DistTensor* tensor,

tensor->global_dims_ = dims;
tensor->dist_attr_ = dist_attr;
tensor->process_mesh_ = dist_attr.process_mesh();
tensor->placements_ = ToPlacements(dist_attr);
}

void ReshardFunction::SetDistProps(DistTensor* tensor,
Expand All @@ -64,6 +66,8 @@ void ReshardFunction::SetDistProps(DistTensor* tensor,
str_join(vectorize(tensor->dims()))));

tensor->dist_attr_ = dist_attr;
tensor->process_mesh_ = dist_attr.process_mesh();
tensor->placements_ = ToPlacements(dist_attr);
}

DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) {
Expand Down

0 comments on commit 2ad9e24

Please sign in to comment.