[Inference] inference add cinn interface (#48741)
jiweibo authored Dec 8, 2022
1 parent 379216a commit 3a387df
Showing 8 changed files with 104 additions and 16 deletions.
28 changes: 24 additions & 4 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
@@ -484,7 +484,8 @@ void AnalyseClusterVariables(
const std::unordered_set<std::string>& deny_var_set,
GraphNodeSet* cluster_inputs,
GraphNodeSet* cluster_outputs,
GraphNodeSet* cluster_internals) {
GraphNodeSet* cluster_internals,
bool is_inference_stage) {
// Collect all inputs and outputs of the cluster's ops
for (auto* op_node : cluster) {
const auto& op_name = op_node->Name();
@@ -523,6 +524,18 @@ void AnalyseClusterVariables(
for (auto* var_node : *cluster_internals) {
cluster_outputs->erase(var_node);
}

if (is_inference_stage) {
// If some outputs of an op are not used by any other operator, demote them
// from cluster outputs to internal vars, e.g. transpose2's XShape output.
auto outs = *cluster_outputs;
for (auto* node : outs) {
if (node->outputs.empty()) {
cluster_outputs->erase(node);
cluster_internals->insert(node);
}
}
}
}

void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
@@ -611,7 +624,7 @@ void ReplaceSubGraphWithCinnOpNode(
// Here we use SubgraphDetector to detect subgraphs whose op nodes are all
// supported by CINN, querying OpMapperRegistry to check whether each op
// node is supported.
void SearchAllSubgraphs(Graph* graph) {
void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
OpTransInfo trans_info;
@@ -671,7 +684,8 @@ void SearchAllSubgraphs(Graph* graph) {
deny_var_set,
&cluster_inputs,
&cluster_outputs,
&cluster_internals);
&cluster_internals,
is_inference_stage);

VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
@@ -698,7 +712,13 @@
}
} // namespace

void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); }
void BuildCinnPass::ApplyImpl(Graph* graph) const {
bool is_inference_stage{false};
if (Has("is_inference_stage")) {
is_inference_stage = Get<bool>("is_inference_stage");
}
SearchAllSubgraphs(graph, is_inference_stage);
}

} // namespace paddle2cinn
} // namespace framework
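To make the new inference-only pruning concrete, here is a minimal standalone sketch of the rule added above. PruneUnconsumedOutputs is a hypothetical helper written for illustration only, assuming GraphNodeSet is the std::unordered_set of ir::Node* that build_cinn_pass.cc operates on:

// Hypothetical helper mirroring the is_inference_stage branch above: a
// cluster output var with no consumer ops (e.g. transpose2's XShape) is
// demoted from a subgraph output to a subgraph-internal var.
void PruneUnconsumedOutputs(GraphNodeSet* cluster_outputs,
                            GraphNodeSet* cluster_internals) {
  GraphNodeSet outs = *cluster_outputs;  // iterate over a copy; we mutate the set
  for (auto* var_node : outs) {
    if (var_node->outputs.empty()) {  // no downstream op reads this var
      cluster_outputs->erase(var_node);
      cluster_internals->insert(var_node);
    }
  }
}

The step is gated on is_inference_stage presumably because during training such outputs may still be read by later stages (e.g. grad ops), whereas at inference nothing consumes them.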
3 changes: 3 additions & 0 deletions paddle/fluid/inference/analysis/argument.h
@@ -368,6 +368,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);

// cinn compiler related
DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);

private:
std::unordered_set<std::string> valid_fields_;
};
2 changes: 2 additions & 0 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -235,6 +235,8 @@ void IRPassManager::CreatePasses(Argument *argument,
new framework::ProgramDesc *(&argument->main_program()));
} else if (pass_name == "memory_optimize_pass") {
pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
} else if (pass_name == "build_cinn_pass") {
pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
}
if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 =
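For context, a minimal sketch of how the attribute set here travels: CreatePasses attaches it to the pass instance, and BuildCinnPass::ApplyImpl (first file in this commit) reads it back via Has()/Get<bool>(). Done standalone, with the same PassRegistry API that ir_pass_manager.cc uses, the wiring would look roughly like this (graph is assumed to be an existing ir::Graph*):

// Create the pass, attach the attribute, and run it on the graph.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
pass->Set("is_inference_stage", new bool(true));  // the pass takes ownership
pass->Apply(graph);  // ApplyImpl: Has("is_inference_stage") -> Get<bool>()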
31 changes: 30 additions & 1 deletion paddle/fluid/inference/api/analysis_config.cc
file mode changed 100755 → 100644
@@ -477,6 +477,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// profile related.
CP_MEMBER(with_profile_);

// cinn compiler related.
CP_MEMBER(use_cinn_compiler_);

// glog related.
CP_MEMBER(with_glog_info_);

@@ -542,7 +545,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#undef CP_MEMBER

Update();
if (use_tensorrt_) {
if (use_tensorrt_ || use_cinn_compiler_) {
// Update() will reset all the passes; any pass deleted via
// other.pass_builder() would be re-added, so we remove the deleted
// passes again here.
@@ -872,6 +875,14 @@ void AnalysisConfig::Update() {
}
}

// TODO(wilber): An ugly way to update the passes; needs to be fixed.
if (use_cinn_compiler_) {
pass_builder()->ClearPasses();
for (const auto &pass : kCINNCompilerPasses) {
pass_builder()->AppendPass(pass);
}
}

if (use_dlnne_) {
pass_builder()->ClearPasses();
for (const auto &pass : kDlnneSubgraphPasses) {
@@ -1316,6 +1327,9 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
}

// cinn compiler
os.InsertRow({"use_cinn_compiler", use_cinn_compiler_ ? "true" : "false"});

// ir info
os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
@@ -1429,4 +1443,19 @@ void AnalysisConfig::Exp_DisableMixedInferOps(
mixed_black_list_ = black_list;
}

void AnalysisConfig::Exp_EnableCINNCompiler() {
#ifdef PADDLE_WITH_CINN
use_cinn_compiler_ = true;
Update();
#else
PADDLE_THROW(platform::errors::Unavailable(
"You tried to use CINN compiler, but Paddle was not compiled "
"with CINN."));
#endif
}

bool AnalysisConfig::cinn_compiler_enabled() const {
return use_cinn_compiler_;
}

} // namespace paddle
30 changes: 19 additions & 11 deletions paddle/fluid/inference/api/analysis_predictor.cc
file mode changed 100755 → 100644
@@ -1217,6 +1217,10 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
}

if (config_.use_cinn_compiler_) {
argument_.SetUseCinnCompiler(config_.use_cinn_compiler_);
}

#ifdef PADDLE_WITH_MKLDNN
if (config_.mkldnn_quantizer_enabled()) {
LOG(INFO) << "Quantization is enabled";
@@ -1239,21 +1243,25 @@ void AnalysisPredictor::PrepareArgument() {
#endif

auto *pass_builder = config_.pass_builder();
// TODO(inference): Need to reconstruct the pass_builder, pass should be
// processed in a single
if (model_precision_ != phi::DataType::FLOAT32) {
LOG(INFO) << "Model is mixed precision type with " << model_precision_
<< ", we will use a new PassStrategy. Note that only the GPU "
"backend is supported for now.";
pass_builder->ClearPasses();
const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
if (config_.tensorrt_engine_enabled()) {
for (const auto &pass : kTrtLowerPrecisionPasses) {
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
} else if (config_.use_gpu()) {
for (const auto &pass : kGpuLowerPrecisionPasses) {
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
if (!config_.use_cinn_compiler_) {
pass_builder->ClearPasses();
const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
if (config_.tensorrt_engine_enabled()) {
for (const auto &pass : kTrtLowerPrecisionPasses) {
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
} else if (config_.use_gpu()) {
for (const auto &pass : kGpuLowerPrecisionPasses) {
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
}
}
}
16 changes: 16 additions & 0 deletions paddle/fluid/inference/api/paddle_analysis_config.h
@@ -1016,6 +1016,19 @@ struct PD_INFER_DECL AnalysisConfig {

void SetSkipLoadParams(bool value) { skip_load_params_ = value; }

///
/// \brief Enable the CINN compiler optimization.
///
void Exp_EnableCINNCompiler();

///
/// \brief A boolean state telling whether the CINN compiler optimization is
/// turned on.
///
/// \return bool Whether the CINN compiler optimization is turned on.
///
bool cinn_compiler_enabled() const;

protected:
// Update the config.
void Update();
@@ -1143,6 +1156,9 @@
Precision lite_precision_mode_;
bool lite_zero_copy_;

// CINN compiler related.
bool use_cinn_compiler_{false};

// XPU related.
bool use_xpu_{false};
int xpu_device_id_{0};
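With the declarations above, enabling CINN from the C++ inference API would look roughly like the sketch below, assuming a Paddle build with PADDLE_WITH_CINN; the model directory is a placeholder:

#include <cassert>
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder model path
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  config.Exp_EnableCINNCompiler();  // throws if built without CINN support
  assert(config.cinn_compiler_enabled());
  auto predictor = paddle_infer::CreatePredictor(config);
  // ... bind inputs, call predictor->Run(), fetch outputs as usual ...
  return 0;
}

paddle_infer::Config is an alias of paddle::AnalysisConfig, so the experimental method declared above is directly available on it.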
7 changes: 7 additions & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -204,6 +204,13 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
"tensorrt_subgraph_pass",
};

const std::vector<std::string> kCINNCompilerPasses{
"gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"gpu_cpu_map_matmul_to_mul_pass",
"build_cinn_pass",
};

GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
// "identity_scale_op_clean_pass", //
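Because kCINNCompilerPasses is exported (see the header change below), the lineup that Exp_EnableCINNCompiler() installs through Update() can also be assembled by hand; a sketch of the equivalent pass-builder calls on an AnalysisConfig:

// Equivalent to the use_cinn_compiler_ branch in AnalysisConfig::Update().
auto* builder = config.pass_builder();
builder->ClearPasses();
for (const auto& pass : paddle::kCINNCompilerPasses) {
  builder->AppendPass(pass);
}

The three gpu_cpu_map_matmul_* passes run before build_cinn_pass, presumably so matmul variants are first normalized into ops the CINN op mapper recognizes during subgraph clustering.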
3 changes: 3 additions & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.h
@@ -349,6 +349,9 @@ PD_INFER_DECL extern const std::vector<std::string> kDlnneSubgraphPasses;
/// \brief List of lite subgraph passes.
PD_INFER_DECL extern const std::vector<std::string> kLiteSubgraphPasses;

/// \brief List of CINN compiler passes.
PD_INFER_DECL extern const std::vector<std::string> kCINNCompilerPasses;

/// \brief TODO(inference): Most of the existing pass fusion operators do not
/// support fp16/bf16 precision, temporarily use low precision pass to prevent
/// running errors. After fusion operator supports low precision, delete this.
