inference.py
import os
import shutil

import numpy as np
import torch
from tqdm import tqdm

from utils import save_audio, get_loss


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
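
# Illustrative usage (a minimal sketch; the tensor shapes and the guidance scale
# of 3.0 below are assumptions for illustration, not values taken from this repo):
# >>> noise_pred_text = torch.randn(2, 64, 128)
# >>> noise_pred_uncond = torch.randn(2, 64, 128)
# >>> noise_cfg = noise_pred_uncond + 3.0 * (noise_pred_text - noise_pred_uncond)
# >>> rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7).shape
# torch.Size([2, 64, 128])
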
@torch.no_grad()
def eval_ddim(unet, autoencoder, scheduler, eval_loader, args, device, epoch=0,
              uncond_path=None,
              guidance_scale=False, guidance_rescale=0.0,
              ddim_steps=50, eta=0,
              random_seed=2024):
    # DDIM sampling over the eval set; when `guidance_scale` is falsy, plain
    # conditional sampling is used (no classifier-free guidance).
    if random_seed is not None:
        generator = torch.Generator(device=device).manual_seed(random_seed)
    else:
        generator = torch.Generator(device=device)
        generator.seed()

    scheduler.set_timesteps(ddim_steps)
    unet.eval()

    for step, (mixture, target, timbre, mix_id, mixture_path, source_path, enroll_path) in enumerate(tqdm(eval_loader)):
        mixture = mixture.to(device)
        target = target.to(device)
        timbre = timbre.to(device)

        # init noise
        noise = torch.randn(mixture.shape, generator=generator, device=device)
        pred = noise

        for t in scheduler.timesteps:
            pred = scheduler.scale_model_input(pred, t)
            if guidance_scale:
                # classifier-free guidance: run the conditional and unconditional passes in one batch.
                # `uncond` is a single unconditional timbre embedding, so this path assumes an eval batch size of 1.
                uncond = torch.tensor(np.load(uncond_path)['arr_0']).unsqueeze(0).to(device)
                pred_combined = torch.cat([pred, pred], dim=0)
                mixture_combined = torch.cat([mixture, mixture], dim=0)
                timbre_combined = torch.cat([timbre, uncond], dim=0)
                output_combined = unet(x=pred_combined, timesteps=t, mixture=mixture_combined, timbre=timbre_combined)
                output_pos, output_neg = torch.chunk(output_combined, 2, dim=0)
                model_output = output_neg + guidance_scale * (output_pos - output_neg)
                if guidance_rescale > 0.0:
                    # avoid over-exposure
                    model_output = rescale_noise_cfg(model_output, output_pos,
                                                     guidance_rescale=guidance_rescale)
            else:
                model_output = unet(x=pred, timesteps=t, mixture=mixture, timbre=timbre)

            pred = scheduler.step(model_output=model_output, timestep=t, sample=pred,
                                  eta=eta, generator=generator).prev_sample

        pred_wav = autoencoder(embedding=pred)
        os.makedirs(f'{args.log_dir}/audio/{epoch}/', exist_ok=True)
        for j in range(pred_wav.shape[0]):
            shutil.copyfile(mixture_path[j], f'{args.log_dir}/audio/{epoch}/pred_{mix_id[j]}_mixture.wav')
            shutil.copyfile(source_path[j], f'{args.log_dir}/audio/{epoch}/pred_{mix_id[j]}_source.wav')
            shutil.copyfile(enroll_path[j], f'{args.log_dir}/audio/{epoch}/pred_{mix_id[j]}_enroll.wav')
            save_audio(f'{args.log_dir}/audio/{epoch}/pred_{mix_id[j]}.wav', 24000, pred_wav[j].unsqueeze(0))
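
# Example driver (a minimal sketch, assuming a diffusers-style DDIM scheduler;
# `parse_args`, `build_model`, and `build_eval_loader` are hypothetical helpers
# and are not part of this file):
# if __name__ == "__main__":
#     from diffusers import DDIMScheduler
#     args = parse_args()                    # hypothetical CLI parsing
#     unet, autoencoder = build_model(args)  # hypothetical model constructors
#     eval_loader = build_eval_loader(args)  # hypothetical DataLoader
#     scheduler = DDIMScheduler(num_train_timesteps=1000)
#     eval_ddim(unet, autoencoder, scheduler, eval_loader, args, device='cuda',
#               uncond_path=args.uncond_path, guidance_scale=3.0,
#               guidance_rescale=0.0, ddim_steps=50, eta=0)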