Skip to content

Commit

Permalink
Replaced deprecated assertDictContainsSubset with assertLessEqual(ite…
Browse files Browse the repository at this point in the history
…msA, itemsB)

PiperOrigin-RevId: 425717602
  • Loading branch information
zhitaoli authored and tfx-copybara committed Feb 1, 2022
1 parent bdab20c commit 010e94d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 49 deletions.
5 changes: 4 additions & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

## Bug Fixes and Other Changes

* Fixed the cluster spec error in CAIP Tuner on Vertex when `num_parallel_trials = 1`
* Fixed the cluster spec error in CAIP Tuner on Vertex when
`num_parallel_trials = 1`
* Replaced deprecated assertDictContainsSubset with
assertLessEqual(itemsA, itemsB).

## Dependency Updates

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ def testStart(self):
# Check calls.
self._docker_client.containers.run.assert_called()
_, run_kwargs = self._docker_client.containers.run.call_args
self.assertDictContainsSubset(dict(
image='tensorflow/serving:1.15.0',
environment={
'MODEL_NAME': 'chicago-taxi',
'MODEL_BASE_PATH': '/model'
},
publish_all_ports=True,
auto_remove=True,
detach=True
), run_kwargs)
self.assertLessEqual(
dict(
image='tensorflow/serving:1.15.0',
environment={
'MODEL_NAME': 'chicago-taxi',
'MODEL_BASE_PATH': '/model'
},
publish_all_ports=True,
auto_remove=True,
detach=True).items(), run_kwargs.items())

def testStartMultipleTimesFail(self):
# Prepare mocks and variables.
Expand Down
74 changes: 36 additions & 38 deletions tfx/extensions/google_cloud_ai_platform/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,20 @@ def testStartCloudTraining(self, mock_discovery):

default_image = 'gcr.io/tfx-oss-public/tfx:{}'.format(
version_utils.get_image_version())
self.assertDictContainsSubset(
{
'masterConfig': {
'imageUri':
default_image,
'containerCommand':
runner._CONTAINER_COMMAND + [
'--executor_class_path', class_path, '--inputs', '{}',
'--outputs', '{}', '--exec-properties',
('{"custom_config": '
'"{\\"ai_platform_training_args\\": {\\"project\\": \\"12345\\"'
'}}"}')
],
},
}, body['training_input'])
self.assertLessEqual({
'masterConfig': {
'imageUri':
default_image,
'containerCommand':
runner._CONTAINER_COMMAND + [
'--executor_class_path', class_path, '--inputs', '{}',
'--outputs', '{}', '--exec-properties',
('{"custom_config": '
'"{\\"ai_platform_training_args\\": {\\"project\\": \\"12345\\"'
'}}"}')
],
},
}.items(), body['training_input'].items())
self.assertNotIn('project', body['training_input'])
self.assertStartsWith(body['job_id'], 'tfx_')
self._mock_get.execute.assert_called_with()
Expand Down Expand Up @@ -239,28 +238,27 @@ def testStartCloudTrainingWithUserContainer_Vertex(self, mock_gapic):
custom_job=mock.ANY)
kwargs = self._mock_create.call_args[1]
body = kwargs['custom_job']
self.assertDictContainsSubset(
{
'worker_pool_specs': [{
'container_spec': {
'image_uri':
'my-custom-image',
'command':
runner._CONTAINER_COMMAND + [
'--executor_class_path', class_path, '--inputs',
'{}', '--outputs', '{}', '--exec-properties',
('{"custom_config": '
'"{\\"ai_platform_training_args\\": '
'{\\"project\\": \\"12345\\", '
'\\"worker_pool_specs\\": '
'[{\\"container_spec\\": '
'{\\"image_uri\\": \\"my-custom-image\\"}}]}, '
'\\"ai_platform_training_job_id\\": '
'\\"my_jobid\\"}"}')
],
},
},],
}, body['job_spec'])
self.assertLessEqual({
'worker_pool_specs': [{
'container_spec': {
'image_uri':
'my-custom-image',
'command':
runner._CONTAINER_COMMAND + [
'--executor_class_path', class_path, '--inputs', '{}',
'--outputs', '{}', '--exec-properties',
('{"custom_config": '
'"{\\"ai_platform_training_args\\": '
'{\\"project\\": \\"12345\\", '
'\\"worker_pool_specs\\": '
'[{\\"container_spec\\": '
'{\\"image_uri\\": \\"my-custom-image\\"}}]}, '
'\\"ai_platform_training_job_id\\": '
'\\"my_jobid\\"}"}')
],
},
},],
}.items(), body['job_spec'].items())
self.assertEqual(body['display_name'], 'my_jobid')
self._mock_get.assert_called_with(name='vertex_job_study_id')

Expand Down Expand Up @@ -329,7 +327,7 @@ def testStartCloudTrainingWithVertexCustomJob(self, mock_gapic):
}, body['job_spec'])
self.assertEqual(body['display_name'], 'valid_name')
self.assertDictEqual(body['encryption_spec'], expected_encryption_spec)
self.assertDictContainsSubset(user_provided_labels, body['labels'])
self.assertLessEqual(user_provided_labels.items(), body['labels'].items())
self._mock_get.assert_called_with(name='vertex_job_study_id')

def _setUpPredictionMocks(self):
Expand Down

0 comments on commit 010e94d

Please sign in to comment.