Skip to content

Commit

Permalink
allow selecting an integer with by_id and by_index
Browse files Browse the repository at this point in the history
  • Loading branch information
tmichela committed Sep 30, 2024
1 parent d75fb6d commit b14aa4b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
4 changes: 3 additions & 1 deletion extra_data/read_machinery.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ def _tid_to_slice_ix(tid, train_ids, stop=False):


def select_train_ids(train_ids, sel):
if isinstance(sel, by_id) and isinstance(sel.value, int):
sel.value = slice(sel.value, sel.value+1, None)
if isinstance(sel, by_index):
sel = sel.value
elif isinstance(sel, int):
if isinstance(sel, int):
sel = slice(sel, sel+1, None)

if isinstance(sel, by_id) and isinstance(sel.value, slice):
Expand Down
60 changes: 59 additions & 1 deletion extra_data/tests/test_read_machinery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from unittest import mock

import numpy as np
from extra_data import RunDirectory, read_machinery
import pytest

from extra_data import RunDirectory, by_id, by_index, read_machinery
from extra_data.read_machinery import select_train_ids


def test_find_proposal(tmpdir):
Expand Down Expand Up @@ -51,3 +54,58 @@ def test_same_run(mock_spb_raw_run, mock_jungfrau_run, mock_scs_run):
assert read_machinery.same_run(sd1, sd2)
else:
assert not read_machinery.same_run(sd1, sd2)


def test_select_train_ids():
train_ids = list(range(1000000, 1000010))

# Test by_id with a single integer
assert select_train_ids(train_ids, by_id[1000002]) == [1000002]

# Test by_id with a slice
assert select_train_ids(train_ids, by_id[1000002:1000005]) == [1000002, 1000003, 1000004]

# Test by_id with a list
assert select_train_ids(train_ids, by_id[[1000002, 1000005]]) == [1000002, 1000005]

# Test by_id with a numpy array
assert select_train_ids(train_ids, by_id[np.array([1000002, 1000005])]) == [1000002, 1000005]

# Test by_id with a slice and step
assert select_train_ids(train_ids, by_id[1000000:1000008:2]) == [1000000, 1000002, 1000004, 1000006]

# Test by_id with an open-ended slice (end)
assert select_train_ids(train_ids, by_id[1000005:]) == [1000005, 1000006, 1000007, 1000008, 1000009]

# Test by_id with an open-ended slice (start)
assert select_train_ids(train_ids, by_id[:1000003]) == [1000000, 1000001, 1000002]

# Test by_index with a single integer
assert select_train_ids(train_ids, by_index[2]) == [1000002]

# Test by_index with a slice
assert select_train_ids(train_ids, by_index[1:4]) == [1000001, 1000002, 1000003]

# Test by_index with a list
assert select_train_ids(train_ids, by_index[[1, 3]]) == [1000001, 1000003]

# Test by_index with a slice and step
assert select_train_ids(train_ids, by_index[::2]) == [1000000, 1000002, 1000004, 1000006, 1000008]

# Test with a plain slice
assert select_train_ids(train_ids, slice(1, 4)) == [1000001, 1000002, 1000003]

# Test with a plain list
assert select_train_ids(train_ids, [1, 3]) == [1000001, 1000003]

# Test with a numpy array
assert select_train_ids(train_ids, np.array([1, 3])) == [1000001, 1000003]

# Test with an invalid type (should raise TypeError)
with pytest.raises(TypeError):
select_train_ids(train_ids, "invalid")

# Test by_id with train IDs not found in the list (should raise a warning)
with pytest.warns(UserWarning):
result = select_train_ids(train_ids, by_id[[999999, 1000010]])
assert result == []

0 comments on commit b14aa4b

Please sign in to comment.