From dc37616d6773acc55c7452c242c7f13e838362f4 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Fri, 12 Jan 2024 19:11:14 -0800
Subject: [PATCH] [torch][quant] Support quantize and dequantize for torch (#2731)

Handle both `torch.dequantize` and `torch.quantize_per_tensor`, including
op-based tracking of the quantization parameters. This also adds `qint32`
to the torch types, as it was missing from the initial type inclusion.

For testing we only use `torch.int8` and `torch.float` types at function
boundaries, since the `qint8` types require passing the scale and
zero-point quantization information, which is not supported yet.
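For reference, the affine scheme these lowerings implement is sketched
below (an illustration only, not code from this change; the helper names
are made up, and the constants mirror the new e2e test):

```python
def quantize_per_tensor(x, scale, zero_point, qmin=-128, qmax=127):
    # aten.quantize_per_tensor: divide by the scale, round, add the
    # zero point, then clamp to the storage range (qint8 shown).
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))

def dequantize(q, scale, zero_point):
    # aten.dequantize: subtract the zero point in the integer domain,
    # then scale back to floating point.
    return (q - zero_point) * scale

# Round-tripping a representable value recovers it (up to rounding):
assert abs(dequantize(quantize_per_tensor(0.2, 0.04, -110), 0.04, -110) - 0.2) < 1e-9
```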
---
 .../Conversion/TorchToLinalg/Utils.h          |   2 +
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 120 ++++++++++++++
 .../torch-mlir/Dialect/Torch/IR/TorchTypes.td |  11 ++
 .../TorchToLinalg/Uncategorized.cpp           | 153 +++++++++++++++++-
 lib/Conversion/TorchToLinalg/Utils.cpp        |  17 ++
 lib/Dialect/Torch/IR/TorchTypes.cpp           |  12 +-
 .../Transforms/AbstractInterpLibrary.cpp      |  73 +++++++++
 lib/Dialect/Torch/Utils/Utils.cpp             |  12 ++
 .../base_lazy_backend/shape_inference.cpp     |   2 +-
 projects/pt1/e2e_testing/xfail_sets.py        |   6 +
 .../build_tools/abstract_interp_lib_gen.py    |  43 +++++
 .../build_tools/torch_ods_gen.py              |   7 +
 .../test_suite/elementwise.py                 |  46 ++++++
 13 files changed, 496 insertions(+), 8 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h
index 134fbeca46dc..7c9257075824 100644
--- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h
+++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h
@@ -95,6 +95,8 @@ FailureOr<Type> getBackendTypeForScalarType(MLIRContext *context,
                                             torch_upstream::ScalarType dtypeInt);
 
+bool isUnsignedTorchType(Type type);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 12a2bf4a86e2..9525f9f9ffa6 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -14206,6 +14206,126 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
+def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$scale,
+    Torch_IntType:$zero_point,
+    Torch_IntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenQuantizePerTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenQuantizePerTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
+def Torch_AtenDequantizeSelfOp : Torch_Op<"aten.dequantize.self", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::dequantize.self : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenDequantizeSelfOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenDequantizeSelfOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenDequantizeTensorOp : Torch_Op<"aten.dequantize.tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::dequantize.tensor : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$qtensor
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenDequantizeTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenDequantizeTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::int_repr : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIntReprOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIntReprOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$scale,
+    Torch_IntType:$zero_point
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_MakePerTensorQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void Aten_MakePerTensorQuantizedTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
index 1f7231b3500a..c3b5c1582c02 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -306,6 +306,17 @@ def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
   }];
 }
 
+def Torch_QInt32Type : Torch_Type<"QInt32", "qint32"> {
+  let summary = "Type modeling `ScalarType::QInt32`";
+  let description = [{
+    This is intended to be a 1:1 match for the Torch `ScalarType` types.
+
+    Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of
+    types, it is deemed preferable to import them as one-off ad-hoc types
+    instead of a single parameterized type.
+ }]; +} + def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> { let summary = "Torch packed linear params type"; let description = [{ diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 749945dee6e2..e35136e333f0 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1316,6 +1316,106 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, payloadArgs[0], allOnesVal); } + if (isa(op)) { + auto value = payloadArgs[0]; + auto valueTy = value.getType(); + auto qtensor = op->getOperand(0); + auto qtensorTy = qtensor.getType().cast().getDtype(); + auto makeQTensor = + qtensor.getDefiningOp(); + if (!makeQTensor) { + op->emitError( + "unimplemented: dequantizing tensor of unknown scale / zero-point"); + return nullptr; + } + + auto outFpTy = payloadArgs[1].getType(); + auto outBw = outFpTy.getIntOrFloatBitWidth(); + auto outIntTy = b.getIntegerType(outBw); + + if (valueTy != outIntTy) { + if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) { + value = b.create(loc, outIntTy, value); + } else { + value = b.create(loc, outIntTy, value); + } + } + + Value zp = makeQTensor.getZeroPoint(); + zp = converter->materializeTargetConversion( + b, loc, converter->convertType(zp.getType()), + makeQTensor.getZeroPoint()); + auto zpTy = zp.getType(); + + if (zpTy != outIntTy) { + zp = b.create(loc, outIntTy, zp); + } + + value = b.create(loc, value, zp); + + if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) { + value = b.create(loc, outFpTy, value); + } else { + value = b.create(loc, outFpTy, value); + } + + Value scale = makeQTensor.getScale(); + scale = converter->materializeTargetConversion( + b, loc, converter->convertType(scale.getType()), + makeQTensor.getScale()); + if (scale.getType() != value.getType()) { + scale = b.create(loc, value.getType(), scale); + } + value = b.create(loc, value, scale); + return value; + } + + if (auto quant = dyn_cast(op)) { + Value value = payloadArgs[0]; + Value scale = quant.getScale(); + Value zp = quant.getZeroPoint(); + auto valueTy = value.getType(); + + zp = converter->materializeTargetConversion( + b, loc, converter->convertType(zp.getType()), zp); + zp = b.create(loc, valueTy, zp); + + scale = converter->materializeTargetConversion( + b, loc, converter->convertType(scale.getType()), scale); + scale = b.create(loc, valueTy, scale); + + value = b.create(loc, value, scale); + value = b.create(loc, value); + value = b.create(loc, value, zp); + + auto destTy = payloadArgs[1].getType(); + auto bitwidth = destTy.getIntOrFloatBitWidth(); + bool isUnsigned = torch_to_linalg::isUnsignedTorchType(quant.getType()); + APInt min = isUnsigned ? APInt::getMinValue(bitwidth) + : APInt::getSignedMinValue(bitwidth); + APInt max = isUnsigned ? 
+    APInt min = isUnsigned ? APInt::getMinValue(bitwidth)
+                           : APInt::getSignedMinValue(bitwidth);
+    APInt max = isUnsigned ? APInt::getMaxValue(bitwidth)
+                           : APInt::getSignedMaxValue(bitwidth);
+
+    Value minVal = b.create<arith::ConstantOp>(
+        loc, b.getFloatAttr(valueTy, min.getSExtValue()));
+    Value maxVal = b.create<arith::ConstantOp>(
+        loc, b.getFloatAttr(valueTy, max.getSExtValue()));
+    Value minCmp =
+        b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, value, minVal);
+    Value maxCmp =
+        b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, value, maxVal);
+    value = b.create<arith::SelectOp>(loc, minCmp, minVal, value);
+    value = b.create<arith::SelectOp>(loc, maxCmp, maxVal, value);
+
+    if (isUnsigned) {
+      value = b.create<arith::FPToUIOp>(loc, destTy, value);
+    } else {
+      value = b.create<arith::FPToSIOp>(loc, destTy, value);
+    }
+
+    return value;
+  }
+
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
@@ -1368,9 +1468,10 @@ class ConvertElementwiseOp : public ConversionPattern {
             AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
             AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-            AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
-            AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
-            AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
+            AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
+            AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
+            AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp,
+            AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2080,6 +2181,42 @@ class ConvertLogitOp : public OpConversionPattern<AtenLogitOp> {
   }
 };
 } // namespace
+
+namespace {
+class ConvertAtenIntReprOp : public OpConversionPattern<AtenIntReprOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
+                                                adaptor.getSelf());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertMakePerTensorQuantizedTensorOp
+    : public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
+                                                adaptor.getSelf());
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -2102,9 +2239,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
       AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
       AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
-      AtenTrilOp,
-      AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-      AtenFillTensorOp, AtenRealOp, AtenImagOp>();
+      AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
+      AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp,
+      AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
@@ -2122,4 +2259,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add(typeConverter, context);
   patterns.add(typeConverter, context);
   target.addIllegalOp();
+  patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
+  target.addIllegalOp<AtenIntReprOp>();
+  patterns.add<ConvertMakePerTensorQuantizedTensorOp>(typeConverter, context);
+  target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
 }
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index ccc78985dc6c..77459aca3a60 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -559,3 +559,20 @@ FailureOr<Type> torch_to_linalg::getBackendTypeForScalarType(
   }
   return type;
 }
+
+bool torch_to_linalg::isUnsignedTorchType(Type type) {
+  if (auto tty = dyn_cast<ValueTensorType>(type))
+    return isUnsignedTorchType(tty.getDtype());
+  if (isa<mlir::FloatType>(type))
+    return false;
+  if (isa<QInt8Type>(type))
+    return false;
+  if (isa<QUInt8Type>(type))
+    return true;
+  if (isa<QInt32Type>(type))
+    return false;
+  if (auto intTy = dyn_cast<IntegerType>(type))
+    return intTy.isUnsigned();
+  llvm_unreachable("Unknown type checked for signedness");
+  return false;
+}
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index cf832b1b755e..33ef459081c4 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -184,7 +184,7 @@ static bool isValidTorchDtype(Type dtype) {
     dtype = dtype.cast<ComplexType>().getElementType();
   }
   // Torch quantized types.
-  if (dtype.isa<QInt8Type, QUInt8Type>())
+  if (dtype.isa<QInt8Type, QUInt8Type, QInt32Type>())
     return true;
   // Builtin floating point types.
   if (dtype.isa())
     return true;
@@ -410,6 +410,16 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
   } else if (dtype.isa()){
     return dtype;
   }
+
+  if (isa<QUInt8Type>(dtype))
+    return IntegerType::get(context, 8, IntegerType::Signless);
+
+  if (isa<QInt8Type>(dtype))
+    return IntegerType::get(context, 8, IntegerType::Signless);
+
+  if (isa<QInt32Type>(dtype))
+    return IntegerType::get(context, 32, IntegerType::Signless);
+
   emitError(UnknownLoc::get(context))
       << "unimplemented: conversion of dtype " << dtype
       << " to builtin tensor element type";
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 3901cd34a4aa..c286168080e4 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6481,6 +6481,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.dequantize.self\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.dequantize.tensor\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.int_repr\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -11783,6 +11803,59 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
+"    return %arg3 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.dequantize.self\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int6 = torch.constant.int 6\n"
+"    return %int6 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.dequantize.tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int6 = torch.constant.int 6\n"
+"    return %int6 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int3 = torch.constant.int 3\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int12 = torch.constant.int 12\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int13 = torch.constant.int 13\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int13 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int0 : !torch.int\n"
+"    } else {\n"
+"      %3 = torch.aten.eq.int %0#1, %int12 : !torch.int, !torch.int -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int1 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %int3 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %4 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n"
+"    %int14 = torch.constant.int 14\n"
+"    %int12 = torch.constant.int 12\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int13 = torch.constant.int 13\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int13 : !torch.int\n"
+"    } else {\n"
+"      %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int12 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %int14 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %4 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
 "}\n"
 "";
 // clang-format on
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 12ac1d58ee59..06330f16a57e 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -64,6 +64,12 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     return torch_upstream::ScalarType::Byte;
   if (type.isSignedInteger(8))
     return torch_upstream::ScalarType::Char;
+  if (type.isa<QUInt8Type>())
+    return torch_upstream::ScalarType::QUInt8;
+  if (type.isa<QInt8Type>())
+    return torch_upstream::ScalarType::QInt8;
+  if (type.isa<QInt32Type>())
+    return torch_upstream::ScalarType::QInt32;
   if (type.isa<ComplexType>()) {
     mlir::Type complexElemType = type.cast<ComplexType>().getElementType();
     if (complexElemType.isF16())
@@ -109,6 +115,12 @@ Torch::getTypeForScalarType(MLIRContext *context,
     return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned);
   case torch_upstream::ScalarType::Char:
     return mlir::IntegerType::get(context, 8, mlir::IntegerType::Signed);
+  case torch_upstream::ScalarType::QUInt8:
+    return QUInt8Type::get(context);
+  case torch_upstream::ScalarType::QInt8:
+    return QInt8Type::get(context);
+  case torch_upstream::ScalarType::QInt32:
+    return QInt32Type::get(context);
   case torch_upstream::ScalarType::ComplexHalf:
     return mlir::ComplexType::get(Float16Type::get(context));
   case torch_upstream::ScalarType::ComplexFloat:
diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
index 3971fdd3258a..15080f9764cc 100644
--- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
+++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
@@ -431,4 +431,4 @@ std::vector<torch::lazy::Shape> compute_shape_linspace(const at::Scalar & start,
 
 } // namespace lazy
-} // namespace torch
\ No newline at end of file
+} // namespace torch
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index de68680a82f2..e04657df4d2c 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -311,6 +311,10 @@
     # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
     "GroupNormModule_basic",
     "GroupNormNoWeightAndBiasModule_basic",
+
+    # Dynamo does not support tracing quantized tensors
+    "ElementwiseDequantizePerTensorModule_basic",
+    "ElementwiseQuantizePerTensorModule_basic",
 }
 
 TORCHDYNAMO_CRASHING_SET = {
@@ -1507,4 +1511,6 @@
     "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
+    "ElementwiseQuantizePerTensorModule_basic",
+    "ElementwiseDequantizePerTensorModule_basic"
 }
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 211023a9deec..6a8fbf34e911 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -221,6 +221,21 @@ def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]:
 def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
+def aten〇dequantize〇self〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
+def aten〇dequantize〇tensor〡shape(qtensor: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(qtensor)
+
+def aten〇int_repr〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
+def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]:
     return upstream_shape_functions.unary(a)
 
@@ -3958,6 +3973,34 @@ def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int
 
     return a_dtype
 
+def aten〇quantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int:
+    return dtype
+
+def aten〇dequantize〇self〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    return torch.float32
+
+def aten〇dequantize〇tensor〡dtype(qtensor_rank_dtype: Tuple[int, int]) -> int:
+    return torch.float32
+
+def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if (self_dtype == torch.quint8):
+        return torch.uint8
+    if (self_dtype == torch.qint8):
+        return torch.int8
+    return torch.int32
+
+def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if (self_dtype == torch.uint8):
+        return torch.quint8
+    if (self_dtype == torch.int8):
+        return torch.qint8
+    return torch.qint32
+
+
+
+
 # ==============================================================================
 # Main
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index a9f9ed96dce2..249c25628a82 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -805,6 +805,13 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)")
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
 
+    # quantized ops
+    emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)")
+    emit("aten::dequantize.self : (Tensor) -> (Tensor)")
+    emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
+    emit("aten::int_repr : (Tensor) -> (Tensor)")
+    emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
+
     # ==========================================================================
     # `prim::` namespace.
     # ==========================================================================
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 5d6217b59072..23a22142c4d5 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -4136,6 +4136,52 @@ def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ElementwiseQuantizePerTensorModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float, True),
+    ])
+    def forward(self, x):
+        scale = 0.04
+        zp = -110
+        dtype = torch.qint8
+        # Return the int representation, since we cannot yet map the qint8
+        # type across function boundaries.
+ q = torch.quantize_per_tensor(x, scale, zp, dtype).int_repr() + return q + +@register_test_case(module_factory=lambda: ElementwiseQuantizePerTensorModule()) +def ElementwiseQuantizePerTensorModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + +class ElementwiseDequantizePerTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8) + qx = torch.dequantize(qx) + return qx + +@register_test_case(module_factory=lambda: ElementwiseDequantizePerTensorModule()) +def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + class GluStaticModule(torch.nn.Module): def __init__(self): super().__init__()