Add output defs for sgd kernel #51332

Merged 3 commits, Mar 9, 2023
@@ -115,7 +115,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
     "select",
     "send_recv",
     "send_ue_recv",
-    "sgd",
     "svd",
     "sync_batch_norm_grad",
     "unique",
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -2481,6 +2481,15 @@ void SgdInferMeta(const MetaTensor& param,

   param_out->set_dims(param.dims());
   param_out->set_dtype(param.dtype());
+  if (multi_precision) {
+    master_param_out->set_dims(master_param.dims());
+    if (DataType::FLOAT16 == master_param.dtype() ||
+        DataType::BFLOAT16 == master_param.dtype()) {
+      master_param_out->set_dtype(DataType::FLOAT32);
+    } else {
+      master_param_out->set_dtype(master_param.dtype());
+    }
+  }
 }

 void SendUERecvInferMeta(const MetaTensor& x,
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/sgd_kernel.cu
@@ -187,7 +187,9 @@ PD_REGISTER_KERNEL(sgd,
                    phi::SGDDenseKernel,
                    phi::dtype::float16,
                    float,
-                   double) {}
+                   double) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
Contributor

Now that master_param_out is registered as UNDEFINED, you need to infer its data type in SgdInferMeta.

Contributor Author

Added. Could you please check that it is correct?

Contributor

Based on

using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

you need to implement type inference logic in InferMeta that is equivalent to phi::dtype::MPTypeTrait<T>::Type. Please look at phi::dtype::MPTypeTrait<T>::Type and, following how it maps each template type, implement the equivalent check in InferMeta.
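
For readers unfamiliar with this kind of trait, here is a minimal sketch of how a compile-time "master type" mapping is commonly written. The `Half` and `BFloat16` structs and the `MPTypeTraitSketch` name are hypothetical stand-ins invented for illustration, not Paddle's actual definitions. The sketch assumes that 16-bit floating-point types map to `float` and that every other type maps to itself.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for 16-bit floating-point storage types.
struct Half { std::uint16_t bits; };
struct BFloat16 { std::uint16_t bits; };

// Primary template: by default, a type serves as its own master type.
template <typename T>
struct MPTypeTraitSketch {
  using Type = T;
};

// Specializations: low-precision types keep their master copy in float.
template <>
struct MPTypeTraitSketch<Half> {
  using Type = float;
};

template <>
struct MPTypeTraitSketch<BFloat16> {
  using Type = float;
};

// The mapping is resolved entirely at compile time from the template argument.
static_assert(std::is_same<MPTypeTraitSketch<Half>::Type, float>::value,
              "Half should map to float");
static_assert(std::is_same<MPTypeTraitSketch<double>::Type, double>::value,
              "double should map to itself");
```

Because the result depends on a template argument, a mapping like this is only usable where the element type `T` is known at compile time, such as inside a kernel template.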

Contributor Author

> Based on
>
> using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
>
> you need to implement type inference logic in InferMeta that is equivalent to phi::dtype::MPTypeTrait<T>::Type. Please look at phi::dtype::MPTypeTrait<T>::Type and, following how it maps each template type, implement the equivalent check in InferMeta.

Thanks for the reply. One thing I am unsure about: can I copy the template-based deduction directly, or should I use a chain of if/else that gives the result for each dtype?

Contributor

Use if-else. InferMeta has no access to the template type, so it cannot do template-based deduction; you need to determine the output dtype from the dtype of the input.
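
The if-else approach can be illustrated with a small runtime function. This is only a sketch under stated assumptions: the `DataTypeSketch` enum and the `InferMasterParamDtype` function are hypothetical names made up for illustration, and the rule it encodes (16-bit floating-point dtypes map to `FLOAT32`, all others pass through unchanged) is the one used in the `SgdInferMeta` change in this PR.

```cpp
// Hypothetical runtime dtype tag; a stand-in for a framework's dtype enum.
enum class DataTypeSketch { FLOAT16, BFLOAT16, FLOAT32, FLOAT64 };

// Runtime counterpart of a compile-time type trait. No template parameter is
// available here, so the result is chosen by branching on the dtype value.
inline DataTypeSketch InferMasterParamDtype(DataTypeSketch input) {
  if (input == DataTypeSketch::FLOAT16 ||
      input == DataTypeSketch::BFLOAT16) {
    // Low-precision inputs get a float32 master copy.
    return DataTypeSketch::FLOAT32;
  } else {
    // Every other dtype is kept as-is.
    return input;
  }
}
```

For example, `InferMasterParamDtype(DataTypeSketch::FLOAT16)` yields `DataTypeSketch::FLOAT32`, while `DataTypeSketch::FLOAT64` is returned unchanged.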

Contributor Author

Done.

+}

 PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
                    GPU,