Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Container Image should be an input into computing the task version #2194

Merged
merged 3 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
from flyteidl.core import literals_pb2 as literals_pb2

from flytekit import ImageSpec
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings
Expand Down Expand Up @@ -939,10 +940,21 @@ def register_script(
)

if version is None:

def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[str]:
if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec):
return [entity.container_image.image_name()]
if isinstance(entity, WorkflowBase):
image_names = []
for n in entity.nodes:
image_names.extend(_get_image_names(n.flyte_entity))
return image_names
return []

# The md5 version that we send to S3/GCS has to match the file contents exactly,
# but we don't have to use it when registering with the Flyte backend.
# For that add the hash of the compilation settings to hash of file
version = self._version_from_hash(md5_bytes, serialization_settings)
version = self._version_from_hash(md5_bytes, serialization_settings, *_get_image_names(entity))

if isinstance(entity, PythonTask):
return self.register_task(entity, serialization_settings, version)
Expand Down
36 changes: 35 additions & 1 deletion tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mock import ANY, MagicMock, patch

import flytekit.configuration
from flytekit import CronSchedule, LaunchPlan, WorkflowFailurePolicy, task, workflow
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow
from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
Expand Down Expand Up @@ -387,6 +387,40 @@ def test_launch_backfill(remote):
assert wf.workflow_metadata.on_failure == WorkflowFailurePolicy.FAIL_IMMEDIATELY


@mock.patch("pathlib.Path.read_bytes")
@mock.patch("flytekit.remote.remote.FlyteRemote._version_from_hash")
@mock.patch("flytekit.remote.remote.FlyteRemote.register_workflow")
@mock.patch("flytekit.remote.remote.FlyteRemote.upload_file")
@mock.patch("flytekit.remote.remote.compress_scripts")
def test_get_image_names(
compress_scripts_mock, upload_file_mock, register_workflow_mock, version_from_hash_mock, read_bytes_mock
):
md5_bytes = bytes([1, 2, 3])
read_bytes_mock.return_value = bytes([4, 5, 6])
compress_scripts_mock.return_value = "compressed"
upload_file_mock.return_value = md5_bytes, "localhost:30084"

image_spec = ImageSpec(requirements="requirements.txt", registry="flyteorg")

@task(container_image=image_spec)
def say_hello(name: str) -> str:
return f"hello {name}!"

@workflow
def sub_wf(name: str = "union"):
say_hello(name=name)

@workflow
def wf(name: str = "union"):
sub_wf(name=name)

flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote.register_script(wf)

version_from_hash_mock.assert_called_once_with(md5_bytes, mock.ANY, image_spec.image_name())
register_workflow_mock.assert_called_once()


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_local_server(mock_client):
ctx = FlyteContextManager.current_context()
Expand Down
Loading