Fix deferrable mode for DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator #39018

Merged
177 changes: 145 additions & 32 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -41,9 +41,13 @@
MessagesV1Beta3AsyncClient,
MetricsV1Beta3AsyncClient,
)
from google.cloud.dataflow_v1beta3.types import GetJobMetricsRequest, JobMessageImportance, JobMetrics
from google.cloud.dataflow_v1beta3.types import (
GetJobMetricsRequest,
JobMessageImportance,
JobMetrics,
)
from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest
from googleapiclient.discovery import build
from googleapiclient.discovery import Resource, build

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
@@ -573,7 +577,7 @@ def __init__(
impersonation_chain=impersonation_chain,
)

def get_conn(self) -> build:
def get_conn(self) -> Resource:
"""Return a Google Cloud Dataflow service object."""
http_authorized = self._authorize()
return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
@@ -653,9 +657,9 @@ def start_template_dataflow(
on_new_job_callback: Callable[[dict], None] | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
environment: dict | None = None,
) -> dict:
) -> dict[str, str]:
"""
Start Dataflow template job.
Launch a Dataflow job with a Classic Template and wait for its completion.

:param job_name: The name of the job.
:param variables: Map of job runtime environment options.
@@ -688,34 +692,22 @@ def start_template_dataflow(
environment=environment,
)

service = self.get_conn()

request = (
service.projects()
.locations()
.templates()
.launch(
projectId=project_id,
location=location,
gcsPath=dataflow_template,
body={
"jobName": name,
"parameters": parameters,
"environment": environment,
},
)
job: dict[str, str] = self.send_launch_template_request(
project_id=project_id,
location=location,
gcs_path=dataflow_template,
job_name=name,
parameters=parameters,
environment=environment,
)
response = request.execute(num_retries=self.num_retries)

job = response["job"]

if on_new_job_id_callback:
warnings.warn(
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
AirflowProviderDeprecationWarning,
stacklevel=3,
)
on_new_job_id_callback(job.get("id"))
on_new_job_id_callback(job["id"])

if on_new_job_callback:
on_new_job_callback(job)
@@ -734,7 +726,62 @@ def start_template_dataflow(
expected_terminal_state=self.expected_terminal_state,
)
jobs_controller.wait_for_done()
return response["job"]
return job
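
For orientation, a minimal usage sketch of the blocking path above (hypothetical hook instance, placeholder project, bucket, and connection id; not part of this diff): the call submits the Classic Template job and returns only after the jobs controller's wait_for_done() observes a terminal state.

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

# Sketch only: connection id, bucket paths and project are placeholders.
hook = DataflowHook(gcp_conn_id="google_cloud_default")
job = hook.start_template_dataflow(
    job_name="example-job",
    variables={},
    parameters={"inputFile": "gs://example-bucket/input"},
    dataflow_template="gs://example-bucket/templates/wordcount",
    project_id="example-project",
    location="europe-west1",
)
print(job["id"])  # job dict from the launch response, after the wait completes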

@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
def launch_job_with_template(
self,
*,
job_name: str,
variables: dict,
parameters: dict,
dataflow_template: str,
project_id: str,
append_job_name: bool = True,
location: str = DEFAULT_DATAFLOW_LOCATION,
environment: dict | None = None,
) -> dict[str, str]:
"""
Launch a Dataflow job with a Classic Template and exit without waiting for its completion.

:param job_name: The name of the job.
:param variables: Map of job runtime environment options.
If passed, these update the ``environment`` argument.

.. seealso::
For more information on possible configurations, look at the API documentation
`RuntimeEnvironment
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__

:param parameters: Parameters for the template
:param dataflow_template: GCS path to the template.
:param project_id: Optional, the Google Cloud project ID in which to start a job.
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param append_job_name: True if a unique suffix has to be appended to the job name.
:param location: Job location.

.. seealso::
For more information on possible configurations, look at the API documentation
`RuntimeEnvironment
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
:return: the Dataflow job response
"""
name = self.build_dataflow_job_name(job_name, append_job_name)
environment = self._update_environment(
variables=variables,
environment=environment,
)
job: dict[str, str] = self.send_launch_template_request(
project_id=project_id,
location=location,
gcs_path=dataflow_template,
job_name=name,
parameters=parameters,
environment=environment,
)
return job
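
A hypothetical sketch (not part of this diff) of how a deferrable operator can use this non-blocking method: launch the job, extract its id, then hand polling off to a trigger.

# Sketch only: values are placeholders, and the trigger hand-off is indicative
# of how a deferrable operator would consume this method.
job = hook.launch_job_with_template(  # returns as soon as the job is launched
    job_name="example-job",
    variables={},
    parameters={"inputFile": "gs://example-bucket/input"},
    dataflow_template="gs://example-bucket/templates/wordcount",
    project_id="example-project",
    location="europe-west1",
)
job_id = hook.extract_job_id(job)  # raises AirflowException if "id" is missing
# Inside an operator, execution would now defer to a Dataflow trigger, e.g.:
# self.defer(trigger=..., method_name="execute_complete")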

def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
environment = environment or {}
@@ -770,6 +817,35 @@ def _check_one(key, val):

return environment

def send_launch_template_request(
self,
*,
project_id: str,
location: str,
gcs_path: str,
job_name: str,
parameters: dict,
environment: dict,
) -> dict[str, str]:
service: Resource = self.get_conn()
request = (
service.projects()
.locations()
.templates()
.launch(
projectId=project_id,
location=location,
gcsPath=gcs_path,
body={
"jobName": job_name,
"parameters": parameters,
"environment": environment,
},
)
)
response: dict = request.execute(num_retries=self.num_retries)
return response["job"]

@GoogleBaseHook.fallback_to_default_project_id
def start_flex_template(
self,
@@ -778,9 +854,9 @@ def start_flex_template(
project_id: str,
on_new_job_id_callback: Callable[[str], None] | None = None,
on_new_job_callback: Callable[[dict], None] | None = None,
) -> dict:
) -> dict[str, str]:
"""
Start flex templates with the Dataflow pipeline.
Launch a Dataflow job with a Flex Template and wait for its completion.

:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -791,31 +867,32 @@ def start_flex_template(
:param on_new_job_callback: A callback that is called when a Job is detected.
:return: the Job
"""
service = self.get_conn()
service: Resource = self.get_conn()
request = (
service.projects()
.locations()
.flexTemplates()
.launch(projectId=project_id, body=body, location=location)
)
response = request.execute(num_retries=self.num_retries)
response: dict = request.execute(num_retries=self.num_retries)
job = response["job"]
job_id: str = job["id"]

if on_new_job_id_callback:
warnings.warn(
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
AirflowProviderDeprecationWarning,
stacklevel=3,
)
on_new_job_id_callback(job.get("id"))
on_new_job_id_callback(job_id)

if on_new_job_callback:
on_new_job_callback(job)

jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
job_id=job.get("id"),
job_id=job_id,
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
@@ -826,6 +903,42 @@

return jobs_controller.get_jobs(refresh=True)[0]

@GoogleBaseHook.fallback_to_default_project_id
def launch_job_with_flex_template(
self,
body: dict,
location: str,
project_id: str,
) -> dict[str, str]:
"""
Launch a Dataflow Job with a Flex Template and exit without waiting for the job completion.

:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
:param location: The location of the Dataflow job (for example europe-west1)
:param project_id: The ID of the GCP project that owns the job.
If set to ``None`` or missing, the default project_id from the GCP connection is used.
:return: a Dataflow job response
"""
service: Resource = self.get_conn()
request = (
service.projects()
.locations()
.flexTemplates()
.launch(projectId=project_id, body=body, location=location)
)
response: dict = request.execute(num_retries=self.num_retries)
return response["job"]

@staticmethod
def extract_job_id(job: dict) -> str:
try:
return job["id"]
except KeyError:
raise AirflowException(
"While reading job object after template execution error occurred. Job object has no id."
)

@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id