Add support for `weights_only` flag when loading state_dict (#32481)

Conversation
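For context, `weights_only` is an existing `torch.load` argument that this PR exposes through `from_pretrained`; below is a minimal sketch of what the flag controls (the checkpoint filename is illustrative, not part of the PR):

```python
import torch

# weights_only=True restricts unpickling to an allowlist of safe types;
# plain tensors load fine, but wrapper tensor subclasses do not.
state_dict = torch.load("pytorch_model.bin", weights_only=True)

# weights_only=False runs the full (unrestricted) unpickler, which is what
# checkpoints containing tensor subclasses currently require.
state_dict = torch.load("pytorch_model.bin", weights_only=False)
```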
Thanks for adding this support!
Overall, I think this looks good. Let's get a second 👍 from @ArthurZucker too, as it's touching core code.
I learned from @mikaylagawarecki that …
Hi @ArthurZucker, can you take a look at this PR?
Please let me know if we want to add a test, but a load test is a bit harder to write because it needs to upload some models, like the one referenced in transformers/tests/test_modeling_utils.py (line 930 in abbffc4). A hypothetical local round-trip test is sketched below.
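For illustration only, here is a hypothetical test that sidesteps the Hub-upload issue by round-tripping a tiny checkpoint locally; the test name and the tiny model id are assumptions, not the repo's actual test:

```python
import tempfile

import torch
from transformers import AutoModelForCausalLM


def test_weights_only_roundtrip():
    # A tiny model keeps the test fast; the checkpoint id is an assumption.
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    with tempfile.TemporaryDirectory() as tmp:
        # Force a pickle (.bin) checkpoint so torch.load is actually exercised.
        model.save_pretrained(tmp, safe_serialization=False)
        reloaded = AutoModelForCausalLM.from_pretrained(tmp, weights_only=False)
    for name, param in model.state_dict().items():
        assert torch.equal(param, reloaded.state_dict()[name])
```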
LGTM, sorry for being late on the review here!
Can you rebase on main and resolve the conflicts?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from ce3775a to e309bf1.
Thanks for the review @ArthurZucker and @amyeroberts. I just updated the PR; please feel free to merge when the CI is green.
@jerryzh168 Could you run `make fix-copies`?
Thanks, updated. Please take a look again @amyeroberts.
Let's revert the unrelated changes and let's go.
This unrelated change should not be included here! 🤗
@ArthurZucker this is a fix from running `make fix-copies` 😅
OK, no worries then, let's merge!
Sorry for the delay! I would kindly ask you to rebase to make sure we are up to date, as other things were merged.
Summary: This is to enable loading a state_dict with wrapper tensor subclasses (used in torchao for quantized weights).

Test Plan: Tested locally with torchao weights; also needs huggingface#32306:

```python
import torch
import torchao
from torchao.utils import benchmark_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

DEVICE_TYPE = "cuda"


def init_model_and_benchmark(model_id, torch_dtype=torch.bfloat16, quantization_config=None):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if quantization_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=DEVICE_TYPE,
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=DEVICE_TYPE,
            torch_dtype=torch.bfloat16,
            weights_only=False,
        )

    # sanity check: run the model
    input_text = "What are we having for dinner?"
    input_ids = tokenizer(input_text, return_tensors="pt").to(DEVICE_TYPE)
    output = model.generate(**input_ids, max_new_tokens=1000)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    NUM_WARMUP = 1
    NUM_RUNS = 5
    if quantization_config is not None:
        torchao.quantization.utils.recommended_inductor_config_setter()
        model = torch.compile(model, mode="max-autotune")
    benchmark_model(model.generate, NUM_WARMUP, kwargs=input_ids, device_type=DEVICE_TYPE)
    print("running benchmark")
    results = benchmark_model(model.generate, NUM_RUNS, kwargs=input_ids, device_type=DEVICE_TYPE)
    return model, results


model_id = "jerryzh168/test-model"
torchao.quantization.utils.recommended_inductor_config_setter()
bf16_model, bf16_time = init_model_and_benchmark(model_id)
print(f"bf16: {bf16_time}")
```
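As a standalone illustration of the failure mode the test plan works around (the subclass and filename here are made up, not torchao's actual classes):

```python
import torch


class MyWrapperTensor(torch.Tensor):
    """Stand-in for a torchao-style wrapper tensor subclass."""
    pass


torch.save({"w": torch.randn(2).as_subclass(MyWrapperTensor)}, "ckpt.pt")

try:
    # The restricted unpickler rejects the unknown subclass.
    torch.load("ckpt.pt", weights_only=True)
except Exception as e:
    print("weights_only=True failed:", e)

# The escape hatch this PR plumbs through from_pretrained.
sd = torch.load("ckpt.pt", weights_only=False)
print(type(sd["w"]))  # <class '__main__.MyWrapperTensor'>
```

Newer PyTorch versions also offer `torch.serialization.add_safe_globals` to allowlist specific classes while keeping `weights_only=True`; whether torchao checkpoints eventually take that path is outside the scope of this PR.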
Thanks once more @jerryzh168.