
Add support for weights_only flag when loading state_dict #32481

Merged: 2 commits into huggingface:main on Oct 3, 2024

Conversation

@jerryzh168 (Contributor) commented Aug 7, 2024:

Summary:
This is to enable loading a state_dict with wrapper tensor subclasses (used in torchao for quantized weights).
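
For context, a minimal sketch of the `torch.load` behavior the new flag is forwarded to (the checkpoint path here is illustrative):

```
import torch

# With weights_only=True, torch.load uses a restricted unpickler that only
# accepts plain tensors and other allowlisted types, so wrapper tensor
# subclasses (like torchao's quantized tensors) are rejected. Passing
# weights_only=False opts out of that restriction for trusted checkpoints.
state_dict = torch.load("pytorch_model.bin", weights_only=False)
```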

Test Plan:
Tested locally with torchao weights; also needs #32306:

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TorchAoConfig
from torchao.utils import benchmark_model
import torchao

DEVICE_TYPE = "cuda"

def init_model_and_benchmark(model_id, torch_dtype=torch.bfloat16, quantization_config=None):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if quantization_config is not None:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch_dtype, quantization_config=quantization_config)
    else:
        # weights_only=False is the flag added by this PR; it is needed to load
        # checkpoints whose state_dict contains wrapper tensor subclasses
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch_dtype, weights_only=False)

    # sanity check: run the model
    input_text = "What are we having for dinner?"
    input_ids = tokenizer(input_text, return_tensors="pt").to(DEVICE_TYPE)
    output = model.generate(**input_ids, max_new_tokens=1000)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    NUM_WARMUP = 1
    NUM_RUNS = 5

    if quantization_config is not None:
        torchao.quantization.utils.recommended_inductor_config_setter()

    model = torch.compile(model, mode="max-autotune")

    # one warmup pass, then the timed runs
    benchmark_model(model.generate, NUM_WARMUP, kwargs=input_ids, device_type=DEVICE_TYPE)
    print("running benchmark")
    results = benchmark_model(model.generate, NUM_RUNS, kwargs=input_ids, device_type=DEVICE_TYPE)
    return model, results

model_id = "jerryzh168/test-model"
torchao.quantization.utils.recommended_inductor_config_setter()
bf16_model, bf16_time = init_model_and_benchmark(model_id)
print(f"bf16: {bf16_time}")
```


@amyeroberts (Collaborator) left a comment:

Thanks for adding this support!

Overall, I think this looks good. Let's get a second 👍 from @ArthurZucker too, as it's touching core code.

(One inline review comment on src/transformers/modeling_utils.py, since resolved.)
@jerryzh168 (Contributor, Author):
I learned from @mikaylagawarecki that `weights_only` is going to default to `True` in the future, but I think this flag could still be helpful for people who are using older versions of PyTorch and want to use torchao.
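
For illustration, a hedged sketch of the two ways to load such a checkpoint (the file name and the import path are assumptions):

```
import torch

# Option 1: trust the checkpoint and opt out of the restricted unpickler.
state_dict = torch.load("quantized_model.pt", weights_only=False)

# Option 2 (recent PyTorch): keep weights_only=True but allowlist the
# tensor subclass so the safe unpickler accepts it.
from torchao.dtypes import AffineQuantizedTensor  # assumed import path
torch.serialization.add_safe_globals([AffineQuantizedTensor])
state_dict = torch.load("quantized_model.pt", weights_only=True)
```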

@jerryzh168 (Contributor, Author):
Hi @ArthurZucker, can you take a look at this PR?

@jerryzh168 (Contributor, Author):
Please let me know if we want to add a test. A load test is a bit harder to write because it needs some models uploaded to the Hub, like:

model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
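
For what it's worth, a minimal sketch of the shape such a test could take, reusing the tiny repo above (this only exercises the flag's plumbing; a real regression test would need an uploaded checkpoint that actually contains tensor subclasses):

```
from transformers import BertModel

def test_from_pretrained_weights_only_false():
    # tiny-random-bert loads fine either way; the point is just that the
    # new weights_only kwarg is accepted and forwarded to torch.load.
    model = BertModel.from_pretrained(
        "hf-internal-testing/tiny-random-bert", weights_only=False
    )
    assert model is not None
```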

@ArthurZucker (Collaborator) left a comment:

LGTM, sorry for being late on the review here!

@ArthurZucker (Collaborator):
Can you rebase on main and resolve the conflicts?

@HuggingFaceDocBuilderDev:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jerryzh168 (Contributor, Author) commented Sep 3, 2024:

Thanks for the review @ArthurZucker and @amyeroberts! I just updated the PR; please feel free to merge when the CI is green.

@amyeroberts (Collaborator):
@jerryzh168 Could you run `make fix-copies` and push the changes? This should resolve the failing quality checks.

@jerryzh168 (Contributor, Author):
Thanks, updated; please take a look again @amyeroberts.

@ArthurZucker (Collaborator) left a comment:

Let's revert the unrelated changes and let's go!

@ArthurZucker (Collaborator), in an inline review comment:

This unrelated change should not be included here! 🤗

@jerryzh168 (Contributor, Author):

@ArthurZucker this is the fix from running `make fix-copies` 😅

@ArthurZucker (Collaborator):

OK, no worries then, let's merge!

@ArthurZucker (Collaborator) left a comment:

Sorry for the delay! I would kindly ask you to rebase to make sure we are up to date, as other things were merged.

@ArthurZucker merged commit 15a4d24 into huggingface:main on Oct 3, 2024
21 checks passed
@ArthurZucker (Collaborator):

Thanks once more @jerryzh168

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Add support for `weights_only` flag when loading state_dict (huggingface#32481)

* Add support for `weights_only` flag when loading state_dict

* format