From 794b873e2dcf4e476742c64de3869e82401d17ad Mon Sep 17 00:00:00 2001 From: Fu Jingguo Date: Thu, 16 May 2024 20:22:35 +0800 Subject: [PATCH] [Ascend] fuj/replace-baddbmm-and-fix-max-and-min-config (#1208) * replace baddbmm and fix max and min config * replace bmm with aclnn * support bfloat16 --- impl/ascend/ascend_tensor.hpp | 2 + impl/ascend/device_configs.py | 14 ++++++ impl/ascend/functions/baddbmm.cpp | 68 ++++++------------------------ impl/ascend/functions/bmm.cpp | 18 ++++---- impl/ascend_npu/ascend_config.yaml | 12 ++---- 5 files changed, 41 insertions(+), 73 deletions(-) diff --git a/impl/ascend/ascend_tensor.hpp b/impl/ascend/ascend_tensor.hpp index 715ec4be2f..7c8c9ec628 100644 --- a/impl/ascend/ascend_tensor.hpp +++ b/impl/ascend/ascend_tensor.hpp @@ -95,6 +95,7 @@ constexpr aclDataType diopiDtypeToAclDataType(diopiDtype_t dtype) noexcept { return acl_dtype; switch (dtype) { + DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_bfloat16, ACL_BF16) DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float16, ACL_FLOAT16) DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float32, ACL_FLOAT) DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float64, ACL_DOUBLE) @@ -107,6 +108,7 @@ constexpr aclDataType diopiDtypeToAclDataType(diopiDtype_t dtype) noexcept { DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_int64, ACL_INT64) DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_uint64, ACL_UINT64) DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_bool, ACL_BOOL) + DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex32, ACL_COMPLEX32) DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex64, ACL_COMPLEX64) DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex128, ACL_COMPLEX128) default: diff --git a/impl/ascend/device_configs.py b/impl/ascend/device_configs.py index bcaef224d0..6cabba3e61 100755 --- a/impl/ascend/device_configs.py +++ b/impl/ascend/device_configs.py @@ -35,6 +35,20 @@ ] ) ), + # Bad in-place call: input tensor size [2] and output tensor size [2, 0, 2] should match + # pytorch 2.1.0 does not support this case + # input: (2,), batch1: (2, 0, 4), batch2: (2, 4, 2) + 'baddbmm_without_inplace': dict( + name=["baddbmm"], + tensor_para=dict( + args=[ + { + "ins": ["input"], + "shape": [Skip((2,))], + }, + ], + ), + ), # temp for 910B 'uniform': dict( diff --git a/impl/ascend/functions/baddbmm.cpp b/impl/ascend/functions/baddbmm.cpp index eec036acb7..008970abac 100644 --- a/impl/ascend/functions/baddbmm.cpp +++ b/impl/ascend/functions/baddbmm.cpp @@ -4,72 +4,32 @@ * @copyright (c) 2023, DeepLink. */ -#include "../common/acloprunner.hpp" +#include "../aclnn/acl_scalar.hpp" +#include "../aclnn/adaptor.hpp" namespace impl { namespace ascend { diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t batch1, diopiConstTensorHandle_t batch2, double beta, double alpha) { - diopiDtype_t outDtype; - diopiGetTensorDtype(out, &outDtype); + AscendTensor inAt(input); + auto betas = constructDiopiScalarT(inAt.dtype(), beta); + auto alphas = constructDiopiScalarT(inAt.dtype(), alpha); - AscendTensor inputAt(input); - AscendTensor outputAt(out); - AscendTensor batch1At(batch1); - AscendTensor batch2At(batch2); - - // get the size of batch1 * batch2 - std::vector batch1Shape = batch1At.shape(); - std::vector batch2Shape = batch2At.shape(); - std::vector vectorSizeBatchMatMulTensor = {batch1Shape[0], batch1Shape[1], batch2Shape[2]}; - - // init a tensor according to the size of batch1 * batch2 ; - diopiSize_t diopiSizeBatchMatMulTensor = vectorToDiopiSize(vectorSizeBatchMatMulTensor); - AscendTensor batchMatMulTensorAt; - makeTensor(ctx, batchMatMulTensorAt, &diopiSizeBatchMatMulTensor, outDtype, diopiDevice_t::diopi_device); - - // does batch1/batch2 need to transpose? - bool isSelfT = false; - bool isMat2T = false; - - // do batch1 times batch2 -> BatchMatMulTensor - AclOpRunner<2, 1>("BatchMatMul", ctx) - .addInput(batch1At) - .addInput(batch2At) - .addOutput(batchMatMulTensorAt) - .setAttr("adj_x1", isSelfT) - .setAttr("adj_x2", isMat2T) - .run(); - - // init memory based on the size of alphaMulTensor and betaMulTensor - AscendTensor alphaMulTensor; - AscendTensor betaMulTensor; - makeTensorLike(ctx, alphaMulTensor, batchMatMulTensorAt, outDtype); - makeTensorLike(ctx, betaMulTensor, inputAt, outDtype); - - diopiScalar_t alphaScalar = constructDiopiScalarT(outDtype, alpha); - diopiScalar_t betaScalar = constructDiopiScalarT(outDtype, beta); - - // transform ascendTensor to diopiTensorHandle_t - diopiTensorHandle_t diopiAlphaMulTensor = const_cast(alphaMulTensor.tensorHandle()); - diopiTensorHandle_t diopiBateMulTensor = const_cast(betaMulTensor.tensorHandle()); - diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast(batchMatMulTensorAt.tensorHandle()); - diopiTensorHandle_t diopiInput = const_cast(inputAt.tensorHandle()); - - // alpha times BatchMatMulTensor -> alphaMulTensor and beta times input -> betaMulTensor - diopiMulScalar(ctx, diopiAlphaMulTensor, diopiAsBatchMatMulTensor, &alphaScalar); - diopiMulScalar(ctx, diopiBateMulTensor, diopiInput, &betaScalar); - - diopiScalar_t otherScalar = constructDiopiScalarT(outDtype, 1); - diopiTensorHandle_t diopiOutput = const_cast(outputAt.tensorHandle()); - diopiAdd(ctx, diopiOutput, diopiAlphaMulTensor, diopiBateMulTensor, &otherScalar); + int cubeMathType = 0; + DIOPI_ASCEND_CALL_ACLNN(aclnnBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, out, cubeMathType); return diopiSuccess; } diopiError_t diopiBaddbmmInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t batch1, diopiConstTensorHandle_t batch2, double beta, double alpha) { - return diopiBaddbmm(ctx, input, input, batch1, batch2, beta, alpha); + AscendTensor inAt(input); + auto betas = constructDiopiScalarT(inAt.dtype(), beta); + auto alphas = constructDiopiScalarT(inAt.dtype(), alpha); + + int cubeMathType = 0; + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, cubeMathType); + return diopiSuccess; } } // namespace ascend diff --git a/impl/ascend/functions/bmm.cpp b/impl/ascend/functions/bmm.cpp index fbc010f73f..aa6c149d18 100644 --- a/impl/ascend/functions/bmm.cpp +++ b/impl/ascend/functions/bmm.cpp @@ -4,22 +4,20 @@ * @copyright (c) 2023, DeepLink. */ -#include "../common/acloprunner.hpp" +#include "../aclnn/acl_scalar.hpp" +#include "../aclnn/adaptor.hpp" namespace impl { namespace ascend { diopiError_t diopiBmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mat2) { - AscendTensor inputAt(input); - AscendTensor mat2At(mat2); - AscendTensor outputAt(out); - if (inputAt.numel() == 0 || mat2At.numel() == 0) { - diopiScalar_t zero = constructDiopiScalarT(outputAt.dtype(), 0.0); - diopiFill(ctx, out, &zero); - return diopiSuccess; - } + AscendTensor inAt(input); + AscendTensor matAt(mat2); + ASCEND_CHECK_ABORT(inAt.dtype() == matAt.dtype(), "[diopiBmm] tensors dtype does not matched."); + + int cubeMathType = 0; + DIOPI_ASCEND_CALL_ACLNN(aclnnBatchMatMul, ctx, input, mat2, out, cubeMathType); - AclOpRunner<2, 1>("BatchMatMulV2", ctx).addInput(input).addInput(mat2).setAttr("adj_x1", false).setAttr("adj_x1", false).addOutput(out).run(); return diopiSuccess; } diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index 8398977c58..f97ff245b1 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -13,6 +13,8 @@ ascend: - diopiArgmax - diopiAtan - diopiAtanInp +- diopiBaddbmm +- diopiBaddbmmInp - diopiBitwiseNot - diopiBitwiseNotInp - diopiBitwiseAnd @@ -23,6 +25,7 @@ ascend: - diopiBitwiseOrInp - diopiBitwiseOrScalar - diopiBitwiseOrInpScalar +- diopiBmm - diopiCastDtype - diopiClamp - diopiClampInp @@ -201,17 +204,10 @@ ascend_npu: - diopiAddcmul - diopiAddcmulInp - diopiAddmm -- diopiBaddbmm -- diopiBaddbmmInp - diopiBatchNorm - diopiBatchNormBackward - diopiNonzero -- diopiBmm - diopiMatmul -- diopiMaxAll -- diopiMin -- diopiMinAll -- diopiMinimum - diopiCat - diopiDropout - diopiDropoutInp @@ -219,8 +215,6 @@ ascend_npu: - diopiExpand - diopiGroupNorm - diopiGroupNormBackward -- diopiMax -- diopiMaximum - diopiMaskedFill - diopiMaskedFillInp - diopiMaskedFillInpScalar