-
Notifications
You must be signed in to change notification settings - Fork 258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add some new features for layer-wise quant #1899
Conversation
Signed-off-by: n1ck-guo <[email protected]>
⛈️ Required checks status: Has failure 🔴
Groups summary🟢 Code Scan Tests workflow
These checks are required after the changes to 🔴 Model Tests 3x workflow
These checks are required after the changes to 🔴 Unit Tests 3x-PyTorch workflow
These checks are required after the changes to Thank you for your contribution! 💜
|
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add some UT verification functionality?
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s") | ||
logger = logging.getLogger("layer_wise_tools") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s") | |
logger = logging.getLogger("layer_wise_tools") | |
from neural_compressor.torch.utils import logger |
@@ -18,19 +18,24 @@ | |||
|
|||
import gc | |||
import json | |||
import logging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import logging |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s") | ||
logger = logging.getLogger("layer_wise_tools") | ||
|
||
LWQ_WORKSPACE = os.path.join("layer_wise_tmp") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LWQ_WORKSPACE = os.path.join("layer_wise_tmp") | |
from neural_compressor.common import options | |
LWQ_WORKSPACE = os.path.join(options.workspace, "lwq_tmpdir") |
@@ -121,7 +126,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): | |||
return file_path | |||
|
|||
|
|||
def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs): | |||
def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, save_path=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
save_path
defaults to LWQ_WORKSPACE
if save_path is None: | ||
save_path = LWQ_WORKSPACE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove if use default value
pretrained_model_name_or_path, cls=AutoModelForCausalLM, device="cpu", clean_weight=True, saved_path=None, **kwargs | ||
): | ||
if saved_path is None: | ||
saved_path = LWQ_WORKSPACE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pretrained_model_name_or_path, cls=AutoModelForCausalLM, device="cpu", clean_weight=True, saved_path=None, **kwargs | |
): | |
if saved_path is None: | |
saved_path = LWQ_WORKSPACE | |
pretrained_model_name_or_path, cls=AutoModelForCausalLM, device="cpu", clean_weight=True, saved_path=LWQ_WORKSPACE, **kwargs | |
): |
m.forward = partial(_forward, m, n) | ||
|
||
try: | ||
model.forward( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it applicable to all or most models?
marked draft and will migrate to #1883 |
Type of Change
feature
Description
add some new features for layer-wise quant, include get_weight, get_bias, update, and save/load. Make it more easy to use, like a normal model.
Expected Behavior & Potential Risk
None
How has this PR been tested?
how to reproduce the test (including hardware information)
Dependency Change?
None