diff --git a/biotrainer/embedders/embedding_service.py b/biotrainer/embedders/embedding_service.py
index 663c34ae..e688e952 100644
--- a/biotrainer/embedders/embedding_service.py
+++ b/biotrainer/embedders/embedding_service.py
@@ -1,6 +1,7 @@
 import os
 import gc
 import time
+import psutil
 import h5py
 import torch
 import logging
@@ -9,17 +10,20 @@
 from tqdm import tqdm
 from pathlib import Path
 from numpy import ndarray
-from typing import Dict, Any, List, Optional
+from typing import Dict, Tuple, Any, List, Optional
 
 from .embedder_interfaces import EmbedderInterface
 
 from ..protocols import Protocol
-from ..utilities import read_FASTA, SAVE_AFTER_N_EMBEDDINGS
+from ..utilities import read_FASTA
 
 logger = logging.getLogger(__name__)
 
 
 class EmbeddingService:
+    """
+    A service class for computing embeddings using a provided embedder.
+    """
 
     def __init__(self, embedder: EmbedderInterface = None, use_half_precision: bool = False):
         self._embedder = embedder
@@ -29,15 +33,18 @@ def compute_embeddings(self, sequence_file: str, output_dir: Path, protocol: Pro
                            force_output_dir: Optional[bool] = False,
                            force_recomputing: Optional[bool] = False) -> str:
         """
-        Compute embeddings with the provided embedder from file.
+        Compute embeddings with the provided embedder from a sequence file.
 
-        :param sequence_file: Path to the sequence file
-        :param output_dir: Output directory to store the computed embeddings
-        :param protocol: Protocol for the embeddings. Determines if the embeddings should be reduced to per-protein
-        :param force_output_dir: If True, the given output directory is directly used to store the embeddings file,
-            without any path enhancing
-        :param force_recomputing: If True, the embedding file is re-recomputed, even if it already exists
-        :return: Path to the generated output embeddings file
+        Parameters:
+            sequence_file (str): Path to the sequence file.
+            output_dir (Path): Output directory to store the computed embeddings.
+            protocol (Protocol): Protocol for the embeddings. Determines if the embeddings should be reduced to per-protein.
+            force_output_dir (bool, optional): If True, the given output directory is directly used to store the embeddings file,
+                without any path enhancement. Defaults to False.
+            force_recomputing (bool, optional): If True, the embedding file is re-computed, even if it already exists. Defaults to False.
+
+        Returns:
+            str: Path to the generated output embeddings file.
         """
         use_reduced_embeddings = protocol in Protocol.using_per_sequence_embeddings()
         embedder_name = self._embedder.name.split("/")[-1]
@@ -68,45 +75,165 @@ def compute_embeddings(self, sequence_file: str, output_dir: Path, protocol: Pro
         protein_sequences = {seq.id: str(seq.seq) for seq in sorted(read_FASTA(sequence_file),
                                                                      key=lambda seq: len(seq.seq),
                                                                      reverse=True)}
+
+        embeddings_file_path = self._do_embeddings_computation(protein_sequences, embeddings_file_path,
+                                                               use_reduced_embeddings)
+
+        return str(embeddings_file_path)
+
+    def _do_embeddings_computation(self, protein_sequences: Dict[str, str], embeddings_file_path: Path,
+                                   use_reduced_embeddings: bool) -> str:
+        """
+        Performs the embedding computation for the given protein sequences.
+
+        Parameters:
+            protein_sequences (Dict[str, str]): A dictionary mapping sequence IDs to protein sequences.
+            embeddings_file_path (Path): The path where embeddings will be saved.
+            use_reduced_embeddings (bool): Indicates if reduced embeddings should be used.
+
+        Returns:
+            str: The path to the embeddings file.
+        """
         sequence_ids = list(protein_sequences.keys())
         embeddings = {}
         idx: int = 0
         last_save_id: int = 0
-        for embedding in tqdm(self._embedder.embed_many(protein_sequences.values()),
-                              total=len(protein_sequences.values())):
-            embeddings[sequence_ids[idx]] = embedding
-            idx += 1
-            if len(embeddings) > SAVE_AFTER_N_EMBEDDINGS:
-                if use_reduced_embeddings:
-                    embeddings = self._reduce_embeddings(embeddings, self._embedder)
-                last_save_id = self._save_embeddings(save_id=last_save_id, embeddings=embeddings,
-                                                     embeddings_file_path=embeddings_file_path)
-
-                # Manually clear to ensure that embeddings are deleted from RAM
-                del embeddings
-                embeddings = {}
-
+        start_time = time.time()
+
+        embedding_iter = self._embedder.embed_many(protein_sequences.values())
+        total_sequences = len(protein_sequences.values())
+
+        logger.info("If your dataset contains long sequences, the first few embeddings may take more time to compute.")
+
+        with tqdm(total=total_sequences, desc="Computing Embeddings") as pbar:
+
+            # Load the first sequence and calculate the initial max_embedding_fit
+            embeddings[sequence_ids[0]] = next(embedding_iter, None)
+            pbar.update(1)
+
+            if embeddings[sequence_ids[0]] is None:
+                logger.info("No embeddings found.")
+                return str(embeddings_file_path)
+
+            max_embedding_fit = self._max_embedding_fit(embeddings[sequence_ids[0]])
+
+            embedding_dimension = embeddings[sequence_ids[0]].shape[-1]
+
+            # Load other sequences
+            for idx in range(1, total_sequences):
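+                # Flush the buffered embeddings to disk once the buffer holds a multiple of
+                # max_embedding_fit embeddings, right before the last sequence is embedded,
+                # or on every iteration when the estimated fit is very small (<= 3).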
+                if max_embedding_fit <= 3 or len(embeddings) % max_embedding_fit == 0 or idx == total_sequences - 1:
+                    last_save_id, embeddings = self._save_and_reset_embeddings(embeddings, last_save_id,
+                                                                               embeddings_file_path, use_reduced_embeddings)
+
+                    embeddings[sequence_ids[idx]] = next(embedding_iter, None)
+                    pbar.update(1)
+
+                    # Calculate the new max_embedding_fit for the next batch
+                    if embeddings[sequence_ids[idx]] is not None:
+                        max_embedding_fit = self._max_embedding_fit(embeddings[sequence_ids[idx]])
+                        logger.debug(f"New {max_embedding_fit=}")
+
+                else:
+                    embeddings[sequence_ids[idx]] = next(embedding_iter, None)
+                    pbar.update(1)
+
+                if embeddings[sequence_ids[idx]] is None:
+                    logger.debug("len(sequence_ids) > len(embedding_iter) or found a None value in the embedding_iter")
+                    del embeddings[sequence_ids[idx]]
+                    return str(embeddings_file_path)
+
+        logger.info(f"Embedding dimension: {embedding_dimension}")
+
         # Save remaining embeddings
         if len(embeddings) > 0:
-            if use_reduced_embeddings:
-                embeddings = self._reduce_embeddings(embeddings, self._embedder)
-            _ = self._save_embeddings(save_id=last_save_id, embeddings=embeddings,
-                                      embeddings_file_path=embeddings_file_path)
+            last_save_id, embeddings = self._save_and_reset_embeddings(embeddings, last_save_id,
+                                                                       embeddings_file_path, use_reduced_embeddings)
 
-        # Delete embeddings and embedding model from memory now, because they will no longer be needed
+        end_time = time.time()
+        logger.info(f"Time elapsed for saving embeddings: {end_time - start_time:.2f}[s]")
+
         del embeddings
         del self._embedder
         gc.collect()
 
         return str(embeddings_file_path)
 
+    @staticmethod
+    def _max_embedding_fit(embedding: ndarray) -> int:
+        """
+        Calculates the maximum number of embeddings that can fit in available memory.
+
+        This function estimates the maximum number of embeddings that can be stored in
+        the available system memory without exceeding it. The calculation includes a
+        safety factor to prevent exhausting memory.
+
+        Parameters:
+            embedding (ndarray): An embedding array, representing the data structure
+                whose memory footprint is being considered.
+
+        Returns:
+            int: The maximum number of embeddings that can fit in memory.
+
+        Notes:
+            - The number 18 was determined experimentally as a factor correlating the
+              embedding size to the memory usage, indicating that each unit of
+              embedding size corresponds to approximately 18 bytes of memory.
+            - The multiplier 0.75 is a safety margin to ensure that the memory usage
+              stays within 75% of the available system memory, reducing the risk of
+              running out of RAM during operations.
+        """
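+        # Illustrative example (hypothetical numbers): for a 2000 x 1024 float embedding,
+        # embedding.size == 2_048_000; with ~8 GB of available RAM this yields roughly
+        # int(0.75 * 8e9 / (2_048_000 * 18)) == 162 embeddings buffered between saves.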
+        max_embedding_fit = int(0.75 * (psutil.virtual_memory().available / (embedding.size*18)))
+        max_embedding_fit = 1 if max_embedding_fit == 0 else max_embedding_fit
+        return max_embedding_fit
+
+    def _save_and_reset_embeddings(self, embeddings: Dict[str, ndarray], last_save_id: int,
+                                   embeddings_file_path: Path, use_reduced_embeddings: bool) -> Tuple[int, Dict[str, ndarray]]:
+        """
+        Save the embeddings and reset the dictionary.
+
+        Parameters:
+            embeddings (Dict[str, ndarray]): Dictionary of embeddings to be saved.
+            last_save_id (int): The last save ID used for tracking saved embeddings.
+            embeddings_file_path (Path): The path where embeddings are saved.
+            use_reduced_embeddings (bool): Flag to determine if embeddings should be reduced.
+
+        Returns:
+            out (Tuple[int, Dict[str, ndarray]]): Updated last_save_id and an empty embeddings dictionary.
+        """
+        if use_reduced_embeddings:
+            embeddings = self._reduce_embeddings(embeddings, self._embedder)
+        last_save_id = self._save_embeddings(save_id=last_save_id, embeddings=embeddings,
+                                             embeddings_file_path=embeddings_file_path)
+        del embeddings
+        return last_save_id, {}
+
     @staticmethod
     def _reduce_embeddings(embeddings: Dict[str, ndarray], embedder) -> Dict[str, ndarray]:
+        """
+        Reduces the per-residue embeddings to per-protein embeddings.
+
+        Parameters:
+            embeddings (Dict[str, ndarray]): Dictionary of embeddings.
+            embedder: The embedder used for reducing embeddings.
+
+        Returns:
+            out (Dict[str, ndarray]): Dictionary of reduced embeddings.
+        """
         return {seq_id: embedder.reduce_per_protein(embedding) for seq_id, embedding in embeddings.items()}
 
     @staticmethod
     def _save_embeddings(save_id: int, embeddings: Dict[str, ndarray], embeddings_file_path: Path) -> int:
+        """
+        Saves the embeddings to a file.
+
+        Parameters:
+            save_id (int): The save ID used for tracking saved embeddings.
+            embeddings (Dict[str, ndarray]): Dictionary of embeddings to be saved.
+            embeddings_file_path (Path): The path where embeddings are saved.
+
+        Returns:
+            out (int): The updated save ID.
+        """
         with h5py.File(embeddings_file_path, "a") as embeddings_file:
             idx = save_id
             for seq_id, embedding in embeddings.items():
@@ -119,9 +246,12 @@ def compute_embeddings_from_list(self, protein_sequences: List[str], protocol: P
         """
         Compute embeddings with the provided embedder directly from a list of sequences.
 
-        :param protein_sequences: List of protein sequences as string
-        :param protocol: Protocol for the embeddings. Determines if the embeddings should be reduced to per-protein
-        :return: List of computed embeddings
+        Parameters:
+            protein_sequences (List[str]): List of protein sequences as strings.
+            protocol (Protocol): Protocol for the embeddings. Determines if the embeddings should be reduced to per-protein.
+
+        Returns:
+            out (List): List of computed embeddings.
         """
         use_reduced_embeddings = protocol in Protocol.using_per_sequence_embeddings()
 
@@ -134,6 +264,15 @@
 
     @staticmethod
     def load_embeddings(embeddings_file_path: str) -> Dict[str, Any]:
+        """
+        Loads precomputed embeddings from a file.
+
+        Parameters:
+            embeddings_file_path (str): Path to the embeddings file.
+
+        Returns:
+            out (Dict[str, Any]): Dictionary mapping sequence IDs to embeddings.
+        """
         # Load computed embeddings in .h5 file format
         logger.info(f"Loading embeddings from: {embeddings_file_path}")
         start = time.perf_counter()
diff --git a/biotrainer/utilities/__init__.py b/biotrainer/utilities/__init__.py
index 25eca24a..f4333829 100644
--- a/biotrainer/utilities/__init__.py
+++ b/biotrainer/utilities/__init__.py
@@ -3,7 +3,7 @@
 from .cuda_device import get_device, is_device_cpu
 from .data_classes import Split, SplitResult, DatasetSample
 from .constants import SEQUENCE_PAD_VALUE, MASK_AND_LABELS_PAD_VALUE, INTERACTION_INDICATOR, \
-    METRICS_WITHOUT_REVERSED_SORTING, SAVE_AFTER_N_EMBEDDINGS
+    METRICS_WITHOUT_REVERSED_SORTING
 from .fasta import read_FASTA, get_attributes_from_seqrecords, \
     get_attributes_from_seqrecords_for_protein_interactions, get_split_lists
 
@@ -19,7 +19,6 @@
     'MASK_AND_LABELS_PAD_VALUE',
     'INTERACTION_INDICATOR',
     'METRICS_WITHOUT_REVERSED_SORTING',
-    'SAVE_AFTER_N_EMBEDDINGS',
     'Split',
     'SplitResult',
     'DatasetSample',
diff --git a/biotrainer/utilities/constants.py b/biotrainer/utilities/constants.py
index 18093fd2..3a25942a 100644
--- a/biotrainer/utilities/constants.py
+++ b/biotrainer/utilities/constants.py
@@ -13,6 +13,3 @@
 # Usually, a higher metric means a better model (accuracy: 0.9 > 0.8, precision: 0.5 > 0.1 ..)
 # For some metrics, however, the opposite is true (loss: 0.1 > 0.2, rmse: 20.05 > 40.05)
 METRICS_WITHOUT_REVERSED_SORTING: Final[List[str]] = ["loss", "mse", "rmse"]
-
-# Embeddings
-SAVE_AFTER_N_EMBEDDINGS: Final[int] = 100
diff --git a/tests/test_embedding_service.py b/tests/test_embedding_service.py
new file mode 100644
index 00000000..2bcb4f79
--- /dev/null
+++ b/tests/test_embedding_service.py
@@ -0,0 +1,121 @@
+import unittest
+import logging
+import tempfile
+import os
+import random
+import psutil
+from pathlib import Path
+
+from biotrainer.embedders import OneHotEncodingEmbedder, EmbeddingService
+from biotrainer.protocols import Protocol
+
+logger = logging.getLogger(__name__)
+
+class TestEmbeddingService(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls._configure_logging()
+
+    @classmethod
+    def _configure_logging(cls):
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+            handlers=[logging.StreamHandler()]
+        )
+        logger.info('Setting up the test case class')
+
+    def setUp(self):
+        self._setup_test_environment()
+        self._setup_fasta_parameters()
+
+    def _setup_test_environment(self):
+        logger.info('Setting up test environment')
+        self.embedder = OneHotEncodingEmbedder()
+        self.embedding_service = EmbeddingService(embedder=self.embedder)
+
+    def _setup_fasta_parameters(self):
+        self.num_reads = 100
+        if os.environ.get('CI') == 'true':
+            logger.debug("CI environment detected. Generating ultra-long sequences to test memory limits and extreme edge cases.")
+            logger.debug(
+                "Sequence generation may take considerable time due to large memory allocation: "
+                f"{psutil.virtual_memory().available / (1024 ** 3):.2f} GB available."
+            )
+            logger.debug(
+                "The calculations for embeddings of these ultra-long sequences are also computationally intensive, "
+                "which contributes to the overall long duration of the test."
+            )
+            # Use the original calculation in CI environment
+            # (0.75 and 18 mirror the factors in EmbeddingService._max_embedding_fit; 21 is
+            # presumably the one-hot embedding dimension, so a single embedding of this
+            # length roughly fills the estimated memory budget.)
+            self.long_length = int((0.75 * psutil.virtual_memory().available) / (18 * 21))
+        else:
+            logger.debug("Local environment detected. Using shorter sequence length for faster testing.")
+            # Use a fixed value for local development
+            self.long_length = 50000
+        self.other_length = 250
+
+    @staticmethod
+    def _generate_sequence(length):
+        """Generate a random sequence of a given length."""
+        return ''.join(random.choice("ACDE") for _ in range(length))
+
+    def _generate_fasta(self, num_reads, long_length, other_length, filename, include_long=True, include_short=True):
+        """Generate a FASTA file with a specified number of reads."""
+        logger.info(f"Generating FASTA file: {filename}")
+        with open(filename, 'w') as file:
+            if include_long:
+                logger.info("Generating long sequences; this may take a bit of time.")
+                for i in range(1, 3):
+                    file.write(f">read_{i}\n{self._generate_sequence(long_length)}\n")
+
+            if include_short:
+                for i in range(3, num_reads + 1):
+                    file.write(f">read_{i}\n{self._generate_sequence(other_length)}\n")
+        logger.info(f"FASTA file generated: {filename}")
+
+    def _run_embedding_test(self, test_name, num_reads, protocol, include_long=True, include_short=True):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            logger.info(f'Starting test: compute_embeddings_{test_name}')
+            sequence_path = os.path.join(tmp_dir, f"{test_name}.fasta")
+            self._generate_fasta(num_reads, self.long_length, self.other_length, sequence_path, include_long, include_short)
+            result = self._compute_embeddings(sequence_path, tmp_dir, protocol)
+            self._verify_result(protocol, result, tmp_dir)
+            logger.info(f'Test compute_embeddings_{test_name} completed successfully')
+
+    def _compute_embeddings(self, sequence_path, output_dir, protocol):
+        logger.info('Computing embeddings')
+        return self.embedding_service.compute_embeddings(
+            sequence_path,
+            Path(output_dir),
+            protocol
+        )
+
+    def _verify_result(self, protocol, result, tmp_dir):
+        self.assertTrue(os.path.exists(result), f"Result file does not exist: {result}")
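+        # Expected layout: <tmp_dir>/<protocol>/one_hot_encoding/[reduced_]embeddings_file_one_hot_encoding.h5
+        # (per-sequence protocols get the "reduced_" prefix).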
+        if protocol == Protocol.sequence_to_class:
+            expected_path = os.path.join(tmp_dir, "sequence_to_class", "one_hot_encoding", "reduced_embeddings_file_one_hot_encoding.h5")
+        elif protocol == Protocol.residue_to_class:
+            expected_path = os.path.join(tmp_dir, "residue_to_class", "one_hot_encoding", "embeddings_file_one_hot_encoding.h5")
+        self.assertEqual(result, expected_path, f"Unexpected result path. Expected {expected_path}, got {result}")
+
+    # Test methods
+    def test_long_sequence_to_class(self):
+        self._run_embedding_test("long_sequences", 2, Protocol.sequence_to_class, include_short=False)
+
+    def test_short_sequence_to_class(self):
+        self._run_embedding_test("short_sequences", self.num_reads, Protocol.sequence_to_class, include_long=False)
+
+    def test_mixed_sequence_to_class(self):
+        self._run_embedding_test("mixed_sequences", self.num_reads, Protocol.sequence_to_class)
+
+    def test_long_residue_to_class(self):
+        self._run_embedding_test("long_sequences", 2, Protocol.residue_to_class, include_short=False)
+
+    def test_short_residue_to_class(self):
+        self._run_embedding_test("short_sequences", self.num_reads, Protocol.residue_to_class, include_long=False)
+
+    def test_mixed_residue_to_class(self):
+        self._run_embedding_test("mixed_sequences", self.num_reads, Protocol.residue_to_class)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file