diff --git a/tests/test_record.py b/tests/test_record.py index 3d9d32a7..a6c57e79 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -219,7 +219,7 @@ def test_1f(self): "Mismatch in %s" % name, ) - def test_read_flac(self): + def test_read_write_flac(self): """ All FLAC formats, multiple signal files in one record. @@ -250,6 +250,28 @@ def test_read_flac(self): f"Mismatch in {name}", ) + # Test file writing + record.wrsamp() + record_write = wfdb.rdrecord("flacformats", physical=False) + assert record == record_write + + def test_read_write_flac_multifrequency(self): + """ + Format 516 with multiple signal files and variable samples per frame. + """ + # Check that we can read a record and write it out again + record = wfdb.rdrecord( + "sample-data/mixedsignals", + physical=False, + smooth_frames=False, + ) + record.wrsamp(expanded=True) + + # Check that result matches the original + record = wfdb.rdrecord("sample-data/mixedsignals", smooth_frames=False) + record_write = wfdb.rdrecord("mixedsignals", smooth_frames=False) + assert record == record_write + def test_read_flac_longduration(self): """ Three signals multiplexed in a FLAC file, over 2**24 samples. @@ -628,6 +650,14 @@ def tearDownClass(cls): "100_3chan.hea", "a103l.hea", "a103l.mat", + "flacformats.d0", + "flacformats.d1", + "flacformats.d2", + "flacformats.hea", + "mixedsignals.hea", + "mixedsignals_e.dat", + "mixedsignals_p.dat", + "mixedsignals_r.dat", "s0010_re.dat", "s0010_re.hea", "s0010_re.xyz", diff --git a/wfdb/io/_header.py b/wfdb/io/_header.py index 7322ee87..2bddd01a 100644 --- a/wfdb/io/_header.py +++ b/wfdb/io/_header.py @@ -361,6 +361,35 @@ def get_write_fields(self): return rec_write_fields, sig_write_fields + def _auto_signal_file_names(self): + fmt = self.fmt or [None] * self.n_sig + spf = self.samps_per_frame or [None] * self.n_sig + num_groups = 0 + group_number = [] + prev_fmt = prev_spf = None + channels_in_group = 0 + + for ch_fmt, ch_spf in zip(fmt, spf): + if ch_fmt != prev_fmt: + num_groups += 1 + channels_in_group = 0 + elif ch_fmt in ("508", "516", "524"): + if channels_in_group >= 8 or ch_spf != prev_spf: + num_groups += 1 + channels_in_group = 0 + group_number.append(num_groups) + prev_fmt = ch_fmt + prev_spf = ch_spf + + if num_groups < 2: + return [self.record_name + ".dat"] * self.n_sig + else: + digits = len(str(group_number[-1])) + return [ + self.record_name + "_" + str(g).rjust(digits, "0") + ".dat" + for g in group_number + ] + def set_default(self, field): """ Set the object's attribute to its default value if it is missing @@ -394,7 +423,7 @@ def set_default(self, field): # Specific dynamic case if field == "file_name" and self.file_name is None: - self.file_name = self.n_sig * [self.record_name + ".dat"] + self.file_name = self._auto_signal_file_names() return item = getattr(self, field) diff --git a/wfdb/io/_signal.py b/wfdb/io/_signal.py index d3d26439..61abb568 100644 --- a/wfdb/io/_signal.py +++ b/wfdb/io/_signal.py @@ -950,12 +950,11 @@ def wr_dat_files(self, expanded=False, write_dir=""): dat_offsets[fn], True, [self.e_d_signal[ch] for ch in dat_channels[fn]], - self.samps_per_frame, + [self.samps_per_frame[ch] for ch in dat_channels[fn]], write_dir=write_dir, ) else: - # Create a copy to prevent overwrite - dsig = self.d_signal.copy() + dsig = self.d_signal for fn in file_names: wr_dat_file( fn, @@ -2267,16 +2266,15 @@ def wr_dat_file( fmt : str WFDB fmt of the dat file. d_signal : ndarray - The digital conversion of the signal. Either a 2d numpy - array or a list of 1d numpy arrays. + The digital conversion of the signal, as a 2d numpy array. byte_offset : int The byte offset of the dat file. expanded : bool, optional Whether to transform the `e_d_signal` attribute (True) or the `d_signal` attribute (False). - d_signal : ndarray, optional - The expanded digital conversion of the signal. Either a 2d numpy - array or a list of 1d numpy arrays. + e_d_signal : ndarray, optional + The expanded digital conversion of the signal, as a list of 1d + numpy arrays. samps_per_frame : list, optional The samples/frame for each signal of the dat file. write_dir : str, optional @@ -2287,10 +2285,19 @@ def wr_dat_file( N/A """ + file_path = os.path.join(write_dir, file_name) + # Combine list of arrays into single array if expanded: n_sig = len(e_d_signal) - sig_len = int(len(e_d_signal[0]) / samps_per_frame[0]) + if len(samps_per_frame) != n_sig: + raise ValueError("mismatch between samps_per_frame and e_d_signal") + + sig_len = len(e_d_signal[0]) // samps_per_frame[0] + for sig, spf in zip(e_d_signal, samps_per_frame): + if len(sig) != sig_len * spf: + raise ValueError("mismatch in lengths of expanded signals") + # Effectively create MxN signal, with extra frame samples acting # like extra channels d_signal = np.zeros((sig_len, sum(samps_per_frame)), dtype="int64") @@ -2301,10 +2308,17 @@ def wr_dat_file( for framenum in range(spf): d_signal[:, expand_ch] = e_d_signal[ch][framenum::spf] expand_ch = expand_ch + 1 + else: + # Create a copy to prevent overwrite + d_signal = d_signal.copy() - # This n_sig is used for making list items. - # Does not necessarily represent number of signals (ie. for expanded=True) - n_sig = d_signal.shape[1] + # Non-expanded format always has 1 sample per frame + n_sig = d_signal.shape[1] + samps_per_frame = [1] * n_sig + + # Total number of samples per frame (equal to number of signals if + # expanded=False, but may be greater for expanded=True) + tsamps_per_frame = d_signal.shape[1] if fmt == "80": # convert to 8 bit offset binary form @@ -2362,8 +2376,8 @@ def wr_dat_file( # convert to 16 bit two's complement d_signal[d_signal < 0] = d_signal[d_signal < 0] + 65536 # Split samples into separate bytes using binary masks - b1 = d_signal & [255] * n_sig - b2 = (d_signal & [65280] * n_sig) >> 8 + b1 = d_signal & [255] * tsamps_per_frame + b2 = (d_signal & [65280] * tsamps_per_frame) >> 8 # Interweave the bytes so that the same samples' bytes are consecutive b1 = b1.reshape((-1, 1)) b2 = b2.reshape((-1, 1)) @@ -2375,9 +2389,9 @@ def wr_dat_file( # convert to 24 bit two's complement d_signal[d_signal < 0] = d_signal[d_signal < 0] + 16777216 # Split samples into separate bytes using binary masks - b1 = d_signal & [255] * n_sig - b2 = (d_signal & [65280] * n_sig) >> 8 - b3 = (d_signal & [16711680] * n_sig) >> 16 + b1 = d_signal & [255] * tsamps_per_frame + b2 = (d_signal & [65280] * tsamps_per_frame) >> 8 + b3 = (d_signal & [16711680] * tsamps_per_frame) >> 16 # Interweave the bytes so that the same samples' bytes are consecutive b1 = b1.reshape((-1, 1)) b2 = b2.reshape((-1, 1)) @@ -2391,10 +2405,10 @@ def wr_dat_file( # convert to 32 bit two's complement d_signal[d_signal < 0] = d_signal[d_signal < 0] + 4294967296 # Split samples into separate bytes using binary masks - b1 = d_signal & [255] * n_sig - b2 = (d_signal & [65280] * n_sig) >> 8 - b3 = (d_signal & [16711680] * n_sig) >> 16 - b4 = (d_signal & [4278190080] * n_sig) >> 24 + b1 = d_signal & [255] * tsamps_per_frame + b2 = (d_signal & [65280] * tsamps_per_frame) >> 8 + b3 = (d_signal & [16711680] * tsamps_per_frame) >> 16 + b4 = (d_signal & [4278190080] * tsamps_per_frame) >> 24 # Interweave the bytes so that the same samples' bytes are consecutive b1 = b1.reshape((-1, 1)) b2 = b2.reshape((-1, 1)) @@ -2404,9 +2418,54 @@ def wr_dat_file( b_write = b_write.reshape((1, -1))[0] # Convert to un_signed 8 bit dtype to write b_write = b_write.astype("uint8") + + elif fmt in ("508", "516", "524"): + import soundfile + + if any(spf != samps_per_frame[0] for spf in samps_per_frame): + raise ValueError( + "All channels in a FLAC signal file must have the same " + "sampling rate and samples per frame" + ) + if n_sig > 8: + raise ValueError( + "A single FLAC signal file cannot contain more than 8 channels" + ) + + d_signal = d_signal.reshape(-1, n_sig, samps_per_frame[0]) + d_signal = d_signal.transpose(0, 2, 1) + d_signal = d_signal.reshape(-1, n_sig) + + if fmt == "508": + d_signal = d_signal.astype("int16") + np.left_shift(d_signal, 8, out=d_signal) + subtype = "PCM_S8" + elif fmt == "516": + d_signal = d_signal.astype("int16") + subtype = "PCM_16" + elif fmt == "524": + d_signal = d_signal.astype("int32") + np.left_shift(d_signal, 8, out=d_signal) + subtype = "PCM_24" + else: + raise ValueError(f"unknown format ({fmt})") + + sf = soundfile.SoundFile( + file_path, + mode="w", + samplerate=96000, + channels=n_sig, + subtype=subtype, + format="FLAC", + ) + with sf: + sf.write(d_signal) + return + else: raise ValueError( - "This library currently only supports writing the following formats: 80, 16, 24, 32" + "This library currently only supports writing the " + "following formats: 80, 16, 24, 32, 508, 516, 524" ) # Byte offset in the file @@ -2421,7 +2480,7 @@ def wr_dat_file( b_write = np.append(np.zeros(byte_offset, dtype="uint8"), b_write) # Write the bytes to the file - with open(os.path.join(write_dir, file_name), "wb") as f: + with open(file_path, "wb") as f: b_write.tofile(f) diff --git a/wfdb/io/record.py b/wfdb/io/record.py index ba145407..399b69ff 100644 --- a/wfdb/io/record.py +++ b/wfdb/io/record.py @@ -520,9 +520,15 @@ def check_field(self, field, required_channels="all"): "block_size values must be non-negative integers" ) elif field == "sig_name": - if re.search(r"\s", item[ch]): + if item[ch][:1].isspace() or item[ch][-1:].isspace(): + raise ValueError( + "sig_name strings may not begin or end with " + "whitespace." + ) + if re.search(r"[\x00-\x1f\x7f-\x9f]", item[ch]): raise ValueError( - "sig_name strings may not contain whitespaces." + "sig_name strings may not contain " + "control characters." ) if len(set(item)) != len(item): raise ValueError("sig_name strings must be unique.")