Skip to content

Commit

Permalink
Fix de-serialisation of numpy fixed-width dtypes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705077026
  • Loading branch information
Flax Team committed Dec 16, 2024
1 parent fc38f21 commit ca6972c
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions flax/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,30 @@ def _ndarray_to_bytes(arr) -> bytes:
return msgpack.packb(tpl, use_bin_type=True)


def _dtype_from_name(name: str):
"""Handle JAX bfloat16 dtype correctly."""
def _dtype_from_name(name: bytes):
"""Handle JAX bfloat16 and other numpy fixed-width dtypes correctly."""

def _parse_bit_len(name: bytes, dtype_name: bytes):
return int(name.replace(dtype_name, b''))

if name == b'bfloat16':
return jax.numpy.bfloat16
elif name.startswith(b'str'):
string_dtype = np.asarray('x').dtype
# Typically '<U' indicating a little endian unicode. Get it programatically
# in case of big-endian architecture.
# See https://numpy.org/doc/2.1/reference/arrays.interface.html#python-side.
string_dtype_signature = str(string_dtype)[:2]
# numpy stores unicode strings in UCS4 encoding with 4 bytes per character.
ucs4_bits_per_character = _parse_bit_len(string_dtype.name.encode(), b'str')
len_in_bits = _parse_bit_len(name, b'str')
return np.dtype(
f'{string_dtype_signature}{len_in_bits//ucs4_bits_per_character}'
)
elif name.startswith(b'bytes'):
len_in_bits = _parse_bit_len(name, b'bytes')
# |S indicates no-order specific flag for zero-terminated bytes.
return np.dtype(f'|S{len_in_bits//8}')
else:
return np.dtype(name)

Expand Down

0 comments on commit ca6972c

Please sign in to comment.