Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR save/load]Add cf saveload and flag #68628

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paddle/common/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1953,3 +1953,7 @@ PHI_DEFINE_EXPORTED_bool(fused_multi_transformer_op_use_mbfmha,
PHI_DEFINE_EXPORTED_int64(multi_block_attention_min_partition_size,
1024,
"The minimum partition size for flash decoding");

// Controls whether control-flow stack ops are kept when a program is
// serialized. With the default (false), ProgramWriter::WriteBlock erases every
// cf.stack_create op in the block before writing it, together with its
// tuple_push / tuple_pop op when the stack's inlet / outlet has exactly one
// use. Set to true to skip that erasure and serialize the ops as they are.
PHI_DEFINE_EXPORTED_bool(save_cf_stack_op,
false,
"Save cf stack op for higher-order derivatives.");
29 changes: 29 additions & 0 deletions paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
#include "paddle/utils/flat_hash_map.h"

namespace pir {
Expand Down Expand Up @@ -66,6 +67,10 @@ class AttrTypeReader {
static pir::Attribute ReadPaddleDistAttr(const std::string attr_name,
Json* attr_json,
pir::IrContext* ctx);

static pir::Type ReadControlFlowType(const std::string type_name,
Json* type_json,
pir::IrContext* ctx);
};

template <typename T>
Expand Down Expand Up @@ -237,6 +242,9 @@ pir::Type parseType(Json* type_json) {
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
paddle::dialect::DistDialect::name()) {
return AttrTypeReader::ReadPaddleDistType(name.second, type_json, ctx);
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
pir::ControlFlowDialect::name()) {
return AttrTypeReader::ReadControlFlowType(name.second, type_json, ctx);
} else {
PADDLE_ENFORCE(
false,
Expand Down Expand Up @@ -695,4 +703,25 @@ pir::Type AttrTypeReader::ReadPaddleDistType(const std::string type_name,
}
}

/**
 * Rebuild a control-flow dialect type from its serialized form.
 *
 * type_name is matched, in order, against StackType::name(),
 * InletType::name() and OutletType::name(); the first match is handed to
 * deserializeTypeFromJson together with type_json and ctx.
 *
 * Any other name is reported through PADDLE_ENFORCE as an InvalidArgument
 * error that includes the offending name.
 */
pir::Type AttrTypeReader::ReadControlFlowType(const std::string type_name,
                                              Json* type_json,
                                              pir::IrContext* ctx) {
  if (type_name == pir::StackType::name()) {
    VLOG(8) << "Parse StackType ... ";
    return pir::deserializeTypeFromJson<pir::StackType>(type_json, ctx);
  }

  if (type_name == pir::InletType::name()) {
    VLOG(8) << "Parse InletType ... ";
    return pir::deserializeTypeFromJson<pir::InletType>(type_json, ctx);
  }

  if (type_name == pir::OutletType::name()) {
    VLOG(8) << "Parse OutletType ... ";
    return pir::deserializeTypeFromJson<pir::OutletType>(type_json, ctx);
  }

  // None of the known control-flow types matched.
  PADDLE_ENFORCE(
      false,
      common::errors::InvalidArgument(
          "Unknown Type %s for parse controlflow dialect type", type_name));
  // Kept so that every path through the function returns a value.
  return pir::Type();
}

} // namespace pir
28 changes: 28 additions & 0 deletions paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"

namespace pir {
#define COMPRESS_DIALECT_NAME(attr_template) \
Expand All @@ -53,6 +54,8 @@ class AttrTypeWriter {
static Json WritePaddleDistType(const pir::Type& type);

static Json WritePaddleDistAttr(const pir::Attribute& attr);

static Json WriteControlFlowType(const pir::Type& type);
};
/** serializeTypeToJson is a template function to serialize
* a pir type to a json object. a pir type may have value or no value
Expand Down Expand Up @@ -245,6 +248,9 @@ Json writeType(const pir::Type& type) {
} else if (type.dialect().name() == paddle::dialect::DistDialect::name()) {
VLOG(6) << "write PaddleDistType ... ";
return AttrTypeWriter::WritePaddleDistType(type);
} else if (type.dialect().name() == pir::ControlFlowDialect::name()) {
VLOG(6) << "write ControlFlowDialect ... ";
return AttrTypeWriter::WriteControlFlowType(type);
} else {
PADDLE_ENFORCE(
false,
Expand Down Expand Up @@ -723,4 +729,26 @@ Json AttrTypeWriter::WritePaddleDistAttr(const pir::Attribute& attr) {
return Json::object();
}

/**
 * Serialize a control-flow dialect type to JSON.
 *
 * The type is tested, in order, for StackType, InletType and OutletType; the
 * first match is cast to that concrete type and passed to
 * serializeTypeToJson.
 *
 * Any other type is reported through PADDLE_ENFORCE as an InvalidArgument
 * error; the fall-through return value is an empty JSON object.
 */
Json AttrTypeWriter::WriteControlFlowType(const pir::Type& type) {
  if (type.isa<pir::StackType>()) {
    VLOG(8) << "Write StackType ... ";
    auto stack_type = type.dyn_cast<pir::StackType>();
    return pir::serializeTypeToJson<pir::StackType>(stack_type);
  }

  if (type.isa<pir::InletType>()) {
    VLOG(8) << "Write InletType ... ";
    auto inlet_type = type.dyn_cast<pir::InletType>();
    return pir::serializeTypeToJson<pir::InletType>(inlet_type);
  }

  if (type.isa<pir::OutletType>()) {
    VLOG(8) << "Write OutletType ... ";
    auto outlet_type = type.dyn_cast<pir::OutletType>();
    return pir::serializeTypeToJson<pir::OutletType>(outlet_type);
  }

  // Not one of the control-flow types handled above.
  PADDLE_ENFORCE(false,
                 common::errors::InvalidArgument(
                     "Unknown Type when write controlflow dialect type"));
  return Json::object();
}

} // namespace pir
44 changes: 24 additions & 20 deletions paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
// limitations under the License.

#include "paddle/fluid/pir/serialize_deserialize/include/ir_serialize.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h"
#include "paddle/pir/include/core/dialect.h"
#include "paddle/pir/include/core/operation.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

COMMON_DECLARE_bool(save_cf_stack_op);
namespace pir {

Json ProgramWriter::GetProgramJson(const pir::Program* program) {
Expand Down Expand Up @@ -99,29 +101,31 @@ Json ProgramWriter::WriteBlock(pir::Block* block,
Json ops_json = Json::array();

/* delete cf.stack_create / cf.tuple_push */
std::vector<pir::Operation*> delete_ops;
for (auto op : block->ops()) {
if (op->isa<pir::StackCreateOp>()) {
delete_ops.push_back(op);
}
}
VLOG(6) << "program before delete stack op :" << *(block->parent_program());
for (auto op : delete_ops) {
VLOG(0) << "Delete cf.stack_create / cf.tuple_push.";
auto stack_op = op->dyn_cast<pir::StackCreateOp>();
if (stack_op.inlet().HasOneUse()) {
auto tuple_push_op = stack_op.tuple_push_op();
auto block_in = tuple_push_op->GetParent();
block_in->erase(*tuple_push_op);
if (!FLAGS_save_cf_stack_op) {
std::vector<pir::Operation*> delete_ops;
for (auto op : block->ops()) {
if (op->isa<pir::StackCreateOp>()) {
delete_ops.push_back(op);
}
}
if (stack_op.outlet().HasOneUse()) {
auto tuple_pop_op = stack_op.tuple_pop_op();
auto block_in = tuple_pop_op->GetParent();
block_in->erase(*tuple_pop_op);
VLOG(6) << "program before delete stack op :" << *(block->parent_program());
for (auto op : delete_ops) {
VLOG(0) << "Delete cf.stack_create / cf.tuple_push.";
auto stack_op = op->dyn_cast<pir::StackCreateOp>();
if (stack_op.inlet().HasOneUse()) {
auto tuple_push_op = stack_op.tuple_push_op();
auto block_in = tuple_push_op->GetParent();
block_in->erase(*tuple_push_op);
}
if (stack_op.outlet().HasOneUse()) {
auto tuple_pop_op = stack_op.tuple_pop_op();
auto block_in = tuple_pop_op->GetParent();
block_in->erase(*tuple_pop_op);
}
block->erase(*op);
}
block->erase(*op);
VLOG(6) << "program after delete stack op :" << *(block->parent_program());
}
VLOG(6) << "program after delete stack op :" << *(block->parent_program());
for (auto op : block->ops()) {
auto op_json = WriteOp(*op);
ops_json.emplace_back(op_json);
Expand Down
52 changes: 34 additions & 18 deletions paddle/fluid/pir/serialize_deserialize/src/patch_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"

namespace pir {

Expand Down Expand Up @@ -168,9 +169,15 @@ std::string GetTypeName(const YAML::Node &action) {

Json GetTypeJson(const YAML::Node &action) {
Json json;
std::string dialect = DialectIdMap::Instance()->GetCompressDialectId(
pir::BuiltinDialect::name()) +
".";
std::string builtin_dialect = DialectIdMap::Instance()->GetCompressDialectId(
pir::BuiltinDialect::name()) +
".";
std::string op_dialect = DialectIdMap::Instance()->GetCompressDialectId(
paddle::dialect::OperatorDialect::name()) +
".";
std::string cf_dialect = DialectIdMap::Instance()->GetCompressDialectId(
pir::ControlFlowDialect::name()) +
".";
std::string type_name = "";
if (action.IsScalar()) {
type_name = action.as<std::string>();
Expand All @@ -181,54 +188,54 @@ Json GetTypeJson(const YAML::Node &action) {
}
if (type_name == "pir::BoolType") {
VLOG(8) << "Get BoolType name.";
json[ID] = dialect + pir::BoolType::name();
json[ID] = builtin_dialect + pir::BoolType::name();
} else if (type_name == "pir::BFloat16Type") {
VLOG(8) << "Get BFloat16Type name.";
json[ID] = dialect + pir::BFloat16Type::name();
json[ID] = builtin_dialect + pir::BFloat16Type::name();
} else if (type_name == "pir::Float16Type") {
VLOG(8) << "Get Float16Type name.";
json[ID] = dialect + pir::Float16Type::name();
json[ID] = builtin_dialect + pir::Float16Type::name();
} else if (type_name == "pir::Float32Type") {
VLOG(8) << "Get Float32Type name.";
json[ID] = dialect + pir::Float32Type::name();
json[ID] = builtin_dialect + pir::Float32Type::name();
} else if (type_name == "pir::Float64Type") {
VLOG(8) << "Get Float64Type name.";
json[ID] = dialect + pir::Float64Type::name();
json[ID] = builtin_dialect + pir::Float64Type::name();
} else if (type_name == "pir::Int8Type") {
VLOG(8) << "Get Int8Type name.";
json[ID] = dialect + pir::Int8Type::name();
json[ID] = builtin_dialect + pir::Int8Type::name();
} else if (type_name == "pir::UInt8Type") {
VLOG(8) << "Get UInt8Type name.";
json[ID] = dialect + pir::UInt8Type::name();
json[ID] = builtin_dialect + pir::UInt8Type::name();
} else if (type_name == "pir::Int16Type") {
VLOG(8) << "Get Int16Type name.";
json[ID] = dialect + pir::Int16Type::name();
json[ID] = builtin_dialect + pir::Int16Type::name();
} else if (type_name == "pir::Int32Type") {
VLOG(8) << "Get Int32Type name.";
json[ID] = dialect + pir::Int32Type::name();
json[ID] = builtin_dialect + pir::Int32Type::name();
} else if (type_name == "pir::Int64Type") {
VLOG(8) << "Get Int64Type name.";
json[ID] = dialect + pir::Int64Type::name();
json[ID] = builtin_dialect + pir::Int64Type::name();
} else if (type_name == "pir::IndexType") {
VLOG(8) << "Get IndexType name.";
json[ID] = dialect + pir::IndexType::name();
json[ID] = builtin_dialect + pir::IndexType::name();
} else if (type_name == "pir::Complex64Type") {
VLOG(8) << "Get Complex64Type name.";
json[ID] = dialect + pir::Complex64Type::name();
json[ID] = builtin_dialect + pir::Complex64Type::name();
} else if (type_name == "pir::Complex128Type") {
VLOG(8) << "Get Complex128Type name.";
json[ID] = dialect + pir::Complex128Type::name();
json[ID] = builtin_dialect + pir::Complex128Type::name();
} else if (type_name == "pir::VectorType") {
VLOG(8) << "Get VectorType name.";
json[ID] = dialect + pir::VectorType::name();
json[ID] = builtin_dialect + pir::VectorType::name();
json[DATA] = Json::array();
for (size_t i = 0; i < action["default"].size(); i++) {
YAML::Node array_value = action["default"][i];
json[DATA].push_back(BuildTypeJsonPatch(array_value));
}
} else if (type_name == "pir::DenseTensorType") {
VLOG(8) << "Get DenseTensorType name.";
json[ID] = dialect + pir::DenseTensorType::name();
json[ID] = builtin_dialect + pir::DenseTensorType::name();
Json content = Json::array();
YAML::Node tensor_value = action["default"];
content.push_back(BuildTypeJsonPatch(tensor_value[0]));
Expand All @@ -242,6 +249,15 @@ Json GetTypeJson(const YAML::Node &action) {

content.push_back(tensor_value[4].as<int>()); // offset
json[DATA] = content;
} else if (type_name == "pir::StackType") {
VLOG(8) << "Get StackType name.";
json[ID] = cf_dialect + pir::StackType::name();
} else if (type_name == "pir::InletType") {
VLOG(8) << "Get InletType name.";
json[ID] = cf_dialect + pir::InletType::name();
} else if (type_name == "pir::OutletType") {
VLOG(8) << "Get OutletType name.";
json[ID] = cf_dialect + pir::OutletType::name();
}
return json;
}
Expand Down
3 changes: 2 additions & 1 deletion paddle/pir/include/core/storage_manager_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class StorageHelperBase : public BaseT {
using InterfaceList =
typename Filter<TypeInterfaceBase, std::tuple<TraitOrInterface...>>::Type;

static ConcreteT dyn_cast_impl(BaseT type) {
template <typename T>
static ConcreteT dyn_cast_impl(T type) {
if (type && type.type_id() == TypeId::get<ConcreteT>()) {
return ConcreteT(type.storage());
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/include/dialect/control_flow/ir/cf_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@ class IR_API StackType
: public Type::TypeBase<StackType, ContainerType, TypeStorage> {
public:
using Base::Base;
static std::string name() { return "t_stack"; }
};

// Control-flow dialect type built on the plain TypeStorage; it declares no
// members beyond the inherited constructors and its name.
class IR_API InletType : public Type::TypeBase<InletType, Type, TypeStorage> {
 public:
  using Base::Base;

  // Identifier of this type. The serialize/deserialize utilities compare a
  // parsed type name against it and use it to build the type's JSON id.
  static std::string name() {
    return std::string("t_inlet");
  }
};

// Control-flow dialect type built on the plain TypeStorage; it declares no
// members beyond the inherited constructors and its name.
class IR_API OutletType
    : public Type::TypeBase<OutletType, Type, TypeStorage> {
 public:
  using Base::Base;

  // Identifier of this type. The serialize/deserialize utilities compare a
  // parsed type name against it and use it to build the type's JSON id.
  static std::string name() {
    return std::string("t_outlet");
  }
};

} // namespace pir
Expand Down