-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Paddle Inference] Add add eye trt converter #48937
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
// Declare inputs attr | ||
const int num_rows = PADDLE_GET_CONST(int, op_desc.GetAttr("num_rows")); | ||
int num_columns = PADDLE_GET_CONST(int, op_desc.GetAttr("num_columns")); | ||
const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以改成
auto dtype = static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")));
if (-1 == num_columns) { | ||
input_shape.d[1] = num_rows; | ||
num_columns = num_rows; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
num_columns为默认值-1时,num_columns=num_rows, 这段逻辑可以放在input_shape.d赋值前
} | ||
|
||
std::vector<T> constant_arr(num_rows * num_columns, 0); | ||
for (int i = 0; i < std::min(num_rows, num_columns); i++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::min(num_rows, num_columns) 放循环外避免多次调用std::min
nvinfer1::DataType nv_type = nvinfer1::DataType::kFLOAT; | ||
switch (dtype) { | ||
case paddle::framework::proto::VarType::FP32: | ||
nv_type = nvinfer1::DataType::kFLOAT; | ||
typedef float T; | ||
break; | ||
case paddle::framework::proto::VarType::FP16: | ||
nv_type = nvinfer1::DataType::kHALF; | ||
typedef uint16_t T; | ||
break; | ||
case paddle::framework::proto::VarType::INT32: | ||
nv_type = nvinfer1::DataType::kINT32; | ||
typedef int32_t T; | ||
break; | ||
default: | ||
paddle::platform::errors::InvalidArgument( | ||
"Paddle-TRT loads weighths failed, found not supported data type " | ||
"%s.", | ||
dtype); | ||
break; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里T类型声明编译报错,可以参考convert/fill_constant_op.cc 中写法,进行不同类型数组赋值
for _ in range(6): | ||
if np.random.random() > 0.5: | ||
num_rows = generate_input_attr1() | ||
attr_dic = {"num_rows": num_rows} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以加一个num_columns为-1。两处attr_dic中添加一个dtype字段,37-38行之间添加一层for循环用于dtype赋值,
for dtype in [2, 4, 5]
@zhangjun 打扰了,问一下这个 |
很抱歉,经过我们的反复讨论,你的PR暂未达到合入标准,请阅读飞桨原生算子开发规范,你可以重新提交新的PR,我们先将此PR关闭,感谢你的贡献。 |
PR types
Others
PR changes
Others
Describe
#48292