Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FLAC output and assorted wrsamp improvements #420

Merged
merged 10 commits into from
Sep 20, 2022
32 changes: 31 additions & 1 deletion tests/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_1f(self):
"Mismatch in %s" % name,
)

def test_read_flac(self):
def test_read_write_flac(self):
"""
All FLAC formats, multiple signal files in one record.

Expand Down Expand Up @@ -250,6 +250,28 @@ def test_read_flac(self):
f"Mismatch in {name}",
)

# Test file writing
record.wrsamp()
record_write = wfdb.rdrecord("flacformats", physical=False)
assert record == record_write

def test_read_write_flac_multifrequency(self):
    """
    Round-trip a format 516 record that uses multiple signal files and
    variable samples per frame.
    """
    # Read the digital, non-smoothed samples and write them straight
    # back out in expanded form.
    wfdb.rdrecord(
        "sample-data/mixedsignals",
        physical=False,
        smooth_frames=False,
    ).wrsamp(expanded=True)

    # Re-read both the reference record and the freshly written copy
    # (looked up by its bare record name) and require them to be equal.
    expected = wfdb.rdrecord("sample-data/mixedsignals", smooth_frames=False)
    written = wfdb.rdrecord("mixedsignals", smooth_frames=False)
    assert expected == written

def test_read_flac_longduration(self):
"""
Three signals multiplexed in a FLAC file, over 2**24 samples.
Expand Down Expand Up @@ -628,6 +650,14 @@ def tearDownClass(cls):
"100_3chan.hea",
"a103l.hea",
"a103l.mat",
"flacformats.d0",
"flacformats.d1",
"flacformats.d2",
"flacformats.hea",
"mixedsignals.hea",
"mixedsignals_e.dat",
"mixedsignals_p.dat",
"mixedsignals_r.dat",
"s0010_re.dat",
"s0010_re.hea",
"s0010_re.xyz",
Expand Down
31 changes: 30 additions & 1 deletion wfdb/io/_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,35 @@ def get_write_fields(self):

return rec_write_fields, sig_write_fields

def _auto_signal_file_names(self):
    """
    Choose default signal file names for every channel of the record.

    Consecutive channels are assigned to the same signal file (a
    "group") as long as they share the same format. For the FLAC
    formats (508, 516 and 524), a new group is also started whenever
    the samples per frame changes or the current group already holds
    8 channels.

    Parameters
    ----------
    N/A

    Returns
    -------
    list
        One file name per channel. If fewer than two groups are needed,
        every channel uses "<record_name>.dat"; otherwise channels use
        "<record_name>_<group>.dat", where the group number is
        zero-padded to the width of the last group number.

    """
    flac_formats = ("508", "516", "524")
    max_flac_channels = 8

    # Missing fmt / samps_per_frame are treated as "unknown" per channel.
    fmt = self.fmt or [None] * self.n_sig
    spf = self.samps_per_frame or [None] * self.n_sig

    num_groups = 0
    group_number = []
    prev_fmt = prev_spf = None
    channels_in_group = 0

    for ch_fmt, ch_spf in zip(fmt, spf):
        if ch_fmt != prev_fmt:
            start_new_group = True
        elif ch_fmt in flac_formats:
            start_new_group = (
                channels_in_group >= max_flac_channels or ch_spf != prev_spf
            )
        else:
            start_new_group = False

        if start_new_group:
            num_groups += 1
            channels_in_group = 0

        group_number.append(num_groups)
        # Count this channel towards the current group; without this the
        # 8-channel limit above could never trigger.
        channels_in_group += 1
        prev_fmt = ch_fmt
        prev_spf = ch_spf

    if num_groups < 2:
        return [self.record_name + ".dat"] * self.n_sig

    digits = len(str(group_number[-1]))
    return [
        self.record_name + "_" + str(g).rjust(digits, "0") + ".dat"
        for g in group_number
    ]

def set_default(self, field):
"""
Set the object's attribute to its default value if it is missing
Expand Down Expand Up @@ -394,7 +423,7 @@ def set_default(self, field):

# Specific dynamic case
if field == "file_name" and self.file_name is None:
self.file_name = self.n_sig * [self.record_name + ".dat"]
self.file_name = self._auto_signal_file_names()
return

item = getattr(self, field)
Expand Down
105 changes: 82 additions & 23 deletions wfdb/io/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,12 +950,11 @@ def wr_dat_files(self, expanded=False, write_dir=""):
dat_offsets[fn],
True,
[self.e_d_signal[ch] for ch in dat_channels[fn]],
self.samps_per_frame,
[self.samps_per_frame[ch] for ch in dat_channels[fn]],
write_dir=write_dir,
)
else:
# Create a copy to prevent overwrite
dsig = self.d_signal.copy()
dsig = self.d_signal
for fn in file_names:
wr_dat_file(
fn,
Expand Down Expand Up @@ -2267,16 +2266,15 @@ def wr_dat_file(
fmt : str
WFDB fmt of the dat file.
d_signal : ndarray
The digital conversion of the signal. Either a 2d numpy
array or a list of 1d numpy arrays.
The digital conversion of the signal, as a 2d numpy array.
byte_offset : int
The byte offset of the dat file.
expanded : bool, optional
Whether to transform the `e_d_signal` attribute (True) or
the `d_signal` attribute (False).
d_signal : ndarray, optional
The expanded digital conversion of the signal. Either a 2d numpy
array or a list of 1d numpy arrays.
e_d_signal : ndarray, optional
The expanded digital conversion of the signal, as a list of 1d
numpy arrays.
samps_per_frame : list, optional
The samples/frame for each signal of the dat file.
write_dir : str, optional
Expand All @@ -2287,10 +2285,19 @@ def wr_dat_file(
N/A

"""
file_path = os.path.join(write_dir, file_name)

# Combine list of arrays into single array
if expanded:
n_sig = len(e_d_signal)
sig_len = int(len(e_d_signal[0]) / samps_per_frame[0])
if len(samps_per_frame) != n_sig:
raise ValueError("mismatch between samps_per_frame and e_d_signal")

sig_len = len(e_d_signal[0]) // samps_per_frame[0]
for sig, spf in zip(e_d_signal, samps_per_frame):
if len(sig) != sig_len * spf:
raise ValueError("mismatch in lengths of expanded signals")

# Effectively create MxN signal, with extra frame samples acting
# like extra channels
d_signal = np.zeros((sig_len, sum(samps_per_frame)), dtype="int64")
Expand All @@ -2301,10 +2308,17 @@ def wr_dat_file(
for framenum in range(spf):
d_signal[:, expand_ch] = e_d_signal[ch][framenum::spf]
expand_ch = expand_ch + 1
else:
# Create a copy to prevent overwrite
d_signal = d_signal.copy()

# This n_sig is used for making list items.
# Does not necessarily represent number of signals (ie. for expanded=True)
n_sig = d_signal.shape[1]
# Non-expanded format always has 1 sample per frame
n_sig = d_signal.shape[1]
samps_per_frame = [1] * n_sig

# Total number of samples per frame (equal to number of signals if
# expanded=False, but may be greater for expanded=True)
tsamps_per_frame = d_signal.shape[1]

if fmt == "80":
# convert to 8 bit offset binary form
Expand Down Expand Up @@ -2362,8 +2376,8 @@ def wr_dat_file(
# convert to 16 bit two's complement
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 65536
# Split samples into separate bytes using binary masks
b1 = d_signal & [255] * n_sig
b2 = (d_signal & [65280] * n_sig) >> 8
b1 = d_signal & [255] * tsamps_per_frame
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
# Interweave the bytes so that the same samples' bytes are consecutive
b1 = b1.reshape((-1, 1))
b2 = b2.reshape((-1, 1))
Expand All @@ -2375,9 +2389,9 @@ def wr_dat_file(
# convert to 24 bit two's complement
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 16777216
# Split samples into separate bytes using binary masks
b1 = d_signal & [255] * n_sig
b2 = (d_signal & [65280] * n_sig) >> 8
b3 = (d_signal & [16711680] * n_sig) >> 16
b1 = d_signal & [255] * tsamps_per_frame
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
b3 = (d_signal & [16711680] * tsamps_per_frame) >> 16
# Interweave the bytes so that the same samples' bytes are consecutive
b1 = b1.reshape((-1, 1))
b2 = b2.reshape((-1, 1))
Expand All @@ -2391,10 +2405,10 @@ def wr_dat_file(
# convert to 32 bit two's complement
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 4294967296
# Split samples into separate bytes using binary masks
b1 = d_signal & [255] * n_sig
b2 = (d_signal & [65280] * n_sig) >> 8
b3 = (d_signal & [16711680] * n_sig) >> 16
b4 = (d_signal & [4278190080] * n_sig) >> 24
b1 = d_signal & [255] * tsamps_per_frame
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
b3 = (d_signal & [16711680] * tsamps_per_frame) >> 16
b4 = (d_signal & [4278190080] * tsamps_per_frame) >> 24
# Interweave the bytes so that the same samples' bytes are consecutive
b1 = b1.reshape((-1, 1))
b2 = b2.reshape((-1, 1))
Expand All @@ -2404,9 +2418,54 @@ def wr_dat_file(
b_write = b_write.reshape((1, -1))[0]
# Convert to un_signed 8 bit dtype to write
b_write = b_write.astype("uint8")

elif fmt in ("508", "516", "524"):
import soundfile

if any(spf != samps_per_frame[0] for spf in samps_per_frame):
raise ValueError(
"All channels in a FLAC signal file must have the same "
"sampling rate and samples per frame"
)
if n_sig > 8:
raise ValueError(
"A single FLAC signal file cannot contain more than 8 channels"
)

d_signal = d_signal.reshape(-1, n_sig, samps_per_frame[0])
d_signal = d_signal.transpose(0, 2, 1)
d_signal = d_signal.reshape(-1, n_sig)

if fmt == "508":
d_signal = d_signal.astype("int16")
np.left_shift(d_signal, 8, out=d_signal)
subtype = "PCM_S8"
elif fmt == "516":
d_signal = d_signal.astype("int16")
subtype = "PCM_16"
elif fmt == "524":
d_signal = d_signal.astype("int32")
np.left_shift(d_signal, 8, out=d_signal)
subtype = "PCM_24"
else:
raise ValueError(f"unknown format ({fmt})")

sf = soundfile.SoundFile(
file_path,
mode="w",
samplerate=96000,
channels=n_sig,
subtype=subtype,
format="FLAC",
)
with sf:
sf.write(d_signal)
return

else:
raise ValueError(
"This library currently only supports writing the following formats: 80, 16, 24, 32"
"This library currently only supports writing the "
"following formats: 80, 16, 24, 32, 508, 516, 524"
)

# Byte offset in the file
Expand All @@ -2421,7 +2480,7 @@ def wr_dat_file(
b_write = np.append(np.zeros(byte_offset, dtype="uint8"), b_write)

# Write the bytes to the file
with open(os.path.join(write_dir, file_name), "wb") as f:
with open(file_path, "wb") as f:
b_write.tofile(f)


Expand Down
10 changes: 8 additions & 2 deletions wfdb/io/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,9 +520,15 @@ def check_field(self, field, required_channels="all"):
"block_size values must be non-negative integers"
)
elif field == "sig_name":
if re.search(r"\s", item[ch]):
if item[ch][:1].isspace() or item[ch][-1:].isspace():
raise ValueError(
"sig_name strings may not begin or end with "
"whitespace."
)
if re.search(r"[\x00-\x1f\x7f-\x9f]", item[ch]):
raise ValueError(
"sig_name strings may not contain whitespaces."
"sig_name strings may not contain "
"control characters."
)
if len(set(item)) != len(item):
raise ValueError("sig_name strings must be unique.")
Expand Down