Enable Kerberos Authentication in Async Hive Sensors (#357)
rajaths010494 authored May 23, 2022
1 parent fc99f74 commit dc0256c
Showing 4 changed files with 71 additions and 19 deletions.
24 changes: 18 additions & 6 deletions astronomer/providers/apache/hive/hooks/hive.py
@@ -2,6 +2,7 @@
 import asyncio
 from typing import Tuple
 
+from airflow.configuration import conf
 from airflow.hooks.base import BaseHook
 from impala.dbapi import connect
 from impala.hiveserver2 import HiveServer2Connection
@@ -17,17 +18,28 @@ class HiveCliHookAsync(BaseHook):
 
     def __init__(self, metastore_conn_id: str) -> None:
         """Get the connection parameters separated from connection string"""
-        self.metastore_conn_id = self.get_connection(conn_id=metastore_conn_id)
-        self.auth_mechanism = self.metastore_conn_id.extra_dejson.get("authMechanism", "PLAIN")
+        super().__init__()
+        self.conn = self.get_connection(conn_id=metastore_conn_id)
+        self.auth_mechanism = self.conn.extra_dejson.get("authMechanism", "PLAIN")
 
     def get_hive_client(self) -> HiveServer2Connection:
         """Makes a connection to the hive client using impyla library"""
+        if conf.get("core", "security") == "kerberos":
+            auth_mechanism = self.conn.extra_dejson.get("authMechanism", "GSSAPI")
+            kerberos_service_name = self.conn.extra_dejson.get("kerberos_service_name", "hive")
+            return connect(
+                host=self.conn.host,
+                port=self.conn.port,
+                auth_mechanism=auth_mechanism,
+                kerberos_service_name=kerberos_service_name,
+            )
+
         return connect(
-            host=self.metastore_conn_id.host,
-            port=self.metastore_conn_id.port,
+            host=self.conn.host,
+            port=self.conn.port,
             auth_mechanism=self.auth_mechanism,
-            user=self.metastore_conn_id.login,
-            password=self.metastore_conn_id.password,
+            user=self.conn.login,
+            password=self.conn.password,
         )
 
     async def partition_exists(self, table: str, schema: str, partition: str, polling_interval: float) -> str:
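
The new branch in get_hive_client() is taken only when Airflow's [core] security option is set to kerberos; it then reads authMechanism and kerberos_service_name from the metastore connection's extras instead of using a username and password. The sketch below shows one way such a connection could be defined; the connection id, host, and the idea of registering it programmatically are illustrative assumptions, not part of this commit.

# Illustrative sketch: a metastore connection whose extras drive the new
# Kerberos branch of HiveCliHookAsync.get_hive_client(). The host and the
# programmatic registration below are assumptions; the Airflow UI or an
# AIRFLOW_CONN_* environment variable would work just as well.
import json

from airflow import settings
from airflow.models import Connection

metastore_conn = Connection(
    conn_id="metastore_default",
    conn_type="metastore",
    host="emr-master.example.internal",  # placeholder HiveServer2 host
    port=10000,  # impyla needs the HiveServer2 port, not the 9083 metastore port
    extra=json.dumps(
        {
            "authMechanism": "GSSAPI",  # read by get_hive_client()
            "kerberos_service_name": "hive",  # optional; "hive" is the default
        }
    ),
)

session = settings.Session()
session.add(metastore_conn)
session.commit()

With security = kerberos under [core] in airflow.cfg (the value checked by conf.get("core", "security") above) and a valid ticket maintained by the airflow kerberos renewer, the hook connects with GSSAPI; otherwise it falls back to the existing PLAIN path.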
20 changes: 13 additions & 7 deletions astronomer/providers/apache/hive/sensors/hive_partition.py
@@ -1,6 +1,6 @@
 from typing import Any, Dict, Optional
 
-from airflow import AirflowException
+from airflow.exceptions import AirflowException
 from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor
 from airflow.utils.context import Context

@@ -14,21 +14,27 @@ class HivePartitionSensorAsync(HivePartitionSensor):
     Waits for a given partition to show up in Hive table asynchronously.
     .. note::
-       HivePartitionSensorAsync uses implya library instead of pyhive.
-       The sync version of this sensor uses `pyhive <https://github.com/dropbox/PyHive>`_,
-       but `pyhive <https://github.com/dropbox/PyHive>`_ is currently unsupported.
+       HivePartitionSensorAsync uses impyla library instead of PyHive.
+       The sync version of this sensor uses `PyHive <https://github.com/dropbox/PyHive>`.
-       Since we use `implya <https://github.com/cloudera/impyla>`_ library,
+       Since we use `impyla <https://github.com/cloudera/impyla>`_ library,
        please set the connection to use the port ``10000`` instead of ``9083``.
-       This sensor currently supports ``auth_mechansim='PLAIN'`` only.
+       For ``auth_mechanism='GSSAPI'`` the ticket renewal happens through command
+       ``airflow kerberos`` in
+       `worker/trigger <https://airflow.apache.org/docs/apache-airflow/stable/security/kerberos.html>`_.
+       You may also need to allow traffic from Airflow worker/Triggerer to the Hive instance, depending on where
+       they are running. For example, you might consider adding an entry in the ``etc/hosts`` file present in the
+       Airflow worker/Triggerer, which maps the EMR Master node Public IP Address to its Private DNS Name to
+       allow the network traffic.
+       The library version of hive and hadoop in ``Dockerfile`` should match the remote
+       cluster where they are running.
     :param table: the table where the partition is present.
     :param partition: The partition clause to wait for. This is passed as
         notation as in "ds='2015-01-01'"
-    :param schema: database which needs to be connected in hive. By default it is 'default'
+    :param schema: database which needs to be connected in hive. By default, it is 'default'
     :param metastore_conn_id: connection string to connect to hive.
     :param polling_interval: The interval in seconds to wait between checks for partition.
     """
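
For orientation, a minimal DAG exercising the sensor documented above might look like the sketch below; the dag id, table, partition value, and connection id are illustrative assumptions, under the docstring's requirement that the connection points at port 10000.

# Illustrative usage sketch for HivePartitionSensorAsync; names and values
# are placeholders, not part of this commit.
from datetime import datetime

from airflow import DAG
from astronomer.providers.apache.hive.sensors.hive_partition import (
    HivePartitionSensorAsync,
)

with DAG(
    dag_id="example_hive_partition_async",  # placeholder dag id
    start_date=datetime(2022, 5, 1),
    schedule_interval=None,
    catchup=False,
):
    wait_for_partition = HivePartitionSensorAsync(
        task_id="wait_for_partition",
        table="user_events",  # placeholder table
        partition="ds='2022-05-23'",  # partition clause, as in the docstring
        schema="default",
        metastore_conn_id="metastore_default",
        polling_interval=5,
    )

Ticket renewal itself happens outside the DAG: as the note above says, the airflow kerberos renewer runs alongside the worker/triggerer when auth_mechanism='GSSAPI' is used.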
18 changes: 18 additions & 0 deletions astronomer/providers/apache/hive/sensors/named_hive_partition.py
@@ -15,6 +15,24 @@ class NamedHivePartitionSensorAsync(NamedHivePartitionSensor):
     """
     Waits asynchronously for a set of partitions to show up in Hive.
+    .. note::
+       HivePartitionSensorAsync uses impyla library instead of PyHive.
+       The sync version of this sensor uses `PyHive <https://github.com/dropbox/PyHive>`.
+       Since we use `impyla <https://github.com/cloudera/impyla>`_ library,
+       please set the connection to use the port ``10000`` instead of ``9083``.
+       For ``auth_mechanism='GSSAPI'`` the ticket renewal happens through command
+       ``airflow kerberos`` in
+       `worker/trigger <https://airflow.apache.org/docs/apache-airflow/stable/security/kerberos.html>`_.
+       You may also need to allow traffic from Airflow worker/Triggerer to the Hive instance, depending on where
+       they are running. For example, you might consider adding an entry in the ``etc/hosts`` file present in the
+       Airflow worker/Triggerer, which maps the EMR Master node Public IP Address to its Private DNS Name to
+       allow the network traffic.
+       The library version of hive and hadoop in ``Dockerfile`` should match the remote
+       cluster where they are running.
     :param partition_names: List of fully qualified names of the
         partitions to wait for. A fully qualified name is of the
         form ``schema.table/pk1=pv1/pk2=pv2``, for example,
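
The fully qualified partition-name format shown above maps to usage roughly like the following; the task id, partition names, and connection id are placeholder assumptions.

# Illustrative usage sketch for NamedHivePartitionSensorAsync; partition
# names follow the schema.table/pk1=pv1/pk2=pv2 format from the docstring.
from astronomer.providers.apache.hive.sensors.named_hive_partition import (
    NamedHivePartitionSensorAsync,
)

wait_for_named_partitions = NamedHivePartitionSensorAsync(
    task_id="wait_for_named_partitions",  # attach to a DAG as usual
    partition_names=[
        "default.user_events/ds=2022-05-23",  # placeholder partition
        "default.user_events/ds=2022-05-23/region=us",  # multiple partition keys
    ],
    metastore_conn_id="metastore_default",
    poke_interval=5,
)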
28 changes: 22 additions & 6 deletions tests/apache/hive/hooks/test_hive.py
@@ -15,21 +15,37 @@
 
 
 @mock.patch("astronomer.providers.apache.hive.hooks.hive.HiveCliHookAsync.get_connection")
-@mock.patch("astronomer.providers.apache.hive.hooks.hive.HiveCliHookAsync.get_hive_client")
-def test_get_hive_client(mock_connect, mock_get_connection):
+@mock.patch("airflow.configuration.AirflowConfigParser.get")
+@mock.patch("impala.hiveserver2.connect")
+def test_get_hive_client_with_conf(mock_get_connect, mock_get_conf, mock_get_connection):
+    """Checks the connection to hive client"""
+    mock_get_connect.return_value = mock.AsyncMock(HiveServer2Connection)
+    mock_get_conf.return_value = "kerberos"
+    mock_get_connection.return_value = models.Connection(
+        conn_id="metastore_default",
+        conn_type="metastore",
+        port=10000,
+        host="localhost",
+    )
+    hook = HiveCliHookAsync(TEST_METASTORE_CONN_ID)
+    result = hook.get_hive_client()
+    assert isinstance(result, HiveServer2Connection)
+
+
+@mock.patch("astronomer.providers.apache.hive.hooks.hive.HiveCliHookAsync.get_connection")
+@mock.patch("impala.hiveserver2.connect")
+def test_get_hive_client(mock_get_connect, mock_get_connection):
     """Checks the connection to hive client"""
-    mock_connect.return_value = HiveServer2Connection
+    mock_get_connect.return_value = mock.AsyncMock(HiveServer2Connection)
     mock_get_connection.return_value = models.Connection(
         conn_id="metastore_default",
         conn_type="metastore",
         port=10000,
         host="localhost",
+        extra='{"auth": ""}',
+        schema="default",
     )
     hook = HiveCliHookAsync(TEST_METASTORE_CONN_ID)
     result = hook.get_hive_client()
-    assert isinstance(result, type(HiveServer2Connection))
+    assert isinstance(result, HiveServer2Connection)
 
 
 @pytest.mark.asyncio
