Skip to content

Commit

Permalink
Refactor runners that use '_transport_request_params()' to set transp…
Browse files Browse the repository at this point in the history
…ort level options via 'options()' (#1705)

With this commit we fix a bug that prevented any runner using the low
level `perform_request()` ES client method, and `_transport_request_params()`
helper method from passing in the `request-timeout` parameter.

Relates #1673 
Relates elastic/rally-tracks#393

Note that there's still follow up work to be done to refactor the remaining runners, specifically around overriding the client's `options()` method to handle the `distribution_version` param to support REST compatibility headers as discussed in #1673, but that will only affect us when 9.x is released.
  • Loading branch information
b-deam authored Apr 19, 2023
1 parent b39af39 commit 79e05b2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 21 deletions.
48 changes: 32 additions & 16 deletions esrally/driver/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,29 @@ def _default_kw_params(self, params):
# filter Nones
return dict(filter(lambda kv: kv[1] is not None, full_result.items()))

def _transport_request_params(self, params):
@staticmethod
def _transport_request_params(params):
"""
Takes all of a runner's params and splits out request parameters, transport
level parameters, and headers into their own respective dicts.
:param params: A hash with all the respective runner's parameters.
:return: A tuple of the specific runner's params, request level parameters, transport level parameters, and headers, respectively.
"""
transport_params = {}
request_params = params.get("request-params", {})
request_timeout = params.get("request-timeout")
if request_timeout is not None:
request_params["request_timeout"] = request_timeout
headers = params.get("headers") or {}
opaque_id = params.get("opaque-id")
if opaque_id is not None:

if request_timeout := params.pop("request-timeout", None):
transport_params["request_timeout"] = request_timeout

if (ignore_status := request_params.pop("ignore", None)) or (ignore_status := params.pop("ignore", None)):
transport_params["ignore_status"] = ignore_status

headers = params.pop("headers", None) or {}
if opaque_id := params.pop("opaque-id", None):
headers.update({"x-opaque-id": opaque_id})
return request_params, headers

return params, request_params, transport_params, headers


class Delegator:
Expand Down Expand Up @@ -847,10 +860,9 @@ def __init__(self):
self._composite_agg_extractor = CompositeAggExtractor()

async def __call__(self, es, params):
request_params, headers = self._transport_request_params(params)
request_timeout = request_params.pop("request_timeout", None)
if request_timeout is not None:
es.options(request_timeout=request_timeout)
params, request_params, transport_params, headers = self._transport_request_params(params)
# we don't set headers at the options level because the Query runner sets them via the client's '_perform_request' method
es.options(**transport_params)
# Mandatory to ensure it is always provided. This is especially important when this runner is used in a
# composite context where there is no actual parameter source and the entire request structure must be provided
# by the composite's parameter source.
Expand Down Expand Up @@ -1933,16 +1945,19 @@ def __repr__(self, *args, **kwargs):

class RawRequest(Runner):
async def __call__(self, es, params):
request_params, headers = self._transport_request_params(params)
if "ignore" in params:
request_params["ignore"] = params["ignore"]
params, request_params, transport_params, headers = self._transport_request_params(params)
es.options(**transport_params)

path = mandatory(params, "path", self)

if not path.startswith("/"):
self.logger.error("RawRequest failed. Path parameter: [%s] must begin with a '/'.", path)
raise exceptions.RallyAssertionError(f"RawRequest [{path}] failed. Path parameter must begin with a '/'.")

if not bool(headers):
# counter-intuitive, but preserves prior behavior
headers = None

# disable eager response parsing - responses might be huge thus skewing results
es.return_raw_response()

Expand Down Expand Up @@ -2731,7 +2746,8 @@ class Downsample(Runner):
"""

async def __call__(self, es, params):
request_params, request_headers = self._transport_request_params(params)
params, request_params, transport_params, request_headers = self._transport_request_params(params)
es.options(**transport_params)

fixed_interval = mandatory(params, "fixed-interval", self)
if fixed_interval is None:
Expand Down
11 changes: 6 additions & 5 deletions tests/driver/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,7 @@ async def test_query_with_timeout_and_headers(self, es):
"shards": {"total": 808, "successful": 808, "skipped": 0, "failed": 0},
}

es.options.assert_called_once_with(request_timeout=3.0)
es.perform_request.assert_awaited_once_with(
method="GET",
path="/_all/_search",
Expand Down Expand Up @@ -3641,10 +3642,8 @@ async def test_issue_delete_index(self, es):
},
}
await r(es, params)

es.perform_request.assert_called_once_with(
method="DELETE", path="/twitter", headers=None, body=None, params={"ignore": [400, 404], "pretty": "true"}
)
es.options.assert_called_once_with(ignore_status=[400, 404])
es.perform_request.assert_called_once_with(method="DELETE", path="/twitter", headers=None, body=None, params={"pretty": "true"})

@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
Expand Down Expand Up @@ -3720,6 +3719,8 @@ async def test_raw_with_timeout_and_opaqueid(self, es):
}
await r(es, params)

es.options.assert_called_once_with(request_timeout=3.0)

es.perform_request.assert_called_once_with(
method="GET",
path="/_msearch",
Expand All @@ -3730,7 +3731,7 @@ async def test_raw_with_timeout_and_opaqueid(self, es):
{"index": "test", "search_type": "dfs_query_then_fetch"},
{"query": {"match_all": {}}},
],
params={"request_timeout": 3.0},
params={},
)


Expand Down

0 comments on commit 79e05b2

Please sign in to comment.