diff --git a/im2deep/__main__.py b/im2deep/__main__.py index a01508b..fe91858 100644 --- a/im2deep/__main__.py +++ b/im2deep/__main__.py @@ -101,6 +101,13 @@ def setup_logging(passed_level): default=2, help="Charge state to use for calibration. Only used if calibrate_per_charge is set to False.", ) +@click.option( + "--use_single_model", + type=click.BOOL, + default=False, + help="Use a single model for prediction.", +) + def main( psm_file: str, calibration_file: Optional[str] = None, @@ -108,6 +115,7 @@ def main( model_name: Optional[str] = "tims", log_level: Optional[str] = "info", n_jobs: Optional[int] = None, + use_single_model: Optional[bool] = False, calibrate_per_charge: Optional[bool] = True, use_charge_state: Optional[int] = 2, ): @@ -185,6 +193,7 @@ def main( calibrate_per_charge=calibrate_per_charge, use_charge_state=use_charge_state, n_jobs=n_jobs, + use_single_model=use_single_model, ) except IM2DeepError as e: LOGGER.error(e) diff --git a/im2deep/calibrate.py b/im2deep/calibrate.py index 5b028b3..9fb2987 100644 --- a/im2deep/calibrate.py +++ b/im2deep/calibrate.py @@ -84,7 +84,7 @@ def get_ccs_shift( """Calculating CCS shift based on {} overlapping peptide-charge pairs between PSMs and reference dataset""".format(both.shape[0]) ) - LOGGER.debug(both.columns) + # How much CCS in calibration data is larger than reference CCS, so predictions # need to be increased by this amount return 0 if both.shape[0] == 0 else np.mean(both["ccs_observed"] - both["CCS"]) diff --git a/im2deep/im2deep.py b/im2deep/im2deep.py index b1bc3d1..00e9894 100644 --- a/im2deep/im2deep.py +++ b/im2deep/im2deep.py @@ -20,6 +20,7 @@ def predict_ccs( model_name="tims", calibrate_per_charge=True, use_charge_state=2, + use_single_model=False, n_jobs=None, write_output=True, ): @@ -31,6 +32,8 @@ def predict_ccs( path_model = Path(__file__).parent / "models" / "TIMS" path_model_list = list(path_model.glob("*.hdf5")) + if use_single_model: + path_model_list = [path_model_list[0]] dlc = DeepLC(path_model=path_model_list, n_jobs=n_jobs, predict_ccs=True) LOGGER.info("Predicting CCS values...")