-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add output defs for sgd kernel #51332
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
Could you give some comments? @From00 |
@@ -187,7 +187,9 @@ PD_REGISTER_KERNEL(sgd, | |||
phi::SGDDenseKernel, | |||
phi::dtype::float16, | |||
float, | |||
double) {} | |||
double) { | |||
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
master_param_out
注册UNDEFINED
后,你需要在SgdInferMeta
中对其数据类型进行推导。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
补充了,请帮看下正确性
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
根据
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
你需要在InferMeta中实现和phi::dtype::MPTypeTrait<T>::Type
等价的类型推导逻辑。请查看phi::dtype::MPTypeTrait<T>::Type
,根据模板类型推断实现InferMeta中的等价判断逻辑。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
根据
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;你需要在InferMeta中实现和
phi::dtype::MPTypeTrait<T>::Type
等价的类型推导逻辑。请查看phi::dtype::MPTypeTrait<T>::Type
,根据模板类型推断实现InferMeta中的等价判断逻辑。
谢谢回复,这里有点不明白,是可以抄这个模板推断还是用一串if else给出不同的类型时的结果。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用if-else实现,InferMeta中无法获取模板类型进行推断,你需要根据输入的dtype,判断出不同的输出结果。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改完了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Describe
Task 59 in #51292