Skip to content

Commit

Permalink
feat: split generator in two: a streamer and a Dataset class
Browse files Browse the repository at this point in the history
  • Loading branch information
js2264 committed Oct 19, 2024
1 parent 1fc8bf0 commit 27dc618
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 35 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

pygments_style = "one-dark"
# pygments_style = "one-dark"
todo_include_todos = False
master_doc = "index"

Expand Down
106 changes: 104 additions & 2 deletions src/momics/dataloader.py → src/momics/dataset.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,122 @@
from typing import Optional, Tuple
from typing import Callable, Generator, Optional, Tuple

import numpy as np
import pyranges as pr
import logging
import tensorflow as tf

from .momics import Momics
from .streamer import MomicsStreamer
from .multirangequery import MultiRangeQuery


class MomicsDataset(tf.data.Dataset):
"""
This class is implemented to train deep learning models, where the
input data (features) is a track or a sequence and the labeled data (target)
is another track. The data loader will iterate over the ranges in batches
and extract the features and target for each range. It is a subclass of
`tf.data.DataSet` and can be used as a generator for a `tf.keras.Model`.
For a more basic generator to stream a `momics` by batches of ranges,
see `momics.streamer.MomicsStreamer`.
See Also
--------
`momics.streamer.MomicsStreamer`
"""

def __new__(
cls,
momics: Momics,
ranges: pr.PyRanges,
features: str,
target: str,
target_size: Optional[int] = None,
batch_size: Optional[int] = None,
preprocess_func: Optional[Callable] = None,
shuffle_buffer_size: int = 10000,
prefetch_buffer_size: Optional[int] = tf.data.experimental.AUTOTUNE,
silent: bool = False,
) -> tf.data.Dataset:
"""Create the MomicsDataset object.
Args:
momics (Momics): a Momics object
ranges (pr.PyRanges): pr.PyRanges object
features (str): the name of the track to use for input data
target_size (int): To which width should the target be centered
target (str): the name of the track to use for output data
batch_size (int): the batch size
preprocess_func (Callable): a function to preprocess the queried data
shuffle_buffer_size (int): the size of the shuffle buffer. Pass 0 to disable shuffling
prefetch_buffer_size (int): the size of the prefetch buffer
silent (bool): whether to suppress info messages
"""

# Check that all ranges have the same width
df = ranges.df
widths = df.End - df.Start
if len(set(widths)) != 1:
raise ValueError("All ranges must have the same width")
w = int(widths[0])

# Check that the target size is smaller than the features width
if target_size is not None and target_size > w:
raise ValueError("Target size must be smaller than the features width.")

# Encapsulate MomicsStreamer logic
def generator() -> Generator:
streamer = MomicsStreamer(
momics, ranges, batch_size, features=[features, target], preprocess_func=preprocess_func, silent=silent
).generator()
for features_data, out in streamer:

# Adjust the output if target_size is provided
if target_size:
center = out.shape[1] // 2
label_data = out[:, int(center - target_size // 2) : int(center + target_size // 2)]
else:
label_data = out

yield features_data, label_data

# Example output signature (modify based on your actual data shapes)
feature_shape = (None, w, 4 if features == "nucleotide" else 1)
label_shape = (None, target_size if target_size else w, 4 if features == "nucleotide" else 1)

dataset = tf.data.Dataset.from_generator(
generator,
output_signature=(
tf.TensorSpec(shape=feature_shape, dtype=tf.float32),
tf.TensorSpec(shape=label_shape, dtype=tf.float32),
),
)

# Add shuffling and prefetching
if shuffle_buffer_size > 0:
shuffle_buffer_size = min(shuffle_buffer_size, batch_size)
dataset = dataset.shuffle(shuffle_buffer_size)

dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)

return dataset


class RangeDataLoader(tf.keras.utils.Sequence):
"""
This class is implemented to train deep learning models, where the
input data (features) is a track or a sequence and the labeled data (target)
is another track. The data loader will iterate over the ranges in batches
and extract the features and target for each range.
and extract the features and target for each range. It is a subclass of
`tf.data.DataSet` and can be used as a generator for a `tf.keras.Model`.
For a more basic generator to stream a `momics` by batches of ranges,
see `momics.streamer.MomicsStreamer`.
See Also
--------
`momics.streamer.MomicsStreamer`
Attributes
----------
Expand Down
4 changes: 3 additions & 1 deletion src/momics/multirangequery.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def query_tracks(self, threads: Optional[int] = None, tracks: Optional[list] = N
).schema
attrs = [_sch.attr(i).name for i in range(_sch.nattr)]
if tracks is not None:
attrs = [attr for attr in attrs if attr in tracks]
if not all([track in attrs for track in tracks]):
raise ValueError(f"Tracks {tracks} not found in the repository.")
attrs = tracks

# Check memory available and warn if it's not enough
self._check_memory_available(len(attrs))
Expand Down
146 changes: 146 additions & 0 deletions src/momics/streamer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from typing import Callable, Optional

import numpy as np
import pyranges as pr
import logging

from .momics import Momics
from .multirangequery import MultiRangeQuery


class MomicsStreamer:
"""
This class is implemented to efficiently query a `momics` repository by batches
and extract any coverage data from it. The data streamer will iterate over ranges in batches
and iteratively query a `momics`.
For a tensorflow DataSet constructor, see `momics.dataset.MomicsDataset`.
See Also
--------
`momics.dataset.MomicsDataset`
Attributes
----------
momics (Momics): a local `.momics` repository.
ranges (dict): pr.PyRanges object.
batch_size (int): the batch size
features (list): list of track labels to query
silent (bool): whether to suppress info messages
"""

def __init__(
self,
momics: Momics,
ranges: pr.PyRanges,
batch_size: Optional[int] = None,
features: Optional[int] = None,
preprocess_func: Optional[Callable] = None,
silent: bool = False,
) -> None:
"""Initialize the MomicsStreamer object.
Args:
momics (Momics): a Momics object
ranges (dict): pr.PyRanges object.
batch_size (int): the batch size
features (list): list of track labels to query
preprocess_func (Callable): a function to preprocess the queried data
silent (bool): whether to suppress info messages
"""

self.momics = momics
self.ranges = ranges
if batch_size is None:
batch_size = len(ranges)
self.batch_size = batch_size
self.num_batches = (len(ranges) + batch_size - 1) // batch_size
if features is not None:
i = 0
if not isinstance(features, list):
features = [features]
if "nucleotide" in features:
features.remove("nucleotide")
i += 1
_ = momics.seq()
if len(features) > i:
tr = momics.tracks()
for f in features:
if f not in list(tr["label"]):
raise ValueError(f"Features {f} not found in momics repository.")

if i > 0:
features.insert(0, "nucleotide")
self.features = features
self.silent = silent
self.preprocess_func = preprocess_func if preprocess_func else self._default_preprocess

def query(self, batch_ranges):
"""
Query function to fetch data from a `momics` repo based on batch_ranges.
Args:
batch_ranges (pr.PyRanges): PyRanges object for a batch
Returns:
Queried coverage/sequence data
"""

attrs = self.features
res = {attr: None for attr in attrs}
q = MultiRangeQuery(self.momics, batch_ranges)

if self.silent:
logging.disable(logging.WARNING)

# Fetch seq if needed
if "nucleotide" in attrs:
attrs.remove("nucleotide")
q.query_sequence()
seqs = list(q.seq["nucleotide"].values())

# One-hot-encode the sequences lists in seqs
def one_hot_encode(seq) -> np.ndarray:
seq = seq.upper()
encoding_map = {"A": [1, 0, 0, 0], "T": [0, 1, 0, 0], "C": [0, 0, 1, 0], "G": [0, 0, 0, 1]}
oha = np.zeros((len(seq), 4), dtype=int)
for i, nucleotide in enumerate(seq):
oha[i] = encoding_map[nucleotide]

return oha

X = np.array([one_hot_encode(seq) for seq in seqs])
sh = X.shape
res["nucleotide"] = X.reshape(-1, sh[1], 4)

# Fetch coverage tracks if needed
if len(attrs) > 0:
q.query_tracks(tracks=attrs)
for attr in attrs:
out = np.array(list(q.coverage[attr].values()))
sh = out.shape
res[attr] = out.reshape(-1, sh[1], 1)

if self.silent:
logging.disable(logging.NOTSET)

return tuple(res.values())

def _default_preprocess(self, data):
"""
Default preprocessing function that normalizes data.
"""
return (data - np.mean(data, axis=0)) / np.std(data, axis=0)

def generator(self):
"""
Generator to yield batches of ranges and queried/preprocessed data.
Yields:
Tuple[pr.PyRanges, np.ndarray]: batch_ranges and preprocessed_data
"""
for i in range(0, len(self.ranges), self.batch_size):
batch_ranges = pr.PyRanges(self.ranges.df.iloc[i : i + self.batch_size])
queried_data = self.query(batch_ranges)
# preprocessed_data = self.preprocess(queried_data)
yield queried_data
48 changes: 48 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest

import momics
from momics.dataset import MomicsDataset
from momics.streamer import MomicsStreamer


@pytest.mark.order(99)
def test_streamer(momics_path: str):
mom = momics.Momics(momics_path)

b = mom.bins(10, 21, cut_last_bin_out=True)
with pytest.raises(ValueError, match=r".*not found in momics repository."):
MomicsStreamer(mom, b, features=["CH1", "bw2"])

rg = MomicsStreamer(mom, b, features=["bw3", "bw2"], batch_size=1000)
n = next(iter(rg.generator()))
assert len(n) == 2
assert n[0].shape == (1000, 10, 1)
assert n[1].shape == (1000, 10, 1)

rg = MomicsStreamer(mom, b, features=["nucleotide", "bw2"], batch_size=10)
n = next(iter(rg.generator()))
assert n[0].shape == (10, 10, 4)
assert n[1].shape == (10, 10, 1)


@pytest.mark.order(99)
def test_dataset(momics_path: str):
mom = momics.Momics(momics_path)
b = mom.bins(10, 21)

with pytest.raises(ValueError, match="All ranges must have the same width"):
MomicsDataset(mom, b, "CH0", "CH1")

b = mom.bins(10, 21, cut_last_bin_out=True)
with pytest.raises(ValueError, match=r"Target size must be smaller than the features width"):
MomicsDataset(mom, b, "bw3", "bw2", target_size=1000000)

rg = MomicsDataset(mom, b, "bw3", "bw2", target_size=2, batch_size=10)
n = next(iter(rg))
assert n[0].shape == (10, 10, 1)
assert n[1].shape == (10, 2, 1)

rg = MomicsDataset(mom, b, "nucleotide", "CH1", target_size=2, batch_size=10)
n = next(iter(rg))
assert n[0].shape == (10, 10, 4)
assert n[1].shape == (10, 2, 1)
31 changes: 0 additions & 31 deletions tests/test_dataloader.py

This file was deleted.

0 comments on commit 27dc618

Please sign in to comment.