Skip to content

Commit

Permalink
Added more functions to SparseVector
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jun 20, 2024
1 parent a619f7d commit 76a0a70
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
42 changes: 40 additions & 2 deletions lib/pgvector/sparse_vector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,51 @@ defmodule Pgvector.SparseVector do
def from_binary(binary) when is_binary(binary) do
%Pgvector.SparseVector{data: binary}
end

@doc """
Returns the dimensions
"""
def dimensions(vector) when is_struct(vector, Pgvector.SparseVector) do
<<dim::signed-32, _::binary>> = vector.data
dim
end

@doc """
Returns the indices
"""
def indices(vector) when is_struct(vector, Pgvector.SparseVector) do
<<_::signed-32, nnz::signed-32, 0::signed-32, indices::binary-size(nnz)-unit(32),
_::binary-size(nnz)-unit(32)>> = vector.data

for <<v::signed-32 <- indices>>, do: v
end

@doc """
Returns the values
"""
def values(vector) when is_struct(vector, Pgvector.SparseVector) do
<<_::signed-32, nnz::signed-32, 0::signed-32, _::binary-size(nnz)-unit(32),
values::binary-size(nnz)-unit(32)>> = vector.data

for <<v::float-32 <- values>>, do: v
end
end

defimpl Inspect, for: Pgvector.SparseVector do
import Inspect.Algebra

def inspect(vector, opts) do
# TODO improve
concat(["Pgvector.SparseVector.new(", Inspect.List.inspect(Pgvector.to_list(vector), opts), ")"])
dimensions = vector |> Pgvector.SparseVector.dimensions()
indices = vector |> Pgvector.SparseVector.indices()
values = vector |> Pgvector.SparseVector.values()
elements = Enum.zip(indices, values) |> Enum.into(%{})

concat([
"Pgvector.SparseVector.new(",
Inspect.Map.inspect(elements, opts),
", ",
Inspect.Integer.inspect(dimensions, opts),
")"
])
end
end
17 changes: 16 additions & 1 deletion test/sparse_vector_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,24 @@ defmodule SparseVectorTest do
assert [1.0, 0.0, 2.0, 0.0, 3.0, 0.0] == map |> Pgvector.SparseVector.new(6) |> Pgvector.to_list()
end

test "dimensions" do
vector = Pgvector.SparseVector.new([1, 2, 3])
assert 3 == vector |> Pgvector.SparseVector.dimensions()
end

test "indices" do
vector = Pgvector.SparseVector.new([1, 2, 3])
assert [0, 1, 2] == vector |> Pgvector.SparseVector.indices()
end

test "values" do
vector = Pgvector.SparseVector.new([1, 2, 3])
assert [1, 2, 3] == vector |> Pgvector.SparseVector.values()
end

test "inspect" do
vector = Pgvector.SparseVector.new([1, 2, 3])
assert "Pgvector.SparseVector.new([1.0, 2.0, 3.0])" == inspect(vector)
assert "Pgvector.SparseVector.new(%{0 => 1.0, 1 => 2.0, 2 => 3.0}, 3)" == inspect(vector)
end

test "equals" do
Expand Down

0 comments on commit 76a0a70

Please sign in to comment.