Skip to content

Commit

Permalink
[MLIR][TORCH] Add TorchToTosa lowering for torch.aten.softmax.int op
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis authored and AmosLewis committed Oct 28, 2022
1 parent b723186 commit e37895a
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3028,6 +3028,56 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
}


template <>
LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite(
    AtenSoftmaxIntOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Lowers torch.aten.softmax.int to a single tosa.custom op.
  // Math: exp(%x) / sum(exp(%x), %dim)
  // Torch format:
  //    "aten.softmax.int"(%x,%dim): (tensor<2x3xf32>, int) -> tensor<2x3xf32>
  // Decompose tosa format: with -torch-decompose-complex-ops flag
  //    https://gist.github.com/AmosLewis/e668c3bfd2472e9f9f045e012362d831
  //    %2 = "tosa.exp"(%x) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  //    %3 = "tosa.reduce_sum"(%2) {axis = %dim : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  //    %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
  //    %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  // No-Decompose TOSA format: without -torch-decompose-complex-ops flag
  //    "tosa.custom(%x){dim = 1 : i64, identifier = "softmax"}" : (tensor<2x3xf32>) -> tensor<2x3xf32>

  // The first input of aten.softmax.int must be a tensor.
  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types input are currently supported");

  // The reduction dimension must be a compile-time torch constant int so it
  // can be materialized as an i64 attribute on the tosa.custom op.
  int64_t dim;
  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
    return rewriter.notifyMatchFailure(
        op, "unimplemented: value `dim` should be a torch constant int");

  // Converted result type for the tosa::CustomOp.
  auto outType = getTypeConverter()->convertType(op.getType());

  // Attributes carried by the tosa::CustomOp,
  // example: {dim = 1 : i64, identifier = "softmax"}
  mlir::NamedAttribute nameAttr(rewriter.getStringAttr("identifier"),
                                rewriter.getStringAttr("softmax"));
  mlir::NamedAttribute dimAttr(rewriter.getStringAttr("dim"),
                               rewriter.getI64IntegerAttr(dim));
  // Use an owning SmallVector here. Brace-initializing an
  // llvm::ArrayRef<NamedAttribute> from {nameAttr, dimAttr} would bind the
  // ArrayRef to a temporary std::initializer_list whose backing array is
  // destroyed at the end of that statement, leaving a dangling reference by
  // the time the op is created below.
  llvm::SmallVector<mlir::NamedAttribute, 2> customAttrs{nameAttr, dimAttr};

  // TODO unportable target hardware implementation of exp(%x) / sum(exp(%x), %dim)
  rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, outType, adaptor.self(),
                                              customAttrs);

  return success();
}


template <>
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
AtenArangeStartStepOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -3854,6 +3904,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenCopyOp);
Expand Down

0 comments on commit e37895a

Please sign in to comment.