Skip to content

Commit

Permalink
Add API and unit test for reshape (#37232)
Browse files Browse the repository at this point in the history
* reshape kernel refactor

* fix compile bugs when run ci

* support xpu for reshape

* fix bugs when run unittest in kunlun ci

* fix compile bugs when run kunlun

* perfect code according to suggestion

* add api and unit test for reshape
  • Loading branch information
YuanRisheng authored Nov 16, 2021
1 parent 6ebc318 commit 79b49c2
Show file tree
Hide file tree
Showing 19 changed files with 218 additions and 27 deletions.
11 changes: 8 additions & 3 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1884,9 +1884,14 @@ void OperatorWithKernel::BuildPtenKernelContext(
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
pt_kernel_context_->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr));
std::type_index(typeid(std::vector<int64_t>)) &&
std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Pten_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
pt_kernel_context_->EmplaceBackAttr(vector_int64_attr);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
Expand Down
10 changes: 8 additions & 2 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,14 @@ static void BuildDygraphPtenKernelContext(
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
std::type_index(typeid(std::vector<int64_t>)) &&
std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Pten_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
kernel_ctx->EmplaceBackAttr(vector_int64_attr);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,9 @@ class ReshapeKernel {
}
#endif
} else {
auto &shape_vec = ctx.Attr<std::vector<int>>("shape");
auto &shape_attr = ctx.Attr<std::vector<int>>("shape");
const std::vector<int64_t> shape_vec(shape_attr.begin(),
shape_attr.end());
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
Expand Down
1 change: 1 addition & 0 deletions paddle/pten/api/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ namespace experimental {

PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis);

PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape);
} // namespace experimental
} // namespace paddle
34 changes: 34 additions & 0 deletions paddle/pten/api/lib/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,40 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) {

return out;
}

// Eager-style C++ API for reshape: runs the pten "reshape2" kernel on `x`
// and returns a new Tensor whose meta is inferred from `shape`.
// `shape` follows reshape semantics as implemented by InferShapeFromVecValue
// (e.g. the visible validation lives in pten::ValidateShape).
PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape) {
  // 1. Get kernel signature and kernel
  // NOTE: GetHigestPriorityKernelKey is the (misspelled) upstream API name;
  // it picks the kernel key (backend/layout/dtype) derived from `x`.
  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
      "reshape2", kernel_key);

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = pten::KernelContext(dev_ctx);

  // 3. Auto data transform
  // NOTE(review): dynamic_pointer_cast yields nullptr if x.impl() is not a
  // DenseTensor; dense_x is dereferenced below without a check — presumably
  // callers only pass dense tensors here. TODO confirm.
  // The Emplace* calls below must stay in input/attr order: the kernel reads
  // its arguments positionally from the KernelContext.
  auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
  kernel_context.EmplaceBackInput(dense_x);
  kernel_context.EmplaceBackAttr(shape);

  // 4. InferShape
  auto out_meta = InferShapeFromVecValue(dense_x->meta(), shape);

  // 5. Prepare outputs
  // Output storage is allocated on the place implied by the selected backend,
  // then registered both as the kernel's output and as the returned Tensor's
  // implementation (they share the same DenseTensor).
  Tensor out;
  const auto allocator = std::make_shared<DefaultAllocator>(
      pten::TransToFluidPlace(kernel_key.backend()));
  auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
  kernel_context.EmplaceBackOutput(dense_out);
  out.set_impl(dense_out);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

} // namespace experimental
} // namespace paddle

Expand Down
1 change: 0 additions & 1 deletion paddle/pten/core/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);

/* Output Helpers */

Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ DenseTensor Flatten(const ContextT& dev_ctx,
template <typename T, typename ContextT>
DenseTensor Reshape(const ContextT& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape) {
const std::vector<int64_t>& shape) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
Expand Down
4 changes: 2 additions & 2 deletions paddle/pten/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
}

static paddle::framework::DDim ValidateShape(
const std::vector<int> shape, const paddle::framework::DDim& in_dims) {
const std::vector<int64_t> shape, const paddle::framework::DDim& in_dims) {
const int64_t in_size = paddle::framework::product(in_dims);
auto in_dims_vec = paddle::framework::vectorize(in_dims);
bool all_positive = std::all_of(in_dims_vec.cbegin(),
Expand Down Expand Up @@ -203,7 +203,7 @@ static paddle::framework::DDim ValidateShape(
}

DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int>& shape) {
const std::vector<int64_t>& shape) {
PADDLE_ENFORCE_EQ(!shape.empty(),
true,
paddle::platform::errors::InvalidArgument(
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
DataLayout layout);

DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int>& shape);
const std::vector<int64_t>& shape);
} // namespace pten
10 changes: 6 additions & 4 deletions paddle/pten/kernels/cpu/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void FlattenWithXShape(const CPUContext& dev_ctx,

void ReshapeFromVectorVal(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) {
Expand All @@ -59,7 +59,7 @@ void ReshapeFromVectorVal(const CPUContext& dev_ctx,

void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* xshape,
DenseTensor* out) {
ReshapeFromVectorVal(dev_ctx, x, shape, out);
Expand All @@ -71,8 +71,10 @@ void ReshapeFromDT(const CPUContext& dev_ctx,
const DenseTensor& shape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
out->set_lod(x.lod());
}

void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
Expand All @@ -88,7 +90,7 @@ void ReshapeFromVectorDT(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int> vector_shape;
std::vector<int64_t> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
Expand Down
4 changes: 2 additions & 2 deletions paddle/pten/kernels/cpu/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void ReshapeFromDT(const CPUContext& dev_ctx,

void ReshapeFromVectorVal(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* out);

void ReshapeFromVectorDT(const CPUContext& dev_ctx,
Expand All @@ -52,7 +52,7 @@ void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,

void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* xshape,
DenseTensor* out);

Expand Down
10 changes: 6 additions & 4 deletions paddle/pten/kernels/cuda/manipulation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void FlattenWithXShape(const CUDAContext& dev_ctx,

void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) {
Expand All @@ -60,7 +60,7 @@ void ReshapeFromVectorVal(const CUDAContext& dev_ctx,

void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* xshape,
DenseTensor* out) {
ReshapeFromVectorVal(dev_ctx, x, shape, out);
Expand All @@ -72,8 +72,10 @@ void ReshapeFromDT(const CUDAContext& dev_ctx,
const DenseTensor& shape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
out->set_lod(x.lod());
}

void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
Expand All @@ -89,7 +91,7 @@ void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int> vector_shape;
std::vector<int64_t> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
Expand Down
4 changes: 2 additions & 2 deletions paddle/pten/kernels/cuda/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void ReshapeFromDT(const CUDAContext& dev_ctx,

void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* out);

void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
Expand All @@ -56,7 +56,7 @@ void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,

void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* xshape,
DenseTensor* out);

Expand Down
7 changes: 4 additions & 3 deletions paddle/pten/kernels/xpu/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void FlattenWithXShape(const XPUContext& dev_ctx,

void ReshapeFromVectorVal(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) {
Expand All @@ -69,15 +69,16 @@ void ReshapeFromDT(const XPUContext& dev_ctx,
const DenseTensor& shape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}

void ReshapeFromVectorDT(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int> vector_shape;
std::vector<int64_t> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/kernels/xpu/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void ReshapeFromDT(const XPUContext& dev_ctx,

void ReshapeFromVectorVal(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
const std::vector<int64_t>& shape,
DenseTensor* out);

void ReshapeFromVectorDT(const XPUContext& dev_ctx,
Expand Down
1 change: 1 addition & 0 deletions paddle/pten/tests/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_a
cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils)
70 changes: 70 additions & 0 deletions paddle/pten/tests/api/test_reshape_api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include <memory>

#include "paddle/pten/api/include/manipulation.h"

#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"

PT_DECLARE_MODULE(ManipulationCPU);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif

namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;

// TODO(chenweihang): Remove this test after the API is used in the dygraph
// TODO(chenweihang): Remove this test after the API is used in the dygraph
TEST(API, reshape) {
  // 1. Build a 3x2x2x3 float32 CPU tensor filled with 0..35.
  const auto cpu_alloc =
      std::make_shared<paddle::experimental::DefaultAllocator>(
          paddle::platform::CPUPlace());
  auto dense_input = std::make_shared<pten::DenseTensor>(
      cpu_alloc,
      pten::DenseTensorMeta(pten::DataType::FLOAT32,
                            framework::make_ddim({3, 2, 2, 3}),
                            pten::DataLayout::NCHW));
  auto* input_buf = dense_input->mutable_data<float>();

  for (int i = 0; i < dense_input->numel(); i++) {
    input_buf[i] = i;
  }

  paddle::experimental::Tensor x(dense_input);
  const std::vector<int64_t> target_shape{12, 3};

  // 2. Run the API under test.
  auto out = paddle::experimental::reshape(x, target_shape);

  // 3. Check metadata: shape, element count, place, dtype, layout, init state.
  ASSERT_EQ(out.shape()[0], target_shape[0]);
  ASSERT_EQ(out.shape()[1], target_shape[1]);
  ASSERT_EQ(out.numel(), 36);
  ASSERT_EQ(out.is_cpu(), true);
  ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
  ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
  ASSERT_EQ(out.initialized(), true);

  // 4. Check element-wise that reshape did not alter the underlying data.
  auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
  auto* output_buf = dense_out->data<float>();
  bool data_matches = true;
  for (int i = 0; i < dense_input->numel(); i++) {
    if (std::abs(input_buf[i] - output_buf[i]) > 1e-6f) {
      data_matches = false;
    }
  }
  ASSERT_EQ(data_matches, true);
}
1 change: 1 addition & 0 deletions paddle/pten/tests/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_uti
cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils)
Loading

0 comments on commit 79b49c2

Please sign in to comment.