diff --git a/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo.py b/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo.py index d2a0b652ca69..b4fdda72fe90 100644 --- a/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo.py +++ b/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo.py @@ -110,6 +110,18 @@ def make_input_feature_spec(include_label=True): return result +def fill_in_missing(feature, default_value=-1): + feature = tf.sparse.SparseTensor( + indices=feature.indices, + values=feature.values, + dense_shape=[feature.dense_shape[0], 1]) + feature = tf.sparse.to_dense(feature, default_value=default_value) + # Reshaping from a batch of vectors of size 1 to a batch of + # scalar and adding a bucketized version. + feature = tf.squeeze(feature, axis=1) + return feature + + def make_preprocessing_fn(frequency_threshold): """Creates a preprocessing function for criteo. @@ -132,15 +144,7 @@ def preprocessing_fn(inputs): result = {'clicked': inputs['clicked']} for name in _INTEGER_COLUMN_NAMES: feature = inputs[name] - # TODO(https://github.com/apache/beam/issues/24902): - # Replace this boilerplate with a helper function. - # This is a SparseTensor because it is optional. Here we fill in a - # default value when it is missing. - feature = tft.sparse_tensor_to_dense_with_shape( - feature, [None, 1], default_value=-1) - # Reshaping from a batch of vectors of size 1 to a batch of scalars and - # adding a bucketized version. - feature = tf.squeeze(feature, axis=1) + feature = fill_in_missing(feature) result[name] = feature result[name + '_bucketized'] = tft.bucketize(feature, _NUM_BUCKETS) for name in _CATEGORICAL_COLUMN_NAMES: diff --git a/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo_test.py b/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo_test.py new file mode 100644 index 000000000000..00743c3fa7cb --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/cloudml/criteo_tft/criteo_test.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +import numpy as np +import pytest + +try: + import tensorflow_transform as tft + import tensorflow as tf + from apache_beam.testing.benchmarks.cloudml.criteo_tft.criteo import fill_in_missing +except ImportError: + tft = None + +if not tft: + raise unittest.SkipTest('tensorflow_transform is not installed.') + + +@pytest.mark.uses_tft +@unittest.skipIf(tft is None or tf is None, 'Missing dependencies. ') +class FillInMissingTest(unittest.TestCase): + def test_fill_in_missing(self): + # Create a rank 2 sparse tensor with missing values + indices = np.array([[0, 0], [0, 2], [1, 1], [2, 0]]) + values = np.array([1, 2, 3, 4]) + dense_shape = np.array([3, 3]) + sparse_tensor = tf.sparse.SparseTensor(indices, values, dense_shape) + + # Fill in missing values with -1 + filled_tensor = tf.Tensor() + if fill_in_missing is not None: + filled_tensor = fill_in_missing(sparse_tensor, -1) + + # Convert to a dense tensor and check the values + expected_output = np.array([1, -1, 2, -1, -1, -1, 4, -1, -1]) + actual_output = filled_tensor.numpy() + self.assertEqual(expected_output, actual_output) + + +if __name__ == '__main__': + unittest.main()