diff --git a/ms2query/benchmarking/k_fold_cross_validation.py b/ms2query/benchmarking/k_fold_cross_validation.py index c0138626..95a03b07 100644 --- a/ms2query/benchmarking/k_fold_cross_validation.py +++ b/ms2query/benchmarking/k_fold_cross_validation.py @@ -12,7 +12,8 @@ clean_normalize_and_split_annotated_spectra from ms2query.create_new_library.split_data_for_training import ( select_spectra_per_unique_inchikey, split_spectra_in_random_inchikey_sets) -from ms2query.create_new_library.train_models import train_all_models +from ms2query.create_new_library.train_models import (SettingsTrainingModels, + train_all_models) from ms2query.ms2library import create_library_object_from_one_dir from ms2query.utils import (load_matchms_spectrum_objects_from_file, save_pickled_file) @@ -102,7 +103,8 @@ def train_models_and_test_result_from_k_fold_folder(k_fold_split_folder:str, # Train all models train_all_models(annotated_training_spectra, unannotated_training_spectra, - models_folder) + models_folder, + SettingsTrainingModels({"add_compound_classes": False})) # Generate test results ms2library = create_library_object_from_one_dir(models_folder) diff --git a/ms2query/create_new_library/train_models.py b/ms2query/create_new_library/train_models.py index aaf99426..f0f53d17 100644 --- a/ms2query/create_new_library/train_models.py +++ b/ms2query/create_new_library/train_models.py @@ -18,11 +18,12 @@ class SettingsTrainingModels: def __init__(self, - settings): + settings: dict = None): default_settings = {"ms2ds_fraction_validation_spectra": 30, "ms2ds_epochs": 150, "spec2vec_iterations": 30, - "ms2query_fraction_for_making_pairs": 40} + "ms2query_fraction_for_making_pairs": 40, + "add_compound_classes": True} if settings: for setting in settings: assert setting in default_settings, \ @@ -32,15 +33,15 @@ def __init__(self, self.ms2ds_epochs: int = default_settings["ms2ds_epochs"] self.ms2query_fraction_for_making_pairs: int = 
default_settings["ms2query_fraction_for_making_pairs"] self.spec2vec_iterations = default_settings["spec2vec_iterations"] + self.add_compound_classes = default_settings["add_compound_classes"] def train_all_models(annotated_training_spectra, unannotated_training_spectra, output_folder, - other_settings: dict = None): + settings: SettingsTrainingModels): if not os.path.isdir(output_folder): os.mkdir(output_folder) - settings = SettingsTrainingModels(other_settings) # set file names of new generated files ms2deepscore_model_file_name = os.path.join(output_folder, "ms2deepscore_model.hdf5") spec2vec_model_file_name = os.path.join(output_folder, "spec2vec_model.model") @@ -76,7 +77,8 @@ def train_all_models(annotated_training_spectra, library_files_creator = LibraryFilesCreator(annotated_training_spectra, output_folder, spec2vec_model_file_name, - ms2deepscore_model_file_name) + ms2deepscore_model_file_name, + add_compound_classes=settings.add_compound_classes) library_files_creator.create_all_library_files() @@ -92,12 +94,19 @@ def clean_and_train_models(spectrum_file: str, The ion mode of the spectra you want to use for training the models, choose from "positive" or "negative" :param output_folder: The folder in which the models and library files are stored. + :param model_train_settings: + The settings used for training the models, options can be found in SettingsTrainingModels. If None is given + all the default settings are used. 
The options and default settings are: + {"ms2ds_fraction_validation_spectra": 30, "ms2ds_epochs": 150, "spec2vec_iterations": 30, + "ms2query_fraction_for_making_pairs": 40, "add_compound_classes": True} """ if not os.path.exists(output_folder): os.mkdir(output_folder) assert os.path.isdir(output_folder), "The specified folder is not a folder" assert ion_mode in {"positive", "negative"}, "ion_mode should be set to 'positive' or 'negative'" + settings = SettingsTrainingModels(model_train_settings) + spectra = load_matchms_spectrum_objects_from_file(spectrum_file) annotated_spectra, unnnotated_spectra = clean_normalize_and_split_annotated_spectra(spectra, ion_mode, @@ -105,4 +114,4 @@ def clean_and_train_models(spectrum_file: str, train_all_models(annotated_spectra, unnnotated_spectra, output_folder, - model_train_settings) + settings) diff --git a/tests/conftest.py b/tests/conftest.py index 358b7aa8..6e492d7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import os import numpy as np -import pytest import pandas as pd +import pytest from matchms import Spectrum from matchms.importing.load_from_mgf import load_from_mgf from ms2query.ms2library import MS2Library diff --git a/tests/test_ms2library.py b/tests/test_ms2library.py index ee4e022d..e37674e0 100644 --- a/tests/test_ms2library.py +++ b/tests/test_ms2library.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd from ms2query.ms2library import MS2Library, create_library_object_from_one_dir -from ms2query.utils import (SettingsRunMS2Query, column_names_for_output) +from ms2query.utils import SettingsRunMS2Query, column_names_for_output from tests.test_utils import check_correct_results_csv_file diff --git a/tests/test_train_models.py b/tests/test_train_models.py index 1d5efe72..61f6cc01 100644 --- a/tests/test_train_models.py +++ b/tests/test_train_models.py @@ -14,7 +14,8 @@ def test_train_all_models(path_to_general_test_files, tmp_path): {"ms2ds_fraction_validation_spectra": 2, 
"ms2ds_epochs": 2, "spec2vec_iterations": 2, - "ms2query_fraction_for_making_pairs": 400} + "ms2query_fraction_for_making_pairs": 400, + "add_compound_classes": False} ) ms2library = create_library_object_from_one_dir(models_folder) assert isinstance(ms2library, MS2Library) diff --git a/tests/test_train_ms2query_model.py b/tests/test_train_ms2query_model.py index 379a9599..296ef615 100644 --- a/tests/test_train_ms2query_model.py +++ b/tests/test_train_ms2query_model.py @@ -3,13 +3,14 @@ import numpy as np import pandas as pd import pytest +from matchms import Spectrum from onnxruntime import InferenceSession from ms2query.create_new_library.train_ms2query_model import ( DataCollectorForTraining, calculate_tanimoto_scores_with_library, convert_to_onnx_model, train_ms2query_model, train_random_forest) from ms2query.ms2library import MS2Library from ms2query.utils import predict_onnx_model -from matchms import Spectrum + if sys.version_info < (3, 8): pass