diff --git a/biotrainer/config/configurator.py b/biotrainer/config/configurator.py
index aa6b1a97..d887caa2 100644
--- a/biotrainer/config/configurator.py
+++ b/biotrainer/config/configurator.py
@@ -1,12 +1,13 @@
 import os
 from pathlib import Path
 from typing import Union, List, Dict, Any, Tuple

 from ruamel import yaml
 from ruamel.yaml import YAMLError

 from .config_option import ConfigurationException, ConfigOption, FileOption
 from .config_rules import (
+    ConfigRule,
     MutualExclusive,
     MutualExclusiveValues,
     ProtocolRequires,
@@ -27,17 +28,44 @@
 )
 from .embedding_options import EmbedderName, EmbeddingsFile, embedding_options, UseHalfPrecision
 from .general_options import general_options, Device
-from .input_options import SequenceFile, LabelsFile, input_options
+from .input_options import SequenceFile, LabelsFile, MaskFile, input_options
 from .model_options import model_options
 from .training_options import AutoResume, PretrainedModel, training_options
+from .hf_dataset_options import (
+    hf_dataset_options,
+    HF_DATASET_CONFIG_KEY,
+    HFPath,
+    HFSequenceColumn,
+    HFTargetColumn,
+)
 from ..protocols import Protocol
+from ..utilities import process_hf_dataset_to_fasta

 # Define protocol-specific rules
-protocol_rules = [
+local_file_protocol_rules = [
     ProtocolRequires(protocol=Protocol.per_residue_protocols(), requires=[SequenceFile, LabelsFile]),
     ProtocolRequires(protocol=Protocol.per_sequence_protocols(), requires=[SequenceFile]),
 ]

+hf_dataset_rules = [
+    ProtocolRequires(protocol=Protocol.all(), requires=[HFPath, HFSequenceColumn, HFTargetColumn]),
+    MutualExclusive(
+        exclusive=[HFPath, SequenceFile],
+        error_message="If you want to download from HuggingFace, don't provide a sequence_file.\n"
+                      "Providing sequence_column is enough to download the dataset from HuggingFace."
+    ),
+    MutualExclusive(
+        exclusive=[HFPath, LabelsFile],
+        error_message="If you want to download from HuggingFace, don't provide a labels_file.\n"
+                      "Providing target_column is enough to download the dataset from HuggingFace."
+    ),
+    MutualExclusive(
+        exclusive=[HFPath, MaskFile],
+        error_message="If you want to download from HuggingFace, don't provide a mask_file.\n"
+                      "Providing mask_column is enough to download the dataset from HuggingFace."
+    )
+]
+
 # Define configuration option rules
 config_option_rules = [
     MutualExclusive(
@@ -80,13 +108,20 @@
 # Combine all configuration options into dictionaries for easy access
 all_options_dict: Dict[str, ConfigOption] = {
     option.name: option
-    for option in general_options + input_options + model_options + training_options + embedding_options
+    for option in (
+            general_options + input_options + hf_dataset_options
+            + model_options + training_options + embedding_options
+    )
 }

 cross_validation_dict: Dict[str, ConfigOption] = {
     option.name: option for option in cross_validation_options
 }

+hf_dataset_dict: Dict[str, ConfigOption] = {
+    option.name: option for option in hf_dataset_options
+}
+

 class Configurator:
     """
@@ -98,9 +133,9 @@ class Configurator:
     """

     def __init__(
-        self,
-        config_dict: Dict,
-        config_file_path: Path = None
+            self,
+            config_dict: Dict,
+            config_file_path: Path = None
     ):
         """
         Initialize a Configurator instance.
@@ -216,13 +255,17 @@ def _get_protocol_from_config_dict(config_dict: Dict[str, Any]): Raises: ConfigurationException: If the protocol is not specified or invalid. """ + protocol = config_dict.get("protocol") + if protocol is None: + raise ConfigurationException( + "No protocol specified in config file!" + ) try: - protocol = config_dict["protocol"] return Protocol[protocol] except KeyError: - raise ConfigurationException("No protocol specified in config file!") - except KeyError as e: - raise ConfigurationException(f"Invalid protocol specified: {config_dict.get('protocol')})") from e + raise ConfigurationException( + f"Invalid protocol specified: {protocol}" + ) @staticmethod def _get_cross_validation_map( @@ -253,10 +296,14 @@ def _get_cross_validation_map( cv_object.transform_value_if_necessary() cv_map[cv_name] = cv_object except KeyError: - raise ConfigurationException(f"Unknown cross-validation option: {cv_name}!") + raise ConfigurationException( + f"Unknown cross-validation option: {cv_name}!" + ) if method == "": - raise ConfigurationException("Required option method is missing from cross_validation_config!") + raise ConfigurationException( + "Required option method is missing from cross_validation_config!" + ) else: # Add default value for choose_by if not present if ChooseBy.name not in cv_dict.keys(): @@ -264,14 +311,87 @@ def _get_cross_validation_map( return cv_map + @staticmethod + def _get_hf_dataset_map( + protocol: Protocol, + hf_dict: Dict[str, Any] + ) -> Dict[str, ConfigOption]: + """ + Create a mapping of HuggingFace dataset options based on the provided configuration. + + Args: + protocol (Protocol): The selected protocol. + hf_dict (Dict[str, Any]): The hf_dataset configuration dictionary. + + Returns: + Dict[str, ConfigOption]: A dictionary mapping hf_dataset option names to their instances. + + Raises: + ConfigurationException: If an unknown hf_dataset option is encountered or required options are missing. + """ + hf_map = {} + + for hf_name, hf_value in hf_dict.items(): + try: + hf_option_class = hf_dataset_dict[hf_name] + hf_option: ConfigOption = hf_option_class(protocol=protocol, value=hf_value) + hf_map[hf_name] = hf_option + except KeyError: + raise ConfigurationException( + f"Unknown hf_dataset option: {hf_name}!" + ) + except ConfigurationException as e: + raise ConfigurationException( + f"Invalid value for hf_dataset option '{hf_name}': {e}" + ) + + return hf_map + + def _hf_config_updates(self, config_map: Dict[str, ConfigOption]) -> None: + """ + Apply updates to the configuration map based on the HuggingFace dataset configurations. + + This method ensures that the necessary files (e.g., sequence_file, labels_file, mask_file) are + created or updated based on the HuggingFace dataset configuration and protocol requirements. + + Args: + config_map (Dict[str, ConfigOption]): A dictionary mapping configuration option names + to their instances. + + Raises: + ConfigurationException: If there are issues with file creation or HuggingFace dataset processing. 
+ """ + # Ensure the 'hf_dataset' directory exists + hf_dataset_dir = self._config_file_path / "hf_db" + hf_dataset_dir.mkdir(exist_ok=True) + + if self._config_dict.get("hf_dataset", None): + if self.protocol in Protocol.per_sequence_protocols(): + # Update 'sequence_file' in config_map + sequence_file_path = str(hf_dataset_dir / "sequences.fasta") + self._update_config_map(config_map, "sequence_file", sequence_file_path) + + elif self.protocol in Protocol.per_residue_protocols(): + # Update 'sequence_file' and 'labels_file' in config_map + sequence_file_path = str(hf_dataset_dir / "sequences.fasta") + labels_file_path = str(hf_dataset_dir / "labels.fasta") + self._update_config_map(config_map, "sequence_file", sequence_file_path) + self._update_config_map(config_map, "labels_file", labels_file_path) + + # Update 'mask_file' if 'mask_column' is specified + if self._config_dict["hf_dataset"].get("mask_column", None): + mask_file_path = str(hf_dataset_dir / "mask.fasta") + self._update_config_map(config_map, "mask_file", mask_file_path) + + @staticmethod def _get_config_maps( protocol: Protocol, config_dict: Dict[str, Any], config_file_path: Path = None, - ) -> Tuple[Dict[str, ConfigOption], Dict[str, ConfigOption]]: + ) -> Tuple[Dict[str, ConfigOption], Dict[str, ConfigOption], Dict[str, ConfigOption]]: """ - Generate configuration and cross-validation maps based on the protocol and configuration dictionary. + Generate configuration, cross-validation, and hf_dataset maps based on the protocol and configuration dictionary. Args: protocol (Protocol): The selected protocol. @@ -279,19 +399,25 @@ def _get_config_maps( config_file_path (Path, optional): Path to the configuration file directory. Defaults to None. Returns: - Tuple[Dict[str, ConfigOption], Dict[str, ConfigOption]]: + Tuple[Dict[str, ConfigOption], Dict[str, ConfigOption], Dict[str, ConfigOption]]: - config_map: Mapping of configuration option names to their instances. - cv_map: Mapping of cross-validation option names to their instances. + - hf_map: Mapping of hf_dataset option names to their instances. """ config_map = {} cv_map = {} + hf_map = {} contains_cross_validation_config = False + for config_name in config_dict.keys(): try: if config_name == CROSS_VALIDATION_CONFIG_KEY: cv_map = Configurator._get_cross_validation_map(protocol=protocol, cv_dict=config_dict[config_name]) contains_cross_validation_config = True + elif config_name == HF_DATASET_CONFIG_KEY: + hf_map = Configurator._get_hf_dataset_map(protocol=protocol, + hf_dict=config_dict[config_name]) else: value = config_dict[config_name] if value == "": # Ignore empty values @@ -300,7 +426,10 @@ def _get_config_maps( config_object.transform_value_if_necessary(config_file_path) config_map[config_name] = config_object except KeyError: - raise ConfigurationException(f"Unknown configuration option: {config_name}!") + raise ConfigurationException( + f"Unknown configuration option: {config_name}!" 
+ ) + # Add default values for missing configuration options all_options_for_protocol: List[ConfigOption] = [ option for option in all_options_dict.values() if protocol in option.allowed_protocols @@ -317,30 +446,104 @@ def _get_config_maps( cv_map[Method.name] = Method(protocol=protocol) cv_map[ChooseBy.name] = ChooseBy(protocol=protocol) - return config_map, cv_map + return config_map, cv_map, hf_map + + def _update_config_map( + self, + config_map: Dict[str, ConfigOption], + option_name: str, + value: Any + ) -> None: + """ + Updates an existing ConfigOption in the configuration map or creates a new one if it does not exist. + + This method ensures that the specified configuration option is present in the `config_map`. + If the `option_name` already exists in `config_map`, its value is updated and any necessary + transformations are applied. If the `option_name` does not exist, the method attempts to + create a new ConfigOption instance using the `all_options_dict`. If the `option_name` is + unrecognized, a `ConfigurationException` is raised. + + Args: + config_map (Dict[str, ConfigOption]): + The configuration map to update, where keys are option names and values are ConfigOption instances. + option_name (str): + The name of the configuration option to update or add. + value (Any): + The new value to assign to the configuration option. + + Raises: + ConfigurationException: + If the `option_name` is not recognized or does not correspond to any known configuration option. + + """ + if option_name in config_map: + config_option = config_map[option_name] + config_option.value = value + config_option.transform_value_if_necessary(self._config_file_path) + else: + option_class = all_options_dict.get(option_name) + if option_class: + config_option = option_class(protocol=self.protocol, value=value) + config_option.transform_value_if_necessary(self._config_file_path) + config_map[option_name] = config_option + else: + raise ConfigurationException(f"Unknown configuration option: {option_name}") @staticmethod - def _verify_config( + def _create_hf_files( protocol: Protocol, config_map: Dict[str, ConfigOption], + hf_map: Dict[str, ConfigOption] + ) -> None: + """ + Creates sequences and, if needed, labels and masks FASTA files based on the HuggingFace + dataset configuration and protocol requirements. This method downloads and processes the + HuggingFace dataset according to the selected protocol. + + Args: + protocol (Protocol): The selected protocol determining how the dataset should be processed. + config_map (Dict[str, ConfigOption]): A mapping of configuration option names to their respective ConfigOption instances. + hf_map (Dict[str, ConfigOption]): A mapping of HuggingFace dataset option names to their respective ConfigOption instances. + + Raises: + ConfigurationException: If there is an issue during the creation of the required files or processing the dataset. + """ + try: + process_hf_dataset_to_fasta(protocol, config_map, hf_map) + except Exception as e: + raise ConfigurationException(f"Error in _create_hf_files: {e}") + + @staticmethod + def _check_rules( + protocol: Protocol, + config: Dict[str, ConfigOption], + rules: List[ConfigRule], ignore_file_checks: bool ): """ - Verify the configuration map against all defined rules. + Applies a set of validation rules to the provided configuration. + + This method iterates through each rule in the provided list of rules and applies it to the configuration. 
+ If any rule fails, a `ConfigurationException` is raised with the corresponding failure reason. Args: - protocol (Protocol): The selected protocol. - config_map (Dict[str, ConfigOption]): Mapping of configuration option names to their instances. - ignore_file_checks (bool): If True, file-related checks are ignored. + protocol (Protocol): + The selected protocol that dictates which rules are applicable to the configuration. + config (Dict[str, ConfigOption]): + A dictionary where keys are the names of configuration options, and values are their corresponding `ConfigOption` instances. + rules (List[ConfigRule]): + A list of validation rule objects that will be applied to the configuration. + ignore_file_checks (bool): + If set to `True`, file-related checks (such as existence or correctness of file paths) will be ignored. Raises: - ConfigurationException: If any validation rule is violated. + ConfigurationException: + If any validation rule fails, an exception is raised with the reason why the rule was not met. + """ - config_objects = list(config_map.values()) + config_objects = list(config.values()) - # Combine all applicable rules - all_rules = protocol_rules + config_option_rules - for rule in all_rules: + for rule in rules: success, reason = rule.apply( protocol=protocol, config=config_objects, @@ -349,59 +552,88 @@ def _verify_config( if not success: raise ConfigurationException(reason) + def _verify_config( + self, + protocol: Protocol, + config_map: Dict[str, ConfigOption], + ignore_file_checks: bool + ): + """ + Verify the provided configuration map against a set of validation rules. + + This method ensures that the configuration options are valid for the specified protocol, + adhere to the defined rules, and have correct values. It also skips file-related checks if + `ignore_file_checks` is set to `True`. + + Args: + protocol (Protocol): + The protocol to validate the configuration against. + config_map (Dict[str, ConfigOption]): + A mapping of configuration option names to their respective `ConfigOption` instances. + ignore_file_checks (bool): + If `True`, skips validation of file-related options. Defaults to `False`. + + Raises: + ConfigurationException: + If any configuration rule is violated or an option is invalid for the given protocol. + """ + + self._check_rules( + protocol=protocol, + config=config_map, + rules=config_option_rules, + ignore_file_checks=ignore_file_checks + ) + # Check protocol compatibility and value validity for each configuration option - for config_object in config_objects: + for config_object in list(config_map.values()): if ignore_file_checks and isinstance(config_object, FileOption): continue if protocol not in config_object.allowed_protocols: raise ConfigurationException(f"{config_object.name} not allowed for protocol {protocol}!") - if not config_object.check_value(): raise ConfigurationException(f"{config_object.value} not valid for option {config_object.name}!") - @staticmethod def _verify_cv_config( + self, protocol: Protocol, config_map: Dict[str, ConfigOption], cv_config: Dict[str, ConfigOption], ignore_file_checks: bool, - ): + ) -> None: """ - Verify the cross-validation configuration map against all defined rules. + Validates the cross-validation configuration against defined rules and ensures compatibility with the selected protocol. Args: - protocol (Protocol): The selected protocol. - config_map (Dict[str, ConfigOption]): Mapping of configuration option names to their instances. 
+ protocol (Protocol): The selected protocol that determines allowed configurations. + config_map (Dict[str, ConfigOption]): Mapping of general configuration option names to their instances. cv_config (Dict[str, ConfigOption]): Mapping of cross-validation option names to their instances. - ignore_file_checks (bool): If True, file-related checks are ignored. + ignore_file_checks (bool): If True, file-related checks are ignored during validation. Raises: - ConfigurationException: If any validation rule is violated. + ConfigurationException: If any validation rule is violated, required options are missing, + or incompatible options are used in the cross-validation configuration. """ - cv_objects = list(cv_config.values()) - config_objects = list(config_map.values()) if Method.name not in cv_config.keys(): raise ConfigurationException("Required option method is missing from cross_validation_config!") method = cv_config[Method.name] - # Apply cross-validation specific rules - for rule in cross_validation_rules: - success, reason = rule.apply( - protocol=protocol, - config=cv_objects, - ignore_file_checks=ignore_file_checks - ) - if not success: - raise ConfigurationException(reason) - for rule in optimization_rules: - success, reason = rule.apply( - protocol, - config=cv_objects + config_objects, - ignore_file_checks=ignore_file_checks - ) - if not success: - raise ConfigurationException(reason) + self._check_rules( + protocol=protocol, + config=cv_config, + rules=cross_validation_rules, + ignore_file_checks=ignore_file_checks + ) + + self._check_rules( + protocol=protocol, + config={**config_map, **cv_config}, + rules=optimization_rules, + ignore_file_checks=ignore_file_checks + ) + + cv_objects = list(cv_config.values()) # Ensure that the cross-validation method is compatible with other options if method == "hold_out" and len(cv_objects) > 1: @@ -432,27 +664,50 @@ def get_verified_config(self, ignore_file_checks: bool = False) -> Dict[str, Any Raises: ConfigurationException: If any validation rule is violated or if required options are missing. 
""" - config_map, cv_map = self._get_config_maps( + config_map, cv_map, hf_map = self._get_config_maps( protocol=self.protocol, config_dict=self._config_dict, config_file_path=self._config_file_path, ) + + if hf_map: + self._check_rules( + protocol=self.protocol, + config={**config_map, **hf_map}, + rules=hf_dataset_rules, + ignore_file_checks=ignore_file_checks + ) + self._hf_config_updates(config_map) + self._create_hf_files( + protocol=self.protocol, + config_map=config_map, + hf_map=hf_map, + ) + + else: + self._check_rules( + protocol=self.protocol, + config=config_map, + rules=local_file_protocol_rules, + ignore_file_checks=ignore_file_checks + ) + self._verify_config( protocol=self.protocol, config_map=config_map, ignore_file_checks=ignore_file_checks ) + self._verify_cv_config( protocol=self.protocol, config_map=config_map, cv_config=cv_map, ignore_file_checks=ignore_file_checks, ) - result = {} - for config_object in config_map.values(): - result[config_object.name] = config_object.value - result[CROSS_VALIDATION_CONFIG_KEY] = {} - for cv_object in cv_map.values(): - result[CROSS_VALIDATION_CONFIG_KEY][cv_object.name] = cv_object.value + + # Prepare the final result dictionary + result = {config_object.name: config_object.value for config_object in config_map.values()} + result[CROSS_VALIDATION_CONFIG_KEY] = {cv_object.name: cv_object.value for cv_object in cv_map.values()} + result[HF_DATASET_CONFIG_KEY] = {hf_object.name: hf_object.value for hf_object in hf_map.values()} return result diff --git a/biotrainer/config/hf_dataset_options.py b/biotrainer/config/hf_dataset_options.py new file mode 100644 index 00000000..24eb17b6 --- /dev/null +++ b/biotrainer/config/hf_dataset_options.py @@ -0,0 +1,142 @@ +from typing import List, Any, Union, Type +from abc import ABC + +from .config_option import ConfigOption, ConfigurationException, classproperty +from ..protocols import Protocol + + +class HFDatasetOption(ConfigOption, ABC): + """ + Abstract base class for HuggingFace dataset configuration options. + + Extends `ConfigOption` to provide a framework for defining + specific HuggingFace dataset-related options. + """ + + @classproperty + def category(self) -> str: + return "hf_dataset" + + +class HFPath(HFDatasetOption): + """ + Configuration option for specifying the HuggingFace dataset path. + """ + + @classproperty + def name(self) -> str: + return "path" + + @classproperty + def allow_multiple_values(self) -> bool: + return False + + @classproperty + def required(self) -> bool: + return True + + def check_value(self) -> bool: + if not isinstance(self.value, str) or "/" not in self.value: + raise ConfigurationException( + f"Invalid HuggingFace dataset path: {self.value}. It should be in the format 'username/dataset_name'." + ) + return True + + +class HFSubsetName(HFDatasetOption): + """ + Configuration option for specifying the dataset subset name. + """ + + @classproperty + def name(self) -> str: + return "subset" + + @classproperty + def allow_multiple_values(self) -> bool: + return False + + @classproperty + def required(self) -> bool: + return False + + +class HFSequenceColumn(HFDatasetOption): + """ + Configuration option for specifying the sequence column in the dataset. 
+ """ + + @classproperty + def name(self) -> str: + return "sequence_column" + + @classproperty + def allow_multiple_values(self) -> bool: + return False + + @classproperty + def required(self) -> bool: + return True + + def check_value(self) -> bool: + if not isinstance(self.value, str) or not self.value.strip(): + raise ConfigurationException("sequence_column must be a non-empty string.") + return True + + +class HFTargetColumn(HFDatasetOption): + """ + Configuration option for specifying the target column in the dataset. + """ + + @classproperty + def name(self) -> str: + return "target_column" + + @classproperty + def allow_multiple_values(self) -> bool: + return False + + @classproperty + def required(self) -> bool: + return True + + def check_value(self) -> bool: + if not isinstance(self.value, str) or not self.value.strip(): + raise ConfigurationException("target_column must be a non-empty string.") + return True + +class HFMaskColumn(HFDatasetOption): + """ + Configuration option for specifying the mask column in the dataset. + """ + + @classproperty + def name(self) -> str: + return "mask_column" + + @classproperty + def allow_multiple_values(self) -> bool: + return False + + @classproperty + def required(self) -> bool: + return False + + def check_value(self) -> bool: + if not isinstance(self.value, str) or not self.value.strip(): + raise ConfigurationException("mask_column must be a non-empty string.") + return True + +# Constant key for hf_dataset configuration +HF_DATASET_CONFIG_KEY: str = "hf_dataset" + +# Add hf_dataset options to a separate dictionary +hf_dataset_options: List[Type[HFDatasetOption]] = [ + HFDatasetOption, + HFPath, + HFSubsetName, + HFSequenceColumn, + HFTargetColumn, + HFMaskColumn +] diff --git a/biotrainer/utilities/__init__.py b/biotrainer/utilities/__init__.py index f4333829..5658b202 100644 --- a/biotrainer/utilities/__init__.py +++ b/biotrainer/utilities/__init__.py @@ -2,16 +2,27 @@ from .version import __version__ from .cuda_device import get_device, is_device_cpu from .data_classes import Split, SplitResult, DatasetSample -from .constants import SEQUENCE_PAD_VALUE, MASK_AND_LABELS_PAD_VALUE, INTERACTION_INDICATOR, \ +from .constants import ( + SEQUENCE_PAD_VALUE, + MASK_AND_LABELS_PAD_VALUE, + INTERACTION_INDICATOR, METRICS_WITHOUT_REVERSED_SORTING -from .fasta import read_FASTA, get_attributes_from_seqrecords, \ - get_attributes_from_seqrecords_for_protein_interactions, get_split_lists +) +from .fasta import ( + read_FASTA, + get_attributes_from_seqrecords, + get_attributes_from_seqrecords_for_protein_interactions, + get_split_lists +) + +from .hf_dataset_to_fasta import process_hf_dataset_to_fasta __all__ = [ 'seed_all', 'get_device', 'is_device_cpu', 'read_FASTA', + 'process_hf_dataset_to_fasta', 'get_attributes_from_seqrecords', 'get_attributes_from_seqrecords_for_protein_interactions', 'get_split_lists', diff --git a/biotrainer/utilities/fasta.py b/biotrainer/utilities/fasta.py index 54119035..2e1d1ded 100644 --- a/biotrainer/utilities/fasta.py +++ b/biotrainer/utilities/fasta.py @@ -2,7 +2,7 @@ import logging from Bio import SeqIO -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union, Any, Optional from Bio.SeqRecord import SeqRecord from ..utilities import INTERACTION_INDICATOR diff --git a/biotrainer/utilities/hf_dataset_to_fasta.py b/biotrainer/utilities/hf_dataset_to_fasta.py new file mode 100644 index 00000000..04cc4736 --- /dev/null +++ b/biotrainer/utilities/hf_dataset_to_fasta.py @@ -0,0 +1,311 @@ +import 
logging + +from typing import Any, Dict, List, Optional, Tuple +from datasets import load_dataset + +from ..protocols import Protocol + +logger = logging.getLogger(__name__) + + +def hf_to_fasta( + sequences: List[str], + targets: List[Any], + set_values: List[str], + sequences_file_name: str, + labels_file_name: Optional[str] = None, + masks: Optional[List[Any]] = None, + masks_file_name: Optional[str] = None +) -> None: + """ + Converts sequences, targets, and optional masks from a HuggingFace dataset into FASTA file(s). + + Args: + sequences (List[str]): A list of protein sequences. + targets (List[Any]): A list of target values. + set_values (List[str]): A list of SET values corresponding to each sequence. + sequences_file_name (str): Path and filename for the output sequence FASTA file. + labels_file_name (Optional[str], optional): Path and filename for the output target FASTA file. + Defaults to None. + masks (Optional[List[Any]]): A list of mask values, or None if masks are not provided. + masks_file_name (Optional[str], optional): Path and filename for the output mask FASTA file. + Defaults to None. + + Raises: + ValueError: If the lengths of sequences, targets, masks (if masks are provided), and set_values do not match. + IOError: If there is an issue writing to the output files. + """ + if not masks_file_name: + if not (len(sequences) == len(targets) == len(set_values)): + raise ValueError("The number of sequences, targets, and set_values must be the same.") + else: + if not (len(sequences) == len(targets) == len(set_values) == len(masks)): + raise ValueError("The number of sequences, targets, set_values, and masks must be the same.") + + files = {} + try: + # Open the necessary files + seq_file = open(sequences_file_name, 'w') + files['seq_file'] = seq_file + + if labels_file_name: + tgt_file = open(labels_file_name, 'w') + files['tgt_file'] = tgt_file + else: + tgt_file = None + + if masks_file_name: + mask_file = open(masks_file_name, 'w') + files['mask_file'] = mask_file + else: + mask_file = None + + # Write the data + for idx in range(len(sequences)): + seq_id = f"Seq{idx+1}" + seq = sequences[idx] + target = targets[idx] + set_val = set_values[idx] + mask = masks[idx] if masks is not None else None + + # Write to sequence file + if tgt_file is None: + # Include target in the header + seq_header = f">{seq_id} SET={set_val} TARGET={target}" + else: + seq_header = f">{seq_id}" + seq_file.write(f"{seq_header}\n{seq}\n") + + if tgt_file is not None: + # Write to target file + tgt_header = f">{seq_id} SET={set_val}" + tgt_file.write(f"{tgt_header}\n{target}\n") + + if mask_file is not None: + # Write to mask file + mask_header = f">{seq_id}" + mask_file.write(f"{mask_header}\n{mask}\n") + + # Close all files + for f in files.values(): + f.close() + + except IOError as e: + # Ensure all files are closed in case of exception + for f in files.values(): + f.close() + raise IOError(f"Error writing to FASTA file(s): {e}") + +def process_subset( + current_subset: Any, + sequence_column: str, + target_column: str, + mask_column: Optional[str] = None +) -> Tuple[List[str], List[Any], Optional[List[Any]]]: + """ + Processes a single subset, verifying the presence of required columns + and extracting sequences, targets, and optionally masks. + + Args: + current_subset (Any): The subset to process. This is typically a HuggingFace dataset or similar structure. + sequence_column (str): The name of the column containing the sequences. 
+ target_column (str): The name of the column containing the target values. + mask_column (Optional[str]): The name of the column containing mask values, if applicable. Defaults to None. + + Returns: + Tuple[List[str], List[Any], Optional[List[Any]]]: A tuple containing: + - List[str]: A list of sequences. + - List[Any]: A list of target values. + - Optional[List[Any]]: A list of mask values, or None if no mask column is provided. + + Raises: + Exception: If any of the specified columns are missing from the dataset. + """ + # Verify columns + verify_column(current_subset, sequence_column) + verify_column(current_subset, target_column) + if mask_column: + verify_column(current_subset, mask_column) + + # Extract data + sequences = current_subset[sequence_column] + targets = current_subset[target_column] + masks = current_subset[mask_column] if mask_column else None + + return sequences, targets, masks + +def determine_set_name(subset_name: str) -> str: + """ + Determines the corresponding set name ("TRAIN", "VAL", "TEST") based on the provided subset name. + + This function normalizes and categorizes the input subset name into one of the standard set names. If the input does not + match any of the expected patterns ("train", "val", "test"), it logs a warning and assigns the subset to "TEST" by default. + + Args: + subset_name (str): The name of the subset (e.g., "train1", "validation", "testing"). + + Returns: + str: The normalized set name ("TRAIN", "VAL", "TEST"). + + Logs: + Warning: Logs a warning if the subset name is unrecognized and defaults to "TEST". + + """ + lower_subset_name = subset_name.lower() + if lower_subset_name.startswith("train"): + return "TRAIN" + elif lower_subset_name.startswith("val"): + return "VAL" + elif lower_subset_name.startswith("test"): + return "TEST" + else: + logger.warning(f"Unrecognized subset name '{subset_name}'. Assigning to 'TEST'.") + return "TEST" + +def verify_column(dataset: Any, column: str) -> None: + """ + Verifies that the specified column exists in the given dataset. + + Args: + dataset (Any): The dataset to verify. This is expected to have a `column_names` attribute, + such as a HuggingFace dataset or a similar structure. + column (str): The name of the column to check for existence. + + Raises: + Exception: If the specified column is not found in the dataset. + """ + if column not in dataset.column_names: + raise Exception( + f"Column '{column}' not found in the dataset." + ) + +def load_and_split_hf_dataset(hf_map: Dict) -> Tuple[List[str], List[Any], List[Any], List[str]]: + """ + Loads a HuggingFace dataset and splits it into sequences, targets, masks (if available), and set values. + + Args: + hf_map (Dict[str, Any]): A mapping of configuration options for loading the HuggingFace dataset. + Expected keys include: + - "path" (str): The dataset path. + - "sequence_column" (str): Name of the sequence column. + - "target_column" (str): Name of the target column. + - "mask_column" (Optional[str]): Name of the mask column (if applicable). + - "subset" (Optional[str]): Subset name, if required. + + Returns: + Tuple[List[str], List[Any], Optional[List[Any]], List[str]]: + - List[str]: Sequences extracted from the dataset. + - List[Any]: targets corresponding to the sequences. + - Optional[List[Any]]: Masks, if available; otherwise, None. + - List[str]: Set values (e.g., "TRAIN", "VAL", "TEST") for each sequence. + + Raises: + ValueError: If loading the dataset fails due to missing or invalid configuration. 
+ Exception: If the required splits or columns are missing, or if other errors occur during processing. + + """ + # Extract parameters from hf_map + path = hf_map["path"].value + sequence_column = hf_map["sequence_column"].value + target_column = hf_map["target_column"].value + mask_column = hf_map["mask_column"].value if "mask_column" in hf_map else None + subset_name = hf_map["subset"].value if "subset" in hf_map else None + + logger.info(f"Loading HuggingFace dataset from path: {path}") + + # Load dataset + try: + dataset = load_dataset(path, subset_name if subset_name is not None else 'default') + except ValueError as e: + error_msg = ("Loading the dataset from Hugging Face failed. " + "If the dataset requires a 'subset', you can specify it using the 'subset' option in the config file.\n" + f"Error: {e}") + raise ValueError(error_msg) + except Exception as e: + raise Exception(f"Loading the dataset from Hugging Face failed. Error: {e}") + + # Collect all available subsets + subset_names = list(dataset.keys()) + if not subset_names: + raise Exception(f"No subsets found in the dataset at path '{path}'.") + + sequences = [] + targets = [] + masks = [] + set_values = [] + + if len(subset_names) >= 3: + logger.info(f"{len(subset_names)} subsets found: {subset_names}. Processing each separately.") + for subset_name in subset_names: + current_subset = dataset[subset_name] + if current_subset is None: + logger.warning(f"Subset '{subset_name}' is None. Skipping.") + continue + + # Process subset + current_sequences, current_targets, current_masks = process_subset(current_subset, sequence_column, + target_column, mask_column) + + # Determine SET value based on split name + current_set_value = determine_set_name(subset_name) + + sequences.extend(current_sequences) + targets.extend(current_targets) + set_values.extend([current_set_value] * len(current_sequences)) + if current_masks is not None: + masks.extend(current_masks) + else: + masks.extend([None] * len(current_sequences)) + + else: + raise Exception( + f"Expected 3 subsets (TRAIN, VAL, TEST) in the dataset at path '{path}'. Found: {subset_names}." + ) + + return sequences, targets, masks, set_values + +def process_hf_dataset_to_fasta( + protocol: Protocol, + config_map: Dict, + hf_map: Dict +) -> None: + """ + Loads a HuggingFace dataset, splits it according to the protocol, and writes the data to FASTA files. + + Args: + protocol (Protocol): The selected protocol determining how the dataset should be processed. + config_map (Dict): A mapping of configuration option names to their respective ConfigOption instances. + hf_map (Dict): A mapping of HuggingFace dataset option names to their respective ConfigOption instances. + + Raises: + Exception: If there is an issue during the creation of the required files or processing the dataset. 
+ """ + try: + sequences, targets, masks, set_values = load_and_split_hf_dataset(hf_map) + except Exception as e: + raise Exception(f"Failed to load and split HuggingFace dataset: {e}") + + try: + if protocol in Protocol.per_sequence_protocols(): + hf_to_fasta( + sequences=sequences, + targets=targets, + set_values=set_values, + sequences_file_name=config_map["sequence_file"].value + ) + elif protocol in Protocol.per_residue_protocols(): + hf_to_fasta( + sequences=sequences, + targets=targets, + masks=masks, + set_values=set_values, + sequences_file_name=config_map["sequence_file"].value, + labels_file_name=config_map["labels_file"].value, + masks_file_name=config_map["mask_file"].value if "mask_file" in config_map else None + ) + else: + raise Exception(f"Unsupported protocol: {protocol}") + except Exception as e: + raise Exception(f"Failed to write FASTA files: {e}") + + logger.info("HuggingFace dataset downloaded and processed successfully.") diff --git a/docs/config_file_options.md b/docs/config_file_options.md index ccc49716..3d92c350 100644 --- a/docs/config_file_options.md +++ b/docs/config_file_options.md @@ -381,4 +381,46 @@ total dataset. This can also be automatically done for you in *biotrainer*: limited_sample_size: 100 # Default: -1, must be > 0 to be applied ``` Note that this value is applied only to the train dataset and embedding calculation is currently -done for all sequences! \ No newline at end of file +done for all sequences! + +## HF Dataset Integration + +This configuration enables the use of datasets hosted on the HuggingFace repository. By specifying the `hf_dataset` option, **there is no need to have `sequence_file`, `labels_file`, and `mask_file` on your local machine**. Instead: + +- A new folder will be created as `hf_db` where your config file exists, and new `sequence_file`, `labels_file`, and `mask_file` will be created based on your config needs. + +### General options + +For HuggingFace integration, the `hf_dataset` option is used: +```yaml +hf_dataset: + path: huggingface_user_name/repository_name # Required + subset: subset_name # Optional + sequence_column: sequences_column_name # Required + target_column: targets_column_name # Required + mask_column: mask_column_name # Optional +``` + +### HF Dataset Configuration Options + +The `hf_dataset` section of the configuration includes the following options: + +- **`path` (required):** The repository path to the desired dataset in the HuggingFace hub (e.g., `huggingface_user_name/repository_name`). +- **`subset` (optional):** Specifies the subset of the dataset to download. + - If no subsets exist, you should remove this option or set it to `default`. + - If the subset name is incorrect, an error will display the available subsets. +- **`sequence_column` (required):** The column in the dataset that contains the sequences. +- **`target_column` (required):** The column in the dataset that contains the targets. +- **`mask_column` (optional):** The column in the dataset that contains the masks. + +### Handling Dataset Splits + +Datasets in the HuggingFace repository may include predefined splits (e.g., `train`, `validation`, `test`). The tool handles splits as follows: + +1. **If three predefined splits exist** (e.g., `train`, `validation`, `test`): + - The splits are directly used as **TRAIN/VAL/TEST**. + - Note that their names should start with `train`, `val`, `test`. +2. **Otherwise**: + - A `ConfigurationException` will be raised. + +You can find an example [here](../examples/hf_dataset). 
diff --git a/examples/hf_dataset/.gitignore b/examples/hf_dataset/.gitignore
new file mode 100644
index 00000000..ee42f705
--- /dev/null
+++ b/examples/hf_dataset/.gitignore
@@ -0,0 +1,3 @@
+output/
+hf_db/
+*.fasta
\ No newline at end of file
diff --git a/examples/hf_dataset/README.md b/examples/hf_dataset/README.md
new file mode 100644
index 00000000..633e3ef4
--- /dev/null
+++ b/examples/hf_dataset/README.md
@@ -0,0 +1,60 @@
+# HF Dataset Integration
+
+This integration allows you to use datasets hosted on HuggingFace without needing a local `sequence_file`, `labels_file`, or `mask_file`. The tool creates these files automatically in the `hf_db` folder based on your config.
+
+### Configuration
+
+In your `config.yml`, set the following options under `hf_dataset`:
+
+```yaml
+hf_dataset:
+  path: "huggingface_user_name/repository_name"
+  subset: "subset_name_if_any"
+  sequence_column: "column_for_sequences"
+  target_column: "column_for_targets"
+  mask_column: "optional_column_for_masks"
+```
+
+### Dataset Splits
+
+If the dataset includes predefined splits named `train`, `validation`, and `test`, they are used directly. Otherwise, a `ConfigurationException` is raised.
+
+### Example
+
+Run the example with:
+
+```bash
+poetry run biotrainer examples/hf_dataset/config.yml
+```
+
+**Notes**
+- When using the `hf_dataset` option, remove the `sequence_file`, `labels_file`, and `mask_file` entries from the config.
+- Ensure that the `sequence_column`, `target_column`, and `mask_column` names match the structure of the dataset in the HuggingFace repository.
+
+With this configuration, you can integrate HuggingFace datasets without local sequence and label files, while the dataset's predefined splits are preserved for training, validation, and testing.
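+
+To verify such a configuration programmatically, you can mirror what the test suite does (a minimal sketch, adapted from `tests/test_configurations.py`):
+
+```python
+import tempfile
+from pathlib import Path
+
+from biotrainer.config import Configurator
+
+config = {
+    "protocol": "residue_to_class",
+    "hf_dataset": {
+        "path": "heispv/protein_data_test",
+        "subset": "split_3",
+        "sequence_column": "protein_sequence",
+        "target_column": "secondary_structure",
+    },
+}
+
+with tempfile.TemporaryDirectory() as tmpdir:
+    configurator = Configurator.from_config_dict(config)
+    configurator._config_file_path = Path(tmpdir)  # the hf_db folder is created here
+    # Raises ConfigurationException if any rule is violated.
+    verified = configurator.get_verified_config()
+    print(verified["hf_dataset"])
+```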
diff --git a/examples/hf_dataset/config.yml b/examples/hf_dataset/config.yml
new file mode 100644
index 00000000..2508de6d
--- /dev/null
+++ b/examples/hf_dataset/config.yml
@@ -0,0 +1,15 @@
+protocol: residue_to_class
+hf_dataset:
+  path: heispv/protein_data_test
+  subset: split_3
+  sequence_column: protein_sequence
+  target_column: secondary_structure
+model_choice: FNN
+optimizer_choice: adam
+loss_choice: cross_entropy_loss
+num_epochs: 200
+use_class_weights: False
+learning_rate: 1e-3
+batch_size: 128
+device: cpu
+embedder_name: one_hot_encoding
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 5268f4fe..09e65ae8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,6 +57,7 @@ onnxscript = "^0.1.0.dev20240806"
 onnxruntime = "^1.19.0"
 pandas = "^2.2.3"
 umap-learn = "^0.5.7"
+datasets = "^3.1.0"

 [tool.poetry.dev-dependencies]
 pytest = "^8.3.3"
diff --git a/tests/test_configurations.py b/tests/test_configurations.py
index e9143587..eb8bec89 100644
--- a/tests/test_configurations.py
+++ b/tests/test_configurations.py
@@ -1,7 +1,8 @@
 import unittest
+import tempfile

 from pathlib import Path
 from copy import deepcopy

 from biotrainer.config import Configurator, ConfigurationException

@@ -103,6 +104,43 @@
         "method": "leave_p_out",
         "p": 5,
         }
-    }
+    },
+    "hf_valid_for_sequences": {
+        "protocol": "sequence_to_class",
+        "hf_dataset": {
+            "path": "heispv/protein_data_test",
+            "subset": "split_1",
+            "sequence_column": "protein_sequence",
+            "target_column": "protein_class"
+        }
+    },
+    "hf_valid_for_residues": {
+        "protocol": "residue_to_class",
+        "hf_dataset": {
+            "path": "heispv/protein_data_test",
+            "subset": "split_3",
+            "sequence_column": "protein_sequence",
+            "target_column": "secondary_structure"
+        }
+    },
+    "hf_no_subset_required": {
+        "protocol": "residue_to_class",
+        "hf_dataset": {
+            "path": "heispv/protein_data_test_2",
+            "subset": "random_subset_name",
+            "sequence_column": "protein_sequence",
+            "target_column": "secondary_structure"
+        }
+    },
+    "hf_mutual_exclusive_sequence_file_name": {
+        "sequence_file": "sequence_file",
+        "protocol": "residue_to_class",
+        "hf_dataset": {
+            "path": "heispv/protein_data_test_2",
+            "subset": "random_subset_name",
+            "sequence_column": "protein_sequence",
+            "target_column": "secondary_structure"
+        }
+    }
 }

@@ -273,4 +311,218 @@ def test_multiple_values(self):
         config_dict = {**configurations["multiple_values"], **configurations["leave_p_out"]}
         with self.assertRaises(ConfigurationException,
                                msg="Config with multiple values for leave_p_out cv does not throw an exception!"):
-            configurator.from_config_dict(config_dict).get_verified_config()
\ No newline at end of file
+            configurator.from_config_dict(config_dict).get_verified_config()
+
+    def test_hf_valid_3_split(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_dict = deepcopy(configurations["hf_valid_for_sequences"])
+
+            configurator = Configurator.from_config_dict(config_dict)
+            configurator._config_file_path = Path(tmpdir)
+            self.assertTrue(
+                configurator.get_verified_config(),
+                "Valid hf_dataset configuration for sequences failed."
+ ) + + def test_hf_invalid_1_split(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_dict = deepcopy(configurations["hf_valid_for_sequences"]) + config_dict["hf_dataset"]["subset"] = "split_2" + + configurator = Configurator.from_config_dict(config_dict) + configurator._config_file_path = Path(tmpdir) + with self.assertRaises(ConfigurationException) as context: + configurator.get_verified_config() + + self.assertIn( + "Expected 3 subsets", + str(context.exception), + "Valid hf_dataset configuration for one split failed." + ) + + def test_hf_valid_residues_protocol(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_dict = deepcopy(configurations["hf_valid_for_residues"]) + + configurator = Configurator.from_config_dict(config_dict) + configurator._config_file_path = Path(tmpdir) + self.assertTrue( + configurator.get_verified_config(), + "Valid hf_dataset configuration for residues failed." + ) + + def test_hf_missing_sequence_column_name(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_dict = deepcopy(configurations["hf_valid_for_sequences"]) + del config_dict["hf_dataset"]["sequence_column"] + + configurator = Configurator.from_config_dict(config_dict) + configurator._config_file_path = Path(tmpdir) + with self.assertRaises(ConfigurationException) as context: + configurator.get_verified_config() + + self.assertIn( + "sequence_column", + str(context.exception), + "Exception does not mention the missing sequence_column." + ) + + def test_hf_missing_labels_column_name(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_dict = deepcopy(configurations["hf_valid_for_sequences"]) + del config_dict["hf_dataset"]["target_column"] + + configurator = Configurator.from_config_dict(config_dict) + configurator._config_file_path = Path(tmpdir) + with self.assertRaises(ConfigurationException) as context: + configurator.get_verified_config() + + self.assertIn( + "target_column", + str(context.exception), + "Exception does not mention the missing target_column." + ) + + def test_hf_invalid_sequence_column_name(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_dict = deepcopy(configurations["hf_valid_for_sequences"]) + config_dict["hf_dataset"]["sequence_column"] = "random_invalid_name" + + configurator = Configurator.from_config_dict(config_dict) + configurator._config_file_path = Path(tmpdir) + with self.assertRaises(ConfigurationException) as context: + configurator.get_verified_config() + + self.assertIn( + "not found in the dataset", + str(context.exception), + "Exception does not mention the wrong sequence_column." + ) + + + def test_hf_invalid_target_column_name(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_dict = deepcopy(configurations["hf_valid_for_sequences"]) + config_dict["hf_dataset"]["target_column"] = "random_invalid_name" + + configurator = Configurator.from_config_dict(config_dict) + configurator._config_file_path = Path(tmpdir) + with self.assertRaises(ConfigurationException) as context: + configurator.get_verified_config() + + self.assertIn( + "not found in the dataset", + str(context.exception), + "Exception does not mention the wrong target_column." 
+            )
+
+    def test_hf_invalid_path(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_dict = deepcopy(configurations["hf_valid_for_sequences"])
+            config_dict["hf_dataset"]["path"] = "random_invalid_name"
+
+            configurator = Configurator.from_config_dict(config_dict)
+            configurator._config_file_path = Path(tmpdir)
+            with self.assertRaises(ConfigurationException) as context:
+                configurator.get_verified_config()
+
+            self.assertIn(
+                "doesn't exist on the Hub or cannot be accessed",
+                str(context.exception),
+                "Exception does not mention that the dataset path is invalid."
+            )
+
+    def test_hf_invalid_subset(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_dict = deepcopy(configurations["hf_valid_for_sequences"])
+            config_dict["hf_dataset"]["subset"] = "random_invalid_name"
+
+            configurator = Configurator.from_config_dict(config_dict)
+            configurator._config_file_path = Path(tmpdir)
+            with self.assertRaises(ConfigurationException) as context:
+                configurator.get_verified_config()
+
+            self.assertIn(
+                "not found",
+                str(context.exception),
+                "Exception does not mention that the subset was not found."
+            )
+
+    def test_hf_requires_subset(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_dict = deepcopy(configurations["hf_valid_for_sequences"])
+            del config_dict["hf_dataset"]["subset"]
+
+            configurator = Configurator.from_config_dict(config_dict)
+            configurator._config_file_path = Path(tmpdir)
+            with self.assertRaises(ConfigurationException) as context:
+                configurator.get_verified_config()
+
+            self.assertIn(
+                "dataset requires",
+                str(context.exception),
+                "Exception does not mention that the dataset requires a subset."
+            )
+
+    def test_hf_not_requires_subset(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_dict = deepcopy(configurations["hf_no_subset_required"])
+
+            configurator = Configurator.from_config_dict(config_dict)
+            configurator._config_file_path = Path(tmpdir)
+            with self.assertRaises(ConfigurationException) as context:
+                configurator.get_verified_config()
+
+            self.assertIn(
+                "Available: ['default']",
+                str(context.exception),
+                "Exception does not list the available subsets."
+            )
+
+    def test_hf_mutual_exclusive_sequence_file_name(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_dict = deepcopy(configurations["hf_valid_for_sequences"])
+            config_dict["sequence_file"] = "random_invalid_name"
+
+            configurator = Configurator.from_config_dict(config_dict)
+            configurator._config_file_path = Path(tmpdir)
+            with self.assertRaises(ConfigurationException) as context:
+                configurator.get_verified_config()
+
+            self.assertIn(
+                "mutual exclusive",
+                str(context.exception),
+                "Exception does not mention that sequence_file and hf_dataset are mutually exclusive."
+            )
+
+    def test_hf_mutual_exclusive_labels_file_name(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_dict = deepcopy(configurations["hf_valid_for_sequences"])
+            config_dict["labels_file"] = "random_invalid_name"
+
+            configurator = Configurator.from_config_dict(config_dict)
+            configurator._config_file_path = Path(tmpdir)
+            with self.assertRaises(ConfigurationException) as context:
+                configurator.get_verified_config()
+
+            self.assertIn(
+                "mutual exclusive",
+                str(context.exception),
+                "Exception does not mention that labels_file and hf_dataset are mutually exclusive."
+            )
+
+    def test_hf_mutual_exclusive_mask_file_name(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_dict = deepcopy(configurations["hf_valid_for_sequences"])
+            config_dict["mask_file"] = "random_invalid_name"
+
+            configurator = Configurator.from_config_dict(config_dict)
+            configurator._config_file_path = Path(tmpdir)
+            with self.assertRaises(ConfigurationException) as context:
+                configurator.get_verified_config()
+
+            self.assertIn(
+                "mutual exclusive",
+                str(context.exception),
+                "Exception does not mention that mask_file and hf_dataset are mutually exclusive."
+            )
\ No newline at end of file
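
For reference, a minimal sketch of the FASTA output produced by the new `hf_to_fasta` helper for per-sequence protocols (made-up sequences; per-residue protocols write the targets to a separate labels file instead):

```python
from biotrainer.utilities.hf_dataset_to_fasta import hf_to_fasta

# Made-up example data; SET values are derived from the dataset splits.
hf_to_fasta(
    sequences=["MKTAYIAK", "GAVLILLV"],
    targets=["ClassA", "ClassB"],
    set_values=["TRAIN", "TEST"],
    sequences_file_name="sequences.fasta",
)

# sequences.fasta afterwards contains:
# >Seq1 SET=TRAIN TARGET=ClassA
# MKTAYIAK
# >Seq2 SET=TEST TARGET=ClassB
# GAVLILLV
```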