From 27dc618cdc49bd1056b03b986cf284e81023c9b5 Mon Sep 17 00:00:00 2001 From: js2264 Date: Sat, 19 Oct 2024 03:16:45 +0200 Subject: [PATCH] feat: split generator in two: a streamer and a Dataset class --- docs/source/conf.py | 2 +- src/momics/{dataloader.py => dataset.py} | 106 +++++++++++++++- src/momics/multirangequery.py | 4 +- src/momics/streamer.py | 146 +++++++++++++++++++++++ tests/test_data.py | 48 ++++++++ tests/test_dataloader.py | 31 ----- 6 files changed, 302 insertions(+), 35 deletions(-) rename src/momics/{dataloader.py => dataset.py} (55%) create mode 100644 src/momics/streamer.py create mode 100644 tests/test_data.py delete mode 100644 tests/test_dataloader.py diff --git a/docs/source/conf.py b/docs/source/conf.py index bf4ab95..37a6f03 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -53,7 +53,7 @@ templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] -pygments_style = "one-dark" +# pygments_style = "one-dark" todo_include_todos = False master_doc = "index" diff --git a/src/momics/dataloader.py b/src/momics/dataset.py similarity index 55% rename from src/momics/dataloader.py rename to src/momics/dataset.py index 69c5ca9..4057fba 100644 --- a/src/momics/dataloader.py +++ b/src/momics/dataset.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Callable, Generator, Optional, Tuple import numpy as np import pyranges as pr @@ -6,15 +6,117 @@ import tensorflow as tf from .momics import Momics +from .streamer import MomicsStreamer from .multirangequery import MultiRangeQuery +class MomicsDataset(tf.data.Dataset): + """ + This class is implemented to train deep learning models, where the + input data (features) is a track or a sequence and the labeled data (target) + is another track. The data loader will iterate over the ranges in batches + and extract the features and target for each range. It is a subclass of + `tf.data.DataSet` and can be used as a generator for a `tf.keras.Model`. + + For a more basic generator to stream a `momics` by batches of ranges, + see `momics.streamer.MomicsStreamer`. + + See Also + -------- + `momics.streamer.MomicsStreamer` + """ + + def __new__( + cls, + momics: Momics, + ranges: pr.PyRanges, + features: str, + target: str, + target_size: Optional[int] = None, + batch_size: Optional[int] = None, + preprocess_func: Optional[Callable] = None, + shuffle_buffer_size: int = 10000, + prefetch_buffer_size: Optional[int] = tf.data.experimental.AUTOTUNE, + silent: bool = False, + ) -> tf.data.Dataset: + """Create the MomicsDataset object. + + Args: + momics (Momics): a Momics object + ranges (pr.PyRanges): pr.PyRanges object + features (str): the name of the track to use for input data + target_size (int): To which width should the target be centered + target (str): the name of the track to use for output data + batch_size (int): the batch size + preprocess_func (Callable): a function to preprocess the queried data + shuffle_buffer_size (int): the size of the shuffle buffer. Pass 0 to disable shuffling + prefetch_buffer_size (int): the size of the prefetch buffer + silent (bool): whether to suppress info messages + """ + + # Check that all ranges have the same width + df = ranges.df + widths = df.End - df.Start + if len(set(widths)) != 1: + raise ValueError("All ranges must have the same width") + w = int(widths[0]) + + # Check that the target size is smaller than the features width + if target_size is not None and target_size > w: + raise ValueError("Target size must be smaller than the features width.") + + # Encapsulate MomicsStreamer logic + def generator() -> Generator: + streamer = MomicsStreamer( + momics, ranges, batch_size, features=[features, target], preprocess_func=preprocess_func, silent=silent + ).generator() + for features_data, out in streamer: + + # Adjust the output if target_size is provided + if target_size: + center = out.shape[1] // 2 + label_data = out[:, int(center - target_size // 2) : int(center + target_size // 2)] + else: + label_data = out + + yield features_data, label_data + + # Example output signature (modify based on your actual data shapes) + feature_shape = (None, w, 4 if features == "nucleotide" else 1) + label_shape = (None, target_size if target_size else w, 4 if features == "nucleotide" else 1) + + dataset = tf.data.Dataset.from_generator( + generator, + output_signature=( + tf.TensorSpec(shape=feature_shape, dtype=tf.float32), + tf.TensorSpec(shape=label_shape, dtype=tf.float32), + ), + ) + + # Add shuffling and prefetching + if shuffle_buffer_size > 0: + shuffle_buffer_size = min(shuffle_buffer_size, batch_size) + dataset = dataset.shuffle(shuffle_buffer_size) + + dataset = dataset.prefetch(buffer_size=prefetch_buffer_size) + + return dataset + + class RangeDataLoader(tf.keras.utils.Sequence): """ This class is implemented to train deep learning models, where the input data (features) is a track or a sequence and the labeled data (target) is another track. The data loader will iterate over the ranges in batches - and extract the features and target for each range. + and extract the features and target for each range. It is a subclass of + `tf.data.DataSet` and can be used as a generator for a `tf.keras.Model`. + + For a more basic generator to stream a `momics` by batches of ranges, + see `momics.streamer.MomicsStreamer`. + + See Also + -------- + `momics.streamer.MomicsStreamer` Attributes ---------- diff --git a/src/momics/multirangequery.py b/src/momics/multirangequery.py index c6c0e4d..6ef28e3 100644 --- a/src/momics/multirangequery.py +++ b/src/momics/multirangequery.py @@ -140,7 +140,9 @@ def query_tracks(self, threads: Optional[int] = None, tracks: Optional[list] = N ).schema attrs = [_sch.attr(i).name for i in range(_sch.nattr)] if tracks is not None: - attrs = [attr for attr in attrs if attr in tracks] + if not all([track in attrs for track in tracks]): + raise ValueError(f"Tracks {tracks} not found in the repository.") + attrs = tracks # Check memory available and warn if it's not enough self._check_memory_available(len(attrs)) diff --git a/src/momics/streamer.py b/src/momics/streamer.py new file mode 100644 index 0000000..a663631 --- /dev/null +++ b/src/momics/streamer.py @@ -0,0 +1,146 @@ +from typing import Callable, Optional + +import numpy as np +import pyranges as pr +import logging + +from .momics import Momics +from .multirangequery import MultiRangeQuery + + +class MomicsStreamer: + """ + This class is implemented to efficiently query a `momics` repository by batches + and extract any coverage data from it. The data streamer will iterate over ranges in batches + and iteratively query a `momics`. + + For a tensorflow DataSet constructor, see `momics.dataset.MomicsDataset`. + + See Also + -------- + `momics.dataset.MomicsDataset` + + Attributes + ---------- + momics (Momics): a local `.momics` repository. + ranges (dict): pr.PyRanges object. + batch_size (int): the batch size + features (list): list of track labels to query + silent (bool): whether to suppress info messages + """ + + def __init__( + self, + momics: Momics, + ranges: pr.PyRanges, + batch_size: Optional[int] = None, + features: Optional[int] = None, + preprocess_func: Optional[Callable] = None, + silent: bool = False, + ) -> None: + """Initialize the MomicsStreamer object. + + Args: + momics (Momics): a Momics object + ranges (dict): pr.PyRanges object. + batch_size (int): the batch size + features (list): list of track labels to query + preprocess_func (Callable): a function to preprocess the queried data + silent (bool): whether to suppress info messages + """ + + self.momics = momics + self.ranges = ranges + if batch_size is None: + batch_size = len(ranges) + self.batch_size = batch_size + self.num_batches = (len(ranges) + batch_size - 1) // batch_size + if features is not None: + i = 0 + if not isinstance(features, list): + features = [features] + if "nucleotide" in features: + features.remove("nucleotide") + i += 1 + _ = momics.seq() + if len(features) > i: + tr = momics.tracks() + for f in features: + if f not in list(tr["label"]): + raise ValueError(f"Features {f} not found in momics repository.") + + if i > 0: + features.insert(0, "nucleotide") + self.features = features + self.silent = silent + self.preprocess_func = preprocess_func if preprocess_func else self._default_preprocess + + def query(self, batch_ranges): + """ + Query function to fetch data from a `momics` repo based on batch_ranges. + + Args: + batch_ranges (pr.PyRanges): PyRanges object for a batch + + Returns: + Queried coverage/sequence data + """ + + attrs = self.features + res = {attr: None for attr in attrs} + q = MultiRangeQuery(self.momics, batch_ranges) + + if self.silent: + logging.disable(logging.WARNING) + + # Fetch seq if needed + if "nucleotide" in attrs: + attrs.remove("nucleotide") + q.query_sequence() + seqs = list(q.seq["nucleotide"].values()) + + # One-hot-encode the sequences lists in seqs + def one_hot_encode(seq) -> np.ndarray: + seq = seq.upper() + encoding_map = {"A": [1, 0, 0, 0], "T": [0, 1, 0, 0], "C": [0, 0, 1, 0], "G": [0, 0, 0, 1]} + oha = np.zeros((len(seq), 4), dtype=int) + for i, nucleotide in enumerate(seq): + oha[i] = encoding_map[nucleotide] + + return oha + + X = np.array([one_hot_encode(seq) for seq in seqs]) + sh = X.shape + res["nucleotide"] = X.reshape(-1, sh[1], 4) + + # Fetch coverage tracks if needed + if len(attrs) > 0: + q.query_tracks(tracks=attrs) + for attr in attrs: + out = np.array(list(q.coverage[attr].values())) + sh = out.shape + res[attr] = out.reshape(-1, sh[1], 1) + + if self.silent: + logging.disable(logging.NOTSET) + + return tuple(res.values()) + + def _default_preprocess(self, data): + """ + Default preprocessing function that normalizes data. + """ + return (data - np.mean(data, axis=0)) / np.std(data, axis=0) + + def generator(self): + """ + Generator to yield batches of ranges and queried/preprocessed data. + + Yields: + Tuple[pr.PyRanges, np.ndarray]: batch_ranges and preprocessed_data + """ + for i in range(0, len(self.ranges), self.batch_size): + batch_ranges = pr.PyRanges(self.ranges.df.iloc[i : i + self.batch_size]) + queried_data = self.query(batch_ranges) + # preprocessed_data = self.preprocess(queried_data) + yield queried_data diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..be5f704 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,48 @@ +import pytest + +import momics +from momics.dataset import MomicsDataset +from momics.streamer import MomicsStreamer + + +@pytest.mark.order(99) +def test_streamer(momics_path: str): + mom = momics.Momics(momics_path) + + b = mom.bins(10, 21, cut_last_bin_out=True) + with pytest.raises(ValueError, match=r".*not found in momics repository."): + MomicsStreamer(mom, b, features=["CH1", "bw2"]) + + rg = MomicsStreamer(mom, b, features=["bw3", "bw2"], batch_size=1000) + n = next(iter(rg.generator())) + assert len(n) == 2 + assert n[0].shape == (1000, 10, 1) + assert n[1].shape == (1000, 10, 1) + + rg = MomicsStreamer(mom, b, features=["nucleotide", "bw2"], batch_size=10) + n = next(iter(rg.generator())) + assert n[0].shape == (10, 10, 4) + assert n[1].shape == (10, 10, 1) + + +@pytest.mark.order(99) +def test_dataset(momics_path: str): + mom = momics.Momics(momics_path) + b = mom.bins(10, 21) + + with pytest.raises(ValueError, match="All ranges must have the same width"): + MomicsDataset(mom, b, "CH0", "CH1") + + b = mom.bins(10, 21, cut_last_bin_out=True) + with pytest.raises(ValueError, match=r"Target size must be smaller than the features width"): + MomicsDataset(mom, b, "bw3", "bw2", target_size=1000000) + + rg = MomicsDataset(mom, b, "bw3", "bw2", target_size=2, batch_size=10) + n = next(iter(rg)) + assert n[0].shape == (10, 10, 1) + assert n[1].shape == (10, 2, 1) + + rg = MomicsDataset(mom, b, "nucleotide", "CH1", target_size=2, batch_size=10) + n = next(iter(rg)) + assert n[0].shape == (10, 10, 4) + assert n[1].shape == (10, 2, 1) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py deleted file mode 100644 index b90f6bb..0000000 --- a/tests/test_dataloader.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest - -import momics -from momics.dataloader import RangeDataLoader - - -@pytest.mark.order(99) -def test_generator_init(momics_path: str): - mom = momics.Momics(momics_path) - b = mom.bins(10, 21) - - with pytest.raises(ValueError, match="All ranges must have the same width"): - RangeDataLoader(mom, b, "CH0", "CH1") - - b = mom.bins(10, 20, cut_last_bin_out=True) - with pytest.raises(ValueError, match=r".*not found in momics repository"): - RangeDataLoader(mom, b, "CH0", "CH1") - with pytest.raises(ValueError, match=r".*not found in momics repository"): - RangeDataLoader(mom, b, "bw3", "CH1") - with pytest.raises(ValueError, match=r"Target size must be smaller than the features width"): - RangeDataLoader(mom, b, "bw3", "bw2", target_size=1000000) - - rg = RangeDataLoader(mom, b, "bw3", "bw2", target_size=2, batch_size=1000) - n = rg[0] - assert n[0].shape == (1000, 10, 1) - assert n[1].shape == (1000, 2, 1) - - rg = RangeDataLoader(mom, b, "nucleotide", "bw2", target_size=2, batch_size=1000) - n = rg[0] - assert n[0].shape == (1000, 10, 4) - assert n[1].shape == (1000, 2, 1)