diff --git a/CHANGES.md b/CHANGES.md
index e5c097439240..d8f322755fbd 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -60,6 +60,7 @@ container was based upon Debian 11.
 * RunInference PTransform will accept model paths as SideInputs in Python SDK. ([#24042](https://github.com/apache/beam/issues/24042))
 * RunInference supports ONNX runtime in Python SDK ([#22972](https://github.com/apache/beam/issues/22972))
+* TensorFlow Model Handler for RunInference in Python SDK ([#25366](https://github.com/apache/beam/issues/25366))
 
 ## I/Os
 
diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md
index 69cd773593bd..c56f5eb8242b 100644
--- a/sdks/python/apache_beam/examples/inference/README.md
+++ b/sdks/python/apache_beam/examples/inference/README.md
@@ -523,3 +523,56 @@ background
 ...
 ```
 Each line has a list of predicted label.
+
+---
+## MNIST digit classification with TensorFlow using saved model weights
+[`tensorflow_mnist_with_weights.py`](./tensorflow_mnist_with_weights.py) contains an implementation for a RunInference pipeline that performs image classification on handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) database.
+
+The pipeline reads rows of pixels corresponding to a digit, performs basic preprocessing (converts the input shape to 28x28), passes the pixels to the trained TensorFlow model with RunInference, and then writes the predictions to a text file.
+
+The model is loaded from saved model weights. To do this, pass a function that creates the model to the `create_model_fn` argument of the `TFModelHandler` and set the model type to `ModelType.SAVED_WEIGHTS`. The path to weights saved with `model.save_weights(path)` should be passed to the `model_path` argument.
+
+### Dataset and model for MNIST digit classification
+
+To use this transform, you need a dataset and model for MNIST digit classification.
+
+1. Create a file named [`INPUT.csv`](gs://apache-beam-ml/testing/inputs/it_mnist_data.csv) that contains labels and pixels to feed into the model. Each row should have comma-separated elements. The first element is the label. All other elements are pixel values. The CSV file should not have column headers. The content of the file should be similar to the following example:
+```
+1,0,0,0...
+0,0,0,0...
+1,0,0,0...
+4,0,0,0...
+...
+```
+2. Save the weights of the trained TensorFlow model to a directory `SAVED_WEIGHTS_DIR` (see the weight-saving sketch below).
+
+### Running `tensorflow_mnist_with_weights.py`
+
+To run the MNIST classification pipeline locally, use the following command:
+```sh
+python -m apache_beam.examples.inference.tensorflow_mnist_with_weights \
+  --input INPUT \
+  --output OUTPUT \
+  --model_path SAVED_WEIGHTS_DIR
+```
+For example:
+```sh
+python -m apache_beam.examples.inference.tensorflow_mnist_with_weights \
+  --input INPUT.csv \
+  --output predictions.txt \
+  --model_path SAVED_WEIGHTS_DIR
+```
+
+This writes the output to `predictions.txt` with contents like:
+```
+1,1
+4,4
+0,0
+7,7
+3,3
+5,5
+...
+```
+Each line contains two comma-separated values: the actual label of the digit, followed by the predicted label.
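For step 2 of the README above, something like the following sketch can produce `SAVED_WEIGHTS_DIR`. This is not part of the PR: the architecture is a simplified stand-in (any Keras model matching the `create_model_fn` used at load time works), and the single training epoch is purely illustrative.

```python
# Hypothetical weight-saving sketch: train a small Keras CNN on MNIST and
# persist only its weights for later use with ModelType.SAVED_WEIGHTS.
import tensorflow as tf

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# Normalize pixel values and add the channel dimension the model expects.
x_train = (x_train / 255.0).reshape(-1, 28, 28, 1)

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train, epochs=1)  # illustrative; train longer in practice

# Saves weights only, not the architecture. Pass this same path (the checkpoint
# prefix, exactly as given to save_weights) as --model_path.
model.save_weights("SAVED_WEIGHTS_DIR")
```

The key contract here is that `save_weights` stores parameters but not structure, which is why the handler additionally needs `create_model_fn` to rebuild the architecture before calling `load_weights`.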
diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_with_weights.py b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_with_weights.py
new file mode 100644
index 000000000000..ae51f8d9cdea
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_with_weights.py
@@ -0,0 +1,93 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+
+import apache_beam as beam
+import tensorflow as tf
+from apache_beam.examples.inference.tensorflow_mnist_classification import PostProcessor
+from apache_beam.examples.inference.tensorflow_mnist_classification import parse_known_args
+from apache_beam.examples.inference.tensorflow_mnist_classification import process_input
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import ModelType
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+
+def get_model():
+  inputs = tf.keras.layers.Input(shape=(28, 28, 1))
+  x = tf.keras.layers.Conv2D(32, 3, activation="relu")(inputs)
+  x = tf.keras.layers.Conv2D(32, 3, activation="relu")(x)
+  x = tf.keras.layers.MaxPooling2D(2)(x)
+  x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)
+  x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)
+  x = tf.keras.layers.MaxPooling2D(2)(x)
+  x = tf.keras.layers.Flatten()(x)
+  x = tf.keras.layers.Dropout(0.2)(x)
+  outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
+  model = tf.keras.Model(inputs, outputs)
+  return model
+
+
+def run(
+    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+  """
+  Args:
+    argv: Command line arguments defined for this example.
+    save_main_session: Used for internal testing.
+    test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  # In this example we pass keyed inputs to the RunInference transform.
+  # Therefore, we use the KeyedModelHandler wrapper over TFModelHandlerNumpy.
+  model_loader = KeyedModelHandler(
+      TFModelHandlerNumpy(
+          model_uri=known_args.model_path,
+          model_type=ModelType.SAVED_WEIGHTS,
+          create_model_fn=get_model))
+
+  pipeline = test_pipeline
+  if not test_pipeline:
+    pipeline = beam.Pipeline(options=pipeline_options)
+
+  label_pixel_tuple = (
+      pipeline
+      | "ReadFromInput" >> beam.io.ReadFromText(known_args.input)
+      | "PreProcessInputs" >> beam.Map(process_input))
+
+  predictions = (
+      label_pixel_tuple
+      | "RunInference" >> RunInference(model_loader)
+      | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))
+
+  _ = predictions | "WriteOutput" >> beam.io.WriteToText(
+      known_args.output, shard_name_template='', append_trailing_newlines=True)
+
+  result = pipeline.run()
+  result.wait_until_finish()
+  return result
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index ee33c53cadb0..dcebb9347ed9 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -52,6 +52,7 @@ class ModelType(enum.Enum):
   """Defines how a model file should be loaded."""
   SAVED_MODEL = 1
+  SAVED_WEIGHTS = 2
 
 
 def _load_model(model_uri, model_type):
@@ -61,6 +62,12 @@ def _load_model(model_uri, model_type):
     raise AssertionError('Unsupported model type for loading.')
 
 
+def _load_model_from_weights(create_model_fn, weights_path):
+  model = create_model_fn()
+  model.load_weights(weights_path)
+  return model
+
+
 def default_numpy_inference_fn(
     model: tf.Module,
     batch: Sequence[numpy.ndarray],
@@ -88,6 +95,7 @@ def __init__(
       self,
       model_uri: str,
       model_type: ModelType = ModelType.SAVED_MODEL,
+      create_model_fn: Optional[Callable] = None,
       *,
      inference_fn: TensorInferenceFn = default_numpy_inference_fn):
     """Implementation of the ModelHandler interface for Tensorflow.
@@ -101,6 +109,9 @@ def __init__(
     Args:
       model_uri (str): path to the trained model.
       model_type: type of model to be loaded. Defaults to SAVED_MODEL.
+      create_model_fn: a function that creates and returns a new
+        TensorFlow model into which the saved weights are loaded.
+        It should be used with ModelType.SAVED_WEIGHTS.
       inference_fn: inference function to use during RunInference.
         Defaults to default_numpy_inference_fn.
 
@@ -110,9 +121,16 @@ def __init__(
     self._model_uri = model_uri
     self._model_type = model_type
     self._inference_fn = inference_fn
+    self._create_model_fn = create_model_fn
 
   def load_model(self) -> tf.Module:
     """Loads and initializes a Tensorflow model for processing."""
+    if self._model_type == ModelType.SAVED_WEIGHTS:
+      if not self._create_model_fn:
+        raise ValueError(
+            "Callable create_model_fn must be passed "
+            "with ModelType.SAVED_WEIGHTS")
+      return _load_model_from_weights(self._create_model_fn, self._model_uri)
     return _load_model(self._model_uri, self._model_type)
 
   def update_model_path(self, model_path: Optional[str] = None):
@@ -169,6 +187,7 @@ def __init__(
       self,
       model_uri: str,
       model_type: ModelType = ModelType.SAVED_MODEL,
+      create_model_fn: Optional[Callable] = None,
       *,
       inference_fn: TensorInferenceFn = default_tensor_inference_fn):
     """Implementation of the ModelHandler interface for Tensorflow.
@@ -183,6 +202,9 @@ def __init__(
       model_uri (str): path to the trained model.
       model_type: type of model to be loaded.
         Defaults to SAVED_MODEL.
+      create_model_fn: a function that creates and returns a new
+        TensorFlow model into which the saved weights are loaded.
+        It should be used with ModelType.SAVED_WEIGHTS.
       inference_fn: inference function to use during RunInference.
         Defaults to default_numpy_inference_fn.
 
@@ -192,9 +214,16 @@ def __init__(
     self._model_uri = model_uri
     self._model_type = model_type
     self._inference_fn = inference_fn
+    self._create_model_fn = create_model_fn
 
   def load_model(self) -> tf.Module:
     """Loads and initializes a tensorflow model for processing."""
+    if self._model_type == ModelType.SAVED_WEIGHTS:
+      if not self._create_model_fn:
+        raise ValueError(
+            "Callable create_model_fn must be passed "
+            "with ModelType.SAVED_WEIGHTS")
+      return _load_model_from_weights(self._create_model_fn, self._model_uri)
     return _load_model(self._model_uri, self._model_type)
 
   def update_model_path(self, model_path: Optional[str] = None):
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
index 3c92461c15af..fb1a2964841b 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
@@ -31,6 +31,7 @@
   import tensorflow as tf
   from apache_beam.examples.inference import tensorflow_imagenet_segmentation
   from apache_beam.examples.inference import tensorflow_mnist_classification
+  from apache_beam.examples.inference import tensorflow_mnist_with_weights
 except ImportError as e:
   tf = None
@@ -108,6 +109,36 @@ def test_tf_imagenet_image_segmentation(self):
     for true_label, predicted_label in zip(expected_outputs, predicted_outputs):
       self.assertEqual(true_label, predicted_label)
 
+  def test_tf_mnist_with_weights_classification(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    input_file = 'gs://apache-beam-ml/testing/inputs/it_mnist_data.csv'
+    output_file_dir = 'gs://apache-beam-ml/testing/outputs'
+    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
+    model_path = 'gs://apache-beam-ml/models/tensorflow/mnist'
+    extra_opts = {
+        'input': input_file,
+        'output': output_file,
+        'model_path': model_path,
+    }
+    tensorflow_mnist_with_weights.run(
+        test_pipeline.get_full_options_as_args(**extra_opts),
+        save_main_session=False)
+    self.assertEqual(FileSystems().exists(output_file), True)
+
+    expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/test_sklearn_mnist_classification_actuals.txt'  # pylint: disable=line-too-long
+    expected_outputs = process_outputs(expected_output_filepath)
+    predicted_outputs = process_outputs(output_file)
+    self.assertEqual(len(expected_outputs), len(predicted_outputs))
+
+    predictions_dict = {}
+    for predicted_output in predicted_outputs:
+      true_label, prediction = predicted_output.split(',')
+      predictions_dict[true_label] = prediction
+
+    for expected_output in expected_outputs:
+      true_label, expected_prediction = expected_output.split(',')
+      self.assertEqual(predictions_dict[true_label], expected_prediction)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.DEBUG)
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index c21e384ca866..c38a886dcce8 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -343,7 +343,7 @@ commands =
   /bin/sh -c "pip freeze | grep -E onnx"
   # Run all ONNX unit tests
   pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_onnx {posargs}
-
+
 [testenv:py{37,38,39,310}-tensorflow-{29,210,211}]
 deps =
   -r build-requirements.txt
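To make the new `SAVED_WEIGHTS` code path concrete, here is a minimal, hypothetical usage sketch built only from what this diff adds (`ModelType.SAVED_WEIGHTS` and `create_model_fn`). The model architecture and the `SAVED_WEIGHTS_DIR` path are placeholder assumptions and must match whatever was passed to `model.save_weights`:

```python
# Hypothetical sketch: unkeyed RunInference with weights-only loading.
import numpy
import apache_beam as beam
import tensorflow as tf
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import ModelType
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy


def build_model():
  # Must recreate the exact architecture whose weights were saved;
  # load_model() calls this, then runs model.load_weights(model_uri).
  return tf.keras.Sequential([
      tf.keras.layers.Input(shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation="relu"),
      tf.keras.layers.MaxPooling2D(2),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(10, activation="softmax"),
  ])


model_handler = TFModelHandlerNumpy(
    model_uri="SAVED_WEIGHTS_DIR",  # placeholder: path given to save_weights
    model_type=ModelType.SAVED_WEIGHTS,
    create_model_fn=build_model)

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([numpy.zeros((28, 28, 1), dtype=numpy.float32)])
      | RunInference(model_handler)
      | beam.Map(print))
```

Unlike the keyed MNIST example in this PR, no `KeyedModelHandler` wrapper is needed when elements are bare arrays; the wrapper exists so that (key, value) pairs survive the inference step.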