From a4536af8991d7551befe582f7938ccb5ac63ffb4 Mon Sep 17 00:00:00 2001 From: Mariusz Strzelecki Date: Fri, 17 Feb 2023 10:12:47 +0100 Subject: [PATCH] Allow to pass extra env vars to the nodes (#46) --- .github/workflows/tests_and_publish.yml | 2 +- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 3 ++ docs/source/03_quickstart.rst | 14 ++++++++- kedro_azureml/cli.py | 16 +++++++--- kedro_azureml/cli_functions.py | 13 +++++++- kedro_azureml/config.py | 2 +- kedro_azureml/generator.py | 3 ++ tests/conf/e2e/azureml.yml | 4 +-- tests/test_cli.py | 42 ++++++++++++++++++++++++- tests/test_generator.py | 18 +++++++++++ 11 files changed, 107 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tests_and_publish.yml b/.github/workflows/tests_and_publish.yml index fc39764..8b50942 100644 --- a/.github/workflows/tests_and_publish.yml +++ b/.github/workflows/tests_and_publish.yml @@ -175,7 +175,7 @@ jobs: AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }} AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} run: | - kedro azureml run --wait-for-completion + kedro azureml run --wait-for-completion --env-var 'GETINDATA=ROCKS!' 
publish: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 98cac3a..5bb94b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: ["--profile", "black", "--line-length=79"] diff --git a/CHANGELOG.md b/CHANGELOG.md index 131048b..15ffa2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## [Unreleased] +- Ability to pass extra environment variables to the Kedro nodes using the `--env-var` option +- Default configuration for docker-flow adjusted for the latest kedro-docker plugin + ## [0.3.4] - 2022-12-30 - Add lazy initialization and cache to Kedro's context in the `KedroContextManager` class to prevent re-loading diff --git a/docs/source/03_quickstart.rst b/docs/source/03_quickstart.rst index e4b84b4..e37cb9d 100644 --- a/docs/source/03_quickstart.rst +++ b/docs/source/03_quickstart.rst @@ -382,4 +382,16 @@ The ``distributed_job`` decorator also supports "hard-coded" values for number o # rest of the code pass -We have tested the implementation heavily with PyTorch (+PyTorch Lightning) and GPUs. If you encounter any problems, drop us an issue on GitHub! \ No newline at end of file +We have tested the implementation heavily with PyTorch (+PyTorch Lightning) and GPUs. If you encounter any problems, drop us an issue on GitHub! + +Run customization +----------------- + +In case you need to customize the pipeline run context, modifying configuration files is not always the most convenient option. 
Therefore, the ``kedro azureml run`` command provides a few additional options you may find useful: + +- ``--subscription_id`` overrides Azure Subscription ID, +- ``--azureml_environment`` overrides the configured Azure ML Environment, +- ``--image`` modifies the Docker image used during the execution, +- ``--pipeline`` selects the pipeline to run (by default, the ``__default__`` pipeline is started), +- ``--params`` takes a JSON string with parameter overrides (a JSON version of ``conf/*/parameters.yml``, not Kedro's ``params:`` syntax), +- ``--env-var KEY=VALUE`` sets an OS environment variable injected into the steps at runtime (can be used multiple times). diff --git a/kedro_azureml/cli.py b/kedro_azureml/cli.py index 1afd2a1..edaf3c1 100644 --- a/kedro_azureml/cli.py +++ b/kedro_azureml/cli.py @@ -2,13 +2,14 @@ import logging import os from pathlib import Path -from typing import List, Optional +from typing import Optional, Tuple import click from kedro.framework.startup import ProjectMetadata from kedro_azureml.cli_functions import ( get_context_and_pipeline, + parse_extra_env_params, parse_extra_params, verify_configuration_directory_for_azure, warn_about_ignore_files, @@ -46,7 +47,6 @@ def commands(): @click.pass_obj @click.pass_context def azureml_group(ctx, metadata: ProjectMetadata, env): - click.echo(metadata) ctx.obj = CliContext(env, metadata) @@ -150,6 +150,12 @@ def init( help="Parameters override in form of JSON string", ) @click.option("--wait-for-completion", type=bool, is_flag=True, default=False) +@click.option( + "--env-var", + type=str, + multiple=True, + help="Environment variables to be injected in the steps, format: KEY=VALUE", +) @click.pass_obj @click.pass_context def run( @@ -161,6 +167,7 @@ def run( pipeline: str, params: str, wait_for_completion: bool, + env_var: Tuple[str], ): """Runs the specified pipeline in Azure ML Pipelines; Additional parameters can be passed from command line. 
Can be used with --wait-for-completion param to block the caller until the pipeline finishes in Azure ML. @@ -178,7 +185,8 @@ def run( verify_configuration_directory_for_azure(click_context, ctx) mgr: KedroContextManager - with get_context_and_pipeline(ctx, image, pipeline, params, aml_env) as ( + extra_env = parse_extra_env_params(env_var) + with get_context_and_pipeline(ctx, image, pipeline, params, aml_env, extra_env) as ( mgr, az_pipeline, ): @@ -290,7 +298,7 @@ def compile( ) @click.pass_obj def execute( - ctx: CliContext, pipeline: str, node: str, params: str, azure_outputs: List[str] + ctx: CliContext, pipeline: str, node: str, params: str, azure_outputs: Tuple[str] ): # 1. Run kedro parameters = parse_extra_params(params) diff --git a/kedro_azureml/cli_functions.py b/kedro_azureml/cli_functions.py index 7b05268..ec1810f 100644 --- a/kedro_azureml/cli_functions.py +++ b/kedro_azureml/cli_functions.py @@ -1,9 +1,10 @@ import json import logging import os +import re from contextlib import contextmanager from pathlib import Path -from typing import Optional +from typing import Dict, Optional import click @@ -20,6 +21,7 @@ def get_context_and_pipeline( pipeline: str, params: str, aml_env: Optional[str] = None, + extra_env: Dict[str, str] = {}, ): with KedroContextManager( ctx.metadata.package_name, ctx.env, parse_extra_params(params, True) @@ -47,6 +49,7 @@ def get_context_and_pipeline( docker_image, params, storage_account_key, + extra_env, ) az_pipeline = generator.generate() yield mgr, az_pipeline @@ -129,3 +132,11 @@ def verify_configuration_directory_for_azure(click_context, ctx: CliContext): ) if not click.confirm(click.style(msg, fg="yellow")): click_context.exit(2) + + +def parse_extra_env_params(extra_env): + for entry in extra_env: + if not re.match("[A-Za-z0-9_]+=.*", entry): + raise Exception(f"Invalid env-var: {entry}, expected format: KEY=VALUE") + + return {(e := entry.split("="))[0]: e[1] for entry in extra_env} diff --git 
a/kedro_azureml/config.py b/kedro_azureml/config.py index 185b75e..7252d9f 100644 --- a/kedro_azureml/config.py +++ b/kedro_azureml/config.py @@ -78,7 +78,7 @@ class KedroAzureRunnerConfig(BaseModel): code_directory: "." # Path to the directory in the Docker image to run the code from # Ignored when code_directory is set - working_directory: /home/kedro + working_directory: /home/kedro_docker # Temporary storage settings - this is used to pass some data between steps # if the data is not specified in the catalog directly diff --git a/kedro_azureml/generator.py b/kedro_azureml/generator.py index ac9a3a6..f69f3e0 100644 --- a/kedro_azureml/generator.py +++ b/kedro_azureml/generator.py @@ -49,6 +49,7 @@ def __init__( docker_image: Optional[str] = None, params: Optional[str] = None, storage_account_key: Optional[str] = "", + extra_env: Dict[str, str] = {}, ): self.storage_account_key = storage_account_key self.kedro_environment = kedro_environment @@ -59,6 +60,7 @@ def __init__( self.docker_image = docker_image self.config = config self.pipeline_name = pipeline_name + self.extra_env = extra_env def generate(self) -> Job: pipeline = self.get_kedro_pipeline() @@ -183,6 +185,7 @@ def _construct_azure_command( run_id=kedro_azure_run_id, storage_account_key=self.storage_account_key, ).json(), + **self.extra_env, }, environment=self._resolve_azure_environment(), # TODO: check whether Environment exists inputs={ diff --git a/tests/conf/e2e/azureml.yml b/tests/conf/e2e/azureml.yml index fc9f333..b3ba158 100644 --- a/tests/conf/e2e/azureml.yml +++ b/tests/conf/e2e/azureml.yml @@ -13,7 +13,7 @@ azure: code_directory: null # Path to the directory in the Docker image to run the code from # Ignored when code_directory is set - working_directory: /home/kedro + working_directory: /home/kedro_docker # Temporary storage settings - this is used to pass some data between steps # if the data is not specified in the catalog directly @@ -36,4 +36,4 @@ azure: # cluster_name: "" docker: # 
Docker image to use during pipeline execution - image: "{container_registry}/kedro-azureml-e2e:latest" \ No newline at end of file + image: "{container_registry}/kedro-azureml-e2e:latest" diff --git a/tests/test_cli.py b/tests/test_cli.py index 1a73552..42f0b08 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -174,6 +174,7 @@ def test_can_invoke_execute_cli( ) @pytest.mark.parametrize("amlignore", ("empty", "missing", "filled")) @pytest.mark.parametrize("gitignore", ("empty", "missing", "filled")) +@pytest.mark.parametrize("extra_env", (([], {}), (["A=B", "C="], {"A": "B", "C": ""}))) def test_can_invoke_run( patched_kedro_package, cli_context, @@ -184,6 +185,7 @@ def test_can_invoke_run( use_default_credentials: bool, amlignore: str, gitignore: str, + extra_env: list, ): create_kedro_conf_dirs(tmp_path) with patch.dict( @@ -215,7 +217,8 @@ def test_can_invoke_run( cli.run, ["-s", "subscription_id"] + (["--wait-for-completion"] if wait_for_completion else []) - + (["--aml_env", aml_env] if aml_env else []), + + (["--aml_env", aml_env] if aml_env else []) + + (sum([["--env-var", k] for k in extra_env[0]], [])), obj=cli_context, ) assert result.exit_code == 0 @@ -234,6 +237,13 @@ def test_can_invoke_run( else: interactive_credentials.assert_not_called() + created_pipeline = ml_client.jobs.create_or_update.call_args[0][0] + populated_env_vars = list(created_pipeline.jobs.values())[ + 0 + ].environment_variables + del populated_env_vars["KEDRO_AZURE_RUNNER_CONFIG"] + assert populated_env_vars == extra_env[1] + @pytest.mark.parametrize( "kedro_environment_name", @@ -320,3 +330,33 @@ def test_can_invoke_run_with_failed_pipeline( ml_client.jobs.create_or_update.assert_called_once() ml_client.compute.get.assert_called_once() ml_client.jobs.stream.assert_called_once() + + +@pytest.mark.parametrize("env_var", ("INVALID", "2+2=4")) +def test_fail_if_invalid_env_provided_in_run( + patched_kedro_package, + cli_context, + dummy_pipeline, + tmp_path: Path, + env_var: 
str, +): + create_kedro_conf_dirs(tmp_path) + with patch.dict( + "kedro.framework.project.pipelines", {"__default__": dummy_pipeline} + ), patch.object(Path, "cwd", return_value=tmp_path), patch( + "kedro_azureml.client.MLClient" + ) as ml_client_patched, patch( + "kedro_azureml.client.DefaultAzureCredential" + ), patch.dict( + os.environ, {"AZURE_STORAGE_ACCOUNT_KEY": "dummy_key"} + ): + ml_client = ml_client_patched.from_config() + ml_client.jobs.stream.side_effect = ValueError() + + runner = CliRunner() + result = runner.invoke(cli.run, ["--env-var", env_var], obj=cli_context) + assert result.exit_code == 1 + assert ( + str(result.exception) + == f"Invalid env-var: {env_var}, expected format: KEY=VALUE" + ) diff --git a/tests/test_generator.py b/tests/test_generator.py index 90c72b4..cd409df 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -120,3 +120,21 @@ def test_get_target_resource_from_node_tags_raises_exception( ) with pytest.raises(ConfigException): generator.get_target_resource_from_node_tags(node) + + +def test_azure_pipeline_with_custom_env_vars(dummy_plugin_config, dummy_pipeline): + pipeline_name = "unit_test_pipeline" + node = MagicMock() + node.tags = ["compute-2", "compute-3"] + for t in node.tags: + dummy_plugin_config.azure.compute[t] = ComputeConfig(**{"cluster_name": t}) + with patch.dict( + "kedro.framework.project.pipelines", {pipeline_name: dummy_pipeline} + ): + generator = AzureMLPipelineGenerator( + pipeline_name, "local", dummy_plugin_config, {}, extra_env={"ABC": "def"} + ) + + for node in generator.generate().jobs.values(): + assert "ABC" in node.environment_variables + assert node.environment_variables["ABC"] == "def"