[Optimizer] DQ + MatMul to MatMulNBits support #21180

Closed
wants to merge 36 commits

Changes from all commits

Commits (36)
d7d08f7
adding python interface
fajin-corp May 29, 2024
1c687d1
restarting
fajin-corp Jun 17, 2024
7612401
finished default quantizer
fajin-corp Jun 18, 2024
c5b6175
only enable quant_format in default quantizer
fajin-corp Jun 18, 2024
60c3cf8
added UT and fixed qtype in DQ
fajin-corp Jun 19, 2024
77b7dd9
added dq matmul selectors
fajin-corp Jun 21, 2024
1685070
added selector checks, and init action
fajin-corp Jun 21, 2024
82419a3
finished attribute insertion
fajin-corp Jun 21, 2024
d5d9e61
finished initializer transpose and append to replacement node
fajin-corp Jun 21, 2024
b2548da
change target name generation
fajin-corp Jun 21, 2024
3aa704c
added selector and action to qdq selector transformer
fajin-corp Jun 22, 2024
2a10834
fixed linting
fajin-corp Jun 22, 2024
b516487
fixed building UT
fajin-corp Jun 24, 2024
3fe19cb
fixed non-convert ut
fajin-corp Jun 25, 2024
67542a5
fixed action calling transpose
fajin-corp Jun 26, 2024
95de135
finished changing quantize
fajin-corp Jun 26, 2024
0ad7fe4
finished modifying transpose
fajin-corp Jun 26, 2024
f81d7bb
updated mlas kernel calling
fajin-corp Jun 26, 2024
8e8b314
fixed mlas scale calc bug
fajin-corp Jun 26, 2024
5729762
passed UT
fajin-corp Jun 26, 2024
c4805c8
fixed python build
fajin-corp Jun 26, 2024
53635fa
fixing ci
fajin-corp Jun 26, 2024
425f61b
fixing ci
fajin-corp Jun 26, 2024
b188f55
fixing minimal build
fajin-corp Jun 26, 2024
1d34c27
fixing ci
fajin-corp Jun 26, 2024
75bc4c5
change dq matmul tool chain interface for genAI
fajin-corp Jun 28, 2024
f8773d9
pass accuracy from session.config
fajin-corp Jun 28, 2024
d5b032e
added UT for accuracy level in session config options
fajin-corp Jun 28, 2024
2d01a38
fixed missing transformer path
fajin-corp Jun 28, 2024
c17a9cf
resolved comments
fajin-corp Jul 6, 2024
8c2a121
fix build
fajin-corp Jul 8, 2024
baa9389
fix ut
fajin-corp Jul 8, 2024
f33dbf9
fixing arm ut
fajin-corp Jul 8, 2024
18e0000
fix linting
fajin-corp Jul 8, 2024
1ecf5c5
corrected UT semantics
fajin-corp Jul 9, 2024
f95c3d6
try to fix web ci failure
fajin-corp Jul 12, 2024
7 changes: 5 additions & 2 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -10,6 +10,7 @@
#include "core/common/inlined_containers.h"
#include "core/framework/session_options.h"
#include "core/optimizer/graph_transformer.h"
#include "core/platform/threadpool.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "core/optimizer/rule_based_graph_transformer.h"
@@ -49,7 +50,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
-const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
+const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
+concurrency::ThreadPool* intra_op_thread_pool = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD)

@@ -78,7 +80,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
-const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
+const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
+concurrency::ThreadPool* intra_op_thread_pool = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

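For illustration, here is a sketch of how a caller might pass the new optional intra_op_thread_pool argument when generating level-2 transformers. It is not the actual session wiring from this PR; the helper function is hypothetical, the optimizer_utils namespace is an assumption, and session_options, cpu_execution_provider, and the thread pool are expected to come from the calling scope.

#include "core/optimizer/graph_transformer_utils.h"

namespace example {

// Hypothetical helper: builds the level-2 transformer list, forwarding the
// session's intra-op thread pool so transformers that can use it (such as the
// DQ + MatMul -> MatMulNBits fusion) are able to parallelize their work.
inline auto BuildLevel2Transformers(
    const onnxruntime::SessionOptions& session_options,
    const onnxruntime::IExecutionProvider& cpu_execution_provider,
    onnxruntime::concurrency::ThreadPool* intra_op_thread_pool) {
  return onnxruntime::optimizer_utils::GenerateTransformers(
      onnxruntime::TransformerLevel::Level2,
      session_options,
      cpu_execution_provider,
      /*rules_and_transformers_to_disable=*/{},
      intra_op_thread_pool);  // new parameter; the nullptr default keeps the old behavior
}

}  // namespace example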
@@ -270,3 +270,8 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

+// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
+// Refer to MatMulNBits op schema for more details.
+// If not provided, default is 4.
+static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
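Session config entries such as the key above are plain string key/value pairs, so a minimal sketch of requesting a specific accuracy level from application code could look like the following. The model path and the chosen level "4" are placeholders; AddConfigEntry is the standard C++ API call for session config strings.

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qdq_matmulnbits_example");
  Ort::SessionOptions session_options;

  // Ask the DQ + MatMul -> MatMulNBits fusion to emit MatMulNBits nodes with
  // accuracy_level = 4 (also the documented default when the option is not set).
  session_options.AddConfigEntry("session.qdq_matmulnbits_accuracy_level", "4");

  // "model.onnx" is a placeholder; on Windows the path argument is wide-char.
  Ort::Session session(env, "model.onnx", session_options);
  return 0;
}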
26 changes: 18 additions & 8 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -360,12 +360,12 @@ MlasDequantizeBlockwise(
);

/**
-* @brief Blockwise 2 bits or 4 bits quantization. After quantization, the weights and zero points
-* are packed row-wise. In terms of the qbits type, dst and src have the same shape, and
-* scales and zero_points have the same shape.
-* columns must be multiple of 8 / qbits.
+* @brief Blockwise 4 bits quantization. After quantization, the weights and zero points
+* are packed row-wise. If zero_points is null, quantized type is int4 with default
+* zero point 0, to align with DQ schema. Otherwise, quantized type is uint4.
+* In int4/uint4, dst have the same shape as src, and zero_points have the same shape as scales.
* @tparam Tin
-* @tparam qbits number of bits used for quantization, 2 or 4
+* @tparam qbits number of bits used for quantization, only 4 is supported
* @param src points to the floating point matrix, to be quantized, row major shape [rows, columns]
* @param scales points to the scales matrix, row major
* @param zero_points points to the zero_points matrix, row major
@@ -376,9 +376,10 @@ MlasDequantizeBlockwise(
* @param columns
* @param quant_block_size number of elements in a quantize block
* @param thread_pool
+* @return the quantized type is signed.
*/
template <typename Tin, int qbits>
-void
+bool
MlasQDQQuantizeBlockwise(
const Tin* src,
Tin* scales,
@@ -395,8 +396,17 @@ MlasQDQQuantizeBlockwise(
* @brief Transpose blockwise quantized tensors. The src tensors are row major. src weights and zero
* points are packed row-wise. The dst tensors are column major. dst weights and zero points
* are packed column-wise.
+* dst_weights and dst_zero_points are in uint4.
+* If src_weights is int4 and has src_zero_points, src_weights and src_zero_points are
+* converted to uint4 by adding 8.
+* If src_weights is int4 and no src_zero_points, src_weights is converted to uint4 by adding 8.
+* src_zero_points is 0 and dst_zero_points is 8.
+* If src_weights is uint4 and has src_zero_points, just transpose.
+* If src_weights is uint4 and no src_zero_points, caller must allocate dst_zero_points with
+* 0 values. Otherwise exception is thrown.
* @tparam Tin
-* @tparam qbits number of bits used for quantization, 2 or 4
+* @tparam qbits number of bits used for quantization, only 4 is supported
+* @tparam signed_quant true when quantized type is signed, false when quantized type is unsigned
* @param src_weights points to the quantized matrix, row major, shape [rows, columns] in qbits type.
* In uint8_t type, shape is [rows, columns * qbits / 8].
* @param src_scales points to the scales matrix, row major
@@ -410,7 +420,7 @@ MlasQDQQuantizeBlockwise(
* @param quant_block_size number of elements in a quantize block
* @param thread_pool
*/
-template <typename Tin, int qbits>
+template <typename Tin, int qbits, bool signed_quant>
void
MlasQDQTransposeBlockwiseQuantized(
const uint8_t* src_weights,
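To make the int4/uint4 conventions in the comments above concrete, here is a small standalone sketch (not the MLAS implementation; it ignores the 4-bit packing and block layout): without zero points, a block is quantized symmetrically to signed int4 with an implicit zero point of 0, and converting such a block to the unsigned uint4 form used by MatMulNBits amounts to adding 8 to every quantized value and recording 8 as the zero point, which leaves the dequantized result unchanged.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Quantize one block of floats to signed int4 values in [-8, 7] with zero point 0,
// mirroring the "zero_points == nullptr" path described above. One unpacked int8_t
// per weight is used here purely for readability.
void QuantizeBlockInt4Symmetric(const std::vector<float>& block,
                                float& scale,
                                std::vector<int8_t>& quantized) {
  float absmax = 0.0f;
  for (float v : block) absmax = std::max(absmax, std::fabs(v));
  scale = absmax / 7.0f;  // one common convention; keeps every value inside [-7, 7]
  quantized.resize(block.size());
  for (std::size_t i = 0; i < block.size(); ++i) {
    const int q = (scale == 0.0f) ? 0 : static_cast<int>(std::lroundf(block[i] / scale));
    quantized[i] = static_cast<int8_t>(std::clamp(q, -8, 7));
  }
}

// Convert signed int4 values to the unsigned uint4 convention by adding 8.
// Dequantization is unaffected because the zero point becomes 8:
//   scale * (q_signed - 0) == scale * ((q_signed + 8) - 8)
void Int4ToUint4(const std::vector<int8_t>& q_signed,
                 std::vector<uint8_t>& q_unsigned,
                 uint8_t& zero_point) {
  zero_point = 8;
  q_unsigned.resize(q_signed.size());
  for (std::size_t i = 0; i < q_signed.size(); ++i) {
    q_unsigned[i] = static_cast<uint8_t>(q_signed[i] + 8);
  }
}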