Add `device` argument for `safetensors.flax.load_file` #399
Comments
This issue is stale because it has been open 30 days with no activity. Remove the stale label or comment, or this will be closed in 5 days.
Thanks for the note; the docstring is outdated or was badly copy-pasted. The argument isn't there because Flax doesn't provide a way to create tensors directly on a device (AFAIK). Also, I thought lazy tensor placement was more idiomatic for Flax. How exactly do you move the tensors?
@Narsil I later found out I can load my Flax msgpack models directly to CPU with this:

```python
import jax
from flax.serialization import from_bytes

# msgpack_file: path to the checkpoint; cls: the target pytree passed to from_bytes
cpu_device = jax.devices("cpu")[0]
with jax.default_device(cpu_device):
    with open(msgpack_file, "rb") as state_f:
        state = from_bytes(cls, state_f.read())
```

In any case, I've now moved to saving my Flax model in …

```python
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.zeros((5, 5))
>>> type(x)
<class 'jaxlib.xla_extension.ArrayImpl'>  # array is on device
>>> type(jax.device_get(x))
<class 'numpy.ndarray'>
```

Then, to load the model:

```python
import safetensors.numpy
from flax import jax_utils

state = safetensors.numpy.load_file(st_file)  # NumPy arrays on CPU
state = jax_utils.replicate(state)  # jax arrays replicated across local devices (adds a leading device axis)
```
Thanks for sharing your fix.
Feature request
Hey there - love this library! 👍
Any reason why the `device` argument is not valid (anymore?) for `load_file` for Flax? Also a bit confused, as it is listed as an argument in the docs 🤔 https://huggingface.co/docs/safetensors/main/en/api/flax#safetensors.flax.load_file
I'm using `safetensors==0.4.1`.
Motivation
It's useful to have control over device placement when loading a model.
Your contribution
Probably not...