
Commit

use new prepare/convert
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel committed May 24, 2024
1 parent 3f3b2fd commit 8b54922
Showing 2 changed files with 23 additions and 210 deletions.
214 changes: 14 additions & 200 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -144,18 +144,13 @@ def __init__(self, model) -> None:
         self.device = "cpu"
         self.orig_model = model
 
-
     def forward(self, *args, **kwargs):
-        if kwargs and len(args) == 0:
-            # Handle cases where input data is a dict
+        if args:
+            for arg in args:
+                self.data_pairs.append(arg)
+        if kwargs:
             self.data_pairs.append(kwargs)
-        elif args and len(args) == 1:
-            # Handle cases where input data is a Tensor
-            self.data_pairs.append(args[0])
-        else:
-            logger.error("Handle cases where input data is neither a Tensor nor a dict")
+        return self.orig_model.forward(*args, **kwargs)
-
-
 class AutoRoundQuantizer(Quantizer):
     def __init__(
         self,
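Reviewer note: the reworked forward records every positional argument (and the kwargs dict) in data_pairs and still delegates to the wrapped model, so a normal calibration run both produces outputs and collects inputs for the later convert step. A minimal standalone sketch of that capture pattern (the toy wrapper and Linear model below are illustrative, not part of the commit):

import torch

class ToyCapture(torch.nn.Module):
    """Illustrative stand-in for the new InputCaptureModule.forward logic."""

    def __init__(self, model) -> None:
        super().__init__()
        self.data_pairs = []      # captured calibration inputs
        self.orig_model = model   # the real model being calibrated

    def forward(self, *args, **kwargs):
        # Record positional args individually and the kwargs dict, if any.
        if args:
            for arg in args:
                self.data_pairs.append(arg)
        if kwargs:
            self.data_pairs.append(kwargs)
        # Still run the wrapped model so calibration behaves normally.
        return self.orig_model.forward(*args, **kwargs)

wrapped = ToyCapture(torch.nn.Linear(4, 2))
wrapped(torch.randn(1, 4))          # tensor input is captured as-is
print(len(wrapped.data_pairs))      # -> 1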
@@ -256,9 +251,16 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
         Returns:
             A prepared model.
         """
-        self.rounder = AutoRoundProcessor(
+        prepare_model = InputCaptureModule(model)
+        return prepare_model
+
+    def convert(self, model: torch.nn.Module, *args, **kwargs):
+        dataloader = torch.utils.data.DataLoader(model.data_pairs)
+        model = model.orig_model
+        rounder = AutoRound(
             model=model,
             tokenizer=None,
+            dataset=dataloader,
             weight_config=self.quant_config or {},
             enable_full_range=self.enable_full_range,
             batch_size=self.batch_size,
@@ -282,15 +284,7 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
             data_type=self.data_type,
             scale_dtype=self.scale_dtype,
         )
-
-        self.rounder.prepare()
-        prepare_model = InputCaptureModule(model)
-        return prepare_model
-
-    def convert(self, model: torch.nn.Module, *args, **kwargs):
-        self.rounder.model_input = model.data_pairs
-        model = model.orig_model
-        model, weight_config = self.rounder.convert()
+        model, weight_config = rounder.quantize()
         model.autoround_config = weight_config
         if self.device == "xpu":
             model = pack_model(model, weight_config, device=self.device, inplace=True,
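With AutoRoundProcessor removed (the deletion below), the intended calling sequence becomes: prepare() wraps the model so inputs are captured, the user-supplied run_fn drives forward passes through the wrapper, and convert() hands the captured inputs to AutoRound for one-shot quantization. A hedged end-to-end sketch consistent with the updated tests in this commit; the tiny HF checkpoint name is an assumption used only so the example stays small:

import transformers
from neural_compressor.torch.algorithms.weight_only.autoround import get_autoround_default_run_fn
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

# Assumed small checkpoint for illustration; any causal LM works the same way.
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
fp32_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")

model = prepare(model=fp32_model, quant_config=quant_config)   # returns the InputCaptureModule wrapper
get_autoround_default_run_fn(model, tokenizer, "NeelNanda/pile-10k", 32, 10)  # forwards are captured
q_model = convert(model)   # AutoRound quantizes from the captured inputs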
@@ -384,183 +378,3 @@ def get_autoround_default_run_fn(
             f"Valid samples size:{total_cnt}, Target sample size:{n_samples}"
         )
 
-
-class AutoRoundProcessor(AutoRound):
-
-    @torch.no_grad()
-    def cache_inter_data(self, block_names, n_samples, layer_names=[], last_cache_name=None):
-        """Save the inputs of block_name for calibration. For layers, we cache both of inputs and output.
-        This method temporarily replaces the forward method of the model to capture
-        the inputs passing through the specified block. It then calibrates the model
-        using a specified number of samples. Finally, it restores the original forward
-        method and returns the inputs for the specified block.
-        Args:
-            block_names (list): The names of the blocks for which inputs are to be saved.
-            layer_names (list):The names of the layers for which inputs are to be saved.
-            n_samples (int): The number of samples to use for calibration.
-            last_cache_name (str, optional): The name of the last layer to be cached,
-                we could break the forward in this layer to save time
-        Returns:
-            dict: A dictionary containing the inputs for the specified block.
-        """
-        self.inputs = {}
-        self.to_cached_layers = block_names + layer_names
-        tmp_dtype = None
-        ## have bug if block name is not the first block
-        if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage:
-            tmp_dtype = self.model.dtype
-            self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32)
-
-        self.last_cache_name = last_cache_name
-        if last_cache_name is None and len(block_names) + len(layer_names) == 1:
-            self.last_cache_name = block_names[0] if len(block_names) == 1 else layer_names[0]
-        calib_bs = self.train_bs
-        self.hook_handles = []
-        self._replace_forward()
-        for data in self.model_input:
-            if isinstance(data, torch.Tensor):
-                self.model(data)
-            else:
-                self.model(**data)
-        self._recover_forward()
-        res = self.inputs
-        del self.model_input
-        del self.last_cache_name
-        del self.to_cached_layers
-        if tmp_dtype is not None:
-            self.model = self.model.to(tmp_dtype)
-
-        return res
-
-    @torch.no_grad()
-    def prepare(self):
-        """Prepares a given model for quantization."""
-        self.block_names = get_block_names(self.model)
-        if len(self.block_names) == 0:
-            logger.warning("could not find blocks, exit with original model")
-            return self.model, self.weight_config
-
-        if self.amp:
-            self.model = self.model.to(self.amp_dtype)
-
-        self.layer_names = self.get_quantized_layer_names_outside_blocks()
-        self.start_time = time.time()
-        # all_inputs = self.try_cache_inter_data_gpucpu([block_names[0]], self.n_samples, layer_names=layer_names)
-
-        # try_cache_inter_data_gpucpu
-        # ([block_names[0]], self.n_samples, layer_names=layer_names)
-        # self, block_names, n_samples, layer_names=[], last_cache_name=None
-        last_cache_name = None
-        cache_block_names = [self.block_names[0]]
-        try:
-            self.model = self.model.to(self.device)
-            # all_inputs = self.cache_inter_data(
-            #     block_names[0], self.n_samples, layer_names=layer_names, last_cache_name=last_cache_name
-            # )
-            # cache_inter_data cache_inter_data(self, block_names, n_samples, layer_names=[], last_cache_name=None):
-            self.inputs = {}
-            self.to_cached_layers = cache_block_names + self.layer_names
-            self.tmp_dtype = None
-            ## have bug if block name is not the first block
-            if (len(cache_block_names) > 1 or len(self.layer_names) > 0) and self.low_gpu_mem_usage:
-                self.tmp_dtype = self.model.dtype
-                self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32)
-
-            self.last_cache_name = last_cache_name
-            if last_cache_name is None and len(cache_block_names) + len(self.layer_names) == 1:
-                self.last_cache_name = cache_block_names[0] if len(cache_block_names) == 1 else self.layer_names[0]
-            # calib_bs = self.train_bs
-            self.hook_handles = []
-            self._replace_forward()
-            self.prepared_gpu = True
-            # self.calib(self.n_samples, calib_bs)
-
-        except:
-            logger.info("switch to cpu to cache inputs")
-            self.model = self.model.to("cpu")
-            torch.cuda.empty_cache()
-            # all_inputs = self.cache_inter_data(
-            #     self.block_names[0], self.n_samples, layer_names=self.layer_names, last_cache_name=last_cache_name
-            # )
-            self.inputs = {}
-            self.to_cached_layers = cache_block_names + self.layer_names
-            self.tmp_dtype = None
-            ## have bug if block name is not the first block
-            if (len(cache_block_names) > 1 or len(self.layer_names) > 0) and self.low_gpu_mem_usage:
-                self.tmp_dtype = self.model.dtype
-                self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32)
-
-            self.last_cache_name = last_cache_name
-            if last_cache_name is None and len(cache_block_names) + len(self.layer_names) == 1:
-                self.last_cache_name = cache_block_names[0] if len(cache_block_names) == 1 else self.layer_names[0]
-            # calib_bs = self.train_bs
-            self.hook_handles = []
-            self._replace_forward()
-            cache_block_names
-            # self.calib(n_samples, calib_bs)
-
-    def convert(self):
-        """Converts a prepared model to a quantized model."""
-        self._recover_forward()
-        res = self.inputs
-        del self.last_cache_name
-        del self.to_cached_layers
-        if self.tmp_dtype is not None:
-            self.model = self.model.to(self.tmp_dtype)
-        if self.prepared_gpu is True:
-            self.model = self.model.to("cpu")
-
-        all_inputs = res
-
-        del self.inputs
-        inputs = all_inputs[self.block_names[0]]
-
-        all_inputs.pop(self.block_names[0])
-        self.inputs = None
-        del self.inputs
-        if "input_ids" in inputs.keys():
-            total_samples = len(inputs["input_ids"])
-            self.n_samples = total_samples
-            if total_samples < self.train_bs:
-                self.train_bs = total_samples
-                logger.warning(f"force the train batch size to {total_samples} ")
-
-        self.model = self.model.to("cpu")
-        torch.cuda.empty_cache()
-        self.quant_blocks(
-            self.model,
-            inputs,
-            self.block_names,
-            n_blocks=self.n_blocks,
-            device=self.device,
-        )
-
-        self.quant_layers(self.layer_names, all_inputs)
-
-        self.dump_data_to_weight_config()
-
-        end_time = time.time()
-        cost_time = end_time - self.start_time
-        logger.info(f"quantization tuning time {cost_time}")
-
-        ## dump a summary
-        quantized_layers = []
-        unquantized_layers = []
-        for n, m in self.model.named_modules():
-            if isinstance(m, tuple(self.supported_types)):
-                if self.weight_config[n]["bits"] == 16:
-                    unquantized_layers.append(n)
-                else:
-                    quantized_layers.append(n)
-        summary_info = (
-            f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
-        )
-        if len(unquantized_layers) > 0:
-            summary_info += f", {unquantized_layers} have not been quantized"
-        logger.info(summary_info)
-
-        self.quantized = True
-        ##self.model = self.model.to(self.model_orig_dtype)##keep it as amp dtype
-        return self.model, self.weight_config
19 changes: 9 additions & 10 deletions test/3x/torch/quantization/weight_only/test_autoround.py
@@ -46,8 +46,8 @@ def setup_method(self, method):
 
     @pytest.mark.parametrize("quant_lm_head", [True, False])
    def test_autoround(self, quant_lm_head):
-        gpt_j_model = copy.deepcopy(self.gptj)
-        quant_config = AutoRoundConfig(n_samples=20, seqlen=10, iters=10, scale_dtype="fp32")
+        fp32_model = copy.deepcopy(self.gptj)
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
         if quant_lm_head is False:
             quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
@@ -56,10 +56,9 @@ def test_autoround(self, quant_lm_head):
         run_args = (
             self.tokenizer,
             "NeelNanda/pile-10k",
-            20,
+            32,
             10,
         )
-        fp32_model = gpt_j_model
 
         # prepare + convert API
         model = prepare(model=fp32_model, quant_config=quant_config)
@@ -78,7 +77,7 @@ def test_autoround_with_quantize_API(self):
     def test_autoround_with_quantize_API(self):
         gpt_j_model = copy.deepcopy(self.gptj)
 
-        quant_config = get_default_AutoRound_config()
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
         quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
 
         logger.info(f"Test AutoRound with config {quant_config}")
@@ -91,7 +90,7 @@ def test_autoround_with_quantize_API(self):
             run_args=(
                 self.tokenizer,
                 "NeelNanda/pile-10k",
-                20,
+                32,
                 10,
             ),
         )
@@ -101,15 +100,15 @@ def test_autoround_with_quantize_API(self):
 
     def test_save_and_load(self):
         fp32_model = copy.deepcopy(self.gptj)
-        quant_config = get_default_AutoRound_config()
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
         # quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
 
         run_fn = get_autoround_default_run_fn
         run_args = (
             self.tokenizer,
             "NeelNanda/pile-10k",
-            20,
+            32,
             10,
         )
         # quantizer execute
@@ -144,10 +143,10 @@ def test_conv1d(self):
         run_args = (
             tokenizer,
             "NeelNanda/pile-10k",
-            20,
+            32,
             10,
         )
-        quant_config = get_default_AutoRound_config()
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
         model = prepare(model=model, quant_config=quant_config)
         run_fn(model, *run_args)
         q_model = convert(model)
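For completeness, the same flow through the one-shot quantize() entry point that test_autoround_with_quantize_API exercises above. This is a hedged sketch: the call signature mirrors how the test passes run_fn/run_args, and the tiny checkpoint name is assumed for illustration only:

import transformers
from neural_compressor.torch.algorithms.weight_only.autoround import get_autoround_default_run_fn
from neural_compressor.torch.quantization import AutoRoundConfig, quantize

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"  # assumed test-sized checkpoint
fp32_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))  # keep lm_head unquantized, as in the test

q_model = quantize(
    model=fp32_model,
    quant_config=quant_config,
    run_fn=get_autoround_default_run_fn,
    run_args=(tokenizer, "NeelNanda/pile-10k", 32, 10),
)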
