[AIRFLOW-2658] Add GCP specific k8s pod operator (apache#3532)
Executes a task in a Kubernetes pod in the specified Google Kubernetes
Engine cluster. This makes it easier to interact with the GCP Kubernetes
Engine service because it encapsulates acquiring credentials.
Noremac201 authored and jeffkpayne committed Dec 20, 2018
1 parent 739fbff commit d3243d1
Showing 4 changed files with 327 additions and 1 deletion.
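As context for the diffs below, here is a minimal sketch of wiring the new operator into a DAG. Only the operator class, its module path, and its arguments come from this commit; the DAG id, schedule, and pod details are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.gcp_container_operator import GKEPodOperator

# Hypothetical example DAG; the operator arguments mirror the docstring example.
dag = DAG(dag_id='gke_pod_example',
          start_date=datetime(2018, 1, 1),
          schedule_interval=None)

pod_task = GKEPodOperator(task_id='pod_op',
                          project_id='my-project',
                          location='us-central1-a',
                          cluster_name='my-cluster-name',
                          name='task-name',
                          namespace='default',
                          image='perl',
                          dag=dag)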
149 changes: 149 additions & 0 deletions airflow/contrib/operators/gcp_container_operator.py
@@ -17,8 +17,13 @@
# specific language governing permissions and limitations
# under the License.
#
import os
import subprocess
import tempfile

from airflow import AirflowException
from airflow.contrib.hooks.gcp_container_hook import GKEClusterHook
from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults

@@ -170,3 +175,147 @@ def execute(self, context):
hook = GKEClusterHook(self.project_id, self.location)
create_op = hook.create_cluster(cluster=self.body)
return create_op


KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
G_APP_CRED = "GOOGLE_APPLICATION_CREDENTIALS"


class GKEPodOperator(KubernetesPodOperator):
template_fields = ('project_id', 'location',
'cluster_name') + KubernetesPodOperator.template_fields

@apply_defaults
def __init__(self,
project_id,
location,
cluster_name,
gcp_conn_id='google_cloud_default',
*args,
**kwargs):
"""
Executes a task in a Kubernetes pod in the specified Google Kubernetes
Engine cluster.
This Operator assumes that the system has gcloud installed and either
has working default application credentials or has configured a
connection id with a service account.
The **minimum** required to run a pod are the variables ``task_id``,
``project_id``, ``location``, ``cluster_name``, ``name``,
``namespace``, and ``image``.
**Operator Creation**: ::
operator = GKEPodOperator(task_id='pod_op',
project_id='my-project',
location='us-central1-a',
cluster_name='my-cluster-name',
name='task-name',
namespace='default',
image='perl')
.. seealso::
For more detail about application authentication have a look at the reference:
https://cloud.google.com/docs/authentication/production#providing_credentials_to_your_application
:param project_id: The Google Developers Console project id
:type project_id: str
:param location: The name of the Google Kubernetes Engine zone in which the
cluster resides, e.g. 'us-central1-a'
:type location: str
:param cluster_name: The name of the Google Kubernetes Engine cluster the pod
should be spawned in
:type cluster_name: str
:param gcp_conn_id: The Google Cloud connection id to use. This allows
users to specify a service account.
:type gcp_conn_id: str
"""
super(GKEPodOperator, self).__init__(*args, **kwargs)
self.project_id = project_id
self.location = location
self.cluster_name = cluster_name
self.gcp_conn_id = gcp_conn_id

def execute(self, context):
# Specifying a service account file allows the user to use non-default
# authentication for creating a Kubernetes Pod. This is done by setting the
# environment variable `GOOGLE_APPLICATION_CREDENTIALS` that gcloud looks at.
key_file = None

# If gcp_conn_id is not specified, gcloud will use the default
# service account credentials.
if self.gcp_conn_id:
from airflow.hooks.base_hook import BaseHook
# extras is a deserialized json object
extras = BaseHook.get_connection(self.gcp_conn_id).extra_dejson
# key_file only gets set if a temporary JSON file was created from a JSON
# string configured in the web UI; otherwise it is None
key_file = self._set_env_from_extras(extras=extras)

# Write config to a temp file and set the environment variable to point to it.
# This avoids race conditions when multiple tasks read/write a single file.
with tempfile.NamedTemporaryFile() as conf_file:
os.environ[KUBE_CONFIG_ENV_VAR] = conf_file.name
# Attempt to get/update credentials
# We call gcloud directly instead of using the google-cloud-python API
# because the API provides no way to write the Kubernetes config to a file,
# which is required by KubernetesPodOperator.
# The gcloud command looks at the env variable `KUBECONFIG` for where to save
# the kubernetes config file.
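# As an illustration (hypothetical values), the call below amounts to running:
#   KUBECONFIG=<temp file> gcloud container clusters get-credentials \
#       my-cluster-name --zone us-central1-a --project my-project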
subprocess.check_call(
["gcloud", "container", "clusters", "get-credentials",
self.cluster_name,
"--zone", self.location,
"--project", self.project_id])

# Since the key file was created via mkstemp(), closing it deletes it from
# the file system, so only close it once it is no longer needed
if key_file:
key_file.close()

# Tell `KubernetesPodOperator` where the config file is located
self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
super(GKEPodOperator, self).execute(context)

def _set_env_from_extras(self, extras):
"""
Sets the environment variable `GOOGLE_APPLICATION_CREDENTIALS` with either:
- The path to the keyfile from the specified connection id
- A generated file's path if the user specified JSON in the connection id. The
file is assumed to be deleted after the process dies due to how mkstemp()
works.
The environment variable is used inside the gcloud command to determine the
correct service account to use.
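For example, the two supported extras shapes look like this
(hypothetical values)::
    {'extra__google_cloud_platform__key_path': '/path/to/key.json'}
    {'extra__google_cloud_platform__keyfile_dict': '{"private_key": "..."}'}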
"""
key_path = self._get_field(extras, 'key_path', False)
keyfile_json_str = self._get_field(extras, 'keyfile_dict', False)

if not key_path and not keyfile_json_str:
self.log.info('Using gcloud with application default credentials.')
elif key_path:
os.environ[G_APP_CRED] = key_path
else:
# Write service account JSON to a secure file for gcloud to reference
service_key = tempfile.NamedTemporaryFile(delete=False)
service_key.write(keyfile_json_str)
# Flush so gcloud can read the credentials before the file is closed later
service_key.flush()
os.environ[G_APP_CRED] = service_key.name
# Return the file object so the caller can close (and thus remove) it
# once it is no longer needed.
return service_key

def _get_field(self, extras, field, default=None):
"""
Fetches a field from extras, and returns it. This is some Airflow
magic. The google_cloud_platform hook type adds custom UI elements
to the hook page, which allow admins to specify service_account,
key_path, etc. They get formatted as shown below.
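For example, ``key_path`` is looked up under the key
``extra__google_cloud_platform__key_path``.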
"""
long_f = 'extra__google_cloud_platform__{}'.format(field)
if long_f in extras:
return extras[long_f]
else:
self.log.info('Field {} not found in extras.'.format(field))
return default
1 change: 1 addition & 0 deletions docs/code.rst
@@ -170,6 +170,7 @@ Operators
.. autoclass:: airflow.contrib.operators.file_to_wasb.FileToWasbOperator
.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterCreateOperator
.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterDeleteOperator
.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEPodOperator
.. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator
.. autoclass:: airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageListOperator
.. autoclass:: airflow.contrib.operators.gcs_operator.GoogleCloudStorageCreateBucketOperator
6 changes: 6 additions & 0 deletions docs/integration.rst
@@ -931,6 +931,12 @@ GKEClusterDeleteOperator
.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterDeleteOperator
.. _GKEClusterDeleteOperator:

GKEPodOperator
^^^^^^^^^^^^^^

.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEPodOperator
.. _GKEPodOperator:

Google Kubernetes Engine Hook
"""""""""""""""""""""""""""""

172 changes: 171 additions & 1 deletion tests/contrib/operators/test_gcp_container_operator.py
@@ -17,11 +17,13 @@
# specific language governing permissions and limitations
# under the License.

import os
import unittest

from airflow import AirflowException
from airflow.contrib.operators.gcp_container_operator import GKEClusterCreateOperator, \
    GKEClusterDeleteOperator, GKEPodOperator
from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator

try:
from unittest import mock
@@ -39,6 +41,15 @@
PROJECT_BODY = {'name': 'test-name'}
PROJECT_BODY_CREATE = {'name': 'test-name', 'initial_node_count': 1}

TASK_NAME = 'test-task-name'
NAMESPACE = 'default'
IMAGE = 'bash'

GCLOUD_COMMAND = "gcloud container clusters get-credentials {} --zone {} --project {}"
KUBE_ENV_VAR = 'KUBECONFIG'
GAC_ENV_VAR = 'GOOGLE_APPLICATION_CREDENTIALS'
FILE_NAME = '/tmp/mock_name'


class GoogleCloudPlatformContainerOperatorTest(unittest.TestCase):

@@ -123,3 +134,162 @@ def test_delete_execute_error_location(self, mock_hook):

operator.execute(None)
mock_hook.return_value.delete_cluster.assert_not_called()


class GKEPodOperatorTest(unittest.TestCase):
def setUp(self):
self.gke_op = GKEPodOperator(project_id=PROJECT_ID,
location=PROJECT_LOCATION,
cluster_name=CLUSTER_NAME,
task_id=PROJECT_TASK_ID,
name=TASK_NAME,
namespace=NAMESPACE,
image=IMAGE)

def test_template_fields(self):
self.assertTrue(set(KubernetesPodOperator.template_fields).issubset(
GKEPodOperator.template_fields))

@mock.patch(
'airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator.execute')
@mock.patch('tempfile.NamedTemporaryFile')
@mock.patch("subprocess.check_call")
def test_execute_conn_id_none(self, proc_mock, file_mock, exec_mock):
self.gke_op.gcp_conn_id = None

file_mock.return_value.__enter__.return_value.name = FILE_NAME

self.gke_op.execute(None)

# Assert Environment Variable is being set correctly
self.assertIn(KUBE_ENV_VAR, os.environ)
self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME)

# Assert the gcloud command being called correctly
proc_mock.assert_called_with(
GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, PROJECT_ID).split())

self.assertEqual(self.gke_op.config_file, FILE_NAME)

@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
@mock.patch(
'airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator.execute')
@mock.patch('tempfile.NamedTemporaryFile')
@mock.patch("subprocess.check_call")
@mock.patch.dict(os.environ, {})
def test_execute_conn_id_path(self, proc_mock, file_mock, exec_mock, get_con_mock):
# gcp_conn_id is defaulted to `google_cloud_default`

FILE_PATH = '/path/to/file'
KEYFILE_DICT = {"extra__google_cloud_platform__key_path": FILE_PATH}
get_con_mock.return_value.extra_dejson = KEYFILE_DICT
file_mock.return_value.__enter__.return_value.name = FILE_NAME

self.gke_op.execute(None)

# Assert Environment Variable is being set correctly
self.assertIn(KUBE_ENV_VAR, os.environ)
self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME)

self.assertIn(GAC_ENV_VAR, os.environ)
# since we passed in key_path we should get that path back
self.assertEqual(os.environ[GAC_ENV_VAR], FILE_PATH)

# Assert the gcloud command being called correctly
proc_mock.assert_called_with(
GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, PROJECT_ID).split())

self.assertEqual(self.gke_op.config_file, FILE_NAME)

@mock.patch.dict(os.environ, {})
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
@mock.patch(
'airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator.execute')
@mock.patch('tempfile.NamedTemporaryFile')
@mock.patch("subprocess.check_call")
def test_execute_conn_id_dict(self, proc_mock, file_mock, exec_mock, get_con_mock):
# gcp_conn_id is defaulted to `google_cloud_default`
FILE_PATH = '/path/to/file'

# This is used in the _set_env_from_extras method
file_mock.return_value.name = FILE_PATH
# This is used in the execute method
file_mock.return_value.__enter__.return_value.name = FILE_NAME

KEYFILE_DICT = {"extra__google_cloud_platform__keyfile_dict":
'{"private_key": "r4nd0m_k3y"}'}
get_con_mock.return_value.extra_dejson = KEYFILE_DICT

self.gke_op.execute(None)

# Assert Environment Variable is being set correctly
self.assertIn(KUBE_ENV_VAR, os.environ)
self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME)

self.assertIn(GAC_ENV_VAR, os.environ)
# since we passed in keyfile_dict a temp file holding it should be created
self.assertEqual(os.environ[GAC_ENV_VAR], FILE_PATH)

# Assert the gcloud command being called correctly
proc_mock.assert_called_with(
GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, PROJECT_ID).split())

self.assertEqual(self.gke_op.config_file, FILE_NAME)

@mock.patch.dict(os.environ, {})
def test_set_env_from_extras_none(self):
extras = {}
self.gke_op._set_env_from_extras(extras)
# _set_env_from_extras should not edit os.environ if extras does not specify
self.assertNotIn(GAC_ENV_VAR, os.environ)

@mock.patch.dict(os.environ, {})
@mock.patch('tempfile.NamedTemporaryFile')
def test_set_env_from_extras_dict(self, file_mock):
file_mock.return_value.name = FILE_NAME

KEYFILE_DICT_STR = '{ \"test\": \"cluster\" }'
extras = {
'extra__google_cloud_platform__keyfile_dict': KEYFILE_DICT_STR,
}

self.gke_op._set_env_from_extras(extras)
self.assertEqual(os.environ[GAC_ENV_VAR], FILE_NAME)

file_mock.return_value.write.assert_called_once_with(KEYFILE_DICT_STR)

@mock.patch.dict(os.environ, {})
def test_set_env_from_extras_path(self):
TEST_PATH = '/test/path'

extras = {
'extra__google_cloud_platform__key_path': TEST_PATH,
}

self.gke_op._set_env_from_extras(extras)
self.assertEqual(os.environ[GAC_ENV_VAR], TEST_PATH)

def test_get_field(self):
FIELD_NAME = 'test_field'
FIELD_VALUE = 'test_field_value'
extras = {
'extra__google_cloud_platform__{}'.format(FIELD_NAME):
FIELD_VALUE
}

ret_val = self.gke_op._get_field(extras, FIELD_NAME)
self.assertEqual(FIELD_VALUE, ret_val)

@mock.patch('airflow.contrib.operators.gcp_container_operator.GKEPodOperator.log')
def test_get_field_fail(self, log_mock):
log_mock.info = mock.Mock()
LOG_STR = 'Field {} not found in extras.'
FIELD_NAME = 'test_field'
FIELD_VALUE = 'test_field_value'

extras = {}

ret_val = self.gke_op._get_field(extras, FIELD_NAME, default=FIELD_VALUE)
# Assert default is returned upon failure
self.assertEqual(FIELD_VALUE, ret_val)
log_mock.info.assert_called_with(LOG_STR.format(FIELD_NAME))
