From cbb6458dacb102597e38a0faae10c5507f198739 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 18 May 2024 19:34:11 -0400 Subject: [PATCH] Added support for sparsevec type --- CHANGELOG.md | 2 +- lib/pgvector.ex | 26 ++++++++++++- lib/pgvector/ecto/sparse_vector.ex | 19 ++++++++++ lib/pgvector/extensions/sparsevec.ex | 31 +++++++++++++++ lib/pgvector/sparse_vector.ex | 56 ++++++++++++++++++++++++++++ test/ecto_test.exs | 35 +++++++++++++++-- test/postgrex_test.exs | 13 ++++++- test/sparse_vector_test.exs | 28 ++++++++++++++ 8 files changed, 203 insertions(+), 7 deletions(-) create mode 100644 lib/pgvector/ecto/sparse_vector.ex create mode 100644 lib/pgvector/extensions/sparsevec.ex create mode 100644 lib/pgvector/sparse_vector.ex create mode 100644 test/sparse_vector_test.exs diff --git a/CHANGELOG.md b/CHANGELOG.md index 24c467d..6ef0331 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ ## 0.2.2 (unreleased) -- Added support for `halfvec` type +- Added support for `halfvec` and `sparsevec` types - Added `Pgvector.extensions/0` function - Added `l1_distance` function for Ecto diff --git a/lib/pgvector.ex b/lib/pgvector.ex index b05bb7a..d6ec8d0 100644 --- a/lib/pgvector.ex +++ b/lib/pgvector.ex @@ -56,6 +56,10 @@ defmodule Pgvector do vector.data end + def to_binary(vector) when is_struct(vector, Pgvector.SparseVector) do + vector.data + end + @doc """ Converts the vector to a list """ @@ -69,6 +73,16 @@ defmodule Pgvector do for <>, do: v end + def to_list(vector) when is_struct(vector, Pgvector.SparseVector) do + <> = vector.data + + indices = for <>, do: v + values = for <>, do: v + list = List.duplicate(0.0, dim) + Enum.zip_reduce(indices, values, list, fn x, y, acc -> List.replace_at(acc, x, y) end) + end + if Code.ensure_loaded?(Nx) do @doc """ Converts the vector to a tensor @@ -83,6 +97,11 @@ defmodule Pgvector do bin |> f16_big_to_native() |> Nx.from_binary(:f16) end + def to_tensor(vector) when is_struct(vector, Pgvector.SparseVector) do + # TODO improve + vector |> to_list() |> Nx.tensor(type: :f32) + end + defp f32_big_to_native(binary) do if System.endianness() == :big do binary @@ -106,7 +125,8 @@ defmodule Pgvector do def extensions do [ Pgvector.Extensions.Vector, - Pgvector.Extensions.Halfvec + Pgvector.Extensions.Halfvec, + Pgvector.Extensions.Sparsevec ] end @@ -116,6 +136,10 @@ defmodule Pgvector do vector end + def to_sql(vector) when is_struct(vector, Pgvector.SparseVector) do + vector + end + def to_sql(vector) do vector |> Pgvector.new() end diff --git a/lib/pgvector/ecto/sparse_vector.ex b/lib/pgvector/ecto/sparse_vector.ex new file mode 100644 index 0000000..9611329 --- /dev/null +++ b/lib/pgvector/ecto/sparse_vector.ex @@ -0,0 +1,19 @@ +if Code.ensure_loaded?(Ecto) do + defmodule Pgvector.Ecto.SparseVector do + use Ecto.Type + + def type, do: :sparsevec + + def cast(value) do + {:ok, value |> Pgvector.SparseVector.new()} + end + + def load(data) do + {:ok, data} + end + + def dump(value) do + {:ok, value} + end + end +end diff --git a/lib/pgvector/extensions/sparsevec.ex b/lib/pgvector/extensions/sparsevec.ex new file mode 100644 index 0000000..4bda21a --- /dev/null +++ b/lib/pgvector/extensions/sparsevec.ex @@ -0,0 +1,31 @@ +defmodule Pgvector.Extensions.Sparsevec do + import Postgrex.BinaryUtils, warn: false + + def init(opts), do: Keyword.get(opts, :decode_binary, :copy) + + def matching(_), do: [type: "sparsevec"] + + def format(_), do: :binary + + def encode(_) do + quote do + vec -> + data = vec |> Pgvector.SparseVector.new() |> Pgvector.to_binary() + [<> | data] + end + end + + def decode(:copy) do + quote do + <> -> + bin |> :binary.copy() |> Pgvector.SparseVector.from_binary() + end + end + + def decode(_) do + quote do + <> -> + bin |> Pgvector.SparseVector.from_binary() + end + end +end diff --git a/lib/pgvector/sparse_vector.ex b/lib/pgvector/sparse_vector.ex new file mode 100644 index 0000000..a7bea1e --- /dev/null +++ b/lib/pgvector/sparse_vector.ex @@ -0,0 +1,56 @@ +defmodule Pgvector.SparseVector do + @moduledoc """ + A sparse vector struct for pgvector + """ + + defstruct [:data] + + @doc """ + Creates a new sparse vector from a list, tensor or sparse vector + """ + def new(list) when is_list(list) do + indices = + list + |> Enum.with_index() + |> Enum.filter(fn {v, _} -> v != 0 end) + |> Enum.map(fn {_, i} -> i end) + + values = list |> Enum.filter(fn v -> v != 0 end) + dim = list |> length() + nnz = indices |> length() + indices = for v <- indices, into: "", do: <> + values = for v <- values, into: "", do: <> + from_binary(<>) + end + + def new(%Pgvector.SparseVector{} = vector) do + vector + end + + if Code.ensure_loaded?(Nx) do + def new(tensor) when is_struct(tensor, Nx.Tensor) do + if Nx.rank(tensor) != 1 do + raise ArgumentError, "expected rank to be 1" + end + + # TODO improve + new(tensor |> Nx.to_list()) + end + end + + @doc """ + Creates a new sparse vector from its binary representation + """ + def from_binary(binary) when is_binary(binary) do + %Pgvector.SparseVector{data: binary} + end +end + +defimpl Inspect, for: Pgvector.SparseVector do + import Inspect.Algebra + + def inspect(vec, opts) do + # TODO improve + concat(["Pgvector.SparseVector.new(", Inspect.List.inspect(Pgvector.to_list(vec), opts), ")"]) + end +end diff --git a/test/ecto_test.exs b/test/ecto_test.exs index 48f9cd5..81bafd4 100644 --- a/test/ecto_test.exs +++ b/test/ecto_test.exs @@ -4,6 +4,7 @@ defmodule Item do schema "ecto_items" do field :embedding, Pgvector.Ecto.Vector field :half_embedding, Pgvector.Ecto.HalfVector + field :sparse_embedding, Pgvector.Ecto.SparseVector end end @@ -16,15 +17,15 @@ defmodule EctoTest do setup_all do Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector", []) Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS ecto_items", []) - Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3))", []) + Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), sparse_embedding sparsevec(3))", []) create_items() :ok end defp create_items do - Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1])}) - Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3]}) - Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16)}) + Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1]), sparse_embedding: Pgvector.SparseVector.new([1, 1, 1])}) + Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3], sparse_embedding: [2, 2, 3]}) + Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16), sparse_embedding: Nx.tensor([1, 1, 2], type: :f32)}) end test "vector l2 distance" do @@ -79,6 +80,32 @@ defmodule EctoTest do assert Enum.map(items, fn v -> v.id end) == [1, 3, 2] end + test "sparsevec l2 distance" do + items = Repo.all(from i in Item, order_by: l2_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5) + assert Enum.map(items, fn v -> v.id end) == [1, 3, 2] + assert Enum.map(items, fn v -> v.sparse_embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]] + end + + test "sparsevec max inner product" do + items = Repo.all(from i in Item, order_by: max_inner_product(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5) + assert Enum.map(items, fn v -> v.id end) == [2, 3, 1] + end + + test "sparsevec cosine distance" do + items = Repo.all(from i in Item, order_by: cosine_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5) + assert Enum.map(items, fn v -> v.id end) == [1, 2, 3] + end + + test "sparsevec cosine similarity" do + items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1]))), limit: 5) + assert Enum.map(items, fn v -> v.id end) == [3, 2, 1] + end + + test "sparsevec l1 distance" do + items = Repo.all(from i in Item, order_by: l1_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5) + assert Enum.map(items, fn v -> v.id end) == [1, 3, 2] + end + test "cast" do embedding = [1, 1, 1] items = Repo.all(from i in Item, where: i.embedding == ^embedding) diff --git a/test/postgrex_test.exs b/test/postgrex_test.exs index 2739b04..e25770e 100644 --- a/test/postgrex_test.exs +++ b/test/postgrex_test.exs @@ -10,7 +10,7 @@ defmodule PostgrexTest do {:ok, pid} = Postgrex.start_link(database: "pgvector_elixir_test", types: PostgrexApp.PostgrexTypes) Postgrex.query!(pid, "CREATE EXTENSION IF NOT EXISTS vector", []) Postgrex.query!(pid, "DROP TABLE IF EXISTS postgrex_items", []) - Postgrex.query!(pid, "CREATE TABLE postgrex_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3))", []) + Postgrex.query!(pid, "CREATE TABLE postgrex_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))", []) {:ok, pid: pid} end @@ -41,6 +41,17 @@ defmodule PostgrexTest do assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]] end + test "sparsevec l2 distance", %{pid: pid} = _context do + embeddings = [Pgvector.SparseVector.new([1, 1, 1]), [2, 2, 2], Nx.tensor([1, 1, 2], type: :f32)] + Postgrex.query!(pid, "INSERT INTO postgrex_items (sparse_embedding) VALUES ($1), ($2), ($3)", embeddings) + + result = Postgrex.query!(pid, "SELECT id, sparse_embedding FROM postgrex_items ORDER BY sparse_embedding <-> $1 LIMIT 5", [[1, 1, 1]]) + + assert ["id", "sparse_embedding"] == result.columns + assert Enum.map(result.rows, fn v -> Enum.at(v, 0) end) == [1, 3, 2] + assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]] + end + test "create index", %{pid: pid} = _context do Postgrex.query!(pid, "CREATE INDEX my_index ON postgrex_items USING ivfflat (embedding vector_l2_ops) WITH (lists = 1)", []) end diff --git a/test/sparse_vector_test.exs b/test/sparse_vector_test.exs new file mode 100644 index 0000000..c253f68 --- /dev/null +++ b/test/sparse_vector_test.exs @@ -0,0 +1,28 @@ +defmodule SparseVectorTest do + use ExUnit.Case + + test "sparse vector" do + vector = Pgvector.SparseVector.new([1, 2, 3]) + assert vector == vector |> Pgvector.SparseVector.new() + end + + test "list" do + list = [1.0, 2.0, 3.0] + assert list == list |> Pgvector.SparseVector.new() |> Pgvector.to_list() + end + + test "tensor" do + tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32) + assert tensor == tensor |> Pgvector.SparseVector.new() |> Pgvector.to_tensor() + end + + test "inspect" do + vector = Pgvector.SparseVector.new([1, 2, 3]) + assert "Pgvector.SparseVector.new([1.0, 2.0, 3.0])" == inspect(vector) + end + + test "equals" do + assert Pgvector.SparseVector.new([1, 2, 3]) == Pgvector.SparseVector.new([1, 2, 3]) + refute Pgvector.SparseVector.new([1, 2, 3]) == Pgvector.SparseVector.new([1, 2, 4]) + end +end