Add int16 quantization for embedding (PaddlePaddle#857)
juncaipeng authored Jul 26, 2021
1 parent 4d75cb9 commit c9c0e83
Showing 4 changed files with 80 additions and 16 deletions.
21 changes: 18 additions & 3 deletions docs/zh_cn/api_cn/static/quant/quantization_api.rst
@@ -461,12 +461,27 @@ fluid.Program
 exe = fluid.Executor(place)
 exe.run(fluid.default_startup_program())
+# Quantize to 8 bits: the Embedding parameters shrink 4x, with a slight accuracy loss.
 config = {
     'quantize_op_types': ['lookup_table'],
     'lookup_table': {
-        'quantize_type': 'abs_max'
+        'quantize_type': 'abs_max',
+        'quantize_bits': 8,
+        'dtype': 'int8'
     }
 }
+'''
+# Quantize to 16 bits: the Embedding parameters shrink 2x, with very little accuracy loss.
+config = {
+    'quantize_op_types': ['lookup_table'],
+    'lookup_table': {
+        'quantize_type': 'abs_max',
+        'quantize_bits': 16,
+        'dtype': 'int16'
+    }
+}
+'''
 quant_program = quant.quant_embedding(infer_program, place, config)

 For more detailed usage, see the `Embedding quantization demo <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_embedding>`_
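For context, a minimal end-to-end sketch of how the new 16-bit option is meant to be used on a loaded inference model (the model directory and the `fluid.io.load_inference_model` loading step are illustrative assumptions, not part of this commit):

```python
import paddle.fluid as fluid
import paddleslim.quant as quant

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load a previously saved inference model (the directory name is hypothetical).
[infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    dirname='./infer_model', executor=exe)

# Quantize Embedding weights to 16-bit integers: 2x smaller, very little accuracy loss.
config = {
    'quantize_op_types': ['lookup_table'],
    'lookup_table': {
        'quantize_type': 'abs_max',
        'quantize_bits': 16,
        'dtype': 'int16'
    }
}
quant_program = quant.quant_embedding(infer_program, place, config)
```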
23 changes: 20 additions & 3 deletions docs/zh_cn/tutorials/quant/static/embedding_quant_tutorial.md
@@ -1,8 +1,9 @@
 # Embedding Quantization

-Embedding quantization quantizes a network's Embedding parameters from `float32` to `8-bit` integers, reducing model storage and GPU memory usage with almost no loss of model accuracy.
+Embedding quantization quantizes a network's Embedding parameters from `float32` to `8-bit` or `16-bit` integers, reducing model storage and GPU memory usage with almost no loss of model accuracy.

-Embedding quantization only reduces the size of the model parameters and speeds up loading them; it does not significantly speed up model inference.
+Embedding quantization only reduces the size of the model parameters; it does not significantly speed up model inference.

 ## Usage

 At inference time, call the paddleslim `quant_embedding` API; the main code is as follows:
@@ -29,12 +30,28 @@ place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
 exe = fluid.Executor(place)
 exe.run(fluid.default_startup_program())

+# Quantize to 8 bits: the Embedding parameters shrink 4x, with a slight accuracy loss.
 config = {
     'quantize_op_types': ['lookup_table'],
     'lookup_table': {
-        'quantize_type': 'abs_max'
+        'quantize_type': 'abs_max',
+        'quantize_bits': 8,
+        'dtype': 'int8'
     }
 }
+
+'''
+# Quantize to 16 bits: the Embedding parameters shrink 2x, with very little accuracy loss.
+config = {
+    'quantize_op_types': ['lookup_table'],
+    'lookup_table': {
+        'quantize_type': 'abs_max',
+        'quantize_bits': 16,
+        'dtype': 'int16'
+    }
+}
+'''

 quant_program = quant.quant_embedding(infer_program, place, config)
 ```
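To make the accuracy comments above concrete, here is a standalone NumPy sketch of the abs_max scheme (an illustration of the idea, not PaddleSlim's internal code): a `b`-bit quantizer maps weights onto the integer range `[-(2^(b-1)-1), 2^(b-1)-1]`, so 16 bits gives roughly 256x finer resolution than 8 bits.

```python
import numpy as np

def abs_max_quant_dequant(weights, bits):
    """Quantize to signed `bits`-bit integers with abs_max scaling, then dequantize."""
    qmax = (1 << (bits - 1)) - 1          # 127 for int8, 32767 for int16
    scale = np.abs(weights).max()
    quantized = np.round(weights / scale * qmax)
    return quantized / qmax * scale       # reconstructed weights

weights = np.random.uniform(-0.005, 0.005, (100, 128)).astype(np.float32)
for bits in (8, 16):
    err = np.abs(abs_max_quant_dequant(weights, bits) - weights).max()
    print(f'{bits:2d}-bit max abs error: {err:.2e}')
```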

18 changes: 13 additions & 5 deletions paddleslim/quant/quant_embedding.py
@@ -35,10 +35,13 @@
"quantize_bits": 8,
"dtype": "int8"
}
SUPPORT_OP_TYPES = ['lookup_table', 'fused_embedding_seq_pool', 'pyramid_hash']
SUPPORT_OP_TYPES = [
'lookup_table', 'lookup_table_v2', 'fused_embedding_seq_pool',
'pyramid_hash'
]
SUPPORT_QUANTIZE_TYPES = ['abs_max', 'log']
SUPPORT_QUANTIZE_BITS = [8]
SUPPORT_DTYPE = ['int8']
SUPPORT_QUANTIZE_BITS = [8, 16]
SUPPORT_DTYPE = ['int8', 'int16']

_default_config = {"quantize_op_types": SUPPORT_OP_TYPES, }

@@ -125,7 +128,7 @@ def _get_quant_var_name(var_name):
"""
get quantized var name
"""
return var_name + '.int8'
return var_name + '.int'


def _get_dequant_var_name(var_name):
@@ -151,6 +154,11 @@ def _clear_var(var_name, scope):
     tensor._clear()


+def _get_var_dtype(config):
+    return core.VarDesc.VarType.INT8 if config['dtype'] == 'int8' \
+        else core.VarDesc.VarType.INT16

 def _quant_embedding_abs_max(graph, scope, place, config, var_name,
                              embedding_node):
     """
@@ -230,7 +238,7 @@ def _clip_array(array, config):
         _get_quant_var_name(var_name),
         var_type=embedding_node.type(),
         shape=embedding_node.shape(),
-        var_dtype=core.VarDesc.VarType.INT8)
+        var_dtype=_get_var_dtype(config))
     # create var in scope
     scope.var(_get_quant_var_name(var_name))
     scope.var(_get_scale_var_name(var_name))
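The switch from a hard-coded `INT8` dtype to `_get_var_dtype(config)` is what the storage savings rest on. A quick back-of-envelope sketch (the vocabulary and embedding sizes are made-up numbers):

```python
import numpy as np

# Hypothetical embedding table: 100,000 tokens x 128 dims.
vocab, dim = 100_000, 128
for dtype in ('float32', 'int8', 'int16'):
    mb = vocab * dim * np.dtype(dtype).itemsize / 1e6
    print(f'{dtype:>7}: {mb:5.1f} MB')
# float32:  51.2 MB
#    int8:  12.8 MB  (4x smaller, plus a small float32 scale variable)
#   int16:  25.6 MB  (2x smaller)
```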
34 changes: 29 additions & 5 deletions tests/test_quant_embedding.py
@@ -21,29 +21,53 @@


 class TestQuantEmbedding(StaticCase):
+    def set_config(self):
+        self.config = {
+            'quantize_op_types': ['lookup_table_v2'],
+            'lookup_table': {
+                'quantize_type': 'abs_max',
+                'quantize_bits': 8,
+                'dtype': 'int8'
+            }
+        }
+
     def test_quant_embedding(self):
+        self.set_config()
+
         train_program = paddle.static.Program()
-        with paddle.static.program_guard(train_program):
+        startup_program = paddle.static.Program()
+        with paddle.static.program_guard(train_program, startup_program):
             input_word = paddle.static.data(
                 name="input_word", shape=[None, 1], dtype='int64')
             param_attr = paddle.ParamAttr(
                 name='emb',
                 initializer=paddle.nn.initializer.Uniform(-0.005, 0.005))
-            weight = train_program.global_block().create_parameter(
+            weight = paddle.static.create_parameter(
                 (100, 128), attr=param_attr, dtype="float32")

             input_emb = paddle.nn.functional.embedding(
                 x=input_word, weight=weight, sparse=True)

         infer_program = train_program.clone(for_test=True)

-        use_gpu = True
-        place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
+        place = paddle.CPUPlace()
         exe = paddle.static.Executor(place)
-        exe.run(paddle.static.default_startup_program())
+        exe.run(startup_program)

         quant_program = quant.quant_embedding(infer_program, place)


+class TestQuantEmbeddingInt16(TestQuantEmbedding):
+    def set_config(self):
+        self.config = {
+            'quantize_op_types': ['lookup_table'],
+            'lookup_table': {
+                'quantize_type': 'abs_max',
+                'quantize_bits': 16,
+                'dtype': 'int16'
+            }
+        }


 if __name__ == '__main__':
     unittest.main()
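To run just the new 16-bit case locally, something like the following should work (the module path assumes you are at the PaddleSlim repo root; adjust as needed):

```python
import unittest
from tests.test_quant_embedding import TestQuantEmbeddingInt16

suite = unittest.TestLoader().loadTestsFromTestCase(TestQuantEmbeddingInt16)
unittest.TextTestRunner(verbosity=2).run(suite)
```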
