[Ascend] fuj/replace-baddbmm-and-fix-max-and-min-config (#1208)

* replace baddbmm and fix max and min config

* replace bmm with aclnn

* support bfloat16
jingguo-st authored May 16, 2024
1 parent a45c104 commit 794b873
Showing 5 changed files with 41 additions and 73 deletions.
2 changes: 2 additions & 0 deletions impl/ascend/ascend_tensor.hpp
@@ -95,6 +95,7 @@ constexpr aclDataType diopiDtypeToAclDataType(diopiDtype_t dtype) noexcept {
         return acl_dtype;
 
     switch (dtype) {
+        DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_bfloat16, ACL_BF16)
         DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float16, ACL_FLOAT16)
         DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float32, ACL_FLOAT)
         DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float64, ACL_DOUBLE)
@@ -107,6 +108,7 @@ constexpr aclDataType diopiDtypeToAclDataType(diopiDtype_t dtype) noexcept {
         DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_int64, ACL_INT64)
         DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_uint64, ACL_UINT64)
         DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_bool, ACL_BOOL)
+        DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex32, ACL_COMPLEX32)
         DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex64, ACL_COMPLEX64)
         DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex128, ACL_COMPLEX128)
         default:
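Note on the pattern above: DIOPI_DTYPE_TO_ACL_DTYPE_CASE presumably expands each (diopi_dtype, acl_dtype) pair into a case label that returns the matching ACL enum value, with the hunk's leading `return acl_dtype;` being the tail of that macro definition. A minimal self-contained sketch of the idiom (the enum values, enumerator names, and the function name `toAclDataType` are stand-ins, not CANN's or DIOPI's real definitions):

#include <cstdio>

// Stand-ins for the real DIOPI/ACL enums (assumption: actual enumerators and values differ).
enum diopiDtype_t { diopi_dtype_float16, diopi_dtype_bfloat16, diopi_dtype_complex32 };
enum aclDataType { ACL_FLOAT16, ACL_BF16, ACL_COMPLEX32, ACL_DT_UNDEFINED = -1 };

// Hypothetical expansion of DIOPI_DTYPE_TO_ACL_DTYPE_CASE: one switch case per mapping.
#define DTYPE_CASE(diopiType, aclType) \
    case diopiType:                    \
        return aclType;

constexpr aclDataType toAclDataType(diopiDtype_t dtype) noexcept {
    switch (dtype) {
        DTYPE_CASE(diopi_dtype_float16, ACL_FLOAT16)
        DTYPE_CASE(diopi_dtype_bfloat16, ACL_BF16)        // new in this commit
        DTYPE_CASE(diopi_dtype_complex32, ACL_COMPLEX32)  // new in this commit
        default:
            return ACL_DT_UNDEFINED;
    }
}

int main() {
    std::printf("%d\n", toAclDataType(diopi_dtype_bfloat16));  // prints the ACL_BF16 stub value
}

Keeping the mapping in a single constexpr switch lets unsupported dtypes fall through to one undefined sentinel instead of scattering dtype conversions across call sites.
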
14 changes: 14 additions & 0 deletions impl/ascend/device_configs.py
@@ -35,6 +35,20 @@
             ]
         )
     ),
+    # Bad in-place call: input tensor size [2] and output tensor size [2, 0, 2] should match
+    # pytorch 2.1.0 does not support this case
+    # input: (2,), batch1: (2, 0, 4), batch2: (2, 4, 2)
+    'baddbmm_without_inplace': dict(
+        name=["baddbmm"],
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["input"],
+                    "shape": [Skip((2,))],
+                },
+            ],
+        ),
+    ),
 
     # temp for 910B
     'uniform': dict(
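The shapes in the skipped case make the problem concrete: baddbmm maps batch1 (b, n, m) and batch2 (b, m, p) to a result of shape (b, n, p), so batch1 (2, 0, 4) and batch2 (2, 4, 2) yield an empty (2, 0, 2) result. Out of place, the (2,)-shaped input broadcasts against that; in place, the result cannot be stored back into the (2,)-shaped input. A standalone sketch of the shape arithmetic (illustrative only, not DIOPI code):

#include <array>
#include <cstdio>

// batch1: (b, n, m), batch2: (b, m, p)  ->  baddbmm result: (b, n, p)
std::array<long, 3> baddbmmResultShape(const std::array<long, 3>& batch1, const std::array<long, 3>& batch2) {
    return {batch1[0], batch1[1], batch2[2]};
}

int main() {
    // The skipped configuration: batch1 (2, 0, 4), batch2 (2, 4, 2).
    auto shape = baddbmmResultShape({2, 0, 4}, {2, 4, 2});
    // Prints (2, 0, 2): fine out of place, where input (2,) broadcasts, but an
    // in-place call would have to write this result into the (2,)-shaped input.
    std::printf("(%ld, %ld, %ld)\n", shape[0], shape[1], shape[2]);
}
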
68 changes: 14 additions & 54 deletions impl/ascend/functions/baddbmm.cpp
@@ -4,72 +4,32 @@
  * @copyright (c) 2023, DeepLink.
  */
 
-#include "../common/acloprunner.hpp"
+#include "../aclnn/acl_scalar.hpp"
+#include "../aclnn/adaptor.hpp"
 
 namespace impl {
 namespace ascend {
 
 diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t batch1,
                           diopiConstTensorHandle_t batch2, double beta, double alpha) {
-    diopiDtype_t outDtype;
-    diopiGetTensorDtype(out, &outDtype);
+    AscendTensor inAt(input);
+    auto betas = constructDiopiScalarT(inAt.dtype(), beta);
+    auto alphas = constructDiopiScalarT(inAt.dtype(), alpha);
 
-    AscendTensor inputAt(input);
-    AscendTensor outputAt(out);
-    AscendTensor batch1At(batch1);
-    AscendTensor batch2At(batch2);
-
-    // get the size of batch1 * batch2
-    std::vector<int64_t> batch1Shape = batch1At.shape();
-    std::vector<int64_t> batch2Shape = batch2At.shape();
-    std::vector<int64_t> vectorSizeBatchMatMulTensor = {batch1Shape[0], batch1Shape[1], batch2Shape[2]};
-
-    // init a tensor according to the size of batch1 * batch2 ;
-    diopiSize_t diopiSizeBatchMatMulTensor = vectorToDiopiSize(vectorSizeBatchMatMulTensor);
-    AscendTensor batchMatMulTensorAt;
-    makeTensor(ctx, batchMatMulTensorAt, &diopiSizeBatchMatMulTensor, outDtype, diopiDevice_t::diopi_device);
-
-    // does batch1/batch2 need to transpose?
-    bool isSelfT = false;
-    bool isMat2T = false;
-
-    // do batch1 times batch2 -> BatchMatMulTensor
-    AclOpRunner<2, 1>("BatchMatMul", ctx)
-        .addInput(batch1At)
-        .addInput(batch2At)
-        .addOutput(batchMatMulTensorAt)
-        .setAttr("adj_x1", isSelfT)
-        .setAttr("adj_x2", isMat2T)
-        .run();
-
-    // init memory based on the size of alphaMulTensor and betaMulTensor
-    AscendTensor alphaMulTensor;
-    AscendTensor betaMulTensor;
-    makeTensorLike(ctx, alphaMulTensor, batchMatMulTensorAt, outDtype);
-    makeTensorLike(ctx, betaMulTensor, inputAt, outDtype);
-
-    diopiScalar_t alphaScalar = constructDiopiScalarT(outDtype, alpha);
-    diopiScalar_t betaScalar = constructDiopiScalarT(outDtype, beta);
-
-    // transform ascendTensor to diopiTensorHandle_t
-    diopiTensorHandle_t diopiAlphaMulTensor = const_cast<diopiTensorHandle_t>(alphaMulTensor.tensorHandle());
-    diopiTensorHandle_t diopiBateMulTensor = const_cast<diopiTensorHandle_t>(betaMulTensor.tensorHandle());
-    diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast<diopiTensorHandle_t>(batchMatMulTensorAt.tensorHandle());
-    diopiTensorHandle_t diopiInput = const_cast<diopiTensorHandle_t>(inputAt.tensorHandle());
-
-    // alpha times BatchMatMulTensor -> alphaMulTensor and beta times input -> betaMulTensor
-    diopiMulScalar(ctx, diopiAlphaMulTensor, diopiAsBatchMatMulTensor, &alphaScalar);
-    diopiMulScalar(ctx, diopiBateMulTensor, diopiInput, &betaScalar);
-
-    diopiScalar_t otherScalar = constructDiopiScalarT(outDtype, 1);
-    diopiTensorHandle_t diopiOutput = const_cast<diopiTensorHandle_t>(outputAt.tensorHandle());
-    diopiAdd(ctx, diopiOutput, diopiAlphaMulTensor, diopiBateMulTensor, &otherScalar);
+    int cubeMathType = 0;
+    DIOPI_ASCEND_CALL_ACLNN(aclnnBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, out, cubeMathType);
     return diopiSuccess;
 }
 
 diopiError_t diopiBaddbmmInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t batch1, diopiConstTensorHandle_t batch2, double beta,
                              double alpha) {
-    return diopiBaddbmm(ctx, input, input, batch1, batch2, beta, alpha);
+    AscendTensor inAt(input);
+    auto betas = constructDiopiScalarT(inAt.dtype(), beta);
+    auto alphas = constructDiopiScalarT(inAt.dtype(), alpha);
+
+    int cubeMathType = 0;
+    DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, cubeMathType);
     return diopiSuccess;
 }
 
 } // namespace ascend
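The net effect of this rewrite: the old path decomposed baddbmm into an explicit BatchMatMul, two diopiMulScalar calls, and a diopiAdd, while the new path hands out = beta * input + alpha * (batch1 @ batch2) to a single fused aclnnBaddbmm call, with the in-place variant going through aclnnInplaceBaddbmm rather than aliasing the out-of-place entry point. For reference, a plain C++ sketch of those semantics (contiguous row-major buffers, input pre-expanded to the output shape; not the DIOPI implementation):

#include <cstddef>
#include <vector>

// Reference: out[i] = beta * input[i] + alpha * (batch1 @ batch2)[i]
// batch1: (b, n, m), batch2: (b, m, p), input/out: (b, n, p), row-major.
void baddbmmRef(std::vector<float>& out, const std::vector<float>& input,
                const std::vector<float>& batch1, const std::vector<float>& batch2,
                size_t b, size_t n, size_t m, size_t p, float beta, float alpha) {
    for (size_t i = 0; i < b; ++i)
        for (size_t r = 0; r < n; ++r)
            for (size_t c = 0; c < p; ++c) {
                float acc = 0.0f;  // (batch1 @ batch2) entry for batch i, row r, column c
                for (size_t k = 0; k < m; ++k)
                    acc += batch1[(i * n + r) * m + k] * batch2[(i * m + k) * p + c];
                const size_t idx = (i * n + r) * p + c;
                out[idx] = beta * input[idx] + alpha * acc;
            }
}

The in-place variant is the same computation with input aliasing out, which is why diopiBaddbmmInp now calls aclnnInplaceBaddbmm directly instead of funneling through the out-of-place path.
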
18 changes: 8 additions & 10 deletions impl/ascend/functions/bmm.cpp
@@ -4,22 +4,20 @@
  * @copyright (c) 2023, DeepLink.
  */
 
-#include "../common/acloprunner.hpp"
+#include "../aclnn/acl_scalar.hpp"
+#include "../aclnn/adaptor.hpp"
 
 namespace impl {
 namespace ascend {
 
 diopiError_t diopiBmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mat2) {
-    AscendTensor inputAt(input);
-    AscendTensor mat2At(mat2);
-    AscendTensor outputAt(out);
-    if (inputAt.numel() == 0 || mat2At.numel() == 0) {
-        diopiScalar_t zero = constructDiopiScalarT(outputAt.dtype(), 0.0);
-        diopiFill(ctx, out, &zero);
-        return diopiSuccess;
-    }
+    AscendTensor inAt(input);
+    AscendTensor matAt(mat2);
+    ASCEND_CHECK_ABORT(inAt.dtype() == matAt.dtype(), "[diopiBmm] tensor dtypes do not match.");
+
+    int cubeMathType = 0;
+    DIOPI_ASCEND_CALL_ACLNN(aclnnBatchMatMul, ctx, input, mat2, out, cubeMathType);
 
-    AclOpRunner<2, 1>("BatchMatMulV2", ctx).addInput(input).addInput(mat2).setAttr("adj_x1", false).setAttr("adj_x1", false).addOutput(out).run();
     return diopiSuccess;
 }
 
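DIOPI_ASCEND_CALL_ACLNN wraps CANN's standard two-phase aclnn calling convention: first size a workspace and build an op executor, then launch on a stream. A hedged sketch of the equivalent direct calls (error handling and aclTensor construction omitted; the adaptor's actual expansion may differ, and the reading of cubeMathType = 0 as "keep the tensors' own dtype" is taken from CANN's documentation):

#include "acl/acl.h"
#include "aclnnop/aclnn_batch_matmul.h"

// selfAcl, mat2Acl, outAcl: pre-built aclTensor* handles (assumed); stream: an aclrtStream.
void bmmDirect(aclTensor* selfAcl, aclTensor* mat2Acl, aclTensor* outAcl, aclrtStream stream) {
    int8_t cubeMathType = 0;  // 0: keep original dtype on the cube unit (per CANN docs)
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;

    // Phase 1: size the workspace and build an executor for this op instance.
    aclnnBatchMatMulGetWorkspaceSize(selfAcl, mat2Acl, outAcl, cubeMathType, &workspaceSize, &executor);

    void* workspace = nullptr;
    if (workspaceSize > 0) {
        aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    }

    // Phase 2: enqueue the kernel on the stream.
    aclnnBatchMatMul(workspace, workspaceSize, executor, stream);
}

The same convention applies to aclnnBaddbmm above; the adaptor macro presumably also converts the diopi tensor handles into aclTensor objects before making these calls.
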
12 changes: 3 additions & 9 deletions impl/ascend_npu/ascend_config.yaml
@@ -13,6 +13,8 @@ ascend:
 - diopiArgmax
 - diopiAtan
 - diopiAtanInp
+- diopiBaddbmm
+- diopiBaddbmmInp
 - diopiBitwiseNot
 - diopiBitwiseNotInp
 - diopiBitwiseAnd
@@ -23,6 +25,7 @@ ascend:
 - diopiBitwiseOrInp
 - diopiBitwiseOrScalar
 - diopiBitwiseOrInpScalar
+- diopiBmm
 - diopiCastDtype
 - diopiClamp
 - diopiClampInp
@@ -201,26 +204,17 @@ ascend_npu:
 - diopiAddcmul
 - diopiAddcmulInp
 - diopiAddmm
-- diopiBaddbmm
-- diopiBaddbmmInp
 - diopiBatchNorm
 - diopiBatchNormBackward
 - diopiNonzero
-- diopiBmm
 - diopiMatmul
-- diopiMaxAll
-- diopiMin
-- diopiMinAll
-- diopiMinimum
 - diopiCat
 - diopiDropout
 - diopiDropoutInp
 - diopiCopyInp
 - diopiExpand
 - diopiGroupNorm
 - diopiGroupNormBackward
-- diopiMax
-- diopiMaximum
 - diopiMaskedFill
 - diopiMaskedFillInp
 - diopiMaskedFillInpScalar
