
Check tied parameters #1529

Merged — 10 commits merged into huggingface:main on Jun 5, 2023

Conversation

@SunMarc (Member) commented on Jun 5, 2023

What does this PR do?

This PR fixes two issues users can run into when using big model inference:

  • They pass their own device map but forget that parameters tied together must be placed on the same device. We now raise an error showing which parameters should be on the same device (see the sketch right after this list for the idea behind the check).
  • They forget to tie the parameters before calling infer_auto_device_map(), which can produce a bad device_map. We now raise an error asking them to tie the weights before using this function.
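A minimal sketch of the first check (not the PR's exact implementation): find_tied_parameters is the real accelerate helper; the sketch assumes it returns groups of tied parameter names as a list of lists, which is the format in recent accelerate versions. Each group is resolved against the device map, and an error is raised if a group spans more than one device.

from accelerate.utils import find_tied_parameters

def check_tied_params_on_same_device(model, device_map):
    # find_tied_parameters returns groups of parameter names that share storage,
    # e.g. [["lm_head.weight", "model.decoder.embed_tokens.weight"]] for OPT.
    for group in find_tied_parameters(model):
        devices = set()
        for name in group:
            # A parameter is covered by the longest device_map key that prefixes it.
            matches = [k for k in device_map if name == k or name.startswith(k + ".")]
            if matches:
                devices.add(device_map[max(matches, key=len)])
        if len(devices) > 1:
            raise ValueError(f"Tied parameters {group} should be on the same device, got {devices}.")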

How to test it

Issue 1

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "facebook/opt-350m"

# Works: every module, including the tied pair (lm_head and embed_tokens), is on the same device.
device_map_works = {
    "model.decoder.embed_tokens": "cpu",
    "model.decoder.embed_positions": "cpu",
    "model.decoder.project_out": "cpu",
    "model.decoder.project_in": "cpu",
    "model.decoder.layers": "cpu",
    "lm_head": "cpu",
}

# Does not work: lm_head is offloaded to disk while the embedding it is tied to stays on cpu.
device_map_does_not_work = {
    "model.decoder.embed_tokens": "cpu",
    "model.decoder.embed_positions": "cpu",
    "model.decoder.project_out": "cpu",
    "model.decoder.project_in": "cpu",
    "model.decoder.layers": "cpu",
    "lm_head": "disk",
}

# With this PR, the following call raises an error pointing at the tied
# parameters that ended up on different devices.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map=device_map_does_not_work,
    offload_folder="offload",
    offload_state_dict=True,
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
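To see which parameters are tied for this checkpoint, you can load the model with the working map and inspect it (the exact return format of find_tied_parameters depends on the accelerate version):

from accelerate.utils import find_tied_parameters

model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=device_map_works)
print(find_tied_parameters(model))
# Expected along the lines of: [['lm_head.weight', 'model.decoder.embed_tokens.weight']],
# which is why putting lm_head on "disk" while embed_tokens stays on "cpu" is rejected.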

Issue 2

from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map

checkpoint = "facebook/opt-350m"

config = AutoConfig.from_pretrained(checkpoint)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# The weights are not tied at this point, so the device map could wrongly split
# tied parameters; with this PR the next call raises an error asking to tie them first.
device_map = infer_auto_device_map(model, no_split_module_classes=["OPTDecoderLayer"])
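The fix on the user side is to tie the weights before inferring the device map; tie_weights() is the standard transformers method for this:

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
model.tie_weights()  # make sure tied parameters are tied before computing the device map
device_map = infer_auto_device_map(model, no_split_module_classes=["OPTDecoderLayer"])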

@SunMarc requested a review from @sgugger on Jun 5, 2023 at 15:16
@HuggingFaceDocBuilderDev commented on Jun 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment

Thanks for those QOL improvements!

8 resolved review comments on src/accelerate/utils/modeling.py (outdated)
SunMarc and others added commits on Jun 5, 2023:

  • Fix log (Co-authored-by: Sylvain Gugger <[email protected]>)
  • Fix comments and tests
  • Fix description

Review thread on src/accelerate/utils/modeling.py:
has_tied_encoder_decoder = False
has_tied_module = False

if transformers.modeling_utils.PreTrainedModel in inspect.getmro(model.__class__):
@sgugger (Collaborator) commented:

I was thinking of testing the class __name__ to avoid the extra dependency on Transformers.
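A hedged sketch of that suggestion (illustrative, not the merged code): compare class names along the MRO instead of importing transformers.

import inspect

# Avoids a hard dependency on transformers: check by class name instead of class identity.
if any(cls.__name__ == "PreTrainedModel" for cls in inspect.getmro(model.__class__)):
    ...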

@SunMarc merged commit b9628f1 into huggingface:main on Jun 5, 2023
@SunMarc deleted the check_tied_parameters branch on Jun 5, 2023 at 19:19