
Merge pull request #369 from MannLabs/calibration-metrics
Improved calibration and optimization stats
GeorgWa authored Nov 13, 2024
2 parents 2f882c5 + 73e3ca7 commit 2a31f7b
Showing 5 changed files with 84 additions and 17 deletions.
12 changes: 11 additions & 1 deletion alphadia/calibration/property.py
@@ -75,6 +75,7 @@ def __init__(
float(transform_deviation) if transform_deviation is not None else None
)
self.is_fitted = False
self.metrics = None

def __repr__(self) -> str:
return f"<Calibration {self.name}, is_fitted: {self.is_fitted}>"
@@ -175,7 +176,9 @@ def fit(self, dataframe: pd.DataFrame, plot: bool = False, **kwargs):
logging.exception(f"Could not fit estimator {self.name}: {e}")
return

if plot is True:
self._save_metrics(dataframe)

if plot:
self.plot(dataframe, **kwargs)

def predict(self, dataframe, inplace=True):
@@ -297,6 +300,13 @@ def deviation(self, dataframe: pd.DataFrame):
axis=1,
)

def _save_metrics(self, dataframe):
deviation = self.deviation(dataframe)
self.metrics = {
"median_accuracy": np.median(np.abs(deviation[:, 1])),
"median_precision": np.median(np.abs(deviation[:, 2])),
}

def ci(self, dataframe, ci: float = 0.95):
"""Calculate the residual deviation at the given confidence interval.
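
Usage note: after a successful fit(), the estimator now exposes the two summary statistics via its metrics attribute, so callers no longer need to recompute them from deviation(). The following is a minimal standalone sketch of the computation; the column semantics (column 1 carrying the accuracy-related deviation, column 2 the precision-related deviation per observation) are inferred from the indexing above and are not a documented API.

import numpy as np

def summarize_deviation(deviation: np.ndarray) -> dict:
    # Sketch of Calibration._save_metrics: median absolute value of the
    # assumed accuracy column (1) and precision column (2).
    return {
        "median_accuracy": float(np.median(np.abs(deviation[:, 1]))),
        "median_precision": float(np.median(np.abs(deviation[:, 2]))),
    }

# After mz_calibration.fit(mz_df), the same values are available directly as
# mz_calibration.metrics["median_accuracy"] / ["median_precision"].
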
66 changes: 54 additions & 12 deletions alphadia/outputtransform.py
@@ -942,6 +942,10 @@ def _build_run_stat_df(
folder, peptidecentric.PeptideCentricWorkflow.OPTIMIZATION_MANAGER_PATH
)

calibration_manager_path = os.path.join(
folder, peptidecentric.PeptideCentricWorkflow.CALIBRATION_MANAGER_PATH
)

if channels is None:
channels = [0]
out_df = []
@@ -956,31 +960,69 @@
"proteins": channel_df["pg"].nunique(),
}

if "weighted_mass_error" in channel_df.columns:
base_dict["ms1_accuracy"] = np.mean(channel_df["weighted_mass_error"])

if "cycle_fwhm" in channel_df.columns:
base_dict["fwhm_rt"] = np.mean(channel_df["cycle_fwhm"])

if "mobility_fwhm" in channel_df.columns:
base_dict["fwhm_mobility"] = np.mean(channel_df["mobility_fwhm"])

# collect optimization stats
base_dict["optimization.ms2_error"] = np.nan
base_dict["optimization.ms1_error"] = np.nan
base_dict["optimization.rt_error"] = np.nan
base_dict["optimization.mobility_error"] = np.nan

if os.path.exists(optimization_manager_path):
optimization_manager = manager.OptimizationManager(
path=optimization_manager_path
)

base_dict["ms2_error"] = optimization_manager.ms2_error
base_dict["ms1_error"] = optimization_manager.ms1_error
base_dict["rt_error"] = optimization_manager.rt_error
base_dict["mobility_error"] = optimization_manager.mobility_error
base_dict["optimization.ms2_error"] = optimization_manager.ms2_error
base_dict["optimization.ms1_error"] = optimization_manager.ms1_error
base_dict["optimization.rt_error"] = optimization_manager.rt_error
base_dict["optimization.mobility_error"] = (
optimization_manager.mobility_error
)

else:
logger.warning(f"Error reading optimization manager for {raw_name}")
base_dict["ms2_error"] = np.nan
base_dict["ms1_error"] = np.nan
base_dict["rt_error"] = np.nan
base_dict["mobility_error"] = np.nan

# collect calibration stats
base_dict["calibration.ms2_median_accuracy"] = np.nan
base_dict["calibration.ms2_median_precision"] = np.nan
base_dict["calibration.ms1_median_accuracy"] = np.nan
base_dict["calibration.ms1_median_precision"] = np.nan

if os.path.exists(calibration_manager_path):
calibration_manager = manager.CalibrationManager(
path=calibration_manager_path
)

if (
fragment_mz_estimator := calibration_manager.get_estimator(
"fragment", "mz"
)
) and (fragment_mz_metrics := fragment_mz_estimator.metrics):
base_dict["calibration.ms2_median_accuracy"] = fragment_mz_metrics[
"median_accuracy"
]
base_dict["calibration.ms2_median_precision"] = fragment_mz_metrics[
"median_precision"
]

if (
precursor_mz_estimator := calibration_manager.get_estimator(
"precursor", "mz"
)
) and (precursor_mz_metrics := precursor_mz_estimator.metrics):
base_dict["calibration.ms1_median_accuracy"] = precursor_mz_metrics[
"median_accuracy"
]
base_dict["calibration.ms1_median_precision"] = precursor_mz_metrics[
"median_precision"
]

else:
logger.warning(f"Error reading calibration manager for {raw_name}")

out_df.append(base_dict)

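
With this change the per-run stats table namespaces the optimization targets under optimization.* and adds four calibration.* columns reporting the median mass accuracy and precision of the fragment (MS2) and precursor (MS1) m/z estimators; all of them default to NaN when the corresponding manager cannot be read or the estimator carries no metrics. A hedged sketch of inspecting the new columns from the stats output (the stat.tsv path is an assumption; the tests build it as f"{output.STAT_OUTPUT}.tsv" inside the output folder):

import pandas as pd

# Assumed location of the stats output; adjust to your output folder.
stat_df = pd.read_csv("output/stat.tsv", sep="\t")

new_columns = [
    "optimization.ms2_error",
    "optimization.ms1_error",
    "optimization.rt_error",
    "optimization.mobility_error",
    "calibration.ms2_median_accuracy",
    "calibration.ms2_median_precision",
    "calibration.ms1_median_accuracy",
    "calibration.ms1_median_precision",
]

# Columns are always written; values are NaN where no manager or fitted
# estimator data was available for a run.
print(stat_df[new_columns].describe())
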
10 changes: 8 additions & 2 deletions alphadia/workflow/manager.py
@@ -2,6 +2,7 @@
import logging
import os
import pickle
import traceback
import typing
from collections import defaultdict
from copy import deepcopy
@@ -82,11 +83,16 @@ def save(self):
try:
with open(self.path, "wb") as f:
pickle.dump(self, f)
except Exception:
except Exception as e:
self.reporter.log_string(
f"Failed to save {self.__class__.__name__} to {self.path}",
f"Failed to save {self.__class__.__name__} to {self.path}: {str(e)}",
verbosity="error",
)
# Log the full traceback

self.reporter.log_string(
f"Traceback: {traceback.format_exc()}", verbosity="error"
)

def load(self):
"""Load the state from pickle file."""
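
The save path now reports the exception message and the full traceback instead of only a generic failure line. A minimal standalone sketch of the same pattern using the standard logging module (the actual implementation routes both messages through self.reporter.log_string, as shown above):

import logging
import pickle
import traceback

logger = logging.getLogger(__name__)

def save_state(obj, path: str) -> None:
    # Persist obj to path; on failure, log the cause and the full traceback.
    try:
        with open(path, "wb") as f:
            pickle.dump(obj, f)
    except Exception as e:
        logger.error(f"Failed to save {type(obj).__name__} to {path}: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
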
8 changes: 8 additions & 0 deletions tests/unit_tests/test_calibration_property.py
@@ -20,6 +20,8 @@ def test_uninitialized_calibration():
with pytest.raises(ValueError):
mz_calibration.fit(mz_df)

assert mz_calibration.metrics is None


def test_fit_predict_linear():
library_mz = np.linspace(100, 1000, 100)
@@ -38,6 +40,8 @@ def test_fit_predict_linear():
mz_calibration.predict(mz_df)

assert "calibrated_mz" in mz_df.columns
assert "median_accuracy" in mz_calibration.metrics
assert "median_precision" in mz_calibration.metrics


def test_fit_predict_loess():
@@ -57,6 +61,8 @@ def test_fit_predict_loess():
mz_calibration.predict(mz_df)

assert "calibrated_mz" in mz_df.columns
assert "median_accuracy" in mz_calibration.metrics
assert "median_precision" in mz_calibration.metrics


def test_save_load():
@@ -86,3 +92,5 @@ def test_save_load():
mz_calibration_loaded.predict(df_loaded)

assert np.allclose(df_original["calibrated_mz"], df_loaded["calibrated_mz"])
assert "median_accuracy" in mz_calibration.metrics
assert "median_precision" in mz_calibration.metrics
5 changes: 3 additions & 2 deletions tests/unit_tests/test_outputtransform.py
@@ -134,8 +134,9 @@ def test_output_transform():
os.path.join(temp_folder, f"{output.STAT_OUTPUT}.tsv"), sep="\t"
)
assert len(stat_df) == 3
assert stat_df["ms2_error"][0] == 6
assert stat_df["rt_error"][0] == 200

assert stat_df["optimization.ms2_error"][0] == 6
assert stat_df["optimization.rt_error"][0] == 200

assert all(
[
