Can we just embed the image with the visual part? #287

Closed
whyiug opened this issue Sep 27, 2024 · 11 comments

whyiug commented Sep 27, 2024

As the title says, we need to embed the image with the visual part separately.

For now, I use the code below.
The trouble is that I still load all the parameters.
Can you give me some tips to simplify the code?

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

model_name = "Qwen/Qwen2-VL-7B-Instruct"
# model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
model = Qwen2VLForConditionalGeneration.from_pretrained(model_name)  # still loads all weights
visual_model = model.visual  # keep only the vision tower
# move to GPU in bfloat16 to match the bf16 pixel_values below
visual_model.to("cuda", dtype=torch.bfloat16)

min_pixels = 100000
max_pixels = 500000
processor = AutoProcessor.from_pretrained(model_name, min_pixels=min_pixels, max_pixels=max_pixels)

IMAGE_PATH = '/home/xxxx.jpeg'
question = "描述这张图片。"
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": IMAGE_PATH,
                'max_pixels': max_pixels,
                'min_pixels': min_pixels,
            },
            {"type": "text", "text": question},
        ],
    }
]

text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)

inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")


pixel_values = inputs["pixel_values"].type(torch.bfloat16)

with torch.no_grad():  # inference only; avoid building an autograd graph
    image_embeds = visual_model(pixel_values, grid_thw=inputs["image_grid_thw"])
print(image_embeds.shape)
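A quick sanity check on the output shape (an aside on my part, assuming the default Qwen2-VL spatial merge size of 2, so every 2x2 block of patches is merged into one visual token):

# for a single input image, the first dimension of image_embeds
# should equal t * h * w / 4 from its grid_thw entry
t, h, w = inputs["image_grid_thw"][0].tolist()
assert image_embeds.shape[0] == t * h * w // 4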

whyiug commented Sep 29, 2024

@simonJJJ @ShuaiBai623 Please have a look and consider it; this is really important for our production environment.

@Andcircle

@whyiug I also need this to use vLLM. Interestingly, they only support input with image embeds instead of pixel values.


whyiug commented Oct 8, 2024

I solved the problem with a smaller GPU memory footprint.
I'll post the code in case anyone else runs into a similar problem.
Just look at the load_model section.

import torch
from transformers import AutoConfig
from transformers import AutoProcessor
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VisionTransformerPretrainedModel,
)
from safetensors.torch import load_file
from qwen_vl_utils import process_vision_info
import time
import os

visual_model = None
processor = None
device = "cuda"


def load_model():
    global visual_model, processor
    model_path = "/home/work/.cache/modelscope/hub/qwen/Qwen2-VL-7B-Instruct"
    config = AutoConfig.from_pretrained(model_path)
    visual_model = Qwen2VisionTransformerPretrainedModel._from_config(
        config=config.vision_config
    )
    # NOTE: this assumes all "visual." tensors live in this one shard;
    # check model.safetensors.index.json to confirm for your checkpoint.
    checkpoint_path = os.path.join(model_path, "model-00001-of-00005.safetensors")
    checkpoint = load_file(checkpoint_path)
    visual_weights = {
        key.replace("visual.", ""): value
        for key, value in checkpoint.items()
        if key.startswith("visual.")
    }
    visual_model.load_state_dict(visual_weights, strict=False)
    # keep the vision tower in bfloat16 to match the bf16 pixel_values below
    visual_model.to(device, dtype=torch.bfloat16)
    visual_model.eval()
    processor = AutoProcessor.from_pretrained(model_path)


load_model()


def tmp_process_image(image_path, question="", min_pixels=350000, max_pixels=500000):
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path,
                    "max_pixels": max_pixels,
                    "min_pixels": min_pixels,
                },
                {"type": "text", "text": question},
            ],
        }
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    pixel_values = inputs["pixel_values"].type(torch.bfloat16)
    with torch.no_grad():
        image_embeds = visual_model(pixel_values, grid_thw=inputs["image_grid_thw"])
    return inputs["image_grid_thw"], image_embeds


if __name__ == "__main__":
    image_path = "/home/work/data/xxx/up.jpeg"
    print("Processing image...")
    image_grid_thw1, image_embeds1 = tmp_process_image(
        image_path, question="", min_pixels=350000, max_pixels=500000
    )
    print(image_embeds1)
    print("Image processed.")
    time.sleep(1000)  # keep the process alive (e.g. to inspect GPU memory usage)
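As a side note, the hard-coded model-00001-of-00005.safetensors shard may not hold every visual.* tensor for other checkpoints. A minimal sketch of a more defensive variant, assuming the standard Hugging Face sharded layout with a model.safetensors.index.json file, which collects the visual weights from whichever shards contain them so strict loading can be used:

import json
import os
from safetensors.torch import load_file


def load_visual_weights(model_path):
    # The index file maps every tensor name to the shard file that stores it.
    with open(os.path.join(model_path, "model.safetensors.index.json")) as f:
        weight_map = json.load(f)["weight_map"]

    # Only open the shards that actually contain "visual." tensors.
    shards = sorted({fname for key, fname in weight_map.items() if key.startswith("visual.")})

    visual_weights = {}
    for shard in shards:
        checkpoint = load_file(os.path.join(model_path, shard))
        visual_weights.update(
            {key[len("visual."):]: value for key, value in checkpoint.items() if key.startswith("visual.")}
        )
    return visual_weights


# visual_model.load_state_dict(load_visual_weights(model_path), strict=True)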

whyiug closed this as completed Oct 8, 2024

LianghuiGuo commented Oct 30, 2024

@whyiug amazing work!
I'm wondering whether we can get the text embedding in a similar way. Have you tried it?


whyiug commented Oct 30, 2024

> @whyiug amazing work! I'm wondering whether we can get the text embedding in a similar way. Have you tried it?

No. I feel that extracting only the text embedding is not very meaningful, and it takes very little time anyway. If you really need it, you can look at the code here.
I split out the visual model because I want to deploy the image model and the text model separately, and our use case asks multiple questions about one image, so this improves both latency and throughput.
The vLLM PR that supports image embedding input: vllm-project/vllm#8856
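(For anyone who does want just the text side: a minimal sketch, under the assumption that the token-embedding table is stored as model.embed_tokens.weight and that it sits in the first shard; check model.safetensors.index.json for the exact key and file names.)

import torch
from transformers import AutoConfig, AutoTokenizer
from safetensors.torch import load_file

model_path = "/home/work/.cache/modelscope/hub/qwen/Qwen2-VL-7B-Instruct"
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Build a bare embedding layer and fill it with the checkpoint's token-embedding table.
# The shard name is an assumption; look it up in model.safetensors.index.json.
checkpoint = load_file(f"{model_path}/model-00001-of-00005.safetensors")
embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size)
embed_tokens.load_state_dict({"weight": checkpoint["model.embed_tokens.weight"]})

input_ids = tokenizer("描述这张图片。", return_tensors="pt").input_ids
with torch.no_grad():
    text_embeds = embed_tokens(input_ids)  # shape: (1, seq_len, hidden_size)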

@LianghuiGuo

> No. I feel that extracting only the text embedding is not very meaningful [...] The vLLM PR that supports image embedding input: vllm-project/vllm#8856

OK, thanks! I was wondering whether the image and text embeddings could be used for image-text retrieval; I'm not sure whether that is feasible.

@goen-kkk

> I solved the problem with a smaller GPU memory footprint. I'll post the code in case anyone else runs into a similar problem. Just look at the load_model section. [code quoted above]

May I ask: how can I load only the weights of the Qwen2-VL language-model part?


whyiug commented Nov 20, 2024

> I solved the problem with a smaller GPU memory footprint. [...] Just look at the load_model section. [code quoted above]
>
> May I ask: how can I load only the weights of the Qwen2-VL language-model part?

Follow the same approach as visual_weights, together with
https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/model.safetensors.index.json
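A minimal sketch of what that advice could look like (my own reading, not the author's exact code): in the transformers version from the time of this thread, the text decoder is exposed as Qwen2VLModel and its weights are stored under the model. prefix (lm_head.weight lives under its own key if you also need logits), and the index file tells you which shards to open.

import json
import os
import torch
from transformers import AutoConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
from safetensors.torch import load_file

model_path = "/home/work/.cache/modelscope/hub/qwen/Qwen2-VL-7B-Instruct"
config = AutoConfig.from_pretrained(model_path)

# Build only the text decoder (no vision tower is instantiated).
language_model = Qwen2VLModel._from_config(config)

# Find the shards that contain language-model tensors ("model." prefix).
with open(os.path.join(model_path, "model.safetensors.index.json")) as f:
    weight_map = json.load(f)["weight_map"]
shards = sorted({fname for key, fname in weight_map.items() if key.startswith("model.")})

language_weights = {}
for shard in shards:
    checkpoint = load_file(os.path.join(model_path, shard))
    language_weights.update(
        {key[len("model."):]: value for key, value in checkpoint.items() if key.startswith("model.")}
    )
# Relax to strict=False if your transformers version registers extra non-weight buffers.
language_model.load_state_dict(language_weights, strict=True)
language_model.to("cuda", dtype=torch.bfloat16).eval()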


ql632 commented Dec 3, 2024

> OK, thanks! I was wondering whether the image and text embeddings could be used for image-text retrieval; I'm not sure whether that is feasible.

How did it go? Is image-text retrieval feasible?


suexin27 commented Jan 2, 2025

> How did it go? Is image-text retrieval feasible?

Hello, may I ask whether the fused image-text embedding can be obtained? Did you solve it?


whyiug commented Jan 7, 2025

> Hello, may I ask whether the fused image-text embedding can be obtained? Did you solve it?

Are you also doing retrieval, or separating the inference architecture?
