-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
208 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
if Code.ensure_loaded?(Ecto) do | ||
defmodule Pgvector.Ecto.HalfVector do | ||
use Ecto.Type | ||
|
||
def type, do: :halfvec | ||
|
||
def cast(value) do | ||
{:ok, value |> Pgvector.HalfVector.new()} | ||
end | ||
|
||
def load(data) do | ||
{:ok, data} | ||
end | ||
|
||
def dump(value) do | ||
{:ok, value} | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
defmodule Pgvector.Extensions.Halfvec do | ||
import Postgrex.BinaryUtils, warn: false | ||
|
||
def init(opts), do: Keyword.get(opts, :decode_binary, :copy) | ||
|
||
def matching(_), do: [type: "halfvec"] | ||
|
||
def format(_), do: :binary | ||
|
||
def encode(_) do | ||
quote do | ||
vec -> | ||
data = vec |> Pgvector.HalfVector.new() |> Pgvector.to_binary() | ||
[<<IO.iodata_length(data)::int32()>> | data] | ||
end | ||
end | ||
|
||
def decode(:copy) do | ||
quote do | ||
<<len::int32(), bin::binary-size(len)>> -> | ||
bin |> :binary.copy() |> Pgvector.HalfVector.from_binary() | ||
end | ||
end | ||
|
||
def decode(_) do | ||
quote do | ||
<<len::int32(), bin::binary-size(len)>> -> | ||
bin |> Pgvector.HalfVector.from_binary() | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
defmodule Pgvector.HalfVector do | ||
@moduledoc """ | ||
A half vector struct for pgvector | ||
""" | ||
|
||
defstruct [:data] | ||
|
||
@doc """ | ||
Creates a new half vector from a list, tensor, or half vector | ||
""" | ||
def new(list) when is_list(list) do | ||
dim = list |> length() | ||
bin = for v <- list, into: "", do: <<v::float-16>> | ||
from_binary(<<dim::unsigned-16, 0::unsigned-16, bin::binary>>) | ||
end | ||
|
||
def new(%Pgvector.HalfVector{} = vector) do | ||
vector | ||
end | ||
|
||
if Code.ensure_loaded?(Nx) do | ||
def new(tensor) when is_struct(tensor, Nx.Tensor) do | ||
if Nx.rank(tensor) != 1 do | ||
raise ArgumentError, "expected rank to be 1" | ||
end | ||
|
||
dim = tensor |> Nx.size() | ||
bin = tensor |> Nx.as_type(:f16) |> Nx.to_binary() |> f16_native_to_big() | ||
from_binary(<<dim::unsigned-16, 0::unsigned-16, bin::binary>>) | ||
end | ||
|
||
defp f16_native_to_big(binary) do | ||
if System.endianness() == :big do | ||
binary | ||
else | ||
for <<n::float-16-little <- binary>>, into: "", do: <<n::float-16-big>> | ||
end | ||
end | ||
end | ||
|
||
@doc """ | ||
Creates a new half vector from its binary representation | ||
""" | ||
def from_binary(binary) when is_binary(binary) do | ||
%Pgvector.HalfVector{data: binary} | ||
end | ||
end | ||
|
||
defimpl Inspect, for: Pgvector.HalfVector do | ||
import Inspect.Algebra | ||
|
||
def inspect(vec, opts) do | ||
concat(["Pgvector.HalfVector.new(", Inspect.List.inspect(Pgvector.to_list(vec), opts), ")"]) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
defmodule HalfVectorTest do | ||
use ExUnit.Case | ||
|
||
test "half vector" do | ||
vector = Pgvector.HalfVector.new([1, 2, 3]) | ||
assert vector == vector |> Pgvector.HalfVector.new() | ||
end | ||
|
||
test "list" do | ||
list = [1.0, 2.0, 3.0] | ||
assert list == list |> Pgvector.HalfVector.new() |> Pgvector.to_list() | ||
end | ||
|
||
test "tensor" do | ||
tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f16) | ||
assert tensor == tensor |> Pgvector.HalfVector.new() |> Pgvector.to_tensor() | ||
end | ||
|
||
test "inspect" do | ||
vector = Pgvector.HalfVector.new([1, 2, 3]) | ||
assert "Pgvector.HalfVector.new([1.0, 2.0, 3.0])" == inspect(vector) | ||
end | ||
|
||
test "equals" do | ||
assert Pgvector.HalfVector.new([1, 2, 3]) == Pgvector.HalfVector.new([1, 2, 3]) | ||
refute Pgvector.HalfVector.new([1, 2, 3]) == Pgvector.HalfVector.new([1, 2, 4]) | ||
end | ||
end |