Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smoothquant refactor for 3.x API #1792

Merged
merged 36 commits into master from zixuan/sq_refactor
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e2b5ced
Smoothquant refactor for 3.x API
violetch24 May 13, 2024
aa15c3f
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 13, 2024
5ffa853
modify smoothquant ut
violetch24 May 15, 2024
3b09ba1
Update utility.py
violetch24 May 15, 2024
fe9a810
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 15, 2024
4824301
modify sq example
violetch24 May 15, 2024
19fbf86
minor fix
violetch24 May 16, 2024
968b12e
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 16, 2024
5c5ccd1
minor fix
violetch24 May 16, 2024
987ba49
modify ut
violetch24 May 16, 2024
5e8b7ab
update requirements
violetch24 May 16, 2024
b04a9a0
update requirements
violetch24 May 17, 2024
1ef4415
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
55dfaed
modify ut for coverage
violetch24 May 17, 2024
9b8628b
minor fix
violetch24 May 17, 2024
276b029
Update smooth_quant.py
violetch24 May 17, 2024
605bcb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2024
f08aaae
minor fix
violetch24 May 17, 2024
5a40e16
minor fix
violetch24 May 17, 2024
f0f5b58
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
18b2990
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
a98e202
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
cd805d0
code fix
violetch24 May 17, 2024
ca311ca
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 17, 2024
cf9096e
Update requirements.txt
violetch24 May 17, 2024
e365ce7
Update requirements.txt
violetch24 May 17, 2024
ce214bd
Update run_clm_no_trainer.py
violetch24 May 17, 2024
b79e74c
modify ut
violetch24 May 17, 2024
a67daff
Merge branch 'master' into zixuan/sq_refactor
violetch24 May 19, 2024
757ce29
ut coverage
violetch24 May 19, 2024
e058667
minor fix
violetch24 May 19, 2024
718f163
minor fix
violetch24 May 19, 2024
28fbb09
remove overrides
violetch24 May 20, 2024
2e93927
ut for 2.x and 3.x API
violetch24 May 20, 2024
1c0f4aa
Update test_smooth_quant.py
violetch24 May 20, 2024
531e8d9
Update test_smooth_quant.py
violetch24 May 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -366,20 +366,11 @@ def run_fn(model):

from utils import get_example_inputs
example_inputs = get_example_inputs(user_model, calib_dataloader)
if args.sq:
# Currently, smooth quant only supports the quantize API.
# TODO: support prepare/convert API for smooth quant
from neural_compressor.torch.quantization import quantize

user_model = quantize(
model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
)
else:
from neural_compressor.torch.quantization import prepare, convert

user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
from neural_compressor.torch.quantization import prepare, convert
user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
user_model.save(args.output_dir)


Expand Down
33 changes: 30 additions & 3 deletions neural_compressor/torch/algorithms/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Optional
from typing import Any, Callable, Optional

import torch

Expand Down Expand Up @@ -76,7 +76,28 @@ def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
"""
raise NotImplementedError("{} doesn't implement `convert` function. ".format(self.__class__.__name__))

def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
@abstractmethod
violetch24 marked this conversation as resolved.
Show resolved Hide resolved
def quantize(
self,
model: torch.nn.Module,
tune_cfg: OrderedDict,
run_fn: Callable,
example_inputs: Any,
inplace=True,
*args,
**kwargs
):
"""Quantizes a given float model.

Args:
model (torch.nn.Module): The float model to be quantized.

Returns:
A quantized model.
"""
raise NotImplementedError("{} doesn't implement `quantize` function. ".format(self.__class__.__name__))

def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any): # noqa: F811
"""Quantizes a given float model.

Args:
Expand Down Expand Up @@ -111,5 +132,11 @@ def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):
elif mode == Mode.CONVERT:
model = self.convert(model, *args, **kwargs)
elif mode == Mode.QUANTIZE:
model = self.quantize(model, *args, **kwargs)
if "recipe_cfgs" in self.quant_config: # keep quantize API for smoothquant
run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
model = self.quantize(model, self.quant_config, run_fn, example_inputs, inplace)
else:
model = self.quantize(model, *args, **kwargs)
return model
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
# limitations under the License.

from .utility import *
from .smooth_quant import smooth_quantize
from .smooth_quant import SmoothQuantQuantizer
from .save_load import save, load, recover_model_from_json
Loading
Loading