Skip to content

Commit

Permalink
Update test_library_files_creator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
niekdejonge committed Nov 23, 2023
1 parent 8ba91b3 commit deb5665
Showing 1 changed file with 15 additions and 23 deletions.
38 changes: 15 additions & 23 deletions tests/test_library_files_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from ms2query.clean_and_filter_spectra import normalize_and_filter_peaks
from ms2query.create_new_library.library_files_creator import \
LibraryFilesCreator
from ms2query.utils import (load_matchms_spectrum_objects_from_file,
load_pickled_file)


def test_give_already_used_file_name(tmp_path, path_to_general_test_files, hundred_test_spectra):
Expand All @@ -17,49 +15,43 @@ def test_give_already_used_file_name(tmp_path, path_to_general_test_files, hundr
LibraryFilesCreator(hundred_test_spectra, tmp_path)


def test_store_ms2ds_embeddings(tmp_path, path_to_general_test_files,
hundred_test_spectra,
expected_ms2ds_embeddings):
def test_create_ms2ds_embeddings(tmp_path, path_to_general_test_files,
hundred_test_spectra,
expected_ms2ds_embeddings):
"""Tests store_ms2ds_embeddings"""
base_file_name = os.path.join(tmp_path, '100_test_spectra')
library_spectra = [normalize_and_filter_peaks(s) for s in hundred_test_spectra if s is not None]
test_create_files = LibraryFilesCreator(library_spectra, base_file_name,
ms2ds_model_file_name=os.path.join(path_to_general_test_files,
'ms2ds_siamese_210301_5000_500_400.hdf5'))
test_create_files.store_ms2ds_embeddings()

new_embeddings_file_name = os.path.join(base_file_name, "ms2ds_embeddings.pickle")
assert os.path.isfile(new_embeddings_file_name), \
"Expected file to be created"
# Test if correct embeddings are stored
embeddings = load_pickled_file(new_embeddings_file_name)
embeddings = test_create_files.create_ms2ds_embeddings()
pd.testing.assert_frame_equal(embeddings, expected_ms2ds_embeddings,
check_exact=False,
atol=1e-5)


def test_store_s2v_embeddings(tmp_path, path_to_general_test_files, hundred_test_spectra,
def test_create_s2v_embeddings(tmp_path, path_to_general_test_files, hundred_test_spectra,
expected_s2v_embeddings):
"""Tests store_ms2ds_embeddings"""
base_file_name = os.path.join(tmp_path, '100_test_spectra')
library_spectra = [normalize_and_filter_peaks(s) for s in hundred_test_spectra if s is not None]
test_create_files = LibraryFilesCreator(library_spectra, base_file_name,
s2v_model_file_name=os.path.join(path_to_general_test_files,
"100_test_spectra_s2v_model.model"))
test_create_files.store_s2v_embeddings()

new_embeddings_file_name = os.path.join(base_file_name, "s2v_embeddings.pickle")
assert os.path.isfile(new_embeddings_file_name), \
"Expected file to be created"
embeddings = load_pickled_file(new_embeddings_file_name)
embeddings = test_create_files.create_s2v_embeddings()
pd.testing.assert_frame_equal(embeddings, expected_s2v_embeddings,
check_exact=False,
atol=1e-5)


def test_create_sqlite_file(tmp_path, path_to_general_test_files, hundred_test_spectra):
def test_create_sqlite_file_with_embeddings(hundred_test_spectra, path_to_general_test_files):
test_create_files = LibraryFilesCreator(
hundred_test_spectra[:20], output_directory=os.path.join(tmp_path, '100_test_spectra'),
add_compound_classes=False)
hundred_test_spectra[:20],
output_directory=os.path.join(path_to_general_test_files),
add_compound_classes=False,
ms2ds_model_file_name=os.path.join(path_to_general_test_files,
'ms2ds_siamese_210301_5000_500_400.hdf5'),
s2v_model_file_name=os.path.join(path_to_general_test_files,
"100_test_spectra_s2v_model.model")
)
test_create_files.create_sqlite_file()

0 comments on commit deb5665

Please sign in to comment.