diff --git a/component_sdk/python/kfp_component/google/dataflow/_client.py b/component_sdk/python/kfp_component/google/dataflow/_client.py
index 1d5e38c3274..8d5783afb56 100644
--- a/component_sdk/python/kfp_component/google/dataflow/_client.py
+++ b/component_sdk/python/kfp_component/google/dataflow/_client.py
@@ -30,18 +30,18 @@ def launch_template(self, project_id, gcs_path, location,
         ).execute()

     def get_job(self, project_id, job_id, location=None, view=None):
-        return self._df.projects().jobs().get(
+        return self._df.projects().locations().jobs().get(
             projectId = project_id,
             jobId = job_id,
-            location = location,
+            location = self._get_location(location),
             view = view
         ).execute()

     def cancel_job(self, project_id, job_id, location):
-        return self._df.projects().jobs().update(
+        return self._df.projects().locations().jobs().update(
             projectId = project_id,
             jobId = job_id,
-            location = location,
+            location = self._get_location(location),
             body = {
                 'requestedState': 'JOB_STATE_CANCELLED'
             }
@@ -56,3 +56,8 @@ def list_aggregated_jobs(self, project_id, filter=None,
             pageSize = page_size,
             pageToken = page_token,
             location = location).execute()
+
+    def _get_location(self, location):
+        if not location:
+            location = 'us-central1'
+        return location
diff --git a/component_sdk/python/kfp_component/google/dataflow/_common_ops.py b/component_sdk/python/kfp_component/google/dataflow/_common_ops.py
index 9fbae8974bb..a2a11891d95 100644
--- a/component_sdk/python/kfp_component/google/dataflow/_common_ops.py
+++ b/component_sdk/python/kfp_component/google/dataflow/_common_ops.py
@@ -26,41 +26,6 @@
 _JOB_FAILED_STATES = ['JOB_STATE_STOPPED', 'JOB_STATE_FAILED', 'JOB_STATE_CANCELLED']
 _JOB_TERMINATED_STATES = _JOB_SUCCESSFUL_STATES + _JOB_FAILED_STATES

-def generate_job_name(job_name, context_id):
-    """Generates a stable job name in the job context.
-
-    If user provided ``job_name`` has value, the function will use it
-    as a prefix and appends first 8 characters of ``context_id`` to
-    make the name unique across contexts. If the ``job_name`` is empty,
-    it will use ``job-{context_id}`` as the job name.
-    """
-    if job_name:
-        return '{}-{}'.format(
-            gcp_common.normalize_name(job_name),
-            context_id[:8])
-
-    return 'job-{}'.format(context_id)
-
-def get_job_by_name(df_client, project_id, job_name, location=None):
-    """Gets a job by its name.
-
-    The function lists all jobs under a project or a region location.
-    Compares their names with the ``job_name`` and return the job
-    once it finds a match. If none of the jobs matches, it returns
-    ``None``.
- """ - page_token = None - while True: - response = df_client.list_aggregated_jobs(project_id, - page_size=50, page_token=page_token, location=location) - for job in response.get('jobs', []): - name = job.get('name', None) - if job_name == name: - return job - page_token = response.get('nextPageToken', None) - if not page_token: - return None - def wait_for_job_done(df_client, project_id, job_id, location=None, wait_interval=30): while True: job = df_client.get_job(project_id, job_id, location=location) @@ -120,3 +85,37 @@ def stage_file(local_or_gcs_path): download_blob(local_or_gcs_path, local_file_path) return local_file_path +def get_staging_location(staging_dir, context_id): + if not staging_dir: + return None + + staging_location = os.path.join(staging_dir, context_id) + logging.info('staging_location: {}'.format(staging_location)) + return staging_location + +def read_job_id_and_location(storage_client, staging_location): + if staging_location: + job_blob = _get_job_blob(storage_client, staging_location) + if job_blob.exists(): + job_data = job_blob.download_as_string().decode().split(',') + # Returns (job_id, location) + logging.info('Found existing job {}.'.format(job_data)) + return (job_data[0], job_data[1]) + + return (None, None) + +def upload_job_id_and_location(storage_client, staging_location, job_id, location): + if not staging_location: + return + if not location: + location = '' + data = '{},{}'.format(job_id, location) + job_blob = _get_job_blob(storage_client, staging_location) + logging.info('Uploading {} to {}.'.format(data, job_blob)) + job_blob.upload_from_string(data) + +def _get_job_blob(storage_client, staging_location): + bucket_name, staging_blob_name = parse_blob_path(staging_location) + job_blob_name = os.path.join(staging_blob_name, 'kfp/dataflow/launch_python/job.txt') + bucket = storage_client.bucket(bucket_name) + return bucket.blob(job_blob_name) \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/dataflow/_launch_python.py b/component_sdk/python/kfp_component/google/dataflow/_launch_python.py index 456b0c92009..191902251e2 100644 --- a/component_sdk/python/kfp_component/google/dataflow/_launch_python.py +++ b/component_sdk/python/kfp_component/google/dataflow/_launch_python.py @@ -14,34 +14,38 @@ import subprocess import re import logging +import os +from google.cloud import storage from kfp_component.core import KfpExecutionContext from ._client import DataflowClient -from .. import common as gcp_common -from ._common_ops import (generate_job_name, get_job_by_name, - wait_and_dump_job, stage_file) +from ._common_ops import (wait_and_dump_job, stage_file, get_staging_location, + read_job_id_and_location, upload_job_id_and_location) from ._process import Process +from ..storage import parse_blob_path -def launch_python(python_file_path, project_id, requirements_file_path=None, - location=None, job_name_prefix=None, args=[], wait_interval=30): +def launch_python(python_file_path, project_id, staging_dir=None, requirements_file_path=None, + args=[], wait_interval=30): """Launch a self-executing beam python file. Args: python_file_path (str): The gcs or local path to the python file to run. project_id (str): The ID of the parent project. + staging_dir (str): Optional. The GCS directory for keeping staging files. 
+            A random subdirectory will be created under the directory to keep job info
+            for resuming the job in case of failure. It will also be passed as the
+            `staging_location` and `temp_location` command line args of the Beam code.
         requirements_file_path (str): Optional, the gcs or local path to the pip
             requirements file.
-        location (str): The regional endpoint to which to direct the
-            request.
-        job_name_prefix (str): Optional. The prefix of the genrated job
-            name. If not provided, the method will generated a random name.
         args (list): The list of args to pass to the python file.
         wait_interval (int): The wait seconds between polling.
     Returns:
         The completed job.
     """
+    storage_client = storage.Client()
     df_client = DataflowClient()
     job_id = None
+    location = None
     def cancel():
         if job_id:
             df_client.cancel_job(
@@ -50,27 +54,23 @@ def cancel():
                 location
             )
     with KfpExecutionContext(on_cancel=cancel) as ctx:
-        job_name = generate_job_name(
-            job_name_prefix,
-            ctx.context_id())
-        # We will always generate unique name for the job. We expect
-        # job with same name was created in previous tries from the same
-        # pipeline run.
-        job = get_job_by_name(df_client, project_id, job_name,
-            location)
-        if job:
+        staging_location = get_staging_location(staging_dir, ctx.context_id())
+        job_id, location = read_job_id_and_location(storage_client, staging_location)
+        # Continue waiting for the job if its id has already been uploaded to the staging location.
+        if job_id:
+            job = df_client.get_job(project_id, job_id, location)
             return wait_and_dump_job(df_client, project_id, location, job,
                 wait_interval)

         _install_requirements(requirements_file_path)
         python_file_path = stage_file(python_file_path)
-        cmd = _prepare_cmd(project_id, location, job_name, python_file_path,
-            args)
+        cmd = _prepare_cmd(project_id, python_file_path, args, staging_location)
         sub_process = Process(cmd)
         for line in sub_process.read_lines():
-            job_id = _extract_job_id(line)
+            job_id, location = _extract_job_id_and_location(line)
             if job_id:
-                logging.info('Found job id {}'.format(job_id))
+                logging.info('Found job id {} and location {}.'.format(job_id, location))
+                upload_job_id_and_location(storage_client, staging_location, job_id, location)
                 break
         sub_process.wait_and_check()
         if not job_id:
@@ -82,23 +82,24 @@ def cancel():
         return wait_and_dump_job(df_client, project_id, location, job,
             wait_interval)

-def _prepare_cmd(project_id, location, job_name, python_file_path, args):
+def _prepare_cmd(project_id, python_file_path, args, staging_location):
     dataflow_args = [
         '--runner', 'dataflow',
-        '--project', project_id,
-        '--job-name', job_name]
-    if location:
-        dataflow_args += ['--location', location]
+        '--project', project_id]
+    if staging_location:
+        dataflow_args += ['--staging_location', staging_location, '--temp_location', staging_location]
     return (['python2', '-u', python_file_path] +
         dataflow_args + args)

-def _extract_job_id(line):
+def _extract_job_id_and_location(line):
+    """Returns (job_id, location) extracted from a matched log line.
+ """ job_id_pattern = re.compile( - br'.*console.cloud.google.com/dataflow.*/jobs/([a-z|0-9|A-Z|\-|\_]+).*') + br'.*console.cloud.google.com/dataflow.*/locations/([a-z|0-9|A-Z|\-|\_]+)/jobs/([a-z|0-9|A-Z|\-|\_]+).*') matched_job_id = job_id_pattern.search(line or '') if matched_job_id: - return matched_job_id.group(1).decode() - return None + return (matched_job_id.group(2).decode(), matched_job_id.group(1).decode()) + return (None, None) def _install_requirements(requirements_file_path): if not requirements_file_path: diff --git a/component_sdk/python/kfp_component/google/dataflow/_launch_template.py b/component_sdk/python/kfp_component/google/dataflow/_launch_template.py index 4e950c355ee..cb23f3b949e 100644 --- a/component_sdk/python/kfp_component/google/dataflow/_launch_template.py +++ b/component_sdk/python/kfp_component/google/dataflow/_launch_template.py @@ -17,14 +17,14 @@ import re import time +from google.cloud import storage from kfp_component.core import KfpExecutionContext from ._client import DataflowClient -from .. import common as gcp_common -from ._common_ops import (generate_job_name, get_job_by_name, - wait_and_dump_job) +from ._common_ops import (wait_and_dump_job, get_staging_location, + read_job_id_and_location, upload_job_id_and_location) def launch_template(project_id, gcs_path, launch_parameters, - location=None, job_name_prefix=None, validate_only=None, + location=None, validate_only=None, staging_dir=None, wait_interval=30): """Launchs a dataflow job from template. @@ -40,15 +40,17 @@ def launch_template(project_id, gcs_path, launch_parameters, `jobName` will be replaced by generated name. location (str): The regional endpoint to which to direct the request. - job_name_prefix (str): Optional. The prefix of the genrated job - name. If not provided, the method will generated a random name. validate_only (boolean): If true, the request is validated but not actually executed. Defaults to false. + staging_dir (str): Optional. The GCS directory for keeping staging files. + A random subdirectory will be created under the directory to keep job info + for resuming the job in case of failure. wait_interval (int): The wait seconds between polling. Returns: The completed job. """ + storage_client = storage.Client() df_client = DataflowClient() job_id = None def cancel(): @@ -59,19 +61,24 @@ def cancel(): location ) with KfpExecutionContext(on_cancel=cancel) as ctx: - job_name = generate_job_name( - job_name_prefix, - ctx.context_id()) - print(job_name) - job = get_job_by_name(df_client, project_id, job_name, - location) - if not job: - launch_parameters['jobName'] = job_name - response = df_client.launch_template(project_id, gcs_path, - location, validate_only, launch_parameters) - job = response.get('job', None) + staging_location = get_staging_location(staging_dir, ctx.context_id()) + job_id, _ = read_job_id_and_location(storage_client, staging_location) + # Continue waiting for the job if it's has been uploaded to staging location. 
+        if job_id:
+            job = df_client.get_job(project_id, job_id, location)
+            return wait_and_dump_job(df_client, project_id, location, job,
+                wait_interval)
+
+        if not launch_parameters:
+            launch_parameters = {}
+        launch_parameters['jobName'] = 'job-' + ctx.context_id()
+        response = df_client.launch_template(project_id, gcs_path,
+            location, validate_only, launch_parameters)
+        job = response.get('job', None)
         if not job:
             # Validate only mode
             return job
+        job_id = job.get('id')
+        upload_job_id_and_location(storage_client, staging_location, job_id, location)
         return wait_and_dump_job(df_client, project_id, location, job,
             wait_interval)
\ No newline at end of file
diff --git a/component_sdk/python/tests/google/dataflow/test__launch_python.py b/component_sdk/python/tests/google/dataflow/test__launch_python.py
index e8f653862f6..3151d4fbcd7 100644
--- a/component_sdk/python/tests/google/dataflow/test__launch_python.py
+++ b/component_sdk/python/tests/google/dataflow/test__launch_python.py
@@ -20,6 +20,7 @@

 MODULE = 'kfp_component.google.dataflow._launch_python'

+@mock.patch(MODULE + '.storage')
 @mock.patch('kfp_component.google.dataflow._common_ops.display')
 @mock.patch(MODULE + '.stage_file')
 @mock.patch(MODULE + '.KfpExecutionContext')
@@ -29,11 +30,9 @@
 class LaunchPythonTest(unittest.TestCase):

     def test_launch_python_succeed(self, mock_subprocess, mock_process,
-        mock_client, mock_context, mock_stage_file, mock_display):
+        mock_client, mock_context, mock_stage_file, mock_display, mock_storage):
         mock_context().__enter__().context_id.return_value = 'ctx-1'
-        mock_client().list_aggregated_jobs.return_value = {
-            'jobs': []
-        }
+        mock_storage.Client().bucket().blob().exists.return_value = False
         mock_process().read_lines.return_value = [
             b'https://console.cloud.google.com/dataflow/locations/us-central1/jobs/job-1?project=project-1'
         ]
@@ -43,36 +42,32 @@ def test_launch_python_succeed(self, mock_subprocess, mock_process,
         }
         mock_client().get_job.return_value = expected_job

-        result = launch_python('/tmp/test.py', 'project-1')
+        result = launch_python('/tmp/test.py', 'project-1', staging_dir='gs://staging/dir')

         self.assertEqual(expected_job, result)
+        mock_storage.Client().bucket().blob().upload_from_string.assert_called_with(
+            'job-1,us-central1'
+        )

     def test_launch_python_retry_succeed(self, mock_subprocess, mock_process,
-        mock_client, mock_context, mock_stage_file, mock_display):
+        mock_client, mock_context, mock_stage_file, mock_display, mock_storage):
         mock_context().__enter__().context_id.return_value = 'ctx-1'
-        mock_client().list_aggregated_jobs.return_value = {
-            'jobs': [{
-                'id': 'job-1',
-                'name': 'test_job-ctx-1'
-            }]
-        }
+        mock_storage.Client().bucket().blob().exists.return_value = True
+        mock_storage.Client().bucket().blob().download_as_string.return_value = b'job-1,us-central1'
         expected_job = {
             'id': 'job-1',
             'currentState': 'JOB_STATE_DONE'
         }
         mock_client().get_job.return_value = expected_job

-        result = launch_python('/tmp/test.py', 'project-1', job_name_prefix='test-job')
+        result = launch_python('/tmp/test.py', 'project-1', staging_dir='gs://staging/dir')

         self.assertEqual(expected_job, result)
         mock_process.assert_not_called()

     def test_launch_python_no_job_created(self, mock_subprocess, mock_process,
-        mock_client, mock_context, mock_stage_file, mock_display):
+        mock_client, mock_context, mock_stage_file, mock_display, mock_storage):
         mock_context().__enter__().context_id.return_value = 'ctx-1'
-        mock_client().list_aggregated_jobs.return_value = {
-            'jobs': []
-        }
         mock_process().read_lines.return_value = [
             b'no job id',
             b'no job id'
diff --git a/component_sdk/python/tests/google/dataflow/test__launch_template.py b/component_sdk/python/tests/google/dataflow/test__launch_template.py
index 138238ab3f9..e78a3700a33 100644
--- a/component_sdk/python/tests/google/dataflow/test__launch_template.py
+++ b/component_sdk/python/tests/google/dataflow/test__launch_template.py
@@ -20,16 +20,16 @@

 MODULE = 'kfp_component.google.dataflow._launch_template'

+@mock.patch(MODULE + '.storage')
 @mock.patch('kfp_component.google.dataflow._common_ops.display')
 @mock.patch(MODULE + '.KfpExecutionContext')
 @mock.patch(MODULE + '.DataflowClient')
 class LaunchTemplateTest(unittest.TestCase):

-    def test_launch_template_succeed(self, mock_client, mock_context, mock_display):
+    def test_launch_template_succeed(self, mock_client, mock_context, mock_display,
+        mock_storage):
         mock_context().__enter__().context_id.return_value = 'context-1'
-        mock_client().list_aggregated_jobs.return_value = {
-            'jobs': []
-        }
+        mock_storage.Client().bucket().blob().exists.return_value = False
         mock_client().launch_template.return_value = {
             'job': { 'id': 'job-1' }
         }
@@ -46,21 +46,19 @@ def test_launch_template_succeed(self, mock_client, mock_context, mock_display):
             "environment": {
                 "zone": "us-central1"
             }
-        })
+        }, staging_dir='gs://staging/dir')

         self.assertEqual(expected_job, result)
         mock_client().launch_template.assert_called_once()
+        mock_storage.Client().bucket().blob().upload_from_string.assert_called_with(
+            'job-1,'
+        )

     def test_launch_template_retry_succeed(self,
-        mock_client, mock_context, mock_display):
+        mock_client, mock_context, mock_display, mock_storage):
         mock_context().__enter__().context_id.return_value = 'ctx-1'
-        # The job with same name already exists.
-        mock_client().list_aggregated_jobs.return_value = {
-            'jobs': [{
-                'id': 'job-1',
-                'name': 'test_job-ctx-1'
-            }]
-        }
+        mock_storage.Client().bucket().blob().exists.return_value = True
+        mock_storage.Client().bucket().blob().download_as_string.return_value = b'job-1,'
         pending_job = {
             'currentState': 'JOB_STATE_PENDING'
         }
@@ -77,16 +75,15 @@ def test_launch_template_retry_succeed(self,
             "environment": {
                 "zone": "us-central1"
             }
-        }, job_name_prefix='test-job', wait_interval=0)
+        }, staging_dir='gs://staging/dir', wait_interval=0)

         self.assertEqual(expected_job, result)
         mock_client().launch_template.assert_not_called()

-    def test_launch_template_fail(self, mock_client, mock_context, mock_display):
+    def test_launch_template_fail(self, mock_client, mock_context, mock_display,
+        mock_storage):
         mock_context().__enter__().context_id.return_value = 'context-1'
-        mock_client().list_aggregated_jobs.return_value = {
-            'jobs': []
-        }
+        mock_storage.Client().bucket().blob().exists.return_value = False
         mock_client().launch_template.return_value = {
             'job': { 'id': 'job-1' }
         }