revert the hook change
wenhuach21 committed May 31, 2024
1 parent 787fa01 commit 5c94cea
Showing 7 changed files with 63 additions and 31 deletions.
18 changes: 8 additions & 10 deletions auto_round/auto_quantizer.py
@@ -274,9 +274,11 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):

def validate_environment(self, *args, **kwargs):
if not is_auto_round_available():
raise ImportError("Loading a AutoRound quantized model requires auto-round library (`pip install auto-round`)")
raise ImportError("Loading a AutoRound quantized model requires auto-round library (`pip install "
"auto-round`)")
elif version.parse(importlib.metadata.version("auto_round")) < version.parse("0.2.0"):
raise ImportError("You need a version of auto_round > 0.2.0 to use AutoRound: `pip install --upgrade auto-round`")
raise ImportError("You need a version of auto_round > 0.2.0 to use AutoRound: `pip install --upgrade "
"auto-round`")

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
@@ -381,14 +383,10 @@ def post_init_model(self, model):
The input model
"""
#
# if self.bits == 4:
# if get_device(model) == torch.device("cpu") or (
# hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
# ):
# raise ValueError(
# "Found modules on cpu/disk. Using Exllamav2 backend requires all the modules to be on GPU."
# "You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object"
# )
# if self.bits == 4: if get_device(model) == torch.device("cpu") or ( hasattr(model, "hf_device_map") and
# any(d in model.hf_device_map for d in ["cpu", "disk"]) ): raise ValueError( "Found modules on cpu/disk.
# Using Exllamav2 backend requires all the modules to be on GPU." "You can deactivate exllama backend by
# setting `disable_exllama=True` in the quantization config object" )

class StoreAttr(object):
pass
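
The two ImportError messages in validate_environment above are only re-wrapped by this commit, but the check they belong to is small enough to read on its own. A minimal standalone sketch, assuming `version.parse` comes from `packaging` and using a PackageNotFoundError catch in place of `is_auto_round_available()`:

import importlib.metadata

from packaging import version


def check_auto_round_environment(minimum: str = "0.2.0") -> None:
    # Presence check: stand-in for is_auto_round_available() in the diff.
    try:
        installed = importlib.metadata.version("auto_round")
    except importlib.metadata.PackageNotFoundError:
        raise ImportError(
            "Loading an AutoRound quantized model requires the auto-round "
            "library (`pip install auto-round`)"
        )
    # Version gate: mirrors the `< 0.2.0` comparison shown above.
    if version.parse(installed) < version.parse(minimum):
        raise ImportError(
            f"You need auto_round newer than {minimum}: `pip install --upgrade auto-round`"
        )
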
10 changes: 7 additions & 3 deletions auto_round/export/export_to_autogptq.py
@@ -52,6 +52,10 @@
@register_format("auto_gptq")
def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwargs):
"""Export the model to autogptq format to easily leverage cuda kernel."""
try:
import auto_gptq
except ImportError:
raise ImportError("export to autogptq requires autogptq library. Please run 'pip install auto-gptq'")
model = kwargs["model"]
weight_config = kwargs["weight_config"]
sym = kwargs["sym"]
@@ -95,7 +99,7 @@ def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwargs):
else:
compressed_model = copy.deepcopy(model.to("cpu"))

from auto_gptq.modeling._utils import pack_model
from auto_gptq.modeling._utils import pack_model # pylint: disable=E0401

if bits == 3 or use_triton is False:
if bits == 3 and use_triton is True:
@@ -127,7 +131,7 @@ def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwargs):
info = weight_config[key]
if not check_to_quantized(info):
continue
quantizers[key] = (None, info["scale"].to(torch.float32), info["zp"].to(torch.float32), info["g_idx"])
quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32), info["g_idx"])
pack_model(
compressed_model,
quantizers,
@@ -236,7 +240,7 @@ def _save_quantized_to_autogptq(
model_save_name = model_base_name + ".bin"
torch.save(model.state_dict(), join(save_dir, model_save_name))

from auto_gptq.modeling._base import BaseQuantizeConfig
from auto_gptq.modeling._base import BaseQuantizeConfig # pylint: disable=E0401

quantization_config = BaseQuantizeConfig(
bits=bits,
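
The try/except added at the top of save_quantized_as_autogptq turns a late ModuleNotFoundError into an actionable install hint. A generic sketch of the same guard (the helper name `require_package` is illustrative, not part of the diff):

import importlib


def require_package(module_name: str, pip_name: str):
    """Import an optional dependency or fail with an install hint."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"This export path requires `{module_name}`. Please run `pip install {pip_name}`."
        ) from exc


# Mirrors the guard added above:
# auto_gptq = require_package("auto_gptq", "auto-gptq")
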
7 changes: 4 additions & 3 deletions auto_round/export/export_to_autoround/export.py
@@ -22,7 +22,7 @@
import transformers

from auto_round.export.register import register_format
from auto_round.utils import get_layer_names_in_block, get_block_names, get_module, logger, set_module
from auto_round.utils import get_layer_names_in_block, get_module, logger, set_module


def check_neq_config(config, data_type, bits, group_size, sym):
@@ -87,15 +87,16 @@ def dynamic_QuantLienar_for_packing(backend, bits, group_size):
disable_marlin=disable_marlin,
)
return QuantLinear
elif "autoround" in backend or "auto-round" in backend or "auto_round" in backend: ##export all use trition,inferce use exllamav2
##export all use trition, inference use exllamav2
elif "autoround" in backend or "auto-round" in backend or "auto_round" in backend:
from .qliner_triton import QuantLinear
return QuantLinear

else:
assert False, f"only support gptq and autoround backend"


@register_format("autoround")
@register_format("auto_round")
def save_quantized_as_autoround(output_dir, inplace=True, backend="autoround:exllamav2", **kwargs):
model = kwargs["model"]
if not inplace:
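
dynamic_QuantLienar_for_packing dispatches purely on substrings of the backend name. The sketch below isolates that dispatch; the string return values are placeholders for the real QuantLinear classes imported in the diff:

def select_quant_linear(backend: str) -> str:
    """Toy dispatch mirroring dynamic_QuantLienar_for_packing (placeholder returns)."""
    if "gptq" in backend:
        return "gptq QuantLinear"  # the real code builds this from auto_gptq
    if any(alias in backend for alias in ("autoround", "auto-round", "auto_round")):
        # export always packs with the Triton kernel; inference later uses exllamav2
        return "triton QuantLinear"
    raise AssertionError("only gptq and autoround backends are supported")
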
38 changes: 30 additions & 8 deletions auto_round/export/export_to_autoround/post_init.py
@@ -12,8 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2023 潘其威(William)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH=2048

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048


def autoround_post_init(model):
"""
@@ -32,15 +55,15 @@ def autoround_post_init(model):
"max_inner_outer_dim": 1,
}


submodule._use_act_order = False


# Disable this heuristic for detecting act_order, but it could be used instead of the config.
"""
if submodule.g_idx is None:
submodule.act_order = False
elif submodule.g_idx is not None and ((submodule.g_idx == 0).all() or torch.equal(submodule.g_idx.cpu(), torch.tensor([i // submodule.group_size for i in range(submodule.g_idx.shape[0])], dtype=torch.int32))):
elif submodule.g_idx is not None and ((submodule.g_idx == 0).all() or
torch.equal(submodule.g_idx.cpu(),
torch.tensor([i // submodule.group_size for i in range(submodule.g_idx.shape[0])], dtype=torch.int32))):
submodule.g_idx = None
submodule.act_order = False
else:
@@ -52,19 +75,18 @@ def autoround_post_init(model):
submodule.qweight.numel() * 8,
)


if model_uses_exllama:
# To be honest this is quite ugly, not proud of this.
try:
from exllama_kernels import prepare_buffers, set_tuning_params
except ImportError as e:
raise ImportError(
f"Could not import exllama backend dependencies prepare_buffers, set_tuning_params with the following error: {e}"
f"Could not import exllama backend dependencies prepare_buffers, set_tuning_params with the following "
f"error: {e}"
)

device_to_buffers = {}


max_input_len = 1

for device, buffers_size in device_to_buffers_size.items():
@@ -129,4 +151,4 @@ def autoround_post_init(model):
submodule.post_init(temp_dq=model.device_tensors[device])
torch.cuda.empty_cache()

return model
return model
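
The long condition re-wrapped above belongs to a heuristic that the file keeps commented out: a g_idx that is all zeros or equal to the trivial pattern [i // group_size] carries no act-order information. A standalone sketch of that test (the function name is mine, not the library's):

from typing import Optional

import torch


def uses_act_order(g_idx: Optional[torch.Tensor], group_size: int) -> bool:
    if g_idx is None:
        return False
    trivial = torch.tensor(
        [i // group_size for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    # All-zero or strictly group-ordered indices mean no reordering was applied.
    if (g_idx == 0).all() or torch.equal(g_idx.cpu(), trivial):
        return False
    return True


# e.g. uses_act_order(torch.zeros(128, dtype=torch.int32), 32) -> False
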
14 changes: 10 additions & 4 deletions auto_round/export/export_to_autoround/qliner_exllamav2.py
@@ -50,7 +50,8 @@

def error_raiser_exllama(*args, **kwargs):
raise ValueError(
f"Trying to use the exllama v2 backend, but could not import the C++/CUDA dependencies with the following error: {exllama_v2_import_exception}"
f"Trying to use the exllama v2 backend, but could not import the C++/CUDA dependencies with the following "
f"error: {exllama_v2_import_exception}"
)

make_q_matrix = error_raiser_exllama
@@ -110,7 +111,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
device=w["qweight"].device,
)
w["q_invperm"] = torch.empty_like(w["q_perm"])
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs
# to be passed for g_idx.
return make_q_matrix(
w["qweight"],
w["q_perm"],
@@ -148,7 +150,8 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False
super().__init__()
if bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model "
f"initialization."
)
if trainable:
raise NotImplementedError("Exllamav2 kernel does not support training.")
@@ -217,7 +220,7 @@ def post_init(self, temp_dq):
def forward(self, x, force_cuda=False):
if x.dtype != torch.float16:
logger.warning_once(
f"The exllama v2 kernel for GPTQ requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
f"The exllama v2 kernel for GPTQ requires a float16 input activation, while {x.dtype} was passed. "
f"Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model "
f"definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 "
f"intermediate activations in the model."
)

x = x.half()
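
The re-wrapped warning in forward() reduces to a dtype guard: the exllama v2 kernel only accepts float16 activations, so anything else is cast. A dependency-light sketch (print stands in for logger.warning_once):

import torch


def ensure_half(x: torch.Tensor) -> torch.Tensor:
    if x.dtype != torch.float16:
        # The real code warns once via logger.warning_once and then casts.
        print(f"exllama v2 expects float16 activations, got {x.dtype}; casting to float16.")
        return x.half()
    return x
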
2 changes: 1 addition & 1 deletion auto_round_extension/cuda/exllamav2/cuda/util.cuh
@@ -31,7 +31,7 @@ __forceinline__ __device__ float clamp(float x, float a, float b)
return fmaxf(a, fminf(b, x));
}

#define cuda_check(and) { gpu_assert((and), __FILE__, __LINE__); }
#define cuda_check(res) { gpu_assert((res), __FILE__, __LINE__); }
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
5 changes: 3 additions & 2 deletions examples/language-modeling/main.py
@@ -151,6 +151,7 @@ def get_library_version(library_name):
except subprocess.CalledProcessError:
return "Library not found"


res = get_library_version("lm-eval")
if res == "0.3.0":
use_eval_legacy = True
@@ -289,7 +290,7 @@ def get_library_version(library_name):
f"supported currently")
break
if args.quant_lm_head:
weight_config[lm_head_layer_name] = {"data_type": "int"}
weight_config[lm_head_layer_name] = {"data_type": "int", "bits": 4, "group_size": 32}
transformers_version = [int(item) for item in transformers.__version__.split('.')[:2]]
if transformers_version[0] == 4 and transformers_version[1] < 38:
error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization."
@@ -342,4 +343,4 @@ def get_library_version(library_name):
print(excel_name, flush=True)
eval_model(model_path=output_dir, tasks=tasks, dtype=dtype, limit=None,
eval_bs=args.eval_bs, use_accelerate=not args.disable_low_gpu_mem_usage,
device=torch_device, excel_file=excel_name)
device=torch_device, excel_file=excel_name)
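
In the main.py diff above, the lm-head entry now carries explicit bits and group_size, and lm-head quantization remains gated on transformers >= 4.38. A condensed sketch of that gate; the lm_head_layer_name value and the raise are illustrative stand-ins for the script's own naming and error handling:

import transformers

weight_config = {}
lm_head_layer_name = "lm_head"  # illustrative; main.py derives the real layer name

major, minor = (int(part) for part in transformers.__version__.split(".")[:2])
if (major, minor) < (4, 38):
    raise RuntimeError("Please upgrade transformers>=4.38.0 to support lm-head quantization.")
weight_config[lm_head_layer_name] = {"data_type": "int", "bits": 4, "group_size": 32}
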
