Skip to content

Commit

Permalink
Allow to pass extra env vars to the nodes (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
szczeles authored and j0rd1smit committed Feb 17, 2023
1 parent f7b2e47 commit 9c58ded
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests_and_publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ jobs:
AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }}
AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }}
run: |
kedro azureml run --wait-for-completion
kedro azureml run --wait-for-completion --env-var 'GETINDATA=ROCKS!'
publish:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--line-length=79"]
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## [Unreleased]

- Ability to pass extra environment variables to the Kedro nodes using `--env-var` option
- Default configuration for docker-flow adjusted for the latest kedro-docker plugin

## [0.3.4] - 2022-12-30

- Add lazy initialization and cache to Kedro's context in the `KedroContextManager` class to prevent re-loading
Expand Down
14 changes: 13 additions & 1 deletion docs/source/03_quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,16 @@ The ``distributed_job`` decorator also supports "hard-coded" values for number o
# rest of the code
pass
We have tested the implementation heavily with PyTorch (+PyTorch Lightning) and GPUs. If you encounter any problems, drop us an issue on GitHub!
We have tested the implementation heavily with PyTorch (+PyTorch Lightning) and GPUs. If you encounter any problems, drop us an issue on GitHub!

Run customization
-----------------

In case you need to customize pipeline run context, modifying configuration files is not always the most convinient option. Therefore, ``kedro azureml run`` command provides a few additional options you may find useful:

- ``--subscription_id`` overrides Azure Subscription ID,
- ``--azureml_environment`` overrides the configured Azure ML Environment,
- ``--image`` modifies the Docker image used during the execution,
- ``--pipeline`` allows to select a pipeline to run (by default, the ``__default__`` pipeline is started),
- ``--params`` takes a JSON string with parameters override (JSONed version of ``conf/*/parameters.yml``, not the Kedro's ``params:`` syntax),
- ``--env-var KEY=VALUE`` sets the OS environment variable injected to the steps during runtime (can be used multiple times).
16 changes: 12 additions & 4 deletions kedro_azureml/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import logging
import os
from pathlib import Path
from typing import List, Optional
from typing import Optional, Tuple

import click
from kedro.framework.startup import ProjectMetadata

from kedro_azureml.cli_functions import (
get_context_and_pipeline,
parse_extra_env_params,
parse_extra_params,
verify_configuration_directory_for_azure,
warn_about_ignore_files,
Expand Down Expand Up @@ -46,7 +47,6 @@ def commands():
@click.pass_obj
@click.pass_context
def azureml_group(ctx, metadata: ProjectMetadata, env):
click.echo(metadata)
ctx.obj = CliContext(env, metadata)


Expand Down Expand Up @@ -150,6 +150,12 @@ def init(
help="Parameters override in form of JSON string",
)
@click.option("--wait-for-completion", type=bool, is_flag=True, default=False)
@click.option(
"--env-var",
type=str,
multiple=True,
help="Environment variables to be injected in the steps, format: KEY=VALUE",
)
@click.pass_obj
@click.pass_context
def run(
Expand All @@ -161,6 +167,7 @@ def run(
pipeline: str,
params: str,
wait_for_completion: bool,
env_var: Tuple[str],
):
"""Runs the specified pipeline in Azure ML Pipelines; Additional parameters can be passed from command line.
Can be used with --wait-for-completion param to block the caller until the pipeline finishes in Azure ML.
Expand All @@ -178,7 +185,8 @@ def run(
verify_configuration_directory_for_azure(click_context, ctx)

mgr: KedroContextManager
with get_context_and_pipeline(ctx, image, pipeline, params, aml_env) as (
extra_env = parse_extra_env_params(env_var)
with get_context_and_pipeline(ctx, image, pipeline, params, aml_env, extra_env) as (
mgr,
az_pipeline,
):
Expand Down Expand Up @@ -290,7 +298,7 @@ def compile(
)
@click.pass_obj
def execute(
ctx: CliContext, pipeline: str, node: str, params: str, azure_outputs: List[str]
ctx: CliContext, pipeline: str, node: str, params: str, azure_outputs: Tuple[str]
):
# 1. Run kedro
parameters = parse_extra_params(params)
Expand Down
13 changes: 12 additions & 1 deletion kedro_azureml/cli_functions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import logging
import os
import re
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from typing import Dict, Optional

import click

Expand All @@ -20,6 +21,7 @@ def get_context_and_pipeline(
pipeline: str,
params: str,
aml_env: Optional[str] = None,
extra_env: Dict[str, str] = {},
):
with KedroContextManager(
ctx.metadata.package_name, ctx.env, parse_extra_params(params, True)
Expand Down Expand Up @@ -47,6 +49,7 @@ def get_context_and_pipeline(
docker_image,
params,
storage_account_key,
extra_env,
)
az_pipeline = generator.generate()
yield mgr, az_pipeline
Expand Down Expand Up @@ -129,3 +132,11 @@ def verify_configuration_directory_for_azure(click_context, ctx: CliContext):
)
if not click.confirm(click.style(msg, fg="yellow")):
click_context.exit(2)


def parse_extra_env_params(extra_env):
for entry in extra_env:
if not re.match("[A-Za-z0-9_]+=.*", entry):
raise Exception(f"Invalid env-var: {entry}, expected format: KEY=VALUE")

return {(e := entry.split("="))[0]: e[1] for entry in extra_env}
2 changes: 1 addition & 1 deletion kedro_azureml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class KedroAzureRunnerConfig(BaseModel):
code_directory: "."
# Path to the directory in the Docker image to run the code from
# Ignored when code_directory is set
working_directory: /home/kedro
working_directory: /home/kedro_docker
# Temporary storage settings - this is used to pass some data between steps
# if the data is not specified in the catalog directly
Expand Down
3 changes: 3 additions & 0 deletions kedro_azureml/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
docker_image: Optional[str] = None,
params: Optional[str] = None,
storage_account_key: Optional[str] = "",
extra_env: Dict[str, str] = {},
):
self.storage_account_key = storage_account_key
self.kedro_environment = kedro_environment
Expand All @@ -59,6 +60,7 @@ def __init__(
self.docker_image = docker_image
self.config = config
self.pipeline_name = pipeline_name
self.extra_env = extra_env

def generate(self) -> Job:
pipeline = self.get_kedro_pipeline()
Expand Down Expand Up @@ -183,6 +185,7 @@ def _construct_azure_command(
run_id=kedro_azure_run_id,
storage_account_key=self.storage_account_key,
).json(),
**self.extra_env,
},
environment=self._resolve_azure_environment(), # TODO: check whether Environment exists
inputs={
Expand Down
4 changes: 2 additions & 2 deletions tests/conf/e2e/azureml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ azure:
code_directory: null
# Path to the directory in the Docker image to run the code from
# Ignored when code_directory is set
working_directory: /home/kedro
working_directory: /home/kedro_docker

# Temporary storage settings - this is used to pass some data between steps
# if the data is not specified in the catalog directly
Expand All @@ -36,4 +36,4 @@ azure:
# cluster_name: "<your_cluster_name>"
docker:
# Docker image to use during pipeline execution
image: "{container_registry}/kedro-azureml-e2e:latest"
image: "{container_registry}/kedro-azureml-e2e:latest"
42 changes: 41 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_can_invoke_execute_cli(
)
@pytest.mark.parametrize("amlignore", ("empty", "missing", "filled"))
@pytest.mark.parametrize("gitignore", ("empty", "missing", "filled"))
@pytest.mark.parametrize("extra_env", (([], {}), (["A=B", "C="], {"A": "B", "C": ""})))
def test_can_invoke_run(
patched_kedro_package,
cli_context,
Expand All @@ -184,6 +185,7 @@ def test_can_invoke_run(
use_default_credentials: bool,
amlignore: str,
gitignore: str,
extra_env: list,
):
create_kedro_conf_dirs(tmp_path)
with patch.dict(
Expand Down Expand Up @@ -215,7 +217,8 @@ def test_can_invoke_run(
cli.run,
["-s", "subscription_id"]
+ (["--wait-for-completion"] if wait_for_completion else [])
+ (["--aml_env", aml_env] if aml_env else []),
+ (["--aml_env", aml_env] if aml_env else [])
+ (sum([["--env-var", k] for k in extra_env[0]], [])),
obj=cli_context,
)
assert result.exit_code == 0
Expand All @@ -234,6 +237,13 @@ def test_can_invoke_run(
else:
interactive_credentials.assert_not_called()

created_pipeline = ml_client.jobs.create_or_update.call_args[0][0]
populated_env_vars = list(created_pipeline.jobs.values())[
0
].environment_variables
del populated_env_vars["KEDRO_AZURE_RUNNER_CONFIG"]
assert populated_env_vars == extra_env[1]


@pytest.mark.parametrize(
"kedro_environment_name",
Expand Down Expand Up @@ -320,3 +330,33 @@ def test_can_invoke_run_with_failed_pipeline(
ml_client.jobs.create_or_update.assert_called_once()
ml_client.compute.get.assert_called_once()
ml_client.jobs.stream.assert_called_once()


@pytest.mark.parametrize("env_var", ("INVALID", "2+2=4"))
def test_fail_if_invalid_env_provided_in_run(
patched_kedro_package,
cli_context,
dummy_pipeline,
tmp_path: Path,
env_var: str,
):
create_kedro_conf_dirs(tmp_path)
with patch.dict(
"kedro.framework.project.pipelines", {"__default__": dummy_pipeline}
), patch.object(Path, "cwd", return_value=tmp_path), patch(
"kedro_azureml.client.MLClient"
) as ml_client_patched, patch(
"kedro_azureml.client.DefaultAzureCredential"
), patch.dict(
os.environ, {"AZURE_STORAGE_ACCOUNT_KEY": "dummy_key"}
):
ml_client = ml_client_patched.from_config()
ml_client.jobs.stream.side_effect = ValueError()

runner = CliRunner()
result = runner.invoke(cli.run, ["--env-var", env_var], obj=cli_context)
assert result.exit_code == 1
assert (
str(result.exception)
== f"Invalid env-var: {env_var}, expected format: KEY=VALUE"
)
18 changes: 18 additions & 0 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,21 @@ def test_get_target_resource_from_node_tags_raises_exception(
)
with pytest.raises(ConfigException):
generator.get_target_resource_from_node_tags(node)


def test_azure_pipeline_with_custom_env_vars(dummy_plugin_config, dummy_pipeline):
pipeline_name = "unit_test_pipeline"
node = MagicMock()
node.tags = ["compute-2", "compute-3"]
for t in node.tags:
dummy_plugin_config.azure.compute[t] = ComputeConfig(**{"cluster_name": t})
with patch.dict(
"kedro.framework.project.pipelines", {pipeline_name: dummy_pipeline}
):
generator = AzureMLPipelineGenerator(
pipeline_name, "local", dummy_plugin_config, {}, extra_env={"ABC": "def"}
)

for node in generator.generate().jobs.values():
assert "ABC" in node.environment_variables
assert node.environment_variables["ABC"] == "def"

0 comments on commit 9c58ded

Please sign in to comment.