-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprocess.py
115 lines (90 loc) · 4.46 KB
/
process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import glob
import cv2
import argparse
import numpy as np
import torch
from PIL import Image
import rembg
# Cap OpenMP worker threads for this process.
# NOTE(review): this runs after numpy, cv2, torch and rembg have already been
# imported above. OpenMP/BLAS runtimes usually read OMP_NUM_THREADS when they
# are first loaded, so this may not limit their thread pools -- consider moving
# it above those imports. TODO confirm.
os.environ.update(OMP_NUM_THREADS="10")
class BLIP2():
    """Image captioner wrapping a Hugging Face BLIP-2 checkpoint.

    The processor and model are loaded once at construction; the model is
    loaded in float16 and moved to ``device``.  Calling the instance returns a
    caption string for one image.
    """

    def __init__(self, device='cuda', model_name="Salesforce/blip2-opt-2.7b"):
        """Load the BLIP-2 processor and model.

        Args:
            device: torch device the model (and, at call time, the inputs)
                are moved to.
            model_name: Hugging Face checkpoint id passed to
                ``from_pretrained``; defaults to the checkpoint this class
                previously hard-coded.
        """
        self.device = device
        # Imported here rather than at module level, presumably so the
        # background-removal script does not require `transformers` unless
        # captioning is actually used.
        from transformers import AutoProcessor, Blip2ForConditionalGeneration
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).to(device)

    @torch.no_grad()
    def __call__(self, image, max_new_tokens=20):
        """Generate a caption for ``image``.

        Args:
            image: array accepted by ``PIL.Image.fromarray``.
                NOTE(review): PIL treats a 3-channel array as RGB; images read
                with ``cv2`` are BGR and would need converting first -- confirm
                what callers pass.
            max_new_tokens: upper bound on generated tokens (default 20, the
                previously hard-coded value).

        Returns:
            The first decoded caption with special tokens removed and
            surrounding whitespace stripped.
        """
        image = Image.fromarray(image)
        # Inputs are cast to float16 to match the model's dtype set in __init__.
        inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16)
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return generated_text
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``, so ``--recenter False`` could never disable
    recentering.  This converter keeps the same ``--recenter <value>`` syntax
    but interprets the value correctly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _foreground_bbox(mask):
    """Return ``(x_min, x_max, y_min, y_max)`` of the True pixels in ``mask``.

    ``x`` indexes rows and ``y`` columns.  Bounds are half-open (``*_max`` is
    one past the last foreground row/column) so they can be used directly as
    slice ends.  Returns ``None`` when ``mask`` has no True pixels.
    """
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return None
    return int(rows.min()), int(rows.max()) + 1, int(cols.min()), int(cols.max()) + 1


def _recenter(image, carved_image, mask, size, border_ratio):
    """Centre the foreground on ``size`` x ``size`` RGBA and RGB canvases.

    Args:
        image: 3-channel source image, same height/width as ``mask``.
        carved_image: rembg output for the source image, ``[H, W, 4]``.
        mask: boolean foreground mask, ``[H, W]``.
        size: output side length in pixels.
        border_ratio: fraction of ``size`` kept as empty margin, in ``[0, 1)``.

    Returns:
        ``(final_rgba, final_rgb)``, or ``None`` if ``mask`` is empty.
        ``final_rgba`` is the foreground crop of ``carved_image``, scaled so its
        longer side spans ``int(size * (1 - border_ratio))`` pixels and pasted
        in the middle of a zero (transparent) canvas.  ``final_rgb`` is a
        square crop of ``image`` centred on the same bbox, zero-padded where
        it leaves the frame, resized to ``size`` x ``size``.
    """
    bbox = _foreground_bbox(mask)
    if bbox is None:
        return None
    x_min, x_max, y_min, y_max = bbox
    h = x_max - x_min
    w = y_max - y_min

    # RGBA: scale the carved crop and paste it centred on a blank canvas.
    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)
    # At least one pixel, otherwise cv2.resize rejects the target size.
    h2 = max(1, int(h * scale))
    w2 = max(1, int(w * scale))
    x2_min = (size - h2) // 2
    y2_min = (size - w2) // 2
    final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
    final_rgba[x2_min:x2_min + h2, y2_min:y2_min + w2] = cv2.resize(
        carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA
    )

    # RGB: take a square window of the original image around the bbox centre,
    # enlarged by the same border ratio.
    xc = (x_min + x_max) // 2
    yc = (y_min + y_max) // 2
    half = max(1, int(max(h, w) / (1 - border_ratio)) // 2)
    x_lo, x_hi = xc - half, xc + half
    y_lo, y_hi = yc - half, yc + half
    H, W = image.shape[:2]
    # Pad the image in case the window extends past the image boundary.
    canvas = np.zeros(
        (max(H, x_hi) - min(0, x_lo), max(W, y_hi) - min(0, y_lo), 3), dtype=image.dtype
    )
    x_offset = -min(0, x_lo)
    y_offset = -min(0, y_lo)
    canvas[x_offset:x_offset + H, y_offset:y_offset + W] = image
    roi = canvas[x_offset + x_lo:x_offset + x_hi, y_offset + y_lo:y_offset + y_hi]
    final_rgb = cv2.resize(roi, (size, size), interpolation=cv2.INTER_AREA)
    return final_rgba, final_rgb


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)")
    parser.add_argument('--model', default='u2net', type=str, help="rembg model, see https://github.com/danielgatis/rembg#models")
    parser.add_argument('--size', default=256, type=int, help="output resolution")
    parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio")
    # `type=bool` would map any non-empty string, even "False", to True.
    parser.add_argument('--recenter', type=_str2bool, default=True, help="recenter, potentially not helpful for multiview zero123")
    opt = parser.parse_args()
    if opt.size <= 0:
        parser.error('--size must be positive')
    if not 0 <= opt.border_ratio < 1:
        parser.error('--border_ratio must be in [0, 1)')

    session = rembg.new_session(model_name=opt.model)

    if os.path.isdir(opt.path):
        print(f'[INFO] processing directory {opt.path}...')
        # Regular files only: after a previous run the directory also holds
        # the `processed`/`source` output folders created below.
        files = sorted(
            f for f in glob.glob(os.path.join(glob.escape(opt.path), '*'))
            if os.path.isfile(f)
        )
        out_dir = opt.path
    else:  # isfile
        files = [opt.path]
        out_dir = os.path.dirname(opt.path)
    os.makedirs(os.path.join(out_dir, 'processed'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'source'), exist_ok=True)

    for file in files:
        # splitext keeps dotted stems intact ("a.b.png" -> "a.b", not "a").
        out_base = os.path.splitext(os.path.basename(file))[0]
        out_rgba = os.path.join(out_dir, 'processed', out_base + '_rgba.png')
        out_rgb = os.path.join(out_dir, 'source', out_base + '.png')

        # load image
        print(f'[INFO] loading image {file}...')
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if image is None:
            print(f'[WARN] could not read {file} as an image, skipping')
            continue
        if image.ndim == 2:
            # Grayscale input: expand to 3 channels for rembg and the canvas.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        # carve background
        print('[INFO] background removal...')
        carved_image = rembg.remove(image, session=session)  # [H, W, 4]
        mask = carved_image[..., -1] > 0
        # Drop any input alpha channel; the RGB output has exactly 3 channels.
        rgb_image = np.ascontiguousarray(image[..., :3])

        # recenter
        if opt.recenter:
            print('[INFO] recenter...')
            result = _recenter(rgb_image, carved_image, mask, opt.size, opt.border_ratio)
            if result is None:
                print(f'[WARN] no foreground found in {file}, skipping')
                continue
            final_rgba, final_rgb = result
        else:
            final_rgba = carved_image
            final_rgb = rgb_image

        # write image
        for out_path, out_image in ((out_rgba, final_rgba), (out_rgb, final_rgb)):
            if not cv2.imwrite(out_path, out_image):
                print(f'[WARN] failed to write {out_path}')