-
Notifications
You must be signed in to change notification settings - Fork 90
/
gaussian_diffusion.py
executable file
·969 lines (860 loc) · 35.4 KB
/
gaussian_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
"""
This code started out as a PyTorch port of Ho et al's diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""
import enum
import math
import numpy as np
import torch as th
import sys
sys.path.append('.')
import torch.nn.functional as F
from .utils.nn import mean_flat
from .utils.losses import normal_kl, discretized_gaussian_log_likelihood
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
elif schedule_name == 'sqrt':
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: 1-np.sqrt(t + 0.0001),
)
elif schedule_name == "trunc_cos":
return betas_for_alpha_bar_left(
num_diffusion_timesteps,
lambda t: np.cos((t + 0.1) / 1.1 * np.pi / 2) ** 2,
)
elif schedule_name == 'trunc_lin':
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001 + 0.01
beta_end = scale * 0.02 + 0.01
return np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif schedule_name == 'pw_lin':
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001 + 0.01
beta_mid = scale * 0.0001 #scale * 0.02
beta_end = scale * 0.02
first_part = np.linspace(
beta_start, beta_mid, 10, dtype=np.float64
)
second_part = np.linspace(
beta_mid, beta_end, num_diffusion_timesteps - 10 , dtype=np.float64
)
return np.concatenate(
[first_part, second_part]
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar_left(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, but shifts towards left interval starting from 0
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
betas.append(min(1-alpha_bar(0), max_beta))
for i in range(num_diffusion_timesteps-1):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion models.
Ported directly from here, and then adapted over time to further experimentation.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
:param predict_xstart: the model outputs to predict x_0, else to predict eps.
:param learn_sigmas: the model outputs to predict sigma or not. Default: False
:param rescale_learned_sigmas, sigma_small: details setting of learned sigmas
:param rescale_timesteps: if True, pass floating point timesteps into the
model so that they are always scaled like in the
original paper (0 to 1000).
"""
def __init__(
self,
*,
betas,
predict_xstart,
rescale_learned_sigmas,
learn_sigmas,
sigma_small,
use_kl,
rescale_timesteps=False,
):
self.rescale_timesteps = rescale_timesteps
self.predict_xstart = predict_xstart
self.rescale_learned_sigmas = rescale_learned_sigmas
self.learn_sigmas = learn_sigmas
self.sigma_small = sigma_small
self.use_kl = use_kl
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# log calculation clipped because the posterior variance is 0 at the
# beginning of the diffusion chain.
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* np.sqrt(alphas)
/ (1.0 - self.alphas_cumprod)
)
self.mapping_func = None # implement in train main()
self.add_mask_noise = False # TODO
def training_losses(self, model, *args, **kwargs):
self.model = model
return self.training_losses_seq2seq(model, *args, **kwargs)
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * (1000.0 / self.num_timesteps)
return t
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
)
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None, mask=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:param mask: anchoring masked position
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
x_t = (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
if mask == None:
return x_t
else:
mask = th.broadcast_to(mask.unsqueeze(dim=-1), x_start.shape)
return th.where(mask==0, x_start, x_t)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.size(0), x.size(-1)
assert t.shape == (B,)
# print(x.shape)
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
model_variance = np.append(self.posterior_variance[1], self.betas[1:])
model_log_variance = np.log(np.append(self.posterior_variance[1], self.betas[1:]))
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
def process_xstart(x):
if denoised_fn is not None:
# print(denoised_fn)
x = denoised_fn(x, t)
if clip_denoised:
return x.clamp(-1, 1)
return x
if self.predict_xstart:
pred_xstart = process_xstart(model_output)
else:
### model is used to predict eps
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
assert (
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
}
def p_sample(
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
top_p=None, mask=None, x_start=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
:param x: the current tensor at x_{t-1}.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param mask: anchoring masked position to x_start
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if top_p is not None and top_p > 0:
# print('top_p sampling')
noise = th.randn_like(x)
replace_mask = th.abs(noise) > top_p
while replace_mask.any():
noise[replace_mask] = th.randn_like(noise[replace_mask])
replace_mask = th.abs(noise) > top_p
assert (th.abs(noise) <= top_p).all()
else:
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
if mask == None:
pass
else:
sample = th.where(mask==0, x_start, sample)
return {
"sample": sample,
"pred_xstart": out["pred_xstart"],
"greedy_mean": out["mean"],
"out": out
}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
top_p=None,
clamp_step=None,
clamp_first=None,
mask=None,
x_start=None,
gap=1,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param mask: anchoring masked position to x_start
:param clamp_step: in clamp_first mode, choose end clamp step, otherwise starting clamp step
:param clamp_first: bool, clamp_first mode
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = []
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
top_p=top_p,
clamp_step=clamp_step,
clamp_first=clamp_first,
mask=mask,
x_start=x_start
):
final.append(sample['sample'])
return final
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
top_p=None,
clamp_step=None,
clamp_first=None,
mask=None,
x_start=None,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None: # custom your the start point of x_0
sample_x = noise
else:
sample_x = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices: # from T to 0
t = th.tensor([i] * shape[0], device=device)
if not clamp_first:
if i > clamp_step:
denoised_fn_cur = None
else:
denoised_fn_cur = denoised_fn
else:
if i >= clamp_step:
denoised_fn_cur = denoised_fn
else:
denoised_fn_cur = None
with th.no_grad():
out = self.p_sample(
model,
sample_x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn_cur,
model_kwargs=model_kwargs,
top_p=top_p,
mask=mask,
x_start=x_start
)
yield out
sample_x = out["sample"]
def _get_x_start(self, x_start_mean, std):
'''
Word embedding projection from {Emb(w)} to {x_0}
:param x_start_mean: word embedding
:return: x_0
'''
noise = th.randn_like(x_start_mean)
assert noise.shape == x_start_mean.shape
# print(x_start_mean.device, noise.device)
return (
x_start_mean + std * noise
)
def _token_discrete_loss(self, x_t, get_logits, input_ids, mask=None, truncate=False, t=None):
'''
the loss of -log p(w|z_0)
:param x_start_mean: word embedding
:return: x_0
'''
reshaped_x_t = x_t
logits = get_logits(reshaped_x_t) # bsz, seqlen, vocab
# print(logits.shape)
loss_fct = th.nn.CrossEntropyLoss(reduction='none')
decoder_nll = loss_fct(logits.view(-1, logits.size(-1)), input_ids.view(-1)).view(input_ids.shape)
if mask != None:
decoder_nll *= mask
# print(decoder_nll.shape)
if mask != None:
decoder_nll = decoder_nll.sum(dim=-1)/mask.sum(dim=-1)
else:
decoder_nll = decoder_nll.mean(dim=-1)
return decoder_nll
def _x0_helper(self, model_output, x, t):
if self.predict_xstart:
pred_xstart = model_output
pred_prev, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
else: # predict eps
pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
pred_prev, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
return {'pred_xprev':pred_prev, 'pred_xstart':pred_xstart}
def training_losses_seq2seq(self, model, x_start, t, model_kwargs=None, noise=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs. # not used unless fixing the input embeddings
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
x_start_fix = x_start # save the orignal x_0
assert 'input_ids' in model_kwargs
input_ids_x = model_kwargs.pop('input_ids').to(t.device)
input_ids_mask = model_kwargs.pop('input_mask').to(t.device)
x_start_mean = model.model.module.get_embeds(input_ids_x)
std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
th.tensor([0]).to(x_start_mean.device),
x_start_mean.shape)
# print(std.shape, )
# x_start_log_var = 2 * th.log(std)
x_start = self._get_x_start(x_start_mean, std)
# print(x_start_mean.shape, x_start.shape)
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise, mask=input_ids_mask) # reparametrization trick.
get_logits = model.model.module.get_logits
terms = {}
target = x_start
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
model_out_x_start = self._x0_helper(model_output, x_t, t)['pred_xstart'] # predicted_xstart = model_output
t0_mask = (t == 0)
t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])
# tT_mask = (t == self.num_timesteps - 1)
out_mean, _, _ = self.q_mean_variance(x_start, th.LongTensor([self.num_timesteps - 1]).to(x_start.device))
tT_loss = mean_flat(out_mean ** 2)
decoder_nll = self._token_discrete_loss(x_start, get_logits, input_ids_x) # embedding regularization
terms["nll"] = self._token_discrete_loss(model_out_x_start, get_logits, input_ids_x, mask=input_ids_mask, truncate=True, t=t) # x_0->model_out_x_start
# assert (model.lm_head.weight == model.word_embedding.weight).all()
terms["loss"] = terms["mse"] + decoder_nll + tT_loss
return terms
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
eta=0.0,
langevin_fn=None,
mask=None,
x_start=None
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12.
noise = th.randn_like(x)
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
# print(sigma.mean())
sample = mean_pred + nonzero_mask * sigma * noise
if langevin_fn:
print(t.shape)
sample=langevin_fn(sample, mean_pred, sigma, self.alphas_cumprod_prev[t[0]], t, x)
if mask == None:
pass
else:
sample = th.where(mask==0, x_start, sample)
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ th.sqrt(1 - alpha_bar_next) * eps
)
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
top_p=None,
clamp_step=None,
clamp_first=None,
mask=None,
x_start=None,
gap=1,
):
"""
Generate samples from the model using DDIM.
:param gap: compute ddim sampling for each {gap} step
Same usage as p_sample_loop().
"""
final = []
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
mask=mask,
x_start=x_start,
gap = gap
):
final.append(sample['sample'])
return final
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
langevin_fn=None,
mask=None,
x_start=None,
gap=1
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
sample_x = noise
else:
sample_x = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1][::gap]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
sample_x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
mask=mask,
x_start=x_start
)
yield out
sample_x = out["sample"]
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
# print(kwargs.keys())
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
# print('called p_mean_var')
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
# print('called training_losses')
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
# print(ts)
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
# print(new_ts)
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
# temp = self.model(x, new_ts, **kwargs)
# print(temp.shape)
# return temp
# print(new_ts)
return self.model(x, new_ts, **kwargs)