forked from openvinotoolkit/openvino_notebooks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
blip_model.py
156 lines (139 loc) · 6.58 KB
/
blip_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
import numpy as np
import openvino as ov
from typing import List, Dict
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
def init_past_inputs(model_inputs: List, batch_size: int = 1):
    """
    Helper function for initialization of past inputs on first inference step

    Parameters:
      model_inputs (List): list of model inputs. The first four entries are the regular
        decoder inputs (input_ids, attention_mask, encoder_hidden_states,
        encoder_attention_mask); every remaining entry is a past key/value input.
      batch_size (int, *optional*, 1): size of the first (batch) axis of the created
        past tensors; it has to match the batch size of the input_ids they are fed with.
    Returns:
      pkv (List[ov.Tensor]): list of f32 past key value tensors with an empty axis 2
    """
    pkv = []
    for input_tensor in model_inputs[4:]:
        partial_shape = input_tensor.partial_shape
        # Previously fixed to 1, which breaks batched / beam-search generation because
        # the decoder combines this cache with key/values computed for the whole batch.
        partial_shape[0] = batch_size
        # Nothing is cached before the first step; axis 2 is assumed to be the
        # cached sequence axis of the exported decoder (TODO confirm against export).
        partial_shape[2] = 0
        pkv.append(ov.Tensor(ov.Type.f32, partial_shape.get_shape()))
    return pkv


def postprocess_text_decoder_outputs(output: Dict):
    """
    Helper function for rearranging model outputs and wrapping to CausalLMOutputWithCrossAttentions

    Parameters:
      output (Dict): dictionary with model output; the first entry holds the logits and
        all remaining entries (in order) are the updated past key values
    Returns:
      wrapped_outputs (CausalLMOutputWithCrossAttentions): outputs wrapped to CausalLMOutputWithCrossAttentions format
    """
    logits = torch.from_numpy(output[0])
    # Past key values are passed back unchanged so the next step can feed them in.
    past_kv = list(output.values())[1:]
    return CausalLMOutputWithCrossAttentions(
        loss=None,
        logits=logits,
        past_key_values=past_kv,
        hidden_states=None,
        attentions=None,
        cross_attentions=None,
    )


def text_decoder_forward(
    ov_text_decoder_with_past: "ov.CompiledModel",
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    past_key_values: List["ov.Tensor"],
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    **kwargs
):
    """
    Inference function for text_decoder in one generation step

    Parameters:
      ov_text_decoder_with_past (ov.CompiledModel): compiled decoder that takes the four
        regular inputs followed by the past key values
      input_ids (torch.Tensor): input token ids
      attention_mask (torch.Tensor): attention mask for input token ids
      past_key_values (List[ov.Tensor] or None): list of cached decoder hidden states from
        the previous step; None on the first step, in which case an empty cache is created
      encoder_hidden_states (torch.Tensor): encoder (vision or text) hidden states
      encoder_attention_mask (torch.Tensor): attention mask for encoder hidden states
      **kwargs: any further arguments from the generation loop; ignored
    Returns:
      model outputs (CausalLMOutputWithCrossAttentions): model prediction wrapped to CausalLMOutputWithCrossAttentions class including predicted logits and hidden states for caching
    """
    inputs = [input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask]
    if past_key_values is None:
        # The empty cache must share the batch size of the current input_ids.
        inputs.extend(init_past_inputs(ov_text_decoder_with_past.inputs, batch_size=input_ids.shape[0]))
    else:
        inputs.extend(past_key_values)
    outputs = ov_text_decoder_with_past(inputs)
    return postprocess_text_decoder_outputs(outputs)
class OVBlipModel:
    """
    Model class for inference BLIP model with OpenVINO.

    The sub-models are supplied by the caller:
      - ``vision_model`` and ``text_encoder`` are called with numpy inputs, and their
        results are indexed with their own ``output(0)`` handle;
      - ``text_decoder`` must provide a ``generate`` method accepting keyword arguments.
    """

    def __init__(self, config, decoder_start_token_id: int, vision_model, text_encoder, text_decoder):
        """
        Initialization class parameters

        Parameters:
          config: model config; ``config.text_config`` supplies bos/eos/sep/pad token ids
          decoder_start_token_id (int): token id used as the first token of generated answers
          vision_model: image encoder model (e.g. an OpenVINO compiled model)
          text_encoder: question encoder model (e.g. an OpenVINO compiled model)
          text_decoder: text decoder exposing a ``generate`` method
        """
        self.vision_model = vision_model
        self.vision_model_out = vision_model.output(0)
        self.text_encoder = text_encoder
        self.text_encoder_out = text_encoder.output(0)
        self.text_decoder = text_decoder
        self.config = config
        self.decoder_start_token_id = decoder_start_token_id
        # Not read by this class; kept in case external code relies on the attribute.
        self.decoder_input_ids = config.text_config.bos_token_id

    @staticmethod
    def _as_long_tensor(value):
        """Convert a (nested) Python list to a LongTensor; return anything else unchanged."""
        if isinstance(value, list):
            return torch.LongTensor(value)
        return value

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach a torch tensor, move it to host memory and return it as a numpy array."""
        return tensor.detach().cpu().numpy()

    def generate_answer(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor = None, **generate_kwargs):
        """
        Visual Question Answering prediction

        Parameters:
          pixel_values (torch.Tensor): preprocessed image pixel values
          input_ids (torch.Tensor or list): question token ids after tokenization
          attention_mask (torch.Tensor or list, *optional*, None): attention mask for question
            tokens; when omitted, every question token is attended
          **generate_kwargs: forwarded to ``text_decoder.generate``
        Returns:
          generation output (torch.Tensor): tensor which represents sequence of generated answer token ids
        """
        image_embed = self.vision_model(self._to_numpy(pixel_values))[self.vision_model_out]
        # Explicit int64: numpy's default int is 32-bit on some platforms (e.g. Windows).
        image_attention_mask = np.ones(image_embed.shape[:-1], dtype=np.int64)
        input_ids = self._as_long_tensor(input_ids)
        attention_mask = self._as_long_tensor(attention_mask)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        question_embeds = self.text_encoder(
            [self._to_numpy(input_ids), self._to_numpy(attention_mask), image_embed, image_attention_mask]
        )[self.text_encoder_out]
        question_attention_mask = torch.ones(question_embeds.shape[:-1], dtype=torch.long)
        # One start token per batch element; token ids must be int64 for generation.
        bos_ids = torch.full((question_embeds.shape[0], 1), self.decoder_start_token_id, dtype=torch.long)
        return self.text_decoder.generate(
            input_ids=bos_ids,
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            encoder_hidden_states=torch.from_numpy(question_embeds),
            encoder_attention_mask=question_attention_mask,
            **generate_kwargs,
        )

    def generate_caption(self, pixel_values: torch.Tensor, input_ids: torch.Tensor = None, attention_mask: torch.Tensor = None, **generate_kwargs):
        """
        Image Captioning prediction

        Parameters:
          pixel_values (torch.Tensor): preprocessed image pixel values
          input_ids (torch.Tensor or list, *optional*, None): pregenerated caption token ids after
            tokenization; if provided, caption generation continues the provided text. The caller's
            tensor is not modified.
          attention_mask (torch.Tensor or list, *optional*, None): attention mask for caption tokens,
            used only if input_ids provided
          **generate_kwargs: forwarded to ``text_decoder.generate``
        Returns:
          generation output (torch.Tensor): tensor which represents sequence of generated caption token ids
        """
        text_config = self.config.text_config
        batch_size = pixel_values.shape[0]
        image_embeds = self.vision_model(self._to_numpy(pixel_values))[self.vision_model_out]
        image_attention_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long)
        input_ids = self._as_long_tensor(input_ids)
        attention_mask = self._as_long_tensor(attention_mask)
        if input_ids is None:
            input_ids = torch.LongTensor([[text_config.bos_token_id, text_config.eos_token_id]]).repeat(batch_size, 1)
        else:
            # Work on a copy so the BOS overwrite below does not mutate the caller's tensor.
            input_ids = input_ids.clone()
        input_ids[:, 0] = text_config.bos_token_id
        # The last token (eos for the default prompt) is dropped from ids and mask alike.
        if attention_mask is not None:
            attention_mask = attention_mask[:, :-1]
        return self.text_decoder.generate(
            input_ids=input_ids[:, :-1],
            eos_token_id=text_config.sep_token_id,
            pad_token_id=text_config.pad_token_id,
            attention_mask=attention_mask,
            encoder_hidden_states=torch.from_numpy(image_embeds),
            encoder_attention_mask=image_attention_mask,
            **generate_kwargs,
        )