-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add reduce_all and reduce_any #49765
Changes from 5 commits
955b0ab
409f3c8
0a4650c
f79cc0d
4d5e89e
e7ef35b
5257b7b
5475262
8ee53a9
8662195
47e3822
8094c09
3e4745d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,15 +65,24 @@ def generate_input1(dtype, attrs: List[Dict[str, Any]]): | |
[3, 4, 5], | ||
]: | ||
for reduce_all in [True, False]: | ||
for out_dtype in [-1, 2, 5]: | ||
for op_type in [ | ||
"reduce_max", | ||
"reduce_min", | ||
"reduce_mean", | ||
"reduce_sum", | ||
"reduce_prod", | ||
]: | ||
dics1 = [ | ||
for out_dtype in [-1, 0, 2, 5]: | ||
if out_dtype != 0: | ||
reduce_type_list = [ | ||
"reduce_max", | ||
"reduce_min", | ||
"reduce_mean", | ||
"reduce_sum", | ||
"reduce_prod", | ||
] | ||
else: | ||
reduce_type_list = [ | ||
"reduce_all", | ||
"reduce_any", | ||
] | ||
|
||
for op_type in reduce_type_list: | ||
|
||
dics = [ | ||
{ | ||
"keep_dim": keep_dim, | ||
"dim": dim, | ||
|
@@ -83,46 +92,35 @@ def generate_input1(dtype, attrs: List[Dict[str, Any]]): | |
}, | ||
{}, | ||
] | ||
dics2 = [ | ||
ops_config = [ | ||
{ | ||
"keep_dim": keep_dim, | ||
"dim": dim, | ||
"reduce_all": reduce_all, | ||
Comment on lines
-88
to
-90
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里不能去掉,会报错 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
ci里面有一个验证(op_teller)是 getAttr 说要没有这些的时候,我想过这个ci,不知道那个是怎么回事。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 现在的覆盖率也是很差,不懂为什么 reduce_any 整个不会进去 |
||
"out_dtype": out_dtype, | ||
"in_dtype": out_dtype, | ||
}, | ||
{}, | ||
"op_type": op_type, | ||
"op_inputs": {"X": ["input_data"]}, | ||
"op_outputs": { | ||
"Out": ["reduce_output_data"] | ||
}, | ||
"op_attrs": dics[0], | ||
} | ||
] | ||
for dics in [dics1, dics2]: | ||
ops_config = [ | ||
{ | ||
"op_type": op_type, | ||
"op_inputs": {"X": ["input_data"]}, | ||
"op_outputs": { | ||
"Out": ["reduce_output_data"] | ||
}, | ||
"op_attrs": dics[0], | ||
} | ||
] | ||
ops = self.generate_op_config(ops_config) | ||
|
||
program_config = ProgramConfig( | ||
ops=ops, | ||
weights={}, | ||
inputs={ | ||
"input_data": TensorConfig( | ||
data_gen=partial( | ||
generate_input1, out_dtype, dics | ||
) | ||
ops = self.generate_op_config(ops_config) | ||
|
||
program_config = ProgramConfig( | ||
ops=ops, | ||
weights={}, | ||
inputs={ | ||
"input_data": TensorConfig( | ||
data_gen=partial( | ||
generate_input1, out_dtype, dics | ||
) | ||
}, | ||
outputs=["reduce_output_data"], | ||
) | ||
) | ||
}, | ||
outputs=["reduce_output_data"], | ||
) | ||
Comment on lines
+124
to
+125
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 输出为bool类型时,需添加 |
||
|
||
if not self.is_program_valid(program_config): | ||
continue | ||
if not self.is_program_valid(program_config): | ||
continue | ||
|
||
yield program_config | ||
yield program_config | ||
|
||
def sample_predictor_configs( | ||
self, program_config | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generate_input1 也需要添加bool支持