
Optimize Memory Handling in Embedding Computations and Refactor EmbeddingService #103

Merged
merged 8 commits on Aug 26, 2024
203 changes: 171 additions & 32 deletions biotrainer/embedders/embedding_service.py
@@ -1,6 +1,7 @@
import os
import gc
import time
import psutil
import h5py
import torch
import logging
@@ -9,17 +10,20 @@
from tqdm import tqdm
from pathlib import Path
from numpy import ndarray
from typing import Dict, Any, List, Optional
from typing import Dict, Tuple, Any, List, Optional

from .embedder_interfaces import EmbedderInterface

from ..protocols import Protocol
from ..utilities import read_FASTA, SAVE_AFTER_N_EMBEDDINGS
from ..utilities import read_FASTA

logger = logging.getLogger(__name__)


class EmbeddingService:
    """
    A service class for computing embeddings using a provided embedder.
    """

    def __init__(self, embedder: EmbedderInterface = None, use_half_precision: bool = False):
        self._embedder = embedder
@@ -29,15 +33,18 @@ def compute_embeddings(self, sequence_file: str, output_dir: Path, protocol: Pro
                           force_output_dir: Optional[bool] = False,
                           force_recomputing: Optional[bool] = False) -> str:
        """
        Compute embeddings with the provided embedder from file.
        Compute embeddings with the provided embedder from a sequence file.

        :param sequence_file: Path to the sequence file
        :param output_dir: Output directory to store the computed embeddings
        :param protocol: Protocol for the embeddings. Determines if the embeddings should be reduced to per-protein
        :param force_output_dir: If True, the given output directory is directly used to store the embeddings file,
            without any path enhancing
        :param force_recomputing: If True, the embedding file is re-computed, even if it already exists
        :return: Path to the generated output embeddings file
        Parameters:
            sequence_file (str): Path to the sequence file.
            output_dir (Path): Output directory to store the computed embeddings.
            protocol (Protocol): Protocol for the embeddings. Determines if the embeddings should be reduced to per-protein.
            force_output_dir (bool, optional): If True, the given output directory is directly used to store the embeddings file,
                without any path enhancement. Defaults to False.
            force_recomputing (bool, optional): If True, the embedding file is re-computed, even if it already exists. Defaults to False.

        Returns:
            str: Path to the generated output embeddings file.
        """
        use_reduced_embeddings = protocol in Protocol.using_per_sequence_embeddings()
        embedder_name = self._embedder.name.split("/")[-1]
@@ -68,45 +75,165 @@ def compute_embeddings(self, sequence_file: str, output_dir: Path, protocol: Pro
        protein_sequences = {seq.id: str(seq.seq) for seq in sorted(read_FASTA(sequence_file),
                                                                    key=lambda seq: len(seq.seq),
                                                                    reverse=True)}

        embeddings_file_path = self._do_embeddings_computation(protein_sequences, embeddings_file_path,
                                                               use_reduced_embeddings)

        return str(embeddings_file_path)
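
For orientation, a minimal usage sketch of this entry point (the embedder object, file names, and import paths here are assumptions inferred from this diff, not part of the PR):

from pathlib import Path

from biotrainer.protocols import Protocol
from biotrainer.embedders.embedding_service import EmbeddingService

# `my_embedder` stands in for any EmbedderInterface implementation.
service = EmbeddingService(embedder=my_embedder)
h5_path = service.compute_embeddings(
    sequence_file="sequences.fasta",
    output_dir=Path("output"),
    protocol=Protocol.sequence_to_class,  # per-sequence protocol -> reduced (per-protein) embeddings
)
print(h5_path)  # path of the generated .h5 embeddings file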

    def _do_embeddings_computation(self, protein_sequences: Dict[str, str], embeddings_file_path: Path,
                                   use_reduced_embeddings: bool) -> str:
        """
        Computes embeddings for the given protein sequences and saves them to the embeddings file.

        Parameters:
            protein_sequences (Dict[str, str]): A dictionary mapping sequence IDs to protein sequences.
            embeddings_file_path (Path): The path where embeddings will be saved.
            use_reduced_embeddings (bool): Indicates if reduced embeddings should be used.

        Returns:
            str: The path to the embeddings file.
        """
        sequence_ids = list(protein_sequences.keys())
        embeddings = {}
        idx: int = 0
        last_save_id: int = 0
        for embedding in tqdm(self._embedder.embed_many(protein_sequences.values()),
                              total=len(protein_sequences.values())):
            embeddings[sequence_ids[idx]] = embedding
            idx += 1
            if len(embeddings) > SAVE_AFTER_N_EMBEDDINGS:
                if use_reduced_embeddings:
                    embeddings = self._reduce_embeddings(embeddings, self._embedder)
                last_save_id = self._save_embeddings(save_id=last_save_id, embeddings=embeddings,
                                                     embeddings_file_path=embeddings_file_path)

                # Manually clear to ensure that embeddings are deleted from RAM
                del embeddings
                embeddings = {}

        start_time = time.time()

        embedding_iter = self._embedder.embed_many(protein_sequences.values())
        total_sequences = len(protein_sequences.values())

        logger.info("If your dataset contains long reads, it may take more time to process the first few sequences.")

        with tqdm(total=total_sequences, desc="Computing Embeddings") as pbar:

            # Load the first sequence and calculate the initial max_embedding_fit
            embeddings[sequence_ids[0]] = next(embedding_iter, None)
            pbar.update(1)

            if embeddings[sequence_ids[0]] is None:
                logger.info("No embeddings found.")
                return str(embeddings_file_path)

            max_embedding_fit = self._max_embedding_fit(embeddings[sequence_ids[0]])

            embedding_dimension = embeddings[sequence_ids[0]].shape[-1]

            # Load other sequences
            for idx in range(1, total_sequences):
                if max_embedding_fit <= 3 or len(embeddings) % max_embedding_fit == 0 or idx == total_sequences - 1:
                    last_save_id, embeddings = self._save_and_reset_embeddings(embeddings, last_save_id,
                                                                               embeddings_file_path, use_reduced_embeddings)
                    logger.debug(f"New {max_embedding_fit=}")

                    embeddings[sequence_ids[idx]] = next(embedding_iter, None)
                    pbar.update(1)

                    # Calculate the new max_embedding_fit for the next batch
                    if embeddings[sequence_ids[idx]] is not None:
                        max_embedding_fit = self._max_embedding_fit(embeddings[sequence_ids[idx]])

                else:
                    embeddings[sequence_ids[idx]] = next(embedding_iter, None)
                    pbar.update(1)

                if embeddings[sequence_ids[idx]] is None:
                    logger.debug("len(sequence_ids) > len(embedding_iter) or found a None value in the embedding_iter")
                    del embeddings[sequence_ids[idx]]
                    return str(embeddings_file_path)

logger.info(f"Embedding dimension: {embedding_dimension}")

# Save remaining embeddings
if len(embeddings) > 0:
if use_reduced_embeddings:
embeddings = self._reduce_embeddings(embeddings, self._embedder)
_ = self._save_embeddings(save_id=last_save_id, embeddings=embeddings,
embeddings_file_path=embeddings_file_path)
last_save_id, embeddings = self._save_and_reset_embeddings(embeddings, last_save_id,
embeddings_file_path, use_reduced_embeddings)

# Delete embeddings and embedding model from memory now, because they will no longer be needed
end_time = time.time()
logger.info(f"Time elapsed for saving embeddings: {end_time - start_time:.2f}[s]")

del embeddings
del self._embedder
gc.collect()

return str(embeddings_file_path)
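
Stepping back, the loop above implements a buffered, memory-aware checkpointing pattern: fill a dictionary until the estimated budget is reached, flush it to disk, and re-estimate the budget. A distilled, self-contained sketch of that pattern (all names are illustrative, not part of the PR):

from typing import Callable, Dict, Iterable, Iterator

import numpy as np


def checkpointed_consume(ids: Iterable[str],
                         arrays: Iterator[np.ndarray],
                         fit_fn: Callable[[np.ndarray], int],
                         save_fn: Callable[[Dict[str, np.ndarray]], None]) -> None:
    """Buffer arrays in RAM and flush to disk whenever the memory budget is reached."""
    buffer: Dict[str, np.ndarray] = {}
    max_fit = None
    for seq_id, array in zip(ids, arrays):
        buffer[seq_id] = array
        if max_fit is None:
            max_fit = fit_fn(array)  # budget estimated from the first item
        if len(buffer) >= max_fit:
            save_fn(buffer)  # flush a full batch to disk
            buffer = {}
            max_fit = fit_fn(array)  # re-estimate against currently free RAM
    if buffer:
        save_fn(buffer)  # flush the remainder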

    @staticmethod
    def _max_embedding_fit(embedding: ndarray) -> int:
        """
        Calculates the maximum number of embeddings that can fit in available memory.

        This function estimates the maximum number of embeddings that can be stored in
        the available system memory without exceeding it. The calculation includes a
        safety factor to prevent exhausting memory.

        Parameters:
            embedding (ndarray): An embedding array, representing the data structure
                whose memory footprint is being considered.

        Returns:
            int: The maximum number of embeddings that can fit in memory.

        Notes:
            - The number 18 was determined experimentally as a factor correlating the
              embedding size to the memory usage, indicating that each unit of
              embedding size corresponds to approximately 18 bytes of memory.
            - The multiplier 0.75 is a safety margin to ensure that the memory usage
              stays within 75% of the available system memory, reducing the risk of
              running out of RAM during operations.
        """
        max_embedding_fit = int(0.75 * (psutil.virtual_memory().available / (embedding.size * 18)))
        max_embedding_fit = 1 if max_embedding_fit == 0 else max_embedding_fit
        return max_embedding_fit
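
As a quick sanity check of the formula with illustrative numbers:

import numpy as np

# A per-residue embedding of shape (500, 1024) in float32 has
# embedding.size = 512,000 elements.
embedding = np.zeros((500, 1024), dtype=np.float32)

available = 8 * 1024 ** 3  # assume 8 GiB of available RAM (psutil reports the real value)
max_fit = int(0.75 * (available / (embedding.size * 18)))
print(max_fit)  # 699 -> roughly 700 such embeddings are buffered between saves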

    def _save_and_reset_embeddings(self, embeddings: Dict[str, ndarray], last_save_id: int,
                                   embeddings_file_path: Path, use_reduced_embeddings: bool) -> Tuple[int, Dict[str, ndarray]]:
        """
        Save the embeddings and reset the dictionary.

        Parameters:
            embeddings (Dict[str, ndarray]): Dictionary of embeddings to be saved.
            last_save_id (int): The last save ID used for tracking saved embeddings.
            embeddings_file_path (Path): The path where embeddings are saved.
            use_reduced_embeddings (bool): Flag to determine if embeddings should be reduced.

        Returns:
            Tuple[int, Dict[str, ndarray]]: Updated last_save_id and an empty embeddings dictionary.
        """
        if use_reduced_embeddings:
            embeddings = self._reduce_embeddings(embeddings, self._embedder)
        last_save_id = self._save_embeddings(save_id=last_save_id, embeddings=embeddings,
                                             embeddings_file_path=embeddings_file_path)
        del embeddings
        return last_save_id, {}

    @staticmethod
    def _reduce_embeddings(embeddings: Dict[str, ndarray], embedder) -> Dict[str, ndarray]:
        """
        Reduces the per-residue embeddings to per-protein embeddings.

        Parameters:
            embeddings (Dict[str, ndarray]): Dictionary of embeddings.
            embedder: The embedder used for reducing embeddings.

        Returns:
            Dict[str, ndarray]: Dictionary of reduced embeddings.
        """
        return {seq_id: embedder.reduce_per_protein(embedding) for seq_id, embedding in
                embeddings.items()}
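
For intuition, reduce_per_protein typically pools over the residue dimension; a minimal mean-pooling stand-in (the actual reduction depends on the embedder implementation):

import numpy as np


def reduce_per_protein_mean(per_residue: np.ndarray) -> np.ndarray:
    """Mean-pool a (sequence_length, embedding_dim) array down to (embedding_dim,)."""
    return per_residue.mean(axis=0)


per_residue = np.random.rand(120, 1024)  # 120 residues, 1024-dim embeddings
per_protein = reduce_per_protein_mean(per_residue)
print(per_protein.shape)  # (1024,)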

    @staticmethod
    def _save_embeddings(save_id: int, embeddings: Dict[str, ndarray], embeddings_file_path: Path) -> int:
        """
        Saves the embeddings to a file.

        Parameters:
            save_id (int): The save ID used for tracking saved embeddings.
            embeddings (Dict[str, ndarray]): Dictionary of embeddings to be saved.
            embeddings_file_path (Path): The path where embeddings are saved.

        Returns:
            int: The updated save ID.
        """
        with h5py.File(embeddings_file_path, "a") as embeddings_file:
            idx = save_id
            for seq_id, embedding in embeddings.items():
@@ -119,9 +246,12 @@ def compute_embeddings_from_list(self, protein_sequences: List[str], protocol: P
"""
Compute embeddings with the provided embedder directly from a list of sequences.

:param protein_sequences: List of protein sequences as string
:param protocol: Protocol for the embeddings. Determines if the embeddings should be reduced to per-protein
:return: List of computed embeddings
Parameters:
protein_sequences (List[str]): List of protein sequences as strings.
protocol (Protocol): Protocol for the embeddings. Determines if the embeddings should be reduced to per-protein.

Returns:
out (List): List of computed embeddings.
"""
use_reduced_embeddings = protocol in Protocol.using_per_sequence_embeddings()

@@ -134,6 +264,15 @@ def compute_embeddings_from_list(self, protein_sequences: List[str], protocol: P

    @staticmethod
    def load_embeddings(embeddings_file_path: str) -> Dict[str, Any]:
        """
        Loads precomputed embeddings from a file.

        Parameters:
            embeddings_file_path (str): Path to the embeddings file.

        Returns:
            Dict[str, Any]: Dictionary mapping sequence IDs to embeddings.
        """
        # Load computed embeddings in .h5 file format
        logger.info(f"Loading embeddings from: {embeddings_file_path}")
        start = time.perf_counter()
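A minimal read-back sketch using the static loader above (the file name is illustrative):

from biotrainer.embedders.embedding_service import EmbeddingService

# Returns a dictionary mapping sequence IDs to their embeddings.
embeddings = EmbeddingService.load_embeddings("output/embeddings_file.h5")
for seq_id, embedding in embeddings.items():
    print(seq_id, embedding.shape)
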
3 changes: 1 addition & 2 deletions biotrainer/utilities/__init__.py
@@ -3,7 +3,7 @@
from .cuda_device import get_device, is_device_cpu
from .data_classes import Split, SplitResult, DatasetSample
from .constants import SEQUENCE_PAD_VALUE, MASK_AND_LABELS_PAD_VALUE, INTERACTION_INDICATOR, \
    METRICS_WITHOUT_REVERSED_SORTING, SAVE_AFTER_N_EMBEDDINGS
    METRICS_WITHOUT_REVERSED_SORTING
from .fasta import read_FASTA, get_attributes_from_seqrecords, \
    get_attributes_from_seqrecords_for_protein_interactions, get_split_lists

@@ -19,7 +19,6 @@
    'MASK_AND_LABELS_PAD_VALUE',
    'INTERACTION_INDICATOR',
    'METRICS_WITHOUT_REVERSED_SORTING',
    'SAVE_AFTER_N_EMBEDDINGS',
    'Split',
    'SplitResult',
    'DatasetSample',
3 changes: 0 additions & 3 deletions biotrainer/utilities/constants.py
@@ -13,6 +13,3 @@
# Usually, a higher metric means a better model (accuracy: 0.9 > 0.8, precision: 0.5 > 0.1 ..)
# For some metrics, however, the opposite is true (loss: 0.1 > 0.2, rmse: 20.05 > 40.05)
METRICS_WITHOUT_REVERSED_SORTING: Final[List[str]] = ["loss", "mse", "rmse"]

# Embeddings
SAVE_AFTER_N_EMBEDDINGS: Final[int] = 100