Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add output defs for sgd kernel #51332

Merged
merged 3 commits into from
Mar 9, 2023
Merged

Conversation

jinyouzhi
Copy link
Contributor

PR types

Others

PR changes

Others

Describe

Task 59 in #51292

@paddle-bot
Copy link

paddle-bot bot commented Mar 7, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Mar 7, 2023
@jinyouzhi
Copy link
Contributor Author

Could you give some comments? @From00

@@ -187,7 +187,9 @@ PD_REGISTER_KERNEL(sgd,
phi::SGDDenseKernel,
phi::dtype::float16,
float,
double) {}
double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

master_param_out注册UNDEFINED后,你需要在SgdInferMeta中对其数据类型进行推导。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充了,请帮看下正确性

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据

using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

你需要在InferMeta中实现和phi::dtype::MPTypeTrait<T>::Type等价的类型推导逻辑。请查看phi::dtype::MPTypeTrait<T>::Type,根据模板类型推断实现InferMeta中的等价判断逻辑。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据

using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

你需要在InferMeta中实现和phi::dtype::MPTypeTrait<T>::Type等价的类型推导逻辑。请查看phi::dtype::MPTypeTrait<T>::Type,根据模板类型推断实现InferMeta中的等价判断逻辑。

谢谢回复,这里有点不明白,是可以抄这个模板推断还是用一串if else给出不同的类型时的结果。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用if-else实现,InferMeta中无法获取模板类型进行推断,你需要根据输入的dtype,判断出不同的输出结果。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改完了

@jinyouzhi jinyouzhi requested a review from From00 March 9, 2023 05:08
Copy link
Contributor

@From00 From00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@From00 From00 merged commit c0f84b8 into PaddlePaddle:develop Mar 9, 2023
@jinyouzhi jinyouzhi deleted the phi/sgd branch March 9, 2023 06:28
@From00 From00 mentioned this pull request Mar 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants