Commit
Merge pull request CompVis#36 from enzymezoo-code/inpainting_1.0
Updating ipynb with colab-convert
enzymezoo-code authored Aug 29, 2022
2 parents 5776e29 + 9a62dce commit 3ba2f92
Showing 1 changed file with 113 additions and 84 deletions.
197 changes: 113 additions & 84 deletions Deforum_Stable_Diffusion.ipynb
@@ -164,14 +164,10 @@
 "def add_noise(sample: torch.Tensor, noise_amt: float):\n",
 "    return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n",
 "\n",
-"def get_output_folder(output_path,batch_folder=None):\n",
-"    yearMonth = time.strftime('%Y-%m/')\n",
-"    out_path = os.path.join(output_path,yearMonth)\n",
+"def get_output_folder(output_path, batch_folder):\n",
+"    out_path = os.path.join(output_path,time.strftime('%Y-%m/'))\n",
 "    if batch_folder != \"\":\n",
-"        out_path = os.path.join(out_path,batch_folder)\n",
-"        # we will also make sure the path suffix is a slash if linux and a backslash if windows\n",
-"        if out_path[-1] != os.path.sep:\n",
-"            out_path += os.path.sep\n",
+"        out_path = os.path.join(out_path, batch_folder)\n",
 "    os.makedirs(out_path, exist_ok=True)\n",
 "    return out_path\n",
 "\n",
@@ -203,14 +199,19 @@
 "    mask = torch.from_numpy(mask)\n",
 "    return mask\n",
 "\n",
-"def maintain_colors(prev_img, color_match_sample, hsv=False):\n",
-"    if hsv:\n",
+"def maintain_colors(prev_img, color_match_sample, mode):\n",
+"    if mode == 'Match Frame 0 RGB':\n",
+"        return match_histograms(prev_img, color_match_sample, multichannel=True)\n",
+"    elif mode == 'Match Frame 0 HSV':\n",
 "        prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)\n",
 "        color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)\n",
 "        matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)\n",
 "        return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)\n",
-"    else:\n",
-"        return match_histograms(prev_img, color_match_sample, multichannel=True)\n",
+"    else: # Match Frame 0 LAB\n",
+"        prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)\n",
+"        color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)\n",
+"        matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)\n",
+"        return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)\n",
 "\n",
 "\n",
 "def make_callback(sampler_name, dynamic_threshold=None, static_threshold=None, mask=None, init_latent=None, sigmas=None, sampler=None, masked_noise_modifier=1.0): \n",
@@ -330,57 +331,56 @@
 "    with torch.no_grad():\n",
 "        with precision_scope(\"cuda\"):\n",
 "            with model.ema_scope():\n",
-"                for n in range(args.n_samples):\n",
-"                    for prompts in data:\n",
-"                        uc = None\n",
-"                        if args.scale != 1.0:\n",
-"                            uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
-"                        if isinstance(prompts, tuple):\n",
-"                            prompts = list(prompts)\n",
-"                        c = model.get_learned_conditioning(prompts)\n",
-"\n",
-"                        if args.init_c != None:\n",
-"                            c = args.init_c\n",
-"\n",
-"                        if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n",
-"                            samples = sampler_fn(\n",
-"                                c=c, \n",
-"                                uc=uc, \n",
-"                                args=args, \n",
-"                                model_wrap=model_wrap, \n",
-"                                init_latent=init_latent, \n",
-"                                t_enc=t_enc, \n",
-"                                device=device, \n",
-"                                cb=callback)\n",
-"                        else:\n",
-"                            # args.sampler == 'plms' or args.sampler == 'ddim':\n",
-"                            if init_latent is not None and args.strength > 0:\n",
-"                                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
-"                            else:\n",
-"                                z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n",
-"                            samples = sampler.decode(z_enc, \n",
-"                                                     c, \n",
-"                                                     t_enc, \n",
-"                                                     unconditional_guidance_scale=args.scale,\n",
-"                                                     unconditional_conditioning=uc,\n",
-"                                                     img_callback=callback)\n",
-"\n",
-"                        if return_latent:\n",
-"                            results.append(samples.clone())\n",
-"\n",
-"                        x_samples = model.decode_first_stage(samples)\n",
-"                        if return_sample:\n",
-"                            results.append(x_samples.clone())\n",
-"\n",
-"                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
-"\n",
-"                        if return_c:\n",
-"                            results.append(c.clone())\n",
-"\n",
-"                        for x_sample in x_samples:\n",
-"                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
-"                            image = Image.fromarray(x_sample.astype(np.uint8))\n",
-"                            results.append(image)\n",
+"                for prompts in data:\n",
+"                    uc = None\n",
+"                    if args.scale != 1.0:\n",
+"                        uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
+"                    if isinstance(prompts, tuple):\n",
+"                        prompts = list(prompts)\n",
+"                    c = model.get_learned_conditioning(prompts)\n",
+"\n",
+"                    if args.init_c != None:\n",
+"                        c = args.init_c\n",
+"\n",
+"                    if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n",
+"                        samples = sampler_fn(\n",
+"                            c=c, \n",
+"                            uc=uc, \n",
+"                            args=args, \n",
+"                            model_wrap=model_wrap, \n",
+"                            init_latent=init_latent, \n",
+"                            t_enc=t_enc, \n",
+"                            device=device, \n",
+"                            cb=callback)\n",
+"                    else:\n",
+"                        # args.sampler == 'plms' or args.sampler == 'ddim':\n",
+"                        if init_latent is not None and args.strength > 0:\n",
+"                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
+"                        else:\n",
+"                            z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n",
+"                        samples = sampler.decode(z_enc, \n",
+"                                                 c, \n",
+"                                                 t_enc, \n",
+"                                                 unconditional_guidance_scale=args.scale,\n",
+"                                                 unconditional_conditioning=uc,\n",
+"                                                 img_callback=callback)\n",
+"\n",
+"                    if return_latent:\n",
+"                        results.append(samples.clone())\n",
+"\n",
+"                    x_samples = model.decode_first_stage(samples)\n",
+"                    if return_sample:\n",
+"                        results.append(x_samples.clone())\n",
+"\n",
+"                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
+"\n",
+"                    if return_c:\n",
+"                        results.append(c.clone())\n",
+"\n",
+"                    for x_sample in x_samples:\n",
+"                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
+"                        image = Image.fromarray(x_sample.astype(np.uint8))\n",
+"                        results.append(image)\n",
 "    return results\n",
 "\n",
 "def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:\n",
@@ -569,13 +569,14 @@
 "    scale_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n",
 "\n",
 "    #@markdown ####**Coherence:**\n",
-"    color_coherence = 'MatchFrame0' #@param ['None', 'MatchFrame0'] {type:'string'}\n",
+"    color_coherence = 'Match Frame 0 HSV' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n",
 "\n",
 "    #@markdown ####**Video Input:**\n",
 "    video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
 "    extract_nth_frame = 1#@param {type:\"number\"}\n",
 "\n",
 "    #@markdown ####**Interpolation:**\n",
+"    interpolate_key_frames = False #@param {type:\"boolean\"}\n",
 "    interpolate_x_frames = 4 #@param {type:\"number\"}\n",
 "\n",
 "    return locals()\n",
@@ -657,6 +658,7 @@
 "id": "2ujwkGZTcGev"
 },
 "source": [
+"\n",
 "prompts = [\n",
 "    \"a beautiful forest by Asher Brown Durand, trending on Artstation\", #the first prompt I want\n",
 "    \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", #the second prompt I want\n",
@@ -665,9 +667,9 @@
 "\n",
 "animation_prompts = {\n",
 "    0: \"a beautiful apple, trending on Artstation\",\n",
-"    10: \"a beautiful banana, trending on Artstation\",\n",
-"    100: \"a beautiful coconut, trending on Artstation\",\n",
-"    101: \"a beautiful durian, trending on Artstation\",\n",
+"    20: \"a beautiful banana, trending on Artstation\",\n",
+"    30: \"a beautiful coconut, trending on Artstation\",\n",
+"    40: \"a beautiful durian, trending on Artstation\",\n",
 "}"
 ],
 "outputs": [],
@@ -726,7 +728,7 @@
 "    seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n",
 "\n",
 "    #@markdown **Grid Settings**\n",
-"    make_grid = True #@param {type:\"boolean\"}\n",
+"    make_grid = False #@param {type:\"boolean\"}\n",
 "    grid_rows = 2 #@param \n",
 "\n",
 "    precision = 'autocast' \n",
@@ -893,11 +895,11 @@
 "            )\n",
 "\n",
 "            # apply color matching\n",
-"            if anim_args.color_coherence == 'MatchFrame0':\n",
+"            if anim_args.color_coherence != 'None':\n",
 "                if color_match_sample is None:\n",
 "                    color_match_sample = prev_img.copy()\n",
 "                else:\n",
-"                    prev_img = maintain_colors(prev_img, color_match_sample, (frame_idx%2) == 0)\n",
+"                    prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n",
 "\n",
 "            # apply scaling\n",
 "            scaled_sample = prev_img * scale\n",
@@ -999,25 +1001,52 @@
 "\n",
 "    frame_idx = 0\n",
 "\n",
-"    for i in range(len(prompts_c_s)-1):\n",
-"        for j in range(anim_args.interpolate_x_frames+1):\n",
-"            # interpolate the text embedding\n",
-"            prompt1_c = prompts_c_s[i]\n",
-"            prompt2_c = prompts_c_s[i+1] \n",
-"            args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n",
-"\n",
-"            # sample the diffusion model\n",
-"            results = generate(args)\n",
-"            image = results[0]\n",
-"\n",
-"            filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
-"            image.save(os.path.join(args.outdir, filename))\n",
-"            frame_idx += 1\n",
-"\n",
-"            display.clear_output(wait=True)\n",
-"            display.display(image)\n",
-"\n",
-"            args.seed = next_seed(args)\n",
+"    if anim_args.interpolate_key_frames:\n",
+"        for i in range(len(prompts_c_s)-1):\n",
+"            dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]\n",
+"            if dist_frames <= 0:\n",
+"                print(\"key frames duplicated or reversed. interpolation skipped.\")\n",
+"                return\n",
+"            else:\n",
+"                for j in range(dist_frames):\n",
+"                    # interpolate the text embedding\n",
+"                    prompt1_c = prompts_c_s[i]\n",
+"                    prompt2_c = prompts_c_s[i+1] \n",
+"                    args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))\n",
+"\n",
+"                    # sample the diffusion model\n",
+"                    results = generate(args)\n",
+"                    image = results[0]\n",
+"\n",
+"                    filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
+"                    image.save(os.path.join(args.outdir, filename))\n",
+"                    frame_idx += 1\n",
+"\n",
+"                    display.clear_output(wait=True)\n",
+"                    display.display(image)\n",
+"\n",
+"                    args.seed = next_seed(args)\n",
+"\n",
+"    else:\n",
+"        for i in range(len(prompts_c_s)-1):\n",
+"            for j in range(anim_args.interpolate_x_frames+1):\n",
+"                # interpolate the text embedding\n",
+"                prompt1_c = prompts_c_s[i]\n",
+"                prompt2_c = prompts_c_s[i+1] \n",
+"                args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n",
+"\n",
+"                # sample the diffusion model\n",
+"                results = generate(args)\n",
+"                image = results[0]\n",
+"\n",
+"                filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
+"                image.save(os.path.join(args.outdir, filename))\n",
+"                frame_idx += 1\n",
+"\n",
+"                display.clear_output(wait=True)\n",
+"                display.display(image)\n",
+"\n",
+"                args.seed = next_seed(args)\n",
 "\n",
 "    # generate the last prompt\n",
 "    args.init_c = prompts_c_s[-1]\n",
@@ -1110,7 +1139,7 @@
 "accelerator": "GPU",
 "colab": {
   "collapsed_sections": [],
-  "name": "Deforum_Stable_Diffusion_+_Interpolation.ipynb",
+  "name": "Deforum_Stable_Diffusion.ipynb",
   "provenance": [],
   "private_outputs": true
 },