Skip to content

Commit

Permalink
fix: tf.int32 in tf dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
js2264 committed Oct 28, 2024
1 parent d24095d commit fbbb1a9
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
12 changes: 11 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
## Unreleased

## [v0.3.0](https://github.com/js2264/momics/releases/tag/0.4.0)
### New features

### Enhancements

* Add `silent` argument to `query_*` methods.

### Bug fixes

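* Fix MomicsDataset and generator: nucleotide tensor specs now use `tf.int32` instead of `tf.string`, and the target spec is chosen from `target` rather than `features`.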

## [v0.4.0](https://github.com/js2264/momics/releases/tag/0.4.0)

*Date: 2024-10-25*

Expand Down
6 changes: 3 additions & 3 deletions src/momics/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __new__(
x_streamer = MomicsStreamer(repo, ranges, batch_size, features=[features], preprocess_func=preprocess_func, silent=silent)
x_gen = x_streamer.generator
if features == "nucleotide":
out = tf.TensorSpec(shape=(None, features_size, 4), dtype=tf.string)
out = tf.TensorSpec(shape=(None, features_size, 4), dtype=tf.int32)
else:
out = tf.TensorSpec(shape=(None, features_size, 1), dtype=tf.float32)

Expand All @@ -80,8 +80,8 @@ def __new__(
repo, ranges_target, batch_size, features=[target], preprocess_func=preprocess_func, silent=silent
)
y_gen = y_streamer.generator
if features == "nucleotide":
out = tf.TensorSpec(shape=(None, target_size, 4), dtype=tf.string)
if target == "nucleotide":
out = tf.TensorSpec(shape=(None, target_size, 4), dtype=tf.int32)
else:
out = tf.TensorSpec(shape=(None, target_size, 1), dtype=tf.float32)

Expand Down
19 changes: 7 additions & 12 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generator
from typing import Generator, Iterator
import pytest

from momics import momics
Expand All @@ -16,7 +16,7 @@ def test_streamer(momics_path: str):

rg = MomicsStreamer(mom, b, features="bw2", batch_size=1000, silent=False)
assert isinstance(rg.generator(), Generator)
assert isinstance(iter(rg), Generator)
assert isinstance(iter(rg), Iterator)

n = next(iter(rg))
assert len(n) == 1
Expand All @@ -31,11 +31,6 @@ def test_streamer(momics_path: str):
with pytest.raises(StopIteration):
next(rg)
assert rg.batch_index == rg.num_batches
rg.reset()
assert rg.batch_index == 0
n = next(rg)
assert len(n) == 1
assert n[0].shape == (1000, 10, 1)

rg = MomicsStreamer(mom, b, features=["bw3", "bw2"], batch_size=1000)
n = next(rg)
Expand Down Expand Up @@ -67,15 +62,15 @@ def test_dataset(momics_path: str):
MomicsDataset(mom, b, "CH0", "CH1")

b = mom.bins(10, 21, cut_last_bin_out=True)
with pytest.raises(ValueError, match=r"Target size must be smaller than the features width"):
with pytest.raises(ValueError, match=r"Target size must be smaller.*"):
MomicsDataset(mom, b, "bw3", "bw2", target_size=1000000)

rg = MomicsDataset(mom, b, "bw3", "bw2", target_size=2, batch_size=10)
n = next(iter(rg))
assert n[0].shape == (10, 10, 1)
assert n[1].shape == (10, 2, 1)
assert n[0][0].shape == (10, 10, 1)
assert n[1][0].shape == (10, 2, 1)

rg = MomicsDataset(mom, b, "nucleotide", "bw2", target_size=2, batch_size=10)
n = next(iter(rg))
assert n[0].shape == (10, 10, 4)
assert n[1].shape == (10, 2, 1)
assert n[0][0].shape == (10, 10, 4)
assert n[1][0].shape == (10, 2, 1)

0 comments on commit fbbb1a9

Please sign in to comment.