Skip to content

Commit

Permalink
Container Image should be an input into computing the task version (#…
Browse files Browse the repository at this point in the history
…2194)

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Feb 17, 2024
1 parent d7f312b commit 05378f6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
14 changes: 13 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
from flyteidl.core import literals_pb2

from flytekit import ImageSpec
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings
Expand Down Expand Up @@ -943,10 +944,21 @@ def register_script(
)

if version is None:

def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[str]:
if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec):
return [entity.container_image.image_name()]
if isinstance(entity, WorkflowBase):
image_names = []
for n in entity.nodes:
image_names.extend(_get_image_names(n.flyte_entity))
return image_names
return []

# The md5 version that we send to S3/GCS has to match the file contents exactly,
# but we don't have to use it when registering with the Flyte backend.
# For that add the hash of the compilation settings to hash of file
version = self._version_from_hash(md5_bytes, serialization_settings)
version = self._version_from_hash(md5_bytes, serialization_settings, *_get_image_names(entity))

if isinstance(entity, PythonTask):
return self.register_task(entity, serialization_settings, version)
Expand Down
36 changes: 35 additions & 1 deletion tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mock import ANY, MagicMock, patch

import flytekit.configuration
from flytekit import CronSchedule, LaunchPlan, WorkflowFailurePolicy, task, workflow
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow
from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
Expand Down Expand Up @@ -387,6 +387,40 @@ def test_launch_backfill(remote):
assert wf.workflow_metadata.on_failure == WorkflowFailurePolicy.FAIL_IMMEDIATELY


@mock.patch("pathlib.Path.read_bytes")
@mock.patch("flytekit.remote.remote.FlyteRemote._version_from_hash")
@mock.patch("flytekit.remote.remote.FlyteRemote.register_workflow")
@mock.patch("flytekit.remote.remote.FlyteRemote.upload_file")
@mock.patch("flytekit.remote.remote.compress_scripts")
def test_get_image_names(
compress_scripts_mock, upload_file_mock, register_workflow_mock, version_from_hash_mock, read_bytes_mock
):
md5_bytes = bytes([1, 2, 3])
read_bytes_mock.return_value = bytes([4, 5, 6])
compress_scripts_mock.return_value = "compressed"
upload_file_mock.return_value = md5_bytes, "localhost:30084"

image_spec = ImageSpec(requirements="requirements.txt", registry="flyteorg")

@task(container_image=image_spec)
def say_hello(name: str) -> str:
return f"hello {name}!"

@workflow
def sub_wf(name: str = "union"):
say_hello(name=name)

@workflow
def wf(name: str = "union"):
sub_wf(name=name)

flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote.register_script(wf)

version_from_hash_mock.assert_called_once_with(md5_bytes, mock.ANY, image_spec.image_name())
register_workflow_mock.assert_called_once()


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_local_server(mock_client):
ctx = FlyteContextManager.current_context()
Expand Down

0 comments on commit 05378f6

Please sign in to comment.