Skip to content

Commit

Permalink
Add unit tests for mlengine_operator_utils (#9702)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 3cc5756d04c7d3fd14b71e9ec6f5f906d5a24212
  • Loading branch information
chipmyersjr authored and Cloud Composer Team committed Sep 12, 2024
1 parent 1de8729 commit 24368cd
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 1 deletion.
277 changes: 277 additions & 0 deletions tests/providers/google/cloud/utils/test_mlengine_operator_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import base64
import json
import unittest
from datetime import datetime

import dill
import mock

from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
from airflow.providers.google.cloud.utils.mlengine_operator_utils import create_evaluate_ops

TASK_PREFIX = "test-task-prefix"
TASK_PREFIX_PREDICTION = TASK_PREFIX + "-prediction"
TASK_PREFIX_SUMMARY = TASK_PREFIX + "-summary"
TASK_PREFIX_VALIDATION = TASK_PREFIX + "-validation"
DATA_FORMAT = "TEXT"
INPUT_PATHS = ["gs://path/to/input/file.json",
"gs://path/to/input/file2.json",
"gs://path/to/input/file3.json"]
PREDICTION_PATH = "gs://path/to/output/predictions.json"
BATCH_PREDICTION_JOB_ID = "test-batch-prediction-job-id"
PROJECT_ID = "test-project-id"
REGION = "test-region"
DATAFLOW_OPTIONS = {
"project": "my-gcp-project",
"zone": "us-central1-f",
"stagingLocation": "gs://bucket/tmp/dataflow/staging/"
}
MODEL_URI = "gs://path/to/model/model"
MODEL_NAME = "test-model-name"
VERSION_NAME = "test-version-name"
DAG_DEFAULT_ARGS = {
"project_id": PROJECT_ID,
"region": REGION,
"model_name": MODEL_NAME,
"version_name": VERSION_NAME,
"dataflow_default_options": DATAFLOW_OPTIONS
}
TEST_DAG = DAG(dag_id="test-dag-id", start_date=datetime(2000, 1, 1), default_args=DAG_DEFAULT_ARGS)


def get_metric_fn_and_keys():
import math

def error_and_squared_error(inst):
label = float(inst['input_label'])
classes = float(inst['classes'])
err = abs(classes - label)
squared_err = math.pow(classes - label, 2)
return err, squared_err

return error_and_squared_error, ['err', 'mse']


METRIC_FN, METRIC_KEYS = get_metric_fn_and_keys()
METRIC_FN_ENCODED = base64.b64encode(dill.dumps(METRIC_FN, recurse=True)).decode()
METRIC_KEYS_EXPECTED = ','.join(METRIC_KEYS)


def validate_err_and_count(summary):
if summary['err'] > 0.2:
raise ValueError('Too high err>0.2; summary=%s' % summary)
if summary['mse'] > 0.05:
raise ValueError('Too high mse>0.05; summary=%s' % summary)
if summary['count'] < 1000:
raise ValueError('Too few instances<1000; summary=%s' % summary)
return summary


class TestMlengineOperatorUtils(unittest.TestCase):
@mock.patch.object(PythonOperator, "set_upstream")
@mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream")
def test_create_evaluate_ops(self, mock_dataflow, mock_python):
result = create_evaluate_ops(task_prefix=TASK_PREFIX,
data_format=DATA_FORMAT,
input_paths=INPUT_PATHS,
prediction_path=PREDICTION_PATH,
metric_fn_and_keys=get_metric_fn_and_keys(),
validate_fn=validate_err_and_count,
batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
project_id=PROJECT_ID,
region=REGION,
dataflow_options=DATAFLOW_OPTIONS,
model_uri=MODEL_URI
)

evaluate_prediction, evaluate_summary, evaluate_validation = result

mock_dataflow.assert_called_once_with(evaluate_prediction)
mock_python.assert_called_once_with(evaluate_summary)

self.assertEqual(TASK_PREFIX_PREDICTION, evaluate_prediction.task_id)
self.assertEqual(PROJECT_ID, evaluate_prediction._project_id)
self.assertEqual(BATCH_PREDICTION_JOB_ID, evaluate_prediction._job_id)
self.assertEqual(REGION, evaluate_prediction._region)
self.assertEqual(DATA_FORMAT, evaluate_prediction._data_format)
self.assertEqual(INPUT_PATHS, evaluate_prediction._input_paths)
self.assertEqual(PREDICTION_PATH, evaluate_prediction._output_path)
self.assertEqual(MODEL_URI, evaluate_prediction._uri)

self.assertEqual(TASK_PREFIX_SUMMARY, evaluate_summary.task_id)
self.assertEqual(DATAFLOW_OPTIONS, evaluate_summary.dataflow_default_options)
self.assertEqual(PREDICTION_PATH, evaluate_summary.options["prediction_path"])
self.assertEqual(METRIC_FN_ENCODED, evaluate_summary.options["metric_fn_encoded"])
self.assertEqual(METRIC_KEYS_EXPECTED, evaluate_summary.options["metric_keys"])

self.assertEqual(TASK_PREFIX_VALIDATION, evaluate_validation.task_id)
self.assertEqual(PREDICTION_PATH, evaluate_validation.templates_dict["prediction_path"])

@mock.patch.object(PythonOperator, "set_upstream")
@mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream")
def test_create_evaluate_ops_model_and_version_name(self, mock_dataflow, mock_python):
result = create_evaluate_ops(task_prefix=TASK_PREFIX,
data_format=DATA_FORMAT,
input_paths=INPUT_PATHS,
prediction_path=PREDICTION_PATH,
metric_fn_and_keys=get_metric_fn_and_keys(),
validate_fn=validate_err_and_count,
batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
project_id=PROJECT_ID,
region=REGION,
dataflow_options=DATAFLOW_OPTIONS,
model_name=MODEL_NAME,
version_name=VERSION_NAME
)

evaluate_prediction, evaluate_summary, evaluate_validation = result

mock_dataflow.assert_called_once_with(evaluate_prediction)
mock_python.assert_called_once_with(evaluate_summary)

self.assertEqual(TASK_PREFIX_PREDICTION, evaluate_prediction.task_id)
self.assertEqual(PROJECT_ID, evaluate_prediction._project_id)
self.assertEqual(BATCH_PREDICTION_JOB_ID, evaluate_prediction._job_id)
self.assertEqual(REGION, evaluate_prediction._region)
self.assertEqual(DATA_FORMAT, evaluate_prediction._data_format)
self.assertEqual(INPUT_PATHS, evaluate_prediction._input_paths)
self.assertEqual(PREDICTION_PATH, evaluate_prediction._output_path)
self.assertEqual(MODEL_NAME, evaluate_prediction._model_name)
self.assertEqual(VERSION_NAME, evaluate_prediction._version_name)

self.assertEqual(TASK_PREFIX_SUMMARY, evaluate_summary.task_id)
self.assertEqual(DATAFLOW_OPTIONS, evaluate_summary.dataflow_default_options)
self.assertEqual(PREDICTION_PATH, evaluate_summary.options["prediction_path"])
self.assertEqual(METRIC_FN_ENCODED, evaluate_summary.options["metric_fn_encoded"])
self.assertEqual(METRIC_KEYS_EXPECTED, evaluate_summary.options["metric_keys"])

self.assertEqual(TASK_PREFIX_VALIDATION, evaluate_validation.task_id)
self.assertEqual(PREDICTION_PATH, evaluate_validation.templates_dict["prediction_path"])

@mock.patch.object(PythonOperator, "set_upstream")
@mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream")
def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python):
result = create_evaluate_ops(task_prefix=TASK_PREFIX,
data_format=DATA_FORMAT,
input_paths=INPUT_PATHS,
prediction_path=PREDICTION_PATH,
metric_fn_and_keys=get_metric_fn_and_keys(),
validate_fn=validate_err_and_count,
batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
dag=TEST_DAG
)

evaluate_prediction, evaluate_summary, evaluate_validation = result

mock_dataflow.assert_called_once_with(evaluate_prediction)
mock_python.assert_called_once_with(evaluate_summary)

self.assertEqual(TASK_PREFIX_PREDICTION, evaluate_prediction.task_id)
self.assertEqual(PROJECT_ID, evaluate_prediction._project_id)
self.assertEqual(BATCH_PREDICTION_JOB_ID, evaluate_prediction._job_id)
self.assertEqual(REGION, evaluate_prediction._region)
self.assertEqual(DATA_FORMAT, evaluate_prediction._data_format)
self.assertEqual(INPUT_PATHS, evaluate_prediction._input_paths)
self.assertEqual(PREDICTION_PATH, evaluate_prediction._output_path)
self.assertEqual(MODEL_NAME, evaluate_prediction._model_name)
self.assertEqual(VERSION_NAME, evaluate_prediction._version_name)

self.assertEqual(TASK_PREFIX_SUMMARY, evaluate_summary.task_id)
self.assertEqual(DATAFLOW_OPTIONS, evaluate_summary.dataflow_default_options)
self.assertEqual(PREDICTION_PATH, evaluate_summary.options["prediction_path"])
self.assertEqual(METRIC_FN_ENCODED, evaluate_summary.options["metric_fn_encoded"])
self.assertEqual(METRIC_KEYS_EXPECTED, evaluate_summary.options["metric_keys"])

self.assertEqual(TASK_PREFIX_VALIDATION, evaluate_validation.task_id)
self.assertEqual(PREDICTION_PATH, evaluate_validation.templates_dict["prediction_path"])

@mock.patch.object(GCSHook, "download")
@mock.patch.object(PythonOperator, "set_upstream")
@mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream")
def test_apply_validate_fn(self, mock_dataflow, mock_python, mock_download):
result = create_evaluate_ops(task_prefix=TASK_PREFIX,
data_format=DATA_FORMAT,
input_paths=INPUT_PATHS,
prediction_path=PREDICTION_PATH,
metric_fn_and_keys=get_metric_fn_and_keys(),
validate_fn=validate_err_and_count,
batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
project_id=PROJECT_ID,
region=REGION,
dataflow_options=DATAFLOW_OPTIONS,
model_uri=MODEL_URI
)

_, _, evaluate_validation = result

mock_download.return_value = json.dumps({
"err": 0.3,
"mse": 0.04,
"count": 1100
})
templates_dict = {"prediction_path": PREDICTION_PATH}
with self.assertRaises(ValueError) as context:
evaluate_validation.python_callable(templates_dict=templates_dict)

self.assertEqual("Too high err>0.2; summary={'err': 0.3, 'mse': 0.04, 'count': 1100}",
str(context.exception))
mock_download.assert_called_once_with("path", "to/output/predictions.json/prediction.summary.json")

invalid_prediction_paths = ["://path/to/output/predictions.json", "gs://", ""]

for path in invalid_prediction_paths:
templates_dict = {"prediction_path": path}
with self.assertRaises(ValueError) as context:
evaluate_validation.python_callable(templates_dict=templates_dict)
self.assertEqual("Wrong format prediction_path:", str(context.exception)[:29])

def test_invalid_task_prefix(self):
invalid_task_prefix_values = ["test-task-prefix&", "~test-task-prefix", "test-task(-prefix"]

for invalid_task_prefix_value in invalid_task_prefix_values:
with self.assertRaises(AirflowException):
create_evaluate_ops(task_prefix=invalid_task_prefix_value,
data_format=DATA_FORMAT,
input_paths=INPUT_PATHS,
prediction_path=PREDICTION_PATH,
metric_fn_and_keys=get_metric_fn_and_keys(),
validate_fn=validate_err_and_count)

def test_non_callable_metric_fn(self):
with self.assertRaises(AirflowException):
create_evaluate_ops(task_prefix=TASK_PREFIX,
data_format=DATA_FORMAT,
input_paths=INPUT_PATHS,
prediction_path=PREDICTION_PATH,
metric_fn_and_keys=("error_and_squared_error", ['err', 'mse']),
validate_fn=validate_err_and_count)

def test_non_callable_validate_fn(self):
with self.assertRaises(AirflowException):
create_evaluate_ops(task_prefix=TASK_PREFIX,
data_format=DATA_FORMAT,
input_paths=INPUT_PATHS,
prediction_path=PREDICTION_PATH,
metric_fn_and_keys=get_metric_fn_and_keys(),
validate_fn="validate_err_and_count")
1 change: 0 additions & 1 deletion tests/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
'tests/providers/google/cloud/transfers/test_sql_to_gcs.py',
'tests/providers/google/cloud/utils/test_field_sanitizer.py',
'tests/providers/google/cloud/utils/test_field_validator.py',
'tests/providers/google/cloud/utils/test_mlengine_operator_utils.py',
'tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py',
'tests/providers/jenkins/hooks/test_jenkins.py',
'tests/providers/microsoft/azure/sensors/test_azure_cosmos.py',
Expand Down

0 comments on commit 24368cd

Please sign in to comment.