diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index fd0cc0024d510..d755e31870e46 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -1965,3 +1965,7 @@ PHI_DEFINE_EXPORTED_bool(fused_multi_transformer_op_use_mbfmha, PHI_DEFINE_EXPORTED_int64(multi_block_attention_min_partition_size, 1024, "The minimum partition size for flash decoding"); + +PHI_DEFINE_EXPORTED_bool(save_cf_stack_op, + false, + "Save cf stack op for higher-order derivatives."); diff --git a/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h b/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h index 3305167842810..51d5a0199d318 100644 --- a/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h +++ b/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h @@ -29,6 +29,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h" #include "paddle/utils/flat_hash_map.h" namespace pir { @@ -66,6 +67,10 @@ class AttrTypeReader { static pir::Attribute ReadPaddleDistAttr(const std::string attr_name, Json* attr_json, pir::IrContext* ctx); + + static pir::Type ReadControlFlowType(const std::string type_name, + Json* type_json, + pir::IrContext* ctx); }; template @@ -237,6 +242,9 @@ pir::Type parseType(Json* type_json) { } else if (DECOMPRESS_DIALECT_ID(name.first) == paddle::dialect::DistDialect::name()) { return AttrTypeReader::ReadPaddleDistType(name.second, type_json, ctx); + } else if (DECOMPRESS_DIALECT_ID(name.first) == + pir::ControlFlowDialect::name()) { + return AttrTypeReader::ReadControlFlowType(name.second, type_json, ctx); } else { PADDLE_ENFORCE( false, @@ -695,4 +703,25 @@ pir::Type AttrTypeReader::ReadPaddleDistType(const std::string type_name, } } +pir::Type AttrTypeReader::ReadControlFlowType(const std::string type_name, + Json* type_json, + pir::IrContext* ctx) { + if (type_name == pir::StackType::name()) { + VLOG(8) << "Parse StackType ... "; + return pir::deserializeTypeFromJson(type_json, ctx); + } else if (type_name == pir::InletType::name()) { + VLOG(8) << "Parse InletType ... "; + return pir::deserializeTypeFromJson(type_json, ctx); + } else if (type_name == pir::OutletType::name()) { + VLOG(8) << "Parse OutletType ... "; + return pir::deserializeTypeFromJson(type_json, ctx); + } else { + PADDLE_ENFORCE( + false, + common::errors::InvalidArgument( + "Unknown Type %s for parse controlflow dialect type", type_name)); + return pir::Type(); + } +} + } // namespace pir diff --git a/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h b/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h index 782c808bb7607..5b39f72138671 100644 --- a/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h +++ b/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h @@ -28,6 +28,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h" namespace pir { #define COMPRESS_DIALECT_NAME(attr_template) \ @@ -53,6 +54,8 @@ class AttrTypeWriter { static Json WritePaddleDistType(const pir::Type& type); static Json WritePaddleDistAttr(const pir::Attribute& attr); + + static Json WriteControlFlowType(const pir::Type& type); }; /** serializeTypeToJson is a template function to serialize * a pir type to a json object. a pir type may have value or no value @@ -245,6 +248,9 @@ Json writeType(const pir::Type& type) { } else if (type.dialect().name() == paddle::dialect::DistDialect::name()) { VLOG(6) << "write PaddleDistType ... "; return AttrTypeWriter::WritePaddleDistType(type); + } else if (type.dialect().name() == pir::ControlFlowDialect::name()) { + VLOG(6) << "write ControlFlowDialect ... "; + return AttrTypeWriter::WriteControlFlowType(type); } else { PADDLE_ENFORCE( false, @@ -723,4 +729,26 @@ Json AttrTypeWriter::WritePaddleDistAttr(const pir::Attribute& attr) { return Json::object(); } +Json AttrTypeWriter::WriteControlFlowType(const pir::Type& type) { + Json type_json = Json::object(); + if (type.isa()) { + VLOG(8) << "Write StackType ... "; + return pir::serializeTypeToJson( + type.dyn_cast()); + } else if (type.isa()) { + VLOG(8) << "Write InletType ... "; + return pir::serializeTypeToJson( + type.dyn_cast()); + } else if (type.isa()) { + VLOG(8) << "Write OutletType ... "; + return pir::serializeTypeToJson( + type.dyn_cast()); + } else { + PADDLE_ENFORCE(false, + common::errors::InvalidArgument( + "Unknown Type when write controlflow dialect type")); + } + return type_json; +} + } // namespace pir diff --git a/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc b/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc index 7bcf38559b5f3..dc41e7a7d58b4 100644 --- a/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc +++ b/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pir/serialize_deserialize/include/ir_serialize.h" +#include "paddle/common/flags.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h" #include "paddle/pir/include/core/dialect.h" @@ -20,6 +21,7 @@ #include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" +COMMON_DECLARE_bool(save_cf_stack_op); namespace pir { Json ProgramWriter::GetProgramJson(const pir::Program* program) { @@ -99,29 +101,31 @@ Json ProgramWriter::WriteBlock(pir::Block* block, Json ops_json = Json::array(); /* delete cf.stack_create / cf.tuple_push */ - std::vector delete_ops; - for (auto op : block->ops()) { - if (op->isa()) { - delete_ops.push_back(op); - } - } - VLOG(6) << "program before delete stack op :" << *(block->parent_program()); - for (auto op : delete_ops) { - VLOG(0) << "Delete cf.stack_create / cf.tuple_push."; - auto stack_op = op->dyn_cast(); - if (stack_op.inlet().HasOneUse()) { - auto tuple_push_op = stack_op.tuple_push_op(); - auto block_in = tuple_push_op->GetParent(); - block_in->erase(*tuple_push_op); + if (!FLAGS_save_cf_stack_op) { + std::vector delete_ops; + for (auto op : block->ops()) { + if (op->isa()) { + delete_ops.push_back(op); + } } - if (stack_op.outlet().HasOneUse()) { - auto tuple_pop_op = stack_op.tuple_pop_op(); - auto block_in = tuple_pop_op->GetParent(); - block_in->erase(*tuple_pop_op); + VLOG(6) << "program before delete stack op :" << *(block->parent_program()); + for (auto op : delete_ops) { + VLOG(0) << "Delete cf.stack_create / cf.tuple_push."; + auto stack_op = op->dyn_cast(); + if (stack_op.inlet().HasOneUse()) { + auto tuple_push_op = stack_op.tuple_push_op(); + auto block_in = tuple_push_op->GetParent(); + block_in->erase(*tuple_push_op); + } + if (stack_op.outlet().HasOneUse()) { + auto tuple_pop_op = stack_op.tuple_pop_op(); + auto block_in = tuple_pop_op->GetParent(); + block_in->erase(*tuple_pop_op); + } + block->erase(*op); } - block->erase(*op); + VLOG(6) << "program after delete stack op :" << *(block->parent_program()); } - VLOG(6) << "program after delete stack op :" << *(block->parent_program()); for (auto op : block->ops()) { auto op_json = WriteOp(*op); ops_json.emplace_back(op_json); diff --git a/paddle/fluid/pir/serialize_deserialize/src/patch_util.cc b/paddle/fluid/pir/serialize_deserialize/src/patch_util.cc index 91d5f37db2426..c3b68f16acad5 100644 --- a/paddle/fluid/pir/serialize_deserialize/src/patch_util.cc +++ b/paddle/fluid/pir/serialize_deserialize/src/patch_util.cc @@ -24,6 +24,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h" namespace pir { @@ -168,9 +169,15 @@ std::string GetTypeName(const YAML::Node &action) { Json GetTypeJson(const YAML::Node &action) { Json json; - std::string dialect = DialectIdMap::Instance()->GetCompressDialectId( - pir::BuiltinDialect::name()) + - "."; + std::string builtin_dialect = DialectIdMap::Instance()->GetCompressDialectId( + pir::BuiltinDialect::name()) + + "."; + std::string op_dialect = DialectIdMap::Instance()->GetCompressDialectId( + paddle::dialect::OperatorDialect::name()) + + "."; + std::string cf_dialect = DialectIdMap::Instance()->GetCompressDialectId( + pir::ControlFlowDialect::name()) + + "."; std::string type_name = ""; if (action.IsScalar()) { type_name = action.as(); @@ -181,46 +188,46 @@ Json GetTypeJson(const YAML::Node &action) { } if (type_name == "pir::BoolType") { VLOG(8) << "Get BoolType name."; - json[ID] = dialect + pir::BoolType::name(); + json[ID] = builtin_dialect + pir::BoolType::name(); } else if (type_name == "pir::BFloat16Type") { VLOG(8) << "Get BFloat16Type name."; - json[ID] = dialect + pir::BFloat16Type::name(); + json[ID] = builtin_dialect + pir::BFloat16Type::name(); } else if (type_name == "pir::Float16Type") { VLOG(8) << "Get Float16Type name."; - json[ID] = dialect + pir::Float16Type::name(); + json[ID] = builtin_dialect + pir::Float16Type::name(); } else if (type_name == "pir::Float32Type") { VLOG(8) << "Get Float32Type name."; - json[ID] = dialect + pir::Float32Type::name(); + json[ID] = builtin_dialect + pir::Float32Type::name(); } else if (type_name == "pir::Float64Type") { VLOG(8) << "Get Float64Type name."; - json[ID] = dialect + pir::Float64Type::name(); + json[ID] = builtin_dialect + pir::Float64Type::name(); } else if (type_name == "pir::Int8Type") { VLOG(8) << "Get Int8Type name."; - json[ID] = dialect + pir::Int8Type::name(); + json[ID] = builtin_dialect + pir::Int8Type::name(); } else if (type_name == "pir::UInt8Type") { VLOG(8) << "Get UInt8Type name."; - json[ID] = dialect + pir::UInt8Type::name(); + json[ID] = builtin_dialect + pir::UInt8Type::name(); } else if (type_name == "pir::Int16Type") { VLOG(8) << "Get Int16Type name."; - json[ID] = dialect + pir::Int16Type::name(); + json[ID] = builtin_dialect + pir::Int16Type::name(); } else if (type_name == "pir::Int32Type") { VLOG(8) << "Get Int32Type name."; - json[ID] = dialect + pir::Int32Type::name(); + json[ID] = builtin_dialect + pir::Int32Type::name(); } else if (type_name == "pir::Int64Type") { VLOG(8) << "Get Int64Type name."; - json[ID] = dialect + pir::Int64Type::name(); + json[ID] = builtin_dialect + pir::Int64Type::name(); } else if (type_name == "pir::IndexType") { VLOG(8) << "Get IndexType name."; - json[ID] = dialect + pir::IndexType::name(); + json[ID] = builtin_dialect + pir::IndexType::name(); } else if (type_name == "pir::Complex64Type") { VLOG(8) << "Get Complex64Type name."; - json[ID] = dialect + pir::Complex64Type::name(); + json[ID] = builtin_dialect + pir::Complex64Type::name(); } else if (type_name == "pir::Complex128Type") { VLOG(8) << "Get Complex128Type name."; - json[ID] = dialect + pir::Complex128Type::name(); + json[ID] = builtin_dialect + pir::Complex128Type::name(); } else if (type_name == "pir::VectorType") { VLOG(8) << "Get VectorType name."; - json[ID] = dialect + pir::VectorType::name(); + json[ID] = builtin_dialect + pir::VectorType::name(); json[DATA] = Json::array(); for (size_t i = 0; i < action["default"].size(); i++) { YAML::Node array_value = action["default"][i]; @@ -228,7 +235,7 @@ Json GetTypeJson(const YAML::Node &action) { } } else if (type_name == "pir::DenseTensorType") { VLOG(8) << "Get DenseTensorType name."; - json[ID] = dialect + pir::DenseTensorType::name(); + json[ID] = builtin_dialect + pir::DenseTensorType::name(); Json content = Json::array(); YAML::Node tensor_value = action["default"]; content.push_back(BuildTypeJsonPatch(tensor_value[0])); @@ -242,6 +249,15 @@ Json GetTypeJson(const YAML::Node &action) { content.push_back(tensor_value[4].as()); // offset json[DATA] = content; + } else if (type_name == "pir::StackType") { + VLOG(8) << "Get StackType name."; + json[ID] = cf_dialect + pir::StackType::name(); + } else if (type_name == "pir::InletType") { + VLOG(8) << "Get InletType name."; + json[ID] = cf_dialect + pir::InletType::name(); + } else if (type_name == "pir::OutletType") { + VLOG(8) << "Get OutletType name."; + json[ID] = cf_dialect + pir::OutletType::name(); } return json; } diff --git a/paddle/pir/include/core/storage_manager_support.h b/paddle/pir/include/core/storage_manager_support.h index 614f3938c54e2..8964ab0d1a5f0 100644 --- a/paddle/pir/include/core/storage_manager_support.h +++ b/paddle/pir/include/core/storage_manager_support.h @@ -65,7 +65,8 @@ class StorageHelperBase : public BaseT { using InterfaceList = typename Filter>::Type; - static ConcreteT dyn_cast_impl(BaseT type) { + template + static ConcreteT dyn_cast_impl(T type) { if (type && type.type_id() == TypeId::get()) { return ConcreteT(type.storage()); } diff --git a/paddle/pir/include/dialect/control_flow/ir/cf_type.h b/paddle/pir/include/dialect/control_flow/ir/cf_type.h index 1e76de313d861..cf3e50c4790f2 100644 --- a/paddle/pir/include/dialect/control_flow/ir/cf_type.h +++ b/paddle/pir/include/dialect/control_flow/ir/cf_type.h @@ -30,16 +30,19 @@ class IR_API StackType : public Type::TypeBase { public: using Base::Base; + static std::string name() { return "t_stack"; } }; class IR_API InletType : public Type::TypeBase { public: using Base::Base; + static std::string name() { return "t_inlet"; } }; class IR_API OutletType : public Type::TypeBase { public: using Base::Base; + static std::string name() { return "t_outlet"; } }; } // namespace pir