Skip to content

Commit

Permalink
Add retrieve output docker swarm operator (#41531)
Browse files Browse the repository at this point in the history
* Add retrieve_output functionality to DockerSwarmOperator

* Result XCOM as list only if replicated containers

* Fix unit tests
  • Loading branch information
rgriffier authored Aug 22, 2024
1 parent 9608ebf commit 16f0073
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 21 deletions.
27 changes: 13 additions & 14 deletions airflow/providers/docker/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,24 +472,23 @@ def _attempt_to_retrieve_result(self):
This uses Docker's ``get_archive`` function. If the file is not yet
ready, *None* is returned.
"""

def copy_from_docker(container_id, src):
archived_result, stat = self.cli.get_archive(container_id, src)
if stat["size"] == 0:
# 0 byte file, it can't be anything else than None
return None
# no need to port to a file since we intend to deserialize
with BytesIO(b"".join(archived_result)) as f:
tar = tarfile.open(fileobj=f)
file = tar.extractfile(stat["name"])
lib = getattr(self, "pickling_library", pickle)
return lib.load(file)

try:
return copy_from_docker(self.container["Id"], self.retrieve_output_path)
return self._copy_from_docker(self.container["Id"], self.retrieve_output_path)
except APIError:
return None

def _copy_from_docker(self, container_id, src):
    """
    Fetch ``src`` from a container and return its deserialized content.

    The file is obtained through Docker's ``get_archive`` call, which yields
    a tar stream and a ``stat`` mapping describing the archived path.

    :param container_id: ID of the container to read from.
    :param src: path of the file inside the container.
    :return: the object loaded from the file, or *None* when the file is empty.
    """
    archived_result, stat = self.cli.get_archive(container_id, src)
    if stat["size"] == 0:
        # 0 byte file, it can't be anything else than None
        return None
    # The archive is only needed long enough to deserialize one member, so it
    # is kept in memory; both the buffer and the TarFile are closed on exit.
    with BytesIO(b"".join(archived_result)) as buffer, tarfile.open(fileobj=buffer) as archive:
        member = archive.extractfile(stat["name"])
        # ``pickling_library`` lets the operator swap in another pickle-compatible
        # module; fall back to the stdlib ``pickle`` when it is not set.
        lib = getattr(self, "pickling_library", pickle)
        # SECURITY NOTE: this unpickles data written by the container. Unpickling
        # can execute arbitrary code, so only use this with trusted images.
        return lib.load(member)

def execute(self, context: Context) -> list[str] | str | None:
# Pull the docker image if `force_pull` is set or image does not exist locally
if self.force_pull or not self.cli.images(name=self.image):
Expand Down
38 changes: 38 additions & 0 deletions airflow/providers/docker/operators/docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing import TYPE_CHECKING

from docker import types
from docker.errors import APIError

from airflow.exceptions import AirflowException
from airflow.providers.docker.operators.docker import DockerOperator
Expand Down Expand Up @@ -87,6 +88,10 @@ class DockerSwarmOperator(DockerOperator):
:param enable_logging: Show the application's logs in operator's logs.
Supported only if the Docker engine is using json-file or journald logging drivers.
The `tty` parameter should be set to use this with Python applications.
:param retrieve_output: Should this docker image consistently attempt to pull from an output
file before manually shutting down the image. Useful for cases where users want a pickle-serialized
output that is not posted to logs.
:param retrieve_output_path: path for output file that will be retrieved and passed to xcom
:param configs: List of docker configs to be exposed to the containers of the swarm service.
The configs are ConfigReference objects as per the docker api
[https://docker-py.readthedocs.io/en/stable/services.html#docker.models.services.ServiceCollection.create]_
Expand Down Expand Up @@ -122,6 +127,8 @@ def __init__(
self.args = args
self.enable_logging = enable_logging
self.service = None
self.tasks: list[dict] = []
self.containers: list[dict] = []
self.configs = configs
self.secrets = secrets
self.mode = mode
Expand Down Expand Up @@ -173,6 +180,18 @@ def _run_service(self) -> None:
self.log.info("Service status before exiting: %s", self._service_status())
break

if self.service and self._service_status() == "complete":
self.tasks = self.cli.tasks(filters={"service": self.service["ID"]})
for task in self.tasks:
container_id = task["Status"]["ContainerStatus"]["ContainerID"]
container = self.cli.inspect_container(container_id)
self.containers.append(container)
else:
raise AirflowException(f"Service did not complete: {self.service!r}")

if self.retrieve_output:
return self._attempt_to_retrieve_results()

self.log.info("auto_removeauto_removeauto_removeauto_removeauto_remove : %s", str(self.auto_remove))
if self.service and self._service_status() != "complete":
if self.auto_remove == "success":
Expand Down Expand Up @@ -230,6 +249,25 @@ def stream_new_logs(last_line_logged, since=0):
sleep(2)
last_line_logged, last_timestamp = stream_new_logs(last_line_logged, since=last_timestamp)

def _attempt_to_retrieve_results(self):
    """
    Attempt to pull the result from the expected file of each container.

    Every container in ``self.containers`` is read with ``_copy_from_docker``
    using ``self.retrieve_output_path``.

    :return: the single result when there is exactly one container, otherwise
        the list of results (one entry per container); *None* if Docker
        raises ``APIError`` while reading any of the files.
    """
    try:
        file_contents = [
            self._copy_from_docker(container["Id"], self.retrieve_output_path)
            for container in self.containers
        ]
    except APIError as err:
        # Retrieval stays best-effort (None is returned), but leave a trace so a
        # missing output can be diagnosed instead of vanishing silently.
        self.log.warning("Could not retrieve output from %s: %s", self.retrieve_output_path, err)
        return None
    # Only replicated services produce a list; a lone container yields its bare result.
    if len(file_contents) == 1:
        return file_contents[0]
    return file_contents

@staticmethod
def format_args(args: list[str] | str | None) -> list[str] | None:
"""
Expand Down
24 changes: 17 additions & 7 deletions tests/providers/docker/operators/test_docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _client_tasks_side_effect():
for _ in range(2):
yield [{"Status": {"State": "pending"}}]
while True:
yield [{"Status": {"State": "complete"}}]
yield [{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}]

def _client_service_logs_effect():
service_logs = [
Expand Down Expand Up @@ -123,7 +123,7 @@ def _client_service_logs_effect():
assert cskwargs["labels"] == {"name": "airflow__adhoc_airflow__unittest"}
assert cskwargs["name"].startswith("airflow-")
assert cskwargs["mode"] == types.ServiceMode(mode="replicated", replicas=3)
assert client_mock.tasks.call_count == 6
assert client_mock.tasks.call_count == 8
client_mock.remove_service.assert_called_once_with("some_id")

@mock.patch("airflow.providers.docker.operators.docker_swarm.types")
Expand All @@ -134,7 +134,9 @@ def test_auto_remove(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand All @@ -157,7 +159,9 @@ def test_no_auto_remove(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand Down Expand Up @@ -233,7 +237,9 @@ def test_container_resources(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand Down Expand Up @@ -278,7 +284,9 @@ def test_service_args_str(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand Down Expand Up @@ -316,7 +324,9 @@ def test_service_args_list(self, types_mock, docker_api_client_patcher):
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
client_mock.tasks.return_value = [
{"Status": {"State": "complete", "ContainerStatus": {"ContainerID": "some_id"}}}
]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
Expand Down

0 comments on commit 16f0073

Please sign in to comment.