diff --git a/tests/test_library_files_creator.py b/tests/test_library_files_creator.py index cf0a5e9a..005ece74 100644 --- a/tests/test_library_files_creator.py +++ b/tests/test_library_files_creator.py @@ -4,8 +4,6 @@ from ms2query.clean_and_filter_spectra import normalize_and_filter_peaks from ms2query.create_new_library.library_files_creator import \ LibraryFilesCreator -from ms2query.utils import (load_matchms_spectrum_objects_from_file, - load_pickled_file) def test_give_already_used_file_name(tmp_path, path_to_general_test_files, hundred_test_spectra): @@ -17,28 +15,22 @@ def test_give_already_used_file_name(tmp_path, path_to_general_test_files, hundr LibraryFilesCreator(hundred_test_spectra, tmp_path) -def test_store_ms2ds_embeddings(tmp_path, path_to_general_test_files, - hundred_test_spectra, - expected_ms2ds_embeddings): +def test_create_ms2ds_embeddings(tmp_path, path_to_general_test_files, + hundred_test_spectra, + expected_ms2ds_embeddings): """Tests store_ms2ds_embeddings""" base_file_name = os.path.join(tmp_path, '100_test_spectra') library_spectra = [normalize_and_filter_peaks(s) for s in hundred_test_spectra if s is not None] test_create_files = LibraryFilesCreator(library_spectra, base_file_name, ms2ds_model_file_name=os.path.join(path_to_general_test_files, 'ms2ds_siamese_210301_5000_500_400.hdf5')) - test_create_files.store_ms2ds_embeddings() - - new_embeddings_file_name = os.path.join(base_file_name, "ms2ds_embeddings.pickle") - assert os.path.isfile(new_embeddings_file_name), \ - "Expected file to be created" - # Test if correct embeddings are stored - embeddings = load_pickled_file(new_embeddings_file_name) + embeddings = test_create_files.create_ms2ds_embeddings() pd.testing.assert_frame_equal(embeddings, expected_ms2ds_embeddings, check_exact=False, atol=1e-5) -def test_store_s2v_embeddings(tmp_path, path_to_general_test_files, hundred_test_spectra, +def test_create_s2v_embeddings(tmp_path, path_to_general_test_files, hundred_test_spectra, expected_s2v_embeddings): """Tests store_ms2ds_embeddings""" base_file_name = os.path.join(tmp_path, '100_test_spectra') @@ -46,20 +38,20 @@ def test_store_s2v_embeddings(tmp_path, path_to_general_test_files, hundred_test test_create_files = LibraryFilesCreator(library_spectra, base_file_name, s2v_model_file_name=os.path.join(path_to_general_test_files, "100_test_spectra_s2v_model.model")) - test_create_files.store_s2v_embeddings() - - new_embeddings_file_name = os.path.join(base_file_name, "s2v_embeddings.pickle") - assert os.path.isfile(new_embeddings_file_name), \ - "Expected file to be created" - embeddings = load_pickled_file(new_embeddings_file_name) + embeddings = test_create_files.create_s2v_embeddings() pd.testing.assert_frame_equal(embeddings, expected_s2v_embeddings, check_exact=False, atol=1e-5) -def test_create_sqlite_file(tmp_path, path_to_general_test_files, hundred_test_spectra): +def test_create_sqlite_file_with_embeddings(hundred_test_spectra, path_to_general_test_files): test_create_files = LibraryFilesCreator( - hundred_test_spectra[:20], output_directory=os.path.join(tmp_path, '100_test_spectra'), - add_compound_classes=False) + hundred_test_spectra[:20], + output_directory=os.path.join(path_to_general_test_files), + add_compound_classes=False, + ms2ds_model_file_name=os.path.join(path_to_general_test_files, + 'ms2ds_siamese_210301_5000_500_400.hdf5'), + s2v_model_file_name=os.path.join(path_to_general_test_files, + "100_test_spectra_s2v_model.model") + ) test_create_files.create_sqlite_file() -