Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ljf/fix ascend unique op #1343

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions impl/ascend/functions/unique.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,8 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio
makeTensor(ctx, outTmpAt, {inputAt.numel()}, inputAt.dtype());
}

// allocate temp inverse tensor
diopiTensorHandle_t inverseTmp = nullptr;
AscendTensor inverseTmpAt(inverseTmp);
bool returnInverse = (indices != nullptr) ? true : false;
std::vector<int64_t> zeroShape = {0};
if (returnInverse || returnCounts) {
makeTensor(ctx, inverseTmpAt, inputAt.shape(), diopi_dtype_int64);
} else {
makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64);
}

// allocate temp counts tensor
diopiTensorHandle_t countsTmp = nullptr;
Expand All @@ -48,8 +40,23 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio
}

// call aclnnUnique2
auto params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, inverseTmpAt, countsTmpAt).params();
DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params);
std::tuple<aclTensor*, bool, bool, bool, aclTensor*, aclTensor*, aclTensor*> params;
if (returnInverse) {
params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, indices, countsTmpAt).params();
DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params);
} else {
// allocate temp inverse tensor
diopiTensorHandle_t inverseTmp = nullptr;
AscendTensor inverseTmpAt(inverseTmp);
makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64);
if (returnCounts) {
makeTensor(ctx, inverseTmpAt, inputAt.shape(), diopi_dtype_int64);
} else {
makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64);
}
params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, inverseTmpAt, countsTmpAt).params();
DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params);
}

// get true outShape by aclGetViewShape
int64_t* viewDims = nullptr;
Expand All @@ -65,11 +72,6 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio
AscendTensor outReshapeAt = reshape(ctx, outTmpAt, {viewDims, viewDims + viewDimNum});
*out = const_cast<diopiTensorHandle_t>(outReshapeAt.tensorHandle());

// fill indices tensor
if (returnInverse) {
indices = const_cast<diopiTensorHandle_t>(inverseTmpAt.tensorHandle());
}

// fill counts tensor
if (returnCounts) {
// get counts tensor shape, counts tensor is the 7th tensor in aclnnUnique2, index = 6
Expand Down
Loading