Skip to content

Commit

Permalink
added more arguments to the transformer quantizer
Browse files Browse the repository at this point in the history
fixed assert statement in attentions.py to allow training with quantizer

PiperOrigin-RevId: 652968446
  • Loading branch information
The praxis Authors committed Jul 16, 2024
1 parent dafadc1 commit eb5473d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
2 changes: 1 addition & 1 deletion praxis/layers/quantization/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _get_weight_scale_shape(self, block_size, use_block_size):
elif axes is not None:
h_sharding += axes
wt = [h_sharding, wp.wt[2]]
assert len(self.wt) == 2
assert len(wt) == 2
else:
wt = wp.wt

Expand Down
40 changes: 40 additions & 0 deletions praxis/layers/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@ def for_transformer(
num_bits_act: int | None = None,
use_symmetric_act: bool | None = None,
skip_transformers: list[str] | None = None,
num_optimize_clipping: int | None = None,
min_clipping: float | None = None,
clipping_coeff: float = 1.0,
optimize_clipping_per_channel: bool = False,
vn_scale: float | None = None,
):
"""Find and quantize transformer.
Expand Down Expand Up @@ -459,6 +464,16 @@ def for_transformer(
weight_quant_only is false.
skip_transformers: If not None, will skip quantizing transformers with the
name in this list.
num_optimize_clipping: Number of optimization steps used for scale
estimation with search over clipping values in range [min_clipping ... 1].
min_clipping: Clipping value that will be used for clipping optimization in
    the range [min_clipping ... 1].
clipping_coeff: The coefficient to shrink the hard range for weight
quantization. 1.0 means using hard min/max.
optimize_clipping_per_channel: If True, choose the best clipping value per
    channel; otherwise choose it per tensor. Only effective when both
    min_clipping and num_optimize_clipping are set.
vn_scale: Scale coefficient for VN quantization.
Returns:
A modifier that quantizes transformers when applied to a config.
Expand Down Expand Up @@ -505,6 +520,11 @@ def task(self):
use_symmetric_act=use_symmetric_act,
num_bits_act=num_bits_act,
skip_transformers=skip_transformers,
num_optimize_clipping=num_optimize_clipping,
min_clipping=min_clipping,
clipping_coeff=clipping_coeff,
optimize_clipping_per_channel=optimize_clipping_per_channel,
vn_scale=vn_scale,
)
return task_p

Expand Down Expand Up @@ -715,6 +735,11 @@ def set_transformer_quantization(
num_bits_act: int | None = None,
use_symmetric_act: bool | None = None,
skip_transformers: list[str] | None = None,
num_optimize_clipping: int | None = None,
min_clipping: int | None = None,
clipping_coeff: float = 1.0,
optimize_clipping_per_channel: bool | None = None,
vn_scale: float | None = None,
):
"""Sets quantization params for TransformerLm or TransformerEncoderDecoder.
Expand Down Expand Up @@ -764,6 +789,16 @@ def set_transformer_quantization(
weight_quant_only is false.
skip_transformers: If not None, will skip quantizing transformers with the
name in this list.
num_optimize_clipping: Number of optimization steps used for scale
estimation with search over clipping values in range [min_clipping ... 1].
min_clipping: Clipping value that will be used for clipping optimization in
    the range [min_clipping ... 1].
clipping_coeff: The coefficient to shrink the hard range for weight
quantization. 1.0 means using hard min/max.
optimize_clipping_per_channel: If True, choose the best clipping value per
    channel; otherwise choose it per tensor. Only effective when both
    min_clipping and num_optimize_clipping are set.
vn_scale: Scale coefficient for VN quantization.
"""
weight_quantization_params = WeightQuantizationParams(
precision=num_bits,
Expand All @@ -773,6 +808,11 @@ def set_transformer_quantization(
use_int4_packed_weights=use_int4_packed_weights,
int4_packed_weights_container_dtype=int4_packed_weights_container_dtype,
# Pass internal quantization parameters.
num_optimize_clipping=num_optimize_clipping,
min_clipping=min_clipping,
clipping_coeff=clipping_coeff,
optimize_clipping_per_channel=optimize_clipping_per_channel,
vn_scale=vn_scale,
)
act_quantization_params = None
if (
Expand Down

0 comments on commit eb5473d

Please sign in to comment.