diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 4bb5cffb128..b318e1b12b4 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -845,6 +845,8 @@ def __init__(
                 device = -1
 
         if is_torch_available() and self.framework == "pt":
+            if device == -1 and self.model.device is not None:
+                device = self.model.device
             if isinstance(device, torch.device):
                 if device.type == "xpu" and not is_torch_xpu_available(check_device=True):
                     raise ValueError(f'{device} is not available, you should use device="cpu" instead')
@@ -871,11 +873,10 @@ def __init__(
             self.device = device if device is not None else -1
 
         self.binary_output = binary_output
-
-        # We shouldn't call `model.to()` for models loaded with accelerate
+        # We shouldn't call `model.to()` for models loaded with accelerate, or when the model is already on the requested device
         if (
             self.framework == "pt"
-            and self.device is not None
+            and self.model.device != self.device
            and not (isinstance(self.device, int) and self.device < 0)
             and hf_device_map is None
         ):
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index c680b4c634d..763c7d1a883 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -48,6 +48,7 @@
     require_tf,
     require_torch,
     require_torch_accelerator,
+    require_torch_multi_accelerator,
     require_torch_or_tf,
     slow,
     torch_device,
@@ -519,6 +520,52 @@ def test_pipeline_negative_device(self):
         actual_output = classifier("Test input.")
         self.assertEqual(expected_output, actual_output)
 
+    @require_torch_accelerator
+    def test_pipeline_no_device(self):
+        # Test the case where no device is passed to the pipeline
+        import torch
+
+        from transformers import AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        # Case 1: Model is manually moved to device
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16
+        ).to(torch_device)
+        model_device = model.device
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, model_device)
+        # Case 2: Model is loaded by accelerate
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", device_map=torch_device, torch_dtype=torch.float16
+        )
+        model_device = model.device
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, model_device)
+        # Case 3: device_map is passed to the model and a device is passed to the pipeline
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", device_map=torch_device, torch_dtype=torch.float16
+        )
+        with self.assertRaises(ValueError):
+            pipe = pipeline("text-generation", model=model, device="cpu", tokenizer=tokenizer)
+
+    @require_torch_multi_accelerator
+    def test_pipeline_device_not_equal_model_device(self):
+        # When the device ids differ, the pipeline should move the model to the requested device
+        import torch
+
+        from transformers import AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model_device = f"{torch_device}:1"
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16
+        ).to(model_device)
+        target_device = f"{torch_device}:0"
+        self.assertNotEqual(model_device, target_device)
+        pipe = pipeline("text-generation", model=model, device=target_device, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, torch.device(target_device))
+
     @slow
     @require_torch
     def test_load_default_pipelines_pt(self):
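
With this change, constructing a pipeline without a `device` argument keeps an already-placed model where it is instead of silently leaving the pipeline on CPU. A minimal sketch of the intended behavior, assuming a CUDA device is available (the "gpt2" checkpoint is only illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    # Model is moved to an accelerator before the pipeline is created
    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda:0")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # No `device` passed: the pipeline now adopts the model's device
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print(pipe.model.device)  # expected: cuda:0 rather than cpu

Because the `.to()` call is now skipped when `self.model.device` already equals the target device, passing a `device` that matches the model's placement is a no-op, while passing a different device (as in test_pipeline_device_not_equal_model_device) still moves the model.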