From 76a0a7052da9f77ba4372a039eff5cb7c66738fb Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 20 Jun 2024 00:08:23 -0700 Subject: [PATCH] Added more functions to SparseVector --- lib/pgvector/sparse_vector.ex | 42 +++++++++++++++++++++++++++++++++-- test/sparse_vector_test.exs | 17 +++++++++++++- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/lib/pgvector/sparse_vector.ex b/lib/pgvector/sparse_vector.ex index 7619a80..53295aa 100644 --- a/lib/pgvector/sparse_vector.ex +++ b/lib/pgvector/sparse_vector.ex @@ -60,13 +60,51 @@ defmodule Pgvector.SparseVector do def from_binary(binary) when is_binary(binary) do %Pgvector.SparseVector{data: binary} end + + @doc """ + Returns the dimensions + """ + def dimensions(vector) when is_struct(vector, Pgvector.SparseVector) do + <> = vector.data + dim + end + + @doc """ + Returns the indices + """ + def indices(vector) when is_struct(vector, Pgvector.SparseVector) do + <<_::signed-32, nnz::signed-32, 0::signed-32, indices::binary-size(nnz)-unit(32), + _::binary-size(nnz)-unit(32)>> = vector.data + + for <>, do: v + end + + @doc """ + Returns the values + """ + def values(vector) when is_struct(vector, Pgvector.SparseVector) do + <<_::signed-32, nnz::signed-32, 0::signed-32, _::binary-size(nnz)-unit(32), + values::binary-size(nnz)-unit(32)>> = vector.data + + for <>, do: v + end end defimpl Inspect, for: Pgvector.SparseVector do import Inspect.Algebra def inspect(vector, opts) do - # TODO improve - concat(["Pgvector.SparseVector.new(", Inspect.List.inspect(Pgvector.to_list(vector), opts), ")"]) + dimensions = vector |> Pgvector.SparseVector.dimensions() + indices = vector |> Pgvector.SparseVector.indices() + values = vector |> Pgvector.SparseVector.values() + elements = Enum.zip(indices, values) |> Enum.into(%{}) + + concat([ + "Pgvector.SparseVector.new(", + Inspect.Map.inspect(elements, opts), + ", ", + Inspect.Integer.inspect(dimensions, opts), + ")" + ]) end end diff --git a/test/sparse_vector_test.exs b/test/sparse_vector_test.exs index 44dd20c..5bf7ed0 100644 --- a/test/sparse_vector_test.exs +++ b/test/sparse_vector_test.exs @@ -21,9 +21,24 @@ defmodule SparseVectorTest do assert [1.0, 0.0, 2.0, 0.0, 3.0, 0.0] == map |> Pgvector.SparseVector.new(6) |> Pgvector.to_list() end + test "dimensions" do + vector = Pgvector.SparseVector.new([1, 2, 3]) + assert 3 == vector |> Pgvector.SparseVector.dimensions() + end + + test "indices" do + vector = Pgvector.SparseVector.new([1, 2, 3]) + assert [0, 1, 2] == vector |> Pgvector.SparseVector.indices() + end + + test "values" do + vector = Pgvector.SparseVector.new([1, 2, 3]) + assert [1, 2, 3] == vector |> Pgvector.SparseVector.values() + end + test "inspect" do vector = Pgvector.SparseVector.new([1, 2, 3]) - assert "Pgvector.SparseVector.new([1.0, 2.0, 3.0])" == inspect(vector) + assert "Pgvector.SparseVector.new(%{0 => 1.0, 1 => 2.0, 2 => 3.0}, 3)" == inspect(vector) end test "equals" do