[Python] Allow users to pass service name for profiler #26220

Merged
merged 15 commits on May 3, 2023
1 change: 1 addition & 0 deletions CHANGES.md
@@ -64,6 +64,7 @@

## New Features / Improvements

* Allow passing service name for google-cloud-profiler (Python) ([#26280](https://github.com/apache/beam/issues/26280)).
* Dead letter queue support added to RunInference in Python ([#24209](https://github.com/apache/beam/issues/24209)).

## Breaking Changes
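For context, a minimal sketch of how a pipeline author could opt in to the profiler with a custom service name from the launch command line. The project, region, bucket, and the name my-profiler-service are placeholders, not taken from this PR; only the dataflow_service_options value is introduced here.

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# Placeholder project/region/bucket values; only the dataflow_service_options
# flag below comes from this PR.
options = PipelineOptions([
    '--runner=DataflowRunner',
    '--project=my-project',
    '--region=us-central1',
    '--temp_location=gs://my-bucket/tmp',
    # Append "=<name>" to choose the Cloud Profiler service name; omit it to
    # fall back to the Dataflow job name.
    '--dataflow_service_options=enable_google_cloud_profiler=my-profiler-service',
])

with beam.Pipeline(options=options) as pipeline:
  _ = pipeline | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * x)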
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/options/pipeline_options.py
@@ -22,6 +22,7 @@
import argparse
import json
import logging
import os
from typing import Any
from typing import Callable
from typing import Dict
@@ -876,6 +877,21 @@ def validate(self, validator):

    return errors

  def get_cloud_profiler_service_name(self):
    _ENABLE_GOOGLE_CLOUD_PROFILER = 'enable_google_cloud_profiler'
    if self.dataflow_service_options:
      if _ENABLE_GOOGLE_CLOUD_PROFILER in self.dataflow_service_options:
        return os.environ["JOB_NAME"]
      for option_name in self.dataflow_service_options:
        if option_name.startswith(_ENABLE_GOOGLE_CLOUD_PROFILER + '='):
          return option_name.split('=', 1)[1]

    experiments = self.view_as(DebugOptions).experiments or []
    if _ENABLE_GOOGLE_CLOUD_PROFILER in experiments:
      return os.environ["JOB_NAME"]

    return None


class AzureOptions(PipelineOptions):
  """Azure Blob Storage options."""
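Illustrative only: how the new GoogleCloudOptions.get_cloud_profiler_service_name() helper resolves the service name for the three spellings this PR supports. JOB_NAME is normally injected by the Dataflow worker environment; it is set by hand here purely to show the fallback.

import os

from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions

os.environ['JOB_NAME'] = 'my-job'  # stand-in for the worker-provided env var

explicit = PipelineOptions(
    ['--dataflow_service_options=enable_google_cloud_profiler=custom_name'])
bare = PipelineOptions(
    ['--dataflow_service_options=enable_google_cloud_profiler'])
experiment = PipelineOptions(['--experiment=enable_google_cloud_profiler'])

# Explicit "=<name>" wins; otherwise the job name from the environment is used.
print(explicit.view_as(GoogleCloudOptions).get_cloud_profiler_service_name())    # custom_name
print(bare.view_as(GoogleCloudOptions).get_cloud_profiler_service_name())        # my-job
print(experiment.view_as(GoogleCloudOptions).get_cloud_profiler_service_name())  # my-job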
1 change: 0 additions & 1 deletion sdks/python/apache_beam/runners/worker/data_sampler.py
@@ -29,7 +29,6 @@
from typing import Iterable
from typing import List
from typing import Optional
-from typing import Tuple
from typing import Union

from apache_beam.coders.coder_impl import CoderImpl
58 changes: 34 additions & 24 deletions sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -173,32 +173,42 @@ def create_harness(environment, dry_run=False):
  return fn_log_handler, sdk_harness, sdk_pipeline_options


def _start_profiler(gcp_profiler_service_name, gcp_profiler_service_version):
  try:
    import googlecloudprofiler
    if gcp_profiler_service_name and gcp_profiler_service_version:
      googlecloudprofiler.start(
          service=gcp_profiler_service_name,
          service_version=gcp_profiler_service_version,
          verbose=1)
      _LOGGER.info('Turning on Google Cloud Profiler.')
    else:
      raise RuntimeError('Unable to find the job id or job name from envvar.')
  except Exception as e:  # pylint: disable=broad-except
    _LOGGER.warning(
        'Unable to start google cloud profiler due to error: %s. For how to '
        'enable Cloud Profiler with Dataflow see '
        'https://cloud.google.com/dataflow/docs/guides/profiling-a-pipeline.'
        'For troubleshooting tips with Cloud Profiler see '
        'https://cloud.google.com/profiler/docs/troubleshooting.' % e)


def _get_gcp_profiler_name_if_enabled(sdk_pipeline_options):
  gcp_profiler_service_name = sdk_pipeline_options.view_as(
      GoogleCloudOptions).get_cloud_profiler_service_name()

  return gcp_profiler_service_name


def main(unused_argv):
  """Main entry point for SDK Fn Harness."""
-  fn_log_handler, sdk_harness, sdk_pipeline_options = create_harness(os.environ)
-  experiments = sdk_pipeline_options.view_as(DebugOptions).experiments or []
-  dataflow_service_options = (
-      sdk_pipeline_options.view_as(GoogleCloudOptions).dataflow_service_options
-      or [])
-  if (_ENABLE_GOOGLE_CLOUD_PROFILER in experiments) or (
-      _ENABLE_GOOGLE_CLOUD_PROFILER in dataflow_service_options):
-    try:
-      import googlecloudprofiler
-      job_id = os.environ["JOB_ID"]
-      job_name = os.environ["JOB_NAME"]
-      if job_id and job_name:
-        googlecloudprofiler.start(
-            service=job_name, service_version=job_id, verbose=1)
-        _LOGGER.info('Turning on Google Cloud Profiler.')
-      else:
-        raise RuntimeError('Unable to find the job id or job name from envvar.')
-    except Exception as e:  # pylint: disable=broad-except
-      _LOGGER.warning(
-          'Unable to start google cloud profiler due to error: %s. For how to '
-          'enable Cloud Profiler with Dataflow see '
-          'https://cloud.google.com/dataflow/docs/guides/profiling-a-pipeline.'
-          'For troubleshooting tips with Cloud Profiler see '
-          'https://cloud.google.com/profiler/docs/troubleshooting.' % e)
+  (fn_log_handler, sdk_harness,
+   sdk_pipeline_options) = create_harness(os.environ)
+
+  gcp_profiler_name = _get_gcp_profiler_name_if_enabled(sdk_pipeline_options)
+  if gcp_profiler_name:
+    _start_profiler(gcp_profiler_name, os.environ["JOB_ID"])

  try:
    _LOGGER.info('Python sdk harness starting.')
    sdk_harness.run()
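Illustrative only: a rough sketch of exercising the refactored worker path outside a real Dataflow worker. googlecloudprofiler is mocked so the snippet runs without the profiler agent installed, and the JOB_ID/JOB_NAME values are placeholders for what the worker environment would provide.

import os
import sys
import unittest.mock as mock

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.worker import sdk_worker_main

os.environ['JOB_ID'] = '2023-05-03_00_00_00-1234567890'  # placeholder
os.environ['JOB_NAME'] = 'wordcount'                     # placeholder

options = PipelineOptions(
    ['--dataflow_service_options=enable_google_cloud_profiler=wordcount-prof'])

# Resolve the service name exactly as main() now does, then start the
# (mocked) profiler with the job id as the service version.
name = sdk_worker_main._get_gcp_profiler_name_if_enabled(options)
with mock.patch.dict(sys.modules, {'googlecloudprofiler': mock.MagicMock()}):
  if name:
    sdk_worker_main._start_profiler(name, os.environ['JOB_ID'])
    sys.modules['googlecloudprofiler'].start.assert_called_once_with(
        service='wordcount-prof',
        service_version=os.environ['JOB_ID'],
        verbose=1)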
29 changes: 29 additions & 0 deletions sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
@@ -21,6 +21,7 @@

import io
import logging
import os
import unittest

from hamcrest import all_of
@@ -205,6 +206,34 @@ def test__set_log_level_overrides_error(self):
      sdk_worker_main._set_log_level_overrides(overrides)
    self.assertIn(expected, cm.output[0])

  def test_gcp_profiler_uses_provided_service_name_when_specified(self):
    options = PipelineOptions(
        ['--dataflow_service_options=enable_google_cloud_profiler=sample'])
    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
        options)
    sdk_worker_main._start_profiler = unittest.mock.MagicMock()
    sdk_worker_main._start_profiler(gcp_profiler_name, "version")
    sdk_worker_main._start_profiler.assert_called_with("sample", "version")

  @unittest.mock.patch.dict(os.environ, {"JOB_NAME": "sample_job"}, clear=True)
  def test_gcp_profiler_uses_job_name_when_service_name_not_specified(self):
    options = PipelineOptions(
        ['--dataflow_service_options=enable_google_cloud_profiler'])
    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
        options)
    sdk_worker_main._start_profiler = unittest.mock.MagicMock()
    sdk_worker_main._start_profiler(gcp_profiler_name, "version")
    sdk_worker_main._start_profiler.assert_called_with("sample_job", "version")

  @unittest.mock.patch.dict(os.environ, {"JOB_NAME": "sample_job"}, clear=True)
  def test_gcp_profiler_uses_job_name_when_enabled_as_experiment(self):
    options = PipelineOptions(['--experiment=enable_google_cloud_profiler'])
    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
        options)
    sdk_worker_main._start_profiler = unittest.mock.MagicMock()
    sdk_worker_main._start_profiler(gcp_profiler_name, "version")
    sdk_worker_main._start_profiler.assert_called_with("sample_job", "version")


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)