Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deal with unsigned int to int conversion #1707

Merged
merged 10 commits into from
Jun 27, 2023
10 changes: 10 additions & 0 deletions src/spikeinterface/preprocessing/astype.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from ..core.core_tools import define_function_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from .filter import fix_dtype
Expand Down Expand Up @@ -45,6 +47,14 @@ def get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
channel_indices = slice(None)
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
# if uint --> take care of offset
traces_dtype = traces.dtype
if traces_dtype.kind == "u" and np.dtype(self.dtype).kind == "i":
itemsize = traces_dtype.itemsize
assert itemsize < 8, "Cannot upcast uint64!"
nbits = traces_dtype.itemsize * 8
# upcast to int with double itemsize
traces = traces.astype(f"int{2 * (traces_dtype.itemsize) * 8}") - 2 ** (nbits - 1)
return traces.astype(self.dtype, copy=False)


Expand Down
42 changes: 42 additions & 0 deletions src/spikeinterface/preprocessing/tests/test_astype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
from pathlib import Path
import shutil

from spikeinterface import set_global_tmp_folder, NumpyRecording
from spikeinterface.core import generate_recording

from spikeinterface.preprocessing import astype

import numpy as np


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "preprocessing"
else:
cache_folder = Path("cache_folder") / "preprocessing"

set_global_tmp_folder(cache_folder)


def test_astype():
traces = (np.random.rand(10000, 4) * 100).astype("float32")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these tests will work better with a user defined traces

Just write something like [1.0 ,3.0, 5.0, ...] or at least let's make them non-deterministic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's ok to leave randomness here. It should work in any case since it's only casting the dtypes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is unlikely to create a problem here but I also don't see why the randomness is required. If it is meant to work as a test then it should be reproducible so you can debug it if it fails. But here, the failure is unlikely to come from randonmess so not a big concern ...

Anyway, not a big deal:
"A foolish consistency is the hobgoblin of little minds"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll set a seed ;)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

rec_float32 = NumpyRecording(traces, sampling_frequency=30000)
traces_int16 = traces.astype("int16")
np.testing.assert_array_equal(traces_int16, astype(rec_float32, "int16").get_traces())
traces_float64 = traces.astype("float64")
np.testing.assert_array_equal(traces_float64, astype(rec_float32, "float64").get_traces())


def test_astype_unsigned():
traces = np.random.rand(10000, 4) * 100 + 500
traces_uint16 = traces.astype("uint16")
rec_uint16 = NumpyRecording(traces_uint16, sampling_frequency=30000)
traces_int16 = (traces.astype("int32") - 2**15).astype("int16")
np.testing.assert_array_equal(traces_int16, astype(rec_uint16, "int16").get_traces())
traces_int32 = (traces.astype("int32") - 2**15).astype("int32")
np.testing.assert_array_equal(traces_int32, astype(rec_uint16, "int32").get_traces())


if __name__ == "__main__":
test_astype()
test_astype_unsigned()