Skip to content

Commit

Permalink
add env vars option in pyflyte package (#2171)
Browse files Browse the repository at this point in the history
  • Loading branch information
fiedlerNr9 authored Feb 8, 2024
1 parent d2c6353 commit 8b77a18
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 5 deletions.
29 changes: 27 additions & 2 deletions flytekit/clis/sdk_in_container/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ImageConfig,
SerializationSettings,
)
from flytekit.interaction.click_types import key_value_callback
from flytekit.tools.repo import NoSerializableEntitiesError, serialize_and_package


Expand Down Expand Up @@ -83,9 +84,27 @@
is_flag=True,
help="Enables symlink dereferencing when packaging files in fast registration",
)
@click.option(
"--env",
"--envvars",
required=False,
multiple=True,
type=str,
callback=key_value_callback,
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
)
@click.pass_context
def package(
ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter, deref_symlinks
ctx,
image_config,
source,
output,
force,
fast,
in_container_source_path,
python_interpreter,
deref_symlinks,
env,
):
"""
This command produces a Flyte backend registrable package of all entities in Flyte.
Expand All @@ -95,7 +114,12 @@ def package(
This serialization step will set the name of the tasks to the fully qualified name of the task function.
"""
if os.path.exists(output) and not force:
raise click.BadParameter(click.style(f"Output file {output} already exists, specify -f to override.", fg="red"))
raise click.BadParameter(
click.style(
f"Output file {output} already exists, specify -f to override.",
fg="red",
)
)

serialization_settings = SerializationSettings(
image_config=image_config,
Expand All @@ -104,6 +128,7 @@ def package(
destination_dir=in_container_source_path,
),
python_interpreter=python_interpreter,
env=env,
)

pkgs = ctx.obj[constants.CTX_PACKAGES]
Expand Down
82 changes: 79 additions & 3 deletions tests/flytekit/unit/cli/pyflyte/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shutil

from click.testing import CliRunner
from flyteidl.admin import task_pb2

import flytekit
import flytekit.clis.sdk_in_container.utils
Expand Down Expand Up @@ -80,26 +81,82 @@ def test_get_registrable_entities():
assert False, f"found unknown entity {type(e)}"


def test_package_with_fast_registration():
def test_package_with_fast_registration_and_envvars():
runner = CliRunner()
with runner.isolated_filesystem():
os.makedirs("core", exist_ok=True)
with open(os.path.join("core", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()
result = runner.invoke(pyflyte.main, ["--pkgs", "core", "package", "--image", "core:v1", "--fast"])
result = runner.invoke(
pyflyte.main,
[
"--pkgs",
"core",
"package",
"--image",
"core:v1",
"--fast",
"--env",
"abc=42",
"--env",
"euler=2.71828",
],
)
assert result.exit_code == 0
assert "Successfully serialized" in result.output
assert "Successfully packaged" in result.output

# verify existence of flyte-package.tgz file
assert os.path.exists("flyte-package.tgz")

# verify the contents of the flyte-package.tgz file
import tarfile

# Uncompress flyte-package.tgz
tarfile.open("flyte-package.tgz", "r:gz").extractall()

# Load the proto message from file 3_core.sample.sum_1.pb
task_spec = task_pb2.TaskSpec()
task_spec.ParseFromString(open("3_core.sample.sum_1.pb", "rb").read())

assert task_spec.template.container.env[0].key == "abc"
assert task_spec.template.container.env[0].value == "42"
assert task_spec.template.container.env[1].key == "euler"
assert task_spec.template.container.env[1].value == "2.71828"

result = runner.invoke(pyflyte.main, ["--pkgs", "core", "package", "--image", "core:v1", "--fast"])
assert result.exit_code == 2
assert "flyte-package.tgz already exists, specify -f to override" in result.output
result = runner.invoke(
pyflyte.main,
["--pkgs", "core", "package", "--image", "core:v1", "--fast", "--force"],
[
"--pkgs",
"core",
"package",
"--image",
"core:v1",
"--fast",
"--force",
"--env",
"k1=v1",
"--env",
"pi=3.14159265",
],
)
assert result.exit_code == 0
assert "deleting and re-creating it" in result.output

tarfile.open("flyte-package.tgz", "r:gz").extractall()

# Load the proto message from file 3_core.sample.sum_1.pb
task_spec = task_pb2.TaskSpec()
task_spec.ParseFromString(open("3_core.sample.sum_1.pb", "rb").read())

assert task_spec.template.container.env[0].key == "k1"
assert task_spec.template.container.env[0].value == "v1"
assert task_spec.template.container.env[1].key == "pi"
assert task_spec.template.container.env[1].value == "3.14159265"
shutil.rmtree("core")


Expand Down Expand Up @@ -131,3 +188,22 @@ def test_package_with_no_pkgs():
result = runner.invoke(pyflyte.main, ["package"])
assert result.exit_code == 1
assert "No packages to scan for flyte entities. Aborting!" in result.output


def test_package_with_envs_wrong_format():
runner = CliRunner()
with runner.isolated_filesystem():
result = runner.invoke(
pyflyte.main,
[
"--pkgs",
"flytekit.unit.cli.pyflyte.test_package",
"package",
"--image",
"myapp:03eccc1cf101adbd8c4734dba865d3fdeb720aa7",
"--env",
"Key0:Value0",
],
)
assert result.exit_code == 2
assert "Expected key-value pair of the form key=value, got" in result.output

0 comments on commit 8b77a18

Please sign in to comment.