Skip to content

Commit

Permalink
src: graph: backend: dnnl: patterns: add input num check for avgpool bwd
Browse files Browse the repository at this point in the history
  • Loading branch information
wzt1997 authored and vpirogov committed Nov 9, 2023
1 parent 9e0602a commit d025ef6
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/graph/backend/dnnl/patterns/single_op_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,23 @@ DNNL_BACKEND_REGISTER_PATTERN_DEF_BEGIN(single_op_pass)
// register passes with dnnl backend support
DNNL_BACKEND_SINGLE_OP_TRANSFORM(abs_pass, Abs, float_eltwise_fwd)
DNNL_BACKEND_SINGLE_OP_TRANSFORM(abs_bw_pass, AbsBackward, eltwise_bwd_t)
DNNL_BACKEND_SINGLE_OP_TRANSFORM(
avg_pool_bw_pass, AvgPoolBackward, pooling_bwd_t)
DNNL_BACKEND_SINGLE_OP_TRANSFORM(bias_add_pass, BiasAdd, binary_t)

DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, avg_pool_bw_pass)
.set_priority(8.f)
.set_kind(partition_kind_t::misc_post_ops)
.set_attr<FCreatePattern>("FCreatePattern",
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
graph::utils::pm::pb_op_t *p_avg_pool_backward
= pgraph->append_op(
graph::op_kind::AvgPoolBackward);
p_avg_pool_backward->append_decision_function(
check_input_num<1>);
})
.set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr {
return std::make_shared<pooling_bwd_t>();
});

DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, bn_pass)
.set_priority(8.f)
.set_kind(partition_kind_t::misc_post_ops)
Expand Down

0 comments on commit d025ef6

Please sign in to comment.