-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
203 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
if Code.ensure_loaded?(Ecto) do | ||
defmodule Pgvector.Ecto.SparseVector do | ||
use Ecto.Type | ||
|
||
def type, do: :sparsevec | ||
|
||
def cast(value) do | ||
{:ok, value |> Pgvector.SparseVector.new()} | ||
end | ||
|
||
def load(data) do | ||
{:ok, data} | ||
end | ||
|
||
def dump(value) do | ||
{:ok, value} | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
defmodule Pgvector.Extensions.Sparsevec do | ||
import Postgrex.BinaryUtils, warn: false | ||
|
||
def init(opts), do: Keyword.get(opts, :decode_binary, :copy) | ||
|
||
def matching(_), do: [type: "sparsevec"] | ||
|
||
def format(_), do: :binary | ||
|
||
def encode(_) do | ||
quote do | ||
vec -> | ||
data = vec |> Pgvector.SparseVector.new() |> Pgvector.to_binary() | ||
[<<IO.iodata_length(data)::int32()>> | data] | ||
end | ||
end | ||
|
||
def decode(:copy) do | ||
quote do | ||
<<len::int32(), bin::binary-size(len)>> -> | ||
bin |> :binary.copy() |> Pgvector.SparseVector.from_binary() | ||
end | ||
end | ||
|
||
def decode(_) do | ||
quote do | ||
<<len::int32(), bin::binary-size(len)>> -> | ||
bin |> Pgvector.SparseVector.from_binary() | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
defmodule Pgvector.SparseVector do | ||
@moduledoc """ | ||
A sparse vector struct for pgvector | ||
""" | ||
|
||
defstruct [:data] | ||
|
||
@doc """ | ||
Creates a new sparse vector from a list, tensor or sparse vector | ||
""" | ||
def new(list) when is_list(list) do | ||
indices = | ||
list | ||
|> Enum.with_index() | ||
|> Enum.filter(fn {v, _} -> v != 0 end) | ||
|> Enum.map(fn {_, i} -> i end) | ||
|
||
values = list |> Enum.filter(fn v -> v != 0 end) | ||
dim = list |> length() | ||
nnz = indices |> length() | ||
indices = for v <- indices, into: "", do: <<v::signed-32>> | ||
values = for v <- values, into: "", do: <<v::float-32>> | ||
from_binary(<<dim::signed-32, nnz::signed-32, 0::signed-32, indices::binary, values::binary>>) | ||
end | ||
|
||
def new(%Pgvector.SparseVector{} = vector) do | ||
vector | ||
end | ||
|
||
if Code.ensure_loaded?(Nx) do | ||
def new(tensor) when is_struct(tensor, Nx.Tensor) do | ||
if Nx.rank(tensor) != 1 do | ||
raise ArgumentError, "expected rank to be 1" | ||
end | ||
|
||
# TODO improve | ||
new(tensor |> Nx.to_list()) | ||
end | ||
end | ||
|
||
@doc """ | ||
Creates a new sparse vector from its binary representation | ||
""" | ||
def from_binary(binary) when is_binary(binary) do | ||
%Pgvector.SparseVector{data: binary} | ||
end | ||
end | ||
|
||
defimpl Inspect, for: Pgvector.SparseVector do | ||
import Inspect.Algebra | ||
|
||
def inspect(vec, opts) do | ||
# TODO improve | ||
concat(["Pgvector.SparseVector.new(", Inspect.List.inspect(Pgvector.to_list(vec), opts), ")"]) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
defmodule SparseVectorTest do | ||
use ExUnit.Case | ||
|
||
test "sparse vector" do | ||
vector = Pgvector.SparseVector.new([1, 2, 3]) | ||
assert vector == vector |> Pgvector.SparseVector.new() | ||
end | ||
|
||
test "list" do | ||
list = [1.0, 2.0, 3.0] | ||
assert list == list |> Pgvector.SparseVector.new() |> Pgvector.to_list() | ||
end | ||
|
||
test "tensor" do | ||
tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32) | ||
assert tensor == tensor |> Pgvector.SparseVector.new() |> Pgvector.to_tensor() | ||
end | ||
|
||
test "inspect" do | ||
vector = Pgvector.SparseVector.new([1, 2, 3]) | ||
assert "Pgvector.SparseVector.new([1.0, 2.0, 3.0])" == inspect(vector) | ||
end | ||
|
||
test "equals" do | ||
assert Pgvector.SparseVector.new([1, 2, 3]) == Pgvector.SparseVector.new([1, 2, 3]) | ||
refute Pgvector.SparseVector.new([1, 2, 3]) == Pgvector.SparseVector.new([1, 2, 4]) | ||
end | ||
end |