Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR][DynamicShape] Remove redundant code for shapeAnalysis and shapedTypeInterface #60744

Merged
merged 6 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ bool ProcessOp(paddle::dialect::ExpandOp op, pir::PatternRewriter* rewriter) {
pir::ShapeConstraintIRAnalysis& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(
op.x().defining_op()->GetParentProgram());
CHECK(shape_analysis.value_id_to_shapeordata_.find(GetValueId(&value)) !=
shape_analysis.value_id_to_shapeordata_.end());
return shape_analysis.value_id_to_shapeordata_.at(GetValueId(&value));

return shape_analysis.GetShapeOrDataForValue(value);
};
std::optional<pir::Value> opt_generated_shape =
GetOutOfRewritedGenerateShapeOp(
Expand Down
43 changes: 22 additions & 21 deletions paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"

namespace paddle::dialect {
Expand All @@ -33,27 +34,25 @@ bool SameOperandsAndResultShape(
pir::Value operand_source = op->operand_source(0);

symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source];
shape_analysis->GetShapeOrDataForValue(operand_source);

op->set_attribute("symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
operand_shape_or_data));
pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data;
shape_analysis->SetShapeOrDataForValue(res, operand_shape_or_data);
return true;
}

bool InferSymbolicShapeElementWiseBinary(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source_0 = op->operand_source(0);
std::string operand_source_0_id = pir::GetValueId(&operand_source_0);
std::vector<symbol::DimExpr> shape_0{
shape_analysis->value_id_to_shapeordata_[operand_source_0_id].shape()};
shape_analysis->GetShapeOrDataForValue(operand_source_0).shape()};

pir::Value operand_source_1 = op->operand_source(1);
std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
std::vector<symbol::DimExpr> shape_1{
shape_analysis->value_id_to_shapeordata_[operand_source_1_id].shape()};
shape_analysis->GetShapeOrDataForValue(operand_source_1).shape()};

if (shape_0.size() > shape_1.size()) {
for (size_t i = 0; i < shape_0.size() - shape_1.size(); i++) {
Expand All @@ -75,9 +74,11 @@ bool InferSymbolicShapeElementWiseBinary(
std::vector<symbol::DimExpr> data;

pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);
symbol::ShapeOrDataDimExprs shape_data{shapes, data};
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
shape_analysis->SetShapeOrDataForValue(res, shape_data);
op->set_attribute(
"symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
return true;
}

Expand All @@ -104,7 +105,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
std::vector<symbol::DimExpr> sym_dims;
for (auto dim : dims) {
symbol::DimExpr dim_expr;
if (dim == -1) {
if (dim == pir::ShapedTypeInterface::kDynamic) {
symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName());
dim_expr = symbolic_dim_expr;
} else {
Expand All @@ -120,7 +121,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;
shape_analysis->SetShapeOrDataForValue(res, shape_data);

return true;
}
Expand Down Expand Up @@ -171,13 +172,13 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
pir::OpResult res = op->result(0);

symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source];
shape_analysis->GetShapeOrDataForValue(operand_source);

symbol::ShapeOrDataDimExprs extend_shape_or_data =
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(
operand_shape_or_data);

shape_analysis->value_to_shape_or_data_[res] = extend_shape_or_data;
shape_analysis->SetShapeOrDataForValue(res, extend_shape_or_data);
op->set_attribute("symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
extend_shape_or_data));
Expand All @@ -193,7 +194,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source];
shape_analysis->GetShapeOrDataForValue(operand_source);

std::vector<symbol::DimExpr> out_dims;
if (operand_shape_or_data.data().has_value()) {
Expand All @@ -213,7 +214,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
"symbolic_shape",
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;
shape_analysis->SetShapeOrDataForValue(res, shape_data);
return true;
}

Expand All @@ -222,7 +223,7 @@ bool ReshapeOpInferSymbolicShape(
pir::Value operand_source_shape = op->operand_source(1);

symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source_shape];
shape_analysis->GetShapeOrDataForValue(operand_source_shape);

std::vector<symbol::DimExpr> out_dims;
if (operand_shape_or_data.data().has_value()) {
Expand All @@ -236,9 +237,9 @@ bool ReshapeOpInferSymbolicShape(

pir::OpResult res0 = op->result(0);
pir::OpResult res1 = op->result(1);
shape_analysis->value_to_shape_or_data_[res0] = shape_data;
shape_analysis->value_to_shape_or_data_[res1] =
shape_analysis->value_to_shape_or_data_[operand_source_shape];
shape_analysis->SetShapeOrDataForValue(res0, shape_data);
shape_analysis->SetShapeOrDataForValue(
res1, shape_analysis->GetShapeOrDataForValue(operand_source_shape));
return true;
}

Expand Down Expand Up @@ -267,7 +268,7 @@ bool FullIntArrayOpInferSymbolicShape(
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;
shape_analysis->SetShapeOrDataForValue(res, shape_data);
return true;
}

Expand All @@ -286,7 +287,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
// dialect.
pir::Value operand_source = op->operand_source(0);
symbol::ShapeOrDataDimExprs operand_shape_or_data =
shape_analysis->value_to_shape_or_data_[operand_source];
shape_analysis->GetShapeOrDataForValue(operand_source);
pir::AttributeMap attributes = op->attributes();

std::vector<pir::Attribute> attr_starts =
Expand All @@ -309,7 +310,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

pir::OpResult res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;
shape_analysis->SetShapeOrDataForValue(res, shape_data);
return true;
}

Expand Down
15 changes: 4 additions & 11 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3005,15 +3005,9 @@ bool ShapeBroadcastOp::InferSymbolicShape(
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value x = operand_source(0);
pir::Value y = operand_source(1);
std::string x_id = pir::GetValueId(&x);
std::string y_id = pir::GetValueId(&y);

IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
"x_id does not exist.");
IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
"y_id does not exist.");
const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);

const auto &x_data_shape = shape_analysis->GetShapeOrDataForValue(x);
const auto &y_data_shape = shape_analysis->GetShapeOrDataForValue(y);
IR_ENFORCE(x_data_shape.data().has_value(),
"Value x comes from ShapeOp, it must have data");
IR_ENFORCE(y_data_shape.data().has_value(),
Expand All @@ -3028,10 +3022,9 @@ bool ShapeBroadcastOp::InferSymbolicShape(
}

pir::OpResult res = result(0);
std::string res_id = pir::GetValueId(&res);
symbol::ShapeOrDataDimExprs output_data_shape =
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
shape_analysis->SetShapeOrDataForValue(res, output_data_shape);
return true;
}

Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ struct CombineOpInferSymbolicShapeInterfaceModel
}

auto operand_source_1st_data =
shape_analysis->value_to_shape_or_data_[op->operand_source(0)].data();
shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).data();
if (operand_source_1st_data.has_value()) {
for (auto operand_source : op->operands_source()) {
auto source_data =
shape_analysis->value_to_shape_or_data_[operand_source]
shape_analysis->GetShapeOrDataForValue(operand_source)
.data()
.value();
out_dims.push_back(source_data[0]);
Expand All @@ -83,7 +83,7 @@ struct CombineOpInferSymbolicShapeInterfaceModel
pir::shape::SymbolAttribute::get(
pir::IrContext::Instance(), shape_data));
auto res = op->result(0);
shape_analysis->value_to_shape_or_data_[res] = shape_data;
shape_analysis->SetShapeOrDataForValue(res, shape_data);
return true;
}

Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/pir/transforms/shape_optimization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void DebugPrintOpInfo(
<< "ShapeOrData: ";

if (shape_analysis != nullptr) {
auto shape_data = shape_analysis->value_to_shape_or_data_[res];
auto shape_data = shape_analysis->GetShapeOrDataForValue(res);
print_stream << "shape: [";

for (size_t i = 0; i < shape_data.shape().size(); ++i) {
Expand Down Expand Up @@ -94,7 +94,9 @@ void InferSymExprForAllValues(ModuleOp module_op) {
if (infer_symbolic_shape_interface) {
VLOG(3) << op.name() << " has InferSymbolicShapeInterface.";
PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape(
&shape_analysis));
&shape_analysis),
"InferSymbolicShape for %s failed.",
op.name());
}
DebugPrintOpInfo(&op, &shape_analysis);
}
Expand Down
12 changes: 2 additions & 10 deletions paddle/pir/core/builtin_type_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,8 @@ Type ShapedTypeInterface::GetElementType() const {
return impl_->get_element_type(*this);
}

std::vector<int64_t> ShapedTypeInterface::GetDyShape() const {
if (dy_shape_.size() == 0) {
auto ddim_vec = common::vectorize(impl_->get_shape(*this));
dy_shape_ = ddim_vec;
std::replace(dy_shape_.begin(),
dy_shape_.end(),
(int64_t)-1,
ShapedTypeInterface::kDynamic);
}
return dy_shape_;
pir::DDim ShapedTypeInterface::GetShape() const {
return impl_->get_shape(*this);
}

} // namespace pir
Expand Down
22 changes: 10 additions & 12 deletions paddle/pir/core/builtin_type_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class IR_API ShapedTypeInterface
///
/// \brief kDynamic
///
static constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
static constexpr int64_t kDynamic = std::int64_t(-1);

ShapedTypeInterface(Type type, Concept *impl)
: TypeInterfaceBase<ShapedTypeInterface>(type), impl_(impl) {}
Expand All @@ -69,7 +69,7 @@ class IR_API ShapedTypeInterface
///
/// \brief Get the shape of this type.
///
std::vector<int64_t> GetDyShape() const;
pir::DDim GetShape() const;

///
/// \brief Check whether this type is ranked, currently return true.
Expand All @@ -81,7 +81,7 @@ class IR_API ShapedTypeInterface
///
int64_t GetRank() const {
IR_ENFORCE((*this).HasRank(), "Cannot query rank of unranked shaped type.");
return (*this).GetDyShape().size();
return (*this).GetShape().size();
}

///
Expand All @@ -94,11 +94,10 @@ class IR_API ShapedTypeInterface
/// dimension.
///
bool IsDynamicShape() const {
auto size_vec = (*this).GetDyShape();
return std::any_of(
size_vec.begin(), size_vec.end(), [](int64_t size_value) {
return IsDynamic(size_value);
});
auto size_vec = common::vectorize(impl_->get_shape(*this));
return std::any_of(size_vec.begin(), size_vec.end(), [](int64_t size_val) {
return IsDynamic(size_val);
});
}

///
Expand All @@ -112,15 +111,15 @@ class IR_API ShapedTypeInterface
///
bool IsDynamicDim(unsigned idx) const {
IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
return ShapedTypeInterface::IsDynamic((*this).GetDyShape()[idx]);
return ShapedTypeInterface::IsDynamic((*this).GetShape()[idx]);
}

///
/// \brief Get the number of dimensions with dynamic size for a ranked type.
/// Aborts for unranked types.
///
int64_t GetNumDynamicDims() const {
auto shape_vec = (*this).GetDyShape();
auto shape_vec = vectorize((*this).GetShape());
return std::count_if(
shape_vec.begin(), shape_vec.end(), ShapedTypeInterface::IsDynamic);
}
Expand All @@ -131,12 +130,11 @@ class IR_API ShapedTypeInterface
///
int64_t GetDimSize(unsigned idx) const {
IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
return (*this).GetDyShape()[idx];
return (*this).GetShape()[idx];
}

private:
Concept *impl_;
mutable std::vector<int64_t> dy_shape_;
};

} // namespace pir
Expand Down
12 changes: 6 additions & 6 deletions paddle/pir/core/type_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ Type GetElementTypeOrSelf(Type type) {
return type;
}

bool VerifyCompatibleShape(const std::vector<int64_t> &lhs_shape,
const std::vector<int64_t> &rhs_shape) {
bool VerifyCompatibleShape(const pir::DDim &lhs_shape,
const pir::DDim &rhs_shape) {
if (lhs_shape.size() != rhs_shape.size()) return false;

for (auto dim1 : lhs_shape) {
for (auto dim2 : rhs_shape) {
for (auto dim1 : common::vectorize(lhs_shape)) {
for (auto dim2 : common::vectorize(rhs_shape)) {
if (!ShapedTypeInterface::IsDynamic(dim1) &&
!ShapedTypeInterface::IsDynamic(dim2) && dim1 != dim2)
return false;
Expand All @@ -47,8 +47,8 @@ bool VerifyCompatibleShape(Type lhs_type, Type rhs_type) {

if (!lhs_shaped_type.HasRank() || !rhs_shaped_type.HasRank()) return true;

return VerifyCompatibleShape(lhs_shaped_type.GetDyShape(),
rhs_shaped_type.GetDyShape());
return VerifyCompatibleShape(lhs_shaped_type.GetShape(),
rhs_shaped_type.GetShape());
}

bool VerifyCompatibleDims(const std::vector<int64_t> &dims) {
Expand Down
7 changes: 3 additions & 4 deletions paddle/pir/dialect/shape/utils/shape_optimization_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,9 @@ std::vector<SymbolicDimOp> SymbolicDimMgr::CreateSymbolicDimsForRankedValue(
std::vector<SymbolicDimOp> symbols;
auto dims = value.type().dyn_cast<pir::DenseTensorType>().dims();
for (int idx = 0; idx < dims.size(); ++idx) {
symbols.push_back(
(dims[idx] == ShapedTypeInterface::kDynamic || dims[idx] == -1)
? NewSymbolicDim()
: NewConstantSymbolicDim(dims[idx]));
symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic
? NewSymbolicDim()
: NewConstantSymbolicDim(dims[idx]));
}
return symbols;
}
Expand Down
Loading