diff --git a/CHANGELOG.md b/CHANGELOG.md index bb55b6bc3..b4b17f63a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,10 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y - ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Enable automatic job recovery from last recovery checkpoint in case of job pre-emption on AML. Give the possibility to the user to keep more than one recovery checkpoint. +- ([#442](https://github.com/microsoft/InnerEye-DeepLearning/pull/442)) Enable defining custom scalar losses + (`ScalarLoss.CustomClassification` and `CustomRegression`), prediction targets (`ScalarModelBase.target_names`), + and reporting (`ModelConfigBase.generate_custom_report()`) in scalar configs, providing more flexibility for defining + model configs with custom behaviour while leveraging the existing InnerEye workflows. ### Changed diff --git a/InnerEye/ML/lightning_models.py b/InnerEye/ML/lightning_models.py index 6087a09d1..490e693ee 100644 --- a/InnerEye/ML/lightning_models.py +++ b/InnerEye/ML/lightning_models.py @@ -178,6 +178,7 @@ def __init__(self, config: ScalarModelBase, *args: Any, **kwargs: Any) -> None: super().__init__(config, *args, **kwargs) self.model = config.create_model() raw_loss = model_util.create_scalar_loss_function(config) + self.posthoc_label_transform = config.get_posthoc_label_transform() if isinstance(config, SequenceModelBase): self.loss_fn = lambda model_output, loss: apply_sequence_model_loss(raw_loss, model_output, loss) self.target_indices = config.get_target_indices() @@ -186,7 +187,7 @@ def __init__(self, config: ScalarModelBase, *args: Any, **kwargs: Any) -> None: else: self.loss_fn = raw_loss self.target_indices = [] - self.target_names = config.class_names + self.target_names = config.target_names self.is_classification_model = config.is_classification_model self.use_mean_teacher_model = config.compute_mean_teacher_model self.is_binary_classification_or_regression = True if len(config.class_names) == 1 else False @@ -269,6 +270,7 @@ def training_or_validation_step(self, """ model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, self.target_indices, sample) labels = model_inputs_and_labels.labels + labels = self.posthoc_label_transform(labels) if is_training: logits = self.model(*model_inputs_and_labels.model_inputs) else: diff --git a/InnerEye/ML/model_config_base.py b/InnerEye/ML/model_config_base.py index 7973eca1e..745d248ba 100644 --- a/InnerEye/ML/model_config_base.py +++ b/InnerEye/ML/model_config_base.py @@ -248,6 +248,20 @@ def set_derived_model_properties(self, model: Any) -> None: """ pass + def generate_custom_report(self, report_dir: Path, train_metrics: Path, val_metrics: Path, + test_metrics: Path) -> Path: + """ + Enables creating a custom results report, given the metrics files written during model training and inference. + By default, this method is a no-op. + + :param report_dir: The output directory where the generated report should be saved. + :param train_metrics: The CSV file with training metrics. + :param val_metrics: The CSV file with validation metrics. + :param test_metrics: The CSV file with test metrics. + :return: The path to the generated report file. + """ + pass + class ModelTransformsPerExecutionMode: """ diff --git a/InnerEye/ML/model_testing.py b/InnerEye/ML/model_testing.py index 39364a7db..7b7421f4b 100644 --- a/InnerEye/ML/model_testing.py +++ b/InnerEye/ML/model_testing.py @@ -388,7 +388,7 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \ return SequenceMetricsDict.create(is_classification_model=config.is_classification_model, sequence_target_positions=config.sequence_target_positions) else: - return ScalarMetricsDict(hues=config.class_names, + return ScalarMetricsDict(hues=config.target_names, is_classification_metrics=config.is_classification_model) @@ -407,6 +407,7 @@ def classification_model_test(config: ScalarModelBase, :param model_proc: whether we are testing an ensemble or single model :return: InferenceMetricsForClassification object that contains metrics related for all of the checkpoint epochs. """ + posthoc_label_transform = config.get_posthoc_label_transform() def test_epoch(checkpoint_paths: List[Path]) -> Optional[MetricsDict]: pipeline = create_inference_pipeline(config=config, @@ -431,6 +432,7 @@ def test_epoch(checkpoint_paths: List[Path]) -> Optional[MetricsDict]: result = pipeline.predict(sample) model_output = result.posteriors label = result.labels.to(device=model_output.device) + label = posthoc_label_transform(label) sample_id = result.subject_ids[0] compute_scalar_metrics(metrics_dict, subject_ids=[sample_id], diff --git a/InnerEye/ML/reports/classification_crossval_report.ipynb b/InnerEye/ML/reports/classification_crossval_report.ipynb index 7b3a647d0..547a58714 100644 --- a/InnerEye/ML/reports/classification_crossval_report.ipynb +++ b/InnerEye/ML/reports/classification_crossval_report.ipynb @@ -220,7 +220,7 @@ "outputs": [], "source": [ "if not is_crossval_report and val_metrics_csv.is_file() and test_metrics_csv.is_file():\n", - " for prediction_target in config.class_names:\n", + " for prediction_target in config.target_names:\n", " print_header(f\"Class: {prediction_target}\", level=3)\n", " print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n", " k=number_best_and_worst_performing,\n", @@ -242,7 +242,7 @@ "outputs": [], "source": [ "if not is_crossval_report and val_metrics_csv.is_file() and test_metrics_csv.is_file():\n", - " for prediction_target in config.class_names:\n", + " for prediction_target in config.target_names:\n", " print_header(f\"Class: {prediction_target}\", level=3)\n", " plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n", " k=number_best_and_worst_performing, prediction_target=prediction_target,\n", diff --git a/InnerEye/ML/reports/classification_report.ipynb b/InnerEye/ML/reports/classification_report.ipynb index ee862b62a..49f612fe2 100644 --- a/InnerEye/ML/reports/classification_report.ipynb +++ b/InnerEye/ML/reports/classification_report.ipynb @@ -232,7 +232,7 @@ "outputs": [], "source": [ "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n", - " for prediction_target in config.class_names:\n", + " for prediction_target in config.target_names:\n", " print_header(f\"Class {prediction_target}\", level=3)\n", " print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n", " k=number_best_and_worst_performing,\n", @@ -255,7 +255,7 @@ "outputs": [], "source": [ "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n", - " for prediction_target in config.class_names:\n", + " for prediction_target in config.target_names:\n", " print_header(f\"Class {prediction_target}\", level=3)\n", " plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n", " k=number_best_and_worst_performing, prediction_target=prediction_target, config=config)" diff --git a/InnerEye/ML/reports/classification_report.py b/InnerEye/ML/reports/classification_report.py index 1e376714a..86000843d 100644 --- a/InnerEye/ML/reports/classification_report.py +++ b/InnerEye/ML/reports/classification_report.py @@ -81,6 +81,7 @@ def check_column_present(dataframe: pd.DataFrame, column: LoggingColumns) -> Non df = pd.read_csv(csv) df = df[df[LoggingColumns.Hue.value] == prediction_target] # Filter by prediction target + df = df[~df[LoggingColumns.Label.value].isna()] # Filter missing labels # Filter by crossval split index if crossval_split_index is not None: @@ -279,7 +280,7 @@ def plot_pr_and_roc_curves_from_csv(metrics_csv: Path, config: ScalarModelBase, :param is_crossval_report: If True, assumes CSV contains results for multiple cross-validation runs and plots the curves with median and confidence intervals. Otherwise, plots curves for a single run. """ - for prediction_target in config.class_names: + for prediction_target in config.target_names: print_header(f"Class: {prediction_target}", level=3) if is_crossval_report: all_metrics = [get_labels_and_predictions(metrics_csv, prediction_target, @@ -469,7 +470,7 @@ def print_metrics_for_all_prediction_targets(csv_to_set_optimal_threshold: Path, :param is_crossval_report: If True, assumes CSVs contain results for multiple cross-validation runs and prints the metrics along with means and standard deviations. Otherwise, prints metrics for a single run. """ - for prediction_target in config.class_names: + for prediction_target in config.target_names: print_header(f"Class: {prediction_target}", level=3) rows, header = get_metrics_table_for_prediction_target( csv_to_set_optimal_threshold=csv_to_set_optimal_threshold, @@ -484,7 +485,7 @@ def print_metrics_for_all_prediction_targets(csv_to_set_optimal_threshold: Path, def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_csv: Path, - prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Results: + prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Optional[Results]: """ Given the paths to the metrics files for the validation and test sets, get a list of true positives, false positives, false negatives and true negatives. @@ -495,12 +496,18 @@ def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_c """ df_val = read_csv_and_filter_prediction_target(val_metrics_csv, prediction_target) + if len(df_val) == 0: + return None + fpr, tpr, thresholds = roc_curve(df_val[LoggingColumns.Label.value], df_val[LoggingColumns.ModelOutput.value]) optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr) optimal_threshold = thresholds[optimal_idx] df_test = read_csv_and_filter_prediction_target(test_metrics_csv, prediction_target) + if len(df_test) == 0: + return None + df_test["predicted"] = df_test.apply(lambda x: int(x[LoggingColumns.ModelOutput.value] >= optimal_threshold), axis=1) @@ -516,7 +523,7 @@ def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_c def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int, - prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Results: + prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Optional[Results]: """ Get the top "k" best predictions (i.e. correct classifications where the model was the most certain) and the top "k" worst predictions (i.e. misclassifications where the model was the most confident). @@ -524,6 +531,8 @@ def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pat results = get_correct_and_misclassified_examples(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv, prediction_target=prediction_target) + if results is None: + return None # sort by model_output sorted = Results(true_positives=results.true_positives.sort_values(by=LoggingColumns.ModelOutput.value, @@ -553,6 +562,9 @@ def print_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: P test_metrics_csv=test_metrics_csv, k=k, prediction_target=prediction_target) + if results is None: + print_header("Empty validation or test set", level=2) + return print_header(f"Top {k} false positives", level=2) for index, (subject, model_output) in enumerate(zip(results.false_positives[LoggingColumns.Patient.value], @@ -729,6 +741,9 @@ def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pa test_metrics_csv=test_metrics_csv, k=k, prediction_target=prediction_target) + if results is None: + print_header("Empty validation or test set", level=4) + return test_metrics = pd.read_csv(test_metrics_csv, dtype=str) diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py index 67ef3e443..181ee84dd 100644 --- a/InnerEye/ML/run_ml.py +++ b/InnerEye/ML/run_ml.py @@ -873,7 +873,7 @@ def get_epoch_path(mode: ModelExecutionMode) -> Path: val_metrics=path_to_best_epoch_val, test_metrics=path_to_best_epoch_test) - if len(config.class_names) > 1: + if config.should_generate_multilabel_report(): generate_classification_multilabel_notebook( result_notebook=reports_dir / get_ipynb_report_name( f"{config.model_category.value}_multilabel"), @@ -883,6 +883,11 @@ def get_epoch_path(mode: ModelExecutionMode) -> Path: test_metrics=path_to_best_epoch_test) else: logging.info(f"Cannot create report for config of type {type(config)}.") + + config.generate_custom_report(report_dir=reports_dir, + train_metrics=path_to_best_epoch_train, + val_metrics=path_to_best_epoch_val, + test_metrics=path_to_best_epoch_test) except Exception as ex: print_exception(ex, "Failed to generated reporting notebook.") raise diff --git a/InnerEye/ML/scalar_config.py b/InnerEye/ML/scalar_config.py index 952db3493..a903806df 100644 --- a/InnerEye/ML/scalar_config.py +++ b/InnerEye/ML/scalar_config.py @@ -42,12 +42,15 @@ class ScalarLoss(Enum): BinaryCrossEntropyWithLogits = "BinaryCrossEntropyWithLogits" WeightedCrossEntropyWithLogits = "WeightedCrossEntropyWithLogits" MeanSquaredError = "MeanSquaredError" + CustomClassification = "CustomClassification" + CustomRegression = "CustomRegression" def is_classification_loss(self) -> bool: - return self == self.BinaryCrossEntropyWithLogits or self == self.WeightedCrossEntropyWithLogits + return self in {self.BinaryCrossEntropyWithLogits, self.WeightedCrossEntropyWithLogits, + self.CustomClassification} def is_regression_loss(self) -> bool: - return self == self.MeanSquaredError + return self in {self.MeanSquaredError, self.CustomRegression} @unique @@ -112,6 +115,14 @@ class ScalarModelBase(ModelConfigBase): "For binary classification, this field must be a list of size 1, and " "is by default ['Default'], but can optionally be set to a more descriptive " "name for the positive class.") + target_names: List[str] = param.List(class_=str, + default=None, + bounds=(1, None), + doc="The label names for each output target, used for logging metrics and " + "reporting results. If provided, the length of this list must match the " + "number of model outputs (and of transformed labels, if defined; see " + "get_posthoc_label_transform()). By default, this inherits the value of " + "class_names at initialisation.") aggregation_type: AggregationType = param.ClassSelector(default=AggregationType.Average, class_=AggregationType, doc="The type of global pooling aggregation to use between" " the encoder and the classifier.") @@ -214,6 +225,8 @@ def __init__(self, num_dataset_reader_workers: int = 0, **params: Any) -> None: "num_dataset_reader_workers to 0 as this is an AML run.") else: self.num_dataset_reader_workers = num_dataset_reader_workers + if self.target_names is None: + self.target_names = self.class_names def validate(self) -> None: if len(self.class_names) > 1 and not self.is_classification_model: @@ -240,6 +253,10 @@ def is_non_imaging_model(self) -> bool: """ return len(self.image_channels) == 0 + def should_generate_multilabel_report(self) -> bool: + """Determines whether to produce a multilabel report. Override this to implement custom behaviour.""" + return len(self.class_names) > 1 + def get_total_number_of_non_imaging_features(self) -> int: """Returns the total number of non imaging features expected in the input""" return self.get_total_number_of_numerical_non_imaging_features() + \ @@ -338,6 +355,14 @@ def get_label_transform(self) -> Union[Callable, List[Callable]]: """ return LabelTransformation.identity + def get_posthoc_label_transform(self) -> Callable: + """ + Return a transformation to apply to the labels after they are loaded, for computing losses, metrics, and + reports. The transformed labels refer to the config's target_names, if defined (class_names, otherwise). + If not overriden, this method does not change the loaded labels. + """ + return lambda x: x # no-op by default + def read_dataset_into_dataframe_and_pre_process(self) -> None: assert self.local_dataset is not None file_path = self.local_dataset / self.dataset_csv @@ -408,6 +433,12 @@ def get_total_number_of_training_samples(self) -> int: def create_model(self) -> Any: pass + def get_loss_function(self) -> Callable: + """Returns a custom loss function to be used with ScalarLoss.CustomClassification or CustomRegression.""" + assert self.loss_type in {ScalarLoss.CustomClassification, ScalarLoss.CustomRegression}, \ + f"get_loss_function() should be called only for custom loss types (received {self.loss_type})" + raise NotImplementedError(f"get_loss_function() must be implemented for loss type {self.loss_type}") + def get_post_loss_logits_normalization_function(self) -> Callable: """ Post loss normalization function to apply to the logits produced by the model. diff --git a/InnerEye/ML/utils/model_util.py b/InnerEye/ML/utils/model_util.py index 1e06bfc20..d182f3432 100644 --- a/InnerEye/ML/utils/model_util.py +++ b/InnerEye/ML/utils/model_util.py @@ -115,6 +115,8 @@ def create_scalar_loss_function(config: ScalarModelBase) -> torch.nn.Module: num_train_samples=config.get_total_number_of_training_samples()) elif config.loss_type == ScalarLoss.MeanSquaredError: return MSELoss() + elif config.loss_type == ScalarLoss.CustomClassification or config.loss_type == ScalarLoss.CustomRegression: + return config.get_loss_function() # type: ignore else: raise NotImplementedError(f"Loss type {config.loss_type} is not implemented") diff --git a/Tests/ML/models/test_scalar_model.py b/Tests/ML/models/test_scalar_model.py index 4bd7ef657..2f709cd84 100644 --- a/Tests/ML/models/test_scalar_model.py +++ b/Tests/ML/models/test_scalar_model.py @@ -56,7 +56,7 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol """ logging_to_stdout(logging.DEBUG) config = ClassificationModelForTesting() - config.class_names = [class_name] + config.class_names = config.target_names = [class_name] config.set_output_to(test_output_dirs.root_dir) # Train for 4 epochs, checkpoints at epochs 2 and 4 config.num_epochs = 4 diff --git a/Tests/ML/reports/test_classification_report.py b/Tests/ML/reports/test_classification_report.py index 8b5d30a14..2fef24568 100644 --- a/Tests/ML/reports/test_classification_report.py +++ b/Tests/ML/reports/test_classification_report.py @@ -301,6 +301,8 @@ def test_get_correct_and_misclassified_examples() -> None: results = get_correct_and_misclassified_examples(val_metrics_csv=val_metrics_file, test_metrics_csv=test_metrics_file) + assert results is not None # for mypy + true_positives = [item[LoggingColumns.Patient.value] for _, item in results.true_positives.iterrows()] assert all([i in true_positives for i in [3, 4, 5]]) @@ -323,6 +325,8 @@ def test_get_k_best_and_worst_performing() -> None: test_metrics_csv=test_metrics_file, k=2) + assert results is not None # for mypy + best_true_positives = [item[LoggingColumns.Patient.value] for _, item in results.true_positives.iterrows()] assert best_true_positives == [5, 4]