[pipeline] A simple fix for half-precision & 8bit models #21479

Merged: 15 commits into huggingface:main on Feb 10, 2023

Conversation

@younesbelkada (Contributor) commented Feb 6, 2023

What does this PR do?

Currently, on the main branch of transformers, if a user wants to run a pipeline with a large model (ideally loaded with device_map=...) in half precision (or in int8), they may encounter the following error when calling the pipeline with top_p & top_k sampling:

RuntimeError: "topk_cpu" not implemented for 'Half'

Snippet to reproduce & explanations:

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

model_id = "EleutherAI/gpt-neo-125M"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=20, temperature=1, do_sample=True, top_p=0.95, top_k=60, num_return_sequences=3)
text = "What can you tell me about the LHC?"
response = pipe(text)
print(response[0]["generated_text"])

This happens because the input_ids are automatically placed on cpu, since the device argument is not passed when initializing the pipeline. A model loaded with device_map=... (i.e. with accelerate) always moves its output tensors to the device of the input tensors thanks to the forward hooks. Therefore, when top_k is applied, the logits are in fp16 (because the model was loaded in fp16) and on cpu, hence the torch error above. A minimal illustration of the failing operation is shown below.
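The sketch below reproduces just the failing operation in isolation, assuming an older PyTorch build where topk has no fp16 CPU kernel; the tensor shape and k value are illustrative, not taken from the PR:

import torch

# fp16 logits that ended up on CPU because no `device` was passed to the pipeline
logits = torch.randn(1, 50257, dtype=torch.float16)

# On PyTorch versions without a half-precision CPU kernel this raises:
# RuntimeError: "topk_cpu" not implemented for 'Half'
values, indices = torch.topk(logits, k=60)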

Currently, a workaround is to pass device=0 when initializing the pipeline, but this leads to inconsistent and undesirable behaviour in some cases, for example when a large model is loaded across several GPUs, since the call to model.to(self.device) breaks some internals (the hooks remain in place, but the weights end up on the wrong devices). A snippet to reproduce is below:

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

model_id = "EleutherAI/gpt-neo-125M"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="balanced", torch_dtype=torch.float16)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=20, temperature=1, do_sample=True, top_p=0.95, top_k=60, num_return_sequences=3, device=0)
text = "What can you tell me about the LHC?"
response = pipe(text)
print(response[0]["generated_text"])

Adding this workaround also breaks pipeline usage with int8 models, since the .to method is blocked for these models:

ValueError: `.to` is not supported for `8-bit` models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.

Thus, I propose to fix this by simply checking whether the model has been loaded with accelerate, by looking at the hf_device_map attribute, and moving the model to the requested device only if it has not been loaded with accelerate as the backend (a rough sketch of the check is below). This fixes 3 bugs: using pipeline with an fp16 model loaded with accelerate without any error in the multi-GPU case, using pipeline with an fp16 model loaded with accelerate together with sampling strategies, and using pipeline with int8 models together with sampling strategies.
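A rough sketch of the idea, written as a standalone helper rather than the exact code of this PR; the function name is illustrative:

import torch

def maybe_move_to_device(model, device: torch.device):
    # Models dispatched by accelerate carry an `hf_device_map` attribute;
    # calling .to() on them would misplace the weights (or error for 8-bit
    # models), so we leave them untouched.
    if getattr(model, "hf_device_map", None) is not None:
        return model
    return model.to(device)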

cc @sgugger @Narsil

@HuggingFaceDocBuilderDev commented Feb 6, 2023

The documentation is not available anymore as the PR was closed or merged.

@Narsil (Contributor) commented Feb 6, 2023

Thanks for the well-thought-out issue and proposed fix.

I don't particularly like the fix because it depends on some weird internals and, if I understand correctly, still forces users to use both device_map and device.

Couldn't we just use device_map and use the accelerate API to figure out where to put the inputs? (Most likely cuda:0, but still cpu if no GPUs are available, I think.)
Or just do something special for device_map without asking where the model is (if the API doesn't exist or is tricky).

IMO, using both device_map and device should be an error (ambiguous intent).

@younesbelkada (Contributor, Author) commented Feb 7, 2023

Thanks for the feedback @Narsil!
I think

> IMO, using both device_map and device should be an error (ambiguous intent)

makes sense!
Another fix would be to force-upcast the logits to fp32 when doing top_k & top_p sampling on the generation side, only if the logits are on cpu. Is this a reasonable fix @gante? Happy to open a PR for it!

@Narsil (Contributor) commented Feb 7, 2023

> force-upcast

I would highly advise against that too. There's a limit to magic. Doing half precision on cpu should crash in a lot of places. We shouldn't upcast on behalf of a user who explicitly asked for half precision, IMO. That's breaking user intent.
But the user also asked for GPU; that's where we're breaking their intent, and that's what should be fixed IMO.

Does accelerate let us know which device the start of the model is on?

@younesbelkada (Contributor, Author)

I see, makes sense!

> Does accelerate let us know which device the start of the model is on?

I am not sure here; maybe @sgugger & @muellerzr know better.

@Narsil (Contributor) commented Feb 7, 2023

> I am not sure here; maybe @sgugger & @muellerzr know better.

If not, the pipeline could use the simplest heuristic, 'cuda:0' if torch.cuda.is_available() else 'cpu', which should work 99% of the time (see the sketch below).
But it wouldn't if a user specified an odd map (which is why having direct access would be better).
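A minimal sketch of that fallback heuristic, assuming the pipeline has no access to the model's device map:

import torch

# Fallback when we cannot inspect where accelerate placed the model:
# put the inputs on the first GPU if one exists, otherwise on CPU.
default_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")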

@Narsil (Contributor) commented Feb 7, 2023

What do you think:

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 30402b36e..e698f1aa3 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -749,7 +749,7 @@ class Pipeline(_ScikitCompat):
         framework: Optional[str] = None,
         task: str = "",
         args_parser: ArgumentHandler = None,
-        device: Union[int, str, "torch.device"] = -1,
+        device: Union[int, str, "torch.device"] = None,
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
         binary_output: bool = False,
         **kwargs,
@@ -764,6 +764,20 @@ class Pipeline(_ScikitCompat):
         self.image_processor = image_processor
         self.modelcard = modelcard
         self.framework = framework
+
+        # Special handling
+        if self.framework == "pt" and device is not None:
+            self.model = self.model.to(device=device)
+
+        if device is None:
+            # `accelerate` device map
+            hf_device_map = getattr(self.model, "hf_device_map", None)
+            if hf_device_map is not None:
+                # Take the first device used by `accelerate`.
+                device = next(iter(hf_device_map.values()))
+            else:
+                device = -1
+
         if is_torch_available() and self.framework == "pt":
             if isinstance(device, torch.device):
                 self.device = device
@@ -775,13 +789,10 @@ class Pipeline(_ScikitCompat):
                 self.device = torch.device(f"cuda:{device}")
         else:
             self.device = device
+
         self.torch_dtype = torch_dtype
         self.binary_output = binary_output

-        # Special handling
-        if self.framework == "pt" and self.device.type != "cpu":
-            self.model = self.model.to(self.device)
-
         # Update config with task specific parameters
         task_specific_params = self.model.config.task_specific_params
         if task_specific_params is not None and task in task_specific_params:

Here we just modify the default device when the model uses accelerate's device_map.
We still depend on something magical, but it only modifies the default device and doesn't move the model unless device was specified by the user (which is correct in terms of intent, IMO). With this change, the snippet below should work without passing device at all.
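As a usage example (assuming the patch above), the original reproduction now runs without a device argument, since the pipeline picks up the first device from hf_device_map:

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

model_id = "EleutherAI/gpt-neo-125M"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

# No `device` argument: the pipeline reads model.hf_device_map and places the
# inputs on the first device accelerate used (e.g. cuda:0), so top_k/top_p
# sampling runs on GPU in fp16 instead of failing on CPU.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=20, do_sample=True, top_p=0.95, top_k=60)
print(pipe("What can you tell me about the LHC?")[0]["generated_text"])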

@younesbelkada (Contributor, Author) commented Feb 7, 2023

I think that would totally work @Narsil! Happy to update the PR with your proposed changes, let me know!

@Narsil (Contributor) commented Feb 7, 2023

Sure. Let's update the doc too.

@sgugger (Collaborator) commented Feb 7, 2023

Side notes:

  • You should probably do something if the user passes a device and the model has an hf_device_map (at least a warning), as the line model.to(device) will probably break things (it will at least error if some weights are offloaded).
  • The device on which the model is executed is determined by this rule in Accelerate; maybe you should use the same code? I can also store it on the Accelerate side in a special attribute, but then you'd have to wait for a release. (A small example applying this rule is sketched below.)
    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
        main_device = "cpu"
    else:
        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
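A small illustration of that rule, using a hypothetical device map for a model split between GPU 0, CPU and disk offload:

# Hypothetical map of the kind accelerate produces for a dispatched model.
device_map = {"transformer.wte": 0, "transformer.h.0": 0, "transformer.h.1": "cpu", "lm_head": "disk"}

if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
    main_device = "cpu"
else:
    main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

print(main_device)  # 0, i.e. the inputs should be placed on cuda:0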

@younesbelkada (Contributor, Author)

Thanks a lot for the valuable feedback @Narsil @sgugger!
I updated the PR and added more clarification (and a new section) to the docs.

@Narsil (Contributor) left a comment

Thanks for the update.

I propose to simplify the examples a lot.

Then, regarding the actual code, I think your proposed change entangles and confuses device/self.device more than it should - at least more than my original modification did.

We could go even further in "cleanliness" and handle all that logic within the pipeline function rather than in Pipeline.__init__, since at that point the model should be a black box (and that includes moving the model to a device).

Modifying how the model is moved to a device will break calls like TranslationPipeline(model=MyModel(), device=0). I don't think we can realistically break that, but we can probably figure out a non-breaking change that handles device_map better.

docs/source/en/pipeline_tutorial.mdx (outdated, resolved)
Comment on lines 782 to 783
if self.framework == "pt" and device is not None:
    self.model = self.model.to(device=self.device)
@Narsil (Contributor)

Don't mix self.device and device. This is super error-prone.

The proposed change I made was at least explicit about its default value.
I really think this needs to be changed - too many opportunities to introduce bugs later on.

  • Set the default value (if no value provided)
  • Handle device to create self.device.
  • Use self.device everywhere.

@younesbelkada (Contributor, Author)

Yeah, I somehow didn't fully consider your proposition in #21479 (comment) - I think it's wiser to revert my changes in favour of yours!

Comment on lines 785 to 786
hf_device_map = getattr(self.model, "hf_device_map", None)
if hf_device_map is not None:
@Narsil (Contributor) commented Feb 9, 2023

There's probably a way to structure the code so this is written only once.

I think directly in the pipeline function we could warn when both device and device_map are set. That prevents having to guess here. If you're splitting model loading and pipeline loading, then you should be aware of what you're doing, but we shouldn't actively depend on internals to figure out what's going on.

Essentially, when users use pipeline(model=MyModel()) the model is a black box to us, and we shouldn't look at it. In my proposed change, we only look at it when no device is sent.

And to be even purer, we could modify the pipeline itself to check hf_device_map only when we do from_pretrained. That seems even cleaner, since we know this internal map could exist there (whereas we can't know that if the user passes in a real model object). A hypothetical sketch of such a warning is below.
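A hypothetical sketch of such a warning inside the top-level pipeline() factory; the signature is heavily simplified and the warning text is illustrative, not the merged implementation:

import warnings

def pipeline(task, model=None, device=None, device_map=None, **kwargs):
    # If the user asks for both, the intent is ambiguous: `device_map` means
    # accelerate already decides the placement, so `device` would fight it.
    if device is not None and device_map is not None:
        warnings.warn(
            "Both `device` and `device_map` were passed; `device` will be ignored "
            "because the model is dispatched with `device_map`."
        )
    ...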

@younesbelkada (Contributor, Author)

💯 on this @Narsil

tests/pipelines/test_pipelines_text_generation.py (outdated, resolved)
docs/source/en/pipeline_tutorial.mdx (outdated, resolved)
@sgugger (Collaborator) left a comment

Agreed with all of @Narsil's comments!

Comment on lines 271 to 277
# pip install accelerate
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
@sgugger (Collaborator)

Maybe use a smaller example and say in a note that the user can replace it with BLOOM?

@younesbelkada (Contributor, Author)

Sure!

@sgugger (Collaborator) left a comment

Thanks for iterating!

@Narsil (Contributor) left a comment

LGTM.

@younesbelkada (Contributor, Author)

The failing test is independent of this PR! Merging!
Thanks for all your comments!

@younesbelkada younesbelkada merged commit f839426 into huggingface:main Feb 10, 2023
@younesbelkada younesbelkada deleted the fix-int8-pipeline branch February 28, 2023 16:37
ArthurZucker pushed a commit to ArthurZucker/transformers that referenced this pull request Mar 2, 2023
…ce#21479)

* v1 fix

* adapt from suggestions

* make style

* fix tests

* add gpu tests

* update docs

* fix other tests

* Apply suggestions from code review

Co-authored-by: Nicolas Patry <[email protected]>

* better fix

* make fixup

* better example

* revert changes

* proposal

* more elegant solution

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Sylvain Gugger <[email protected]>

---------

Co-authored-by: Nicolas Patry <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>