-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvery_attentive_tacotron.py
1771 lines (1541 loc) · 59.1 KB
/
very_attentive_tacotron.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reference implementations for Very Attentive Tacotron and T5-TTS.
From the paper:
"Very Attentive Tacotron: Robust and Unbounded Length Generalization in
Transformer-based Text-to-Speech."
https://arxiv.org/abs/2410.22179
Includes text encoder and autoregressive decoder configurations for the
large and small (3/8 width) variants of the T5-TTS baseline model and the
proposed Very Attentive Tacotron (VAT) model.
"""
import abc
import dataclasses
import functools
import math
import numpy as np
from sequence_layers import tensorflow as sl
from sequence_layers.examples import t5
from sequence_layers.tensorflow import utils
import tensorflow.compat.v2 as tf
# Additive bias placed on attention logits of excluded (masked-out) positions.
# Its magnitude is large enough that softmax gives those positions ~zero weight.
ATTENTION_MASK_BIAS = -1.0e9
# Tiny constant added inside logarithms so they never evaluate log(0).
LOG_EPS = 1.0e-16
# Key that identifies the alignment position within sl.Constants.
ALIGNMENT_POSITION = 'alignment_position'
# ------------------------------------------------------------------------------
# Text encoder configurations.
# ------------------------------------------------------------------------------
def SmallT5TTSTextEncoder() -> sl.SequenceLayer:
  """Small (192-dim) T5-TTS text encoder with standard T5 position biases."""
  encoder_kwargs = dict(dimension=192, use_irpbs=False)
  return TextEncoder(name='small_t5_tts_text_encoder', **encoder_kwargs)
def SmallVATTextEncoder() -> sl.SequenceLayer:
  """Small (192-dim) VAT text encoder with interpolated position biases."""
  encoder_kwargs = dict(dimension=192, use_irpbs=True)
  return TextEncoder(name='small_vat_text_encoder', **encoder_kwargs)
def LargeT5TTSTextEncoder() -> sl.SequenceLayer:
  """Large (512-dim) T5-TTS text encoder with standard T5 position biases."""
  encoder_kwargs = dict(dimension=512, use_irpbs=False)
  return TextEncoder(name='large_t5_tts_text_encoder', **encoder_kwargs)
def LargeVATTextEncoder() -> sl.SequenceLayer:
  """Large (512-dim) VAT text encoder with interpolated position biases."""
  encoder_kwargs = dict(dimension=512, use_irpbs=True)
  return TextEncoder(name='large_vat_text_encoder', **encoder_kwargs)
# ------------------------------------------------------------------------------
# Decoder configurations.
# ------------------------------------------------------------------------------
def SmallT5TTSDecoder() -> sl.SequenceLayer:
  """Small T5-TTS decoder: 6 blocks of 8 heads x 48 units, 4x FFN expansion."""
  num_heads, units_per_head = 8, 48
  model_width = num_heads * units_per_head
  config = T5TTSDecoderConfig(
      name='t5_tts_decoder',
      source_name='text_encoder_top',
      num_codebooks=8,
      codebook_size=256,
      num_decoder_blocks=6,
      num_heads=num_heads,
      units_per_head=units_per_head,
      ffn_dimension=4 * model_width,
      dropout_rate=0.1,
      max_past_horizon=128,
  )
  return T5TTSDecoder(config=config)
def SmallVATDecoder() -> sl.SequenceLayer:
  """Small VAT decoder: 6 blocks of 8 heads x 48 units, 96-unit alignment RNN."""
  num_heads, units_per_head = 8, 48
  model_width = num_heads * units_per_head
  alignment_config = AlignmentLayerConfig(alignment_rnn_units=96)
  block_config = DecoderBlockConfig(
      alignment_rnn_units=96,
      num_heads=num_heads,
      units_per_head=units_per_head,
      hidden_dim=model_width,
      feedforward_hidden_dim=4 * model_width,
  )
  config = VATDecoderConfig(
      name='very_attentive_decoder',
      source_name='text_encoder_top',
      alignment_layer=alignment_config,
      num_codebooks=8,
      codebook_size=256,
      num_decoder_blocks=6,
      dropout_rate=0.1,
      decoder_block=block_config,
  )
  return VATDecoder(config=config)
def LargeT5TTSDecoder() -> sl.SequenceLayer:
  """Large T5-TTS decoder: 6 blocks of 16 heads x 64 units, 4x FFN expansion."""
  num_heads, units_per_head = 16, 64
  model_width = num_heads * units_per_head
  config = T5TTSDecoderConfig(
      name='t5_tts_decoder',
      source_name='text_encoder_top',
      num_codebooks=8,
      codebook_size=256,
      num_decoder_blocks=6,
      num_heads=num_heads,
      units_per_head=units_per_head,
      ffn_dimension=4 * model_width,
      dropout_rate=0.1,
      max_past_horizon=128,
  )
  return T5TTSDecoder(config=config)
def LargeVATDecoder() -> sl.SequenceLayer:
  """Large VAT decoder: 6 blocks of 16 heads x 64 units, 256-unit alignment RNN."""
  num_heads, units_per_head = 16, 64
  model_width = num_heads * units_per_head
  alignment_config = AlignmentLayerConfig(alignment_rnn_units=256)
  # NOTE(review): the alignment layer uses 256 RNN units but the decoder block
  # config says 96 (the small config uses 96 for both). Confirm this asymmetry
  # is intended; the values are intentionally left unchanged here.
  block_config = DecoderBlockConfig(
      alignment_rnn_units=96,
      num_heads=num_heads,
      units_per_head=units_per_head,
      hidden_dim=model_width,
      feedforward_hidden_dim=4 * model_width,
  )
  config = VATDecoderConfig(
      name='very_attentive_decoder',
      source_name='text_encoder_top',
      alignment_layer=alignment_config,
      num_codebooks=8,
      codebook_size=256,
      num_decoder_blocks=6,
      dropout_rate=0.1,
      decoder_block=block_config,
  )
  return VATDecoder(config=config)
class PreprocessConstants(abc.ABC):
  """Mix-in for sl.SequenceLayers that cache values derived from constants.

  Implementers compute constant-derived tensors once and keep them on the
  layer, so those tensors need not be recomputed on every call to `step`.
  """

  @abc.abstractmethod
  def preprocess_constants(self, constants: sl.Constants) -> None:  # pylint: disable=invalid-name
    """Computes values derived from `constants` and stores them on the layer.

    Args:
      constants: The constants to preprocess.
    """

  @abc.abstractmethod
  def clear_preprocessed_constants(self) -> None:  # pylint: disable=invalid-name
    """Discards every value stored by `preprocess_constants`.

    Call this once the stored tensors are no longer needed in the current graph
    (or tf.function). This keeps them from leaking into a different graph,
    which TensorFlow reports as an error.
    """
def FeedforwardBlock(
    hidden_dim: int,
    output_dim: int,
    activation: ...,
    dropout_rate: float,
) -> sl.SequenceLayer:
  """T5-style pre-norm feed-forward block wrapped in sl.Residual.

  Layer order: RMSNorm -> Dense(hidden_dim, activation) -> Dropout ->
  Dense(output_dim) -> Dropout. Both dense layers are bias-free.

  Args:
    hidden_dim: Width of the hidden (expansion) projection.
    output_dim: Width of the output projection.
    activation: Activation applied to the hidden projection.
    dropout_rate: Rate for both dropout layers. Their noise_shape
      [None, 1, None] shares each mask along axis 1 (presumably time).

  Returns:
    The residual feed-forward sl.SequenceLayer.
  """

  def shared_mask_dropout():
    return sl.Dropout(dropout_rate, noise_shape=[None, 1, None])

  sublayers = [
      sl.RMSNormalization(epsilon=1e-6, name='rms_normalization'),
      sl.Dense(hidden_dim, use_bias=False, activation=activation),
      shared_mask_dropout(),
      sl.Dense(output_dim, use_bias=False, name='dense'),
      shared_mask_dropout(),
  ]
  return sl.Residual(sublayers)
# ------------------------------------------------------------------------------
# Relative position embedding helper classes.
# ------------------------------------------------------------------------------
@dataclasses.dataclass
class InterpolatedRelativePositionBiasesConfig:
"""Configuration for InterpolatedRelativePositionBiases."""
# Number of buckets to use for the relative position bias matrix.
num_buckets: int
# Maximum relative distance to support. All distances above this value are
# mapped to the same bucket.
max_distance: int
# Scale of value to subtract from the output position biases for relative
# positions that exceed max_distance. The subtracted value is equal to:
# `max_distance_penalty * (abs(relative_position) - max_distance)`.
max_distance_penalty: float
# pyformat: disable
# Bias matrix initialization scheme. One of:
# * 'constant': Initialize all biases to `init_scheme_value`.
# * 'gaussian_window_stddev': Initialize biases with values drawn from a
# Gaussian distribution centered at relative position 0, with standard
# deviation `init_scheme_value`.
# * 'truncated_normal_stddev': Initialize biases with values drawn from a
# truncated normal distribution centered at relative position 0, with
# standard deviation `init_scheme_value`.
# pyformat: enable
init_scheme: str
# Value to use for the bias matrix initialization scheme.
init_scheme_value: float
@dataclasses.dataclass
class T5RelativePositionEmbeddingConfig:
  """Settings forwarded to sl.T5RelativePositionEmbedding.

  Attributes:
    num_buckets: How many buckets the relative position bias matrix has.
    max_distance: Largest relative distance that gets its own treatment;
      every distance beyond it lands in one shared bucket.
  """

  num_buckets: int
  max_distance: int
class GaussianWindowBiasInitializer(tf.keras.initializers.Initializer):
  """Gaussian window initializer for InterpolatedRelativePositionBiases.

  Returns a [num_buckets, num_heads] matrix. Every column holds the same
  normalized logits (log-probabilities) of a discrete Gaussian window over the
  bucket indices.

  The window is centered on the bucket that relative position 0 maps to:
  bucket `num_buckets // 2` when bidirectional, otherwise bucket 0. The
  standard deviation is measured in buckets, not in positions. Beyond the
  "exact" buckets, bucket spacing grows logarithmically with distance.
  """

  def __init__(self, stddev: float, bidirectional: bool):
    """Construct GaussianWindowBiasInitializer.

    Args:
      stddev: The standard deviation of the Gaussian window, in buckets.
      bidirectional: Whether the position biases are bidirectional. This
        determines where the window is centered.
    """
    self.bidirectional = bidirectional
    self.stddev = stddev

  def __call__(self, shape, dtype=tf.float32, **kwargs):
    """Returns the initial bias matrix.

    Args:
      shape: Two-element [num_buckets, num_heads] shape of the bias matrix.
      dtype: Floating point dtype of the result. None falls back to float32,
        as Keras may pass None to mean "default dtype".
      **kwargs: Unused; accepted for Keras initializer API compatibility.

    Returns:
      A [num_buckets, num_heads] Tensor of normalized window logits.
    """
    del kwargs  # Unused.
    num_buckets, num_heads = shape
    dtype = tf.as_dtype(tf.float32 if dtype is None else dtype)
    # Work around missing bfloat16 support in tf.range.
    if dtype == tf.bfloat16:
      bucket_inds = tf.cast(
          tf.range(num_buckets, dtype=tf.float32), tf.bfloat16
      )
    else:
      bucket_inds = tf.range(num_buckets, dtype=dtype)
    # Bucket holding relative position 0 (see _relative_position_bucket).
    offset = num_buckets // 2 if self.bidirectional else 0
    logits = -0.5 * tf.square((bucket_inds - offset) / self.stddev)
    logits -= tf.math.reduce_logsumexp(logits)  # Normalize logits.
    return tf.tile(logits[:, tf.newaxis], [1, num_heads])

  def get_config(self):
    """Returns constructor arguments so Keras can serialize this initializer."""
    return {'stddev': self.stddev, 'bidirectional': self.bidirectional}
class InterpolatedRelativePositionBiases(sl.RelativePositionEmbedding):
  """InterpolatedRelativePositionBiases.

  This is an extension of T5's relative position biases that supports
  non-integer relative positions. It does so by interpolating between the bias
  values of adjacent integer bins. Interpolation is also used to compute the
  bias values for the "non-exact", logarithmically spaced bins, instead of
  flooring to the nearest bin.

  The motivation is to support location-relative cross-attention in which the
  "query" position is differentiable, and therefore must be continuous.

  This also supports the SequenceLayer RelativePositionEmbedding interface.
  """

  def __init__(
      self,
      num_heads: int,
      num_buckets: int,
      max_distance: int,
      max_distance_penalty: float,
      bidirectional: bool,
      initializer: str | tf.keras.initializers.Initializer,
      name: str | None = None,
  ):
    """Construct InterpolatedRelativePositionBiases.

    Args:
      num_heads: Number of biases to produce for each relative position.
      num_buckets: Number of relative position buckets to use.
      max_distance: Maximum relative distance to support. All distances above
        this value are mapped to the same bucket.
      max_distance_penalty: Scale of value to subtract from the output position
        biases for relative positions that exceed max_distance. The subtracted
        value is equal to: `max_distance_penalty * (abs(relative_position) -
        max_distance)`.
      bidirectional: Whether biases should be produced for both positive and
        negative relative positions. If False, positive relative positions
        (where key_position > query_position) are not supported.
      initializer: Initializer (or Keras initializer identifier) for the
        biases (embedding) matrix.
      name: Module name.

    If bidirectional is True, the buckets are split evenly between positive and
    negative relative positions. If False, all buckets are used for negative
    relative positions (and zero). In either case, half of the buckets are used
    for "exact" relative positions and the other half for logarithmically
    spaced buckets of increasing size.

    Positive relative positions greater than or equal to max_distance are
    mapped to the same bias bucket. Negative relative positions less than or
    equal to -max_distance are mapped to the same bias bucket.

    Raises:
      ValueError: If num_buckets leaves no "exact" bucket per side, or if
        max_distance does not exceed the number of exact buckets per side.
        Either case would make the logarithmic bucket mapping in
        `_relative_position_bucket` divide by zero or become non-monotonic.
    """
    # Mirror the bucket layout computed in _relative_position_bucket.
    buckets_per_side = num_buckets // 2 if bidirectional else num_buckets
    max_exact = buckets_per_side // 2
    if max_exact < 1:
      raise ValueError(
          f'num_buckets={num_buckets} is too small for '
          f'bidirectional={bidirectional}: at least one exact bucket per side '
          'is required.'
      )
    if max_distance <= max_exact:
      raise ValueError(
          f'max_distance={max_distance} must exceed the number of exact '
          f'buckets per side ({max_exact}).'
      )
    super().__init__(name=name)
    self.num_heads = num_heads
    self.num_buckets = num_buckets
    self.bidirectional = bidirectional
    self.max_distance = max_distance
    self.max_distance_penalty = max_distance_penalty
    with self.name_scope:
      self.bias_matrix = tf.keras.layers.Embedding(
          input_dim=self.num_buckets,
          output_dim=self.num_heads,
          embeddings_initializer=initializer,
          use_one_hot_matmul=True,  # Avoid gathers on TPU.
          name='bias_matrix',
      )

  @classmethod
  def from_config(
      cls,
      num_heads: int,
      bidirectional: bool,
      config: InterpolatedRelativePositionBiasesConfig,
  ) -> 'InterpolatedRelativePositionBiases':
    """Builds an instance from an InterpolatedRelativePositionBiasesConfig.

    Args:
      num_heads: Number of biases to produce for each relative position.
      bidirectional: Whether to support positive relative positions too.
      config: Bucket, distance, penalty and initialization settings.

    Returns:
      A new InterpolatedRelativePositionBiases.

    Raises:
      NotImplementedError: If `config.init_scheme` is not recognized.
    """
    if config.init_scheme is None or config.init_scheme == 'constant':
      initializer = tf.keras.initializers.Constant(config.init_scheme_value)
    elif config.init_scheme == 'gaussian_window_stddev':
      initializer = GaussianWindowBiasInitializer(
          config.init_scheme_value, bidirectional=bidirectional
      )
    elif config.init_scheme == 'truncated_normal_stddev':
      initializer = tf.keras.initializers.TruncatedNormal(
          stddev=config.init_scheme_value
      )
    else:
      raise NotImplementedError(f'Unknown {config.init_scheme=}')
    return cls(
        num_heads=num_heads,
        num_buckets=config.num_buckets,
        max_distance=config.max_distance,
        max_distance_penalty=config.max_distance_penalty,
        bidirectional=bidirectional,
        initializer=initializer,
    )

  def _relative_position_bucket(self, relative_position):
    """Translate relative position to a continuous bucket value.

    The relative position is defined as memory_position - query_position, i.e.
    the distance in tokens from the attending position to the attended-to
    position.

    If bidirectional=False, then positive relative positions are invalid.

    Smaller buckets are used for small absolute relative positions and larger
    buckets for larger absolute relative positions.

    All relative positions >=max_distance map to the same bucket.
    All relative positions <=-max_distance map to the same bucket.

    Args:
      relative_position: a float32 Tensor

    Returns:
      a Tensor with the same shape as relative_position, containing float32
      values in the range [0, num_buckets)
    """
    is_positive = None
    n = -relative_position
    if self.bidirectional:
      buckets_per_side = self.num_buckets // 2
      is_positive = tf.math.less(n, 0)  # n = -relative_position
      n = tf.math.abs(n)
    else:
      buckets_per_side = self.num_buckets
      n = tf.math.maximum(n, 0)
    # now n is in the range [0, inf)
    max_exact = buckets_per_side // 2
    is_small = tf.math.less(n, max_exact)
    # Note that `(buckets_per_side - 1 - max_exact)` below differs from the
    # reference implementation of T5 relative position biases which uses
    # `(buckets_per_side - max_exact)`, and therefore doesn't produce the
    # max_distance behavior described in the docstring above.
    val_if_large = max_exact + (
        tf.math.log(n / max_exact + LOG_EPS)
        / math.log(self.max_distance / max_exact)
        * (buckets_per_side - 1 - max_exact)
    )
    val_if_large = tf.math.minimum(val_if_large, buckets_per_side - 1)
    val = tf.where(is_small, n, val_if_large)
    if self.bidirectional:
      # Relative position `0` is centered in the bias matrix.
      val = tf.where(
          is_positive, buckets_per_side + val, buckets_per_side - val
      )
    return val

  @tf.Module.with_name_scope
  def __call__(
      self, query_positions: tf.Tensor, key_positions: tf.Tensor
  ) -> tf.Tensor:
    """Get biases from position indices.

    Args:
      query_positions: [batch, query_length] int/float Tensor containing query
        position indices.
      key_positions: [batch, key_length] int/float Tensor containing position
        indices.

    Returns:
      position_biases: [batch, num_heads, query_length, key_length] Tensor
        containing head-wise biases for each query-key pair.
    """
    compute_dtype = utils.compute_dtype()
    # [batch, query_length, 1]
    query_positions = tf.cast(query_positions[:, :, tf.newaxis], compute_dtype)
    # [batch, 1, key_length]
    key_positions = tf.cast(key_positions[:, tf.newaxis, :], compute_dtype)
    # [batch, query_length, key_length]
    relative_positions = key_positions - query_positions
    bucket_val = self._relative_position_bucket(relative_positions)
    # Interpolate linearly between the two adjacent integer buckets.
    bucket_ind_low = tf.math.floor(bucket_val)
    bucket_ind_high = tf.math.ceil(bucket_val)
    # [batch, query_length, key_length, 1]
    high_weight = tf.expand_dims(bucket_val - bucket_ind_low, axis=-1)
    # [batch, query_length, key_length, num_heads]
    biases_low = self.bias_matrix(tf.cast(bucket_ind_low, tf.int32))
    biases_high = self.bias_matrix(tf.cast(bucket_ind_high, tf.int32))
    biases = (1.0 - high_weight) * biases_low + high_weight * biases_high
    # Apply optional max_distance_penalty.
    if self.max_distance_penalty != 0.0:
      # [batch, query_length, key_length]
      excess_distance = tf.maximum(
          0.0, tf.abs(relative_positions) - self.max_distance
      )
      penalty_amount = excess_distance * self.max_distance_penalty
      # [batch, query_length, key_length, num_heads]
      biases -= penalty_amount[:, :, :, tf.newaxis]
    # -> [batch, num_heads, query_length, key_length]
    position_biases = tf.transpose(biases, [0, 3, 1, 2])
    return position_biases

  @tf.Module.with_name_scope
  def get_position_bias_raw(
      self,
      queries_position: tf.Tensor,
      queries_length: tf.Tensor | int,
      keys_position: tf.Tensor,
      keys_length: tf.Tensor | int,
  ) -> tf.Tensor:
    """Computes relative self-attention position biases for absolute positions.

    This method computes relative position biases for absolute query / key
    positions and lengths.

    Args:
      queries_position: Scalar integer query absolute position.
      queries_length: Number of query timesteps to produce relative position
        embeddings for.
      keys_position: Scalar integer key absolute position.
      keys_length: Number of key timesteps to produce relative position
        embeddings for.

    Returns:
      A tensor of relative position biases broadcastable to
      [batch, queries_length, num_heads, keys_length].
    """
    queries_positions = (
        queries_position + tf.range(queries_length)[tf.newaxis, :]
    )
    keys_positions = keys_position + tf.range(keys_length)[tf.newaxis, :]
    position_biases = self(queries_positions, keys_positions)
    position_biases = tf.transpose(position_biases, [0, 2, 1, 3])
    return position_biases

  @tf.Module.with_name_scope
  def get_position_bias_streaming(
      self, queries: tf.Tensor, keys: tf.Tensor, queries_position: tf.Tensor
  ) -> tf.Tensor:
    """Computes relative self-attention position biases for streaming queries.

    This method computes relative position biases for a block of queries_time
    timesteps of the overall queries tensor. The available keys/values are
    always the queries within the current queries_time block plus a fixed
    "max_previous" trailing window of timesteps.

    Args:
      queries: [batch, queries_time, num_heads, units_per_head] queries.
        queries_time is the number of streaming steps we are taking at once.
      keys: [batch, queries_time + max_previous, num_heads, units_per_head]
        keys.
      queries_position: scalar integer indicating the current decode position of
        queries[:, 0, :]. For example, this value is n * queries_time for the
        n'th queries_time block we process.

    Returns:
      A tensor of relative position biases broadcastable to
      [batch, num_heads, queries_time, keys_time].
    """
    queries_time_static = queries.shape.dims[1].value
    keys_time_static = keys.shape.dims[1].value
    queries_time = sl.utils.smart_dimension_size(queries, 1)
    keys_time = sl.utils.smart_dimension_size(keys, 1)
    # Add singleton batch dimension to context/memory_position.
    context_position = queries_position + tf.range(queries_time)[tf.newaxis, :]
    # keys[:, 0, :]'s absolute position is queries_position - max_previous.
    max_previous = keys_time - queries_time
    memory_start_position = queries_position - max_previous
    memory_position = memory_start_position + tf.range(keys_time)[tf.newaxis, :]
    values = self.__call__(context_position, memory_position)
    values.shape.assert_is_compatible_with(
        [1, self.num_heads, queries_time_static, keys_time_static]
    )
    return values

  @tf.Module.with_name_scope
  def get_position_bias(self, queries: tf.Tensor) -> tf.Tensor:
    """Computes relative self-attention position biases for queries.

    Args:
      queries: [batch, queries_time, num_heads, units_per_head] queries.

    Returns:
      A tensor of relative position biases broadcastable to
      [batch, num_heads, queries_time, keys_time].
    """
    queries_time_static = queries.shape.dims[1].value
    queries_time = sl.utils.smart_dimension_size(queries, 1)
    # Add singleton batch dimension.
    context_position = tf.range(queries_time)[tf.newaxis, :]
    values = self.__call__(context_position, context_position)
    values.shape.assert_is_compatible_with(
        [1, self.num_heads, queries_time_static, queries_time_static]
    )
    return values
# ------------------------------------------------------------------------------
# Text encoder helper classes.
# ------------------------------------------------------------------------------
def ConvStage(
    num_blocks: int,
    strides: int,
    num_filters: int = 256,
    kernel_size: int = 3,
    output_dim: int = 256,
    dropout_rate: float = 0.1,
    name: str = 'conv_stage',
) -> sl.SequenceLayer:
  """Resampling Conv1D, N residual Conv1D blocks, then RMSNorm + dropout.

  Args:
    num_blocks: Number of residual Conv1D blocks after the resampling conv.
    strides: Stride of the initial 'resample_conv1d' convolution.
    num_filters: Filter count of every Conv1D in the stage.
    kernel_size: Kernel size of the residual blocks' Conv1D. The resampling
      conv always uses a kernel size of 3.
    output_dim: Width of each residual block's output Dense projection.
      NOTE(review): the residual add presumably needs output_dim to equal
      num_filters; every caller in this file passes equal values.
    dropout_rate: Rate for every dropout layer in the stage.
    name: Name of the resulting sl.Serial.

  Returns:
    The stage as an sl.Serial.
  """

  def time_shared_dropout():
    return sl.Dropout(dropout_rate, noise_shape=[-1, 1, -1])

  def residual_conv_block(index):
    conv_branch = [
        sl.RMSNormalization(epsilon=1e-6, name='rms_norm'),
        sl.Conv1D(
            filters=num_filters,
            kernel_size=kernel_size,
            dilation_rate=1,
            padding='same',
            activation=tf.nn.gelu,
            name='conv_layer',
        ),
        time_shared_dropout(),
        sl.Dense(units=output_dim, use_bias=False, name='output_layer'),
        time_shared_dropout(),
    ]
    return sl.Residual(conv_branch, name=f'residual_block_{index:02d}')

  stage_layers = [
      sl.Conv1D(
          filters=num_filters,
          kernel_size=3,
          strides=strides,
          padding='same',
          name='resample_conv1d',
      )
  ]
  stage_layers.extend(residual_conv_block(i) for i in range(num_blocks))
  stage_layers.append(sl.RMSNormalization(epsilon=1e-6, name='rms_norm'))
  stage_layers.append(time_shared_dropout())
  return sl.Serial(stage_layers, name=name)
def SelfAttentionBlock(
    output_dim: int,
    num_heads: int,
    units_per_head: int,
    max_horizon: int,
    max_future_horizon: int,
    relative_position_embedding: (
        InterpolatedRelativePositionBiases | sl.T5RelativePositionEmbedding
    ),
    dropout_rate: float,
) -> sl.SequenceLayer:
  """Pre-norm residual self-attention with relative position biases.

  Layer order: RMSNorm -> DotProductSelfAttention -> Flatten ->
  Dense(output_dim) -> Dropout, all wrapped in sl.Residual.

  Args:
    output_dim: Width of the output projection.
    num_heads: Number of attention heads.
    units_per_head: Width of each attention head.
    max_horizon: Past attention horizon passed to DotProductSelfAttention.
    max_future_horizon: Future attention horizon passed to
      DotProductSelfAttention.
    relative_position_embedding: Interpolated or standard T5 position biases.
    dropout_rate: Rate for the attention-probability and output dropouts.

  Returns:
    The residual self-attention sl.SequenceLayer.
  """
  sublayers = [sl.RMSNormalization(epsilon=1e-6, name='rms_normalization')]
  sublayers.append(
      sl.DotProductSelfAttention(
          num_heads=num_heads,
          units_per_head=units_per_head,
          max_horizon=max_horizon,
          max_future_horizon=max_future_horizon,
          use_relative_position_embedding=True,
          relative_position_embedding=relative_position_embedding,
          attention_probabilities_dropout_rate=dropout_rate,
          broadcast_dropout_across_queries=True,
          use_bias=False,
          name='dot_product_self_attention',
      )
  )
  # Presumably merges the per-head axes before the output projection.
  sublayers.append(sl.Flatten())
  sublayers.append(sl.Dense(output_dim, use_bias=False, name='dense'))
  sublayers.append(sl.Dropout(dropout_rate, noise_shape=[None, 1, None]))
  return sl.Residual(sublayers)
def TransformerEncoderBlock(
    dimension: int,
    num_heads: int,
    units_per_head: int,
    position_embeddings_config: (
        InterpolatedRelativePositionBiasesConfig
        | T5RelativePositionEmbeddingConfig
    ),
    ffn_activation: ... = tf.nn.gelu,
    dropout_rate: float = 0.1,
) -> sl.SequenceLayer:
  """Bidirectional self-attention followed by a 4x-wide feed-forward block.

  Args:
    dimension: Model width (also the output width of both sub-blocks).
    num_heads: Number of attention heads.
    units_per_head: Width of each attention head.
    position_embeddings_config: Selects interpolated relative position biases
      or standard T5 relative position embeddings.
    ffn_activation: Activation of the feed-forward hidden layer.
    dropout_rate: Dropout rate for both sub-blocks.

  Returns:
    An sl.Serial of the attention block and the feed-forward block.
  """
  uses_interpolated_biases = isinstance(
      position_embeddings_config, InterpolatedRelativePositionBiasesConfig
  )
  if not uses_interpolated_biases:
    relative_biases = sl.T5RelativePositionEmbedding(
        num_buckets=position_embeddings_config.num_buckets,
        num_heads=num_heads,
        bidirectional=True,
        max_distance=position_embeddings_config.max_distance,
    )
  else:
    relative_biases = InterpolatedRelativePositionBiases.from_config(
        bidirectional=True,
        num_heads=num_heads,
        config=position_embeddings_config,
    )
  # Horizons of -1 are passed for both past and future.
  attention_block = SelfAttentionBlock(
      output_dim=dimension,
      num_heads=num_heads,
      units_per_head=units_per_head,
      relative_position_embedding=relative_biases,
      max_horizon=-1,
      max_future_horizon=-1,
      dropout_rate=dropout_rate,
  )
  feedforward_block = FeedforwardBlock(
      output_dim=dimension,
      hidden_dim=4 * dimension,
      activation=ffn_activation,
      dropout_rate=dropout_rate,
  )
  return sl.Serial([attention_block, feedforward_block])
# ------------------------------------------------------------------------------
# Text encoder base class.
# ------------------------------------------------------------------------------
def TextEncoder(
    dimension: int,
    use_irpbs: bool,
    num_layers: int = 3,
    max_distance: int = 64,
    num_heads: int = 8,
    num_buckets: int = 32,
    ffn_activation: ... = tf.nn.gelu,
    dropout_rate: float = 0.1,
    name: str | None = None,
) -> sl.SequenceLayer:
  """Builds a hybrid conv + Transformer text encoder.

  Architecture:
    1. A stride-1 ConvStage of width `dimension // 2`.
    2. A stride-2 ConvStage of width `dimension`.
    3. `num_layers` Transformer encoder blocks.
    4. A final RMSNorm and dropout.

  Args:
    dimension: Model width. Must be divisible by `num_heads`.
    use_irpbs: If True, self-attention uses InterpolatedRelativePositionBiases
      (IRPBs). Otherwise it uses standard T5 relative position embeddings.
    num_layers: Number of Transformer encoder blocks.
    max_distance: Maximum relative distance for the position biases.
    num_heads: Number of self-attention heads. Must be positive.
    num_buckets: Number of relative position buckets.
    ffn_activation: Activation of the feed-forward hidden layers.
    dropout_rate: Dropout rate of the Transformer blocks and the final
      dropout. The ConvStages keep their own default rate.
    name: Name scope used while constructing layers; defaults to 't5_encoder'.

  Returns:
    The encoder as an sl.Serial.

  Raises:
    ValueError: If `num_heads` is not positive or does not divide `dimension`.
  """
  # Explicit validation instead of `assert`, which is stripped under -O.
  if num_heads <= 0 or dimension % num_heads != 0:
    raise ValueError(
        f'dimension ({dimension}) must be divisible by a positive num_heads '
        f'({num_heads}).'
    )
  units_per_head = dimension // num_heads
  if use_irpbs:
    position_embeddings_config = InterpolatedRelativePositionBiasesConfig(
        num_buckets=num_buckets,
        max_distance=max_distance,
        max_distance_penalty=1.0,
        init_scheme='truncated_normal_stddev',
        init_scheme_value=1.0,
    )
  else:
    position_embeddings_config = T5RelativePositionEmbeddingConfig(
        num_buckets=num_buckets,
        max_distance=max_distance,
    )
  half_dimension = dimension // 2
  with tf.name_scope(name or 't5_encoder'):
    return sl.Serial([
        ConvStage(
            num_blocks=3,
            num_filters=half_dimension,
            strides=1,
            output_dim=half_dimension,
        ),
        ConvStage(
            num_blocks=3,
            num_filters=dimension,
            strides=2,
            output_dim=dimension,
        ),
        sl.Serial([
            TransformerEncoderBlock(
                dimension=dimension,
                num_heads=num_heads,
                units_per_head=units_per_head,
                position_embeddings_config=position_embeddings_config,
                ffn_activation=ffn_activation,
                dropout_rate=dropout_rate,
            )
            for _ in range(num_layers)
        ]),
        sl.RMSNormalization(epsilon=1e-6),
        sl.Dropout(dropout_rate),
    ])
# ------------------------------------------------------------------------------
# VAT decoder helper classes.
# ------------------------------------------------------------------------------
@dataclasses.dataclass
class AlignmentLayerConfig:
  """Settings for AlignmentLayer.

  Attributes:
    alignment_rnn_units: Size of the alignment LSTM sublayer.
    num_heads: Head count of the location-based cross-attention.
    units_per_head: Per-head width of the location-based cross-attention.
    initial_delta: Step size of the alignment position at initialization.
    cross_attention_bias: Interpolated relative position bias settings for the
      cross-attention. By default: 32 buckets, max distance 64, penalty 1.0,
      and a Gaussian-window init with stddev 15.
  """

  alignment_rnn_units: int
  num_heads: int = 4
  units_per_head: int = 32
  initial_delta: float = 0.25
  cross_attention_bias: InterpolatedRelativePositionBiasesConfig = (
      dataclasses.field(
          default_factory=functools.partial(
              InterpolatedRelativePositionBiasesConfig,
              num_buckets=32,
              max_distance=64,
              max_distance_penalty=1.0,
              init_scheme='gaussian_window_stddev',
              init_scheme_value=15.0,
          )
      )
  )
class AlignmentLayer(sl.Emitting, PreprocessConstants):
"""AlignmentLayer.
Emits a differentiable monotonic alignment position. This can be used to
provide the "query" position in a cross-attention mechanism that uses relative
position biases.
The order of operations is as follows:
1. The alignment position is initialized to zero.
2. For the current time step:
a. A multi-head context vector is computed for the current alignment
position using a purely location-relative mechanism (using no query-key
comparisons) based on T5 relative position biases.
b. The input to this layer and the context vector are concatenated and fed
into an RNN.
c. To produce the updated alignment position:
- The output of the MLP is converted to a positive scalar using a
softplus layer.
- The positive float is added to the current alignment position to produce
the updated alignment position.
3. Repeat #2 for all time steps in the input sequence.
4. RNN outputs are returned as layer outputs and the alignment position
Sequence is returned as layer emits.
"""
def __init__(
    self,
    source_name: str,
    config: AlignmentLayerConfig,
    dropout_rate: float,
    name: str | None = None,
):
  """Constructs an AlignmentLayer.

  Args:
    source_name: Key of the source sl.Sequence within the constants passed
      to this layer.
    config: Alignment layer configuration.
    dropout_rate: Dropout rate applied to the cross-attention weights.
    name: Optional module name.

  Raises:
    ValueError: If `config.initial_delta` is not positive.
  """
  super().__init__(name=name)
  self.config = config
  self.source_name = source_name
  self.num_heads = config.num_heads
  self.units_per_head = config.units_per_head
  # This is set by preprocess_constants().
  self._source_head_values = None

  # Validate before building any sublayers.
  if config.initial_delta <= 0.0:
    raise ValueError('initial_delta must be positive')

  def _inverse_softplus(x: float) -> float:  # pylint: disable=invalid-name
    # softplus^-1(x) = log(exp(x) - 1) = x + log(1 - exp(-x)). The form below
    # neither overflows for large x nor loses precision to cancellation for
    # small x, unlike the naive log(exp(x) - 1.0).
    return x + np.log(-np.expm1(-x))

  with self.name_scope:
    # Source value-projection for cross-attention.
    self.source_value_projection = tf.keras.layers.Dense(
        units=config.num_heads * config.units_per_head,
        use_bias=False,
        name='source_value_projection',
    )
    # Position biases for cross-attention.
    self.position_embeddings = InterpolatedRelativePositionBiases.from_config(
        num_heads=config.num_heads,
        bidirectional=True,
        config=config.cross_attention_bias,
    )
    self.sublayer = sl.RNN(
        tf.keras.layers.LSTMCell(config.alignment_rnn_units)
    )
    # Choose the bias so that, with the zero-initialized kernel, the softplus
    # output (the per-step delta) starts at exactly initial_delta.
    initial_output_bias = _inverse_softplus(config.initial_delta)
    # NOTE(review): `_compute_alignment_deltas` calls `self.delta_net`, which
    # is not created in this constructor; confirm it is defined elsewhere.
    self.delta_output_layer = tf.keras.layers.Dense(
        units=1,
        activation='softplus',  # Keeps the output positive.
        kernel_initializer='zeros',
        bias_initializer=tf.constant_initializer(initial_output_bias),
        name='delta_output_layer',
    )
    self.dropout = tf.keras.layers.Dropout(
        rate=dropout_rate, name='attention_dropout'
    )
@tf.Module.with_name_scope
def preprocess_constants(self, constants: sl.Constants) -> None:
  """Projects the source into per-head values and caches the result.

  Args:
    constants: Constants holding the source sl.Sequence under
      `self.source_name`.

  Raises:
    ValueError: If that entry is missing or is not an sl.Sequence.
  """
  source = self._get_source(constants)
  batch_size, source_len, _ = sl.utils.smart_dimension_size(source.values)
  projected = self.source_value_projection(source.values)
  per_head_shape = [batch_size, source_len, self.num_heads, self.units_per_head]
  per_head_values = tf.reshape(projected, per_head_shape)
  # Zero the values at invalid (padded) source positions.
  masked = sl.Sequence(per_head_values, source.mask).mask_invalid()
  self._source_head_values = masked
@tf.Module.with_name_scope
def clear_preprocessed_constants(self) -> None:
  """Drops the per-head source values cached by preprocess_constants()."""
  self._source_head_values = None
def _get_source_head_values(self) -> sl.Sequence:
  """Returns the cached [bs, source_len, num_heads, units_per_head] Sequence.

  Raises:
    ValueError: If preprocess_constants() has not populated the cache, or the
      cache was cleared.
  """
  cached = self._source_head_values
  if cached is not None:
    return cached
  raise ValueError(
      'source_head_values must be precomputed using preprocess_constants().'
  )
def _get_source(self, constants: sl.Constants) -> sl.Sequence:
  """Returns `constants[self.source_name]`, which must be an sl.Sequence.

  Args:
    constants: Mapping of constants, or None (treated as empty).

  Raises:
    ValueError: If the entry is missing or is not an sl.Sequence.
  """
  available = {} if constants is None else constants
  source = available.get(self.source_name)
  if isinstance(source, sl.Sequence):
    return source
  raise ValueError('constants[source_name] must contain an sl.Sequence')
def _compute_context(
    self, x: sl.Sequence, alignment_position: tf.Tensor, training: bool
) -> sl.Sequence:
  """Computes a bias-only attention context (no query-key comparisons).

  The attention weights over the source depend only on the relative position
  between `alignment_position` and each source index, plus the source padding
  mask.

  Args:
    x: Current layer input. Only its mask is used, to mask the context.
    alignment_position: Alignment position(s), used as the query positions of
      the position biases. Presumably [batch, 1]; TODO confirm against the
      caller.
    training: Whether attention dropout is active.

  Returns:
    A [batch, 1, num_heads * units_per_head] sl.Sequence of context vectors,
    masked by `x.mask`.
  """
  source_head_values = self._get_source_head_values()

  # Compute position biases.
  batch_size, source_len, _, _ = sl.utils.smart_dimension_size(
      source_head_values.values
  )
  source_position = tf.tile(
      tf.range(source_len)[tf.newaxis, :], [batch_size, 1]
  )
  # [bs, num_heads, query_len=1, source_len]
  position_biases = self.position_embeddings(
      alignment_position, source_position
  )
  # [bs, num_heads, source_len]
  position_biases = tf.squeeze(position_biases, axis=2)

  # Compute attention bias mask.
  # [bs, num_heads=1, source_len] (broadcast across heads). The mask is cast to
  # float32 explicitly, rather than asserted to already be float32, so the
  # softmax below always runs in float32 whatever dtype the mask is stored in.
  attention_mask = tf.cast(source_head_values.mask, tf.float32)[
      :, tf.newaxis, :
  ]
  attention_bias_mask = (1.0 - attention_mask) * ATTENTION_MASK_BIAS

  compute_dtype = utils.compute_dtype()
  # Compute attention weights in float32, then downcast to compute_dtype.
  # [bs, num_heads, source_len]
  scores = tf.cast(position_biases, tf.float32) + attention_bias_mask
  weights = tf.nn.softmax(scores, axis=-1)
  weights = tf.cast(weights, compute_dtype)
  weights = self.dropout(weights, training=training)

  # Compute context vectors.
  # b=batch_size, s=source_len, n=num_heads, d=units_per_head.
  context = tf.einsum('bns,bsnd->bnd', weights, source_head_values.values)
  # Add time dim, flatten heads.
  context = tf.reshape(
      context, [batch_size, 1, self.num_heads * self.units_per_head]
  )
  context = sl.Sequence(context, x.mask).mask_invalid()
  return context
def _compute_alignment_deltas(
    self, sublayer_output: sl.Sequence, training: bool
) -> tf.Tensor:
  """Maps single-step RNN outputs to positive alignment-position deltas.

  Args:
    sublayer_output: [B, T=1, D] sl.Sequence of RNN outputs for one step.
    training: Whether the layers run in training mode.

  Returns:
    A [B, T=1] Tensor of deltas. Deltas are positive (softplus output) at valid
    steps and zero at invalid steps.
  """
  # Only single steps are supported here, so T must be 1.
  hidden = tf.ensure_shape(sublayer_output.values, [None, 1, None])
  # NOTE(review): `self.delta_net` is not created in the visible __init__;
  # confirm it is defined elsewhere before relying on this method.
  hidden = self.delta_net(hidden, training=training)
  # [B, T=1, 1] after the softplus output layer.
  raw_deltas = self.delta_output_layer(hidden, training=training)
  raw_deltas = tf.squeeze(raw_deltas, axis=2)  # -> [B, T=1]
  masked_deltas = sl.Sequence(raw_deltas, sublayer_output.mask).mask_invalid()
  return masked_deltas.values
def _single_step_with_emits(
self,