Skip to content

Commit

Permalink
feat: improve Streamer and Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
js2264 committed Oct 19, 2024
1 parent 609484c commit 0ce0e07
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 178 deletions.
160 changes: 4 additions & 156 deletions src/momics/dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from typing import Callable, Generator, Optional, Tuple
from typing import Callable, Generator, Optional

import numpy as np
import pyranges as pr
import logging
import tensorflow as tf

from .momics import Momics
from .streamer import MomicsStreamer
from .multirangequery import MultiRangeQuery


class MomicsDataset(tf.data.Dataset):
Expand Down Expand Up @@ -37,7 +34,7 @@ def __new__(
preprocess_func: Optional[Callable] = None,
shuffle_buffer_size: int = 10000,
prefetch_buffer_size: Optional[int] = tf.data.experimental.AUTOTUNE,
silent: bool = False,
silent: bool = True,
) -> tf.data.Dataset:
"""Create the MomicsDataset object.
Expand Down Expand Up @@ -69,7 +66,7 @@ def __new__(
def generator() -> Generator:
streamer = MomicsStreamer(
momics, ranges, batch_size, features=[features, target], preprocess_func=preprocess_func, silent=silent
).generator()
)
for features_data, out in streamer:

# Adjust the output if target_size is provided
Expand All @@ -83,7 +80,7 @@ def generator() -> Generator:

# Example output signature (modify based on your actual data shapes)
feature_shape = (None, w, 4 if features == "nucleotide" else 1)
label_shape = (None, target_size if target_size else w, 4 if features == "nucleotide" else 1)
label_shape = (None, target_size if target_size else w, 4 if target == "nucleotide" else 1)

dataset = tf.data.Dataset.from_generator(
generator,
Expand All @@ -101,152 +98,3 @@ def generator() -> Generator:
dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)

return dataset


class RangeDataLoader(tf.keras.utils.Sequence):
    """Batch data loader pairing an input track (or sequence) with a target track.

    This class is implemented to train deep learning models, where the
    input data (features) is a track or a sequence and the labeled data
    (target) is another track. The data loader iterates over the ranges
    in batches and extracts the features and target for each range. It is
    a subclass of `tf.keras.utils.Sequence` and can be used as a generator
    for a `tf.keras.Model`.

    For a more basic generator to stream a `momics` by batches of ranges,
    see `momics.streamer.MomicsStreamer`.

    See Also
    --------
    `momics.streamer.MomicsStreamer`

    Attributes
    ----------
    momics (Momics): a local `.momics` repository.
    ranges (pr.PyRanges): ranges to iterate over, all of identical width.
    features (str): the name of the track to use for input data
    target (str): the name of the track to use for output data
    target_size (int): to which width the target should be centered
    """

    def __init__(
        self,
        momics: Momics,
        ranges: pr.PyRanges,
        features: str,
        target: str,
        target_size: Optional[int] = None,
        batch_size: Optional[int] = None,
        silent: bool = False,
    ) -> None:
        """Initialize the RangeDataLoader object.

        Args:
            momics (Momics): a Momics object
            ranges (pr.PyRanges): pr.PyRanges object
            features (str): the name of the track to use for input data
            target (str): the name of the track to use for output data
            target_size (int): to which width the target should be centered
            batch_size (int): the batch size
            silent (bool): whether to suppress info messages

        Raises:
            ValueError: if ranges have unequal widths, if a track label is
                missing from the repository, or if ``target_size`` exceeds
                the range width.
        """
        # All ranges must share one width so each batch stacks into a
        # rectangular array.
        df = ranges.df
        widths = df.End - df.Start
        if len(set(widths)) != 1:
            raise ValueError("All ranges must have the same width")

        self.momics = momics
        self.ranges = ranges
        if batch_size is None:
            # Default to a single batch covering every range.
            batch_size = len(ranges)
        self.start = 0
        self.stop = len(ranges)
        self.batch_size = batch_size
        self.current = self.start
        self.silent = silent

        tr = momics.tracks()
        if features == "nucleotide":
            # Ensure the repository actually stores a sequence before training.
            _ = momics.seq()

        if features not in list(tr["label"]) and features != "nucleotide":
            raise ValueError(f"Features {features} not found in momics repository.")
        if target not in list(tr["label"]):
            raise ValueError(f"Target {target} not found in momics repository.")

        self.features = features
        self.target = target

        # Use positional access (.iloc): the ranges dataframe index is not
        # guaranteed to contain the label 0 (e.g. after subsetting).
        if target_size is not None and target_size > int(widths.iloc[0]):
            raise ValueError("Target size must be smaller than the features width.")
        self.target_size = target_size

    def __len__(self) -> int:
        """Return the number of batches per epoch."""
        return int(np.ceil(len(self.ranges) / self.batch_size))

    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
        """Return the (features, target) arrays for batch ``idx``.

        Returns:
            Tuple[np.ndarray, np.ndarray]: ``X`` of shape
            (batch, width, 4) for one-hot sequences or (batch, width, 1)
            for coverage tracks, and ``Y`` of shape
            (batch, target_size or width, 1).
        """
        subrg = pr.PyRanges(self.ranges.df[idx * self.batch_size : (idx + 1) * self.batch_size])

        # Fetch only the required tracks.
        attrs = [self.target]
        q = MultiRangeQuery(self.momics, subrg)
        if self.features != "nucleotide":
            attrs.append(self.features)

        # Temporarily silence logging during the query; the try/finally
        # guarantees logging is restored (NOTSET) even if the query raises.
        # (Fixes the previous restore sequence, which issued a redundant
        # disable(WARNING) before re-enabling.)
        if self.silent:
            logging.disable(logging.WARNING)
        try:
            q.query_tracks(tracks=attrs)
        finally:
            if self.silent:
                logging.disable(logging.NOTSET)

        # If the input is a track, reshape and replace NaN with 0.
        if self.features in q.coverage.keys():  # type: ignore
            X = np.array(list(q.coverage[self.features].values()))  # type: ignore
            X = np.nan_to_num(X, nan=0)
            X = X.reshape(-1, X.shape[1], 1)

        # If the input is the sequences, one-hot-encode and reshape.
        elif self.features == "nucleotide":
            q.query_sequence()
            seqs = list(q.seq["nucleotide"].values())  # type: ignore

            def one_hot_encode(seq) -> np.ndarray:
                """One-hot-encode a sequence as (len, 4) in A/T/C/G order."""
                seq = seq.upper()
                encoding_map = {"A": [1, 0, 0, 0], "T": [0, 1, 0, 0], "C": [0, 0, 1, 0], "G": [0, 0, 0, 1]}
                oha = np.zeros((len(seq), 4), dtype=int)
                for i, nucleotide in enumerate(seq):
                    # Unknown characters (e.g. "N") encode as all zeros
                    # instead of raising a KeyError.
                    oha[i] = encoding_map.get(nucleotide, [0, 0, 0, 0])
                return oha

            X = np.array([one_hot_encode(seq) for seq in seqs])
            X = X.reshape(-1, X.shape[1], 4)
        else:
            raise ValueError("features must be a track label or 'nucleotide'")

        # Extract the label and replace NaN with 0.
        out = np.array(list(q.coverage[self.target].values()))  # type: ignore
        out = np.nan_to_num(out, nan=0)

        # Recenter the label on the range midpoint if a target_size is set.
        if self.target_size is not None:
            midpos = out.shape[1] // 2
            out = out[:, int(midpos - self.target_size / 2) : int(midpos + self.target_size / 2)]
            dim = self.target_size
        else:
            dim = out.shape[1]

        Y = out.reshape(-1, dim, 1)

        return X, Y

    def __str__(self):
        return f"RangeDataLoader(start={self.start}, stop={self.stop}, batch_size={self.batch_size})"
9 changes: 6 additions & 3 deletions src/momics/multirangequery.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,12 @@ def query_tracks(self, threads: Optional[int] = None, tracks: Optional[list] = N
).schema
attrs = [_sch.attr(i).name for i in range(_sch.nattr)]
if tracks is not None:
if not all([track in attrs for track in tracks]):
raise ValueError(f"Tracks {tracks} not found in the repository.")
attrs = tracks
for track in tracks:
if track == "nucleotide":
logger.debug("'nucleotide' track is not a coverage track.")
elif track not in attrs:
raise ValueError(f"Track {track} not found in the repository.")
attrs = [tr for tr in tracks if tr != "nucleotide"]

# Check memory available and warn if it's not enough
self._check_memory_available(len(attrs))
Expand Down
58 changes: 42 additions & 16 deletions src/momics/streamer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Callable, Optional, Generator, Tuple

import numpy as np
import pyranges as pr
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(
batch_size: Optional[int] = None,
features: Optional[int] = None,
preprocess_func: Optional[Callable] = None,
silent: bool = False,
silent: bool = True,
) -> None:
"""Initialize the MomicsStreamer object.
Expand All @@ -56,26 +56,26 @@ def __init__(
self.batch_size = batch_size
self.num_batches = (len(ranges) + batch_size - 1) // batch_size
if features is not None:
i = 0
if not isinstance(features, list):
features = [features]
i = len(features)
if "nucleotide" in features:
features.remove("nucleotide")
i += 1
_ = momics.seq()
if len(features) > i:
tr = momics.tracks()
i -= 1
_ = momics.seq() # Check that the momics object has a sequence
if i > 0: # Other features besides "nucleotide"
tr = momics.tracks() # Check that the momics object has the tracks
for f in features:
if f == "nucleotide":
continue
if f not in list(tr["label"]):
raise ValueError(f"Features {f} not found in momics repository.")

if i > 0:
features.insert(0, "nucleotide")
self.features = features
self.silent = silent
self.preprocess_func = preprocess_func if preprocess_func else self._default_preprocess
self.batch_index = 0

def query(self, batch_ranges):
def query(self, batch_ranges) -> Tuple:
"""
Query function to fetch data from a `momics` repo based on batch_ranges.
Expand All @@ -87,6 +87,7 @@ def query(self, batch_ranges):
"""

attrs = self.features
i = len(attrs)
res = {attr: None for attr in attrs}
q = MultiRangeQuery(self.momics, batch_ranges)

Expand All @@ -95,7 +96,7 @@ def query(self, batch_ranges):

# Fetch seq if needed
if "nucleotide" in attrs:
attrs.remove("nucleotide")
i -= 1
q.query_sequence()
seqs = list(q.seq["nucleotide"].values())

Expand All @@ -114,9 +115,10 @@ def one_hot_encode(seq) -> np.ndarray:
res["nucleotide"] = X.reshape(-1, sh[1], 4)

# Fetch coverage tracks if needed
if len(attrs) > 0:
q.query_tracks(tracks=attrs)
for attr in attrs:
if i > 0:
attrs2 = [attr for attr in attrs if attr != "nucleotide"]
q.query_tracks(tracks=attrs2)
for attr in attrs2:
out = np.array(list(q.coverage[attr].values()))
sh = out.shape
res[attr] = out.reshape(-1, sh[1], 1)
Expand All @@ -132,15 +134,39 @@ def _default_preprocess(self, data):
"""
return (data - np.mean(data, axis=0)) / np.std(data, axis=0)

def generator(self) -> Generator:
    """
    Generator to yield batches of ranges and queried/preprocessed data.

    Yields:
        Tuple[pr.PyRanges, np.ndarray]: batch_ranges and preprocessed_data
    """
    # Start every fresh generator from the first batch.
    self.batch_index = 0
    total = len(self.ranges)
    step = self.batch_size
    offset = 0
    while offset < total:
        batch_ranges = pr.PyRanges(self.ranges.df.iloc[offset : offset + step])
        batch_data = self.query(batch_ranges)
        self.batch_index += 1
        yield batch_data
        offset += step

def __iter__(self):
    """Iterate over batches by delegating to :meth:`generator`."""
    yield from self.generator()

def __next__(self):
    """Return the next batch or raise StopIteration."""
    # Guard clause: stop as soon as every batch has been consumed.
    if self.batch_index >= self.num_batches:
        raise StopIteration
    lo = self.batch_index * self.batch_size
    hi = min(lo + self.batch_size, len(self.ranges))
    batch_ranges = pr.PyRanges(self.ranges.df.iloc[lo:hi])
    batch_data = self.query(batch_ranges)
    self.batch_index += 1
    return batch_data

def __len__(self):
    """Return the total number of batches this streamer yields."""
    # num_batches is set in __init__ as ceil(len(ranges) / batch_size).
    return self.num_batches

def reset(self):
    """Reset the iterator to allow re-iteration."""
    # Rewind so that __next__ starts again from the first batch.
    self.batch_index = 0
39 changes: 36 additions & 3 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Generator
import pytest

import momics
from momics.dataset import MomicsDataset
from momics.streamer import MomicsStreamer
from momics.dataset import MomicsDataset


@pytest.mark.order(99)
Expand All @@ -13,14 +14,46 @@ def test_streamer(momics_path: str):
with pytest.raises(ValueError, match=r".*not found in momics repository."):
MomicsStreamer(mom, b, features=["CH1", "bw2"])

rg = MomicsStreamer(mom, b, features="bw2", batch_size=1000, silent=False)
assert isinstance(rg.generator(), Generator)
assert isinstance(iter(rg), Generator)

n = next(iter(rg))
assert len(n) == 1
assert n[0].shape == (1000, 10, 1)
n = next(rg)
assert len(n) == 1
assert n[0].shape == (1000, 10, 1)

for n in iter(rg):
print(rg.batch_index, " / ", rg.num_batches)
assert isinstance(n, tuple)
with pytest.raises(StopIteration):
next(rg)
assert rg.batch_index == rg.num_batches
rg.reset()
assert rg.batch_index == 0
n = next(rg)
assert len(n) == 1
assert n[0].shape == (1000, 10, 1)

rg = MomicsStreamer(mom, b, features=["bw3", "bw2"], batch_size=1000)
n = next(iter(rg.generator()))
n = next(rg)
assert len(n) == 2
assert n[0].shape == (1000, 10, 1)
assert n[1].shape == (1000, 10, 1)
n = next(rg)
assert len(n) == 2
assert n[0].shape == (1000, 10, 1)
assert n[1].shape == (1000, 10, 1)

rg = MomicsStreamer(mom, b, features=["nucleotide", "bw2"], batch_size=10)
n = next(iter(rg.generator()))
n = next(rg)
assert len(n) == 2
assert n[0].shape == (10, 10, 4)
assert n[1].shape == (10, 10, 1)
n = next(rg)
assert len(n) == 2
assert n[0].shape == (10, 10, 4)
assert n[1].shape == (10, 10, 1)

Expand Down

0 comments on commit 0ce0e07

Please sign in to comment.