Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Enable overriding AzureConfig parameters from a LightningContainer (#589)
Browse files Browse the repository at this point in the history
  • Loading branch information
dccastro authored Nov 17, 2021
1 parent a9f3dd8 commit 0b1d68f
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ jobs that run in AzureML.
`NIH_COVID_BYOL` to specify the name of the SSL training dataset.
- ([#560](https://github.com/microsoft/InnerEye-DeepLearning/pull/560)) Added pre-commit hooks.
- ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Adding the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise robust models, running label cleaning simulation and loading our label cleaning benchmark datasets.
- ([#589](https://github.com/microsoft/InnerEye-DeepLearning/pull/589)) Added the `LightningContainer.update_azure_config()`
hook to enable overriding `AzureConfig` parameters from a container (e.g. `experiment_name`, `cluster`, `num_nodes`).

### Changed
- ([#576](https://github.com/microsoft/InnerEye-DeepLearning/pull/576)) The console output is no longer written to stdout.txt because AzureML handles that better now
Expand Down
14 changes: 14 additions & 0 deletions InnerEye/ML/lightning_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch.optim.lr_scheduler import _LRScheduler
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import GridParameterSampling, HyperDriveConfig, PrimaryMetricGoal, choice
from InnerEye.Azure.azure_config import AzureConfig

from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
from InnerEye.Common.generic_parsing import GenericConfig, create_from_matching_params
Expand Down Expand Up @@ -213,6 +214,19 @@ def get_parameter_search_hyperdrive_config(self, _: ScriptRunConfig) -> HyperDri
"""
raise NotImplementedError("Parameter search is not implemented. It should be implemented in a sub class if needed.")

def update_azure_config(self, azure_config: AzureConfig) -> None:
    """
    Hook that lets a LightningContainer override parameters of the AzureConfig in-place.
    This default implementation changes nothing; sub classes can assign new values to the
    attributes of `azure_config` here. The hook is invoked right after the AzureConfig and the
    container have been initialised.
    Take care when container class parameters are the source of the override values: if a container
    parameter has the same name as an AzureConfig parameter, a value given on the command line is
    consumed by the AzureConfig, while the container parameter keeps its default. Giving container
    parameters unique names avoids this clash.
    Do not store a reference to `azure_config` in order to modify its attributes at any other point,
    because that may lead to unexpected behaviour.
    :param azure_config: The initialised AzureConfig whose parameters to override in-place.
    """

def create_report(self) -> None:
"""
This method is called after training and testing has been completed. It can aggregate all files that were
Expand Down
4 changes: 4 additions & 0 deletions InnerEye/ML/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def parse_overrides_and_apply(c: object, previous_parser_result: ParserResult) -
self.lightning_container = InnerEyeContainer(config_or_container)
else:
raise ValueError(f"Don't know how to handle a loaded configuration of type {type(config_or_container)}")

# Allow overriding AzureConfig params from within the container.
self.lightning_container.update_azure_config(self.azure_config)

if azure_config.extra_code_directory:
exist = "exists" if Path(azure_config.extra_code_directory).exists() else "does not exist"
logging.info(f"extra_code_directory is {azure_config.extra_code_directory}, which {exist}")
Expand Down
64 changes: 64 additions & 0 deletions Tests/ML/test_lightning_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@
from unittest import mock

import pandas as pd
import param
import pytest
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive.runconfig import HyperDriveConfig
from pytorch_lightning import LightningModule

from InnerEye.Azure.azure_config import AzureConfig
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.deep_learning_config import ARGS_TXT, DatasetParams, WorkflowParams
from InnerEye.ML.lightning_base import InnerEyeContainer
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.runner import Runner
from Tests.ML.configs.DummyModel import DummyModel
from Tests.ML.configs.lightning_test_containers import (DummyContainerWithAzureDataset, DummyContainerWithHooks,
DummyContainerWithModel, DummyContainerWithPlainLightning)
Expand Down Expand Up @@ -345,3 +348,64 @@ def mocked_convert_channels_to_file_paths(
convert_channels_to_file_paths_mock.side_effect = mocked_convert_channels_to_file_paths
container.setup()
convert_channels_to_file_paths_mock.assert_called()


class DummyContainerWithAzureConfigOverrides(LightningContainer):
    """
    Test container that uses the `update_azure_config` hook to write three kinds of values into
    an AzureConfig: one from a uniquely named container parameter, one from a container parameter
    whose name clashes with an AzureConfig parameter, and one hard-coded value.
    """
    # Container parameter whose name differs from the AzureConfig parameter that it overrides.
    container_subscription_id: str = param.String("default-container-subscription-id")
    # Container parameters whose names clash with parameters in AzureConfig. Only `tenant_id` is
    # written back to the AzureConfig in the hook below, `application_id` is left alone.
    tenant_id: str = param.String("default-container-tenant-id")
    application_id: str = param.String("default-container-application-id")

    def update_azure_config(self, azure_config: AzureConfig) -> None:
        """
        Overwrite `subscription_id`, `tenant_id` and `experiment_name` of the given AzureConfig.
        :param azure_config: The AzureConfig object to modify in-place.
        """
        # Source is a container parameter with a different name than the target
        azure_config.subscription_id = self.container_subscription_id
        # Source is a container parameter with the same name as the target
        azure_config.tenant_id = self.tenant_id
        # Source is a fixed value
        azure_config.experiment_name = "hardcoded-experiment-name"


def test_override_azure_config_from_container() -> None:
    """
    Check which values end up in the runner's AzureConfig when a container overrides some of them in
    `update_azure_config`, and the same or other parameters are also given on the command line.
    """
    # Commandline options, some of them meant for the AzureConfig, some for the container.
    cli_options = {
        "--model": DummyContainerWithAzureConfigOverrides.__name__,
        "--model_configs_namespace": "Tests.ML.test_lightning_containers",
        "--container_subscription_id": "cli-container-subscription-id",
        "--subscription_id": "cli-subscription-id",
        "--tenant_id": "cli-tenant-id",
        "--application_id": "cli-application-id",
        "--experiment_name": "cli-experiment-name",
        "--workspace_name": "cli-workspace-name",
    }
    # sys.argv starts with the script name, an empty string takes that place here.
    args = [""]
    for flag, value in cli_options.items():
        args += [flag, value]
    with mock.patch("sys.argv", args):
        runner: Runner = default_runner()
        runner.parse_and_load_model()
    azure_config = runner.azure_config
    container = runner.lightning_container
    assert azure_config is not None
    assert container is not None

    # The order of precedence for AzureConfig parameters at present, highest first:
    # 1. Container
    # 2. CLI
    # 3. YAML
    # 4. AzureConfig defaults

    # ==== Parameters declared in the container ====
    # A container parameter with a unique name receives its CLI value, and that value is then
    # written to the AzureConfig
    assert (azure_config.subscription_id
            == container.container_subscription_id
            == "cli-container-subscription-id")

    # For a container parameter with a clashing name, the original AzureConfig consumes the CLI
    # value, and the container parameter stays at its default
    assert azure_config.application_id == "cli-application-id"
    assert container.application_id == "default-container-application-id"
    # The container default may however be written over the CLI value in AzureConfig afterwards.
    # Name clashes should be avoided to prevent this unexpected behaviour
    assert (azure_config.tenant_id
            == container.tenant_id
            == "default-container-tenant-id")

    # ==== Parameters declared only in AzureConfig ====
    # A value that is hard-coded in the container wins over the CLI value
    assert azure_config.experiment_name == "hardcoded-experiment-name"

    # AzureConfig parameters that the container does not override still take their CLI value
    assert azure_config.workspace_name == "cli-workspace-name"

0 comments on commit 0b1d68f

Please sign in to comment.