Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added options for embeddings dimension reduction #123

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions biotrainer/config/configurator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import os
from pathlib import Path
from typing import Union, List, Dict, Any, Tuple
from datasets import load_dataset, concatenate_datasets
from sklearn.model_selection import train_test_split

from ruamel import yaml
from ruamel.yaml import YAMLError
from webencodings import labels

from . import config_rules
from .config_option import ConfigurationException, ConfigOption, FileOption, logger
from .config_option import ConfigurationException, ConfigOption, FileOption
from .config_rules import (
ConfigRule,
MutualExclusive,
Expand Down
59 changes: 57 additions & 2 deletions biotrainer/config/embedding_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .config_option import FileOption, classproperty, ConfigOption
from ..embedders import get_predefined_embedder_names
from ..protocols import Protocol


class EmbeddingOption(ConfigOption, ABC):
Expand Down Expand Up @@ -144,5 +145,59 @@ def allow_download(self) -> bool:
return True


# List of all embedding-related configuration options
embedding_options: List[Type[EmbeddingOption]] = [EmbedderName, UseHalfPrecision, EmbeddingsFile]
class DimensionReductionMethod(EmbeddingOption, ConfigOption):
    """Configuration option naming the dimensionality reduction algorithm.

    The option is optional, takes a single value and is restricted to the
    protocols returned by ``Protocol.using_per_sequence_embeddings()``.
    """

    @classproperty
    def name(self) -> str:
        """Key under which this option appears in the configuration."""
        return "dimension_reduction_method"

    @classproperty
    def required(self) -> bool:
        """The option may be omitted from the configuration."""
        return False

    @classproperty
    def allow_multiple_values(self) -> bool:
        """Only a single method can be configured."""
        return False

    @classproperty
    def allowed_protocols(self) -> List[Protocol]:
        """Protocols for which this option may be set."""
        return Protocol.using_per_sequence_embeddings()

    @property
    def default_value(self) -> Union[str, int, float, bool, Any]:
        """Empty string (presumably meaning "no reduction" - confirm against callers)."""
        return ""

    @property
    def possible_values(self) -> List[Any]:
        """Names of the supported reduction methods."""
        supported_methods = ["umap", "tsne"]
        return supported_methods


class NReducedComponents(EmbeddingOption, ConfigOption):
    """Configuration option holding the target number of embedding dimensions.

    The option is optional, takes a single value and is restricted to the
    protocols returned by ``Protocol.using_per_sequence_embeddings()``.
    """

    @classproperty
    def name(self) -> str:
        """Key under which this option appears in the configuration."""
        return "n_reduced_components"

    @classproperty
    def required(self) -> bool:
        """The option may be omitted from the configuration."""
        return False

    @classproperty
    def allow_multiple_values(self) -> bool:
        """Only a single component count can be configured."""
        return False

    @classproperty
    def allowed_protocols(self) -> List[Protocol]:
        """Protocols for which this option may be set."""
        return Protocol.using_per_sequence_embeddings()

    @property
    def default_value(self) -> Union[str, int, float, bool, Any]:
        """Empty string (presumably meaning "not set" - confirm against callers)."""
        return ""

    @staticmethod
    def _is_value_valid(config_option: ConfigOption, value) -> bool:
        """Accept only strictly positive values whose exact type is ``int``.

        The exact type comparison (rather than ``isinstance``) is deliberate:
        it rejects ``bool`` values, which are a subclass of ``int``.
        """
        if type(value) is not int:
            return False
        return value > 0


# List of all embedding-related configuration options
embedding_options: List[Type[EmbeddingOption]] = [EmbedderName, UseHalfPrecision, EmbeddingsFile,
DimensionReductionMethod, NReducedComponents]
2 changes: 1 addition & 1 deletion biotrainer/config/general_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def possible_values(self) -> List[Any]:

@classproperty
def allowed_protocols(self) -> List[Protocol]:
return [Protocol.sequence_to_class, Protocol.sequence_to_value]
return Protocol.using_per_sequence_embeddings()

@classproperty
def required(self) -> bool:
Expand Down
46 changes: 44 additions & 2 deletions biotrainer/embedders/embedding_service.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import os
import gc
import time
import psutil
import h5py
import torch
import psutil
import logging
import numpy as np

from umap import UMAP
from tqdm import tqdm
from pathlib import Path
from numpy import ndarray
from typing import Dict, Tuple, Any, Optional, List, Union
from sklearn.manifold import TSNE
from typing import Dict, Tuple, Any, List, Union, Optional

from .embedder_interfaces import EmbedderInterface

Expand Down Expand Up @@ -230,6 +232,46 @@ def _save_and_reset_embeddings(self, embeddings: Dict[str, ndarray], last_save_i
del embeddings
return last_save_id, {}

@staticmethod
def embeddings_dimensionality_reduction(
        embeddings: Dict[str, Any],
        dimension_reduction_method: str,
        n_reduced_components: int) -> Dict[str, Any]:
    """Reduces the dimension of per-protein embeddings using one of the
    dimensionality reduction methods

    Args:
        embeddings (Dict[str, Any]): Dictionary of embeddings. The values are
            stacked with torch.stack, so they must be tensors of equal shape.
        dimension_reduction_method (str): The method used to reduce
            the dimensionality of embeddings. Options are 'umap' or 'tsne'.
        n_reduced_components (int): The target number of dimensions for
            the reduced embeddings. It is capped by the original embedding
            dimension and by the number of samples (samples - 2 for umap,
            samples - 1 for tsne).

    Returns:
        Dict[str, Any]: Dictionary of embeddings with reduced dimensions,
            keyed by the same ids as the input.

    Raises:
        KeyError: If dimension_reduction_method is neither 'umap' nor 'tsne'.
    """
    # Sort the ids so that the row order of the stacked matrix is deterministic
    sorted_keys = sorted(embeddings)
    all_embeddings = torch.stack([embeddings[k] for k in sorted_keys], dim=0)
    n_samples, embedding_dim = all_embeddings.shape[0], all_embeddings.shape[1]

    max_components_by_method = {
        "umap": n_samples - 2,
        "tsne": n_samples - 1
    }
    n_reduced_components = min(
        n_reduced_components,
        max_components_by_method[dimension_reduction_method],
        embedding_dim)

    # Only build the reducer that is actually requested
    if dimension_reduction_method == "umap":
        reducer = UMAP(n_components=n_reduced_components)
    else:
        reducer = TSNE(
            n_components=n_reduced_components,
            # Perplexity relates to the number of neighbours, so it is bounded
            # by the sample count (it must be smaller than the number of samples)
            perplexity=min(30, n_samples - 1),
            # The default barnes_hut algorithm only supports fewer than 4 components
            method="barnes_hut" if n_reduced_components < 4 else "exact")

    logger.info(f"Starting embeddings dimensionality reduction via method {dimension_reduction_method}")
    embeddings_reduced_dimensions = reducer.fit_transform(all_embeddings)
    logger.info("Finished embeddings dimensionality reduction!")
    return {key: torch.tensor(reduced_embedding)
            for key, reduced_embedding in zip(sorted_keys, embeddings_reduced_dimensions)}

@staticmethod
def _reduce_embeddings(embeddings: Dict[str, ndarray], embedder) -> Dict[str, ndarray]:
"""
Expand Down
33 changes: 32 additions & 1 deletion biotrainer/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __init__(self,
cross_validation_config: Dict[str, Any] = None,
interaction: Optional[str] = None,
sanity_check: bool = True,
dimension_reduction_method: Optional[str] = None,
n_reduced_components: Optional[int] = None,
# Ignore rest
**kwargs
):
Expand All @@ -76,6 +78,8 @@ def __init__(self,
self._cross_validation_splitter = CrossValidationSplitter(self._protocol, self._cross_validation_config)
self._hp_manager = hp_manager
self._sanity_check = sanity_check
self._dimension_reduction_method = dimension_reduction_method
self._n_reduced_components = n_reduced_components

def training_and_evaluation_routine(self):
# SETUP
Expand Down Expand Up @@ -173,9 +177,36 @@ def _create_and_load_embeddings(self) -> Dict[str, Any]:

# Mapping from id to embeddings
id2emb = embedding_service.load_embeddings(embeddings_file_path=embeddings_file)

if self._is_dimension_reduction_possible(id2emb):
id2emb = embedding_service.embeddings_dimensionality_reduction(
embeddings=id2emb,
dimension_reduction_method=self._dimension_reduction_method,
n_reduced_components=self._n_reduced_components)
return id2emb

def _is_dimension_reduction_possible(self, embeddings: Dict[str, Any]) -> bool:
if (self._protocol.using_per_sequence_embeddings() and
self._dimension_reduction_method and
self._n_reduced_components and
len(embeddings)>=3 and
list(embeddings.values())[0].shape[0]>=3):
return True
else:
if (self._dimension_reduction_method and
self._n_reduced_components):
if len(embeddings)<3:
raise Exception(f"Dimensionality reduction cannot be performed as \
the number of samples is less than 3")
if list(embeddings.values())[0].shape[0]<3:
raise Exception(f"Dimensionality reduction cannot be performed as \
the original embedding dimension is less than 3")
if not self._protocol.using_per_sequence_embeddings():
raise Exception(f"Dimensionality reduction cannot be performed as \
the embeddings are not per-protein embeddings")
return False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to raise an exception here, because if the user provides both the dimension_reduction_method and n_reduced_components, then he or she will probably expect this to work and might not realize, even after training, that the reduction was not applied. Admittedly, this is an edge case because if the number of samples is less than 3, then the train/validation/test splitting will also not work. But better safe than sorry :)




def _get_class_weights(self, target_manager: TargetManager) -> Union[None, torch.FloatTensor]:
# Get x_to_class specific logs and weights
class_weights = None
Expand Down
9 changes: 9 additions & 0 deletions docs/config_file_options.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ The file will be downloaded and stored in the path of your config file with pref
**Note that *embedder_name* and *embeddings_file* are mutually exclusive. In case you provide your own embeddings,
the experiment directory will be called *custom_embeddings*.**

To perform dimensionality reduction on the embeddings, specify the dimension reduction method to be used:
```yaml
dimension_reduction_method: umap | tsne # Default: None
```
and the number of dimensions to reduce the embeddings to (any positive integer):
```yaml
n_reduced_components: 5 # Default: None
```

## Model parameters

There are multiple options available to specify the model you want to train.
Expand Down
2 changes: 2 additions & 0 deletions examples/sequence_to_class_dim_reduction/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
output/
out.yml
13 changes: 13 additions & 0 deletions examples/sequence_to_class_dim_reduction/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# sequence_to_class_dim_reduction example

This example shows how to use the sequence_to_class protocol with embeddings dimensionality reduction. It predicts a class for every sequence in the
sequences.fasta file, using embeddings whose dimensionality has been reduced. Class labels and dataset annotations are also stored in the sequences.fasta file
for this protocol (see [data standardization](../../docs/data_standardization.md#sequence_to_class)).

Additionally, in this example the `use_class_weights: True` flag is set,
which can also be used to get a quick overview of the class distribution in your dataset from the console logs.

Execute the example (from the base directory):
```bash
poetry run python3 run-biotrainer.py examples/sequence_to_class_dim_reduction/config.yml
```
16 changes: 16 additions & 0 deletions examples/sequence_to_class_dim_reduction/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
sequence_file: sequences.fasta
protocol: sequence_to_class
model_choice: FNN
optimizer_choice: adam
loss_choice: cross_entropy_loss
num_epochs: 200
use_class_weights: True
learning_rate: 1e-3
batch_size: 128
save_split_ids: False
use_half_precision: True
device: cuda
disable_pytorch_compile: True
embedder_name: Rostlab/prot_t5_xl_uniref50
dimension_reduction_method: umap
n_reduced_components: 3
8 changes: 8 additions & 0 deletions examples/sequence_to_class_dim_reduction/sequences.fasta
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
>Seq1 TARGET=Glob SET=train
SEQWENCE
>Seq2 TARGET=GlobSP SET=val
PRTEIN
>Seq3 TARGET=TM SET=test
SEQVENCEPROTEI
>Seq4 TARGET=TMSP SET=test
PRTEINSEQWENCE
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ onnxscript = "^0.1.0.dev20240806"
onnxruntime = "^1.19.0"
pandas = "^2.2.3"
datasets = "^3.1.0"
umap-learn = "^0.5.7"

[tool.poetry.dev-dependencies]
pytest = "^8.3.3"
Expand Down
71 changes: 70 additions & 1 deletion tests/test_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,4 +527,73 @@ def test_hf_mutual_exclusive_mask_file_name(self):
"mutual exclusive",
str(context.exception),
"Exception does not raise an exception for mutual exclusive mask file name."
)
)

def test_dimension_reduction_methods(self):
    """A supported reduction method verifies; an unknown one is rejected."""
    with tempfile.TemporaryDirectory() as tmpdir:
        config_dict = deepcopy(configurations["minimal"])

        # Supported method: verification must succeed
        config_dict["dimension_reduction_method"] = "umap"
        valid_configurator = Configurator.from_config_dict(config_dict)
        valid_configurator._config_file_path = Path(tmpdir)
        verified_config = valid_configurator.get_verified_config()
        self.assertTrue(
            verified_config,
            "Valid dimension_reduction_method: umap failed!"
        )

        # Unsupported method: verification must raise
        config_dict["dimension_reduction_method"] = "nonexistingmethod"
        invalid_configurator = Configurator.from_config_dict(config_dict)
        invalid_configurator._config_file_path = Path(tmpdir)
        self.assertRaises(ConfigurationException,
                          invalid_configurator.get_verified_config)

def test_dimension_reduction_components(self):
    """n_reduced_components accepts positive integers and rejects everything else."""
    with tempfile.TemporaryDirectory() as tmpdir:
        config_dict = deepcopy(configurations["minimal"])
        # Positive integer works
        valid_value = 23
        config_dict["dimension_reduction_method"] = "umap"
        config_dict["n_reduced_components"] = valid_value

        configurator = Configurator.from_config_dict(config_dict)
        configurator._config_file_path = Path(tmpdir)

        self.assertTrue(
            configurator.get_verified_config(),
            f"Valid n_reduced_components: {valid_value} failed"
        )

        # Zero, negative integers and (negative) floating point values do not work
        for invalid_value in (0, -50, 5.5, -10.25):
            # subTest reports every failing value instead of stopping at the first
            with self.subTest(n_reduced_components=invalid_value):
                config_dict["n_reduced_components"] = invalid_value
                configurator = Configurator.from_config_dict(config_dict)
                configurator._config_file_path = Path(tmpdir)

                with self.assertRaises(ConfigurationException):
                    configurator.get_verified_config()
Loading