Skip to content

Commit

Permalink
[ascend]Zgc/diopi ascend add masked fill (#1005)
Browse files Browse the repository at this point in the history
* enable use aclnnCopy in tensor.copy_

* add masked_fill relate
  • Loading branch information
zhaoguochun1995 authored Mar 1, 2024
1 parent 4fb272f commit 57a3b7a
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 5 deletions.
12 changes: 12 additions & 0 deletions impl/ascend/convert_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,15 @@
- diopiIsNan:
dtype: (uint8, int8, int32, int16, int64, bool)->float32

- diopiMaskedFill:
dtype: (int16, uint8)->int32, (float64)->float32

- diopiMaskedFillInp:
dtype: (int16, uint8)->int32, (float64)->float32

- diopiMaskedFillInpScalar:
dtype: (int16, uint8)->int32, (float64)->float32

- diopiMaskedFillScalar:
dtype: (int16, uint8)->int32, (float64)->float32

8 changes: 4 additions & 4 deletions impl/ascend_npu/ascend_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ ascend:
- diopiLtScalar
- diopiMSELoss
- diopiMSELossBackward
- diopiMaskedFill
- diopiMaskedFillInp
- diopiMaskedFillInpScalar
- diopiMaskedFillScalar
- diopiMax
- diopiMaxAll
- diopiMaximum
Expand Down Expand Up @@ -194,6 +190,10 @@ ascend_npu:
- diopiRemainderScalar
- diopiRemainder
- diopiRepeat
- diopiMaskedFill
- diopiMaskedFillInp
- diopiMaskedFillInpScalar
- diopiMaskedFillScalar
- diopiSilu
- diopiSiluInp
- diopiSiluBackward
Expand Down
56 changes: 56 additions & 0 deletions impl/ascend_npu/diopi_impl/masked_fill.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/**
* @file
* @author DeepLink
* @copyright (c) 2024, DeepLink.
*/

#include "helper.hpp"
#include "op_plugin/OpApiInterface.h"

namespace OP_IMPL_NS {

diopiError_t diopiMaskedFill(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mask,
diopiConstTensorHandle_t value) {
BEGIN_CALL_ACL_OP(out, input, mask, value);
if (input == nullptr || inputAt.numel() <= 0) {
return diopiSuccess;
}
if (outAt.data_ptr() != inputAt.data_ptr()) {
outAt.copy_(inputAt);
}
op_api::masked_fill_(outAt, maskAt, valueAt);
END_CALL_ACL_OP();
}

diopiError_t diopiMaskedFillInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t mask, diopiConstTensorHandle_t value) {
BEGIN_CALL_ACL_OP(input, mask, value);
if (input == nullptr || inputAt.numel() <= 0) {
return diopiSuccess;
}
op_api::masked_fill_(inputAt, maskAt, valueAt);
END_CALL_ACL_OP();
}

diopiError_t diopiMaskedFillScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mask,
const diopiScalar_t* value) {
BEGIN_CALL_ACL_OP(out, input, mask, value);
if (input == nullptr || inputAt.numel() <= 0) {
return diopiSuccess;
}
if (outAt.data_ptr() != inputAt.data_ptr()) {
outAt.copy_(inputAt);
}
op_api::masked_fill_(outAt, maskAt, valueAt);
END_CALL_ACL_OP();
}

diopiError_t diopiMaskedFillInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t mask, const diopiScalar_t* value) {
BEGIN_CALL_ACL_OP(input, mask, value);
if (input == nullptr || inputAt.numel() <= 0) {
return diopiSuccess;
}
op_api::masked_fill_(inputAt, maskAt, valueAt);
END_CALL_ACL_OP();
}

} // namespace OP_IMPL_NS
2 changes: 1 addition & 1 deletion impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3116,7 +3116,7 @@ namespace {
at::Tensor& wrapper_Tensor_fill_(at::Tensor& self, const at::Tensor& value) { return acl_op::fill_(self, value); }

at::Tensor& wrapper__copy_(at::Tensor& self, const at::Tensor& src, bool non_blocking) {
return at_npu::native::NPUNativeFunctions::copy_(self, src, non_blocking);
return at_npu::native::NPUNativeOpApiFunctions::copy_(self, src, non_blocking);
}

at::Tensor wrapper__view(const at::Tensor& self, at::IntArrayRef size) { return impl::aten::viewStorage(self, size); }
Expand Down

0 comments on commit 57a3b7a

Please sign in to comment.