Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamic Shape] Fuse shape ops into generate shape op pass #60490

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ if(NOT CINN_ONLY)
op_dialect.cc
${cinn_op_source_file}
${cinn_op_info_file}
generate_shape_util.cc
manual_op.cc
op_attribute.cc
DEPS
op_dialect_vjp)
op_dialect_vjp
pir)

target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_SOURCE_DIR})
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/dialect/shape/utils/dim_expr_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_attribute.h"

namespace symbol {
namespace cinn::dialect {
using namespace symbol; // NOLINT

namespace {

Expand Down Expand Up @@ -58,71 +59,71 @@ std::string GetSerializedTag<Broadcast<DimExpr>>() {
return "Broadcast";
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const std::int64_t& dim_expr) {
return builder->int64_attr(dim_expr);
return pir::Int64Attribute::get(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const std::string& dim_expr) {
return builder->str_attr(dim_expr);
return pir::StrAttribute::get(ctx, dim_expr);
}

template <typename T>
::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::IrContext* ctx,
const T& dim_expr) {
std::vector<::pir::Attribute> attr_vecs{};
attr_vecs.push_back(builder->str_attr(GetSerializedTag<T>()));
attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag<T>()));
const auto& operand = dim_expr->data;
attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand));
return builder->array_attr(attr_vecs);
attr_vecs.push_back(ConvertDimExprToAttribute(ctx, operand));
return pir::ArrayAttribute::get(ctx, attr_vecs);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::Builder* builder, const Negative<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr);
::pir::IrContext* ctx, const Negative<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::Builder* builder, const Reciprocal<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr);
::pir::IrContext* ctx, const Reciprocal<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
}

template <typename T>
::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::Builder* builder,
::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::IrContext* ctx,
const T& dim_expr) {
std::vector<::pir::Attribute> attr_vecs{};
attr_vecs.push_back(builder->str_attr(GetSerializedTag<T>()));
attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag<T>()));
const auto& operands = *(dim_expr.operands);
for (const auto& operand : operands) {
attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand));
attr_vecs.push_back(ConvertDimExprToAttribute(ctx, operand));
}
return builder->array_attr(attr_vecs);
return pir::ArrayAttribute::get(ctx, attr_vecs);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Add<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Mul<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Max<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Min<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::Builder* builder, const Broadcast<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
::pir::IrContext* ctx, const Broadcast<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

std::optional<DimExpr> ConvertInt64AttributeToDimExpr(
Expand Down Expand Up @@ -211,11 +212,11 @@ std::optional<DimExpr> ConvertArrayAttributeToDimExpr(

} // namespace

::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx,
const DimExpr& dim_expr) {
return std::visit(
[&](const auto& impl) {
return ConvertDimExprToAttributeImpl(builder, impl);
return ConvertDimExprToAttributeImpl(ctx, impl);
},
dim_expr.variant());
}
Expand Down Expand Up @@ -359,4 +360,66 @@ MakeGetterDimExpr4SymbolName(
};
}

} // namespace symbol
namespace {

std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
const GenerateShapeOp::DataSymbolBinding& symbol_binding,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr =
DimExpr4InputDim(symbol_binding.input_tensor_idx);
if (!shape_or_data_dim_expr.data().has_value()) return std::nullopt;
int dim_idx = symbol_binding.input_tensor_dim_idx;
if (dim_idx >= shape_or_data_dim_expr.data().value().size())
return std::nullopt;
return shape_or_data_dim_expr.data().value().at(dim_idx);
}

std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
const GenerateShapeOp::ShapeSymbolBinding& symbol_binding,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr =
DimExpr4InputDim(symbol_binding.input_tensor_idx);
int dim_idx = symbol_binding.input_tensor_dim_idx;
if (dim_idx >= shape_or_data_dim_expr.shape().size()) return std::nullopt;
return shape_or_data_dim_expr.shape().at(dim_idx);
}

} // namespace

std::function<std::optional<DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const GenerateShapeOp::SymbolBindings& symbol_bindings,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
std::unordered_map<std::string, std::vector<GenerateShapeOp::SymbolBinding>>
symbol_name2symbol_bindins{};
const auto& GetDimExpr =
[&](const GenerateShapeOp::SymbolBinding& symbol_binding) {
return std::visit(
[&](const auto& impl) {
return GetDimExprBySymbolBindingImpl(impl, DimExpr4InputDim);
},
symbol_binding);
};
return [map = std::move(symbol_name2symbol_bindins), GetDimExpr](
const std::string& symbol_name) -> std::optional<DimExpr> {
const auto& iter = map.find(symbol_name);
if (iter == map.end()) return std::nullopt;
std::optional<DimExpr> ret = std::nullopt;
for (const auto& symbol_binding : iter->second) {
const auto& current = GetDimExpr(symbol_binding);
if (!current.has_value()) return std::nullopt;
if (ret.has_value()) {
// Same names, same DimExprs.
if (ret.value() != current.value()) return std::nullopt;
} else {
ret = current;
}
}
return ret;
};
}

} // namespace cinn::dialect
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,35 @@
#pragma once

#include <optional>
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace symbol {
namespace cinn::dialect {

IR_API ::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder,
const DimExpr& dim_expr);
IR_API std::optional<DimExpr> ConvertAttributeToDimExpr(
::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx,
const symbol::DimExpr& dim_expr);

std::optional<symbol::DimExpr> ConvertAttributeToDimExpr(
::pir::Attribute attribute);

IR_API std::optional<DimExpr> SubstituteDimExpr(
const DimExpr& dim_expr,
const std::function<std::optional<DimExpr>(const std::string& symbol_name)>&
DimExpr4SymbolName);
std::optional<symbol::DimExpr> SubstituteDimExpr(
const symbol::DimExpr& dim_expr,
const std::function<std::optional<symbol::DimExpr>(
const std::string& symbol_name)>& DimExpr4SymbolName);

IR_API std::function<std::optional<DimExpr>(const std::string& symbol_name)>
std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const std::vector<std::tuple<std::string /*symbol_name*/,
int /*in_tensor_idx*/,
int /*in_tensor_dim_idx*/>>& symbol_bindings,
const std::function<std::optional<DimExpr>(
const std::function<std::optional<symbol::DimExpr>(
int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim);

} // namespace symbol
std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const GenerateShapeOp::SymbolBindings& symbol_bindings,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim);

} // namespace cinn::dialect
Loading