-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
54 lines (43 loc) · 1.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import datasets
import pandas as pd
from datasets import Dataset
from transformers import GitProcessor

from literal import ANSWER, IMG, IMG_PATH, QUESTION
@dataclass
class DataCollatorForGit:
    """Collate image/question(/answer) features into one processor batch.

    Each feature is a mapping holding an image under ``IMG`` and a question
    string under ``QUESTION``; when the first feature also holds ``ANSWER``,
    every feature is expected to hold one and the batch gets ``labels``.

    Attributes:
        processor: Processor whose image pipeline and ``tokenizer`` are used.
        max_length: Forwarded to the tokenizer call as ``max_length``.
        pad_to_multiple_of: Forwarded to the tokenizer call as
            ``pad_to_multiple_of``.
        return_tensors: Tensor format requested from processor and tokenizer.
        padding: Forwarded to the tokenizer call as ``padding``.
    """

    processor: GitProcessor
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"
    # Annotated so it is a real dataclass field (configurable through the
    # constructor). Declared last so the positional order of the fields that
    # already existed stays unchanged for current callers.
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build a batch from ``features``.

        Args:
            features: Non-empty list of feature mappings (see class docstring).

        Returns:
            The processor output for the images, extended with ``input_ids``
            and ``attention_mask`` from the tokenized text and, when answers
            are present, ``labels``.
        """
        # Only the first feature decides whether the batch carries answers.
        has_answer = ANSWER in features[0]

        images = [feature[IMG].convert("RGB") for feature in features]

        if has_answer:
            sep_token = self.processor.tokenizer.sep_token
            texts = [
                feature[QUESTION] + sep_token + feature[ANSWER]
                for feature in features
            ]
        else:
            texts = [feature[QUESTION] for feature in features]

        batch = self.processor(images=images, return_tensors=self.return_tensors)
        tokenized = self.processor.tokenizer(
            texts,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch["input_ids"] = tokenized.input_ids
        batch["attention_mask"] = tokenized.attention_mask

        if has_answer:
            # NOTE(review): labels are the same object as input_ids, and
            # padded positions are not masked here - confirm this is what
            # the training loss expects.
            batch["labels"] = batch["input_ids"]
        return batch
def get_dataset(csv_path: os.PathLike) -> Dataset:
    """Load a CSV file into a ``Dataset`` with an image column.

    The CSV is read with pandas. The ``IMG_PATH`` column is stored under the
    ``IMG`` key and cast to ``datasets.Image()``; the ``QUESTION`` column is
    copied under its own name, and the ``ANSWER`` column is included only
    when the CSV has one.

    Args:
        csv_path: Location of the CSV file to read.

    Returns:
        The dataset built from the selected columns.
    """
    frame = pd.read_csv(csv_path)

    # (dataset key, CSV column) pairs, in the order they enter the dataset.
    key_to_column = [(IMG, IMG_PATH), (QUESTION, QUESTION)]
    if ANSWER in frame.columns:
        key_to_column.append((ANSWER, ANSWER))

    records = {key: frame[column].tolist() for key, column in key_to_column}
    return Dataset.from_dict(records).cast_column(IMG, datasets.Image())