This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Vsalva/deepmil panda #619

Merged: 18 commits, Dec 14, 2021
Changes from 15 commits
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
@@ -27,9 +27,3 @@ repos:
rev: v1.5.7
hooks:
- id: autopep8

- repo: https://github.com/ambv/black
rev: 21.9b0
hooks:
- id: black
language_version: python3.7
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -28,6 +28,7 @@ jobs that run in AzureML.
- ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Added a parameter `pretraining_dataset_id` to
`NIH_COVID_BYOL` to specify the name of the SSL training dataset.
- ([#560](https://github.com/microsoft/InnerEye-DeepLearning/pull/560)) Added pre-commit hooks.
- ([#619](https://github.com/microsoft/InnerEye-DeepLearning/pull/619)) Add DeepMIL PANDA.
- ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Adding the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise robust models, running label cleaning simulation and loading our label cleaning benchmark datasets.
- ([#589](https://github.com/microsoft/InnerEye-DeepLearning/pull/589)) Add `LightningContainer.update_azure_config()`
hook to enable overriding `AzureConfig` parameters from a container (e.g. `experiment_name`, `cluster`, `num_nodes`).
136 changes: 95 additions & 41 deletions InnerEye/Azure/azure_util.py
@@ -52,14 +52,20 @@ def split_recovery_id(id: str) -> Tuple[str, str]:
"""
components = id.strip().split(EXPERIMENT_RUN_SEPARATOR)
if len(components) > 2:
raise ValueError("recovery_id must be in the format: 'experiment_name:run_id', but got: {}".format(id))
raise ValueError(
"recovery_id must be in the format: 'experiment_name:run_id', but got: {}".format(
id
)
)
elif len(components) == 2:
return components[0], components[1]
else:
recovery_id_regex = r"^(\w+)_\d+_[0-9a-f]+$|^(\w+)_\d+$"
match = re.match(recovery_id_regex, id)
if not match:
raise ValueError("The recovery ID was not in the expected format: {}".format(id))
raise ValueError(
"The recovery ID was not in the expected format: {}".format(id)
)
return (match.group(1) or match.group(2)), id
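For context, a minimal sketch of how `split_recovery_id` behaves; the experiment name and run ID below are made up for illustration:

```python
from InnerEye.Azure.azure_util import split_recovery_id

# Explicit "experiment:run" form.
experiment, run_id = split_recovery_id("my_experiment:my_experiment_1600000000_abcd1234")
# -> ("my_experiment", "my_experiment_1600000000_abcd1234")

# Bare run ID: the experiment name is inferred via the regex above.
experiment, run_id = split_recovery_id("my_experiment_1600000000_abcd1234")
# -> ("my_experiment", "my_experiment_1600000000_abcd1234")
```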


@@ -77,9 +83,15 @@ def fetch_run(workspace: Workspace, run_recovery_id: str) -> Run:
try:
experiment_to_recover = Experiment(workspace, experiment)
except Exception as ex:
raise Exception(f"Unable to retrieve run {run} in experiment {experiment}: {str(ex)}")
raise Exception(
f"Unable to retrieve run {run} in experiment {experiment}: {str(ex)}"
)
run_to_recover = fetch_run_for_experiment(experiment_to_recover, run)
logging.info("Fetched run #{} {} from experiment {}.".format(run, run_to_recover.number, experiment))
logging.info(
"Fetched run #{} {} from experiment {}.".format(
run, run_to_recover.number, experiment
)
)
return run_to_recover


@@ -94,9 +106,13 @@ def fetch_run_for_experiment(experiment_to_recover: Experiment, run_id: str) ->
except Exception:
available_runs = experiment_to_recover.get_runs()
available_ids = ", ".join([run.id for run in available_runs])
raise (Exception(
"Run {} not found for experiment: {}. Available runs are: {}".format(
run_id, experiment_to_recover.name, available_ids)))
raise (
Exception(
"Run {} not found for experiment: {}. Available runs are: {}".format(
run_id, experiment_to_recover.name, available_ids
)
)
)


def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]:
@@ -116,8 +132,11 @@ def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]:
return exp_runs


def fetch_child_runs(run: Run, status: Optional[str] = None,
expected_number_cross_validation_splits: int = 0) -> List[Run]:
def fetch_child_runs(
run: Run,
status: Optional[str] = None,
expected_number_cross_validation_splits: int = 0,
) -> List[Run]:
"""
Fetch child runs for the provided runs that have the provided AML status (or fetch all by default)
and have a run_recovery_id tag value set (this is to ignore superfluous AML infrastructure platform runs).
@@ -138,18 +157,25 @@ def fetch_child_runs(run: Run, status: Optional[str] = None,
if 0 < expected_number_cross_validation_splits != len(children_runs):
logging.warning(
f"The expected number of child runs was {expected_number_cross_validation_splits}."
f"Fetched only: {len(children_runs)} runs. Now trying to fetch them manually.")
run_ids_to_evaluate = [f"{create_run_recovery_id(run)}_{i}"
for i in range(expected_number_cross_validation_splits)]
children_runs = [fetch_run(run.experiment.workspace, id) for id in run_ids_to_evaluate]
f"Fetched only: {len(children_runs)} runs. Now trying to fetch them manually."
)
run_ids_to_evaluate = [
f"{create_run_recovery_id(run)}_{i}"
for i in range(expected_number_cross_validation_splits)
]
children_runs = [
fetch_run(run.experiment.workspace, id) for id in run_ids_to_evaluate
]
if status is not None:
children_runs = [child_run for child_run in children_runs if child_run.get_status() == status]
children_runs = [
child_run for child_run in children_runs if child_run.get_status() == status
]
return children_runs
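A hedged usage sketch for the reformatted `fetch_child_runs`; `parent_run` is assumed to be an `azureml.core.Run` for a cross-validation parent, and the argument values are hypothetical:

```python
from InnerEye.Azure.azure_util import fetch_child_runs

# Fetch only finished cross-validation children; "Completed" is the status
# string that Run.get_status() returns for finished AzureML runs.
completed_children = fetch_child_runs(
    run=parent_run,
    status="Completed",
    expected_number_cross_validation_splits=5,
)
```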


def is_ensemble_run(run: Run) -> bool:
"""Checks if the run was an ensemble of multiple models"""
return run.get_tags().get(IS_ENSEMBLE_KEY_NAME) == 'True'
return run.get_tags().get(IS_ENSEMBLE_KEY_NAME) == "True"


def to_azure_friendly_string(x: Optional[str]) -> Optional[str]:
@@ -160,7 +186,7 @@ def to_azure_friendly_string(x: Optional[str]) -> Optional[str]:
if x is None:
return x
else:
return re.sub('_+', '_', re.sub(r'\W+', '_', x))
return re.sub("_+", "_", re.sub(r"\W+", "_", x))


def to_azure_friendly_container_path(path: Path) -> str:
@@ -178,7 +204,7 @@ def is_offline_run_context(run_context: Run) -> bool:
:param run_context: Context of the run to check
:return:
"""
return not hasattr(run_context, 'experiment')
return not hasattr(run_context, "experiment")


def get_run_context_or_default(run: Optional[Run] = None) -> Run:
@@ -199,7 +225,12 @@ def get_cross_validation_split_index(run: Run) -> int:
if is_offline_run_context(run):
return DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
else:
return int(run.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX))
return int(
run.get_tags().get(
CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY,
DEFAULT_CROSS_VALIDATION_SPLIT_INDEX,
)
)


def is_cross_validation_child_run(run: Run) -> bool:
@@ -220,7 +251,7 @@ def strip_prefix(string: str, prefix: str) -> str:
:return: Input string with prefix removed.
"""
if string.startswith(prefix):
return string[len(prefix):]
return string[len(prefix) :]
return string


@@ -256,9 +287,7 @@ def is_parent_run(run: Run) -> bool:
return PARENT_RUN_CONTEXT and run.id == PARENT_RUN_CONTEXT.id


def download_run_output_file(blob_path: Path,
destination: Path,
run: Run) -> Path:
def download_run_output_file(blob_path: Path, destination: Path, run: Run) -> Path:
"""
Downloads a single file from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs").
For example, if blobs_path = "foo/bar.csv", then the run result file "outputs/foo/bar.csv" will be downloaded
@@ -270,17 +299,21 @@ def download_run_output_file(blob_path: Path,
"""
blobs_prefix = str((fixed_paths.DEFAULT_AML_UPLOAD_DIR / blob_path).as_posix())
destination = destination / blob_path.name
logging.info(f"Downloading single file from run {run.id}: {blobs_prefix} -> {str(destination)}")
logging.info(
f"Downloading single file from run {run.id}: {blobs_prefix} -> {str(destination)}"
)
try:
run.download_file(blobs_prefix, str(destination), _validate_checksum=True)
except Exception as ex:
raise ValueError(f"Unable to download file '{blobs_prefix}' from run {run.id}") from ex
raise ValueError(
f"Unable to download file '{blobs_prefix}' from run {run.id}"
) from ex
return destination
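A hypothetical call following the docstring above; the paths are illustrative and `run` is assumed to be an `azureml.core.Run` obtained elsewhere (for example via `fetch_run`):

```python
from pathlib import Path
from InnerEye.Azure.azure_util import download_run_output_file

# Downloads "outputs/foo/bar.csv" from the run into downloads/bar.csv
# and returns that local path.
local_csv = download_run_output_file(
    blob_path=Path("foo/bar.csv"),
    destination=Path("downloads"),
    run=run,  # an azureml.core.Run
)
```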


def download_run_outputs_by_prefix(blobs_prefix: Path,
destination: Path,
run: Run) -> None:
def download_run_outputs_by_prefix(
blobs_prefix: Path, destination: Path, run: Run
) -> None:
"""
Download all the blobs from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs") that
have a given prefix (folder structure). When saving, the prefix string will be stripped off. For example,
@@ -291,19 +324,25 @@ def download_run_outputs_by_prefix(blobs_prefix: Path,
:param destination: Local path to save the downloaded blobs to.
"""
prefix_str = str((fixed_paths.DEFAULT_AML_UPLOAD_DIR / blobs_prefix).as_posix())
logging.info(f"Downloading multiple files from run {run.id}: {prefix_str} -> {str(destination)}")
logging.info(
f"Downloading multiple files from run {run.id}: {prefix_str} -> {str(destination)}"
)
# There is a download_files function, but that can time out when downloading several large checkpoints file
# (120sec timeout for all files).
for file in run.get_file_names():
if file.startswith(prefix_str):
target_path = file[len(prefix_str):]
target_path = file[len(prefix_str) :]
if target_path.startswith("/"):
target_path = target_path[1:]
logging.info(f"Downloading {file}")
run.download_file(file, str(destination / target_path), _validate_checksum=True)
run.download_file(
file, str(destination / target_path), _validate_checksum=True
)
else:
logging.warning(f"Skipping file {file}, because the desired prefix {prefix_str} is not aligned with "
f"the folder structure")
logging.warning(
f"Skipping file {file}, because the desired prefix {prefix_str} is not aligned with "
f"the folder structure"
)


def is_running_on_azure_agent() -> bool:
@@ -314,10 +353,9 @@ def is_running_on_azure_agent() -> bool:
return bool(os.environ.get("AGENT_OS", None))


def get_comparison_baseline_paths(outputs_folder: Path,
blob_path: Path, run: Run,
dataset_csv_file_name: str) -> \
Tuple[Optional[Path], Optional[Path]]:
def get_comparison_baseline_paths(
outputs_folder: Path, blob_path: Path, run: Run, dataset_csv_file_name: str
) -> Tuple[Optional[Path], Optional[Path]]:
run_rec_id = run.id
# We usually find dataset.csv in the same directory as metrics.csv, but we sometimes
# have to look higher up.
@@ -328,21 +366,29 @@ def get_comparison_baseline_paths(outputs_folder: Path,
for blob_path_parent in step_up_directories(blob_path):
try:
comparison_dataset_path = download_run_output_file(
blob_path_parent / dataset_csv_file_name, destination_folder, run)
blob_path_parent / dataset_csv_file_name, destination_folder, run
)
break
except (ValueError, UserErrorException):
logging.warning(f"cannot find {dataset_csv_file_name} at {blob_path_parent} in {run_rec_id}")
logging.warning(
f"cannot find {dataset_csv_file_name} at {blob_path_parent} in {run_rec_id}"
)
except NotADirectoryError:
logging.warning(f"{blob_path_parent} is not a directory")
break
if comparison_dataset_path is None:
logging.warning(f"cannot find {dataset_csv_file_name} at or above {blob_path} in {run_rec_id}")
logging.warning(
f"cannot find {dataset_csv_file_name} at or above {blob_path} in {run_rec_id}"
)
# Look for epoch_NNN/Test/metrics.csv
try:
comparison_metrics_path = download_run_output_file(
blob_path / SUBJECT_METRICS_FILE_NAME, destination_folder, run)
blob_path / SUBJECT_METRICS_FILE_NAME, destination_folder, run
)
except (ValueError, UserErrorException):
logging.warning(f"cannot find {SUBJECT_METRICS_FILE_NAME} at {blob_path} in {run_rec_id}")
logging.warning(
f"cannot find {SUBJECT_METRICS_FILE_NAME} at {blob_path} in {run_rec_id}"
)
return (comparison_dataset_path, comparison_metrics_path)


@@ -357,3 +403,11 @@ def step_up_directories(path: Path) -> Generator[Path, None, None]:
if parent == path:
break
path = parent
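Judging from its use in `get_comparison_baseline_paths`, `step_up_directories` yields the given path followed by each of its ancestors until the root is reached; a sketch under that assumption:

```python
from pathlib import Path
from InnerEye.Azure.azure_util import step_up_directories

list(step_up_directories(Path("epoch_010/Test")))
# -> [Path("epoch_010/Test"), Path("epoch_010"), Path(".")]
```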


def get_default_azure_config_json_path() -> Path:
"""
Gets the path to the project's default Azure config JSON file.
"""
azure_config_json_path = fixed_paths.repository_root_directory() / "config.json"
return azure_config_json_path
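A minimal sketch of the new helper; the printed path depends on where the repository is checked out:

```python
from InnerEye.Azure.azure_util import get_default_azure_config_json_path

config_json = get_default_azure_config_json_path()
print(config_json)  # <repository root>/config.json
```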
5 changes: 4 additions & 1 deletion InnerEye/ML/Histopathology/datamodules/panda_module.py
@@ -3,7 +3,7 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Tuple
from typing import Tuple, Any

from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
@@ -15,6 +15,9 @@ class PandaTilesDataModule(TilesDataModule):
Method get_splits() returns the train, val, test splits from the PANDA dataset
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]:
dataset = PandaTilesDataset(self.root_path)
splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(),
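A rough construction sketch; the `root_path` keyword and the dataset location are assumptions inferred from the `self.root_path` used in `get_splits()`, not a confirmed constructor signature:

```python
from pathlib import Path
from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule

# Hypothetical usage, for illustration only; constructor kwargs are assumptions.
data_module = PandaTilesDataModule(root_path=Path("/datasets/PANDA"))
train_dataset, val_dataset, test_dataset = data_module.get_splits()
```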
6 changes: 3 additions & 3 deletions InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py
@@ -30,10 +30,10 @@ class PandaTilesDataset(TilesDataset):
SPLIT_COLUMN = None # PANDA does not have an official train/test split
N_CLASSES = 6

_RELATIVE_ROOT_FOLDER = "PANDA_tiles_20210926-135446/panda_tiles_level1_224"
_RELATIVE_ROOT_FOLDER = Path("PANDA_tiles_20210926-135446/panda_tiles_level1_224")

def __init__(self,
root: Union[str, Path],
root: Path,
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None) -> None:
super().__init__(root=Path(root) / self._RELATIVE_ROOT_FOLDER,
@@ -48,7 +48,7 @@ class PandaTilesDatasetReturnImageLabel(VisionDataset):
class label.
"""
def __init__(self,
root: Union[str, Path],
root: Path,
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
transform: Optional[Callable] = None,
8 changes: 6 additions & 2 deletions InnerEye/ML/Histopathology/models/deepmil.py
@@ -166,7 +166,7 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK
bag_labels_list = []
bag_logits_list = []
bag_attn_list = []
for bag_idx in range(len(batch[TilesDataset.LABEL_COLUMN])):
for bag_idx in range(len(batch[self.label_column])):
images = batch[TilesDataset.IMAGE_COLUMN][bag_idx]
labels = batch[self.label_column][bag_idx]
bag_labels_list.append(self.get_bag_label(labels))
@@ -177,7 +177,7 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK
bag_labels = torch.stack(bag_labels_list).view(-1)

if self.n_classes > 1:
loss = self.loss_fn(bag_logits, bag_labels)
loss = self.loss_fn(bag_logits, bag_labels.long())
else:
loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())

@@ -201,6 +201,10 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK
ResultsKey.PROB: probs, ResultsKey.PRED_LABEL: preds,
ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list,
ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]})
if (TilesDataset.TILE_X_COLUMN in batch.keys()) and (TilesDataset.TILE_Y_COLUMN in batch.keys()):
results.update({ResultsKey.TILE_X: batch[TilesDataset.TILE_X_COLUMN],
ResultsKey.TILE_Y: batch[TilesDataset.TILE_Y_COLUMN]}
)
return results
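With this change, downstream code can check for the optional coordinate entries in the returned dictionary; a small hedged sketch of a hypothetical consumer:

```python
from InnerEye.ML.Histopathology.utils.naming import ResultsKey

# `results` is the dictionary returned by _shared_step; the coordinate keys are
# only present when the batch carried tile X/Y columns.
if ResultsKey.TILE_X in results and ResultsKey.TILE_Y in results:
    tile_coordinates = list(zip(results[ResultsKey.TILE_X], results[ResultsKey.TILE_Y]))
```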

def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
3 changes: 3 additions & 0 deletions InnerEye/ML/Histopathology/utils/naming.py
@@ -15,3 +15,6 @@ class ResultsKey(str, Enum):
PRED_LABEL = 'pred_label'
TRUE_LABEL = 'true_label'
BAG_ATTN = 'bag_attn'
TILE_X = "x"
TILE_Y = "y"
