
Commit

manually rebase
Signed-off-by: xin3he <[email protected]>
xin3he committed Jan 18, 2024
1 parent dd8511f commit 59be57e
Showing 29 changed files with 1,044 additions and 1,235 deletions.
@@ -230,8 +230,8 @@ def get_user_model():

# 3.x api
if args.approach == 'weight_only':
from neural_compressor.torch import RTNWeightQuantConfig, GPTQConfig, quantize
from neural_compressor.torch.utils.utility import get_double_quant_config
from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize
from neural_compressor.torch.utils import get_double_quant_config
weight_sym = True if args.woq_scheme == "sym" else False
double_quant_config_dict = get_double_quant_config(args.double_quant_type, weight_sym=weight_sym)

@@ -243,9 +243,9 @@ def get_user_model():
"enable_mse_search": args.woq_enable_mse_search,
}
)
quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict)
quant_config = RTNConfig.from_dict(double_quant_config_dict)
else:
quant_config = RTNWeightQuantConfig(
quant_config = RTNConfig(
weight_dtype=args.woq_dtype,
weight_bits=args.woq_bits,
weight_group_size=args.woq_group_size,
@@ -257,7 +257,7 @@ def get_user_model():
double_quant_sym=args.double_quant_sym,
double_quant_group_size=args.double_quant_group_size,
)
quant_config.set_local("lm_head", RTNWeightQuantConfig(weight_dtype="fp32"))
quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
user_model = quantize(
model=user_model, quant_config=quant_config
)
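
For context, a minimal sketch of the 3.x weight-only flow the example script switches to above. The import path, `RTNConfig` fields, `set_local`, and `quantize` call are taken from the hunk; the toy model and the specific values are illustrative only:

```python
import torch
from neural_compressor.torch.quantization import RTNConfig, quantize

class TinyLM(torch.nn.Module):
    """Hypothetical stand-in for the example's user_model."""
    def __init__(self):
        super().__init__()
        self.body = torch.nn.Linear(64, 64)
        self.lm_head = torch.nn.Linear(64, 8)
    def forward(self, x):
        return self.lm_head(self.body(x))

user_model = TinyLM()

# Global RTN weight-only config; keep lm_head in fp32, as the example script does.
quant_config = RTNConfig(weight_bits=4, weight_group_size=32)
quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))

user_model = quantize(model=user_model, quant_config=quant_config)
```
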
43 changes: 22 additions & 21 deletions neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py
@@ -127,23 +127,20 @@ def __init__(
dtype=self.float_type,
).to(device),
)
self.scales = self.scales.T
self.register_buffer(
"qweight",
torch.zeros(
(math.ceil(in_features / self.n_pack), out_features),
dtype=self.compression_dtype,
).to(device),
)
self.qweight = self.qweight.T
self.register_buffer(
"qzeros",
torch.zeros(
(math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
dtype=self.compression_dtype,
).to(device),
)
self.qzeros = self.qzeros.T
self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
else:
self.compression_dtype = compression_dtype
@@ -193,6 +190,10 @@ def __init__(
self.bias = None

def pack(self, int_weight, scale, zp, bias):
if self.use_optimum_format:
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
@@ -206,8 +207,8 @@ def pack(self, int_weight, scale, zp, bias):
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_optimum_format and self.compression_dim == 0:
int_weight = int_weight.T
self.qweight = self.qweight.T
int_weight = int_weight.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
origin_shape = int_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -223,15 +224,15 @@ def pack(self, int_weight, scale, zp, bias):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qweight[:, j] |= tmp[:, e]
if not self.use_optimum_format and self.compression_dim == 0:
self.qweight = self.qweight.T
self.qweight = self.qweight.t_().contiguous()

if zp is not None:
zp = zp.to(self.device)
if self.use_optimum_format:
zp -= 1
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
self.qzeros = self.qzeros.T
zp = zp.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
@@ -243,16 +244,16 @@ def pack(self, int_weight, scale, zp, bias):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qzeros[:, j] |= tmp[:, e]
if self.use_optimum_format or self.compression_dim == 0:
self.qzeros = self.qzeros.T
self.qzeros = self.qzeros.t_().contiguous()
if self.use_optimum_format:
self.scales = self.scales.T
self.qweight = self.qweight.T
self.qzeros = self.qzeros.T
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()

def recover(self):
logger.debug(f"Recovering {self} weight")
scales = self.scales.T if self.use_optimum_format else self.scales
qweight = self.qweight.T if self.use_optimum_format else self.qweight
scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight

device = scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -264,8 +265,8 @@ def recover(self):
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.T
qweight = qweight.T
weight = weight.t_().contiguous()
qweight = qweight.t_().contiguous()
origin_shape = weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
@@ -280,7 +281,7 @@ def recover(self):
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.T
weight = weight.t_().contiguous()
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
@@ -290,10 +291,10 @@ def recover(self):
if hasattr(self, "qzeros"):
zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
qzeros = qzeros.T
zp = zp.t_().contiguous()
qzeros = qzeros.t_().contiguous()
origin_shape = zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
Expand All @@ -307,7 +308,7 @@ def recover(self):
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
zp = zp.t_().contiguous()
if self.use_optimum_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
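
The recurring change from `.T` to `.t_().contiguous()` in `pack()` and `recover()` above replaces a lazy transposed view with an in-place transpose followed by a contiguous copy. A standalone illustration of the difference (not code from the patch):

```python
import torch

a = torch.zeros(4, 8, dtype=torch.int32)
b = torch.zeros(4, 8, dtype=torch.int32)

ta = a.T                   # view: no copy, strides swapped, shares a's storage
tb = b.t_().contiguous()   # transposes b in place, then materializes a dense copy

print(ta.is_contiguous())  # False - later ops see a strided view of a's storage
print(tb.is_contiguous())  # True  - the packed data now lives in its own dense buffer
```
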
42 changes: 21 additions & 21 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -327,9 +327,9 @@ def __init__(

def pack(self, int_weight, scale, zp, bias, g_idx=None):
if self.use_optimum_format:
self.scales = self.scales.T
self.qweight = self.qweight.T
self.qzeros = self.qzeros.T
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
@@ -350,8 +350,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_optimum_format and self.compression_dim == 0:
int_weight = int_weight.T
self.qweight = self.qweight.T
int_weight = int_weight.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
origin_shape = int_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -367,15 +367,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qweight[:, j] |= tmp[:, e]
if not self.use_optimum_format and self.compression_dim == 0:
self.qweight = self.qweight.T
self.qweight = self.qweight.t_().contiguous()

if zp is not None:
zp = zp.to(self.device)
if self.use_optimum_format:
zp -= 1
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
self.qzeros = self.qzeros.T
zp = zp.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
@@ -387,16 +387,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qzeros[:, j] |= tmp[:, e]
if self.use_optimum_format or self.compression_dim == 0:
self.qzeros = self.qzeros.T
self.qzeros = self.qzeros.t_().contiguous()
if self.use_optimum_format:
self.scales = self.scales.T
self.qweight = self.qweight.T
self.qzeros = self.qzeros.T
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()

def recover(self):
logger.debug(f"Recovering {self} weight")
scales = self.scales.T if self.use_optimum_format else self.scales
qweight = self.qweight.T if self.use_optimum_format else self.qweight
scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight

device = scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -411,8 +411,8 @@ def recover(self):
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.T
qweight = qweight.T
weight = weight.t_().contiguous()
qweight = qweight.t_().contiguous()
origin_shape = weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
@@ -427,7 +427,7 @@ def recover(self):
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.T
weight = weight.t_().contiguous()
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
@@ -437,10 +437,10 @@ def recover(self):
if hasattr(self, "qzeros"):
zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
qzeros = qzeros.T
zp = zp.t_().contiguous()
qzeros = qzeros.t_().contiguous()
origin_shape = zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
@@ -454,7 +454,7 @@ def recover(self):
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
zp = zp.t_().contiguous()
if self.use_optimum_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
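
Both `pack()`/`recover()` hunks above implement the same bit-packing scheme: `n_pack` low-bit integers are shifted and OR-ed into one element of the compression dtype, and recovery shifts them back out and masks off the sign extension. A self-contained toy version of that loop (simplified shapes and names, not the `WeightOnlyLinear` code itself):

```python
import torch

bits, n_pack = 4, 8  # eight 4-bit values per int32 element
int_weight = torch.randint(0, 2**bits, (2, 8), dtype=torch.int32)
qweight = torch.zeros(2, 1, dtype=torch.int32)
mask = torch.tensor(2**bits - 1, dtype=torch.int32)

# pack: shift each 4-bit value into its slot and OR it into the packed element
for e in range(n_pack):
    qweight[:, 0] |= (int_weight[:, e] & mask) << (bits * e)

# unpack, mirroring recover(): shift right, then mask off the sign extension
recovered = torch.stack(
    [(qweight[:, 0] >> (bits * e)) & mask for e in range(n_pack)], dim=1
)
assert torch.equal(recovered, int_weight)
```
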
20 changes: 10 additions & 10 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -429,7 +429,7 @@ def rtn_quantize(
if num_bits <= 0:
logger.info(f"Skip {name}")
continue
weight = m.weight.T if group_dim == 0 else m.weight
weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
if enable_mse_search:
quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
if return_int:
@@ -447,8 +447,8 @@
)
if group_dim == 0:
weight.transpose_(0, 1)
scale = scale.T if group_dim == 0 else scale
zp = zp.T if group_dim == 0 and zp is not None else zp
scale = scale.t_().contiguous() if group_dim == 0 else scale
zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
new_module = WeightOnlyLinear(
m.in_features,
m.out_features,
@@ -651,18 +651,18 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):
if zp is not None:
zp = zp.to(device)
if group_size == -1:
return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp)
return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_()
int_weight = torch.zeros(weight.shape).to(device)
leng = weight.shape[1] // group_size
tail_flag = False if weight.shape[1] % group_size == 0 else True
for i in range(leng):
int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1)
int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
if zp is not None:
int_weight_tmp += zp[:, i].unsqueeze(1)
int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp)
int_weight_tmp.add_(zp[:, i].unsqueeze(1))
int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_())
if tail_flag:
int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1)
int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1))
if zp is not None:
int_weight_tmp += zp[:, -1].unsqueeze(1)
int_weight[:, leng * group_size :] = torch.round(int_weight_tmp)
int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
return int_weight
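
The `quant_weight_w_scale` hunk above replaces out-of-place `torch.round(weight / scale)` with chained in-place ops (`div_`, `add_`, `round_`, `copy_`), so no full-size temporary is allocated; the column slices are views, so the caller's weight buffer is reused as scratch space. A simplified restatement, assuming the number of columns is a multiple of `group_size` (the real helper also handles a tail group):

```python
import torch

def quant_weight_w_scale_sketch(weight, scale, zp=None, group_size=-1):
    """Round weight to integers group by group using in-place ops (illustrative)."""
    if group_size == -1:
        return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_()
    int_weight = torch.zeros_like(weight)
    for i in range(weight.shape[1] // group_size):
        # div_/add_ mutate the slice in place, i.e. the original weight storage
        block = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
        if zp is not None:
            block.add_(zp[:, i].unsqueeze(1))
        int_weight[:, i * group_size : (i + 1) * group_size].copy_(block.round_())
    return int_weight

w = torch.randn(2, 8)
s = w.reshape(2, 2, 4).abs().amax(dim=-1) / 7  # per-group scale for 4-bit symmetric
q = quant_weight_w_scale_sketch(w.clone(), s, group_size=4)
```
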
4 changes: 2 additions & 2 deletions neural_compressor/common/base_config.py
@@ -180,7 +180,7 @@ def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
self.local_config[operator_name] = config
return self

def to_dict(self, params_list=[], operator2str=None):
def to_dict(self):
result = {}
global_config = self.get_params_dict()
if bool(self.local_config):
@@ -200,7 +200,7 @@ def get_params_dict(self):
return result

@classmethod
def from_dict(cls, config_dict, str2operator=None):
def from_dict(cls, config_dict):
"""Construct config from a dict.
Args:
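
With the simplified signatures above (`to_dict()` and `from_dict(config_dict)` no longer take `params_list`/`operator2str`/`str2operator`), a config round-trip reduces to the following sketch (import path per the diff; field values illustrative):

```python
from neural_compressor.torch.quantization import RTNConfig

cfg = RTNConfig(weight_bits=4, weight_group_size=32)
restored = RTNConfig.from_dict(cfg.to_dict())
```
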
2 changes: 1 addition & 1 deletion neural_compressor/common/utility.py
@@ -27,7 +27,7 @@
# config name
BASE_CONFIG = "base_config"
COMPOSABLE_CONFIG = "composable_config"
RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
RTN = "rtn"
STATIC_QUANT = "static_quant"
GPTQ = "gptq"
FP8_QUANT = "fp8_quant"
2 changes: 1 addition & 1 deletion neural_compressor/tensorflow/utils.py
@@ -35,7 +35,7 @@ def register_algo(name):
Usage example:
@register_algo(name=example_algo)
def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
...
Args:
name (str): The name under which the algorithm function will be registered.
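
The docstring touched above describes the `register_algo` decorator. A generic sketch of that registry pattern, with illustrative names (the real registry and dispatch live inside neural_compressor):

```python
ALGO_REGISTRY = {}

def register_algo(name):
    """Register a quantization entry function under `name`."""
    def decorator(fn):
        ALGO_REGISTRY[name] = fn
        return fn
    return decorator

@register_algo(name="rtn")
def rtn_entry(model, quant_config):
    # the framework would later look this entry up via ALGO_REGISTRY["rtn"]
    return model
```
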
13 changes: 0 additions & 13 deletions neural_compressor/torch/__init__.py
@@ -11,16 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from neural_compressor.torch.utils.utility import register_algo
from neural_compressor.torch.algorithms import rtn_quantize_entry, gptq_quantize_entry

from neural_compressor.torch.quantization import (
quantize,
RTNWeightQuantConfig,
get_default_rtn_config,
GPTQConfig,
get_default_gptq_config,
)

from neural_compressor.torch.tune import autotune, TuningConfig, get_default_tune_config
6 changes: 4 additions & 2 deletions neural_compressor/torch/algorithms/__init__.py
@@ -13,5 +13,7 @@
# limitations under the License.


from neural_compressor.torch.algorithms.weight_only_algos import rtn_quantize_entry
from neural_compressor.torch.algorithms.weight_only_algos import gptq_quantize_entry
from .weight_only import (
rtn_quantize,
gptq_quantize,
)
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/weight_only/README.md
@@ -0,0 +1 @@
# Demo of algorithm usage w/o INC
4 changes: 4 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/__init__.py
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utility import *
from .rtn import rtn_quantize
from .gptq import gptq_quantize
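
After the `__init__.py` reshuffling in this commit, the public names shown in the diff resolve from these locations (a usage sketch based only on the imports visible above):

```python
from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize
from neural_compressor.torch.algorithms.weight_only import rtn_quantize, gptq_quantize
from neural_compressor.torch.utils import get_double_quant_config
```
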
