diff --git a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh index c5df80c9bce..f85ea02818e 100644 --- a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh +++ b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh @@ -31,6 +31,7 @@ pip install torch==1.12.0 \ onnxruntime_extensions \ tf_slim \ transformers \ + accelerate \ flask==2.1.3 \ xgboost diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 0f4fef519b5..2ca2ef6b9bf 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -3346,6 +3346,21 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): else: self.prepare_custom_config_dict, self.convert_custom_config_dict = None, None self.fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, self.approach) + + # for layer-wise quant + # recipe_cfgs = tune_cfg.get('recipe_cfgs', None) + if recipe_cfgs and recipe_cfgs.get('layer_wise_quant', False) \ + and self.approach != 'post_training_dynamic_quant': + from .torch_utils.layer_wise_quant import LayerWiseQuant + + model_path = recipe_cfgs['layer_wise_quant_args'].get('model_path', None) + assert model_path is not None,\ + "the layer_wise_quant_args should have args model_path to load the weight of model." + device = recipe_cfgs['layer_wise_quant_args'].get('decvice', 'cpu') + lw_quant = LayerWiseQuant(q_model._model, model_path, self.fx_op_cfgs, device=device) + q_model._model = lw_quant.quantize(dataloader, clean_weight=False) + return q_model + self.tune_cfg['fx_sub_module_list'] = self.sub_module_list if self.approach == 'quant_aware_training': q_model._model.train() diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py new file mode 100644 index 00000000000..6f01b1288bf --- /dev/null +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Torch layer-wise quantization module.""" +from .utils import load_shell +from .quantize import LayerWiseQuant diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/modified_pickle.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/modified_pickle.py new file mode 100644 index 00000000000..8b260af5ed0 --- /dev/null +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/modified_pickle.py @@ -0,0 +1,1819 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Create portable serialized representations of Python objects. + +See module copyreg for a mechanism for registering custom picklers. +See module pickletools source for extensive comments. + +Classes: + + Pickler + Unpickler + +Functions: + + dump(object, file) + dumps(object) -> string + load(file) -> object + loads(string) -> object + +Misc variables: + + __version__ + format_version + compatible_formats + +""" + +from types import FunctionType +from copyreg import dispatch_table +from copyreg import _extension_registry, _inverted_registry, _extension_cache +from itertools import islice +from functools import partial +import sys +from sys import maxsize +from struct import pack, unpack +import re +import io +import codecs +import _compat_pickle + +__all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler", + "Unpickler", "dump", "dumps", "load", "loads"] + +try: + from _pickle import PickleBuffer + __all__.append("PickleBuffer") + _HAVE_PICKLE_BUFFER = True +except ImportError: + _HAVE_PICKLE_BUFFER = False + + +# Shortcut for use in isinstance testing +bytes_types = (bytes, bytearray) + +# These are purely informational; no code uses these. +format_version = "4.0" # File format version we write +compatible_formats = ["1.0", # Original protocol 0 + "1.1", # Protocol 0 with INST added + "1.2", # Original protocol 1 + "1.3", # Protocol 1 with BINFLOAT added + "2.0", # Protocol 2 + "3.0", # Protocol 3 + "4.0", # Protocol 4 + "5.0", # Protocol 5 + ] # Old format versions we can read + +# This is the highest protocol number we know how to read. +HIGHEST_PROTOCOL = 5 + +# The protocol we write by default. May be less than HIGHEST_PROTOCOL. +# Only bump this if the oldest still supported version of Python already +# includes it. +DEFAULT_PROTOCOL = 4 + +class PickleError(Exception): + """A common base class for the other pickling exceptions.""" + pass + +class PicklingError(PickleError): + """This exception is raised when an unpicklable object is passed to the + dump() method. + + """ + pass + +class UnpicklingError(PickleError): + """This exception is raised when there is a problem unpickling an object, + such as a security violation. + + Note that other exceptions may also be raised during unpickling, including + (but not necessarily limited to) AttributeError, EOFError, ImportError, + and IndexError. + + """ + pass + +# An instance of _Stop is raised by Unpickler.load_stop() in response to +# the STOP opcode, passing the object that is the result of unpickling. +class _Stop(Exception): + def __init__(self, value): + self.value = value + +# Jython has PyStringMap; it's a dict subclass with string keys +try: + from org.python.core import PyStringMap +except ImportError: + PyStringMap = None + +# Pickle opcodes. See pickletools.py for extensive docs. The listing +# here is in kind-of alphabetical order of 1-character pickle code. +# pickletools groups them by purpose. + +MARK = b'(' # push special markobject on stack +STOP = b'.' # every pickle ends with STOP +POP = b'0' # discard topmost stack item +POP_MARK = b'1' # discard stack top through topmost markobject +DUP = b'2' # duplicate top stack item +FLOAT = b'F' # push float object; decimal string argument +INT = b'I' # push integer or bool; decimal string argument +BININT = b'J' # push four-byte signed int +BININT1 = b'K' # push 1-byte unsigned int +LONG = b'L' # push long; decimal string argument +BININT2 = b'M' # push 2-byte unsigned int +NONE = b'N' # push None +PERSID = b'P' # push persistent object; id is taken from string arg +BINPERSID = b'Q' # " " " ; " " " " stack +REDUCE = b'R' # apply callable to argtuple, both on stack +STRING = b'S' # push string; NL-terminated string argument +BINSTRING = b'T' # push string; counted binary string argument +SHORT_BINSTRING= b'U' # " " ; " " " " < 256 bytes +UNICODE = b'V' # push Unicode string; raw-unicode-escaped'd argument +BINUNICODE = b'X' # " " " ; counted UTF-8 string argument +APPEND = b'a' # append stack top to list below it +BUILD = b'b' # call __setstate__ or __dict__.update() +GLOBAL = b'c' # push self.find_class(modname, name); 2 string args +DICT = b'd' # build a dict from stack items +EMPTY_DICT = b'}' # push empty dict +APPENDS = b'e' # extend list on stack by topmost stack slice +GET = b'g' # push item from memo on stack; index is string arg +BINGET = b'h' # " " " " " " ; " " 1-byte arg +INST = b'i' # build & push class instance +LONG_BINGET = b'j' # push item from memo on stack; index is 4-byte arg +LIST = b'l' # build list from topmost stack items +EMPTY_LIST = b']' # push empty list +OBJ = b'o' # build & push class instance +PUT = b'p' # store stack top in memo; index is string arg +BINPUT = b'q' # " " " " " ; " " 1-byte arg +LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg +SETITEM = b's' # add key+value pair to dict +TUPLE = b't' # build tuple from topmost stack items +EMPTY_TUPLE = b')' # push empty tuple +SETITEMS = b'u' # modify dict by adding topmost key+value pairs +BINFLOAT = b'G' # push float; arg is 8-byte float encoding + +TRUE = b'I01\n' # not an opcode; see INT docs in pickletools.py +FALSE = b'I00\n' # not an opcode; see INT docs in pickletools.py + +# Protocol 2 + +PROTO = b'\x80' # identify pickle protocol +NEWOBJ = b'\x81' # build object by applying cls.__new__ to argtuple +EXT1 = b'\x82' # push object from extension registry; 1-byte index +EXT2 = b'\x83' # ditto, but 2-byte index +EXT4 = b'\x84' # ditto, but 4-byte index +TUPLE1 = b'\x85' # build 1-tuple from stack top +TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items +TUPLE3 = b'\x87' # build 3-tuple from three topmost stack items +NEWTRUE = b'\x88' # push True +NEWFALSE = b'\x89' # push False +LONG1 = b'\x8a' # push long from < 256 bytes +LONG4 = b'\x8b' # push really big long + +_tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3] + +# Protocol 3 (Python 3.x) + +BINBYTES = b'B' # push bytes; counted binary string argument +SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes + +# Protocol 4 + +SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes +BINUNICODE8 = b'\x8d' # push very long string +BINBYTES8 = b'\x8e' # push very long bytes string +EMPTY_SET = b'\x8f' # push empty set on the stack +ADDITEMS = b'\x90' # modify set by adding topmost stack items +FROZENSET = b'\x91' # build frozenset from topmost stack items +NEWOBJ_EX = b'\x92' # like NEWOBJ but work with keyword only arguments +STACK_GLOBAL = b'\x93' # same as GLOBAL but using names on the stacks +MEMOIZE = b'\x94' # store top of the stack in memo +FRAME = b'\x95' # indicate the beginning of a new frame + +# Protocol 5 + +BYTEARRAY8 = b'\x96' # push bytearray +NEXT_BUFFER = b'\x97' # push next out-of-band buffer +READONLY_BUFFER = b'\x98' # make top of stack readonly + +__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)]) + + +class _Framer: # pragma: no cover + + _FRAME_SIZE_MIN = 4 + _FRAME_SIZE_TARGET = 64 * 1024 + + def __init__(self, file_write): + self.file_write = file_write + self.current_frame = None + + def start_framing(self): + self.current_frame = io.BytesIO() + + def end_framing(self): + if self.current_frame and self.current_frame.tell() > 0: + self.commit_frame(force=True) + self.current_frame = None + + def commit_frame(self, force=False): + if self.current_frame: + f = self.current_frame + if f.tell() >= self._FRAME_SIZE_TARGET or force: + data = f.getbuffer() + write = self.file_write + if len(data) >= self._FRAME_SIZE_MIN: + # Issue a single call to the write method of the underlying + # file object for the frame opcode with the size of the + # frame. The concatenation is expected to be less expensive + # than issuing an additional call to write. + write(FRAME + pack("': + raise AttributeError("Can't get local attribute {!r} on {!r}" + .format(name, obj)) + try: + parent = obj + obj = getattr(obj, subpath) + except AttributeError: + raise AttributeError("Can't get attribute {!r} on {!r}" + .format(name, obj)) from None + return obj, parent + +def whichmodule(obj, name): # pragma: no cover + """Find the module an object belong to.""" + module_name = getattr(obj, '__module__', None) + if module_name is not None: + return module_name + # Protect the iteration by using a list copy of sys.modules against dynamic + # modules that trigger imports of other modules upon calls to getattr. + for module_name, module in sys.modules.copy().items(): + if (module_name == '__main__' + or module_name == '__mp_main__' # bpo-42406 + or module is None): + continue + try: + if _getattribute(module, name)[0] is obj: + return module_name + except AttributeError: + pass + return '__main__' + +def encode_long(x): # pragma: no cover + r"""Encode a long to a two's complement little-endian binary string. + Note that 0 is a special case, returning an empty string, to save a + byte in the LONG1 pickling context. + + >>> encode_long(0) + b'' + >>> encode_long(255) + b'\xff\x00' + >>> encode_long(32767) + b'\xff\x7f' + >>> encode_long(-256) + b'\x00\xff' + >>> encode_long(-32768) + b'\x00\x80' + >>> encode_long(-128) + b'\x80' + >>> encode_long(127) + b'\x7f' + >>> + """ + if x == 0: + return b'' + nbytes = (x.bit_length() >> 3) + 1 + result = x.to_bytes(nbytes, byteorder='little', signed=True) + if x < 0 and nbytes > 1: + if result[-1] == 0xff and (result[-2] & 0x80) != 0: + result = result[:-1] + return result + +def decode_long(data): # pragma: no cover + r"""Decode a long from a two's complement little-endian binary string. + + >>> decode_long(b'') + 0 + >>> decode_long(b"\xff\x00") + 255 + >>> decode_long(b"\xff\x7f") + 32767 + >>> decode_long(b"\x00\xff") + -256 + >>> decode_long(b"\x00\x80") + -32768 + >>> decode_long(b"\x80") + -128 + >>> decode_long(b"\x7f") + 127 + """ + return int.from_bytes(data, byteorder='little', signed=True) + + +# Pickling machinery + +class _Pickler: # pragma: no cover + + def __init__(self, file, protocol=None, *, fix_imports=True, + buffer_callback=None): + """This takes a binary file for writing a pickle data stream. + + The optional *protocol* argument tells the pickler to use the + given protocol; supported protocols are 0, 1, 2, 3, 4 and 5. + The default protocol is 4. It was introduced in Python 3.4, and + is incompatible with previous versions. + + Specifying a negative protocol version selects the highest + protocol version supported. The higher the protocol used, the + more recent the version of Python needed to read the pickle + produced. + + The *file* argument must have a write() method that accepts a + single bytes argument. It can thus be a file object opened for + binary writing, an io.BytesIO instance, or any other custom + object that meets this interface. + + If *fix_imports* is True and *protocol* is less than 3, pickle + will try to map the new Python 3 names to the old module names + used in Python 2, so that the pickle data stream is readable + with Python 2. + + If *buffer_callback* is None (the default), buffer views are + serialized into *file* as part of the pickle stream. + + If *buffer_callback* is not None, then it can be called any number + of times with a buffer view. If the callback returns a false value + (such as None), the given buffer is out-of-band; otherwise the + buffer is serialized in-band, i.e. inside the pickle stream. + + It is an error if *buffer_callback* is not None and *protocol* + is None or smaller than 5. + """ + if protocol is None: + protocol = DEFAULT_PROTOCOL + if protocol < 0: + protocol = HIGHEST_PROTOCOL + elif not 0 <= protocol <= HIGHEST_PROTOCOL: + raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL) + if buffer_callback is not None and protocol < 5: + raise ValueError("buffer_callback needs protocol >= 5") + self._buffer_callback = buffer_callback + try: + self._file_write = file.write + except AttributeError: + raise TypeError("file must have a 'write' attribute") + self.framer = _Framer(self._file_write) + self.write = self.framer.write + self._write_large_bytes = self.framer.write_large_bytes + self.memo = {} + self.proto = int(protocol) + self.bin = protocol >= 1 + self.fast = 0 + self.fix_imports = fix_imports and protocol < 3 + + def clear_memo(self): + """Clears the pickler's "memo". + + The memo is the data structure that remembers which objects the + pickler has already seen, so that shared or recursive objects + are pickled by reference and not by value. This method is + useful when re-using picklers. + """ + self.memo.clear() + + def dump(self, obj): + """Write a pickled representation of obj to the open file.""" + # Check whether Pickler was initialized correctly. This is + # only needed to mimic the behavior of _pickle.Pickler.dump(). + if not hasattr(self, "_file_write"): + raise PicklingError("Pickler.__init__() was not called by " + "%s.__init__()" % (self.__class__.__name__,)) + if self.proto >= 2: + self.write(PROTO + pack("= 4: + self.framer.start_framing() + self.save(obj) + self.write(STOP) + self.framer.end_framing() + + def memoize(self, obj): + """Store an object in the memo.""" + + # The Pickler memo is a dictionary mapping object ids to 2-tuples + # that contain the Unpickler memo key and the object being memoized. + # The memo key is written to the pickle and will become + # the key in the Unpickler's memo. The object is stored in the + # Pickler memo so that transient objects are kept alive during + # pickling. + + # The use of the Unpickler memo length as the memo key is just a + # convention. The only requirement is that the memo values be unique. + # But there appears no advantage to any other scheme, and this + # scheme allows the Unpickler memo to be implemented as a plain (but + # growable) array, indexed by memo key. + if self.fast: + return + assert id(obj) not in self.memo + idx = len(self.memo) + self.write(self.put(idx)) + self.memo[id(obj)] = idx, obj + + # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i. + def put(self, idx): + if self.proto >= 4: + return MEMOIZE + elif self.bin: + if idx < 256: + return BINPUT + pack("= 2 and func_name == "__newobj_ex__": + cls, args, kwargs = args + if not hasattr(cls, "__new__"): + raise PicklingError("args[0] from {} args has no __new__" + .format(func_name)) + if obj is not None and cls is not obj.__class__: + raise PicklingError("args[0] from {} args has the wrong class" + .format(func_name)) + if self.proto >= 4: + save(cls) + save(args) + save(kwargs) + write(NEWOBJ_EX) + else: + func = partial(cls.__new__, cls, *args, **kwargs) + save(func) + save(()) + write(REDUCE) + elif self.proto >= 2 and func_name == "__newobj__": + # A __reduce__ implementation can direct protocol 2 or newer to + # use the more efficient NEWOBJ opcode, while still + # allowing protocol 0 and 1 to work normally. For this to + # work, the function returned by __reduce__ should be + # called __newobj__, and its first argument should be a + # class. The implementation for __newobj__ + # should be as follows, although pickle has no way to + # verify this: + # + # def __newobj__(cls, *args): + # return cls.__new__(cls, *args) + # + # Protocols 0 and 1 will pickle a reference to __newobj__, + # while protocol 2 (and above) will pickle a reference to + # cls, the remaining args tuple, and the NEWOBJ code, + # which calls cls.__new__(cls, *args) at unpickling time + # (see load_newobj below). If __reduce__ returns a + # three-tuple, the state from the third tuple item will be + # pickled regardless of the protocol, calling __setstate__ + # at unpickling time (see load_build below). + # + # Note that no standard __newobj__ implementation exists; + # you have to provide your own. This is to enforce + # compatibility with Python 2.2 (pickles written using + # protocol 0 or 1 in Python 2.3 should be unpicklable by + # Python 2.2). + cls = args[0] + if not hasattr(cls, "__new__"): + raise PicklingError( + "args[0] from __newobj__ args has no __new__") + if obj is not None and cls is not obj.__class__: + raise PicklingError( + "args[0] from __newobj__ args has the wrong class") + args = args[1:] + save(cls) + save(args) + write(NEWOBJ) + else: + save(func) + save(args) + write(REDUCE) + + if obj is not None: + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + if id(obj) in self.memo: + write(POP + self.get(self.memo[id(obj)][0])) + else: + self.memoize(obj) + + # More new special cases (that work with older protocols as + # well): when __reduce__ returns a tuple with 4 or 5 items, + # the 4th and 5th item should be iterators that provide list + # items and dict items (as (key, value) tuples), or None. + + if listitems is not None: + self._batch_appends(listitems) + + if dictitems is not None: + self._batch_setitems(dictitems) + + if state is not None: + if state_setter is None: + save(state) + write(BUILD) + else: + # If a state_setter is specified, call it instead of load_build + # to update obj's with its previous state. + # First, push state_setter and its tuple of expected arguments + # (obj, state) onto the stack. + save(state_setter) + save(obj) # simple BINGET opcode as obj is already memoized. + save(state) + write(TUPLE2) + # Trigger a state_setter(obj, state) function call. + write(REDUCE) + # The purpose of state_setter is to carry-out an + # inplace modification of obj. We do not care about what the + # method might return, so its output is eventually removed from + # the stack. + write(POP) + + # Methods below this point are dispatched through the dispatch table + + dispatch = {} + + def save_none(self, obj): + self.write(NONE) + dispatch[type(None)] = save_none + + def save_bool(self, obj): + if self.proto >= 2: + self.write(NEWTRUE if obj else NEWFALSE) + else: + self.write(TRUE if obj else FALSE) + dispatch[bool] = save_bool + + def save_long(self, obj): + if self.bin: + # If the int is small enough to fit in a signed 4-byte 2's-comp + # format, we can store it more efficiently than the general + # case. + # First one- and two-byte unsigned ints: + if obj >= 0: + if obj <= 0xff: + self.write(BININT1 + pack("= 2: + encoded = encode_long(obj) + n = len(encoded) + if n < 256: + self.write(LONG1 + pack("d', obj)) + else: + self.write(FLOAT + repr(obj).encode("ascii") + b'\n') + dispatch[float] = save_float + + def save_bytes(self, obj): + if self.proto < 3: + if not obj: # bytes object is empty + self.save_reduce(bytes, (), obj=obj) + else: + self.save_reduce(codecs.encode, + (str(obj, 'latin1'), 'latin1'), obj=obj) + return + n = len(obj) + if n <= 0xff: + self.write(SHORT_BINBYTES + pack(" 0xffffffff and self.proto >= 4: + self._write_large_bytes(BINBYTES8 + pack("= self.framer._FRAME_SIZE_TARGET: + self._write_large_bytes(BINBYTES + pack("= self.framer._FRAME_SIZE_TARGET: + self._write_large_bytes(BYTEARRAY8 + pack("= 5") + with obj.raw() as m: + if not m.contiguous: + raise PicklingError("PickleBuffer can not be pickled when " + "pointing to a non-contiguous buffer") + in_band = True + if self._buffer_callback is not None: + in_band = bool(self._buffer_callback(obj)) + if in_band: + # Write data in-band + # XXX The C implementation avoids a copy here + if m.readonly: + self.save_bytes(m.tobytes()) + else: + self.save_bytearray(m.tobytes()) + else: + # Write data out-of-band + self.write(NEXT_BUFFER) + if m.readonly: + self.write(READONLY_BUFFER) + + dispatch[PickleBuffer] = save_picklebuffer + + def save_str(self, obj): + if self.bin: + encoded = obj.encode('utf-8', 'surrogatepass') + n = len(encoded) + if n <= 0xff and self.proto >= 4: + self.write(SHORT_BINUNICODE + pack(" 0xffffffff and self.proto >= 4: + self._write_large_bytes(BINUNICODE8 + pack("= self.framer._FRAME_SIZE_TARGET: + self._write_large_bytes(BINUNICODE + pack("= 2: + for element in obj: + save(element) + # Subtle. Same as in the big comment below. + if id(obj) in memo: + get = self.get(memo[id(obj)][0]) + self.write(POP * n + get) + else: + self.write(_tuplesize2code[n]) + self.memoize(obj) + return + + # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple + # has more than 3 elements. + write = self.write + write(MARK) + for element in obj: + save(element) + + if id(obj) in memo: + # Subtle. d was not in memo when we entered save_tuple(), so + # the process of saving the tuple's elements must have saved + # the tuple itself: the tuple is recursive. The proper action + # now is to throw away everything we put on the stack, and + # simply GET the tuple (it's already constructed). This check + # could have been done in the "for element" loop instead, but + # recursive tuples are a rare thing. + get = self.get(memo[id(obj)][0]) + if self.bin: + write(POP_MARK + get) + else: # proto 0 -- POP_MARK not available + write(POP * (n+1) + get) + return + + # No recursion. + write(TUPLE) + self.memoize(obj) + + dispatch[tuple] = save_tuple + + def save_list(self, obj): + if self.bin: + self.write(EMPTY_LIST) + else: # proto 0 -- can't use EMPTY_LIST + self.write(MARK + LIST) + + self.memoize(obj) + self._batch_appends(obj) + + dispatch[list] = save_list + + _BATCHSIZE = 1000 + + def _batch_appends(self, items): + # Helper to batch up APPENDS sequences + save = self.save + write = self.write + + if not self.bin: + for x in items: + save(x) + write(APPEND) + return + + it = iter(items) + while True: + tmp = list(islice(it, self._BATCHSIZE)) + n = len(tmp) + if n > 1: + write(MARK) + for x in tmp: + save(x) + write(APPENDS) + elif n: + save(tmp[0]) + write(APPEND) + # else tmp is empty, and we're done + if n < self._BATCHSIZE: + return + + def save_dict(self, obj): + if self.bin: + self.write(EMPTY_DICT) + else: # proto 0 -- can't use EMPTY_DICT + self.write(MARK + DICT) + + self.memoize(obj) + self._batch_setitems(obj.items()) + + dispatch[dict] = save_dict + if PyStringMap is not None: + dispatch[PyStringMap] = save_dict + + def _batch_setitems(self, items): + # Helper to batch up SETITEMS sequences; proto >= 1 only + save = self.save + write = self.write + + if not self.bin: + for k, v in items: + save(k) + save(v) + write(SETITEM) + return + + it = iter(items) + while True: + tmp = list(islice(it, self._BATCHSIZE)) + n = len(tmp) + if n > 1: + write(MARK) + for k, v in tmp: + save(k) + save(v) + write(SETITEMS) + elif n: + k, v = tmp[0] + save(k) + save(v) + write(SETITEM) + # else tmp is empty, and we're done + if n < self._BATCHSIZE: + return + + def save_set(self, obj): + save = self.save + write = self.write + + if self.proto < 4: + self.save_reduce(set, (list(obj),), obj=obj) + return + + write(EMPTY_SET) + self.memoize(obj) + + it = iter(obj) + while True: + batch = list(islice(it, self._BATCHSIZE)) + n = len(batch) + if n > 0: + write(MARK) + for item in batch: + save(item) + write(ADDITEMS) + if n < self._BATCHSIZE: + return + dispatch[set] = save_set + + def save_frozenset(self, obj): + save = self.save + write = self.write + + if self.proto < 4: + self.save_reduce(frozenset, (list(obj),), obj=obj) + return + + write(MARK) + for item in obj: + save(item) + + if id(obj) in self.memo: + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + write(POP_MARK + self.get(self.memo[id(obj)][0])) + return + + write(FROZENSET) + self.memoize(obj) + dispatch[frozenset] = save_frozenset + + def save_global(self, obj, name=None): + write = self.write + memo = self.memo + + if name is None: + name = getattr(obj, '__qualname__', None) + if name is None: + name = obj.__name__ + + module_name = whichmodule(obj, name) + try: + __import__(module_name, level=0) + module = sys.modules[module_name] + obj2, parent = _getattribute(module, name) + except (ImportError, KeyError, AttributeError): + raise PicklingError( + "Can't pickle %r: it's not found as %s.%s" % + (obj, module_name, name)) from None + else: + if obj2 is not obj: + raise PicklingError( + "Can't pickle %r: it's not the same object as %s.%s" % + (obj, module_name, name)) + + if self.proto >= 2: + code = _extension_registry.get((module_name, name)) + if code: + assert code > 0 + if code <= 0xff: + write(EXT1 + pack("= 3. + if self.proto >= 4: + self.save(module_name) + self.save(name) + write(STACK_GLOBAL) + elif parent is not module: + self.save_reduce(getattr, (parent, lastname)) + elif self.proto >= 3: + write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + + bytes(name, "utf-8") + b'\n') + else: + if self.fix_imports: + r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING + r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING + if (module_name, name) in r_name_mapping: + module_name, name = r_name_mapping[(module_name, name)] + elif module_name in r_import_mapping: + module_name = r_import_mapping[module_name] + try: + write(GLOBAL + bytes(module_name, "ascii") + b'\n' + + bytes(name, "ascii") + b'\n') + except UnicodeEncodeError: + raise PicklingError( + "can't pickle global identifier '%s.%s' using " + "pickle protocol %i" % (module, name, self.proto)) from None + + self.memoize(obj) + + def save_type(self, obj): + if obj is type(None): + return self.save_reduce(type, (None,), obj=obj) + elif obj is type(NotImplemented): + return self.save_reduce(type, (NotImplemented,), obj=obj) + elif obj is type(...): + return self.save_reduce(type, (...,), obj=obj) + return self.save_global(obj) + + dispatch[FunctionType] = save_global + dispatch[type] = save_type + + +# Unpickling machinery + + +class _Unpickler: # pragma: no cover + + def __init__(self, file, *, fix_imports=True, + encoding="ASCII", errors="strict", buffers=None): + """This takes a binary file for reading a pickle data stream. + + The protocol version of the pickle is detected automatically, so + no proto argument is needed. + + The argument *file* must have two methods, a read() method that + takes an integer argument, and a readline() method that requires + no arguments. Both methods should return bytes. Thus *file* + can be a binary file object opened for reading, an io.BytesIO + object, or any other custom object that meets this interface. + + The file-like object must have two methods, a read() method + that takes an integer argument, and a readline() method that + requires no arguments. Both methods should return bytes. + Thus file-like object can be a binary file object opened for + reading, a BytesIO object, or any other custom object that + meets this interface. + + If *buffers* is not None, it should be an iterable of buffer-enabled + objects that is consumed each time the pickle stream references + an out-of-band buffer view. Such buffers have been given in order + to the *buffer_callback* of a Pickler object. + + If *buffers* is None (the default), then the buffers are taken + from the pickle stream, assuming they are serialized there. + It is an error for *buffers* to be None if the pickle stream + was produced with a non-None *buffer_callback*. + + Other optional arguments are *fix_imports*, *encoding* and + *errors*, which are used to control compatibility support for + pickle stream generated by Python 2. If *fix_imports* is True, + pickle will try to map the old Python 2 names to the new names + used in Python 3. The *encoding* and *errors* tell pickle how + to decode 8-bit string instances pickled by Python 2; these + default to 'ASCII' and 'strict', respectively. *encoding* can be + 'bytes' to read theses 8-bit string instances as bytes objects. + """ + self._buffers = iter(buffers) if buffers is not None else None + self._file_readline = file.readline + self._file_read = file.read + self.memo = {} + self.encoding = encoding + self.errors = errors + self.proto = 0 + self.fix_imports = fix_imports + + def load(self, tensor_name=None): + """Read a pickled object representation from the open file. + + Return the reconstituted object hierarchy specified in the file. + """ + # Check whether Unpickler was initialized correctly. This is + # only needed to mimic the behavior of _pickle.Unpickler.dump(). + + if not hasattr(self, "_file_read"): + raise UnpicklingError("Unpickler.__init__() was not called by " + "%s.__init__()" % (self.__class__.__name__,)) + self.tensor_name = tensor_name + self._unframer = _Unframer(self._file_read, self._file_readline) + self.read = self._unframer.read + self.readinto = self._unframer.readinto + self.readline = self._unframer.readline + self.metastack = [] + self.stack = [] + self.append = self.stack.append + self.proto = 0 + read = self.read + dispatch = self.dispatch + try: + while True: + key = read(1) + if not key: + raise EOFError + assert isinstance(key, bytes_types) + dispatch[key[0]](self) + except _Stop as stopinst: + return stopinst.value + + # Return a list of items pushed in the stack after last MARK instruction. + def pop_mark(self): + items = self.stack + self.stack = self.metastack.pop() + self.append = self.stack.append + return items + + def persistent_load(self, pid): + raise UnpicklingError("unsupported persistent id encountered") + + dispatch = {} + + def load_proto(self): + proto = self.read(1)[0] + if not 0 <= proto <= HIGHEST_PROTOCOL: + raise ValueError("unsupported pickle protocol: %d" % proto) + self.proto = proto + dispatch[PROTO[0]] = load_proto + + def load_frame(self): + frame_size, = unpack(' sys.maxsize: + raise ValueError("frame size > sys.maxsize: %d" % frame_size) + self._unframer.load_frame(frame_size) + dispatch[FRAME[0]] = load_frame + + def load_persid(self): + try: + pid = self.readline()[:-1].decode("ascii") + except UnicodeDecodeError: + raise UnpicklingError( + "persistent IDs in protocol 0 must be ASCII strings") + self.append(self.persistent_load(pid)) + dispatch[PERSID[0]] = load_persid + + def load_binpersid(self): + pid = self.stack.pop() + self.append(self.persistent_load(pid)) + dispatch[BINPERSID[0]] = load_binpersid + + def load_none(self): + self.append(None) + dispatch[NONE[0]] = load_none + + def load_false(self): + self.append(False) + dispatch[NEWFALSE[0]] = load_false + + def load_true(self): + self.append(True) + dispatch[NEWTRUE[0]] = load_true + + def load_int(self): + data = self.readline() + if data == FALSE[1:]: + val = False + elif data == TRUE[1:]: + val = True + else: + val = int(data, 0) + self.append(val) + dispatch[INT[0]] = load_int + + def load_binint(self): + self.append(unpack('d', self.read(8))[0]) + dispatch[BINFLOAT[0]] = load_binfloat + + def _decode_string(self, value): + # Used to allow strings from Python 2 to be decoded either as + # bytes or Unicode strings. This should be used only with the + # STRING, BINSTRING and SHORT_BINSTRING opcodes. + if self.encoding == "bytes": + return value + else: + return value.decode(self.encoding, self.errors) + + def load_string(self): + data = self.readline()[:-1] + # Strip outermost quotes + if len(data) >= 2 and data[0] == data[-1] and data[0] in b'"\'': + data = data[1:-1] + else: + raise UnpicklingError("the STRING opcode argument must be quoted") + self.append(self._decode_string(codecs.escape_decode(data)[0])) + dispatch[STRING[0]] = load_string + + def load_binstring(self): + # Deprecated BINSTRING uses signed 32-bit length + len, = unpack(' maxsize: + raise UnpicklingError("BINBYTES exceeds system's maximum size " + "of %d bytes" % maxsize) + self.append(self.read(len)) + dispatch[BINBYTES[0]] = load_binbytes + + def load_unicode(self): + self.append(str(self.readline()[:-1], 'raw-unicode-escape')) + dispatch[UNICODE[0]] = load_unicode + + def load_binunicode(self): + len, = unpack(' maxsize: + raise UnpicklingError("BINUNICODE exceeds system's maximum size " + "of %d bytes" % maxsize) + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[BINUNICODE[0]] = load_binunicode + + def load_binunicode8(self): + len, = unpack(' maxsize: + raise UnpicklingError("BINUNICODE8 exceeds system's maximum size " + "of %d bytes" % maxsize) + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[BINUNICODE8[0]] = load_binunicode8 + + def load_binbytes8(self): + len, = unpack(' maxsize: + raise UnpicklingError("BINBYTES8 exceeds system's maximum size " + "of %d bytes" % maxsize) + self.append(self.read(len)) + dispatch[BINBYTES8[0]] = load_binbytes8 + + def load_bytearray8(self): + len, = unpack(' maxsize: + raise UnpicklingError("BYTEARRAY8 exceeds system's maximum size " + "of %d bytes" % maxsize) + b = bytearray(len) + self.readinto(b) + self.append(b) + dispatch[BYTEARRAY8[0]] = load_bytearray8 + + def load_next_buffer(self): + if self._buffers is None: + raise UnpicklingError("pickle stream refers to out-of-band data " + "but no *buffers* argument was given") + try: + buf = next(self._buffers) + except StopIteration: + raise UnpicklingError("not enough out-of-band buffers") + self.append(buf) + dispatch[NEXT_BUFFER[0]] = load_next_buffer + + def load_readonly_buffer(self): + buf = self.stack[-1] + with memoryview(buf) as m: + if not m.readonly: + self.stack[-1] = m.toreadonly() + dispatch[READONLY_BUFFER[0]] = load_readonly_buffer + + def load_short_binstring(self): + len = self.read(1)[0] + data = self.read(len) + self.append(self._decode_string(data)) + dispatch[SHORT_BINSTRING[0]] = load_short_binstring + + def load_short_binbytes(self): + len = self.read(1)[0] + self.append(self.read(len)) + dispatch[SHORT_BINBYTES[0]] = load_short_binbytes + + def load_short_binunicode(self): + len = self.read(1)[0] + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode + + def load_tuple(self): + items = self.pop_mark() + self.append(tuple(items)) + dispatch[TUPLE[0]] = load_tuple + + def load_empty_tuple(self): + self.append(()) + dispatch[EMPTY_TUPLE[0]] = load_empty_tuple + + def load_tuple1(self): + self.stack[-1] = (self.stack[-1],) + dispatch[TUPLE1[0]] = load_tuple1 + + def load_tuple2(self): + self.stack[-2:] = [(self.stack[-2], self.stack[-1])] + dispatch[TUPLE2[0]] = load_tuple2 + + def load_tuple3(self): + self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])] + dispatch[TUPLE3[0]] = load_tuple3 + + def load_empty_list(self): + self.append([]) + dispatch[EMPTY_LIST[0]] = load_empty_list + + def load_empty_dictionary(self): + self.append({}) + dispatch[EMPTY_DICT[0]] = load_empty_dictionary + + def load_empty_set(self): + self.append(set()) + dispatch[EMPTY_SET[0]] = load_empty_set + + def load_frozenset(self): + items = self.pop_mark() + self.append(frozenset(items)) + dispatch[FROZENSET[0]] = load_frozenset + + def load_list(self): + items = self.pop_mark() + self.append(items) + dispatch[LIST[0]] = load_list + + def load_dict(self): + items = self.pop_mark() + d = {items[i]: items[i+1] + for i in range(0, len(items), 2)} + self.append(d) + dispatch[DICT[0]] = load_dict + + # INST and OBJ differ only in how they get a class object. It's not + # only sensible to do the rest in a common routine, the two routines + # previously diverged and grew different bugs. + # klass is the class to instantiate, and k points to the topmost mark + # object, following which are the arguments for klass.__init__. + def _instantiate(self, klass, args): + if (args or not isinstance(klass, type) or + hasattr(klass, "__getinitargs__")): + try: + value = klass(*args) + except TypeError as err: + raise TypeError("in constructor for %s: %s" % + (klass.__name__, str(err)), sys.exc_info()[2]) + else: + value = klass.__new__(klass) + self.append(value) + + def load_inst(self): + module = self.readline()[:-1].decode("ascii") + name = self.readline()[:-1].decode("ascii") + klass = self.find_class(module, name) + self._instantiate(klass, self.pop_mark()) + dispatch[INST[0]] = load_inst + + def load_obj(self): + # Stack is ... markobject classobject arg1 arg2 ... + args = self.pop_mark() + cls = args.pop(0) + self._instantiate(cls, args) + dispatch[OBJ[0]] = load_obj + + def load_newobj(self): + args = self.stack.pop() + cls = self.stack.pop() + obj = cls.__new__(cls, *args) + self.append(obj) + dispatch[NEWOBJ[0]] = load_newobj + + def load_newobj_ex(self): + kwargs = self.stack.pop() + args = self.stack.pop() + cls = self.stack.pop() + obj = cls.__new__(cls, *args, **kwargs) + self.append(obj) + dispatch[NEWOBJ_EX[0]] = load_newobj_ex + + def load_global(self): + module = self.readline()[:-1].decode("utf-8") + name = self.readline()[:-1].decode("utf-8") + klass = self.find_class(module, name) + self.append(klass) + dispatch[GLOBAL[0]] = load_global + + def load_stack_global(self): + name = self.stack.pop() + module = self.stack.pop() + if type(name) is not str or type(module) is not str: + raise UnpicklingError("STACK_GLOBAL requires str") + self.append(self.find_class(module, name)) + dispatch[STACK_GLOBAL[0]] = load_stack_global + + def load_ext1(self): + code = self.read(1)[0] + self.get_extension(code) + dispatch[EXT1[0]] = load_ext1 + + def load_ext2(self): + code, = unpack('= 4: + return _getattribute(sys.modules[module], name)[0] + else: + return getattr(sys.modules[module], name) + + def load_reduce(self): + stack = self.stack + args = stack.pop() + func = stack[-1] + if len(args) > 0 and args[0] is None: + stack[-1] = None + else: + stack[-1] = func(*args) + # stack[-1] = func(*args) + dispatch[REDUCE[0]] = load_reduce + + def load_pop(self): + if self.stack: + del self.stack[-1] + else: + self.pop_mark() + dispatch[POP[0]] = load_pop + + def load_pop_mark(self): + self.pop_mark() + dispatch[POP_MARK[0]] = load_pop_mark + + def load_dup(self): + self.append(self.stack[-1]) + dispatch[DUP[0]] = load_dup + + def load_get(self): + i = int(self.readline()[:-1]) + self.append(self.memo[i]) + dispatch[GET[0]] = load_get + + def load_binget(self): + i = self.read(1)[0] + self.append(self.memo[i]) + dispatch[BINGET[0]] = load_binget + + def load_long_binget(self): + i, = unpack(' maxsize: + raise ValueError("negative LONG_BINPUT argument") + self.memo[i] = self.stack[-1] + dispatch[LONG_BINPUT[0]] = load_long_binput + + def load_memoize(self): + memo = self.memo + memo[len(memo)] = self.stack[-1] + dispatch[MEMOIZE[0]] = load_memoize + + def load_append(self): + stack = self.stack + value = stack.pop() + list = stack[-1] + list.append(value) + dispatch[APPEND[0]] = load_append + + def load_appends(self): + items = self.pop_mark() + list_obj = self.stack[-1] + try: + extend = list_obj.extend + except AttributeError: + pass + else: + extend(items) + return + # Even if the PEP 307 requires extend() and append() methods, + # fall back on append() if the object has no extend() method + # for backward compatibility. + append = list_obj.append + for item in items: + append(item) + dispatch[APPENDS[0]] = load_appends + + def load_setitem(self): + stack = self.stack + value = stack.pop() + key = stack.pop() + dict = stack[-1] + dict[key] = value + dispatch[SETITEM[0]] = load_setitem + + def load_setitems(self): + items = self.pop_mark() + dict = self.stack[-1] + for i in range(0, len(items), 2): + dict[items[i]] = items[i + 1] + dispatch[SETITEMS[0]] = load_setitems + + def load_additems(self): + items = self.pop_mark() + set_obj = self.stack[-1] + if isinstance(set_obj, set): + set_obj.update(items) + else: + add = set_obj.add + for item in items: + add(item) + dispatch[ADDITEMS[0]] = load_additems + + def load_build(self): + stack = self.stack + state = stack.pop() + inst = stack[-1] + setstate = getattr(inst, "__setstate__", None) + if setstate is not None: + setstate(state) + return + slotstate = None + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + if state: + inst_dict = inst.__dict__ + intern = sys.intern + for k, v in state.items(): + if type(k) is str: + inst_dict[intern(k)] = v + else: + inst_dict[k] = v + if slotstate: + for k, v in slotstate.items(): + setattr(inst, k, v) + dispatch[BUILD[0]] = load_build + + def load_mark(self): + self.metastack.append(self.stack) + self.stack = [] + self.append = self.stack.append + dispatch[MARK[0]] = load_mark + + def load_stop(self): + value = self.stack.pop() + raise _Stop(value) + dispatch[STOP[0]] = load_stop + + +# Shorthands + +def _dump(obj, file, protocol=None, *, fix_imports=True, buffer_callback=None): # pragma: no cover + _Pickler(file, protocol, fix_imports=fix_imports, + buffer_callback=buffer_callback).dump(obj) + +def _dumps(obj, protocol=None, *, fix_imports=True, buffer_callback=None): # pragma: no cover + f = io.BytesIO() + _Pickler(f, protocol, fix_imports=fix_imports, + buffer_callback=buffer_callback).dump(obj) + res = f.getvalue() + assert isinstance(res, bytes_types) + return res + +def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict", + buffers=None): # pragma: no cover + return _Unpickler(file, fix_imports=fix_imports, buffers=buffers, + encoding=encoding, errors=errors).load() + +def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict", + buffers=None): # pragma: no cover + if isinstance(s, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(s) + return _Unpickler(file, fix_imports=fix_imports, buffers=buffers, + encoding=encoding, errors=errors).load() + +# Use the faster _pickle if possible +Pickler, Unpickler = _Pickler, _Unpickler +dump, dumps, load, loads = _dump, _dumps, _load, _loads + +# Doctest +def _test(): # pragma: no cover + import doctest + return doctest.testmod() + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser( + description='display contents of the pickle files') + parser.add_argument( + 'pickle_file', type=argparse.FileType('br'), + nargs='*', help='the pickle file') + parser.add_argument( + '-t', '--test', action='store_true', + help='run self-test suite') + parser.add_argument( + '-v', action='store_true', + help='run verbosely; only affects self-test run') + args = parser.parse_args() + if args.test: + _test() + else: + if not args.pickle_file: + parser.print_help() + else: + import pprint + for f in args.pickle_file: + obj = load(f) + pprint.pprint(obj) diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py new file mode 100644 index 00000000000..4519a7a091e --- /dev/null +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Layer wise quantization.""" +import os +import gc +import time +import shutil +from copy import deepcopy +from tqdm import tqdm + +from .utils import torch +from ..model_wrapper import QDQLayer +from torch.quantization import prepare, convert +from accelerate.utils import set_module_tensor_to_device +from .utils import _get_path, get_named_children, update_module, load_tensor_from_shard, load_tensor + +from neural_compressor.config import default_workspace + +TMP_DIR = f'{default_workspace}/layer_wise_quant_tmp_dir_{time.time()}' + + +def mk_tmp_dir(): + os.makedirs(TMP_DIR, exist_ok=True) + + +def del_tmp_dir(): + shutil.rmtree(TMP_DIR) + + +class LayerWiseQuant: + """Layer wise quantization. + Layer-by-layer quantize the model, in order to save memomery. + """ + def __init__(self, q_model, pretrained_model_name_or_path, op_cfgs, + output_dir=None, device='cpu'): + """Init LayerWiseQuant.""" + # self.q_model = load_shell(pretrained_model_name_or_path, cls) + self.q_model = q_model + self.fp32_model = deepcopy(self.q_model) + self.path = _get_path(pretrained_model_name_or_path) + self.op_cfgs = op_cfgs + self.output_dir = output_dir + if self.output_dir: + os.makedirs(self.output_dir, exist_ok=True) + self.modules = get_named_children(self.q_model) + self.device = device + self._handle = {} + + def quantize(self, calib_data, clean_weight=True): + """The main entry of layer wise quantization.""" + mk_tmp_dir() + self._layer_wise_quantize(calib_data) + if self.output_dir: + self._save(self.output_dir, clean_weight=clean_weight) + else: + self._convert(clean_weight=clean_weight) + del_tmp_dir() + return self.q_model + + def _layer_wise_quantize(self, calib_data): + for idx, (name, module) in enumerate(self.modules): + qconfig = self.op_cfgs.module_name_qconfigs.get(name + '.module') + # if qconfig: + if module.__class__.__name__ in ['Linear']: + module = QDQLayer(module) + self.modules[idx] = (name, module) + update_module(self.q_model, name, module) + # module.qconfig = self.qconfig + module.qconfig = qconfig + self._regist_hooks() + + self.q_model.eval() + with torch.no_grad(): + if isinstance(calib_data, torch.Tensor): + self.q_model(calib_data) + elif isinstance(calib_data, torch.utils.data.dataloader.DataLoader): + pbar = tqdm(enumerate(calib_data), total=len(calib_data)) + try: + for idx, input in pbar: + pbar.set_description(f'iter {idx}') + self.q_model(**input) + except Exception: # pragma: no cover + for idx, (input, label) in pbar: + self.q_model(**input) + else: + self.q_model(**calib_data) + self._remove_hooks() + + def _save(self, path=None, clean_weight=True): # pragma: no cover + if path is None: + path = TMP_DIR + for name, module in self.modules: + self._load_state_dict(name, TMP_DIR) + new_module = convert(module, inplace=False) + torch.save(new_module, os.path.join(path, f'{name}.pt')) + del new_module + if clean_weight: + self._clean_weight(module, name) + torch.save(self.fp32_model, os.path.join(path, 'model_arch.pt')) + + def _convert(self, clean_weight=False): + for name, module in self.modules: + self._load_state_dict(name, TMP_DIR) + convert(module, inplace=True) + if clean_weight: + self._clean_weight(module, name) + + def _regist_hooks(self): + def forward_pre_hook(name): + def load_value(param_name): + if 'lm_head' in param_name and getattr(self.q_model.config, "tie_word_embeddings", True): + input_embeddings = self.q_model.get_input_embeddings() + for name, module in self.modules: + if module == input_embeddings: + param_name = name + '.' + param_name.split('.')[-1] + prefix = self.q_model.base_model_prefix + if 'pytorch_model.bin.index.json' in os.listdir(self.path): + value = load_tensor_from_shard(self.path, param_name, prefix) + else: + value = load_tensor(os.path.join(self.path, 'pytorch_model.bin'), param_name, prefix) + return value + + def hook(module, input): + file_path = os.path.join(TMP_DIR, f'{name}.pt') + if os.path.exists(file_path): + self._load_state_dict(name, TMP_DIR) + else: + if isinstance(module, QDQLayer): + for n, _ in module.module.named_parameters(): + value = load_value(name + '.' + n) + set_module_tensor_to_device(self.q_model, name + '.module.' + n, self.device, value) + prepare(module, inplace=True) + else: + for n, p in module.named_parameters(): + param_name = name + '.' + n + value = load_value(param_name) + # from hf transformers.modeling_utils._load_state_dict_into_meta_model + set_module_tensor_to_device(self.q_model, param_name, self.device, value) + return hook + + def forward_hook(name): + def hook(module, input, output): + file_path = os.path.join(TMP_DIR, f'{name}.pt') + if os.path.exists(TMP_DIR): + torch.save(module.state_dict(), file_path) + self._clean_weight(module, name) + return hook + + for name, module in self.modules: + self._handle[name] = [module.register_forward_pre_hook(forward_pre_hook(name))] + self._handle[name] += [module.register_forward_hook(forward_hook(name))] + + def _remove_hooks(self): + for handle in self._handle.values(): + [h.remove() for h in handle] + + def _clean_weight(self, module, name): + if isinstance(module, QDQLayer): + submodule = module.module + else: + submodule = module + + for n, m in submodule.named_parameters(): + is_buffer = n in submodule._buffers + old_value = getattr(submodule, n) + with torch.no_grad(): + if is_buffer: + submodule._buffers[n] = torch.zeros([0], device="meta") + else: + param_cls = type(submodule._parameters[n]) + kwargs = submodule._parameters[n].__dict__ + new_value = torch.zeros([0], device="meta") + new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to("meta") + submodule._parameters[n] = new_value + gc.collect() + + + def _load_state_dict(self, module_name, weight_path): + file_path = os.path.join(weight_path, f'{module_name}.pt') + state_dict = torch.load(file_path) + for n, p in state_dict.items(): + set_module_tensor_to_device(self.q_model, f'{module_name}.{n}', self.device, p) diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/torch_load.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/torch_load.py new file mode 100644 index 00000000000..a1f9c0b3548 --- /dev/null +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/torch_load.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Load one specify tensor from a bin file.""" + +import io +import os +import warnings +from typing import Any, BinaryIO, Callable, Dict, Optional, Union, IO + +from .utils import torch +from neural_compressor.adaptor.torch_utils.layer_wise_quant import modified_pickle as pickle + +from torch.serialization import ( + _get_restore_location, + StorageType, + _maybe_decode_ascii, + _open_file_like, + _open_zipfile_reader, + _is_torchscript_zip, + _is_zipfile +) + +FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]] +MAP_LOCATION = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], + torch.device, str, Dict[str, str]]] + + +def _load(zip_file, tensor_name, prefix, map_location, pickle_module, + pickle_file='data.pkl', **pickle_load_args): + restore_location = _get_restore_location(map_location) + + loaded_storages = {} + + from packaging.version import Version + torch_version = torch.__version__.split('+')[0] + version = Version(torch_version) + + def load_tensor(dtype, numel, key, location): + name = f'data/{key}' + + if version.release < Version("2.0.0").release: # pragma: no cover + storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage().untyped() + typed_storage = torch.storage.TypedStorage( + wrap_storage=restore_location(storage, location), + dtype=dtype) + loaded_storages[key] = typed_storage + else: + storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)\ + ._typed_storage()._untyped_storage + typed_storage = torch.storage.TypedStorage( + wrap_storage=restore_location(storage, location), + dtype=dtype, + _internal=True) + + if typed_storage._data_ptr() != 0: + loaded_storages[key] = typed_storage + + return typed_storage + + # def persistent_load(saved_id): + # print(saved_id) + # assert isinstance(saved_id, tuple) + # typename = _maybe_decode_ascii(saved_id[0]) + # data = saved_id[1:] + + # assert typename == 'storage', \ + # f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + # storage_type, key, location, numel = data + # if storage_type is torch.UntypedStorage: + # dtype = torch.uint8 + # else: + # dtype = storage_type.dtype + + # if key in loaded_storages: + # typed_storage = loaded_storages[key] + # else: + # nbytes = numel * torch._utils._element_size(dtype) + # typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) + + # return typed_storage + + load_module_mapping: Dict[str, str] = { + 'torch.tensor': 'torch._tensor' + } + + # Need to subclass Unpickler instead of directly monkey-patching the find_class method + # because it's marked readonly in pickle. + # The type: ignore is because mypy can't statically determine the type of this class. + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if type(name) is str and 'Storage' in name: + try: + return StorageType(name) + except KeyError: # pragma: no cover + pass + mod_name = load_module_mapping.get(mod_name, mod_name) + return super().find_class(mod_name, name) + + def persistent_load(self, saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + assert typename == 'storage', \ + f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + storage_type, key, location, numel = data + if storage_type is torch.UntypedStorage: # pragma: no cover + dtype = torch.uint8 + else: + dtype = storage_type.dtype + + if key in loaded_storages: + typed_storage = loaded_storages[key] + else: + name_list = [self.tensor_name] + if prefix: + no_prefix_name = self.tensor_name.split('.') + no_prefix_name.remove(prefix) + no_prefix_name = '.'.join(no_prefix_name) + name_list.append(no_prefix_name) + if self.tensor_name and self.metastack[-1][-2] not in name_list: + # typed_storage = None + # loaded_storages[key] = typed_storage + # nbytes = numel * torch._utils._element_size(dtype) + # typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) + typed_storage = None + else: + nbytes = numel * torch._utils._element_size(dtype) + typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) + + return typed_storage + # Load the data (which may in turn use `persistent_load` to load tensors) + data_file = io.BytesIO(zip_file.get_record(pickle_file)) + + unpickler = UnpicklerWrapper(data_file, **pickle_load_args) + # unpickler.persistent_load = persistent_load + result = unpickler.load(tensor_name) + + torch._utils._validate_loaded_sparse_tensors() + return result + + +def load( + f: FILE_LIKE, + tensor_name: str = None, + prefix: str = None, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: bool = False, + **pickle_load_args: Any +) -> Any: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. `>> # xdoctest: +SKIP("undefined filepaths") + >>> torch.load('tensors.pt') + # Load all tensors onto the CPU + >>> torch.load('tensors.pt', map_location=torch.device('cpu')) + # Load all tensors onto the CPU, using a function + >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) + # Load all tensors onto GPU 1 + >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) + # Map tensors from GPU 1 to GPU 0 + >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}) + # Load tensor from io.BytesIO object + >>> with open('tensor.pt', 'rb') as f: + ... buffer = io.BytesIO(f.read()) + >>> torch.load(buffer) + # Load a module with 'ascii' encoding for unpickling + >>> torch.load('module.pt', encoding='ascii') + """ + torch._C._log_api_usage_once("torch.load") + # Add ability to force safe only weight loads via environment variable + if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() \ + in ['1', 'y', 'yes', 'true']: # pragma: no cover + weights_only = True + + if weights_only: # pragma: no cover + if pickle_module is not None: + raise RuntimeError("Can not safely load weights when explicit pickle_module is specified") + else: + if pickle_module is None: + pickle_module = pickle + + if 'encoding' not in pickle_load_args.keys(): + pickle_load_args['encoding'] = 'utf-8' + + with _open_file_like(f, 'rb') as opened_file: + if _is_zipfile(opened_file): + # The zipfile reader is going to advance the current file position. + # If we want to actually tail call to torch.jit.load, we need to + # reset back to the original position. + orig_position = opened_file.tell() + with _open_zipfile_reader(opened_file) as opened_zipfile: + if _is_torchscript_zip(opened_zipfile): # pragma: no cover + warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive" + " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" + " silence this warning)", UserWarning) + opened_file.seek(orig_position) + return torch.jit.load(opened_file, map_location=map_location) + return _load(opened_zipfile, tensor_name, prefix, map_location, pickle_module, **pickle_load_args) diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py new file mode 100644 index 00000000000..5b370a974c6 --- /dev/null +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for layer wise quantization.""" + +import os +import json + +from neural_compressor.utils.utility import LazyImport +torch = LazyImport("torch") +from accelerate import init_empty_weights +from transformers import AutoConfig +from transformers.models.auto.auto_factory import _BaseAutoModelClass + +from .torch_load import load + + +def get_children(model): + """Get all the children of given model.""" + module_list = [] + children = list(model.children()) + if len(children) == 0: + return [model] + for child in children: + module_list += get_children(child) + return module_list + + +def get_named_children(model, pre=[]): + """Get all the name and children of given model.""" + module_list = [] + if len(list(model.children())) == 0: + return [('.'.join(pre), model)] + for name, module in model.named_children(): + module_list += get_named_children(module, pre=pre + [name]) + return module_list + + +def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): + """Download hugging face model from hf hub.""" + from huggingface_hub.file_download import repo_folder_name, REGEX_COMMIT_HASH + from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, DEFAULT_REVISION + from huggingface_hub.utils import EntryNotFoundError + if cache_dir is None: + cache_dir = HUGGINGFACE_HUB_CACHE + if revision is None: + revision = DEFAULT_REVISION + if repo_type is None: + repo_type = 'model' + storage_folder = os.path.join( + cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) + ) + commit_hash = None + if REGEX_COMMIT_HASH.match(revision): + commit_hash = revision + else: + ref_path = os.path.join(storage_folder, "refs", revision) + if os.path.exists(ref_path): + with open(ref_path) as f: + commit_hash = f.read() + pointer_path = os.path.join( + storage_folder, "snapshots", commit_hash + ) + if os.path.isdir(pointer_path): + return pointer_path + else: # pragma: no cover + from huggingface_hub import snapshot_download + file_path = snapshot_download(repo_id) + return file_path + + +def load_shell(pretrained_model_name_or_path, cls): + """Load a empty model.""" + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: # pragma: no cover + path = pretrained_model_name_or_path + else: + path = dowload_hf_model(pretrained_model_name_or_path) + if cls.__base__ == _BaseAutoModelClass: + config = AutoConfig.from_pretrained(path) + with init_empty_weights(): + model = cls.from_config(config) + else: # pragma: no cover + config = cls.config_class.from_pretrained(path) + with init_empty_weights(): + model = cls(config) + model.tie_weights() + model.eval() + return model + + +def get_super_module_by_name(model, module_name): + """Get the father module with given name of child module.""" + name_list = module_name.split(".") + for name in name_list[:-1]: + if hasattr(model, name): + model = getattr(model, name) + else: # pragma: no cover + return None + if hasattr(model, name_list[-1]): + return model + else: # pragma: no cover + return None + + +def update_module(model, module_name, new_module): + """Update module.""" + super_module = get_super_module_by_name(model, module_name) + if super_module: + setattr(super_module, module_name.split('.')[-1], new_module) + + +def load_layer_wise_quantized_model(path): # pragma: no cover + """Load layer wise quantized model.""" + model = torch.load(os.path.join(path, 'model_arch.pt')) + for name, _ in model.named_modules(): + if name + '.pt' in os.listdir(path): + update_module(model, name, torch.load(os.path.join(path, name + '.pt'))) + model.eval() + return model + + +def load_tensor_from_shard(pretrained_model_name_or_path, + tensor_name, + prefix=None): # pragma: no cover + """Load tensor from shard.""" + path = _get_path(pretrained_model_name_or_path) + idx_dict = json.load(open(os.path.join(path, 'pytorch_model.bin.index.json'), 'r'))['weight_map'] + if tensor_name not in idx_dict.keys(): + if tensor_name.replace(f'{prefix}.', '') in idx_dict.keys(): + tensor_name = tensor_name.replace(f'{prefix}.', '') + else: + assert False, '{} not in the index.json'.format(tensor_name) + return load_tensor(os.path.join(path, idx_dict[tensor_name]), tensor_name, None) + + +def load_tensor(path, tensor_name=None, prefix=None): + """Load a tensor from bin file with given tensor name.""" + # transformers.modeling_utils + if tensor_name: + if "gamma" in tensor_name: # pragma: no cover + tensor_name = tensor_name.replace("gamma", "weight") + if "beta" in tensor_name: # pragma: no cover + tensor_name = tensor_name.replace("beta", "bias") + + if os.path.isdir(path): + path = os.path.join(path, 'pytorch_model.bin') + state_dict = load(path, tensor_name, prefix) + if tensor_name: + if tensor_name in state_dict: + return state_dict[tensor_name] + else: # pragma: no cover + return state_dict[tensor_name.replace(f'{prefix}.', '')] + else: # pragma: no cover + return state_dict + +def _get_path(pretrained_model_name_or_path): + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: # pragma: no cover + path = pretrained_model_name_or_path + else: + path = dowload_hf_model(pretrained_model_name_or_path) + return path diff --git a/neural_compressor/adaptor/torch_utils/model_wrapper.py b/neural_compressor/adaptor/torch_utils/model_wrapper.py index 153fddcba49..93480d3e6c2 100644 --- a/neural_compressor/adaptor/torch_utils/model_wrapper.py +++ b/neural_compressor/adaptor/torch_utils/model_wrapper.py @@ -68,6 +68,20 @@ def qdq_weight(self): self.module.weight = torch.nn.Parameter(weith_qdq) +class QDQLayer(torch.nn.Module): + def __init__(self, module, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.quant = torch.ao.quantization.QuantStub() + self.module = module + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, X): + X = self.quant(X) + X = self.module(X) + X = self.dequant(X) + return X + + class SQLinearWrapper(torch.nn.Module): def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8): super().__init__() diff --git a/neural_compressor/compression/pruner/pruners/progressive.py b/neural_compressor/compression/pruner/pruners/progressive.py index 9efc6876637..8739643986c 100644 --- a/neural_compressor/compression/pruner/pruners/progressive.py +++ b/neural_compressor/compression/pruner/pruners/progressive.py @@ -133,7 +133,6 @@ def check_progressive_validity(self): f"Pruning pattern is {self.pattern.N} in {self.pattern.M}, " f"while progressive steps {self.progressive_steps} is overflowing.") else: - breakpoint() raise NotImplementedError def check_is_pruned_progressive_step(self, step): diff --git a/neural_compressor/config.py b/neural_compressor/config.py index 9205d507f1e..fde44331395 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -703,6 +703,8 @@ class _BaseQuantizationConfig: recipes: Recipes for quantiztaion, support list is as below. 'smooth_quant': whether do smooth quant 'smooth_quant_args': parameters for smooth_quant + 'layer_wise_quant': whether to use layer wise quant + 'layer_wise_quant_args': parameters for layer_wise_quant 'fast_bias_correction': whether do fast bias correction 'weight_correction': whether do weight correction 'gemm_to_matmul': whether convert gemm to matmul and add, only valid for onnx models @@ -859,6 +861,18 @@ def smooth_quant_args(val=None): else: return {} + def layer_wise_quant(val=None): + if val is not None: + return _check_value("layer_wise_quant", val, bool) + else: + return False + + def layer_wise_quant_args(val=None): + if val is not None: + return _check_value("layer_wise_quant_args", val, dict) + else: + return {} + def rtn_args(val=None): if val is not None: return _check_value("rtn_args", val, dict) @@ -946,6 +960,8 @@ def dedicated_qdq_pair(val=None): RECIPES = {"smooth_quant": smooth_quant, "smooth_quant_args": smooth_quant_args, + "layer_wise_quant": layer_wise_quant, + "layer_wise_quant_args": layer_wise_quant_args, "fast_bias_correction": fast_bias_correction, "weight_correction": weight_correction, "gemm_to_matmul": gemm_to_matmul, @@ -1140,6 +1156,7 @@ class PostTrainingQuantConfig(_BaseQuantizationConfig): recipes: Recipes for quantiztaion, support list is as below. 'smooth_quant': whether do smooth quant 'smooth_quant_args': parameters for smooth_quant + 'layer_wise_quant': whether to use layer wise quant 'fast_bias_correction': whether do fast bias correction 'weight_correction': whether do weight correction 'gemm_to_matmul': whether convert gemm to matmul and add, only valid for onnx models diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py index 308a9a9ed39..48edd6329e3 100644 --- a/neural_compressor/strategy/strategy.py +++ b/neural_compressor/strategy/strategy.py @@ -1249,6 +1249,8 @@ def _tune_cfg_converter(self, op_tuning_cfg): if isinstance(v, list) and len(v) >= 1: v = v[0] tune_cfg['recipe_cfgs'].setdefault('smooth_quant_args', {})[k] = v + if 'layer_wise_quant_args' in self.config.recipes: + tune_cfg['recipe_cfgs']['layer_wise_quant_args'] = self.config.recipes['layer_wise_quant_args'] # For tuning recipe, use the default value if it not specified by recipe tuning sampler. for recipe_name, recipe_val in self._tuning_recipes_default_values.items(): if recipe_name not in tune_cfg['recipe_cfgs']: diff --git a/neural_compressor/utils/constant.py b/neural_compressor/utils/constant.py index d2ed191f15e..d255be1db16 100644 --- a/neural_compressor/utils/constant.py +++ b/neural_compressor/utils/constant.py @@ -80,6 +80,7 @@ }, "pytorch": { 'smooth_quant': [False, True], + 'layer_wise_quant': [False, True] }, } diff --git a/test/algorithm/test_layer_wise_quant.py b/test/algorithm/test_layer_wise_quant.py new file mode 100644 index 00000000000..4a16c7f362a --- /dev/null +++ b/test/algorithm/test_layer_wise_quant.py @@ -0,0 +1,77 @@ +import unittest +import sys +import torch +from torch.utils.data import DataLoader, Dataset + +sys.path.insert(0, './') + +from transformers import AutoModelForCausalLM, AutoTokenizer +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor import PostTrainingQuantConfig, quantization + + +class TestLayerWise(unittest.TestCase): + @classmethod + def test_layer_wise(self): + + model_name_or_path = 'facebook/opt-125m' + fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + + texts = [ + "Replace me by any text you'd like.", "This is a test sentence." + ] + tokenizer.pad_token = tokenizer.eos_token + encoded_input = tokenizer(texts, return_tensors='pt', padding=True) + + class TestDataset(Dataset): + def __init__(self, encodings): + self.encodings = encodings + + def __getitem__(self, idx): + return { + key: torch.tensor(val[idx]) + for key, val in self.encodings.items() + } + + def __len__(self): + return len(self.encodings.input_ids) + + eval_dataset = TestDataset(encoded_input) + eval_dataloader = DataLoader(eval_dataset) + + conf = PostTrainingQuantConfig( + calibration_sampling_size=8, + recipes={ + "layer_wise_quant": True, + "layer_wise_quant_args": { + "model_path": "facebook/opt-125m", + "output_dir": "./saved_result" + } + }) + + q_model = quantization.fit( + fp32_model, + conf, + calib_dataloader=eval_dataloader, + eval_func=lambda x: 0.1, + ) + + def test_util(self): + from neural_compressor.adaptor.torch_utils.layer_wise_quant.utils import ( + get_children, + get_named_children, + get_super_module_by_name + ) + + model_name_or_path = 'facebook/opt-125m' + model = load_shell(model_name_or_path, AutoModelForCausalLM) + children = get_children(model) + named_children = get_named_children(model) + self.assertEqual(children, [v for k, v in named_children]) + module = get_super_module_by_name(model, named_children[0][0]) + self.assertTrue(hasattr(module, named_children[0][0].split('.')[-1])) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/requirements.txt b/test/requirements.txt index 41cf6845598..b9b11da2ca4 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -11,6 +11,7 @@ pandas tf_slim pillow>=8.2.0 transformers>=4.30.2 +accelerate tensorflow_model_optimization horovod tensorflow-addons