diff --git a/.gitignore b/.gitignore index c4046a8d6b6e3..008e4b06e5834 100644 --- a/.gitignore +++ b/.gitignore @@ -108,7 +108,7 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.* paddle/fluid/pir/dialect/operator/ir/op_decomp.cc paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc paddle/fluid/pir/dialect/operator/ir/pd_op.* -paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.* +paddle/fluid/pir/dialect/operator/ir/onednn_op.* paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.* paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.* paddle/fluid/pir/dialect/operator/ir/pd_op_fused.* diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 823aaca9eed3a..6d3657ed4f65c 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -731,7 +731,7 @@ void PirInterpreter::BuildInstruction() { CREATE_INSTR(PhiKernelInstruction); } #ifdef PADDLE_WITH_DNNL - } else if (op.dialect()->name() == "pd_onednn_kernel") { + } else if (op.dialect()->name() == "onednn_kernel") { auto op_name = op.attributes() .at("op_name") .dyn_cast<::pir::StrAttribute>() diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 0227091e0aa53..09921306ffa67 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -45,7 +45,7 @@ #include "paddle/pir/core/value.h" #ifdef PADDLE_WITH_DNNL -#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" #endif // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/pir/dialect/CMakeLists.txt. @@ -81,7 +81,7 @@ using AttributeHandlerFn = std::function { class OneDNNPhiKernelOp : public pir::Op { public: using Op::Op; - static const char *name() { return "pd_onednn_kernel.phi_kernel"; } + static const char *name() { return "onednn_kernel.phi_kernel"; } static constexpr uint32_t attributes_num = 3; static const char *attributes_name[attributes_num]; std::string op_name(); @@ -72,7 +72,7 @@ class OneDNNPhiKernelOp : public pir::Op { class OneDNNMixedPhiKernelOp : public pir::Op { public: using Op::Op; - static const char *name() { return "pd_onednn_kernel.phi_mixed_kernel"; } + static const char *name() { return "onednn_kernel.phi_mixed_kernel"; } static constexpr uint32_t attributes_num = 3; static const char *attributes_name[attributes_num]; std::string op_name(); @@ -84,7 +84,7 @@ class OneDNNMixedPhiKernelOp : public pir::Op { class OneDNNLegacyKernelOp : public pir::Op { public: using Op::Op; - static const char *name() { return "pd_onednn_kernel.legacy_kernel"; } + static const char *name() { return "onednn_kernel.legacy_kernel"; } static constexpr uint32_t attributes_num = 3; static const char *attributes_name[attributes_num]; std::string op_name(); diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 01ba59f79c4e2..8011054b20ac7 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1149,20 +1149,20 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if ( op_info.backward_name and op_info.op_phi_name[0] not in vjp_interface_black_list - and dialect_name != "pd_onednn_op" + and dialect_name != "onednn_op" ): op_interfaces += ["paddle::dialect::VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str( op_info, op_info_items ) - if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": + if dialect_name == "pd_op" or dialect_name == "onednn_op": op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"] # if op has custom vjp rule, then append a CustomVjpTrait to it if ( op_info.op_phi_name[0] in custom_vjp_op_name_list - and dialect_name != "pd_onednn_op" + and dialect_name != "onednn_op" ): op_traits += ["paddle::dialect::CustomVjpTrait"] @@ -1184,7 +1184,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if op_name[-1] == "_": op_traits += ["paddle::dialect::InplaceTrait"] - if dialect_name == "pd_onednn_op": + if dialect_name == "onednn_op": op_traits += ["paddle::dialect::OneDNNTrait"] if op_info.is_onednn_only: @@ -1208,7 +1208,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if ( op_name in decomp_interface_declare_gen_op_list and kernel_func_name in decomp_interface_declare_gen_op_list - and dialect_name != "pd_onednn_op" + and dialect_name != "onednn_op" ): op_interfaces = op_interfaces + [ "paddle::dialect::DecompInterface" @@ -1272,7 +1272,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): build_func_with_muta_attr_is_input = "" get_kernel_type_for_var_declare_str = "" - if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": + if dialect_name == "pd_op" or dialect_name == "onednn_op": get_kernel_type_for_var_declare_str = ( get_kernel_type_for_var_declare_template ) @@ -1607,7 +1607,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): origin_op_name=op_info.op_yaml_item['name'], ) - if dialect_name == "pd_onednn_op": + if dialect_name == "onednn_op": if len(op_info.onednn_extra_args) > 0: args_name = [] for arg in op_info.onednn_extra_args: @@ -1698,7 +1698,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): # generate op GetKernelKeyForVar function str op_get_kernel_type_for_var_str = '' - if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": + if dialect_name == "pd_op" or dialect_name == "onednn_op": op_get_kernel_type_for_var_str = ( gen_kernel_type_for_var_str( op_class_name, @@ -1727,7 +1727,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): op_info.backward_name and op_info.op_phi_name[0] not in vjp_interface_black_list - and dialect_name != "pd_onednn_op" + and dialect_name != "onednn_op" ): op_vjp_str = gen_op_vjp_str( op_class_name, @@ -1758,7 +1758,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): ops_defined_list.append(infer_symbolic_shape_define_str) # NOTE(chenxi67)skip if dialect_name==cinn - if dialect_name == "cinn" or dialect_name == "pd_onednn_op": + if dialect_name == "cinn" or dialect_name == "onednn_op": pass else: ops_vjp_defined_list.append(op_vjp_str) @@ -1855,7 +1855,7 @@ def OpGenerator( # (2) parse yaml files op_compat_parser = OpCompatParser(op_compat_yaml_file) - if dialect_name == "pd_onednn_op": + if dialect_name == "onednn_op": with open(ops_onednn_extra_yaml_file, "r") as f: ops_onednn_extra = yaml.safe_load(f) ops_onednn_extra_map = {} @@ -1890,7 +1890,7 @@ def OpGenerator( op_info_items = {} for op in op_yaml_items: op_compat_item = None - if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": + if dialect_name == "pd_op" or dialect_name == "onednn_op": op_compat_item = op_compat_parser.get_compat(op['name']) if ( @@ -1916,7 +1916,7 @@ def OpGenerator( ) = op_compat_parser.parse_support_tensor(op) op_compat_item['scalar'] = scalar_item op_compat_item['int_array'] = int_array_item - if dialect_name == "pd_onednn_op": + if dialect_name == "onednn_op": if first_file: first_file = False op["is_onednn_only"] = True @@ -1934,7 +1934,7 @@ def OpGenerator( all_op_info_items[op['name']] = item op_infos.append(op_info_items) - if dialect_name == "pd_onednn_op": + if dialect_name == "onednn_op": op_infos = [all_op_info_items] # (3) auto code gen @@ -2047,7 +2047,7 @@ def OpGenerator( namespace=name, input=source_file_str ) # Add namespaces - if dialect_name == "pd_onednn_op": + if dialect_name == "onednn_op": op_def_h_file_tmp = ( "paddle/fluid/pir/dialect/operator/ir/pd_op.h\"\n#include \"" + op_def_h_file @@ -2070,7 +2070,7 @@ def OpGenerator( vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(input=vjp_source_file_str) if ( dialect_name != 'cinn' - and dialect_name != 'pd_onednn_op' + and dialect_name != 'onednn_op' and op_vjp_cc_file ): with open(op_vjp_cc_file, 'w') as f: diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc index b2d817b506199..28c8533da2efe 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc @@ -27,7 +27,7 @@ #include "paddle/pir/dialect/control_flow/ir/cf_op.h" #ifdef PADDLE_WITH_DNNL -#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" #endif namespace paddle { diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h index ac6483d4d53ec..608f4733fe104 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h @@ -23,7 +23,7 @@ class OneDNNOperatorDialect : public pir::Dialect { public: explicit OneDNNOperatorDialect(pir::IrContext* context); - static const char* name() { return "pd_onednn_op"; } + static const char* name() { return "onednn_op"; } pir::Type ParseType(pir::IrParser& parser) override; // NOLINT pir::Attribute ParseAttribute(pir::IrParser& parser) override; // NOLINT diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index eff7fecb844fb..06084b36b8c16 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -28,7 +28,7 @@ #include "paddle/utils/string/string_helper.h" #ifdef PADDLE_WITH_DNNL -#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" #endif namespace paddle { diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 73e26ca1b1b09..fd154a6d20362 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -46,8 +46,8 @@ #include "paddle/utils/flags.h" #ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" #include "paddle/fluid/pir/dialect/operator/trait/onednn.h" #endif @@ -2219,7 +2219,7 @@ void ProcessBlock( } } std::string target_op_name = op_item->name(); - target_op_name.replace(0, 12, "pd_op"); + target_op_name.replace(0, 9, "pd_op"); auto op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { IR_THROW("Ctx should have corresponding OpInfo %s", target_op_name);