Skip to content

Commit

Permalink
fix (#60570)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 authored Jan 5, 2024
1 parent 09544f6 commit c3106c4
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions paddle/phi/kernels/gpu/arange_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ void ArangeNullaryKernel(const Context& dev_ctx,
const T end_value,
const T step_value,
DenseTensor* out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType start_value_mpt = static_cast<MPType>(start_value);
MPType end_value_mpt = static_cast<MPType>(end_value);
MPType step_value_mpt = static_cast<MPType>(step_value);
int64_t size = 0;
phi::funcs::GetSize(start_value, end_value, step_value, &size);
phi::funcs::GetSize(start_value_mpt, end_value_mpt, step_value_mpt, &size);
out->Resize(common::make_ddim({size}));
T* out_data = dev_ctx.template Alloc<T>(out);

Expand All @@ -77,7 +81,8 @@ void ArangeNullaryKernel(const Context& dev_ctx,
return;
}
int64_t grid = (size + block - 1) / block;
Range<T><<<grid, block, 0, stream>>>(start_value, step_value, size, out_data);
Range<MPType, T><<<grid, block, 0, stream>>>(
start_value_mpt, step_value_mpt, size, out_data);
}

template <typename T, typename Context>
Expand All @@ -86,11 +91,10 @@ void ArangeKernel(const Context& dev_ctx,
const Scalar& end,
const Scalar& step,
DenseTensor* out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType start_value = start.to<MPType>();
MPType end_value = end.to<MPType>();
MPType step_value = step.to<MPType>();
ArangeNullaryKernel<MPType, Context>(
T start_value = start.to<T>();
T end_value = end.to<T>();
T step_value = step.to<T>();
ArangeNullaryKernel<T, Context>(
dev_ctx, start_value, end_value, step_value, out);
}

Expand Down

0 comments on commit c3106c4

Please sign in to comment.