Skip to content

Commit

Permalink
Add Safetensors.write!/2 for streamed write (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Feb 22, 2024
1 parent dd365da commit ed9267d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 31 deletions.
103 changes: 72 additions & 31 deletions lib/safetensors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,70 @@ defmodule Safetensors do

@dtype_to_type for {k, v} <- @type_to_dtype, into: %{}, do: {v, k}

@doc """
Writes a map of tensors to a file.
Tensors are written into the file one by one, without the need to
dump all of them into the memory at once.
"""
@spec write!(path :: Path.t(), %{String.t() => Nx.Tensor.t()}) :: :ok
def write!(path, tensors) when is_map(tensors) do
File.open!(path, [:write, :raw], fn file ->
tensors = Enum.sort(tensors)

{header_entries, _offset} =
Enum.map_reduce(tensors, 0, fn {tensor_name, tensor}, offset ->
tensor_header_entry(tensor_name, tensor, offset)
end)

:ok = :file.write(file, header_binary(header_entries))

for {_tensor_name, tensor} <- tensors do
:ok = :file.write(file, tensor_to_binary(tensor))
end
end)

:ok
end

defp tensor_header_entry(tensor_name, tensor, offset) do
end_offset = offset + tensor_byte_size(tensor)

header_entry = {
tensor_name,
Jason.OrderedObject.new(
dtype: tensor |> Nx.type() |> type_to_dtype(),
shape: tensor |> Nx.shape() |> Tuple.to_list(),
data_offsets: [offset, end_offset]
)
}

{header_entry, end_offset}
end

defp header_binary(header_entries) do
header_json =
header_entries
|> Jason.OrderedObject.new()
|> Jason.encode!()

[<<byte_size(header_json)::unsigned-64-integer-little>>, header_json]
end

defp tensor_byte_size(tensor) do
{_, elem_size} = Nx.type(tensor)
elem_byte_size = div(elem_size, 8)
Nx.size(tensor) * elem_byte_size
end

defp tensor_to_binary(tensor) do
{_, elem_size} = Nx.type(tensor)

tensor
|> Nx.to_binary()
|> new_byte_order(elem_size, :little)
end

@doc """
Serializes the given map of tensors to iodata.
Expand All @@ -50,46 +114,23 @@ defmodule Safetensors do
"""
@spec dump(%{String.t() => Nx.Tensor.t()}) :: iodata()
def dump(tensors) when is_map(tensors) do
tensors = Enum.sort(tensors)

{header_entries, {buffer, _offset}} =
Enum.map_reduce(tensors, {[], 0}, fn {tensor_name, tensor}, {buffer, offset} ->
{_, elem_size} = Nx.type(tensor)

binary =
tensor
|> Nx.to_binary()
|> new_byte_order(elem_size, :little)

end_offset = offset + byte_size(binary)

header_entry = {
tensor_name,
Jason.OrderedObject.new(
dtype: tensor |> Nx.type() |> type_to_dtype(),
shape: tensor |> Nx.shape() |> Tuple.to_list(),
data_offsets: [offset, end_offset]
)
}

{header_entry, end_offset} = tensor_header_entry(tensor_name, tensor, offset)
binary = tensor_to_binary(tensor)
{header_entry, {[buffer, binary], end_offset}}
end)

header_json =
header_entries
|> Jason.OrderedObject.new()
|> Jason.encode!()

[
<<byte_size(header_json)::unsigned-64-integer-little>>,
header_json,
buffer
]
[header_binary(header_entries), buffer]
end

@doc """
Reads a safe tensor from file.
Reads a serialized map of tensors from a file.
Tensors are loaded into Nx one by one,
without the need to load the entire file from disk into memory.
Tensors are loaded into Nx one by one, without the need to load the
entire file from disk into memory.
"""
@spec read!(path :: Path.t()) :: %{String.t() => Nx.Tensor.t()}
def read!(path) do
Expand Down
14 changes: 14 additions & 0 deletions test/safetensors_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@ defmodule SafetensorsTest do

doctest Safetensors

@tag :tmp_dir
test "write", %{tmp_dir: tmp_dir} do
path = Path.join(tmp_dir, "safetensor")

data = %{test: Nx.tensor([[1, 2], [3, 4]], type: :s32)}
Safetensors.write!(path, data)

# source:
# https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L22-L25
# with the header padding removed and changed numbers
assert File.read!(path) ==
~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00)
end

test "dump" do
binary =
%{test: Nx.tensor([[1, 2], [3, 4]], type: :s32)}
Expand Down

0 comments on commit ed9267d

Please sign in to comment.