Make sure that only valid elasticsearch keys are passed to handler
The elasticsearch handler got all configuration parameters
from the "elasticsearch_configs" section, but in Airflow versions
before 2.7 that section could also contain old config keys, which
rendered the new provider unusable.

This PR filters the configuration parameters so that only valid
ones are passed to the new handler.

Fixes: apache#34099
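
The core of the change, shown in full in the diff below, is to derive the set of keyword arguments that the Elasticsearch client actually accepts and to drop everything else before constructing it. A minimal sketch of that idea, assuming the elasticsearch package and Airflow's conf are importable; `filtered_es_kwargs` is a simplified, hypothetical stand-in for the new `get_es_kwargs_from_config` helper:

```python
import inspect

import elasticsearch

from airflow.configuration import conf

# Keyword arguments that elasticsearch.Elasticsearch.__init__ actually accepts.
VALID_ES_CONFIG_KEYS = set(inspect.signature(elasticsearch.Elasticsearch.__init__).parameters)
VALID_ES_CONFIG_KEYS.discard("self")


def filtered_es_kwargs() -> dict:
    """Simplified sketch: keep only [elasticsearch_configs] keys the client understands."""
    section = conf.getsection("elasticsearch_configs") or {}
    return {key: value for key, value in section.items() if key in VALID_ES_CONFIG_KEYS}
```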
potiuk committed Sep 6, 2023
1 parent 5f47e60 commit 23d1da9
Showing 2 changed files with 92 additions and 50 deletions.
37 changes: 31 additions & 6 deletions airflow/providers/elasticsearch/log/es_task_handler.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import inspect
import logging
import sys
import warnings
@@ -30,6 +31,7 @@
import elasticsearch
import pendulum
from elasticsearch.exceptions import NotFoundError
from typing_extensions import Literal

from airflow.configuration import conf
from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -56,6 +58,32 @@
USE_PER_RUN_LOG_ID = hasattr(DagRun, "get_log_template")


VALID_ES_CONFIG_KEYS = set(inspect.signature(elasticsearch.Elasticsearch.__init__).parameters.keys())
# Remove `self` from the valid set of kwargs
VALID_ES_CONFIG_KEYS.remove("self")


def get_es_kwargs_from_config() -> dict[str, Any]:
    """Build Elasticsearch client kwargs from [elasticsearch_configs], keeping only valid keys."""
    elastic_search_config = conf.getsection("elasticsearch_configs")
    kwargs_dict = (
        {key: value for key, value in elastic_search_config.items() if key in VALID_ES_CONFIG_KEYS}
        if elastic_search_config
        else {}
    )
    # For elasticsearch>8 the retry_timeout argument has been renamed to retry_on_timeout
    # in Elasticsearch() compared to previous versions.
    # Read more at: https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#module-elasticsearch
    if (
        elastic_search_config
        and "retry_timeout" in elastic_search_config
        and not kwargs_dict.get("retry_on_timeout")
    ):
        retry_timeout = elastic_search_config.get("retry_timeout")
        if retry_timeout is not None:
            kwargs_dict["retry_on_timeout"] = retry_timeout
    return kwargs_dict


class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMixin):
"""
ElasticsearchTaskHandler is a python log handler that reads logs from Elasticsearch.
@@ -95,17 +123,14 @@ def __init__(
host: str = "http://localhost:9200",
frontend: str = "localhost:5601",
index_patterns: str | None = conf.get("elasticsearch", "index_patterns", fallback="_all"),
es_kwargs: dict | None = conf.getsection("elasticsearch_configs"),
es_kwargs: dict | None | Literal["default_es_kwargs"] = "default_es_kwargs",
*,
filename_template: str | None = None,
log_id_template: str | None = None,
):
es_kwargs = es_kwargs or {}
# For elasticsearch>8,arguments like retry_timeout have changed for elasticsearch to retry_on_timeout
# in Elasticsearch() compared to previous versions.
# Read more at: https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#module-elasticsearch
if es_kwargs.get("retry_timeout"):
es_kwargs["retry_on_timeout"] = es_kwargs.pop("retry_timeout")
if es_kwargs == "default_es_kwargs":
es_kwargs = get_es_kwargs_from_config()
host = self.format_url(host)
super().__init__(base_log_folder, filename_template)
self.closed = False
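
For context, a hedged usage sketch (not part of the commit; the constructor arguments mirror the tests below, and the Elasticsearch client is mocked so no real cluster is needed) showing the two ways the handler now receives its client kwargs:

```python
from unittest import mock

from airflow.providers.elasticsearch.log.es_task_handler import ElasticsearchTaskHandler

with mock.patch(
    "airflow.providers.elasticsearch.log.es_task_handler.elasticsearch.Elasticsearch"
):
    # An explicit dict is passed straight through to elasticsearch.Elasticsearch().
    explicit = ElasticsearchTaskHandler(
        base_log_folder="dummy_folder",
        end_of_log_mark="end_of_log_mark",
        write_stdout=False,
        json_format=False,
        json_fields="fields",
        host_field="host",
        offset_field="offset",
        es_kwargs={"retry_on_timeout": True},
    )

    # Leaving es_kwargs at its "default_es_kwargs" sentinel makes the handler call
    # get_es_kwargs_from_config(), so only valid, filtered keys reach the client.
    default = ElasticsearchTaskHandler(
        base_log_folder="dummy_folder",
        end_of_log_mark="end_of_log_mark",
        write_stdout=False,
        json_format=False,
        json_fields="fields",
        host_field="host",
        offset_field="offset",
    )
```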
105 changes: 61 additions & 44 deletions tests/providers/elasticsearch/log/test_es_task_handler.py
@@ -23,6 +23,7 @@
import os
import re
import shutil
from pathlib import Path
from unittest import mock
from urllib.parse import quote

@@ -32,15 +33,24 @@

from airflow.configuration import conf
from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse
from airflow.providers.elasticsearch.log.es_task_handler import ElasticsearchTaskHandler, getattr_nested
from airflow.providers.elasticsearch.log.es_task_handler import (
    VALID_ES_CONFIG_KEYS,
    ElasticsearchTaskHandler,
    get_es_kwargs_from_config,
    getattr_nested,
)
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.timezone import datetime
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs

from .elasticmock import elasticmock
from .elasticmock.utilities import SearchFailedException

AIRFLOW_SOURCES_ROOT_DIR = Path(__file__).parents[4].resolve()
ES_PROVIDER_YAML_FILE = AIRFLOW_SOURCES_ROOT_DIR / "airflow" / "providers" / "elasticsearch" / "provider.yaml"


def get_ti(dag_id, task_id, execution_date, create_task_instance):
    ti = create_task_instance(
@@ -145,49 +155,6 @@ def test_format_url(self, host, expected):
        else:
            assert ElasticsearchTaskHandler.format_url(host) == expected

    def test_elasticsearch_constructor_retry_timeout_handling(self):
        """
        Test if the ElasticsearchTaskHandler constructor properly handles the retry_timeout argument.
        """
        # Mock the Elasticsearch client
        with mock.patch(
            "airflow.providers.elasticsearch.log.es_task_handler.elasticsearch.Elasticsearch"
        ) as mock_es:
            # Test when 'retry_timeout' is present in es_kwargs
            es_kwargs = {"retry_timeout": 10}
            ElasticsearchTaskHandler(
                base_log_folder="dummy_folder",
                end_of_log_mark="end_of_log_mark",
                write_stdout=False,
                json_format=False,
                json_fields="fields",
                host_field="host",
                offset_field="offset",
                es_kwargs=es_kwargs,
            )

            # Check the arguments with which the Elasticsearch client is instantiated
            mock_es.assert_called_once_with("http://localhost:9200", retry_on_timeout=10)

            # Reset the mock for the next test
            mock_es.reset_mock()

            # Test when 'retry_timeout' is not present in es_kwargs
            es_kwargs = {}
            ElasticsearchTaskHandler(
                base_log_folder="dummy_folder",
                end_of_log_mark="end_of_log_mark",
                write_stdout=False,
                json_format=False,
                json_fields="fields",
                host_field="host",
                offset_field="offset",
                es_kwargs=es_kwargs,
            )

            # Check that the Elasticsearch client is instantiated without the 'retry_on_timeout' argument
            mock_es.assert_called_once_with("http://localhost:9200")

    def test_client(self):
        assert isinstance(self.es_task_handler.client, elasticsearch.Elasticsearch)
        assert self.es_task_handler.index_patterns == "_all"
@@ -691,3 +658,53 @@ class A:
assert getattr_nested(a, "aa", "heya") == "heya" # respects non-none default
assert getattr_nested(a, "c", "heya") is None # respects none value
assert getattr_nested(a, "aa", None) is None # respects none default


def test_retrieve_config_keys():
    """
    Tests that get_es_kwargs_from_config() retrieves the correct configuration keys from the config file.
    * old parameters are removed
    * parameters from config are automatically added
    * constructor parameters missing from config are also added
    """
    with conf_vars(
        {
            ("elasticsearch_configs", "use_ssl"): "True",
            ("elasticsearch_configs", "http_compress"): "False",
            ("elasticsearch_configs", "timeout"): "10",
        }
    ):
        args_from_config = get_es_kwargs_from_config().keys()
        # use_ssl is removed from config
        assert "use_ssl" not in args_from_config
        # verify_certs comes from default config value
        assert "verify_certs" in args_from_config
        # timeout comes from config provided value
        assert "timeout" in args_from_config
        # http_compress comes from config value
        assert "http_compress" in args_from_config
        assert "self" not in args_from_config


def test_retrieve_retry_on_timeout():
    """
    Test that retry_timeout is converted to retry_on_timeout.
    """
    with conf_vars(
        {
            ("elasticsearch_configs", "retry_timeout"): "True",
        }
    ):
        args_from_config = get_es_kwargs_from_config().keys()
        # the old retry_timeout key is not passed through unchanged
        assert "retry_timeout" not in args_from_config
        # it is converted to retry_on_timeout instead
        assert "retry_on_timeout" in args_from_config


def test_self_not_valid_arg():
    """
    Test that "self" is not a valid argument.
    """
    assert "self" not in VALID_ES_CONFIG_KEYS
