import json
import os
from typing import List, Optional

import numpy as np
import pytorch_lightning as pl
import torch
from nltk import edit_distance
from PIL import Image, ImageFile
from pytorch_lightning.callbacks import Callback
from torch.utils.data import Dataset
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup  # used by the commented-out alternative in configure_optimizers

# Allow PIL to open partially written / truncated page images instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
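
# Fine-tuning utilities for Pix2Struct on document VQA: a Dataset that renders
# each question as a header onto its page image, a LightningModule training
# wrapper, and a callback that appends the validation edit distance to a file.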


class ImageCaptioningDataset(Dataset):
    """Document-VQA dataset for Pix2Struct.

    Each record must carry "image_name", "question", and "answer". The
    question is rendered onto the page image as a header by the processor,
    and the tokenized answer becomes the label sequence.
    """

    def __init__(
        self,
        data_json: List[dict],
        processor,
        model,
        max_patches: int,
        max_length: int,
        ignore_id: int = -100,
        task_start_token: str = "",
        prompt_end_token: Optional[str] = None,
    ):
        super().__init__()
        self.dataset = data_json
        self.processor = processor
        self.added_tokens = []
        self.model = model
        self.max_patches = max_patches
        self.max_length = max_length
        self.ignore_id = ignore_id
        self.task_start_token = task_start_token
        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
        self.prompt_end_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)

    def add_tokens(self, list_of_tokens: List[str]):
        """Add special tokens to the tokenizer and resize the decoder's token embeddings."""
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.model.decoder.resize_token_embeddings(len(self.processor.tokenizer))
            self.added_tokens.extend(list_of_tokens)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Prepare inputs: locate the page image under $DATA_DIR/images.
        image_folder_path = os.path.join(os.getenv("DATA_DIR"), "images")
        if not os.path.exists(image_folder_path):
            raise FileNotFoundError(f"Image folder not found: {image_folder_path}")
        img_path = os.path.join(image_folder_path, item["image_name"])

        # Strip Unicode right-to-left marks (U+200F) that leak into the text fields.
        item["question"] = item["question"].replace("\u200f", "").strip()
        item["answer"] = item["answer"].replace("\u200f", "").strip()

        img = Image.open(img_path)
        encoding = self.processor(
            images=img,
            text=item["question"],
            max_patches=self.max_patches,
            return_tensors="pt",
            font_path="/usr/src/app/Arial.TTF",
        )
        encoding = {k: v.squeeze() for k, v in encoding.items()}

        # Prepare targets: tokenize the answer and mask padding out of the loss.
        target_sequence = item["answer"]
        input_ids = self.processor.tokenizer(
            target_sequence,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        labels = input_ids.squeeze().clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad tokens
        # For prompt-style targets one could also mask everything up to
        # prompt_end_token_id (as in Donut's VQA setup); left disabled here:
        # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id
        encoding["labels"] = labels
        return encoding, target_sequence
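
# A minimal sketch of the record format __getitem__ expects and of add_tokens
# usage; the file name and token strings below are hypothetical examples:
#
#   with open("train.json") as f:
#       records = json.load(f)
#   # records == [{"image_name": "page_0001.png", "question": "...", "answer": "..."}, ...]
#   dataset = ImageCaptioningDataset(records, processor, model, max_patches=2048, max_length=128)
#   dataset.add_tokens(["<s_answer>", "</s_answer>"])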


class Pix2Struct(pl.LightningModule):
    """LightningModule that fine-tunes Pix2Struct and tracks a normalized edit distance."""

    def __init__(self, config, processor, model, train_data, val_data):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.train_data = train_data
        self.val_data = val_data

    def training_step(self, batch, batch_idx):
        encoding, _ = batch
        outputs = self.model(**encoding)
        loss = outputs.loss
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        encoding, answers = batch
        flattened_patches = encoding["flattened_patches"]
        attention_mask = encoding["attention_mask"]

        # Generate free-form answers; the question prompt is already rendered
        # into the image, so no decoder prompt is fed here.
        outputs = self.model.generate(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            max_new_tokens=256,
            min_length=1,
            return_dict_in_generate=True,
        )

        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            predictions.append(seq)

        # Score each prediction with edit distance normalized by the longer string.
        scores = []
        for pred, answer in zip(predictions, answers):
            answer = answer.replace(self.processor.tokenizer.eos_token, "").strip()
            scores.append(edit_distance(pred.strip(), answer.strip()) / max(len(pred), len(answer), 1))
            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))
        return scores

    def configure_optimizers(self):
        # Earlier experiments used Adafactor with a cosine warmup schedule
        # (hence the imports kept at the top of the file):
        # optimizer = Adafactor(self.parameters(), scale_parameter=False,
        #                       relative_step=False, lr=self.config.get("lr"),
        #                       weight_decay=1e-05)
        # scheduler = get_cosine_schedule_with_warmup(
        #     optimizer,
        #     num_warmup_steps=self.config.get("warmup_epochs"),
        #     num_training_steps=self.config.get("max_epochs"))
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
        return [optimizer]

    def train_dataloader(self):
        return self.train_data

    def val_dataloader(self):
        return self.val_data
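
# If the commented-out cosine schedule in configure_optimizers is re-enabled,
# Lightning expects the scheduler returned alongside the optimizer. A minimal
# sketch; the step-based interval is an assumption, since "warmup_epochs" and
# "max_epochs" in the config are not confirmed to be step counts:
#
#   return [optimizer], [{"scheduler": scheduler, "interval": "step"}]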


class LogValidationDistanceCallback(Callback):
    """Append each epoch's mean validation edit distance to a text file."""

    def __init__(self, file_path):
        super().__init__()
        self.file_path = file_path

    def on_validation_epoch_end(self, trainer, pl_module):
        avg_edit_distance = trainer.callback_metrics.get("val_edit_distance")
        if avg_edit_distance is not None:
            with open(self.file_path, "a") as file:
                file.write(f"Epoch {trainer.current_epoch}: {avg_edit_distance:.4f}\n")
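

# ---------------------------------------------------------------------------
# A minimal end-to-end training sketch, not part of the original module. The
# checkpoint name, JSON file names, batch size, and hyperparameters below are
# assumptions for illustration; only ImageCaptioningDataset, Pix2Struct, and
# LogValidationDistanceCallback come from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

    checkpoint = "google/pix2struct-docvqa-base"  # assumed checkpoint
    processor = Pix2StructProcessor.from_pretrained(checkpoint)
    model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)

    data_dir = os.getenv("DATA_DIR")
    with open(os.path.join(data_dir, "train.json")) as f:  # hypothetical file name
        train_records = json.load(f)
    with open(os.path.join(data_dir, "val.json")) as f:  # hypothetical file name
        val_records = json.load(f)

    train_ds = ImageCaptioningDataset(train_records, processor, model, max_patches=2048, max_length=128)
    val_ds = ImageCaptioningDataset(val_records, processor, model, max_patches=2048, max_length=128)

    module = Pix2Struct(
        config={"lr": 1e-5, "verbose": True},
        processor=processor,
        model=model,
        train_data=DataLoader(train_ds, batch_size=2, shuffle=True),
        val_data=DataLoader(val_ds, batch_size=2),
    )
    trainer = pl.Trainer(
        max_epochs=5,
        callbacks=[LogValidationDistanceCallback("val_edit_distance.txt")],
    )
    trainer.fit(module)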