Skip to content

Commit

Permalink
Merge pull request #223 from iomega/add_compound_classes_part_of_sett…
Browse files Browse the repository at this point in the history
…ings

Add add compound classes to settings
  • Loading branch information
niekdejonge authored Nov 23, 2023
2 parents 1487ab1 + 48c04ac commit 0733c60
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 12 deletions.
6 changes: 4 additions & 2 deletions ms2query/benchmarking/k_fold_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
clean_normalize_and_split_annotated_spectra
from ms2query.create_new_library.split_data_for_training import (
select_spectra_per_unique_inchikey, split_spectra_in_random_inchikey_sets)
from ms2query.create_new_library.train_models import train_all_models
from ms2query.create_new_library.train_models import (SettingsTrainingModels,
train_all_models)
from ms2query.ms2library import create_library_object_from_one_dir
from ms2query.utils import (load_matchms_spectrum_objects_from_file,
save_pickled_file)
Expand Down Expand Up @@ -102,7 +103,8 @@ def train_models_and_test_result_from_k_fold_folder(k_fold_split_folder:str,
# Train all models
train_all_models(annotated_training_spectra,
unannotated_training_spectra,
models_folder)
models_folder,
SettingsTrainingModels({"add_compound_classes": False}))

# Generate test results
ms2library = create_library_object_from_one_dir(models_folder)
Expand Down
21 changes: 15 additions & 6 deletions ms2query/create_new_library/train_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

class SettingsTrainingModels:
def __init__(self,
settings):
settings: dict = None):
default_settings = {"ms2ds_fraction_validation_spectra": 30,
"ms2ds_epochs": 150,
"spec2vec_iterations": 30,
"ms2query_fraction_for_making_pairs": 40}
"ms2query_fraction_for_making_pairs": 40,
"add_compound_classes": True}
if settings:
for setting in settings:
assert setting in default_settings, \
Expand All @@ -32,15 +33,15 @@ def __init__(self,
self.ms2ds_epochs: int = default_settings["ms2ds_epochs"]
self.ms2query_fraction_for_making_pairs: int = default_settings["ms2query_fraction_for_making_pairs"]
self.spec2vec_iterations = default_settings["spec2vec_iterations"]
self.add_compound_classes = default_settings["add_compound_classes"]


def train_all_models(annotated_training_spectra,
unannotated_training_spectra,
output_folder,
other_settings: dict = None):
settings: SettingsTrainingModels):
if not os.path.isdir(output_folder):
os.mkdir(output_folder)
settings = SettingsTrainingModels(other_settings)
# set file names of new generated files
ms2deepscore_model_file_name = os.path.join(output_folder, "ms2deepscore_model.hdf5")
spec2vec_model_file_name = os.path.join(output_folder, "spec2vec_model.model")
Expand Down Expand Up @@ -76,7 +77,8 @@ def train_all_models(annotated_training_spectra,
library_files_creator = LibraryFilesCreator(annotated_training_spectra,
output_folder,
spec2vec_model_file_name,
ms2deepscore_model_file_name)
ms2deepscore_model_file_name,
add_compound_classes=settings.add_compound_classes)
library_files_creator.create_all_library_files()


Expand All @@ -92,17 +94,24 @@ def clean_and_train_models(spectrum_file: str,
The ion mode of the spectra you want to use for training the models, choose from "positive" or "negative"
:param output_folder:
The folder in which the models and library files are stored.
:param model_train_settings:
The settings used for training the models, options can be found in SettingsTrainingModels. If None is given
all the default settings are used. The options and default settings are:
{"ms2ds_fraction_validation_spectra": 30, "ms2ds_epochs": 150, "spec2vec_iterations": 30,
"ms2query_fraction_for_making_pairs": 40, "add_compound_classes": False}
"""
if not os.path.exists(output_folder):
os.mkdir(output_folder)
assert os.path.isdir(output_folder), "The specified folder is not a folder"
assert ion_mode in {"positive", "negative"}, "ion_mode should be set to 'positive' or 'negative'"

settings = SettingsTrainingModels(model_train_settings)

spectra = load_matchms_spectrum_objects_from_file(spectrum_file)
annotated_spectra, unnnotated_spectra = clean_normalize_and_split_annotated_spectra(spectra,
ion_mode,
do_pubchem_lookup=True)
train_all_models(annotated_spectra,
unnnotated_spectra,
output_folder,
model_train_settings)
settings)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import numpy as np
import pytest
import pandas as pd
import pytest
from matchms import Spectrum
from matchms.importing.load_from_mgf import load_from_mgf
from ms2query.ms2library import MS2Library
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ms2library.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
from ms2query.ms2library import MS2Library, create_library_object_from_one_dir
from ms2query.utils import (SettingsRunMS2Query, column_names_for_output)
from ms2query.utils import SettingsRunMS2Query, column_names_for_output
from tests.test_utils import check_correct_results_csv_file


Expand Down
3 changes: 2 additions & 1 deletion tests/test_train_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def test_train_all_models(path_to_general_test_files, tmp_path):
{"ms2ds_fraction_validation_spectra": 2,
"ms2ds_epochs": 2,
"spec2vec_iterations": 2,
"ms2query_fraction_for_making_pairs": 400}
"ms2query_fraction_for_making_pairs": 400,
"add_compound_classes": False}
)
ms2library = create_library_object_from_one_dir(models_folder)
assert isinstance(ms2library, MS2Library)
3 changes: 2 additions & 1 deletion tests/test_train_ms2query_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import numpy as np
import pandas as pd
import pytest
from matchms import Spectrum
from onnxruntime import InferenceSession
from ms2query.create_new_library.train_ms2query_model import (
DataCollectorForTraining, calculate_tanimoto_scores_with_library,
convert_to_onnx_model, train_ms2query_model, train_random_forest)
from ms2query.ms2library import MS2Library
from ms2query.utils import predict_onnx_model
from matchms import Spectrum


if sys.version_info < (3, 8):
pass
Expand Down

0 comments on commit 0733c60

Please sign in to comment.