Skip to content

Commit

Permalink
Added support for bit type to Ecto
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 18, 2024
1 parent cbb6458 commit 7acad90
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 5 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
## 0.2.2 (unreleased)

- Added support for `halfvec` and `sparsevec` types
- Added support for `bit` type to Ecto
- Added `Pgvector.extensions/0` function
- Added `l1_distance` function for Ecto
- Added `l1_distance`, `hamming_distance`, and `jaccard_distance` functions for Ecto

## 0.2.1 (2023-09-25)

Expand Down
6 changes: 6 additions & 0 deletions lib/pgvector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ defmodule Pgvector do
def to_sql(vector) do
vector |> Pgvector.new()
end

# TODO move / improve pattern
@doc false
def to_bit_sql(vector) when is_bitstring(vector) do
vector
end
end

defimpl Inspect, for: Pgvector do
Expand Down
19 changes: 19 additions & 0 deletions lib/pgvector/ecto/bit.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
if Code.ensure_loaded?(Ecto) do
defmodule Pgvector.Ecto.Bit do
use Ecto.Type

def type, do: :bit

def cast(value) do
{:ok, value}
end

def load(data) do
{:ok, data}
end

def dump(value) do
{:ok, value}
end
end
end
18 changes: 18 additions & 0 deletions lib/pgvector/ecto/query.ex
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,23 @@ if Code.ensure_loaded?(Ecto) do
fragment("(? <+> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
end
end

@doc """
Returns the Hamming distance
"""
defmacro hamming_distance(column, value) do
quote do
fragment("(? <~> ?)", unquote(column), ^Pgvector.to_bit_sql(unquote(value)))
end
end

@doc """
Returns the Jaccard distance
"""
defmacro jaccard_distance(column, value) do
quote do
fragment("(? <%> ?)", unquote(column), ^Pgvector.to_bit_sql(unquote(value)))
end
end
end
end
19 changes: 15 additions & 4 deletions test/ecto_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Item do
schema "ecto_items" do
field :embedding, Pgvector.Ecto.Vector
field :half_embedding, Pgvector.Ecto.HalfVector
field :binary_embedding, Pgvector.Ecto.Bit
field :sparse_embedding, Pgvector.Ecto.SparseVector
end
end
Expand All @@ -17,15 +18,15 @@ defmodule EctoTest do
setup_all do
Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector", [])
Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS ecto_items", [])
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), sparse_embedding sparsevec(3))", [])
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))", [])
create_items()
:ok
end

defp create_items do
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1]), sparse_embedding: Pgvector.SparseVector.new([1, 1, 1])})
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3], sparse_embedding: [2, 2, 3]})
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16), sparse_embedding: Nx.tensor([1, 1, 2], type: :f32)})
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1]), binary_embedding: <<0::1, 0::1, 0::1>>, sparse_embedding: Pgvector.SparseVector.new([1, 1, 1])})
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3], binary_embedding: <<1::1, 0::1, 1::1>>, sparse_embedding: [2, 2, 3]})
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16), binary_embedding: <<1::1, 1::1, 1::1>>, sparse_embedding: Nx.tensor([1, 1, 2], type: :f32)})
end

test "vector l2 distance" do
Expand Down Expand Up @@ -80,6 +81,16 @@ defmodule EctoTest do
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
end

test "bit hamming distance" do
items = Repo.all(from i in Item, order_by: hamming_distance(i.binary_embedding, <<1::1, 0::1, 1::1>>), limit: 5)
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
end

test "bit jaccard distance" do
items = Repo.all(from i in Item, order_by: jaccard_distance(i.binary_embedding, <<1::1, 0::1, 1::1>>), limit: 5)
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
end

test "sparsevec l2 distance" do
items = Repo.all(from i in Item, order_by: l2_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
Expand Down

0 comments on commit 7acad90

Please sign in to comment.