Skip to content

Commit

Permalink
[PHI decoupling] remove framework/data_type.h from phi (#47776)
Browse files Browse the repository at this point in the history
* remove framework/data_type.h from phi

* fix CI fail: map proto::VarType to phi::DataType

* refactor code to add more detailed comments
  • Loading branch information
GreatV authored Nov 9, 2022
1 parent 7e91438 commit 1631836
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 32 deletions.
18 changes: 18 additions & 0 deletions paddle/phi/core/utils/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#pragma once
#include <iostream>
#include <map>
#include <string>
#include <typeindex>

Expand All @@ -23,6 +24,23 @@ limitations under the License. */

namespace phi {

// Here we can't depend on the fluid proto::VarType, so we use the dtype enum
// value directly. See also `assign_value_sig.cc`.
// proto::VarType::INT16 -> 1 -> phi::DataType::INT16
// proto::VarType::INT32 -> 2 -> phi::DataType::INT32
// proto::VarType::INT64 -> 3 -> phi::DataType::INT64
// proto::VarType::FP16 -> 4 -> phi::DataType::FLOAT16
// proto::VarType::FP32 -> 5 -> phi::DataType::FLOAT32
// proto::VarType::FP64 -> 6 -> phi::DataType::FLOAT64
// proto::VarType::UINT8 -> 20 -> phi::DataType::UINT8
// NOTE(review): `static` at namespace scope in a header (`#pragma once` above)
// gives every translation unit that includes this file its own copy of the map
// (internal linkage), each with its own dynamic initializer. Callers index it
// with non-const operator[] (e.g. var_type_map[dtype]), so it cannot simply be
// made const; consider a C++17 `inline` variable or an accessor function if
// project language level permits — TODO confirm with maintainers.
static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
{2, phi::DataType::INT32},
{3, phi::DataType::INT64},
{4, phi::DataType::FLOAT16},
{5, phi::DataType::FLOAT32},
{6, phi::DataType::FLOAT64},
{20, phi::DataType::UINT8}};

#define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
callback(cpp_type, data_type);

Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/kernels/cpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

Expand Down Expand Up @@ -141,15 +142,14 @@ void ArgMinMaxKernel(const Context& dev_ctx,
int dtype,
DenseTensor* out) {
if (dtype < 0) {
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
phi::VisitDataTypeTiny(
phi::DataType::INT64,
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
phi::VisitDataTypeTiny(
var_type_map[dtype],
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/cumprod_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "paddle/phi/kernels/funcs/for_range.h"

// NOTE(@xiongkun): use of IsComplex<>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/utils/data_type.h"

namespace phi {
template <typename T, typename Context>
Expand Down Expand Up @@ -51,7 +51,7 @@ void CumprodGradKernel(const Context& dev_ctx,
const T* out_data_deal;
Allocator::AllocationPtr x_conj;
Allocator::AllocationPtr out_conj;
if (paddle::framework::IsComplex<T>::value) {
if (phi::IsComplexType(x.dtype())) {
x_conj = const_cast<Allocator&>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T));
auto* x_data_conj = reinterpret_cast<T*>(x_conj->ptr());
Expand Down
11 changes: 5 additions & 6 deletions paddle/phi/kernels/cpu/unique_consecutive_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/utils/data_type.h"

namespace phi {

Expand All @@ -33,8 +32,8 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* index,
DenseTensor* counts) {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype);
if (data_type == paddle::framework::proto::VarType::INT32) {
auto data_type = var_type_map[dtype];
if (data_type == phi::DataType::INT32) {
PADDLE_ENFORCE_LE(
x.numel(),
INT_MAX,
Expand All @@ -46,13 +45,13 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
}

if (axis.empty()) {
paddle::framework::VisitDataTypeTiny(
phi::VisitDataTypeTiny(
data_type,
UniqueConsecutiveFlattenedTensorFunctor<Context, T>(
dev_ctx, x, out, return_inverse, return_counts, index, counts));
} else {
int valid_axis = axis[0];
paddle::framework::VisitDataTypeTiny(
phi::VisitDataTypeTiny(
data_type,
UniqueConsecutiveDimFunctor<Context, T>(dev_ctx,
x,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/math_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ limitations under the License. */
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"
#include "unsupported/Eigen/CXX11/Tensor"
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/math_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ limitations under the License. */
#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/math_function_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License. */
#include <memory>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {
Expand Down
12 changes: 5 additions & 7 deletions paddle/phi/kernels/gpu/arg_min_max_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ namespace cub = hipcub;
#endif
#include <limits>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/ddim.h"

#include "paddle/phi/core/utils/data_type.h"
namespace phi {

namespace { // NOLINT
Expand Down Expand Up @@ -209,15 +208,14 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
int dtype,
DenseTensor* out) {
if (dtype < 0) {
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
phi::VisitDataTypeTiny(
phi::DataType::INT64,
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
phi::VisitDataTypeTiny(
var_type_map[dtype],
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/gpu/cumprod_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// NOTE(@xiongkun): use of IsComplex<>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/utils/data_type.h"

namespace phi {

Expand Down Expand Up @@ -152,7 +152,7 @@ void CumprodGradKernel(const Context &dev_ctx,
const T *y_data_deal;
Allocator::AllocationPtr x_conj;
Allocator::AllocationPtr y_conj;
if (paddle::framework::IsComplex<T>::value) {
if (phi::IsComplexType(x.dtype())) {
x_conj = const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T));
auto *x_data_conj = reinterpret_cast<T *>(x_conj->ptr());
Expand Down
10 changes: 4 additions & 6 deletions paddle/phi/kernels/gpu/unique_consecutive_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/fluid/framework/data_type.h"

namespace phi {

template <typename T, typename Context>
Expand All @@ -35,8 +33,8 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* index,
DenseTensor* counts) {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype);
if (data_type == paddle::framework::proto::VarType::INT32) {
auto data_type = var_type_map[dtype];
if (data_type == phi::DataType::INT32) {
PADDLE_ENFORCE_LE(
x.numel() + 1,
INT_MAX,
Expand All @@ -49,14 +47,14 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,

// if 'axis' is not required, flatten the Tensor.
if (axis.empty()) {
paddle::framework::VisitDataTypeTiny(
phi::VisitDataTypeTiny(
data_type,
UniqueConsecutiveFlattenedCUDAFunctor<Context, T>(
dev_ctx, x, out, return_inverse, return_counts, index, counts));
} else {
// 'axis' is required.
int valid_axis = axis[0];
paddle::framework::VisitDataTypeTiny(
phi::VisitDataTypeTiny(
data_type,
UniqueConsecutiveDimsCUDAFunctor<Context, T>(dev_ctx,
x,
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/kernels/impl/isclose_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include <cmath>
#include <string>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"

namespace phi {

Expand Down

0 comments on commit 1631836

Please sign in to comment.