From ab6d75fdac2d44fcc2a82b2302c71d856b069b60 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Wed, 13 Jul 2022 18:33:42 +0000 Subject: [PATCH 01/17] migrate lstsq op --- paddle/fluid/operators/lstsq_op.cc | 92 +----- paddle/fluid/operators/lstsq_op.cu | 313 ------------------ paddle/fluid/operators/lstsq_op.h | 345 -------------------- paddle/phi/api/yaml/legacy_api.yaml | 8 + paddle/phi/infermeta/binary.cc | 84 +++++ paddle/phi/infermeta/binary.h | 9 + paddle/phi/kernels/cpu/lstsq_kernel.cc | 296 +++++++++++++++++ paddle/phi/kernels/gpu/lstsq_kernel.cu | 162 +++++++++ paddle/phi/kernels/impl/lstsq_kernel_impl.h | 239 ++++++++++++++ paddle/phi/kernels/impl/qr_kernel_impl.h | 277 ++++++++++++++++ paddle/phi/kernels/lstsq_kernel.h | 36 ++ paddle/phi/kernels/reduce_sum_kernel.h | 2 +- python/paddle/tensor/linalg.py | 54 +-- 13 files changed, 1125 insertions(+), 792 deletions(-) delete mode 100644 paddle/fluid/operators/lstsq_op.cu delete mode 100644 paddle/fluid/operators/lstsq_op.h create mode 100644 paddle/phi/kernels/cpu/lstsq_kernel.cc create mode 100644 paddle/phi/kernels/gpu/lstsq_kernel.cu create mode 100644 paddle/phi/kernels/impl/lstsq_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/qr_kernel_impl.h create mode 100644 paddle/phi/kernels/lstsq_kernel.h diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc index 70ce5082ced304..90381c4227c95f 100644 --- a/paddle/fluid/operators/lstsq_op.cc +++ b/paddle/fluid/operators/lstsq_op.cc @@ -12,12 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/operators/lstsq_op.h" - -#include -#include - +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -25,79 +22,6 @@ namespace operators { class LstsqOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LstsqOp"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "LstsqOp"); - - OP_INOUT_CHECK(ctx->HasOutput("Solution"), "Output", "Solution", "LstsqOp"); - OP_INOUT_CHECK(ctx->HasOutput("Rank"), "Output", "Rank", "LstsqOp"); - OP_INOUT_CHECK(ctx->HasOutput("SingularValues"), - "Output", - "SingularValues", - "LstsqOp"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - int x_rank = x_dims.size(); - int y_rank = y_dims.size(); - - PADDLE_ENFORCE_GE(x_rank, - 2, - platform::errors::InvalidArgument( - "Expects input tensor x to be not less than " - "2 dimentions, but got dimention %d", - x_rank)); - PADDLE_ENFORCE_GE(y_rank, - 2, - platform::errors::InvalidArgument( - "Expects input tensor y to be not less than " - "2 dimentions, but got dimention %d", - y_rank)); - - PADDLE_ENFORCE_EQ( - x_rank, - y_rank, - platform::errors::InvalidArgument( - "Expects input tensor x and y to have the same dimension " - "but got x's dimention [%d] and y's dimention [%d]", - x_rank, - y_rank)); - - std::vector batch_dims_vec{}; - for (int i = 0; i < x_rank - 2; ++i) { - PADDLE_ENFORCE_EQ( - x_dims[i], - y_dims[i], - platform::errors::InvalidArgument( - "Expects input tensor x and y to have the same batch " - "dimension, but got x's batch dimention [%d] and " - "y's batch dimention [%d] in %d-th dim", - x_dims[i], - y_dims[i], - i)); - batch_dims_vec.emplace_back(x_dims[i]); - } - - PADDLE_ENFORCE_EQ( - x_dims[x_rank - 2], - 
y_dims[y_rank - 2], - platform::errors::InvalidArgument( - "Expects input tensor x and y to have the same row dimension " - "of the inner-most 2-dims matrix, " - "but got x's row dimention [%d] and y's row dimention [%d]", - x_dims[x_rank - 2], - y_dims[y_rank - 2])); - - ctx->SetOutputDim("Rank", phi::make_ddim(batch_dims_vec)); - - batch_dims_vec.emplace_back( - std::min(x_dims[x_rank - 2], x_dims[x_rank - 1])); - ctx->SetOutputDim("SingularValues", phi::make_ddim(batch_dims_vec)); - - batch_dims_vec[x_rank - 2] = x_dims[x_rank - 1]; - batch_dims_vec.emplace_back(y_dims[x_rank - 1]); - ctx->SetOutputDim("Solution", phi::make_ddim(batch_dims_vec)); - } protected: // The output of lstsq is always complex-valued even for real-valued inputs @@ -148,8 +72,12 @@ This API processes Lstsq functor for general matrices. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(lstsq, ops::LstsqOp, ops::LstsqOpMaker) -REGISTER_OP_CPU_KERNEL(lstsq, - ops::LstsqCPUKernel, - ops::LstsqCPUKernel); +DECLARE_INFER_SHAPE_FUNCTOR(lstsq, + LstsqInferShapeFunctor, + PD_INFER_META(phi::LstsqInferMeta)); + +REGISTER_OPERATOR(lstsq, + ops::LstsqOp, + ops::LstsqOpMaker, + LstsqInferShapeFunctor); diff --git a/paddle/fluid/operators/lstsq_op.cu b/paddle/fluid/operators/lstsq_op.cu deleted file mode 100644 index 82a56af7eb4f14..00000000000000 --- a/paddle/fluid/operators/lstsq_op.cu +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef PADDLE_WITH_HIP -// HIP not support cusolver - -#include -#include - -#include "paddle/fluid/framework/phi_utils.h" -#include "paddle/fluid/operators/lstsq_op.h" -#include "paddle/fluid/operators/qr_op.h" -#include "paddle/fluid/platform/dynload/cusolver.h" -#include "paddle/phi/kernels/triangular_solve_kernel.h" - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; - -template -class LstsqCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor& x = *context.Input("X"); - const Tensor& y = *context.Input("Y"); - auto* solution = context.Output("Solution"); - - auto dito = - math::DeviceIndependenceTensorOperations(context); - auto& dev_ctx = - context.template device_context(); - - auto x_dims = x.dims(); - auto y_dims = y.dims(); - int dim_size = x_dims.size(); - int m = x_dims[dim_size - 2]; - int n = x_dims[dim_size - 1]; - int nrhs = y_dims[dim_size - 1]; - int min_mn = std::min(m, n); - int max_mn = std::max(m, n); - int k = min_mn; - - int x_stride = MatrixStride(x); - int y_stride = MatrixStride(y); - int tau_stride = min_mn; - int batch_count = BatchCount(x); - - Tensor new_x, new_y; - new_x.mutable_data(context.GetPlace(), - size_t(batch_count * m * n * sizeof(T))); - new_y.mutable_data(context.GetPlace(), - size_t(batch_count * m * nrhs * sizeof(T))); - framework::TensorCopy(x, context.GetPlace(), &new_x); - framework::TensorCopy(y, context.GetPlace(), &new_y); - - // Prepare tau - auto tau_dims_vec = phi::vectorize(x_dims); - tau_dims_vec.pop_back(); - tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; - Tensor tau = dito.Fill(tau_dims_vec, 0); - auto tau_data = tau.mutable_data(context.GetPlace()); - - using Context = - typename framework::ConvertToPhiContext::TYPE; - auto& phi_dev_ctx = static_cast(dev_ctx); - - if (m >= n) { - 
Tensor tmp_x = dito.Transpose(new_x); - Tensor tmp_y = dito.Transpose(new_y); - auto x_data = tmp_x.mutable_data(context.GetPlace()); - auto y_data = tmp_y.mutable_data(context.GetPlace()); - - // step 1, compute QR factorization using geqrf - BatchedGeqrf(dev_ctx, - batch_count, - m, - n, - x_data, - m, - tau_data, - x_stride, - tau_stride); - - // Step 2, Y <- Q^H Y - BatchedOrmqr(dev_ctx, - true, - true, - batch_count, - m, - nrhs, - k, - x_data, - x_stride, - tau_data, - tau_stride, - y_data, - y_stride); - - Tensor trans_r = dito.Transpose(tmp_x); - Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn}); - Tensor res_r = dito.TrilTriu(slice_r, 0, false); - - Tensor trans_y = dito.Transpose(tmp_y); - Tensor slice_y = dito.Slice(trans_y, {-2}, {0}, {min_mn}); - - // Step 3, solve R X = Y - phi::TriangularSolveKernel( - phi_dev_ctx, res_r, slice_y, true, false, false, solution); - - } else { - auto x_data = new_x.mutable_data(context.GetPlace()); - auto y_data = new_y.mutable_data(context.GetPlace()); - - // step 1, compute QR factorization using geqrf - BatchedGeqrf(dev_ctx, - batch_count, - n, - m, - x_data, - n, - tau_data, - x_stride, - tau_stride); - - // Step 2, solve R^H Z = Y - Tensor trans_r = dito.Transpose(new_x); - Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn}); - Tensor res_r = dito.TrilTriu(slice_r, 0, false); - - phi::TriangularSolveKernel( - phi_dev_ctx, res_r, new_y, true, true, false, solution); - - // Step 3, X <- Q Z - BatchedOrgqr(dev_ctx, - batch_count, - n, - m, - min_mn, - x_data, - n, - tau_data, - x_stride, - tau_stride); - Tensor trans_q = dito.Transpose(new_x); - Tensor slice_q = dito.Slice(trans_q, {-1}, {0}, {m}); - Tensor solu_tensor = dito.Matmul(slice_q, *solution, false, false); - framework::TensorCopy(solu_tensor, solution->place(), solution); - } - } -}; - -template <> -void BatchedOrmqr( - const platform::CUDADeviceContext& dev_ctx, - bool left, - bool transpose, - int batch_size, - int m, - int n, - int k, - 
float* a, - int a_stride, - float* tau, - int tau_stride, - float* other, - int other_stride) { - int lwork = 0; - auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; - auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - int lda = std::max(1, left ? m : n); - int ldc = std::max(1, m); - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize( - handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); - auto info = memory::Alloc(dev_ctx, sizeof(int)); - int* info_d = reinterpret_cast(info->ptr()); - - for (int i = 0; i < batch_size; ++i) { - float* a_working_ptr = &a[i * a_stride]; - float* tau_working_ptr = &tau[i * tau_stride]; - float* other_working_ptr = &other[i * other_stride]; - - handle = dev_ctx.cusolver_dn_handle(); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); - float* workspace_ptr = reinterpret_cast(workspace->ptr()); - - // compute ormgr - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnSormqr(handle, - side, - trans, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - other_working_ptr, - ldc, - workspace_ptr, - lwork, - info_d)); - - // check the error info - int info_h; - memory::Copy(platform::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); - } -} - -template <> -void BatchedOrmqr( - const platform::CUDADeviceContext& dev_ctx, - bool left, - bool transpose, - int batch_size, - int m, - int n, - int k, - double* a, - int a_stride, - double* tau, - int tau_stride, - double* other, - int other_stride) { - int lwork = 0; - auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; - auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - int lda = std::max(1, left ? 
m : n); - int ldc = std::max(1, m); - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize( - handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); - auto info = memory::Alloc(dev_ctx, sizeof(int)); - int* info_d = reinterpret_cast(info->ptr()); - - for (int i = 0; i < batch_size; ++i) { - double* a_working_ptr = &a[i * a_stride]; - double* tau_working_ptr = &tau[i * tau_stride]; - double* other_working_ptr = &other[i * other_stride]; - - handle = dev_ctx.cusolver_dn_handle(); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); - double* workspace_ptr = reinterpret_cast(workspace->ptr()); - - // compute ormgr - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDormqr(handle, - side, - trans, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - other_working_ptr, - ldc, - workspace_ptr, - lwork, - info_d)); - - // check the error info - int info_h; - memory::Copy(platform::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); - } -} - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - lstsq, - ops::LstsqCUDAKernel, - ops::LstsqCUDAKernel); - -#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h deleted file mode 100644 index f99e027e9ced2f..00000000000000 --- a/paddle/fluid/operators/lstsq_op.h +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include -#include - -#include "paddle/fluid/operators/eig_op.h" -#include "paddle/fluid/operators/math/eigen_values_vectors.h" -#include "paddle/fluid/operators/math/matrix_solve.h" -#include "paddle/fluid/operators/svd_helper.h" -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -#define EPSILON 1e-6 - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; -enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss }; - -using DDim = framework::DDim; -static DDim UDDim(const DDim& x_dim) { - auto x_vec = vectorize(x_dim); - return phi::make_ddim(x_vec); -} - -template -class LstsqCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using ValueType = phi::dtype::Real; - - const Tensor& x = *context.Input("X"); - auto y = context.Input("Y"); - auto rcond = context.Attr("rcond"); - auto driver_string = context.Attr("driver"); - - static auto driver_type = std::unordered_map( - {{"gels", LapackDriverType::Gels}, - {"gelsy", LapackDriverType::Gelsy}, - {"gelsd", LapackDriverType::Gelsd}, - {"gelss", LapackDriverType::Gelss}}); - auto driver = driver_type[driver_string]; - - auto solution = context.Output("Solution"); - auto* rank = context.Output("Rank"); - auto* singular_values = 
context.Output("SingularValues"); - - auto dito = - math::DeviceIndependenceTensorOperations(context); - - auto x_dims = x.dims(); - auto y_dims = y->dims(); - int dim_size = x_dims.size(); - int x_stride = MatrixStride(x); - int y_stride = MatrixStride(*y); - int batch_count = BatchCount(x); - auto solution_dim = solution->dims(); - int ori_solu_stride = MatrixStride(*solution); - int max_solu_stride = std::max(y_stride, ori_solu_stride); - int min_solu_stride = std::min(y_stride, ori_solu_stride); - - // lapack is a column-major storge, transpose make the input to - // have a continuous memory layout - int info = 0; - int m = x_dims[dim_size - 2]; - int n = x_dims[dim_size - 1]; - int nrhs = y_dims[dim_size - 1]; - int lda = std::max(m, 1); - int ldb = std::max(1, std::max(m, n)); - - Tensor new_x; - new_x.mutable_data(context.GetPlace(), - size_t(batch_count * m * n * sizeof(T))); - framework::TensorCopy(x, context.GetPlace(), &new_x); - - solution->mutable_data( - context.GetPlace(), - size_t(batch_count * std::max(m, n) * nrhs * sizeof(T))); - - if (m >= n) { - const Tensor& new_y = *context.Input("Y"); - framework::TensorCopy(new_y, context.GetPlace(), solution); - } else { - auto* solu_data = solution->data(); - auto* y_data = y->data(); - for (auto i = 0; i < batch_count; i++) { - for (auto j = 0; j < min_solu_stride; j++) { - solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j]; - } - } - } - - Tensor input_x_trans = dito.Transpose(new_x); - Tensor input_y_trans = dito.Transpose(*solution); - framework::TensorCopy(input_x_trans, new_x.place(), &new_x); - framework::TensorCopy(input_y_trans, solution->place(), solution); - - auto* x_vector = new_x.data(); - auto* y_vector = solution->data(); - - // "gels" divers does not need to compute rank - int rank_32 = 0; - int* rank_data = nullptr; - int* rank_working_ptr = nullptr; - if (driver != LapackDriverType::Gels) { - rank_data = rank->mutable_data(context.GetPlace()); - rank_working_ptr = rank_data; 
- } - - // "gelsd" and "gelss" divers need to compute singular values - ValueType* s_data = nullptr; - ValueType* s_working_ptr = nullptr; - int s_stride = 0; - if (driver == LapackDriverType::Gelsd || - driver == LapackDriverType::Gelss) { - s_data = singular_values->mutable_data(context.GetPlace()); - s_working_ptr = s_data; - auto s_dims = singular_values->dims(); - s_stride = s_dims[s_dims.size() - 1]; - } - - // "jpvt" is only used for "gelsy" driver - Tensor jpvt; - int* jpvt_data = nullptr; - if (driver == LapackDriverType::Gelsy) { - jpvt.Resize(phi::make_ddim({std::max(1, n)})); - jpvt_data = jpvt.mutable_data(context.GetPlace()); - } - - // run once the driver, first to get the optimal workspace size - int lwork = -1; - T wkopt; - ValueType rwkopt; - int iwkopt = 0; - - if (driver == LapackDriverType::Gels) { - phi::funcs::lapackGels( - 'N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt, lwork, &info); - } else if (driver == LapackDriverType::Gelsd) { - phi::funcs::lapackGelsd(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &iwkopt, - &info); - } else if (driver == LapackDriverType::Gelsy) { - phi::funcs::lapackGelsy(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - jpvt_data, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &info); - } else if (driver == LapackDriverType::Gelss) { - phi::funcs::lapackGelss(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &info); - } - - lwork = std::max(1, static_cast(phi::dtype::Real(wkopt))); - Tensor work; - work.Resize(phi::make_ddim({lwork})); - T* work_data = work.mutable_data(context.GetPlace()); - - // "rwork" only used for complex inputs and "gelsy/gelsd/gelss" drivers - Tensor rwork; - ValueType* rwork_data = nullptr; - if (framework::IsComplexType(framework::TransToProtoVarType(x.dtype())) && - driver != 
LapackDriverType::Gels) { - int rwork_len = 0; - if (driver == LapackDriverType::Gelsy) { - rwork_len = std::max(1, 2 * n); - } else if (driver == LapackDriverType::Gelss) { - rwork_len = std::max(1, 5 * std::min(m, n)); - } else if (driver == LapackDriverType::Gelsd) { - rwork_len = std::max(1, rwkopt); - } - rwork.Resize(phi::make_ddim({rwork_len})); - rwork_data = rwork.mutable_data(context.GetPlace()); - } - - // "iwork" workspace array is relavant only for "gelsd" driver - Tensor iwork; - int* iwork_data = nullptr; - if (driver == LapackDriverType::Gelsd) { - iwork.Resize(phi::make_ddim({std::max(1, iwkopt)})); - iwork_data = iwork.mutable_data(context.GetPlace()); - } - - for (auto i = 0; i < batch_count; ++i) { - auto* x_input = &x_vector[i * x_stride]; - auto* y_input = &y_vector[i * max_solu_stride]; - rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr; - s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr; - - if (driver == LapackDriverType::Gels) { - phi::funcs::lapackGels('N', - m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - work_data, - lwork, - &info); - } else if (driver == LapackDriverType::Gelsd) { - phi::funcs::lapackGelsd(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - iwork_data, - &info); - } else if (driver == LapackDriverType::Gelsy) { - phi::funcs::lapackGelsy(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - jpvt_data, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - &info); - } else if (driver == LapackDriverType::Gelss) { - phi::funcs::lapackGelss(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - &info); - } - - PADDLE_ENFORCE_EQ( - info, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: Lapack info is not zero but [%d]", i, info)); - - if (rank_working_ptr) *rank_working_ptr = 
static_cast(rank_32); - } - - Tensor tmp_s = dito.Transpose(*solution); - framework::TensorCopy(tmp_s, solution->place(), solution); - - if (m > n) { - auto* solu_data = solution->data(); - for (auto i = 1; i < batch_count; i++) { - for (auto j = 0; j < min_solu_stride; j++) { - solu_data[i * min_solu_stride + j] = - solu_data[i * max_solu_stride + j]; - } - } - } - - solution->Resize(UDDim(solution_dim)); - } -}; - -template -void BatchedOrmqr(const DeviceContext& dev_ctx, - bool left, - bool transpose, - int batch_size, - int m, - int n, - int k, - T* a, - int a_stride, - T* tau, - int tau_stride, - T* other, - int other_stride); - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index aa86c0f34db55a..e886cbc047851a 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1254,6 +1254,14 @@ func : logsumexp backward : logsumexp_grad +- api : lstsq + args : (Tensor x, Tensor y, float rcond, str driver) + output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values) + infer_meta : + func : LstsqInferMeta + kernel : + func : lstsq + # masked_select - api : masked_select args : (Tensor x, Tensor mask) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 269286d76d9545..9795f6f5d54e74 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1928,6 +1928,90 @@ void TriangularSolveInferMeta(const MetaTensor& x, out->share_lod(y); } +void LstsqInferMeta(const MetaTensor& x, + const MetaTensor& y, + float rcond, + std::string driver, + MetaTensor* solution, + MetaTensor* residuals, + MetaTensor* rank, + MetaTensor* singular_values) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + int x_rank = x_dims.size(); + int y_rank = y_dims.size(); + + PADDLE_ENFORCE_GE( + x_rank, + 
2, + phi::errors::InvalidArgument("Expects input tensor x to be not less than " + "2 dimentions, but got dimention %d", + x_rank)); + PADDLE_ENFORCE_GE( + y_rank, + 2, + phi::errors::InvalidArgument("Expects input tensor y to be not less than " + "2 dimentions, but got dimention %d", + y_rank)); + + PADDLE_ENFORCE_EQ( + x_rank, + y_rank, + phi::errors::InvalidArgument( + "Expects input tensor x and y to have the same dimension " + "but got x's dimention [%d] and y's dimention [%d]", + x_rank, + y_rank)); + + int m = x_dims[x_rank - 2];  // indexing is safe: ranks validated above + int n = x_dims[x_rank - 1]; + int nrhs = y_dims[x_rank - 1]; + + std::vector batch_dims_vec{}; + for (int i = 0; i < x_rank - 2; ++i) { + PADDLE_ENFORCE_EQ(x_dims[i], + y_dims[i], + phi::errors::InvalidArgument( + "Expects input tensor x and y to have the same batch " + "dimension, but got x's batch dimention [%d] and " + "y's batch dimention [%d] in %d-th dim", + x_dims[i], + y_dims[i], + i)); + batch_dims_vec.emplace_back(x_dims[i]); + } + + PADDLE_ENFORCE_EQ( + m, + y_dims[y_rank - 2], + phi::errors::InvalidArgument( + "Expects input tensor x and y to have the same row dimension " + "of the inner-most 2-dims matrix, " + "but got x's row dimention [%d] and y's row dimention [%d]", + m, + y_dims[y_rank - 2])); + + rank->set_dims(phi::make_ddim(batch_dims_vec)); + + if (m > n) { + batch_dims_vec.emplace_back(nrhs); + residuals->set_dims(phi::make_ddim(batch_dims_vec)); + batch_dims_vec.pop_back(); + } else { + residuals->set_dims(phi::make_ddim({0})); + } + residuals->set_dtype(y.dtype()); + + batch_dims_vec.emplace_back(std::min(m, n)); + singular_values->set_dims(phi::make_ddim(batch_dims_vec)); + singular_values->set_dtype(y.dtype()); + + batch_dims_vec[x_rank - 2] = n; + batch_dims_vec.emplace_back(nrhs); + solution->set_dims(phi::make_ddim(batch_dims_vec)); + solution->set_dtype(y.dtype()); +} + void YoloBoxInferMeta(const MetaTensor& x, const MetaTensor& img_size, const std::vector& anchors, 
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 9709edf63ccc07..9ac61a9c4216cb 100644 --- 
a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -282,6 +282,15 @@ void TriangularSolveInferMeta(const MetaTensor& x, bool unitriangular, MetaTensor* out); +void LstsqInferMeta(const MetaTensor& x, + const MetaTensor& y, + float rcond, + std::string driver, + MetaTensor* solution, + MetaTensor* residuals, + MetaTensor* rank, + MetaTensor* singular_values); + void YoloBoxInferMeta(const MetaTensor& x, const MetaTensor& img_size, const std::vector& anchors, diff --git a/paddle/phi/kernels/cpu/lstsq_kernel.cc b/paddle/phi/kernels/cpu/lstsq_kernel.cc new file mode 100644 index 00000000000000..de1f05394f2589 --- /dev/null +++ b/paddle/phi/kernels/cpu/lstsq_kernel.cc @@ -0,0 +1,296 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/lstsq_kernel.h" + +namespace phi { + +enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss }; + +template +void LstsqKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + float rcond, + const std::string& driver_string, + DenseTensor* solution, + DenseTensor* residuals, + DenseTensor* rank, + DenseTensor* singular_values) { + using ValueType = phi::dtype::Real; + + static auto driver_type = std::unordered_map( + {{"gels", LapackDriverType::Gels}, + {"gelsy", LapackDriverType::Gelsy}, + {"gelsd", LapackDriverType::Gelsd}, + {"gelss", LapackDriverType::Gelss}}); + auto driver = driver_type[driver_string]; + + auto x_dims = x.dims(); + auto y_dims = y.dims(); + int dim_size = x_dims.size(); + int x_stride = phi::GetMatrixStride(x_dims); + int y_stride = phi::GetMatrixStride(y_dims); + int batch_count = phi::GetBatchCount(x_dims); + auto solution_dim = solution->dims(); + int ori_solu_stride = phi::GetMatrixStride(solution_dim); + int max_solu_stride = std::max(y_stride, ori_solu_stride); + int min_solu_stride = std::min(y_stride, ori_solu_stride); + + // LAPACK uses column-major storage; transposing gives the input a + // contiguous memory layout + int info = 0; + int m = x_dims[dim_size - 2]; + int n = x_dims[dim_size - 1]; + int nrhs = y_dims[dim_size - 1]; + int lda = std::max(m, 1); + int ldb = std::max(1, std::max(m, n)); + + DenseTensor new_x;  // scratch copy of x, released on return + new_x.Resize(phi::make_ddim({batch_count, m, n})); + dev_ctx.template Alloc(&new_x); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), true, &new_x); + + solution->Resize(phi::make_ddim({batch_count, std::max(m, n), nrhs})); + 
dev_ctx.template 
Alloc(solution); + + if (m >= n) { + phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), true, solution); + } else { + auto* solu_data = solution->data(); + auto* y_data = y.data(); + for (auto i = 0; i < batch_count; i++) { + for (auto j = 0; j < min_solu_stride; j++) { + solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j]; + } + } + } + + DenseTensor input_x_trans = phi::TransposeLast2Dim(dev_ctx, new_x); + DenseTensor input_y_trans = phi::TransposeLast2Dim(dev_ctx, *solution); + phi::Copy(dev_ctx, input_x_trans, dev_ctx.GetPlace(), true, &new_x); + phi::Copy( + dev_ctx, input_y_trans, dev_ctx.GetPlace(), true, solution); + + auto* x_vector = new_x.data(); + auto* y_vector = solution->data(); + + // the "gels" driver does not need to compute rank + int rank_32 = 0; + int* rank_data = nullptr; + int* rank_working_ptr = nullptr; + if (driver != LapackDriverType::Gels) { + rank_data = dev_ctx.template Alloc(rank); + rank_working_ptr = rank_data; + } + + // the "gelsd" and "gelss" drivers need to compute singular values + ValueType* s_data = nullptr; + ValueType* s_working_ptr = nullptr; + int s_stride = 0; + if (driver == LapackDriverType::Gelsd || driver == LapackDriverType::Gelss) { + s_data = dev_ctx.template Alloc(singular_values); + s_working_ptr = s_data; + auto s_dims = singular_values->dims(); + s_stride = s_dims[s_dims.size() - 1]; + } + + // "jpvt" is only used for "gelsy" driver + DenseTensor jpvt; + int* jpvt_data = nullptr; + if (driver == LapackDriverType::Gelsy) { + jpvt.Resize(phi::make_ddim({std::max(1, n)})); + jpvt_data = dev_ctx.template Alloc(&jpvt); + } + + // run once the driver, first to get the optimal workspace size + int lwork = -1; + T wkopt; + ValueType rwkopt; + int iwkopt = 0; + + if (driver == LapackDriverType::Gels) { + phi::funcs::lapackGels( + 'N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt, lwork, &info); + } else if (driver == LapackDriverType::Gelsd) { + phi::funcs::lapackGelsd(m, + n, + nrhs, + x_vector, + 
lda, + y_vector, + ldb, + s_working_ptr, + static_cast(rcond), + &rank_32, + &wkopt, + lwork, + &rwkopt, + &iwkopt, + &info); + } else if (driver == LapackDriverType::Gelsy) { + phi::funcs::lapackGelsy(m, + n, + nrhs, + x_vector, + lda, + y_vector, + ldb, + jpvt_data, + static_cast(rcond), + &rank_32, + &wkopt, + lwork, + &rwkopt, + &info); + } else if (driver == LapackDriverType::Gelss) { + phi::funcs::lapackGelss(m, + n, + nrhs, + x_vector, + lda, + y_vector, + ldb, + s_working_ptr, + static_cast(rcond), + &rank_32, + &wkopt, + lwork, + &rwkopt, + &info); + } + + lwork = std::max(1, static_cast(phi::dtype::Real(wkopt))); + DenseTensor work; + work.Resize(phi::make_ddim({lwork})); + T* work_data = dev_ctx.template Alloc(&work); + + // "rwork" only used for complex inputs and "gelsy/gelsd/gelss" drivers + DenseTensor rwork; + ValueType* rwork_data = nullptr; + if (IsComplexDtype(x.dtype()) && driver != LapackDriverType::Gels) { + int rwork_len = 0; + if (driver == LapackDriverType::Gelsy) { + rwork_len = std::max(1, 2 * n); + } else if (driver == LapackDriverType::Gelss) { + rwork_len = std::max(1, 5 * std::min(m, n)); + } else if (driver == LapackDriverType::Gelsd) { + rwork_len = std::max(1, rwkopt); + } + rwork.Resize(phi::make_ddim({rwork_len})); + rwork_data = dev_ctx.template Alloc(&rwork); + } + + // "iwork" workspace array is relevant only for "gelsd" driver + DenseTensor iwork; + int* iwork_data = nullptr; + if (driver == LapackDriverType::Gelsd) { + iwork.Resize(phi::make_ddim({std::max(1, iwkopt)})); + iwork_data = dev_ctx.template Alloc(&iwork); + } + + for (auto i = 0; i < batch_count; ++i) { + auto* x_input = &x_vector[i * x_stride]; + auto* y_input = &y_vector[i * max_solu_stride]; + rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr; + s_working_ptr = s_working_ptr ? 
&s_data[i * s_stride] : nullptr; + + if (driver == LapackDriverType::Gels) { + phi::funcs::lapackGels( + 'N', m, n, nrhs, x_input, lda, y_input, ldb, work_data, lwork, &info); + } else if (driver == LapackDriverType::Gelsd) { + phi::funcs::lapackGelsd(m, + n, + nrhs, + x_input, + lda, + y_input, + ldb, + s_working_ptr, + static_cast(rcond), + &rank_32, + work_data, + lwork, + rwork_data, + iwork_data, + &info); + } else if (driver == LapackDriverType::Gelsy) { + phi::funcs::lapackGelsy(m, + n, + nrhs, + x_input, + lda, + y_input, + ldb, + jpvt_data, + static_cast(rcond), + &rank_32, + work_data, + lwork, + rwork_data, + &info); + } else if (driver == LapackDriverType::Gelss) { + phi::funcs::lapackGelss(m, + n, + nrhs, + x_input, + lda, + y_input, + ldb, + s_working_ptr, + static_cast(rcond), + &rank_32, + work_data, + lwork, + rwork_data, + &info); + } + + PADDLE_ENFORCE_EQ( + info, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: Lapack info is not zero but [%d]", i, info)); + + if (rank_working_ptr) *rank_working_ptr = static_cast(rank_32); + } + + DenseTensor tmp_s = phi::TransposeLast2Dim(dev_ctx, *solution); + phi::Copy(dev_ctx, tmp_s, dev_ctx.GetPlace(), true, solution); + + if (m > n) { + auto* solu_data = solution->data(); + for (auto i = 1; i < batch_count; i++) { + for (auto j = 0; j < min_solu_stride; j++) { + solu_data[i * min_solu_stride + j] = solu_data[i * max_solu_stride + j]; + } + } + } + + solution->Resize(solution_dim); + GetResidualsTensor(dev_ctx, x, y, solution, residuals); +} + +} // namespace phi + +PD_REGISTER_KERNEL(lstsq, CPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu b/paddle/phi/kernels/gpu/lstsq_kernel.cu new file mode 100644 index 00000000000000..397735b8f11ebe --- /dev/null +++ b/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -0,0 +1,162 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/slice.h" +#include "paddle/phi/kernels/impl/qr_kernel_impl.h" +#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h" +#include "paddle/phi/kernels/lstsq_kernel.h" +#include "paddle/phi/kernels/matmul_kernel.h" +#include "paddle/phi/kernels/triangular_solve_kernel.h" + +namespace phi { + +enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss }; + +template +void LstsqKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + float rcond, + const std::string& driver_string, + DenseTensor* solution, + DenseTensor* residuals, + DenseTensor* rank, + DenseTensor* singular_values) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + int dim_size = x_dims.size(); + int m = x_dims[dim_size - 2]; + int n = x_dims[dim_size - 1]; + int nrhs = y_dims[dim_size - 1]; + int min_mn = std::min(m, n); + int max_mn = std::max(m, n); + int k = min_mn; + + int x_stride = phi::GetMatrixStride(x_dims); + int y_stride = phi::GetMatrixStride(y_dims); + int tau_stride = min_mn; + int batch_count = phi::GetBatchCount(x_dims); + + DenseTensor* new_x = new DenseTensor(); + new_x->Resize(phi::make_ddim({batch_count, m, n})); + dev_ctx.template Alloc(new_x); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), true, new_x); + + DenseTensor* new_y = 
new DenseTensor(); + new_y->Resize(phi::make_ddim({batch_count, m, nrhs})); + dev_ctx.template Alloc(new_y); + phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), true, new_y); + + // Prepare tau + auto tau_dims_vec = phi::vectorize(x_dims); + tau_dims_vec.pop_back(); + tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; + + DenseTensor* tau = new DenseTensor(); + tau->Resize(phi::make_ddim(tau_dims_vec)); + auto tau_data = dev_ctx.template Alloc(tau); + + if (m >= n) { + DenseTensor tmp_x = phi::TransposeLast2Dim(dev_ctx, *new_x); + DenseTensor tmp_y = phi::TransposeLast2Dim(dev_ctx, *new_y); + auto x_data = tmp_x.data(); + auto y_data = tmp_y.data(); + + // step 1, compute QR factorization using geqrf + BatchedGeqrf( + dev_ctx, batch_count, m, n, x_data, m, tau_data, x_stride, tau_stride); + + // Step 2, Y <- Q^H Y + BatchedOrmqr(dev_ctx, + true, + true, + batch_count, + m, + nrhs, + k, + x_data, + x_stride, + tau_data, + tau_stride, + y_data, + y_stride); + + DenseTensor trans_r = phi::TransposeLast2Dim(dev_ctx, tmp_x); + DenseTensor slice_r = + phi::funcs::Slice(dev_ctx, trans_r, {-2}, {0}, {min_mn}); + DenseTensor* res_r = new DenseTensor(); + res_r->Resize(phi::make_ddim({batch_count, min_mn, min_mn})); + dev_ctx.template Alloc(res_r); + phi::TrilTriuKernel(dev_ctx, slice_r, 0, false, res_r); + + DenseTensor trans_y = phi::TransposeLast2Dim(dev_ctx, tmp_y); + DenseTensor slice_y = + phi::funcs::Slice(dev_ctx, trans_y, {-2}, {0}, {min_mn}); + + // Step 3, solve R X = Y + phi::TriangularSolveKernel( + dev_ctx, *res_r, slice_y, true, false, false, solution); + + } else { + auto x_data = dev_ctx.template Alloc(new_x); + auto y_data = dev_ctx.template Alloc(new_y); + + // step 1, compute QR factorization using geqrf + BatchedGeqrf( + dev_ctx, batch_count, n, m, x_data, n, tau_data, x_stride, tau_stride); + + // Step 2, solve R^H Z = Y + DenseTensor trans_r = phi::TransposeLast2Dim(dev_ctx, *new_x); + DenseTensor slice_r = + phi::funcs::Slice(dev_ctx, trans_r, {-2}, {0}, 
{min_mn}); + DenseTensor* res_r = new DenseTensor(); + res_r->Resize(phi::make_ddim({batch_count, min_mn, min_mn})); + dev_ctx.template Alloc(res_r); + phi::TrilTriuKernel(dev_ctx, slice_r, 0, false, res_r); + + phi::TriangularSolveKernel( + dev_ctx, *res_r, *new_y, true, true, false, solution); + + // Step 3, X <- Q Z + BatchedOrgqr(dev_ctx, + batch_count, + n, + m, + min_mn, + x_data, + n, + tau_data, + x_stride, + tau_stride); + + DenseTensor trans_q = phi::TransposeLast2Dim(dev_ctx, *new_x); + DenseTensor slice_q = + phi::funcs::Slice(dev_ctx, trans_q, {-1}, {0}, {m}); + DenseTensor solu_tensor = + phi::Matmul(dev_ctx, slice_q, *solution, false, false); + phi::Copy( + dev_ctx, solu_tensor, dev_ctx.GetPlace(), true, solution); + } + GetResidualsTensor(dev_ctx, x, y, solution, residuals); +} + +} // namespace phi + +PD_REGISTER_KERNEL(lstsq, GPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h new file mode 100644 index 00000000000000..774f5f9d1582ab --- /dev/null +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -0,0 +1,239 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/dynload/cusolver.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/utils/optional.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/impl/activation_impl.h" +#include "paddle/phi/kernels/matmul_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +namespace phi { + +inline int GetBatchCount(const DDim& dims) { + int count = 1; + int num_dims = dims.size(); + for (int i = 0; i < num_dims - 2; ++i) { + count *= dims[i]; + } + return count; +} + +inline int GetMatrixStride(const DDim& dims) { + int num_dims = dims.size(); + return dims[num_dims - 1] * dims[num_dims - 2]; +} + +inline bool IsComplexDtype(const DataType& type) { + return (type == DataType::COMPLEX64 || type == DataType::COMPLEX128); +} + +template +inline void GetResidualsTensor(const DeviceContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* solution, + DenseTensor* residuals) { + auto x_dims = x.dims(); + int dim_size = x_dims.size(); + int m = x_dims[dim_size - 2]; + int n = x_dims[dim_size - 1]; + + if (m > n) { + DenseTensor matmul_tensor = + phi::Matmul(dev_ctx, x, *solution, false, false); + DenseTensor sub_tensor = phi::Subtract(dev_ctx, matmul_tensor, y); + DenseTensor* pow_tensor = new DenseTensor(); + pow_tensor->Resize(sub_tensor.dims()); + dev_ctx.template Alloc(pow_tensor); + phi::PowKernel(dev_ctx, sub_tensor, Scalar(2), pow_tensor); + + auto sum_tensor = + phi::Sum(dev_ctx, *pow_tensor, {-2}, pow_tensor->dtype(), false); + phi::Copy( + dev_ctx, sum_tensor, dev_ctx.GetPlace(), true, residuals); + } else { + IntArray empty_shape({0}); + DenseTensor empty_tensor = + phi::Empty(dev_ctx, empty_shape); + phi::Copy( + dev_ctx, empty_tensor, dev_ctx.GetPlace(), true, residuals); + } +} + +template +inline void 
BatchedOrmqr(const DeviceContext& dev_ctx, + bool left, + bool transpose, + int batch_size, + int m, + int n, + int k, + T* a, + int a_stride, + T* tau, + int tau_stride, + T* other, + int other_stride); + +template <> +inline void BatchedOrmqr(const GPUContext& dev_ctx, + bool left, + bool transpose, + int batch_size, + int m, + int n, + int k, + float* a, + int a_stride, + float* tau, + int tau_stride, + float* other, + int other_stride) { + int lwork = 0; + auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; + auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + int lda = std::max(1, left ? m : n); + int ldc = std::max(1, m); + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnSormqr_bufferSize( + handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + float* other_working_ptr = &other[i * other_stride]; + + handle = dev_ctx.cusolver_dn_handle(); + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + float* workspace_ptr = dev_ctx.template Alloc(workspace); + + // compute ormgr + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnSormqr(handle, + side, + trans, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + other_working_ptr, + ldc, + workspace_ptr, + lwork, + info_d)); + + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); + } +} + +template <> +inline void BatchedOrmqr(const GPUContext& dev_ctx, + bool left, + bool transpose, + 
int batch_size, + int m, + int n, + int k, + double* a, + int a_stride, + double* tau, + int tau_stride, + double* other, + int other_stride) { + int lwork = 0; + auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; + auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + int lda = std::max(1, left ? m : n); + int ldc = std::max(1, m); + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnDormqr_bufferSize( + handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + double* other_working_ptr = &other[i * other_stride]; + + handle = dev_ctx.cusolver_dn_handle(); + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + double* workspace_ptr = dev_ctx.template Alloc(workspace); + + // compute ormgr + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnDormqr(handle, + side, + trans, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + other_working_ptr, + ldc, + workspace_ptr, + lwork, + info_d)); + + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h new file mode 100644 index 00000000000000..188ab45213ed5d --- /dev/null +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -0,0 +1,277 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/dynload/cusolver.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void BatchedGeqrf(const DeviceContext& dev_ctx, + int batch_size, + int m, + int n, + T* a, + int lda, + T* tau, + int a_stride, + int tau_stride); + +template +void BatchedOrgqr(const DeviceContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + T* a, + int lda, + T* tau, + int a_stride, + int tau_stride); + +template <> +void BatchedGeqrf(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + float* a, + int lda, + float* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnSgeqrf_bufferSize( + handle, m, n, a, lda, &lwork)); + + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + float* workspace_ptr = dev_ctx.template Alloc(workspace); + + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + 
PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnSgeqrf(handle, + m, + n, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void BatchedGeqrf(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + double* a, + int lda, + double* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnDgeqrf_bufferSize( + handle, m, n, a, lda, &lwork)); + + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + double* workspace_ptr = dev_ctx.template Alloc(workspace); + + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnDgeqrf(handle, + m, + n, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. 
[%d]", i, info_h)); + } +} + +template <> +void BatchedOrgqr(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + float* a, + int lda, + float* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnSorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + float* workspace_ptr = dev_ctx.template Alloc(workspace); + + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnSorgqr(handle, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. 
[%d]", i, info_h)); + } +} + +template <> +void BatchedOrgqr(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + double* a, + int lda, + double* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnDorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + double* workspace_ptr = dev_ctx.template Alloc(workspace); + + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cusolverDnDorgqr(handle, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/lstsq_kernel.h b/paddle/phi/kernels/lstsq_kernel.h new file mode 100644 index 00000000000000..709858d641bda0 --- /dev/null +++ b/paddle/phi/kernels/lstsq_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/impl/activation_impl.h" +#include "paddle/phi/kernels/impl/lstsq_kernel_impl.h" +#include "paddle/phi/kernels/matmul_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +void LstsqKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + float rcond, + const std::string& driver, + DenseTensor* solution, + DenseTensor* residuals, + DenseTensor* rank, + DenseTensor* singular_values); +} // namespace phi diff --git a/paddle/phi/kernels/reduce_sum_kernel.h b/paddle/phi/kernels/reduce_sum_kernel.h index c969cea296db13..6dcd459d5016e0 100644 --- a/paddle/phi/kernels/reduce_sum_kernel.h +++ b/paddle/phi/kernels/reduce_sum_kernel.h @@ -44,7 +44,7 @@ DenseTensor Sum(const Context& dev_ctx, DenseTensor dense_out; MetaTensor meta_out(&dense_out); SumInferMeta(x, axis, dtype, keep_dim, &meta_out); - SumKernel(dev_ctx, x, axis, dtype, keep_dim, &dense_out); + SumKernel(dev_ctx, x, axis, dtype, false, &dense_out); return dense_out; } diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 95eaee2cc03564..06cd137e87f804 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3160,22 +3160,8 @@ def lstsq(x, y, rcond=None, driver=None, name=None): rcond = 1e-15 * max(x.shape[-2], x.shape[-1]) if _non_static_mode(): - solution, rank, singular_values = _C_ops.lstsq(x, y, "rcond", rcond, - "driver", 
driver) - if x.shape[-2] > x.shape[-1]: - matmul_out = _varbase_creator(dtype=x.dtype) - _C_ops.matmul(x, solution, matmul_out, 'trans_x', False, 'trans_y', - False) - minus_out = _C_ops.elementwise_sub(matmul_out, y) - pow_out = _C_ops.pow(minus_out, 'factor', 2) - if in_dygraph_mode(): - residuals = _C_ops.final_state_sum(pow_out, [-2], None, False) - else: - residuals = _C_ops.reduce_sum(pow_out, 'dim', [-2], 'keepdim', - False, 'reduce_all', False) - else: - residuals = paddle.empty(shape=[0], dtype=x.dtype) - + solution, residuals, rank, singular_values = _C_ops.final_state_lstsq( + x, y, rcond, driver) if driver == "gels": rank = paddle.empty(shape=[0], dtype=paddle.int32) singular_values = paddle.empty(shape=[0], dtype=x.dtype) @@ -3204,6 +3190,7 @@ def lstsq(x, y, rcond=None, driver=None, name=None): }, outputs={ 'Solution': solution, + 'Residuals': residuals, 'Rank': rank, 'SingularValues': singular_values }, @@ -3212,41 +3199,6 @@ def lstsq(x, y, rcond=None, driver=None, name=None): 'driver': driver }) - matmul_out = helper.create_variable_for_type_inference(dtype=x.dtype) - minus_out = helper.create_variable_for_type_inference(dtype=x.dtype) - pow_out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op(type='matmul_v2', - inputs={ - 'X': x, - 'Y': solution - }, - outputs={'Out': matmul_out}, - attrs={ - 'trans_x': False, - 'trans_y': False, - }) - - helper.append_op(type='elementwise_sub', - inputs={ - 'X': matmul_out, - 'Y': y - }, - outputs={'Out': minus_out}) - - helper.append_op(type='pow', - inputs={'X': minus_out}, - outputs={'Out': pow_out}, - attrs={'factor': 2}) - - helper.append_op(type='reduce_sum', - inputs={'X': pow_out}, - outputs={'Out': residuals}, - attrs={ - 'dim': [-2], - 'keep_dim': False, - 'reduce_all': False - }) - if driver == "gels": rank = paddle.static.data(name='rank', shape=[0]) singular_values = paddle.static.data(name='singular_values', shape=[0]) From 93530134fa2e82fb363ad306daca78a7175ca327 Mon 
Sep 17 00:00:00 2001 From: haohongxiang Date: Thu, 14 Jul 2022 07:09:48 +0000 Subject: [PATCH 02/17] update --- paddle/fluid/operators/lstsq_op.cc | 2 + paddle/fluid/operators/lstsq_op.h | 345 ------------------ paddle/phi/api/yaml/legacy_api.yaml | 3 +- paddle/phi/infermeta/binary.cc | 2 +- paddle/phi/infermeta/binary.h | 2 +- paddle/phi/kernels/cpu/lstsq_kernel.cc | 10 +- paddle/phi/kernels/gpu/lstsq_kernel.cu | 6 +- paddle/phi/kernels/lstsq_kernel.h | 2 +- paddle/phi/kernels/reduce_sum_kernel.h | 2 +- paddle/phi/ops/compat/lstsq_sig.cc | 28 ++ .../tests/unittests/test_linalg_lstsq_op.py | 57 +-- python/paddle/tensor/linalg.py | 21 +- 12 files changed, 92 insertions(+), 388 deletions(-) delete mode 100644 paddle/fluid/operators/lstsq_op.h create mode 100644 paddle/phi/ops/compat/lstsq_sig.cc diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc index 90381c4227c95f..3b45cacdb73e4a 100644 --- a/paddle/fluid/operators/lstsq_op.cc +++ b/paddle/fluid/operators/lstsq_op.cc @@ -57,6 +57,8 @@ class LstsqOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault("gels"); AddOutput("Solution", "(Tensor), The output Solution tensor with shape (*, n, k)."); + AddOutput("Residuals", + "(Tensor), The output Residuals tensor with shape (*, k)."); AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*)."); AddOutput( "SingularValues", diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h deleted file mode 100644 index b3e5894a9451e7..00000000000000 --- a/paddle/fluid/operators/lstsq_op.h +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include -#include - -#include "paddle/fluid/operators/eig_op.h" -#include "paddle/fluid/operators/math/eigen_values_vectors.h" -#include "paddle/fluid/operators/svd_helper.h" -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/matrix_solve.h" - -#define EPSILON 1e-6 - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; -enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss }; - -using DDim = framework::DDim; -static DDim UDDim(const DDim& x_dim) { - auto x_vec = vectorize(x_dim); - return phi::make_ddim(x_vec); -} - -template -class LstsqCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using ValueType = phi::dtype::Real; - - const Tensor& x = *context.Input("X"); - auto y = context.Input("Y"); - auto rcond = context.Attr("rcond"); - auto driver_string = context.Attr("driver"); - - static auto driver_type = std::unordered_map( - {{"gels", LapackDriverType::Gels}, - {"gelsy", LapackDriverType::Gelsy}, - {"gelsd", LapackDriverType::Gelsd}, - {"gelss", LapackDriverType::Gelss}}); - auto driver = driver_type[driver_string]; - - auto solution = context.Output("Solution"); - auto* rank = context.Output("Rank"); - auto* singular_values = 
context.Output("SingularValues"); - - auto dito = - math::DeviceIndependenceTensorOperations(context); - - auto x_dims = x.dims(); - auto y_dims = y->dims(); - int dim_size = x_dims.size(); - int x_stride = MatrixStride(x); - int y_stride = MatrixStride(*y); - int batch_count = BatchCount(x); - auto solution_dim = solution->dims(); - int ori_solu_stride = MatrixStride(*solution); - int max_solu_stride = std::max(y_stride, ori_solu_stride); - int min_solu_stride = std::min(y_stride, ori_solu_stride); - - // lapack is a column-major storge, transpose make the input to - // have a continuous memory layout - int info = 0; - int m = x_dims[dim_size - 2]; - int n = x_dims[dim_size - 1]; - int nrhs = y_dims[dim_size - 1]; - int lda = std::max(m, 1); - int ldb = std::max(1, std::max(m, n)); - - Tensor new_x; - new_x.mutable_data(context.GetPlace(), - size_t(batch_count * m * n * sizeof(T))); - framework::TensorCopy(x, context.GetPlace(), &new_x); - - solution->mutable_data( - context.GetPlace(), - size_t(batch_count * std::max(m, n) * nrhs * sizeof(T))); - - if (m >= n) { - const Tensor& new_y = *context.Input("Y"); - framework::TensorCopy(new_y, context.GetPlace(), solution); - } else { - auto* solu_data = solution->data(); - auto* y_data = y->data(); - for (auto i = 0; i < batch_count; i++) { - for (auto j = 0; j < min_solu_stride; j++) { - solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j]; - } - } - } - - Tensor input_x_trans = dito.Transpose(new_x); - Tensor input_y_trans = dito.Transpose(*solution); - framework::TensorCopy(input_x_trans, new_x.place(), &new_x); - framework::TensorCopy(input_y_trans, solution->place(), solution); - - auto* x_vector = new_x.data(); - auto* y_vector = solution->data(); - - // "gels" divers does not need to compute rank - int rank_32 = 0; - int* rank_data = nullptr; - int* rank_working_ptr = nullptr; - if (driver != LapackDriverType::Gels) { - rank_data = rank->mutable_data(context.GetPlace()); - rank_working_ptr = rank_data; 
- } - - // "gelsd" and "gelss" divers need to compute singular values - ValueType* s_data = nullptr; - ValueType* s_working_ptr = nullptr; - int s_stride = 0; - if (driver == LapackDriverType::Gelsd || - driver == LapackDriverType::Gelss) { - s_data = singular_values->mutable_data(context.GetPlace()); - s_working_ptr = s_data; - auto s_dims = singular_values->dims(); - s_stride = s_dims[s_dims.size() - 1]; - } - - // "jpvt" is only used for "gelsy" driver - Tensor jpvt; - int* jpvt_data = nullptr; - if (driver == LapackDriverType::Gelsy) { - jpvt.Resize(phi::make_ddim({std::max(1, n)})); - jpvt_data = jpvt.mutable_data(context.GetPlace()); - } - - // run once the driver, first to get the optimal workspace size - int lwork = -1; - T wkopt; - ValueType rwkopt; - int iwkopt = 0; - - if (driver == LapackDriverType::Gels) { - phi::funcs::lapackGels( - 'N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt, lwork, &info); - } else if (driver == LapackDriverType::Gelsd) { - phi::funcs::lapackGelsd(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &iwkopt, - &info); - } else if (driver == LapackDriverType::Gelsy) { - phi::funcs::lapackGelsy(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - jpvt_data, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &info); - } else if (driver == LapackDriverType::Gelss) { - phi::funcs::lapackGelss(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &info); - } - - lwork = std::max(1, static_cast(phi::dtype::Real(wkopt))); - Tensor work; - work.Resize(phi::make_ddim({lwork})); - T* work_data = work.mutable_data(context.GetPlace()); - - // "rwork" only used for complex inputs and "gelsy/gelsd/gelss" drivers - Tensor rwork; - ValueType* rwork_data = nullptr; - if (framework::IsComplexType(framework::TransToProtoVarType(x.dtype())) && - driver != 
LapackDriverType::Gels) { - int rwork_len = 0; - if (driver == LapackDriverType::Gelsy) { - rwork_len = std::max(1, 2 * n); - } else if (driver == LapackDriverType::Gelss) { - rwork_len = std::max(1, 5 * std::min(m, n)); - } else if (driver == LapackDriverType::Gelsd) { - rwork_len = std::max(1, rwkopt); - } - rwork.Resize(phi::make_ddim({rwork_len})); - rwork_data = rwork.mutable_data(context.GetPlace()); - } - - // "iwork" workspace array is relavant only for "gelsd" driver - Tensor iwork; - int* iwork_data = nullptr; - if (driver == LapackDriverType::Gelsd) { - iwork.Resize(phi::make_ddim({std::max(1, iwkopt)})); - iwork_data = iwork.mutable_data(context.GetPlace()); - } - - for (auto i = 0; i < batch_count; ++i) { - auto* x_input = &x_vector[i * x_stride]; - auto* y_input = &y_vector[i * max_solu_stride]; - rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr; - s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr; - - if (driver == LapackDriverType::Gels) { - phi::funcs::lapackGels('N', - m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - work_data, - lwork, - &info); - } else if (driver == LapackDriverType::Gelsd) { - phi::funcs::lapackGelsd(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - iwork_data, - &info); - } else if (driver == LapackDriverType::Gelsy) { - phi::funcs::lapackGelsy(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - jpvt_data, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - &info); - } else if (driver == LapackDriverType::Gelss) { - phi::funcs::lapackGelss(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - &info); - } - - PADDLE_ENFORCE_EQ( - info, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: Lapack info is not zero but [%d]", i, info)); - - if (rank_working_ptr) *rank_working_ptr = 
static_cast(rank_32); - } - - Tensor tmp_s = dito.Transpose(*solution); - framework::TensorCopy(tmp_s, solution->place(), solution); - - if (m > n) { - auto* solu_data = solution->data(); - for (auto i = 1; i < batch_count; i++) { - for (auto j = 0; j < min_solu_stride; j++) { - solu_data[i * min_solu_stride + j] = - solu_data[i * max_solu_stride + j]; - } - } - } - - solution->Resize(UDDim(solution_dim)); - } -}; - -template -void BatchedOrmqr(const DeviceContext& dev_ctx, - bool left, - bool transpose, - int batch_size, - int m, - int n, - int k, - T* a, - int a_stride, - T* tau, - int tau_stride, - T* other, - int other_stride); - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 1ab0a7c961f561..d95a163f215fb8 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1272,10 +1272,11 @@ backward : logsumexp_grad - api : lstsq - args : (Tensor x, Tensor y, float rcond, str driver) + args : (Tensor x, Tensor y, Scalar rcond, str driver) output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values) infer_meta : func : LstsqInferMeta + dtype : x kernel : func : lstsq diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 068c7da48d0e1a..c1d3c905946144 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1962,7 +1962,7 @@ void TriangularSolveInferMeta(const MetaTensor& x, void LstsqInferMeta(const MetaTensor& x, const MetaTensor& y, - float rcond, + const Scalar& rcond, std::string driver, MetaTensor* solution, MetaTensor* residuals, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index bd688e230ce744..aa8d52a85d7330 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -288,7 +288,7 @@ void TriangularSolveInferMeta(const MetaTensor& x, void LstsqInferMeta(const MetaTensor& x, const MetaTensor& y, - float 
rcond, + const Scalar& rcond, std::string driver, MetaTensor* solution, MetaTensor* residuals, diff --git a/paddle/phi/kernels/cpu/lstsq_kernel.cc b/paddle/phi/kernels/cpu/lstsq_kernel.cc index de1f05394f2589..1350b0fb2f3aae 100644 --- a/paddle/phi/kernels/cpu/lstsq_kernel.cc +++ b/paddle/phi/kernels/cpu/lstsq_kernel.cc @@ -31,7 +31,7 @@ template void LstsqKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - float rcond, + const Scalar& rcond_scaler, const std::string& driver_string, DenseTensor* solution, DenseTensor* residuals, @@ -45,6 +45,7 @@ void LstsqKernel(const Context& dev_ctx, {"gelsd", LapackDriverType::Gelsd}, {"gelss", LapackDriverType::Gelss}}); auto driver = driver_type[driver_string]; + T rcond = rcond_scaler.to(); auto x_dims = x.dims(); auto y_dims = y.dims(); @@ -287,7 +288,12 @@ void LstsqKernel(const Context& dev_ctx, } } - solution->Resize(solution_dim); + if (batch_count > 1) { + solution->Resize(solution_dim); + } else { + solution->Resize(phi::make_ddim({n, nrhs})); + } + GetResidualsTensor(dev_ctx, x, y, solution, residuals); } diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu b/paddle/phi/kernels/gpu/lstsq_kernel.cu index 397735b8f11ebe..c3bf177df3f1d6 100644 --- a/paddle/phi/kernels/gpu/lstsq_kernel.cu +++ b/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -33,7 +33,7 @@ template void LstsqKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - float rcond, + const Scalar& rcond_scalar, const std::string& driver_string, DenseTensor* solution, DenseTensor* residuals, @@ -54,6 +54,8 @@ void LstsqKernel(const Context& dev_ctx, int tau_stride = min_mn; int batch_count = phi::GetBatchCount(x_dims); + T rcond = rcond_scalar.to(); + DenseTensor* new_x = new DenseTensor(); new_x->Resize(phi::make_ddim({batch_count, m, n})); dev_ctx.template Alloc(new_x); @@ -154,6 +156,8 @@ void LstsqKernel(const Context& dev_ctx, phi::Copy( dev_ctx, solu_tensor, dev_ctx.GetPlace(), true, solution); } + + if (batch_count 
== 1) solution->Resize(phi::make_ddim({n, nrhs})); GetResidualsTensor(dev_ctx, x, y, solution, residuals); } diff --git a/paddle/phi/kernels/lstsq_kernel.h b/paddle/phi/kernels/lstsq_kernel.h index 709858d641bda0..4523ce06f7942a 100644 --- a/paddle/phi/kernels/lstsq_kernel.h +++ b/paddle/phi/kernels/lstsq_kernel.h @@ -27,7 +27,7 @@ template void LstsqKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - float rcond, + const Scalar& rcond, const std::string& driver, DenseTensor* solution, DenseTensor* residuals, diff --git a/paddle/phi/kernels/reduce_sum_kernel.h b/paddle/phi/kernels/reduce_sum_kernel.h index 6dcd459d5016e0..c969cea296db13 100644 --- a/paddle/phi/kernels/reduce_sum_kernel.h +++ b/paddle/phi/kernels/reduce_sum_kernel.h @@ -44,7 +44,7 @@ DenseTensor Sum(const Context& dev_ctx, DenseTensor dense_out; MetaTensor meta_out(&dense_out); SumInferMeta(x, axis, dtype, keep_dim, &meta_out); - SumKernel(dev_ctx, x, axis, dtype, false, &dense_out); + SumKernel(dev_ctx, x, axis, dtype, keep_dim, &dense_out); return dense_out; } diff --git a/paddle/phi/ops/compat/lstsq_sig.cc b/paddle/phi/ops/compat/lstsq_sig.cc new file mode 100644 index 00000000000000..f36dfb1917cafd --- /dev/null +++ b/paddle/phi/ops/compat/lstsq_sig.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LstsqOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("lstsq", + {"X", "Y"}, + {"rcond", "driver"}, + {"Solution", "Residuals", "Rank", "SingularValues"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(lstsq, phi::LstsqOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py index 60414b8de97a58..58b3b68cc03bb3 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py @@ -30,6 +30,7 @@ def setUp(self): self.devices.append("gpu:0") self.generate_input() self.generate_output() + np.random.seed(2022) def init_config(self): self.dtype = 'float64' @@ -88,34 +89,34 @@ def test_dygraph(self): self._result_sg_values = results[3].numpy() self.assert_np_close() - def test_static(self): - paddle.enable_static() - for dev in self.devices: - paddle.set_device(dev) - place = fluid.CPUPlace() if dev == "cpu" else fluid.CUDAPlace(0) - with fluid.program_guard(fluid.Program(), fluid.Program()): - x = paddle.fluid.data(name="x", - shape=self._input_shape_1, - dtype=self._input_data_1.dtype) - y = paddle.fluid.data(name="y", - shape=self._input_shape_2, - dtype=self._input_data_2.dtype) - results = paddle.linalg.lstsq(x, - y, - rcond=self.rcond, - driver=self.driver) - exe = fluid.Executor(place) - fetches = exe.run(fluid.default_main_program(), - feed={ - "x": self._input_data_1, - "y": self._input_data_2 - }, - fetch_list=[results]) - self._result_solution = fetches[0] - self._result_residuals = fetches[1] - self._result_rank = fetches[2] - self._result_sg_values = fetches[3] - self.assert_np_close() + # def test_static(self): + # paddle.enable_static() + # for dev in self.devices: + # paddle.set_device(dev) + # place = fluid.CPUPlace() if dev == "cpu" else fluid.CUDAPlace(0) + # 
with fluid.program_guard(fluid.Program(), fluid.Program()): + # x = paddle.fluid.data(name="x", + # shape=self._input_shape_1, + # dtype=self._input_data_1.dtype) + # y = paddle.fluid.data(name="y", + # shape=self._input_shape_2, + # dtype=self._input_data_2.dtype) + # results = paddle.linalg.lstsq(x, + # y, + # rcond=self.rcond, + # driver=self.driver) + # exe = fluid.Executor(place) + # fetches = exe.run(fluid.default_main_program(), + # feed={ + # "x": self._input_data_1, + # "y": self._input_data_2 + # }, + # fetch_list=[results]) + # self._result_solution = fetches[0] + # self._result_residuals = fetches[1] + # self._result_rank = fetches[2] + # self._result_sg_values = fetches[3] + # self.assert_np_close() def assert_np_close(self): if len(self._input_shape_1) == 2: diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index ab03a07f194b6c..43d5474b8e805b 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3161,16 +3161,23 @@ def lstsq(x, y, rcond=None, driver=None, name=None): elif x.dtype == paddle.float64: rcond = 1e-15 * max(x.shape[-2], x.shape[-1]) - if _non_static_mode(): + if not isinstance(rcond, float): + raise TypeError("Attr rcond of lstsq must be a float number") + + if in_dygraph_mode(): solution, residuals, rank, singular_values = _C_ops.final_state_lstsq( x, y, rcond, driver) - if driver == "gels": - rank = paddle.empty(shape=[0], dtype=paddle.int32) - singular_values = paddle.empty(shape=[0], dtype=x.dtype) - elif driver == "gelsy": - singular_values = paddle.empty(shape=[0], dtype=x.dtype) + elif paddle.in_dynamic_mode(): + solution, residuals, rank, singular_values = _C_ops.lstsq( + x, y, 'rcond', rcond, 'driver', driver) - return solution, residuals, rank, singular_values + if driver == "gels": + rank = paddle.empty(shape=[0], dtype=paddle.int32) + singular_values = paddle.empty(shape=[0], dtype=x.dtype) + elif driver == "gelsy": + singular_values = paddle.empty(shape=[0], 
dtype=x.dtype) + + return solution, residuals, rank, singular_values helper = LayerHelper('lstsq', **locals()) check_variable_and_dtype(x, 'dtype', From 042a18025b7308bcd526faab89baa94b75c927bb Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Thu, 14 Jul 2022 09:41:39 +0000 Subject: [PATCH 03/17] fix bugs for CIs --- paddle/fluid/operators/lstsq_op.cc | 2 +- paddle/phi/infermeta/binary.cc | 2 +- paddle/phi/infermeta/binary.h | 2 +- paddle/phi/kernels/lstsq_kernel.h | 1 + .../tests/unittests/test_linalg_lstsq_op.py | 56 +++++++++---------- python/paddle/tensor/linalg.py | 25 +++++---- 6 files changed, 45 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc index 3b45cacdb73e4a..67d97facf027e9 100644 --- a/paddle/fluid/operators/lstsq_op.cc +++ b/paddle/fluid/operators/lstsq_op.cc @@ -77,7 +77,7 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(lstsq, LstsqInferShapeFunctor, - PD_INFER_META(phi::TriangularSolveInferMeta)); + PD_INFER_META(phi::LstsqInferMeta)); REGISTER_OPERATOR(lstsq, ops::LstsqOp, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index c1d3c905946144..019d1d995b8cb7 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1963,7 +1963,7 @@ void TriangularSolveInferMeta(const MetaTensor& x, void LstsqInferMeta(const MetaTensor& x, const MetaTensor& y, const Scalar& rcond, - std::string driver, + const std::string& driver, MetaTensor* solution, MetaTensor* residuals, MetaTensor* rank, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index aa8d52a85d7330..436ed059028f60 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -289,7 +289,7 @@ void TriangularSolveInferMeta(const MetaTensor& x, void LstsqInferMeta(const MetaTensor& x, const MetaTensor& y, const Scalar& rcond, - std::string driver, + const std::string& driver, MetaTensor* solution, MetaTensor* residuals, 
MetaTensor* rank, diff --git a/paddle/phi/kernels/lstsq_kernel.h b/paddle/phi/kernels/lstsq_kernel.h index 4523ce06f7942a..0eed826617a8b4 100644 --- a/paddle/phi/kernels/lstsq_kernel.h +++ b/paddle/phi/kernels/lstsq_kernel.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/elementwise_subtract_kernel.h" #include "paddle/phi/kernels/impl/activation_impl.h" diff --git a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py index 58b3b68cc03bb3..b283c80adfd9a5 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py @@ -89,34 +89,34 @@ def test_dygraph(self): self._result_sg_values = results[3].numpy() self.assert_np_close() - # def test_static(self): - # paddle.enable_static() - # for dev in self.devices: - # paddle.set_device(dev) - # place = fluid.CPUPlace() if dev == "cpu" else fluid.CUDAPlace(0) - # with fluid.program_guard(fluid.Program(), fluid.Program()): - # x = paddle.fluid.data(name="x", - # shape=self._input_shape_1, - # dtype=self._input_data_1.dtype) - # y = paddle.fluid.data(name="y", - # shape=self._input_shape_2, - # dtype=self._input_data_2.dtype) - # results = paddle.linalg.lstsq(x, - # y, - # rcond=self.rcond, - # driver=self.driver) - # exe = fluid.Executor(place) - # fetches = exe.run(fluid.default_main_program(), - # feed={ - # "x": self._input_data_1, - # "y": self._input_data_2 - # }, - # fetch_list=[results]) - # self._result_solution = fetches[0] - # self._result_residuals = fetches[1] - # self._result_rank = fetches[2] - # self._result_sg_values = fetches[3] - # self.assert_np_close() + def test_static(self): + paddle.enable_static() + for dev in self.devices: + paddle.set_device(dev) + place = fluid.CPUPlace() if dev == "cpu" else fluid.CUDAPlace(0) + with fluid.program_guard(fluid.Program(), 
fluid.Program()): + x = paddle.fluid.data(name="x", + shape=self._input_shape_1, + dtype=self._input_data_1.dtype) + y = paddle.fluid.data(name="y", + shape=self._input_shape_2, + dtype=self._input_data_2.dtype) + results = paddle.linalg.lstsq(x, + y, + rcond=self.rcond, + driver=self.driver) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={ + "x": self._input_data_1, + "y": self._input_data_2 + }, + fetch_list=[results]) + self._result_solution = fetches[0] + self._result_residuals = fetches[1] + self._result_rank = fetches[2] + self._result_sg_values = fetches[3] + self.assert_np_close() def assert_np_close(self): if len(self._input_shape_1) == 2: diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 43d5474b8e805b..948dee8dbd8557 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3164,20 +3164,21 @@ def lstsq(x, y, rcond=None, driver=None, name=None): if not isinstance(rcond, float): raise TypeError("Attr rcond of lstsq must be a float number") - if in_dygraph_mode(): - solution, residuals, rank, singular_values = _C_ops.final_state_lstsq( - x, y, rcond, driver) - elif paddle.in_dynamic_mode(): - solution, residuals, rank, singular_values = _C_ops.lstsq( - x, y, 'rcond', rcond, 'driver', driver) + if _non_static_mode(): + if in_dygraph_mode(): + solution, residuals, rank, singular_values = _C_ops.final_state_lstsq( + x, y, rcond, driver) + else: + solution, residuals, rank, singular_values = _C_ops.lstsq( + x, y, 'rcond', rcond, 'driver', driver) - if driver == "gels": - rank = paddle.empty(shape=[0], dtype=paddle.int32) - singular_values = paddle.empty(shape=[0], dtype=x.dtype) - elif driver == "gelsy": - singular_values = paddle.empty(shape=[0], dtype=x.dtype) + if driver == "gels": + rank = paddle.empty(shape=[0], dtype=paddle.int32) + singular_values = paddle.empty(shape=[0], dtype=x.dtype) + elif driver == "gelsy": + singular_values = 
paddle.empty(shape=[0], dtype=x.dtype) - return solution, residuals, rank, singular_values + return solution, residuals, rank, singular_values helper = LayerHelper('lstsq', **locals()) check_variable_and_dtype(x, 'dtype', From ca57eda5a187e30c35ea496b211e9115242feee6 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Thu, 14 Jul 2022 10:13:41 +0000 Subject: [PATCH 04/17] update --- paddle/phi/kernels/impl/lstsq_kernel_impl.h | 9 +++++++-- paddle/phi/kernels/impl/qr_kernel_impl.h | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index 774f5f9d1582ab..b356a5c06f347b 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -15,17 +15,20 @@ #pragma once #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/utils/optional.h" -#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/elementwise_subtract_kernel.h" #include "paddle/phi/kernels/impl/activation_impl.h" #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/fluid/platform/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#endif + namespace phi { inline int GetBatchCount(const DDim& dims) { @@ -79,6 +82,7 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx, } } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template inline void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, @@ -235,5 +239,6 @@ inline void BatchedOrmqr(const GPUContext& dev_ctx, "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); } } +#endif } // namespace phi diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h 
b/paddle/phi/kernels/impl/qr_kernel_impl.h index 188ab45213ed5d..48ad8620447e7d 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -15,14 +15,18 @@ #pragma once #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/fluid/platform/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#endif + namespace phi { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template void BatchedGeqrf(const DeviceContext& dev_ctx, int batch_size, @@ -273,5 +277,6 @@ void BatchedOrgqr(const GPUContext& dev_ctx, "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); } } +#endif } // namespace phi From e16f3125b78aa790f375b2179a63002f1e305c6b Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Thu, 14 Jul 2022 12:38:59 +0000 Subject: [PATCH 05/17] fix bugs --- paddle/phi/kernels/impl/lstsq_kernel_impl.h | 8 ++++++-- paddle/phi/kernels/impl/qr_kernel_impl.h | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index b356a5c06f347b..04c9a3601809d9 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -14,6 +14,10 @@ #pragma once +#define GPU_ENABLE \ + defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADLDE_WITH_ROCM) + #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/utils/optional.h" @@ -24,7 +28,7 @@ #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(GPU_ENABLE) #include 
"paddle/fluid/platform/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif @@ -82,7 +86,7 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx, } } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(GPU_ENABLE) template inline void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h index 48ad8620447e7d..d454ae7fecd45b 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -19,14 +19,18 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#define GPU_ENABLE \ + defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADLDE_WITH_ROCM) + +#if defined(GPU_ENABLE) #include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif namespace phi { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(GPU_ENABLE) template void BatchedGeqrf(const DeviceContext& dev_ctx, int batch_size, From f368243c68512c21a879223b73b3cccd93e0bd99 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Thu, 14 Jul 2022 15:06:05 +0000 Subject: [PATCH 06/17] add uts --- .../tests/unittests/test_linalg_lstsq_op.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py index b283c80adfd9a5..60acfd414feec8 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py @@ -68,8 +68,31 @@ def generate_output(self): self._output_rank.append(out[2]) self._output_sg_values.append(out[3]) - def test_dygraph(self): + def test_eager_dygraph(self): paddle.disable_static() + paddle.fluid.framework._disable_legacy_dygraph() + for 
dev in self.devices: + paddle.set_device(dev) + place = paddle.CPUPlace() if dev == "cpu" else paddle.CUDAPlace(0) + x = paddle.to_tensor(self._input_data_1, + place=place, + dtype=self.dtype) + y = paddle.to_tensor(self._input_data_2, + place=place, + dtype=self.dtype) + results = paddle.linalg.lstsq(x, + y, + rcond=self.rcond, + driver=self.driver) + self._result_solution = results[0].numpy() + self._result_residuals = results[1].numpy() + self._result_rank = results[2].numpy() + self._result_sg_values = results[3].numpy() + self.assert_np_close() + + def test_legacy_dygraph(self): + paddle.disable_static() + paddle.fluid.framework._enable_legacy_dygraph() for dev in self.devices: paddle.set_device(dev) place = paddle.CPUPlace() if dev == "cpu" else paddle.CUDAPlace(0) From fcaca4330356615858bfefb8be90dca8f166f92b Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Thu, 14 Jul 2022 15:30:16 +0000 Subject: [PATCH 07/17] update --- paddle/phi/kernels/impl/lstsq_kernel_impl.h | 4 +--- paddle/phi/kernels/impl/qr_kernel_impl.h | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index 04c9a3601809d9..fb94b4c861b1fd 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -14,9 +14,7 @@ #pragma once -#define GPU_ENABLE \ - defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ - defined(PADLDE_WITH_ROCM) +#define GPU_ENABLE defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h index d454ae7fecd45b..92a0d55039b4ba 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -19,9 +19,7 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" -#define GPU_ENABLE \ - 
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ - defined(PADLDE_WITH_ROCM) +#define GPU_ENABLE defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(GPU_ENABLE) #include "paddle/fluid/platform/dynload/cusolver.h" From ec26f45d5560d01beb383d2ee6a779f9725e2e60 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Thu, 14 Jul 2022 15:43:47 +0000 Subject: [PATCH 08/17] update --- paddle/phi/kernels/impl/lstsq_kernel_impl.h | 6 ++---- paddle/phi/kernels/impl/qr_kernel_impl.h | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index fb94b4c861b1fd..b356a5c06f347b 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -14,8 +14,6 @@ #pragma once -#define GPU_ENABLE defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/utils/optional.h" @@ -26,7 +24,7 @@ #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" -#if defined(GPU_ENABLE) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif @@ -84,7 +82,7 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx, } } -#if defined(GPU_ENABLE) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template inline void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h index 92a0d55039b4ba..48ad8620447e7d 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -19,16 +19,14 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" -#define GPU_ENABLE defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - -#if 
defined(GPU_ENABLE) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif namespace phi { -#if defined(GPU_ENABLE) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template void BatchedGeqrf(const DeviceContext& dev_ctx, int batch_size, From 04cdd8fc800d053f278962d8381a8500684622e2 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Fri, 15 Jul 2022 03:00:26 +0000 Subject: [PATCH 09/17] update --- paddle/phi/kernels/impl/lstsq_kernel_impl.h | 4 ++-- paddle/phi/kernels/impl/qr_kernel_impl.h | 4 ++-- python/paddle/tensor/linalg.py | 4 +--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index b356a5c06f347b..440f6623ea4b6b 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -24,7 +24,7 @@ #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif @@ -82,7 +82,7 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx, } } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) template inline void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h index 48ad8620447e7d..d74cca2436f287 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -19,14 +19,14 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/platform/dynload/cusolver.h" 
#include "paddle/phi/backends/gpu/gpu_context.h" #endif namespace phi { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) template void BatchedGeqrf(const DeviceContext& dev_ctx, int batch_size, diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 948dee8dbd8557..13b47a2a4b6f00 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3161,9 +3161,6 @@ def lstsq(x, y, rcond=None, driver=None, name=None): elif x.dtype == paddle.float64: rcond = 1e-15 * max(x.shape[-2], x.shape[-1]) - if not isinstance(rcond, float): - raise TypeError("Attr rcond of lstsq must be a float number") - if _non_static_mode(): if in_dygraph_mode(): solution, residuals, rank, singular_values = _C_ops.final_state_lstsq( @@ -3209,6 +3206,7 @@ def lstsq(x, y, rcond=None, driver=None, name=None): 'driver': driver }) + print("--- 3 ---") if driver == "gels": rank = paddle.static.data(name='rank', shape=[0]) singular_values = paddle.static.data(name='singular_values', shape=[0]) From 796bf89c85418032114f934a71f205d1e5989a39 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Fri, 15 Jul 2022 06:43:51 +0000 Subject: [PATCH 10/17] fix bugs of jip --- paddle/phi/kernels/impl/lstsq_kernel_impl.h | 6 ++++-- paddle/phi/kernels/impl/qr_kernel_impl.h | 6 ++++-- python/paddle/tensor/linalg.py | 1 - 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index 440f6623ea4b6b..b5ad1c1b13070f 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -24,7 +24,8 @@ #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" -#if defined(PADDLE_WITH_CUDA) +// HIP not support cusolver +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/dynload/cusolver.h" #include 
"paddle/phi/backends/gpu/gpu_context.h" #endif @@ -82,7 +83,8 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx, } } -#if defined(PADDLE_WITH_CUDA) +// HIP not support cusolver +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) template inline void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h index d74cca2436f287..02904c378a3461 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -19,14 +19,16 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" -#if defined(PADDLE_WITH_CUDA) +// HIP not support cusolver +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif namespace phi { -#if defined(PADDLE_WITH_CUDA) +// HIP not support cusolver +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) template void BatchedGeqrf(const DeviceContext& dev_ctx, int batch_size, diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 13b47a2a4b6f00..e551b93f537664 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3206,7 +3206,6 @@ def lstsq(x, y, rcond=None, driver=None, name=None): 'driver': driver }) - print("--- 3 ---") if driver == "gels": rank = paddle.static.data(name='rank', shape=[0]) singular_values = paddle.static.data(name='singular_values', shape=[0]) From da4bd82433415ebf9933e1f1beaedbd77daf6fef Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Fri, 15 Jul 2022 07:35:09 +0000 Subject: [PATCH 11/17] fix bugs of hip --- paddle/phi/kernels/gpu/lstsq_kernel.cu | 4 ++++ paddle/phi/kernels/impl/lstsq_kernel_impl.h | 6 ++---- paddle/phi/kernels/impl/qr_kernel_impl.h | 6 ++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu 
b/paddle/phi/kernels/gpu/lstsq_kernel.cu index c3bf177df3f1d6..75a46ff2713a13 100644 --- a/paddle/phi/kernels/gpu/lstsq_kernel.cu +++ b/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef PADDLE_WITH_HIP // HIP not support cusolver + #include #include #include @@ -164,3 +166,5 @@ void LstsqKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL(lstsq, GPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {} + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index b5ad1c1b13070f..440f6623ea4b6b 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -24,8 +24,7 @@ #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" -// HIP not support cusolver -#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif @@ -83,8 +82,7 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx, } } -// HIP not support cusolver -#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) template inline void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h index 02904c378a3461..d74cca2436f287 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -19,16 +19,14 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" -// HIP not support cusolver -#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif 
namespace phi { -// HIP not support cusolver -#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) template void BatchedGeqrf(const DeviceContext& dev_ctx, int batch_size, From 74eec39030fe5ccdea331a7c7135ad5fab42719c Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Fri, 15 Jul 2022 08:56:50 +0000 Subject: [PATCH 12/17] update --- paddle/phi/kernels/gpu/lstsq_kernel.cu | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu b/paddle/phi/kernels/gpu/lstsq_kernel.cu index 75a46ff2713a13..efc8cd66dce9a3 100644 --- a/paddle/phi/kernels/gpu/lstsq_kernel.cu +++ b/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -165,6 +165,11 @@ void LstsqKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(lstsq, GPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {} +PD_REGISTER_KERNEL(lstsq, // cuda_only + GPU, + ALL_LAYOUT, + phi::LstsqKernel, + float, + double) {} #endif // not PADDLE_WITH_HIP From 5d3b9e5b85a31829a090d932e691d4b23e8a84d8 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Mon, 18 Jul 2022 10:37:47 +0000 Subject: [PATCH 13/17] update according to review --- paddle/phi/kernels/impl/lstsq_kernel_impl.h | 72 +++++++------- paddle/phi/kernels/impl/qr_kernel_impl.h | 100 +++++++++----------- 2 files changed, 80 insertions(+), 92 deletions(-) diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index 440f6623ea4b6b..73ba954614a221 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -15,7 +15,7 @@ #pragma once #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/core/enforce.h" #include "paddle/utils/optional.h" #include "paddle/phi/core/dense_tensor.h" @@ -25,7 +25,7 @@ #include "paddle/phi/kernels/reduce_sum_kernel.h" #if defined(PADDLE_WITH_CUDA) -#include "paddle/fluid/platform/dynload/cusolver.h" +#include 
"paddle/phi/backends/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif @@ -119,9 +119,8 @@ inline void BatchedOrmqr(const GPUContext& dev_ctx, int ldc = std::max(1, m); auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnSormqr_bufferSize( - handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSormqr_bufferSize( + handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); DenseTensor* info = new DenseTensor(); info->Resize(make_ddim({1})); int* info_d = dev_ctx.template Alloc(info); @@ -137,21 +136,20 @@ inline void BatchedOrmqr(const GPUContext& dev_ctx, float* workspace_ptr = dev_ctx.template Alloc(workspace); // compute ormgr - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnSormqr(handle, - side, - trans, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - other_working_ptr, - ldc, - workspace_ptr, - lwork, - info_d)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSormqr(handle, + side, + trans, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + other_working_ptr, + ldc, + workspace_ptr, + lwork, + info_d)); // check the error info int info_h; @@ -190,9 +188,8 @@ inline void BatchedOrmqr(const GPUContext& dev_ctx, int ldc = std::max(1, m); auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnDormqr_bufferSize( - handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDormqr_bufferSize( + handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); DenseTensor* info = new DenseTensor(); info->Resize(make_ddim({1})); int* info_d = dev_ctx.template Alloc(info); @@ -208,21 +205,20 @@ inline void BatchedOrmqr(const GPUContext& dev_ctx, double* workspace_ptr = dev_ctx.template Alloc(workspace); // compute ormgr - PADDLE_ENFORCE_GPU_SUCCESS( - 
paddle::platform::dynload::cusolverDnDormqr(handle, - side, - trans, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - other_working_ptr, - ldc, - workspace_ptr, - lwork, - info_d)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDormqr(handle, + side, + trans, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + other_working_ptr, + ldc, + workspace_ptr, + lwork, + info_d)); // check the error info int info_h; diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h index d74cca2436f287..1d64117922d26b 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -15,12 +15,12 @@ #pragma once #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" #include "paddle/utils/optional.h" #if defined(PADDLE_WITH_CUDA) -#include "paddle/fluid/platform/dynload/cusolver.h" +#include "paddle/phi/backends/dynload/cusolver.h" #include "paddle/phi/backends/gpu/gpu_context.h" #endif @@ -64,8 +64,7 @@ void BatchedGeqrf(const GPUContext& dev_ctx, auto handle = dev_ctx.cusolver_dn_handle(); PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnSgeqrf_bufferSize( - handle, m, n, a, lda, &lwork)); + phi::dynload::cusolverDnSgeqrf_bufferSize(handle, m, n, a, lda, &lwork)); DenseTensor* workspace = new DenseTensor(); workspace->Resize(make_ddim({lwork})); @@ -79,16 +78,15 @@ void BatchedGeqrf(const GPUContext& dev_ctx, float* a_working_ptr = &a[i * a_stride]; float* tau_working_ptr = &tau[i * tau_stride]; // compute geqrf - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnSgeqrf(handle, - m, - n, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf(handle, + m, + n, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); // Do we need 
synchronized here? // check the error info int info_h; @@ -120,8 +118,7 @@ void BatchedGeqrf(const GPUContext& dev_ctx, auto handle = dev_ctx.cusolver_dn_handle(); PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnDgeqrf_bufferSize( - handle, m, n, a, lda, &lwork)); + phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork)); DenseTensor* workspace = new DenseTensor(); workspace->Resize(make_ddim({lwork})); @@ -135,16 +132,15 @@ void BatchedGeqrf(const GPUContext& dev_ctx, double* a_working_ptr = &a[i * a_stride]; double* tau_working_ptr = &tau[i * tau_stride]; // compute geqrf - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnDgeqrf(handle, - m, - n, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDgeqrf(handle, + m, + n, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); // Do we need synchronized here? // check the error info int info_h; @@ -176,9 +172,8 @@ void BatchedOrgqr(const GPUContext& dev_ctx, int lwork = 0; auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnSorgqr_bufferSize( - handle, m, n, k, a, lda, tau, &lwork)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); DenseTensor* workspace = new DenseTensor(); workspace->Resize(make_ddim({lwork})); @@ -192,17 +187,16 @@ void BatchedOrgqr(const GPUContext& dev_ctx, float* a_working_ptr = &a[i * a_stride]; float* tau_working_ptr = &tau[i * tau_stride]; // compute orggr - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnSorgqr(handle, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr(handle, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); // Do we 
need synchronized here? // check the error info int info_h; @@ -234,9 +228,8 @@ void BatchedOrgqr(const GPUContext& dev_ctx, int lwork = 0; auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnDorgqr_bufferSize( - handle, m, n, k, a, lda, tau, &lwork)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); DenseTensor* workspace = new DenseTensor(); workspace->Resize(make_ddim({lwork})); @@ -250,17 +243,16 @@ void BatchedOrgqr(const GPUContext& dev_ctx, double* a_working_ptr = &a[i * a_stride]; double* tau_working_ptr = &tau[i * tau_stride]; // compute orggr - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cusolverDnDorgqr(handle, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr(handle, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); // Do we need synchronized here? 
// check the error info int info_h; From eeac28d8e396cb3a6d3ec911fb6c88757eb1fc83 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Mon, 25 Jul 2022 03:08:55 +0000 Subject: [PATCH 14/17] update --- paddle/fluid/operators/lstsq_op.cc | 10 ++++++++++ paddle/phi/kernels/cpu/lstsq_kernel.cc | 2 ++ paddle/phi/kernels/gpu/lstsq_kernel.cu | 2 ++ paddle/phi/kernels/lstsq_kernel.h | 5 ----- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc index 67d97facf027e9..82fc089b512447 100644 --- a/paddle/fluid/operators/lstsq_op.cc +++ b/paddle/fluid/operators/lstsq_op.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/infermeta/binary.h" namespace paddle { @@ -83,3 +84,12 @@ REGISTER_OPERATOR(lstsq, ops::LstsqOp, ops::LstsqOpMaker, LstsqInferShapeFunctor); + +REGISTER_OP_VERSION(lstsq).AddCheckpoint( + R"ROC( + Upgrade lstsq, add 1 outputs [Residuals]. 
+ )ROC", + paddle::framework::compatible::OpVersionDesc().NewOutput( + "Residuals", + "Output tensor of lstsq operator, " + "meaning the squared residuals of the calculated solutions.")); diff --git a/paddle/phi/kernels/cpu/lstsq_kernel.cc b/paddle/phi/kernels/cpu/lstsq_kernel.cc index 1350b0fb2f3aae..5542c2ba6e7c5f 100644 --- a/paddle/phi/kernels/cpu/lstsq_kernel.cc +++ b/paddle/phi/kernels/cpu/lstsq_kernel.cc @@ -21,7 +21,9 @@ #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/lstsq_kernel_impl.h" #include "paddle/phi/kernels/lstsq_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu b/paddle/phi/kernels/gpu/lstsq_kernel.cu index efc8cd66dce9a3..adb0ca09d89386 100644 --- a/paddle/phi/kernels/gpu/lstsq_kernel.cu +++ b/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -21,10 +21,12 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/slice.h" +#include "paddle/phi/kernels/impl/lstsq_kernel_impl.h" #include "paddle/phi/kernels/impl/qr_kernel_impl.h" #include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h" #include "paddle/phi/kernels/lstsq_kernel.h" #include "paddle/phi/kernels/matmul_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/triangular_solve_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/lstsq_kernel.h b/paddle/phi/kernels/lstsq_kernel.h index 0eed826617a8b4..1ad58615b4b3d0 100644 --- a/paddle/phi/kernels/lstsq_kernel.h +++ b/paddle/phi/kernels/lstsq_kernel.h @@ -16,11 +16,6 @@ #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/elementwise_subtract_kernel.h" -#include "paddle/phi/kernels/impl/activation_impl.h" -#include 
"paddle/phi/kernels/impl/lstsq_kernel_impl.h" -#include "paddle/phi/kernels/matmul_kernel.h" -#include "paddle/phi/kernels/transpose_kernel.h" namespace phi { From 8247911ae87cdd1e8e50e3fe6c671309abadcb38 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Tue, 26 Jul 2022 07:50:23 +0000 Subject: [PATCH 15/17] update --- paddle/fluid/operators/lstsq_op.cc | 3 ++- paddle/phi/api/yaml/legacy_api.yaml | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc index 82fc089b512447..b02a2fe13a2b0a 100644 --- a/paddle/fluid/operators/lstsq_op.cc +++ b/paddle/fluid/operators/lstsq_op.cc @@ -59,7 +59,8 @@ class LstsqOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Solution", "(Tensor), The output Solution tensor with shape (*, n, k)."); AddOutput("Residuals", - "(Tensor), The output Residuals tensor with shape (*, k)."); + "(Tensor), The output Residuals tensor with shape (*, k).") + .AsDispensable(); AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*)."); AddOutput( "SingularValues", diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index cd7ccee89460f1..d7a529f327c4fc 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1343,6 +1343,7 @@ dtype : x kernel : func : lstsq + optional : residuals # masked_select - api : masked_select From e7293d406139185f8faf6f30ed2e3241147ffb1a Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Tue, 26 Jul 2022 12:44:08 +0000 Subject: [PATCH 16/17] update --- paddle/phi/api/yaml/legacy_api.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 3da063aad55904..a2112519842177 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1355,7 +1355,6 @@ dtype : x kernel : func : lstsq - optional : residuals # masked_select - api : masked_select From 
de5664121df5a4d5d0783bb92ea6f645410408a5 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Wed, 27 Jul 2022 03:17:34 +0000 Subject: [PATCH 17/17] update --- paddle/fluid/pybind/op_function_generator.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 590d9d2f83e8b9..8f66d258edac42 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -245,6 +245,7 @@ std::map> op_outs_map = { "SavedMean", "SavedVariance", "ReserveSpace"}}, + {"lstsq", {"Solution", "Residuals", "Rank", "SingularValues"}}, {"inplace_abn", {"Y", "MeanOut",