Skip to content

Commit

Permalink
Add additional tests for TensorDataset (#187)
Browse files Browse the repository at this point in the history
* Add additional tests for TensorDataset

* Add explicit casting to avoid windows error
  • Loading branch information
stes authored Oct 27, 2024
1 parent 51f048d commit e652b9a
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 7 deletions.
13 changes: 11 additions & 2 deletions cebra/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self,
super().__init__(device=device)
self.neural = self._to_tensor(neural, check_dtype="float").float()
self.continuous = self._to_tensor(continuous, check_dtype="float")
self.discrete = self._to_tensor(discrete, check_dtype="integer")
self.discrete = self._to_tensor(discrete, check_dtype="int")
if self.continuous is None and self.discrete is None:
raise ValueError(
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
Expand All @@ -87,7 +87,7 @@ def _to_tensor(
Args:
array: Array to check.
check_dtype (list, optional): If not `None`, list of dtypes to which the values in `array`
check_dtype: If not `None`, list of dtypes to which the values in `array`
must belong to. Defaults to None.
Returns:
Expand All @@ -98,11 +98,20 @@ def _to_tensor(
if isinstance(array, np.ndarray):
array = torch.from_numpy(array)
if check_dtype is not None:
if check_dtype not in ["int", "float"]:
raise ValueError(
f"check_dtype must be 'int' or 'float', got {check_dtype}")
if (check_dtype == "int" and not cebra_helper._is_integer(array)
) or (check_dtype == "float" and
not cebra_helper._is_floating(array)):
raise TypeError(
f"Array has type {array.dtype} instead of {check_dtype}.")
if cebra_helper._is_floating(array):
array = array.float()
if cebra_helper._is_integer(array):
# NOTE(stes): Required for standardizing number format on
# windows machines.
array = array.long()
return array

@property
Expand Down
113 changes: 108 additions & 5 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def test_demo():

@pytest.mark.requires_dataset
def test_hippocampus():
from cebra.datasets import hippocampus

pytest.skip("Outdated")

from cebra.datasets import hippocampus # noqa: F401
dataset = cebra.datasets.init("rat-hippocampus-single")
loader = cebra.data.ContinuousDataLoader(
dataset=dataset,
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_hippocampus():

@pytest.mark.requires_dataset
def test_monkey():
from cebra.datasets import monkey_reaching
from cebra.datasets import monkey_reaching # noqa: F401

dataset = cebra.datasets.init(
"area2-bump-pos-active-passive",
Expand All @@ -111,7 +111,7 @@ def test_monkey():

@pytest.mark.requires_dataset
def test_allen():
from cebra.datasets import allen
from cebra.datasets import allen # noqa: F401

pytest.skip("Test takes too long")

Expand Down Expand Up @@ -148,7 +148,7 @@ def test_allen():
multisubject_options.extend(
cebra.datasets.get_options(
"rat-hippocampus-multisubjects-3fold-trial-split*"))
except:
except: # noqa: E722
options = []


Expand Down Expand Up @@ -388,3 +388,106 @@ def test_download_file_wrong_content_disposition(filename, url,
expected_checksum=expected_checksum,
location=temp_dir,
file_name=filename)


@pytest.mark.parametrize("neural, continuous, discrete", [
(np.random.randn(100, 30), np.random.randn(
100, 2), np.random.randint(0, 5, (100,))),
(np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))),
(np.random.randn(200, 40), np.random.randn(200, 5), None),
])
def test_tensor_dataset_initialization(neural, continuous, discrete):
dataset = cebra.data.datasets.TensorDataset(neural,
continuous=continuous,
discrete=discrete)
assert dataset.neural.shape == neural.shape
if continuous is not None:
assert dataset.continuous.shape == continuous.shape
if discrete is not None:
assert dataset.discrete.shape == discrete.shape


def test_tensor_dataset_invalid_initialization():
neural = np.random.randn(100, 30)
with pytest.raises(ValueError):
cebra.data.datasets.TensorDataset(neural)


@pytest.mark.parametrize("neural, continuous, discrete", [
(np.random.randn(100, 30), np.random.randn(
100, 2), np.random.randint(0, 5, (100,))),
(np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))),
(np.random.randn(200, 40), np.random.randn(200, 5), None),
])
def test_tensor_dataset_length(neural, continuous, discrete):
dataset = cebra.data.datasets.TensorDataset(neural,
continuous=continuous,
discrete=discrete)
assert len(dataset) == len(neural)


@pytest.mark.parametrize("neural, continuous, discrete", [
(np.random.randn(100, 30), np.random.randn(
100, 2), np.random.randint(0, 5, (100,))),
(np.random.randn(50, 20), None, np.random.randint(0, 3, (50,))),
(np.random.randn(200, 40), np.random.randn(200, 5), None),
])
def test_tensor_dataset_getitem(neural, continuous, discrete):
dataset = cebra.data.datasets.TensorDataset(neural,
continuous=continuous,
discrete=discrete)
index = torch.randint(0, len(dataset), (10,))
batch = dataset[index]
assert batch.shape[0] == len(index)
assert batch.shape[1] == neural.shape[1]


def test_tensor_dataset_invalid_discrete_type():
neural = np.random.randn(100, 30)
continuous = np.random.randn(100, 2)
discrete = np.random.randn(100, 2) # Invalid type: float instead of int
with pytest.raises(TypeError):
cebra.data.datasets.TensorDataset(neural,
continuous=continuous,
discrete=discrete)


@pytest.mark.parametrize("array, check_dtype, expected_dtype", [
(np.random.randn(100, 30), "float", torch.float32),
(np.random.randint(0, 5, (100, 30)), "int", torch.int64),
(torch.randn(100, 30), "float", torch.float32),
(torch.randint(0, 5, (100, 30)), "int", torch.int64),
(None, None, None),
])
def test_to_tensor(array, check_dtype, expected_dtype):
dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
continuous=np.random.randn(
10, 2))
result = dataset._to_tensor(array, check_dtype=check_dtype)
if array is None:
assert result is None
else:
assert isinstance(result, torch.Tensor)
assert result.dtype == expected_dtype


def test_to_tensor_invalid_dtype():
dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
continuous=np.random.randn(
10, 2))
array = np.random.randn(100, 30)
with pytest.raises(TypeError):
dataset._to_tensor(array, check_dtype="int")
array = np.random.randint(0, 5, (100, 30))
with pytest.raises(TypeError):
dataset._to_tensor(array, check_dtype="float")


def test_to_tensor_invalid_check_dtype():
dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
continuous=np.random.randn(
10, 2))
array = np.random.randn(100, 30)
with pytest.raises(ValueError,
match="check_dtype must be 'int' or 'float', got"):
dataset._to_tensor(array, check_dtype="invalid_dtype")

0 comments on commit e652b9a

Please sign in to comment.