Skip to content

Commit

Permalink
Add ImageConfig to the serialization context for dynamic task (#2456)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
pingsutw authored and fiedlerNr9 committed Jul 25, 2024
1 parent d4b5799 commit cc37c50
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 15 deletions.
10 changes: 7 additions & 3 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, _calculate_deduped_hash_from_image_spec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext
Expand Down Expand Up @@ -276,8 +276,12 @@ def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg:
:return:
"""
if isinstance(img, ImageSpec):
ImageBuildEngine.build(img)
return img.image_name()
image = cfg.find_image(_calculate_deduped_hash_from_image_spec(img))
image_name = image.full if image else None
if not image_name:
ImageBuildEngine.build(img)
image_name = img.image_name()
return image_name

if img is not None and img != "":
matches = _IMAGE_REPLACE_REGEX.findall(img)
Expand Down
15 changes: 15 additions & 0 deletions flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,21 @@ def _build_image(cls, builder, image_spec, img_name):
cls._IMAGE_NAME_TO_REAL_NAME[img_name] = fully_qualified_image_name


@lru_cache
def _calculate_deduped_hash_from_image_spec(image_spec: ImageSpec):
"""
Calculate this special hash from the image spec,
and it used to identify the imageSpec in the ImageConfig in the serialization context.
ImageConfig:
- deduced hash 1: flyteorg/flytekit: 123
- deduced hash 2: flyteorg/flytekit: 456
"""
image_spec_bytes = asdict(image_spec).__str__().encode("utf-8")
# copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different.
return base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii").rstrip("=")


@lru_cache
def calculate_hash_from_image_spec(image_spec: ImageSpec):
"""
Expand Down
19 changes: 17 additions & 2 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

from flyteidl.admin import schedule_pb2

from flytekit import PythonFunctionTask, SourceCode
from flytekit.configuration import SerializationSettings
from flytekit import ImageSpec, PythonFunctionTask, SourceCode
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core import constants as _common_constants
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.base_task import PythonTask
from flytekit.core.condition import BranchNode
Expand All @@ -22,6 +23,7 @@
from flytekit.core.task import ReferenceTask
from flytekit.core.utils import ClassDecorator, _dnsify
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase
from flytekit.image_spec.image_spec import _calculate_deduped_hash_from_image_spec
from flytekit.models import common as _common_models
from flytekit.models import common as common_models
from flytekit.models import interface as interface_models
Expand Down Expand Up @@ -176,6 +178,19 @@ def get_serializable_task(
)

if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
for e in context_manager.FlyteEntities.entities:
if isinstance(e, PythonAutoContainerTask):
# 1. Build the ImageSpec for all the entities that are inside the current context,
# 2. Add images to the serialization context, so the dynamic task can look it up at runtime.
if isinstance(e.container_image, ImageSpec):
if settings.image_config.images is None:
settings.image_config = ImageConfig.create_from(settings.image_config.default_image)
settings.image_config.images.append(
Image.look_up_image_info(
_calculate_deduped_hash_from_image_spec(e.container_image), e.get_image(settings)
)
)

# In case of Dynamic tasks, we want to pass the serialization context, so that they can reconstruct the state
# from the serialization context. This is passed through an environment variable, that is read from
# during dynamic serialization
Expand Down
6 changes: 3 additions & 3 deletions plugins/flytekit-envd/tests/test_image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def build():
run(commands=["echo hello"])
install.python_packages(name=["pandas"])
install.apt_packages(name=["git"])
runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root'])
runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root'])
config.pip_index(url="https://private-pip-index/simple")
install.python(version="3.8")
io.copy(source="./", target="/root")
Expand Down Expand Up @@ -88,7 +88,7 @@ def build():
run(commands=[])
install.python_packages(name=["flytekit"])
install.apt_packages(name=[])
runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root'])
runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root'])
config.pip_index(url="https://pypi.org/simple")
install.conda(use_mamba=True)
install.conda_packages(name=["pytorch", "cpuonly"], channel=["pytorch"])
Expand Down Expand Up @@ -122,7 +122,7 @@ def build():
run(commands=[])
install.python_packages(name=["-U --pre pandas", "torch", "torchvision"])
install.apt_packages(name=[])
runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root'])
runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root'])
config.pip_index(url="https://pypi.org/simple", extra_url="https://download.pytorch.org/whl/cpu https://pypi.anaconda.org/scientific-python-nightly-wheels/simple")
"""
)
Expand Down
5 changes: 4 additions & 1 deletion tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.extras.accelerators import A100, T4
from flytekit.image_spec.image_spec import ImageBuildEngine
from flytekit.models import literals as _literal_models
from flytekit.models.task import Resources as _resources_models
from flytekit.tools.translator import get_serializable


def test_normal_task():
def test_normal_task(mock_image_spec_builder):
ImageBuildEngine.register("test", mock_image_spec_builder)

@task
def t1(a: str) -> str:
return a + " world"
Expand Down
12 changes: 10 additions & 2 deletions tests/flytekit/unit/core/test_python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image
from flytekit.core.resources import Resources
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, _calculate_deduped_hash_from_image_spec
from flytekit.tools.translator import get_serializable_task


Expand Down Expand Up @@ -55,9 +55,17 @@ def serialization_settings(request):


def test_image_name_interpolation(default_image_config):
image_spec = ImageSpec(name="image-1", registry="localhost:30000", builder="test")

new_img_cfg = ImageConfig.create_from(
default_image_config.default_image,
other_images=[Image.look_up_image_info(_calculate_deduped_hash_from_image_spec(image_spec), "flyte/test:d1")],
)
img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special"
img = get_registerable_container_image(img=img_to_interpolate, cfg=default_image_config)
img = get_registerable_container_image(img=img_to_interpolate, cfg=new_img_cfg)
assert img == "docker.io/xyz:some-git-hash-special"
img = get_registerable_container_image(img=image_spec, cfg=new_img_cfg)
assert img == "flyte/test:d1"


class DummyAutoContainerTask(PythonAutoContainerTask):
Expand Down
28 changes: 24 additions & 4 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import pytest

import flytekit.configuration
from flytekit import ContainerTask, kwtypes
from flytekit import ContainerTask, ImageSpec, kwtypes
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.condition import conditional
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.task import task
from flytekit.core.workflow import workflow
from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.types import SimpleType
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -250,7 +251,9 @@ def test_bad_configuration():
get_registerable_container_image(container_image, image_config)


def test_serialization_images():
def test_serialization_images(mock_image_spec_builder):
ImageBuildEngine.register("test", mock_image_spec_builder)

@task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}")
def t1(a: int) -> int:
return a
Expand All @@ -271,10 +274,24 @@ def t5(a: int) -> int:
def t6(a: int) -> int:
return a

image_spec = ImageSpec(
packages=["mypy"],
apt_packages=["git"],
registry="ghcr.io/flyteorg",
builder="test",
)

@task(container_image=image_spec)
def t7(a: int) -> int:
return a

with mock.patch.dict(os.environ, {"FLYTE_INTERNAL_IMAGE": "docker.io/default:version"}):
imgs = ImageConfig.auto(
config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")
)
imgs.images.append(
Image(name=_calculate_deduped_hash_from_image_spec(image_spec), fqn="docker.io/t7", tag="latest")
)
rs = flytekit.configuration.SerializationSettings(
project="project",
domain="domain",
Expand All @@ -295,8 +312,11 @@ def t6(a: int) -> int:
t5_spec = get_serializable(OrderedDict(), rs, t5)
assert t5_spec.template.container.image == "docker.io/org/myimage:latest"

t5_spec = get_serializable(OrderedDict(), rs, t6)
assert t5_spec.template.container.image == "docker.io/xyz_123:v1"
t6_spec = get_serializable(OrderedDict(), rs, t6)
assert t6_spec.template.container.image == "docker.io/xyz_123:v1"

t7_spec = get_serializable(OrderedDict(), rs, t7)
assert t7_spec.template.container.image == "docker.io/t7:latest"


def test_serialization_command1():
Expand Down

0 comments on commit cc37c50

Please sign in to comment.