Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement fletcher32 #412

Merged
merged 10 commits into from
Jan 15, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions numcodecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,6 @@
register_codec(VLenUTF8)
register_codec(VLenBytes)
register_codec(VLenArray)

from numcodecs.fletcher32 import Fletcher32
register_codec(Fletcher32)
79 changes: 79 additions & 0 deletions numcodecs/fletcher32.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# cython: language_level=3
# cython: overflowcheck=False
# cython: cdivision=True
import struct

from numcodecs.abc import Codec
from numcodecs.compat import ensure_contiguous_ndarray

from libc.stdint cimport uint8_t, uint16_t, uint32_t
martindurant marked this conversation as resolved.
Show resolved Hide resolved


cdef uint32_t _fletcher32(const uint8_t[::1] _data):
cdef:
const uint8_t *data = &_data[0]
size_t _len = _data.shape[0]
size_t len = _len / 2
size_t tlen
uint32_t sum1 = 0, sum2 = 0;


while len:
tlen = 360 if len > 360 else len
len -= tlen
while True:
sum1 += <uint32_t>((<uint16_t>data[0]) << 8) | (<uint16_t>data[1])
data += 2
sum2 += sum1
tlen -= 1
if tlen < 1:
break
sum1 = (sum1 & 0xffff) + (sum1 >> 16)
sum2 = (sum2 & 0xffff) + (sum2 >> 16)

if _len % 2:
sum1 += <uint32_t>((<uint16_t>(data[0])) << 8)
sum2 += sum1
sum1 = (sum1 & 0xffff) + (sum1 >> 16)
sum2 = (sum2 & 0xffff) + (sum2 >> 16)

sum1 = (sum1 & 0xffff) + (sum1 >> 16)
sum2 = (sum2 & 0xffff) + (sum2 >> 16)

return (sum2 << 16) | sum1


class Fletcher32(Codec):
"""The fletcher checksum with 16-bit words and 32-bit output

With this codec, the checksum is concatenated on the end of the data
bytes when encoded. At decode time, the checksum is performed on
the data portion and compared with the four-byte checksum, raising
RuntimeError if inconsistent.
"""

codec_id = "fletcher32"

def encode(self, buf):
"""Return buffer plus 4-byte fletcher checksum"""
buf = ensure_contiguous_ndarray(buf).ravel().view('uint8')
cdef const uint8_t[::1] b_ptr = buf
val = _fletcher32(b_ptr)
return buf.tobytes() + struct.pack("<I", val)

def decode(self, buf, out=None):
"""Check fletcher checksum, and return buffer without it"""
b = ensure_contiguous_ndarray(buf).view('uint8')
cdef const uint8_t[::1] b_ptr = b[:-4]
val = _fletcher32(b_ptr)
found = b[-4:].view("<u4")[0]
if val != found:
raise RuntimeError(
f"The fletcher32 checksum of the data ({val}) did not"
f" match the expected checksum ({found}).\n"
"This could be a sign that the data has been corrupted."
)
if out:
out.view("uint8")[:] = b[:-4]
return out
return memoryview(b[:-4])
42 changes: 42 additions & 0 deletions numcodecs/tests/test_fletcher32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import pytest

from numcodecs.fletcher32 import Fletcher32


@pytest.mark.parametrize(
"dtype",
["uint8", "int32", "float32"]
)
def test_with_data(dtype):
data = np.arange(100, dtype=dtype)
f = Fletcher32()
arr = np.frombuffer(f.decode(f.encode(data)), dtype=dtype)
assert (arr == data).all()
martindurant marked this conversation as resolved.
Show resolved Hide resolved


def test_error():
data = np.arange(100)
f = Fletcher32()
enc = f.encode(data)
enc2 = bytearray(enc)
enc2[0] += 1
with pytest.raises(RuntimeError) as e:
f.decode(enc2)
assert "fletcher32 checksum" in str(e.value)


def test_known():
data = (
b'w\x07\x00\x00\x00\x00\x00\x00\x85\xf6\xff\xff\xff\xff\xff\xff'
b'i\x07\x00\x00\x00\x00\x00\x00\x94\xf6\xff\xff\xff\xff\xff\xff'
b'\x88\t\x00\x00\x00\x00\x00\x00i\x03\x00\x00\x00\x00\x00\x00'
b'\x93\xfd\xff\xff\xff\xff\xff\xff\xc3\xfc\xff\xff\xff\xff\xff\xff'
b"'\x02\x00\x00\x00\x00\x00\x00\xba\xf7\xff\xff\xff\xff\xff\xff"
b'\xfd%\x86d')
data3 = Fletcher32().decode(data)
outarr = np.frombuffer(data3, dtype="<i8")
expected = [
1911, -2427, 1897, -2412, 2440, 873, -621, -829, 551, -2118,
]
assert outarr.tolist() == expected
28 changes: 27 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,31 @@ def vlen_extension():
return extensions


def fletcher_extension():
info('setting up fletcher32 extension')

extra_compile_args = base_compile_args.copy()
define_macros = []

# setup sources
include_dirs = ['numcodecs']
# define_macros += [('CYTHON_TRACE', '1')]

sources = ['numcodecs/fletcher32.pyx']

# define extension module
extensions = [
Extension('numcodecs.fletcher32',
sources=sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
),
]

return extensions


def compat_extension():
info('setting up compat extension')

Expand Down Expand Up @@ -265,7 +290,8 @@ def run_setup(with_extensions):

if with_extensions:
ext_modules = (blosc_extension() + zstd_extension() + lz4_extension() +
compat_extension() + shuffle_extension() + vlen_extension())
compat_extension() + shuffle_extension() + vlen_extension() +
fletcher_extension())

cmdclass = dict(build_ext=ve_build_ext)
else:
Expand Down