Skip to content

Commit

Permalink
add cf saveload and flag (#68628) (#68709)
Browse files Browse the repository at this point in the history
  • Loading branch information
changeyoung98 authored Oct 16, 2024
1 parent 5a5c6fc commit a1050d8
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 39 deletions.
4 changes: 4 additions & 0 deletions paddle/common/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1965,3 +1965,7 @@ PHI_DEFINE_EXPORTED_bool(fused_multi_transformer_op_use_mbfmha,
// Minimum partition size used for flash decoding; defaults to 1024.
PHI_DEFINE_EXPORTED_int64(multi_block_attention_min_partition_size,
1024,
"The minimum partition size for flash decoding");

// Controls whether control-flow stack ops survive PIR program serialization.
// Defaults to false. ProgramWriter::WriteBlock checks this flag: when it is
// false, every cf.stack_create op in the block is erased before the block's
// ops are written, together with its tuple_push / tuple_pop op when the
// corresponding inlet / outlet has exactly one use. When true, those ops are
// left in place and serialized like any other op.
PHI_DEFINE_EXPORTED_bool(save_cf_stack_op,
false,
"Save cf stack op for higher-order derivatives.");
29 changes: 29 additions & 0 deletions paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
#include "paddle/utils/flat_hash_map.h"

namespace pir {
Expand Down Expand Up @@ -66,6 +67,10 @@ class AttrTypeReader {
static pir::Attribute ReadPaddleDistAttr(const std::string attr_name,
Json* attr_json,
pir::IrContext* ctx);

static pir::Type ReadControlFlowType(const std::string type_name,
Json* type_json,
pir::IrContext* ctx);
};

template <typename T>
Expand Down Expand Up @@ -237,6 +242,9 @@ pir::Type parseType(Json* type_json) {
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
paddle::dialect::DistDialect::name()) {
return AttrTypeReader::ReadPaddleDistType(name.second, type_json, ctx);
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
pir::ControlFlowDialect::name()) {
return AttrTypeReader::ReadControlFlowType(name.second, type_json, ctx);
} else {
PADDLE_ENFORCE(
false,
Expand Down Expand Up @@ -695,4 +703,25 @@ pir::Type AttrTypeReader::ReadPaddleDistType(const std::string type_name,
}
}

/**
 * Rebuild a control-flow dialect type from its JSON form.
 *
 * type_name is compared, in order, against the names of pir::StackType,
 * pir::InletType and pir::OutletType; the first match is deserialized from
 * type_json within ctx. Any other name fails PADDLE_ENFORCE with an
 * InvalidArgument error; the trailing empty pir::Type is only a fallback
 * return value for that path.
 */
pir::Type AttrTypeReader::ReadControlFlowType(const std::string type_name,
                                              Json* type_json,
                                              pir::IrContext* ctx) {
  if (type_name == pir::StackType::name()) {
    VLOG(8) << "Parse StackType ... ";
    return pir::deserializeTypeFromJson<pir::StackType>(type_json, ctx);
  }

  if (type_name == pir::InletType::name()) {
    VLOG(8) << "Parse InletType ... ";
    return pir::deserializeTypeFromJson<pir::InletType>(type_json, ctx);
  }

  if (type_name == pir::OutletType::name()) {
    VLOG(8) << "Parse OutletType ... ";
    return pir::deserializeTypeFromJson<pir::OutletType>(type_json, ctx);
  }

  // None of the known control-flow type names matched.
  PADDLE_ENFORCE(
      false,
      common::errors::InvalidArgument(
          "Unknown Type %s for parse controlflow dialect type", type_name));
  return pir::Type();
}

} // namespace pir
28 changes: 28 additions & 0 deletions paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"

namespace pir {
#define COMPRESS_DIALECT_NAME(attr_template) \
Expand All @@ -53,6 +54,8 @@ class AttrTypeWriter {
static Json WritePaddleDistType(const pir::Type& type);

static Json WritePaddleDistAttr(const pir::Attribute& attr);

static Json WriteControlFlowType(const pir::Type& type);
};
/** serializeTypeToJson is a template function to serialize
* a pir type to a json object. a pir type may have value or no value
Expand Down Expand Up @@ -245,6 +248,9 @@ Json writeType(const pir::Type& type) {
} else if (type.dialect().name() == paddle::dialect::DistDialect::name()) {
VLOG(6) << "write PaddleDistType ... ";
return AttrTypeWriter::WritePaddleDistType(type);
} else if (type.dialect().name() == pir::ControlFlowDialect::name()) {
VLOG(6) << "write ControlFlowDialect ... ";
return AttrTypeWriter::WriteControlFlowType(type);
} else {
PADDLE_ENFORCE(
false,
Expand Down Expand Up @@ -723,4 +729,26 @@ Json AttrTypeWriter::WritePaddleDistAttr(const pir::Attribute& attr) {
return Json::object();
}

/**
 * Serialize a control-flow dialect type to JSON.
 *
 * The type is tested, in order, for pir::StackType, pir::InletType and
 * pir::OutletType; the first match is cast to that type and serialized.
 * Any other type fails PADDLE_ENFORCE with an InvalidArgument error; the
 * trailing empty JSON object is only a fallback return value for that path.
 */
Json AttrTypeWriter::WriteControlFlowType(const pir::Type& type) {
  if (type.isa<pir::StackType>()) {
    VLOG(8) << "Write StackType ... ";
    return pir::serializeTypeToJson<pir::StackType>(
        type.dyn_cast<pir::StackType>());
  }

  if (type.isa<pir::InletType>()) {
    VLOG(8) << "Write InletType ... ";
    return pir::serializeTypeToJson<pir::InletType>(
        type.dyn_cast<pir::InletType>());
  }

  if (type.isa<pir::OutletType>()) {
    VLOG(8) << "Write OutletType ... ";
    return pir::serializeTypeToJson<pir::OutletType>(
        type.dyn_cast<pir::OutletType>());
  }

  // Not one of the control-flow types handled above.
  PADDLE_ENFORCE(false,
                 common::errors::InvalidArgument(
                     "Unknown Type when write controlflow dialect type"));
  return Json::object();
}

} // namespace pir
44 changes: 24 additions & 20 deletions paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
// limitations under the License.

#include "paddle/fluid/pir/serialize_deserialize/include/ir_serialize.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h"
#include "paddle/pir/include/core/dialect.h"
#include "paddle/pir/include/core/operation.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

COMMON_DECLARE_bool(save_cf_stack_op);
namespace pir {

Json ProgramWriter::GetProgramJson(const pir::Program* program) {
Expand Down Expand Up @@ -99,29 +101,31 @@ Json ProgramWriter::WriteBlock(pir::Block* block,
Json ops_json = Json::array();

/* delete cf.stack_create / cf.tuple_push */
std::vector<pir::Operation*> delete_ops;
for (auto op : block->ops()) {
if (op->isa<pir::StackCreateOp>()) {
delete_ops.push_back(op);
}
}
VLOG(6) << "program before delete stack op :" << *(block->parent_program());
for (auto op : delete_ops) {
VLOG(0) << "Delete cf.stack_create / cf.tuple_push.";
auto stack_op = op->dyn_cast<pir::StackCreateOp>();
if (stack_op.inlet().HasOneUse()) {
auto tuple_push_op = stack_op.tuple_push_op();
auto block_in = tuple_push_op->GetParent();
block_in->erase(*tuple_push_op);
if (!FLAGS_save_cf_stack_op) {
std::vector<pir::Operation*> delete_ops;
for (auto op : block->ops()) {
if (op->isa<pir::StackCreateOp>()) {
delete_ops.push_back(op);
}
}
if (stack_op.outlet().HasOneUse()) {
auto tuple_pop_op = stack_op.tuple_pop_op();
auto block_in = tuple_pop_op->GetParent();
block_in->erase(*tuple_pop_op);
VLOG(6) << "program before delete stack op :" << *(block->parent_program());
for (auto op : delete_ops) {
VLOG(0) << "Delete cf.stack_create / cf.tuple_push.";
auto stack_op = op->dyn_cast<pir::StackCreateOp>();
if (stack_op.inlet().HasOneUse()) {
auto tuple_push_op = stack_op.tuple_push_op();
auto block_in = tuple_push_op->GetParent();
block_in->erase(*tuple_push_op);
}
if (stack_op.outlet().HasOneUse()) {
auto tuple_pop_op = stack_op.tuple_pop_op();
auto block_in = tuple_pop_op->GetParent();
block_in->erase(*tuple_pop_op);
}
block->erase(*op);
}
block->erase(*op);
VLOG(6) << "program after delete stack op :" << *(block->parent_program());
}
VLOG(6) << "program after delete stack op :" << *(block->parent_program());
for (auto op : block->ops()) {
auto op_json = WriteOp(*op);
ops_json.emplace_back(op_json);
Expand Down
52 changes: 34 additions & 18 deletions paddle/fluid/pir/serialize_deserialize/src/patch_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"

namespace pir {

Expand Down Expand Up @@ -168,9 +169,15 @@ std::string GetTypeName(const YAML::Node &action) {

Json GetTypeJson(const YAML::Node &action) {
Json json;
std::string dialect = DialectIdMap::Instance()->GetCompressDialectId(
pir::BuiltinDialect::name()) +
".";
std::string builtin_dialect = DialectIdMap::Instance()->GetCompressDialectId(
pir::BuiltinDialect::name()) +
".";
std::string op_dialect = DialectIdMap::Instance()->GetCompressDialectId(
paddle::dialect::OperatorDialect::name()) +
".";
std::string cf_dialect = DialectIdMap::Instance()->GetCompressDialectId(
pir::ControlFlowDialect::name()) +
".";
std::string type_name = "";
if (action.IsScalar()) {
type_name = action.as<std::string>();
Expand All @@ -181,54 +188,54 @@ Json GetTypeJson(const YAML::Node &action) {
}
if (type_name == "pir::BoolType") {
VLOG(8) << "Get BoolType name.";
json[ID] = dialect + pir::BoolType::name();
json[ID] = builtin_dialect + pir::BoolType::name();
} else if (type_name == "pir::BFloat16Type") {
VLOG(8) << "Get BFloat16Type name.";
json[ID] = dialect + pir::BFloat16Type::name();
json[ID] = builtin_dialect + pir::BFloat16Type::name();
} else if (type_name == "pir::Float16Type") {
VLOG(8) << "Get Float16Type name.";
json[ID] = dialect + pir::Float16Type::name();
json[ID] = builtin_dialect + pir::Float16Type::name();
} else if (type_name == "pir::Float32Type") {
VLOG(8) << "Get Float32Type name.";
json[ID] = dialect + pir::Float32Type::name();
json[ID] = builtin_dialect + pir::Float32Type::name();
} else if (type_name == "pir::Float64Type") {
VLOG(8) << "Get Float64Type name.";
json[ID] = dialect + pir::Float64Type::name();
json[ID] = builtin_dialect + pir::Float64Type::name();
} else if (type_name == "pir::Int8Type") {
VLOG(8) << "Get Int8Type name.";
json[ID] = dialect + pir::Int8Type::name();
json[ID] = builtin_dialect + pir::Int8Type::name();
} else if (type_name == "pir::UInt8Type") {
VLOG(8) << "Get UInt8Type name.";
json[ID] = dialect + pir::UInt8Type::name();
json[ID] = builtin_dialect + pir::UInt8Type::name();
} else if (type_name == "pir::Int16Type") {
VLOG(8) << "Get Int16Type name.";
json[ID] = dialect + pir::Int16Type::name();
json[ID] = builtin_dialect + pir::Int16Type::name();
} else if (type_name == "pir::Int32Type") {
VLOG(8) << "Get Int32Type name.";
json[ID] = dialect + pir::Int32Type::name();
json[ID] = builtin_dialect + pir::Int32Type::name();
} else if (type_name == "pir::Int64Type") {
VLOG(8) << "Get Int64Type name.";
json[ID] = dialect + pir::Int64Type::name();
json[ID] = builtin_dialect + pir::Int64Type::name();
} else if (type_name == "pir::IndexType") {
VLOG(8) << "Get IndexType name.";
json[ID] = dialect + pir::IndexType::name();
json[ID] = builtin_dialect + pir::IndexType::name();
} else if (type_name == "pir::Complex64Type") {
VLOG(8) << "Get Complex64Type name.";
json[ID] = dialect + pir::Complex64Type::name();
json[ID] = builtin_dialect + pir::Complex64Type::name();
} else if (type_name == "pir::Complex128Type") {
VLOG(8) << "Get Complex128Type name.";
json[ID] = dialect + pir::Complex128Type::name();
json[ID] = builtin_dialect + pir::Complex128Type::name();
} else if (type_name == "pir::VectorType") {
VLOG(8) << "Get VectorType name.";
json[ID] = dialect + pir::VectorType::name();
json[ID] = builtin_dialect + pir::VectorType::name();
json[DATA] = Json::array();
for (size_t i = 0; i < action["default"].size(); i++) {
YAML::Node array_value = action["default"][i];
json[DATA].push_back(BuildTypeJsonPatch(array_value));
}
} else if (type_name == "pir::DenseTensorType") {
VLOG(8) << "Get DenseTensorType name.";
json[ID] = dialect + pir::DenseTensorType::name();
json[ID] = builtin_dialect + pir::DenseTensorType::name();
Json content = Json::array();
YAML::Node tensor_value = action["default"];
content.push_back(BuildTypeJsonPatch(tensor_value[0]));
Expand All @@ -242,6 +249,15 @@ Json GetTypeJson(const YAML::Node &action) {

content.push_back(tensor_value[4].as<int>()); // offset
json[DATA] = content;
} else if (type_name == "pir::StackType") {
VLOG(8) << "Get StackType name.";
json[ID] = cf_dialect + pir::StackType::name();
} else if (type_name == "pir::InletType") {
VLOG(8) << "Get InletType name.";
json[ID] = cf_dialect + pir::InletType::name();
} else if (type_name == "pir::OutletType") {
VLOG(8) << "Get OutletType name.";
json[ID] = cf_dialect + pir::OutletType::name();
}
return json;
}
Expand Down
3 changes: 2 additions & 1 deletion paddle/pir/include/core/storage_manager_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class StorageHelperBase : public BaseT {
using InterfaceList =
typename Filter<TypeInterfaceBase, std::tuple<TraitOrInterface...>>::Type;

static ConcreteT dyn_cast_impl(BaseT type) {
template <typename T>
static ConcreteT dyn_cast_impl(T type) {
if (type && type.type_id() == TypeId::get<ConcreteT>()) {
return ConcreteT(type.storage());
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/include/dialect/control_flow/ir/cf_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@ class IR_API StackType
: public Type::TypeBase<StackType, ContainerType, TypeStorage> {
public:
using Base::Base;
static std::string name() { return "t_stack"; }
};

/// Inlet type declared alongside StackType and OutletType in cf_type.h.
class IR_API InletType : public Type::TypeBase<InletType, Type, TypeStorage> {
 public:
  using Base::Base;

  /// Name string identifying this type ("t_inlet").
  static std::string name() { return std::string("t_inlet"); }
};

/// Outlet type declared alongside StackType and InletType in cf_type.h.
class IR_API OutletType : public Type::TypeBase<OutletType, Type, TypeStorage> {
 public:
  using Base::Base;

  /// Name string identifying this type ("t_outlet").
  static std::string name() { return std::string("t_outlet"); }
};

} // namespace pir
Expand Down

0 comments on commit a1050d8

Please sign in to comment.