fully support bfloat16 please #1840

Open
BBC-Esq opened this issue Jan 6, 2025 · 0 comments

BBC-Esq commented Jan 6, 2025

I am trying to run an embedding model in bfloat16, which works, but the embeddings it returns are actually float32. It's my understanding that this is because of the following:

Currently, ctranslate2's core C++ implementation (storage_view.h) supports BFLOAT16 computation, but this data type cannot be exposed through the Python bindings due to two limitations:

1. In module.cc, the array interface implementation (the dtype_to_typestr and typestr_to_dtype functions) does not include BFLOAT16 among its supported types, so BFLOAT16 falls into the default case, which throws a runtime error instructing users to convert to float32.
2. More fundamentally, NumPy does not natively support BFLOAT16. The ml_dtypes library (https://github.com/jax-ml/ml_dtypes) exists specifically to add BFLOAT16 and other ML-focused dtypes to NumPy (a short demonstration follows this list).
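
To illustrate the second point, here is a minimal sketch (mine, for illustration only, not ctranslate2 code) showing that ml_dtypes provides a bfloat16 dtype that plain NumPy lacks and that behaves like any other array dtype:

```python
import numpy as np
import ml_dtypes  # pip install ml_dtypes

# NumPy itself ships no bfloat16 dtype; ml_dtypes registers one as a
# NumPy extension type that can be used like any built-in dtype.
x = np.array([1.5, 2.25, -3.0], dtype=ml_dtypes.bfloat16)
print(x.dtype)               # bfloat16
print(x.itemsize)            # 2 bytes per element
print(x.astype(np.float32))  # exact round-trip for these values
```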

Therefore, my proposed solution is:

1. Add ml_dtypes as an optional dependency of ctranslate2.
2. Extend the array interface implementation in module.cc to support BFLOAT16 when ml_dtypes is available.
3. Update dtype_to_typestr to return the correct type string for BFLOAT16 that ml_dtypes expects.
4. Update typestr_to_dtype to properly handle BFLOAT16 type strings from ml_dtypes (a sketch of what those strings look like follows this list).
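
As a sketch only (this queries ml_dtypes from Python and is not the actual module.cc change), the exact values dtype_to_typestr would need to emit and typestr_to_dtype would need to accept can be read off the dtype that ml_dtypes registers, rather than being hard-coded:

```python
import numpy as np
import ml_dtypes

bf16 = np.dtype(ml_dtypes.bfloat16)

# Properties the C++ array interface would have to reproduce for a
# zero-copy exchange with NumPy.
print(bf16.itemsize)   # 2 bytes per element
print(bf16.str)        # the typestr ml_dtypes registers with NumPy
print(bf16.byteorder)  # byte order the bindings would need to match

# A bfloat16 value is just 16 raw bits, so reinterpreting a uint16
# buffer with .view() is essentially what a BFLOAT16-aware array
# interface would do with the StorageView's memory.
raw = np.arange(4, dtype=np.uint16)
print(raw.view(ml_dtypes.bfloat16))
```

If ml_dtypes is not installed, the bindings could simply keep the current behavior and continue to raise the error asking for a float32 conversion.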
