

Replace torch.tensor() with torch.from_numpy() when processing numpy arrays #33185

Closed
2 of 4 tasks
shinyano opened this issue Aug 29, 2024 · 5 comments · Fixed by #33201
Comments

@shinyano
Contributor

shinyano commented Aug 29, 2024

System Info

  • transformers version: 4.43.4
  • Platform: Windows-10-10.0.22631-SP0
  • Python version: 3.11.4
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.4
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2+cpu (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker @gante @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm using PemJa, a C-extension-based library that runs Python in multiple threads, to execute Python scripts from Java. I use BLIP models in my Python script, and it hangs when I create a new Python thread to execute the script.

This is my minimal script:

import numpy as np
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

class TensorTest:
    def __init__(self):
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    def generate_caption(self, image_data):
        import io
        image = Image.open(io.BytesIO(image_data))

        # original way:
        inputs = self.processor(images=image, return_tensors="pt")   # return_tensors="pt" is the pain point: my program gets stuck here forever
        out = self.model.generate(**inputs)

        # workaround:
        # inputs = self.processor(images=image)               # Workaround to avoid the hang: it skips the torch.tensor() step and uses torch.from_numpy() instead
        # pixel_values = np.array(inputs["pixel_values"][0])
        # pixel_values = torch.from_numpy(pixel_values).unsqueeze(0)
        # pixel_values = pixel_values.to(self.model.device)
        # out = self.model.generate(pixel_values)

        caption = self.processor.decode(out[0], skip_special_tokens=True)
        return caption

    def transform(self):
        with open("E:/pics/horse.jpg", "rb") as image_file:
            image_data = image_file.read()
            return self.generate_caption(image_data)

if __name__ == '__main__':
    t = TensorTest()
    t.transform()

The minimal Java code to call the script:

PythonInterpreterConfig config =
    PythonInterpreterConfig.newBuilder()
        .setPythonExec(pythonCMD)
        .addPythonPaths(PYTHON_PATH)
        .build();
try(PythonInterpreter interpreter = new PythonInterpreter(config)) {
  interpreter.exec(String.format("import %s", moduleName));
  interpreter.exec(String.format("t = %s.%s()", moduleName, className));
  interpreter.exec("t.transform()");                       // <- works fine
}
try(PythonInterpreter interpreter2 = new PythonInterpreter(config)) {
  interpreter2.exec(String.format("import %s", moduleName));
  interpreter2.exec(String.format("t = %s.%s()", moduleName, className));
  interpreter2.exec("t.transform()");                       // <- hangs here
}

As noted in the comment, inputs = self.processor(images=image, return_tensors="pt") gets stuck. By stepping into the source code, I found that the processor hangs here at line 149, where value is a numpy array:

def as_tensor(value):
    if isinstance(value, (list, tuple)) and len(value) > 0:
        if isinstance(value[0], np.ndarray):
            value = np.array(value)
        elif (
            isinstance(value[0], (list, tuple))
            and len(value[0]) > 0
            and isinstance(value[0][0], np.ndarray)
        ):
            value = np.array(value)
    return torch.tensor(value)

While researching, I noticed a similar issue had occurred here: stackoverflow. And torch.from_numpy() works fine both in the Stack Overflow case and in mine, as shown in the script above.

Although this bug doesn't appear when I call the script directly from Python, it still seems that calling torch.tensor() on numpy arrays is not a reliable practice when multiprocessing or multithreading is involved.

So I would suggest replacing torch.tensor() with torch.from_numpy() when processing numpy arrays. Please correct me if this is actually a bad idea.
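(For context, the practical difference between the two calls can be shown in a few lines: torch.tensor() always copies the data into a new tensor, while torch.from_numpy() wraps the existing numpy buffer without copying, so the two tensors behave differently when the source array is mutated.)

```python
import numpy as np
import torch

arr = np.zeros((2, 3), dtype=np.float32)

copied = torch.tensor(arr)      # always allocates and copies the data
shared = torch.from_numpy(arr)  # wraps the same underlying buffer, no copy

arr[0, 0] = 1.0
# `copied` still holds the old value; `shared` sees the mutation
```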

Expected behavior

The script should not hang and should finish its work.

@shinyano shinyano added the bug label Aug 29, 2024
@LysandreJik
Member

Thanks for your issue! cc @qubvel in case you have bandwidth :)

@qubvel
Member

qubvel commented Aug 29, 2024

Hi @shinyano, thanks for opening the issue!

I'm not sure about replacing torch.tensor() with torch.from_numpy() unconditionally, because I suspect the values may not always be np.ndarrays; they could be, for example, scalar values. However, we can resolve it by checking whether the value is an array.

 def as_tensor(value): 
     if isinstance(value, (list, tuple)) and len(value) > 0: 
         if isinstance(value[0], np.ndarray): 
             value = np.array(value) 
         elif ( 
             isinstance(value[0], (list, tuple)) 
             and len(value[0]) > 0 
             and isinstance(value[0][0], np.ndarray) 
         ): 
             value = np.array(value)
             
     # Modified code 
     if isinstance(value, np.ndarray):
         tensor = torch.from_numpy(value)
     else:
         tensor = torch.tensor(value)
     return tensor
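(A self-contained sketch of the check proposed above, copied out of the library context so it can be exercised directly: ndarrays take the zero-copy from_numpy path, while scalars and plain lists fall through to torch.tensor.)

```python
import numpy as np
import torch

def as_tensor(value):
    # Standalone copy of the proposed fix, for illustration only.
    if isinstance(value, (list, tuple)) and len(value) > 0:
        if isinstance(value[0], np.ndarray):
            value = np.array(value)
        elif (
            isinstance(value[0], (list, tuple))
            and len(value[0]) > 0
            and isinstance(value[0][0], np.ndarray)
        ):
            value = np.array(value)
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value)  # zero-copy path, avoids the hang
    return torch.tensor(value)          # scalars, plain lists, etc.

t1 = as_tensor([np.ones((2, 2), dtype=np.float32)])  # list of ndarrays
t2 = as_tensor(3.0)                                  # scalar still works
```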

What do you think? Would you like to submit a PR to fix it?

@shinyano
Contributor Author

> What do you think? Would you like to submit a PR to fix it?

The modification is exactly what I wanted :). And I would be glad to submit the PR.

@qubvel
Member

qubvel commented Aug 29, 2024

@shinyano great! ping me for review when it's ready 🤗

@shinyano
Contributor Author

shinyano commented Sep 2, 2024

> @shinyano great! ping me for review when it's ready 🤗

@qubvel I have created a PR. Would you like to review it?
