Skip to content

Commit

Permalink
Add output defs for sgd kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyouzhi committed Mar 7, 2023
1 parent b76c2dc commit 26fdcde
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"select",
"send_recv",
"send_ue_recv",
"sgd",
"svd",
"sync_batch_norm_grad",
"unique",
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/sgd_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ PD_REGISTER_KERNEL(sgd,
phi::SGDDenseKernel,
phi::dtype::float16,
float,
double) {}
double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}

PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
GPU,
Expand Down

0 comments on commit 26fdcde

Please sign in to comment.