# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import sys
import time
import subprocess
import shutil
from PIL import Image
from omegaconf import OmegaConf
from moviepy.editor import VideoFileClip
import numpy as np
from cog import BasePredictor, Input, Path
import torch
from diffusers import DDIMInverseScheduler, DDIMScheduler
from diffusers.utils import load_image
import imageio
from black_box_image_edit import InstructPix2Pix
sys.path.insert(0, "i2vgen-xl")
from utils import load_ddim_latents_at_t
from pipelines.pipeline_i2vgen_xl import I2VGenXLPipeline
from run_group_ddim_inversion import ddim_inversion
from run_group_pnp_edit import init_pnp
# Weights are saved and loaded from replicate.delivery for faster booting
INSTRUCTPIX2PIX_URL = "https://weights.replicate.delivery/default/timbrooks/instruct-pix2pix.tar" # original pipeline weights from timbrooks/instruct-pix2pix
INSTRUCTPIX2PIX_CACHE = "weights/timbrooks/instruct-pix2pix"
ALI_I2VGENXL_URL = "https://weights.replicate.delivery/default/ali-vilab/i2vgen-xl.tar" # original pipeline weights from ali-vilab/i2vgen-xl
ALI_I2VGENXL_CACHE = "weights/ali-vilab/i2vgen-xl"
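# Assumption: pget (https://github.com/replicate/pget) is available on PATH; with the -x flag
# used below it downloads each .tar archive and extracts it into the cache directory, which is
# expected to follow the diffusers layout of the original Hugging Face repos.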


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(INSTRUCTPIX2PIX_CACHE):
            download_weights(INSTRUCTPIX2PIX_URL, INSTRUCTPIX2PIX_CACHE)
        self.black_box_image_model = InstructPix2Pix(weight=INSTRUCTPIX2PIX_CACHE)

        if not os.path.exists(ALI_I2VGENXL_CACHE):
            download_weights(ALI_I2VGENXL_URL, ALI_I2VGENXL_CACHE)
        # Initialize the DDIM inverse scheduler
        self.inverse_scheduler = DDIMInverseScheduler.from_pretrained(
            ALI_I2VGENXL_CACHE,
            subfolder="scheduler",
        )
        # Initialize the DDIM scheduler
        self.ddim_scheduler = DDIMScheduler.from_pretrained(
            ALI_I2VGENXL_CACHE,
            subfolder="scheduler",
        )

        # Set up the default inversion and PnP configs
        config = {
            # DDIM inversion
            "inverse_config": {
                "image_size": [512, 512],
                "n_frames": 16,
                "cfg": 1.0,
                "target_fps": 8,
                "prompt": "",
                "negative_prompt": "",
            },
            "pnp_config": {
                "ddim_inv_prompt": "",
                "random_ratio": 0.0,
                "target_fps": 8,
            },
        }
        self.config = OmegaConf.create(config)

    def predict(
        self,
        video: Path = Input(description="Input video"),
        edited_first_frame: Path = Input(
            description="Provide the edited first frame of the input video. This is optional; leave it blank and provide the prompt below to use the default pipeline, which edits the first frame with InstructPix2Pix.",
            default=None,
        ),
        instruct_pix2pix_prompt: str = Input(
            description="The first step involves using timbrooks/instruct-pix2pix to edit the first frame. Specify the prompt for editing the first frame. This is ignored if edited_first_frame above is provided.",
            default="turn man into robot",
        ),
        editing_prompt: str = Input(
            description="Describe the input video",
            default="a man doing exercises for the body and mind",
        ),
        editing_negative_prompt: str = Input(
            description="Things not to see in the edited video",
            default="Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=9.0
        ),
        pnp_f_t: float = Input(
            description="Proportion of DDIM sampling steps during which the convolutional feature injection is applied. A higher value improves motion consistency; 1.0 injects at every time step.",
            ge=0.0,
            le=1.0,
            default=1.0,
        ),
        pnp_spatial_attn_t: float = Input(
            description="Proportion of DDIM sampling steps during which the spatial attention injection is applied. A higher value improves motion consistency; 1.0 injects at every time step.",
            ge=0.0,
            le=1.0,
            default=1.0,
        ),
        pnp_temp_attn_t: float = Input(
            description="Proportion of DDIM sampling steps during which the temporal attention injection is applied. A higher value improves motion consistency; 1.0 injects at every time step.",
            ge=0.0,
            le=1.0,
            default=1.0,
        ),
        ddim_init_latents_t_idx: int = Input(
            description="Time step index at which to begin sampling from the DDIM-inverted latents, in the range [0, num_inference_steps - 1]. With 50 sampling steps, the scheduler progresses through the time steps [981, 961, 941, ..., 1], so index 0 starts sampling at t=981 and index 1 starts at t=961. A higher index enhances motion consistency with the source video but may introduce flickering and cause the edited video to diverge from the edited first frame.",
            ge=0,
            default=0,
        ),
        ddim_inversion_steps: int = Input(
            description="Number of DDIM inversion steps", default=100
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
"""Run a single prediction on the model"""
if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")
tmp_dir = "exp_dir"
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
os.makedirs(tmp_dir)
ddim_latents_path = os.path.join(tmp_dir, "ddim_latents")
frame_list = read_frames(str(video))
self.config.inverse_config.image_size = list(frame_list[0].size)
self.config.inverse_config.n_steps = ddim_inversion_steps
self.config.inverse_config.n_frames = len(frame_list)
self.config.inverse_config.output_dir = ddim_latents_path
ddim_init_latents_t_idx = min(ddim_init_latents_t_idx, num_inference_steps - 1)
if edited_first_frame is not None:
edited_first_frame_path = str(edited_first_frame)
else:
# Step 0. Black-box image editing for the first frame
edited_first_frame_path = os.path.join(tmp_dir, "edited_first_frame.png")
infer_video(
self.black_box_image_model,
str(video),
edited_first_frame_path,
instruct_pix2pix_prompt,
seed=seed,
)
# Step 1. DDIM Inversion
first_frame = frame_list[0]
pipe = I2VGenXLPipeline.from_pretrained(
ALI_I2VGENXL_CACHE,
torch_dtype=torch.float16,
variant="fp16",
).to("cuda:0")
generator = torch.Generator(device="cuda:0")
generator = generator.manual_seed(seed)
_ddim_latents = ddim_inversion(
self.config.inverse_config,
first_frame,
frame_list,
pipe,
self.inverse_scheduler,
generator,
)
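        # ddim_inversion (from run_group_ddim_inversion) saves the inverted latents for each
        # inversion timestep under inverse_config.output_dir, i.e. ddim_latents_path; they are
        # reloaded below with load_ddim_latents_at_t and handed to sample_with_pnp via
        # ddim_inv_latents_path for the PnP injection.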

        # Step 2. DDIM Sampling + PnP feature and attention injection
        # Load the edited first frame
        edited_1st_frame = load_image(edited_first_frame_path).resize(
            self.config.inverse_config.image_size, resample=Image.Resampling.LANCZOS
        )
        # Load the initial latents at t
        self.ddim_scheduler.set_timesteps(num_inference_steps)
        print(f"ddim_scheduler.timesteps: {self.ddim_scheduler.timesteps}")
        ddim_latents_at_t = load_ddim_latents_at_t(
            self.ddim_scheduler.timesteps[ddim_init_latents_t_idx],
            ddim_latents_path=ddim_latents_path,
        )
        print(
            f"ddim_scheduler.timesteps[t_idx]: {self.ddim_scheduler.timesteps[ddim_init_latents_t_idx]}"
        )
        print(f"ddim_latents_at_t.shape: {ddim_latents_at_t.shape}")

        # Blend the latents
        random_latents = torch.randn_like(ddim_latents_at_t)
        print(
            f"Blending random_ratio (1 means random latent): {self.config.pnp_config.random_ratio}"
        )
        mixed_latents = (
            random_latents * self.config.pnp_config.random_ratio
            + ddim_latents_at_t * (1 - self.config.pnp_config.random_ratio)
        )
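        # With the default pnp_config.random_ratio of 0.0, mixed_latents are exactly the
        # DDIM-inverted latents; raising the ratio toward 1.0 mixes in fresh Gaussian noise,
        # loosening the edited video's tie to the source motion.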

        # Init PnP
        self.config.pnp_config.n_steps = num_inference_steps
        self.config.pnp_config.pnp_f_t = pnp_f_t
        self.config.pnp_config.pnp_spatial_attn_t = pnp_spatial_attn_t
        self.config.pnp_config.pnp_temp_attn_t = pnp_temp_attn_t
        self.config.pnp_config.ddim_init_latents_t_idx = ddim_init_latents_t_idx
        init_pnp(pipe, self.ddim_scheduler, self.config.pnp_config)

        # Edit video
        pipe.register_modules(scheduler=self.ddim_scheduler)
        edited_video = pipe.sample_with_pnp(
            prompt=editing_prompt,
            image=edited_1st_frame,
            height=self.config.inverse_config.image_size[1],
            width=self.config.inverse_config.image_size[0],
            num_frames=self.config.inverse_config.n_frames,
            num_inference_steps=self.config.pnp_config.n_steps,
            guidance_scale=guidance_scale,
            negative_prompt=editing_negative_prompt,
            target_fps=self.config.pnp_config.target_fps,
            latents=mixed_latents,
            generator=generator,
            return_dict=True,
            ddim_init_latents_t_idx=ddim_init_latents_t_idx,
            ddim_inv_latents_path=ddim_latents_path,
            ddim_inv_prompt="",
            ddim_inv_1st_frame=first_frame,
        ).frames[0]

        edited_video = [
            frame.resize(self.config.inverse_config.image_size, resample=Image.LANCZOS)
            for frame in edited_video
        ]

        output_path = "/tmp/out.mp4"
        images_to_video(
            edited_video, output_path, fps=self.config.pnp_config.target_fps
        )
        return Path(output_path)


def infer_video(
    model, video_path, result_path, prompt, force_512=False, seed=42, negative_prompt=""
):
    # Load the source clip; only its first frame is edited below
    video_clip = VideoFileClip(video_path)

    def process_frame(image):
        pil_image = Image.fromarray(image)
        if force_512:
            pil_image = pil_image.resize((512, 512), Image.LANCZOS)
        result = model.infer_one_image(
            pil_image,
            instruct_prompt=prompt,
            seed=seed,
            negative_prompt=negative_prompt,
        )
        if force_512:
            result = result.resize(video_clip.size, Image.LANCZOS)
        return np.array(result)

    # Process only the first frame
    first_frame = video_clip.get_frame(0)  # Get the first frame
    processed_frame = process_frame(first_frame)  # Process the first frame
    Image.fromarray(processed_frame).save(result_path)
    print(f"Processed and saved the first frame: {result_path}")


def images_to_video(images, output_path, fps=24):
    writer = imageio.get_writer(output_path, fps=fps)
    for img in images:
        img_np = np.array(img)
        writer.append_data(img_np)
    writer.close()


def read_frames(video_path):
    frames = []
    with imageio.get_reader(video_path) as reader:
        for frame in reader:
            pil_image = Image.fromarray(frame)
            frames.append(pil_image)
    return frames
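

# Example invocation (a sketch, not part of the Cog entry point): with Cog installed and this
# repository's cog.yaml in place, a run might look like
#
#   cog predict \
#       -i video=@input.mp4 \
#       -i instruct_pix2pix_prompt="turn man into robot" \
#       -i editing_prompt="a man doing exercises for the body and mind"
#
# Argument names match the Input fields of Predictor.predict above; the surrounding cog.yaml
# configuration and GPU environment are assumed, not shown here.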