From d025ef6620b131f3487bb748866ddd9d7225c09f Mon Sep 17 00:00:00 2001 From: "Wang, Zhitao" Date: Thu, 26 Oct 2023 07:20:49 +0000 Subject: [PATCH] src: graph: backend: dnnl: patterns: add input num check for avgpool bwd --- .../backend/dnnl/patterns/single_op_pattern.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/graph/backend/dnnl/patterns/single_op_pattern.cpp b/src/graph/backend/dnnl/patterns/single_op_pattern.cpp index 10013331f34..093f32e4abc 100644 --- a/src/graph/backend/dnnl/patterns/single_op_pattern.cpp +++ b/src/graph/backend/dnnl/patterns/single_op_pattern.cpp @@ -49,10 +49,23 @@ DNNL_BACKEND_REGISTER_PATTERN_DEF_BEGIN(single_op_pass) // register passes with dnnl backend support DNNL_BACKEND_SINGLE_OP_TRANSFORM(abs_pass, Abs, float_eltwise_fwd) DNNL_BACKEND_SINGLE_OP_TRANSFORM(abs_bw_pass, AbsBackward, eltwise_bwd_t) -DNNL_BACKEND_SINGLE_OP_TRANSFORM( - avg_pool_bw_pass, AvgPoolBackward, pooling_bwd_t) DNNL_BACKEND_SINGLE_OP_TRANSFORM(bias_add_pass, BiasAdd, binary_t) +DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, avg_pool_bw_pass) + .set_priority(8.f) + .set_kind(partition_kind_t::misc_post_ops) + .set_attr("FCreatePattern", + [](const std::shared_ptr &pgraph) -> void { + graph::utils::pm::pb_op_t *p_avg_pool_backward + = pgraph->append_op( + graph::op_kind::AvgPoolBackward); + p_avg_pool_backward->append_decision_function( + check_input_num<1>); + }) + .set_attr("FCreateKernel", []() -> kernel_ptr { + return std::make_shared(); + }); + DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, bn_pass) .set_priority(8.f) .set_kind(partition_kind_t::misc_post_ops)