diff --git a/src/ilabs_streamsync/example_script.py b/src/ilabs_streamsync/example_script.py index 710b7ed..e006f72 100644 --- a/src/ilabs_streamsync/example_script.py +++ b/src/ilabs_streamsync/example_script.py @@ -1,34 +1,44 @@ +from __future__ import annotations + import mne +from streamsync import StreamSync, extract_audio_from_video + +if __name__ == "__main__": + # load an MNE raw file + raw = "/Users/user/VideoSync_NonSubject/sinclair_alexis_audiosync_240110_raw.fif" + channel = "STI011" + cams = ["/Users/user/VideoSync_NonSubject/sinclair_alexis_audiosync_240110_CAM3.mp4"] + output_dir = "/Users/user/VideoSync_NonSubject/output" + flux1 = None + my_events = [] + + for cam in cams: + extract_audio_from_video(cam, output_dir, overwrite=False) #This could potentially return filenames to avoid the hardcoding seen below. + ss = StreamSync(raw, channel) + + ss.add_stream("/Users/user/VideoSync_NonSubject/output/sinclair_alexis_audiosync_240110_CAM3_16bit.wav", channel=1) + ss.plot_sync_pulses(tmin=0.5,tmax=50) + + # subjects = ["146a", "222b"] + + # for subj in subjects: + # construct the filename/path + # load the Raw + # figure out where video files are & load them + # extract_audio_from_video(cam1) + + # ss = StreamSync(raw, "STIM001") + # ss.add_stream(audio1) + # ss.add_camera_events(my_events) + # ss.add_stream(flux1) + # result = ss.do_syncing() + # fig = ss.plot_sync() + # annot = ss.add_camera_events(my_events) + # raw.set_annotations(annot) + # fig.savefig(...) + # if result < 0.7: + # write_log_msg(f"subj {subj} had bad pulse syncing, aborting") + # continue -from ilabs_streamsync import StreamSync, extract_audio_from_video - -# load an MNE raw file -raw = None -cam1 = None -flux1 = None -my_events = [] - - -subjects = ["146a", "222b"] - -for subj in subjects: - # construct the filename/path - # load the Raw - # figure out where video files are & load them - audio1 = extract_audio_from_video(cam1) - - ss = StreamSync(raw, "STIM001") - ss.add_stream(audio1) - ss.add_camera_events(my_events) - ss.add_stream(flux1) - result = ss.do_syncing() - fig = ss.plot_sync() - annot = ss.add_camera_events(my_events) - raw.set_annotations(annot) - fig.savefig(...) - if result < 0.7: - write_log_msg(f"subj {subj} had bad pulse syncing, aborting") - continue - - # apply maxfilter - # do ICA + # apply maxfilter + # do ICA diff --git a/src/ilabs_streamsync/streamdata.py b/src/ilabs_streamsync/streamdata.py new file mode 100644 index 0000000..9a30fcf --- /dev/null +++ b/src/ilabs_streamsync/streamdata.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +class StreamData: + """ + Store information about stream of data. + """ + def __init__(self, filename, sample_rate, pulses, data): + """ + Initialize object with associated properties. + + filename: str + Path to the file with stream data + sample_rate: int + Sampling rate of the data + pulses: np.array + Numpy array representing the pulses. + data: np.array + NumPy array representing all streams of data. + """ + self.filename = filename + self.sample_rate = sample_rate + self.pulses = pulses + self.data = data \ No newline at end of file diff --git a/src/ilabs_streamsync/streamsync.py b/src/ilabs_streamsync/streamsync.py index c8a8552..f967202 100644 --- a/src/ilabs_streamsync/streamsync.py +++ b/src/ilabs_streamsync/streamsync.py @@ -1,3 +1,19 @@ + +from __future__ import annotations + +import logging +import os +import pathlib +import subprocess + +import matplotlib.pyplot as plt +import mne +import numpy as np +from scipy.io.wavfile import read as wavread +from streamdata import StreamData + +FFMPEG_TIMEOUT_SEC = 50 + class StreamSync: """Synchronize two data streams. @@ -9,39 +25,145 @@ class StreamSync: """ def __init__(self, reference_object, pulse_channel): - self.ref_stream = reference_object.get_chan(pulse_channel) - self.sfreq = reference_object.info["sfreq"] # Hz - self.streams = [] + """Initialize StreamSync object with 'Raw' MEG associated with it. + + reference_object: str TODO: is str the best method for this, or should this be pathlib obj? + File path to an MEG raw file with fif formatting. TODO: Verify fif only? + pulse_channel: str + A string associated with the stim channel name. + """ + # Check provided reference_object for type and existence. + if not reference_object: + raise TypeError("reference_object is None. Please provide a path.") + if type(reference_object) is not str: + raise TypeError("reference_object must be a file path of type str.") + ref_path_obj = pathlib.Path(reference_object) + if not ref_path_obj.exists(): + raise OSError("reference_object file path does not exist.") + if not ref_path_obj.suffix == ".fif": + raise ValueError("Provided reference object does not point to a .fif file.") + + # Load in raw file if valid + raw = mne.io.read_raw_fif(reference_object, preload=False, allow_maxshield=True) + + #Check type and value of pulse_channel, and ensure reference object has such a channel. + if not pulse_channel: + raise TypeError("pulse_channel is None. Please provide a channel name of type str.") + if type(pulse_channel) is not str: + raise TypeError("pulse_chanel parameter must be of type str.") + if raw[pulse_channel] is None: + raise ValueError('pulse_channel does not exist in refrence_object.') + + + self.raw = mne.io.read_raw_fif(reference_object, preload=False, allow_maxshield=True) + self.ref_stream = raw[pulse_channel] + + self.sfreq = self.raw.info["sfreq"] # Hz + + self.streams = [] # list of StreamData objects def add_stream(self, stream, channel=None, events=None): """Add a new ``Raw`` or video stream, optionally with events. - stream : Raw | wav - An audio or FIF stream. + stream : str + File path to an audio or FIF stream. channel : str | int | None Which channel of `stream` contains the sync pulse sequence. events : array-like | None Events associated with the stream. TODO: should they be integer sample numbers? Timestamps? Do we support both? """ - pulses = self._extract_pulse_sequence_from_stream(stream, channel=channel) - self.streams.append(pulses) + self.streams.append(self._extract_data_from_stream(stream, channel=channel)) - def _extract_pulse_sequence_from_stream(self, stream, channel): - # TODO triage based on input type (e.g., if it's a Raw, pull out a stim chan, - # if it's audio, just add it as-is) + def _extract_data_from_stream(self, stream, channel): + """Extract pulses and raw data from stream provided. TODO: Implement adding a annotation stream.""" + ext = pathlib.Path(stream).suffix + if ext == ".wav": + return self._extract_data_from_wav(stream, channel) + raise TypeError("Stream provided was of unsupported format. Please provide a wav file.") + + + def _extract_data_from_wav(self, stream, channel): + """Return tuple of (pulse channel, audio channel) from stereo file.""" + srate, wav_signal = wavread(stream) + return StreamData(filename = stream, sample_rate=srate, pulses=wav_signal[:,channel], data=wav_signal[:,1-channel]) + + def remove_stream(self, stream): pass def do_syncing(self): """Synchronize all streams with the reference stream.""" # TODO (waves hands) do the hard part. # TODO spit out a report of correlation/association between all pairs of streams - pass - def plot_sync(self): - pass + def plot_sync_pulses(self, tmin=0, tmax=None): + """Plot each stream in the class. + + tmin: int + Minimum timestamp to be graphed. + tmax: int + Maximum timestamp to be graphed. + """ + fig, axset = plt.subplots(len(self.streams)+1, 1, figsize = [8,6]) #show individual channels seperately, and the 0th plot is the combination of these. + # Plot reference_object + trig, tt_trig = self.ref_stream + trig = trig.reshape(tt_trig.shape) + idx = np.where((tt_trig>=tmin) & (tt_trig=tmin) & (tt