use inplace=True mode for WOQ #1557

Merged
merged 10 commits on Jan 25, 2024
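
This PR reworks the weight-only quantization (WOQ) path to favor in-place tensor operations so that packing and recovery do not keep extra transposed copies of the weights alive. As background for the diffs below, here is a minimal standalone sketch (illustrative only, not code from this PR) of how the out-of-place .T view differs from the t_() plus contiguous() pattern the changes adopt:

import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# .T returns a non-contiguous view; x keeps its shape and storage.
view = x.T
assert view.shape == (3, 2) and x.shape == (2, 3)
assert view.data_ptr() == x.data_ptr()  # no new allocation, just different strides

# t_() transposes x in place (metadata only, still no copy) ...
x.t_()
assert x.shape == (3, 2)
# ... and contiguous() then materializes the transposed layout into fresh memory.
y = x.contiguous()
assert y.data_ptr() != x.data_ptr()

In the wrappers below this shows up as self.qweight = self.qweight.t_().contiguous() and similar lines: each buffer is flipped to the layout a step needs and flipped back afterwards, rather than being created transposed once at __init__ time.
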
43 changes: 22 additions & 21 deletions neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py
@@ -127,23 +127,20 @@ def __init__(
dtype=self.float_type,
).to(device),
)
self.scales = self.scales.T
self.register_buffer(
"qweight",
torch.zeros(
(math.ceil(in_features / self.n_pack), out_features),
dtype=self.compression_dtype,
).to(device),
)
self.qweight = self.qweight.T
self.register_buffer(
"qzeros",
torch.zeros(
(math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
dtype=self.compression_dtype,
).to(device),
)
self.qzeros = self.qzeros.T
self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
else:
self.compression_dtype = compression_dtype
@@ -193,6 +190,10 @@ def __init__(
self.bias = None

def pack(self, int_weight, scale, zp, bias):
if self.use_optimum_format:
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
@@ -206,8 +207,8 @@ def pack(self, int_weight, scale, zp, bias):
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_optimum_format and self.compression_dim == 0:
int_weight = int_weight.T
self.qweight = self.qweight.T
int_weight = int_weight.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
origin_shape = int_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -223,15 +224,15 @@ def pack(self, int_weight, scale, zp, bias):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qweight[:, j] |= tmp[:, e]
if not self.use_optimum_format and self.compression_dim == 0:
self.qweight = self.qweight.T
self.qweight = self.qweight.t_().contiguous()

if zp is not None:
zp = zp.to(self.device)
if self.use_optimum_format:
zp -= 1
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
self.qzeros = self.qzeros.T
zp = zp.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
@@ -243,16 +244,16 @@ def pack(self, int_weight, scale, zp, bias):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qzeros[:, j] |= tmp[:, e]
if self.use_optimum_format or self.compression_dim == 0:
self.qzeros = self.qzeros.T
self.qzeros = self.qzeros.t_().contiguous()
if self.use_optimum_format:
self.scales = self.scales.T
self.qweight = self.qweight.T
self.qzeros = self.qzeros.T
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()

def recover(self):
logger.debug(f"Recovering {self} weight")
scales = self.scales.T if self.use_optimum_format else self.scales
qweight = self.qweight.T if self.use_optimum_format else self.qweight
scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight

device = scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -264,8 +265,8 @@ def recover(self):
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.T
qweight = qweight.T
weight = weight.t_().contiguous()
qweight = qweight.t_().contiguous()
origin_shape = weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
@@ -280,7 +281,7 @@ def recover(self):
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.T
weight = weight.t_().contiguous()
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
@@ -290,10 +291,10 @@ def recover(self):
if hasattr(self, "qzeros"):
zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
qzeros = qzeros.T
zp = zp.t_().contiguous()
qzeros = qzeros.t_().contiguous()
origin_shape = zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
@@ -307,7 +308,7 @@ def recover(self):
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
zp = zp.t_().contiguous()
if self.use_optimum_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
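
The diff above also drops the transposes that previously ran at __init__ time (self.scales = self.scales.T and so on): the buffers are now registered directly in their stored layout, and pack() flips them in place only around the packing loops before flipping them back at the end. The next file gives the generic wrapper's pack()/recover() the same treatment. A rough sketch of that flow, with hypothetical names and the actual bit-packing elided:

import math
import torch

class PackFlowSketch(torch.nn.Module):
    """Hypothetical wrapper: shows where the in-place transposes happen, not the real packing."""

    def __init__(self, in_features, out_features, n_pack=8, compression_dtype=torch.int32):
        super().__init__()
        # Registered directly in the stored layout, so __init__ no longer ends with
        # a separate `self.qweight = self.qweight.T` step.
        self.register_buffer(
            "qweight",
            torch.zeros(math.ceil(in_features / n_pack), out_features, dtype=compression_dtype),
        )

    def pack(self, int_weight: torch.Tensor) -> None:
        # Flip to (out_features, ceil(in_features / n_pack)) only while packing.
        self.qweight = self.qweight.t_().contiguous()
        # ... pack int_weight into self.qweight here ...
        # Flip back to the stored layout before returning.
        self.qweight = self.qweight.t_().contiguous()
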
42 changes: 21 additions & 21 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -327,9 +327,9 @@ def __init__(

def pack(self, int_weight, scale, zp, bias, g_idx=None):
if self.use_optimum_format:
self.scales = self.scales.T
self.qweight = self.qweight.T
self.qzeros = self.qzeros.T
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
@@ -350,8 +350,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_optimum_format and self.compression_dim == 0:
int_weight = int_weight.T
self.qweight = self.qweight.T
int_weight = int_weight.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
origin_shape = int_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -367,15 +367,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qweight[:, j] |= tmp[:, e]
if not self.use_optimum_format and self.compression_dim == 0:
self.qweight = self.qweight.T
self.qweight = self.qweight.t_().contiguous()

if zp is not None:
zp = zp.to(self.device)
if self.use_optimum_format:
zp -= 1
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
self.qzeros = self.qzeros.T
zp = zp.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
@@ -387,16 +387,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qzeros[:, j] |= tmp[:, e]
if self.use_optimum_format or self.compression_dim == 0:
self.qzeros = self.qzeros.T
self.qzeros = self.qzeros.t_().contiguous()
if self.use_optimum_format:
self.scales = self.scales.T
self.qweight = self.qweight.T
self.qzeros = self.qzeros.T
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()

def recover(self):
logger.debug(f"Recovering {self} weight")
scales = self.scales.T if self.use_optimum_format else self.scales
qweight = self.qweight.T if self.use_optimum_format else self.qweight
scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight

device = scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -411,8 +411,8 @@ def recover(self):
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.T
qweight = qweight.T
weight = weight.t_().contiguous()
qweight = qweight.t_().contiguous()
origin_shape = weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
@@ -427,7 +427,7 @@ def recover(self):
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.T
weight = weight.t_().contiguous()
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
@@ -437,10 +437,10 @@ def recover(self):
if hasattr(self, "qzeros"):
zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
qzeros = qzeros.T
zp = zp.t_().contiguous()
qzeros = qzeros.t_().contiguous()
origin_shape = zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
@@ -454,7 +454,7 @@ def recover(self):
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.T
zp = zp.t_().contiguous()
if self.use_optimum_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
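
In weight_only.py below, search_clip no longer materializes a quantized copy of the weight for every candidate clipping ratio: the weight is quantized in place, the MSE is computed against a saved original, and the original values are restored with copy_ before the next iteration. A minimal standalone sketch of that pattern (fake_quant_ is a hypothetical stand-in for the library's quant_weight, not its real implementation):

import torch

def fake_quant_(w: torch.Tensor, quantile: float) -> None:
    """Hypothetical in-place 4-bit symmetric fake-quant, standing in for quant_weight(...)."""
    scale = w.abs().max() * quantile / 7
    w.div_(scale).round_().clamp_(-8, 7).mul_(scale)

def search_clip_sketch(weight: torch.Tensor, n_grid: int = 20, max_shrink: float = 0.8) -> float:
    org_weight = weight.clone()              # single extra copy, reused for every ratio
    best_ratio, best_error = 1.0, float("inf")
    for i_s in range(int(max_shrink * n_grid)):
        ratio = 1 - i_s / n_grid
        fake_quant_(weight, ratio)           # mutate weight in place, no per-step copy
        loss = (org_weight - weight).float().pow(2).mean().item()
        weight.copy_(org_weight)             # restore before trying the next ratio
        if loss < best_error:
            best_error, best_ratio = loss, ratio
    return best_ratio
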
40 changes: 22 additions & 18 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -330,16 +330,17 @@ def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", en
history = []
for i_s in range(int(max_shrink * n_grid)):
ratio = 1 - i_s / n_grid # 1, 0.805-1.0
cur_weight = quant_weight(
m.weight.data,
quant_weight(
m.weight.data, # in-place mode
num_bits=num_bits,
group_size=group_size,
scheme=scheme,
data_type=data_type,
full_range=enable_full_range,
quantile=ratio,
)
loss = (org_weight - cur_weight).float().pow(2).mean().item()
loss = (org_weight - m.weight.data).float().pow(2).mean().item()
m.weight.data.copy_(org_weight)
history.append(loss)
is_best = loss < best_error
if is_best:
@@ -429,14 +430,17 @@ def rtn_quantize(
if num_bits <= 0:
logger.info(f"Skip {name}")
continue
weight = m.weight.T if group_dim == 0 else m.weight
# contiguous is not an in-place op and returns Tensor instead of Parameter, so set it back to m.weight.data.
# transpose should be executed on Parameter level because Param.data.t_() is not an in-place op.
# Parameter.T is an in-place op while Tensor.T is not.
m.weight.data = m.weight.t_().data.contiguous() if group_dim == 0 else m.weight.data
if enable_mse_search:
quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
if return_int:
from .model_wrapper import WeightOnlyLinear

_, scale, zp = quant_weight(
weight,
m.weight.data,
num_bits,
group_size,
scheme,
@@ -446,9 +450,9 @@
full_range=enable_full_range,
)
if group_dim == 0:
weight.transpose_(0, 1)
scale = scale.T if group_dim == 0 else scale
zp = zp.T if group_dim == 0 and zp is not None else zp
m.weight.t_()
scale = scale.t_().contiguous() if group_dim == 0 else scale
zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
new_module = WeightOnlyLinear(
m.in_features,
m.out_features,
@@ -463,14 +467,14 @@
device=device,
use_optimum_format=use_optimum_format,
)
new_module.pack(weight, scale, zp, m.bias)
new_module.pack(m.weight.data, scale, zp, m.bias)
if name == "":
return new_module
else:
set_module(model, name, new_module)
else:
quant_weight(
weight,
m.weight.data,
num_bits,
group_size,
scheme,
@@ -479,7 +483,7 @@
full_range=enable_full_range,
)
if group_dim == 0:
weight.transpose_(0, 1)
m.weight.t_()
if orig_dtype != torch.float:
m = m.to(orig_dtype)
return model
@@ -651,18 +655,18 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):
if zp is not None:
zp = zp.to(device)
if group_size == -1:
return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp)
return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_()
int_weight = torch.zeros(weight.shape).to(device)
leng = weight.shape[1] // group_size
tail_flag = False if weight.shape[1] % group_size == 0 else True
for i in range(leng):
int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1)
int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
if zp is not None:
int_weight_tmp += zp[:, i].unsqueeze(1)
int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp)
int_weight_tmp.add_(zp[:, i].unsqueeze(1))
int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_())
if tail_flag:
int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1)
int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1))
if zp is not None:
int_weight_tmp += zp[:, -1].unsqueeze(1)
int_weight[:, leng * group_size :] = torch.round(int_weight_tmp)
int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
return int_weight
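
quant_weight_w_scale now does its per-group division, zero-point add, and rounding with div_/add_/round_ on slices of the input and writes the results with copy_, instead of building out-of-place temporaries. The trade-off is that the caller's floating-point weight is mutated. A short usage sketch under that assumption (the import path is inferred from the file path shown above):

import torch
from neural_compressor.adaptor.torch_utils.weight_only import quant_weight_w_scale

group_size = 4
weight = torch.randn(4, 8)
scale = weight.reshape(4, -1, group_size).abs().amax(dim=-1) / 7  # one scale per (row, group), 4-bit example
orig = weight.clone()  # keep a copy: the call below mutates `weight` group by group

int_weight = quant_weight_w_scale(weight, scale, None, group_size=group_size)
# `weight` itself now holds the divided/rounded values; restore it if the fp weight is still needed.
weight.copy_(orig)
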
4 changes: 4 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/__init__.py
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utility import *
from .rtn import rtn_quantize
from .gptq import gptq_quantize