Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ImageConfig to the serialization context for dynamic task #2456

Merged
merged 25 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, _calculate_deduced_hash_from_image_spec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext
Expand Down Expand Up @@ -276,8 +276,12 @@ def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg:
:return:
"""
if isinstance(img, ImageSpec):
ImageBuildEngine.build(img)
return img.image_name()
image = cfg.find_image(_calculate_deduced_hash_from_image_spec(img))
image_name = image.full if image else None
if not image_name:
ImageBuildEngine.build(img)
image_name = img.image_name()
return image_name

if img is not None and img != "":
matches = _IMAGE_REPLACE_REGEX.findall(img)
Expand Down
15 changes: 15 additions & 0 deletions flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,21 @@ def _build_image(cls, builder, image_spec, img_name):
cls._IMAGE_NAME_TO_REAL_NAME[img_name] = fully_qualified_image_name


@lru_cache
def _calculate_deduced_hash_from_image_spec(image_spec: ImageSpec) -> str:
    """
    Calculate the "deduced" hash of an image spec.

    This hash is used as the name that identifies the ImageSpec in the ImageConfig
    stored in the serialization context, for example:

    ImageConfig:
    - deduced hash 1: flyteorg/flytekit: 123
    - deduced hash 2: flyteorg/flytekit: 456

    :param image_spec: The image spec to hash. It is passed to ``asdict`` and, because of
        ``lru_cache``, must be hashable.
    :return: The MD5 digest of the stringified ``asdict(image_spec)``, encoded as URL-safe
        base64 with the trailing ``=`` padding removed.
    """
    # The hash is derived from the string form of the spec's fields; image_spec itself is
    # only read here, never modified.
    image_spec_bytes = str(asdict(image_spec)).encode("utf-8")
    digest = hashlib.md5(image_spec_bytes).digest()
    return base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")


@lru_cache
def calculate_hash_from_image_spec(image_spec: ImageSpec):
"""
Expand Down
19 changes: 17 additions & 2 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

from flyteidl.admin import schedule_pb2

from flytekit import PythonFunctionTask, SourceCode
from flytekit.configuration import SerializationSettings
from flytekit import ImageSpec, PythonFunctionTask, SourceCode
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core import constants as _common_constants
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.base_task import PythonTask
from flytekit.core.condition import BranchNode
Expand All @@ -22,6 +23,7 @@
from flytekit.core.task import ReferenceTask
from flytekit.core.utils import ClassDecorator, _dnsify
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase
from flytekit.image_spec.image_spec import _calculate_deduced_hash_from_image_spec
from flytekit.models import common as _common_models
from flytekit.models import common as common_models
from flytekit.models import interface as interface_models
Expand Down Expand Up @@ -176,6 +178,19 @@ def get_serializable_task(
)

if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
for e in context_manager.FlyteEntities.entities:
if isinstance(e, PythonAutoContainerTask):
# 1. Build the ImageSpec for all the entities that are inside the current context,
# 2. Add images to the serialization context, so the dynamic task can look it up at runtime.
if isinstance(e.container_image, ImageSpec):
if settings.image_config.images is None:
settings.image_config = ImageConfig.create_from(settings.image_config.default_image)
settings.image_config.images.append(
Image.look_up_image_info(
_calculate_deduced_hash_from_image_spec(e.container_image), e.get_image(settings)
)
)

# In the case of dynamic tasks, we want to pass the serialization context so that they can reconstruct the state
# from it. The context is passed through an environment variable, which is read
# during dynamic serialization
Expand Down
5 changes: 4 additions & 1 deletion tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.extras.accelerators import A100, T4
from flytekit.image_spec.image_spec import ImageBuildEngine
from flytekit.models import literals as _literal_models
from flytekit.models.task import Resources as _resources_models
from flytekit.tools.translator import get_serializable


def test_normal_task():
def test_normal_task(mock_image_spec_builder):
ImageBuildEngine.register("test", mock_image_spec_builder)

@task
def t1(a: str) -> str:
return a + " world"
Expand Down
12 changes: 10 additions & 2 deletions tests/flytekit/unit/core/test_python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image
from flytekit.core.resources import Resources
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, _calculate_deduced_hash_from_image_spec
from flytekit.tools.translator import get_serializable_task


Expand Down Expand Up @@ -55,9 +55,17 @@ def serialization_settings(request):


def test_image_name_interpolation(default_image_config):
    """
    Check that get_registerable_container_image resolves both kinds of image reference.

    A templated image string is interpolated from the default image of the config, and an
    ImageSpec whose deduced hash is registered in the config resolves to that registered image.
    """
    image_spec = ImageSpec(name="image-1", registry="localhost:30000", builder="test")

    # Register an image under the spec's deduced hash so the ImageSpec lookup below can find it.
    new_img_cfg = ImageConfig.create_from(
        default_image_config.default_image,
        other_images=[Image.look_up_image_info(_calculate_deduced_hash_from_image_spec(image_spec), "flyte/test:d1")],
    )

    # Template placeholders are filled in from the default image.
    img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special"
    img = get_registerable_container_image(img=img_to_interpolate, cfg=new_img_cfg)
    assert img == "docker.io/xyz:some-git-hash-special"

    # The ImageSpec resolves to the image registered under its deduced hash.
    img = get_registerable_container_image(img=image_spec, cfg=new_img_cfg)
    assert img == "flyte/test:d1"


class DummyAutoContainerTask(PythonAutoContainerTask):
Expand Down
28 changes: 24 additions & 4 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import pytest

import flytekit.configuration
from flytekit import ContainerTask, kwtypes
from flytekit import ContainerTask, ImageSpec, kwtypes
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.condition import conditional
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.task import task
from flytekit.core.workflow import workflow
from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduced_hash_from_image_spec
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.types import SimpleType
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -250,7 +251,9 @@ def test_bad_configuration():
get_registerable_container_image(container_image, image_config)


def test_serialization_images():
def test_serialization_images(mock_image_spec_builder):
ImageBuildEngine.register("test", mock_image_spec_builder)

@task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}")
def t1(a: int) -> int:
return a
Expand All @@ -271,10 +274,24 @@ def t5(a: int) -> int:
def t6(a: int) -> int:
return a

image_spec = ImageSpec(
packages=["mypy"],
apt_packages=["git"],
registry="ghcr.io/flyteorg",
builder="test",
)

@task(container_image=image_spec)
def t7(a: int) -> int:
return a

with mock.patch.dict(os.environ, {"FLYTE_INTERNAL_IMAGE": "docker.io/default:version"}):
imgs = ImageConfig.auto(
config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")
)
imgs.images.append(
Image(name=_calculate_deduced_hash_from_image_spec(image_spec), fqn="docker.io/t7", tag="latest")
)
rs = flytekit.configuration.SerializationSettings(
project="project",
domain="domain",
Expand All @@ -295,8 +312,11 @@ def t6(a: int) -> int:
t5_spec = get_serializable(OrderedDict(), rs, t5)
assert t5_spec.template.container.image == "docker.io/org/myimage:latest"

t5_spec = get_serializable(OrderedDict(), rs, t6)
assert t5_spec.template.container.image == "docker.io/xyz_123:v1"
t6_spec = get_serializable(OrderedDict(), rs, t6)
assert t6_spec.template.container.image == "docker.io/xyz_123:v1"

t7_spec = get_serializable(OrderedDict(), rs, t7)
assert t7_spec.template.container.image == "docker.io/t7:latest"


def test_serialization_command1():
Expand Down
Loading