caption.py

import math
import os
import time
import json
import argparse
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import open_clip
from utils import (
    GenerationDataset,
    count_all_parameters,
    count_trainable_parameters,
    set_random_seed,
)
from tqdm import tqdm
from loguru import logger


def create_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="Beijing",
        choices=["Beijing", "Shanghai", "Guangzhou", "Shenzhen"],
        help="which dataset",
    )
    parser.add_argument("--batch_size", type=int, default=2, help="batch size")
    parser.add_argument(
        "--pretrained_model",
        type=str,
        # default="checkpoints/best_model.bin",
        default="/root/" + "laion-mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/open_clip_pytorch_model.bin",
        help="pretrained model, mscoco_finetuned_laion2B-s13B-b90k",
    )
    parser.add_argument(
        "--logging_dir", type=str, default="logs/downtask2", help="logging directory"
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="checkpoints/downtask2",
        help="checkpoint path",
    )
    parser.add_argument(
        "--seed", type=int, default=133, help="random seed for reproducibility"
    )
    args = parser.parse_args()
    return args


def create_datasets(args, transform):
    """Create the inference dataset for the selected city."""
    if args.dataset == "Beijing":
        path = Path("data/images/Beijing")
        jpg_files = list(path.glob("*.jpg"))
    elif args.dataset == "Shanghai":
        path = Path("data/images/Shanghai")
        jpg_files = list(path.glob("*.jpg"))
    elif args.dataset == "Guangzhou":
        path = Path("data/images/Guangzhou")
        jpg_files = list(path.glob("*.jpg"))
    elif args.dataset == "Shenzhen":
        path = Path("data/images/Shenzhen")
        jpg_files = list(path.glob("*.jpg"))
    else:
        raise ValueError("dataset not found")
    # create datasets
    dataset = GenerationDataset(jpg_files, transform)
    return dataset


def inference(model, image_paths, dataloader, args, logger):
    """Generate a caption for every image in the inference dataset and save them to CSV."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    all_text_strings = []
    with torch.no_grad():
        for i, images in tqdm(enumerate(dataloader)):
            images = images.to(device=device, non_blocking=True)
            generated = model.generate(images, generation_type="top_p")
            # print(type(generated))  # Tensor
            # print("generated.shape: {}".format(generated.shape))  # torch.Size([batch_size, 12])
            for j in range(generated.shape[0]):
                # decode the generated token ids to text and strip the special tokens
                text = (
                    open_clip.decode(generated[j])
                    .split("<end_of_text>")[0]
                    .replace("<start_of_text>", "")
                )
                all_text_strings.append(text)
    df = pd.DataFrame(
        {
            "image": [str(item).split("/")[-1] for item in image_paths],
            "caption": all_text_strings,
        }
    )
    df.to_csv(os.path.join(args.checkpoint_dir, "captions.csv"), index=False)


def main():
    args = create_args()
    set_random_seed(args.seed)

    # create logger
    if not os.path.exists(args.logging_dir):
        os.makedirs(args.logging_dir)
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    logger.remove(handler_id=None)  # remove default logger
    logger.add(os.path.join(args.logging_dir, str(args.seed) + ".log"), level="INFO")
    logger.info(args)

    # create model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, _, transform = open_clip.create_model_and_transforms(
        model_name="coca_ViT-L-14", pretrained=args.pretrained_model
    )
    model.to(device)
    for param in model.parameters():
        param.requires_grad = False
    logger.info("model parameters: {}".format(count_all_parameters(model)))
    logger.info(
        "model trainable parameters: {}".format(count_trainable_parameters(model))
    )
    # tokenizer = open_clip.get_tokenizer("coca_ViT-L-14")

    # create datasets
    dataset = create_datasets(args, transform)
    logger.info("inference dataset size: {}".format(len(dataset)))

    # create dataloaders
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
    )
    inference(model, dataset.jpg_list, dataloader, args, logger)


if __name__ == "__main__":
    main()
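

# Example invocation (illustrative only; it assumes the city images under
# data/images/<dataset> and the CoCa checkpoint referenced by --pretrained_model
# already exist on the machine):
#
#   python caption.py --dataset Beijing --batch_size 4
#
# The generated captions are written to <checkpoint_dir>/captions.csv
# (checkpoints/downtask2 by default).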