-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathmeshgpt_pytorch.py
1758 lines (1298 loc) · 55.7 KB
/
meshgpt_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
from pathlib import Path
from functools import partial
from math import ceil, pi, sqrt
import torch
from torch import nn, Tensor, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch.cuda.amp import autocast
from pytorch_custom_utils import save_load
from beartype.typing import Tuple, Callable, List, Dict, Any
from meshgpt_pytorch.typing import Float, Int, Bool, typecheck
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
from einx import get_at
from x_transformers import Decoder
from x_transformers.x_transformers import RMSNorm, FeedForward, LayerIntermediates
from x_transformers.autoregressive_wrapper import (
eval_decorator,
top_k,
top_p,
)
from local_attention import LocalMHA
from vector_quantize_pytorch import (
ResidualVQ,
ResidualLFQ
)
from meshgpt_pytorch.data import derive_face_edges_from_faces
from meshgpt_pytorch.version import __version__
from taylor_series_linear_attention import TaylorSeriesLinearAttn
from classifier_free_guidance_pytorch import (
classifier_free_guidance,
TextEmbeddingReturner
)
from torch_geometric.nn.conv import SAGEConv
from gateloop_transformer import SimpleGateLoopLayer
from tqdm import tqdm
# helper functions
def exists(v):
    """Return True when *v* is anything other than ``None`` (falsy values count as existing)."""
    return not (v is None)
def default(v, d):
    """Return *v*, falling back to *d* only when *v* is ``None``."""
    if v is None:
        return d
    return v
def first(it):
    """Return the element at index 0 of the indexable sequence *it*."""
    head = it[0]
    return head
def identity(t, *args, **kwargs):
    """No-op transform: hand back *t* untouched, ignoring any extra positional/keyword args."""
    result = t
    return result
def divisible_by(num, den):
    """Return whether *num* leaves no remainder when divided by *den*."""
    remainder = num % den
    return remainder == 0
def is_odd(n):
    """Return True when *n* is not a multiple of two."""
    return not (n % 2 == 0)
def is_empty(x):
    """Return True if the sized container *x* has no elements."""
    return not len(x)
def is_tensor_empty(t: Tensor):
    """Return True if tensor *t* holds zero elements (some dimension has size 0)."""
    element_count = t.numel()
    return element_count == 0
def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    """Set ``requires_grad`` on every parameter of *module* (recursively), in place.

    The trailing underscore follows the PyTorch naming convention for in-place operations.
    """
    all_params = list(module.parameters())
    for p in all_params:
        p.requires_grad = requires_grad
def l1norm(t):
    """Scale *t* so the absolute values along its last axis sum to one."""
    return F.normalize(t, p = 1, dim = -1)
def l2norm(t):
    """Scale *t* to unit Euclidean length along its last axis."""
    return F.normalize(t, p = 2, dim = -1)
def safe_cat(tensors, dim):
    """Concatenate the non-``None`` entries of *tensors* along *dim*.

    Returns ``None`` when nothing remains, and the single remaining tensor itself (no copy)
    when exactly one entry is present.
    """
    present = [t for t in tensors if t is not None]
    if not present:
        return None
    if len(present) == 1:
        return present[0]
    return torch.cat(present, dim = dim)
def pad_at_dim(t, padding, dim = -1, value = 0):
    """Pad tensor *t* along the single dimension *dim*.

    *padding* is a ``(before, after)`` pair. ``F.pad`` reads pad widths starting from the last
    dimension, so a ``(0, 0)`` pair is emitted for every dimension to the right of *dim*.
    """
    if dim < 0:
        dims_after = -dim - 1
    else:
        dims_after = t.ndim - dim - 1
    pad_spec = (0, 0) * dims_after + tuple(padding)
    return F.pad(t, pad_spec, value = value)
def pad_to_length(t, length, dim = -1, value = 0, right = True):
    """Pad *t* along *dim* with *value* until it has *length* entries there.

    Tensors that are already at least *length* long are returned unchanged (no truncation).
    ``right=True`` appends the fill, ``right=False`` prepends it.
    """
    shortfall = length - t.shape[dim]
    if shortfall > 0:
        padding = (0, shortfall) if right else (shortfall, 0)
        return pad_at_dim(t, padding, dim = dim, value = value)
    return t
def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
    """Mean of *tensor* over *dim*, counting only positions where *mask* is True.

    A trailing singleton axis is appended to *mask* so that it broadcasts against the last axis
    of *tensor*. Reductions with no valid positions yield 0 instead of NaN. With ``mask=None``
    this is a plain ``tensor.mean(dim = dim)``.
    """
    if mask is None:
        return tensor.mean(dim = dim)

    expanded_mask = mask[..., None]
    numerator = tensor.masked_fill(~expanded_mask, 0.).sum(dim = dim)
    counts = expanded_mask.sum(dim = dim)

    result = numerator / counts.float().clamp(min = eps)
    return result.masked_fill(counts == 0, 0.)
# continuous embed
def ContinuousEmbed(dim_cont):
    """Build an MLP that embeds every scalar of its input into a ``dim_cont``-dim vector.

    A trailing axis of size 1 is added first, so an input of shape ``(...)`` comes out as
    ``(..., dim_cont)``.
    """
    layers = [
        Rearrange('... -> ... 1'),
        nn.Linear(1, dim_cont),
        nn.SiLU(),
        nn.Linear(dim_cont, dim_cont),
        nn.LayerNorm(dim_cont),
    ]
    return nn.Sequential(*layers)
# additional encoder features
# 1. angle (3), 2. area (1), 3. normals (3)
def derive_angle(x, y, eps = 1e-5):
    """Angle in radians between vectors *x* and *y* along their last axis.

    The cosine is clipped to ``[-1 + eps, 1 - eps]`` before ``arccos``, presumably so the
    gradient of ``arccos`` (infinite at +/-1) stays finite.
    """
    unit_x = F.normalize(x, dim = -1, p = 2)
    unit_y = F.normalize(y, dim = -1, p = 2)
    cosine = einsum('... d, ... d -> ...', unit_x, unit_y)
    return cosine.clip(-1 + eps, 1 - eps).arccos()
@torch.no_grad()
@typecheck
def get_derived_face_features(
    face_coords: Float['b nf nvf 3'] # 3 or 4 vertices with 3 coordinates
):
    """Compute per-face geometric features from vertex positions, without tracking gradients.

    Returns a dict with keys:
      angles  - ``derive_angle`` between each vertex position and the position of the previous
                vertex (cyclic roll by one along the vertex axis), shape ``(b, nf, nvf)``
      area    - half the norm of the cross product of two difference vectors, ``(b, nf, 1)``
      normals - that cross product, L2-normalised, ``(b, nf, 3)``

    For triangles the two difference vectors are edges (v0 - v2, v1 - v0). For quads the second
    operand is rolled by two vertices, so they are the diagonals (v0 - v2, v1 - v3).

    NOTE(review): ``angles`` measures the angle between position vectors as seen from the origin,
    not the interior angle of the face; changing it would alter the model's input features.
    """
    quads = face_coords.shape[-2] == 4

    rolled = torch.roll(face_coords, 1, dims = (2,))
    angles = derive_angle(face_coords, rolled)

    if quads:
        # @sbriseid says quads need to be shifted by 2
        rolled = torch.roll(rolled, 1, dims = (2,))

    first_vec, second_vec, *_ = (face_coords - rolled).unbind(dim = 2)
    cross = torch.cross(first_vec, second_vec, dim = -1)

    return dict(
        angles = angles,
        area = cross.norm(dim = -1, keepdim = True) * 0.5,
        normals = l2norm(cross)
    )
# tensor helper functions
@typecheck
def discretize(
    t: Tensor,
    *,
    continuous_range: Tuple[float, float],
    num_discrete: int = 128
) -> Tensor:
    """Map continuous values in ``continuous_range`` to integer bins ``0 .. num_discrete - 1``.

    Values outside the range are clamped into the first or last bin. ``undiscretize`` maps a
    bin index back to the centre of its bin.
    """
    lo, hi = continuous_range
    assert hi > lo

    scaled = (t - lo) / (hi - lo) * num_discrete - 0.5
    return scaled.round().long().clamp(min = 0, max = num_discrete - 1)
@typecheck
def undiscretize(
    t: Tensor,
    *,
    continuous_range: Tuple[float, float],
    num_discrete: int = 128
) -> Tensor:
    """Map integer bin indices back to the centre of their bin in ``continuous_range``.

    Inverse of ``discretize``: bin ``i`` maps to ``lo + (i + 0.5) / num_discrete * (hi - lo)``.
    The input tensor is never modified.

    Args:
        t: bin indices (any integer or float dtype).
        continuous_range: ``(lo, hi)`` with ``hi > lo``; required, as in ``discretize``.
        num_discrete: number of bins.
    """
    lo, hi = continuous_range
    assert hi > lo

    # out-of-place on purpose: `Tensor.float()` returns `t` itself when it is already float32,
    # so the previous in-place `+=` silently overwrote the caller's tensor
    centred = (t.float() + 0.5) / num_discrete
    return centred * (hi - lo) + lo
@typecheck
def gaussian_blur_1d(
    t: Tensor,
    *,
    sigma: float = 1.
) -> Tensor:
    """Blur a ``(batch, length, channels)`` tensor along its length axis with a Gaussian kernel.

    Each channel is filtered independently (grouped conv). The kernel is ``ceil(5 * sigma)``
    taps wide, bumped up to the next odd width, so the output keeps the input length.
    """
    _, _, channels = t.shape

    width = int(ceil(sigma * 5))
    if width % 2 == 0:
        width += 1
    radius = width // 2

    offsets = torch.arange(-radius, radius + 1, dtype = t.dtype, device = t.device)
    weights = l1norm(torch.exp(-(offsets ** 2) / (2 * sigma ** 2)))
    kernel = repeat(weights, 'n -> c 1 n', c = channels)

    channels_first = rearrange(t, 'b n c -> b c n')
    blurred = F.conv1d(channels_first, kernel, padding = radius, groups = channels)
    return rearrange(blurred, 'b c n -> b n c')
@typecheck
def scatter_mean(
    tgt: Tensor,
    indices: Tensor,
    src: Tensor,
    *,
    dim: int = -1,
    eps: float = 1e-5
):
    """Average entries of *src* into the slots of *tgt* selected by *indices* along *dim*.

    Computes ``(tgt.scatter_add(dim, indices, src)) / max(count, eps)``, where ``count`` is the
    number of *src* entries routed to each slot. Values already in *tgt* are added to the
    numerator but not counted, so pass a zero tensor to get a true mean; slots that receive
    nothing then stay 0.

    ``src`` is now a real annotation; previously ``src = Tensor`` made the ``Tensor`` class a
    default value and bypassed type checking.

    todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
    """
    num = tgt.scatter_add(dim, indices, src)
    den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))
    return num / den.clamp(min = eps)
# resnet block
class FiLM(Module):
    """Feature-wise linear modulation of ``x`` by a conditioning vector.

    ``cond`` of shape ``(b, dim)`` is projected to a per-channel scale and shift, which broadcast
    over the middle axis of ``x``. Both pass through ``tanh`` and are gated by learnable scalars
    initialised to zero, so the layer starts out as the identity.
    """

    def __init__(self, dim, dim_out = None):
        super().__init__()
        out_features = default(dim_out, dim)

        self.to_gamma = nn.Linear(dim, out_features, bias = False)
        self.to_beta = nn.Linear(dim, out_features)

        self.gamma_mult = nn.Parameter(torch.zeros(1,))
        self.beta_mult = nn.Parameter(torch.zeros(1,))

    def forward(self, x, cond):
        scale = rearrange(self.to_gamma(cond), 'b d -> b 1 d')
        shift = rearrange(self.to_beta(cond), 'b d -> b 1 d')

        # zero-initialised multipliers make scale == 1 and shift == 0 at the start
        scale = (1 + self.gamma_mult * scale.tanh())
        shift = shift.tanh() * self.beta_mult

        # classic film
        return x * scale + shift
class PixelNorm(Module):
    """Normalise ``x`` to unit L2 norm along ``dim``, then multiply by ``sqrt(x.shape[dim])``.

    The output therefore has a root-mean-square of 1 along ``dim``; ``eps`` guards the division.
    """

    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        axis = self.dim
        scale = sqrt(x.shape[axis])
        unit = F.normalize(x, dim = axis, eps = self.eps)
        return unit * scale
class SqueezeExcite(Module):
    """Squeeze-and-excitation over a ``(b, c, n)`` tensor.

    The length axis is pooled (a masked mean when ``mask`` of shape ``(b, 1, n)`` is given), a
    bottleneck MLP with a sigmoid produces one gate per channel, and ``x`` is rescaled by it.
    When a mask is given, masked positions of ``x`` are zeroed in the output as well.
    """

    def __init__(
        self,
        dim,
        reduction_factor = 4,
        min_dim = 16
    ):
        super().__init__()
        hidden = max(dim // reduction_factor, min_dim)

        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1')
        )

    def forward(self, x, mask = None):
        if not exists(mask):
            pooled = reduce(x, 'b c n -> b c', 'mean')
        else:
            x = x.masked_fill(~mask, 0.)
            summed = reduce(x, 'b c n -> b c', 'sum')
            counts = reduce(mask.float(), 'b 1 n -> b 1', 'sum')
            pooled = summed / counts.clamp(min = 1e-5)

        return x * self.net(pooled)
class Block(Module):
    """Conv1d (kernel 3, same padding) -> PixelNorm over channels -> SiLU -> dropout.

    Operates on ``(batch, channels, length)`` tensors. If ``mask`` is given, masked positions
    are zeroed both before and after the convolution.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        dropout = 0.
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1)
        self.norm = PixelNorm(dim = 1)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.SiLU()

    def forward(self, x, mask = None):
        has_mask = exists(mask)

        if has_mask:
            x = x.masked_fill(~mask, 0.)

        x = self.proj(x)

        if has_mask:
            x = x.masked_fill(~mask, 0.)

        return self.dropout(self.act(self.norm(x)))
class ResnetBlock(Module):
    """Two ``Block``s followed by squeeze-excitation, plus a residual connection.

    The residual path is a 1x1 conv when the channel count changes, otherwise the identity.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        dropout = 0.
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.block1 = Block(dim, dim_out, dropout = dropout)
        self.block2 = Block(dim_out, dim_out, dropout = dropout)
        self.excite = SqueezeExcite(dim_out)
        self.residual_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(
        self,
        x,
        mask = None
    ):
        shortcut = self.residual_conv(x)

        out = x
        for stage in (self.block1, self.block2, self.excite):
            out = stage(out, mask = mask)

        return out + shortcut
# gateloop layers
class GateLoopBlock(Module):
    """A stack of ``SimpleGateLoopLayer``s, each applied residually, with optional caching."""

    def __init__(
        self,
        dim,
        *,
        depth,
        use_heinsen = True
    ):
        super().__init__()
        self.gateloops = ModuleList([
            SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen)
            for _ in range(depth)
        ])

    def forward(
        self,
        x,
        cache = None
    ):
        """Run the layers over ``x``.

        Returns ``(x, None)`` for an empty ``x``, otherwise ``(x, per_layer_caches)``. When a cache
        is passed, only the last position along axis 1 is run through the layers; the earlier
        positions are concatenated back unchanged.
        """
        has_cache = exists(cache)

        if is_tensor_empty(x):
            return x, None

        if has_cache:
            history, x = x[:, :-1], x[:, -1:]

        layer_caches = iter(default(cache, []))
        updated_caches = []

        for gateloop in self.gateloops:
            out, updated = gateloop(x, cache = next(layer_caches, None), return_cache = True)
            updated_caches.append(updated)
            x = x + out

        if has_cache:
            x = torch.cat((history, x), dim = -2)

        return x, updated_caches
# main classes
@save_load(version = __version__)
class MeshAutoencoder(Module):
@typecheck
def __init__(
    self,
    num_discrete_coors = 128,
    coor_continuous_range: Tuple[float, float] = (-1., 1.),
    dim_coor_embed = 64,
    num_discrete_area = 128,
    dim_area_embed = 16,
    num_discrete_normals = 128,
    dim_normal_embed = 64,
    num_discrete_angle = 128,
    dim_angle_embed = 16,
    encoder_dims_through_depth: Tuple[int, ...] = (
        64, 128, 256, 256, 576
    ),
    init_decoder_conv_kernel = 7,
    decoder_dims_through_depth: Tuple[int, ...] = (
        128, 128, 128, 128,
        192, 192, 192, 192,
        256, 256, 256, 256, 256, 256,
        384, 384, 384
    ),
    dim_codebook = 192,
    num_quantizers = 2,           # or 'D' in the paper
    codebook_size = 16384,        # they use 16k, shared codebook between layers
    use_residual_lfq = True,      # whether to use the latest lookup-free quantization
    rq_kwargs: dict = dict(
        quantize_dropout = True,
        quantize_dropout_cutoff_index = 1,
        quantize_dropout_multiple_of = 1,
    ),
    rvq_kwargs: dict = dict(
        kmeans_init = True,
        threshold_ema_dead_code = 2,
    ),
    rlfq_kwargs: dict = dict(
        frac_per_sample_entropy = 1.,
        soft_clamp_input_value = 10.,
        experimental_softplus_entropy_loss = True,
    ),
    rvq_stochastic_sample_codes = True,
    sageconv_kwargs: dict = dict(
        normalize = True,
        project = True
    ),
    commit_loss_weight = 0.1,
    bin_smooth_blur_sigma = 0.4,  # they blur the one hot discretized coordinate positions
    attn_encoder_depth = 0,
    attn_decoder_depth = 0,
    local_attn_kwargs: dict = dict(
        dim_head = 32,
        heads = 8
    ),
    local_attn_window_size = 64,
    linear_attn_kwargs: dict = dict(
        dim_head = 8,
        heads = 16
    ),
    use_linear_attn = True,
    pad_id = -1,
    flash_attn = True,
    attn_dropout = 0.,
    ff_dropout = 0.,
    resnet_dropout = 0,
    checkpoint_quantizer = False,
    quads = False
):
    """Build the mesh autoencoder.

    Pipeline built here:
      - embeddings for discretised face coordinates, angles, area and normals
      - a SAGEConv graph encoder (plus optional local / linear attention blocks)
      - a residual quantizer (``ResidualLFQ`` or ``ResidualVQ``) over per-vertex embeddings
      - optional decoder attention, a conv stem and ResNet blocks, and a head predicting
        discretised face coordinates

    Notable arguments:
        quads: faces have 4 vertices instead of 3.
        use_residual_lfq: pick ``ResidualLFQ`` (with ``rlfq_kwargs``) over ``ResidualVQ``
            (with ``rvq_kwargs`` and ``rvq_stochastic_sample_codes``); ``rq_kwargs`` go to either.
        init_decoder_conv_kernel: must be odd (asserted).
        pad_id: value marking padded faces / edges, and written into padded codes.
        checkpoint_quantizer: memory-checkpoint the quantizer call.
        flash_attn: accepted but not referenced in this constructor.
    """
    super().__init__()

    self.num_vertices_per_face = 3 if not quads else 4
    total_coordinates_per_face = self.num_vertices_per_face * 3

    # main face coordinate embedding

    self.num_discrete_coors = num_discrete_coors
    self.coor_continuous_range = coor_continuous_range

    self.discretize_face_coords = partial(discretize, num_discrete = num_discrete_coors, continuous_range = coor_continuous_range)
    self.coor_embed = nn.Embedding(num_discrete_coors, dim_coor_embed)

    # derived feature embedding

    self.discretize_angle = partial(discretize, num_discrete = num_discrete_angle, continuous_range = (0., pi))
    self.angle_embed = nn.Embedding(num_discrete_angle, dim_angle_embed)

    lo, hi = coor_continuous_range
    self.discretize_area = partial(discretize, num_discrete = num_discrete_area, continuous_range = (0., (hi - lo) ** 2))
    self.area_embed = nn.Embedding(num_discrete_area, dim_area_embed)

    # NOTE(review): unit normals are binned over the coordinate range, not a fixed (-1, 1)
    self.discretize_normals = partial(discretize, num_discrete = num_discrete_normals, continuous_range = coor_continuous_range)
    self.normal_embed = nn.Embedding(num_discrete_normals, dim_normal_embed)

    # attention related

    attn_kwargs = dict(
        causal = False,
        prenorm = True,
        dropout = attn_dropout,
        window_size = local_attn_window_size,
    )

    # initial dimension: per-coordinate embeds for every vertex, one angle embed per vertex,
    # three normal-component embeds and one area embed

    init_dim = dim_coor_embed * (3 * self.num_vertices_per_face) + dim_angle_embed * self.num_vertices_per_face + dim_normal_embed * 3 + dim_area_embed

    # project into model dimension

    self.project_in = nn.Linear(init_dim, dim_codebook)

    # initial sage conv

    init_encoder_dim, *encoder_dims_through_depth = encoder_dims_through_depth
    curr_dim = init_encoder_dim

    self.init_sage_conv = SAGEConv(dim_codebook, init_encoder_dim, **sageconv_kwargs)

    self.init_encoder_act_and_norm = nn.Sequential(
        nn.SiLU(),
        nn.LayerNorm(init_encoder_dim)
    )

    self.encoders = ModuleList([])

    for dim_layer in encoder_dims_through_depth:
        sage_conv = SAGEConv(
            curr_dim,
            dim_layer,
            **sageconv_kwargs
        )

        self.encoders.append(sage_conv)
        curr_dim = dim_layer

    self.encoder_attn_blocks = ModuleList([])

    for _ in range(attn_encoder_depth):
        self.encoder_attn_blocks.append(nn.ModuleList([
            TaylorSeriesLinearAttn(curr_dim, prenorm = True, **linear_attn_kwargs) if use_linear_attn else None,
            LocalMHA(dim = curr_dim, **attn_kwargs, **local_attn_kwargs),
            nn.Sequential(RMSNorm(curr_dim), FeedForward(curr_dim, glu = True, dropout = ff_dropout))
        ]))

    # residual quantization

    self.codebook_size = codebook_size
    self.num_quantizers = num_quantizers

    self.project_dim_codebook = nn.Linear(curr_dim, dim_codebook * self.num_vertices_per_face)

    if use_residual_lfq:
        self.quantizer = ResidualLFQ(
            dim = dim_codebook,
            num_quantizers = num_quantizers,
            codebook_size = codebook_size,
            commitment_loss_weight = 1.,
            **rlfq_kwargs,
            **rq_kwargs
        )
    else:
        self.quantizer = ResidualVQ(
            dim = dim_codebook,
            num_quantizers = num_quantizers,
            codebook_size = codebook_size,
            shared_codebook = True,
            commitment_weight = 1.,
            rotation_trick = True,
            stochastic_sample_codes = rvq_stochastic_sample_codes,
            **rvq_kwargs,
            **rq_kwargs
        )

    self.checkpoint_quantizer = checkpoint_quantizer # whether to memory checkpoint the quantizer

    self.pad_id = pad_id # for variable lengthed faces, padding quantized ids will be set to this value

    # decoder

    # quantized faces are `num_vertices_per_face` concatenated vertex embeddings of width
    # `dim_codebook` (matching `init_decoder_conv` below); a hard-coded `* 3` broke quads
    decoder_input_dim = dim_codebook * self.num_vertices_per_face

    self.decoder_attn_blocks = ModuleList([])

    for _ in range(attn_decoder_depth):
        self.decoder_attn_blocks.append(nn.ModuleList([
            TaylorSeriesLinearAttn(decoder_input_dim, prenorm = True, **linear_attn_kwargs) if use_linear_attn else None,
            LocalMHA(dim = decoder_input_dim, **attn_kwargs, **local_attn_kwargs),
            nn.Sequential(RMSNorm(decoder_input_dim), FeedForward(decoder_input_dim, glu = True, dropout = ff_dropout))
        ]))

    init_decoder_dim, *decoder_dims_through_depth = decoder_dims_through_depth
    curr_dim = init_decoder_dim

    assert is_odd(init_decoder_conv_kernel)

    self.init_decoder_conv = nn.Sequential(
        nn.Conv1d(dim_codebook * self.num_vertices_per_face, init_decoder_dim, kernel_size = init_decoder_conv_kernel, padding = init_decoder_conv_kernel // 2),
        nn.SiLU(),
        Rearrange('b c n -> b n c'),
        nn.LayerNorm(init_decoder_dim),
        Rearrange('b n c -> b c n')
    )

    self.decoders = ModuleList([])

    for dim_layer in decoder_dims_through_depth:
        resnet_block = ResnetBlock(curr_dim, dim_layer, dropout = resnet_dropout)

        self.decoders.append(resnet_block)
        curr_dim = dim_layer

    self.to_coor_logits = nn.Sequential(
        nn.Linear(curr_dim, num_discrete_coors * total_coordinates_per_face),
        Rearrange('... (v c) -> ... v c', v = total_coordinates_per_face)
    )

    # loss related

    self.commit_loss_weight = commit_loss_weight
    self.bin_smooth_blur_sigma = bin_smooth_blur_sigma
@property
def device(self):
    """Device of the first registered parameter (raises StopIteration if there are none)."""
    first_param = next(self.parameters())
    return first_param.device
@classmethod
def _from_pretrained(
    cls,
    *,
    model_id: str,
    revision: str | None,
    cache_dir: str | Path | None,
    force_download: bool,
    proxies: Dict | None,
    resume_download: bool,
    local_files_only: bool,
    token: str | bool | None,
    map_location: str = "cpu",
    strict: bool = False,
    **model_kwargs,
):
    """Load a pretrained model from a local directory or a Hugging Face Hub repo.

    ``model_id`` is first treated as a local directory containing ``mesh-autoencoder.bin``. If
    that file does not exist, it is downloaded with ``hf_hub_download`` using the remaining
    arguments. The model is built with ``cls.init_and_load`` and then moved to ``map_location``.
    ``model_kwargs`` are accepted but not used.
    """
    weights_name = "mesh-autoencoder.bin"
    local_weights = Path(model_id) / weights_name

    if local_weights.exists():
        weights_file = local_weights
    else:
        weights_file = hf_hub_download(
            repo_id = model_id,
            filename = weights_name,
            revision = revision,
            cache_dir = cache_dir,
            force_download = force_download,
            proxies = proxies,
            resume_download = resume_download,
            token = token,
            local_files_only = local_files_only,
        )

    model = cls.init_and_load(weights_file, strict = strict)
    model.to(map_location)
    return model
@typecheck
def encode(
    self,
    *,
    vertices: Float['b nv 3'],
    faces: Int['b nf nvf'],
    face_edges: Int['b e 2'],
    face_mask: Bool['b nf'],
    face_edges_mask: Bool['b e'],
    return_face_coordinates = False
):
    """Encode mesh faces into per-face embeddings using a graph network over face adjacency.

    einops:
    b   - batch
    nf  - number of faces
    nv  - number of vertices in the mesh (not per face)
    nvf - number of vertices per face (3 or 4) - triangles vs quads
    c   - coordinates (3)
    d   - embed dim

    Args:
        vertices: vertex positions.
        faces: vertex indices per face; rows where ``face_mask`` is False are treated as padding.
        face_edges: pairs of face indices (within the same batch item) connecting adjacent faces.
        face_mask: True for real faces.
        face_edges_mask: True for real face edges.
        return_face_coordinates: also return the discretised face coordinates.

    Returns:
        ``face_embed`` of shape ``(b, nf, d)``, or ``(face_embed, discrete_face_coords)`` where
        ``discrete_face_coords`` is ``(b, nf, nvf * 3)``. Masked faces are zero-filled before the
        optional encoder attention blocks.
    """
    _, num_faces, num_vertices_per_face = faces.shape

    assert self.num_vertices_per_face == num_vertices_per_face

    # padded face rows point at vertex 0 so the gather below stays in bounds
    face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0)

    # continuous face coords

    face_coords = get_at('b [nv] c, b nf mv -> b nf mv c', vertices, face_without_pad)

    # compute derived features and embed

    derived_features = get_derived_face_features(face_coords)

    discrete_angle = self.discretize_angle(derived_features['angles'])
    angle_embed = self.angle_embed(discrete_angle)

    discrete_area = self.discretize_area(derived_features['area'])
    area_embed = self.area_embed(discrete_area)

    discrete_normal = self.discretize_normals(derived_features['normals'])
    normal_embed = self.normal_embed(discrete_normal)

    # discretize vertices for face coordinate embedding
    # (in this pattern `nv` stands for vertices per face)

    discrete_face_coords = self.discretize_face_coords(face_coords)
    discrete_face_coords = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)') # 9 or 12 coordinates per face

    face_coor_embed = self.coor_embed(discrete_face_coords)
    face_coor_embed = rearrange(face_coor_embed, 'b nf c d -> b nf (c d)')

    # combine all features and project into model dimension

    face_embed, _ = pack([face_coor_embed, angle_embed, area_embed, normal_embed], 'b nf *')
    face_embed = self.project_in(face_embed)

    # handle variable lengths by using masked_select and masked_scatter

    # first handle edges
    # needs to be offset by number of faces for each batch

    # exclusive prefix sum of valid-face counts: pad a 0 in front and drop the last total
    # NOTE(review): assumes each item's valid faces come before its padding, so that these
    # offsets line up with the packed rows of `face_embed[face_mask]` - confirm with the data pipeline
    face_index_offsets = reduce(face_mask.long(), 'b nf -> b', 'sum')
    face_index_offsets = F.pad(face_index_offsets.cumsum(dim = 0), (1, -1), value = 0)
    face_index_offsets = rearrange(face_index_offsets, 'b -> b 1 1')

    face_edges = face_edges + face_index_offsets
    face_edges = face_edges[face_edges_mask]
    face_edges = rearrange(face_edges, 'be ij -> ij be')

    # next prepare the face_mask for using masked_select and masked_scatter

    orig_face_embed_shape = face_embed.shape[:2]

    face_embed = face_embed[face_mask]

    # initial sage conv followed by activation and norm

    face_embed = self.init_sage_conv(face_embed, face_edges)

    face_embed = self.init_encoder_act_and_norm(face_embed)

    for conv in self.encoders:
        face_embed = conv(face_embed, face_edges)

    # scatter the packed node features back into a zero-filled (b, nf, d) layout

    shape = (*orig_face_embed_shape, face_embed.shape[-1])

    face_embed = face_embed.new_zeros(shape).masked_scatter(rearrange(face_mask, '... -> ... 1'), face_embed)

    for linear_attn, attn, ff in self.encoder_attn_blocks:
        if exists(linear_attn):
            face_embed = linear_attn(face_embed, mask = face_mask) + face_embed

        face_embed = attn(face_embed, mask = face_mask) + face_embed
        face_embed = ff(face_embed) + face_embed

    if not return_face_coordinates:
        return face_embed

    return face_embed, discrete_face_coords
@typecheck
def quantize(
    self,
    *,
    faces: Int['b nf nvf'],
    face_mask: Bool['b n'],
    face_embed: Float['b nf d'],
    pad_id = None,
    rvq_sample_codebook_temp = 1.
):
    """Quantize per-face embeddings into per-vertex codes.

    Each face embedding is projected to one embedding per face vertex. Embeddings that refer to
    the same vertex index are averaged (scatter mean), and the averaged vertices are run through
    the residual quantizer. Quantized vertices and their codes are then gathered back into face order.

    Args:
        faces: vertex indices per face.
        face_mask: True for real faces.
        face_embed: encoder output per face.
        pad_id: value written into the codes of padded faces; defaults to ``self.pad_id``.
        rvq_sample_codebook_temp: sampling temperature, forwarded only to a ``ResidualVQ`` quantizer.

    Returns:
        face_embed_output: ``(b, nf, nvf * d)`` quantized vertex embeddings concatenated per face.
        codes_output: ``(b, nf * nvf, q)`` codes, set to ``pad_id`` for padded faces.
        commit_loss: the quantizer's commitment loss.
    """
    pad_id = default(pad_id, self.pad_id)
    batch, device = faces.shape[0], faces.device

    # NOTE(review): padded rows of `faces` still hold the caller's pad value here; this is only
    # safe while that value is not a large positive index (e.g. the default -1)
    max_vertex_index = faces.amax()
    num_vertices = int(max_vertex_index.item() + 1)

    face_embed = self.project_dim_codebook(face_embed)
    face_embed = rearrange(face_embed, 'b nf (nvf d) -> b nf nvf d', nvf = self.num_vertices_per_face)

    vertex_dim = face_embed.shape[-1]
    vertices = torch.zeros((batch, num_vertices, vertex_dim), device = device)

    # create pad vertex, due to variable lengthed faces

    pad_vertex_id = num_vertices
    vertices = pad_at_dim(vertices, (0, 1), dim = -2, value = 0.)

    faces = faces.masked_fill(~rearrange(face_mask, 'b n -> b n 1'), pad_vertex_id)

    # prepare for scatter mean

    faces_with_dim = repeat(faces, 'b nf nvf -> b (nf nvf) d', d = vertex_dim)

    face_embed = rearrange(face_embed, 'b ... d -> b (...) d')

    # scatter mean

    averaged_vertices = scatter_mean(vertices, faces_with_dim, face_embed, dim = -2)

    # mask out null vertex token

    mask = torch.ones((batch, num_vertices + 1), device = device, dtype = torch.bool)
    mask[:, -1] = False

    # rvq specific kwargs

    quantize_kwargs = dict(mask = mask)

    if isinstance(self.quantizer, ResidualVQ):
        quantize_kwargs.update(sample_codebook_temp = rvq_sample_codebook_temp)

    # a quantize function that makes it memory checkpointable

    def quantize_wrapper_fn(inp):
        unquantized, quantize_kwargs = inp
        return self.quantizer(unquantized, **quantize_kwargs)

    # maybe checkpoint the quantize fn

    if self.checkpoint_quantizer:
        quantize_wrapper_fn = partial(checkpoint, quantize_wrapper_fn, use_reentrant = False)

    # residual VQ

    quantized, codes, commit_loss = quantize_wrapper_fn((averaged_vertices, quantize_kwargs))

    # gather quantized vertexes back to faces for decoding
    # now the faces have quantized vertices

    face_embed_output = get_at('b [n] d, b nf nvf -> b nf (nvf d)', quantized, faces)

    # vertex codes also need to be gathered to be organized by face sequence
    # for autoregressive learning

    codes_output = get_at('b [n] q, b nf nvf -> b (nf nvf) q', codes, faces)

    # make sure codes being outputted have this padding
    # (use the resolved `pad_id` so an explicit argument is honoured instead of always `self.pad_id`)

    face_mask = repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face)
    codes_output = codes_output.masked_fill(~face_mask, pad_id)

    # output quantized, codes, as well as commitment loss

    return face_embed_output, codes_output, commit_loss
@typecheck
def decode(
    self,
    quantized: Float['b n d'],
    face_mask:  Bool['b n']
):
    """Decode quantized per-face features into decoder features.

    Applies the optional decoder attention blocks over faces, zeroes masked faces, then runs
    the conv stem and the ResNet blocks (channels-first). Returns a ``(b, n, d_out)`` tensor,
    where ``d_out`` is the last decoder width.
    """
    channel_mask = rearrange(face_mask, 'b n -> b 1 n')

    h = quantized

    for linear_attn, attn, ff in self.decoder_attn_blocks:
        if exists(linear_attn):
            h = linear_attn(h, mask = face_mask) + h

        h = attn(h, mask = face_mask) + h
        h = ff(h) + h

    h = rearrange(h, 'b n d -> b d n').masked_fill(~channel_mask, 0.)
    h = self.init_decoder_conv(h)

    for resnet_block in self.decoders:
        h = resnet_block(h, mask = channel_mask)

    return rearrange(h, 'b d n -> b n d')
@typecheck
@torch.no_grad()
def decode_from_codes_to_faces(
    self,
    codes: Tensor,
    face_mask: Bool['b n'] | None = None,
    return_discrete_codes = False
):
    """Turn quantizer codes back into continuous face coordinates (no gradients).

    ``codes`` may have any shape with the batch axis first; it is flattened per batch item into
    ``(faces * vertices_per_face * num_quantizers)`` codes. Without an explicit ``face_mask``, a
    face counts as valid only if none of its codes equal ``self.pad_id``.

    Returns ``(coords, face_mask)``, or ``(coords, discrete_coords, face_mask)`` when
    ``return_discrete_codes`` is set. Coordinates of invalid faces are NaN.
    """
    flat_codes = rearrange(codes, 'b ... -> b (...)')

    if not exists(face_mask):
        face_mask = reduce(flat_codes != self.pad_id, 'b (nf nvf q) -> b nf', 'all', nvf = self.num_vertices_per_face, q = self.num_quantizers)

    # one row of `num_quantizers` codes per vertex, then look up and regroup per face

    vertex_codes = rearrange(flat_codes, 'b (n q) -> b n q', q = self.num_quantizers)
    vertex_embeds = self.quantizer.get_output_from_indices(vertex_codes)
    face_embeds = rearrange(vertex_embeds, 'b (nf nvf) d -> b nf (nvf d)', nvf = self.num_vertices_per_face)

    decoded = self.decode(face_embeds, face_mask = face_mask)
    decoded = decoded.masked_fill(~face_mask[..., None], 0.)

    logits = self.to_coor_logits(decoded)
    discrete_coords = rearrange(logits.argmax(dim = -1), '... (v c) -> ... v c', v = self.num_vertices_per_face)

    # back to continuous space, with invalid faces filled by nan

    coords = undiscretize(
        discrete_coords,
        num_discrete = self.num_discrete_coors,
        continuous_range = self.coor_continuous_range
    )
    invalid_faces = ~rearrange(face_mask, 'b nf -> b nf 1 1')
    coords = coords.masked_fill(invalid_faces, float('nan'))

    if return_discrete_codes:
        return coords, discrete_coords, face_mask

    return coords, face_mask
@torch.no_grad()
def tokenize(self, vertices, faces, face_edges = None, **kwargs):
    """Return the quantizer codes for a mesh or a batch of meshes.

    Inputs are either all batched (3-D) or all un-batched (2-D). In the un-batched case a batch
    axis is added before the forward pass and removed from the result. ``face_edges`` is optional;
    extra ``kwargs`` go to ``forward`` (``return_codes`` is set here and may not be passed).

    The forward pass runs in eval mode. The train/eval flag of every submodule is restored
    afterwards, so calling this from a training loop no longer leaves the model in eval mode.
    """
    assert 'return_codes' not in kwargs

    inputs = [vertices, faces, face_edges]
    inputs = [*filter(exists, inputs)]
    ndims = {i.ndim for i in inputs}

    assert len(ndims) == 1
    batch_less = first(list(ndims)) == 2

    if batch_less:
        inputs = [rearrange(i, '... -> 1 ...') for i in inputs]

    # face_edges is last in the list, so zip still pairs names correctly when it was filtered out
    input_kwargs = dict(zip(['vertices', 'faces', 'face_edges'], inputs))

    # remember each submodule's mode so a mixed train/eval setup is restored exactly
    previous_modes = [(module, module.training) for module in self.modules()]
    self.eval()

    try:
        codes = self.forward(
            **input_kwargs,
            return_codes = True,
            **kwargs
        )
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    if batch_less:
        codes = rearrange(codes, '1 ... -> ...')

    return codes
@typecheck
def forward(
self,
*,
vertices: Float['b nv 3'],
faces: Int['b nf nvf'],
face_edges: Int['b e 2'] | None = None,
return_codes = False,
return_loss_breakdown = False,
return_recon_faces = False,
only_return_recon_faces = False,
rvq_sample_codebook_temp = 1.
):
if not exists(face_edges):
face_edges = derive_face_edges_from_faces(faces, pad_id = self.pad_id)
device = faces.device
face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all')
face_edges_mask = reduce(face_edges != self.pad_id, 'b e ij -> b e', 'all')
encoded, face_coordinates = self.encode(
vertices = vertices,