Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ImageConfig to the serialization context for dynamic task #2456

Merged
merged 25 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def _get_image(self, settings: SerializationSettings) -> str:
if isinstance(self._image, ImageSpec):
# Set the source root for the image spec if it's non-fast registration
self._image.source_root = settings.source_root
return get_registerable_container_image(self._image, settings.image_config)
return get_registerable_container_image(self._image, settings.image_config, self.name)

def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = settings.env or {}
Expand Down
15 changes: 11 additions & 4 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def get_image(self, settings: SerializationSettings) -> str:
if isinstance(self.container_image, ImageSpec):
# Set the source root for the image spec if it's non-fast registration
self.container_image.source_root = settings.source_root
return get_registerable_container_image(self.container_image, settings.image_config)
return get_registerable_container_image(self.container_image, settings.image_config, self.name)

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template is not None, return None here but in get_k8s_pod, return pod_template merged with container
Expand Down Expand Up @@ -264,7 +264,9 @@ def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore
default_task_resolver = DefaultTaskResolver()


def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg: ImageConfig) -> str:
def get_registerable_container_image(
img: Optional[Union[str, ImageSpec]], cfg: ImageConfig, task_name: Optional[str] = None
) -> str:
"""
Resolve the image to the real image name that should be used for registration.
1. If img is an ImageSpec, the image registered in cfg as ``ft_<task_name>`` is used when present; otherwise the ImageSpec is built and its image name is returned
Expand All @@ -273,11 +275,16 @@ def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg:

:param img: Configured image or image spec
:param cfg: Registration configuration
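:param task_name: The name of the container task. When given, it is used to look up a pre-registered image named ``ft_<task_name>`` in cfg before building an ImageSpec.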
:return:
"""
if isinstance(img, ImageSpec):
ImageBuildEngine.build(img)
return img.image_name()
image = cfg.find_image(f"ft_{task_name}") if task_name else None
image_name = image.full if image else None
if not image_name:
ImageBuildEngine.build(img)
image_name = img.image_name()
return image_name

if img is not None and img != "":
matches = _IMAGE_REPLACE_REGEX.findall(img)
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/python_customized_container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_image(self, settings: SerializationSettings) -> str:
if isinstance(self.container_image, ImageSpec):
# Set the source root for the image spec if it's non-fast registration
self.container_image.source_root = settings.source_root
return get_registerable_container_image(self.container_image, settings.image_config)
return get_registerable_container_image(self.container_image, settings.image_config, self.name)

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = {**settings.env, **self.environment} if self.environment else settings.env
Expand Down
14 changes: 12 additions & 2 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

from flyteidl.admin import schedule_pb2

from flytekit import PythonFunctionTask, SourceCode
from flytekit.configuration import SerializationSettings
from flytekit import ImageSpec, PythonFunctionTask, SourceCode
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core import constants as _common_constants
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.base_task import PythonTask
from flytekit.core.condition import BranchNode
Expand Down Expand Up @@ -176,6 +177,15 @@ def get_serializable_task(
)

if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
for e in context_manager.FlyteEntities.entities:
if isinstance(e, PythonAutoContainerTask):
# 1. Build the ImageSpec for the entities that are inside the dynamic task,
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
# 2. Add the images to the serialization context, so the dynamic task can look them up at runtime.
if isinstance(e.container_image, ImageSpec):
if settings.image_config.images is None:
settings.image_config = ImageConfig.create_from(settings.image_config.default_image)
settings.image_config.images.append(Image.look_up_image_info(f"ft_{e.name}", e.get_image(settings)))

# In case of Dynamic tasks, we want to pass the serialization context, so that they can reconstruct the state
# from the serialization context. This is passed through an environment variable, that is read from
# during dynamic serialization
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_image(self, settings: SerializationSettings) -> str:
# Ensure that the code is always copied into the image, even during fast-registration.
self.container_image.source_root = settings.source_root

return get_registerable_container_image(self.container_image, settings.image_config)
return get_registerable_container_image(self.container_image, settings.image_config, self.name)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = SparkJob(
Expand Down
5 changes: 4 additions & 1 deletion tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.extras.accelerators import A100, T4
from flytekit.image_spec.image_spec import ImageBuildEngine
from flytekit.models import literals as _literal_models
from flytekit.models.task import Resources as _resources_models
from flytekit.tools.translator import get_serializable


def test_normal_task():
def test_normal_task(mock_image_spec_builder):
ImageBuildEngine.register("test", mock_image_spec_builder)

@task
def t1(a: str) -> str:
return a + " world"
Expand Down
8 changes: 7 additions & 1 deletion tests/flytekit/unit/core/test_python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,15 @@ def serialization_settings(request):


def test_image_name_interpolation(default_image_config):
new_img_cfg = ImageConfig.create_from(
default_image_config.default_image, other_images=[Image.look_up_image_info("ft_d1", "flyte/test:d1")]
)
img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special"
img = get_registerable_container_image(img=img_to_interpolate, cfg=default_image_config)
img = get_registerable_container_image(img=img_to_interpolate, cfg=new_img_cfg)
assert img == "docker.io/xyz:some-git-hash-special"
image = ImageSpec(name="image-1", registry="localhost:30000", builder="test")
img = get_registerable_container_image(img=image, cfg=new_img_cfg, task_name="d1")
assert img == "flyte/test:d1"


class DummyAutoContainerTask(PythonAutoContainerTask):
Expand Down
26 changes: 22 additions & 4 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import pytest

import flytekit.configuration
from flytekit import ContainerTask, kwtypes
from flytekit import ContainerTask, ImageSpec, kwtypes
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.condition import conditional
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.task import task
from flytekit.core.workflow import workflow
from flytekit.image_spec.image_spec import ImageBuildEngine
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.types import SimpleType
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -250,7 +251,9 @@ def test_bad_configuration():
get_registerable_container_image(container_image, image_config)


def test_serialization_images():
def test_serialization_images(mock_image_spec_builder):
ImageBuildEngine.register("test", mock_image_spec_builder)

@task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}")
def t1(a: int) -> int:
return a
Expand All @@ -271,10 +274,22 @@ def t5(a: int) -> int:
def t6(a: int) -> int:
return a

image_spec = ImageSpec(
packages=["mypy"],
apt_packages=["git"],
registry="ghcr.io/flyteorg",
builder="test",
)

@task(container_image=image_spec)
def t7(a: int) -> int:
return a

with mock.patch.dict(os.environ, {"FLYTE_INTERNAL_IMAGE": "docker.io/default:version"}):
imgs = ImageConfig.auto(
config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")
)
imgs.images.append(Image(name=f"ft_{t7.name}", fqn="docker.io/t7", tag="latest"))
rs = flytekit.configuration.SerializationSettings(
project="project",
domain="domain",
Expand All @@ -295,8 +310,11 @@ def t6(a: int) -> int:
t5_spec = get_serializable(OrderedDict(), rs, t5)
assert t5_spec.template.container.image == "docker.io/org/myimage:latest"

t5_spec = get_serializable(OrderedDict(), rs, t6)
assert t5_spec.template.container.image == "docker.io/xyz_123:v1"
t6_spec = get_serializable(OrderedDict(), rs, t6)
assert t6_spec.template.container.image == "docker.io/xyz_123:v1"

t7_spec = get_serializable(OrderedDict(), rs, t7)
assert t7_spec.template.container.image == "docker.io/t7:latest"


def test_serialization_command1():
Expand Down
Loading