fix(ingestion/kafka): OAuth callback execution #11900

Merged
13 commits merged on Nov 22, 2024
22 changes: 22 additions & 0 deletions metadata-ingestion/docs/sources/kafka/kafka.md
@@ -102,7 +102,29 @@ source:
connection:
bootstrap: "broker:9092"
schema_registry_url: http://localhost:8081
```

### OAuth Callback
The OAuth callback function can be configured via `config.connection.consumer_config.oauth_cb`.

You need to specify a Python function reference in the format `<python-module>:<function-name>`.

For example, with the configuration `oauth:create_token`, `create_token` is a function defined in `oauth.py`, and `oauth.py` must be accessible on the PYTHONPATH.

```YAML
source:
type: "kafka"
config:
# Set the custom schema registry implementation class
schema_registry_class: "datahub.ingestion.source.confluent_schema_registry.ConfluentSchemaRegistry"
# Coordinates
connection:
bootstrap: "broker:9092"
schema_registry_url: http://localhost:8081
consumer_config:
security.protocol: "SASL_PLAINTEXT"
sasl.mechanism: "OAUTHBEARER"
oauth_cb: "oauth:create_token"
# sink configs
```
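
For reference, here is a minimal sketch of what such a callback module might look like; the module name, token value, and expiry are placeholders, and a real implementation would fetch a token from your identity provider:

```python
# oauth.py -- minimal sketch of an OAuth callback (placeholder values).
# confluent-kafka invokes the configured callback and expects a
# (token, expiry) pair in return; swap in a real token fetch here.
from typing import Any, Tuple


def create_token(*args: Any, **kwargs: Any) -> Tuple[str, float]:
    access_token = "<access-token>"  # placeholder; obtain from your identity provider
    expiry = 3600.0  # placeholder expiry value
    return access_token, expiry
```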

4 changes: 2 additions & 2 deletions metadata-ingestion/setup.py
@@ -741,8 +741,8 @@
"hive = datahub.ingestion.source.sql.hive:HiveSource",
"hive-metastore = datahub.ingestion.source.sql.hive_metastore:HiveMetastoreSource",
"json-schema = datahub.ingestion.source.schema.json_schema:JsonSchemaSource",
"kafka = datahub.ingestion.source.kafka:KafkaSource",
"kafka-connect = datahub.ingestion.source.kafka_connect:KafkaConnectSource",
"kafka = datahub.ingestion.source.kafka.kafka:KafkaSource",
"kafka-connect = datahub.ingestion.source.kafka.kafka_connect:KafkaConnectSource",
"ldap = datahub.ingestion.source.ldap:LDAPSource",
"looker = datahub.ingestion.source.looker.looker_source:LookerDashboardSource",
"lookml = datahub.ingestion.source.looker.lookml_source:LookMLSource",
13 changes: 12 additions & 1 deletion metadata-ingestion/src/datahub/configuration/kafka.py
@@ -1,6 +1,7 @@
from pydantic import Field, validator

from datahub.configuration.common import ConfigModel
from datahub.configuration.common import ConfigModel, ConfigurationError
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig
from datahub.configuration.validate_host_port import validate_host_port


@@ -36,6 +37,16 @@ class KafkaConsumerConnectionConfig(_KafkaConnectionConfig):
description="Extra consumer config serialized as JSON. These options will be passed into Kafka's DeserializingConsumer. See https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#deserializingconsumer and https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md .",
)

@validator("consumer_config")
@classmethod
def resolve_callback(cls, value: dict) -> dict:
if CallableConsumerConfig.is_callable_config(value):
try:
value = CallableConsumerConfig(value).callable_config()
except Exception as e:
raise ConfigurationError(e)
return value


class KafkaProducerConnectionConfig(_KafkaConnectionConfig):
"""Configuration class for holding connectivity information for Kafka producers"""
@@ -0,0 +1,35 @@
import logging
from typing import Any, Dict, Optional

from datahub.ingestion.api.registry import import_path

logger = logging.getLogger(__name__)


class CallableConsumerConfig:
CALLBACK_ATTRIBUTE: str = "oauth_cb"

def __init__(self, config: Dict[str, Any]):
self._config = config

self._resolve_oauth_callback()

def callable_config(self) -> Dict[str, Any]:
return self._config

@staticmethod
def is_callable_config(config: Dict[str, Any]) -> bool:
return CallableConsumerConfig.CALLBACK_ATTRIBUTE in config

def get_call_back_attribute(self) -> Optional[str]:
return self._config.get(CallableConsumerConfig.CALLBACK_ATTRIBUTE)

def _resolve_oauth_callback(self) -> None:
if not self.get_call_back_attribute():
return

call_back = self.get_call_back_attribute()

assert call_back  # to silence the linter
# Set the callback
self._config[CallableConsumerConfig.CALLBACK_ATTRIBUTE] = import_path(call_back)
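
For context, a minimal usage sketch (not part of this diff) of how this helper rewrites the `oauth_cb` entry in a consumer config, assuming an importable `oauth.py` that defines `create_token`:

```python
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig

consumer_config = {
    "security.protocol": "SASL_PLAINTEXT",
    "sasl.mechanism": "OAUTHBEARER",
    "oauth_cb": "oauth:create_token",  # <python-module>:<function-name> reference
}

if CallableConsumerConfig.is_callable_config(consumer_config):
    # Replaces the string reference with the imported function object.
    consumer_config = CallableConsumerConfig(consumer_config).callable_config()

# consumer_config["oauth_cb"] is now a callable that confluent_kafka.Consumer can use.
```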
@@ -16,8 +16,10 @@
from datahub.ingestion.extractor import protobuf_util, schema_util
from datahub.ingestion.extractor.json_schema_util import JsonSchemaTranslator
from datahub.ingestion.extractor.protobuf_util import ProtobufSchema
from datahub.ingestion.source.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
from datahub.ingestion.source.kafka.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka.kafka_schema_registry_base import (
KafkaSchemaRegistryBase,
)
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
KafkaSchema,
SchemaField,
Empty file.
@@ -18,6 +18,7 @@

from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.kafka import KafkaConsumerConnectionConfig
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig
from datahub.configuration.source_common import (
DatasetSourceConfigMixin,
LowerCaseDatasetUrnConfigMixin,
@@ -49,7 +50,9 @@
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
from datahub.ingestion.source.kafka.kafka_schema_registry_base import (
KafkaSchemaRegistryBase,
)
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
StaleEntityRemovalSourceReport,
@@ -143,14 +146,21 @@ class KafkaSourceConfig(
def get_kafka_consumer(
connection: KafkaConsumerConnectionConfig,
) -> confluent_kafka.Consumer:
return confluent_kafka.Consumer(
consumer = confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": connection.bootstrap,
**connection.consumer_config,
}
)

if CallableConsumerConfig.is_callable_config(connection.consumer_config):
# As per the documentation, we need to explicitly call poll() to make sure the OAuth callback gets executed
# https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration
consumer.poll(timeout=30)

return consumer


@dataclass
class KafkaSourceReport(StaleEntityRemovalSourceReport):
20 changes: 20 additions & 0 deletions metadata-ingestion/tests/integration/kafka/kafka_to_file_oauth.yml
@@ -0,0 +1,20 @@
run_id: kafka-test

source:
type: kafka
config:
connection:
bootstrap: "localhost:29092"
schema_registry_url: "http://localhost:28081"
consumer_config:
security.protocol: "SASL_PLAINTEXT"
sasl.mechanism: "OAUTHBEARER"
oauth_cb: "oauth:create_token"
domain:
"urn:li:domain:sales":
allow:
- "key_value_topic"
sink:
type: file
config:
filename: "./kafka_mces.json"
14 changes: 14 additions & 0 deletions metadata-ingestion/tests/integration/kafka/oauth.py
@@ -0,0 +1,14 @@
import logging
from typing import Any, Tuple

logger = logging.getLogger(__name__)

MESSAGE: str = "OAuth token `create_token` callback"


def create_token(*args: Any, **kwargs: Any) -> Tuple[str, int]:
logger.warning(MESSAGE)
return (
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJjbGllbnRfaWQiOiJrYWZrYV9jbGllbnQiLCJleHAiOjE2OTg3NjYwMDB9.dummy_sig_abcdef123456",
3600,
)
39 changes: 38 additions & 1 deletion metadata-ingestion/tests/integration/kafka/test_kafka.py
@@ -1,10 +1,14 @@
import logging
import subprocess

import pytest
import yaml
from freezegun import freeze_time

from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.source.kafka import KafkaSource
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.kafka.kafka import KafkaSource
from tests.integration.kafka import oauth # type: ignore
from tests.test_helpers import mce_helpers, test_connection_helpers
from tests.test_helpers.click_helpers import run_datahub_cmd
from tests.test_helpers.docker_helpers import wait_for_port
@@ -99,3 +103,36 @@ def test_kafka_test_connection(mock_kafka_service, config_dict, is_success):
SourceCapability.SCHEMA_METADATA: "Failed to establish a new connection"
},
)


@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_kafka_oauth_callback(
mock_kafka_service, test_resources_dir, pytestconfig, tmp_path, mock_time
):
# Run the metadata ingestion pipeline.
config_file = (test_resources_dir / "kafka_to_file_oauth.yml").resolve()

log_file = tmp_path / "kafka_oauth_message.log"

file_handler = logging.FileHandler(
str(log_file)
) # Add a file handler to later validate a test-case
logging.getLogger().addHandler(file_handler)

recipe: dict = {}
with open(config_file) as fp:
recipe = yaml.safe_load(fp)

pipeline = Pipeline.create(recipe)

pipeline.run()

is_found: bool = False
with open(log_file, "r") as file:
for line_number, line in enumerate(file, 1):
if oauth.MESSAGE in line:
is_found = True
break

assert is_found
8 changes: 6 additions & 2 deletions metadata-ingestion/tests/unit/api/test_pipeline.py
@@ -33,7 +33,9 @@

class TestPipeline:
@patch("confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.KafkaSource.get_workunits", autospec=True)
@patch(
"datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True
)
@patch("datahub.ingestion.sink.console.ConsoleSink.close", autospec=True)
@freeze_time(FROZEN_TIME)
def test_configure(self, mock_sink, mock_source, mock_consumer):
@@ -198,7 +200,9 @@ def test_configure_with_rest_sink_with_additional_props_initializes_graph(
assert pipeline.ctx.graph.config.token == pipeline.config.sink.config["token"]

@freeze_time(FROZEN_TIME)
@patch("datahub.ingestion.source.kafka.KafkaSource.get_workunits", autospec=True)
@patch(
"datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True
)
def test_configure_with_file_sink_does_not_init_graph(self, mock_source, tmp_path):
pipeline = Pipeline.create(
{
@@ -8,7 +8,7 @@
)

from datahub.ingestion.source.confluent_schema_registry import ConfluentSchemaRegistry
from datahub.ingestion.source.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka.kafka import KafkaSourceConfig, KafkaSourceReport


class ConfluentSchemaRegistryTest(unittest.TestCase):
32 changes: 17 additions & 15 deletions metadata-ingestion/tests/unit/test_kafka_source.py
@@ -23,7 +23,7 @@
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.kafka import KafkaSource, KafkaSourceConfig
from datahub.ingestion.source.kafka.kafka import KafkaSource, KafkaSourceConfig
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.schema_classes import (
BrowsePathsClass,
@@ -38,11 +38,13 @@

@pytest.fixture
def mock_admin_client():
with patch("datahub.ingestion.source.kafka.AdminClient", autospec=True) as mock:
with patch(
"datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True
) as mock:
yield mock


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_configuration(mock_kafka):
ctx = PipelineContext(run_id="test")
kafka_source = KafkaSource(
@@ -53,7 +55,7 @@ def test_kafka_source_configuration(mock_kafka):
assert mock_kafka.call_count == 1


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_wildcard_topic(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
mock_cluster_metadata = MagicMock()
@@ -74,7 +76,7 @@ def test_kafka_source_workunits_wildcard_topic(mock_kafka, mock_admin_client):
assert len(workunits) == 4


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_topic_pattern(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
mock_cluster_metadata = MagicMock()
@@ -108,7 +110,7 @@ def test_kafka_source_workunits_topic_pattern(mock_kafka, mock_admin_client):
assert len(workunits) == 4


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_with_platform_instance(mock_kafka, mock_admin_client):
PLATFORM_INSTANCE = "kafka_cluster"
PLATFORM = "kafka"
@@ -160,7 +162,7 @@ def test_kafka_source_workunits_with_platform_instance(mock_kafka, mock_admin_client):
assert f"/prod/{PLATFORM}/{PLATFORM_INSTANCE}" in browse_path_aspects[0].paths


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_no_platform_instance(mock_kafka, mock_admin_client):
PLATFORM = "kafka"
TOPIC_NAME = "test"
@@ -204,7 +206,7 @@ def test_kafka_source_workunits_no_platform_instance(mock_kafka, mock_admin_client):
assert f"/prod/{PLATFORM}" in browse_path_aspects[0].paths


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_close(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
ctx = PipelineContext(run_id="test")
Expand All @@ -223,7 +225,7 @@ def test_close(mock_kafka, mock_admin_client):
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_schema_registry_subject_name_strategies(
mock_kafka_consumer, mock_schema_registry_client, mock_admin_client
):
@@ -415,7 +417,7 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]:
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_ignore_warnings_on_schema_type(
mock_kafka_consumer,
mock_schema_registry_client,
@@ -483,8 +485,8 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]:
assert kafka_source.report.warnings


@patch("datahub.ingestion.source.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_succeeds_with_admin_client_init_error(
mock_kafka, mock_kafka_admin_client
):
@@ -513,8 +515,8 @@ def test_kafka_source_succeeds_with_admin_client_init_error(
assert len(workunits) == 2


@patch("datahub.ingestion.source.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_succeeds_with_describe_configs_error(
mock_kafka, mock_kafka_admin_client
):
@@ -550,7 +552,7 @@ def test_kafka_source_succeeds_with_describe_configs_error(
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_topic_meta_mappings(
mock_kafka_consumer, mock_schema_registry_client, mock_admin_client
):