Skip to content

Commit

Permalink
Merge pull request ryanharvey1#44 from ryanharvey1/make-cut_artifacts…
Browse files Browse the repository at this point in the history
…_intan

Cut artifacts
  • Loading branch information
ryanharvey1 authored Dec 3, 2024
2 parents 2a81c9b + 8377b4c commit 44b196d
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 5 deletions.
14 changes: 12 additions & 2 deletions neuro_py/raw/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
__all__ = ["remove_artifacts", "fill_missing_channels"]
__all__ = [
"remove_artifacts",
"fill_missing_channels",
"cut_artifacts",
"cut_artifacts_intan",
]

from .preprocessing import remove_artifacts, fill_missing_channels
from .preprocessing import (
cut_artifacts,
fill_missing_channels,
remove_artifacts,
cut_artifacts_intan,
)
162 changes: 159 additions & 3 deletions neuro_py/raw/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ def remove_artifacts(
data[start, ch],
data[end, ch],
end - start,
).astype(
data.dtype
) # Ensure consistent dtype
).astype(data.dtype) # Ensure consistent dtype
data[start:end, ch] = interpolated
else:
warnings.warn(
Expand Down Expand Up @@ -174,6 +172,164 @@ def remove_artifacts(
warnings.warn(f"Failed to create log file: {e}")


def cut_artifacts(
filepath: str,
n_channels: int,
cut_intervals: List[Tuple[int, int]],
precision: str = "int16",
output_filepath: Optional[str] = None,
) -> None:
"""
Remove user-defined periods from recordings in a binary file, resulting in a shorter file.
Parameters
----------
filepath : str
Path to the original binary file.
n_channels : int
Number of channels in the file.
cut_intervals : List[Tuple[int, int]]
List of intervals (start, end) in sample indices to remove.
precision : str, optional
Data precision, by default "int16".
output_filepath : str, optional
Path to save the modified binary file. If None, appends "_cut" to the original filename.
Returns
-------
None
"""
# Check if file exists
if not os.path.exists(filepath):
raise FileNotFoundError(f"File '{filepath}' does not exist.")

# Set default output filepath
if output_filepath is None:
output_filepath = os.path.splitext(filepath)[0] + "_cut.dat"

# Check for valid intervals
cut_intervals = sorted(cut_intervals)
for start, end in cut_intervals:
if start >= end:
raise ValueError(
f"Invalid interval: ({start}, {end}). Start must be less than end."
)

# Map the original file and calculate parameters
bytes_size = np.dtype(precision).itemsize
with open(filepath, "rb") as f:
startoffile = f.seek(0, 0)
endoffile = f.seek(0, 2)
n_samples = int((endoffile - startoffile) / n_channels / bytes_size)

data = np.memmap(filepath, dtype=precision, mode="r", shape=(n_samples, n_channels))

# Identify the indices to keep
keep_mask = np.ones(n_samples, dtype=bool)
for start, end in cut_intervals:
if 0 <= start < n_samples and 0 < end <= n_samples:
keep_mask[start:end] = False
else:
warnings.warn(
f"Interval ({start}, {end}) is out of bounds and was skipped."
)

keep_indices = np.flatnonzero(keep_mask)

# Create a new binary file with only the retained data
with open(output_filepath, "wb") as output_file:
for start_idx in range(0, len(keep_indices), 10_000): # Process in chunks
chunk_indices = keep_indices[start_idx : start_idx + 10_000]
output_file.write(data[chunk_indices].tobytes())

del data # Release memory-mapped file


def cut_artifacts_intan(
folder_name: str, n_channels_amplifier: int, cut_intervals: List[Tuple[int, int]]
) -> None:
"""
Cut specified artifact intervals from Intan data files.
This function iterates through a set of Intan data files (amplifier, auxiliary,
digitalin, digitalout, analogin, time, and supply), and for each file, it removes
artifacts within the specified intervals by invoking the `cut_artifacts` function.
Parameters
----------
folder_name : str
The folder where the Intan data files are located.
n_channels_amplifier : int
The number of amplifier channels used in the amplifier data file.
cut_intervals : List[Tuple[int, int]]
A list of intervals (start, end) in sample indices to remove artifacts.
Each tuple represents the start and end sample index for an artifact to be cut.
Returns
-------
None
This function modifies the files in place, so there is no return value.
Notes
-----
Does not correct the time.dat file, timestamps will be discontinuous after cutting artifacts.
Raises
------
FileNotFoundError
If the amplifier data file does not exist in the provided folder.
ValueError
If video files are found in the folder, as this function does not support video files.
Examples
--------
>>> fs = 20_000
>>> cut_artifacts_intan(
... folder_name=r"path/to/data",
... 128, [(np.array([394.4, 394.836]) * fs).astype(int)]
... )
"""

# refuse to cut artifacts if any video file exist in folder
video_files = [f for f in os.listdir(folder_name) if f.endswith(".avi")]
if video_files:
raise ValueError(f"Video files found in folder, refusing to cut: {video_files}")

# Define data types for each file (from Intan documentation)
files_table = {
"amplifier": "int16",
"auxiliary": "uint16",
"digitalin": "uint16",
"digitalout": "uint16",
"analogin": "uint16",
"time": "int32",
"supply": "uint16",
}

# determine number of samples from amplifier file
amplifier_file_path = os.path.join(folder_name, "amplifier.dat")
if not os.path.exists(amplifier_file_path):
raise FileNotFoundError(f"File '{amplifier_file_path}' does not exist.")

# get number of bytes per sample
bytes_size = np.dtype(files_table["amplifier"]).itemsize

# each file should have the same number of samples
n_samples = os.path.getsize(amplifier_file_path) // (
n_channels_amplifier * bytes_size
)

for file_name, precision in files_table.items():
file_path = os.path.join(folder_name, f"{file_name}.dat")

if os.path.exists(file_path):
# determine number of channels from n_samples
n_channels = int(os.path.getsize(file_path) / n_samples / bytes_size)

# cut artifacts
cut_artifacts(file_path, n_channels, cut_intervals, precision)


def fill_missing_channels(
basepath: str,
n_channels: int,
Expand Down
65 changes: 65 additions & 0 deletions tests/test_cut_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest
import numpy as np
import os
import tempfile
from typing import List, Tuple

from neuro_py.raw.preprocessing import cut_artifacts


@pytest.fixture
def create_test_file():
"""Fixture to create a temporary binary file with test data."""
with tempfile.TemporaryDirectory() as temp_dir:
filepath = os.path.join(temp_dir, "test_data.dat")

# Test parameters
n_channels = 4
precision = "int16"
original_data = np.arange(100).reshape(-1, n_channels).astype(precision)

# Write test data to the file
with open(filepath, "wb") as f:
f.write(original_data.tobytes())

yield filepath, n_channels, precision, original_data


def test_cut_artifacts(create_test_file):
# Get the temporary file, parameters, and data
filepath, n_channels, precision, original_data = create_test_file

# Define intervals to cut
cut_intervals: List[Tuple[int, int]] = [(5, 10), (15, 20)] # In sample indices

# Expected output after cutting
keep_mask = np.ones(len(original_data), dtype=bool)
for start, end in cut_intervals:
keep_mask[start:end] = False
expected_data = original_data[keep_mask]

# Run the function
output_filepath = os.path.splitext(filepath)[0] + "_cut.dat"
cut_artifacts(
filepath=filepath,
n_channels=n_channels,
cut_intervals=cut_intervals,
precision=precision,
output_filepath=output_filepath,
)

# Verify the output file
with open(output_filepath, "rb") as f:
cut_data = np.frombuffer(f.read(), dtype=precision).reshape(-1, n_channels)

# Assertions
assert len(cut_data) == len(expected_data), "The output file length does not match the expected length."
np.testing.assert_array_equal(
cut_data, expected_data, "The output data does not match the expected data."
)

# Check if the file exists and is smaller than the original
assert os.path.exists(output_filepath), "The output file does not exist."
assert os.path.getsize(output_filepath) < os.path.getsize(filepath), (
"The output file size is not smaller than the original."
)

0 comments on commit 44b196d

Please sign in to comment.