# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The core public API of TFTransform. Provide functions to transform tensors.
The core tf.Transform API requires a user to construct a
"preprocessing function" that accepts and returns `Tensor`s. This function is
built by composing regular functions built from TensorFlow ops, as well as
special functions we refer to as `Analyzer`s. `Analyzer`s behave similarly to
TensorFlow ops but require a full pass over the whole dataset to compute their
output value. The analyzers are defined in analyzers.py, while this module
provides helper functions that call analyzers and then use the results of the
analyzers to transform the original data.
The user-defined preprocessing function should accept and return `Tensor`s that
are batches from the dataset, whose batch size may vary. For example the
following preprocessing function centers the input 'x' while returning 'y'
unchanged.
import tensorflow_transform as tft
def preprocessing_fn(inputs):
x = inputs['x']
y = inputs['y']
# Apply the `mean` analyzer to obtain the mean x.
x_mean = tft.mean(x)
# Subtract the mean.
x_centered = x - x_mean
# Return a new dictionary containing x_centered, and y unchanged
return {
'x_centered': x_centered,
'y': y
}
This user-defined function then must be run using an implementation based on
some distributed computation framework. The canonical implementation uses
Apache Beam as the underlying framework. See beam/impl.py for how to use the
Beam implementation.
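A minimal sketch of running the example above with the Beam implementation
(`raw_data` and `raw_data_metadata` are assumed to be provided by the caller):
  import tempfile
  import tensorflow_transform.beam as tft_beam
  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    transformed_dataset, transform_fn = (
        (raw_data, raw_data_metadata)
        | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))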
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from typing import Any, Callable, Iterable, Optional, Tuple, Union
# GOOGLE-INITIALIZATION
import six
import tensorflow as tf
from tensorflow_transform import analyzers
from tensorflow_transform import common
from tensorflow_transform import common_types
from tensorflow_transform import gaussianization
from tensorflow_transform import schema_inference
from tensorflow_transform import tf_utils
# TODO(https://issues.apache.org/jira/browse/SPARK-22674): Switch to
# `collections.namedtuple` or `typing.NamedTuple` once the Spark issue is
# resolved.
from tfx_bsl.types import tfx_namedtuple
@common.log_api_use(common.MAPPER_COLLECTION)
def scale_to_gaussian(
x: common_types.ConsistentTensorType,
elementwise: bool = False,
name: Optional[str] = None,
output_dtype: Optional[tf.DType] = None
) -> common_types.ConsistentTensorType:
"""Returns an (approximately) normal column with mean to 0 and variance 1.
We transform the column to values that are approximately distributed
according to a standard normal distribution.
The transformation is obtained by applying the moments method to estimate
the parameters of a Tukey HH distribution and applying the inverse of the
estimated function to the column values.
The method is partially described in
Georg M. Goerg, "The Lambert Way to Gaussianize Heavy-Tailed Data with the
Inverse of Tukey's h Transformation as a Special Case," The Scientific World
Journal, Vol. 2015, Hindawi Publishing Corporation.
We use the L-moments instead of conventional moments to be able to deal with
long-tailed distributions. The expressions of the L-moments for the Tukey HH
distribution are given in
Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey H and
HH-Distributions through L-Moments and the L-Correlation," ISRN Applied
Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153
Note that the transformation to Gaussian is applied only if the column has
long-tails. If this is not the case, for instance if values are uniformly
distributed, the values are only normalized using the z score. This applies
also to the cases where only one of the tails is long; the other tail is only
rescaled but not non-linearly transformed.
Also, if the analysis set is empty, the transformation is set to leave the
input values unchanged.
Args:
x: A numeric `Tensor` or `SparseTensor`.
elementwise: If true, scales each element of the tensor independently;
otherwise uses the parameters of the whole tensor.
name: (Optional) A name for this operation.
output_dtype: (Optional) If not None, casts the output tensor to this type.
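Example:
A minimal sketch (assuming the same tft_beam pipeline setup used in the other
examples in this module; the exact output values depend on the estimated Tukey
HH parameters):
>>> def preprocessing_fn(inputs):
...   return {'x_gaussianized': tft.scale_to_gaussian(inputs['x'])}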
Returns:
A `Tensor` or `SparseTensor` containing the input column transformed to be
approximately standard normal (i.e. a Gaussian with mean 0 and variance
1). If `x` is floating point, the mean will have the same type as `x`. If
`x` is integral, the output is cast to tf.float32.
Note that TFLearn generally permits only tf.int64 and tf.float32, so casting
this scaler's output may be necessary.
"""
with tf.compat.v1.name_scope(name, 'scale_to_gaussian'):
return _scale_to_gaussian_internal(
x=x,
elementwise=elementwise,
output_dtype=output_dtype)
def _scale_to_gaussian_internal(
x: common_types.ConsistentTensorType,
elementwise: bool = False,
output_dtype: Optional[tf.DType] = None
) -> common_types.ConsistentTensorType:
"""Implementation for scale_to_gaussian."""
# x_mean will be float16, float32, or float64, depending on type of x.
x_loc, x_scale, hl, hr = analyzers._tukey_parameters( # pylint: disable=protected-access
x, reduce_instance_dims=not elementwise, output_dtype=output_dtype)
compose_result_fn = _make_sparse_tensor_wrapper_if_sparse(x)
x_values = x
x_var = analyzers.var(x, reduce_instance_dims=not elementwise,
output_dtype=output_dtype)
if isinstance(x, tf.SparseTensor):
x_values = x.values
if elementwise:
x_loc = tf.gather_nd(x_loc, x.indices[:, 1:])
x_scale = tf.gather_nd(x_scale, x.indices[:, 1:])
hl = tf.gather_nd(hl, x.indices[:, 1:])
hr = tf.gather_nd(hr, x.indices[:, 1:])
x_var = tf.gather_nd(x_var, x.indices[:, 1:])
numerator = tf.cast(x_values, x_loc.dtype) - x_loc
is_long_tailed = tf.math.logical_or(hl > 0.0, hr > 0.0)
# If the distribution is long-tailed, we apply the robust scale computed
# with L-moments; otherwise, we scale using the standard deviation so that
# we obtain the same result of scale_to_z_score.
denominator = tf.where(is_long_tailed, x_scale, tf.sqrt(x_var))
cond = tf.not_equal(denominator, 0)
if cond.shape.as_list() != x_values.shape.as_list():
# Repeats cond when necessary across the batch dimension for it to be
# compatible with the shape of numerator.
cond = tf.cast(
tf.zeros_like(numerator) + tf.cast(cond, numerator.dtype),
dtype=tf.bool)
scaled_values = tf.where(cond, tf.divide(numerator, denominator),
numerator)
gaussianized_values = gaussianization.inverse_tukey_hh(scaled_values, hl, hr)
return compose_result_fn(gaussianized_values)
@common.log_api_use(common.MAPPER_COLLECTION)
def sparse_tensor_to_dense_with_shape(
x: tf.SparseTensor,
shape: Union[tf.TensorShape, Iterable[int]],
default_value: Optional[Union[tf.Tensor, int, float,
str]] = 0) -> tf.Tensor:
"""Converts a `SparseTensor` into a dense tensor and sets its shape.
Args:
x: A `SparseTensor`.
shape: The desired shape of the densified `Tensor`.
default_value: (Optional) Value to set for indices not specified. Defaults
to zero.
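Example (an illustrative sketch; the tensor values are assumptions made up for
the illustration):
>>> st = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
...                      dense_shape=[2, 3])
>>> dense = tft.sparse_tensor_to_dense_with_shape(st, [None, 3])
>>> # `dense` is [[1, 0, 0], [0, 0, 2]] and has static shape [None, 3].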
Returns:
A `Tensor` with the desired shape.
Raises:
ValueError: If input is not a `SparseTensor`.
"""
if not isinstance(x, tf.SparseTensor):
raise ValueError('input must be a SparseTensor')
new_dense_shape = [
x.dense_shape[i] if size is None else size
for i, size in enumerate(shape)
]
dense = tf.raw_ops.SparseToDense(
sparse_indices=x.indices,
output_shape=new_dense_shape,
sparse_values=x.values,
default_value=default_value)
dense.set_shape(shape)
return dense
@common.log_api_use(common.MAPPER_COLLECTION)
def sparse_tensor_left_align(sparse_tensor: tf.SparseTensor) -> tf.SparseTensor:
"""Re-arranges a `tf.SparseTensor` and returns a left-aligned version of it.
This mapper can be useful when returning a sparse tensor that may not be
left-aligned from a preprocessing_fn.
Args:
sparse_tensor: A 2D `tf.SparseTensor`.
Raises:
ValueError if `sparse_tensor` is not 2D.
Returns:
A left-aligned version of sparse_tensor as a `tf.SparseTensor`.
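Example (an illustrative sketch; the indices and values are assumptions made
up for the illustration):
>>> st = tf.SparseTensor(indices=[[0, 3], [1, 1], [1, 4]],
...                      values=[1, 2, 3], dense_shape=[2, 5])
>>> aligned = tft.sparse_tensor_left_align(st)
>>> # `aligned.indices` is [[0, 0], [1, 0], [1, 1]]; the values and
>>> # dense_shape are unchanged.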
"""
if sparse_tensor.get_shape().ndims != 2:
raise ValueError('sparse_tensor_left_align requires a 2D input')
reordered_tensor = tf.sparse.reorder(sparse_tensor)
transposed_indices = tf.transpose(reordered_tensor.indices)
row_indices = transposed_indices[0]
row_counts = tf.unique_with_counts(row_indices, out_idx=tf.int64).count
column_indices = tf.ragged.range(row_counts).flat_values
return tf.SparseTensor(
indices=tf.transpose(tf.stack([row_indices, column_indices])),
values=reordered_tensor.values,
dense_shape=reordered_tensor.dense_shape)
@common.log_api_use(common.MAPPER_COLLECTION)
def scale_by_min_max(
x: common_types.ConsistentTensorType,
output_min: float = 0.0,
output_max: float = 1.0,
elementwise: bool = False,
name: Optional[str] = None) -> common_types.ConsistentTensorType:
"""Scale a numerical column into the range [output_min, output_max].
Args:
x: A numeric `Tensor` or `SparseTensor`.
output_min: The minimum of the range of output values.
output_max: The maximum of the range of output values.
elementwise: If true, scale each element of the tensor independently.
name: (Optional) A name for this operation.
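Example:
A minimal sketch (assuming the same tft_beam pipeline setup used in the
`scale_by_min_max_per_key` example below):
>>> def preprocessing_fn(inputs):
...   return {'x_scaled': tft.scale_by_min_max(inputs['x'], 0, 10)}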
Returns:
A `Tensor` containing the input column scaled to [output_min, output_max].
If the analysis dataset is empty or contains a single distinct value, then
`x` is scaled using a sigmoid function.
Raises:
ValueError: If output_min, output_max have the wrong order.
"""
with tf.compat.v1.name_scope(name, 'scale_by_min_max'):
return _scale_by_min_max_internal(
x,
key=None,
output_min=output_min,
output_max=output_max,
elementwise=elementwise,
key_vocabulary_filename=None)
@common.log_api_use(common.MAPPER_COLLECTION)
def scale_by_min_max_per_key(
x: common_types.ConsistentTensorType,
key: common_types.TensorType,
output_min: float = 0.0,
output_max: float = 1.0,
elementwise: bool = False,
key_vocabulary_filename: Optional[str] = None,
name: Optional[str] = None) -> common_types.ConsistentTensorType:
# pyformat: disable
"""Scale a numerical column into a predefined range on a per-key basis.
Args:
x: A numeric `Tensor` or `SparseTensor`.
key: A `Tensor` or `SparseTensor` of dtype tf.string.
Must meet one of the following conditions:
0. key is None
1. Both x and key are dense,
2. Both x and key are sparse and `key` must exactly match `x` in
everything except values,
3. The axis=1 index of each x matches its index of dense key.
output_min: The minimum of the range of output values.
output_max: The maximum of the range of output values.
elementwise: If true, scale each element of the tensor independently.
key_vocabulary_filename: (Optional) The file name for the per-key file.
If None, this combiner will assume the keys fit in memory and will not
store the analyzer result in a file. If '', a file name will be chosen
based on the current TensorFlow scope. If not '', it should be unique
within a given preprocessing function.
name: (Optional) A name for this operation.
Example:
>>> def preprocessing_fn(inputs):
... return {
... 'scaled': tft.scale_by_min_max_per_key(inputs['x'], inputs['s'])
... }
>>> raw_data = [dict(x=1, s='a'), dict(x=0, s='b'), dict(x=3, s='a')]
>>> feature_spec = dict(
... x=tf.io.FixedLenFeature([], tf.float32),
... s=tf.io.FixedLenFeature([], tf.string))
>>> raw_data_metadata = tft.tf_metadata.dataset_metadata.DatasetMetadata(
... tft.tf_metadata.schema_utils.schema_from_feature_spec(feature_spec))
>>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
... transformed_dataset, transform_fn = (
... (raw_data, raw_data_metadata)
... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
>>> transformed_data, transformed_metadata = transformed_dataset
>>> transformed_data
[{'scaled': 0.0}, {'scaled': 0.5}, {'scaled': 1.0}]
Returns:
A `Tensor` or `SparseTensor` containing the input column scaled to
[output_min, output_max] on a per-key basis if a key is provided. If the
analysis dataset is empty, a certain key contains a single distinct value or
the computed key vocabulary doesn't have an entry for `key`, then `x` is
scaled using a sigmoid function.
Raises:
ValueError: If output_min, output_max have the wrong order.
NotImplementedError: If elementwise is True and key is not None.
InvalidArgumentError: If indices of sparse x and key do not match.
"""
# pyformat: enable
with tf.compat.v1.name_scope(name, 'scale_by_min_max_per_key'):
if key is None:
raise ValueError('key is None, call `tft.scale_by_min_max` instead')
return _scale_by_min_max_internal(
x,
key=key,
output_min=output_min,
output_max=output_max,
elementwise=elementwise,
key_vocabulary_filename=key_vocabulary_filename)
def _scale_by_min_max_internal(
x: common_types.ConsistentTensorType,
key: Optional[common_types.TensorType],
output_min: float,
output_max: float,
elementwise: bool,
key_vocabulary_filename: Optional[str] = None
) -> common_types.ConsistentTensorType:
"""Implementation for scale_by_min_max."""
if output_min >= output_max:
raise ValueError('output_min must be less than output_max')
x = tf.cast(x, tf.float32)
if key is None:
min_x_value, max_x_value = analyzers._min_and_max( # pylint: disable=protected-access
x,
reduce_instance_dims=not elementwise)
else:
if elementwise:
raise NotImplementedError('Per-key elementwise reduction not supported')
key_values = analyzers._min_and_max_per_key( # pylint: disable=protected-access
x,
key,
reduce_instance_dims=True,
key_vocabulary_filename=key_vocabulary_filename)
if key_vocabulary_filename is None:
key_vocab, min_x_value, max_x_value = key_values
# Missing keys will translate to 0 for both min and max which will be
# ignored below in the tf.where.
min_x_value, max_x_value = tf_utils.map_per_key_reductions(
(min_x_value, max_x_value), key, key_vocab, x)
else:
minus_min_max_for_key = tf_utils.apply_per_key_vocabulary(
key_values, key, target_ndims=x.get_shape().ndims)
min_x_value, max_x_value = (
-minus_min_max_for_key[:, 0], minus_min_max_for_key[:, 1])
compose_result_fn = _make_sparse_tensor_wrapper_if_sparse(x)
x_values = x
if isinstance(x, tf.SparseTensor):
if elementwise:
min_x_value = tf.gather_nd(
tf.broadcast_to(min_x_value, x.dense_shape), x.indices)
max_x_value = tf.gather_nd(
tf.broadcast_to(max_x_value, x.dense_shape), x.indices)
x_values = x.values
# If min>=max, then the corresponding input to the min_and_max analyzer either
# was empty and the analyzer returned default values, or contained only one
# distinct value. In this case we scale x by applying a sigmoid function which
# is continuous, increasing and maps (-inf, inf) -> (0, 1). Its output is
# then projected on the requested range. Note that both the options of
# tf.where are computed, which means that this will compute unused NaNs.
numerator = tf.cast(x_values, min_x_value.dtype) - min_x_value
where_cond = min_x_value < max_x_value
where_cond = tf.cast(
tf.zeros_like(numerator) + tf.cast(where_cond, numerator.dtype),
dtype=tf.bool)
scaled_result = tf.where(where_cond, numerator / (max_x_value - min_x_value),
tf.math.sigmoid(x_values))
return compose_result_fn((scaled_result * (output_max - output_min)) +
output_min)
@common.log_api_use(common.MAPPER_COLLECTION)
def scale_to_0_1(
x: common_types.ConsistentTensorType,
elementwise: bool = False,
name: Optional[str] = None) -> common_types.ConsistentTensorType:
"""Returns a column which is the input column scaled to have range [0,1].
Args:
x: A numeric `Tensor` or `SparseTensor`.
elementwise: If true, scale each element of the tensor independently.
name: (Optional) A name for this operation.
Returns:
A `Tensor` or `SparseTensor` containing the input column scaled to [0, 1].
If the analysis dataset is empty or contains a single distinct value, then
`x` is scaled using a sigmoid function.
"""
with tf.compat.v1.name_scope(name, 'scale_to_0_1'):
return _scale_by_min_max_internal(
x,
key=None,
output_min=0,
output_max=1,
elementwise=elementwise,
key_vocabulary_filename=None)
@common.log_api_use(common.MAPPER_COLLECTION)
def scale_to_0_1_per_key(
x: common_types.ConsistentTensorType,
key: common_types.TensorType,
elementwise: bool = False,
key_vocabulary_filename: Optional[str] = None,
name: Optional[str] = None) -> common_types.ConsistentTensorType:
# pyformat: disable
"""Returns a column which is the input column scaled to have range [0,1].
Args:
x: A numeric `Tensor` or `SparseTensor`.
key: A `Tensor` or `SparseTensor` of type string.
elementwise: If true, scale each element of the tensor independently.
key_vocabulary_filename: (Optional) The file name for the per-key file. If
None, this combiner will assume the keys fit in memory and will not store
the analyzer result in a file. If '', a file name will be chosen based on
the current TensorFlow scope. If not '', it should be unique within a
given preprocessing function.
name: (Optional) A name for this operation.
Example:
>>> def preprocessing_fn(inputs):
... return {
... 'scaled': tft.scale_to_0_1_per_key(inputs['x'], inputs['s'])
... }
>>> raw_data = [dict(x=1, s='a'), dict(x=0, s='b'), dict(x=3, s='a')]
>>> feature_spec = dict(
... x=tf.io.FixedLenFeature([], tf.float32),
... s=tf.io.FixedLenFeature([], tf.string))
>>> raw_data_metadata = tft.tf_metadata.dataset_metadata.DatasetMetadata(
... tft.tf_metadata.schema_utils.schema_from_feature_spec(feature_spec))
>>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
... transformed_dataset, transform_fn = (
... (raw_data, raw_data_metadata)
... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
>>> transformed_data, transformed_metadata = transformed_dataset
>>> transformed_data
[{'scaled': 0.0}, {'scaled': 0.5}, {'scaled': 1.0}]
Returns:
A `Tensor` or `SparseTensor` containing the input column scaled to [0, 1],
per key. If the analysis dataset is empty, contains a single distinct value
or the computed key vocabulary doesn't have an entry for `key`, then `x` is
scaled using a sigmoid function.
"""
# pyformat: enable
with tf.compat.v1.name_scope(name, 'scale_to_0_1_per_key'):
if key is None:
raise ValueError('key is None, call `tft.scale_to_0_1` instead')
return _scale_by_min_max_internal(
x,
key=key,
output_min=0,
output_max=1,
elementwise=elementwise,
key_vocabulary_filename=key_vocabulary_filename)
@common.log_api_use(common.MAPPER_COLLECTION)
def scale_to_z_score(
x: common_types.ConsistentTensorType,
elementwise: bool = False,
name: Optional[str] = None,
output_dtype: Optional[tf.DType] = None
) -> common_types.ConsistentTensorType:
"""Returns a standardized column with mean 0 and variance 1.
Scaling to z-score subtracts out the mean and divides by standard deviation.
Note that the standard deviation computed here is based on the biased variance
(0 delta degrees of freedom), as computed by analyzers.var.
Args:
x: A numeric `Tensor` or `SparseTensor`.
elementwise: If true, scales each element of the tensor independently;
otherwise uses the mean and variance of the whole tensor.
name: (Optional) A name for this operation.
output_dtype: (Optional) If not None, casts the output tensor to this type.
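Example:
A minimal sketch (assuming the same tft_beam pipeline setup used in the other
examples in this module):
>>> def preprocessing_fn(inputs):
...   return {'x_standardized': tft.scale_to_z_score(inputs['x'])}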
Returns:
A `Tensor` or `SparseTensor` containing the input column scaled to mean 0
and variance 1 (standard deviation 1), given by: (x - mean(x)) / std_dev(x).
If `x` is floating point, the mean will have the same type as `x`. If `x` is
integral, the output is cast to tf.float32. If the analysis dataset is empty
or contains a single distinct value, then the input is returned without
scaling.
Note that TFLearn generally permits only tf.int64 and tf.float32, so casting
this scaler's output may be necessary.
"""
with tf.compat.v1.name_scope(name, 'scale_to_z_score'):
return _scale_to_z_score_internal(
x=x,
key=None,
elementwise=elementwise,
key_vocabulary_filename=None,
output_dtype=output_dtype)
@common.log_api_use(common.MAPPER_COLLECTION)
def scale_to_z_score_per_key(
x: common_types.ConsistentTensorType,
key: common_types.TensorType,
elementwise: bool = False,
key_vocabulary_filename: Optional[str] = None,
name: Optional[str] = None,
output_dtype: Optional[tf.DType] = None
) -> common_types.ConsistentTensorType:
"""Returns a standardized column with mean 0 and variance 1, grouped per key.
Scaling to z-score subtracts out the mean and divides by standard deviation.
Note that the standard deviation computed here is based on the biased variance
(0 delta degrees of freedom), as computed by analyzers.var.
Args:
x: A numeric `Tensor` or `SparseTensor`.
key: A Tensor or `SparseTensor` of dtype tf.string.
Must meet one of the following conditions:
0. key is None
1. Both x and key are dense,
2. Both x and key are sparse and `key` must exactly match `x` in
everything except values,
3. The axis=1 index of each x matches its index of dense key.
elementwise: If true, scales each element of the tensor independently;
otherwise uses the mean and variance of the whole tensor.
Currently, not supported for per-key operations.
key_vocabulary_filename: (Optional) The file name for the per-key file.
If None, this combiner will assume the keys fit in memory and will not
store the analyzer result in a file. If '', a file name will be chosen
based on the current TensorFlow scope. If not '', it should be unique
within a given preprocessing function.
name: (Optional) A name for this operation.
output_dtype: (Optional) If not None, casts the output tensor to this type.
Returns:
A `Tensor` or `SparseTensor` containing the input column scaled to mean 0
and variance 1 (standard deviation 1), grouped per key if a key is provided.
That is, for all keys k: (x - mean(x)) / std_dev(x) for all x with key k.
If `x` is floating point, the mean will have the same type as `x`. If `x` is
integral, the output is cast to tf.float32. If the analysis dataset is
empty, contains a single distinct value or the computed key vocabulary
doesn't have an entry for `key`, then the input is returned without scaling.
Note that TFLearn generally permits only tf.int64 and tf.float32, so casting
this scaler's output may be necessary.
"""
with tf.compat.v1.name_scope(name, 'scale_to_z_score_per_key'):
if key is None:
raise ValueError('key is None, call `tft.scale_to_z_score` instead')
return _scale_to_z_score_internal(
x=x,
key=key,
elementwise=elementwise,
key_vocabulary_filename=key_vocabulary_filename,
output_dtype=output_dtype)
def _scale_to_z_score_internal(
x: common_types.ConsistentTensorType,
key: Optional[common_types.TensorType], elementwise: bool,
key_vocabulary_filename: Optional[str],
output_dtype: Optional[tf.DType]) -> common_types.ConsistentTensorType:
"""Implementation for scale_to_z_score."""
# x_mean will be float16, float32, or float64, depending on type of x
if key is None:
x_mean, x_var = analyzers._mean_and_var( # pylint: disable=protected-access
x,
reduce_instance_dims=not elementwise,
output_dtype=output_dtype)
else:
if elementwise:
raise NotImplementedError('Per-key elementwise reduction not supported')
mean_and_var_per_key_result = analyzers._mean_and_var_per_key( # pylint: disable=protected-access
x, key, key_vocabulary_filename=key_vocabulary_filename,
output_dtype=output_dtype)
if key_vocabulary_filename is None:
# Missing keys will translate to 0 for both mean and var which will be
# ignored below in the tf.where.
key_vocab, key_means, key_vars = mean_and_var_per_key_result
x_mean, x_var = tf_utils.map_per_key_reductions((key_means, key_vars),
key, key_vocab, x)
else:
mean_var_for_key = tf_utils.apply_per_key_vocabulary(
mean_and_var_per_key_result, key, target_ndims=x.get_shape().ndims)
x_mean, x_var = (mean_var_for_key[:, 0], mean_var_for_key[:, 1])
compose_result_fn = _make_sparse_tensor_wrapper_if_sparse(x)
x_values = x
if isinstance(x, tf.SparseTensor):
x_values = x.values
if elementwise:
x_mean = tf.gather_nd(tf.broadcast_to(x_mean, x.dense_shape), x.indices)
x_var = tf.gather_nd(tf.broadcast_to(x_var, x.dense_shape), x.indices)
numerator = tf.cast(x_values, x_mean.dtype) - x_mean
denominator = tf.sqrt(x_var)
cond = tf.not_equal(denominator, 0)
if cond.shape.as_list() != x_values.shape.as_list():
# Repeats cond when necessary across the batch dimension for it to be
# compatible with the shape of numerator.
cond = tf.cast(
tf.zeros_like(numerator) + tf.cast(cond, numerator.dtype),
dtype=tf.bool)
deviation_values = tf.where(cond, tf.divide(numerator, denominator),
numerator)
return compose_result_fn(deviation_values)
@common.log_api_use(common.MAPPER_COLLECTION)
def tfidf(
x: tf.SparseTensor,
vocab_size: int,
smooth: bool = True,
name: Optional[str] = None) -> Tuple[tf.SparseTensor, tf.SparseTensor]:
# pyformat: disable
"""Maps the terms in x to their term frequency * inverse document frequency.
The term frequency of a term in a document is calculated as
(count of term in document) / (document size)
The inverse document frequency of a term is, by default, calculated as
1 + log((corpus size + 1) / (count of documents containing term + 1)).
Example usage:
>>> def preprocessing_fn(inputs):
... integerized = tft.compute_and_apply_vocabulary(inputs['x'])
... vocab_size = tft.get_num_buckets_for_transformed_feature(integerized)
... vocab_index, tfidf_weight = tft.tfidf(integerized, vocab_size)
... return {
... 'index': vocab_index,
... 'tf_idf': tfidf_weight,
... 'integerized': integerized,
... }
>>> raw_data = [dict(x=["I", "like", "pie", "pie", "pie"]),
... dict(x=["yum", "yum", "pie"])]
>>> feature_spec = dict(x=tf.io.VarLenFeature(tf.string))
>>> raw_data_metadata = tft.tf_metadata.dataset_metadata.DatasetMetadata(
... tft.tf_metadata.schema_utils.schema_from_feature_spec(feature_spec))
>>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
... transformed_dataset, transform_fn = (
... (raw_data, raw_data_metadata)
... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
>>> transformed_data, transformed_metadata = transformed_dataset
>>> transformed_data
[{'index': array([0, 2, 3]), 'integerized': array([3, 2, 0, 0, 0]),
'tf_idf': array([0.6, 0.28109303, 0.28109303], dtype=float32)},
{'index': array([0, 1]), 'integerized': array([1, 1, 0]),
'tf_idf': array([0.33333334, 0.9369768 ], dtype=float32)}]
```
example strings: [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie"]]
in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4],
[1, 0], [1, 1], [1, 2]],
values=[1, 2, 0, 0, 0, 3, 3, 0])
out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
values=[1, 2, 0, 3, 0])
SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
values=[(1/5)*(log(3/2)+1), (1/5)*(log(3/2)+1), (3/5),
(2/3)*(log(3/2)+1), (1/3)])
```
NOTE: the first doc's duplicate "pie" strings have been combined to
one output, as have the second doc's duplicate "yum" strings.
Args:
x: A 2D `SparseTensor` representing int64 values (most likely that are the
result of calling `compute_and_apply_vocabulary` on a tokenized string).
vocab_size: An int - the count of vocab used to turn the string into int64s
including any OOV buckets.
smooth: A bool indicating if the inverse document frequency should be
smoothed. If True, which is the default, then the idf is calculated as
1 + log((corpus size + 1) / (document frequency of term + 1)).
Otherwise, the idf is
1 + log((corpus size) / (document frequency of term)), which could
result in a division by zero error.
name: (Optional) A name for this operation.
Returns:
Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words].
The first has values vocab_index, which is taken from input `x`.
The second has values tfidf_weight.
Raises:
ValueError if `x` does not have 2 dimensions.
"""
# pyformat: enable
if x.get_shape().ndims != 2:
raise ValueError('tft.tfidf requires a 2D SparseTensor input. '
'Input had {} dimensions.'.format(x.get_shape().ndims))
def _to_vocab_range(x):
"""Enforces that the vocab_ids in x are positive."""
return tf.SparseTensor(
indices=x.indices,
values=tf.math.mod(x.values, vocab_size),
dense_shape=x.dense_shape)
with tf.compat.v1.name_scope(name, 'tfidf'):
cleaned_input = _to_vocab_range(x)
term_frequencies = _to_term_frequency(cleaned_input, vocab_size)
count_docs_with_term_column = _count_docs_with_term(term_frequencies)
# Expand dims to get around the min_tensor_rank checks
sizes = tf.expand_dims(tf.shape(input=cleaned_input)[0], 0)
# [batch, vocab] - tfidf
tfidfs = _to_tfidf(term_frequencies,
analyzers.sum(count_docs_with_term_column,
reduce_instance_dims=False),
analyzers.sum(sizes),
smooth)
return _split_tfidfs_to_outputs(tfidfs)
def _split_tfidfs_to_outputs(
tfidfs: tf.SparseTensor) -> Tuple[tf.SparseTensor, tf.SparseTensor]:
"""Splits [batch, vocab]-weight into [batch, bow]-vocab & [batch, bow]-tfidf.
Args:
tfidfs: the `SparseTensor` output of _to_tfidf
Returns:
Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words].
The first has values vocab_index, which is taken from input `x`.
The second has values tfidf_weight.
"""
# Split tfidfs tensor into [batch, dummy] -> vocab & [batch, dummy] -> tfidf
# The "dummy" index counts from 0 to the number of unique tokens in the doc.
# So example doc ["I", "like", "pie", "pie", "pie"], with 3 unique tokens,
# will have "dummy" indices [0, 1, 2]. The particular dummy index that any
# token receives is not important, only that the tfidf value and vocab index
# have the *same* dummy index, so that feature_column can apply the weight to
# the correct vocab item.
dummy_index = segment_indices(tfidfs.indices[:, 0])
out_index = tf.concat(
[tf.expand_dims(tfidfs.indices[:, 0], 1),
tf.expand_dims(dummy_index, 1)], 1)
out_shape_second_dim = tf.maximum(
tf.reduce_max(input_tensor=dummy_index), -1) + 1
out_shape = tf.stack([tfidfs.dense_shape[0], out_shape_second_dim])
out_shape.set_shape([2])
de_duped_indicies_out = tf.SparseTensor( # NOTYPO ('indices')
indices=out_index,
values=tfidfs.indices[:, 1],
dense_shape=out_shape)
de_duped_tfidf_out = tf.SparseTensor(
indices=out_index,
values=tfidfs.values,
dense_shape=out_shape)
return de_duped_indicies_out, de_duped_tfidf_out # NOTYPO ('indices')
def _to_term_frequency(x: tf.SparseTensor,
vocab_size: Union[int, tf.Tensor]) -> tf.SparseTensor:
"""Creates a SparseTensor of term frequency for every doc/term pair.
Args:
x : a SparseTensor of int64 representing string indices in vocab.
vocab_size: A scalar int64 Tensor - the count of vocab used to turn the
string into int64s including any OOV buckets.
Returns:
a SparseTensor with the count of times a term appears in a document at
indices <doc_index_in_batch>, <term_index_in_vocab>,
with size (num_docs_in_batch, vocab_size).
"""
# Construct intermediary sparse tensor with indices
# [<doc>, <term_index_in_doc>, <vocab_id>] and tf.ones values.
vocab_size = tf.convert_to_tensor(value=vocab_size, dtype=tf.int64)
split_indices = tf.cast(
tf.split(x.indices, axis=1, num_or_size_splits=2), dtype=tf.int64)
expanded_values = tf.cast(tf.expand_dims(x.values, 1), dtype=tf.int64)
next_index = tf.concat(
[split_indices[0], split_indices[1], expanded_values], axis=1)
next_values = tf.ones_like(x.values)
expanded_vocab_size = tf.expand_dims(vocab_size, 0)
next_shape = tf.concat(
[x.dense_shape, expanded_vocab_size], 0)
next_tensor = tf.SparseTensor(
indices=tf.cast(next_index, dtype=tf.int64),
values=next_values,
dense_shape=next_shape)
# Take the intermediary tensor and reduce over the term_index_in_doc
# dimension. This produces a tensor with indices [<doc_id>, <term_id>]
# and values [count_of_term_in_doc] and shape batch x vocab_size
term_count_per_doc = tf.compat.v1.sparse_reduce_sum_sparse(next_tensor, 1)
dense_doc_sizes = tf.cast(
tf.sparse.reduce_sum(
tf.SparseTensor(
indices=x.indices,
values=tf.ones_like(x.values),
dense_shape=x.dense_shape), 1),
dtype=tf.float64)
gather_indices = term_count_per_doc.indices[:, 0]
gathered_doc_sizes = tf.gather(dense_doc_sizes, gather_indices)
term_frequency = (
tf.cast(term_count_per_doc.values, dtype=tf.float64) /
tf.cast(gathered_doc_sizes, dtype=tf.float64))
return tf.SparseTensor(
indices=term_count_per_doc.indices,
values=term_frequency,
dense_shape=term_count_per_doc.dense_shape)
def _to_tfidf(term_frequency: tf.SparseTensor, reduced_term_freq: tf.Tensor,
corpus_size: tf.Tensor, smooth: bool) -> tf.SparseTensor:
"""Calculates the inverse document frequency of terms in the corpus.
Args:
term_frequency: The `SparseTensor` output of _to_term_frequency.
reduced_term_freq: A `Tensor` of shape (vocabSize,) that represents the
count of the number of documents with each term.
corpus_size: A scalar count of the number of documents in the corpus.
smooth: A bool indicating if the idf value should be smoothed. See
tfidf_weights documentation for details.
Returns:
A `SparseTensor` with indices=<doc_index_in_batch>, <term_index_in_vocab>,
values=term frequency * inverse document frequency,
and shape=(batch, vocab_size)
"""
# The idf tensor has shape (vocab_size,)
if smooth:
idf = tf.math.log((tf.cast(corpus_size, dtype=tf.float64) + 1.0) /
(1.0 + tf.cast(reduced_term_freq, dtype=tf.float64))) + 1
else:
idf = tf.math.log(
tf.cast(corpus_size, dtype=tf.float64) /
(tf.cast(reduced_term_freq, dtype=tf.float64))) + 1
gathered_idfs = tf.gather(tf.squeeze(idf), term_frequency.indices[:, 1])
tfidf_values = (tf.cast(term_frequency.values, tf.float32)
* tf.cast(gathered_idfs, tf.float32))
return tf.SparseTensor(
indices=term_frequency.indices,
values=tfidf_values,
dense_shape=term_frequency.dense_shape)
def _count_docs_with_term(term_frequency: tf.SparseTensor) -> tf.Tensor:
"""Computes the number of documents in a batch that contain each term.
Args:
term_frequency: The `SparseTensor` output of _to_term_frequency.
Returns:
A `Tensor` of shape (vocab_size,) that contains the number of documents in
the batch that contain each term.
"""
count_of_doc_inter = tf.SparseTensor(
indices=term_frequency.indices,
values=tf.ones_like(term_frequency.values),
dense_shape=term_frequency.dense_shape)
out = tf.sparse.reduce_sum(count_of_doc_inter, axis=0)
return tf.expand_dims(out, 0)
@common.log_api_use(common.MAPPER_COLLECTION)
def compute_and_apply_vocabulary(
x: common_types.ConsistentTensorType,
default_value: Optional[Any] = -1,
top_k: Optional[int] = None,
frequency_threshold: Optional[int] = None,
num_oov_buckets: Optional[int] = 0,
vocab_filename: Optional[str] = None,
weights: Optional[tf.Tensor] = None,
labels: Optional[tf.Tensor] = None,
use_adjusted_mutual_info: bool = False,
min_diff_from_avg: Optional[float] = 0.0,
coverage_top_k: Optional[int] = None,
coverage_frequency_threshold: Optional[int] = None,
key_fn: Optional[Callable[[Any], Any]] = None,
fingerprint_shuffle: bool = False,
file_format: Optional[common_types.VocabularyFileFormatType] = analyzers
.DEFAULT_VOCABULARY_FILE_FORMAT,
name: Optional[str] = None) -> common_types.ConsistentTensorType: # TODO(b/64987151): Remove # pytype: disable=annotation-type-mismatch
r"""Generates a vocabulary for `x` and maps it to an integer with this vocab.
In case one of the tokens contains the '\n' or '\r' characters or is empty it
will be discarded since we are currently writing the vocabularies as text
files. This behavior will likely be fixed/improved in the future.
Note that this function will cause a vocabulary to be computed. For large
datasets it is highly recommended to either set frequency_threshold or top_k
to control the size of the vocabulary, and also the run time of this
operation.
Args:
x: A `Tensor` or `SparseTensor` of type tf.string or tf.int[8|16|32|64].
default_value: The value to use for out-of-vocabulary values, unless
'num_oov_buckets' is greater than zero.
top_k: Limit the generated vocabulary to the first `top_k` elements. If set
to None, the full vocabulary is generated.
frequency_threshold: Limit the generated vocabulary only to elements whose
absolute frequency is >= to the supplied threshold. If set to None, the
full vocabulary is generated. Absolute frequency means the number of
occurrences of the element in the dataset, as opposed to the proportion of
instances that contain that element. If labels are provided and the vocab
is computed using mutual information, tokens are filtered if their mutual
information with the label is < the supplied threshold.
num_oov_buckets: Any lookup of an out-of-vocabulary token will return a
bucket ID based on its hash if `num_oov_buckets` is greater than zero.
Otherwise it is assigned the `default_value`.
vocab_filename: The file name for the vocabulary file. If None, a name based
on the scope name in the context of this graph will be used as the
file name. If not None, should be unique within a given preprocessing
function.
NOTE: to make your pipelines resilient to implementation details, please set
`vocab_filename` if you are using the vocabulary file in a downstream
component.
weights: (Optional) Weights `Tensor` for the vocabulary. It must have the
same shape as x.
labels: (Optional) A `Tensor` of labels for the vocabulary. If provided,
the vocabulary is calculated based on mutual information with the label,
rather than frequency. The labels must have the same batch dimension as x.
If x is sparse, labels should be a 1D tensor reflecting row-wise labels.
If x is dense, labels can either be a 1D tensor of row-wise labels, or
a dense tensor of the identical shape as x (i.e. element-wise labels).
Labels should be a discrete integerized tensor (If the label is numeric,
it should first be bucketized; If the label is a string, an integer
vocabulary should first be applied). Note: `SparseTensor` labels are not
yet supported (b/134931826). WARNING: when labels are provided, the
frequency_threshold argument functions as a mutual information threshold,
which is a float. TODO(b/116308354): Fix confusing naming.
use_adjusted_mutual_info: If true, use adjusted mutual information.
min_diff_from_avg: Mutual information of a feature will be adjusted to zero
whenever the difference between count of the feature with any label and
its expected count is lower than min_diff_from_average.
coverage_top_k: (Optional), (Experimental) The minimum number of elements
per key to be included in the vocabulary.
coverage_frequency_threshold: (Optional), (Experimental) Limit the coverage
arm of the vocabulary only to elements whose absolute frequency is >= this
threshold for a given key.
key_fn: (Optional), (Experimental) A fn that takes in a single entry of `x`
and returns the corresponding key for coverage calculation. If this is
`None`, no coverage arm is added to the vocabulary.