From ca6972c874ee0c51873a4cd583f0c590245d7f5a Mon Sep 17 00:00:00 2001 From: Flax Team Date: Wed, 11 Dec 2024 05:55:47 -0800 Subject: [PATCH] Fix de-serialisation of numpy fixed-width dtypes. PiperOrigin-RevId: 705077026 --- flax/serialization.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/flax/serialization.py b/flax/serialization.py index bd6d0853cc..8894fe66cd 100644 --- a/flax/serialization.py +++ b/flax/serialization.py @@ -259,10 +259,30 @@ def _ndarray_to_bytes(arr) -> bytes: return msgpack.packb(tpl, use_bin_type=True) -def _dtype_from_name(name: str): - """Handle JAX bfloat16 dtype correctly.""" +def _dtype_from_name(name: bytes): + """Handle JAX bfloat16 and other numpy fixed-width dtypes correctly.""" + + def _parse_bit_len(name: bytes, dtype_name: bytes): + return int(name.replace(dtype_name, b'')) + if name == b'bfloat16': return jax.numpy.bfloat16 + elif name.startswith(b'str'): + string_dtype = np.asarray('x').dtype + # Typically '