I am running an embedding model in bfloat16, which works, but the output embeddings are returned as float32. As I understand it, this is due to the following.
Currently, ctranslate2's core C++ implementation (storage_view.h) supports BFLOAT16 computation, but this data type cannot be exposed through the Python bindings because of two limitations:
1. In module.cc, the array interface implementation (the dtype_to_typestr and typestr_to_dtype functions) does not list BFLOAT16 among its supported types, so it falls through to the default case, which throws a runtime error telling users to convert to float32.
2. More fundamentally, NumPy does not natively support BFLOAT16. The ml_dtypes library (https://github.com/jax-ml/ml_dtypes) exists specifically to add BFLOAT16 and other ML-focused dtypes to NumPy.
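For context on why the float32 fallback loses no information but doubles the memory of each embedding: a bfloat16 value is the upper 16 bits of a float32, so widening bfloat16 to float32 is exact and only narrowing loses precision. The standard-library sketch below illustrates this. The helper names are hypothetical, and using round-to-nearest-even with a NaN guard for narrowing is my assumption, not a description of how ctranslate2's kernels convert.

```python
import struct


def float32_to_bf16_bits(x: float) -> int:
    """Narrow a float32-representable value to bfloat16 bits (hypothetical helper).

    Uses round-to-nearest-even on the dropped 16 bits. struct raises
    OverflowError if x lies outside the float32 range.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # NaN: exponent all ones and mantissa non-zero; set a mantissa bit so the
    # truncated result stays NaN instead of collapsing to infinity.
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        return (bits >> 16) | 0x0040
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties round to even
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float32(b: int) -> float:
    """Widen bfloat16 bits to float32: exact, just append 16 zero bits."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x


# Narrowing pi keeps only 8 significant bits; widening back is exact.
approx_pi = bf16_bits_to_float32(float32_to_bf16_bits(3.14159))
print(approx_pi)
```

Because every bfloat16 value widens exactly, returning float32 is a correct but memory-hungry workaround; exposing bfloat16 directly would halve the size of each returned embedding.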
My proposed solution is therefore to:
1. Add ml_dtypes as an optional dependency for ctranslate2.
2. Extend the array interface implementation in module.cc to support BFLOAT16 when ml_dtypes is available.
3. Update dtype_to_typestr to return the type string that ml_dtypes expects for BFLOAT16.
4. Update typestr_to_dtype to correctly handle BFLOAT16 type strings coming from ml_dtypes.