Skip to content

Commit

Permalink
Added support for halfvec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 18, 2024
1 parent 8d2ffa1 commit 963ba57
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 0.2.2 (unreleased)

- Added support for `halfvec` type
- Added `Pgvector.extensions/0` function
- Added `l1_distance` function for Ecto

Expand Down
35 changes: 34 additions & 1 deletion lib/pgvector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ defmodule Pgvector do
vector.data
end

# Half vectors already hold their encoded form, so the `data` field is
# returned unchanged.
def to_binary(%{data: payload} = vector) when is_struct(vector, Pgvector.HalfVector) do
  payload
end

@doc """
Converts the vector to a list
"""
Expand All @@ -60,6 +64,11 @@ defmodule Pgvector do
for <<v::float-32 <- bin>>, do: v
end

def to_list(vector) when is_struct(vector, Pgvector.HalfVector) do
  # Layout: 16-bit element count, a 16-bit field that must be zero, then
  # `count` 16-bit floats (big-endian, the bitstring default).
  <<count::unsigned-16, 0::unsigned-16, payload::binary-size(count)-unit(16)>> = vector.data

  for <<element::float-16 <- payload>> do
    element
  end
end

if Code.ensure_loaded?(Nx) do
@doc """
Converts the vector to a tensor
Expand All @@ -69,23 +78,47 @@ defmodule Pgvector do
bin |> f32_big_to_native() |> Nx.from_binary(:f32)
end

def to_tensor(vector) when is_struct(vector, Pgvector.HalfVector) do
  # Layout: 16-bit element count, a 16-bit field that must be zero, then
  # `count` 16-bit floats.
  <<count::unsigned-16, 0::unsigned-16, payload::binary-size(count)-unit(16)>> = vector.data

  # The payload is re-encoded to the host byte order before Nx reads it.
  native = f16_big_to_native(payload)
  Nx.from_binary(native, :f16)
end

# Re-encodes a run of big-endian 32-bit floats into the host byte order.
# On a big-endian host the input is returned untouched.
defp f32_big_to_native(binary) do
  case System.endianness() do
    :big ->
      binary

    _little ->
      for <<value::float-32-big <- binary>>, into: "" do
        <<value::float-32-little>>
      end
  end
end

# Re-encodes a run of big-endian 16-bit floats into the host byte order.
# On a big-endian host the input is returned untouched.
defp f16_big_to_native(binary) do
  case System.endianness() do
    :big ->
      binary

    _little ->
      for <<value::float-16-big <- binary>>, into: "" do
        <<value::float-16-little>>
      end
  end
end
end

@doc """
Extensions for Postgrex

Returns the list of Postgrex extension modules provided by this library,
covering the `vector` and `halfvec` types.
"""
def extensions do
  [
    Pgvector.Extensions.Vector,
    Pgvector.Extensions.Halfvec
  ]
end

# TODO move / improve pattern
# Prepares a value for use as a query parameter: half vectors pass through
# unchanged, anything else is converted with Pgvector.new/1.
@doc false
def to_sql(vector) when is_struct(vector, Pgvector.HalfVector), do: vector

def to_sql(vector), do: Pgvector.new(vector)
end

defimpl Inspect, for: Pgvector do
Expand Down
19 changes: 19 additions & 0 deletions lib/pgvector/ecto/halfvec.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
if Code.ensure_loaded?(Ecto) do
  defmodule Pgvector.Ecto.HalfVector do
    @moduledoc """
    Ecto type for pgvector `halfvec` columns.

    Values are cast with `Pgvector.HalfVector.new/1`. Loading and dumping
    pass the value through unchanged.
    """

    use Ecto.Type

    # Database type name used by Ecto for this field.
    def type, do: :halfvec

    # Returns `{:ok, half_vector}` for any input `Pgvector.HalfVector.new/1`
    # accepts, and `:error` otherwise.
    def cast(value) do
      {:ok, Pgvector.HalfVector.new(value)}
    rescue
      # new/1 raises FunctionClauseError for unsupported terms and
      # ArgumentError for elements that cannot be encoded or tensors of the
      # wrong rank. Ecto expects :error for invalid input, not an exception.
      _ in [ArgumentError, FunctionClauseError] -> :error
    end

    def load(data) do
      {:ok, data}
    end

    def dump(value) do
      {:ok, value}
    end
  end
end
8 changes: 4 additions & 4 deletions lib/pgvector/ecto/query.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro l2_distance(column, value) do
  # The value goes through Pgvector.to_sql/1 and is pinned as a query
  # parameter. There is no `::vector` cast, so the macro also works with
  # half vector values.
  quote do
    fragment("(? <-> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
  end
end

Expand All @@ -18,7 +18,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro max_inner_product(column, value) do
  # The value goes through Pgvector.to_sql/1 and is pinned as a query
  # parameter. There is no `::vector` cast, so the macro also works with
  # half vector values.
  quote do
    fragment("(? <#> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
  end
end

Expand All @@ -27,7 +27,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro cosine_distance(column, value) do
  # The value goes through Pgvector.to_sql/1 and is pinned as a query
  # parameter. There is no `::vector` cast, so the macro also works with
  # half vector values.
  quote do
    fragment("(? <=> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
  end
end

Expand All @@ -36,7 +36,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro l1_distance(column, value) do
  # The value goes through Pgvector.to_sql/1 and is pinned as a query
  # parameter. There is no `::vector` cast, so the macro also works with
  # half vector values.
  quote do
    fragment("(? <+> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
  end
end
end
Expand Down
31 changes: 31 additions & 0 deletions lib/pgvector/extensions/halfvec.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
defmodule Pgvector.Extensions.Halfvec do
  # Postgrex extension for the `halfvec` type, using the binary wire format.
  # Encoding accepts anything Pgvector.HalfVector.new/1 accepts; decoding
  # yields a Pgvector.HalfVector struct.

  import Postgrex.BinaryUtils, warn: false

  # The extension state is the `:decode_binary` option (default `:copy`),
  # which later selects the decode/1 clause.
  def init(opts) do
    Keyword.get(opts, :decode_binary, :copy)
  end

  def matching(_state) do
    [type: "halfvec"]
  end

  def format(_state) do
    :binary
  end

  # Emits the payload prefixed with its byte length as an int32.
  def encode(_state) do
    quote do
      value ->
        payload = Pgvector.to_binary(Pgvector.HalfVector.new(value))
        [<<IO.iodata_length(payload)::int32()>> | payload]
    end
  end

  # With `:copy`, the payload is copied before it is wrapped in the struct.
  def decode(:copy) do
    quote do
      <<size::int32(), body::binary-size(size)>> ->
        Pgvector.HalfVector.from_binary(:binary.copy(body))
    end
  end

  # Any other setting wraps the matched payload directly.
  def decode(_state) do
    quote do
      <<size::int32(), body::binary-size(size)>> ->
        Pgvector.HalfVector.from_binary(body)
    end
  end
end
55 changes: 55 additions & 0 deletions lib/pgvector/half_vector.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
defmodule Pgvector.HalfVector do
  @moduledoc """
  A half vector struct for pgvector

  The `data` field holds the encoded vector: a 16-bit element count, a
  16-bit zero field, then one big-endian 16-bit float per element.
  """

  defstruct [:data]

  # The element count is stored in an unsigned 16-bit header field.
  @max_dim 0xFFFF

  @doc """
  Creates a new half vector from a list, tensor, or half vector

  A half vector is returned unchanged. Raises `ArgumentError` when the
  input has more than #{@max_dim} elements, because the count would not
  fit in the header, or when a tensor's rank is not 1.
  """
  def new(list) when is_list(list) do
    dim = length(list)
    bin = for v <- list, into: "", do: <<v::float-16>>
    build(dim, bin)
  end

  def new(%__MODULE__{} = vector) do
    vector
  end

  if Code.ensure_loaded?(Nx) do
    def new(tensor) when is_struct(tensor, Nx.Tensor) do
      if Nx.rank(tensor) != 1 do
        raise ArgumentError, "expected rank to be 1"
      end

      dim = Nx.size(tensor)
      bin = tensor |> Nx.as_type(:f16) |> Nx.to_binary() |> f16_native_to_big()
      build(dim, bin)
    end

    # Re-encodes host-order 16-bit floats as big-endian; no-op on big-endian hosts.
    defp f16_native_to_big(binary) do
      if System.endianness() == :big do
        binary
      else
        for <<n::float-16-little <- binary>>, into: "", do: <<n::float-16-big>>
      end
    end
  end

  @doc """
  Creates a new half vector from its binary representation

  The binary is stored as given, without validation.
  """
  def from_binary(binary) when is_binary(binary) do
    %__MODULE__{data: binary}
  end

  # Prepends the header to the encoded elements. A larger count would wrap
  # in the 16-bit field and corrupt the binary, so it is rejected instead.
  defp build(dim, bin) when dim <= @max_dim do
    from_binary(<<dim::unsigned-16, 0::unsigned-16, bin::binary>>)
  end

  defp build(dim, _bin) do
    raise ArgumentError, "expected at most #{@max_dim} dimensions, got #{dim}"
  end
end

defimpl Inspect, for: Pgvector.HalfVector do
  import Inspect.Algebra

  # Renders the vector as a `Pgvector.HalfVector.new(...)` call around its
  # decoded list of elements.
  def inspect(vec, opts) do
    elements = Pgvector.to_list(vec)
    rendered = Inspect.List.inspect(elements, opts)
    concat(["Pgvector.HalfVector.new(", rendered, ")"])
  end
end
45 changes: 36 additions & 9 deletions test/ecto_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ defmodule Item do

# Test schema backed by the "ecto_items" table, with one field per
# pgvector Ecto type under test.
schema "ecto_items" do
field :embedding, Pgvector.Ecto.Vector
field :half_embedding, Pgvector.Ecto.HalfVector
end
end

Expand All @@ -15,43 +16,69 @@ defmodule EctoTest do
# Recreates the test table with a vector and a halfvec column, then seeds it.
setup_all do
  Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector", [])
  Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS ecto_items", [])
  Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3))", [])
  create_items()
  :ok
end

# Seeds three rows, one per accepted input form: struct, plain list and
# Nx tensor. The tests assert on the resulting ids, so exactly three rows
# must be inserted, in this order.
defp create_items do
  Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1])})
  Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3]})
  Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16)})
end

test "vector l2 distance" do
  items = Repo.all(from i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5)
  assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
  assert Enum.map(items, fn v -> v.embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
end

test "vector max inner product" do
  items = Repo.all(from i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5)
  assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
end

test "vector cosine distance" do
  items = Repo.all(from i in Item, order_by: cosine_distance(i.embedding, [1, 1, 1]), limit: 5)
  assert Enum.map(items, fn v -> v.id end) == [1, 2, 3]
end

test "vector cosine similarity" do
  items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.embedding, [1, 1, 1])), limit: 5)
  assert Enum.map(items, fn v -> v.id end) == [3, 2, 1]
end

test "vector l1 distance" do
  items = Repo.all(from i in Item, order_by: l1_distance(i.embedding, [1, 1, 1]), limit: 5)
  assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
end

test "halfvec l2 distance" do
  target = Pgvector.HalfVector.new([1, 1, 1])
  query = from i in Item, order_by: l2_distance(i.half_embedding, target), limit: 5
  items = Repo.all(query)

  # Nearest first, and stored values decode back to the seeded lists.
  assert Enum.map(items, & &1.id) == [1, 3, 2]
  assert Enum.map(items, &Pgvector.to_list(&1.half_embedding)) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
end

test "halfvec max inner product" do
  target = Pgvector.HalfVector.new([1, 1, 1])
  query = from i in Item, order_by: max_inner_product(i.half_embedding, target), limit: 5
  items = Repo.all(query)

  assert Enum.map(items, & &1.id) == [2, 3, 1]
end

test "halfvec cosine distance" do
  target = Pgvector.HalfVector.new([1, 1, 1])
  query = from i in Item, order_by: cosine_distance(i.half_embedding, target), limit: 5
  items = Repo.all(query)

  assert Enum.map(items, & &1.id) == [1, 2, 3]
end

test "halfvec cosine similarity" do
  target = Pgvector.HalfVector.new([1, 1, 1])
  query = from i in Item, order_by: (1 - cosine_distance(i.half_embedding, target)), limit: 5
  items = Repo.all(query)

  assert Enum.map(items, & &1.id) == [3, 2, 1]
end

test "halfvec l1 distance" do
  target = Pgvector.HalfVector.new([1, 1, 1])
  query = from i in Item, order_by: l1_distance(i.half_embedding, target), limit: 5
  items = Repo.all(query)

  assert Enum.map(items, & &1.id) == [1, 3, 2]
end

test "cast" do
embedding = [1, 1, 1]
items = Repo.all(from i in Item, where: i.embedding == ^embedding)
Expand Down
28 changes: 28 additions & 0 deletions test/half_vector_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
defmodule HalfVectorTest do
  use ExUnit.Case

  alias Pgvector.HalfVector

  # Passing a half vector to new/1 yields an equal half vector.
  test "half vector" do
    original = HalfVector.new([1, 2, 3])
    assert original == HalfVector.new(original)
  end

  # List -> half vector -> list round trip.
  test "list" do
    values = [1.0, 2.0, 3.0]
    round_tripped = values |> HalfVector.new() |> Pgvector.to_list()
    assert values == round_tripped
  end

  # f16 tensor -> half vector -> tensor round trip.
  test "tensor" do
    source = Nx.tensor([1.0, 2.0, 3.0], type: :f16)
    round_tripped = source |> HalfVector.new() |> Pgvector.to_tensor()
    assert source == round_tripped
  end

  test "inspect" do
    rendered = inspect(HalfVector.new([1, 2, 3]))
    assert "Pgvector.HalfVector.new([1.0, 2.0, 3.0])" == rendered
  end

  # Equality follows the encoded contents.
  test "equals" do
    reference = HalfVector.new([1, 2, 3])
    assert reference == HalfVector.new([1, 2, 3])
    refute reference == HalfVector.new([1, 2, 4])
  end
end

0 comments on commit 963ba57

Please sign in to comment.