Skip to content

Commit

Permalink
Added support for sparsevec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 18, 2024
1 parent 7ef2332 commit cbb6458
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## 0.2.2 (unreleased)

- Added support for `halfvec` type
- Added support for `halfvec` and `sparsevec` types
- Added `Pgvector.extensions/0` function
- Added `l1_distance` function for Ecto

Expand Down
26 changes: 25 additions & 1 deletion lib/pgvector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ defmodule Pgvector do
vector.data
end

# A sparse vector keeps its encoded bytes in `data`; return them unchanged.
def to_binary(vector) when is_struct(vector, Pgvector.SparseVector), do: vector.data

@doc """
Converts the vector to a list
"""
Expand All @@ -69,6 +73,16 @@ defmodule Pgvector do
for <<v::float-16 <- bin>>, do: v
end

# Expands a sparse vector into a dense list of `dim` floats.
#
# The binary read here is three signed 32-bit integers (dimension count, number
# of stored entries `nnz`, and a field that must be zero), followed by `nnz`
# signed 32-bit indices and then `nnz` 32-bit floats. Positions without a stored
# entry come back as 0.0. Data that does not match this layout raises MatchError.
def to_list(vector) when is_struct(vector, Pgvector.SparseVector) do
  <<dim::signed-32, nnz::signed-32, 0::signed-32, indices::binary-size(nnz)-unit(32),
    values::binary-size(nnz)-unit(32)>> = vector.data

  indices = for <<i::signed-32 <- indices>>, do: i
  values = for <<v::float-32 <- values>>, do: v

  # One map lookup per position instead of one List.replace_at/3 per stored
  # entry, which re-walked the dense list every time.
  stored = indices |> Enum.zip(values) |> Map.new()

  # The explicit step keeps the range empty when dim is 0.
  for i <- 0..(dim - 1)//1, do: Map.get(stored, i, 0.0)
end

if Code.ensure_loaded?(Nx) do
@doc """
Converts the vector to a tensor
Expand All @@ -83,6 +97,11 @@ defmodule Pgvector do
bin |> f16_big_to_native() |> Nx.from_binary(:f16)
end

# Builds an f32 tensor by first expanding the sparse vector into a dense list.
# TODO improve
def to_tensor(vector) when is_struct(vector, Pgvector.SparseVector) do
  dense = to_list(vector)
  Nx.tensor(dense, type: :f32)
end

defp f32_big_to_native(binary) do
if System.endianness() == :big do
binary
Expand All @@ -106,7 +125,8 @@ defmodule Pgvector do
# Returns the list of extension modules provided by this library.
# NOTE(review): this hunk appears to show both the pre-change `Halfvec` line
# (no trailing comma) and its post-change replacement; confirm against the file.
def extensions do
[
Pgvector.Extensions.Vector,
Pgvector.Extensions.Halfvec
Pgvector.Extensions.Halfvec,
Pgvector.Extensions.Sparsevec
]
end

Expand All @@ -116,6 +136,10 @@ defmodule Pgvector do
vector
end

# Sparse vectors are passed through untouched.
def to_sql(vector) when is_struct(vector, Pgvector.SparseVector), do: vector

# Anything not matched by an earlier clause is converted with Pgvector.new/1.
def to_sql(vector), do: Pgvector.new(vector)
Expand Down
19 changes: 19 additions & 0 deletions lib/pgvector/ecto/sparse_vector.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
if Code.ensure_loaded?(Ecto) do
  defmodule Pgvector.Ecto.SparseVector do
    @moduledoc """
    Ecto type for `sparsevec` columns.

    `cast/1` accepts a `Pgvector.SparseVector`, a list of numbers, or an
    `Nx.Tensor`, and returns `:error` for anything else so that invalid input
    is reported as a cast failure instead of raising.
    """

    use Ecto.Type

    def type, do: :sparsevec

    # Already a sparse vector: nothing to convert.
    def cast(value) when is_struct(value, Pgvector.SparseVector), do: {:ok, value}

    # Lists are only castable when every element is a number; a non-numeric
    # element would otherwise raise while the binary is being built.
    def cast(value) when is_list(value) do
      if Enum.all?(value, &is_number/1) do
        {:ok, Pgvector.SparseVector.new(value)}
      else
        :error
      end
    end

    # Tensors are delegated to Pgvector.SparseVector.new/1, which raises
    # ArgumentError when the rank is not 1.
    def cast(value) when is_struct(value, Nx.Tensor) do
      {:ok, Pgvector.SparseVector.new(value)}
    end

    def cast(_value), do: :error

    # Loaded values are returned unchanged.
    def load(data), do: {:ok, data}

    # Dumped values are returned unchanged.
    def dump(value), do: {:ok, value}
  end
end
31 changes: 31 additions & 0 deletions lib/pgvector/extensions/sparsevec.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
defmodule Pgvector.Extensions.Sparsevec do
  # Extension for the "sparsevec" type using the binary format.
  import Postgrex.BinaryUtils, warn: false

  # The extension state is the :decode_binary option, defaulting to :copy;
  # the decode/1 clauses below match on this value.
  def init(opts) do
    Keyword.get(opts, :decode_binary, :copy)
  end

  def matching(_state), do: [type: "sparsevec"]

  def format(_state), do: :binary

  # Converts the parameter to a sparse vector and emits its bytes prefixed by
  # their length as an int32.
  def encode(_state) do
    quote do
      value ->
        payload = Pgvector.to_binary(Pgvector.SparseVector.new(value))
        [<<IO.iodata_length(payload)::int32()>> | payload]
    end
  end

  # :copy mode copies the payload before wrapping it in a struct.
  def decode(:copy) do
    quote do
      <<byte_count::int32(), payload::binary-size(byte_count)>> ->
        Pgvector.SparseVector.from_binary(:binary.copy(payload))
    end
  end

  # Any other mode wraps the matched payload directly.
  def decode(_state) do
    quote do
      <<byte_count::int32(), payload::binary-size(byte_count)>> ->
        Pgvector.SparseVector.from_binary(payload)
    end
  end
end
56 changes: 56 additions & 0 deletions lib/pgvector/sparse_vector.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
defmodule Pgvector.SparseVector do
  @moduledoc """
  Struct holding a pgvector sparse vector in its binary form
  """

  defstruct [:data]

  @doc """
  Builds a sparse vector from a list, a tensor, or an existing sparse vector
  """
  def new(list) when is_list(list) do
    # Keep only the non-zero entries, remembering the position of each one.
    nonzero = for {value, index} <- Enum.with_index(list), value != 0, do: {index, value}
    {positions, entries} = Enum.unzip(nonzero)

    # Header: dimension count, number of stored entries, then a zero field.
    header = <<length(list)::signed-32, length(positions)::signed-32, 0::signed-32>>
    position_bytes = for p <- positions, into: <<>>, do: <<p::signed-32>>
    entry_bytes = for e <- entries, into: <<>>, do: <<e::float-32>>

    from_binary(header <> position_bytes <> entry_bytes)
  end

  def new(%Pgvector.SparseVector{} = vector), do: vector

  if Code.ensure_loaded?(Nx) do
    def new(tensor) when is_struct(tensor, Nx.Tensor) do
      if Nx.rank(tensor) != 1 do
        raise ArgumentError, "expected rank to be 1"
      end

      # TODO improve
      tensor |> Nx.to_list() |> new()
    end
  end

  @doc """
  Wraps an already-encoded binary in a sparse vector struct
  """
  def from_binary(binary) when is_binary(binary), do: %__MODULE__{data: binary}
end

defimpl Inspect, for: Pgvector.SparseVector do
  import Inspect.Algebra

  # Renders the vector as a `Pgvector.SparseVector.new(...)` call wrapping its dense list.
  # TODO improve
  def inspect(vec, opts) do
    dense_doc = vec |> Pgvector.to_list() |> Inspect.List.inspect(opts)
    concat(["Pgvector.SparseVector.new(", dense_doc, ")"])
  end
end
35 changes: 31 additions & 4 deletions test/ecto_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Item do
# One field per vector column of the ecto_items table created in setup_all,
# each using the Ecto type that corresponds to the column's type.
schema "ecto_items" do
field :embedding, Pgvector.Ecto.Vector
field :half_embedding, Pgvector.Ecto.HalfVector
field :sparse_embedding, Pgvector.Ecto.SparseVector
end
end

Expand All @@ -16,15 +17,15 @@ defmodule EctoTest do
# Recreates the ecto_items table once for the whole module and seeds it via create_items/0.
# NOTE(review): two CREATE TABLE statements are shown; they look like the
# before/after lines of this diff hunk (the second adds sparse_embedding) — confirm.
setup_all do
Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector", [])
Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS ecto_items", [])
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3))", [])
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), sparse_embedding sparsevec(3))", [])
create_items()
:ok
end

# Seeds the table, giving each column a vector struct, a plain list, and an Nx tensor in turn.
# NOTE(review): the first three inserts look like the pre-change lines of this
# diff hunk and the last three (with sparse_embedding) their replacements — confirm.
defp create_items do
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1])})
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3]})
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16)})
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1]), sparse_embedding: Pgvector.SparseVector.new([1, 1, 1])})
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3], sparse_embedding: [2, 2, 3]})
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16), sparse_embedding: Nx.tensor([1, 1, 2], type: :f32)})
end

test "vector l2 distance" do
Expand Down Expand Up @@ -79,6 +80,32 @@ defmodule EctoTest do
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
end

# Orders items by l2_distance to [1, 1, 1], then checks both the id order and
# that the loaded sparse vectors expand to the expected float lists.
test "sparsevec l2 distance" do
items = Repo.all(from i in Item, order_by: l2_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
assert Enum.map(items, fn v -> v.sparse_embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
end

# Orders items by max_inner_product against [1, 1, 1]; the expected order puts
# the item with the largest dot product ([2, 2, 3]) first.
test "sparsevec max inner product" do
items = Repo.all(from i in Item, order_by: max_inner_product(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
end

# Orders items by cosine_distance to [1, 1, 1] and checks the resulting id order.
test "sparsevec cosine distance" do
items = Repo.all(from i in Item, order_by: cosine_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
assert Enum.map(items, fn v -> v.id end) == [1, 2, 3]
end

# Orders by 1 - cosine_distance, so the expected ids are the reverse of the
# cosine distance test.
test "sparsevec cosine similarity" do
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1]))), limit: 5)
assert Enum.map(items, fn v -> v.id end) == [3, 2, 1]
end

# Orders items by l1_distance to [1, 1, 1] and checks the resulting id order.
test "sparsevec l1 distance" do
items = Repo.all(from i in Item, order_by: l1_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
end

test "cast" do
embedding = [1, 1, 1]
items = Repo.all(from i in Item, where: i.embedding == ^embedding)
Expand Down
13 changes: 12 additions & 1 deletion test/postgrex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ defmodule PostgrexTest do
{:ok, pid} = Postgrex.start_link(database: "pgvector_elixir_test", types: PostgrexApp.PostgrexTypes)
Postgrex.query!(pid, "CREATE EXTENSION IF NOT EXISTS vector", [])
Postgrex.query!(pid, "DROP TABLE IF EXISTS postgrex_items", [])
Postgrex.query!(pid, "CREATE TABLE postgrex_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3))", [])
Postgrex.query!(pid, "CREATE TABLE postgrex_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))", [])
{:ok, pid: pid}
end

Expand Down Expand Up @@ -41,6 +41,17 @@ defmodule PostgrexTest do
assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]]
end

# Inserts a sparse vector struct, a plain list and an Nx tensor, then checks
# ordering by the `<->` operator and the decoded values.
test "sparsevec l2 distance", %{pid: pid} do
  inputs = [Pgvector.SparseVector.new([1, 1, 1]), [2, 2, 2], Nx.tensor([1, 1, 2], type: :f32)]
  Postgrex.query!(pid, "INSERT INTO postgrex_items (sparse_embedding) VALUES ($1), ($2), ($3)", inputs)

  result =
    Postgrex.query!(
      pid,
      "SELECT id, sparse_embedding FROM postgrex_items ORDER BY sparse_embedding <-> $1 LIMIT 5",
      [[1, 1, 1]]
    )

  assert ["id", "sparse_embedding"] == result.columns
  assert Enum.map(result.rows, &Enum.at(&1, 0)) == [1, 3, 2]

  assert Enum.map(result.rows, fn row -> row |> Enum.at(1) |> Pgvector.to_list() end) ==
           [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]]
end

# Passes as long as the CREATE INDEX statement does not raise.
test "create index", %{pid: pid} do
  index_sql = "CREATE INDEX my_index ON postgrex_items USING ivfflat (embedding vector_l2_ops) WITH (lists = 1)"
  Postgrex.query!(pid, index_sql, [])
end
Expand Down
28 changes: 28 additions & 0 deletions test/sparse_vector_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
defmodule SparseVectorTest do
  use ExUnit.Case

  # new/1 applied to an existing sparse vector gives back an equal vector.
  test "sparse vector" do
    original = Pgvector.SparseVector.new([1, 2, 3])
    assert Pgvector.SparseVector.new(original) == original
  end

  # A float list survives the trip through new/1 and to_list/1.
  test "list" do
    floats = [1.0, 2.0, 3.0]
    round_tripped = Pgvector.to_list(Pgvector.SparseVector.new(floats))
    assert round_tripped == floats
  end

  # An f32 tensor survives the trip through new/1 and to_tensor/1.
  test "tensor" do
    source = Nx.tensor([1.0, 2.0, 3.0], type: :f32)
    round_tripped = Pgvector.to_tensor(Pgvector.SparseVector.new(source))
    assert round_tripped == source
  end

  # Integer input is displayed as floats.
  test "inspect" do
    rendered = inspect(Pgvector.SparseVector.new([1, 2, 3]))
    assert rendered == "Pgvector.SparseVector.new([1.0, 2.0, 3.0])"
  end

  # Vectors compare equal exactly when they were built from the same values.
  test "equals" do
    base = Pgvector.SparseVector.new([1, 2, 3])
    same = Pgvector.SparseVector.new([1, 2, 3])
    different = Pgvector.SparseVector.new([1, 2, 4])

    assert base == same
    refute base == different
  end
end

0 comments on commit cbb6458

Please sign in to comment.